1678 lines
61 KiB
Python
1678 lines
61 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025 LlamaIndex Inc.
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
import json
|
|
import logging
|
|
from importlib.metadata import version
|
|
from pathlib import Path
|
|
from typing import Any, AsyncGenerator, Callable, Awaitable
|
|
from datetime import datetime, timezone
|
|
|
|
from llama_index_instrumentation.dispatcher import instrument_tags
|
|
import uvicorn
|
|
from starlette.applications import Starlette
|
|
from starlette.exceptions import HTTPException
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|
from starlette.routing import Route
|
|
from starlette.schemas import SchemaGenerator
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
from workflows import Context, Workflow
|
|
from workflows.events import (
|
|
Event,
|
|
InternalDispatchEvent,
|
|
StartEvent,
|
|
StepState,
|
|
StepStateChanged,
|
|
StopEvent,
|
|
)
|
|
from workflows.handler import WorkflowHandler
|
|
|
|
|
|
from workflows.protocol import (
|
|
CancelHandlerResponse,
|
|
HandlerData,
|
|
HandlersListResponse,
|
|
HealthResponse,
|
|
SendEventResponse,
|
|
WorkflowEventsListResponse,
|
|
WorkflowGraphResponse,
|
|
WorkflowSchemaResponse,
|
|
is_status_completed,
|
|
)
|
|
from workflows.server.abstract_workflow_store import (
|
|
AbstractWorkflowStore,
|
|
HandlerQuery,
|
|
PersistentHandler,
|
|
Status,
|
|
)
|
|
from workflows.server.memory_workflow_store import MemoryWorkflowStore
|
|
from workflows.types import RunResultT
|
|
|
|
# Protocol models are used on the client side; server responds with plain dicts
|
|
from workflows.utils import _nanoid as nanoid
|
|
from .representation_utils import _extract_workflow_structure
|
|
from workflows.protocol.serializable_events import (
|
|
EventValidationError,
|
|
EventEnvelopeWithMetadata,
|
|
EventEnvelope,
|
|
)
|
|
|
|
# Module-scoped logger: records carry this module's name so they can be filtered
# or configured separately, and they still propagate to handlers on the root logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WorkflowServer:
|
|
def __init__(
    self,
    *,
    middleware: list[Middleware] | None = None,
    workflow_store: AbstractWorkflowStore | None = None,
    # retry/backoff seconds for persisting the handler state in the store after failures. Configurable mainly for testing.
    persistence_backoff: list[float] = [0.5, 3],
):
    """Set up the in-memory registries, the HTTP routes and the Starlette app.

    Args:
        middleware: Starlette middleware stack. When omitted (or empty), a
            permissive CORS middleware is installed instead.
        workflow_store: Store used for handler persistence. Defaults to a new
            ``MemoryWorkflowStore``.
        persistence_backoff: Retry/backoff delays, in seconds, for persisting
            handler state after failures. The value is copied, so the shared
            default list is never mutated.
    """
    # In-memory registries.
    self._workflows: dict[str, Workflow] = {}
    self._additional_events: dict[str, list[type[Event]] | None] = {}
    self._contexts: dict[str, Context] = {}
    self._handlers: dict[str, _WorkflowHandler] = {}
    self._results: dict[str, RunResultT] = {}

    if workflow_store is None:
        workflow_store = MemoryWorkflowStore()
    self._workflow_store = workflow_store
    self._assets_path = Path(__file__).parent / "static"
    self._persistence_backoff = list(persistence_backoff)

    if middleware:
        self._middleware = middleware
    else:
        self._middleware = [
            Middleware(
                CORSMiddleware,
                # regex echoes the origin header back, which some browsers require (rather than "*") when credentials are required
                allow_origin_regex=".*",
                allow_methods=["*"],
                allow_headers=["*"],
                allow_credentials=True,
            )
        ]

    # (path, endpoint, HTTP verb) — order is preserved in the resulting route list.
    route_table = (
        ("/workflows", self._list_workflows, "GET"),
        ("/workflows/{name}/run", self._run_workflow, "POST"),
        ("/workflows/{name}/run-nowait", self._run_workflow_nowait, "POST"),
        ("/workflows/{name}/schema", self._get_events_schema, "GET"),
        ("/results/{handler_id}", self._get_workflow_result, "GET"),
        ("/events/{handler_id}", self._stream_events, "GET"),
        ("/events/{handler_id}", self._post_event, "POST"),
        ("/health", self._health_check, "GET"),
        ("/handlers", self._get_handlers, "GET"),
        ("/handlers/{handler_id}", self._get_workflow_handler, "GET"),
        ("/handlers/{handler_id}/cancel", self._cancel_handler, "POST"),
        ("/workflows/{name}/representation", self._get_workflow_representation, "GET"),
        ("/workflows/{name}/events", self._list_workflow_events, "GET"),
    )
    self._routes = [
        Route(path, endpoint, methods=[verb]) for path, endpoint, verb in route_table
    ]

    @asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
        # Tie server start/stop to the ASGI application's lifespan.
        async with self.contextmanager():
            yield

    self.app = Starlette(
        routes=self._routes,
        middleware=self._middleware,
        lifespan=lifespan,
    )
    # Serve the UI as static files
    ui_files = StaticFiles(directory=self._assets_path, html=True)
    self.app.mount("/", app=ui_files, name="ui")
|
|
|
|
def add_workflow(
    self,
    name: str,
    workflow: Workflow,
    additional_events: list[type[Event]] | None = None,
) -> None:
    """Register ``workflow`` under ``name``.

    Args:
        name: Key the workflow is stored under; an existing entry with the
            same name is replaced.
        workflow: The workflow instance to register.
        additional_events: Extra event types recorded for this workflow.
            Nothing is recorded when ``None``.
    """
    self._workflows[name] = workflow
    if additional_events is None:
        return
    self._additional_events[name] = additional_events
|
|
|
|
async def start(self) -> "WorkflowServer":
    """Resumes previously running workflows, if they were not complete at last shutdown.

    Queries the workflow store for handlers with status ``running`` that belong
    to a registered workflow and restarts each one from its persisted context.
    A handler that cannot be resumed is logged and marked ``failed`` in the
    store. If recording that failure also fails, the error is logged and
    startup continues with the remaining handlers.

    Returns:
        This server instance.
    """
    handlers = await self._workflow_store.query(
        HandlerQuery(
            status_in=["running"], workflow_name_in=list(self._workflows.keys())
        )
    )
    for persistent in handlers:
        workflow = self._workflows[persistent.workflow_name]
        try:
            await self._start_workflow(
                workflow=_NamedWorkflow(
                    name=persistent.workflow_name, workflow=workflow
                ),
                handler_id=persistent.handler_id,
                context=Context.from_dict(workflow=workflow, data=persistent.ctx),
            )
        except Exception as e:
            logger.error(
                f"Failed to resume handler {persistent.handler_id} for workflow {persistent.workflow_name}: {e}",
                exc_info=True,
            )

            try:
                now = datetime.now(timezone.utc)
                await self._workflow_store.update(
                    PersistentHandler(
                        handler_id=persistent.handler_id,
                        workflow_name=persistent.workflow_name,
                        status="failed",
                        run_id=persistent.run_id,
                        error=str(e),
                        result=None,
                        started_at=persistent.started_at,
                        updated_at=now,
                        completed_at=now,
                        ctx=persistent.ctx,
                    )
                )
            except Exception as update_error:
                # Best effort: one bad record must not block server startup,
                # but the failure should be visible rather than silently dropped.
                logger.warning(
                    f"Failed to mark handler {persistent.handler_id} as failed in the workflow store: {update_error}",
                    exc_info=True,
                )

    return self
|
|
|
|
@asynccontextmanager
async def contextmanager(self) -> AsyncGenerator["WorkflowServer", None]:
    """Async context manager that starts the server on entry and stops it on exit.

    Yields:
        This server instance, after ``start()`` has completed.
    """
    await self.start()
    try:
        yield self
    finally:
        # Runs even if the body raised, so the server is always stopped.
        await self.stop()
|
|
|
|
async def stop(self) -> None:
    """Close every active handler and clear the in-memory handler and result maps."""
    logger.info(
        f"Shutting down Workflow server. Cancelling {len(self._handlers)} handlers."
    )
    # Snapshot first, so the iteration is independent of later changes to the dict.
    active = list(self._handlers.values())
    closing = [self._close_handler(handler) for handler in active]
    await asyncio.gather(*closing)
    self._handlers.clear()
    self._results.clear()
|
|
|
|
async def serve(
    self,
    host: str = "localhost",
    port: int = 80,
    uvicorn_config: dict[str, Any] | None = None,
) -> None:
    """Run the server.

    Args:
        host: Interface passed to uvicorn.
        port: Port passed to uvicorn.
        uvicorn_config: Extra keyword arguments forwarded to ``uvicorn.Config``.
    """
    uvicorn_config = uvicorn_config or {}

    server = uvicorn.Server(
        uvicorn.Config(self.app, host=host, port=port, **uvicorn_config)
    )
    root_path = uvicorn_config.get("root_path", "/")
    logger.info(f"Starting Workflow server at http://{host}:{port}{root_path}")

    await server.serve()
|
|
|
|
def openapi_schema(self) -> dict:
    """Return the OpenAPI document for this server's routes.

    A base document (info plus shared component schemas) is handed to
    Starlette's ``SchemaGenerator``, which is then run over the application's
    routes.
    """
    # Shared component: serialized event with type metadata.
    envelope_schema = {
        "type": "object",
        "properties": {
            "value": {"type": "object"},
            "types": {"type": "array", "items": {"type": "string"}},
            "type": {"type": "string"},
            "qualified_name": {"type": "string"},
        },
        "required": ["value", "type"],
    }

    nullable_timestamp = {
        "type": "string",
        "format": "date-time",
        "nullable": True,
    }

    # Shared component: a single workflow handler.
    handler_schema = {
        "type": "object",
        "properties": {
            "handler_id": {"type": "string"},
            "workflow_name": {"type": "string"},
            "run_id": {"type": "string", "nullable": True},
            "status": {
                "type": "string",
                "enum": [
                    "running",
                    "completed",
                    "failed",
                    "cancelled",
                ],
            },
            "started_at": {"type": "string", "format": "date-time"},
            "updated_at": dict(nullable_timestamp),
            "completed_at": dict(nullable_timestamp),
            "error": {"type": "string", "nullable": True},
            "result": {
                "description": "Workflow result value",
                "oneOf": [
                    {"$ref": "#/components/schemas/EventEnvelopeWithMetadata"},
                    {"type": "null"},
                ],
            },
        },
        "required": [
            "handler_id",
            "workflow_name",
            "status",
            "started_at",
        ],
    }

    # Shared component: list wrapper around handlers.
    handlers_list_schema = {
        "type": "object",
        "properties": {
            "handlers": {
                "type": "array",
                "items": {"$ref": "#/components/schemas/Handler"},
            }
        },
        "required": ["handlers"],
    }

    base_schema = {
        "openapi": "3.0.0",
        "info": {
            "title": "Workflows API",
            "version": version("llama-index-workflows"),
        },
        "components": {
            "schemas": {
                "EventEnvelopeWithMetadata": envelope_schema,
                "Handler": handler_schema,
                "HandlersList": handlers_list_schema,
            }
        },
    }

    return SchemaGenerator(base_schema).get_schema(self.app.routes)
|
|
|
|
#
|
|
# HTTP endpoints
|
|
#
|
|
|
|
async def _health_check(self, request: Request) -> JSONResponse:
    """Report that the server is healthy.

    ---
    summary: Health check
    description: Returns the server health status.
    responses:
      200:
        description: Successful health check
        content:
          application/json:
            schema:
              type: object
              properties:
                status:
                  type: string
                  example: healthy
              required: [status]
    """
    payload = HealthResponse(status="healthy").model_dump()
    return JSONResponse(payload)
|
|
|
|
async def _list_workflows(self, request: Request) -> JSONResponse:
    """Return the names of all registered workflows.

    ---
    summary: List workflows
    description: Returns the list of registered workflow names.
    responses:
      200:
        description: List of workflows
        content:
          application/json:
            schema:
              type: object
              properties:
                workflows:
                  type: array
                  items:
                    type: string
              required: [workflows]
    """
    # Iterating a dict yields its keys, in insertion (registration) order.
    return JSONResponse({"workflows": list(self._workflows)})
|
|
|
|
async def _list_workflow_events(self, request: Request) -> JSONResponse:
    """Return the JSON schemas of a workflow's events plus its registered additional events.

    ---
    summary: List workflow events
    description: Returns the list of registered workflow event schemas.
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    responses:
      200:
        description: List of workflow event schemas
        content:
          application/json:
            schema:
              type: object
              properties:
                events:
                  type: array
                  description: List of workflow event JSON schemas
                  items:
                    type: object
              required: [events]
    """
    if "name" not in request.path_params:
        raise HTTPException(status_code=400, detail="name param is required")

    name = request.path_params["name"]
    if name not in self._workflows:
        raise HTTPException(status_code=404, detail=f"Workflow '{name}' not found")

    # Build a fresh list instead of extending ``workflow.events`` in place:
    # extending would mutate whatever list the workflow exposes, so repeated
    # requests could accumulate duplicate entries. ``or []`` also covers a
    # stored ``None``.
    events = [
        *self._workflows[name].events,
        *(self._additional_events.get(name) or []),
    ]
    event_objs = [event.model_json_schema() for event in events]

    return JSONResponse(WorkflowEventsListResponse(events=event_objs).model_dump())
|
|
|
|
async def _run_workflow(self, request: Request) -> JSONResponse:
    """Start a workflow, wait for it to finish, and return its handler data.

    ---
    summary: Run workflow (wait)
    description: |
      Runs the specified workflow synchronously and returns the final result.
      The request body may include an optional serialized start event, an optional
      context object, and optional keyword arguments passed to the workflow run.
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
      content:
        application/json:
          schema:
            type: object
            properties:
              start_event:
                type: object
                description: 'Plain JSON object representing the start event (e.g., {"message": "..."}).'
              context:
                type: object
                description: Serialized workflow Context.
              handler_id:
                type: string
                description: Workflow handler identifier to continue from a previous completed run.
              kwargs:
                type: object
                description: Additional keyword arguments for the workflow.
    responses:
      200:
        description: Workflow completed successfully
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      400:
        description: Invalid start_event payload
      404:
        description: Workflow or handler identifier not found
      500:
        description: Error running workflow or invalid request body
    """
    workflow = self._extract_workflow(request)
    context, start_event, handler_id = await self._extract_run_params(
        request, workflow.workflow, workflow.name
    )

    input_ev = (
        None
        if start_event is None
        else workflow.workflow.start_event_class.model_validate(start_event)
    )

    try:
        wrapper = await self._start_workflow(
            workflow=_NamedWorkflow(name=workflow.name, workflow=workflow.workflow),
            handler_id=handler_id,
            context=context,
            start_event=input_ev,
        )
        handler = wrapper.run_handler
        try:
            await handler
        except Exception as run_error:
            # A failed run still produces a handler payload, reported with status 500.
            status = 500
            logger.error(f"Error running workflow: {run_error}", exc_info=True)
        else:
            status = 200

        if wrapper.task is not None:
            try:
                await wrapper.task
            except Exception:
                # Already accounted for above; only waiting for the task to settle.
                pass

        # explicitly close handlers from this synchronous api so they don't linger with events
        # that no-one is listening for
        await self._close_handler(wrapper)

        body = wrapper.to_response_model().model_dump()
        return JSONResponse(body, status_code=status)
    except Exception as e:
        logger.error(f"Error running workflow: {e}", exc_info=True)
        raise HTTPException(detail=f"Error running workflow: {e}", status_code=500)
|
|
|
|
async def _get_events_schema(self, request: Request) -> JSONResponse:
    """Return the JSON schemas of a workflow's start and stop event classes.

    ---
    summary: Get JSON schema for start event
    description: |
      Gets the JSON schema of the start and stop events from the specified workflow and returns it under "start" (start event) and "stop" (stop event)
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
    responses:
      200:
        description: JSON schema successfully retrieved for start event
        content:
          application/json:
            schema:
              type: object
              properties:
                start:
                  description: JSON schema for the start event
                stop:
                  description: JSON schema for the stop event
              required: [start, stop]
      404:
        description: Workflow not found
      500:
        description: Error while getting the JSON schema for the start or stop event
    """
    target = self._extract_workflow(request).workflow

    # Each schema is generated separately so the error names the failing event.
    try:
        start_schema = target.start_event_class.model_json_schema()
    except Exception as e:
        raise HTTPException(
            detail=f"Error getting schema of start event for workflow: {e}",
            status_code=500,
        )

    try:
        stop_schema = target.stop_event_class.model_json_schema()
    except Exception as e:
        raise HTTPException(
            detail=f"Error getting schema of stop event for workflow: {e}",
            status_code=500,
        )

    response = WorkflowSchemaResponse(start=start_schema, stop=stop_schema)
    return JSONResponse(response.model_dump())
|
|
|
|
async def _get_workflow_representation(self, request: Request) -> JSONResponse:
    """Return the workflow's structure as a graph.

    ---
    summary: Get the representation of the workflow
    description: |
      Get the representation of the workflow as a directed graph in JSON format
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
    responses:
      200:
        description: JSON representation successfully retrieved
        content:
          application/json:
            schema:
              type: object
              properties:
                graph:
                  description: the elements of the JSON representation of the workflow
              required: [graph]
      404:
        description: Workflow not found
      500:
        description: Error while getting JSON workflow representation
    """
    named = self._extract_workflow(request)
    try:
        structure = _extract_workflow_structure(named.workflow)
    except Exception as e:
        raise HTTPException(
            detail=f"Error while getting JSON workflow representation: {e}",
            status_code=500,
        )

    graph_response = WorkflowGraphResponse(graph=structure.to_response_model())
    return JSONResponse(graph_response.model_dump())
|
|
|
|
async def _run_workflow_nowait(self, request: Request) -> JSONResponse:
    """Start a workflow and return its handler data without waiting for completion.

    ---
    summary: Run workflow (no-wait)
    description: |
      Starts the specified workflow asynchronously and returns a handler identifier
      which can be used to query results or stream events.
    parameters:
      - in: path
        name: name
        required: true
        schema:
          type: string
        description: Registered workflow name.
    requestBody:
      required: false
      content:
        application/json:
          schema:
            type: object
            properties:
              start_event:
                type: object
                description: 'Plain JSON object representing the start event (e.g., {"message": "..."}).'
              context:
                type: object
                description: Serialized workflow Context.
              handler_id:
                type: string
                description: Workflow handler identifier to continue from a previous completed run.
              kwargs:
                type: object
                description: Additional keyword arguments for the workflow.
    responses:
      200:
        description: Workflow started
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      400:
        description: Invalid start_event payload
      404:
        description: Workflow or handler identifier not found
    """
    workflow = self._extract_workflow(request)
    context, start_event, handler_id = await self._extract_run_params(
        request, workflow.workflow, workflow.name
    )

    input_ev = (
        None
        if start_event is None
        else workflow.workflow.start_event_class.model_validate(start_event)
    )

    try:
        wrapper = await self._start_workflow(
            workflow=_NamedWorkflow(name=workflow.name, workflow=workflow.workflow),
            handler_id=handler_id,
            context=context,
            start_event=input_ev,
        )
    except Exception as e:
        raise HTTPException(
            detail=f"Initial persistence failed: {e}", status_code=500
        )

    return JSONResponse(wrapper.to_response_model().model_dump())
|
|
|
|
async def _load_handler(self, handler_id: str) -> HandlerData:
    """Look up handler data, preferring the in-memory handler over the store.

    Args:
        handler_id: Identifier of the handler to load.

    Returns:
        The handler's response model, built from the live handler when it is
        in memory and from the persisted record otherwise.

    Raises:
        HTTPException: 404 when the handler is neither in memory nor in the store.
    """
    wrapper = self._handlers.get(handler_id)
    if wrapper is not None:
        if wrapper.run_handler.done() and wrapper.task is not None:
            try:
                await wrapper.task  # make sure its fully done
            except Exception:
                # failed workflows raise their exception here
                pass
        return wrapper.to_response_model()

    # Not in memory: fall back to the persisted record.
    found = await self._workflow_store.query(HandlerQuery(handler_id_in=[handler_id]))
    if not found:
        raise HTTPException(detail="Handler not found", status_code=404)
    return _WorkflowHandler.handler_data_from_persistent(found[0])
|
|
|
|
async def _get_workflow_result(self, request: Request) -> JSONResponse:
    """Deprecated result endpoint; returns handler data with a flattened ``result``.

    ---
    summary: Get workflow result (deprecated)
    description: |
      Deprecated. Use GET /handlers/{handler_id} instead. Returns the final result of an asynchronously started workflow, if available.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow run identifier returned from the no-wait run endpoint.
    deprecated: true
    responses:
      200:
        description: Result is available
        content:
          application/json:
            schema:
              type: object
      202:
        description: Result not ready yet
        content:
          application/json:
            schema:
              type: object
      404:
        description: Handler not found
      500:
        description: Error computing result
        content:
          text/plain:
            schema:
              type: string
    """
    handler_id = request.path_params["handler_id"]
    if not handler_id:
        raise HTTPException(detail="Handler ID is required", status_code=400)

    handler_data = await self._load_handler(handler_id)

    # Compare with ``==``: ``status in "running"`` is a substring test, which
    # would also match values such as "" or "run".
    if handler_data.status == "running":
        status = 202
    elif handler_data.status == "completed":
        status = 200
    else:
        status = 500

    response_model = handler_data.model_dump()

    # compatibility. Use handler api instead
    result = handler_data.result
    if not result:
        response_model["result"] = None
    elif result.qualified_name == "workflows.events.StopEvent":
        # A plain StopEvent is unwrapped to its inner "result" field.
        response_model["result"] = result.value.get("result")
    else:
        response_model["result"] = result.value
    return JSONResponse(response_model, status_code=status)
|
|
|
|
async def _get_workflow_handler(self, request: Request) -> JSONResponse:
    """Return the handler data for ``handler_id``, with a status code reflecting its state.

    ---
    summary: Get workflow handler
    description: Returns the final result of an asynchronously started workflow, if available
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow run identifier returned from the no-wait run endpoint.
    responses:
      200:
        description: Result is available
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      202:
        description: Result not ready yet
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/Handler'
      404:
        description: Handler not found
      500:
        description: Error computing result
        content:
          text/plain:
            schema:
              type: string
    """
    handler_id = request.path_params["handler_id"]
    if not handler_id:
        raise HTTPException(detail="Handler ID is required", status_code=400)

    handler_data = await self._load_handler(handler_id)

    # Compare with ``==``: ``status in "running"`` is a substring test, which
    # would also match values such as "" or "run".
    if handler_data.status == "running":
        status = 202
    elif handler_data.status == "completed":
        status = 200
    else:
        status = 500
    return JSONResponse(handler_data.model_dump(), status_code=status)
|
|
|
|
async def _stream_events(self, request: Request) -> StreamingResponse:
    """Stream a handler's events as Server-Sent Events or newline-delimited JSON.

    ---
    summary: Stream workflow events
    description: |
      Streams events produced by a workflow execution. Events are emitted as
      newline-delimited JSON by default, or as Server-Sent Events when `sse=true`.
      Event data is returned as an envelope that preserves backward-compatible fields
      and adds metadata for type-safety on the client:
      {
        "value": <pydantic serialized value>,
        "types": [<class names from MRO excluding the event class and base Event>],
        "type": <class name>,
        "qualified_name": <python module path + class name>,
      }

      Event queue is mutable. Elements are added to the queue by the workflow handler, and removed by any consumer of the queue.
      The queue is protected by a lock that is acquired by the consumer, so only one consumer of the queue at a time is allowed.

    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Identifier returned from the no-wait run endpoint.
      - in: query
        name: sse
        required: false
        schema:
          type: boolean
          default: true
        description: If false, as NDJSON instead of Server-Sent Events.
      - in: query
        name: include_internal
        required: false
        schema:
          type: boolean
          default: false
        description: If true, include internal workflow events (e.g., step state changes).
      - in: query
        name: acquire_timeout
        required: false
        schema:
          type: number
          default: 1
        description: Timeout for acquiring the lock to iterate over the events.
      - in: query
        name: include_qualified_name
        required: false
        schema:
          type: boolean
          default: true
        description: If true, include the qualified name of the event in the response body.
    responses:
      200:
        description: Streaming started
        content:
          text/event-stream:
            schema:
              type: object
              description: Server-Sent Events stream of event data.
              properties:
                value:
                  type: object
                  description: The event value.
                type:
                  type: string
                  description: The class name of the event.
                types:
                  type: array
                  description: Superclass names from MRO (excluding the event class and base Event).
                  items:
                    type: string
                qualified_name:
                  type: string
                  description: The qualified name of the event.
              required: [value, type]
      404:
        description: Handler not found
    """
    params = request.query_params
    handler_id = request.path_params["handler_id"]

    def flag(name: str, default: str) -> bool:
        # Only the literal "true" (case-insensitive) enables a flag.
        return params.get(name, default).lower() == "true"

    raw_timeout = params.get("acquire_timeout", "1").lower()
    include_internal = flag("include_internal", "false")
    include_qualified_name = flag("include_qualified_name", "true")
    sse = flag("sse", "true")
    try:
        timeout = float(raw_timeout)
    except ValueError:
        raise HTTPException(
            detail=f"Invalid acquire_timeout: '{raw_timeout}'", status_code=400
        )

    handler = self._handlers.get(handler_id)
    if handler is None:
        persisted = await self._workflow_store.query(
            HandlerQuery(handler_id_in=[handler_id])
        )
        if persisted and persisted[0].status in {"completed", "failed", "cancelled"}:
            raise HTTPException(detail="Handler is completed", status_code=204)
        raise HTTPException(detail="Handler not found", status_code=404)

    drained = handler.queue.empty()
    if drained and handler.task is not None and handler.task.done():
        # https://html.spec.whatwg.org/multipage/server-sent-events.html
        # Clients will reconnect if the connection is closed; a client can
        # be told to stop reconnecting using the HTTP 204 No Content response code.
        raise HTTPException(detail="Handler is completed", status_code=204)

    media_type = "text/event-stream" if sse else "application/x-ndjson"

    try:
        generator = await handler.acquire_events_stream(timeout=timeout)
    except NoLockAvailable as e:
        raise HTTPException(
            detail=f"No lock available to acquire after {timeout}s timeout",
            status_code=409,
        ) from e

    async def event_stream(handler: _WorkflowHandler) -> AsyncGenerator[str, None]:
        async for event in generator:
            if isinstance(event, InternalDispatchEvent) and not include_internal:
                continue
            payload = EventEnvelopeWithMetadata.from_event(
                event, include_qualified_name=include_qualified_name
            ).model_dump_json()
            # emit as untyped data. Difficult to subscribe to dynamic event types with SSE.
            yield f"data: {payload}\n\n" if sse else f"{payload}\n"

            # Hand control back to the event loop between events.
            await asyncio.sleep(0)

    return StreamingResponse(event_stream(handler), media_type=media_type)
|
|
|
|
async def _get_handlers(self, request: Request) -> JSONResponse:
    """List persisted handlers, optionally filtered by status and workflow name.

    ---
    summary: Get handlers
    description: Returns workflow handlers, optionally filtered by query parameters.
    parameters:
      - in: query
        name: status
        required: false
        schema:
          type: array
          items:
            type: string
            enum: [running, completed, failed, cancelled]
        style: form
        explode: true
        description: |
          Filter by handler status. Can be provided multiple times (e.g., status=running&status=failed)
      - in: query
        name: workflow_name
        required: false
        schema:
          type: array
          items:
            type: string
        style: form
        explode: true
        description: |
          Filter by workflow name. Can be provided multiple times (e.g., workflow_name=test&workflow_name=other)
    responses:
      200:
        description: List of handlers
        content:
          application/json:
            schema:
              $ref: '#/components/schemas/HandlersList'
    """

    def _parse_list_param(param_name: str) -> list[str] | None:
        # parse repeated params; blank values are dropped, and an empty result means "no filter"
        raw = list(request.query_params.getlist(param_name)) or [
            request.query_params.get(param_name) or ""
        ]
        cleaned = [value.strip() for value in raw if value.strip()]
        return cleaned or None

    def _to_handler_data(h: PersistentHandler) -> HandlerData:
        # Timestamps are serialized as ISO-8601 strings; a missing start time becomes "".
        return HandlerData(
            handler_id=h.handler_id,
            workflow_name=h.workflow_name,
            run_id=h.run_id,
            status=h.status,
            started_at=h.started_at.isoformat() if h.started_at else "",
            updated_at=h.updated_at.isoformat() if h.updated_at else None,
            completed_at=h.completed_at.isoformat() if h.completed_at else None,
            error=h.error,
            result=EventEnvelopeWithMetadata.from_event(h.result) if h.result else None,
        )

    # Parse filters
    status_values = _parse_list_param("status")
    workflow_name_in = _parse_list_param("workflow_name")

    # Narrow types for status to match HandlerQuery expectations
    allowed_status_values: set[Status] = {
        "running",
        "completed",
        "failed",
        "cancelled",
    }

    if status_values is None:
        status_in = None
    else:
        # Unknown status values are silently discarded.
        status_in = list(allowed_status_values.intersection(status_values))

    persistent_handlers = await self._workflow_store.query(
        HandlerQuery(status_in=status_in, workflow_name_in=workflow_name_in)
    )
    items = [_to_handler_data(h) for h in persistent_handlers]
    return JSONResponse(HandlersListResponse(handlers=items).model_dump())
|
|
|
|
async def _post_event(self, request: Request) -> JSONResponse:
    """Deliver an event from the request body to a running workflow's context.

    ---
    summary: Send event to workflow
    description: Sends an event to a running workflow's context.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow handler identifier.
    requestBody:
      required: true
      content:
        application/json:
          schema:
            type: object
            properties:
              event:
                description: Serialized event. Accepts object or JSON-encoded string for backward compatibility.
                oneOf:
                  - type: string
                    description: JSON string of the event envelope or value.
                    examples:
                      - '{"type": "ExternalEvent", "value": {"response": "hi"}}'
                  - type: object
                    properties:
                      type:
                        type: string
                        description: The class name of the event.
                      value:
                        type: object
                        description: The event value object (preferred over data).
                    additionalProperties: true
              step:
                type: string
                description: Optional target step name. If not provided, event is sent to all steps.
            required: [event]
    responses:
      200:
        description: Event sent successfully
        content:
          application/json:
            schema:
              type: object
              properties:
                status:
                  type: string
                  enum: [sent]
              required: [status]
      400:
        description: Invalid event data
      404:
        description: Handler not found
      409:
        description: Workflow already completed
    """
    handler_id = request.path_params["handler_id"]

    # Check if handler exists
    wrapper = self._handlers.get(handler_id)
    if wrapper is None:
        # Not in memory: consult the store (raises 404 when unknown).
        handler_data = await self._load_handler(handler_id)
        if is_status_completed(handler_data.status):
            raise HTTPException(detail="Workflow already completed", status_code=409)
        # this branch is for cases where handler status is running but somehow not in memory
        # Ideally, this should never happen. We probably need to revisit when we add pause/expire functionality.
        logger.warning(f"Handler {handler_id} is running but not in memory.")
        raise HTTPException(detail="Handler expired", status_code=409)

    if is_status_completed(wrapper.status):
        raise HTTPException(detail="Workflow already completed", status_code=409)

    # Get the context
    ctx = wrapper.run_handler.ctx
    if ctx is None:
        raise HTTPException(detail="Context not available", status_code=500)

    try:
        # Parse request body
        body = await request.json()
        event_str = body.get("event")
        step = body.get("step")

        if not event_str:
            raise HTTPException(detail="Event data is required", status_code=400)

        # Deserialize the event
        try:
            registry = self._event_registry(wrapper.workflow_name)
            event = EventEnvelope.parse(event_str, registry)
        except EventValidationError as e:
            raise HTTPException(detail=str(e), status_code=400)
        except Exception as e:
            raise HTTPException(
                detail=f"Failed to deserialize event: {e}", status_code=400
            )

        # Send the event to the context
        try:
            ctx.send_event(event, step=step)
        except Exception as e:
            raise HTTPException(detail=f"Failed to send event: {e}", status_code=400)

        return JSONResponse(SendEventResponse(status="sent").model_dump())
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(detail=f"Error processing request: {e}", status_code=500)
|
|
|
|
async def _cancel_handler(self, request: Request) -> JSONResponse:
    """
    ---
    summary: Stop and delete handler
    description: |
      Stops a running workflow handler by cancelling its tasks. Optionally removes the
      handler from the persistence store if purge=true.
    parameters:
      - in: path
        name: handler_id
        required: true
        schema:
          type: string
        description: Workflow handler identifier.
      - in: query
        name: purge
        required: false
        schema:
          type: boolean
          default: false
        description: If true, also deletes the handler from the store, otherwise updates the status to cancelled.
    responses:
      200:
        description: Handler cancelled and deleted or cancelled only
        content:
          application/json:
            schema:
              type: object
              properties:
                status:
                  type: string
                  enum: [deleted, cancelled]
              required: [status]
      404:
        description: Handler not found
    """
    handler_id = request.path_params["handler_id"]
    # Strict flag parsing, aligned with the other query flags (e.g. `sse`):
    # only the literal "true" (any letter case) switches purging on.
    raw_purge = request.query_params.get("purge", "false")
    purge = raw_purge.lower() == "true"

    wrapper = self._handlers.get(handler_id)
    if wrapper is not None:
        # Closing cancels the run, waits for the streaming task (which
        # writes its final checkpoint) and unregisters the handler.
        await self._close_handler(wrapper)
    elif not purge:
        # Nothing in memory and no request to touch the store.
        raise HTTPException(detail="Handler not found", status_code=404)

    if not purge:
        return JSONResponse(CancelHandlerResponse(status="cancelled").model_dump())

    # Purge requested: the persisted record must exist, even when the
    # handler was no longer held in memory.
    n_deleted = await self._workflow_store.delete(
        HandlerQuery(handler_id_in=[handler_id])
    )
    if n_deleted == 0:
        raise HTTPException(detail="Handler not found", status_code=404)

    return JSONResponse(CancelHandlerResponse(status="deleted").model_dump())
|
|
|
|
#
|
|
# Private methods
|
|
#
|
|
def _extract_workflow(self, request: Request) -> _NamedWorkflow:
    """Look up the registered workflow named by the ``name`` path parameter.

    Raises:
        HTTPException: 400 if the route carries no ``name`` path parameter,
            404 if no workflow is registered under that name.
    """
    path_params = request.path_params
    if "name" not in path_params:
        raise HTTPException(detail="'name' parameter missing", status_code=400)

    name = path_params["name"]
    if name not in self._workflows:
        raise HTTPException(detail="Workflow not found", status_code=404)

    workflow = self._workflows[name]
    return _NamedWorkflow(name=name, workflow=workflow)
|
|
|
|
async def _extract_run_params(
    self, request: Request, workflow: Workflow, workflow_name: str
) -> tuple[Context | None, StartEvent | None, str]:
    """Parse the JSON body of a run request.

    Keys read from the body: ``context``, ``kwargs``, ``start_event`` and
    ``handler_id``. When ``start_event`` is absent, ``kwargs`` (default ``{}``)
    is used as the start event payload.

    Returns:
        A ``(context, start_event, handler_id)`` tuple. ``context`` is rebuilt
        from the inline ``context`` payload if given, otherwise from the
        persisted *completed* handler identified by ``handler_id``; it is
        ``None`` when neither is supplied. ``handler_id`` is freshly generated
        when the body does not provide one.

    Raises:
        HTTPException: 400 for an unparsable body, a body that is not a JSON
            object, or an invalid start event; 404 when ``handler_id`` does not
            match a completed persisted handler of this workflow; 500 for any
            other failure while processing the body.
    """
    try:
        try:
            body = await request.json()
        except Exception as e:
            raise HTTPException(
                detail=f"Invalid JSON body: {e}", status_code=400
            ) from e
        # Valid JSON that is not an object (e.g. `[]` or `null`) is a client
        # error; without this guard `body.get` would surface it as a 500.
        if not isinstance(body, dict):
            raise HTTPException(
                detail="Invalid JSON body: expected a JSON object",
                status_code=400,
            )

        context_data = body.get("context")
        run_kwargs = body.get("kwargs", {})
        start_event_data = body.get("start_event", run_kwargs)
        handler_id = body.get("handler_id")

        # Extract custom StartEvent if present
        start_event = None
        if start_event_data is not None:
            try:
                start_event = EventEnvelope.parse(
                    start_event_data,
                    self._event_registry(workflow_name),
                    explicit_event=workflow.start_event_class,
                )
            except Exception as e:
                raise HTTPException(
                    detail=f"Validation error for 'start_event': {e}",
                    status_code=400,
                ) from e
            if start_event is not None and not isinstance(
                start_event, workflow.start_event_class
            ):
                raise HTTPException(
                    detail=f"Start event must be an instance of {workflow.start_event_class}",
                    status_code=400,
                )

        # Extract custom Context if present; an inline context takes
        # precedence over resuming from a persisted handler.
        context = None
        if context_data:
            context = Context.from_dict(workflow=workflow, data=context_data)
        elif handler_id:
            # Only completed runs of this same workflow can be continued.
            persisted_handlers = await self._workflow_store.query(
                HandlerQuery(
                    handler_id_in=[handler_id],
                    workflow_name_in=[workflow_name],
                    status_in=["completed"],
                )
            )
            if not persisted_handlers:
                raise HTTPException(detail="Handler not found", status_code=404)

            context = Context.from_dict(workflow, persisted_handlers[0].ctx)

        handler_id = handler_id or nanoid()
        return (context, start_event, handler_id)

    except HTTPException:
        # Re-raise HTTPExceptions as-is (like start_event validation errors)
        raise
    except Exception as e:
        raise HTTPException(
            detail=f"Error processing request body: {e}", status_code=500
        ) from e
|
|
|
|
async def _start_workflow(
    self,
    workflow: _NamedWorkflow,
    handler_id: str,
    start_event: StartEvent | None = None,
    context: Context | None = None,
) -> _WorkflowHandler:
    """Launch a workflow run and return the server-side wrapper tracking it.

    The run is started and registered under an instrumentation tag that
    carries ``handler_id``.
    """
    tags = {"handler_id": handler_id}
    with instrument_tags(tags):
        run = workflow.workflow.run(ctx=context, start_event=start_event)
        return await self._run_workflow_handler(handler_id, workflow.name, run)
|
|
|
|
async def _run_workflow_handler(
    self, handler_id: str, workflow_name: str, handler: WorkflowHandler
) -> _WorkflowHandler:
    """Wrap a running handler, persist it once, register it, and start streaming.

    The first checkpoint happens before the wrapper is stored in
    ``self._handlers``; if it raises, the handler is never registered.
    """
    now = datetime.now(timezone.utc)
    wrapper = _WorkflowHandler(
        run_handler=handler,
        queue=asyncio.Queue(),
        task=None,  # Will be set by start_streaming()
        consumer_mutex=asyncio.Lock(),
        handler_id=handler_id,
        workflow_name=workflow_name,
        started_at=now,
        updated_at=now,
        completed_at=None,
        _workflow_store=self._workflow_store,
        _persistence_backoff=self._persistence_backoff,
    )

    # Initial checkpoint before registration; fail fast if persistence is unavailable
    await wrapper.checkpoint()

    # Now register and start streaming
    self._handlers[handler_id] = wrapper

    async def _forget_handler() -> None:
        # Invoked by the wrapper once its event stream has been consumed.
        for registry in (self._handlers, self._results):
            registry.pop(handler_id, None)

    wrapper.start_streaming(on_finish=_forget_handler)
    return wrapper
|
|
|
|
async def _close_handler(self, handler: _WorkflowHandler) -> None:
    """Cancel a handler's run if still active, await its streaming task, and unregister it.

    Cancellation is best effort: failures of ``cancel()`` / ``cancel_run()``
    are logged and otherwise ignored. An exception raised by the streaming
    task propagates to the caller, but the handler is removed from
    ``self._handlers`` and ``self._results`` in every case so a failed stream
    cannot leave a dead handler registered.
    """
    try:
        # Cancel the run_handler if not done
        if not handler.run_handler.done():
            try:
                handler.run_handler.cancel()
            except Exception:
                logger.debug(
                    f"Ignoring error while cancelling handler {handler.handler_id}",
                    exc_info=True,
                )
            try:
                await handler.run_handler.cancel_run()
            except Exception:
                logger.debug(
                    f"Ignoring error while cancelling the run of handler {handler.handler_id}",
                    exc_info=True,
                )

        # Let the streaming task finish so its final checkpoint is written.
        if handler.task is not None:
            await handler.task
    finally:
        self._handlers.pop(handler.handler_id, None)
        self._results.pop(handler.handler_id, None)
|
|
|
|
def _event_registry(self, workflow_name: str) -> dict[str, type[Event]]:
    """Return the class-name to event-class lookup for one workflow.

    Starts from the workflow's own ``events`` and then overlays any classes
    registered for it in ``self._additional_events``; on a name collision the
    additional event wins.
    """
    registry: dict[str, type[Event]] = {}
    for event_cls in self._workflows[workflow_name].events:
        registry[event_cls.__name__] = event_cls

    # A missing entry and an explicit None both mean "no extra events".
    extra_events = self._additional_events.get(workflow_name) or []
    for event_cls in extra_events:
        registry[event_cls.__name__] = event_cls
    return registry
|
|
|
|
|
|
@dataclass
class _WorkflowHandler:
    """A wrapper around a handler: WorkflowHandler. Necessary to monitor and dispatch events from the handler's stream_events.

    A background task (see ``start_streaming``) copies the handler's events
    into ``queue`` and checkpoints the handler to the workflow store; a single
    consumer at a time reads the queue through ``acquire_events_stream``.
    """

    run_handler: WorkflowHandler
    queue: asyncio.Queue[Event]
    task: asyncio.Task[None] | None
    # only one consumer of the queue at a time allowed
    consumer_mutex: asyncio.Lock

    # metadata
    handler_id: str
    workflow_name: str
    started_at: datetime
    updated_at: datetime
    completed_at: datetime | None

    # Dependencies for persistence
    _workflow_store: AbstractWorkflowStore
    _persistence_backoff: list[float]
    _on_finish: Callable[[], Awaitable[None]] | None = None

    def _as_persistent(self) -> PersistentHandler:
        """Build a ``PersistentHandler`` snapshot of the current state.

        Does not write to the store (see ``persist`` / ``checkpoint``), but
        does refresh ``updated_at`` and stamps ``completed_at`` the first time
        a terminal status is observed.
        """
        self.updated_at = datetime.now(timezone.utc)
        # Stamp completion once; later checkpoints must not move it forward.
        if self.completed_at is None and self.status in (
            "completed",
            "failed",
            "cancelled",
        ):
            self.completed_at = self.updated_at

        return PersistentHandler(
            handler_id=self.handler_id,
            workflow_name=self.workflow_name,
            status=self.status,
            run_id=self.run_handler.run_id,
            error=self.error,
            result=self.result,
            started_at=self.started_at,
            updated_at=self.updated_at,
            completed_at=self.completed_at,
            ctx=self.run_handler.ctx.to_dict() if self.run_handler.ctx else {},
        )

    async def persist(self, persistent: PersistentHandler) -> None:
        """Write one snapshot to the workflow store (single attempt, no retry)."""
        await self._workflow_store.update(persistent)

    async def checkpoint(self) -> None:
        """Persist with retry/backoff; cancel handler when retries exhausted.

        Raises:
            Exception: whatever ``_as_persistent`` raised while building the
                snapshot, or the last persistence error once every delay in
                ``_persistence_backoff`` has been used up.
        """
        backoffs = list(self._persistence_backoff)
        try:
            persistent = self._as_persistent()
        except Exception as e:
            logger.error(
                f"Failed to checkpoint handler {self.handler_id} to persistent state. Is there non-serializable state in an event or the state store? {e}",
                exc_info=True,
            )
            raise
        while True:
            try:
                await self.persist(persistent)
                return
            except Exception as e:
                backoff = backoffs.pop(0) if backoffs else None
                if backoff is None:
                    logger.error(
                        f"Failed to checkpoint handler {self.handler_id} after final attempt. Failing the handler.",
                        exc_info=True,
                    )
                    # Cancel the underlying workflow (best effort), then
                    # re-raise the persistence error to the caller.
                    try:
                        self.run_handler.cancel()
                    except Exception:
                        logger.debug(
                            f"Ignoring error while cancelling handler {self.handler_id}",
                            exc_info=True,
                        )
                    raise
                logger.error(
                    f"Failed to checkpoint handler {self.handler_id}. Retrying in {backoff} seconds: {e}"
                )
                await asyncio.sleep(backoff)

    def to_response_model(self) -> HandlerData:
        """Convert runtime handler to API response model."""
        return HandlerData(
            handler_id=self.handler_id,
            workflow_name=self.workflow_name,
            run_id=self.run_handler.run_id,
            status=self.status,
            started_at=self.started_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            completed_at=self.completed_at.isoformat()
            if self.completed_at is not None
            else None,
            error=self.error,
            result=EventEnvelopeWithMetadata.from_event(self.result)
            if self.result is not None
            else None,
        )

    @staticmethod
    def handler_data_from_persistent(persistent: PersistentHandler) -> HandlerData:
        """Convert a persisted handler record to the API response model.

        A record without ``started_at`` is reported with the current UTC time.
        """
        return HandlerData(
            handler_id=persistent.handler_id,
            workflow_name=persistent.workflow_name,
            run_id=persistent.run_id,
            status=persistent.status,
            started_at=persistent.started_at.isoformat()
            if persistent.started_at is not None
            else datetime.now(timezone.utc).isoformat(),
            updated_at=persistent.updated_at.isoformat()
            if persistent.updated_at is not None
            else None,
            completed_at=persistent.completed_at.isoformat()
            if persistent.completed_at is not None
            else None,
            error=persistent.error,
            result=EventEnvelopeWithMetadata.from_event(persistent.result)
            if persistent.result is not None
            else None,
        )

    @property
    def status(self) -> Status:
        """Get the current status by inspecting the handler state."""
        if not self.run_handler.done():
            return "running"
        # done - check if cancelled first
        if self.run_handler.cancelled():
            return "cancelled"
        # then check for exception
        exc = self.run_handler.exception()
        if exc is not None:
            return "failed"
        return "completed"

    @property
    def error(self) -> str | None:
        """String form of the run's exception, or None if running, cancelled or successful."""
        if not self.run_handler.done():
            return None
        try:
            exc = self.run_handler.exception()
        except asyncio.CancelledError:
            return None
        return str(exc) if exc is not None else None

    @property
    def result(self) -> StopEvent | None:
        """Stop event of a finished run, or None if running or if it cannot be retrieved."""
        if not self.run_handler.done():
            return None
        try:
            return self.run_handler.get_stop_event()
        except (asyncio.CancelledError, Exception):
            # CancelledError is listed explicitly because it does not derive
            # from Exception on current Python versions.
            return None

    def start_streaming(self, on_finish: Callable[[], Awaitable[None]]) -> None:
        """Start streaming events from the handler and managing state."""
        self.task = asyncio.create_task(self._stream_events(on_finish=on_finish))

    async def _stream_events(self, on_finish: Callable[[], Awaitable[None]]) -> None:
        """Internal method that streams events, updates status, and persists state."""
        with instrument_tags({"handler_id": self.handler_id}):
            await self.checkpoint()
            self._on_finish = on_finish
            async for event in self.run_handler.stream_events(expose_internal=True):
                if (  # Watch for a specific internal event that signals the step is complete
                    isinstance(event, StepStateChanged)
                    and event.step_state == StepState.NOT_RUNNING
                ):
                    # checkpoint() serializes the context itself, so only its
                    # presence is checked here (no second to_dict() call).
                    if not self.run_handler.ctx:
                        logger.warning(
                            f"Context state is None for handler {self.handler_id}. This is not expected."
                        )
                        # NOTE(review): this also skips queueing the event for
                        # consumers; kept as-is, confirm that it is intended.
                        continue
                    await self.checkpoint()

                self.queue.put_nowait(event)
            # done when stream events are complete
            try:
                await self.run_handler
            except asyncio.CancelledError:
                # Handler was cancelled - status will be automatically detected via handler.cancelled()
                logger.info(f"Workflow run {self.handler_id} was cancelled")
                # Don't re-raise, just let the task complete
            except Exception as e:
                logger.error(
                    f"Workflow run {self.handler_id} failed! {e}", exc_info=True
                )

            await self.checkpoint()

    async def acquire_events_stream(
        self, timeout: float = 1
    ) -> AsyncGenerator[Event, None]:
        """
        Acquires the lock to iterate over the events, and returns generator of events.

        The lock is released by the returned generator when it finishes or is
        closed after having been started, so the caller is expected to iterate it.

        Raises:
            NoLockAvailable: if another consumer still holds the lock after
                ``timeout`` seconds.
        """
        try:
            await asyncio.wait_for(self.consumer_mutex.acquire(), timeout=timeout)
        except asyncio.TimeoutError:
            raise NoLockAvailable(
                f"No lock available to acquire after {timeout}s timeout"
            )
        return self._iter_events(timeout=timeout)

    async def _iter_events(self, timeout: float = 1) -> AsyncGenerator[Event, None]:
        """
        Converts the queue to an async generator while the workflow is still running, and there are still events.

        Expects ``consumer_mutex`` to be held by the caller and releases it on
        exit. ``timeout`` is accepted for signature compatibility and is not
        used here.
        """
        try:
            while not self.queue.empty() or (
                self.task is not None and not self.task.done()
            ):
                # Hand out buffered events one at a time straight from the
                # queue, so a consumer that stops midway leaves the rest
                # queued instead of dropping an already-drained batch.
                while not self.queue.empty():
                    yield self.queue.get_nowait()

                queue_get_task: asyncio.Task[Event] = asyncio.create_task(
                    self.queue.get()
                )
                task_waitable = self.task
                try:
                    done, _ = await asyncio.wait(
                        {queue_get_task, task_waitable}
                        if task_waitable is not None
                        else {queue_get_task},
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                finally:
                    # Runs when the streaming task finished first and also when
                    # this consumer is cancelled mid-wait: a getter left
                    # pending would silently swallow the next queued event.
                    if not queue_get_task.done():
                        queue_get_task.cancel()
                if queue_get_task in done:
                    yield queue_get_task.result()
                else:  # otherwise task completed, so nothing else will be published to the queue
                    break
        finally:
            try:
                if self._on_finish is not None and self.run_handler.done():
                    # clean up the resources if the stream has been consumed
                    await self._on_finish()
            finally:
                # Release even if the cleanup callback raised, otherwise no
                # consumer could ever acquire the stream again.
                self.consumer_mutex.release()
|
|
|
|
|
|
class NoLockAvailable(Exception):
    """Raised when no lock is available to acquire after a timeout"""
|
|
|
|
|
|
@dataclass
class _NamedWorkflow:
    """A workflow paired with the name it is looked up by."""

    # key of the workflow in the server's workflow registry
    name: str
    # the workflow instance stored under that key
    workflow: Workflow
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Script mode: write the OpenAPI schema of an empty server to a JSON file.
    cli = argparse.ArgumentParser(description="Generate OpenAPI schema")
    cli.add_argument(
        "--output", type=str, default="openapi.json", help="Output file path"
    )
    output_path = cli.parse_args().output

    schema = WorkflowServer().openapi_schema()
    with open(output_path, "w") as out_file:
        json.dump(schema, out_file, indent=2)
    print(f"OpenAPI schema written to {output_path}")
|