brief-extractor/backend/venv/lib/python3.10/site-packages/workflows/utils.py
2026-03-06 18:42:46 +00:00

263 lines
7.8 KiB
Python
Executable file

# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.
from __future__ import annotations
import inspect
from typing import (
Annotated,
Any,
Callable,
Optional,
get_args,
get_origin,
get_type_hints,
)
try:
from typing import Union
except ImportError: # pragma: no cover
from typing_extensions import Union
# handle python version compatibility
try:
from types import UnionType # type: ignore[attr-defined]
except ImportError: # pragma: no cover
from typing import Union as UnionType # type: ignore[assignment]
from pydantic import BaseModel
from .errors import WorkflowValidationError
from .events import Event, EventType
from .resource import ResourceDefinition
BUSY_WAIT_DELAY = 0.01
class StepSignatureSpec(BaseModel):
"""A Pydantic model representing the signature of a step function or method."""
accepted_events: dict[str, list[EventType]]
return_types: list[Any]
context_parameter: str | None
context_state_type: Any | None
resources: list[Any]
def inspect_signature(fn: Callable) -> StepSignatureSpec:
"""
Given a function, ensure the signature is compatible with a workflow step.
Args:
fn (Callable): The function to inspect.
Returns:
StepSignatureSpec: A specification object containing:
- accepted_events: Dictionary mapping parameter names to their event types
- return_types: List of return type annotations
- context_parameter: Name of the context parameter if present
Raises:
TypeError: If fn is not a callable object
"""
if not callable(fn):
raise TypeError(f"Expected a callable object, got {type(fn).__name__}")
sig = inspect.signature(fn)
type_hints = get_type_hints(fn, include_extras=True)
accepted_events: dict[str, list[EventType]] = {}
context_parameter = None
context_state_type = None
resources = []
# Inspect function parameters
for name, t in sig.parameters.items():
# Ignore self and cls
if name in ("self", "cls"):
continue
annotation = type_hints.get(name, t.annotation)
# Handle Context[StateType] annotations
if get_origin(annotation) is not None:
origin = get_origin(annotation)
args = get_args(annotation)
# Check if this is Context[StateType]
if hasattr(origin, "__name__") and origin.__name__ == "Context":
context_parameter = name
# Extract state type from generic parameter
if args:
context_state_type = args[0]
continue
# Handle Annotated types for resources
if get_origin(annotation) is Annotated:
_, resource = get_args(annotation)
resources.append(ResourceDefinition(name=name, resource=resource))
continue
# Get name and type of the Context param (without state type)
if hasattr(annotation, "__name__") and annotation.__name__ == "Context":
context_parameter = name
continue
# Collect name and types of the event param
param_types = _get_param_types(t, type_hints)
if all(
param_t == Event
or (inspect.isclass(param_t) and issubclass(param_t, Event))
for param_t in param_types
):
accepted_events[name] = param_types
continue
return StepSignatureSpec(
accepted_events=accepted_events,
return_types=_get_return_types(fn),
context_parameter=context_parameter,
context_state_type=context_state_type,
resources=resources,
)
def validate_step_signature(spec: StepSignatureSpec) -> None:
"""
Validate that a step signature specification meets workflow requirements.
Args:
spec (StepSignatureSpec): The signature specification to validate.
Raises:
WorkflowValidationError: If the signature is invalid for a workflow step.
"""
num_of_events = len(spec.accepted_events)
if num_of_events == 0:
msg = "Step signature must have at least one parameter annotated as type Event"
raise WorkflowValidationError(msg)
elif num_of_events > 1:
msg = f"Step signature must contain exactly one parameter of type Event but found {num_of_events}."
raise WorkflowValidationError(msg)
if not spec.return_types:
msg = "Return types of workflows step functions must be annotated with their type."
raise WorkflowValidationError(msg)
def get_steps_from_class(_class: object) -> dict[str, Callable]:
"""
Given a class, return the list of its methods that were defined as steps.
Args:
_class (object): The class to inspect for step methods.
Returns:
dict[str, Callable]: A dictionary mapping step names to their corresponding methods.
"""
step_methods: dict[str, Callable] = {}
all_methods = inspect.getmembers(_class, predicate=inspect.isfunction)
for name, method in all_methods:
if hasattr(method, "__step_config"):
step_methods[name] = method
return step_methods
def get_steps_from_instance(workflow: object) -> dict[str, Callable]:
"""
Given a workflow instance, return the list of its methods that were defined as steps.
Args:
workflow (object): The workflow instance to inspect.
Returns:
dict[str, Callable]: A dictionary mapping step names to their corresponding methods.
"""
step_methods: dict[str, Callable] = {}
all_methods = inspect.getmembers(workflow, predicate=inspect.ismethod)
for name, method in all_methods:
if hasattr(method, "__step_config"):
step_methods[name] = method
return step_methods
def _get_param_types(param: inspect.Parameter, type_hints: dict) -> list[Any]:
"""
Extract and process the types of a parameter.
This helper function handles Union and Optional types, returning a list of the actual types.
For Union[A, None] (Optional[A]), it returns [A].
Args:
param (inspect.Parameter): The parameter to analyze.
type_hints (dict): The resolved type hints for the function.
Returns:
list[Any]: A list of extracted types, excluding None from Unions/Optionals.
"""
typ = type_hints.get(param.name, param.annotation)
if typ is inspect.Parameter.empty:
return [Any]
if get_origin(typ) in (Union, Optional, UnionType):
return [t for t in get_args(typ) if t is not type(None)]
return [typ]
def _get_return_types(func: Callable) -> list[Any]:
"""
Extract the return type hints from a function.
Handles Union, Optional, and List types.
"""
type_hints = get_type_hints(func)
return_hint = type_hints.get("return")
if return_hint is None:
return []
origin = get_origin(return_hint)
if origin in (Union, UnionType):
# Optional is Union[type, None] so it's covered here
return [t for t in get_args(return_hint) if t is not type(None)]
else:
return [return_hint]
def is_free_function(qualname: str) -> bool:
"""
Determines whether a certain qualified name points to a free function.
A free function is either a module-level function or a nested function.
This implementation follows PEP-3155 for handling nested function detection.
Args:
qualname (str): The qualified name to analyze.
Returns:
bool: True if the name represents a free function, False otherwise.
Raises:
ValueError: If the qualified name is empty.
"""
if not qualname:
msg = "The qualified name cannot be empty"
raise ValueError(msg)
toks = qualname.split(".")
if len(toks) == 1:
# e.g. `my_function`
return True
elif "<locals>" not in toks:
# e.g. `MyClass.my_method`
return False
else:
return toks[-2] == "<locals>"