# Source: ferrero-opentext/Python-Version/venv/lib/python3.12/site-packages/workflows/runtime/broker.py
# (344 lines, 12 KiB, Python)

# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.
from __future__ import annotations
import asyncio
import logging
from collections import Counter, defaultdict
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Generic,
Type,
TypeVar,
cast,
)
from llama_index_instrumentation.dispatcher import (
active_instrument_tags,
instrument_tags,
)
from workflows.utils import _nanoid as nanoid
from workflows.errors import WorkflowRuntimeError
from workflows.events import (
Event,
StartEvent,
)
from workflows.runtime.control_loop import control_loop, rebuild_state_from_ticks
from workflows.runtime.types.internal_state import BrokerState
from workflows.runtime.types.plugin import Plugin, WorkflowRuntime, as_snapshottable
from workflows.runtime.types.results import (
AddCollectedEvent,
AddWaiter,
DeleteCollectedEvent,
DeleteWaiter,
StepWorkerContext,
StepWorkerStateContextVar,
WaitingForEvent,
)
from workflows.runtime.types.step_function import (
StepWorkerFunction,
as_step_worker_function,
)
from workflows.runtime.types.ticks import TickAddEvent, TickCancelRun, WorkflowTick
from workflows.runtime.workflow_registry import workflow_registry
from ..context.state_store import MODEL_T
from workflows.handler import WorkflowHandler
if TYPE_CHECKING:
from workflows import Workflow
from workflows.context.context import Context
# Type variable bound to Event; WorkflowBroker.wait_for_event uses it so the
# returned event is typed as the requested ``event_type``.
T = TypeVar("T", bound=Event)
# Alias for a mapping from a string key to a list of events.
# NOTE(review): not referenced in the visible part of this module — presumably
# kept for importers; confirm before removing.
EventBuffer = dict[str, list[Event]]
# Module-scoped logger: records carry this module's name and can be filtered or
# configured per module, instead of being emitted directly on the root logger.
logger = logging.getLogger(__name__)
# NOTE(review): the visible part of this module neither emits nor filters this
# category; the "warn once" policy is presumably configured elsewhere — confirm.
class UnserializableKeyWarning(Warning):
    """Warning category for unserializable keys."""
class WorkflowBroker(Generic[MODEL_T]):
    """
    The workflow broker manages starting up and connecting a workflow handler, a runtime, and triggering the
    execution of the workflow. From there it manages communication between the workflow and the outside world.
    """

    _context: Context[MODEL_T]
    _runtime: WorkflowRuntime
    _plugin: Plugin
    _is_running: bool
    _handler: WorkflowHandler | None
    _workflow: Workflow
    # transient tasks to run async ops in background, exposing sync interfaces
    _workers: list[asyncio.Task]
    _init_state: BrokerState | None

    def __init__(
        self,
        workflow: Workflow,
        context: Context[MODEL_T],
        runtime: WorkflowRuntime,
        plugin: Plugin,
    ) -> None:
        """Store the collaborators; nothing is started until ``start`` is called."""
        self._context = context
        self._runtime = runtime
        self._plugin = plugin
        self._is_running = False
        self._handler = None
        self._workflow = workflow
        self._workers = []
        self._init_state = None

    def _execute_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
        """Schedule ``coro`` as a background task and track it in ``_workers``.

        The task is removed from ``_workers`` again when it finishes. Uses
        ``asyncio.create_task``, so a running event loop is required.
        """
        task = asyncio.create_task(coro)
        self._workers.append(task)

        def _forget(_: asyncio.Task[Any]) -> None:
            # shutdown() clears ``_workers`` right after requesting
            # cancellation, i.e. before the cancelled tasks' done-callbacks
            # run, so the task may legitimately be gone already.
            try:
                self._workers.remove(task)
            except ValueError:
                pass

        task.add_done_callback(_forget)
        return task

    # context API only
    def start(
        self,
        workflow: Workflow,
        previous: BrokerState | None = None,
        start_event: StartEvent | None = None,
        before_start: Callable[[], Awaitable[None]] | None = None,
        after_complete: Callable[[], Awaitable[None]] | None = None,
    ) -> WorkflowHandler:
        """Start the workflow run. Can only be called once.

        Args:
            workflow: Workflow whose steps are registered and executed.
            previous: State to start from. When falsy, a fresh state is built
                with ``BrokerState.from_workflow(workflow)``.
            start_event: Passed as-is to the registered workflow function.
            before_start: Awaited inside the run task before the workflow
                function is invoked. An exception raised here propagates
                through the run task; ``after_complete`` is not called then.
            after_complete: Awaited once the run has finished, whether it
                succeeded or failed.

        Returns:
            The handler wrapping the background run task.

        Raises:
            WorkflowRuntimeError: If ``start`` has already been called.
        """
        if self._handler is not None:
            raise WorkflowRuntimeError(
                "this WorkflowBroker already run or running. Cannot start again."
            )

        self._init_state = previous

        async def _run_workflow(run_id: str, tags: dict[str, Any]) -> None:
            with instrument_tags({"run_id": run_id, **tags}):
                # defer execution to make sure the task can be captured and passed
                # to the handler as async exception, protecting against exceptions from before_start
                self._is_running = True
                await asyncio.sleep(0)
                if before_start is not None:
                    await before_start()
                try:
                    init_state = previous or BrokerState.from_workflow(workflow)
                    try:
                        # Avoid capturing a bound method (which retains the instance).
                        # If it's a bound method, extract the unbound function from the class.
                        step_workers: dict[str, StepWorkerFunction] = {
                            name: as_step_worker_function(
                                getattr(step_func, "__func__", step_func)
                            )
                            for name, step_func in workflow._get_steps().items()
                        }
                        registered = workflow_registry.get_registered_workflow(
                            workflow, self._plugin, control_loop, step_workers
                        )
                        # Register run context prior to invoking control loop.
                        # The registry's ``plugin`` argument receives the runtime here.
                        workflow_registry.register_run(
                            run_id=run_id,
                            workflow=workflow,
                            plugin=self._runtime,
                            context=self._context,  # type: ignore
                            steps=registered.steps,
                        )
                        try:
                            workflow_result = await registered.workflow_function(
                                start_event,
                                init_state,
                                run_id,
                            )
                        finally:
                            # ensure run context is cleaned up even on failure
                            workflow_registry.delete_run(run_id)
                        result._set_stop_event(workflow_result)
                    except Exception as exc:
                        # Surface the failure through the handler unless it has
                        # already been resolved. ``result`` is bound by the time
                        # this runs because of the ``sleep(0)`` deferral above.
                        if not result.done():
                            result.set_exception(exc)
                finally:
                    if after_complete is not None:
                        await after_complete()
                    self._is_running = False

        run_id = nanoid()
        # Instrumentation tags are read here, in the caller, and re-applied
        # inside the task via ``instrument_tags``.
        run_task = self._execute_task(
            _run_workflow(run_id, tags=active_instrument_tags.get())
        )
        result = WorkflowHandler(
            ctx=self._context,  # type: ignore
            run_id=run_id,
            run_task=run_task,
        )
        self._handler = result
        return result

    # outer handler API to cancel the workflow run
    def cancel_run(self) -> None:
        """Send a ``TickCancelRun`` to the runtime from a background task."""
        self._execute_task(self._runtime.send_event(TickCancelRun()))

    @property
    def is_running(self) -> bool:
        """Whether the run task started by ``start`` is currently executing."""
        return self._is_running

    @property
    def _state(self) -> BrokerState:
        """Current state, rebuilt by replaying the tick log over the initial state.

        Every access performs a full replay, so callers should read it once
        and reuse the result.
        """
        initial = self._init_state or BrokerState.from_workflow(self._workflow)
        return rebuild_state_from_ticks(initial, self._tick_log)

    @property
    def _tick_log(self) -> list[WorkflowTick]:
        """Ticks replayed from the runtime.

        Raises:
            WorkflowRuntimeError: If the runtime is not snapshottable.
        """
        snapshottable = as_snapshottable(self._runtime)
        if snapshottable is None:
            raise WorkflowRuntimeError("Plugin is not snapshottable")
        return snapshottable.replay()

    # mostly a debug API. May be removed in the future.
    async def running_steps(self) -> list[str]:
        """Return the names of the steps whose worker state is ``in_progress``."""
        # Read the state once: each ``_state`` access replays the whole tick
        # log, and a single snapshot keeps the answer self-consistent.
        workers = self._state.workers
        return [name for name, worker in workers.items() if worker.in_progress]

    # step api only
    def collect_events(
        self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
    ) -> list[Event] | None:
        """Buffer ``ev`` until one event of every type in ``expected`` is available.

        Args:
            ev: The event that just arrived.
            expected: Event types to collect; duplicates mean that many
                events of that type are required.
            buffer_id: Key of the buffer to collect into; ``"default"`` when
                omitted.

        Returns:
            ``None`` while the collection is incomplete, otherwise the
            collected events (including ``ev``) ordered like ``expected``.

        Raises:
            WorkflowRuntimeError: If called outside of a step function.
        """
        step_ctx = self._get_step_ctx(fn="collect_events")
        buffer_id = buffer_id or "default"
        collected_events = step_ctx.state.collected_events.get(buffer_id, [])
        remaining_event_types = Counter(expected) - Counter(
            type(e) for e in collected_events
        )

        if remaining_event_types != Counter([type(ev)]):
            # ``ev`` does not complete the set. Buffer it only if its type is
            # still needed; either way there is nothing to return yet.
            if type(ev) in remaining_event_types:
                step_ctx.returns.return_values.append(
                    AddCollectedEvent(event_id=buffer_id, event=ev)
                )
            return None

        by_type: defaultdict[type, list[Event]] = defaultdict(list)
        for e in collected_events + [ev]:
            by_type[type(e)].append(e)
        # order by expected type
        total: list[Event] = [by_type[e_type].pop(0) for e_type in expected]
        # if we got here, it means the collection is fulfilled. Clear the collected events when the step is complete
        step_ctx.returns.return_values.append(DeleteCollectedEvent(event_id=buffer_id))
        return total

    # may be called from both step API and outer handler API
    def send_event(self, message: Event, step: str | None = None) -> None:
        """Send ``message`` to the runtime, optionally targeting a single step.

        Raises:
            WorkflowRuntimeError: If ``step`` is given but does not exist, or
                does not accept events of ``message``'s exact type.
        """
        if step is not None:
            steps = self._workflow._get_steps()
            if step not in steps:
                raise WorkflowRuntimeError(f"Step {step} does not exist")
            # Validate that the step accepts this event type
            step_config = steps[step]._step_config
            if type(message) not in step_config.accepted_events:
                raise WorkflowRuntimeError(
                    f"Step {step} does not accept event of type {type(message)}"
                )
        self._execute_task(
            self._runtime.send_event(TickAddEvent(event=message, step_name=step))
        )

    def _get_step_ctx(self, fn: str) -> StepWorkerContext:
        """Return the current step worker context.

        Args:
            fn: Name of the calling API, used in the error message.

        Raises:
            WorkflowRuntimeError: If no step context is set, i.e. the caller
                is not running inside a step function.
        """
        try:
            return StepWorkerStateContextVar.get()
        except LookupError:
            # The LookupError carries no useful detail; suppress it so the
            # traceback shows only the actionable error.
            raise WorkflowRuntimeError(
                f"{fn} may only be called from within a step function"
            ) from None

    # step api only
    async def wait_for_event(
        self,
        event_type: Type[T],
        waiter_event: Event | None = None,
        waiter_id: str | None = None,
        requirements: dict[str, Any] | None = None,
        timeout: float | None = 2000,
    ) -> T:
        """Return the event that resolved this waiter, or register the waiter.

        Args:
            event_type: Type of the event to wait for.
            waiter_event: Forwarded unchanged to ``AddWaiter``.
            waiter_id: Identifier of the waiter. Defaults to a key derived
                from ``event_type`` and ``requirements``.
            requirements: Forwarded to ``AddWaiter``; ``{}`` when omitted.
            timeout: Forwarded to ``AddWaiter``.
                NOTE(review): units are presumably seconds — confirm.

        Returns:
            The resolved event, once the waiter with this id has one.

        Raises:
            WaitingForEvent: When no resolved waiter with this id exists yet;
                carries the ``AddWaiter`` request.
            WorkflowRuntimeError: If called outside of a step function.
        """
        step_ctx = self._get_step_ctx(fn="wait_for_event")
        requirements = requirements or {}
        # Generate a unique key for the waiter
        waiter_id = waiter_id or (
            f"waiter_{self._get_full_path(event_type)}_{str(requirements)}"
        )
        waiter = next(
            (w for w in step_ctx.state.collected_waiters if w.waiter_id == waiter_id),
            None,
        )
        if waiter is None or waiter.resolved_event is None:
            raise WaitingForEvent(
                AddWaiter(
                    waiter_id=waiter_id,
                    requirements=requirements,
                    timeout=timeout,
                    event_type=event_type,
                    waiter_event=waiter_event,
                )
            )
        step_ctx.returns.return_values.append(DeleteWaiter(waiter_id=waiter_id))
        return cast(T, waiter.resolved_event)

    def _get_full_path(self, ev_type: Type[Event]) -> str:
        """Return ``"<module>.<name>"`` for ``ev_type``."""
        return f"{ev_type.__module__}.{ev_type.__name__}"

    def stream_published_events(self) -> AsyncGenerator[Event, None]:
        """Return the runtime's stream of published events."""
        return self._runtime.stream_published_events()

    # step API only
    def write_event_to_stream(self, ev: Event | None) -> None:
        """Write ``ev`` to the runtime's event stream; ``None`` is ignored."""
        if ev is not None:
            self._execute_task(self._runtime.write_to_event_stream(ev))

    async def shutdown(self) -> None:
        """Cancel the running workflow loop and close the runtime.

        Sends a ``TickCancelRun`` to the runtime, requests cancellation of
        every tracked background task, stops tracking them, and closes the
        runtime. The cancelled tasks are not awaited here, so they may still
        be unwinding when this method returns.
        """
        await self._runtime.send_event(TickCancelRun())
        for worker in self._workers:
            worker.cancel()
        self._workers.clear()
        await self._runtime.close()