From a3e81da767ba618028634fea89d588b40053b647 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Thu, 7 Aug 2025 04:44:39 +0545 Subject: [PATCH 1/8] feat(fastapi): adding llm agent that supports tool calls and handling --- servers/fastapi/enums/llm_call_type.py | 8 + servers/fastapi/models/llm_message.py | 25 ++- servers/fastapi/models/llm_tool_call.py | 8 + servers/fastapi/models/llm_tools.py | 21 ++ servers/fastapi/services/llm_client.py | 196 ++++++++++++++---- .../services/llm_tool_calls_handler.py | 111 ++++++++++ 6 files changed, 325 insertions(+), 44 deletions(-) create mode 100644 servers/fastapi/enums/llm_call_type.py create mode 100644 servers/fastapi/models/llm_tool_call.py create mode 100644 servers/fastapi/models/llm_tools.py create mode 100644 servers/fastapi/services/llm_tool_calls_handler.py diff --git a/servers/fastapi/enums/llm_call_type.py b/servers/fastapi/enums/llm_call_type.py new file mode 100644 index 00000000..e37fe4ae --- /dev/null +++ b/servers/fastapi/enums/llm_call_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LLMCallType(Enum): + UNSTRUCTURED = "unstructured" + UNSTRUCTURED_STREAM = "unstructured_stream" + STRUCTURED = "structured" + STRUCTURED_STREAM = "structured_stream" diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 51284173..6b364c09 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -1,7 +1,28 @@ -from typing import Literal +from typing import List, Literal, Optional from pydantic import BaseModel class LLMMessage(BaseModel): - role: Literal["user", "system"] + pass + + +class LLMUserMessage(LLMMessage): + role: Literal["user"] content: str + + +class LLMSystemMessage(LLMMessage): + role: Literal["system"] + content: str + + +class LLMToolCallMessage(LLMMessage): + role: Literal["tool"] + content: str + tool_call_id: str + + +class LLMAssistantMessage(LLMMessage): + role: Literal["assistant"] + content: str | None = None + tool_calls: Optional[List[dict]] = None diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py new file mode 100644 index 00000000..51cc4731 --- /dev/null +++ b/servers/fastapi/models/llm_tool_call.py @@ -0,0 +1,8 @@ +from typing import Optional +from pydantic import BaseModel + + +class LLMToolCall(BaseModel): + id: Optional[str] = None + name: str + arguments: Optional[str] = None diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py new file mode 100644 index 00000000..ed8ad6a1 --- /dev/null +++ b/servers/fastapi/models/llm_tools.py @@ -0,0 +1,21 @@ +from typing import Any, Callable, Coroutine +from pydantic import BaseModel, Field + + +class LLMTool(BaseModel): + pass + + +class LLMDynamicTool(LLMTool): + name: str + description: str + parameters: dict + handler: Callable[..., Coroutine[Any, Any, str]] + + +class SearchWebTool(LLMTool): + query: str = Field(description="The query to search the web for") + + +class GetCurrentDatetimeTool(LLMTool): + pass diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index eaf61770..1e52abef 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -3,13 +3,22 @@ import json from typing import List, Optional from fastapi import HTTPException from openai import AsyncOpenAI +from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall from google import genai from google.genai.types import GenerateContentConfig from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider -from models.llm_message import LLMMessage +from models.llm_message import ( + LLMAssistantMessage, + LLMMessage, + LLMSystemMessage, + LLMUserMessage, +) +from models.llm_tool_call import LLMToolCall +from models.llm_tools import LLMDynamicTool, LLMTool +from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async from utils.get_env import ( get_anthropic_api_key_env, @@ -23,6 +32,7 @@ from utils.get_env import ( ) from utils.llm_provider import get_llm_provider from utils.parsers import parse_bool_or_none +from utils.randomizers import get_random_uuid from utils.schema_utils import ensure_strict_json_schema @@ -30,6 +40,7 @@ class LLMClient: def __init__(self): self.llm_provider = get_llm_provider() self._client = self._get_client() + self.tool_calls_handler = LLMToolCallsHandler(self) # ? Use tool calls def use_tool_calls(self) -> bool: @@ -104,15 +115,19 @@ class LLMClient: # ? Prompts def _get_system_prompt(self, messages: List[LLMMessage]) -> str: for message in messages: - if message.role == "system": + if isinstance(message, LLMSystemMessage): return message.content return "" def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]: - return [message.content for message in messages if message.role == "user"] + return [ + message.content + for message in messages + if isinstance(message, LLMUserMessage) + ] def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if message.role == "user"] + return [message for message in messages if isinstance(message, LLMUserMessage)] # ? Generate Unstructured Content async def _generate_openai( @@ -120,6 +135,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -127,8 +143,36 @@ class LLMClient: model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, + tools=tools, extra_body=extra_body, ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + [ + LLMToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_calls + ] + ) + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[ + tool_call.model_dump() for tool_call in tool_call_messages + ], + ), + *tool_call_messages, + ] + return await self._generate_openai( + model, new_messages, max_tokens, tools, extra_body + ) + return response.choices[0].message.content async def _generate_google( @@ -192,7 +236,10 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.get_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: @@ -220,6 +267,7 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -231,46 +279,75 @@ class LLMClient: path=(), root=response_schema, ) - if not use_tool_calls: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - response_format={ - "type": "json_schema", - "json_schema": ( - { - "name": "ResponseSchema", - "strict": strict, - "schema": response_schema, - } - ), - }, - max_completion_tokens=max_tokens, - extra_body=extra_body, - ) - content = response.choices[0].message.content - else: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - tools=[ + response = await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + response_format={ + "type": "json_schema", + "json_schema": ( { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, + "name": "ResponseSchema", + "strict": strict, + "schema": response_schema, } - ], - tool_choice="required", - max_completion_tokens=max_tokens, - extra_body=extra_body, + ), + }, + max_completion_tokens=max_tokens, + extra_body=extra_body, + ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + [ + LLMToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_calls + ] ) - tool_calls = response.choices[0].message.tool_calls - if tool_calls: - content = tool_calls[0].function.arguments + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[each.model_dump() for each in tool_calls], + ), + *tool_call_messages, + ] + return await self._generate_openai_structured( + model, + new_messages, + response_format, + strict, + max_tokens, + tools, + extra_body, + ) + content = response.choices[0].message.content + # else: + # response = await client.chat.completions.create( + # model=model, + # messages=[message.model_dump() for message in messages], + # tools=[ + # { + # "type": "function", + # "function": { + # "name": "ResponseSchema", + # "description": "A response to the user's message", + # "strict": strict, + # "parameters": response_format, + # }, + # } + # ], + # tool_choice="required", + # max_completion_tokens=max_tokens, + # extra_body=extra_body, + # ) + # tool_calls = response.choices[0].message.tool_calls + # if tool_calls: + # content = tool_calls[0].function.arguments if content: return json.loads(content) @@ -404,6 +481,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -413,9 +491,35 @@ class LLMClient: max_completion_tokens=max_tokens, extra_body=extra_body, ) as stream: + tool_calls = [] async for event in stream: - if event.type == "content.delta": + if ( + event.type == "tool_calls.function.arguments.delta" + and event.name == "ResponseSchema" + ): + yield event.arguments_delta + elif event.type == "content.delta": yield event.delta + elif ( + event.type == "tool_calls.function.arguments.done" + and event.name != "ResponseSchema" + ): + tool_calls.append( + LLMToolCall( + id=get_random_uuid(), + name=event.name, + arguments=event.arguments, + ) + ) + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_openai(tool_calls) + ) + + new_messages = [ + *messages, + *tool_call_messages, + ] async def _stream_google( self, @@ -497,6 +601,7 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -659,3 +764,10 @@ class LLMClient: return self._stream_custom_structured( model, messages, response_format, strict, max_tokens ) + + # ? Tool call handling + def get_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): + if tools is None: + return None + parsed_tools = map(self.tool_calls_handler.parse_tool, tools) + return list(parsed_tools) diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py new file mode 100644 index 00000000..c21a1dd0 --- /dev/null +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -0,0 +1,111 @@ +import asyncio +from datetime import datetime +from typing import Any, Callable, Coroutine, List +from fastapi import HTTPException +from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall +from enums.llm_call_type import LLMCallType +from enums.llm_provider import LLMProvider +from models.llm_message import LLMMessage, LLMToolCallMessage +from models.llm_tool_call import LLMToolCall +from models.llm_tools import ( + GetCurrentDatetimeTool, + LLMDynamicTool, + LLMTool, + SearchWebTool, +) + + +class LLMToolCallsHandler: + def __init__(self, client): + from services.llm_client import LLMClient + + self.client: LLMClient = client + + self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = { + "SearchWebTool": self.search_web_tool_call_handler, + "GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler, + } + self.dynamic_tools: List[LLMDynamicTool] = [] + + def get_tool_handler( + self, tool_name: str + ) -> Callable[..., Coroutine[Any, Any, str]]: + handler = self.tools_map.get(tool_name) + if not handler: + dynamic_tools = list( + filter(lambda tool: tool.name == tool_name, self.dynamic_tools) + ) + if dynamic_tools: + return dynamic_tools[0].handler + raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found") + + def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False): + if isinstance(tool, LLMDynamicTool): + self.dynamic_tools.append(tool) + + match self.client.llm_provider: + case LLMProvider.OPENAI: + return self.parse_tool_openai(tool, strict) + case LLMProvider.ANTHROPIC: + return self.parse_tool_anthropic(tool) + case LLMProvider.GOOGLE: + return self.parse_tool_google(tool) + case _: + raise ValueError( + f"LLM provider must be either openai, anthropic, or google" + ) + + def parse_tool_openai( + self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False + ): + if isinstance(tool, LLMDynamicTool): + name = tool.name + description = tool.description + parameters = tool.parameters + else: + name = tool.__class__.__name__ + description = tool.__class__.__doc__ or "" + parameters = tool.model_dump(mode="json") + + return { + "type": "function", + "function": { + "name": name, + "description": description, + "strict": strict, + "parameters": parameters, + }, + } + + def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): + pass + + def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): + pass + + async def handle_tool_calls_openai( + self, + tool_calls: List[LLMToolCall], + ) -> List[LLMToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(tool_call.arguments)) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + return [ + LLMToolCallMessage( + role="tool", + content=result, + tool_call_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + + # ? Tool call handlers + async def search_web_tool_call_handler(self, tool_call: dict) -> str: + pass + + async def get_current_datetime_tool_call_handler(self, tool_call: dict) -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") From 5030908974186633399ee914db1182b2dcde7b66 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Fri, 8 Aug 2025 02:49:51 +0545 Subject: [PATCH 2/8] feat(fastapi): generate and stream implementation for openai with agentic behavior --- servers/fastapi/api/main.py | 2 + servers/fastapi/api/v1/test/router.py | 49 ++ servers/fastapi/models/llm_message.py | 10 +- servers/fastapi/models/llm_tool_call.py | 15 +- servers/fastapi/models/llm_tools.py | 13 +- servers/fastapi/pyproject.toml | 1 + servers/fastapi/services/llm_client.py | 687 +++++++++++++----- .../services/llm_tool_calls_handler.py | 76 +- servers/fastapi/utils/dummy_functions.py | 2 + .../generate_presentation_outlines.py | 2 + servers/fastapi/uv.lock | 27 + 11 files changed, 665 insertions(+), 219 deletions(-) create mode 100644 servers/fastapi/api/v1/test/router.py create mode 100644 servers/fastapi/utils/dummy_functions.py diff --git a/servers/fastapi/api/main.py b/servers/fastapi/api/main.py index 80eea709..f5f4a8d3 100644 --- a/servers/fastapi/api/main.py +++ b/servers/fastapi/api/main.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from api.lifespan import app_lifespan from api.middlewares import UserConfigEnvUpdateMiddleware from api.v1.ppt.router import API_V1_PPT_ROUTER +from api.v1.test.router import API_V1_TEST_ROUTER app = FastAPI(lifespan=app_lifespan) @@ -10,6 +11,7 @@ app = FastAPI(lifespan=app_lifespan) # Routers app.include_router(API_V1_PPT_ROUTER) +app.include_router(API_V1_TEST_ROUTER) # Middlewares origins = ["*"] diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py new file mode 100644 index 00000000..6dcc58f4 --- /dev/null +++ b/servers/fastapi/api/v1/test/router.py @@ -0,0 +1,49 @@ +from datetime import datetime +import json +from fastapi import APIRouter +from pydantic import BaseModel + +from models.llm_message import LLMUserMessage +from models.llm_tools import LLMDynamicTool, SearchWebTool +from services.llm_client import LLMClient + +API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) + + +class ResponseContent(BaseModel): + trending_ai_tool: str + current_date_time: str + + +@API_V1_TEST_ROUTER.get("") +async def test(): + client = LLMClient() + + async def get_current_datetime_tool_handler(_) -> str: + return datetime.now().isoformat() + + get_current_datetime_tool = LLMDynamicTool( + name="GetDateTimeDynamicTool", + description="Get the current date and time", + parameters=None, + handler=get_current_datetime_tool_handler, + ) + + accumulated_content = "" + + async for chunk in client.stream_structured( + model="gpt-4.1-mini", + messages=[ + LLMUserMessage( + content="What is the current date and time ? What is the trending AI tool now ?" + ), + ], + response_format=ResponseContent.model_json_schema(), + tools=[ + SearchWebTool, + get_current_datetime_tool, + ], + ): + accumulated_content += chunk + + return {"data": json.loads(accumulated_content)} diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 6b364c09..03cb0a16 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -7,22 +7,24 @@ class LLMMessage(BaseModel): class LLMUserMessage(LLMMessage): - role: Literal["user"] + role: Literal["user"] = "user" content: str class LLMSystemMessage(LLMMessage): - role: Literal["system"] + role: Literal["system"] = "system" content: str class LLMToolCallMessage(LLMMessage): - role: Literal["tool"] + role: Literal["tool"] = "tool" + id: str content: str + type: str tool_call_id: str class LLMAssistantMessage(LLMMessage): - role: Literal["assistant"] + role: Literal["assistant"] = "assistant" content: str | None = None tool_calls: Optional[List[dict]] = None diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py index 51cc4731..de2444ec 100644 --- a/servers/fastapi/models/llm_tool_call.py +++ b/servers/fastapi/models/llm_tool_call.py @@ -1,8 +1,17 @@ -from typing import Optional +from typing import Literal from pydantic import BaseModel class LLMToolCall(BaseModel): - id: Optional[str] = None + pass + + +class OpenAIToolCallFunction(BaseModel): name: str - arguments: Optional[str] = None + arguments: str + + +class OpenAIToolCall(LLMToolCall): + id: str + type: Literal["function"] = "function" + function: OpenAIToolCallFunction diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py index ed8ad6a1..6ee2e214 100644 --- a/servers/fastapi/models/llm_tools.py +++ b/servers/fastapi/models/llm_tools.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Coroutine +from typing import Any, Callable, Coroutine, Optional from pydantic import BaseModel, Field @@ -9,13 +9,22 @@ class LLMTool(BaseModel): class LLMDynamicTool(LLMTool): name: str description: str - parameters: dict + strict: bool = False + parameters: Optional[dict] = None handler: Callable[..., Coroutine[Any, Any, str]] class SearchWebTool(LLMTool): + """ + Search the web for information. + """ + query: str = Field(description="The query to search the web for") class GetCurrentDatetimeTool(LLMTool): + """ + Get the current datetime. + """ + pass diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml index f0caf8e3..14240244 100644 --- a/servers/fastapi/pyproject.toml +++ b/servers/fastapi/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "openai>=1.98.0", "pathvalidate>=3.3.1", "pdfplumber>=0.11.7", + "pytest>=8.4.1", "python-pptx>=1.0.2", "redis>=6.2.0", "sqlmodel>=0.0.24", diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 1e52abef..ee506f7f 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -1,9 +1,11 @@ import asyncio import json -from typing import List, Optional +from typing import AsyncGenerator, List, Optional from fastapi import HTTPException from openai import AsyncOpenAI -from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) from google import genai from google.genai.types import GenerateContentConfig from anthropic import AsyncAnthropic @@ -16,10 +18,11 @@ from models.llm_message import ( LLMSystemMessage, LLMUserMessage, ) -from models.llm_tool_call import LLMToolCall +from models.llm_tool_call import LLMToolCall, OpenAIToolCall, OpenAIToolCallFunction from models.llm_tools import LLMDynamicTool, LLMTool from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async +from utils.dummy_functions import do_nothing_async from utils.get_env import ( get_anthropic_api_key_env, get_custom_llm_api_key_env, @@ -43,7 +46,7 @@ class LLMClient: self.tool_calls_handler = LLMToolCallsHandler(self) # ? Use tool calls - def use_tool_calls(self) -> bool: + def use_tool_calls_for_structured_output(self) -> bool: if self.llm_provider != LLMProvider.CUSTOM: return False return parse_bool_or_none(get_tool_calls_env()) or False @@ -137,7 +140,8 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> str | None: client: AsyncOpenAI = self._client response = await client.chat.completions.create( model=model, @@ -148,29 +152,37 @@ class LLMClient: ) tool_calls = response.choices[0].message.tool_calls if tool_calls: - tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( - [ - LLMToolCall( - id=tool_call.id, + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) - for tool_call in tool_calls - ] + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + assistant_message = LLMAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls], ) new_messages = [ *messages, - LLMAssistantMessage( - role="assistant", - content=response.choices[0].message.content, - tool_calls=[ - tool_call.model_dump() for tool_call in tool_call_messages - ], - ), + assistant_message, *tool_call_messages, ] return await self._generate_openai( - model, new_messages, max_tokens, tools, extra_body + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, ) return response.choices[0].message.content @@ -180,6 +192,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): client: genai.Client = self._client response = await asyncio.to_thread( @@ -199,6 +212,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client response: AnthropicMessage = await client.messages.create( @@ -219,16 +233,30 @@ class LLMClient: return text async def _generate_ollama( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): - return await self._generate_openai(model, messages, max_tokens) + return await self._generate_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) async def _generate_custom( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} return await self._generate_openai( - model, messages, max_tokens, extra_body=extra_body + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) async def generate( @@ -238,20 +266,33 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): - parsed_tools = self.get_tools(tools) + parsed_tools = self.tool_calls_handler.parse_tools(tools) content = None match self.llm_provider: case LLMProvider.OPENAI: - content = await self._generate_openai(model, messages, max_tokens) + content = await self._generate_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - content = await self._generate_google(model, messages, max_tokens) + content = await self._generate_google( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.ANTHROPIC: - content = await self._generate_anthropic(model, messages, max_tokens) + content = await self._generate_anthropic( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.OLLAMA: - content = await self._generate_ollama(model, messages, max_tokens) + content = await self._generate_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - content = await self._generate_custom(model, messages, max_tokens) + content = await self._generate_custom( + model=model, messages=messages, max_tokens=max_tokens + ) if content is None: raise HTTPException( status_code=400, @@ -269,88 +310,109 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> dict | None: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + strict=strict, + parameters=response_schema, + handler=do_nothing_async, + ) + ) + ) + response = await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], - response_format={ - "type": "json_schema", - "json_schema": ( - { - "name": "ResponseSchema", - "strict": strict, - "schema": response_schema, - } - ), - }, + response_format=( + { + "type": "json_schema", + "json_schema": ( + { + "name": "ResponseSchema", + "strict": strict, + "schema": response_schema, + } + ), + } + if not use_tool_calls_for_structured_output + else None + ), max_completion_tokens=max_tokens, + tools=all_tools, extra_body=extra_body, ) + + content = response.choices[0].message.content + tool_calls = response.choices[0].message.tool_calls + has_response_schema = False + if tool_calls: - tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( - [ - LLMToolCall( + for tool_call in tool_calls: + if tool_call.function.name == "ResponseSchema": + content = tool_call.function.arguments + has_response_schema = True + + if not has_response_schema: + parsed_tool_calls = [ + OpenAIToolCall( id=tool_call.id, - name=tool_call.function.name, - arguments=tool_call.function.arguments, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), ) for tool_call in tool_calls ] - ) - new_messages = [ - *messages, - LLMAssistantMessage( - role="assistant", - content=response.choices[0].message.content, - tool_calls=[each.model_dump() for each in tool_calls], - ), - *tool_call_messages, - ] - return await self._generate_openai_structured( - model, - new_messages, - response_format, - strict, - max_tokens, - tools, - extra_body, - ) - content = response.choices[0].message.content - # else: - # response = await client.chat.completions.create( - # model=model, - # messages=[message.model_dump() for message in messages], - # tools=[ - # { - # "type": "function", - # "function": { - # "name": "ResponseSchema", - # "description": "A response to the user's message", - # "strict": strict, - # "parameters": response_format, - # }, - # } - # ], - # tool_choice="required", - # max_completion_tokens=max_tokens, - # extra_body=extra_body, - # ) - # tool_calls = response.choices[0].message.tool_calls - # if tool_calls: - # content = tool_calls[0].function.arguments - + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + ) + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[each.model_dump() for each in parsed_tool_calls], + ), + *tool_call_messages, + ] + content = await self._generate_openai_structured( + model=model, + messages=new_messages, + response_format=response_schema, + strict=strict, + max_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + depth=depth + 1, + ) if content: - return json.loads(content) + if depth == 0: + return json.loads(content) + return content return None async def _generate_google_structured( @@ -359,6 +421,7 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, + depth: int = 0, ): client: genai.Client = self._client response = await asyncio.to_thread( @@ -384,6 +447,7 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client response: AnthropicMessage = await client.messages.create( @@ -420,9 +484,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) async def _generate_custom_structured( @@ -432,10 +502,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens, extra_body + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) async def generate_structured( @@ -444,29 +521,51 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ) -> dict: + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: content = await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: content = await self._generate_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: content = await self._generate_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: content = await self._generate_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) if content is None: raise HTTPException( @@ -483,49 +582,99 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - async with client.chat.completions.stream( + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + async for event in await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, + tools=tools, extra_body=extra_body, - ) as stream: - tool_calls = [] - async for event in stream: - if ( - event.type == "tool_calls.function.arguments.delta" - and event.name == "ResponseSchema" - ): - yield event.arguments_delta - elif event.type == "content.delta": - yield event.delta - elif ( - event.type == "tool_calls.function.arguments.done" - and event.name != "ResponseSchema" - ): + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: tool_calls.append( - LLMToolCall( - id=get_random_uuid(), - name=event.name, - arguments=event.arguments, + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), ) ) - if tool_calls: - tool_call_messages = ( - await self.tool_calls_handler.handle_tool_calls_openai(tool_calls) - ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments - new_messages = [ - *messages, - *tool_call_messages, - ] + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], + ), + *tool_call_messages, + ] + async for event in self._stream_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ): + yield event async def _stream_google( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): client: genai.Client = self._client async for event in iterator_to_async(client.models.generate_content_stream)( @@ -545,6 +694,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client async with client.messages.stream( @@ -566,32 +716,61 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) def _stream_custom( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} - return self._stream_openai(model, messages, max_tokens, extra_body) + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, + ) def stream( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - return self._stream_google(model, messages, max_tokens) + return self._stream_google( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.ANTHROPIC: - return self._stream_anthropic(model, messages, max_tokens) + return self._stream_anthropic( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.OLLAMA: - return self._stream_ollama(model, messages, max_tokens) + return self._stream_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - return self._stream_custom(model, messages, max_tokens) + return self._stream_custom( + model=model, messages=messages, max_tokens=max_tokens + ) # ? Stream Structured Content async def _stream_openai_structured( @@ -603,58 +782,143 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() + response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - async with client.chat.completions.stream( - model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - response_format=( - { - "type": "json_schema", - "json_schema": { + + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + strict=strict, + parameters=response_schema, + handler=do_nothing_async, + ) + ) + ) + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + + has_response_schema_tool_call = False + async for event in await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + max_completion_tokens=max_tokens, + tools=all_tools, + response_format=( + { + "type": "json_schema", + "json_schema": ( + { "name": "ResponseSchema", "strict": strict, "schema": response_schema, - }, - } + } + ), + } + if not use_tool_calls_for_structured_output + else None + ), + extra_body=extra_body, + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_name == "ResponseSchema": + if tool_arguments: + yield tool_arguments + has_response_schema_tool_call = True + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls and not has_response_schema_tool_call: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], ), - extra_body=extra_body, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta - else: - async with client.chat.completions.stream( + *tool_call_messages, + ] + async for event in self._stream_openai_structured( model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", + messages=new_messages, + max_tokens=max_tokens, + strict=strict, + tools=all_tools, + response_format=response_schema, extra_body=extra_body, - ) as stream: - async for event in stream: - if event.type == "tool_calls.function.arguments.delta": - yield event.arguments_delta + depth=depth + 1, + ): + yield event async def _stream_google_structured( self, @@ -662,6 +926,7 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, + depth: int = 0, ): client: genai.Client = self._client async for event in iterator_to_async(client.models.generate_content_stream)( @@ -683,6 +948,7 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client async with client.messages.stream( @@ -717,9 +983,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) def _stream_custom_structured( @@ -729,10 +1001,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): extra_body = {"enable_thinking": not self.disable_thinking()} return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens, extra_body + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) def stream_structured( @@ -741,33 +1020,67 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: return self._stream_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: return self._stream_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: return self._stream_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: return self._stream_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) - # ? Tool call handling - def get_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): - if tools is None: - return None - parsed_tools = map(self.tool_calls_handler.parse_tool, tools) - return list(parsed_tools) + # ? Web search + def _search_openai(self, query: str) -> str: + client: AsyncOpenAI = self._client + response = client.responses.create( + model="o4-mini", + tools=[ + { + "type": "web_search_preview", + "user_location": { + "type": "approximate", + "country": "GB", + "city": "London", + "region": "London", + }, + } + ], + input="What are the best restaurants around Granary Square?", + ) diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index c21a1dd0..9c589277 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -1,18 +1,12 @@ import asyncio from datetime import datetime -from typing import Any, Callable, Coroutine, List +import json +from typing import Any, Callable, Coroutine, List, Optional from fastapi import HTTPException -from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall -from enums.llm_call_type import LLMCallType from enums.llm_provider import LLMProvider -from models.llm_message import LLMMessage, LLMToolCallMessage -from models.llm_tool_call import LLMToolCall -from models.llm_tools import ( - GetCurrentDatetimeTool, - LLMDynamicTool, - LLMTool, - SearchWebTool, -) +from models.llm_message import LLMToolCallMessage +from models.llm_tool_call import OpenAIToolCall +from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool class LLMToolCallsHandler: @@ -31,7 +25,9 @@ class LLMToolCallsHandler: self, tool_name: str ) -> Callable[..., Coroutine[Any, Any, str]]: handler = self.tools_map.get(tool_name) - if not handler: + if handler: + return handler + else: dynamic_tools = list( filter(lambda tool: tool.name == tool_name, self.dynamic_tools) ) @@ -39,6 +35,12 @@ class LLMToolCallsHandler: return dynamic_tools[0].handler raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found") + def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): + if tools is None: + return None + parsed_tools = map(self.parse_tool, tools) + return list(parsed_tools) + def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False): if isinstance(tool, LLMDynamicTool): self.dynamic_tools.append(tool) @@ -63,9 +65,9 @@ class LLMToolCallsHandler: description = tool.description parameters = tool.parameters else: - name = tool.__class__.__name__ - description = tool.__class__.__doc__ or "" - parameters = tool.model_dump(mode="json") + name = tool.__name__ + description = tool.__doc__ or "" + parameters = tool.model_json_schema() return { "type": "function", @@ -85,27 +87,55 @@ class LLMToolCallsHandler: async def handle_tool_calls_openai( self, - tool_calls: List[LLMToolCall], + tool_calls: List[OpenAIToolCall], ) -> List[LLMToolCallMessage]: async_tool_calls_tasks = [] for tool_call in tool_calls: - tool_name = tool_call.name + tool_name = tool_call.function.name tool_handler = self.get_tool_handler(tool_name) - async_tool_calls_tasks.append(tool_handler(tool_call.arguments)) + async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments)) tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) - return [ + tool_call_messages = [ LLMToolCallMessage( role="tool", + id=tool_call.id, content=result, tool_call_id=tool_call.id, + type=tool_call.type, ) for tool_call, result in zip(tool_calls, tool_call_results) ] + return tool_call_messages # ? Tool call handlers - async def search_web_tool_call_handler(self, tool_call: dict) -> str: - pass - async def get_current_datetime_tool_call_handler(self, tool_call: dict) -> str: - return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + # Search web tool call handler + async def search_web_tool_call_handler(self, arguments: str) -> str: + match self.client.llm_provider: + case LLMProvider.OPENAI: + return await self.search_web_tool_call_handler_openai(arguments) + case LLMProvider.ANTHROPIC: + return await self.search_web_tool_call_handler_anthropic(arguments) + case LLMProvider.GOOGLE: + return await self.search_web_tool_call_handler_google(arguments) + case _: + return ( + "Web search tool call handler not implemented for this LLM provider: " + + self.client.llm_provider.value + ) + + async def search_web_tool_call_handler_openai(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return args.query + + async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: + return "test" + + async def search_web_tool_call_handler_google(self, arguments: str) -> str: + return "test" + + # Get current datetime tool call handler + async def get_current_datetime_tool_call_handler(self, arguments: str) -> str: + current_time = datetime.now() + return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}" diff --git a/servers/fastapi/utils/dummy_functions.py b/servers/fastapi/utils/dummy_functions.py new file mode 100644 index 00000000..461e9695 --- /dev/null +++ b/servers/fastapi/utils/dummy_functions.py @@ -0,0 +1,2 @@ +async def do_nothing_async(_): + return None diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 2a3c9e95..ee5f0224 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,6 +1,7 @@ from typing import Optional from models.llm_message import LLMMessage +from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides from utils.llm_provider import get_model @@ -55,5 +56,6 @@ async def generate_ppt_outline( get_messages(prompt, n_slides, language, content), response_model.model_json_schema(), strict=True, + tools=[SearchWebTool, GetCurrentDatetimeTool], ): yield chunk diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock index b579f42a..4e19c02b 100644 --- a/servers/fastapi/uv.lock +++ b/servers/fastapi/uv.lock @@ -1061,6 +1061,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "isodate" version = "0.7.2" @@ -1907,6 +1916,7 @@ dependencies = [ { name = "openai" }, { name = "pathvalidate" }, { name = "pdfplumber" }, + { name = "pytest" }, { name = "python-pptx" }, { name = "redis" }, { name = "sqlmodel" }, @@ -1928,6 +1938,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.98.0" }, { name = "pathvalidate", specifier = ">=3.3.1" }, { name = "pdfplumber", specifier = ">=0.11.7" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "python-pptx", specifier = ">=1.0.2" }, { name = "redis", specifier = ">=6.2.0" }, { name = "sqlmodel", specifier = ">=0.0.24" }, @@ -2211,6 +2222,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + [[package]] name = "python-bidi" version = "0.6.6" From 49342e7c3c40989715f30af818d7d530ddd636d0 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Fri, 8 Aug 2025 08:01:15 +0545 Subject: [PATCH 3/8] feat(fastapi): adds tools support for openai and google generate and stream --- servers/fastapi/api/v1/test/router.py | 21 +- servers/fastapi/models/llm_message.py | 30 +- servers/fastapi/models/llm_tool_call.py | 7 +- servers/fastapi/services/llm_client.py | 335 +++++++++++++++--- .../services/llm_tool_calls_handler.py | 52 ++- 5 files changed, 369 insertions(+), 76 deletions(-) diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 6dcc58f4..1af2d76a 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -1,17 +1,22 @@ from datetime import datetime import json from fastapi import APIRouter -from pydantic import BaseModel +from pydantic import BaseModel, Field from models.llm_message import LLMUserMessage from models.llm_tools import LLMDynamicTool, SearchWebTool from services.llm_client import LLMClient +from utils.llm_provider import get_model API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) class ResponseContent(BaseModel): - trending_ai_tool: str + trending_ai_tool: str = Field( + description="The summary of the trending AI tool in about 150 words", + min_length=150, + max_length=200, + ) current_date_time: str @@ -29,13 +34,13 @@ async def test(): handler=get_current_datetime_tool_handler, ) - accumulated_content = "" + text_content = "" - async for chunk in client.stream_structured( - model="gpt-4.1-mini", + async for event in client.stream_structured( + model=get_model(), messages=[ LLMUserMessage( - content="What is the current date and time ? What is the trending AI tool now ?" + content="What is the current date and time ? What is the trending AI tool now ? Use Available tools to get the information." ), ], response_format=ResponseContent.model_json_schema(), @@ -44,6 +49,6 @@ async def test(): get_current_datetime_tool, ], ): - accumulated_content += chunk + text_content += event - return {"data": json.loads(accumulated_content)} + return {"data": text_content} diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 03cb0a16..bff98951 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -1,5 +1,6 @@ -from typing import List, Literal, Optional +from typing import Any, List, Literal, Optional from pydantic import BaseModel +from google.genai.types import Content as GoogleContent class LLMMessage(BaseModel): @@ -16,15 +17,24 @@ class LLMSystemMessage(LLMMessage): content: str -class LLMToolCallMessage(LLMMessage): - role: Literal["tool"] = "tool" - id: str - content: str - type: str - tool_call_id: str - - -class LLMAssistantMessage(LLMMessage): +class OpenAIAssistantMessage(LLMMessage): role: Literal["assistant"] = "assistant" content: str | None = None tool_calls: Optional[List[dict]] = None + + +class GoogleAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: GoogleContent + + +class OpenAIToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + content: str + tool_call_id: str + + +class GoogleToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + name: str + response: dict diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py index de2444ec..2eb306e6 100644 --- a/servers/fastapi/models/llm_tool_call.py +++ b/servers/fastapi/models/llm_tool_call.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel @@ -15,3 +15,8 @@ class OpenAIToolCall(LLMToolCall): id: str type: Literal["function"] = "function" function: OpenAIToolCallFunction + + +class GoogleToolCall(LLMToolCall): + name: str + arguments: Optional[dict] = None diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index ee506f7f..2ffcfb1e 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -7,18 +7,27 @@ from openai.types.chat.chat_completion_chunk import ( ChatCompletionChunk as OpenAIChatCompletionChunk, ) from google import genai -from google.genai.types import GenerateContentConfig +from google.genai.types import Content as GoogleContent, Part as GoogleContentPart +from google.genai.types import GenerateContentConfig, GoogleSearch +from google.genai.types import Tool as GoogleTool from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider from models.llm_message import ( - LLMAssistantMessage, + GoogleAssistantMessage, + GoogleToolCallMessage, + OpenAIAssistantMessage, LLMMessage, LLMSystemMessage, LLMUserMessage, ) -from models.llm_tool_call import LLMToolCall, OpenAIToolCall, OpenAIToolCallFunction +from models.llm_tool_call import ( + GoogleToolCall, + LLMToolCall, + OpenAIToolCall, + OpenAIToolCallFunction, +) from models.llm_tools import LLMDynamicTool, LLMTool from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async @@ -31,9 +40,10 @@ from utils.get_env import ( get_google_api_key_env, get_ollama_url_env, get_openai_api_key_env, + get_openai_model_env, get_tool_calls_env, ) -from utils.llm_provider import get_llm_provider +from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none from utils.randomizers import get_random_uuid from utils.schema_utils import ensure_strict_json_schema @@ -122,12 +132,30 @@ class LLMClient: return message.content return "" - def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]: - return [ - message.content - for message in messages - if isinstance(message, LLMUserMessage) - ] + def _get_google_messages(self, messages: List[LLMMessage]) -> List[str]: + contents = [] + for message in messages: + if isinstance(message, LLMUserMessage): + contents.append( + GoogleContent( + role="user", parts=[GoogleContentPart(text=message.content)] + ) + ) + elif isinstance(message, GoogleAssistantMessage): + contents.append(message.content) + elif isinstance(message, GoogleToolCallMessage): + contents.append( + GoogleContent( + role="user", + parts=[ + GoogleContentPart.from_function_response( + name=message.name, response=message.response + ) + ], + ) + ) + + return contents def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: return [message for message in messages if isinstance(message, LLMUserMessage)] @@ -166,7 +194,7 @@ class LLMClient: tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( parsed_tool_calls ) - assistant_message = LLMAssistantMessage( + assistant_message = OpenAIAssistantMessage( role="assistant", content=response.choices[0].message.content, tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls], @@ -191,21 +219,68 @@ class LLMClient: self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, - ): + ) -> str | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", max_output_tokens=max_tokens, ), ) - return response.text + + content = response.candidates[0].content + response_parts = content.parts + + if not response_parts: + return None + + text_content = None + tool_calls = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + if each_part.text: + text_content = each_part.text + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_anthropic( self, @@ -279,7 +354,10 @@ class LLMClient: ) case LLMProvider.GOOGLE: content = await self._generate_google( - model=model, messages=messages, max_tokens=max_tokens + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic( @@ -392,7 +470,7 @@ class LLMClient: ) new_messages = [ *messages, - LLMAssistantMessage( + OpenAIAssistantMessage( role="assistant", content=response.choices[0].message.content, tool_calls=[each.model_dump() for each in parsed_tool_calls], @@ -421,25 +499,87 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, depth: int = 0, - ): + ) -> dict | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters": response_format, + } + ] + ) + ) + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) - content = None - if response.text: - content = json.loads(response.text) - return content + content = response.candidates[0].content + response_parts = content.parts + text_content = None + + if not response_parts: + return None + + tool_calls: List[GoogleToolCall] = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + + if each_part.text: + text_content = each_part.text + + for each in tool_calls: + if each.name == "ResponseSchema": + return each.arguments + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + if text_content: + return json.loads(text_content) + return None async def _generate_anthropic_structured( self, @@ -542,6 +682,7 @@ class LLMClient: model=model, messages=messages, response_format=response_format, + tools=parsed_tools, max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: @@ -652,7 +793,7 @@ class LLMClient: ) new_messages = [ *messages, - LLMAssistantMessage( + OpenAIAssistantMessage( role="assistant", content=None, tool_calls=[each.model_dump() for each in tool_calls], @@ -673,22 +814,60 @@ class LLMClient: self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, - ): + ) -> AsyncGenerator[str, None]: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + + tool_calls = None async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", + tools=google_tools, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic( self, model: str, @@ -757,7 +936,10 @@ class LLMClient: ) case LLMProvider.GOOGLE: return self._stream_google( - model=model, messages=messages, max_tokens=max_tokens + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, ) case LLMProvider.ANTHROPIC: return self._stream_anthropic( @@ -901,7 +1083,7 @@ class LLMClient: ) new_messages = [ *messages, - LLMAssistantMessage( + OpenAIAssistantMessage( role="assistant", content=None, tool_calls=[each.model_dump() for each in tool_calls], @@ -926,22 +1108,82 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, depth: int = 0, ): client: genai.Client = self._client + + google_tools = [] + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters": response_format, + } + ] + ) + ) + + tool_calls: List[GoogleToolCall] = [] async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + has_response_schema_tool_call = False + for each in tool_calls: + if each.name == "ResponseSchema": + has_response_schema_tool_call = True + if each.arguments: + yield json.dumps(each.arguments) + + if has_response_schema_tool_call: + continue + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic_structured( self, model: str, @@ -1040,6 +1282,7 @@ class LLMClient: model=model, messages=messages, response_format=response_format, + tools=parsed_tools, max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: @@ -1067,20 +1310,28 @@ class LLMClient: ) # ? Web search - def _search_openai(self, query: str) -> str: + async def _search_openai(self, query: str) -> str: client: AsyncOpenAI = self._client - response = client.responses.create( - model="o4-mini", + response = await client.responses.create( + model=get_model(), tools=[ { "type": "web_search_preview", - "user_location": { - "type": "approximate", - "country": "GB", - "city": "London", - "region": "London", - }, } ], - input="What are the best restaurants around Granary Square?", + input=query, ) + return response.output_text + + async def _search_google(self, query: str) -> str: + client: genai.Client = self._client + grounding_tool = GoogleTool(google_search=GoogleSearch()) + config = GenerateContentConfig(tools=[grounding_tool]) + + response = await asyncio.to_thread( + client.models.generate_content, + model=get_model(), + contents=query, + config=config, + ) + return response.text diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index 9c589277..52aad663 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -4,8 +4,8 @@ import json from typing import Any, Callable, Coroutine, List, Optional from fastapi import HTTPException from enums.llm_provider import LLMProvider -from models.llm_message import LLMToolCallMessage -from models.llm_tool_call import OpenAIToolCall +from models.llm_message import GoogleToolCallMessage, OpenAIToolCallMessage +from models.llm_tool_call import GoogleToolCall, OpenAIToolCall from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool @@ -79,16 +79,21 @@ class LLMToolCallsHandler: }, } - def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): - pass - def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "parameters": parsed["function"]["parameters"], + } + + def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): pass async def handle_tool_calls_openai( self, tool_calls: List[OpenAIToolCall], - ) -> List[LLMToolCallMessage]: + ) -> List[OpenAIToolCallMessage]: async_tool_calls_tasks = [] for tool_call in tool_calls: tool_name = tool_call.function.name @@ -97,19 +102,35 @@ class LLMToolCallsHandler: tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) tool_call_messages = [ - LLMToolCallMessage( - role="tool", - id=tool_call.id, + OpenAIToolCallMessage( content=result, tool_call_id=tool_call.id, - type=tool_call.type, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_google( + self, + tool_calls: List[GoogleToolCall], + ) -> List[GoogleToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + GoogleToolCallMessage( + name=tool_call.name, + response={"result": result}, ) for tool_call, result in zip(tool_calls, tool_call_results) ] return tool_call_messages # ? Tool call handlers - # Search web tool call handler async def search_web_tool_call_handler(self, arguments: str) -> str: match self.client.llm_provider: @@ -127,12 +148,13 @@ class LLMToolCallsHandler: async def search_web_tool_call_handler_openai(self, arguments: str) -> str: args = SearchWebTool.model_validate_json(arguments) - return args.query - - async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: - return "test" + return await self.client._search_openai(args.query) async def search_web_tool_call_handler_google(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_google(args.query) + + async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: return "test" # Get current datetime tool call handler From 84fd0dee1affc6c53f0b17a2298a6af793970446 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Fri, 8 Aug 2025 09:06:05 +0545 Subject: [PATCH 4/8] feat(fastapi): adds tool calls support for anthropic generate --- servers/fastapi/api/v1/test/router.py | 14 +- servers/fastapi/models/llm_message.py | 18 +++ servers/fastapi/models/llm_tool_call.py | 7 + servers/fastapi/models/llm_tools.py | 2 +- servers/fastapi/services/llm_client.py | 130 ++++++++++++++---- .../services/llm_tool_calls_handler.py | 40 +++++- 6 files changed, 175 insertions(+), 36 deletions(-) diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 1af2d76a..0bc78101 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -13,9 +13,9 @@ API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) class ResponseContent(BaseModel): trending_ai_tool: str = Field( - description="The summary of the trending AI tool in about 150 words", - min_length=150, - max_length=200, + description="The summary of the trending AI tool in about 50 words", + min_length=50, + max_length=100, ) current_date_time: str @@ -30,13 +30,12 @@ async def test(): get_current_datetime_tool = LLMDynamicTool( name="GetDateTimeDynamicTool", description="Get the current date and time", - parameters=None, handler=get_current_datetime_tool_handler, ) text_content = "" - async for event in client.stream_structured( + response = await client.generate_structured( model=get_model(), messages=[ LLMUserMessage( @@ -48,7 +47,6 @@ async def test(): SearchWebTool, get_current_datetime_tool, ], - ): - text_content += event + ) - return {"data": text_content} + return {"data": response} diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index bff98951..db741ca4 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -2,6 +2,8 @@ from typing import Any, List, Literal, Optional from pydantic import BaseModel from google.genai.types import Content as GoogleContent +from models.llm_tool_call import AnthropicToolCall + class LLMMessage(BaseModel): pass @@ -28,6 +30,22 @@ class GoogleAssistantMessage(LLMMessage): content: GoogleContent +class AnthropicAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: List[AnthropicToolCall] + + +class AnthropicToolCallMessage(LLMMessage): + type: Literal["tool_result"] = "tool_result" + tool_use_id: str + content: str + + +class AnthropicUserMessage(LLMMessage): + role: Literal["user"] = "user" + content: List[AnthropicToolCallMessage] + + class OpenAIToolCallMessage(LLMMessage): role: Literal["tool"] = "tool" content: str diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py index 2eb306e6..5eb1f008 100644 --- a/servers/fastapi/models/llm_tool_call.py +++ b/servers/fastapi/models/llm_tool_call.py @@ -20,3 +20,10 @@ class OpenAIToolCall(LLMToolCall): class GoogleToolCall(LLMToolCall): name: str arguments: Optional[dict] = None + + +class AnthropicToolCall(LLMToolCall): + type: Literal["tool_use"] = "tool_use" + id: str + name: str + input: object diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py index 6ee2e214..4bede740 100644 --- a/servers/fastapi/models/llm_tools.py +++ b/servers/fastapi/models/llm_tools.py @@ -10,7 +10,7 @@ class LLMDynamicTool(LLMTool): name: str description: str strict: bool = False - parameters: Optional[dict] = None + parameters: dict = {} handler: Callable[..., Coroutine[Any, Any, str]] diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 2ffcfb1e..07789ed4 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -15,6 +15,8 @@ from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider from models.llm_message import ( + AnthropicAssistantMessage, + AnthropicUserMessage, GoogleAssistantMessage, GoogleToolCallMessage, OpenAIAssistantMessage, @@ -23,6 +25,7 @@ from models.llm_message import ( LLMUserMessage, ) from models.llm_tool_call import ( + AnthropicToolCall, GoogleToolCall, LLMToolCall, OpenAIToolCall, @@ -157,8 +160,10 @@ class LLMClient: return contents - def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if isinstance(message, LLMUserMessage)] + def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + return [ + message for message in messages if not isinstance(message, LLMSystemMessage) + ] # ? Generate Unstructured Content async def _generate_openai( @@ -287,25 +292,61 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, depth: int = 0, - ): + ) -> str | None: client: AsyncAnthropic = self._client + response: AnthropicMessage = await client.messages.create( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], + tools=tools, max_tokens=max_tokens or 4000, ) - text = "" + text_content = None + tool_calls: List[AnthropicToolCall] = [] for content in response.content: if content.type == "text" and isinstance(content.text, str): - text += content.text - if text == "": - return None - return text + text_content = content.text + + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_ollama( self, @@ -361,7 +402,10 @@ class LLMClient: ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic( - model=model, messages=messages, max_tokens=max_tokens + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, ) case LLMProvider.OLLAMA: content = await self._generate_ollama( @@ -586,6 +630,7 @@ class LLMClient: model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, ): @@ -595,7 +640,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -603,19 +648,51 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) - content: dict | None = None - for content_block in response.content: - if content_block.type == "tool_use": - content = content_block.input + tool_calls: List[AnthropicToolCall] = [] + for content in response.content: + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) - return content + for each in tool_calls: + if each.name == "ResponseSchema": + return each.input + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + return None async def _generate_ollama_structured( self, @@ -690,6 +767,7 @@ class LLMClient: model=model, messages=messages, response_format=response_format, + tools=parsed_tools, max_tokens=max_tokens, ) case LLMProvider.OLLAMA: @@ -873,6 +951,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, depth: int = 0, ): client: AsyncAnthropic = self._client @@ -881,14 +960,17 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, + tools=tools, ) as stream: + tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "text" and isinstance(event.text, str): - yield event.text + if event.type == "input_json": + event.partial_json + pass def _stream_ollama( self, diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index 52aad663..a61d24ad 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -4,8 +4,12 @@ import json from typing import Any, Callable, Coroutine, List, Optional from fastapi import HTTPException from enums.llm_provider import LLMProvider -from models.llm_message import GoogleToolCallMessage, OpenAIToolCallMessage -from models.llm_tool_call import GoogleToolCall, OpenAIToolCall +from models.llm_message import ( + AnthropicToolCallMessage, + GoogleToolCallMessage, + OpenAIToolCallMessage, +) +from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool @@ -88,7 +92,13 @@ class LLMToolCallsHandler: } def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): - pass + parsed = self.parse_tool_openai(tool) + input_schema = parsed["function"]["parameters"] + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "input_schema": {"type": "object"} if input_schema == {} else input_schema, + } async def handle_tool_calls_openai( self, @@ -130,6 +140,30 @@ class LLMToolCallsHandler: ] return tool_call_messages + async def handle_tool_calls_anthropic( + self, + tool_calls: List[AnthropicToolCall], + ) -> List[AnthropicToolCallMessage]: + async_tool_calls_tasks = [] + print("--------------------------------") + print(tool_calls) + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + AnthropicToolCallMessage( + content=result, + tool_use_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + print("--------------------------------") + print(tool_call_messages) + return tool_call_messages + # ? Tool call handlers # Search web tool call handler async def search_web_tool_call_handler(self, arguments: str) -> str: From 5c106bd6648d03c9b6f8f4ef4c2875fb7cb386db Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Fri, 8 Aug 2025 22:11:41 +0545 Subject: [PATCH 5/8] feat(fastapi): adds tool call support for anthropic stream and stream structured --- servers/fastapi/api/v1/test/router.py | 7 +- servers/fastapi/services/llm_client.py | 126 ++++++++++++++++-- .../services/llm_tool_calls_handler.py | 4 - .../generate_presentation_outlines.py | 6 +- 4 files changed, 120 insertions(+), 23 deletions(-) diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 0bc78101..9db101e0 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -35,7 +35,7 @@ async def test(): text_content = "" - response = await client.generate_structured( + async for chunk in client.stream_structured( model=get_model(), messages=[ LLMUserMessage( @@ -47,6 +47,7 @@ async def test(): SearchWebTool, get_current_datetime_tool, ], - ) + ): + text_content += chunk - return {"data": response} + return {"data": text_content} diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 07789ed4..847a4e80 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -955,6 +955,7 @@ class LLMClient: depth: int = 0, ): client: AsyncAnthropic = self._client + async with client.messages.stream( model=model, system=self._get_system_prompt(messages), @@ -968,9 +969,48 @@ class LLMClient: tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json": - event.partial_json - pass + + if event.type == "text": + yield event.text + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama( self, @@ -1025,7 +1065,10 @@ class LLMClient: ) case LLMProvider.ANTHROPIC: return self._stream_anthropic( - model=model, messages=messages, max_tokens=max_tokens + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, ) case LLMProvider.OLLAMA: return self._stream_ollama( @@ -1271,6 +1314,7 @@ class LLMClient: model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, ): @@ -1280,7 +1324,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -1288,17 +1332,72 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) as stream: + tool_calls: List[AnthropicToolCall] = [] + has_response_schema_tool_call = False + is_response_schema_tool_call_started = False async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json" and isinstance(event.partial_json, str): - yield event.partial_json + if ( + event.type == "content_block_start" + and event.content_block.type == "tool_use" + ): + if event.content_block.name == "ResponseSchema": + has_response_schema_tool_call = True + is_response_schema_tool_call_started = True + + if ( + event.type == "content_block_delta" + and event.delta.type == "input_json_delta" + and is_response_schema_tool_call_started + ): + yield event.delta.partial_json + + if has_response_schema_tool_call: + continue + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama_structured( self, @@ -1372,6 +1471,7 @@ class LLMClient: model=model, messages=messages, response_format=response_format, + tools=parsed_tools, max_tokens=max_tokens, ) case LLMProvider.OLLAMA: @@ -1416,4 +1516,4 @@ class LLMClient: contents=query, config=config, ) - return response.text + return response.text \ No newline at end of file diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index a61d24ad..723fefe0 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -145,8 +145,6 @@ class LLMToolCallsHandler: tool_calls: List[AnthropicToolCall], ) -> List[AnthropicToolCallMessage]: async_tool_calls_tasks = [] - print("--------------------------------") - print(tool_calls) for tool_call in tool_calls: tool_name = tool_call.name tool_handler = self.get_tool_handler(tool_name) @@ -160,8 +158,6 @@ class LLMToolCallsHandler: ) for tool_call, result in zip(tool_calls, tool_call_results) ] - print("--------------------------------") - print(tool_call_messages) return tool_call_messages # ? Tool call handlers diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index ee5f0224..3bb95c5e 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,6 +1,6 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides @@ -29,11 +29,11 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ - LLMMessage( + LLMSystemMessage( role="system", content=system_prompt, ), - LLMMessage( + LLMUserMessage( role="user", content=get_user_prompt(prompt, n_slides, language, content), ), From dc62eb72d171e8c2edfa49d6e0312983e3ca94aa Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Sat, 9 Aug 2025 01:36:16 +0545 Subject: [PATCH 6/8] feat(fastapi): adds anthropic web search, fix(fastapi): llm messages to system and user message --- .../fastapi/api/v1/ppt/endpoints/outlines.py | 3 ++ servers/fastapi/api/v1/test/router.py | 34 ++---------- servers/fastapi/models/llm_tools.py | 1 - servers/fastapi/services/llm_client.py | 45 ++++++++++------ .../services/llm_tool_calls_handler.py | 12 ++++- servers/fastapi/utils/llm_calls/edit_slide.py | 8 ++- .../utils/llm_calls/edit_slide_html.py | 6 +-- .../generate_presentation_outlines.py | 4 +- .../generate_presentation_structure.py | 8 ++- .../utils/llm_calls/generate_slide_content.py | 8 ++- .../llm_calls/select_slide_type_on_edit.py | 8 ++- servers/fastapi/utils/schema_utils.py | 53 +++++++++++++++++++ 12 files changed, 117 insertions(+), 73 deletions(-) diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py b/servers/fastapi/api/v1/ppt/endpoints/outlines.py index f1eff7ba..bf5b1489 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py +++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py @@ -72,6 +72,9 @@ async def stream_outlines( presentation_outlines_json = json.loads(presentation_outlines_text) except Exception as e: print(e) + with open("./debug/outlines.txt", "w") as f: + f.write(presentation_outlines_text) + print(presentation_outlines_text) raise HTTPException( status_code=400, detail="Failed to generate presentation outlines. Please try again.", diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 9db101e0..71b77d2b 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -1,11 +1,10 @@ -from datetime import datetime -import json from fastapi import APIRouter from pydantic import BaseModel, Field from models.llm_message import LLMUserMessage -from models.llm_tools import LLMDynamicTool, SearchWebTool +from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient +from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline from utils.llm_provider import get_model API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) @@ -24,30 +23,7 @@ class ResponseContent(BaseModel): async def test(): client = LLMClient() - async def get_current_datetime_tool_handler(_) -> str: - return datetime.now().isoformat() + response = await client._search_anthropic("Trending AI tool now") + # print(response) - get_current_datetime_tool = LLMDynamicTool( - name="GetDateTimeDynamicTool", - description="Get the current date and time", - handler=get_current_datetime_tool_handler, - ) - - text_content = "" - - async for chunk in client.stream_structured( - model=get_model(), - messages=[ - LLMUserMessage( - content="What is the current date and time ? What is the trending AI tool now ? Use Available tools to get the information." - ), - ], - response_format=ResponseContent.model_json_schema(), - tools=[ - SearchWebTool, - get_current_datetime_tool, - ], - ): - text_content += chunk - - return {"data": text_content} + return {"data": ""} diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py index 4bede740..ccf64e67 100644 --- a/servers/fastapi/models/llm_tools.py +++ b/servers/fastapi/models/llm_tools.py @@ -9,7 +9,6 @@ class LLMTool(BaseModel): class LLMDynamicTool(LLMTool): name: str description: str - strict: bool = False parameters: dict = {} handler: Callable[..., Coroutine[Any, Any, str]] diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 847a4e80..e220f577 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -43,13 +43,11 @@ from utils.get_env import ( get_google_api_key_env, get_ollama_url_env, get_openai_api_key_env, - get_openai_model_env, get_tool_calls_env, ) from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none -from utils.randomizers import get_random_uuid -from utils.schema_utils import ensure_strict_json_schema +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema class LLMClient: @@ -455,10 +453,10 @@ class LLMClient: LLMDynamicTool( name="ResponseSchema", description="Provide response to the user", - strict=strict, parameters=response_schema, handler=do_nothing_async, - ) + ), + strict=strict, ) ) @@ -557,7 +555,7 @@ class LLMClient: { "name": "ResponseSchema", "description": "Provide response to the user", - "parameters": response_format, + "parameters_json_schema": response_format, } ] ) @@ -571,7 +569,7 @@ class LLMClient: tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="application/json" if not tools else None, - response_json_schema=response_format if not tools else None, + response_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) @@ -1114,10 +1112,10 @@ class LLMClient: LLMDynamicTool( name="ResponseSchema", description="Provide response to the user", - strict=strict, parameters=response_schema, handler=do_nothing_async, - ) + ), + strict=strict, ) ) @@ -1235,10 +1233,11 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[dict]] = None, depth: int = 0, - ): + ) -> AsyncGenerator[str, None]: + client: genai.Client = self._client - google_tools = [] + google_tools = None if tools: google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] google_tools.append( @@ -1247,13 +1246,14 @@ class LLMClient: { "name": "ResponseSchema", "description": "Provide response to the user", - "parameters": response_format, + "parameters_json_schema": response_format, } ] ) ) tool_calls: List[GoogleToolCall] = [] + has_response_schema_tool_call = False async for event in iterator_to_async(client.models.generate_content_stream)( model=model, contents=self._get_google_messages(messages), @@ -1277,7 +1277,6 @@ class LLMClient: for each in event.function_calls ] - has_response_schema_tool_call = False for each in tool_calls: if each.name == "ResponseSchema": has_response_schema_tool_call = True @@ -1317,7 +1316,7 @@ class LLMClient: tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, - ): + ) -> AsyncGenerator[str, None]: client: AsyncAnthropic = self._client async with client.messages.stream( model=model, @@ -1516,4 +1515,20 @@ class LLMClient: contents=query, config=config, ) - return response.text \ No newline at end of file + return response.text + + async def _search_anthropic(self, query: str) -> str: + client: AsyncAnthropic = self._client + + response = await client.messages.create( + model=get_model(), + max_tokens=4000, + messages=[{"role": "user", "content": query}], + tools=[ + {"type": "web_search_20250305", "name": "web_search", "max_uses": 1} + ], + ) + result = "\n".join( + [each.text for each in response.content if each.type == "text"] + ) + return result diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index 723fefe0..1d8ffec4 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -11,6 +11,7 @@ from models.llm_message import ( ) from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema class LLMToolCallsHandler: @@ -73,6 +74,9 @@ class LLMToolCallsHandler: description = tool.__doc__ or "" parameters = tool.model_json_schema() + if strict: + parameters = ensure_strict_json_schema(parameters, path=(), root=parameters) + return { "type": "function", "function": { @@ -85,6 +89,9 @@ class LLMToolCallsHandler: def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): parsed = self.parse_tool_openai(tool) + # parsed["function"]["parameters"] = flatten_json_schema( + # parsed["function"]["parameters"] + # ) return { "name": parsed["function"]["name"], "description": parsed["function"]["description"], @@ -185,9 +192,10 @@ class LLMToolCallsHandler: return await self.client._search_google(args.query) async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: - return "test" + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_anthropic(args.query) # Get current datetime tool call handler - async def get_current_datetime_tool_call_handler(self, arguments: str) -> str: + async def get_current_datetime_tool_call_handler(self, _) -> str: current_time = datetime.now() return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}" diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py index a8df598a..30599d08 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide.py +++ b/servers/fastapi/utils/llm_calls/edit_slide.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.sql.slide import SlideModel from services.llm_client import LLMClient @@ -41,12 +41,10 @@ def get_messages( language: str, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, slide_data, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/edit_slide_html.py b/servers/fastapi/utils/llm_calls/edit_slide_html.py index a5e2dfad..cf58d185 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide_html.py +++ b/servers/fastapi/utils/llm_calls/edit_slide_html.py @@ -1,5 +1,5 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from services.llm_client import LLMClient from utils.llm_provider import get_model @@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str): response = await client.generate( model=model, messages=[ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=get_user_prompt(prompt, html)), + LLMSystemMessage(content=system_prompt), + LLMUserMessage(content=get_user_prompt(prompt, html)), ], ) return extract_html_from_response(response) or html diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 3bb95c5e..6c0ad512 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,6 +1,6 @@ from typing import Optional -from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides @@ -30,11 +30,9 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ LLMSystemMessage( - role="system", content=system_prompt, ), LLMUserMessage( - role="user", content=get_user_prompt(prompt, n_slides, language, content), ), ] diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py index 47f47dba..1bfc0cd0 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from services.llm_client import LLMClient @@ -11,8 +11,7 @@ def get_messages( presentation_layout: PresentationLayoutModel, n_slides: int, data: str ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" You're a professional presentation designer with creative freedom to design engaging presentations. @@ -47,8 +46,7 @@ def get_messages( Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals. """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" {data} """, diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py index 62b87e2b..be19b168 100644 --- a/servers/fastapi/utils/llm_calls/generate_slide_content.py +++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.presentation_outline_model import SlideOutlineModel from services.llm_client import LLMClient @@ -39,12 +39,10 @@ def get_user_prompt(outline: str, language: str): def get_messages(outline: str, language: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(outline, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py index f3532b48..7235e558 100644 --- a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py +++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel from models.slide_layout_index import SlideLayoutIndex from models.sql.slide import SlideModel @@ -13,8 +13,7 @@ def get_messages( current_slide_layout: int, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" Select a Slide Layout index based on provided user prompt and current slide data. {layout.to_string()} @@ -26,8 +25,7 @@ def get_messages( **Go through all notes and steps and make sure they are followed, including mentioned constraints** """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" - User Prompt: {prompt} - Current Slide Data: {slide_data} diff --git a/servers/fastapi/utils/schema_utils.py b/servers/fastapi/utils/schema_utils.py index ae65f002..6cb01a0e 100644 --- a/servers/fastapi/utils/schema_utils.py +++ b/servers/fastapi/utils/schema_utils.py @@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: return resolved +# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions +def flatten_json_schema(schema: dict) -> dict: + root_schema = deepcopy(schema) + + def _flatten(node: Any) -> Any: + if isinstance(node, dict): + # If node is a pure $ref (or combined with extra fields), inline it + if "$ref" in node: + ref_value = node["$ref"] + assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}" + resolved = resolve_ref(root=root_schema, ref=ref_value) + assert isinstance(resolved, dict), ( + f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}" + ) + # Merge: referenced first, then overlay current (excluding $ref) + merged: dict[str, Any] = deepcopy(resolved) + for key, value in node.items(): + if key == "$ref": + continue + merged[key] = value + return _flatten(merged) + + flattened: dict[str, Any] = {} + for key, value in node.items(): + # Drop defs/definitions in output + if key in ("$defs", "definitions"): + continue + if key == "properties" and isinstance(value, dict): + flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()} + elif key in ("items", "contains", "additionalProperties", "not"): + if isinstance(value, dict): + flattened[key] = _flatten(value) + elif isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = value + elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value + return flattened + if isinstance(node, list): + return [_flatten(v) for v in node] + return node + + result = _flatten(schema) + # Ensure top-level cleanup just in case + if isinstance(result, dict): + result.pop("$defs", None) + result.pop("definitions", None) + return result + + # ? Not used def generate_constraint_sentences(schema: dict) -> str: """ From 3f523f149187109e30d1605edb950fa066612800 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Sat, 9 Aug 2025 03:03:13 +0545 Subject: [PATCH 7/8] refactor: removes redis service and env variables, fix(fastapi): user config bool env variables issues, parse tool fix for custom llm on tool call structured output --- servers/fastapi/api/main.py | 2 - .../fastapi/api/v1/ppt/endpoints/ollama.py | 2 +- servers/fastapi/api/v1/test/router.py | 29 ----- servers/fastapi/constants/llm.py | 4 +- servers/fastapi/models/user_config.py | 3 + servers/fastapi/services/llm_client.py | 14 ++- .../services/llm_tool_calls_handler.py | 2 +- servers/fastapi/services/redis_service.py | 115 ------------------ servers/fastapi/utils/get_env.py | 20 +-- .../generate_presentation_outlines.py | 7 +- servers/fastapi/utils/set_env.py | 4 + servers/fastapi/utils/user_config.py | 33 +++-- servers/nextjs/app/api/user-config/route.ts | 4 + servers/nextjs/components/AnthropicConfig.tsx | 21 +++- servers/nextjs/components/GoogleConfig.tsx | 24 +++- servers/nextjs/components/LLMSelection.tsx | 3 + servers/nextjs/components/OpenAIConfig.tsx | 22 +++- servers/nextjs/types/llm_config.ts | 1 + servers/nextjs/utils/providerUtils.ts | 1 + start.js | 1 + 20 files changed, 131 insertions(+), 181 deletions(-) delete mode 100644 servers/fastapi/api/v1/test/router.py delete mode 100644 servers/fastapi/services/redis_service.py diff --git a/servers/fastapi/api/main.py b/servers/fastapi/api/main.py index f5f4a8d3..80eea709 100644 --- a/servers/fastapi/api/main.py +++ b/servers/fastapi/api/main.py @@ -3,7 +3,6 @@ from fastapi.middleware.cors import CORSMiddleware from api.lifespan import app_lifespan from api.middlewares import UserConfigEnvUpdateMiddleware from api.v1.ppt.router import API_V1_PPT_ROUTER -from api.v1.test.router import API_V1_TEST_ROUTER app = FastAPI(lifespan=app_lifespan) @@ -11,7 +10,6 @@ app = FastAPI(lifespan=app_lifespan) # Routers app.include_router(API_V1_PPT_ROUTER) -app.include_router(API_V1_TEST_ROUTER) # Middlewares origins = ["*"] diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index adde8669..0dafa3e1 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -64,7 +64,7 @@ async def pull_model( # If the model is being pulled, return the model if saved_model_status: # If the model is being pulled, return the model - # ? If the model status is pulled in redis but was not found while listing pulled models, + # ? If the model status is pulled in database but was not found while listing pulled models, # ? it means the model was deleted and we need to pull it again if ( saved_model_status["status"] == "error" diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py deleted file mode 100644 index 71b77d2b..00000000 --- a/servers/fastapi/api/v1/test/router.py +++ /dev/null @@ -1,29 +0,0 @@ -from fastapi import APIRouter -from pydantic import BaseModel, Field - -from models.llm_message import LLMUserMessage -from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool -from services.llm_client import LLMClient -from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline -from utils.llm_provider import get_model - -API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) - - -class ResponseContent(BaseModel): - trending_ai_tool: str = Field( - description="The summary of the trending AI tool in about 50 words", - min_length=50, - max_length=100, - ) - current_date_time: str - - -@API_V1_TEST_ROUTER.get("") -async def test(): - client = LLMClient() - - response = await client._search_anthropic("Trending AI tool now") - # print(response) - - return {"data": ""} diff --git a/servers/fastapi/constants/llm.py b/servers/fastapi/constants/llm.py index ac4bd527..7d374f30 100644 --- a/servers/fastapi/constants/llm.py +++ b/servers/fastapi/constants/llm.py @@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1" # Default models DEFAULT_OPENAI_MODEL = "gpt-4.1" -DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash" -DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620" +DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash" +DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514" diff --git a/servers/fastapi/models/user_config.py b/servers/fastapi/models/user_config.py index 50783544..c040d22c 100644 --- a/servers/fastapi/models/user_config.py +++ b/servers/fastapi/models/user_config.py @@ -35,3 +35,6 @@ class UserConfig(BaseModel): TOOL_CALLS: Optional[bool] = None DISABLE_THINKING: Optional[bool] = None EXTENDED_REASONING: Optional[bool] = None + + # Web Search + WEB_GROUNDING: Optional[bool] = None diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index e220f577..3e8b35f2 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -44,10 +44,11 @@ from utils.get_env import ( get_ollama_url_env, get_openai_api_key_env, get_tool_calls_env, + get_web_grounding_env, ) from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none -from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema +from utils.schema_utils import ensure_strict_json_schema class LLMClient: @@ -62,6 +63,15 @@ class LLMClient: return False return parse_bool_or_none(get_tool_calls_env()) or False + # ? Web Grounding + def enable_web_grounding(self) -> bool: + if ( + self.llm_provider == LLMProvider.OLLAMA + or self.llm_provider == LLMProvider.CUSTOM + ): + return False + return parse_bool_or_none(get_web_grounding_env()) or False + # ? Disable thinking def disable_thinking(self) -> bool: return parse_bool_or_none(get_disable_thinking_env()) or False @@ -569,7 +579,7 @@ class LLMClient: tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="application/json" if not tools else None, - response_schema=response_format if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index 1d8ffec4..ed0d51ee 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -51,7 +51,7 @@ class LLMToolCallsHandler: self.dynamic_tools.append(tool) match self.client.llm_provider: - case LLMProvider.OPENAI: + case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM: return self.parse_tool_openai(tool, strict) case LLMProvider.ANTHROPIC: return self.parse_tool_anthropic(tool) diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py deleted file mode 100644 index f2e3d8c9..00000000 --- a/servers/fastapi/services/redis_service.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Optional -import redis -from redis.exceptions import RedisError - -from utils.get_env import ( - get_redis_db_env, - get_redis_host_env, - get_redis_password_env, - get_redis_port_env, -) - - -class RedisService: - def __init__(self): - self.redis_host = get_redis_host_env() or "localhost" - self.redis_port = int(get_redis_port_env() or "6379") - self.redis_db = int(get_redis_db_env() or "0") - self.redis_password = get_redis_password_env() or None - self.client = self._create_client() - - def _create_client(self) -> redis.Redis: - return redis.Redis( - host=self.redis_host, - port=self.redis_port, - db=self.redis_db, - password=self.redis_password, - decode_responses=True, - ) - - def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: - try: - return self.client.set(key, value, ex=expire) - except RedisError: - return False - - def get(self, key: str) -> Optional[str]: - try: - return self.client.get(key) - except RedisError: - return None - - def delete(self, key: str) -> bool: - try: - return bool(self.client.delete(key)) - except RedisError: - return False - - def exists(self, key: str) -> bool: - try: - return bool(self.client.exists(key)) - except RedisError: - return False - - def set_hash(self, name: str, mapping: dict) -> bool: - try: - return self.client.hmset(name, mapping) - except RedisError: - return False - - def get_hash(self, name: str) -> Optional[dict]: - try: - return self.client.hgetall(name) - except RedisError: - return None - - def delete_hash(self, name: str, *fields: str) -> int: - try: - return self.client.hdel(name, *fields) - except RedisError: - return 0 - - def set_list(self, name: str, values: list) -> bool: - try: - self.client.delete(name) - if values: - self.client.rpush(name, *values) - return True - except RedisError: - return False - - def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: - try: - return self.client.lrange(name, start, end) - except RedisError: - return None - - def add_to_set(self, name: str, *values: str) -> int: - try: - return self.client.sadd(name, *values) - except RedisError: - return 0 - - def get_set(self, name: str) -> Optional[set]: - try: - return self.client.smembers(name) - except RedisError: - return None - - def remove_from_set(self, name: str, *values: str) -> int: - try: - return self.client.srem(name, *values) - except RedisError: - return 0 - - def clear(self) -> bool: - try: - return self.client.flushdb() - except RedisError: - return False - - def close(self): - try: - self.client.close() - except RedisError: - pass diff --git a/servers/fastapi/utils/get_env.py b/servers/fastapi/utils/get_env.py index c2c72efd..fa80b2a2 100644 --- a/servers/fastapi/utils/get_env.py +++ b/servers/fastapi/utils/get_env.py @@ -81,22 +81,6 @@ def get_pixabay_api_key_env(): return os.getenv("PIXABAY_API_KEY") -def get_redis_host_env(): - return os.getenv("REDIS_HOST") - - -def get_redis_port_env(): - return os.getenv("REDIS_PORT") - - -def get_redis_db_env(): - return os.getenv("REDIS_DB") - - -def get_redis_password_env(): - return os.getenv("REDIS_PASSWORD") - - def get_tool_calls_env(): return os.getenv("TOOL_CALLS") @@ -107,3 +91,7 @@ def get_disable_thinking_env(): def get_extended_reasoning_env(): return os.getenv("EXTENDED_REASONING") + + +def get_web_grounding_env(): + return os.getenv("WEB_GROUNDING") diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 6c0ad512..892b9cff 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -4,7 +4,10 @@ from models.llm_message import LLMSystemMessage, LLMUserMessage from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides +from utils.get_env import get_web_grounding_env from utils.llm_provider import get_model +from utils.parsers import parse_bool_or_none +from utils.user_config import get_user_config system_prompt = """ You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content. @@ -49,11 +52,13 @@ async def generate_ppt_outline( client = LLMClient() + tools = [SearchWebTool, GetCurrentDatetimeTool] + async for chunk in client.stream_structured( model, get_messages(prompt, n_slides, language, content), response_model.model_json_schema(), strict=True, - tools=[SearchWebTool, GetCurrentDatetimeTool], + tools=tools if client.enable_web_grounding() else None, ): yield chunk diff --git a/servers/fastapi/utils/set_env.py b/servers/fastapi/utils/set_env.py index 7ac0e335..ea3758f3 100644 --- a/servers/fastapi/utils/set_env.py +++ b/servers/fastapi/utils/set_env.py @@ -79,3 +79,7 @@ def set_disable_thinking_env(value): def set_extended_reasoning_env(value): os.environ["EXTENDED_REASONING"] = value + + +def set_web_grounding_env(value): + os.environ["WEB_GROUNDING"] = value \ No newline at end of file diff --git a/servers/fastapi/utils/user_config.py b/servers/fastapi/utils/user_config.py index 06235d5a..49fd1722 100644 --- a/servers/fastapi/utils/user_config.py +++ b/servers/fastapi/utils/user_config.py @@ -22,6 +22,7 @@ from utils.get_env import ( get_image_provider_env, get_pixabay_api_key_env, get_extended_reasoning_env, + get_web_grounding_env, ) from utils.parsers import parse_bool_or_none from utils.set_env import ( @@ -43,6 +44,7 @@ from utils.set_env import ( set_image_provider_env, set_pixabay_api_key_env, set_tool_calls_env, + set_web_grounding_env, ) @@ -76,12 +78,26 @@ def get_user_config(): IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(), PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(), - TOOL_CALLS=existing_config.TOOL_CALLS - or parse_bool_or_none(get_tool_calls_env()), - DISABLE_THINKING=existing_config.DISABLE_THINKING - or parse_bool_or_none(get_disable_thinking_env()), - EXTENDED_REASONING=existing_config.EXTENDED_REASONING - or parse_bool_or_none(get_extended_reasoning_env()), + TOOL_CALLS=( + existing_config.TOOL_CALLS + if existing_config.TOOL_CALLS is not None + else (parse_bool_or_none(get_tool_calls_env()) or False) + ), + DISABLE_THINKING=( + existing_config.DISABLE_THINKING + if existing_config.DISABLE_THINKING is not None + else (parse_bool_or_none(get_disable_thinking_env()) or False) + ), + EXTENDED_REASONING=( + existing_config.EXTENDED_REASONING + if existing_config.EXTENDED_REASONING is not None + else (parse_bool_or_none(get_extended_reasoning_env()) or False) + ), + WEB_GROUNDING=( + existing_config.WEB_GROUNDING + if existing_config.WEB_GROUNDING is not None + else (parse_bool_or_none(get_web_grounding_env()) or False) + ), ) @@ -122,5 +138,6 @@ def update_env_with_user_config(): if user_config.DISABLE_THINKING: set_disable_thinking_env(str(user_config.DISABLE_THINKING)) if user_config.EXTENDED_REASONING: - if user_config.EXTENDED_REASONING: - set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + if user_config.WEB_GROUNDING: + set_web_grounding_env(str(user_config.WEB_GROUNDING)) diff --git a/servers/nextjs/app/api/user-config/route.ts b/servers/nextjs/app/api/user-config/route.ts index ff3c643a..03b801e3 100644 --- a/servers/nextjs/app/api/user-config/route.ts +++ b/servers/nextjs/app/api/user-config/route.ts @@ -57,6 +57,10 @@ export async function POST(request: Request) { userConfig.EXTENDED_REASONING === undefined ? existingConfig.EXTENDED_REASONING : userConfig.EXTENDED_REASONING, + WEB_GROUNDING: + userConfig.WEB_GROUNDING === undefined + ? existingConfig.WEB_GROUNDING + : userConfig.WEB_GROUNDING, USE_CUSTOM_URL: userConfig.USE_CUSTOM_URL === undefined ? existingConfig.USE_CUSTOM_URL diff --git a/servers/nextjs/components/AnthropicConfig.tsx b/servers/nextjs/components/AnthropicConfig.tsx index 567846b5..4b61bb65 100644 --- a/servers/nextjs/components/AnthropicConfig.tsx +++ b/servers/nextjs/components/AnthropicConfig.tsx @@ -19,6 +19,7 @@ interface AnthropicConfigProps { anthropicApiKey: string; anthropicModel: string; extendedReasoning: boolean; + webGrounding?: boolean; onInputChange: (value: string | boolean, field: string) => void; } @@ -27,6 +28,7 @@ export default function AnthropicConfig({ anthropicApiKey, anthropicModel, extendedReasoning, + webGrounding, onInputChange, }: AnthropicConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -65,7 +67,7 @@ export default function AnthropicConfig({ const data = await response.json(); setAvailableModels(data); setModelsChecked(true); - onInputChange("claude-3-5-sonnet-20241022", "anthropic_model"); + onInputChange("claude-sonnet-4-20250514", "anthropic_model"); } else { console.error('Failed to fetch models'); setAvailableModels([]); @@ -226,6 +228,23 @@ export default function AnthropicConfig({ ) : null} + + {/* Web Grounding Toggle - at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/components/GoogleConfig.tsx b/servers/nextjs/components/GoogleConfig.tsx index 6746f779..8d333dd3 100644 --- a/servers/nextjs/components/GoogleConfig.tsx +++ b/servers/nextjs/components/GoogleConfig.tsx @@ -13,16 +13,19 @@ import { import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { cn } from "@/lib/utils"; import { toast } from "sonner"; +import { Switch } from "./ui/switch"; interface GoogleConfigProps { googleApiKey: string; googleModel: string; - onInputChange: (value: string, field: string) => void; + webGrounding?: boolean; + onInputChange: (value: string | boolean, field: string) => void; } export default function GoogleConfig({ googleApiKey, googleModel, + webGrounding, onInputChange }: GoogleConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -61,7 +64,7 @@ export default function GoogleConfig({ const data = await response.json(); setAvailableModels(data); setModelsChecked(true); - onInputChange("models/gemini-2.0-flash", "google_model"); + onInputChange("models/gemini-2.5-flash", "google_model"); } else { console.error('Failed to fetch models'); setAvailableModels([]); @@ -205,6 +208,23 @@ export default function GoogleConfig({ ) : null} + + {/* Web Grounding Toggle - at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/components/LLMSelection.tsx b/servers/nextjs/components/LLMSelection.tsx index 422e8333..ed308226 100644 --- a/servers/nextjs/components/LLMSelection.tsx +++ b/servers/nextjs/components/LLMSelection.tsx @@ -149,6 +149,7 @@ export default function LLMProviderSelection({ @@ -158,6 +159,7 @@ export default function LLMProviderSelection({ @@ -168,6 +170,7 @@ export default function LLMProviderSelection({ anthropicApiKey={llmConfig.ANTHROPIC_API_KEY || ""} anthropicModel={llmConfig.ANTHROPIC_MODEL || ""} extendedReasoning={llmConfig.EXTENDED_REASONING || false} + webGrounding={llmConfig.WEB_GROUNDING || false} onInputChange={input_field_changed} /> diff --git a/servers/nextjs/components/OpenAIConfig.tsx b/servers/nextjs/components/OpenAIConfig.tsx index b73695e9..7c465a99 100644 --- a/servers/nextjs/components/OpenAIConfig.tsx +++ b/servers/nextjs/components/OpenAIConfig.tsx @@ -13,16 +13,19 @@ import { import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { cn } from "@/lib/utils"; import { toast } from "sonner"; +import { Switch } from "./ui/switch"; interface OpenAIConfigProps { openaiApiKey: string; openaiModel: string; - onInputChange: (value: string, field: string) => void; + webGrounding?: boolean; + onInputChange: (value: string | boolean, field: string) => void; } export default function OpenAIConfig({ openaiApiKey, openaiModel, + webGrounding, onInputChange }: OpenAIConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -210,6 +213,23 @@ export default function OpenAIConfig({ ) : null} + + {/* Web Grounding Toggle - show at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/types/llm_config.ts b/servers/nextjs/types/llm_config.ts index 0a44e639..5b73b215 100644 --- a/servers/nextjs/types/llm_config.ts +++ b/servers/nextjs/types/llm_config.ts @@ -31,6 +31,7 @@ export interface LLMConfig { TOOL_CALLS?: boolean; DISABLE_THINKING?: boolean; EXTENDED_REASONING?: boolean; + WEB_GROUNDING?: boolean; // Only used in UI settings USE_CUSTOM_URL?: boolean; diff --git a/servers/nextjs/utils/providerUtils.ts b/servers/nextjs/utils/providerUtils.ts index a0efb4f8..4c776ee8 100644 --- a/servers/nextjs/utils/providerUtils.ts +++ b/servers/nextjs/utils/providerUtils.ts @@ -48,6 +48,7 @@ export const updateLLMConfig = ( tool_calls: "TOOL_CALLS", disable_thinking: "DISABLE_THINKING", extended_reasoning: "EXTENDED_REASONING", + web_grounding: "WEB_GROUNDING", }; const configKey = fieldMappings[field]; diff --git a/start.js b/start.js index f2dfa1b0..2a0e336f 100644 --- a/start.js +++ b/start.js @@ -81,6 +81,7 @@ const setupUserConfigFromEnv = () => { TOOL_CALLS: process.env.TOOL_CALLS || existingConfig.TOOL_CALLS, DISABLE_THINKING: process.env.DISABLE_THINKING || existingConfig.DISABLE_THINKING, EXTENDED_REASONING: process.env.EXTENDED_REASONING || existingConfig.EXTENDED_REASONING, + WEB_GROUNDING: process.env.WEB_GROUNDING || existingConfig.WEB_GROUNDING, USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL, }; From dc474cf0d550c783f18723fa9efc54a1f6aa322b Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Sat, 9 Aug 2025 03:07:33 +0545 Subject: [PATCH 8/8] docs(readme): adds new Web Grounding environment variable --- README.md | 1 + docker-compose.yml | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/README.md b/README.md index 72a53771..a7851cc4 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep - **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom** - **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output. - **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled. +- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results. You can also set the following environment variables to customize the image generation provider and API keys: diff --git a/docker-compose.yml b/docker-compose.yml index 6a0dfbc2..6ea2643a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} production-gpu: @@ -60,6 +63,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development: @@ -87,6 +93,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development-gpu: @@ -121,4 +130,7 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} \ No newline at end of file