344 lines
12 KiB
Python
344 lines
12 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025 LlamaIndex Inc.
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from collections import Counter, defaultdict
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
AsyncGenerator,
|
|
Awaitable,
|
|
Callable,
|
|
Coroutine,
|
|
Generic,
|
|
Type,
|
|
TypeVar,
|
|
cast,
|
|
)
|
|
|
|
from llama_index_instrumentation.dispatcher import (
|
|
active_instrument_tags,
|
|
instrument_tags,
|
|
)
|
|
|
|
|
|
from workflows.utils import _nanoid as nanoid
|
|
from workflows.errors import WorkflowRuntimeError
|
|
from workflows.events import (
|
|
Event,
|
|
StartEvent,
|
|
)
|
|
from workflows.runtime.control_loop import control_loop, rebuild_state_from_ticks
|
|
from workflows.runtime.types.internal_state import BrokerState
|
|
from workflows.runtime.types.plugin import Plugin, WorkflowRuntime, as_snapshottable
|
|
from workflows.runtime.types.results import (
|
|
AddCollectedEvent,
|
|
AddWaiter,
|
|
DeleteCollectedEvent,
|
|
DeleteWaiter,
|
|
StepWorkerContext,
|
|
StepWorkerStateContextVar,
|
|
WaitingForEvent,
|
|
)
|
|
from workflows.runtime.types.step_function import (
|
|
StepWorkerFunction,
|
|
as_step_worker_function,
|
|
)
|
|
from workflows.runtime.types.ticks import TickAddEvent, TickCancelRun, WorkflowTick
|
|
from workflows.runtime.workflow_registry import workflow_registry
|
|
|
|
from ..context.state_store import MODEL_T
|
|
|
|
from workflows.handler import WorkflowHandler
|
|
|
|
if TYPE_CHECKING:
|
|
from workflows import Workflow
|
|
from workflows.context.context import Context
|
|
|
|
|
|
# Type variable for the concrete event type returned by ``wait_for_event``.
T = TypeVar("T", bound=Event)
# Mapping of buffer id -> events collected into that buffer.
EventBuffer = dict[str, list[Event]]

# Module-scoped logger (instead of the root logger) so records are attributed to
# this module and can be filtered/configured by logger name.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UnserializableKeyWarning(Warning):
    """Warning category for keys that could not be serialized.

    A dedicated subclass lets ``warnings`` filters target it specifically,
    e.g. a ``"once"`` filter so the warning is only reported a single time.
    """
|
|
|
|
|
|
class WorkflowBroker(Generic[MODEL_T]):
    """
    The workflow broker manages starting up and connecting a workflow handler, a runtime, and triggering the
    execution of the workflow. From there it manages communication between the workflow and the outside world.
    """

    _context: Context[MODEL_T]
    _runtime: WorkflowRuntime
    _plugin: Plugin
    _is_running: bool
    _handler: WorkflowHandler | None
    _workflow: Workflow
    # transient tasks to run async ops in background, exposing sync interfaces
    _workers: list[asyncio.Task]
    _init_state: BrokerState | None

    def __init__(
        self,
        workflow: Workflow,
        context: Context[MODEL_T],
        runtime: WorkflowRuntime,
        plugin: Plugin,
    ) -> None:
        """Store collaborators; nothing runs until ``start`` is called."""
        self._context = context
        self._runtime = runtime
        self._plugin = plugin
        self._is_running = False
        self._handler = None
        self._workflow = workflow
        self._workers = []
        self._init_state = None

    def _execute_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task[Any]:
        """Schedule ``coro`` as a background task tracked in ``self._workers``.

        The task removes itself from ``self._workers`` once done.
        """
        task = asyncio.create_task(coro)
        self._workers.append(task)
        task.add_done_callback(self._forget_worker)
        return task

    def _forget_worker(self, task: asyncio.Task[Any]) -> None:
        """Done-callback: drop ``task`` from the tracked workers if still present."""
        try:
            self._workers.remove(task)
        except ValueError:
            # Already gone, e.g. shutdown() cleared the list before the
            # cancelled task finished. Raising here would only be logged
            # by the event loop as an error in the callback.
            pass

    @staticmethod
    def _build_step_workers(workflow: Workflow) -> dict[str, StepWorkerFunction]:
        """Wrap each step of ``workflow`` as a step worker function, keyed by step name."""
        step_workers: dict[str, StepWorkerFunction] = {}
        for name, step_func in workflow._get_steps().items():
            # Avoid capturing a bound method (which retains the instance).
            # If it's a bound method, extract the unbound function from the class.
            unbound = getattr(step_func, "__func__", step_func)
            step_workers[name] = as_step_worker_function(unbound)
        return step_workers

    # context API only
    def start(
        self,
        workflow: Workflow,
        previous: BrokerState | None = None,
        start_event: StartEvent | None = None,
        before_start: Callable[[], Awaitable[None]] | None = None,
        after_complete: Callable[[], Awaitable[None]] | None = None,
    ) -> WorkflowHandler:
        """Start the workflow run. Can only be called once.

        Args:
            workflow: the workflow whose steps are executed.
            previous: broker state to resume from. When ``None``, a fresh state
                is built with ``BrokerState.from_workflow(workflow)``.
            start_event: event passed to the registered workflow function.
            before_start: awaited inside the run task before the run begins.
            after_complete: awaited when the run ends, whether it succeeded or
                failed, provided ``before_start`` did not raise.

        Returns:
            The ``WorkflowHandler`` wrapping the background run task.

        Raises:
            WorkflowRuntimeError: if this broker has already been started.
        """
        if self._handler is not None:
            raise WorkflowRuntimeError(
                "this WorkflowBroker already run or running. Cannot start again."
            )
        self._init_state = previous

        async def _run_workflow(run_id: str, tags: dict[str, Any]) -> None:
            with instrument_tags({"run_id": run_id, **tags}):
                self._is_running = True
                try:
                    # Yield once before doing any work so the handler is created
                    # around this task first; exceptions raised afterwards
                    # (including from before_start) surface through the run task.
                    await asyncio.sleep(0)
                    if before_start is not None:
                        await before_start()
                    try:
                        init_state = previous or BrokerState.from_workflow(workflow)
                        try:
                            registered = workflow_registry.get_registered_workflow(
                                workflow,
                                self._plugin,
                                control_loop,
                                self._build_step_workers(workflow),
                            )

                            # Register run context prior to invoking control loop
                            workflow_registry.register_run(
                                run_id=run_id,
                                workflow=workflow,
                                plugin=self._runtime,
                                context=self._context,  # type: ignore
                                steps=registered.steps,
                            )
                            try:
                                workflow_result = await registered.workflow_function(
                                    start_event,
                                    init_state,
                                    run_id,
                                )
                            finally:
                                # ensure run context is cleaned up even on failure
                                workflow_registry.delete_run(run_id)
                            result._set_stop_event(workflow_result)
                        except Exception as e:
                            # Report the failure through the handler unless it
                            # already holds an outcome.
                            if not result.done():
                                result.set_exception(e)
                    finally:
                        if after_complete is not None:
                            await after_complete()
                finally:
                    # Reset even when before_start or after_complete raise, so
                    # is_running never stays True after the run task has ended.
                    self._is_running = False

        run_id = nanoid()
        # `result` is assigned below, before the scheduled task gets to run, so
        # the closure above can safely reference it.
        run_task = self._execute_task(
            _run_workflow(run_id, tags=active_instrument_tags.get())
        )
        result = WorkflowHandler(
            ctx=self._context,  # type: ignore
            run_id=run_id,
            run_task=run_task,
        )
        self._handler = result
        return result

    # outer handler API to cancel the workflow run
    def cancel_run(self) -> None:
        """Send a ``TickCancelRun`` to the runtime in a background task."""
        self._execute_task(self._runtime.send_event(TickCancelRun()))

    @property
    def is_running(self) -> bool:
        """Whether the run task is currently executing."""
        return self._is_running

    @property
    def _state(self) -> BrokerState:
        """Current broker state, rebuilt by replaying the runtime's tick log.

        The ticks are replayed on top of the ``previous`` state given to
        ``start``, or a fresh state built from the workflow when there was none.
        Raises ``WorkflowRuntimeError`` if the runtime is not snapshottable.
        """
        ticks = self._tick_log
        state = self._init_state or BrokerState.from_workflow(self._workflow)
        return rebuild_state_from_ticks(state, ticks)

    @property
    def _tick_log(self) -> list[WorkflowTick]:
        """Ticks recorded by the runtime, read via its snapshottable interface."""
        snapshottable = as_snapshottable(self._runtime)
        if snapshottable is None:
            raise WorkflowRuntimeError("Plugin is not snapshottable")
        return snapshottable.replay()

    # mostly a debug API. May be removed in the future.
    async def running_steps(self) -> list[str]:
        """Names of steps whose worker currently has work in progress."""
        # Rebuild the state once; each `_state` access replays the full tick log.
        workers = self._state.workers
        return [name for name, worker in workers.items() if worker.in_progress]

    # step api only
    def collect_events(
        self, ev: Event, expected: list[Type[Event]], buffer_id: str | None = None
    ) -> list[Event] | None:
        """Accumulate ``ev`` into a buffer until every type in ``expected`` is present.

        Must be called from within a step function.

        Args:
            ev: the event just received by the step.
            expected: event types to collect. A type listed twice requires two
                events of that type. Types are matched exactly (``type(e)``),
                not via ``isinstance``.
            buffer_id: buffer key; defaults to ``"default"``.

        Returns:
            ``None`` while events are still missing. In that case an
            ``AddCollectedEvent`` is recorded for ``ev`` if its type is still
            needed. Once ``ev`` completes the set, returns the collected events
            ordered to match ``expected`` and records a ``DeleteCollectedEvent``
            for the buffer.
        """
        step_ctx = self._get_step_ctx(fn="collect_events")

        buffer_id = buffer_id or "default"

        collected_events = step_ctx.state.collected_events.get(buffer_id, [])

        remaining_event_types = Counter(expected) - Counter(
            [type(e) for e in collected_events]
        )

        # Complete only if `ev` is exactly the single event still missing.
        if remaining_event_types != Counter([type(ev)]):
            if type(ev) in remaining_event_types:
                step_ctx.returns.return_values.append(
                    AddCollectedEvent(event_id=buffer_id, event=ev)
                )
            return None

        by_type = defaultdict(list)
        for e in collected_events + [ev]:
            by_type[type(e)].append(e)
        # order by expected type
        total = [by_type[e_type].pop(0) for e_type in expected]
        # if we got here, it means the collection is fulfilled. Clear the collected events when the step is complete
        step_ctx.returns.return_values.append(DeleteCollectedEvent(event_id=buffer_id))
        return total

    # may be called from both step API and outer handler API
    def send_event(self, message: Event, step: str | None = None) -> None:
        """Send ``message`` to the runtime, optionally targeting a single step.

        Raises:
            WorkflowRuntimeError: if ``step`` does not exist or does not accept
                events of ``type(message)``.
        """
        if step is not None:
            steps = self._workflow._get_steps()
            if step not in steps:
                raise WorkflowRuntimeError(f"Step {step} does not exist")

            # Validate that the step accepts this event type
            step_config = steps[step]._step_config
            if type(message) not in step_config.accepted_events:
                raise WorkflowRuntimeError(
                    f"Step {step} does not accept event of type {type(message)}"
                )

        self._execute_task(
            self._runtime.send_event(TickAddEvent(event=message, step_name=step))
        )

    def _get_step_ctx(self, fn: str) -> StepWorkerContext:
        """Return the current step's worker context, or fail if not inside a step."""
        try:
            return StepWorkerStateContextVar.get()
        except LookupError:
            raise WorkflowRuntimeError(
                f"{fn} may only be called from within a step function"
            ) from None

    # step api only
    async def wait_for_event(
        self,
        event_type: Type[T],
        waiter_event: Event | None = None,
        waiter_id: str | None = None,
        requirements: dict[str, Any] | None = None,
        timeout: float | None = 2000,
    ) -> T:
        """Return the event resolved for this waiter, or request a new waiter.

        The waiter key defaults to one derived from ``event_type`` and
        ``requirements``. If no resolved event exists for that key,
        ``WaitingForEvent`` is raised carrying an ``AddWaiter`` request
        (presumably handled by the control loop to re-run the step later —
        confirm in control_loop). Otherwise a ``DeleteWaiter`` is recorded and
        the resolved event is returned.
        """
        step_ctx = self._get_step_ctx(fn="wait_for_event")

        collected_waiters = step_ctx.state.collected_waiters
        requirements = requirements or {}

        # Generate a unique key for the waiter
        event_str = self._get_full_path(event_type)
        requirements_str = str(requirements)
        waiter_id = waiter_id or f"waiter_{event_str}_{requirements_str}"

        waiter = next((w for w in collected_waiters if w.waiter_id == waiter_id), None)
        if waiter is None or waiter.resolved_event is None:
            raise WaitingForEvent(
                AddWaiter(
                    waiter_id=waiter_id,
                    requirements=requirements,
                    timeout=timeout,
                    event_type=event_type,
                    waiter_event=waiter_event,
                )
            )
        step_ctx.returns.return_values.append(DeleteWaiter(waiter_id=waiter_id))
        return cast(T, waiter.resolved_event)

    def _get_full_path(self, ev_type: Type[Event]) -> str:
        """Fully qualified ``module.ClassName`` of ``ev_type``."""
        return f"{ev_type.__module__}.{ev_type.__name__}"

    def stream_published_events(self) -> AsyncGenerator[Event, None]:
        """The internal queue used for streaming events to callers."""
        return self._runtime.stream_published_events()

    # step API only
    def write_event_to_stream(self, ev: Event | None) -> None:
        """Publish ``ev`` to the runtime's event stream in the background; ``None`` is ignored."""
        if ev is not None:
            self._execute_task(self._runtime.write_to_event_stream(ev))

    async def shutdown(self) -> None:
        """Cancels the running workflow loop

        Cancels all outstanding workers, waits for them to finish, and marks the
        broker as not running. Queues and state remain available so callers can
        inspect or drain leftover events.
        """
        await self._runtime.send_event(TickCancelRun())
        workers = list(self._workers)
        for worker in workers:
            worker.cancel()
        self._workers.clear()
        # Actually wait for the cancelled workers, as documented. Skip the
        # current task in case shutdown() runs inside one of them, so it never
        # awaits itself.
        current = asyncio.current_task()
        pending = [worker for worker in workers if worker is not current]
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        self._is_running = False
        await self._runtime.close()
|