diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 1af2d76a..0bc78101 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -13,9 +13,9 @@ API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) class ResponseContent(BaseModel): trending_ai_tool: str = Field( - description="The summary of the trending AI tool in about 150 words", - min_length=150, - max_length=200, + description="The summary of the trending AI tool in about 50 words", + min_length=50, + max_length=100, ) current_date_time: str @@ -30,13 +30,12 @@ async def test(): get_current_datetime_tool = LLMDynamicTool( name="GetDateTimeDynamicTool", description="Get the current date and time", - parameters=None, handler=get_current_datetime_tool_handler, ) text_content = "" - async for event in client.stream_structured( + response = await client.generate_structured( model=get_model(), messages=[ LLMUserMessage( @@ -48,7 +47,6 @@ async def test(): SearchWebTool, get_current_datetime_tool, ], - ): - text_content += event + ) - return {"data": text_content} + return {"data": response} diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index bff98951..db741ca4 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -2,6 +2,8 @@ from typing import Any, List, Literal, Optional from pydantic import BaseModel from google.genai.types import Content as GoogleContent +from models.llm_tool_call import AnthropicToolCall + class LLMMessage(BaseModel): pass @@ -28,6 +30,22 @@ class GoogleAssistantMessage(LLMMessage): content: GoogleContent +class AnthropicAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: List[AnthropicToolCall] + + +class AnthropicToolCallMessage(LLMMessage): + type: Literal["tool_result"] = "tool_result" + tool_use_id: str + content: str + + +class AnthropicUserMessage(LLMMessage): + role: Literal["user"] = "user" + content: List[AnthropicToolCallMessage] + + class OpenAIToolCallMessage(LLMMessage): role: Literal["tool"] = "tool" content: str diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py index 2eb306e6..5eb1f008 100644 --- a/servers/fastapi/models/llm_tool_call.py +++ b/servers/fastapi/models/llm_tool_call.py @@ -20,3 +20,10 @@ class OpenAIToolCall(LLMToolCall): class GoogleToolCall(LLMToolCall): name: str arguments: Optional[dict] = None + + +class AnthropicToolCall(LLMToolCall): + type: Literal["tool_use"] = "tool_use" + id: str + name: str + input: object diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py index 6ee2e214..4bede740 100644 --- a/servers/fastapi/models/llm_tools.py +++ b/servers/fastapi/models/llm_tools.py @@ -10,7 +10,7 @@ class LLMDynamicTool(LLMTool): name: str description: str strict: bool = False - parameters: Optional[dict] = None + parameters: dict = {} handler: Callable[..., Coroutine[Any, Any, str]] diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 2ffcfb1e..07789ed4 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -15,6 +15,8 @@ from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider from models.llm_message import ( + AnthropicAssistantMessage, + AnthropicUserMessage, GoogleAssistantMessage, GoogleToolCallMessage, OpenAIAssistantMessage, @@ -23,6 +25,7 @@ from models.llm_message import ( LLMUserMessage, ) from models.llm_tool_call import ( + AnthropicToolCall, GoogleToolCall, LLMToolCall, OpenAIToolCall, @@ -157,8 +160,10 @@ class LLMClient: return contents - def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if isinstance(message, LLMUserMessage)] + def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + return [ + message for message in messages if not isinstance(message, LLMSystemMessage) + ] # ? Generate Unstructured Content async def _generate_openai( @@ -287,25 +292,61 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, depth: int = 0, - ): + ) -> str | None: client: AsyncAnthropic = self._client + response: AnthropicMessage = await client.messages.create( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], + tools=tools, max_tokens=max_tokens or 4000, ) - text = "" + text_content = None + tool_calls: List[AnthropicToolCall] = [] for content in response.content: if content.type == "text" and isinstance(content.text, str): - text += content.text - if text == "": - return None - return text + text_content = content.text + + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_ollama( self, @@ -361,7 +402,10 @@ class LLMClient: ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic( - model=model, messages=messages, max_tokens=max_tokens + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, ) case LLMProvider.OLLAMA: content = await self._generate_ollama( @@ -586,6 +630,7 @@ class LLMClient: model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, ): @@ -595,7 +640,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -603,19 +648,51 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) - content: dict | None = None - for content_block in response.content: - if content_block.type == "tool_use": - content = content_block.input + tool_calls: List[AnthropicToolCall] = [] + for content in response.content: + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) - return content + for each in tool_calls: + if each.name == "ResponseSchema": + return each.input + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + return None async def _generate_ollama_structured( self, @@ -690,6 +767,7 @@ class LLMClient: model=model, messages=messages, response_format=response_format, + tools=parsed_tools, max_tokens=max_tokens, ) case LLMProvider.OLLAMA: @@ -873,6 +951,7 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, depth: int = 0, ): client: AsyncAnthropic = self._client @@ -881,14 +960,17 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, + tools=tools, ) as stream: + tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "text" and isinstance(event.text, str): - yield event.text + if event.type == "input_json": + event.partial_json + pass def _stream_ollama( self, diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index 52aad663..a61d24ad 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -4,8 +4,12 @@ import json from typing import Any, Callable, Coroutine, List, Optional from fastapi import HTTPException from enums.llm_provider import LLMProvider -from models.llm_message import GoogleToolCallMessage, OpenAIToolCallMessage -from models.llm_tool_call import GoogleToolCall, OpenAIToolCall +from models.llm_message import ( + AnthropicToolCallMessage, + GoogleToolCallMessage, + OpenAIToolCallMessage, +) +from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool @@ -88,7 +92,13 @@ class LLMToolCallsHandler: } def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): - pass + parsed = self.parse_tool_openai(tool) + input_schema = parsed["function"]["parameters"] + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "input_schema": {"type": "object"} if input_schema == {} else input_schema, + } async def handle_tool_calls_openai( self, @@ -130,6 +140,30 @@ class LLMToolCallsHandler: ] return tool_call_messages + async def handle_tool_calls_anthropic( + self, + tool_calls: List[AnthropicToolCall], + ) -> List[AnthropicToolCallMessage]: + async_tool_calls_tasks = [] + print("--------------------------------") + print(tool_calls) + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + AnthropicToolCallMessage( + content=result, + tool_use_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + print("--------------------------------") + print(tool_call_messages) + return tool_call_messages + # ? Tool call handlers # Search web tool call handler async def search_web_tool_call_handler(self, arguments: str) -> str: