diff --git a/README.md b/README.md index 72a53771..a7851cc4 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep - **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom** - **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output. - **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled. +- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results. You can also set the following environment variables to customize the image generation provider and API keys: diff --git a/docker-compose.yml b/docker-compose.yml index 6a0dfbc2..6ea2643a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} production-gpu: @@ -60,6 +63,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development: @@ -87,6 +93,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development-gpu: @@ -121,4 +130,7 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} \ No newline at end of file diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index adde8669..0dafa3e1 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -64,7 +64,7 @@ async def pull_model( # If the model is being pulled, return the model if saved_model_status: # If the model is being pulled, return the model - # ? If the model status is pulled in redis but was not found while listing pulled models, + # ? If the model status is pulled in database but was not found while listing pulled models, # ? it means the model was deleted and we need to pull it again if ( saved_model_status["status"] == "error" diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py b/servers/fastapi/api/v1/ppt/endpoints/outlines.py index f1eff7ba..bf5b1489 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py +++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py @@ -72,6 +72,9 @@ async def stream_outlines( presentation_outlines_json = json.loads(presentation_outlines_text) except Exception as e: print(e) + with open("./debug/outlines.txt", "w") as f: + f.write(presentation_outlines_text) + print(presentation_outlines_text) raise HTTPException( status_code=400, detail="Failed to generate presentation outlines. Please try again.", diff --git a/servers/fastapi/constants/llm.py b/servers/fastapi/constants/llm.py index ac4bd527..7d374f30 100644 --- a/servers/fastapi/constants/llm.py +++ b/servers/fastapi/constants/llm.py @@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1" # Default models DEFAULT_OPENAI_MODEL = "gpt-4.1" -DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash" -DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620" +DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash" +DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514" diff --git a/servers/fastapi/enums/llm_call_type.py b/servers/fastapi/enums/llm_call_type.py new file mode 100644 index 00000000..e37fe4ae --- /dev/null +++ b/servers/fastapi/enums/llm_call_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LLMCallType(Enum): + UNSTRUCTURED = "unstructured" + UNSTRUCTURED_STREAM = "unstructured_stream" + STRUCTURED = "structured" + STRUCTURED_STREAM = "structured_stream" diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 51284173..db741ca4 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -1,7 +1,58 @@ -from typing import Literal +from typing import Any, List, Literal, Optional from pydantic import BaseModel +from google.genai.types import Content as GoogleContent + +from models.llm_tool_call import AnthropicToolCall class LLMMessage(BaseModel): - role: Literal["user", "system"] + pass + + +class LLMUserMessage(LLMMessage): + role: Literal["user"] = "user" content: str + + +class LLMSystemMessage(LLMMessage): + role: Literal["system"] = "system" + content: str + + +class OpenAIAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: str | None = None + tool_calls: Optional[List[dict]] = None + + +class GoogleAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: GoogleContent + + +class AnthropicAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: List[AnthropicToolCall] + + +class AnthropicToolCallMessage(LLMMessage): + type: Literal["tool_result"] = "tool_result" + tool_use_id: str + content: str + + +class AnthropicUserMessage(LLMMessage): + role: Literal["user"] = "user" + content: List[AnthropicToolCallMessage] + + +class OpenAIToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + content: str + tool_call_id: str + + +class GoogleToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + name: str + response: dict diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py new file mode 100644 index 00000000..5eb1f008 --- /dev/null +++ b/servers/fastapi/models/llm_tool_call.py @@ -0,0 +1,29 @@ +from typing import Literal, Optional +from pydantic import BaseModel + + +class LLMToolCall(BaseModel): + pass + + +class OpenAIToolCallFunction(BaseModel): + name: str + arguments: str + + +class OpenAIToolCall(LLMToolCall): + id: str + type: Literal["function"] = "function" + function: OpenAIToolCallFunction + + +class GoogleToolCall(LLMToolCall): + name: str + arguments: Optional[dict] = None + + +class AnthropicToolCall(LLMToolCall): + type: Literal["tool_use"] = "tool_use" + id: str + name: str + input: object diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py new file mode 100644 index 00000000..ccf64e67 --- /dev/null +++ b/servers/fastapi/models/llm_tools.py @@ -0,0 +1,29 @@ +from typing import Any, Callable, Coroutine, Optional +from pydantic import BaseModel, Field + + +class LLMTool(BaseModel): + pass + + +class LLMDynamicTool(LLMTool): + name: str + description: str + parameters: dict = {} + handler: Callable[..., Coroutine[Any, Any, str]] + + +class SearchWebTool(LLMTool): + """ + Search the web for information. + """ + + query: str = Field(description="The query to search the web for") + + +class GetCurrentDatetimeTool(LLMTool): + """ + Get the current datetime. + """ + + pass diff --git a/servers/fastapi/models/user_config.py b/servers/fastapi/models/user_config.py index 50783544..c040d22c 100644 --- a/servers/fastapi/models/user_config.py +++ b/servers/fastapi/models/user_config.py @@ -35,3 +35,6 @@ class UserConfig(BaseModel): TOOL_CALLS: Optional[bool] = None DISABLE_THINKING: Optional[bool] = None EXTENDED_REASONING: Optional[bool] = None + + # Web Search + WEB_GROUNDING: Optional[bool] = None diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml index f0caf8e3..14240244 100644 --- a/servers/fastapi/pyproject.toml +++ b/servers/fastapi/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "openai>=1.98.0", "pathvalidate>=3.3.1", "pdfplumber>=0.11.7", + "pytest>=8.4.1", "python-pptx>=1.0.2", "redis>=6.2.0", "sqlmodel>=0.0.24", diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index eaf61770..3e8b35f2 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -1,16 +1,40 @@ import asyncio import json -from typing import List, Optional +from typing import AsyncGenerator, List, Optional from fastapi import HTTPException from openai import AsyncOpenAI +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) from google import genai -from google.genai.types import GenerateContentConfig +from google.genai.types import Content as GoogleContent, Part as GoogleContentPart +from google.genai.types import GenerateContentConfig, GoogleSearch +from google.genai.types import Tool as GoogleTool from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider -from models.llm_message import LLMMessage +from models.llm_message import ( + AnthropicAssistantMessage, + AnthropicUserMessage, + GoogleAssistantMessage, + GoogleToolCallMessage, + OpenAIAssistantMessage, + LLMMessage, + LLMSystemMessage, + LLMUserMessage, +) +from models.llm_tool_call import ( + AnthropicToolCall, + GoogleToolCall, + LLMToolCall, + OpenAIToolCall, + OpenAIToolCallFunction, +) +from models.llm_tools import LLMDynamicTool, LLMTool +from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async +from utils.dummy_functions import do_nothing_async from utils.get_env import ( get_anthropic_api_key_env, get_custom_llm_api_key_env, @@ -20,8 +44,9 @@ from utils.get_env import ( get_ollama_url_env, get_openai_api_key_env, get_tool_calls_env, + get_web_grounding_env, ) -from utils.llm_provider import get_llm_provider +from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none from utils.schema_utils import ensure_strict_json_schema @@ -30,13 +55,23 @@ class LLMClient: def __init__(self): self.llm_provider = get_llm_provider() self._client = self._get_client() + self.tool_calls_handler = LLMToolCallsHandler(self) # ? Use tool calls - def use_tool_calls(self) -> bool: + def use_tool_calls_for_structured_output(self) -> bool: if self.llm_provider != LLMProvider.CUSTOM: return False return parse_bool_or_none(get_tool_calls_env()) or False + # ? Web Grounding + def enable_web_grounding(self) -> bool: + if ( + self.llm_provider == LLMProvider.OLLAMA + or self.llm_provider == LLMProvider.CUSTOM + ): + return False + return parse_bool_or_none(get_web_grounding_env()) or False + # ? Disable thinking def disable_thinking(self) -> bool: return parse_bool_or_none(get_disable_thinking_env()) or False @@ -104,15 +139,39 @@ class LLMClient: # ? Prompts def _get_system_prompt(self, messages: List[LLMMessage]) -> str: for message in messages: - if message.role == "system": + if isinstance(message, LLMSystemMessage): return message.content return "" - def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]: - return [message.content for message in messages if message.role == "user"] + def _get_google_messages(self, messages: List[LLMMessage]) -> List[str]: + contents = [] + for message in messages: + if isinstance(message, LLMUserMessage): + contents.append( + GoogleContent( + role="user", parts=[GoogleContentPart(text=message.content)] + ) + ) + elif isinstance(message, GoogleAssistantMessage): + contents.append(message.content) + elif isinstance(message, GoogleToolCallMessage): + contents.append( + GoogleContent( + role="user", + parts=[ + GoogleContentPart.from_function_response( + name=message.name, response=message.response + ) + ], + ) + ) - def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if message.role == "user"] + return contents + + def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + return [ + message for message in messages if not isinstance(message, LLMSystemMessage) + ] # ? Generate Unstructured Content async def _generate_openai( @@ -120,71 +179,208 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> str | None: client: AsyncOpenAI = self._client response = await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, + tools=tools, extra_body=extra_body, ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + assistant_message = OpenAIAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls], + ) + new_messages = [ + *messages, + assistant_message, + *tool_call_messages, + ] + return await self._generate_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ) + return response.choices[0].message.content async def _generate_google( self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> str | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", max_output_tokens=max_tokens, ), ) - return response.text + + content = response.candidates[0].content + response_parts = content.parts + + if not response_parts: + return None + + text_content = None + tool_calls = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + if each_part.text: + text_content = each_part.text + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_anthropic( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> str | None: client: AsyncAnthropic = self._client + response: AnthropicMessage = await client.messages.create( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], + tools=tools, max_tokens=max_tokens or 4000, ) - text = "" + text_content = None + tool_calls: List[AnthropicToolCall] = [] for content in response.content: if content.type == "text" and isinstance(content.text, str): - text += content.text - if text == "": - return None - return text + text_content = content.text + + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_ollama( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): - return await self._generate_openai(model, messages, max_tokens) + return await self._generate_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) async def _generate_custom( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} return await self._generate_openai( - model, messages, max_tokens, extra_body=extra_body + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) async def generate( @@ -192,19 +388,41 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: - content = await self._generate_openai(model, messages, max_tokens) + content = await self._generate_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - content = await self._generate_google(model, messages, max_tokens) + content = await self._generate_google( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.ANTHROPIC: - content = await self._generate_anthropic(model, messages, max_tokens) + content = await self._generate_anthropic( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.OLLAMA: - content = await self._generate_ollama(model, messages, max_tokens) + content = await self._generate_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - content = await self._generate_custom(model, messages, max_tokens) + content = await self._generate_custom( + model=model, messages=messages, max_tokens=max_tokens + ) if content is None: raise HTTPException( status_code=400, @@ -220,22 +438,43 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> dict | None: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - response_format={ + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + parameters=response_schema, + handler=do_nothing_async, + ), + strict=strict, + ) + ) + + response = await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + response_format=( + { "type": "json_schema", "json_schema": ( { @@ -244,36 +483,66 @@ class LLMClient: "schema": response_schema, } ), - }, - max_completion_tokens=max_tokens, - extra_body=extra_body, - ) - content = response.choices[0].message.content - else: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", - max_completion_tokens=max_tokens, - extra_body=extra_body, - ) - tool_calls = response.choices[0].message.tool_calls - if tool_calls: - content = tool_calls[0].function.arguments + } + if not use_tool_calls_for_structured_output + else None + ), + max_completion_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + ) + content = response.choices[0].message.content + + tool_calls = response.choices[0].message.tool_calls + has_response_schema = False + + if tool_calls: + for tool_call in tool_calls: + if tool_call.function.name == "ResponseSchema": + content = tool_call.function.arguments + has_response_schema = True + + if not has_response_schema: + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[each.model_dump() for each in parsed_tool_calls], + ), + *tool_call_messages, + ] + content = await self._generate_openai_structured( + model=model, + messages=new_messages, + response_format=response_schema, + strict=strict, + max_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + depth=depth + 1, + ) if content: - return json.loads(content) + if depth == 0: + return json.loads(content) + return content return None async def _generate_google_structured( @@ -282,31 +551,96 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> dict | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters_json_schema": response_format, + } + ] + ) + ) + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) - content = None - if response.text: - content = json.loads(response.text) - return content + content = response.candidates[0].content + response_parts = content.parts + text_content = None + + if not response_parts: + return None + + tool_calls: List[GoogleToolCall] = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + + if each_part.text: + text_content = each_part.text + + for each in tool_calls: + if each.name == "ResponseSchema": + return each.arguments + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + if text_content: + return json.loads(text_content) + return None async def _generate_anthropic_structured( self, model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client response: AnthropicMessage = await client.messages.create( @@ -314,7 +648,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -322,19 +656,51 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) - content: dict | None = None - for content_block in response.content: - if content_block.type == "tool_use": - content = content_block.input + tool_calls: List[AnthropicToolCall] = [] + for content in response.content: + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) - return content + for each in tool_calls: + if each.name == "ResponseSchema": + return each.input + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + return None async def _generate_ollama_structured( self, @@ -343,9 +709,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) async def _generate_custom_structured( @@ -355,10 +727,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens, extra_body + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) async def generate_structured( @@ -367,29 +746,53 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ) -> dict: + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: content = await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: content = await self._generate_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: content = await self._generate_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: content = await self._generate_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) if content is None: raise HTTPException( @@ -404,90 +807,285 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - async with client.chat.completions.stream( + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + async for event in await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, + tools=tools, extra_body=extra_body, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], + ), + *tool_call_messages, + ] + async for event in self._stream_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ): + yield event async def _stream_google( self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + + tool_calls = None async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", + tools=google_tools, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client + async with client.messages.stream( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, + tools=tools, ) as stream: + tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "text" and isinstance(event.text, str): + + if event.type == "text": yield event.text + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + def _stream_ollama( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) def _stream_custom( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} - return self._stream_openai(model, messages, max_tokens, extra_body) + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, + ) def stream( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - return self._stream_google(model, messages, max_tokens) + return self._stream_google( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.ANTHROPIC: - return self._stream_anthropic(model, messages, max_tokens) + return self._stream_anthropic( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.OLLAMA: - return self._stream_ollama(model, messages, max_tokens) + return self._stream_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - return self._stream_custom(model, messages, max_tokens) + return self._stream_custom( + model=model, messages=messages, max_tokens=max_tokens + ) # ? Stream Structured Content async def _stream_openai_structured( @@ -497,59 +1095,145 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() + response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - async with client.chat.completions.stream( - model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - response_format=( - { - "type": "json_schema", - "json_schema": { + + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + parameters=response_schema, + handler=do_nothing_async, + ), + strict=strict, + ) + ) + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + + has_response_schema_tool_call = False + async for event in await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + max_completion_tokens=max_tokens, + tools=all_tools, + response_format=( + { + "type": "json_schema", + "json_schema": ( + { "name": "ResponseSchema", "strict": strict, "schema": response_schema, - }, - } + } + ), + } + if not use_tool_calls_for_structured_output + else None + ), + extra_body=extra_body, + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_name == "ResponseSchema": + if tool_arguments: + yield tool_arguments + has_response_schema_tool_call = True + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls and not has_response_schema_tool_call: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], ), - extra_body=extra_body, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta - else: - async with client.chat.completions.stream( + *tool_call_messages, + ] + async for event in self._stream_openai_structured( model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", + messages=new_messages, + max_tokens=max_tokens, + strict=strict, + tools=all_tools, + response_format=response_schema, extra_body=extra_body, - ) as stream: - async for event in stream: - if event.type == "tool_calls.function.arguments.delta": - yield event.arguments_delta + depth=depth + 1, + ): + yield event async def _stream_google_structured( self, @@ -557,35 +1241,99 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: + client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters_json_schema": response_format, + } + ] + ) + ) + + tool_calls: List[GoogleToolCall] = [] + has_response_schema_tool_call = False async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + for each in tool_calls: + if each.name == "ResponseSchema": + has_response_schema_tool_call = True + if each.arguments: + yield json.dumps(each.arguments) + + if has_response_schema_tool_call: + continue + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic_structured( self, model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncAnthropic = self._client async with client.messages.stream( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -593,17 +1341,72 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) as stream: + tool_calls: List[AnthropicToolCall] = [] + has_response_schema_tool_call = False + is_response_schema_tool_call_started = False async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json" and isinstance(event.partial_json, str): - yield event.partial_json + if ( + event.type == "content_block_start" + and event.content_block.type == "tool_use" + ): + if event.content_block.name == "ResponseSchema": + has_response_schema_tool_call = True + is_response_schema_tool_call_started = True + + if ( + event.type == "content_block_delta" + and event.delta.type == "input_json_delta" + and is_response_schema_tool_call_started + ): + yield event.delta.partial_json + + if has_response_schema_tool_call: + continue + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama_structured( self, @@ -612,9 +1415,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) def _stream_custom_structured( @@ -624,10 +1433,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens, extra_body + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) def stream_structured( @@ -636,26 +1452,93 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: return self._stream_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: return self._stream_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: return self._stream_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: return self._stream_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) + + # ? Web search + async def _search_openai(self, query: str) -> str: + client: AsyncOpenAI = self._client + response = await client.responses.create( + model=get_model(), + tools=[ + { + "type": "web_search_preview", + } + ], + input=query, + ) + return response.output_text + + async def _search_google(self, query: str) -> str: + client: genai.Client = self._client + grounding_tool = GoogleTool(google_search=GoogleSearch()) + config = GenerateContentConfig(tools=[grounding_tool]) + + response = await asyncio.to_thread( + client.models.generate_content, + model=get_model(), + contents=query, + config=config, + ) + return response.text + + async def _search_anthropic(self, query: str) -> str: + client: AsyncAnthropic = self._client + + response = await client.messages.create( + model=get_model(), + max_tokens=4000, + messages=[{"role": "user", "content": query}], + tools=[ + {"type": "web_search_20250305", "name": "web_search", "max_uses": 1} + ], + ) + result = "\n".join( + [each.text for each in response.content if each.type == "text"] + ) + return result diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py new file mode 100644 index 00000000..ed0d51ee --- /dev/null +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -0,0 +1,201 @@ +import asyncio +from datetime import datetime +import json +from typing import Any, Callable, Coroutine, List, Optional +from fastapi import HTTPException +from enums.llm_provider import LLMProvider +from models.llm_message import ( + AnthropicToolCallMessage, + GoogleToolCallMessage, + OpenAIToolCallMessage, +) +from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall +from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema + + +class LLMToolCallsHandler: + def __init__(self, client): + from services.llm_client import LLMClient + + self.client: LLMClient = client + + self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = { + "SearchWebTool": self.search_web_tool_call_handler, + "GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler, + } + self.dynamic_tools: List[LLMDynamicTool] = [] + + def get_tool_handler( + self, tool_name: str + ) -> Callable[..., Coroutine[Any, Any, str]]: + handler = self.tools_map.get(tool_name) + if handler: + return handler + else: + dynamic_tools = list( + filter(lambda tool: tool.name == tool_name, self.dynamic_tools) + ) + if dynamic_tools: + return dynamic_tools[0].handler + raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found") + + def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): + if tools is None: + return None + parsed_tools = map(self.parse_tool, tools) + return list(parsed_tools) + + def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False): + if isinstance(tool, LLMDynamicTool): + self.dynamic_tools.append(tool) + + match self.client.llm_provider: + case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM: + return self.parse_tool_openai(tool, strict) + case LLMProvider.ANTHROPIC: + return self.parse_tool_anthropic(tool) + case LLMProvider.GOOGLE: + return self.parse_tool_google(tool) + case _: + raise ValueError( + f"LLM provider must be either openai, anthropic, or google" + ) + + def parse_tool_openai( + self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False + ): + if isinstance(tool, LLMDynamicTool): + name = tool.name + description = tool.description + parameters = tool.parameters + else: + name = tool.__name__ + description = tool.__doc__ or "" + parameters = tool.model_json_schema() + + if strict: + parameters = ensure_strict_json_schema(parameters, path=(), root=parameters) + + return { + "type": "function", + "function": { + "name": name, + "description": description, + "strict": strict, + "parameters": parameters, + }, + } + + def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + # parsed["function"]["parameters"] = flatten_json_schema( + # parsed["function"]["parameters"] + # ) + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "parameters": parsed["function"]["parameters"], + } + + def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + input_schema = parsed["function"]["parameters"] + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "input_schema": {"type": "object"} if input_schema == {} else input_schema, + } + + async def handle_tool_calls_openai( + self, + tool_calls: List[OpenAIToolCall], + ) -> List[OpenAIToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments)) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + OpenAIToolCallMessage( + content=result, + tool_call_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_google( + self, + tool_calls: List[GoogleToolCall], + ) -> List[GoogleToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + GoogleToolCallMessage( + name=tool_call.name, + response={"result": result}, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_anthropic( + self, + tool_calls: List[AnthropicToolCall], + ) -> List[AnthropicToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + AnthropicToolCallMessage( + content=result, + tool_use_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + # ? Tool call handlers + # Search web tool call handler + async def search_web_tool_call_handler(self, arguments: str) -> str: + match self.client.llm_provider: + case LLMProvider.OPENAI: + return await self.search_web_tool_call_handler_openai(arguments) + case LLMProvider.ANTHROPIC: + return await self.search_web_tool_call_handler_anthropic(arguments) + case LLMProvider.GOOGLE: + return await self.search_web_tool_call_handler_google(arguments) + case _: + return ( + "Web search tool call handler not implemented for this LLM provider: " + + self.client.llm_provider.value + ) + + async def search_web_tool_call_handler_openai(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_openai(args.query) + + async def search_web_tool_call_handler_google(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_google(args.query) + + async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_anthropic(args.query) + + # Get current datetime tool call handler + async def get_current_datetime_tool_call_handler(self, _) -> str: + current_time = datetime.now() + return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}" diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py deleted file mode 100644 index f2e3d8c9..00000000 --- a/servers/fastapi/services/redis_service.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Optional -import redis -from redis.exceptions import RedisError - -from utils.get_env import ( - get_redis_db_env, - get_redis_host_env, - get_redis_password_env, - get_redis_port_env, -) - - -class RedisService: - def __init__(self): - self.redis_host = get_redis_host_env() or "localhost" - self.redis_port = int(get_redis_port_env() or "6379") - self.redis_db = int(get_redis_db_env() or "0") - self.redis_password = get_redis_password_env() or None - self.client = self._create_client() - - def _create_client(self) -> redis.Redis: - return redis.Redis( - host=self.redis_host, - port=self.redis_port, - db=self.redis_db, - password=self.redis_password, - decode_responses=True, - ) - - def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: - try: - return self.client.set(key, value, ex=expire) - except RedisError: - return False - - def get(self, key: str) -> Optional[str]: - try: - return self.client.get(key) - except RedisError: - return None - - def delete(self, key: str) -> bool: - try: - return bool(self.client.delete(key)) - except RedisError: - return False - - def exists(self, key: str) -> bool: - try: - return bool(self.client.exists(key)) - except RedisError: - return False - - def set_hash(self, name: str, mapping: dict) -> bool: - try: - return self.client.hmset(name, mapping) - except RedisError: - return False - - def get_hash(self, name: str) -> Optional[dict]: - try: - return self.client.hgetall(name) - except RedisError: - return None - - def delete_hash(self, name: str, *fields: str) -> int: - try: - return self.client.hdel(name, *fields) - except RedisError: - return 0 - - def set_list(self, name: str, values: list) -> bool: - try: - self.client.delete(name) - if values: - self.client.rpush(name, *values) - return True - except RedisError: - return False - - def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: - try: - return self.client.lrange(name, start, end) - except RedisError: - return None - - def add_to_set(self, name: str, *values: str) -> int: - try: - return self.client.sadd(name, *values) - except RedisError: - return 0 - - def get_set(self, name: str) -> Optional[set]: - try: - return self.client.smembers(name) - except RedisError: - return None - - def remove_from_set(self, name: str, *values: str) -> int: - try: - return self.client.srem(name, *values) - except RedisError: - return 0 - - def clear(self) -> bool: - try: - return self.client.flushdb() - except RedisError: - return False - - def close(self): - try: - self.client.close() - except RedisError: - pass diff --git a/servers/fastapi/utils/dummy_functions.py b/servers/fastapi/utils/dummy_functions.py new file mode 100644 index 00000000..461e9695 --- /dev/null +++ b/servers/fastapi/utils/dummy_functions.py @@ -0,0 +1,2 @@ +async def do_nothing_async(_): + return None diff --git a/servers/fastapi/utils/get_env.py b/servers/fastapi/utils/get_env.py index c2c72efd..fa80b2a2 100644 --- a/servers/fastapi/utils/get_env.py +++ b/servers/fastapi/utils/get_env.py @@ -81,22 +81,6 @@ def get_pixabay_api_key_env(): return os.getenv("PIXABAY_API_KEY") -def get_redis_host_env(): - return os.getenv("REDIS_HOST") - - -def get_redis_port_env(): - return os.getenv("REDIS_PORT") - - -def get_redis_db_env(): - return os.getenv("REDIS_DB") - - -def get_redis_password_env(): - return os.getenv("REDIS_PASSWORD") - - def get_tool_calls_env(): return os.getenv("TOOL_CALLS") @@ -107,3 +91,7 @@ def get_disable_thinking_env(): def get_extended_reasoning_env(): return os.getenv("EXTENDED_REASONING") + + +def get_web_grounding_env(): + return os.getenv("WEB_GROUNDING") diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py index a8df598a..30599d08 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide.py +++ b/servers/fastapi/utils/llm_calls/edit_slide.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.sql.slide import SlideModel from services.llm_client import LLMClient @@ -41,12 +41,10 @@ def get_messages( language: str, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, slide_data, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/edit_slide_html.py b/servers/fastapi/utils/llm_calls/edit_slide_html.py index a5e2dfad..cf58d185 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide_html.py +++ b/servers/fastapi/utils/llm_calls/edit_slide_html.py @@ -1,5 +1,5 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from services.llm_client import LLMClient from utils.llm_provider import get_model @@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str): response = await client.generate( model=model, messages=[ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=get_user_prompt(prompt, html)), + LLMSystemMessage(content=system_prompt), + LLMUserMessage(content=get_user_prompt(prompt, html)), ], ) return extract_html_from_response(response) or html diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 2a3c9e95..892b9cff 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,9 +1,13 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage +from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides +from utils.get_env import get_web_grounding_env from utils.llm_provider import get_model +from utils.parsers import parse_bool_or_none +from utils.user_config import get_user_config system_prompt = """ You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content. @@ -28,12 +32,10 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, n_slides, language, content), ), ] @@ -50,10 +52,13 @@ async def generate_ppt_outline( client = LLMClient() + tools = [SearchWebTool, GetCurrentDatetimeTool] + async for chunk in client.stream_structured( model, get_messages(prompt, n_slides, language, content), response_model.model_json_schema(), strict=True, + tools=tools if client.enable_web_grounding() else None, ): yield chunk diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py index 47f47dba..1bfc0cd0 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from services.llm_client import LLMClient @@ -11,8 +11,7 @@ def get_messages( presentation_layout: PresentationLayoutModel, n_slides: int, data: str ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" You're a professional presentation designer with creative freedom to design engaging presentations. @@ -47,8 +46,7 @@ def get_messages( Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals. """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" {data} """, diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py index 62b87e2b..be19b168 100644 --- a/servers/fastapi/utils/llm_calls/generate_slide_content.py +++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.presentation_outline_model import SlideOutlineModel from services.llm_client import LLMClient @@ -39,12 +39,10 @@ def get_user_prompt(outline: str, language: str): def get_messages(outline: str, language: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(outline, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py index f3532b48..7235e558 100644 --- a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py +++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel from models.slide_layout_index import SlideLayoutIndex from models.sql.slide import SlideModel @@ -13,8 +13,7 @@ def get_messages( current_slide_layout: int, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" Select a Slide Layout index based on provided user prompt and current slide data. {layout.to_string()} @@ -26,8 +25,7 @@ def get_messages( **Go through all notes and steps and make sure they are followed, including mentioned constraints** """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" - User Prompt: {prompt} - Current Slide Data: {slide_data} diff --git a/servers/fastapi/utils/schema_utils.py b/servers/fastapi/utils/schema_utils.py index ae65f002..6cb01a0e 100644 --- a/servers/fastapi/utils/schema_utils.py +++ b/servers/fastapi/utils/schema_utils.py @@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: return resolved +# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions +def flatten_json_schema(schema: dict) -> dict: + root_schema = deepcopy(schema) + + def _flatten(node: Any) -> Any: + if isinstance(node, dict): + # If node is a pure $ref (or combined with extra fields), inline it + if "$ref" in node: + ref_value = node["$ref"] + assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}" + resolved = resolve_ref(root=root_schema, ref=ref_value) + assert isinstance(resolved, dict), ( + f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}" + ) + # Merge: referenced first, then overlay current (excluding $ref) + merged: dict[str, Any] = deepcopy(resolved) + for key, value in node.items(): + if key == "$ref": + continue + merged[key] = value + return _flatten(merged) + + flattened: dict[str, Any] = {} + for key, value in node.items(): + # Drop defs/definitions in output + if key in ("$defs", "definitions"): + continue + if key == "properties" and isinstance(value, dict): + flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()} + elif key in ("items", "contains", "additionalProperties", "not"): + if isinstance(value, dict): + flattened[key] = _flatten(value) + elif isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = value + elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value + return flattened + if isinstance(node, list): + return [_flatten(v) for v in node] + return node + + result = _flatten(schema) + # Ensure top-level cleanup just in case + if isinstance(result, dict): + result.pop("$defs", None) + result.pop("definitions", None) + return result + + # ? Not used def generate_constraint_sentences(schema: dict) -> str: """ diff --git a/servers/fastapi/utils/set_env.py b/servers/fastapi/utils/set_env.py index 7ac0e335..ea3758f3 100644 --- a/servers/fastapi/utils/set_env.py +++ b/servers/fastapi/utils/set_env.py @@ -79,3 +79,7 @@ def set_disable_thinking_env(value): def set_extended_reasoning_env(value): os.environ["EXTENDED_REASONING"] = value + + +def set_web_grounding_env(value): + os.environ["WEB_GROUNDING"] = value \ No newline at end of file diff --git a/servers/fastapi/utils/user_config.py b/servers/fastapi/utils/user_config.py index 06235d5a..49fd1722 100644 --- a/servers/fastapi/utils/user_config.py +++ b/servers/fastapi/utils/user_config.py @@ -22,6 +22,7 @@ from utils.get_env import ( get_image_provider_env, get_pixabay_api_key_env, get_extended_reasoning_env, + get_web_grounding_env, ) from utils.parsers import parse_bool_or_none from utils.set_env import ( @@ -43,6 +44,7 @@ from utils.set_env import ( set_image_provider_env, set_pixabay_api_key_env, set_tool_calls_env, + set_web_grounding_env, ) @@ -76,12 +78,26 @@ def get_user_config(): IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(), PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(), - TOOL_CALLS=existing_config.TOOL_CALLS - or parse_bool_or_none(get_tool_calls_env()), - DISABLE_THINKING=existing_config.DISABLE_THINKING - or parse_bool_or_none(get_disable_thinking_env()), - EXTENDED_REASONING=existing_config.EXTENDED_REASONING - or parse_bool_or_none(get_extended_reasoning_env()), + TOOL_CALLS=( + existing_config.TOOL_CALLS + if existing_config.TOOL_CALLS is not None + else (parse_bool_or_none(get_tool_calls_env()) or False) + ), + DISABLE_THINKING=( + existing_config.DISABLE_THINKING + if existing_config.DISABLE_THINKING is not None + else (parse_bool_or_none(get_disable_thinking_env()) or False) + ), + EXTENDED_REASONING=( + existing_config.EXTENDED_REASONING + if existing_config.EXTENDED_REASONING is not None + else (parse_bool_or_none(get_extended_reasoning_env()) or False) + ), + WEB_GROUNDING=( + existing_config.WEB_GROUNDING + if existing_config.WEB_GROUNDING is not None + else (parse_bool_or_none(get_web_grounding_env()) or False) + ), ) @@ -122,5 +138,6 @@ def update_env_with_user_config(): if user_config.DISABLE_THINKING: set_disable_thinking_env(str(user_config.DISABLE_THINKING)) if user_config.EXTENDED_REASONING: - if user_config.EXTENDED_REASONING: - set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + if user_config.WEB_GROUNDING: + set_web_grounding_env(str(user_config.WEB_GROUNDING)) diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock index b579f42a..4e19c02b 100644 --- a/servers/fastapi/uv.lock +++ b/servers/fastapi/uv.lock @@ -1061,6 +1061,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "isodate" version = "0.7.2" @@ -1907,6 +1916,7 @@ dependencies = [ { name = "openai" }, { name = "pathvalidate" }, { name = "pdfplumber" }, + { name = "pytest" }, { name = "python-pptx" }, { name = "redis" }, { name = "sqlmodel" }, @@ -1928,6 +1938,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.98.0" }, { name = "pathvalidate", specifier = ">=3.3.1" }, { name = "pdfplumber", specifier = ">=0.11.7" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "python-pptx", specifier = ">=1.0.2" }, { name = "redis", specifier = ">=6.2.0" }, { name = "sqlmodel", specifier = ">=0.0.24" }, @@ -2211,6 +2222,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + [[package]] name = "python-bidi" version = "0.6.6" diff --git a/servers/nextjs/app/api/user-config/route.ts b/servers/nextjs/app/api/user-config/route.ts index ff3c643a..03b801e3 100644 --- a/servers/nextjs/app/api/user-config/route.ts +++ b/servers/nextjs/app/api/user-config/route.ts @@ -57,6 +57,10 @@ export async function POST(request: Request) { userConfig.EXTENDED_REASONING === undefined ? existingConfig.EXTENDED_REASONING : userConfig.EXTENDED_REASONING, + WEB_GROUNDING: + userConfig.WEB_GROUNDING === undefined + ? existingConfig.WEB_GROUNDING + : userConfig.WEB_GROUNDING, USE_CUSTOM_URL: userConfig.USE_CUSTOM_URL === undefined ? existingConfig.USE_CUSTOM_URL diff --git a/servers/nextjs/components/AnthropicConfig.tsx b/servers/nextjs/components/AnthropicConfig.tsx index 567846b5..4b61bb65 100644 --- a/servers/nextjs/components/AnthropicConfig.tsx +++ b/servers/nextjs/components/AnthropicConfig.tsx @@ -19,6 +19,7 @@ interface AnthropicConfigProps { anthropicApiKey: string; anthropicModel: string; extendedReasoning: boolean; + webGrounding?: boolean; onInputChange: (value: string | boolean, field: string) => void; } @@ -27,6 +28,7 @@ export default function AnthropicConfig({ anthropicApiKey, anthropicModel, extendedReasoning, + webGrounding, onInputChange, }: AnthropicConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -65,7 +67,7 @@ export default function AnthropicConfig({ const data = await response.json(); setAvailableModels(data); setModelsChecked(true); - onInputChange("claude-3-5-sonnet-20241022", "anthropic_model"); + onInputChange("claude-sonnet-4-20250514", "anthropic_model"); } else { console.error('Failed to fetch models'); setAvailableModels([]); @@ -226,6 +228,23 @@ export default function AnthropicConfig({ ) : null} + + {/* Web Grounding Toggle - at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/components/GoogleConfig.tsx b/servers/nextjs/components/GoogleConfig.tsx index 6746f779..8d333dd3 100644 --- a/servers/nextjs/components/GoogleConfig.tsx +++ b/servers/nextjs/components/GoogleConfig.tsx @@ -13,16 +13,19 @@ import { import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { cn } from "@/lib/utils"; import { toast } from "sonner"; +import { Switch } from "./ui/switch"; interface GoogleConfigProps { googleApiKey: string; googleModel: string; - onInputChange: (value: string, field: string) => void; + webGrounding?: boolean; + onInputChange: (value: string | boolean, field: string) => void; } export default function GoogleConfig({ googleApiKey, googleModel, + webGrounding, onInputChange }: GoogleConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -61,7 +64,7 @@ export default function GoogleConfig({ const data = await response.json(); setAvailableModels(data); setModelsChecked(true); - onInputChange("models/gemini-2.0-flash", "google_model"); + onInputChange("models/gemini-2.5-flash", "google_model"); } else { console.error('Failed to fetch models'); setAvailableModels([]); @@ -205,6 +208,23 @@ export default function GoogleConfig({ ) : null} + + {/* Web Grounding Toggle - at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/components/LLMSelection.tsx b/servers/nextjs/components/LLMSelection.tsx index 422e8333..ed308226 100644 --- a/servers/nextjs/components/LLMSelection.tsx +++ b/servers/nextjs/components/LLMSelection.tsx @@ -149,6 +149,7 @@ export default function LLMProviderSelection({ @@ -158,6 +159,7 @@ export default function LLMProviderSelection({ @@ -168,6 +170,7 @@ export default function LLMProviderSelection({ anthropicApiKey={llmConfig.ANTHROPIC_API_KEY || ""} anthropicModel={llmConfig.ANTHROPIC_MODEL || ""} extendedReasoning={llmConfig.EXTENDED_REASONING || false} + webGrounding={llmConfig.WEB_GROUNDING || false} onInputChange={input_field_changed} /> diff --git a/servers/nextjs/components/OpenAIConfig.tsx b/servers/nextjs/components/OpenAIConfig.tsx index b73695e9..7c465a99 100644 --- a/servers/nextjs/components/OpenAIConfig.tsx +++ b/servers/nextjs/components/OpenAIConfig.tsx @@ -13,16 +13,19 @@ import { import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { cn } from "@/lib/utils"; import { toast } from "sonner"; +import { Switch } from "./ui/switch"; interface OpenAIConfigProps { openaiApiKey: string; openaiModel: string; - onInputChange: (value: string, field: string) => void; + webGrounding?: boolean; + onInputChange: (value: string | boolean, field: string) => void; } export default function OpenAIConfig({ openaiApiKey, openaiModel, + webGrounding, onInputChange }: OpenAIConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -210,6 +213,23 @@ export default function OpenAIConfig({ ) : null} + + {/* Web Grounding Toggle - show at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/types/llm_config.ts b/servers/nextjs/types/llm_config.ts index 0a44e639..5b73b215 100644 --- a/servers/nextjs/types/llm_config.ts +++ b/servers/nextjs/types/llm_config.ts @@ -31,6 +31,7 @@ export interface LLMConfig { TOOL_CALLS?: boolean; DISABLE_THINKING?: boolean; EXTENDED_REASONING?: boolean; + WEB_GROUNDING?: boolean; // Only used in UI settings USE_CUSTOM_URL?: boolean; diff --git a/servers/nextjs/utils/providerUtils.ts b/servers/nextjs/utils/providerUtils.ts index a0efb4f8..4c776ee8 100644 --- a/servers/nextjs/utils/providerUtils.ts +++ b/servers/nextjs/utils/providerUtils.ts @@ -48,6 +48,7 @@ export const updateLLMConfig = ( tool_calls: "TOOL_CALLS", disable_thinking: "DISABLE_THINKING", extended_reasoning: "EXTENDED_REASONING", + web_grounding: "WEB_GROUNDING", }; const configKey = fieldMappings[field]; diff --git a/start.js b/start.js index f2dfa1b0..2a0e336f 100644 --- a/start.js +++ b/start.js @@ -81,6 +81,7 @@ const setupUserConfigFromEnv = () => { TOOL_CALLS: process.env.TOOL_CALLS || existingConfig.TOOL_CALLS, DISABLE_THINKING: process.env.DISABLE_THINKING || existingConfig.DISABLE_THINKING, EXTENDED_REASONING: process.env.EXTENDED_REASONING || existingConfig.EXTENDED_REASONING, + WEB_GROUNDING: process.env.WEB_GROUNDING || existingConfig.WEB_GROUNDING, USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL, };