feat: enhance LLMClient with new Codex generation method and refactor tool handling

- Added new methods `_generate_codex`, `_generate_codex_structured`, `_stream_codex`, and `_stream_codex_structured` to LLMClient for generating and streaming responses using the Codex Responses API.
- Refactored tool handling to support structured tool calls and improved error handling during Codex interactions.
- Updated the client initialization for Codex to streamline the generation process.
- Enhanced message processing to accommodate new input formats and recursion for tool calls.
This commit is contained in:
sudipnext 2026-02-27 18:38:52 +05:45
parent d2e3ab9d15
commit 2eff131f6b

View file

@ -1,7 +1,7 @@
import asyncio
import dirtyjson
import json
from typing import Any, AsyncGenerator, List, Optional, Union
from typing import AsyncGenerator, List, Optional, Dict, Any
from fastapi import HTTPException
from openai import APIStatusError, AsyncOpenAI, OpenAIError
from openai.types.chat.chat_completion_chunk import (
@ -39,7 +39,6 @@ from models.llm_tool_call import (
OpenAIToolCallFunction,
)
from models.llm_tools import LLMDynamicTool, LLMTool
from services.codex_llm import CodexLLMAdapter
from services.llm_tool_calls_handler import LLMToolCallsHandler
from utils.async_iterator import iterator_to_async
from utils.dummy_functions import do_nothing_async
@ -73,6 +72,7 @@ from utils.schema_utils import (
)
class LLMClient:
def __init__(self):
self.llm_provider = get_llm_provider()
@ -112,7 +112,7 @@ class LLMClient:
case LLMProvider.CUSTOM:
return self._get_custom_client()
case LLMProvider.CODEX:
return None # Codex uses _get_codex_client() with AsyncOpenAI, not self._client
return self._get_codex_client()
case _:
raise HTTPException(
status_code=400,
@ -482,6 +482,147 @@ class LLMClient:
depth=depth,
)
async def _generate_codex(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> Optional[str]:
    """Generate plain text using the Codex Responses API.

    The request is always sent with ``stream=True``; text deltas are collected
    and returned as a single string. When the model emits function calls they
    are run through ``self.tool_calls_handler`` and the method recurses with
    the results appended (same pattern as ``_generate_openai``).

    Text received in a round that ends with tool calls is discarded; only the
    text of the final round is returned.

    Args:
        model: Model name forwarded to ``responses.create``.
        messages: Conversation so far. System messages are not placed in
            ``input``; the system prompt from ``self._get_system_prompt`` is
            sent as ``instructions`` instead.
        max_tokens: Optional cap, sent as ``max_output_tokens``.
        tools: Optional tool definitions, either nested under a ``"function"``
            key or already flat.
        depth: Current tool-call recursion depth. External callers leave it 0.

    Returns:
        The concatenated output text, or ``None`` when no text was produced.

    Raises:
        HTTPException: 502 when the stream reports an error/failed event.
    """
    # Tool calls are no longer followed once this depth is reached; whatever
    # text the current round produced is returned instead.
    _MAX_RECURSION_DEPTH = 5
    client: AsyncOpenAI = self._client

    # Flatten tools to the Responses API shape {type, name, description, parameters}.
    responses_tools: Optional[List[dict]] = None
    if tools:
        responses_tools = []
        for tool in tools:
            # Only dict tools can be flattened. Anything else is forwarded
            # untouched rather than being turned into a nameless function.
            fn = (tool.get("function") or tool) if isinstance(tool, dict) else None
            if isinstance(fn, dict):
                responses_tools.append({
                    "type": "function",
                    "name": fn.get("name", ""),
                    "description": fn.get("description", ""),
                    "parameters": fn.get("parameters", {}),
                })
            else:
                responses_tools.append(tool)

    # Build instructions + input (same shape as _stream_codex_structured).
    instructions = self._get_system_prompt(messages) or None
    input_payload: List[Dict[str, Any]] = []
    for message in messages:
        if isinstance(message, LLMSystemMessage):
            continue
        if isinstance(message, LLMUserMessage):
            input_payload.append({
                "role": "user",
                "content": [{"type": "input_text", "text": message.content}],
            })
            continue
        if isinstance(message, OpenAIAssistantMessage):
            role, part_type = "assistant", "output_text"
            text = message.content or ""
        else:
            # Any other message type (e.g. the tool results appended during
            # recursion) is sent as plain user text.
            role, part_type = "user", "input_text"
            text = getattr(message, "content", "") or ""
        # Messages without text contribute nothing. This includes the
        # assistant tool-call record added below, whose content is None.
        if text:
            input_payload.append({
                "role": role,
                "content": [{"type": part_type, "text": text}],
            })

    create_kwargs: Dict[str, Any] = {
        "model": model,
        "store": False,
        "stream": True,
        "text": {"verbosity": "medium"},
        "include": ["reasoning.encrypted_content"],
        "tool_choice": "auto",
        "parallel_tool_calls": True,
    }
    if instructions:
        create_kwargs["instructions"] = instructions
    if input_payload:
        create_kwargs["input"] = input_payload
    if responses_tools:
        create_kwargs["tools"] = responses_tools
    if max_tokens is not None:
        create_kwargs["max_output_tokens"] = max_tokens

    stream = await client.responses.create(**create_kwargs)

    def _event_dict(ev: Any) -> dict:
        """Normalize a stream event object into a plain dict."""
        if hasattr(ev, "model_dump"):
            return ev.model_dump()
        return {
            "type": getattr(ev, "type", None),
            "delta": getattr(ev, "delta", None),
            "item": getattr(ev, "item", None),
            "message": getattr(ev, "message", None),
        }

    text_parts: List[str] = []
    tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
    async for ev in stream:
        event = ev if isinstance(ev, dict) else _event_dict(ev)
        event_type = event.get("type") or ""
        if event_type == "response.output_text.delta":
            delta = event.get("delta") or ""
            if delta:
                text_parts.append(delta)
        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                cid = item.get("call_id") or item.get("id", "")
                tool_calls_by_id[cid] = item
        elif event_type in ("response.error", "response.failed", "error"):
            err = event.get("message") or event.get("error") or str(event)
            # Detail is truncated so an oversized payload is not echoed back.
            raise HTTPException(status_code=502, detail=f"Codex error: {err}"[:400])

    if tool_calls_by_id and responses_tools and depth < _MAX_RECURSION_DEPTH:
        parsed_tool_calls = [
            OpenAIToolCall(
                id=cid,
                type="function",
                function=OpenAIToolCallFunction(
                    name=data.get("name", ""),
                    arguments=data.get("arguments", ""),
                ),
            )
            for cid, data in tool_calls_by_id.items()
        ]
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
            parsed_tool_calls
        )
        new_messages = [
            *messages,
            OpenAIAssistantMessage(
                role="assistant",
                content=None,
                tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
            ),
            *tool_call_messages,
        ]
        return await self._generate_codex(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            tools=tools,
            depth=depth + 1,
        )

    return "".join(text_parts) or None
async def generate(
self,
model: str,
@ -501,14 +642,11 @@ class LLMClient:
tools=parsed_tools,
)
case LLMProvider.CODEX:
print(
f"LLMClient.generate Codex: model={model} messages={len(messages)} "
f"user_tools={len(parsed_tools) if parsed_tools else 0}"
)
client = self._get_codex_client()
content = await CodexLLMAdapter.generate_codex(
client, model, messages, self.tool_calls_handler,
max_tokens=max_tokens, tools=parsed_tools,
content = await self._generate_codex(
model=model,
messages=messages,
max_tokens=max_tokens,
tools=parsed_tools,
)
case LLMProvider.GOOGLE:
content = await self._generate_google(
@ -657,6 +795,48 @@ class LLMClient:
return content
return None
async def _generate_codex_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    extra_body: Optional[dict] = None,
    depth: int = 0,
) -> dict | None:
    """Produce one structured Codex result by draining the structured stream.

    Every chunk yielded by ``_stream_codex_structured`` is collected and joined
    into a single string.

    Returns:
        ``None`` when the stream produced no text. At ``depth == 0`` the joined
        text parsed with ``dirtyjson`` and converted to a ``dict``. At any
        other depth the unparsed text wrapped as ``{"raw": <text>}``.
    """
    pieces: List[str] = [
        piece
        async for piece in self._stream_codex_structured(
            model=model,
            messages=messages,
            response_format=response_format,
            strict=strict,
            max_tokens=max_tokens,
            tools=tools,
            extra_body=extra_body,
            depth=depth,
        )
    ]
    payload = "".join(pieces)

    # Nothing streamed back: there is no structured result to report.
    if not payload:
        return None

    # Only the root call parses; nested calls hand the text back untouched.
    if depth != 0:
        return {"raw": payload}
    return dict(dirtyjson.loads(payload))
async def _generate_google_structured(
self,
model: str,
@ -887,19 +1067,13 @@ class LLMClient:
max_tokens=max_tokens,
)
case LLMProvider.CODEX:
print(
f"LLMClient.generate_structured Codex: model={model} messages={len(messages)} "
f"strict={strict} user_tools={len(parsed_tools) if parsed_tools else 0}"
)
client = self._get_codex_client()
content = await CodexLLMAdapter.generate_codex_structured(
client, model, messages, response_format, self.tool_calls_handler,
strict=strict, max_tokens=max_tokens, tools=parsed_tools,
)
print(
"LLMClient.generate_structured Codex: done "
f"content_is_none={content is None} "
f"content_keys={list(content.keys())[:8] if isinstance(content, dict) else None}"
content = await self._generate_codex_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.GOOGLE:
content = await self._generate_google_structured(
@ -1174,6 +1348,157 @@ class LLMClient:
):
yield event
async def _stream_codex(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> AsyncGenerator[str, None]:
    """Stream plain text from Codex (Responses API).

    Text deltas are yielded as they arrive. When the stream ends with function
    calls, they are run through ``self.tool_calls_handler`` and the method
    recurses with the results appended, mirroring ``_stream_openai`` but using
    Responses events. Unlike ``_generate_codex``, text from every round is
    yielded, including rounds that end in tool calls.

    Args:
        model: Model name forwarded to ``responses.create``.
        messages: Conversation so far. System messages are not placed in
            ``input``; the system prompt from ``self._get_system_prompt`` is
            sent as ``instructions`` instead.
        max_tokens: Optional cap, sent as ``max_output_tokens``.
        tools: Optional tool definitions, either nested under a ``"function"``
            key or already flat.
        depth: Current tool-call recursion depth. External callers leave it 0.

    Yields:
        Output text fragments in arrival order.

    Raises:
        HTTPException: 502 when the stream reports an error/failed event.
    """
    # Tool calls are no longer followed once this depth is reached.
    _MAX_RECURSION_DEPTH = 5
    client: AsyncOpenAI = (
        self._get_codex_client()
        if self.llm_provider == LLMProvider.CODEX
        else self._client
    )

    # Flatten tools to the Responses API shape {type, name, description, parameters}.
    responses_tools: Optional[List[dict]] = None
    if tools:
        responses_tools = []
        for tool in tools:
            # Only dict tools can be flattened. Anything else is forwarded
            # untouched rather than being turned into a nameless function.
            fn = (tool.get("function") or tool) if isinstance(tool, dict) else None
            if isinstance(fn, dict):
                responses_tools.append(
                    {
                        "type": "function",
                        "name": fn.get("name", ""),
                        "description": fn.get("description", ""),
                        "parameters": fn.get("parameters", {}),
                    }
                )
            else:
                responses_tools.append(tool)

    # Build instructions + input (same shape as _generate_codex/_stream_codex_structured).
    instructions = self._get_system_prompt(messages) or None
    input_payload: List[Dict[str, Any]] = []
    for message in messages:
        if isinstance(message, LLMSystemMessage):
            continue
        if isinstance(message, LLMUserMessage):
            input_payload.append(
                {
                    "role": "user",
                    "content": [{"type": "input_text", "text": message.content}],
                }
            )
            continue
        if isinstance(message, OpenAIAssistantMessage):
            role, part_type = "assistant", "output_text"
            text = message.content or ""
        else:
            # Any other message type (e.g. the tool results appended during
            # recursion) is sent as plain user text.
            role, part_type = "user", "input_text"
            text = getattr(message, "content", "") or ""
        # Messages without text contribute nothing. This includes the
        # assistant tool-call record added below, whose content is None.
        if text:
            input_payload.append(
                {
                    "role": role,
                    "content": [{"type": part_type, "text": text}],
                }
            )

    create_kwargs: Dict[str, Any] = {
        "model": model,
        "store": False,
        "stream": True,
        "text": {"verbosity": "medium"},
        "include": ["reasoning.encrypted_content"],
        "tool_choice": "auto",
        "parallel_tool_calls": True,
    }
    if instructions:
        create_kwargs["instructions"] = instructions
    if input_payload:
        create_kwargs["input"] = input_payload
    if responses_tools:
        create_kwargs["tools"] = responses_tools
    if max_tokens is not None:
        create_kwargs["max_output_tokens"] = max_tokens

    stream = await client.responses.create(**create_kwargs)

    def _event_dict(ev: Any) -> dict:
        """Normalize a stream event object into a plain dict."""
        if hasattr(ev, "model_dump"):
            return ev.model_dump()
        return {
            "type": getattr(ev, "type", None),
            "delta": getattr(ev, "delta", None),
            "item": getattr(ev, "item", None),
            "message": getattr(ev, "message", None),
        }

    tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
    async for ev in stream:
        event = ev if isinstance(ev, dict) else _event_dict(ev)
        event_type = event.get("type") or ""
        if event_type == "response.output_text.delta":
            delta = event.get("delta") or ""
            if delta:
                yield delta
        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                cid = item.get("call_id") or item.get("id", "")
                tool_calls_by_id[cid] = item
        elif event_type in ("response.error", "response.failed", "error"):
            err = event.get("message") or event.get("error") or str(event)
            # Detail is truncated so an oversized payload is not echoed back.
            raise HTTPException(status_code=502, detail=f"Codex stream error: {err}"[:400])

    if tool_calls_by_id and responses_tools and depth < _MAX_RECURSION_DEPTH:
        parsed_tool_calls = [
            OpenAIToolCall(
                id=cid,
                type="function",
                function=OpenAIToolCallFunction(
                    name=data.get("name", ""),
                    arguments=data.get("arguments", ""),
                ),
            )
            for cid, data in tool_calls_by_id.items()
        ]
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
            parsed_tool_calls
        )
        new_messages = [
            *messages,
            OpenAIAssistantMessage(
                role="assistant",
                content=None,
                tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
            ),
            *tool_call_messages,
        ]
        async for chunk in self._stream_codex(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            tools=tools,
            depth=depth + 1,
        ):
            yield chunk
def _stream_ollama(
self,
model: str,
@ -1219,14 +1544,11 @@ class LLMClient:
tools=parsed_tools,
)
case LLMProvider.CODEX:
print(
f"LLMClient.stream Codex: model={model} messages={len(messages)} "
f"user_tools={len(parsed_tools) if parsed_tools else 0}"
)
client = self._get_codex_client()
return CodexLLMAdapter.stream_codex(
client, model, messages, self.tool_calls_handler,
max_tokens=max_tokens, tools=parsed_tools,
return self._stream_codex(
model=model,
messages=messages,
max_tokens=max_tokens,
tools=parsed_tools,
)
case LLMProvider.GOOGLE:
return self._stream_google(
@ -1402,6 +1724,291 @@ class LLMClient:
):
yield event
async def _stream_codex_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
    extra_body: Optional[dict] = None,
) -> AsyncGenerator[str, None]:
    """Stream structured output from the Responses API (Codex-style models).

    Structured output is obtained through a synthetic ``ResponseSchema``
    function tool whose ``parameters`` are the requested JSON schema.
    ``tool_choice`` forces the model to call it, and the argument text of that
    call is what this generator yields.

    How the stream is consumed:

    * The Responses API emits typed events, not ChatCompletion-style
      ``choices[].delta``; this method must not assume the latter.
    * ``response.function_call_arguments.delta`` fragments belonging to the
      ``ResponseSchema`` call are yielded as they arrive.
    * ``response.function_call_arguments.done`` and
      ``response.output_item.done`` carry the complete argument string. It is
      yielded only when nothing was streamed before, so the JSON payload is
      never emitted twice.
    * ``response.output_tool_call.delta`` is still handled.
      NOTE(review): presumably kept for an older/alternate event shape;
      confirm it is still emitted by the backend.
    * ``response.completed`` ends consumption. Error/failed events raise.

    If the stream produced calls to other tools and no ``ResponseSchema``
    output, those tools are run through ``self.tool_calls_handler`` and the
    method recurses with the results appended, up to a fixed depth.

    Args:
        model: Model name forwarded to ``responses.create``.
        messages: Conversation so far. The system prompt from
            ``self._get_system_prompt`` is sent as ``instructions``.
        response_format: JSON schema for the structured result. It is copied
            before being adjusted, so the caller's dict is left untouched.
        strict: When true, the schema is passed through
            ``ensure_strict_json_schema`` once, at ``depth == 0``.
        max_tokens: Optional cap, sent as ``max_output_tokens``.
        tools: Optional extra tool definitions, either nested under a
            ``"function"`` key or already flat.
        depth: Current tool-call recursion depth. External callers leave it 0.
        extra_body: Optional mapping merged into the ``responses.create``
            keyword arguments.

    Yields:
        Fragments of the ``ResponseSchema`` argument JSON, in order.

    Raises:
        RuntimeError: When the stream reports an error/failed event.
    """
    import copy  # local import: only this method needs it

    # Same cap as the other Codex methods, so a model that keeps calling
    # tools cannot recurse without bound.
    _MAX_RECURSION_DEPTH = 5
    client: AsyncOpenAI = self._client

    # Work on a private copy: the adjustments below edit the schema in place
    # and must not leak into the dict owned by the caller.
    response_schema = copy.deepcopy(response_format)

    # Apply strict schema once at root.
    if strict and depth == 0:
        response_schema = ensure_strict_json_schema(
            response_schema,
            path=(),
            root=response_schema,
        )

    # Codex Responses API requires all array schemas to specify `items`.
    def _fix_arrays(node: Any) -> Any:
        """Recursively give every array schema lacking ``items`` a string default."""
        if isinstance(node, dict):
            if node.get("type") == "array" and "items" not in node:
                node["items"] = {"type": "string"}
            for key, value in list(node.items()):
                node[key] = _fix_arrays(value)
        elif isinstance(node, list):
            for idx, value in enumerate(node):
                node[idx] = _fix_arrays(value)
        return node

    response_schema = _fix_arrays(response_schema)

    # Responses API tool format: flat {type, name, description, parameters}.
    response_schema_tool = {
        "type": "function",
        "name": "ResponseSchema",
        "description": "Provide structured response",
        "parameters": response_schema,
    }
    all_tools: List[dict] = [response_schema_tool]
    if tools:
        for tool in tools:
            # Only dict tools can be flattened. Anything else is forwarded
            # untouched rather than being turned into a nameless function.
            fn = (tool.get("function") or tool) if isinstance(tool, dict) else None
            if isinstance(fn, dict):
                all_tools.append({
                    "type": "function",
                    "name": fn.get("name", ""),
                    "description": fn.get("description", ""),
                    "parameters": fn.get("parameters", {}),
                })
            else:
                all_tools.append(tool)

    # Build instructions + input (instructions from system; input_text/output_text).
    instructions = self._get_system_prompt(messages) or None
    input_payload: List[Dict[str, Any]] = []
    for message in messages:
        if isinstance(message, LLMSystemMessage):
            continue
        if isinstance(message, LLMUserMessage):
            input_payload.append({
                "role": "user",
                "content": [{"type": "input_text", "text": message.content}],
            })
            continue
        if isinstance(message, OpenAIAssistantMessage):
            role, part_type = "assistant", "output_text"
            text = message.content or ""
        else:
            # Any other message type (e.g. the tool results appended during
            # recursion) is sent as plain user text.
            role, part_type = "user", "input_text"
            text = getattr(message, "content", "") or ""
        # Messages without text contribute nothing to the input.
        if text:
            input_payload.append({
                "role": role,
                "content": [{"type": part_type, "text": text}],
            })

    # Force model to use ResponseSchema for structured output.
    tool_choice = {"type": "function", "name": "ResponseSchema"}
    create_kwargs: Dict[str, Any] = {
        "model": model,
        "store": False,
        "stream": True,
        "text": {"verbosity": "medium"},
        "include": ["reasoning.encrypted_content"],
        "tool_choice": tool_choice,
        "parallel_tool_calls": True,
        "tools": all_tools,
    }
    if instructions:
        create_kwargs["instructions"] = instructions
    if input_payload:
        create_kwargs["input"] = input_payload
    if max_tokens is not None:
        create_kwargs["max_output_tokens"] = max_tokens
    if extra_body:
        # NOTE(review): keys are merged as top-level create() arguments, not
        # passed via the SDK's `extra_body=` parameter - confirm that is intended.
        create_kwargs.update(extra_body)

    stream = await client.responses.create(**create_kwargs)

    def _event_dict(ev: Any) -> dict:
        """Normalize a stream event object into a plain dict."""
        if hasattr(ev, "model_dump"):
            return ev.model_dump()
        return {
            "type": getattr(ev, "type", None),
            "delta": getattr(ev, "delta", None),
            "arguments": getattr(ev, "arguments", None),
            "arguments_delta": getattr(ev, "arguments_delta", None),
            "item": getattr(ev, "item", None),
            "item_id": getattr(ev, "item_id", None),
            "id": getattr(ev, "id", None),
            "name": getattr(ev, "name", None),
            "error": getattr(ev, "error", None),
            "message": getattr(ev, "message", None),
        }

    tool_calls_by_id: Dict[str, Dict[str, Any]] = {}
    current_call_id: Optional[str] = None
    # Output-item ids of ResponseSchema calls, used to attribute argument
    # deltas to the right call when events carry an `item_id`.
    schema_item_ids: set = set()
    # True once any ResponseSchema argument text has been yielded.
    has_response_schema_tool_call = False

    async for ev in stream:
        event = ev if isinstance(ev, dict) else _event_dict(ev)
        event_type = event.get("type") or ""

        if event_type == "response.output_item.added":
            item = event.get("item") or {}
            if item.get("type") == "function_call" and item.get("name") == "ResponseSchema":
                current_call_id = item.get("call_id") or item.get("id")
                if item.get("id"):
                    schema_item_ids.add(item["id"])

        elif event_type == "response.function_call_arguments.delta":
            item_id = event.get("item_id")
            if item_id and schema_item_ids:
                # Precise attribution: skip deltas that belong to other tools.
                is_schema_delta = item_id in schema_item_ids
            else:
                # No ids to compare: fall back to "a ResponseSchema call was seen".
                is_schema_delta = bool(current_call_id)
            delta = event.get("delta") or ""
            if is_schema_delta and delta:
                has_response_schema_tool_call = True
                yield delta

        elif event_type == "response.function_call_arguments.done":
            if event.get("name") == "ResponseSchema":
                args = event.get("arguments") or ""
                # The complete payload is a fallback only; if deltas were
                # already streamed, yielding it again would duplicate the JSON.
                if args and not has_response_schema_tool_call:
                    has_response_schema_tool_call = True
                    yield args

        elif event_type == "response.output_item.done":
            item = event.get("item") or {}
            if item.get("type") == "function_call":
                cid = item.get("call_id") or item.get("id", "")
                tool_calls_by_id[cid] = item
                if item.get("name") == "ResponseSchema":
                    args = item.get("arguments") or ""
                    # Same fallback rule as above: never re-emit streamed output.
                    if args and not has_response_schema_tool_call:
                        has_response_schema_tool_call = True
                        yield args

        elif event_type == "response.output_tool_call.delta":
            call_id = event.get("id")
            name = event.get("name")
            arguments_delta = event.get("arguments_delta") or ""
            if call_id and name:
                if call_id not in tool_calls_by_id:
                    tool_calls_by_id[call_id] = {"name": name, "arguments": ""}
                tool_calls_by_id[call_id]["arguments"] += arguments_delta
                if name == "ResponseSchema" and arguments_delta:
                    has_response_schema_tool_call = True
                    yield arguments_delta

        elif event_type == "response.completed":
            break

        elif event_type in ("response.error", "response.failed", "error"):
            err = event.get("error") or event.get("message") or str(event)
            raise RuntimeError(err)

    # ============================================
    # EXECUTE NON-STRUCTURED TOOL CALLS (RECURSIVE)
    # ============================================
    other_tool_calls = {
        cid: data
        for cid, data in tool_calls_by_id.items()
        if data.get("name") != "ResponseSchema"
    }
    if (
        other_tool_calls
        and not has_response_schema_tool_call
        and depth < _MAX_RECURSION_DEPTH
    ):
        parsed_tool_calls = [
            OpenAIToolCall(
                id=call_id,
                type="function",
                function=OpenAIToolCallFunction(
                    name=data.get("name", ""),
                    arguments=data.get("arguments", ""),
                ),
            )
            for call_id, data in other_tool_calls.items()
        ]
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
            parsed_tool_calls
        )
        new_messages = [
            *messages,
            OpenAIAssistantMessage(
                role="assistant",
                content=None,
                tool_calls=[tc.model_dump() for tc in parsed_tool_calls],
            ),
            *tool_call_messages,
        ]
        # The already adjusted schema is passed on; strict hardening is not
        # repeated because depth is now > 0.
        async for chunk in self._stream_codex_structured(
            model=model,
            messages=new_messages,
            response_format=response_schema,
            strict=strict,
            max_tokens=max_tokens,
            tools=tools,
            extra_body=extra_body,
            depth=depth + 1,
        ):
            yield chunk
async def _stream_google_structured(
self,
model: str,
@ -1655,14 +2262,13 @@ class LLMClient:
max_tokens=max_tokens,
)
case LLMProvider.CODEX:
print(
f"LLMClient.stream_structured Codex: model={model} messages={len(messages)} "
f"strict={strict} user_tools={len(parsed_tools) if parsed_tools else 0}"
)
client = self._get_codex_client()
return CodexLLMAdapter.stream_codex_structured(
client, model, messages, response_format, self.tool_calls_handler,
strict=strict, tools=parsed_tools, max_tokens=max_tokens,
return self._stream_codex_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.GOOGLE:
return self._stream_google_structured(