# presenton/servers/fastapi/services/llm_client.py
#
# 1544 lines
# 52 KiB
# Python

import asyncio
import json
from typing import AsyncGenerator, List, Optional
from fastapi import HTTPException
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from google import genai
from google.genai.types import Content as GoogleContent, Part as GoogleContentPart
from google.genai.types import GenerateContentConfig, GoogleSearch
from google.genai.types import Tool as GoogleTool
from anthropic import AsyncAnthropic
from anthropic.types import Message as AnthropicMessage
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
from enums.llm_provider import LLMProvider
from models.llm_message import (
AnthropicAssistantMessage,
AnthropicUserMessage,
GoogleAssistantMessage,
GoogleToolCallMessage,
OpenAIAssistantMessage,
LLMMessage,
LLMSystemMessage,
LLMUserMessage,
)
from models.llm_tool_call import (
AnthropicToolCall,
GoogleToolCall,
LLMToolCall,
OpenAIToolCall,
OpenAIToolCallFunction,
)
from models.llm_tools import LLMDynamicTool, LLMTool
from services.llm_tool_calls_handler import LLMToolCallsHandler
from utils.async_iterator import iterator_to_async
from utils.dummy_functions import do_nothing_async
from utils.get_env import (
get_anthropic_api_key_env,
get_custom_llm_api_key_env,
get_custom_llm_url_env,
get_disable_thinking_env,
get_google_api_key_env,
get_ollama_url_env,
get_openai_api_key_env,
get_tool_calls_env,
get_web_grounding_env,
)
from utils.llm_provider import get_llm_provider, get_model
from utils.parsers import parse_bool_or_none
from utils.schema_utils import ensure_strict_json_schema
class LLMClient:
def __init__(self):
    """Resolve the configured provider, build its SDK client and attach the tool-call handler."""
    provider = get_llm_provider()
    self.llm_provider = provider
    # _get_client() dispatches on self.llm_provider, so it has to run after the assignment above.
    self._client = self._get_client()
    handler = LLMToolCallsHandler(self)
    self.tool_calls_handler = handler
# ? Use tool calls
def use_tool_calls_for_structured_output(self) -> bool:
    """Tell whether structured output should be requested through a tool call.

    Only the CUSTOM provider can opt in; for it the tool-calls environment
    setting decides (an unset or unparsable value counts as False).
    """
    if self.llm_provider == LLMProvider.CUSTOM:
        env_flag = parse_bool_or_none(get_tool_calls_env())
        return env_flag or False
    return False
# ? Web Grounding
def enable_web_grounding(self) -> bool:
    """Tell whether web grounding is enabled.

    Always False for the OLLAMA and CUSTOM providers; otherwise the
    web-grounding environment setting decides (None counts as False).
    """
    provider = self.llm_provider
    grounding_unavailable = (
        provider == LLMProvider.OLLAMA or provider == LLMProvider.CUSTOM
    )
    if grounding_unavailable:
        return False
    env_flag = parse_bool_or_none(get_web_grounding_env())
    return env_flag or False
# ? Disable thinking
def disable_thinking(self) -> bool:
    """Return the disable-thinking environment setting, treating a falsy/None value as False."""
    env_flag = parse_bool_or_none(get_disable_thinking_env())
    return env_flag if env_flag else False
# ? Clients
def _get_client(self):
    """Build the SDK client that matches ``self.llm_provider``.

    Raises:
        HTTPException: 400 when the provider is none of the supported values.
    """
    provider = self.llm_provider
    if provider == LLMProvider.OPENAI:
        return self._get_openai_client()
    if provider == LLMProvider.GOOGLE:
        return self._get_google_client()
    if provider == LLMProvider.ANTHROPIC:
        return self._get_anthropic_client()
    if provider == LLMProvider.OLLAMA:
        return self._get_ollama_client()
    if provider == LLMProvider.CUSTOM:
        return self._get_custom_client()
    raise HTTPException(
        status_code=400,
        detail="LLM Provider must be either openai, google, anthropic, ollama, or custom",
    )
def _get_openai_client(self):
    """Return an ``AsyncOpenAI`` client built with default arguments.

    Raises:
        HTTPException: 400 when the OpenAI API key setting is empty.
    """
    api_key = get_openai_api_key_env()
    if api_key:
        # The key itself is not passed on; the client is constructed with no arguments.
        return AsyncOpenAI()
    raise HTTPException(
        status_code=400,
        detail="OpenAI API Key is not set",
    )
def _get_google_client(self):
    """Return a ``genai.Client`` built with default arguments.

    Raises:
        HTTPException: 400 when the Google API key setting is empty.
    """
    api_key = get_google_api_key_env()
    if api_key:
        # The key itself is not passed on; the client is constructed with no arguments.
        return genai.Client()
    raise HTTPException(
        status_code=400,
        detail="Google API Key is not set",
    )
def _get_anthropic_client(self):
    """Return an ``AsyncAnthropic`` client built with default arguments.

    Raises:
        HTTPException: 400 when the Anthropic API key setting is empty.
    """
    api_key = get_anthropic_api_key_env()
    if api_key:
        # The key itself is not passed on; the client is constructed with no arguments.
        return AsyncAnthropic()
    raise HTTPException(
        status_code=400,
        detail="Anthropic API Key is not set",
    )
def _get_ollama_client(self):
    """Return an ``AsyncOpenAI`` client aimed at the ``/v1`` path of the Ollama server.

    Falls back to ``http://localhost:11434`` when no Ollama URL is configured.
    """
    host = get_ollama_url_env() or "http://localhost:11434"
    endpoint = host + "/v1"
    return AsyncOpenAI(base_url=endpoint, api_key="ollama")
def _get_custom_client(self):
    """Return an ``AsyncOpenAI`` client aimed at the configured custom LLM URL.

    The URL setting is read once so the value that passed validation is the
    exact value handed to the client. When no API key is configured the
    placeholder string ``"null"`` is used instead.

    Raises:
        HTTPException: 400 when the custom LLM URL setting is empty.
    """
    base_url = get_custom_llm_url_env()
    if not base_url:
        raise HTTPException(
            status_code=400,
            detail="Custom LLM URL is not set",
        )
    return AsyncOpenAI(
        base_url=base_url,
        api_key=get_custom_llm_api_key_env() or "null",
    )
# ? Prompts
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
for message in messages:
if isinstance(message, LLMSystemMessage):
return message.content
return ""
def _get_google_messages(self, messages: List[LLMMessage]) -> List[str]:
contents = []
for message in messages:
if isinstance(message, LLMUserMessage):
contents.append(
GoogleContent(
role="user", parts=[GoogleContentPart(text=message.content)]
)
)
elif isinstance(message, GoogleAssistantMessage):
contents.append(message.content)
elif isinstance(message, GoogleToolCallMessage):
contents.append(
GoogleContent(
role="user",
parts=[
GoogleContentPart.from_function_response(
name=message.name, response=message.response
)
],
)
)
return contents
def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
return [
message for message in messages if not isinstance(message, LLMSystemMessage)
]
# ? Generate Unstructured Content
async def _generate_openai(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
depth: int = 0,
) -> str | None:
client: AsyncOpenAI = self._client
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
tools=tools,
extra_body=extra_body,
)
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
parsed_tool_calls = [
OpenAIToolCall(
id=tool_call.id,
type=tool_call.type,
function=OpenAIToolCallFunction(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
for tool_call in tool_calls
]
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
parsed_tool_calls
)
assistant_message = OpenAIAssistantMessage(
role="assistant",
content=response.choices[0].message.content,
tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls],
)
new_messages = [
*messages,
assistant_message,
*tool_call_messages,
]
return await self._generate_openai(
model=model,
messages=new_messages,
max_tokens=max_tokens,
tools=tools,
extra_body=extra_body,
depth=depth + 1,
)
return response.choices[0].message.content
async def _generate_google(
    self,
    model: str,
    messages: List[LLMMessage],
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
    depth: int = 0,
) -> str | None:
    """Generate plain text with the Google client, resolving function calls.

    The blocking ``generate_content`` call runs in a worker thread. When the
    reply contains function calls they are executed through
    ``tool_calls_handler`` and the request is repeated with the model turn
    and the tool results appended (``depth`` is incremented per round).

    Returns:
        The text of the last text part of the final reply, or None when the
        response has no candidates, no content, or no parts.
    """
    client: genai.Client = self._client
    google_tools = None
    if tools:
        google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=model,
        contents=self._get_google_messages(messages),
        config=GenerateContentConfig(
            tools=google_tools,
            system_instruction=self._get_system_prompt(messages),
            response_mime_type="text/plain",
            max_output_tokens=max_tokens,
        ),
    )
    # A response may come back without candidates or with a candidate that has
    # no content; treat both like "no parts" instead of raising.
    candidates = response.candidates
    if not candidates:
        return None
    content = candidates[0].content
    if content is None or not content.parts:
        return None

    text_content = None
    tool_calls = []
    for each_part in content.parts:
        if each_part.function_call:
            tool_calls.append(
                GoogleToolCall(
                    name=each_part.function_call.name,
                    arguments=each_part.function_call.args,
                )
            )
        if each_part.text:
            # Later text parts overwrite earlier ones.
            text_content = each_part.text

    if not tool_calls:
        return text_content

    tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
        tool_calls
    )
    new_messages = [
        *messages,
        GoogleAssistantMessage(
            role="assistant",
            content=content,
        ),
        *tool_call_messages,
    ]
    return await self._generate_google(
        model=model,
        messages=new_messages,
        max_tokens=max_tokens,
        tools=tools,
        depth=depth + 1,
    )
async def _generate_anthropic(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> str | None:
    """Generate plain text with the Anthropic client, resolving tool use.

    System messages are sent through the ``system`` parameter and removed
    from the message list. ``max_tokens`` falls back to 4000. When the reply
    contains ``tool_use`` blocks they are executed through
    ``tool_calls_handler`` and the request is repeated with an assistant
    turn (the tool calls) and a user turn (the tool results) appended.

    Returns:
        The text of the last text block of the final reply, or None.
    """
    client: AsyncAnthropic = self._client
    system_prompt = self._get_system_prompt(messages)
    conversation = [
        message.model_dump() for message in self._get_anthropic_messages(messages)
    ]
    response: AnthropicMessage = await client.messages.create(
        model=model,
        system=system_prompt,
        messages=conversation,
        tools=tools,
        max_tokens=max_tokens or 4000,
    )

    text_content = None
    tool_calls: List[AnthropicToolCall] = []
    for block in response.content:
        if block.type == "text" and isinstance(block.text, str):
            text_content = block.text
        if block.type == "tool_use":
            tool_calls.append(
                AnthropicToolCall(
                    id=block.id,
                    type=block.type,
                    name=block.name,
                    input=block.input,
                )
            )

    if not tool_calls:
        return text_content

    tool_results = await self.tool_calls_handler.handle_tool_calls_anthropic(
        tool_calls
    )
    assistant_turn = AnthropicAssistantMessage(
        role="assistant",
        content=[call.model_dump() for call in tool_calls],
    )
    user_turn = AnthropicUserMessage(
        role="user",
        content=[result.model_dump() for result in tool_results],
    )
    return await self._generate_anthropic(
        model=model,
        messages=[*messages, assistant_turn, user_turn],
        max_tokens=max_tokens,
        tools=tools,
        depth=depth + 1,
    )
async def _generate_ollama(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
depth: int = 0,
):
return await self._generate_openai(
model=model, messages=messages, max_tokens=max_tokens, depth=depth
)
async def _generate_custom(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
depth: int = 0,
):
extra_body = {"enable_thinking": not self.disable_thinking()}
return await self._generate_openai(
model=model,
messages=messages,
max_tokens=max_tokens,
extra_body=extra_body,
depth=depth,
)
async def generate(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
    """Generate unstructured text with the configured provider.

    Tools are parsed once and forwarded to the OpenAI, Google and Anthropic
    paths; the Ollama and custom paths are called without tools.

    Raises:
        HTTPException: 400 when the provider path yields None.
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)
    provider = self.llm_provider
    common = {"model": model, "messages": messages, "max_tokens": max_tokens}

    content = None
    if provider == LLMProvider.OPENAI:
        content = await self._generate_openai(tools=parsed_tools, **common)
    elif provider == LLMProvider.GOOGLE:
        content = await self._generate_google(tools=parsed_tools, **common)
    elif provider == LLMProvider.ANTHROPIC:
        content = await self._generate_anthropic(tools=parsed_tools, **common)
    elif provider == LLMProvider.OLLAMA:
        content = await self._generate_ollama(**common)
    elif provider == LLMProvider.CUSTOM:
        content = await self._generate_custom(**common)

    if content is not None:
        return content
    raise HTTPException(
        status_code=400,
        detail="LLM did not return any content",
    )
# ? Generate Structured Content
async def _generate_openai_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
depth: int = 0,
) -> dict | None:
# Produce structured output through an OpenAI-compatible chat-completions client.
#
# The schema is delivered either as a `json_schema` response_format or, when
# use_tool_calls_for_structured_output() is true, as a tool named
# "ResponseSchema". Any other tool calls are executed via tool_calls_handler
# and the request is repeated recursively with the results appended.
#
# Return value: the outermost call (depth == 0) returns the JSON-decoded
# content; nested calls (depth > 0) return the raw content string so that it
# is decoded exactly once by the outermost call. None when there is no content.
# NOTE(review): the `dict | None` annotation therefore only holds at depth 0.
client: AsyncOpenAI = self._client
response_schema = response_format
# Copy so that appending the ResponseSchema tool never mutates the caller's list.
all_tools = [*tools] if tools else None
use_tool_calls_for_structured_output = (
self.use_tool_calls_for_structured_output()
)
# Only the outermost call rewrites the schema; recursive calls receive the
# already-processed schema through `response_format`.
if strict and depth == 0:
response_schema = ensure_strict_json_schema(
response_schema,
path=(),
root=response_schema,
)
# The ResponseSchema tool is appended once; recursive calls inherit it via
# `tools=all_tools`.
if use_tool_calls_for_structured_output and depth == 0:
if all_tools is None:
all_tools = []
all_tools.append(
self.tool_calls_handler.parse_tool(
LLMDynamicTool(
name="ResponseSchema",
description="Provide response to the user",
parameters=response_schema,
handler=do_nothing_async,
),
strict=strict,
)
)
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
# response_format is omitted (None) when the schema travels as a tool.
response_format=(
{
"type": "json_schema",
"json_schema": (
{
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
}
),
}
if not use_tool_calls_for_structured_output
else None
),
max_completion_tokens=max_tokens,
tools=all_tools,
extra_body=extra_body,
)
content = response.choices[0].message.content
tool_calls = response.choices[0].message.tool_calls
has_response_schema = False
if tool_calls:
# A ResponseSchema call carries the final answer in its arguments; if several
# are present the last one wins.
for tool_call in tool_calls:
if tool_call.function.name == "ResponseSchema":
content = tool_call.function.arguments
has_response_schema = True
# No final answer yet: run the requested tools and ask again.
if not has_response_schema:
parsed_tool_calls = [
OpenAIToolCall(
id=tool_call.id,
type=tool_call.type,
function=OpenAIToolCallFunction(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
for tool_call in tool_calls
]
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_openai(
parsed_tool_calls
)
)
new_messages = [
*messages,
OpenAIAssistantMessage(
role="assistant",
content=response.choices[0].message.content,
tool_calls=[each.model_dump() for each in parsed_tool_calls],
),
*tool_call_messages,
]
content = await self._generate_openai_structured(
model=model,
messages=new_messages,
response_format=response_schema,
strict=strict,
max_tokens=max_tokens,
tools=all_tools,
extra_body=extra_body,
depth=depth + 1,
)
if content:
# Decode only at the top level; nested calls pass the raw string upward.
if depth == 0:
return json.loads(content)
return content
return None
async def _generate_google_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> dict | None:
    """Generate structured output with the Google client.

    Without tools the schema is enforced through ``response_mime_type`` /
    ``response_json_schema`` and the text reply is JSON-decoded. With tools
    the schema is exposed as an extra ``ResponseSchema`` function
    declaration; its call arguments are returned as the result, while any
    other function calls are executed through ``tool_calls_handler`` and the
    request is repeated (``depth`` is incremented per round).

    Returns:
        The structured result, or None when the response has no candidates,
        no content, no parts, or neither text nor function calls.
    """
    client: genai.Client = self._client
    google_tools = None
    if tools:
        google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
        google_tools.append(
            GoogleTool(
                function_declarations=[
                    {
                        "name": "ResponseSchema",
                        "description": "Provide response to the user",
                        "parameters_json_schema": response_format,
                    }
                ]
            )
        )
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=model,
        contents=self._get_google_messages(messages),
        config=GenerateContentConfig(
            tools=google_tools,
            system_instruction=self._get_system_prompt(messages),
            # JSON mode is only requested when no tools are in play.
            response_mime_type="application/json" if not tools else None,
            response_json_schema=response_format if not tools else None,
            max_output_tokens=max_tokens,
        ),
    )
    # A response may come back without candidates or with a candidate that has
    # no content; treat both like "no parts" instead of raising.
    candidates = response.candidates
    if not candidates:
        return None
    content = candidates[0].content
    if content is None or not content.parts:
        return None

    text_content = None
    tool_calls: List[GoogleToolCall] = []
    for each_part in content.parts:
        if each_part.function_call:
            tool_calls.append(
                GoogleToolCall(
                    name=each_part.function_call.name,
                    arguments=each_part.function_call.args,
                )
            )
        if each_part.text:
            text_content = each_part.text

    # A ResponseSchema call is the final answer and takes priority over other calls.
    for each in tool_calls:
        if each.name == "ResponseSchema":
            return each.arguments

    if tool_calls:
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
            tool_calls
        )
        new_messages = [
            *messages,
            GoogleAssistantMessage(
                role="assistant",
                content=content,
            ),
            *tool_call_messages,
        ]
        return await self._generate_google_structured(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            response_format=response_format,
            tools=tools,
            depth=depth + 1,
        )

    if text_content:
        return json.loads(text_content)
    return None
async def _generate_anthropic_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
    depth: int = 0,
):
    """Generate structured output with the Anthropic client.

    The schema is always offered as a ``ResponseSchema`` tool placed ahead of
    the caller's tools. The input of a ``ResponseSchema`` tool call is the
    result; other tool calls are executed through ``tool_calls_handler`` and
    the request is repeated. ``max_tokens`` falls back to 4000.

    Returns:
        The ``ResponseSchema`` input, or None when the reply has no tool call.
    """
    client: AsyncAnthropic = self._client
    schema_tool = {
        "name": "ResponseSchema",
        "description": "A response to the user's message",
        "input_schema": response_format,
    }
    response: AnthropicMessage = await client.messages.create(
        model=model,
        system=self._get_system_prompt(messages),
        messages=[
            message.model_dump()
            for message in self._get_anthropic_messages(messages)
        ],
        max_tokens=max_tokens or 4000,
        tools=[schema_tool, *(tools or [])],
    )

    tool_calls: List[AnthropicToolCall] = [
        AnthropicToolCall(
            id=block.id,
            type=block.type,
            name=block.name,
            input=block.input,
        )
        for block in response.content
        if block.type == "tool_use"
    ]

    # The first ResponseSchema call, if any, is the final answer.
    for call in tool_calls:
        if call.name == "ResponseSchema":
            return call.input

    if not tool_calls:
        return None

    tool_results = await self.tool_calls_handler.handle_tool_calls_anthropic(
        tool_calls
    )
    assistant_turn = AnthropicAssistantMessage(
        role="assistant",
        content=[call.model_dump() for call in tool_calls],
    )
    user_turn = AnthropicUserMessage(
        role="user",
        content=[result.model_dump() for result in tool_results],
    )
    return await self._generate_anthropic_structured(
        model=model,
        messages=[*messages, assistant_turn, user_turn],
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        depth=depth + 1,
    )
async def _generate_ollama_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
depth: int = 0,
):
return await self._generate_openai_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
max_tokens=max_tokens,
depth=depth,
)
async def _generate_custom_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
depth: int = 0,
):
extra_body = {"enable_thinking": not self.disable_thinking()}
return await self._generate_openai_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
max_tokens=max_tokens,
extra_body=extra_body,
depth=depth,
)
async def generate_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
    max_tokens: Optional[int] = None,
) -> dict:
    """Generate schema-constrained output with the configured provider.

    ``strict`` is forwarded to the OpenAI, Ollama and custom paths; parsed
    tools are forwarded to the OpenAI, Google and Anthropic paths.

    Raises:
        HTTPException: 400 when the provider path yields None.
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)
    provider = self.llm_provider
    common = {
        "model": model,
        "messages": messages,
        "response_format": response_format,
        "max_tokens": max_tokens,
    }

    content = None
    if provider == LLMProvider.OPENAI:
        content = await self._generate_openai_structured(
            strict=strict, tools=parsed_tools, **common
        )
    elif provider == LLMProvider.GOOGLE:
        content = await self._generate_google_structured(
            tools=parsed_tools, **common
        )
    elif provider == LLMProvider.ANTHROPIC:
        content = await self._generate_anthropic_structured(
            tools=parsed_tools, **common
        )
    elif provider == LLMProvider.OLLAMA:
        content = await self._generate_ollama_structured(strict=strict, **common)
    elif provider == LLMProvider.CUSTOM:
        content = await self._generate_custom_structured(strict=strict, **common)

    if content is not None:
        return content
    raise HTTPException(
        status_code=400,
        detail="LLM did not return any content",
    )
# ? Stream Unstructured Content
async def _stream_openai(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
depth: int = 0,
) -> AsyncGenerator[str, None]:
client: AsyncOpenAI = self._client
tool_calls: List[LLMToolCall] = []
current_index = 0
current_id = None
current_name = None
current_arguments = None
async for event in await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
tools=tools,
extra_body=extra_body,
stream=True,
):
event: OpenAIChatCompletionChunk = event
content_chunk = event.choices[0].delta.content
if content_chunk:
yield content_chunk
tool_call_chunk = event.choices[0].delta.tool_calls
if tool_call_chunk:
tool_index = tool_call_chunk[0].index
tool_id = tool_call_chunk[0].id
tool_name = tool_call_chunk[0].function.name
tool_arguments = tool_call_chunk[0].function.arguments
if current_index != tool_index:
tool_calls.append(
OpenAIToolCall(
id=current_id,
type="function",
function=OpenAIToolCallFunction(
name=current_name,
arguments=current_arguments,
),
)
)
current_index = tool_index
current_id = tool_id
current_name = tool_name
current_arguments = tool_arguments
else:
current_name = tool_name or current_name
current_id = tool_id or current_id
if current_arguments is None:
current_arguments = tool_arguments
else:
current_arguments += tool_arguments
if current_id is not None:
tool_calls.append(
OpenAIToolCall(
id=current_id,
type="function",
function=OpenAIToolCallFunction(
name=current_name,
arguments=current_arguments,
),
)
)
if tool_calls:
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
tool_calls
)
new_messages = [
*messages,
OpenAIAssistantMessage(
role="assistant",
content=None,
tool_calls=[each.model_dump() for each in tool_calls],
),
*tool_call_messages,
]
async for event in self._stream_openai(
model=model,
messages=new_messages,
max_tokens=max_tokens,
tools=tools,
extra_body=extra_body,
depth=depth + 1,
):
yield event
async def _stream_google(
self,
model: str,
messages: List[LLMMessage],
tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
depth: int = 0,
) -> AsyncGenerator[str, None]:
# Stream plain text from the Google client, resolving function calls.
#
# The synchronous generate_content_stream iterator is wrapped with
# iterator_to_async so it can be consumed with `async for`. Text of each chunk
# is yielded as it arrives. Function calls are executed through
# tool_calls_handler and a follow-up stream (depth + 1) is started with the
# model turn and the tool results appended.
client: genai.Client = self._client
google_tools = None
if tools:
google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
tool_calls = None
async for event in iterator_to_async(client.models.generate_content_stream)(
model=model,
contents=self._get_google_messages(messages),
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
tools=google_tools,
max_output_tokens=max_tokens,
),
):
if event.text:
yield event.text
# Each chunk carrying function calls replaces (does not extend) the list.
if event.function_calls:
tool_calls = [
GoogleToolCall(
name=each.name,
arguments=each.args,
)
for each in event.function_calls
]
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_google(tool_calls)
)
new_messages = [
*messages,
GoogleAssistantMessage(
role="assistant",
# NOTE(review): `event` is whichever chunk the loop bound most recently. If this
# section runs after the loop, that may not be the chunk that carried the
# function calls — confirm against the original indentation.
content=event.candidates[0].content,
),
*tool_call_messages,
]
# The loop variable `event` is reused for the chunks of the follow-up stream.
async for event in self._stream_google(
model=model,
messages=new_messages,
max_tokens=max_tokens,
tools=tools,
depth=depth + 1,
):
yield event
async def _stream_anthropic(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
depth: int = 0,
):
# Stream plain text from the Anthropic client, resolving tool use.
#
# System messages go through the `system` parameter and are removed from the
# message list; max_tokens falls back to 4000. Events of type "text" are
# yielded as text. Completed tool_use blocks are collected, executed through
# tool_calls_handler, and a follow-up stream (depth + 1) is started with an
# assistant turn (the tool calls) and a user turn (the tool results) appended.
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=tools,
) as stream:
tool_calls: List[AnthropicToolCall] = []
async for event in stream:
event: AnthropicMessageStreamEvent = event
if event.type == "text":
yield event.text
# A tool call is recorded when its content block closes, reading id, name and
# input from the finished block.
if (
event.type == "content_block_stop"
and event.content_block.type == "tool_use"
):
tool_calls.append(
AnthropicToolCall(
id=event.content_block.id,
type=event.content_block.type,
name=event.content_block.name,
input=event.content_block.input,
)
)
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_anthropic(
tool_calls
)
)
new_messages = [
*messages,
AnthropicAssistantMessage(
role="assistant",
content=[each.model_dump() for each in tool_calls],
),
AnthropicUserMessage(
role="user",
content=[each.model_dump() for each in tool_call_messages],
),
]
# The loop variable `event` is reused for the items of the follow-up stream.
async for event in self._stream_anthropic(
model=model,
messages=new_messages,
max_tokens=max_tokens,
tools=tools,
depth=depth + 1,
):
yield event
def _stream_ollama(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
depth: int = 0,
):
return self._stream_openai(
model=model, messages=messages, max_tokens=max_tokens, depth=depth
)
def _stream_custom(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
depth: int = 0,
):
extra_body = {"enable_thinking": not self.disable_thinking()}
return self._stream_openai(
model=model,
messages=messages,
max_tokens=max_tokens,
extra_body=extra_body,
depth=depth,
)
def stream(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
    """Return the unstructured text stream of the configured provider.

    Parsed tools are forwarded to the OpenAI, Google and Anthropic paths; the
    Ollama and custom paths are called without tools. Returns None when the
    provider matches none of the known values.
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)
    provider = self.llm_provider
    common = {"model": model, "messages": messages, "max_tokens": max_tokens}

    if provider == LLMProvider.OPENAI:
        return self._stream_openai(tools=parsed_tools, **common)
    if provider == LLMProvider.GOOGLE:
        return self._stream_google(tools=parsed_tools, **common)
    if provider == LLMProvider.ANTHROPIC:
        return self._stream_anthropic(tools=parsed_tools, **common)
    if provider == LLMProvider.OLLAMA:
        return self._stream_ollama(**common)
    if provider == LLMProvider.CUSTOM:
        return self._stream_custom(**common)
    return None
# ? Stream Structured Content
async def _stream_openai_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
depth: int = 0,
) -> AsyncGenerator[str, None]:
# Stream structured output from an OpenAI-compatible chat completion.
#
# The schema is delivered either as a `json_schema` response_format or, when
# use_tool_calls_for_structured_output() is true, as a tool named
# "ResponseSchema". Yields content deltas and the argument fragments of a
# ResponseSchema tool call. Other tool calls are accumulated, executed through
# tool_calls_handler once the stream ends, and a follow-up stream (depth + 1)
# is started — unless a ResponseSchema call was seen.
client: AsyncOpenAI = self._client
response_schema = response_format
# Copy so that appending the ResponseSchema tool never mutates the caller's list.
all_tools = [*tools] if tools else None
use_tool_calls_for_structured_output = (
self.use_tool_calls_for_structured_output()
)
# Only the outermost call rewrites the schema; recursive calls receive the
# already-processed schema through `response_format`.
if strict and depth == 0:
response_schema = ensure_strict_json_schema(
response_schema,
path=(),
root=response_schema,
)
# The ResponseSchema tool is appended once; recursive calls inherit it via
# `tools=all_tools`.
if use_tool_calls_for_structured_output and depth == 0:
if all_tools is None:
all_tools = []
all_tools.append(
self.tool_calls_handler.parse_tool(
LLMDynamicTool(
name="ResponseSchema",
description="Provide response to the user",
parameters=response_schema,
handler=do_nothing_async,
),
strict=strict,
)
)
tool_calls: List[LLMToolCall] = []
# State of the tool call currently being assembled from streamed deltas.
current_index = 0
current_id = None
current_name = None
current_arguments = None
has_response_schema_tool_call = False
async for event in await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
tools=all_tools,
# response_format is omitted (None) when the schema travels as a tool.
response_format=(
{
"type": "json_schema",
"json_schema": (
{
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
}
),
}
if not use_tool_calls_for_structured_output
else None
),
extra_body=extra_body,
stream=True,
):
event: OpenAIChatCompletionChunk = event
# NOTE(review): indexes choices[0] unconditionally; a chunk with an empty
# `choices` list would raise IndexError — confirm whether the endpoint can
# emit such chunks.
content_chunk = event.choices[0].delta.content
if content_chunk:
yield content_chunk
tool_call_chunk = event.choices[0].delta.tool_calls
if tool_call_chunk:
# Only the first tool-call delta of each chunk is read.
tool_index = tool_call_chunk[0].index
tool_id = tool_call_chunk[0].id
tool_name = tool_call_chunk[0].function.name
tool_arguments = tool_call_chunk[0].function.arguments
# An index change closes the call being assembled and starts a new one.
# NOTE(review): current_index starts at 0, so a first call whose index is not 0
# would append an entry with id/name/arguments still None — confirm.
if current_index != tool_index:
tool_calls.append(
OpenAIToolCall(
id=current_id,
type="function",
function=OpenAIToolCallFunction(
name=current_name,
arguments=current_arguments,
),
)
)
current_index = tool_index
current_id = tool_id
current_name = tool_name
current_arguments = tool_arguments
else:
# Same call: keep the known id/name and extend the argument string.
current_name = tool_name or current_name
current_id = tool_id or current_id
if current_arguments is None:
current_arguments = tool_arguments
else:
# NOTE(review): raises TypeError if tool_arguments is None here — confirm.
current_arguments += tool_arguments
# Argument fragments of the ResponseSchema call are the structured answer and
# are streamed straight to the caller.
if current_name == "ResponseSchema":
if tool_arguments:
yield tool_arguments
has_response_schema_tool_call = True
# Flush the call that was still being assembled when the stream ended.
if current_id is not None:
tool_calls.append(
OpenAIToolCall(
id=current_id,
type="function",
function=OpenAIToolCallFunction(
name=current_name,
arguments=current_arguments,
),
)
)
# Tools are only executed when no ResponseSchema call was seen.
if tool_calls and not has_response_schema_tool_call:
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
tool_calls
)
new_messages = [
*messages,
OpenAIAssistantMessage(
role="assistant",
content=None,
tool_calls=[each.model_dump() for each in tool_calls],
),
*tool_call_messages,
]
async for event in self._stream_openai_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
strict=strict,
tools=all_tools,
response_format=response_schema,
extra_body=extra_body,
depth=depth + 1,
):
yield event
async def _stream_google_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
depth: int = 0,
) -> AsyncGenerator[str, None]:
# Stream structured output from the Google client.
#
# Without tools the schema is enforced via response_mime_type /
# response_json_schema and chunk text is yielded. With tools the schema is
# exposed as an extra "ResponseSchema" function declaration; the arguments of
# such a call are yielded as a JSON string. Other function calls are executed
# through tool_calls_handler and a follow-up stream (depth + 1) is started.
client: genai.Client = self._client
google_tools = None
if tools:
google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
google_tools.append(
GoogleTool(
function_declarations=[
{
"name": "ResponseSchema",
"description": "Provide response to the user",
"parameters_json_schema": response_format,
}
]
)
)
tool_calls: List[GoogleToolCall] = []
has_response_schema_tool_call = False
async for event in iterator_to_async(client.models.generate_content_stream)(
model=model,
contents=self._get_google_messages(messages),
config=GenerateContentConfig(
tools=google_tools,
system_instruction=self._get_system_prompt(messages),
# JSON mode is only requested when no tools are in play.
response_mime_type="application/json" if not tools else None,
response_json_schema=response_format if not tools else None,
max_output_tokens=max_tokens,
),
):
if event.text:
yield event.text
# Each chunk carrying function calls replaces (does not extend) the list.
if event.function_calls:
tool_calls = [
GoogleToolCall(
name=each.name,
arguments=each.args,
)
for each in event.function_calls
]
# A ResponseSchema call is the final answer: its arguments are serialised and
# yielded instead of being executed as a tool.
for each in tool_calls:
if each.name == "ResponseSchema":
has_response_schema_tool_call = True
if each.arguments:
yield json.dumps(each.arguments)
# NOTE(review): this `continue` and the section below depend on nesting that
# is not recoverable from the flattened source — confirm which loop they
# belong to before changing anything here.
if has_response_schema_tool_call:
continue
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_google(tool_calls)
)
new_messages = [
*messages,
GoogleAssistantMessage(
role="assistant",
content=event.candidates[0].content,
),
*tool_call_messages,
]
# The loop variable `event` is reused for the chunks of the follow-up stream.
async for event in self._stream_google_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
response_format=response_format,
tools=tools,
depth=depth + 1,
):
yield event
async def _stream_anthropic_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
depth: int = 0,
) -> AsyncGenerator[str, None]:
# Stream structured output from the Anthropic client.
#
# The schema is always offered as a "ResponseSchema" tool placed ahead of the
# caller's tools; max_tokens falls back to 4000. Partial JSON of the
# ResponseSchema tool input is yielded as it arrives. Other completed tool_use
# blocks are executed through tool_calls_handler and a follow-up stream
# (depth + 1) is started with the tool calls and results appended.
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=[
{
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": response_format,
},
*(tools or []),
],
) as stream:
tool_calls: List[AnthropicToolCall] = []
has_response_schema_tool_call = False
is_response_schema_tool_call_started = False
async for event in stream:
event: AnthropicMessageStreamEvent = event
# Detect the start of the ResponseSchema tool block.
if (
event.type == "content_block_start"
and event.content_block.type == "tool_use"
):
if event.content_block.name == "ResponseSchema":
has_response_schema_tool_call = True
is_response_schema_tool_call_started = True
# The started flag is never reset in this function, so every input_json_delta
# after the ResponseSchema block begins is yielded.
if (
event.type == "content_block_delta"
and event.delta.type == "input_json_delta"
and is_response_schema_tool_call_started
):
yield event.delta.partial_json
# Once a ResponseSchema call has been seen, no further tool calls are collected.
if has_response_schema_tool_call:
continue
if (
event.type == "content_block_stop"
and event.content_block.type == "tool_use"
):
tool_calls.append(
AnthropicToolCall(
id=event.content_block.id,
type=event.content_block.type,
name=event.content_block.name,
input=event.content_block.input,
)
)
# NOTE(review): tool calls collected before the ResponseSchema block started
# would still be executed here — confirm that is intended.
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_anthropic(
tool_calls
)
)
new_messages = [
*messages,
AnthropicAssistantMessage(
role="assistant",
content=[each.model_dump() for each in tool_calls],
),
AnthropicUserMessage(
role="user",
content=[each.model_dump() for each in tool_call_messages],
),
]
# The loop variable `event` is reused for the items of the follow-up stream.
async for event in self._stream_anthropic_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
response_format=response_format,
tools=tools,
depth=depth + 1,
):
yield event
def _stream_ollama_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
depth: int = 0,
):
return self._stream_openai_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
max_tokens=max_tokens,
depth=depth,
)
def _stream_custom_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
depth: int = 0,
):
extra_body = {"enable_thinking": not self.disable_thinking()}
return self._stream_openai_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
max_tokens=max_tokens,
extra_body=extra_body,
depth=depth,
)
def stream_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
    max_tokens: Optional[int] = None,
):
    """Return the structured-output stream of the configured provider.

    ``strict`` is forwarded to the OpenAI, Ollama and custom paths; parsed
    tools are forwarded to the OpenAI, Google and Anthropic paths. Returns
    None when the provider matches none of the known values.
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)
    provider = self.llm_provider
    common = {
        "model": model,
        "messages": messages,
        "response_format": response_format,
        "max_tokens": max_tokens,
    }

    if provider == LLMProvider.OPENAI:
        return self._stream_openai_structured(
            strict=strict, tools=parsed_tools, **common
        )
    if provider == LLMProvider.GOOGLE:
        return self._stream_google_structured(tools=parsed_tools, **common)
    if provider == LLMProvider.ANTHROPIC:
        return self._stream_anthropic_structured(tools=parsed_tools, **common)
    if provider == LLMProvider.OLLAMA:
        return self._stream_ollama_structured(strict=strict, **common)
    if provider == LLMProvider.CUSTOM:
        return self._stream_custom_structured(strict=strict, **common)
    return None
# ? Web search
async def _search_openai(self, query: str) -> str:
    """Send ``query`` to the OpenAI Responses API with the ``web_search_preview`` tool and return its output text."""
    client: AsyncOpenAI = self._client
    web_search_tool = {"type": "web_search_preview"}
    response = await client.responses.create(
        model=get_model(),
        tools=[web_search_tool],
        input=query,
    )
    answer = response.output_text
    return answer
async def _search_google(self, query: str) -> str:
    """Send ``query`` to the Google client with the Google Search tool enabled and return the response text.

    The blocking ``generate_content`` call runs in a worker thread.
    """
    client: genai.Client = self._client
    search_config = GenerateContentConfig(
        tools=[GoogleTool(google_search=GoogleSearch())]
    )
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=get_model(),
        contents=query,
        config=search_config,
    )
    answer = response.text
    return answer
async def _search_anthropic(self, query: str) -> str:
    """Send ``query`` to Anthropic with the web search tool (at most one use) and return the reply text.

    The text of every ``text`` block in the reply is joined with newlines.
    """
    client: AsyncAnthropic = self._client
    web_search_tool = {
        "type": "web_search_20250305",
        "name": "web_search",
        "max_uses": 1,
    }
    response = await client.messages.create(
        model=get_model(),
        max_tokens=4000,
        messages=[{"role": "user", "content": query}],
        tools=[web_search_tool],
    )
    text_blocks = []
    for block in response.content:
        if block.type == "text":
            text_blocks.append(block.text)
    return "\n".join(text_blocks)