diff --git a/servers/fastapi/enums/llm_call_type.py b/servers/fastapi/enums/llm_call_type.py new file mode 100644 index 00000000..e37fe4ae --- /dev/null +++ b/servers/fastapi/enums/llm_call_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LLMCallType(Enum): + UNSTRUCTURED = "unstructured" + UNSTRUCTURED_STREAM = "unstructured_stream" + STRUCTURED = "structured" + STRUCTURED_STREAM = "structured_stream" diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 51284173..6b364c09 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -1,7 +1,28 @@ -from typing import Literal +from typing import List, Literal, Optional from pydantic import BaseModel class LLMMessage(BaseModel): - role: Literal["user", "system"] + pass + + +class LLMUserMessage(LLMMessage): + role: Literal["user"] content: str + + +class LLMSystemMessage(LLMMessage): + role: Literal["system"] + content: str + + +class LLMToolCallMessage(LLMMessage): + role: Literal["tool"] + content: str + tool_call_id: str + + +class LLMAssistantMessage(LLMMessage): + role: Literal["assistant"] + content: str | None = None + tool_calls: Optional[List[dict]] = None diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py new file mode 100644 index 00000000..51cc4731 --- /dev/null +++ b/servers/fastapi/models/llm_tool_call.py @@ -0,0 +1,8 @@ +from typing import Optional +from pydantic import BaseModel + + +class LLMToolCall(BaseModel): + id: Optional[str] = None + name: str + arguments: Optional[str] = None diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py new file mode 100644 index 00000000..ed8ad6a1 --- /dev/null +++ b/servers/fastapi/models/llm_tools.py @@ -0,0 +1,21 @@ +from typing import Any, Callable, Coroutine +from pydantic import BaseModel, Field + + +class LLMTool(BaseModel): + pass + + +class LLMDynamicTool(LLMTool): + name: str + description: str + parameters: dict + handler: Callable[..., Coroutine[Any, Any, str]] + + +class SearchWebTool(LLMTool): + query: str = Field(description="The query to search the web for") + + +class GetCurrentDatetimeTool(LLMTool): + pass diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index eaf61770..1e52abef 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -3,13 +3,22 @@ import json from typing import List, Optional from fastapi import HTTPException from openai import AsyncOpenAI +from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall from google import genai from google.genai.types import GenerateContentConfig from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider -from models.llm_message import LLMMessage +from models.llm_message import ( + LLMAssistantMessage, + LLMMessage, + LLMSystemMessage, + LLMUserMessage, +) +from models.llm_tool_call import LLMToolCall +from models.llm_tools import LLMDynamicTool, LLMTool +from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async from utils.get_env import ( get_anthropic_api_key_env, @@ -23,6 +32,7 @@ from utils.get_env import ( ) from utils.llm_provider import get_llm_provider from utils.parsers import parse_bool_or_none +from utils.randomizers import get_random_uuid from utils.schema_utils import ensure_strict_json_schema @@ -30,6 +40,7 @@ class LLMClient: def __init__(self): self.llm_provider = get_llm_provider() self._client = self._get_client() + self.tool_calls_handler = LLMToolCallsHandler(self) # ? Use tool calls def use_tool_calls(self) -> bool: @@ -104,15 +115,19 @@ class LLMClient: # ? Prompts def _get_system_prompt(self, messages: List[LLMMessage]) -> str: for message in messages: - if message.role == "system": + if isinstance(message, LLMSystemMessage): return message.content return "" def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]: - return [message.content for message in messages if message.role == "user"] + return [ + message.content + for message in messages + if isinstance(message, LLMUserMessage) + ] def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if message.role == "user"] + return [message for message in messages if isinstance(message, LLMUserMessage)] # ? Generate Unstructured Content async def _generate_openai( @@ -120,6 +135,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -127,8 +143,36 @@ class LLMClient: model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, + tools=tools, extra_body=extra_body, ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + [ + LLMToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_calls + ] + ) + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[ + tool_call.model_dump() for tool_call in tool_call_messages + ], + ), + *tool_call_messages, + ] + return await self._generate_openai( + model, new_messages, max_tokens, tools, extra_body + ) + return response.choices[0].message.content async def _generate_google( @@ -192,7 +236,10 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.get_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: @@ -220,6 +267,7 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -231,46 +279,75 @@ class LLMClient: path=(), root=response_schema, ) - if not use_tool_calls: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - response_format={ - "type": "json_schema", - "json_schema": ( - { - "name": "ResponseSchema", - "strict": strict, - "schema": response_schema, - } - ), - }, - max_completion_tokens=max_tokens, - extra_body=extra_body, - ) - content = response.choices[0].message.content - else: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - tools=[ + response = await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + response_format={ + "type": "json_schema", + "json_schema": ( { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, + "name": "ResponseSchema", + "strict": strict, + "schema": response_schema, } - ], - tool_choice="required", - max_completion_tokens=max_tokens, - extra_body=extra_body, + ), + }, + max_completion_tokens=max_tokens, + extra_body=extra_body, + ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + [ + LLMToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_calls + ] ) - tool_calls = response.choices[0].message.tool_calls - if tool_calls: - content = tool_calls[0].function.arguments + new_messages = [ + *messages, + LLMAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[each.model_dump() for each in tool_calls], + ), + *tool_call_messages, + ] + return await self._generate_openai_structured( + model, + new_messages, + response_format, + strict, + max_tokens, + tools, + extra_body, + ) + content = response.choices[0].message.content + # else: + # response = await client.chat.completions.create( + # model=model, + # messages=[message.model_dump() for message in messages], + # tools=[ + # { + # "type": "function", + # "function": { + # "name": "ResponseSchema", + # "description": "A response to the user's message", + # "strict": strict, + # "parameters": response_format, + # }, + # } + # ], + # tool_choice="required", + # max_completion_tokens=max_tokens, + # extra_body=extra_body, + # ) + # tool_calls = response.choices[0].message.tool_calls + # if tool_calls: + # content = tool_calls[0].function.arguments if content: return json.loads(content) @@ -404,6 +481,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -413,9 +491,35 @@ class LLMClient: max_completion_tokens=max_tokens, extra_body=extra_body, ) as stream: + tool_calls = [] async for event in stream: - if event.type == "content.delta": + if ( + event.type == "tool_calls.function.arguments.delta" + and event.name == "ResponseSchema" + ): + yield event.arguments_delta + elif event.type == "content.delta": yield event.delta + elif ( + event.type == "tool_calls.function.arguments.done" + and event.name != "ResponseSchema" + ): + tool_calls.append( + LLMToolCall( + id=get_random_uuid(), + name=event.name, + arguments=event.arguments, + ) + ) + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_openai(tool_calls) + ) + + new_messages = [ + *messages, + *tool_call_messages, + ] async def _stream_google( self, @@ -497,6 +601,7 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, extra_body: Optional[dict] = None, ): client: AsyncOpenAI = self._client @@ -659,3 +764,10 @@ class LLMClient: return self._stream_custom_structured( model, messages, response_format, strict, max_tokens ) + + # ? Tool call handling + def get_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): + if tools is None: + return None + parsed_tools = map(self.tool_calls_handler.parse_tool, tools) + return list(parsed_tools) diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py new file mode 100644 index 00000000..c21a1dd0 --- /dev/null +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -0,0 +1,111 @@ +import asyncio +from datetime import datetime +from typing import Any, Callable, Coroutine, List +from fastapi import HTTPException +from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall +from enums.llm_call_type import LLMCallType +from enums.llm_provider import LLMProvider +from models.llm_message import LLMMessage, LLMToolCallMessage +from models.llm_tool_call import LLMToolCall +from models.llm_tools import ( + GetCurrentDatetimeTool, + LLMDynamicTool, + LLMTool, + SearchWebTool, +) + + +class LLMToolCallsHandler: + def __init__(self, client): + from services.llm_client import LLMClient + + self.client: LLMClient = client + + self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = { + "SearchWebTool": self.search_web_tool_call_handler, + "GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler, + } + self.dynamic_tools: List[LLMDynamicTool] = [] + + def get_tool_handler( + self, tool_name: str + ) -> Callable[..., Coroutine[Any, Any, str]]: + handler = self.tools_map.get(tool_name) + if not handler: + dynamic_tools = list( + filter(lambda tool: tool.name == tool_name, self.dynamic_tools) + ) + if dynamic_tools: + return dynamic_tools[0].handler + raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found") + + def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False): + if isinstance(tool, LLMDynamicTool): + self.dynamic_tools.append(tool) + + match self.client.llm_provider: + case LLMProvider.OPENAI: + return self.parse_tool_openai(tool, strict) + case LLMProvider.ANTHROPIC: + return self.parse_tool_anthropic(tool) + case LLMProvider.GOOGLE: + return self.parse_tool_google(tool) + case _: + raise ValueError( + f"LLM provider must be either openai, anthropic, or google" + ) + + def parse_tool_openai( + self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False + ): + if isinstance(tool, LLMDynamicTool): + name = tool.name + description = tool.description + parameters = tool.parameters + else: + name = tool.__class__.__name__ + description = tool.__class__.__doc__ or "" + parameters = tool.model_dump(mode="json") + + return { + "type": "function", + "function": { + "name": name, + "description": description, + "strict": strict, + "parameters": parameters, + }, + } + + def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): + pass + + def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): + pass + + async def handle_tool_calls_openai( + self, + tool_calls: List[LLMToolCall], + ) -> List[LLMToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(tool_call.arguments)) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + return [ + LLMToolCallMessage( + role="tool", + content=result, + tool_call_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + + # ? Tool call handlers + async def search_web_tool_call_handler(self, tool_call: dict) -> str: + pass + + async def get_current_datetime_tool_call_handler(self, tool_call: dict) -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S")