# SPDX-License-Identifier: MIT # Copyright (c) 2025 LlamaIndex Inc. from __future__ import annotations import asyncio import functools import uuid import warnings from typing import ( TYPE_CHECKING, Any, AsyncGenerator, Generic, Type, TypeVar, cast, ) from pydantic import BaseModel, ValidationError from workflows.context.context_types import SerializedContext from workflows.decorators import StepConfig from workflows.errors import ( ContextSerdeError, WorkflowRuntimeError, ) from workflows.events import ( Event, StartEvent, StopEvent, ) from workflows.runtime.types.internal_state import BrokerState from workflows.runtime.broker import WorkflowBroker from workflows.plugins.basic import basic_runtime from workflows.runtime.types.plugin import Plugin, WorkflowRuntime from workflows.types import RunResultT from workflows.handler import WorkflowHandler from .serializers import BaseSerializer, JsonSerializer from .state_store import MODEL_T, DictState, InMemoryStateStore if TYPE_CHECKING: # pragma: no cover from workflows import Workflow T = TypeVar("T", bound=Event) EventBuffer = dict[str, list[Event]] # Only warn once about unserializable keys class UnserializableKeyWarning(Warning): pass warnings.simplefilter("once", UnserializableKeyWarning) class Context(Generic[MODEL_T]): """ Global, per-run context for a `Workflow`. Provides an interface into the underlying broker run, for both external (workflow run oberservers) and internal consumption by workflow steps. The `Context` coordinates event delivery between steps, tracks in-flight work, exposes a global state store, and provides utilities for streaming and synchronization. It is created by a `Workflow` at run time and can be persisted and restored. Args: workflow (Workflow): The owning workflow instance. Used to infer step configuration and instrumentation. previous_context: A previous context snapshot to resume from. serializer: A serializer to use for serializing and deserializing the current and previous context snapshots. Attributes: is_running (bool): Whether the workflow is currently running. store (InMemoryStateStore[MODEL_T]): Type-safe, async state store shared across steps. See also [InMemoryStateStore][workflows.context.state_store.InMemoryStateStore]. Examples: Basic usage inside a step: ```python from workflows import step from workflows.events import StartEvent, StopEvent @step async def start(self, ctx: Context, ev: StartEvent) -> StopEvent: await ctx.store.set("query", ev.topic) ctx.write_event_to_stream(ev) # surface progress to UI return StopEvent(result="ok") ``` Persisting the state of a workflow across runs: ```python from workflows import Context # Create a context and run the workflow with the same context ctx = Context(my_workflow) result_1 = await my_workflow.run(..., ctx=ctx) result_2 = await my_workflow.run(..., ctx=ctx) # Serialize the context and restore it ctx_dict = ctx.to_dict() restored_ctx = Context.from_dict(my_workflow, ctx_dict) result_3 = await my_workflow.run(..., ctx=restored_ctx) ``` See Also: - [Workflow][workflows.Workflow] - [Event][workflows.events.Event] - [InMemoryStateStore][workflows.context.state_store.InMemoryStateStore] """ # These keys are set by pre-built workflows and # are known to be unserializable in some cases. known_unserializable_keys = ("memory",) # Backing state store; serialized as `state` _state_store: InMemoryStateStore[MODEL_T] _broker_run: WorkflowBroker[MODEL_T] | None _plugin: Plugin _workflow: Workflow def __init__( self, workflow: Workflow, previous_context: dict[str, Any] | None = None, serializer: BaseSerializer | None = None, plugin: Plugin = basic_runtime, ) -> None: self._serializer = serializer or JsonSerializer() self._broker_run = None self._plugin = plugin self._workflow = workflow # parse the serialized context serializer = serializer or JsonSerializer() if previous_context is not None: try: # Auto-detect and convert V0 to V1 if needed previous_context_parsed = SerializedContext.from_dict_auto( previous_context ) # validate it fully parses synchronously to avoid delayed validation errors BrokerState.from_serialized( previous_context_parsed, workflow, serializer ) except ValidationError as e: raise ContextSerdeError( f"Context dict specified in an invalid format: {e}" ) from e else: previous_context_parsed = SerializedContext() self._init_snapshot = previous_context_parsed # initialization of the state store is a bit complex, due to inferring and validating its type from the # provided workflow context args state_types: set[Type[BaseModel]] = set() for _, step_func in workflow._get_steps().items(): step_config: StepConfig = step_func._step_config if ( step_config.context_state_type is not None and step_config.context_state_type != DictState and issubclass(step_config.context_state_type, BaseModel) ): state_type = step_config.context_state_type state_types.add(state_type) if len(state_types) > 1: raise ValueError( "Multiple state types are not supported. Make sure that each Context[...] has the same generic state type. Found: " + ", ".join([state_type.__name__ for state_type in state_types]) ) state_type = state_types.pop() if state_types else DictState if previous_context_parsed.state: # perhaps offer a way to clear on invalid store_state = InMemoryStateStore.from_dict( previous_context_parsed.state, serializer ) if store_state.state_type != state_type: raise ValueError( f"State type mismatch. Workflow context expected {state_type.__name__}, got {store_state.state_type.__name__}" ) self._state_store = cast(InMemoryStateStore[MODEL_T], store_state) else: try: state_instance = cast(MODEL_T, state_type()) self._state_store = InMemoryStateStore(state_instance) except Exception as e: raise WorkflowRuntimeError( f"Failed to initialize state of type {state_type}. Does your state define defaults for all fields? Original error:\n{e}" ) from e @property def is_running(self) -> bool: """Whether the workflow is currently running.""" if self._broker_run is None: return self._init_snapshot.is_running else: return self._broker_run.is_running def _init_broker( self, workflow: Workflow, plugin: WorkflowRuntime | None = None ) -> WorkflowBroker[MODEL_T]: if self._broker_run is not None: raise WorkflowRuntimeError("Broker already initialized") # Initialize a runtime plugin (asyncio-based by default) runtime: WorkflowRuntime = plugin or self._plugin.new_runtime(str(uuid.uuid4())) # Initialize the new broker implementation (broker2) self._broker_run = WorkflowBroker( workflow=workflow, context=self, runtime=runtime, plugin=self._plugin, ) return self._broker_run def _workflow_run( self, workflow: Workflow, start_event: StartEvent | None = None, semaphore: asyncio.Semaphore | None = None, ) -> WorkflowHandler: """ called by package internally from the workflow to run it """ prev_broker: WorkflowBroker[MODEL_T] | None = None if self._broker_run is not None: prev_broker = self._broker_run self._broker_run = None self._broker_run = self._init_broker(workflow) async def before_start() -> None: if prev_broker is not None: try: await prev_broker.shutdown() except Exception: pass if semaphore is not None: await semaphore.acquire() async def after_complete() -> None: if semaphore is not None: semaphore.release() state = BrokerState.from_serialized( self._init_snapshot, workflow, self._serializer ) return self._broker_run.start( workflow=workflow, previous=state, start_event=start_event, before_start=before_start, after_complete=after_complete, ) def _workflow_cancel_run(self) -> None: """ Called internally from the handler to cancel a context's run """ self._running_broker.cancel_run() @property def _running_broker(self) -> WorkflowBroker[MODEL_T]: if self._broker_run is None: raise WorkflowRuntimeError( "Workflow run is not yet running. Make sure to only call this method after the context has been passed to a workflow.run call." ) return self._broker_run @property def store(self) -> InMemoryStateStore[MODEL_T]: """Typed, process-local state store shared across steps. If no state was initialized yet, a default [DictState][workflows.context.state_store.DictState] store is created. Returns: InMemoryStateStore[MODEL_T]: The state store instance. """ return self._state_store def to_dict(self, serializer: BaseSerializer | None = None) -> dict[str, Any]: """Serialize the context to a JSON-serializable dict. Persists the global state store, event queues, buffers, accepted events, broker log, and running flag. This payload can be fed to [from_dict][workflows.context.context.Context.from_dict] to resume a run or carry state across runs. Args: serializer (BaseSerializer | None): Value serializer used for state and event payloads. Defaults to [JsonSerializer][workflows.context.serializers.JsonSerializer]. Returns: dict[str, Any]: A dict suitable for JSON encoding and later restoration via `from_dict`. See Also: - [InMemoryStateStore.to_dict][workflows.context.state_store.InMemoryStateStore.to_dict] Examples: ```python ctx_dict = ctx.to_dict() my_db.set("key", json.dumps(ctx_dict)) ctx_dict = my_db.get("key") restored_ctx = Context.from_dict(my_workflow, json.loads(ctx_dict)) result = await my_workflow.run(..., ctx=restored_ctx) ``` """ serializer = serializer or self._serializer # Serialize state using the state manager's method state_data = {} if self._state_store is not None: state_data = self._state_store.to_dict(serializer) # Get the broker state - either from the running broker or from the init snapshot if self._broker_run is not None: broker_state = self._broker_run._state else: # Deserialize the init snapshot to get a BrokerState, then re-serialize it # This ensures we always output the current format broker_state = BrokerState.from_serialized( self._init_snapshot, self._workflow, serializer ) context = broker_state.to_serialized(serializer) context.state = state_data # mode="python" to support pickling over json if one so chooses. This should perhaps be moved into the serializers return context.model_dump(mode="python") @classmethod def from_dict( cls, workflow: "Workflow", data: dict[str, Any], serializer: BaseSerializer | None = None, ) -> "Context[MODEL_T]": """Reconstruct a `Context` from a serialized payload. Args: workflow (Workflow): The workflow instance that will own this context. data (dict[str, Any]): Payload produced by [to_dict][workflows.context.context.Context.to_dict]. serializer (BaseSerializer | None): Serializer used to decode state and events. Defaults to JSON. Returns: Context[MODEL_T]: A context instance initialized with the persisted state and queues. Raises: ContextSerdeError: If the payload is missing required fields or is in an incompatible format. Examples: ```python ctx_dict = ctx.to_dict() my_db.set("key", json.dumps(ctx_dict)) ctx_dict = my_db.get("key") restored_ctx = Context.from_dict(my_workflow, json.loads(ctx_dict)) result = await my_workflow.run(..., ctx=restored_ctx) ``` """ try: return cls(workflow, previous_context=data, serializer=serializer) except KeyError as e: msg = "Error creating a Context instance: the provided payload has a wrong or old format." raise ContextSerdeError(msg) from e async def running_steps(self) -> list[str]: """Return the list of currently running step names. Returns: list[str]: Names of steps that have at least one active worker. """ return await self._running_broker.running_steps() def collect_events( self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None ) -> list[Event] | None: """ Buffer events until all expected types are available, then return them. This utility is helpful when a step can receive multiple event types and needs to proceed only when it has a full set. The returned list is ordered according to `expected`. Args: ev (Event): The incoming event to add to the buffer. expected (list[Type[Event]]): Event types to collect, in order. buffer_id (str | None): Optional stable key to isolate buffers across steps or workers. Defaults to an internal key derived from the task name or expected types. Returns: list[Event] | None: The events in the requested order when complete, otherwise `None`. Examples: ```python @step async def synthesize( self, ctx: Context, ev: QueryEvent | RetrieveEvent ) -> StopEvent | None: events = ctx.collect_events(ev, [QueryEvent, RetrieveEvent]) if events is None: return None query_ev, retrieve_ev = events # ... proceed with both inputs present ... ``` See Also: - [Event][workflows.events.Event] """ return self._running_broker.collect_events(ev, expected, buffer_id) def send_event(self, message: Event, step: str | None = None) -> None: """Dispatch an event to one or all workflow steps. If `step` is omitted, the event is broadcast to all step queues and non-matching steps will ignore it. When `step` is provided, the target step must accept the event type or a [WorkflowRuntimeError][workflows.errors.WorkflowRuntimeError] is raised. Args: message (Event): The event to enqueue. step (str | None): Optional step name to target. Raises: WorkflowRuntimeError: If the target step does not exist or does not accept the event type. Examples: It's common to use this method to fan-out events: ```python @step async def my_step(self, ctx: Context, ev: StartEvent) -> WorkerEvent | GatherEvent: for i in range(10): ctx.send_event(WorkerEvent(msg=i)) return GatherEvent() ``` You also see this method used from the caller side to send events into the workflow: ```python handler = my_workflow.run(...) async for ev in handler.stream_events(): if isinstance(ev, SomeEvent): handler.ctx.send_event(SomeOtherEvent(msg="Hello!")) result = await handler ``` """ return self._running_broker.send_event(message, step) async def wait_for_event( self, event_type: Type[T], waiter_event: Event | None = None, waiter_id: str | None = None, requirements: dict[str, Any] | None = None, timeout: float | None = 2000, ) -> T: """Wait for the next matching event of type `event_type`. Optionally emits a `waiter_event` to the event stream once per `waiter_id` to inform callers that the workflow is waiting for external input. This helps to prevent duplicate waiter events from being sent to the event stream. Args: event_type (type[T]): Concrete event class to wait for. waiter_event (Event | None): Optional event to write to the stream once when the wait begins. waiter_id (str | None): Stable identifier to avoid emitting multiple waiter events for the same logical wait. requirements (dict[str, Any] | None): Key/value filters that must be satisfied by the event via `event.get(key) == value`. timeout (float | None): Max seconds to wait. `None` means no timeout. Defaults to 2000 seconds. Returns: T: The received event instance of the requested type. Raises: asyncio.TimeoutError: If the timeout elapses. Examples: ```python @step async def my_step(self, ctx: Context, ev: StartEvent) -> StopEvent: response = await ctx.wait_for_event( HumanResponseEvent, waiter_event=InputRequiredEvent(msg="What's your name?"), waiter_id="user_name", timeout=60, ) return StopEvent(result=response.response) ``` """ return await self._running_broker.wait_for_event( event_type, waiter_event, waiter_id, requirements, timeout ) def write_event_to_stream(self, ev: Event | None) -> None: """Enqueue an event for streaming to [WorkflowHandler]](workflows.handler.WorkflowHandler). Args: ev (Event | None): The event to stream. `None` can be used as a sentinel in some streaming modes. Examples: ```python @step async def my_step(self, ctx: Context, ev: StartEvent) -> StopEvent: ctx.write_event_to_stream(ev) return StopEvent(result="ok") ``` """ self._running_broker.write_event_to_stream(ev) def get_result(self) -> RunResultT: """Return the final result of the workflow run. Deprecated: This method is deprecated and will be removed in a future release. Prefer awaiting the handler returned by `Workflow.run`, e.g.: `result = await workflow.run(..., ctx=ctx)`. Examples: ```python # Preferred result = await my_workflow.run(..., ctx=ctx) # Deprecated result_agent = ctx.get_result() ``` Returns: RunResultT: The value provided via a `StopEvent`. """ _warn_get_result() if self._running_broker._handler is None: raise WorkflowRuntimeError("Workflow handler is not set") return self._running_broker._handler.result() def stream_events(self) -> AsyncGenerator[Event, None]: """The internal queue used for streaming events to callers.""" return self._running_broker.stream_published_events() @property def streaming_queue(self) -> asyncio.Queue: """Deprecated queue-based event stream. Returns an asyncio.Queue that is populated by iterating this context's stream_events(). A deprecation warning is emitted once per process. """ _warn_streaming_queue() q: asyncio.Queue[Event] = asyncio.Queue() async def _pump() -> None: async for ev in self.stream_events(): await q.put(ev) if isinstance(ev, StopEvent): break try: asyncio.create_task(_pump()) except RuntimeError: loop = asyncio.get_event_loop() loop.create_task(_pump()) return q @functools.lru_cache(maxsize=1) def _warn_get_result() -> None: warnings.warn( ( "Context.get_result() is deprecated and will be removed in a future " "release. Prefer awaiting the WorkflowHandler returned by " "Workflow.run: `result = await workflow.run(..., ctx=ctx)`." ), DeprecationWarning, stacklevel=2, ) @functools.lru_cache(maxsize=1) def _warn_streaming_queue() -> None: warnings.warn( ( "Context.streaming_queue is deprecated and will be removed in a future " "release. Prefer iterating Context.stream_events(): " "`async for ev in ctx.stream_events(): ...`" ), DeprecationWarning, stacklevel=2, )