ferrero-opentext/Python-Version/venv/lib/python3.12/site-packages/workflows/server/server.py

1678 lines
61 KiB
Python

# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
import json
import logging
from importlib.metadata import version
from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Awaitable
from datetime import datetime, timezone
from llama_index_instrumentation.dispatcher import instrument_tags
import uvicorn
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
from starlette.routing import Route
from starlette.schemas import SchemaGenerator
from starlette.staticfiles import StaticFiles
from workflows import Context, Workflow
from workflows.events import (
Event,
InternalDispatchEvent,
StartEvent,
StepState,
StepStateChanged,
StopEvent,
)
from workflows.handler import WorkflowHandler
from workflows.protocol import (
CancelHandlerResponse,
HandlerData,
HandlersListResponse,
HealthResponse,
SendEventResponse,
WorkflowEventsListResponse,
WorkflowGraphResponse,
WorkflowSchemaResponse,
is_status_completed,
)
from workflows.server.abstract_workflow_store import (
AbstractWorkflowStore,
HandlerQuery,
PersistentHandler,
Status,
)
from workflows.server.memory_workflow_store import MemoryWorkflowStore
from workflows.types import RunResultT
# Protocol models are shared with the client; the server builds them and returns their model_dump() dicts
from workflows.utils import _nanoid as nanoid
from .representation_utils import _extract_workflow_structure
from workflows.protocol.serializable_events import (
EventValidationError,
EventEnvelopeWithMetadata,
EventEnvelope,
)
# Module-level logger: records propagate to the root logger's handlers, but
# applications can tune this module's verbosity independently.
logger = logging.getLogger(__name__)
class WorkflowServer:
def __init__(
    self,
    *,
    middleware: list[Middleware] | None = None,
    workflow_store: AbstractWorkflowStore | None = None,
    # retry/backoff seconds for persisting the handler state in the store after failures. Configurable mainly for testing.
    persistence_backoff: list[float] | None = None,
):
    """Create the server, register its HTTP routes and build the Starlette app.

    Args:
        middleware: Starlette middleware for the app. When ``None`` *or an
            empty list*, a permissive CORS middleware is installed instead.
        workflow_store: Store used to persist handler state. Defaults to a
            new ``MemoryWorkflowStore``.
        persistence_backoff: Retry/backoff delays, in seconds, used when
            persisting handler state fails. Defaults to ``[0.5, 3]``. The
            sequence is copied, so later mutation by the caller has no effect.
    """
    self._workflows: dict[str, Workflow] = {}
    self._additional_events: dict[str, list[type[Event]] | None] = {}
    self._contexts: dict[str, Context] = {}
    self._handlers: dict[str, _WorkflowHandler] = {}
    self._results: dict[str, RunResultT] = {}
    self._workflow_store = (
        workflow_store if workflow_store is not None else MemoryWorkflowStore()
    )
    self._assets_path = Path(__file__).parent / "static"
    # A None sentinel avoids a shared mutable default list in the signature.
    self._persistence_backoff = (
        [0.5, 3] if persistence_backoff is None else list(persistence_backoff)
    )
    self._middleware = middleware or [
        Middleware(
            CORSMiddleware,
            # regex echoes the origin header back, which some browsers require (rather than "*") when credentials are required
            allow_origin_regex=".*",
            allow_methods=["*"],
            allow_headers=["*"],
            allow_credentials=True,
        )
    ]
    self._routes = [
        Route(
            "/workflows",
            self._list_workflows,
            methods=["GET"],
        ),
        Route(
            "/workflows/{name}/run",
            self._run_workflow,
            methods=["POST"],
        ),
        Route(
            "/workflows/{name}/run-nowait",
            self._run_workflow_nowait,
            methods=["POST"],
        ),
        Route(
            "/workflows/{name}/schema",
            self._get_events_schema,
            methods=["GET"],
        ),
        Route(
            "/results/{handler_id}",
            self._get_workflow_result,
            methods=["GET"],
        ),
        Route(
            "/events/{handler_id}",
            self._stream_events,
            methods=["GET"],
        ),
        Route(
            "/events/{handler_id}",
            self._post_event,
            methods=["POST"],
        ),
        Route(
            "/health",
            self._health_check,
            methods=["GET"],
        ),
        Route(
            "/handlers",
            self._get_handlers,
            methods=["GET"],
        ),
        Route(
            "/handlers/{handler_id}",
            self._get_workflow_handler,
            methods=["GET"],
        ),
        Route(
            "/handlers/{handler_id}/cancel",
            self._cancel_handler,
            methods=["POST"],
        ),
        Route(
            "/workflows/{name}/representation",
            self._get_workflow_representation,
            methods=["GET"],
        ),
        Route(
            "/workflows/{name}/events",
            self._list_workflow_events,
            methods=["GET"],
        ),
    ]

    @asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
        # Tie start()/stop() to the ASGI app's startup and shutdown.
        async with self.contextmanager():
            yield

    self.app = Starlette(
        routes=self._routes,
        middleware=self._middleware,
        lifespan=lifespan,
    )
    # Serve the UI as static files
    self.app.mount(
        "/", app=StaticFiles(directory=self._assets_path, html=True), name="ui"
    )
def add_workflow(
    self,
    name: str,
    workflow: Workflow,
    additional_events: list[type[Event]] | None = None,
) -> None:
    """Register ``workflow`` so it is served under ``name``.

    Args:
        name: Key used in the ``/workflows/{name}/...`` routes. Registering
            an existing name replaces the previous workflow.
        workflow: The workflow instance to run for this name.
        additional_events: Extra event classes for this workflow; also listed
            by the ``/workflows/{name}/events`` endpoint. Only recorded when
            not ``None``.
    """
    self._workflows[name] = workflow
    if additional_events is None:
        return
    self._additional_events[name] = additional_events
async def start(self) -> "WorkflowServer":
    """Resumes previously running workflows, if they were not complete at last shutdown.

    Every handler persisted with status ``running`` for a registered workflow
    is restarted from its stored context. If a handler cannot be resumed, it
    is recorded as ``failed`` in the store and the remaining handlers are
    still processed.

    Returns:
        This server instance, so the call can be chained.
    """
    handlers = await self._workflow_store.query(
        HandlerQuery(
            status_in=["running"], workflow_name_in=list(self._workflows.keys())
        )
    )
    for persistent in handlers:
        workflow = self._workflows[persistent.workflow_name]
        try:
            await self._start_workflow(
                workflow=_NamedWorkflow(
                    name=persistent.workflow_name, workflow=workflow
                ),
                handler_id=persistent.handler_id,
                context=Context.from_dict(workflow=workflow, data=persistent.ctx),
            )
        except Exception as e:
            logger.error(
                f"Failed to resume handler {persistent.handler_id} for workflow {persistent.workflow_name}: {e}",
                exc_info=True,
            )
            now = datetime.now(timezone.utc)
            try:
                await self._workflow_store.update(
                    PersistentHandler(
                        handler_id=persistent.handler_id,
                        workflow_name=persistent.workflow_name,
                        status="failed",
                        run_id=persistent.run_id,
                        error=str(e),
                        result=None,
                        started_at=persistent.started_at,
                        updated_at=now,
                        completed_at=now,
                        ctx=persistent.ctx,
                    )
                )
            except Exception:
                # Best effort: one failed store update must not stop the other
                # handlers from resuming, but it should not vanish silently.
                logger.warning(
                    f"Failed to mark handler {persistent.handler_id} as failed after resume error",
                    exc_info=True,
                )
    return self
@asynccontextmanager
async def contextmanager(self) -> AsyncGenerator["WorkflowServer", None]:
    """Run the server for the duration of an ``async with`` block.

    ``start()`` runs on entry and yields this server; ``stop()`` runs on
    exit, including when the body of the block raises.
    """
    await self.start()
    try:
        yield self
    finally:
        # Always stop, even if the managed block failed.
        await self.stop()
async def stop(self) -> None:
    """Close every in-memory handler concurrently, then clear the registries."""
    active = list(self._handlers.values())
    logger.info(
        f"Shutting down Workflow server. Cancelling {len(active)} handlers."
    )
    closers = [self._close_handler(entry) for entry in active]
    await asyncio.gather(*closers)
    self._handlers.clear()
    self._results.clear()
async def serve(
    self,
    host: str = "localhost",
    port: int = 80,
    uvicorn_config: dict[str, Any] | None = None,
) -> None:
    """Run the server with uvicorn until it shuts down.

    Args:
        host: Interface to bind.
        port: TCP port to bind.
        uvicorn_config: Extra keyword arguments forwarded to ``uvicorn.Config``;
            its ``root_path`` (default ``"/"``) is only used in the startup log.
    """
    options = uvicorn_config if uvicorn_config else {}
    server = uvicorn.Server(
        uvicorn.Config(self.app, host=host, port=port, **options)
    )
    root_path = options.get("root_path", "/")
    logger.info(f"Starting Workflow server at http://{host}:{port}{root_path}")
    await server.serve()
def openapi_schema(self) -> dict:
    """Build the OpenAPI document describing this server's HTTP API.

    Top-level metadata and the shared component schemas
    (``EventEnvelopeWithMetadata``, ``Handler``, ``HandlersList``) are declared
    here. Per-route operations come from ``SchemaGenerator.get_schema``, which
    reads the YAML after ``---`` in each endpoint's docstring for every route
    in ``self.app.routes``.

    Returns:
        The OpenAPI schema as a plain ``dict``.
    """
    app = self.app
    gen = SchemaGenerator(
        {
            "openapi": "3.0.0",
            "info": {
                "title": "Workflows API",
                # Read from installed package metadata; importlib.metadata.version
                # raises PackageNotFoundError if the distribution is not installed.
                "version": version("llama-index-workflows"),
            },
            "components": {
                "schemas": {
                    # Mirrors the envelope emitted by the event stream and stored in Handler.result.
                    "EventEnvelopeWithMetadata": {
                        "type": "object",
                        "properties": {
                            "value": {"type": "object"},
                            "types": {"type": "array", "items": {"type": "string"}},
                            "type": {"type": "string"},
                            "qualified_name": {"type": "string"},
                        },
                        "required": ["value", "type"],
                    },
                    "Handler": {
                        "type": "object",
                        "properties": {
                            "handler_id": {"type": "string"},
                            "workflow_name": {"type": "string"},
                            "run_id": {"type": "string", "nullable": True},
                            "status": {
                                "type": "string",
                                "enum": [
                                    "running",
                                    "completed",
                                    "failed",
                                    "cancelled",
                                ],
                            },
                            "started_at": {"type": "string", "format": "date-time"},
                            "updated_at": {
                                "type": "string",
                                "format": "date-time",
                                "nullable": True,
                            },
                            "completed_at": {
                                "type": "string",
                                "format": "date-time",
                                "nullable": True,
                            },
                            "error": {"type": "string", "nullable": True},
                            "result": {
                                "description": "Workflow result value",
                                # NOTE(review): OpenAPI 3.0.x has no "null" type (it uses
                                # `nullable: true`); {"type": "null"} is only valid from
                                # 3.1. Confirm client generators accept this schema.
                                "oneOf": [
                                    {
                                        "$ref": "#/components/schemas/EventEnvelopeWithMetadata"
                                    },
                                    {"type": "null"},
                                ],
                            },
                        },
                        "required": [
                            "handler_id",
                            "workflow_name",
                            "status",
                            "started_at",
                        ],
                    },
                    "HandlersList": {
                        "type": "object",
                        "properties": {
                            "handlers": {
                                "type": "array",
                                "items": {"$ref": "#/components/schemas/Handler"},
                            }
                        },
                        "required": ["handlers"],
                    },
                }
            },
        }
    )
    return gen.get_schema(app.routes)
#
# HTTP endpoints
#
async def _health_check(self, request: Request) -> JSONResponse:
    """Report that the server is up.

    ---
    summary: Health check
    description: Returns the server health status.
    responses:
      200:
        description: Successful health check
        content:
          application/json:
            schema:
              type: object
              properties:
                status:
                  type: string
                  example: healthy
              required: [status]
    """
    body = HealthResponse(status="healthy")
    return JSONResponse(body.model_dump())
async def _list_workflows(self, request: Request) -> JSONResponse:
    """Return the names of all registered workflows, in registration order.

    ---
    summary: List workflows
    description: Returns the list of registered workflow names.
    responses:
      200:
        description: List of workflows
        content:
          application/json:
            schema:
              type: object
              properties:
                workflows:
                  type: array
                  items:
                    type: string
              required: [workflows]
    """
    names = [*self._workflows]
    return JSONResponse({"workflows": names})
async def _list_workflow_events(self, request: Request) -> JSONResponse:
    """Return the JSON schemas of a workflow's events plus its additional events.

    Raises:
        HTTPException: 400 if the ``name`` path param is missing, 404 if no
            workflow is registered under that name.

    ---
    summary: List workflow events
    description: Returns the list of registered workflow event schemas.
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    responses:
      200:
        description: List of workflow event schemas
        content:
          application/json:
            schema:
              type: object
              properties:
                events:
                  type: array
                  description: List of workflow event JSON schemas
                  items:
                    type: object
              required: [events]
    """
    if "name" not in request.path_params:
        raise HTTPException(status_code=400, detail="name param is required")
    name = request.path_params["name"]
    if name not in self._workflows:
        raise HTTPException(status_code=404, detail=f"Workflow '{name}' not found")
    # Copy before extending: extending the list returned by ``Workflow.events``
    # in place could leak the additional events into the workflow object and
    # make them accumulate (as duplicates) across requests.
    events = list(self._workflows[name].events)
    events.extend(self._additional_events.get(name) or [])
    event_objs = [event.model_json_schema() for event in events]
    return JSONResponse(WorkflowEventsListResponse(events=event_objs).model_dump())
async def _run_workflow(self, request: Request) -> JSONResponse:
    """Run a workflow to completion and answer with its final handler data.

    The handler is closed before responding so no orphaned events remain.
    Status is 200 when the run succeeds and 500 when it raises.

    ---
    summary: Run workflow (wait)
    description: |
      Runs the specified workflow synchronously and returns the final result.
      The request body may include an optional serialized start event, an optional
      context object, and optional keyword arguments passed to the workflow run.
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
      content:
        application/json:
          schema:
            type: object
            properties:
              start_event:
                type: object
                description: 'Plain JSON object representing the start event (e.g., {"message": "..."}).'
              context:
                type: object
                description: Serialized workflow Context.
              handler_id:
                type: string
                description: Workflow handler identifier to continue from a previous completed run.
              kwargs:
                type: object
                description: Additional keyword arguments for the workflow.
    responses:
      200:
        description: Workflow completed successfully
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      400:
        description: Invalid start_event payload
      404:
        description: Workflow or handler identifier not found
      500:
        description: Error running workflow or invalid request body
    """
    named = self._extract_workflow(request)
    context, start_event, handler_id = await self._extract_run_params(
        request, named.workflow, named.name
    )
    start_ev = (
        None
        if start_event is None
        else named.workflow.start_event_class.model_validate(start_event)
    )
    try:
        run_wrapper = await self._start_workflow(
            workflow=_NamedWorkflow(name=named.name, workflow=named.workflow),
            handler_id=handler_id,
            context=context,
            start_event=start_ev,
        )
        try:
            await run_wrapper.run_handler
        except Exception as e:
            status_code = 500
            logger.error(f"Error running workflow: {e}", exc_info=True)
        else:
            status_code = 200
        # Wait for the task set up by start_streaming(); its errors are ignored here.
        if run_wrapper.task is not None:
            try:
                await run_wrapper.task
            except Exception:
                pass
        # explicitly close handlers from this synchronous api so they don't linger with events
        # that no-one is listening for
        await self._close_handler(run_wrapper)
        payload = run_wrapper.to_response_model().model_dump()
        return JSONResponse(payload, status_code=status_code)
    except Exception as e:
        logger.error(f"Error running workflow: {e}", exc_info=True)
        raise HTTPException(detail=f"Error running workflow: {e}", status_code=500)
async def _get_events_schema(self, request: Request) -> JSONResponse:
    """Return the JSON schemas of a workflow's start and stop event classes.

    ---
    summary: Get JSON schema for start event
    description: |
      Gets the JSON schema of the start and stop events from the specified workflow and returns it under "start" (start event) and "stop" (stop event)
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
    responses:
      200:
        description: JSON schema successfully retrieved for start event
        content:
          application/json:
            schema:
              type: object
              properties:
                start:
                  description: JSON schema for the start event
                stop:
                  description: JSON schema for the stop event
              required: [start, stop]
      404:
        description: Workflow not found
      500:
        description: Error while getting the JSON schema for the start or stop event
    """
    target = self._extract_workflow(request).workflow
    try:
        start_schema = target.start_event_class.model_json_schema()
    except Exception as e:
        raise HTTPException(
            detail=f"Error getting schema of start event for workflow: {e}",
            status_code=500,
        )
    try:
        stop_schema = target.stop_event_class.model_json_schema()
    except Exception as e:
        raise HTTPException(
            detail=f"Error getting schema of stop event for workflow: {e}",
            status_code=500,
        )
    response = WorkflowSchemaResponse(start=start_schema, stop=stop_schema)
    return JSONResponse(response.model_dump())
async def _get_workflow_representation(self, request: Request) -> JSONResponse:
    """Return the workflow's structure as a JSON graph.

    ---
    summary: Get the representation of the workflow
    description: |
      Get the representation of the workflow as a directed graph in JSON format
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
    responses:
      200:
        description: JSON representation successfully retrieved
        content:
          application/json:
            schema:
              type: object
              properties:
                graph:
                  description: the elements of the JSON representation of the workflow
              required: [graph]
      404:
        description: Workflow not found
      500:
        description: Error while getting JSON workflow representation
    """
    target = self._extract_workflow(request).workflow
    try:
        structure = _extract_workflow_structure(target)
    except Exception as e:
        raise HTTPException(
            detail=f"Error while getting JSON workflow representation: {e}",
            status_code=500,
        )
    response = WorkflowGraphResponse(graph=structure.to_response_model())
    return JSONResponse(response.model_dump())
async def _run_workflow_nowait(self, request: Request) -> JSONResponse:
    """Start a workflow in the background and return its handler data immediately.

    Any failure while starting (including the initial persistence checkpoint)
    is reported as a 500.

    ---
    summary: Run workflow (no-wait)
    description: |
      Starts the specified workflow asynchronously and returns a handler identifier
      which can be used to query results or stream events.
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
      content:
        application/json:
          schema:
            type: object
            properties:
              start_event:
                type: object
                description: 'Plain JSON object representing the start event (e.g., {"message": "..."}).'
              context:
                type: object
                description: Serialized workflow Context.
              handler_id:
                type: string
                description: Workflow handler identifier to continue from a previous completed run.
              kwargs:
                type: object
                description: Additional keyword arguments for the workflow.
    responses:
      200:
        description: Workflow started
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      400:
        description: Invalid start_event payload
      404:
        description: Workflow or handler identifier not found
    """
    named = self._extract_workflow(request)
    context, start_event, handler_id = await self._extract_run_params(
        request, named.workflow, named.name
    )
    start_ev = None
    if start_event is not None:
        start_ev = named.workflow.start_event_class.model_validate(start_event)
    try:
        started = await self._start_workflow(
            workflow=_NamedWorkflow(name=named.name, workflow=named.workflow),
            handler_id=handler_id,
            context=context,
            start_event=start_ev,
        )
    except Exception as e:
        raise HTTPException(detail=f"Initial persistence failed: {e}", status_code=500)
    payload = started.to_response_model().model_dump()
    return JSONResponse(payload)
async def _load_handler(self, handler_id: str) -> HandlerData:
    """Return handler data from memory, falling back to the workflow store.

    Raises:
        HTTPException: 404 if the handler is neither in memory nor in the store.
    """
    live = self._handlers.get(handler_id)
    if live is not None:
        if live.run_handler.done() and live.task is not None:
            try:
                # make sure the streaming task is fully done before reporting
                await live.task
            except Exception:
                # failed workflows raise their exception here; report the wrapper's state instead
                pass
        return live.to_response_model()
    found = await self._workflow_store.query(HandlerQuery(handler_id_in=[handler_id]))
    if not found:
        raise HTTPException(detail="Handler not found", status_code=404)
    return _WorkflowHandler.handler_data_from_persistent(found[0])
async def _get_workflow_result(self, request: Request) -> JSONResponse:
    """Deprecated result endpoint kept for backward compatibility.

    Responds with the handler data, but ``result`` is flattened: for a plain
    ``StopEvent`` it is the event's ``result`` field, otherwise the raw event
    value. Status is 202 while running, 200 when completed, 500 otherwise.

    ---
    summary: Get workflow result (deprecated)
    description: |
      Deprecated. Use GET /handlers/{handler_id} instead. Returns the final result of an asynchronously started workflow, if available.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow run identifier returned from the no-wait run endpoint.
    deprecated: true
    responses:
      200:
        description: Result is available
        content:
          application/json:
            schema:
              type: object
      202:
        description: Result not ready yet
        content:
          application/json:
            schema:
              type: object
      404:
        description: Handler not found
      500:
        description: Error computing result
        content:
          text/plain:
            schema:
              type: string
    """
    handler_id = request.path_params["handler_id"]
    if not handler_id:
        raise HTTPException(detail="Handler ID is required", status_code=400)
    handler_data = await self._load_handler(handler_id)
    # Exact comparison: the previous `status in "running"` was a substring test.
    if handler_data.status == "running":
        status = 202
    elif handler_data.status == "completed":
        status = 200
    else:
        status = 500
    response_model = handler_data.model_dump()
    # compatibility. Use handler api instead
    result = handler_data.result
    if not result:
        response_model["result"] = None
    elif result.qualified_name == "workflows.events.StopEvent":
        response_model["result"] = result.value.get("result")
    else:
        response_model["result"] = result.value
    return JSONResponse(response_model, status_code=status)
async def _get_workflow_handler(self, request: Request) -> JSONResponse:
    """Return a handler's data; status 202 while running, 200 when completed, 500 otherwise.

    ---
    summary: Get workflow handler
    description: Returns the final result of an asynchronously started workflow, if available
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow run identifier returned from the no-wait run endpoint.
    responses:
      200:
        description: Result is available
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      202:
        description: Result not ready yet
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      404:
        description: Handler not found
      500:
        description: Error computing result
        content:
          text/plain:
            schema:
              type: string
    """
    handler_id = request.path_params["handler_id"]
    if not handler_id:
        raise HTTPException(detail="Handler ID is required", status_code=400)
    handler_data = await self._load_handler(handler_id)
    # Exact comparison: the previous `status in "running"` was a substring test.
    if handler_data.status == "running":
        status = 202
    elif handler_data.status == "completed":
        status = 200
    else:
        status = 500
    return JSONResponse(handler_data.model_dump(), status_code=status)
async def _stream_events(self, request: Request) -> StreamingResponse:
    """Stream the events of an in-memory handler as SSE or NDJSON.

    Only one consumer may read a handler's event queue at a time; the lock is
    acquired up front and a 409 is returned if it cannot be acquired within
    ``acquire_timeout`` seconds. Finished handlers answer 204 so SSE clients
    stop reconnecting.

    ---
    summary: Stream workflow events
    description: |
      Streams events produced by a workflow execution. Events are emitted as
      Server-Sent Events by default, or as newline-delimited JSON when `sse=false`.
      Event data is returned as an envelope that preserves backward-compatible fields
      and adds metadata for type-safety on the client:
      {
        "value": <pydantic serialized value>,
        "types": [<class names from MRO excluding the event class and base Event>],
        "type": <class name>,
        "qualified_name": <python module path + class name>,
      }

      Event queue is mutable. Elements are added to the queue by the workflow handler, and removed by any consumer of the queue.
      The queue is protected by a lock that is acquired by the consumer, so only one consumer of the queue at a time is allowed.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Identifier returned from the no-wait run endpoint.
      - in: query
        name: sse
        required: false
        schema:
          type: boolean
          default: true
        description: If false, stream as NDJSON instead of Server-Sent Events.
      - in: query
        name: include_internal
        required: false
        schema:
          type: boolean
          default: false
        description: If true, include internal workflow events (e.g., step state changes).
      - in: query
        name: acquire_timeout
        required: false
        schema:
          type: number
          default: 1
        description: Timeout in seconds for acquiring the lock to iterate over the events.
      - in: query
        name: include_qualified_name
        required: false
        schema:
          type: boolean
          default: true
        description: If true, include the qualified name of the event in the response body.
    responses:
      200:
        description: Streaming started
        content:
          text/event-stream:
            schema:
              type: object
              description: Server-Sent Events stream of event data.
              properties:
                value:
                  type: object
                  description: The event value.
                type:
                  type: string
                  description: The class name of the event.
                types:
                  type: array
                  description: Superclass names from MRO (excluding the event class and base Event).
                  items:
                    type: string
                qualified_name:
                  type: string
                  description: The qualified name of the event.
              required: [value, type]
      204:
        description: Handler has already completed; clients should stop reconnecting.
      400:
        description: Invalid acquire_timeout value
      404:
        description: Handler not found
      409:
        description: Event stream lock could not be acquired within acquire_timeout
    """
    import math

    handler_id = request.path_params["handler_id"]
    raw_timeout = request.query_params.get("acquire_timeout", "1").lower()
    include_internal = (
        request.query_params.get("include_internal", "false").lower() == "true"
    )
    include_qualified_name = (
        request.query_params.get("include_qualified_name", "true").lower() == "true"
    )
    sse = request.query_params.get("sse", "true").lower() == "true"
    try:
        timeout = float(raw_timeout)
    except ValueError:
        raise HTTPException(
            detail=f"Invalid acquire_timeout: '{raw_timeout}'", status_code=400
        )
    # float() also parses "nan", which is not a meaningful lock timeout.
    if math.isnan(timeout):
        raise HTTPException(
            detail=f"Invalid acquire_timeout: '{raw_timeout}'", status_code=400
        )
    handler = self._handlers.get(handler_id)
    if handler is None:
        persisted = await self._workflow_store.query(
            HandlerQuery(handler_id_in=[handler_id])
        )
        if persisted and persisted[0].status in {"completed", "failed", "cancelled"}:
            raise HTTPException(detail="Handler is completed", status_code=204)
        raise HTTPException(detail="Handler not found", status_code=404)
    if handler.queue.empty() and handler.task is not None and handler.task.done():
        # https://html.spec.whatwg.org/multipage/server-sent-events.html
        # Clients will reconnect if the connection is closed; a client can
        # be told to stop reconnecting using the HTTP 204 No Content response code.
        raise HTTPException(detail="Handler is completed", status_code=204)
    media_type = "text/event-stream" if sse else "application/x-ndjson"
    try:
        generator = await handler.acquire_events_stream(timeout=timeout)
    except NoLockAvailable as e:
        raise HTTPException(
            detail=f"No lock available to acquire after {timeout}s timeout",
            status_code=409,
        ) from e

    async def event_stream() -> AsyncGenerator[str, None]:
        async for event in generator:
            if not include_internal and isinstance(event, InternalDispatchEvent):
                continue
            envelope = EventEnvelopeWithMetadata.from_event(
                event, include_qualified_name=include_qualified_name
            )
            payload = envelope.model_dump_json()
            if sse:
                # emit as untyped data. Difficult to subscribe to dynamic event types with SSE.
                yield f"data: {payload}\n\n"
            else:
                yield f"{payload}\n"
            # Yield control to the event loop between events.
            await asyncio.sleep(0)

    return StreamingResponse(event_stream(), media_type=media_type)
async def _get_handlers(self, request: Request) -> JSONResponse:
    """List persisted handlers, optionally filtered by status and workflow name.

    Unknown status values are ignored; a blank or absent filter means
    "no filter" for that field.

    ---
    summary: Get handlers
    description: Returns workflow handlers, optionally filtered by query parameters.
    parameters:
      - in: query
        name: status
        required: false
        schema:
          type: array
          items:
            type: string
            enum: [running, completed, failed, cancelled]
        style: form
        explode: true
        description: |
          Filter by handler status. Can be provided multiple times (e.g., status=running&status=failed)
      - in: query
        name: workflow_name
        required: false
        schema:
          type: array
          items:
            type: string
        style: form
        explode: true
        description: |
          Filter by workflow name. Can be provided multiple times (e.g., workflow_name=test&workflow_name=other)
    responses:
      200:
        description: List of handlers
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/HandlersList'
    """
    query_params = request.query_params

    def _parse_list_param(param_name: str) -> list[str] | None:
        # Repeated params (?x=a&x=b) all come back from getlist; blank entries
        # are dropped and an empty result means "no filter".
        cleaned = [
            raw.strip() for raw in query_params.getlist(param_name) if raw.strip()
        ]
        return cleaned or None

    def _to_handler_data(h: PersistentHandler) -> HandlerData:
        # Convert a stored handler into the API model, serializing datetimes as ISO strings.
        return HandlerData(
            handler_id=h.handler_id,
            workflow_name=h.workflow_name,
            run_id=h.run_id,
            status=h.status,
            started_at=h.started_at.isoformat() if h.started_at else "",
            updated_at=h.updated_at.isoformat() if h.updated_at else None,
            completed_at=h.completed_at.isoformat() if h.completed_at else None,
            error=h.error,
            result=EventEnvelopeWithMetadata.from_event(h.result)
            if h.result
            else None,
        )

    status_values = _parse_list_param("status")
    workflow_name_in = _parse_list_param("workflow_name")
    # Narrow types for status to match HandlerQuery expectations
    allowed_status_values: set[Status] = {
        "running",
        "completed",
        "failed",
        "cancelled",
    }
    if status_values is None:
        status_in = None
    else:
        status_in = list(allowed_status_values.intersection(status_values))
    persistent_handlers = await self._workflow_store.query(
        HandlerQuery(status_in=status_in, workflow_name_in=workflow_name_in)
    )
    items = [_to_handler_data(h) for h in persistent_handlers]
    return JSONResponse(HandlersListResponse(handlers=items).model_dump())
async def _post_event(self, request: Request) -> JSONResponse:
    """Deliver a client-supplied event to a running handler's context.

    Raises:
        HTTPException: 404 if the handler is unknown; 409 if it has completed
            or is not held in memory; 400 for an unparsable/non-object body or
            invalid event; 500 if the context is missing or on unexpected errors.

    ---
    summary: Send event to workflow
    description: Sends an event to a running workflow's context.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow handler identifier.
    requestBody:
      required: true
      content:
        application/json:
          schema:
            type: object
            properties:
              event:
                description: Serialized event. Accepts object or JSON-encoded string for backward compatibility.
                oneOf:
                  - type: string
                    description: JSON string of the event envelope or value.
                    examples:
                      - '{"type": "ExternalEvent", "value": {"response": "hi"}}'
                  - type: object
                    properties:
                      type:
                        type: string
                        description: The class name of the event.
                      value:
                        type: object
                        description: The event value object (preferred over data).
                    additionalProperties: true
              step:
                type: string
                description: Optional target step name. If not provided, event is sent to all steps.
            required: [event]
    responses:
      200:
        description: Event sent successfully
        content:
          application/json:
            schema:
              type: object
              properties:
                status:
                  type: string
                  enum: [sent]
              required: [status]
      400:
        description: Invalid JSON body or event data
      404:
        description: Handler not found
      409:
        description: Workflow already completed
    """
    handler_id = request.path_params["handler_id"]
    # Check if handler exists
    wrapper = self._handlers.get(handler_id)
    if wrapper is not None and is_status_completed(wrapper.status):
        raise HTTPException(detail="Workflow already completed", status_code=409)
    if wrapper is None:
        # _load_handler raises 404 when the store does not know the handler either.
        handler_data = await self._load_handler(handler_id)
        if is_status_completed(handler_data.status):
            raise HTTPException(
                detail="Workflow already completed", status_code=409
            )
        # this branch is for cases where handler status is running but somehow not in memory
        # Ideally, this should never happen. We probably need to revisit when we add pause/expire functionality.
        logger.warning(f"Handler {handler_id} is running but not in memory.")
        raise HTTPException(detail="Handler expired", status_code=409)
    handler = wrapper.run_handler
    # Get the context
    ctx = handler.ctx
    if ctx is None:
        raise HTTPException(detail="Context not available", status_code=500)
    # Parse request body
    try:
        # A malformed or non-object body is a client error (400), matching
        # _extract_run_params, rather than falling through to the generic 500.
        try:
            body = await request.json()
        except Exception as e:
            raise HTTPException(detail=f"Invalid JSON body: {e}", status_code=400)
        if not isinstance(body, dict):
            raise HTTPException(
                detail="Request body must be a JSON object", status_code=400
            )
        event_str = body.get("event")
        step = body.get("step")
        if not event_str:
            raise HTTPException(detail="Event data is required", status_code=400)
        # Deserialize the event
        try:
            event = EventEnvelope.parse(
                event_str, self._event_registry(wrapper.workflow_name)
            )
        except EventValidationError as e:
            raise HTTPException(detail=str(e), status_code=400)
        except Exception as e:
            raise HTTPException(
                detail=f"Failed to deserialize event: {e}", status_code=400
            )
        # Send the event to the context
        try:
            ctx.send_event(event, step=step)
        except Exception as e:
            raise HTTPException(
                detail=f"Failed to send event: {e}", status_code=400
            )
        return JSONResponse(SendEventResponse(status="sent").model_dump())
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            detail=f"Error processing request: {e}", status_code=500
        )
async def _cancel_handler(self, request: Request) -> JSONResponse:
    """Cancel an in-memory handler and, with ``purge=true``, delete it from the store.

    Without ``purge`` only in-memory handlers can be cancelled (404 otherwise).
    With ``purge`` a 404 is returned when the store deletes nothing.

    ---
    summary: Stop and delete handler
    description: |
      Stops a running workflow handler by cancelling its tasks. Optionally removes the
      handler from the persistence store if purge=true.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow handler identifier.
      - in: query
        name: purge
        required: false
        schema:
          type: boolean
          default: false
        description: If true, also deletes the handler from the store, otherwise updates the status to cancelled.
    responses:
      200:
        description: Handler cancelled and deleted or cancelled only
        content:
          application/json:
            schema:
              type: object
              properties:
                status:
                  type: string
                  enum: [deleted, cancelled]
              required: [status]
      404:
        description: Handler not found
    """
    handler_id = request.path_params["handler_id"]
    # Only the literal "true" (case-insensitive) enables purging, like `sse`.
    purge = request.query_params.get("purge", "false").lower() == "true"
    live = self._handlers.get(handler_id)
    if live is None:
        if not purge:
            raise HTTPException(detail="Handler not found", status_code=404)
    else:
        # Closing cancels the run and waits for its streaming task.
        await self._close_handler(live)
    if purge:
        removed = await self._workflow_store.delete(
            HandlerQuery(handler_id_in=[handler_id])
        )
        if removed == 0:
            raise HTTPException(detail="Handler not found", status_code=404)
        outcome = "deleted"
    else:
        outcome = "cancelled"
    return JSONResponse(CancelHandlerResponse(status=outcome).model_dump())
#
# Private methods
#
def _extract_workflow(self, request: Request) -> _NamedWorkflow:
    """Resolve the ``name`` path parameter to a registered workflow.

    Raises:
        HTTPException: 400 if the path has no ``name`` param, 404 if no
            workflow is registered under it.
    """
    params = request.path_params
    if "name" not in params:
        raise HTTPException(detail="'name' parameter missing", status_code=400)
    name = params["name"]
    try:
        workflow = self._workflows[name]
    except KeyError:
        raise HTTPException(detail="Workflow not found", status_code=404) from None
    return _NamedWorkflow(name=name, workflow=workflow)
async def _extract_run_params(
    self, request: Request, workflow: Workflow, workflow_name: str
) -> tuple[Context | None, StartEvent | None, str]:
    """Parse a run request body into ``(context, start_event, handler_id)``.

    The body is optional (the run endpoints declare ``requestBody.required:
    false``), so an empty body is treated as ``{}``. ``start_event`` falls back
    to ``kwargs`` when absent. The context comes from an explicit ``context``
    object, or, when ``handler_id`` is given, from that completed handler's
    persisted context. A fresh id is generated when ``handler_id`` is absent.

    Raises:
        HTTPException: 400 for unparsable JSON, a non-object body or an invalid
            start event; 404 when ``handler_id`` does not name a completed
            handler of this workflow; 500 for any other error.
    """
    try:
        try:
            raw_body = await request.body()
            body = json.loads(raw_body) if raw_body.strip() else {}
        except Exception as e:
            raise HTTPException(detail=f"Invalid JSON body: {e}", status_code=400)
        if not isinstance(body, dict):
            raise HTTPException(
                detail="Request body must be a JSON object", status_code=400
            )
        context_data = body.get("context")
        run_kwargs = body.get("kwargs", {})
        start_event_data = body.get("start_event", run_kwargs)
        handler_id = body.get("handler_id")
        # Extract custom StartEvent if present
        start_event = None
        if start_event_data is not None:
            try:
                start_event = EventEnvelope.parse(
                    start_event_data,
                    self._event_registry(workflow_name),
                    explicit_event=workflow.start_event_class,
                )
            except Exception as e:
                raise HTTPException(
                    detail=f"Validation error for 'start_event': {e}",
                    status_code=400,
                )
            if start_event is not None and not isinstance(
                start_event, workflow.start_event_class
            ):
                raise HTTPException(
                    detail=f"Start event must be an instance of {workflow.start_event_class}",
                    status_code=400,
                )
        # Extract custom Context if present
        context = None
        if context_data:
            context = Context.from_dict(workflow=workflow, data=context_data)
        elif handler_id:
            # Continuing a previous run: only completed handlers of this workflow qualify.
            persisted_handlers = await self._workflow_store.query(
                HandlerQuery(
                    handler_id_in=[handler_id],
                    workflow_name_in=[workflow_name],
                    status_in=["completed"],
                )
            )
            if len(persisted_handlers) == 0:
                raise HTTPException(detail="Handler not found", status_code=404)
            context = Context.from_dict(workflow, persisted_handlers[0].ctx)
        handler_id = handler_id or nanoid()
        return (context, start_event, handler_id)
    except HTTPException:
        # Re-raise HTTPExceptions as-is (like start_event validation errors)
        raise
    except Exception as e:
        raise HTTPException(
            detail=f"Error processing request body: {e}", status_code=500
        )
async def _start_workflow(
    self,
    workflow: _NamedWorkflow,
    handler_id: str,
    start_event: StartEvent | None = None,
    context: Context | None = None,
) -> _WorkflowHandler:
    """Run ``workflow`` and register the resulting handler under ``handler_id``.

    Both the ``Workflow.run`` call and the wrapping done by
    ``_run_workflow_handler`` happen while instrumentation is tagged with
    ``handler_id``. ``context`` is forwarded as ``ctx`` and ``start_event``
    as ``start_event``.
    """
    tags = {"handler_id": handler_id}
    with instrument_tags(tags):
        run = workflow.workflow.run(start_event=start_event, ctx=context)
        return await self._run_workflow_handler(handler_id, workflow.name, run)
async def _run_workflow_handler(
    self, handler_id: str, workflow_name: str, handler: WorkflowHandler
) -> _WorkflowHandler:
    """Wrap ``handler``, checkpoint it, register it and start streaming its events.

    Args:
        handler_id: key under which the wrapper is stored in ``self._handlers``.
        workflow_name: name recorded on the wrapper and in persisted state.
        handler: the handler returned by ``Workflow.run``.

    Returns:
        The registered ``_WorkflowHandler``.

    Raises:
        Whatever ``_WorkflowHandler.checkpoint`` raises for the initial
        checkpoint; the wrapper is not registered in that case.
    """
    queue: asyncio.Queue[Event] = asyncio.Queue()
    started_at = datetime.now(timezone.utc)
    wrapper = _WorkflowHandler(
        run_handler=handler,
        queue=queue,
        task=None,  # Will be set by start_streaming()
        consumer_mutex=asyncio.Lock(),
        handler_id=handler_id,
        workflow_name=workflow_name,
        started_at=started_at,
        updated_at=started_at,
        completed_at=None,
        _workflow_store=self._workflow_store,
        _persistence_backoff=self._persistence_backoff,
    )
    # Initial checkpoint before registration; fail fast if persistence is unavailable
    await wrapper.checkpoint()

    # Now register and start streaming
    self._handlers[handler_id] = wrapper

    async def on_finish() -> None:
        # A handler_id may be reused to resume a completed run
        # (``_extract_run_params`` returns the caller-supplied id). Only
        # unregister if the registry still points at *this* wrapper, so a
        # late-finishing consumer of an older run cannot evict a newer one.
        if self._handlers.get(handler_id) is wrapper:
            self._handlers.pop(handler_id, None)
        self._results.pop(handler_id, None)

    wrapper.start_streaming(on_finish=on_finish)
    return wrapper
async def _close_handler(self, handler: _WorkflowHandler) -> None:
    """Cancel ``handler``'s run if still going, await its streaming task, and unregister it.

    Cancellation is best-effort: errors from ``cancel()`` / ``cancel_run()``
    are logged at debug level and ignored. The handler is always removed from
    ``self._handlers`` and ``self._results``, even when awaiting its streaming
    task raises; that exception is then propagated to the caller.
    """
    run_handler = handler.run_handler
    # Cancel the run_handler if not done
    if not run_handler.done():
        try:
            run_handler.cancel()
        except Exception:
            logger.debug(
                f"Failed to cancel handler {handler.handler_id}", exc_info=True
            )
        try:
            await run_handler.cancel_run()
        except Exception:
            logger.debug(
                f"Failed to cancel run for handler {handler.handler_id}",
                exc_info=True,
            )
    try:
        if handler.task is not None:
            # May re-raise e.g. a checkpoint failure from the streaming task.
            await handler.task
    finally:
        self._handlers.pop(handler.handler_id, None)
        self._results.pop(handler.handler_id, None)
def _event_registry(self, workflow_name: str) -> dict[str, type[Event]]:
    """Return a class-name -> event-class lookup for ``workflow_name``.

    The workflow's own ``events`` are registered first, then any entries in
    ``self._additional_events[workflow_name]`` (a missing or ``None`` entry
    counts as empty). On a class-name collision the additional event wins.
    """
    registry: dict[str, type[Event]] = {}
    sources = (
        self._workflows[workflow_name].events,
        self._additional_events.get(workflow_name, None) or [],
    )
    for source in sources:
        registry.update((event_cls.__name__, event_cls) for event_cls in source)
    return registry
@dataclass
class _WorkflowHandler:
    """Server-side wrapper around a running ``WorkflowHandler``.

    A background task (``start_streaming``) copies the handler's
    ``stream_events(expose_internal=True)`` into ``queue`` and checkpoints the
    handler to ``_workflow_store`` whenever a step stops running. Consumers
    read the queue through ``acquire_events_stream``, which admits one
    consumer at a time.
    """

    run_handler: WorkflowHandler
    queue: asyncio.Queue[Event]
    # Background streaming task; ``None`` until ``start_streaming`` is called.
    task: asyncio.Task[None] | None
    # only one consumer of the queue at a time allowed
    consumer_mutex: asyncio.Lock
    # metadata
    handler_id: str
    workflow_name: str
    started_at: datetime
    updated_at: datetime
    completed_at: datetime | None
    # Dependencies for persistence
    _workflow_store: AbstractWorkflowStore
    # Seconds slept before each persistence retry in ``checkpoint``; its
    # length is the number of retries after the first attempt.
    _persistence_backoff: list[float]
    # Cleanup hook installed by ``_stream_events``; awaited by a consumer's
    # event stream once the run is done.
    _on_finish: Callable[[], Awaitable[None]] | None = None

    def _as_persistent(self) -> PersistentHandler:
        """Snapshot the current handler state as a ``PersistentHandler``.

        Refreshes ``updated_at`` and, the first time a terminal status
        (completed/failed/cancelled) is observed, stamps ``completed_at``.
        Nothing is written to the store here; ``checkpoint`` does that.
        Exceptions from ``ctx.to_dict()`` propagate.
        """
        now = datetime.now(timezone.utc)
        status = self.status
        self.updated_at = now
        # Keep the first completion time instead of moving it on every
        # checkpoint taken after the run has finished.
        if status in ("completed", "failed", "cancelled") and self.completed_at is None:
            self.completed_at = now
        ctx = self.run_handler.ctx
        return PersistentHandler(
            handler_id=self.handler_id,
            workflow_name=self.workflow_name,
            status=status,
            run_id=self.run_handler.run_id,
            error=self.error,
            result=self.result,
            started_at=self.started_at,
            updated_at=self.updated_at,
            completed_at=self.completed_at,
            ctx=ctx.to_dict() if ctx else {},
        )

    async def persist(self, persistent: PersistentHandler) -> None:
        """Write ``persistent`` to the workflow store (single attempt)."""
        await self._workflow_store.update(persistent)

    async def checkpoint(self) -> None:
        """Persist the current state, retrying with ``_persistence_backoff`` delays.

        Raises:
            Exception: re-raised from ``_as_persistent`` if the snapshot cannot
                be built, or the last store error once every retry has been
                used up. In the latter case the underlying run is cancelled
                before the error is re-raised.
        """
        # Copy so this call consumes its own retry schedule.
        backoffs = list(self._persistence_backoff)
        try:
            persistent = self._as_persistent()
        except Exception as e:
            logger.error(
                f"Failed to checkpoint handler {self.handler_id} to persistent state. Is there non-serializable state in an event or the state store? {e}",
                exc_info=True,
            )
            raise
        while True:
            try:
                await self.persist(persistent)
                return
            except Exception as e:
                backoff = backoffs.pop(0) if backoffs else None
                if backoff is None:
                    logger.error(
                        f"Failed to checkpoint handler {self.handler_id} after final attempt. Failing the handler.",
                        exc_info=True,
                    )
                    # Stop the workflow so it does not keep running without
                    # persistence; the store error is still re-raised below so
                    # the caller sees the failure.
                    try:
                        self.run_handler.cancel()
                    except Exception:
                        logger.debug(
                            f"Failed to cancel handler {self.handler_id}",
                            exc_info=True,
                        )
                    raise
                logger.error(
                    f"Failed to checkpoint handler {self.handler_id}. Retrying in {backoff} seconds: {e}"
                )
                await asyncio.sleep(backoff)

    def to_response_model(self) -> HandlerData:
        """Convert runtime handler to API response model."""
        return HandlerData(
            handler_id=self.handler_id,
            workflow_name=self.workflow_name,
            run_id=self.run_handler.run_id,
            status=self.status,
            started_at=self.started_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            completed_at=self.completed_at.isoformat()
            if self.completed_at is not None
            else None,
            error=self.error,
            result=EventEnvelopeWithMetadata.from_event(self.result)
            if self.result is not None
            else None,
        )

    @staticmethod
    def handler_data_from_persistent(persistent: PersistentHandler) -> HandlerData:
        """Convert a stored ``PersistentHandler`` into the API response model.

        A missing ``started_at`` is reported as the current UTC time.
        """
        return HandlerData(
            handler_id=persistent.handler_id,
            workflow_name=persistent.workflow_name,
            run_id=persistent.run_id,
            status=persistent.status,
            started_at=persistent.started_at.isoformat()
            if persistent.started_at is not None
            else datetime.now(timezone.utc).isoformat(),
            updated_at=persistent.updated_at.isoformat()
            if persistent.updated_at is not None
            else None,
            completed_at=persistent.completed_at.isoformat()
            if persistent.completed_at is not None
            else None,
            error=persistent.error,
            result=EventEnvelopeWithMetadata.from_event(persistent.result)
            if persistent.result is not None
            else None,
        )

    @property
    def status(self) -> Status:
        """Get the current status by inspecting the handler state."""
        if not self.run_handler.done():
            return "running"
        # done - check if cancelled first
        if self.run_handler.cancelled():
            return "cancelled"
        # then check for exception
        exc = self.run_handler.exception()
        if exc is not None:
            return "failed"
        return "completed"

    @property
    def error(self) -> str | None:
        """String form of the run's exception, or ``None`` if running, cancelled or successful."""
        if not self.run_handler.done():
            return None
        try:
            exc = self.run_handler.exception()
        except asyncio.CancelledError:
            return None
        return str(exc) if exc is not None else None

    @property
    def result(self) -> StopEvent | None:
        """The run's stop event, or ``None`` if still running or it cannot be obtained."""
        if not self.run_handler.done():
            return None
        try:
            return self.run_handler.get_stop_event()
        except asyncio.CancelledError:
            return None
        except Exception:
            return None

    def start_streaming(self, on_finish: Callable[[], Awaitable[None]]) -> None:
        """Start streaming events from the handler and managing state."""
        self.task = asyncio.create_task(self._stream_events(on_finish=on_finish))

    async def _stream_events(self, on_finish: Callable[[], Awaitable[None]]) -> None:
        """Body of the streaming task: forward events to ``queue`` and checkpoint.

        Checkpoints once up front, after every ``StepStateChanged`` event whose
        ``step_state`` is ``NOT_RUNNING`` (when a context exists), and once
        more after the run settles. A cancelled or failed run is logged rather
        than re-raised; a checkpoint that exhausts its retries propagates out
        of the task.
        """
        with instrument_tags({"handler_id": self.handler_id}):
            # Install the cleanup hook before the first await so that, even if
            # the initial checkpoint fails (which cancels the run), a consumer
            # can still release this handler.
            self._on_finish = on_finish
            await self.checkpoint()
            async for event in self.run_handler.stream_events(expose_internal=True):
                if (  # Watch for a specific internal event that signals the step is complete
                    isinstance(event, StepStateChanged)
                    and event.step_state == StepState.NOT_RUNNING
                ):
                    if self.run_handler.ctx:
                        await self.checkpoint()
                    else:
                        # Nothing to persist, but the event must still reach consumers.
                        logger.warning(
                            f"Context state is None for handler {self.handler_id}. This is not expected."
                        )
                self.queue.put_nowait(event)
            # done when stream events are complete
            try:
                await self.run_handler
            except asyncio.CancelledError:
                # Handler was cancelled - status will be automatically detected via handler.cancelled()
                logger.info(f"Workflow run {self.handler_id} was cancelled")
                # Don't re-raise, just let the task complete
            except Exception as e:
                logger.error(
                    f"Workflow run {self.handler_id} failed! {e}", exc_info=True
                )
            await self.checkpoint()

    async def acquire_events_stream(
        self, timeout: float = 1
    ) -> AsyncGenerator[Event, None]:
        """
        Acquires the lock to iterate over the events, and returns generator of events.

        Raises:
            NoLockAvailable: if ``consumer_mutex`` cannot be acquired within
                ``timeout`` seconds.
        """
        try:
            await asyncio.wait_for(self.consumer_mutex.acquire(), timeout=timeout)
        except asyncio.TimeoutError:
            raise NoLockAvailable(
                f"No lock available to acquire after {timeout}s timeout"
            )
        # NOTE(review): the mutex is released only by the returned generator's
        # ``finally``; a caller that never starts iterating it keeps the lock.
        return self._iter_events(timeout=timeout)

    async def _iter_events(self, timeout: float = 1) -> AsyncGenerator[Event, None]:
        """Yield queued events until the streaming task finishes and the queue is empty.

        Must be entered with ``consumer_mutex`` held (``acquire_events_stream``
        does this). The mutex is released when the generator finishes, is
        closed or fails; if the run is done at that point ``_on_finish`` is
        awaited first. ``timeout`` is accepted for compatibility and unused.
        """
        try:
            while not self.queue.empty() or (
                self.task is not None and not self.task.done()
            ):
                # Take buffered events from the queue one at a time, so a
                # consumer that stops early leaves the rest queued instead of
                # losing a locally drained batch.
                while not self.queue.empty():
                    yield self.queue.get_nowait()
                queue_get_task: asyncio.Task[Event] = asyncio.create_task(
                    self.queue.get()
                )
                producer = self.task
                waitables = (
                    {queue_get_task}
                    if producer is None
                    else {queue_get_task, producer}
                )
                try:
                    done, _pending = await asyncio.wait(
                        waitables, return_when=asyncio.FIRST_COMPLETED
                    )
                finally:
                    # Whether the producer finished first or this consumer was
                    # cancelled/closed while waiting, never leave the getter
                    # behind: an orphaned ``queue.get()`` would silently
                    # swallow the next event put on the queue.
                    if not queue_get_task.done():
                        queue_get_task.cancel()
                if queue_get_task in done:
                    yield queue_get_task.result()
                # Otherwise the producer finished; loop back so the outer
                # condition drains anything still queued and then exits.
        finally:
            try:
                if self._on_finish is not None and self.run_handler.done():
                    # clean up the resources if the stream has been consumed
                    await self._on_finish()
            finally:
                self.consumer_mutex.release()
class NoLockAvailable(Exception):
    """Raised by ``_WorkflowHandler.acquire_events_stream`` when the consumer
    lock cannot be acquired before the given timeout elapses."""
@dataclass
class _NamedWorkflow:
    """A ``Workflow`` paired with the name it is registered under on the server."""

    # Registry key of the workflow (the ``name`` path parameter in requests).
    name: str
    # The workflow instance itself.
    workflow: Workflow
if __name__ == "__main__":
    # CLI entry point: dump the server's OpenAPI schema to a JSON file.
    import argparse

    cli = argparse.ArgumentParser(description="Generate OpenAPI schema")
    cli.add_argument(
        "--output", type=str, default="openapi.json", help="Output file path"
    )
    options = cli.parse_args()

    schema = WorkflowServer().openapi_schema()
    Path(options.output).write_text(json.dumps(schema, indent=2))
    print(f"OpenAPI schema written to {options.output}")