# SPDX-License-Identifier: MIT # Copyright (c) 2025 LlamaIndex Inc. from __future__ import annotations import httpx import json from typing import ( Any, AsyncGenerator, AsyncIterator, overload, ) from contextlib import asynccontextmanager from workflows.events import StartEvent, Event from workflows import Context from workflows.protocol import ( HandlerData, HandlersListResponse, HealthResponse, SendEventResponse, Status, WorkflowsListResponse, CancelHandlerResponse, ) from workflows.protocol.serializable_events import ( EventEnvelope, EventEnvelopeWithMetadata, ) def _raise_for_status_with_body(response: httpx.Response) -> None: """ Raise an HTTPStatusError with the first 200 characters of the response body for 400 and 500 level errors. """ try: response.raise_for_status() except httpx.HTTPStatusError as e: if 400 <= e.response.status_code < 600: body_preview = e.response.text[:200] method = e.request.method url = e.request.url status_code = e.response.status_code raise httpx.HTTPStatusError( f"{status_code} {e.response.reason_phrase} for {method} {url}. Response: {body_preview}", request=e.request, response=e.response, ) from e raise class WorkflowClient: @overload def __init__(self, *, httpx_client: httpx.AsyncClient): ... @overload def __init__( self, *, base_url: str, ): ... def __init__( self, *, httpx_client: httpx.AsyncClient | None = None, base_url: str | None = None, ): if httpx_client is None and base_url is None: raise ValueError("Either httpx_client or base_url must be provided") if httpx_client is not None and base_url is not None: raise ValueError("Only one of httpx_client or base_url must be provided") self.httpx_client = httpx_client self.base_url = base_url @asynccontextmanager async def _get_client(self) -> AsyncIterator[httpx.AsyncClient]: if self.httpx_client: yield self.httpx_client else: async with httpx.AsyncClient(base_url=self.base_url or "") as client: yield client async def is_healthy(self) -> HealthResponse: """ Check whether the workflow server is helathy or not Returns: HealthResponse: health response from the workflow """ async with self._get_client() as client: response = await client.get("/health") _raise_for_status_with_body(response) return HealthResponse.model_validate(response.json()) async def list_workflows(self) -> WorkflowsListResponse: """ List workflows Returns: WorkflowsListResponse: List of workflow names available through the server. """ async with self._get_client() as client: response = await client.get("/workflows") _raise_for_status_with_body(response) return WorkflowsListResponse.model_validate(response.json()) async def run_workflow( self, workflow_name: str, handler_id: str | None = None, start_event: StartEvent | dict[str, Any] | None = None, context: Context | dict[str, Any] | None = None, ) -> HandlerData: """ Run the workflow and wait until completion. Args: start_event (Union[StartEvent, dict[str, Any], None]): start event class or dictionary representation (optional, defaults to None and get passed as an empty dictionary if not provided). context: Context or serialized representation of it (optional, defaults to None if not provided) handler_id (Optional[str]): Workflow handler identifier to continue from a previous completed run. Returns: HandlerData: Data representing the handler running the workflow (including result and metadata) """ if start_event is not None: try: start_event = _serialize_event(start_event, bare=True) except Exception as e: raise ValueError( f"Impossible to serialize the start event because of: {e}" ) if isinstance(context, Context): try: context = context.to_dict() except Exception as e: raise ValueError(f"Impossible to serialize the context because of: {e}") request_body = { "start_event": start_event or "", "context": context or {}, } if handler_id: request_body["handler_id"] = handler_id async with self._get_client() as client: response = await client.post( f"/workflows/{workflow_name}/run", json=request_body ) _raise_for_status_with_body(response) return HandlerData.model_validate(response.json()) async def run_workflow_nowait( self, workflow_name: str, handler_id: str | None = None, start_event: StartEvent | dict[str, Any] | None = None, context: Context | dict[str, Any] | None = None, ) -> HandlerData: """ Run the workflow in the background. Args: start_event (Union[StartEvent, dict[str, Any], None]): start event class or dictionary representation (optional, defaults to None and get passed as an empty dictionary if not provided). context: Context or serialized representation of it (optional, defaults to None if not provided) handler_id (Optional[str]): Workflow handler identifier to continue from a previous completed run. Returns: HandlerData: data representing the handler running the workflow. """ if start_event is not None: try: start_event = _serialize_event(start_event) except Exception as e: raise ValueError( f"Impossible to serialize the start event because of: {e}" ) if isinstance(context, Context): try: context = context.to_dict() except Exception as e: raise ValueError(f"Impossible to serialize the context because of: {e}") request_body: dict[str, Any] = { "start_event": start_event or _serialize_event(StartEvent()), "context": context or {}, } if handler_id: request_body["handler_id"] = handler_id async with self._get_client() as client: response = await client.post( f"/workflows/{workflow_name}/run-nowait", json=request_body ) _raise_for_status_with_body(response) return HandlerData.model_validate(response.json()) async def get_workflow_events( self, handler_id: str, include_internal_events: bool = False, lock_timeout: float = 1, ) -> AsyncGenerator[EventEnvelopeWithMetadata, None]: """ Stream events as they are produced by the workflow. Args: handler_id (str): ID of the handler running the workflow include_internal_events (bool): Include internal workflow events. Defaults to False. lock_timeout (float): Timeout (in seconds) for acquiring the lock to iterate over the events. Returns: AsyncGenerator[EventEnvelopeWithMetadata, None]: Generator for the events that are streamed as instances of `EventEnvelopeWithMetadata`. """ incl_inter = "true" if include_internal_events else "false" url = f"/events/{handler_id}" async with self._get_client() as client: try: async with client.stream( "GET", url, params={ "sse": "false", "include_internal": incl_inter, "acquire_timeout": lock_timeout, }, headers={"Connection": "keep-alive"}, timeout=None, ) as response: # Handle different response codes if response.status_code == 404: raise ValueError("Handler not found") elif response.status_code == 204: # Handler completed, no more events return _raise_for_status_with_body(response) async for line in response.aiter_lines(): if line.strip(): # Skip empty lines try: event = EventEnvelopeWithMetadata.model_validate_json( line ) yield event except json.JSONDecodeError as e: print(f"Failed to parse JSON: {e}, data: {line}") continue except httpx.TimeoutException: raise TimeoutError( f"Timeout waiting for events from handler {handler_id}" ) except httpx.RequestError as e: raise ConnectionError(f"Failed to connect to event stream: {e}") async def send_event( self, handler_id: str, event: Event | dict[str, Any], step: str | None = None, ) -> SendEventResponse: """ Send an event to the workflow. Args: handler_id (str): ID of the handler of the running workflow to send the event to event (Event | dict[str, Any] | str): Event to send, represented as an Event object, a dictionary or a serialized string. step (Optional[str]): Step to send the event to (optional, defaults to None) Returns: SendEventResponse: Confirmation of the send operation """ try: serialized_event: dict[str, Any] = _serialize_event(event) except Exception as e: raise ValueError(f"Error while serializing the provided event: {e}") request_body: dict[str, Any] = {"event": serialized_event} if step: request_body["step"] = step async with self._get_client() as client: response = await client.post(f"/events/{handler_id}", json=request_body) _raise_for_status_with_body(response) return SendEventResponse.model_validate(response.json()) async def get_result(self, handler_id: str) -> HandlerData: """ Deprecated. Use get_handler instead. """ return await self.get_handler(handler_id) async def get_handlers( self, status: list[Status] | None = None, workflow_name: list[str] | None = None, ) -> HandlersListResponse: """ Get all the workflow handlers. Args: status (list[Status] | None): List of statuses (e.g. "running", "completed", etc. ) to filter by. Defaults to None. workflow_name (list[str] | None): List of workflow names to filter by. Defaults to None. Returns: HandlersListResponse: List of workflow handlers. """ async with self._get_client() as client: response = await client.get( "/handlers", params={ "status": status, "workflow_name": workflow_name, }, ) _raise_for_status_with_body(response) return HandlersListResponse.model_validate(response.json()) async def get_handler(self, handler_id: str) -> HandlerData: """ Get a single workflow handler by identifier. Args: handler_id (str): ID of the handler associated with the workflow run Returns: HandlerData: Handler metadata persisted by the server. """ async with self._get_client() as client: response = await client.get(f"/handlers/{handler_id}") _raise_for_status_with_body(response) return HandlerData.model_validate(response.json()) async def cancel_handler( self, handler_id: str, purge: bool = False ) -> CancelHandlerResponse: """ Stop and cancel a workflow run. Args: handler_id (str): ID of the handler associated with the workflow run purge (bool): Whether or not to delete the run also from the persistent storage. Defaults to false """ async with self._get_client() as client: response = await client.post( f"/handlers/{handler_id}/cancel", params={"purge": "true" if purge else "false"}, ) _raise_for_status_with_body(response) return CancelHandlerResponse.model_validate(response.json()) def _serialize_event( event: Event | dict[str, Any], bare: bool = False ) -> dict[str, Any]: if isinstance(event, dict): return event # assumes you know what you are doing. In many cases this needs to be a dict that contains type metadata and the value return ( event.model_dump() if bare else EventEnvelope.from_event(event=event).model_dump() )