presenton/electron/servers/fastapi/services/llm_tool_calls_handler.py
sudipnext 4370b0291b feat: enhance web search functionality in LLM client and presentation generation
- Added methods to enable web search based on user settings and LLM provider.
- Updated presentation outline generation to utilize prefetched web facts.
- Modified system prompts to clarify web search usage.
- Improved UI text in advanced settings to better inform users about web search controls.
2026-04-16 14:40:31 +05:45

211 lines
8 KiB
Python

import asyncio
from datetime import datetime
import json
from typing import Any, Callable, Coroutine, List, Optional
from fastapi import HTTPException
from enums.llm_provider import LLMProvider
from models.llm_message import (
AnthropicToolCallMessage,
GoogleToolCallMessage,
OpenAIToolCallMessage,
)
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
from utils.schema_utils import (
ensure_strict_json_schema,
flatten_json_schema,
remove_titles_from_schema,
)
class LLMToolCallsHandler:
def __init__(self, client):
from services.llm_client import LLMClient
self.client: LLMClient = client
self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = {
"SearchWebTool": self.search_web_tool_call_handler,
"GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler,
}
self.dynamic_tools: List[LLMDynamicTool] = []
def get_tool_handler(
self, tool_name: str
) -> Callable[..., Coroutine[Any, Any, str]]:
handler = self.tools_map.get(tool_name)
if handler:
return handler
else:
dynamic_tools = list(
filter(lambda tool: tool.name == tool_name, self.dynamic_tools)
)
if dynamic_tools:
return dynamic_tools[0].handler
raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found")
def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None):
if tools is None:
return None
parsed_tools = map(self.parse_tool, tools)
return list(parsed_tools)
def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False):
if isinstance(tool, LLMDynamicTool):
self.dynamic_tools.append(tool)
match self.client.llm_provider:
case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM | LLMProvider.CODEX:
return self.parse_tool_openai(tool, strict)
case LLMProvider.ANTHROPIC:
return self.parse_tool_anthropic(tool)
case LLMProvider.GOOGLE:
return self.parse_tool_google(tool)
case _:
raise ValueError(
"LLM provider must be one of: openai, anthropic, google, codex, ollama, custom"
)
def parse_tool_openai(
self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False
):
if isinstance(tool, LLMDynamicTool):
name = tool.name
description = tool.description
parameters = tool.parameters
else:
name = tool.__name__
description = tool.__doc__ or ""
parameters = tool.model_json_schema()
if strict:
parameters = ensure_strict_json_schema(parameters, path=(), root=parameters)
return {
"type": "function",
"function": {
"name": name,
"description": description,
"strict": strict,
"parameters": parameters,
},
}
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
parsed = self.parse_tool_openai(tool)
parsed["function"]["parameters"] = (
remove_titles_from_schema(
flatten_json_schema(parsed["function"]["parameters"])
)
if parsed["function"]["parameters"]
else {}
)
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],
"parameters": parsed["function"]["parameters"],
}
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
parsed = self.parse_tool_openai(tool)
input_schema = parsed["function"]["parameters"]
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],
"input_schema": {"type": "object"} if input_schema == {} else input_schema,
}
async def handle_tool_calls_openai(
self,
tool_calls: List[OpenAIToolCall],
) -> List[OpenAIToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
OpenAIToolCallMessage(
content=result,
tool_call_id=tool_call.id,
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
return tool_call_messages
async def handle_tool_calls_google(
self,
tool_calls: List[GoogleToolCall],
) -> List[GoogleToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments)))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
GoogleToolCallMessage(
id=tool_call.id,
name=tool_call.name,
response={"result": result},
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
return tool_call_messages
async def handle_tool_calls_anthropic(
self,
tool_calls: List[AnthropicToolCall],
) -> List[AnthropicToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input)))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
AnthropicToolCallMessage(
content=result,
tool_use_id=tool_call.id,
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
return tool_call_messages
# ? Tool call handlers
# Search web tool call handler
async def search_web_tool_call_handler(self, arguments: str) -> str:
match self.client.llm_provider:
case LLMProvider.OPENAI | LLMProvider.CODEX:
return await self.search_web_tool_call_handler_openai(arguments)
case LLMProvider.ANTHROPIC:
return await self.search_web_tool_call_handler_anthropic(arguments)
case LLMProvider.GOOGLE:
return await self.search_web_tool_call_handler_google(arguments)
case _:
return (
"Web search tool call handler not implemented for this LLM provider: "
+ self.client.llm_provider.value
)
async def search_web_tool_call_handler_openai(self, arguments: str) -> str:
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_openai(args.query)
async def search_web_tool_call_handler_google(self, arguments: str) -> str:
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_google(args.query)
async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str:
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_anthropic(args.query)
# Get current datetime tool call handler
async def get_current_datetime_tool_call_handler(self, _) -> str:
current_time = datetime.now()
return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}"