462 lines
17 KiB
Python
462 lines
17 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025 LlamaIndex Inc.
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import (
|
|
Any,
|
|
Tuple,
|
|
)
|
|
|
|
from llama_index_instrumentation import get_dispatcher
|
|
from pydantic import ValidationError
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING: # pragma: no cover
|
|
from .context import Context
|
|
from .decorators import StepConfig, StepFunction
|
|
from .errors import (
|
|
WorkflowConfigurationError,
|
|
WorkflowRuntimeError,
|
|
WorkflowValidationError,
|
|
)
|
|
from .events import (
|
|
Event,
|
|
HumanResponseEvent,
|
|
InputRequiredEvent,
|
|
StartEvent,
|
|
StopEvent,
|
|
)
|
|
from .handler import WorkflowHandler
|
|
from .resource import ResourceManager
|
|
from .types import RunResultT
|
|
from .utils import get_steps_from_class, get_steps_from_instance
|
|
|
|
# Instrumentation dispatcher scoped to this module; used for the `run` span.
dispatcher = get_dispatcher(__name__)
# Use a module-scoped logger rather than the root logger so applications can
# filter or silence this package's messages independently. Records still
# propagate to handlers configured on the root logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WorkflowMeta(type):
    """Metaclass that equips every class it creates with its own step registry.

    Each class built with this metaclass (including subclasses) receives a
    fresh, empty `_step_functions` dict, so the registry is never shared
    between a class and its parents or siblings.
    """

    def __init__(cls, name: str, bases: Tuple[type, ...], dct: dict[str, Any]) -> None:
        # A new dict per class: assigning here (instead of declaring a class
        # attribute on a base) prevents subclasses from inheriting and
        # mutating the parent's registry.
        cls._step_functions: dict[str, StepFunction] = {}
        super().__init__(name, bases, dct)
|
|
|
|
|
|
class Workflow(metaclass=WorkflowMeta):
    """
    Event-driven orchestrator to define and run application flows using typed steps.

    A `Workflow` is composed of `@step`-decorated callables that accept and emit
    typed [Event][workflows.events.Event]s. Steps can be declared as instance
    methods or as free functions registered via the decorator.

    Key features:
    - Validation of step signatures and event graph before running
    - Typed start/stop events
    - Streaming of intermediate events
    - Optional human-in-the-loop events
    - Retry policies per step
    - Resource injection

    Examples:
        Basic usage:

        ```python
        from workflows import Workflow, step
        from workflows.events import StartEvent, StopEvent

        class MyFlow(Workflow):
            @step
            async def start(self, ev: StartEvent) -> StopEvent:
                return StopEvent(result="done")

        result = await MyFlow(timeout=60).run(topic="Pirates")
        ```

        Streaming intermediate events:

        ```python
        handler = MyFlow().run()
        async for ev in handler.stream_events():
            ...
        result = await handler
        ```

    See Also:
        - [step][workflows.decorators.step]
        - [Event][workflows.events.Event]
        - [Context][workflows.context.context.Context]
        - [WorkflowHandler][workflows.handler.WorkflowHandler]
        - [RetryPolicy][workflows.retry_policy.RetryPolicy]
    """

    def __init__(
        self,
        timeout: float | None = 45.0,
        disable_validation: bool = False,
        verbose: bool = False,
        resource_manager: ResourceManager | None = None,
        num_concurrent_runs: int | None = None,
    ) -> None:
        """
        Initialize a workflow instance.

        Args:
            timeout (float | None): Max seconds to wait for completion. `None`
                disables the timeout.
            disable_validation (bool): Skip pre-run validation of the event graph
                (not recommended).
            verbose (bool): If True, print step activity.
            resource_manager (ResourceManager | None): Custom resource manager
                for dependency injection. A default `ResourceManager` is created
                when omitted.
            num_concurrent_runs (int | None): Limit on concurrent `run()` calls.
                A falsy value (`None` or `0`) means no semaphore is created.

        Raises:
            WorkflowConfigurationError: If the steps do not accept exactly one
                `StartEvent` type or do not return exactly one `StopEvent` type.
        """
        # Configuration
        self._timeout = timeout
        self._verbose = verbose
        self._disable_validation = disable_validation
        self._num_concurrent_runs = num_concurrent_runs
        # Detect StartEvent issues before StopEvent for clearer guidance
        self._start_event_class = self._ensure_start_event_class()
        self._stop_event_class = self._ensure_stop_event_class()
        self._events = self._ensure_events_collected()
        self._sem = (
            asyncio.Semaphore(num_concurrent_runs) if num_concurrent_runs else None
        )
        # Resource management
        self._resource_manager = resource_manager or ResourceManager()
        # Instrumentation
        self._dispatcher = dispatcher

    def _ensure_start_event_class(self) -> type[StartEvent]:
        """
        Returns the StartEvent type used in this workflow.

        It works by inspecting the events received by the step methods.

        Raises:
            WorkflowConfigurationError: If no step accepts a `StartEvent`, or if
                more than one distinct `StartEvent` type is accepted.
        """
        start_events_found: set[type[StartEvent]] = set()
        for step_func in self._get_steps().values():
            step_config: StepConfig = step_func._step_config
            for event_type in step_config.accepted_events:
                if issubclass(event_type, StartEvent):
                    start_events_found.add(event_type)

        num_found = len(start_events_found)
        if num_found == 1:
            return start_events_found.pop()

        cls_name = self.__class__.__name__
        if num_found == 0:
            msg = (
                "At least one Event of type StartEvent must be received by any step. "
                f"(Workflow '{cls_name}' has no @step that accepts StartEvent.)"
            )
        else:
            msg = (
                f"Only one type of StartEvent is allowed per workflow, found {num_found}: {start_events_found} "
                f"in workflow '{cls_name}'."
            )
        raise WorkflowConfigurationError(msg)

    @property
    def start_event_class(self) -> type[StartEvent]:
        """The `StartEvent` subclass accepted by this workflow.

        Determined by inspecting step input types.
        """
        return self._start_event_class

    @property
    def events(self) -> list[type[Event]]:
        """Returns all known events emitted by this workflow.

        Determined by inspecting step input/output types.
        """
        return self._events

    def _ensure_events_collected(self) -> list[type[Event]]:
        """Returns all known events emitted by this workflow.

        Determined by inspecting step input/output types. Steps named
        `_done` are skipped.
        """
        events_found: set[type[Event]] = set()
        for step_func in self._get_steps().values():
            step_config: StepConfig = step_func._step_config

            # Do not collect events from the done step
            if step_func.__name__ == "_done":
                continue

            events_found.update(
                event_type
                for event_type in step_config.return_types
                if issubclass(event_type, Event)
            )
            events_found.update(
                event_type
                for event_type in step_config.accepted_events
                if issubclass(event_type, Event)
            )

        return list(events_found)

    def _ensure_stop_event_class(self) -> type[RunResultT]:
        """
        Returns the StopEvent type used in this workflow.

        It works by inspecting the events returned.

        Raises:
            WorkflowConfigurationError: If no step returns a `StopEvent`, or if
                more than one distinct `StopEvent` type is returned.
        """
        stop_events_found: set[type[StopEvent]] = set()
        for step_func in self._get_steps().values():
            step_config: StepConfig = step_func._step_config
            for event_type in step_config.return_types:
                if issubclass(event_type, StopEvent):
                    stop_events_found.add(event_type)

        num_found = len(stop_events_found)
        if num_found == 1:
            return stop_events_found.pop()

        cls_name = self.__class__.__name__
        if num_found == 0:
            msg = (
                "At least one Event of type StopEvent must be returned by any step. "
                f"(Workflow '{cls_name}' has no @step that returns StopEvent.)"
            )
        else:
            msg = (
                f"Only one type of StopEvent is allowed per workflow, found {num_found}: {stop_events_found} "
                f"in workflow '{cls_name}'."
            )
        raise WorkflowConfigurationError(msg)

    @property
    def stop_event_class(self) -> type[RunResultT]:
        """The `StopEvent` subclass produced by this workflow.

        Determined by inspecting step return annotations.
        """
        return self._stop_event_class

    @classmethod
    def add_step(cls, func: StepFunction) -> None:
        """
        Adds a free function as step for this workflow class.

        The function is stored in the class-level registry, so it is visible
        to every instance of this class.

        Raises:
            WorkflowValidationError: If `func` is missing the `@step` decorator,
                or if a step with the same name was already added to the workflow.
        """
        step_config: StepConfig | None = getattr(func, "_step_config", None)
        if not step_config:
            msg = f"Step function {func.__name__} is missing the `@step` decorator."
            raise WorkflowValidationError(msg)

        step_name = func.__name__
        if step_name in get_steps_from_class(cls) or step_name in cls._step_functions:
            msg = f"A step {func.__name__} is already part of this workflow, please choose another name."
            raise WorkflowValidationError(msg)

        cls._step_functions[step_name] = func

    def _get_steps(self) -> dict[str, StepFunction]:
        """Returns all the steps, whether defined as methods or free functions."""
        # Free-function steps are merged last, so they win on a name clash.
        return {**get_steps_from_instance(self), **self.__class__._step_functions}

    def _get_start_event_instance(
        self, start_event: StartEvent | None, **kwargs: Any
    ) -> StartEvent:
        """
        Resolve the start event for a run.

        Args:
            start_event (StartEvent | None): Explicit start event. When given it
                is returned as-is and any `kwargs` are ignored (a warning is
                logged).
            **kwargs (Any): Used to construct an instance of the workflow's
                start event class when `start_event` is `None`.

        Raises:
            ValueError: If `start_event` is not a `StartEvent` instance.
            WorkflowRuntimeError: If the start event class rejects `kwargs`;
                the underlying `ValidationError` is attached as the cause.
        """
        if start_event is not None:
            # start_event was used wrong
            if not isinstance(start_event, StartEvent):
                msg = "The 'start_event' argument must be an instance of 'StartEvent'."
                raise ValueError(msg)

            # start_event is ok but point out that additional kwargs will be ignored in this case
            if kwargs:
                msg = (
                    "Keyword arguments are not supported when 'run()' is invoked with the 'start_event' parameter."
                    f" These keyword arguments will be ignored: {kwargs}"
                )
                logger.warning(msg)
            return start_event

        # Old style start event creation, with kwargs used to create an instance of self._start_event_class
        try:
            return self._start_event_class(**kwargs)
        except ValidationError as e:
            ev_name = self._start_event_class.__name__
            msg = f"Failed creating a start event of type '{ev_name}' with the keyword arguments: {kwargs}"
            logger.debug(e)
            # Chain explicitly so the validation details stay reachable via __cause__.
            raise WorkflowRuntimeError(msg) from e

    @dispatcher.span
    def run(
        self,
        ctx: Context | None = None,
        start_event: StartEvent | None = None,
        **kwargs: Any,
    ) -> WorkflowHandler:
        """Run the workflow and return a handler for results and streaming.

        This schedules the workflow execution in the background and returns a
        [WorkflowHandler][workflows.handler.WorkflowHandler] that can be awaited
        for the final result or used to stream intermediate events.

        You may pass either a concrete `start_event` instance or keyword
        arguments that will be used to construct the inferred
        [StartEvent][workflows.events.StartEvent] subclass.

        Args:
            ctx (Context | None): Optional context to resume or share state
                across runs. If omitted, a fresh context is created.
            start_event (StartEvent | None): Optional explicit start event.
            **kwargs (Any): Keyword args to initialize the start event when
                `start_event` is not provided.

        Returns:
            WorkflowHandler: A future-like object to await the final result and
                stream events.

        Raises:
            WorkflowConfigurationError: If the workflow has no steps or an
                invalid start/stop event setup and validation is enabled.
            WorkflowValidationError: If validation fails and validation is
                enabled.
            ValueError: If `start_event` is not a `StartEvent` instance.
            WorkflowRuntimeError: If the start event cannot be created from kwargs.
            WorkflowTimeoutError: If execution exceeds the configured timeout.

        Examples:
            ```python
            # Create and run with kwargs
            handler = MyFlow().run(topic="Pirates")

            # Stream events
            async for ev in handler.stream_events():
                ...

            # Await final result
            result = await handler
            ```

            If you subclassed the start event, you can also directly pass it in:

            ```python
            result = await my_workflow.run(start_event=MyStartEvent(topic="Pirates"))
            ```
        """
        from workflows.context import Context

        # Validate the workflow
        self._validate()

        # Reuse the caller's context when given; otherwise bind a new one to this workflow
        ctx = ctx if ctx is not None else Context(self)
        # A context that is already running gets no new start event
        start_event_instance: StartEvent | None = (
            None
            if ctx.is_running
            else self._get_start_event_instance(start_event, **kwargs)
        )
        return ctx._workflow_run(
            workflow=self, start_event=start_event_instance, semaphore=self._sem
        )

    def validate(self) -> bool:
        """
        Validate the workflow to ensure it's well-formed.

        Returns True if the workflow uses human-in-the-loop, False otherwise.
        Always returns False when validation is disabled.
        """
        return self._validate()

    def _validate(self) -> bool:
        """
        Check the step/event graph and report human-in-the-loop usage.

        Returns:
            bool: True if any step produces an `InputRequiredEvent` (or subclass)
                or consumes a `HumanResponseEvent` (or subclass). False
                otherwise, and always False when validation is disabled.

        Raises:
            WorkflowConfigurationError: If there are no steps, or the
                start/stop event classes cannot be determined.
            WorkflowValidationError: If a step other than `_done` accepts a
                `StopEvent`, no `StopEvent` is produced, or events are consumed
                but never produced (or vice versa).
        """
        if self._disable_validation:
            return False

        # Ensure at least one step is configured before inspecting events
        steps = self._get_steps()
        if not steps:
            cls_name = self.__class__.__name__
            msg = (
                f"Workflow '{cls_name}' has no configured steps. "
                "Did you forget to annotate methods with @step or to register "
                "free-function steps via @step(workflow=...)?"
            )
            raise WorkflowConfigurationError(msg)

        # Recompute StartEvent and StopEvent classes here to support dynamic changes
        # and to surface StartEvent errors before StopEvent during validation.
        self._start_event_class = self._ensure_start_event_class()
        self._stop_event_class = self._ensure_stop_event_class()

        produced_events: set[type] = {self._start_event_class}
        consumed_events: set[type] = set()

        # Collect steps that incorrectly accept StopEvent
        steps_accepting_stop_event: list[str] = []

        for name, step_func in steps.items():
            step_config: StepConfig = step_func._step_config

            # Check that no user-defined step accepts StopEvent (only _done step should)
            if name != "_done" and any(
                issubclass(event_type, StopEvent)
                for event_type in step_config.accepted_events
            ):
                steps_accepting_stop_event.append(name)

            consumed_events.update(step_config.accepted_events)

            # A `None` return type means the step may not trigger other events
            produced_events.update(
                event_type
                for event_type in step_config.return_types
                if event_type is not type(None)
            )

        # Raise error if any steps incorrectly accept StopEvent
        if steps_accepting_stop_event:
            step_names = "', '".join(steps_accepting_stop_event)
            plural = "" if len(steps_accepting_stop_event) == 1 else "s"
            msg = (
                f"Step{plural} '{step_names}' cannot accept StopEvent. "
                "StopEvent signals the end of the workflow. "
                "Use a different Event type instead."
            )
            raise WorkflowValidationError(msg)

        # Check if no StopEvent is produced
        if not any(issubclass(ev, StopEvent) for ev in produced_events):
            msg = "No event of type StopEvent is produced."
            raise WorkflowValidationError(msg)

        # Check if all consumed events are produced (except specific built-in events)
        never_produced = {
            ev
            for ev in consumed_events - produced_events
            if not issubclass(ev, (InputRequiredEvent, HumanResponseEvent, StopEvent))
        }
        if never_produced:
            # Sorted so the error message is stable across runs
            names = ", ".join(sorted(ev.__name__ for ev in never_produced))
            raise WorkflowValidationError(
                f"The following events are consumed but never produced: {names}"
            )

        # Check if there are any unused produced events (except specific built-in events)
        never_consumed = {
            ev
            for ev in produced_events - consumed_events
            if not issubclass(
                ev, (InputRequiredEvent, HumanResponseEvent, self._stop_event_class)
            )
        }
        if never_consumed:
            # Sorted so the error message is stable across runs
            names = ", ".join(sorted(ev.__name__ for ev in never_consumed))
            raise WorkflowValidationError(
                f"The following events are produced but never consumed: {names}"
            )

        # Check if the workflow uses human-in-the-loop. Subclasses count too,
        # matching the issubclass-based filters above.
        return any(
            issubclass(ev, InputRequiredEvent) for ev in produced_events
        ) or any(issubclass(ev, HumanResponseEvent) for ev in consumed_events)
|