- Fix admin sidebar: remove duplicate Teams, add Storage nav item - Analytics: client-scoped queries, super_admin sees all (including NULL client_id) - Storage management: list/download/delete presentations with file metadata - Settings page with brand config router - AI usage tracking: new AIUsageModel, ai_usage_service, analytics endpoint - Master deck → template bridge: _register_as_template creates TemplateModel + PresentationLayoutCodeModel so parsed layouts appear in template picker - Multi-provider LLM vision in parser: Anthropic/Google/OpenAI with asyncio.to_thread - Fix PPTX upload 400: accept by .pptx extension (browser sends octet-stream) - Fix reparse FK violation: presentation_id=None for parse_master_deck jobs - Worker job_timeout increased to 1800s for LLM-heavy master deck parsing - PYTHONUNBUFFERED=1 in docker-compose worker for real-time log output - Auth: clientId in /me response, dev-login cookie improvements - Frontend: auth slice clientId, master-deck thumbnails, storage page Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1763 lines
61 KiB
Python
1763 lines
61 KiB
Python
import asyncio
|
|
import dirtyjson
|
|
import json
|
|
from typing import AsyncGenerator, List, Optional
|
|
from fastapi import HTTPException
|
|
from openai import AsyncOpenAI
|
|
from openai.types.chat.chat_completion_chunk import (
|
|
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
|
)
|
|
from google import genai
|
|
from google.genai.types import Content as GoogleContent, Part as GoogleContentPart
|
|
from google.genai.types import (
|
|
GenerateContentConfig,
|
|
GoogleSearch,
|
|
ToolConfig as GoogleToolConfig,
|
|
FunctionCallingConfig as GoogleFunctionCallingConfig,
|
|
FunctionCallingConfigMode as GoogleFunctionCallingConfigMode,
|
|
)
|
|
from google.genai.types import Tool as GoogleTool
|
|
from anthropic import AsyncAnthropic
|
|
from anthropic.types import Message as AnthropicMessage
|
|
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
|
|
from enums.llm_provider import LLMProvider
|
|
from models.llm_message import (
|
|
AnthropicAssistantMessage,
|
|
AnthropicUserMessage,
|
|
GoogleAssistantMessage,
|
|
GoogleToolCallMessage,
|
|
OpenAIAssistantMessage,
|
|
LLMMessage,
|
|
LLMSystemMessage,
|
|
LLMUserMessage,
|
|
)
|
|
from models.llm_tool_call import (
|
|
AnthropicToolCall,
|
|
GoogleToolCall,
|
|
LLMToolCall,
|
|
OpenAIToolCall,
|
|
OpenAIToolCallFunction,
|
|
)
|
|
from models.llm_tools import LLMDynamicTool, LLMTool
|
|
from services.llm_tool_calls_handler import LLMToolCallsHandler
|
|
from utils.async_iterator import iterator_to_async
|
|
from utils.dummy_functions import do_nothing_async
|
|
from utils.get_env import (
|
|
get_anthropic_api_key_env,
|
|
get_custom_llm_api_key_env,
|
|
get_custom_llm_url_env,
|
|
get_disable_thinking_env,
|
|
get_fallback_llm_providers_env,
|
|
get_google_api_key_env,
|
|
get_ollama_url_env,
|
|
get_openai_api_key_env,
|
|
get_tool_calls_env,
|
|
get_web_grounding_env,
|
|
)
|
|
from utils.llm_provider import get_llm_provider, get_model
|
|
from utils.parsers import parse_bool_or_none
|
|
from utils.schema_utils import (
|
|
ensure_strict_json_schema,
|
|
flatten_json_schema,
|
|
remove_titles_from_schema,
|
|
)
|
|
|
|
|
|
import logging
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_fallback_providers() -> list:
    """Parse the FALLBACK_LLM_PROVIDERS env var into a list of LLMProvider enums.

    The raw value is read through ``get_fallback_llm_providers_env`` and split
    on commas; every entry is stripped and lower-cased before being converted
    to ``LLMProvider``.

    Returns:
        The recognised providers in configuration order, or an empty list when
        the variable is unset or empty. Blank entries (for example from a
        trailing comma) are skipped. Unrecognised names are skipped too, but a
        warning is logged so a typo in the configuration does not silently
        disable a fallback.
    """
    raw = get_fallback_llm_providers_env()
    if not raw:
        return []
    providers = []
    for name in raw.split(","):
        name = name.strip().lower()
        if not name:
            # Stray separator: nothing to convert.
            continue
        try:
            providers.append(LLMProvider(name))
        except ValueError:
            _logger.warning("Ignoring unknown fallback LLM provider %r", name)
    return providers
|
|
|
|
|
|
class LLMClient:
|
|
def __init__(self):
    """Resolve the configured provider, then build its SDK client and tool-call handler."""
    provider = get_llm_provider()
    self.llm_provider = provider
    # _get_client() dispatches on self.llm_provider, so the provider is stored first.
    self._client = self._get_client()
    handler = LLMToolCallsHandler(self)
    self.tool_calls_handler = handler
|
|
|
|
# ? Use tool calls
|
|
def use_tool_calls_for_structured_output(self) -> bool:
    """Tell whether structured output should be requested through a tool call.

    Only the CUSTOM provider can opt in; for it the answer is the parsed
    tool-calls env setting, with an unset/unparseable value counting as False.
    """
    if self.llm_provider == LLMProvider.CUSTOM:
        flag = parse_bool_or_none(get_tool_calls_env())
        return flag or False
    return False
|
|
|
|
# ? Web Grounding
|
|
def enable_web_grounding(self) -> bool:
    """Tell whether web grounding is enabled for the current provider.

    Always False for the OLLAMA and CUSTOM providers; otherwise the parsed
    web-grounding env setting, with an unset value counting as False.
    """
    never_grounded = (LLMProvider.OLLAMA, LLMProvider.CUSTOM)
    if self.llm_provider in never_grounded:
        return False
    flag = parse_bool_or_none(get_web_grounding_env())
    return flag or False
|
|
|
|
# ? Disable thinking
|
|
def disable_thinking(self) -> bool:
    """Return the parsed disable-thinking env setting; falsy or unset values yield False."""
    setting = parse_bool_or_none(get_disable_thinking_env())
    return setting if setting else False
|
|
|
|
# ? Clients
|
|
def _get_client(self):
    """Build the SDK client that matches ``self.llm_provider``.

    Returns:
        Whatever the provider-specific factory method returns.

    Raises:
        HTTPException: status 400 when the provider is none of openai,
            google, anthropic, ollama or custom.
    """
    provider = self.llm_provider
    if provider == LLMProvider.OPENAI:
        return self._get_openai_client()
    if provider == LLMProvider.GOOGLE:
        return self._get_google_client()
    if provider == LLMProvider.ANTHROPIC:
        return self._get_anthropic_client()
    if provider == LLMProvider.OLLAMA:
        return self._get_ollama_client()
    if provider == LLMProvider.CUSTOM:
        return self._get_custom_client()
    raise HTTPException(
        status_code=400,
        detail="LLM Provider must be either openai, google, anthropic, ollama, or custom",
    )
|
|
|
|
def _get_openai_client(self):
    """Return a default-configured ``AsyncOpenAI`` client.

    Raises:
        HTTPException: status 400 when no OpenAI API key is configured.
    """
    if get_openai_api_key_env():
        return AsyncOpenAI()
    raise HTTPException(
        status_code=400,
        detail="OpenAI API Key is not set",
    )
|
|
|
|
def _get_google_client(self):
    """Return a default-configured ``genai.Client``.

    Raises:
        HTTPException: status 400 when no Google API key is configured.
    """
    if get_google_api_key_env():
        return genai.Client()
    raise HTTPException(
        status_code=400,
        detail="Google API Key is not set",
    )
|
|
|
|
def _get_anthropic_client(self):
    """Return a default-configured ``AsyncAnthropic`` client.

    Raises:
        HTTPException: status 400 when no Anthropic API key is configured.
    """
    if get_anthropic_api_key_env():
        return AsyncAnthropic()
    raise HTTPException(
        status_code=400,
        detail="Anthropic API Key is not set",
    )
|
|
|
|
def _get_ollama_client(self):
    """Return an ``AsyncOpenAI`` client pointed at the Ollama ``/v1`` endpoint.

    The base URL comes from the Ollama URL env setting and falls back to
    ``http://localhost:11434``; the API key is the fixed string ``"ollama"``.
    """
    host = get_ollama_url_env() or "http://localhost:11434"
    return AsyncOpenAI(base_url=host + "/v1", api_key="ollama")
|
|
|
|
def _get_custom_client(self):
    """Return an ``AsyncOpenAI`` client pointed at the custom LLM URL.

    The API key comes from the custom-LLM key env setting and falls back to
    the literal string ``"null"`` when that is unset.

    Raises:
        HTTPException: status 400 when the custom LLM URL is not configured.
    """
    if get_custom_llm_url_env():
        api_key = get_custom_llm_api_key_env() or "null"
        return AsyncOpenAI(base_url=get_custom_llm_url_env(), api_key=api_key)
    raise HTTPException(
        status_code=400,
        detail="Custom LLM URL is not set",
    )
|
|
|
|
# ? Prompts
|
|
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
    """Return the content of the first system message, or ``""`` when there is none."""
    system_contents = (
        entry.content for entry in messages if isinstance(entry, LLMSystemMessage)
    )
    return next(system_contents, "")
|
|
|
|
def _get_google_messages(self, messages: List[LLMMessage]) -> List[GoogleContent]:
    """Convert the conversation into the ``contents`` list passed to Gemini.

    - user messages become a ``GoogleContent`` with one text part, keeping
      the message's own role;
    - Google assistant messages contribute their stored content object as is;
    - Google tool-call messages become a ``"user"`` content holding a
      function-response part.

    Any other message kind (system messages included) is left out.
    """
    converted = []
    for entry in messages:
        if isinstance(entry, LLMUserMessage):
            text_part = GoogleContentPart(text=entry.content)
            converted.append(GoogleContent(role=entry.role, parts=[text_part]))
            continue
        if isinstance(entry, GoogleAssistantMessage):
            converted.append(entry.content)
            continue
        if isinstance(entry, GoogleToolCallMessage):
            result_part = GoogleContentPart.from_function_response(
                name=entry.name,
                response=entry.response,
            )
            converted.append(GoogleContent(role="user", parts=[result_part]))

    return converted
|
|
|
|
def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
    """Return ``messages`` with every system message removed, order preserved."""
    return list(
        filter(lambda entry: not isinstance(entry, LLMSystemMessage), messages)
    )
|
|
|
|
# ? Generate Unstructured Content
|
|
async def _generate_openai(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
max_tokens: Optional[int] = None,
|
|
tools: Optional[List[dict]] = None,
|
|
extra_body: Optional[dict] = None,
|
|
depth: int = 0,
|
|
) -> str | None:
|
|
client: AsyncOpenAI = self._client
|
|
response = await client.chat.completions.create(
|
|
model=model,
|
|
messages=[message.model_dump() for message in messages],
|
|
max_completion_tokens=max_tokens,
|
|
tools=tools,
|
|
extra_body=extra_body,
|
|
)
|
|
|
|
if len(response.choices) == 0:
|
|
return None
|
|
|
|
tool_calls = response.choices[0].message.tool_calls
|
|
if tool_calls:
|
|
parsed_tool_calls = [
|
|
OpenAIToolCall(
|
|
id=tool_call.id,
|
|
type=tool_call.type,
|
|
function=OpenAIToolCallFunction(
|
|
name=tool_call.function.name,
|
|
arguments=tool_call.function.arguments,
|
|
),
|
|
)
|
|
for tool_call in tool_calls
|
|
]
|
|
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
|
parsed_tool_calls
|
|
)
|
|
assistant_message = OpenAIAssistantMessage(
|
|
role="assistant",
|
|
content=response.choices[0].message.content,
|
|
tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls],
|
|
)
|
|
new_messages = [
|
|
*messages,
|
|
assistant_message,
|
|
*tool_call_messages,
|
|
]
|
|
return await self._generate_openai(
|
|
model=model,
|
|
messages=new_messages,
|
|
max_tokens=max_tokens,
|
|
tools=tools,
|
|
extra_body=extra_body,
|
|
depth=depth + 1,
|
|
)
|
|
|
|
# Log AI usage
|
|
from services import ai_usage_service
|
|
usage = getattr(response, "usage", None)
|
|
ai_usage_service.log(
|
|
provider="openai", model=model, call_type="generate",
|
|
input_tokens=getattr(usage, "prompt_tokens", None) if usage else None,
|
|
output_tokens=getattr(usage, "completion_tokens", None) if usage else None,
|
|
total_tokens=getattr(usage, "total_tokens", None) if usage else None,
|
|
)
|
|
return response.choices[0].message.content
|
|
|
|
async def _generate_google(
    self,
    model: str,
    messages: List[LLMMessage],
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
    depth: int = 0,
) -> str | None:
    """Run a non-streaming Gemini completion, resolving tool calls recursively.

    Each entry of ``tools`` is wrapped in its own ``GoogleTool``. When the
    reply contains function calls they are executed through the tool-calls
    handler, the model turn and the tool results are appended to the
    conversation, and the method calls itself with ``depth + 1``. Token
    usage is logged only for the response that ends the chain.

    Returns:
        The text of the last text part of the final response, or ``None``
        when the response has no candidates, no content or no parts.
    """
    client: genai.Client = self._client

    google_tools = None
    if tools:
        google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]

    # generate_content is invoked through asyncio.to_thread, i.e. in a worker thread.
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=model,
        contents=self._get_google_messages(messages),
        config=GenerateContentConfig(
            tools=google_tools,
            system_instruction=self._get_system_prompt(messages),
            response_mime_type="text/plain",
            max_output_tokens=max_tokens,
        ),
    )

    # The SDK may return no candidates or a candidate without content; report
    # "no content" instead of failing on the attribute/index access below.
    candidates = response.candidates
    content = candidates[0].content if candidates else None
    if content is None:
        return None
    response_parts = content.parts

    if not response_parts:
        return None

    text_content = None
    tool_calls = []
    for each_part in response_parts:
        if each_part.function_call:
            tool_calls.append(
                GoogleToolCall(
                    id=each_part.function_call.id,
                    name=each_part.function_call.name,
                    arguments=each_part.function_call.args,
                )
            )
        if each_part.text:
            # Later text parts overwrite earlier ones: only the last is returned.
            text_content = each_part.text

    if tool_calls:
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
            tool_calls
        )
        new_messages = [
            *messages,
            GoogleAssistantMessage(
                role="assistant",
                content=content,
            ),
            *tool_call_messages,
        ]
        return await self._generate_google(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            tools=tools,
            depth=depth + 1,
        )

    # Log AI usage
    from services import ai_usage_service

    usage_meta = getattr(response, "usage_metadata", None)
    ai_usage_service.log(
        provider="google",
        model=model,
        call_type="generate",
        input_tokens=getattr(usage_meta, "prompt_token_count", None) if usage_meta else None,
        output_tokens=getattr(usage_meta, "candidates_token_count", None) if usage_meta else None,
        total_tokens=getattr(usage_meta, "total_token_count", None) if usage_meta else None,
    )
    return text_content
|
|
|
|
async def _generate_anthropic(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> str | None:
    """Run a non-streaming Anthropic completion, resolving tool calls recursively.

    The system prompt is passed separately and system messages are removed
    from the conversation. ``max_tokens`` falls back to 4000 when falsy.
    When the reply contains ``tool_use`` blocks they are executed through the
    tool-calls handler, an assistant turn (the tool calls) and a user turn
    (the tool results) are appended, and the method calls itself with
    ``depth + 1``. Token usage is logged only for the response that ends the
    chain.

    Returns:
        The text of the last text block of the final response, or ``None``
        when it contains no text block.
    """
    client: AsyncAnthropic = self._client

    system_prompt = self._get_system_prompt(messages)
    conversation = [
        entry.model_dump() for entry in self._get_anthropic_messages(messages)
    ]
    response: AnthropicMessage = await client.messages.create(
        model=model,
        system=system_prompt,
        messages=conversation,
        tools=tools,
        max_tokens=max_tokens or 4000,
    )

    text_content = None
    requested_calls: List[AnthropicToolCall] = []
    for block in response.content:
        if block.type == "tool_use":
            requested_calls.append(
                AnthropicToolCall(
                    id=block.id,
                    type=block.type,
                    name=block.name,
                    input=block.input,
                )
            )
        elif block.type == "text" and isinstance(block.text, str):
            text_content = block.text

    if not requested_calls:
        # Log AI usage
        from services import ai_usage_service

        ai_usage_service.log(
            provider="anthropic",
            model=model,
            call_type="generate",
            input_tokens=response.usage.input_tokens,
            output_tokens=response.usage.output_tokens,
        )
        return text_content

    tool_results = await self.tool_calls_handler.handle_tool_calls_anthropic(
        requested_calls
    )
    assistant_turn = AnthropicAssistantMessage(
        role="assistant",
        content=[call.model_dump() for call in requested_calls],
    )
    results_turn = AnthropicUserMessage(
        role="user",
        content=[result.model_dump() for result in tool_results],
    )
    return await self._generate_anthropic(
        model=model,
        messages=[*messages, assistant_turn, results_turn],
        max_tokens=max_tokens,
        tools=tools,
        depth=depth + 1,
    )
|
|
|
|
async def _generate_ollama(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
max_tokens: Optional[int] = None,
|
|
depth: int = 0,
|
|
):
|
|
return await self._generate_openai(
|
|
model=model, messages=messages, max_tokens=max_tokens, depth=depth
|
|
)
|
|
|
|
async def _generate_custom(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
max_tokens: Optional[int] = None,
|
|
depth: int = 0,
|
|
):
|
|
extra_body = {"enable_thinking": False} if self.disable_thinking() else None
|
|
return await self._generate_openai(
|
|
model=model,
|
|
messages=messages,
|
|
max_tokens=max_tokens,
|
|
extra_body=extra_body,
|
|
depth=depth,
|
|
)
|
|
|
|
async def generate(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
    """Generate unstructured content, retrying on the configured fallback providers.

    The primary provider is tried first. If it raises, every configured
    fallback provider other than the primary is tried in order with a
    freshly assembled client; the first success is returned.

    Raises:
        Exception: the primary provider's original error, when there is no
            fallback or every fallback failed as well.
    """
    try:
        return await self._generate_impl(model, messages, max_tokens, tools)
    except Exception as primary_error:
        # The primary provider has already failed, so it is left out of the retry list.
        candidates = [
            provider
            for provider in _get_fallback_providers()
            if provider != self.llm_provider
        ]
        for fb_provider in candidates:
            _logger.warning(
                "LLM generate failed with %s: %s. Trying fallback %s",
                self.llm_provider.value,
                str(primary_error)[:200],
                fb_provider.value,
            )
            try:
                # Bypass __init__, which would re-resolve the primary provider.
                fb_client = LLMClient.__new__(LLMClient)
                fb_client.llm_provider = fb_provider
                fb_client._client = fb_client._get_client()
                fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client)
                # NOTE(review): get_model() receives no provider here — confirm it
                # yields a model name that is valid for the fallback provider.
                fb_model = get_model()
                return await fb_client._generate_impl(
                    fb_model, messages, max_tokens, tools
                )
            except Exception as fb_error:
                _logger.warning(
                    "Fallback %s also failed: %s",
                    fb_provider.value,
                    str(fb_error)[:200],
                )
        raise primary_error
|
|
|
|
async def _generate_impl(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
    """Dispatch an unstructured generation to the current provider.

    ``tools`` are parsed by the tool-calls handler and forwarded to the
    OpenAI, Google and Anthropic paths; the Ollama and custom paths receive
    no tools.

    Raises:
        HTTPException: status 400 when the provider produced no content (or
            the provider matched none of the known values).
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)
    provider = self.llm_provider

    content = None
    if provider == LLMProvider.OPENAI:
        content = await self._generate_openai(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            tools=parsed_tools,
        )
    elif provider == LLMProvider.GOOGLE:
        content = await self._generate_google(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            tools=parsed_tools,
        )
    elif provider == LLMProvider.ANTHROPIC:
        content = await self._generate_anthropic(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            tools=parsed_tools,
        )
    elif provider == LLMProvider.OLLAMA:
        content = await self._generate_ollama(
            model=model, messages=messages, max_tokens=max_tokens
        )
    elif provider == LLMProvider.CUSTOM:
        content = await self._generate_custom(
            model=model, messages=messages, max_tokens=max_tokens
        )

    if content is not None:
        return content
    raise HTTPException(
        status_code=400,
        detail="LLM did not return any content",
    )
|
|
|
|
# ? Generate Structured Content
|
|
async def _generate_openai_structured(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
response_format: dict,
|
|
strict: bool = False,
|
|
max_tokens: Optional[int] = None,
|
|
tools: Optional[List[dict]] = None,
|
|
extra_body: Optional[dict] = None,
|
|
depth: int = 0,
|
|
) -> dict | None:
|
|
client: AsyncOpenAI = self._client
|
|
response_schema = response_format
|
|
all_tools = [*tools] if tools else None
|
|
|
|
use_tool_calls_for_structured_output = (
|
|
self.use_tool_calls_for_structured_output()
|
|
)
|
|
if strict and depth == 0:
|
|
response_schema = ensure_strict_json_schema(
|
|
response_schema,
|
|
path=(),
|
|
root=response_schema,
|
|
)
|
|
if use_tool_calls_for_structured_output and depth == 0:
|
|
if all_tools is None:
|
|
all_tools = []
|
|
all_tools.append(
|
|
self.tool_calls_handler.parse_tool(
|
|
LLMDynamicTool(
|
|
name="ResponseSchema",
|
|
description="Provide response to the user",
|
|
parameters=response_schema,
|
|
handler=do_nothing_async,
|
|
),
|
|
strict=strict,
|
|
)
|
|
)
|
|
|
|
response = await client.chat.completions.create(
|
|
model=model,
|
|
messages=[message.model_dump() for message in messages],
|
|
response_format=(
|
|
{
|
|
"type": "json_schema",
|
|
"json_schema": (
|
|
{
|
|
"name": "ResponseSchema",
|
|
"strict": strict,
|
|
"schema": response_schema,
|
|
}
|
|
),
|
|
}
|
|
if not use_tool_calls_for_structured_output
|
|
else None
|
|
),
|
|
max_completion_tokens=max_tokens,
|
|
tools=all_tools,
|
|
extra_body=extra_body,
|
|
)
|
|
|
|
if len(response.choices) == 0:
|
|
return None
|
|
|
|
content = response.choices[0].message.content
|
|
|
|
tool_calls = response.choices[0].message.tool_calls
|
|
has_response_schema = False
|
|
|
|
if tool_calls:
|
|
for tool_call in tool_calls:
|
|
if tool_call.function.name == "ResponseSchema":
|
|
content = tool_call.function.arguments
|
|
has_response_schema = True
|
|
|
|
if not has_response_schema:
|
|
parsed_tool_calls = [
|
|
OpenAIToolCall(
|
|
id=tool_call.id,
|
|
type=tool_call.type,
|
|
function=OpenAIToolCallFunction(
|
|
name=tool_call.function.name,
|
|
arguments=tool_call.function.arguments,
|
|
),
|
|
)
|
|
for tool_call in tool_calls
|
|
]
|
|
tool_call_messages = (
|
|
await self.tool_calls_handler.handle_tool_calls_openai(
|
|
parsed_tool_calls
|
|
)
|
|
)
|
|
new_messages = [
|
|
*messages,
|
|
OpenAIAssistantMessage(
|
|
role="assistant",
|
|
content=response.choices[0].message.content,
|
|
tool_calls=[each.model_dump() for each in parsed_tool_calls],
|
|
),
|
|
*tool_call_messages,
|
|
]
|
|
content = await self._generate_openai_structured(
|
|
model=model,
|
|
messages=new_messages,
|
|
response_format=response_schema,
|
|
strict=strict,
|
|
max_tokens=max_tokens,
|
|
tools=all_tools,
|
|
extra_body=extra_body,
|
|
depth=depth + 1,
|
|
)
|
|
if content:
|
|
# Log AI usage at the top-level call
|
|
if depth == 0:
|
|
from services import ai_usage_service
|
|
usage = getattr(response, "usage", None)
|
|
ai_usage_service.log(
|
|
provider="openai", model=model, call_type="generate_structured",
|
|
input_tokens=getattr(usage, "prompt_tokens", None) if usage else None,
|
|
output_tokens=getattr(usage, "completion_tokens", None) if usage else None,
|
|
total_tokens=getattr(usage, "total_tokens", None) if usage else None,
|
|
)
|
|
return dict(dirtyjson.loads(content))
|
|
return content
|
|
return None
|
|
|
|
async def _generate_google_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> dict | None:
    """Request JSON matching ``response_format`` from Gemini.

    Without ``tools`` the schema is passed as ``response_json_schema`` with
    an ``application/json`` mime type and the answer is parsed from the
    returned text. With ``tools`` a ``ResponseSchema`` function is declared
    next to them, function calling is set to mode ANY, and the answer is
    taken from the arguments of the ``ResponseSchema`` call; other function
    calls are executed through the tool-calls handler and the request is
    repeated with ``depth + 1``.

    Token usage is logged for the response that delivers the answer, on both
    the text path and the ``ResponseSchema`` path.

    Returns:
        The structured answer, or ``None`` when the response has no
        candidates, no content, no parts, or neither text nor function calls.
    """
    client: genai.Client = self._client

    google_tools = None
    if tools:
        google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
        google_tools.append(
            GoogleTool(
                function_declarations=[
                    {
                        "name": "ResponseSchema",
                        "description": "Provide response to the user",
                        "parameters": remove_titles_from_schema(
                            flatten_json_schema(response_format)
                        ),
                    }
                ]
            )
        )

    # generate_content is invoked through asyncio.to_thread, i.e. in a worker thread.
    response = await asyncio.to_thread(
        client.models.generate_content,
        model=model,
        contents=self._get_google_messages(messages),
        config=GenerateContentConfig(
            tools=google_tools,
            tool_config=(
                GoogleToolConfig(
                    function_calling_config=GoogleFunctionCallingConfig(
                        mode=GoogleFunctionCallingConfigMode.ANY,
                    ),
                )
                if tools
                else None
            ),
            system_instruction=self._get_system_prompt(messages),
            response_mime_type="application/json" if not tools else None,
            response_json_schema=response_format if not tools else None,
            max_output_tokens=max_tokens,
        ),
    )

    def log_usage() -> None:
        """Record the token counts of ``response`` as a generate_structured call."""
        from services import ai_usage_service

        usage_meta = getattr(response, "usage_metadata", None)
        ai_usage_service.log(
            provider="google",
            model=model,
            call_type="generate_structured",
            input_tokens=getattr(usage_meta, "prompt_token_count", None) if usage_meta else None,
            output_tokens=getattr(usage_meta, "candidates_token_count", None) if usage_meta else None,
            total_tokens=getattr(usage_meta, "total_token_count", None) if usage_meta else None,
        )

    # The SDK may return no candidates or a candidate without content; report
    # "no content" instead of failing on the attribute/index access below.
    candidates = response.candidates
    content = candidates[0].content if candidates else None
    if content is None:
        return None
    response_parts = content.parts
    text_content = None

    if not response_parts:
        return None

    tool_calls: List[GoogleToolCall] = []
    for each_part in response_parts:
        if each_part.function_call:
            tool_calls.append(
                GoogleToolCall(
                    id=each_part.function_call.id,
                    name=each_part.function_call.name,
                    arguments=each_part.function_call.args,
                )
            )

        if each_part.text:
            # Later text parts overwrite earlier ones: only the last is parsed.
            text_content = each_part.text

    for each in tool_calls:
        if each.name == "ResponseSchema":
            # The arguments of the schema function are the structured answer.
            log_usage()
            return each.arguments

    if tool_calls:
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
            tool_calls
        )
        new_messages = [
            *messages,
            GoogleAssistantMessage(
                role="assistant",
                content=content,
            ),
            *tool_call_messages,
        ]
        return await self._generate_google_structured(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            response_format=response_format,
            tools=tools,
            depth=depth + 1,
        )

    if text_content:
        log_usage()
        return dict(dirtyjson.loads(text_content))
    return None
|
|
|
|
async def _generate_anthropic_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
    depth: int = 0,
):
    """Request data matching ``response_format`` from Anthropic via a tool call.

    A ``ResponseSchema`` tool whose input schema is ``response_format`` is
    declared ahead of the caller's tools. The input of the first
    ``ResponseSchema`` call in the reply is the answer. If the reply only
    calls other tools, they are executed through the tool-calls handler and
    the request is repeated with ``depth + 1``. ``max_tokens`` falls back to
    4000 when falsy.

    Returns:
        The input of the ``ResponseSchema`` tool call, or ``None`` when the
        reply contains no tool call at all.
    """
    client: AsyncAnthropic = self._client

    system_prompt = self._get_system_prompt(messages)
    conversation = [
        entry.model_dump() for entry in self._get_anthropic_messages(messages)
    ]
    schema_tool = {
        "name": "ResponseSchema",
        "description": "A response to the user's message",
        "input_schema": response_format,
    }
    response: AnthropicMessage = await client.messages.create(
        model=model,
        system=system_prompt,
        messages=conversation,
        max_tokens=max_tokens or 4000,
        tools=[schema_tool, *(tools or [])],
    )

    requested_calls: List[AnthropicToolCall] = [
        AnthropicToolCall(
            id=block.id,
            type=block.type,
            name=block.name,
            input=block.input,
        )
        for block in response.content
        if block.type == "tool_use"
    ]

    answer = next(
        (call for call in requested_calls if call.name == "ResponseSchema"), None
    )
    if answer is not None:
        # Log AI usage
        from services import ai_usage_service

        ai_usage_service.log(
            provider="anthropic",
            model=model,
            call_type="generate_structured",
            input_tokens=response.usage.input_tokens,
            output_tokens=response.usage.output_tokens,
        )
        return answer.input

    if not requested_calls:
        return None

    tool_results = await self.tool_calls_handler.handle_tool_calls_anthropic(
        requested_calls
    )
    assistant_turn = AnthropicAssistantMessage(
        role="assistant",
        content=[call.model_dump() for call in requested_calls],
    )
    results_turn = AnthropicUserMessage(
        role="user",
        content=[result.model_dump() for result in tool_results],
    )
    return await self._generate_anthropic_structured(
        model=model,
        messages=[*messages, assistant_turn, results_turn],
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        depth=depth + 1,
    )
|
|
|
|
async def _generate_ollama_structured(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
response_format: dict,
|
|
strict: bool = False,
|
|
max_tokens: Optional[int] = None,
|
|
depth: int = 0,
|
|
):
|
|
return await self._generate_openai_structured(
|
|
model=model,
|
|
messages=messages,
|
|
response_format=response_format,
|
|
strict=strict,
|
|
max_tokens=max_tokens,
|
|
depth=depth,
|
|
)
|
|
|
|
async def _generate_custom_structured(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
response_format: dict,
|
|
strict: bool = False,
|
|
max_tokens: Optional[int] = None,
|
|
depth: int = 0,
|
|
):
|
|
extra_body = {"enable_thinking": False} if self.disable_thinking() else None
|
|
return await self._generate_openai_structured(
|
|
model=model,
|
|
messages=messages,
|
|
response_format=response_format,
|
|
strict=strict,
|
|
max_tokens=max_tokens,
|
|
extra_body=extra_body,
|
|
depth=depth,
|
|
)
|
|
|
|
async def generate_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
    max_tokens: Optional[int] = None,
) -> dict:
    """Generate structured content, retrying on the configured fallback providers.

    The primary provider is tried first. If it raises, every configured
    fallback provider other than the primary is tried in order with a
    freshly assembled client; the first success is returned.

    Raises:
        Exception: the primary provider's original error, when there is no
            fallback or every fallback failed as well.
    """
    try:
        return await self._generate_structured_impl(
            model, messages, response_format, strict, tools, max_tokens
        )
    except Exception as primary_error:
        # The primary provider has already failed, so it is left out of the retry list.
        candidates = [
            provider
            for provider in _get_fallback_providers()
            if provider != self.llm_provider
        ]
        for fb_provider in candidates:
            _logger.warning(
                "LLM generate_structured failed with %s: %s. Trying fallback %s",
                self.llm_provider.value,
                str(primary_error)[:200],
                fb_provider.value,
            )
            try:
                # Bypass __init__, which would re-resolve the primary provider.
                fb_client = LLMClient.__new__(LLMClient)
                fb_client.llm_provider = fb_provider
                fb_client._client = fb_client._get_client()
                fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client)
                # NOTE(review): get_model() receives no provider here — confirm it
                # yields a model name that is valid for the fallback provider.
                fb_model = get_model()
                return await fb_client._generate_structured_impl(
                    fb_model, messages, response_format, strict, tools, max_tokens
                )
            except Exception as fb_error:
                _logger.warning(
                    "Fallback %s also failed: %s",
                    fb_provider.value,
                    str(fb_error)[:200],
                )
        raise primary_error
|
|
|
|
async def _generate_structured_impl(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
    max_tokens: Optional[int] = None,
) -> dict:
    """Dispatch a structured generation to the current provider.

    ``tools`` are parsed by the tool-calls handler and forwarded to the
    OpenAI, Google and Anthropic paths; the Ollama and custom paths receive
    no tools. ``strict`` is forwarded to every path except Google and
    Anthropic.

    Raises:
        HTTPException: status 400 when the provider produced no content (or
            the provider matched none of the known values).
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)
    provider = self.llm_provider

    content = None
    if provider == LLMProvider.OPENAI:
        content = await self._generate_openai_structured(
            model=model,
            messages=messages,
            response_format=response_format,
            strict=strict,
            tools=parsed_tools,
            max_tokens=max_tokens,
        )
    elif provider == LLMProvider.GOOGLE:
        content = await self._generate_google_structured(
            model=model,
            messages=messages,
            response_format=response_format,
            tools=parsed_tools,
            max_tokens=max_tokens,
        )
    elif provider == LLMProvider.ANTHROPIC:
        content = await self._generate_anthropic_structured(
            model=model,
            messages=messages,
            response_format=response_format,
            tools=parsed_tools,
            max_tokens=max_tokens,
        )
    elif provider == LLMProvider.OLLAMA:
        content = await self._generate_ollama_structured(
            model=model,
            messages=messages,
            response_format=response_format,
            strict=strict,
            max_tokens=max_tokens,
        )
    elif provider == LLMProvider.CUSTOM:
        content = await self._generate_custom_structured(
            model=model,
            messages=messages,
            response_format=response_format,
            strict=strict,
            max_tokens=max_tokens,
        )

    if content is not None:
        return content
    raise HTTPException(
        status_code=400,
        detail="LLM did not return any content",
    )
|
|
|
|
# ? Stream Unstructured Content
|
|
async def _stream_openai(
|
|
self,
|
|
model: str,
|
|
messages: List[LLMMessage],
|
|
max_tokens: Optional[int] = None,
|
|
tools: Optional[List[dict]] = None,
|
|
extra_body: Optional[dict] = None,
|
|
depth: int = 0,
|
|
) -> AsyncGenerator[str, None]:
|
|
client: AsyncOpenAI = self._client
|
|
|
|
tool_calls: List[LLMToolCall] = []
|
|
current_index = 0
|
|
current_id = None
|
|
current_name = None
|
|
current_arguments = None
|
|
async for event in await client.chat.completions.create(
|
|
model=model,
|
|
messages=[message.model_dump() for message in messages],
|
|
max_completion_tokens=max_tokens,
|
|
tools=tools,
|
|
extra_body=extra_body,
|
|
stream=True,
|
|
):
|
|
event: OpenAIChatCompletionChunk = event
|
|
if not event.choices:
|
|
continue
|
|
|
|
content_chunk = event.choices[0].delta.content
|
|
if content_chunk:
|
|
yield content_chunk
|
|
|
|
tool_call_chunk = event.choices[0].delta.tool_calls
|
|
if tool_call_chunk:
|
|
tool_index = tool_call_chunk[0].index
|
|
tool_id = tool_call_chunk[0].id
|
|
tool_name = tool_call_chunk[0].function.name
|
|
tool_arguments = tool_call_chunk[0].function.arguments
|
|
|
|
if current_index != tool_index:
|
|
tool_calls.append(
|
|
OpenAIToolCall(
|
|
id=current_id,
|
|
type="function",
|
|
function=OpenAIToolCallFunction(
|
|
name=current_name,
|
|
arguments=current_arguments,
|
|
),
|
|
)
|
|
)
|
|
current_index = tool_index
|
|
current_id = tool_id
|
|
current_name = tool_name
|
|
current_arguments = tool_arguments
|
|
else:
|
|
current_name = tool_name or current_name
|
|
current_id = tool_id or current_id
|
|
if current_arguments is None:
|
|
current_arguments = tool_arguments
|
|
elif tool_arguments:
|
|
current_arguments += tool_arguments
|
|
|
|
if current_id is not None:
|
|
tool_calls.append(
|
|
OpenAIToolCall(
|
|
id=current_id,
|
|
type="function",
|
|
function=OpenAIToolCallFunction(
|
|
name=current_name,
|
|
arguments=current_arguments,
|
|
),
|
|
)
|
|
)
|
|
|
|
if tool_calls:
|
|
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
|
tool_calls
|
|
)
|
|
new_messages = [
|
|
*messages,
|
|
OpenAIAssistantMessage(
|
|
role="assistant",
|
|
content=None,
|
|
tool_calls=[each.model_dump() for each in tool_calls],
|
|
),
|
|
*tool_call_messages,
|
|
]
|
|
async for event in self._stream_openai(
|
|
model=model,
|
|
messages=new_messages,
|
|
max_tokens=max_tokens,
|
|
tools=tools,
|
|
extra_body=extra_body,
|
|
depth=depth + 1,
|
|
):
|
|
yield event
|
|
|
|
async def _stream_google(
    self,
    model: str,
    messages: List[LLMMessage],
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
    depth: int = 0,
) -> AsyncGenerator[str, None]:
    """Stream a Gemini completion, resolving tool calls recursively.

    Text parts are yielded as they arrive. Function calls are collected
    until the stream ends; they are then executed through the tool-calls
    handler, every streamed content object is appended as its own assistant
    message followed by the tool results, and the follow-up completion is
    streamed with ``depth + 1``.

    Yields:
        The text fragments of this response and of every follow-up response.
    """
    client: genai.Client = self._client

    google_tools = (
        [GoogleTool(function_declarations=[tool]) for tool in tools]
        if tools
        else None
    )

    streamed_contents = []
    requested_calls: List[GoogleToolCall] = []

    open_stream = iterator_to_async(client.models.generate_content_stream)
    contents = self._get_google_messages(messages)
    request_config = GenerateContentConfig(
        system_instruction=self._get_system_prompt(messages),
        response_mime_type="text/plain",
        tools=google_tools,
        max_output_tokens=max_tokens,
    )
    async for chunk in open_stream(
        model=model,
        contents=contents,
        config=request_config,
    ):
        # Skip chunks that carry no candidate, no content or no parts.
        if not chunk.candidates:
            continue
        chunk_content = chunk.candidates[0].content
        if not chunk_content or not chunk_content.parts:
            continue

        streamed_contents.append(chunk_content)

        for part in chunk_content.parts:
            if part.text:
                yield part.text

            call = part.function_call
            if call:
                requested_calls.append(
                    GoogleToolCall(
                        id=call.id,
                        name=call.name,
                        arguments=call.args,
                    )
                )

    if not requested_calls:
        return

    tool_results = await self.tool_calls_handler.handle_tool_calls_google(
        requested_calls
    )
    assistant_turns = [
        GoogleAssistantMessage(role="assistant", content=streamed)
        for streamed in streamed_contents
    ]
    async for follow_up in self._stream_google(
        model=model,
        messages=[*messages, *assistant_turns, *tool_results],
        max_tokens=max_tokens,
        tools=tools,
        depth=depth + 1,
    ):
        yield follow_up
|
|
|
|
async def _stream_anthropic(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
):
    """Stream plain-text output from an Anthropic model, resolving tool calls.

    Events of type ``"text"`` are yielded as they arrive. Each completed
    ``tool_use`` content block is recorded; after the stream closes the
    recorded calls are passed to
    ``self.tool_calls_handler.handle_tool_calls_anthropic`` and the method
    re-invokes itself with an assistant message (the tool calls) and a user
    message (the handler's results) appended.

    Args:
        model: Model name forwarded to ``client.messages.stream``.
        messages: Conversation so far.
        max_tokens: Output limit; 4000 is used when this is falsy.
        tools: Tool definitions forwarded unchanged to the API call.
        depth: Follow-up counter; incremented on each re-invocation and not
            otherwise inspected here.

    Yields:
        Text fragments in the order the model produces them.
    """
    client: AsyncAnthropic = self._client

    system_prompt = self._get_system_prompt(messages)
    payload = [
        converted.model_dump()
        for converted in self._get_anthropic_messages(messages)
    ]

    pending_calls: List[AnthropicToolCall] = []
    async with client.messages.stream(
        model=model,
        system=system_prompt,
        messages=payload,
        max_tokens=max_tokens or 4000,
        tools=tools,
    ) as stream:
        async for event in stream:
            kind = event.type
            if kind == "text":
                yield event.text
            elif kind == "content_block_stop":
                block = event.content_block
                if block.type == "tool_use":
                    pending_calls.append(
                        AnthropicToolCall(
                            id=block.id,
                            type=block.type,
                            name=block.name,
                            input=block.input,
                        )
                    )

    if not pending_calls:
        return

    tool_results = await self.tool_calls_handler.handle_tool_calls_anthropic(
        pending_calls
    )
    assistant_turn = AnthropicAssistantMessage(
        role="assistant",
        content=[call.model_dump() for call in pending_calls],
    )
    results_turn = AnthropicUserMessage(
        role="user",
        content=[result.model_dump() for result in tool_results],
    )

    async for piece in self._stream_anthropic(
        model=model,
        messages=[*messages, assistant_turn, results_turn],
        max_tokens=max_tokens,
        tools=tools,
        depth=depth + 1,
    ):
        yield piece
|
|
|
|
def _stream_ollama(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    depth: int = 0,
):
    """Stream from Ollama by delegating to ``_stream_openai``.

    Only ``model``, ``messages``, ``max_tokens`` and ``depth`` are forwarded;
    no tools or extra body are passed.
    """
    forwarded = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "depth": depth,
    }
    return self._stream_openai(**forwarded)
|
|
|
|
def _stream_custom(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    depth: int = 0,
):
    """Stream from a custom provider by delegating to ``_stream_openai``.

    When ``self.disable_thinking()`` is truthy, ``{"enable_thinking": False}``
    is sent as ``extra_body``; otherwise ``extra_body`` is ``None``.
    """
    thinking_override = None
    if self.disable_thinking():
        thinking_override = {"enable_thinking": False}

    return self._stream_openai(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        extra_body=thinking_override,
        depth=depth,
    )
|
|
|
|
def stream(
    self,
    model: str,
    messages: List[LLMMessage],
    max_tokens: Optional[int] = None,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
    """Return the provider-specific text stream for ``self.llm_provider``.

    ``tools`` are parsed through ``self.tool_calls_handler.parse_tools`` and
    forwarded to the OpenAI, Google and Anthropic streams. The Ollama and
    custom streams receive no tools. When the provider matches none of the
    handled values, ``None`` is returned.
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)

    provider = self.llm_provider
    request = {"model": model, "messages": messages, "max_tokens": max_tokens}

    if provider == LLMProvider.OPENAI:
        return self._stream_openai(**request, tools=parsed_tools)
    if provider == LLMProvider.GOOGLE:
        return self._stream_google(**request, tools=parsed_tools)
    if provider == LLMProvider.ANTHROPIC:
        return self._stream_anthropic(**request, tools=parsed_tools)
    if provider == LLMProvider.OLLAMA:
        return self._stream_ollama(**request)
    if provider == LLMProvider.CUSTOM:
        return self._stream_custom(**request)
    return None
|
|
|
|
# ? Stream Structured Content
|
|
async def _stream_openai_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    extra_body: Optional[dict] = None,
    depth: int = 0,
) -> AsyncGenerator[str, None]:
    """Stream a JSON response matching ``response_format`` via chat completions.

    Two delivery modes exist, chosen by
    ``self.use_tool_calls_for_structured_output()``:

    * falsy: the schema is sent as a ``json_schema`` ``response_format`` and
      content deltas are yielded;
    * truthy: a synthetic ``ResponseSchema`` tool is appended to the tools and
      the argument fragments of that tool call are yielded instead.

    Other tool calls are accumulated from the stream. If no ``ResponseSchema``
    call was seen, they are run through
    ``self.tool_calls_handler.handle_tool_calls_openai`` and the method
    re-invokes itself with the results appended.

    Args:
        model: Model name forwarded to the API call.
        messages: Conversation so far; each is serialised with ``model_dump``.
        response_format: JSON schema of the expected response.
        strict: When true (and ``depth == 0``) the schema is passed through
            ``ensure_strict_json_schema``; also forwarded as the ``strict``
            flag of the schema/tool.
        max_tokens: Forwarded as ``max_completion_tokens``.
        tools: Already-parsed tool definitions.
        extra_body: Forwarded unchanged to the API call.
        depth: Follow-up counter; the schema/tool setup runs only at 0.

    Yields:
        Fragments of the JSON response text.
    """
    client: AsyncOpenAI = self._client

    response_schema = response_format
    # Copy so that appending the ResponseSchema tool never mutates the caller's list.
    all_tools = [*tools] if tools else None

    use_tool_calls_for_structured_output = (
        self.use_tool_calls_for_structured_output()
    )
    # On follow-up calls (depth > 0) the schema passed in is the one already
    # processed at depth 0, so it is not processed again.
    if strict and depth == 0:
        response_schema = ensure_strict_json_schema(
            response_schema,
            path=(),
            root=response_schema,
        )

    # Likewise the ResponseSchema tool is added once; follow-up calls receive
    # it through ``tools`` because ``all_tools`` is what gets forwarded.
    if use_tool_calls_for_structured_output and depth == 0:
        if all_tools is None:
            all_tools = []
        all_tools.append(
            self.tool_calls_handler.parse_tool(
                LLMDynamicTool(
                    name="ResponseSchema",
                    description="Provide response to the user",
                    parameters=response_schema,
                    handler=do_nothing_async,
                ),
                strict=strict,
            )
        )

    # State for the tool call currently being assembled from streamed deltas.
    tool_calls: List[LLMToolCall] = []
    current_index = 0
    current_id = None
    current_name = None
    current_arguments = None

    has_response_schema_tool_call = False
    async for event in await client.chat.completions.create(
        model=model,
        messages=[message.model_dump() for message in messages],
        max_completion_tokens=max_tokens,
        tools=all_tools,
        # The json_schema response format is used only when structured output
        # is NOT being delivered through the ResponseSchema tool call.
        response_format=(
            {
                "type": "json_schema",
                "json_schema": (
                    {
                        "name": "ResponseSchema",
                        "strict": strict,
                        "schema": response_schema,
                    }
                ),
            }
            if not use_tool_calls_for_structured_output
            else None
        ),
        extra_body=extra_body,
        stream=True,
    ):
        event: OpenAIChatCompletionChunk = event
        if not event.choices:
            continue

        content_chunk = event.choices[0].delta.content
        if content_chunk and not use_tool_calls_for_structured_output:
            yield content_chunk

        tool_call_chunk = event.choices[0].delta.tool_calls
        if tool_call_chunk:
            # Only the first entry of the delta's tool_calls list is inspected.
            tool_index = tool_call_chunk[0].index
            tool_id = tool_call_chunk[0].id
            tool_name = tool_call_chunk[0].function.name
            tool_arguments = tool_call_chunk[0].function.arguments

            if current_index != tool_index:
                # A new index means the previous call is complete: store it
                # and start accumulating the new one.
                tool_calls.append(
                    OpenAIToolCall(
                        id=current_id,
                        type="function",
                        function=OpenAIToolCallFunction(
                            name=current_name,
                            arguments=current_arguments,
                        ),
                    )
                )
                current_index = tool_index
                current_id = tool_id
                current_name = tool_name
                current_arguments = tool_arguments
            else:
                # Same call: keep the known id/name when the delta omits them
                # and concatenate the argument fragments.
                current_name = tool_name or current_name
                current_id = tool_id or current_id
                if current_arguments is None:
                    current_arguments = tool_arguments
                elif tool_arguments:
                    current_arguments += tool_arguments

            # Argument fragments of the ResponseSchema call ARE the response.
            if current_name == "ResponseSchema":
                if tool_arguments:
                    yield tool_arguments
                has_response_schema_tool_call = True

    # Store the call that was still being assembled when the stream ended.
    if current_id is not None:
        tool_calls.append(
            OpenAIToolCall(
                id=current_id,
                type="function",
                function=OpenAIToolCallFunction(
                    name=current_name,
                    arguments=current_arguments,
                ),
            )
        )

    # Tools are executed only when the model has not produced the response
    # yet; once ResponseSchema was called the collected calls are not run.
    if tool_calls and not has_response_schema_tool_call:
        tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
            tool_calls
        )
        new_messages = [
            *messages,
            OpenAIAssistantMessage(
                role="assistant",
                content=None,
                tool_calls=[each.model_dump() for each in tool_calls],
            ),
            *tool_call_messages,
        ]
        async for event in self._stream_openai_structured(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            strict=strict,
            tools=all_tools,
            response_format=response_schema,
            extra_body=extra_body,
            depth=depth + 1,
        ):
            yield event
|
|
|
|
async def _stream_google_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    max_tokens: Optional[int] = None,
    tools: Optional[List[dict]] = None,
    depth: int = 0,
) -> AsyncGenerator[str, None]:
    """Stream a JSON response matching ``response_format`` from Google GenAI.

    Without ``tools`` the schema is sent as ``response_json_schema`` with an
    ``application/json`` mime type and text parts are yielded. With ``tools``
    a ``ResponseSchema`` function declaration is added, function calling is
    set to mode ``ANY``, and the arguments of a ``ResponseSchema`` call are
    yielded as a JSON string.

    If function calls were made but none of them was ``ResponseSchema``, they
    are passed to ``self.tool_calls_handler.handle_tool_calls_google`` and the
    method re-invokes itself with the results appended.

    Args:
        model: Model name forwarded to ``generate_content_stream``.
        messages: Conversation so far.
        response_format: JSON schema of the expected response.
        max_tokens: Forwarded as ``max_output_tokens``.
        tools: Function declarations; one ``GoogleTool`` is built per entry.
        depth: Follow-up counter; incremented on each re-invocation and not
            otherwise inspected here.

    Yields:
        Fragments of the JSON response text.
    """
    client: genai.Client = self._client

    declared_tools = None
    if tools:
        declared_tools = [
            GoogleTool(function_declarations=[spec]) for spec in tools
        ]
        response_tool_schema = remove_titles_from_schema(
            flatten_json_schema(response_format)
        )
        declared_tools.append(
            GoogleTool(
                function_declarations=[
                    {
                        "name": "ResponseSchema",
                        "description": "Provide response to the user",
                        "parameters": response_tool_schema,
                    }
                ]
            )
        )

    contents = self._get_google_messages(messages)
    open_stream = iterator_to_async(client.models.generate_content_stream)

    # With tools the answer must arrive as a function call; without tools it
    # arrives as JSON text constrained by the schema.
    if tools:
        forced_tool_use = GoogleToolConfig(
            function_calling_config=GoogleFunctionCallingConfig(
                mode=GoogleFunctionCallingConfigMode.ANY,
            ),
        )
        json_mime_type = None
        json_schema = None
    else:
        forced_tool_use = None
        json_mime_type = "application/json"
        json_schema = response_format

    config = GenerateContentConfig(
        tools=declared_tools,
        tool_config=forced_tool_use,
        system_instruction=self._get_system_prompt(messages),
        response_mime_type=json_mime_type,
        response_json_schema=json_schema,
        max_output_tokens=max_tokens,
    )

    model_turns = []
    pending_calls: List[GoogleToolCall] = []
    answered_via_schema = False
    async for chunk in open_stream(model=model, contents=contents, config=config):
        candidates = chunk.candidates
        if not candidates:
            continue
        content = candidates[0].content
        if not content or not content.parts:
            continue

        model_turns.append(content)

        for part in content.parts:
            if part.text and not declared_tools:
                yield part.text

            call = part.function_call
            if not call:
                continue

            if call.name == "ResponseSchema":
                answered_via_schema = True
                if call.args:
                    yield json.dumps(call.args)

            pending_calls.append(
                GoogleToolCall(
                    id=call.id,
                    name=call.name,
                    arguments=call.args,
                )
            )

    # Nothing to execute, or the model already delivered the response.
    if answered_via_schema or not pending_calls:
        return

    tool_results = await self.tool_calls_handler.handle_tool_calls_google(
        pending_calls
    )
    follow_up = [*messages]
    follow_up.extend(
        GoogleAssistantMessage(role="assistant", content=turn)
        for turn in model_turns
    )
    follow_up.extend(tool_results)

    async for piece in self._stream_google_structured(
        model=model,
        messages=follow_up,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        depth=depth + 1,
    ):
        yield piece
|
|
|
|
async def _stream_anthropic_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
    depth: int = 0,
) -> AsyncGenerator[str, None]:
    """Stream a JSON response matching ``response_format`` from Anthropic.

    The schema is exposed to the model as a ``ResponseSchema`` tool placed
    ahead of the caller's ``tools``. The partial JSON of a ``ResponseSchema``
    tool call is yielded as it streams in. If tool calls were made but none
    of them was ``ResponseSchema``, they are passed to
    ``self.tool_calls_handler.handle_tool_calls_anthropic`` and the method
    re-invokes itself with the results appended.

    Args:
        model: Model name forwarded to ``client.messages.stream``.
        messages: Conversation so far.
        response_format: JSON schema used as the tool's ``input_schema``.
        tools: Additional tool definitions.
        max_tokens: Output limit; 4000 is used when this is falsy.
        depth: Follow-up counter; incremented on each re-invocation and not
            otherwise inspected here.

    Yields:
        Fragments of the JSON response text.
    """
    client: AsyncAnthropic = self._client

    tool_calls: List[AnthropicToolCall] = []
    has_response_schema_tool_call = False
    async with client.messages.stream(
        model=model,
        system=self._get_system_prompt(messages),
        messages=[
            message.model_dump()
            for message in self._get_anthropic_messages(messages)
        ],
        max_tokens=max_tokens or 4000,
        tools=[
            {
                "name": "ResponseSchema",
                "description": "A response to the user's message",
                "input_schema": response_format,
            },
            *(tools or []),
        ],
    ) as stream:
        # True only while the tool_use block currently being streamed is the
        # ResponseSchema call. It is re-evaluated at every tool_use block
        # start and cleared at block stop, so JSON deltas that belong to a
        # different tool call in the same response are never yielded as part
        # of the structured output.
        is_response_schema_tool_call_started = False
        async for event in stream:
            event: AnthropicMessageStreamEvent = event

            if (
                event.type == "content_block_start"
                and event.content_block.type == "tool_use"
            ):
                is_response_schema_tool_call_started = (
                    event.content_block.name == "ResponseSchema"
                )
                if is_response_schema_tool_call_started:
                    has_response_schema_tool_call = True

            if (
                event.type == "content_block_delta"
                and event.delta.type == "input_json_delta"
                and is_response_schema_tool_call_started
            ):
                yield event.delta.partial_json

            if (
                event.type == "content_block_stop"
                and event.content_block.type == "tool_use"
            ):
                is_response_schema_tool_call_started = False
                tool_calls.append(
                    AnthropicToolCall(
                        id=event.content_block.id,
                        type=event.content_block.type,
                        name=event.content_block.name,
                        input=event.content_block.input,
                    )
                )

    # Tools are executed only when the model has not produced the response
    # yet; once ResponseSchema was called the collected calls are not run.
    if tool_calls and not has_response_schema_tool_call:
        tool_call_messages = (
            await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
        )
        new_messages = [
            *messages,
            AnthropicAssistantMessage(
                role="assistant",
                content=[each.model_dump() for each in tool_calls],
            ),
            AnthropicUserMessage(
                role="user",
                content=[each.model_dump() for each in tool_call_messages],
            ),
        ]
        async for event in self._stream_anthropic_structured(
            model=model,
            messages=new_messages,
            max_tokens=max_tokens,
            response_format=response_format,
            tools=tools,
            depth=depth + 1,
        ):
            yield event
|
|
|
|
def _stream_ollama_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    max_tokens: Optional[int] = None,
    depth: int = 0,
):
    """Stream structured output from Ollama via ``_stream_openai_structured``.

    All arguments are forwarded by keyword; no tools or extra body are passed.
    """
    forwarded = {
        "model": model,
        "messages": messages,
        "response_format": response_format,
        "strict": strict,
        "max_tokens": max_tokens,
        "depth": depth,
    }
    return self._stream_openai_structured(**forwarded)
|
|
|
|
def _stream_custom_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    max_tokens: Optional[int] = None,
    depth: int = 0,
):
    """Stream structured output from a custom provider.

    Delegates to ``_stream_openai_structured``. When
    ``self.disable_thinking()`` is truthy, ``{"enable_thinking": False}`` is
    sent as ``extra_body``; otherwise ``extra_body`` is ``None``.
    """
    thinking_override = None
    if self.disable_thinking():
        thinking_override = {"enable_thinking": False}

    return self._stream_openai_structured(
        model=model,
        messages=messages,
        response_format=response_format,
        strict=strict,
        max_tokens=max_tokens,
        extra_body=thinking_override,
        depth=depth,
    )
|
|
|
|
def stream_structured(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
    max_tokens: Optional[int] = None,
):
    """Return the provider-specific structured stream for ``self.llm_provider``.

    ``tools`` are parsed through ``self.tool_calls_handler.parse_tools`` and
    forwarded to the OpenAI, Google and Anthropic streams. ``strict`` is
    forwarded to the OpenAI, Ollama and custom streams. When the provider
    matches none of the handled values, ``None`` is returned.
    """
    parsed_tools = self.tool_calls_handler.parse_tools(tools)

    provider = self.llm_provider
    request = {
        "model": model,
        "messages": messages,
        "response_format": response_format,
        "max_tokens": max_tokens,
    }

    if provider == LLMProvider.OPENAI:
        return self._stream_openai_structured(
            **request, strict=strict, tools=parsed_tools
        )
    if provider == LLMProvider.GOOGLE:
        return self._stream_google_structured(**request, tools=parsed_tools)
    if provider == LLMProvider.ANTHROPIC:
        return self._stream_anthropic_structured(**request, tools=parsed_tools)
    if provider == LLMProvider.OLLAMA:
        return self._stream_ollama_structured(**request, strict=strict)
    if provider == LLMProvider.CUSTOM:
        return self._stream_custom_structured(**request, strict=strict)
    return None
|
|
|
|
# ? Web search
|
|
async def _search_openai(self, query: str) -> str:
    """Send ``query`` to the OpenAI Responses API with web search enabled.

    The model name comes from ``get_model()`` and the request carries a
    single ``web_search_preview`` tool.

    Returns:
        The response's ``output_text``.
    """
    client: AsyncOpenAI = self._client

    web_search_tool = {
        "type": "web_search_preview",
    }
    completed = await client.responses.create(
        model=get_model(),
        tools=[web_search_tool],
        input=query,
    )
    return completed.output_text
|
|
|
|
async def _search_google(self, query: str) -> str:
    """Send ``query`` to Google GenAI with the Google Search tool attached.

    The synchronous ``generate_content`` call is run through
    ``asyncio.to_thread``; the model name comes from ``get_model()``.

    Returns:
        The response's ``text``.
    """
    client: genai.Client = self._client

    search_config = GenerateContentConfig(
        tools=[GoogleTool(google_search=GoogleSearch())]
    )
    call_kwargs = {
        "model": get_model(),
        "contents": query,
        "config": search_config,
    }
    grounded = await asyncio.to_thread(
        client.models.generate_content, **call_kwargs
    )
    return grounded.text
|
|
|
|
async def _search_anthropic(self, query: str) -> str:
    """Send ``query`` to Anthropic with the web search tool attached.

    The request uses the model from ``get_model()``, a 4000-token output
    limit and a ``web_search_20250305`` tool with ``max_uses`` set to 1.

    Returns:
        The ``text`` of every text content block, joined by newlines.
    """
    client: AsyncAnthropic = self._client

    web_search_tool = {
        "type": "web_search_20250305",
        "name": "web_search",
        "max_uses": 1,
    }
    reply = await client.messages.create(
        model=get_model(),
        max_tokens=4000,
        messages=[{"role": "user", "content": query}],
        tools=[web_search_tool],
    )
    return "\n".join(
        block.text for block in reply.content if block.type == "text"
    )
|