ferrero-opentext/Python-Version/venv/lib/python3.12/site-packages/workflows/context/context.py

613 lines
22 KiB
Python

# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.
from __future__ import annotations
import asyncio
import functools
import uuid
import warnings
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generic,
Type,
TypeVar,
cast,
)
from pydantic import BaseModel, ValidationError
from workflows.context.context_types import SerializedContext
from workflows.decorators import StepConfig
from workflows.errors import (
ContextSerdeError,
WorkflowRuntimeError,
)
from workflows.events import (
Event,
StartEvent,
StopEvent,
)
from workflows.runtime.types.internal_state import BrokerState
from workflows.runtime.broker import WorkflowBroker
from workflows.plugins.basic import basic_runtime
from workflows.runtime.types.plugin import Plugin, WorkflowRuntime
from workflows.types import RunResultT
from workflows.handler import WorkflowHandler
from .serializers import BaseSerializer, JsonSerializer
from .state_store import MODEL_T, DictState, InMemoryStateStore
if TYPE_CHECKING: # pragma: no cover
from workflows import Workflow
# Type variable constrained to `Event` subclasses; `Context.wait_for_event`
# uses it so the returned event has the same class that was requested.
T = TypeVar("T", bound=Event)
# Alias for a mapping from a string key to a list of events.
# NOTE(review): not referenced in the visible part of this module - presumably
# kept for importers; confirm before removing.
EventBuffer = dict[str, list[Event]]
class UnserializableKeyWarning(Warning):
    """Warning category reserved for unserializable context keys."""


# Only warn once about unserializable keys
warnings.simplefilter("once", UnserializableKeyWarning)
class Context(Generic[MODEL_T]):
"""
Global, per-run context for a `Workflow`. Provides an interface into the
underlying broker run, for both external (workflow run observers) and
internal consumption by workflow steps.
The `Context` coordinates event delivery between steps, tracks in-flight work,
exposes a global state store, and provides utilities for streaming and
synchronization. It is created by a `Workflow` at run time and can be
persisted and restored.
Args:
workflow (Workflow): The owning workflow instance. Used to infer
step configuration and instrumentation.
previous_context: A previous context snapshot to resume from.
serializer: A serializer to use for serializing and deserializing the current and previous context snapshots.
Attributes:
is_running (bool): Whether the workflow is currently running.
store (InMemoryStateStore[MODEL_T]): Type-safe, async state store shared
across steps. See also
[InMemoryStateStore][workflows.context.state_store.InMemoryStateStore].
Examples:
Basic usage inside a step:
```python
from workflows import step
from workflows.events import StartEvent, StopEvent
@step
async def start(self, ctx: Context, ev: StartEvent) -> StopEvent:
await ctx.store.set("query", ev.topic)
ctx.write_event_to_stream(ev) # surface progress to UI
return StopEvent(result="ok")
```
Persisting the state of a workflow across runs:
```python
from workflows import Context
# Create a context and run the workflow with the same context
ctx = Context(my_workflow)
result_1 = await my_workflow.run(..., ctx=ctx)
result_2 = await my_workflow.run(..., ctx=ctx)
# Serialize the context and restore it
ctx_dict = ctx.to_dict()
restored_ctx = Context.from_dict(my_workflow, ctx_dict)
result_3 = await my_workflow.run(..., ctx=restored_ctx)
```
See Also:
- [Workflow][workflows.Workflow]
- [Event][workflows.events.Event]
- [InMemoryStateStore][workflows.context.state_store.InMemoryStateStore]
"""
# These keys are set by pre-built workflows and
# are known to be unserializable in some cases.
known_unserializable_keys = ("memory",)
# Backing state store; serialized as `state`
_state_store: InMemoryStateStore[MODEL_T]
# Broker of the current run; `None` until `_init_broker` creates one.
_broker_run: WorkflowBroker[MODEL_T] | None
# Plugin asked for a new runtime whenever a broker is initialized without one.
_plugin: Plugin
# Workflow this context was constructed for; also used by `to_dict` to
# rebuild the broker state from the initial snapshot.
_workflow: Workflow
def __init__(
    self,
    workflow: Workflow,
    previous_context: dict[str, Any] | None = None,
    serializer: BaseSerializer | None = None,
    plugin: Plugin = basic_runtime,
) -> None:
    """Create a context for `workflow`, optionally resuming from a snapshot.

    Args:
        workflow (Workflow): The owning workflow instance. Its steps are
            inspected to infer the type of the state store.
        previous_context (dict[str, Any] | None): A previous context
            snapshot to resume from. `None` starts from an empty snapshot.
        serializer (BaseSerializer | None): Serializer used to decode the
            snapshot and, later, to encode this context. Defaults to
            `JsonSerializer`.
        plugin (Plugin): Plugin asked for a new runtime when a broker is
            initialized for this context.

    Raises:
        ContextSerdeError: If `previous_context` fails validation.
        ValueError: If the workflow steps declare more than one state type,
            or the restored state has a different type than the one the
            workflow expects.
        WorkflowRuntimeError: If the state type cannot be instantiated
            without arguments.
    """
    self._serializer = serializer or JsonSerializer()
    self._broker_run = None
    self._plugin = plugin
    self._workflow = workflow

    # Reuse the serializer stored on the instance so that parsing the snapshot
    # and later serialization go through the very same object.
    serializer = self._serializer

    # parse the serialized context
    if previous_context is not None:
        try:
            # Auto-detect and convert V0 to V1 if needed
            previous_context_parsed = SerializedContext.from_dict_auto(
                previous_context
            )
            # validate it fully parses synchronously to avoid delayed validation errors
            BrokerState.from_serialized(
                previous_context_parsed, workflow, serializer
            )
        except ValidationError as e:
            raise ContextSerdeError(
                f"Context dict specified in an invalid format: {e}"
            ) from e
    else:
        previous_context_parsed = SerializedContext()

    self._init_snapshot = previous_context_parsed

    # initialization of the state store is a bit complex, due to inferring and validating its type from the
    # provided workflow context args
    state_types: set[Type[BaseModel]] = set()
    for step_func in workflow._get_steps().values():
        step_config: StepConfig = step_func._step_config
        declared_type = step_config.context_state_type
        if (
            declared_type is not None
            and declared_type != DictState
            and issubclass(declared_type, BaseModel)
        ):
            state_types.add(declared_type)

    if len(state_types) > 1:
        raise ValueError(
            "Multiple state types are not supported. Make sure that each Context[...] has the same generic state type. Found: "
            + ", ".join(state_type.__name__ for state_type in state_types)
        )
    # Steps that declare no custom state type fall back to DictState
    state_type = state_types.pop() if state_types else DictState

    if previous_context_parsed.state:
        # perhaps offer a way to clear on invalid
        store_state = InMemoryStateStore.from_dict(
            previous_context_parsed.state, serializer
        )
        if store_state.state_type != state_type:
            raise ValueError(
                f"State type mismatch. Workflow context expected {state_type.__name__}, got {store_state.state_type.__name__}"
            )
        self._state_store = cast(InMemoryStateStore[MODEL_T], store_state)
    else:
        try:
            # The state model is built without arguments, so every field needs a default
            state_instance = cast(MODEL_T, state_type())
            self._state_store = InMemoryStateStore(state_instance)
        except Exception as e:
            raise WorkflowRuntimeError(
                f"Failed to initialize state of type {state_type}. Does your state define defaults for all fields? Original error:\n{e}"
            ) from e
@property
def is_running(self) -> bool:
    """Whether the workflow is currently running.

    Once a broker exists for this context its own flag is reported; before
    that, the flag recorded in the initial snapshot is used.
    """
    broker = self._broker_run
    if broker is not None:
        return broker.is_running
    return self._init_snapshot.is_running
def _init_broker(
    self, workflow: Workflow, plugin: WorkflowRuntime | None = None
) -> WorkflowBroker[MODEL_T]:
    """Create the broker for a new run and store it on this context.

    Args:
        workflow (Workflow): The workflow the broker will execute.
        plugin (WorkflowRuntime | None): Runtime to use for the broker. When
            omitted, a new runtime is requested from the context's plugin
            under a freshly generated UUID.

    Returns:
        WorkflowBroker[MODEL_T]: The newly created broker.

    Raises:
        WorkflowRuntimeError: If this context already holds a broker.
    """
    if self._broker_run is not None:
        raise WorkflowRuntimeError("Broker already initialized")

    # Initialize a runtime plugin (asyncio-based by default)
    if plugin:
        runtime: WorkflowRuntime = plugin
    else:
        runtime = self._plugin.new_runtime(str(uuid.uuid4()))

    # Initialize the new broker implementation (broker2)
    broker = WorkflowBroker(
        workflow=workflow,
        context=self,
        runtime=runtime,
        plugin=self._plugin,
    )
    self._broker_run = broker
    return broker
def _workflow_run(
    self,
    workflow: Workflow,
    start_event: StartEvent | None = None,
    semaphore: asyncio.Semaphore | None = None,
) -> WorkflowHandler:
    """Start a run of `workflow` on this context.

    Called by the package internally from the workflow to run it. Any broker
    left over from an earlier run is detached and replaced by a new one; the
    old broker is shut down from the `before_start` hook handed to the new
    broker.

    Args:
        workflow (Workflow): The workflow to run.
        start_event (StartEvent | None): Event forwarded to the broker's
            `start` call.
        semaphore (asyncio.Semaphore | None): When given, it is acquired in
            the `before_start` hook and released in the `after_complete`
            hook.

    Returns:
        WorkflowHandler: The handler returned by the broker's `start` call.
    """
    # Local import: `logging` is not among this module's top-level imports.
    import logging

    # Detach the previous broker (if any) so `_init_broker` accepts a new one.
    prev_broker: WorkflowBroker[MODEL_T] | None = self._broker_run
    self._broker_run = None
    self._broker_run = self._init_broker(workflow)

    async def before_start() -> None:
        if prev_broker is not None:
            try:
                await prev_broker.shutdown()
            except Exception:
                # Shutting down the old broker is best effort and must not
                # prevent the new run, but leave a trace for debugging.
                logging.getLogger(__name__).debug(
                    "Ignoring error while shutting down the previous broker",
                    exc_info=True,
                )
        if semaphore is not None:
            await semaphore.acquire()

    async def after_complete() -> None:
        if semaphore is not None:
            semaphore.release()

    # The run always resumes from the snapshot captured at construction time
    state = BrokerState.from_serialized(
        self._init_snapshot, workflow, self._serializer
    )
    return self._broker_run.start(
        workflow=workflow,
        previous=state,
        start_event=start_event,
        before_start=before_start,
        after_complete=after_complete,
    )
def _workflow_cancel_run(self) -> None:
    """Cancel the run attached to this context.

    Called internally from the handler. A `WorkflowRuntimeError` propagates
    from `_running_broker` when no run has been started yet.
    """
    broker = self._running_broker
    broker.cancel_run()
@property
def _running_broker(self) -> WorkflowBroker[MODEL_T]:
    """Return the broker of the current run.

    Raises:
        WorkflowRuntimeError: If no broker has been created for this context
            yet.
    """
    broker = self._broker_run
    if broker is not None:
        return broker
    raise WorkflowRuntimeError(
        "Workflow run is not yet running. Make sure to only call this method after the context has been passed to a workflow.run call."
    )
@property
def store(self) -> InMemoryStateStore[MODEL_T]:
    """Typed, process-local state store shared across steps.

    The store is set up when the context is constructed: it is either
    restored from the previous snapshot or built from the state type inferred
    from the workflow steps, falling back to
    [DictState][workflows.context.state_store.DictState].

    Returns:
        InMemoryStateStore[MODEL_T]: The state store instance.
    """
    state_store = self._state_store
    return state_store
def to_dict(self, serializer: BaseSerializer | None = None) -> dict[str, Any]:
    """Serialize the context to a JSON-serializable dict.

    The payload holds the global state store together with the broker state
    (event queues, buffers, accepted events, broker log, and running flag).
    Feed it to [from_dict][workflows.context.context.Context.from_dict] to
    resume a run or carry state across runs.

    Args:
        serializer (BaseSerializer | None): Value serializer used for state
            and event payloads. Defaults to the serializer this context was
            created with.

    Returns:
        dict[str, Any]: A dict suitable for JSON encoding and later
            restoration via `from_dict`.

    See Also:
        - [InMemoryStateStore.to_dict][workflows.context.state_store.InMemoryStateStore.to_dict]

    Examples:
        ```python
        ctx_dict = ctx.to_dict()
        my_db.set("key", json.dumps(ctx_dict))

        ctx_dict = my_db.get("key")
        restored_ctx = Context.from_dict(my_workflow, json.loads(ctx_dict))
        result = await my_workflow.run(..., ctx=restored_ctx)
        ```
    """
    if not serializer:
        serializer = self._serializer

    # Serialize state using the state manager's method
    state_store = self._state_store
    state_data = state_store.to_dict(serializer) if state_store is not None else {}

    # Get the broker state - either from the running broker or from the init snapshot
    broker = self._broker_run
    if broker is None:
        # Round-trip the init snapshot through BrokerState so that the output
        # is always in the current format
        broker_state = BrokerState.from_serialized(
            self._init_snapshot, self._workflow, serializer
        )
    else:
        broker_state = broker._state

    snapshot = broker_state.to_serialized(serializer)
    snapshot.state = state_data
    # mode="python" to support pickling over json if one so chooses. This should perhaps be moved into the serializers
    return snapshot.model_dump(mode="python")
@classmethod
def from_dict(
    cls,
    workflow: "Workflow",
    data: dict[str, Any],
    serializer: BaseSerializer | None = None,
) -> "Context[MODEL_T]":
    """Reconstruct a `Context` from a serialized payload.

    Args:
        workflow (Workflow): The workflow instance that will own this
            context.
        data (dict[str, Any]): Payload produced by
            [to_dict][workflows.context.context.Context.to_dict].
        serializer (BaseSerializer | None): Serializer used to decode state
            and events. Defaults to JSON.

    Returns:
        Context[MODEL_T]: A context instance initialized with the persisted
            state and queues.

    Raises:
        ContextSerdeError: If the payload is missing required fields or is
            in an incompatible format.

    Examples:
        ```python
        ctx_dict = ctx.to_dict()
        my_db.set("key", json.dumps(ctx_dict))

        ctx_dict = my_db.get("key")
        restored_ctx = Context.from_dict(my_workflow, json.loads(ctx_dict))
        result = await my_workflow.run(..., ctx=restored_ctx)
        ```
    """
    try:
        restored = cls(workflow, previous_context=data, serializer=serializer)
    except KeyError as e:
        # A missing key means the payload does not follow a known layout
        raise ContextSerdeError(
            "Error creating a Context instance: the provided payload has a wrong or old format."
        ) from e
    return restored
async def running_steps(self) -> list[str]:
    """Return the list of currently running step names.

    Returns:
        list[str]: Names of steps that have at least one active worker.

    Raises:
        WorkflowRuntimeError: If no run has been started on this context.
    """
    broker = self._running_broker
    step_names = await broker.running_steps()
    return step_names
def collect_events(
    self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
) -> list[Event] | None:
    """Buffer events until all expected types are available, then return them.

    Useful when a step can receive multiple event types and may only proceed
    once it holds a full set. The returned list is ordered according to
    `expected`.

    Args:
        ev (Event): The incoming event to add to the buffer.
        expected (list[Type[Event]]): Event types to collect, in order.
        buffer_id (str | None): Optional stable key to isolate buffers across
            steps or workers. Defaults to an internal key derived from the
            task name or expected types.

    Returns:
        list[Event] | None: The events in the requested order when complete,
            otherwise `None`.

    Examples:
        ```python
        @step
        async def synthesize(
            self, ctx: Context, ev: QueryEvent | RetrieveEvent
        ) -> StopEvent | None:
            events = ctx.collect_events(ev, [QueryEvent, RetrieveEvent])
            if events is None:
                return None
            query_ev, retrieve_ev = events
            # ... proceed with both inputs present ...
        ```

    See Also:
        - [Event][workflows.events.Event]
    """
    broker = self._running_broker
    collected = broker.collect_events(ev, expected, buffer_id)
    return collected
def send_event(self, message: Event, step: str | None = None) -> None:
    """Dispatch an event to one or all workflow steps.

    Without `step`, the event is broadcast to all step queues and
    non-matching steps will ignore it. With `step`, the target step must
    accept the event type or a
    [WorkflowRuntimeError][workflows.errors.WorkflowRuntimeError] is raised.

    Args:
        message (Event): The event to enqueue.
        step (str | None): Optional step name to target.

    Raises:
        WorkflowRuntimeError: If the target step does not exist or does not
            accept the event type.

    Examples:
        It's common to use this method to fan-out events:

        ```python
        @step
        async def my_step(self, ctx: Context, ev: StartEvent) -> WorkerEvent | GatherEvent:
            for i in range(10):
                ctx.send_event(WorkerEvent(msg=i))
            return GatherEvent()
        ```

        You also see this method used from the caller side to send events into the workflow:

        ```python
        handler = my_workflow.run(...)
        async for ev in handler.stream_events():
            if isinstance(ev, SomeEvent):
                handler.ctx.send_event(SomeOtherEvent(msg="Hello!"))

        result = await handler
        ```
    """
    broker = self._running_broker
    return broker.send_event(message, step)
async def wait_for_event(
    self,
    event_type: Type[T],
    waiter_event: Event | None = None,
    waiter_id: str | None = None,
    requirements: dict[str, Any] | None = None,
    timeout: float | None = 2000,
) -> T:
    """Wait for the next matching event of type `event_type`.

    Optionally emits a `waiter_event` to the event stream once per
    `waiter_id`, telling callers that the workflow is waiting for external
    input while avoiding duplicate waiter events on the stream.

    Args:
        event_type (type[T]): Concrete event class to wait for.
        waiter_event (Event | None): Optional event to write to the stream
            once when the wait begins.
        waiter_id (str | None): Stable identifier to avoid emitting multiple
            waiter events for the same logical wait.
        requirements (dict[str, Any] | None): Key/value filters that must be
            satisfied by the event via `event.get(key) == value`.
        timeout (float | None): Max seconds to wait. `None` means no
            timeout. Defaults to 2000 seconds.

    Returns:
        T: The received event instance of the requested type.

    Raises:
        asyncio.TimeoutError: If the timeout elapses.

    Examples:
        ```python
        @step
        async def my_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
            response = await ctx.wait_for_event(
                HumanResponseEvent,
                waiter_event=InputRequiredEvent(msg="What's your name?"),
                waiter_id="user_name",
                timeout=60,
            )
            return StopEvent(result=response.response)
        ```
    """
    broker = self._running_broker
    received = await broker.wait_for_event(
        event_type, waiter_event, waiter_id, requirements, timeout
    )
    return received
def write_event_to_stream(self, ev: Event | None) -> None:
    """Enqueue an event for streaming to [WorkflowHandler][workflows.handler.WorkflowHandler].

    Args:
        ev (Event | None): The event to stream. `None` can be used as a
            sentinel in some streaming modes.

    Examples:
        ```python
        @step
        async def my_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
            ctx.write_event_to_stream(ev)
            return StopEvent(result="ok")
        ```
    """
    broker = self._running_broker
    broker.write_event_to_stream(ev)
def get_result(self) -> RunResultT:
    """Return the final result of the workflow run.

    Deprecated:
        This method is deprecated and will be removed in a future release.
        Prefer awaiting the handler returned by `Workflow.run`, e.g.:
        `result = await workflow.run(..., ctx=ctx)`.

    Examples:
        ```python
        # Preferred
        result = await my_workflow.run(..., ctx=ctx)

        # Deprecated
        result_agent = ctx.get_result()
        ```

    Returns:
        RunResultT: The value provided via a `StopEvent`.

    Raises:
        WorkflowRuntimeError: If no run has been started, or the running
            broker has no handler set.
    """
    _warn_get_result()
    handler = self._running_broker._handler
    if handler is None:
        raise WorkflowRuntimeError("Workflow handler is not set")
    return handler.result()
def stream_events(self) -> AsyncGenerator[Event, None]:
    """Return an async generator over the events published by the current run.

    Raises:
        WorkflowRuntimeError: If no run has been started on this context.
    """
    broker = self._running_broker
    return broker.stream_published_events()
@property
def streaming_queue(self) -> asyncio.Queue:
    """Deprecated queue-based event stream.

    Returns an asyncio.Queue that is populated by iterating this context's
    stream_events(). A deprecation warning is emitted once per process.

    Every access creates a new queue and schedules a new background task
    that copies events into it; the task stops after forwarding a
    `StopEvent`.

    Returns:
        asyncio.Queue: Queue receiving the streamed events.
    """
    _warn_streaming_queue()
    q: asyncio.Queue[Event] = asyncio.Queue()

    async def _pump() -> None:
        async for ev in self.stream_events():
            await q.put(ev)
            if isinstance(ev, StopEvent):
                break

    # Pick the loop *before* creating the coroutine: if the lookup of a
    # running loop fails, no un-awaited coroutine object is left behind.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.get_event_loop()
    task = loop.create_task(_pump())

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the pump finishes to keep it from being collected.
    pump_tasks = getattr(self, "_streaming_queue_tasks", None)
    if pump_tasks is None:
        pump_tasks = set()
        self._streaming_queue_tasks = pump_tasks
    pump_tasks.add(task)
    task.add_done_callback(pump_tasks.discard)
    return q
@functools.lru_cache(maxsize=1)
def _warn_get_result() -> None:
warnings.warn(
(
"Context.get_result() is deprecated and will be removed in a future "
"release. Prefer awaiting the WorkflowHandler returned by "
"Workflow.run: `result = await workflow.run(..., ctx=ctx)`."
),
DeprecationWarning,
stacklevel=2,
)
@functools.lru_cache(maxsize=1)
def _warn_streaming_queue() -> None:
warnings.warn(
(
"Context.streaming_queue is deprecated and will be removed in a future "
"release. Prefer iterating Context.stream_events(): "
"`async for ev in ctx.stream_events(): ...`"
),
DeprecationWarning,
stacklevel=2,
)