feat(fastapi): adds tool calls support for anthropic generate

This commit is contained in:
sauravniraula 2025-08-08 09:06:05 +05:45
parent 49342e7c3c
commit 84fd0dee1a
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
6 changed files with 175 additions and 36 deletions

View file

@ -13,9 +13,9 @@ API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"])
class ResponseContent(BaseModel):
trending_ai_tool: str = Field(
description="The summary of the trending AI tool in about 150 words",
min_length=150,
max_length=200,
description="The summary of the trending AI tool in about 50 words",
min_length=50,
max_length=100,
)
current_date_time: str
@ -30,13 +30,12 @@ async def test():
get_current_datetime_tool = LLMDynamicTool(
name="GetDateTimeDynamicTool",
description="Get the current date and time",
parameters=None,
handler=get_current_datetime_tool_handler,
)
text_content = ""
async for event in client.stream_structured(
response = await client.generate_structured(
model=get_model(),
messages=[
LLMUserMessage(
@ -48,7 +47,6 @@ async def test():
SearchWebTool,
get_current_datetime_tool,
],
):
text_content += event
)
return {"data": text_content}
return {"data": response}

View file

@ -2,6 +2,8 @@ from typing import Any, List, Literal, Optional
from pydantic import BaseModel
from google.genai.types import Content as GoogleContent
from models.llm_tool_call import AnthropicToolCall
class LLMMessage(BaseModel):
pass
@ -28,6 +30,22 @@ class GoogleAssistantMessage(LLMMessage):
content: GoogleContent
class AnthropicAssistantMessage(LLMMessage):
role: Literal["assistant"] = "assistant"
content: List[AnthropicToolCall]
class AnthropicToolCallMessage(LLMMessage):
type: Literal["tool_result"] = "tool_result"
tool_use_id: str
content: str
class AnthropicUserMessage(LLMMessage):
role: Literal["user"] = "user"
content: List[AnthropicToolCallMessage]
class OpenAIToolCallMessage(LLMMessage):
role: Literal["tool"] = "tool"
content: str

View file

@ -20,3 +20,10 @@ class OpenAIToolCall(LLMToolCall):
class GoogleToolCall(LLMToolCall):
name: str
arguments: Optional[dict] = None
class AnthropicToolCall(LLMToolCall):
type: Literal["tool_use"] = "tool_use"
id: str
name: str
input: object

View file

@ -10,7 +10,7 @@ class LLMDynamicTool(LLMTool):
name: str
description: str
strict: bool = False
parameters: Optional[dict] = None
parameters: dict = {}
handler: Callable[..., Coroutine[Any, Any, str]]

View file

@ -15,6 +15,8 @@ from anthropic.types import Message as AnthropicMessage
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
from enums.llm_provider import LLMProvider
from models.llm_message import (
AnthropicAssistantMessage,
AnthropicUserMessage,
GoogleAssistantMessage,
GoogleToolCallMessage,
OpenAIAssistantMessage,
@ -23,6 +25,7 @@ from models.llm_message import (
LLMUserMessage,
)
from models.llm_tool_call import (
AnthropicToolCall,
GoogleToolCall,
LLMToolCall,
OpenAIToolCall,
@ -157,8 +160,10 @@ class LLMClient:
return contents
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
return [message for message in messages if isinstance(message, LLMUserMessage)]
def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
return [
message for message in messages if not isinstance(message, LLMSystemMessage)
]
# ? Generate Unstructured Content
async def _generate_openai(
@ -287,25 +292,61 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
depth: int = 0,
):
) -> str | None:
client: AsyncAnthropic = self._client
response: AnthropicMessage = await client.messages.create(
model=model,
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
for message in self._get_anthropic_messages(messages)
],
tools=tools,
max_tokens=max_tokens or 4000,
)
text = ""
text_content = None
tool_calls: List[AnthropicToolCall] = []
for content in response.content:
if content.type == "text" and isinstance(content.text, str):
text += content.text
if text == "":
return None
return text
text_content = content.text
if content.type == "tool_use":
tool_calls.append(
AnthropicToolCall(
id=content.id,
type=content.type,
name=content.name,
input=content.input,
)
)
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
)
new_messages = [
*messages,
AnthropicAssistantMessage(
role="assistant",
content=[each.model_dump() for each in tool_calls],
),
AnthropicUserMessage(
role="user",
content=[each.model_dump() for each in tool_call_messages],
),
]
return await self._generate_anthropic(
model=model,
messages=new_messages,
max_tokens=max_tokens,
tools=tools,
depth=depth + 1,
)
return text_content
async def _generate_ollama(
self,
@ -361,7 +402,10 @@ class LLMClient:
)
case LLMProvider.ANTHROPIC:
content = await self._generate_anthropic(
model=model, messages=messages, max_tokens=max_tokens
model=model,
messages=messages,
max_tokens=max_tokens,
tools=parsed_tools,
)
case LLMProvider.OLLAMA:
content = await self._generate_ollama(
@ -586,6 +630,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
response_format: dict,
tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
depth: int = 0,
):
@ -595,7 +640,7 @@ class LLMClient:
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=[
@ -603,19 +648,51 @@ class LLMClient:
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": response_format,
}
},
*(tools or []),
],
tool_choice={
"type": "tool",
"name": "ResponseSchema",
},
)
content: dict | None = None
for content_block in response.content:
if content_block.type == "tool_use":
content = content_block.input
tool_calls: List[AnthropicToolCall] = []
for content in response.content:
if content.type == "tool_use":
tool_calls.append(
AnthropicToolCall(
id=content.id,
type=content.type,
name=content.name,
input=content.input,
)
)
return content
for each in tool_calls:
if each.name == "ResponseSchema":
return each.input
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
)
new_messages = [
*messages,
AnthropicAssistantMessage(
role="assistant",
content=[each.model_dump() for each in tool_calls],
),
AnthropicUserMessage(
role="user",
content=[each.model_dump() for each in tool_call_messages],
),
]
return await self._generate_anthropic_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
response_format=response_format,
tools=tools,
depth=depth + 1,
)
return None
async def _generate_ollama_structured(
self,
@ -690,6 +767,7 @@ class LLMClient:
model=model,
messages=messages,
response_format=response_format,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.OLLAMA:
@ -873,6 +951,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
depth: int = 0,
):
client: AsyncAnthropic = self._client
@ -881,14 +960,17 @@ class LLMClient:
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=tools,
) as stream:
tool_calls: List[AnthropicToolCall] = []
async for event in stream:
event: AnthropicMessageStreamEvent = event
if event.type == "text" and isinstance(event.text, str):
yield event.text
if event.type == "input_json":
event.partial_json
pass
def _stream_ollama(
self,

View file

@ -4,8 +4,12 @@ import json
from typing import Any, Callable, Coroutine, List, Optional
from fastapi import HTTPException
from enums.llm_provider import LLMProvider
from models.llm_message import GoogleToolCallMessage, OpenAIToolCallMessage
from models.llm_tool_call import GoogleToolCall, OpenAIToolCall
from models.llm_message import (
AnthropicToolCallMessage,
GoogleToolCallMessage,
OpenAIToolCallMessage,
)
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
@ -88,7 +92,13 @@ class LLMToolCallsHandler:
}
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
pass
parsed = self.parse_tool_openai(tool)
input_schema = parsed["function"]["parameters"]
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],
"input_schema": {"type": "object"} if input_schema == {} else input_schema,
}
async def handle_tool_calls_openai(
self,
@ -130,6 +140,30 @@ class LLMToolCallsHandler:
]
return tool_call_messages
async def handle_tool_calls_anthropic(
self,
tool_calls: List[AnthropicToolCall],
) -> List[AnthropicToolCallMessage]:
async_tool_calls_tasks = []
print("--------------------------------")
print(tool_calls)
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input)))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
AnthropicToolCallMessage(
content=result,
tool_use_id=tool_call.id,
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
print("--------------------------------")
print(tool_call_messages)
return tool_call_messages
# ? Tool call handlers
# Search web tool call handler
async def search_web_tool_call_handler(self, arguments: str) -> str: