367 lines
13 KiB
Python
367 lines
13 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025 LlamaIndex Inc.
|
|
|
|
from __future__ import annotations

import json
import logging
import warnings
from contextlib import asynccontextmanager
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    overload,
)

import httpx

from workflows.events import StartEvent, Event
from workflows import Context
from workflows.protocol import (
    HandlerData,
    HandlersListResponse,
    HealthResponse,
    SendEventResponse,
    Status,
    WorkflowsListResponse,
    CancelHandlerResponse,
)
from workflows.protocol.serializable_events import (
    EventEnvelope,
    EventEnvelopeWithMetadata,
)
|
|
|
|
|
|
def _raise_for_status_with_body(response: httpx.Response) -> None:
|
|
"""
|
|
Raise an HTTPStatusError with the first 200 characters of the response body
|
|
for 400 and 500 level errors.
|
|
"""
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError as e:
|
|
if 400 <= e.response.status_code < 600:
|
|
body_preview = e.response.text[:200]
|
|
method = e.request.method
|
|
url = e.request.url
|
|
status_code = e.response.status_code
|
|
raise httpx.HTTPStatusError(
|
|
f"{status_code} {e.response.reason_phrase} for {method} {url}. Response: {body_preview}",
|
|
request=e.request,
|
|
response=e.response,
|
|
) from e
|
|
raise
|
|
|
|
|
|
class WorkflowClient:
    """
    Asynchronous HTTP client for a workflows server.

    Either reuses a caller-provided ``httpx.AsyncClient`` for every call (it is
    never closed here), or, when only ``base_url`` is given, opens a fresh
    ``httpx.AsyncClient`` per request and closes it when the request is done.
    """

    @overload
    def __init__(self, *, httpx_client: httpx.AsyncClient): ...

    @overload
    def __init__(
        self,
        *,
        base_url: str,
    ): ...

    def __init__(
        self,
        *,
        httpx_client: httpx.AsyncClient | None = None,
        base_url: str | None = None,
    ):
        """
        Create a client from exactly one of ``httpx_client`` or ``base_url``.

        Args:
            httpx_client (httpx.AsyncClient | None): Pre-configured client to send
                requests with; the caller owns its lifecycle.
            base_url (str | None): Server base URL used to build a new client for
                each request.

        Raises:
            ValueError: If neither or both arguments are provided.
        """
        if httpx_client is None and base_url is None:
            raise ValueError("Either httpx_client or base_url must be provided")
        if httpx_client is not None and base_url is not None:
            raise ValueError("Only one of httpx_client or base_url must be provided")
        self.httpx_client = httpx_client
        self.base_url = base_url

    @asynccontextmanager
    async def _get_client(self) -> AsyncIterator[httpx.AsyncClient]:
        """Yield the ``httpx.AsyncClient`` to use for a single request."""
        if self.httpx_client is not None:
            # Caller-owned client: reuse it and leave closing to the caller.
            yield self.httpx_client
        else:
            async with httpx.AsyncClient(base_url=self.base_url or "") as client:
                yield client

    @staticmethod
    def _serialize_start_event(
        start_event: StartEvent | dict[str, Any] | None, *, bare: bool
    ) -> dict[str, Any] | None:
        """
        Serialize a start event for a run request; ``None`` is returned as-is.

        Raises:
            ValueError: If the event cannot be serialized.
        """
        if start_event is None:
            return None
        try:
            return _serialize_event(start_event, bare=bare)
        except Exception as e:
            raise ValueError(
                f"Impossible to serialize the start event because of: {e}"
            ) from e

    @staticmethod
    def _serialize_context(
        context: Context | dict[str, Any] | None,
    ) -> dict[str, Any] | None:
        """
        Convert a ``Context`` via ``to_dict()``; any other value is returned as-is.

        Raises:
            ValueError: If ``context.to_dict()`` fails.
        """
        if not isinstance(context, Context):
            return context
        try:
            return context.to_dict()
        except Exception as e:
            raise ValueError(
                f"Impossible to serialize the context because of: {e}"
            ) from e

    @staticmethod
    def _handler_query_params(
        status: list[Status] | None,
        workflow_name: list[str] | None,
    ) -> dict[str, Any]:
        """
        Build the ``/handlers`` query parameters, omitting ``None`` filters.

        httpx encodes a ``None`` query value as an empty string (``status=``),
        so unset filters are left out entirely instead.
        """
        params: dict[str, Any] = {}
        if status is not None:
            params["status"] = status
        if workflow_name is not None:
            params["workflow_name"] = workflow_name
        return params

    async def is_healthy(self) -> HealthResponse:
        """
        Check whether the workflow server is healthy or not.

        Returns:
            HealthResponse: Health response from the workflow server.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with self._get_client() as client:
            response = await client.get("/health")
            _raise_for_status_with_body(response)
            return HealthResponse.model_validate(response.json())

    async def list_workflows(self) -> WorkflowsListResponse:
        """
        List workflows.

        Returns:
            WorkflowsListResponse: List of workflow names available through the server.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with self._get_client() as client:
            response = await client.get("/workflows")
            _raise_for_status_with_body(response)
            return WorkflowsListResponse.model_validate(response.json())

    async def run_workflow(
        self,
        workflow_name: str,
        handler_id: str | None = None,
        start_event: StartEvent | dict[str, Any] | None = None,
        context: Context | dict[str, Any] | None = None,
    ) -> HandlerData:
        """
        Run the workflow and wait until completion.

        Args:
            workflow_name (str): Name of the workflow to run.
            handler_id (Optional[str]): Workflow handler identifier to continue
                from a previous completed run.
            start_event (Union[StartEvent, dict[str, Any], None]): Start event
                instance or dictionary representation. An event instance is sent
                as its bare ``model_dump()``; a dict is sent unchanged. When
                omitted (or empty), an empty string is sent in its place.
            context: Context or serialized representation of it. When omitted
                (or empty), ``{}`` is sent.

        Returns:
            HandlerData: Data representing the handler running the workflow
            (including result and metadata).

        Raises:
            ValueError: If the start event or the context cannot be serialized.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        # NOTE(review): this endpoint receives the bare event dump, whereas
        # run_workflow_nowait sends an EventEnvelope; confirm both are intended.
        request_body: dict[str, Any] = {
            "start_event": self._serialize_start_event(start_event, bare=True)
            or "",
            "context": self._serialize_context(context) or {},
        }
        if handler_id:
            request_body["handler_id"] = handler_id
        async with self._get_client() as client:
            response = await client.post(
                f"/workflows/{workflow_name}/run", json=request_body
            )
            _raise_for_status_with_body(response)
            return HandlerData.model_validate(response.json())

    async def run_workflow_nowait(
        self,
        workflow_name: str,
        handler_id: str | None = None,
        start_event: StartEvent | dict[str, Any] | None = None,
        context: Context | dict[str, Any] | None = None,
    ) -> HandlerData:
        """
        Run the workflow in the background.

        Args:
            workflow_name (str): Name of the workflow to run.
            handler_id (Optional[str]): Workflow handler identifier to continue
                from a previous completed run.
            start_event (Union[StartEvent, dict[str, Any], None]): Start event
                instance or dictionary representation. An event instance is sent
                wrapped in an ``EventEnvelope``; a dict is sent unchanged. When
                omitted (or empty), an envelope of a default ``StartEvent()`` is
                sent.
            context: Context or serialized representation of it. When omitted
                (or empty), ``{}`` is sent.

        Returns:
            HandlerData: Data representing the handler running the workflow.

        Raises:
            ValueError: If the start event or the context cannot be serialized.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        request_body: dict[str, Any] = {
            "start_event": self._serialize_start_event(start_event, bare=False)
            or _serialize_event(StartEvent()),
            "context": self._serialize_context(context) or {},
        }
        if handler_id:
            request_body["handler_id"] = handler_id
        async with self._get_client() as client:
            response = await client.post(
                f"/workflows/{workflow_name}/run-nowait", json=request_body
            )
            _raise_for_status_with_body(response)
            return HandlerData.model_validate(response.json())

    async def get_workflow_events(
        self,
        handler_id: str,
        include_internal_events: bool = False,
        lock_timeout: float = 1,
    ) -> AsyncGenerator[EventEnvelopeWithMetadata, None]:
        """
        Stream events as they are produced by the workflow.

        The server response is read line by line. Blank lines are skipped, and
        lines that fail to parse as ``EventEnvelopeWithMetadata`` are logged
        as warnings and skipped.

        Args:
            handler_id (str): ID of the handler running the workflow.
            include_internal_events (bool): Include internal workflow events.
                Defaults to False.
            lock_timeout (float): Timeout (in seconds) for acquiring the lock to
                iterate over the events (sent as ``acquire_timeout``).

        Yields:
            EventEnvelopeWithMetadata: Each streamed event.

        Raises:
            ValueError: If the server answers 404 (handler not found).
            httpx.HTTPStatusError: For other error statuses.
            TimeoutError: If httpx reports a timeout.
            ConnectionError: If the request fails at the transport level.
        """
        params = {
            "sse": "false",
            "include_internal": "true" if include_internal_events else "false",
            "acquire_timeout": lock_timeout,
        }
        async with self._get_client() as client:
            try:
                async with client.stream(
                    "GET",
                    f"/events/{handler_id}",
                    params=params,
                    headers={"Connection": "keep-alive"},
                    # Disable httpx timeouts for this long-lived stream.
                    timeout=None,
                ) as response:
                    if response.status_code == 404:
                        raise ValueError("Handler not found")
                    if response.status_code == 204:
                        # Handler completed, no more events.
                        return
                    if response.is_error:
                        # Streamed bodies are not loaded automatically; read it so
                        # the error message can include the server's response.
                        await response.aread()
                    _raise_for_status_with_body(response)

                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue  # Skip empty lines
                        try:
                            event = EventEnvelopeWithMetadata.model_validate_json(
                                line
                            )
                        except ValueError as e:
                            # pydantic's ValidationError (raised for invalid JSON
                            # as well as schema mismatches) subclasses ValueError.
                            logging.getLogger(__name__).warning(
                                "Failed to parse event: %s, data: %s", e, line
                            )
                            continue
                        yield event
            except httpx.TimeoutException as e:
                raise TimeoutError(
                    f"Timeout waiting for events from handler {handler_id}"
                ) from e
            except httpx.RequestError as e:
                raise ConnectionError(
                    f"Failed to connect to event stream: {e}"
                ) from e

    async def send_event(
        self,
        handler_id: str,
        event: Event | dict[str, Any],
        step: str | None = None,
    ) -> SendEventResponse:
        """
        Send an event to the workflow.

        Args:
            handler_id (str): ID of the handler of the running workflow to send
                the event to.
            event (Event | dict[str, Any]): Event to send. An ``Event`` is wrapped
                in an ``EventEnvelope``; a dictionary is sent unchanged.
            step (Optional[str]): Step to send the event to (optional, defaults
                to None).

        Returns:
            SendEventResponse: Confirmation of the send operation.

        Raises:
            ValueError: If the event cannot be serialized.
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        try:
            serialized_event: dict[str, Any] = _serialize_event(event)
        except Exception as e:
            raise ValueError(
                f"Error while serializing the provided event: {e}"
            ) from e
        request_body: dict[str, Any] = {"event": serialized_event}
        if step:
            request_body["step"] = step
        async with self._get_client() as client:
            response = await client.post(f"/events/{handler_id}", json=request_body)
            _raise_for_status_with_body(response)
            return SendEventResponse.model_validate(response.json())

    async def get_result(self, handler_id: str) -> HandlerData:
        """
        Deprecated. Use get_handler instead.

        Emits a ``DeprecationWarning`` and delegates to ``get_handler``.
        """
        warnings.warn(
            "WorkflowClient.get_result() is deprecated; use get_handler() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return await self.get_handler(handler_id)

    async def get_handlers(
        self,
        status: list[Status] | None = None,
        workflow_name: list[str] | None = None,
    ) -> HandlersListResponse:
        """
        Get all the workflow handlers.

        Args:
            status (list[Status] | None): List of statuses (e.g. "running",
                "completed", etc.) to filter by. Defaults to None (no filter).
            workflow_name (list[str] | None): List of workflow names to filter
                by. Defaults to None (no filter).

        Returns:
            HandlersListResponse: List of workflow handlers.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with self._get_client() as client:
            response = await client.get(
                "/handlers",
                params=self._handler_query_params(status, workflow_name),
            )
            _raise_for_status_with_body(response)
            return HandlersListResponse.model_validate(response.json())

    async def get_handler(self, handler_id: str) -> HandlerData:
        """
        Get a single workflow handler by identifier.

        Args:
            handler_id (str): ID of the handler associated with the workflow run.

        Returns:
            HandlerData: Handler metadata persisted by the server.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with self._get_client() as client:
            response = await client.get(f"/handlers/{handler_id}")
            _raise_for_status_with_body(response)
            return HandlerData.model_validate(response.json())

    async def cancel_handler(
        self, handler_id: str, purge: bool = False
    ) -> CancelHandlerResponse:
        """
        Stop and cancel a workflow run.

        Args:
            handler_id (str): ID of the handler associated with the workflow run.
            purge (bool): Whether or not to delete the run also from the
                persistent storage. Defaults to False.

        Returns:
            CancelHandlerResponse: The server's response to the cancellation.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with self._get_client() as client:
            response = await client.post(
                f"/handlers/{handler_id}/cancel",
                params={"purge": "true" if purge else "false"},
            )
            _raise_for_status_with_body(response)
            return CancelHandlerResponse.model_validate(response.json())
|
|
|
|
|
|
def _serialize_event(
|
|
event: Event | dict[str, Any], bare: bool = False
|
|
) -> dict[str, Any]:
|
|
if isinstance(event, dict):
|
|
return event # assumes you know what you are doing. In many cases this needs to be a dict that contains type metadata and the value
|
|
return (
|
|
event.model_dump()
|
|
if bare
|
|
else EventEnvelope.from_event(event=event).model_dump()
|
|
)
|