feat(fastapi): adds tool calls support for anthropic generate
This commit is contained in:
parent
49342e7c3c
commit
84fd0dee1a
6 changed files with 175 additions and 36 deletions
|
|
@ -13,9 +13,9 @@ API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"])
|
|||
|
||||
class ResponseContent(BaseModel):
|
||||
trending_ai_tool: str = Field(
|
||||
description="The summary of the trending AI tool in about 150 words",
|
||||
min_length=150,
|
||||
max_length=200,
|
||||
description="The summary of the trending AI tool in about 50 words",
|
||||
min_length=50,
|
||||
max_length=100,
|
||||
)
|
||||
current_date_time: str
|
||||
|
||||
|
|
@ -30,13 +30,12 @@ async def test():
|
|||
get_current_datetime_tool = LLMDynamicTool(
|
||||
name="GetDateTimeDynamicTool",
|
||||
description="Get the current date and time",
|
||||
parameters=None,
|
||||
handler=get_current_datetime_tool_handler,
|
||||
)
|
||||
|
||||
text_content = ""
|
||||
|
||||
async for event in client.stream_structured(
|
||||
response = await client.generate_structured(
|
||||
model=get_model(),
|
||||
messages=[
|
||||
LLMUserMessage(
|
||||
|
|
@ -48,7 +47,6 @@ async def test():
|
|||
SearchWebTool,
|
||||
get_current_datetime_tool,
|
||||
],
|
||||
):
|
||||
text_content += event
|
||||
)
|
||||
|
||||
return {"data": text_content}
|
||||
return {"data": response}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from typing import Any, List, Literal, Optional
|
|||
from pydantic import BaseModel
|
||||
from google.genai.types import Content as GoogleContent
|
||||
|
||||
from models.llm_tool_call import AnthropicToolCall
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
pass
|
||||
|
|
@ -28,6 +30,22 @@ class GoogleAssistantMessage(LLMMessage):
|
|||
content: GoogleContent
|
||||
|
||||
|
||||
class AnthropicAssistantMessage(LLMMessage):
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: List[AnthropicToolCall]
|
||||
|
||||
|
||||
class AnthropicToolCallMessage(LLMMessage):
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str
|
||||
content: str
|
||||
|
||||
|
||||
class AnthropicUserMessage(LLMMessage):
|
||||
role: Literal["user"] = "user"
|
||||
content: List[AnthropicToolCallMessage]
|
||||
|
||||
|
||||
class OpenAIToolCallMessage(LLMMessage):
|
||||
role: Literal["tool"] = "tool"
|
||||
content: str
|
||||
|
|
|
|||
|
|
@ -20,3 +20,10 @@ class OpenAIToolCall(LLMToolCall):
|
|||
class GoogleToolCall(LLMToolCall):
|
||||
name: str
|
||||
arguments: Optional[dict] = None
|
||||
|
||||
|
||||
class AnthropicToolCall(LLMToolCall):
|
||||
type: Literal["tool_use"] = "tool_use"
|
||||
id: str
|
||||
name: str
|
||||
input: object
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class LLMDynamicTool(LLMTool):
|
|||
name: str
|
||||
description: str
|
||||
strict: bool = False
|
||||
parameters: Optional[dict] = None
|
||||
parameters: dict = {}
|
||||
handler: Callable[..., Coroutine[Any, Any, str]]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from anthropic.types import Message as AnthropicMessage
|
|||
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import (
|
||||
AnthropicAssistantMessage,
|
||||
AnthropicUserMessage,
|
||||
GoogleAssistantMessage,
|
||||
GoogleToolCallMessage,
|
||||
OpenAIAssistantMessage,
|
||||
|
|
@ -23,6 +25,7 @@ from models.llm_message import (
|
|||
LLMUserMessage,
|
||||
)
|
||||
from models.llm_tool_call import (
|
||||
AnthropicToolCall,
|
||||
GoogleToolCall,
|
||||
LLMToolCall,
|
||||
OpenAIToolCall,
|
||||
|
|
@ -157,8 +160,10 @@ class LLMClient:
|
|||
|
||||
return contents
|
||||
|
||||
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
return [message for message in messages if isinstance(message, LLMUserMessage)]
|
||||
def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
return [
|
||||
message for message in messages if not isinstance(message, LLMSystemMessage)
|
||||
]
|
||||
|
||||
# ? Generate Unstructured Content
|
||||
async def _generate_openai(
|
||||
|
|
@ -287,25 +292,61 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
):
|
||||
) -> str | None:
|
||||
client: AsyncAnthropic = self._client
|
||||
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
for message in self._get_anthropic_messages(messages)
|
||||
],
|
||||
tools=tools,
|
||||
max_tokens=max_tokens or 4000,
|
||||
)
|
||||
text = ""
|
||||
text_content = None
|
||||
tool_calls: List[AnthropicToolCall] = []
|
||||
for content in response.content:
|
||||
if content.type == "text" and isinstance(content.text, str):
|
||||
text += content.text
|
||||
if text == "":
|
||||
return None
|
||||
return text
|
||||
text_content = content.text
|
||||
|
||||
if content.type == "tool_use":
|
||||
tool_calls.append(
|
||||
AnthropicToolCall(
|
||||
id=content.id,
|
||||
type=content.type,
|
||||
name=content.name,
|
||||
input=content.input,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
tool_call_messages = (
|
||||
await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
AnthropicAssistantMessage(
|
||||
role="assistant",
|
||||
content=[each.model_dump() for each in tool_calls],
|
||||
),
|
||||
AnthropicUserMessage(
|
||||
role="user",
|
||||
content=[each.model_dump() for each in tool_call_messages],
|
||||
),
|
||||
]
|
||||
return await self._generate_anthropic(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
)
|
||||
|
||||
return text_content
|
||||
|
||||
async def _generate_ollama(
|
||||
self,
|
||||
|
|
@ -361,7 +402,10 @@ class LLMClient:
|
|||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
content = await self._generate_anthropic(
|
||||
model=model, messages=messages, max_tokens=max_tokens
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama(
|
||||
|
|
@ -586,6 +630,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
tools: Optional[List[dict]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
depth: int = 0,
|
||||
):
|
||||
|
|
@ -595,7 +640,7 @@ class LLMClient:
|
|||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
for message in self._get_anthropic_messages(messages)
|
||||
],
|
||||
max_tokens=max_tokens or 4000,
|
||||
tools=[
|
||||
|
|
@ -603,19 +648,51 @@ class LLMClient:
|
|||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"input_schema": response_format,
|
||||
}
|
||||
},
|
||||
*(tools or []),
|
||||
],
|
||||
tool_choice={
|
||||
"type": "tool",
|
||||
"name": "ResponseSchema",
|
||||
},
|
||||
)
|
||||
content: dict | None = None
|
||||
for content_block in response.content:
|
||||
if content_block.type == "tool_use":
|
||||
content = content_block.input
|
||||
tool_calls: List[AnthropicToolCall] = []
|
||||
for content in response.content:
|
||||
if content.type == "tool_use":
|
||||
tool_calls.append(
|
||||
AnthropicToolCall(
|
||||
id=content.id,
|
||||
type=content.type,
|
||||
name=content.name,
|
||||
input=content.input,
|
||||
)
|
||||
)
|
||||
|
||||
return content
|
||||
for each in tool_calls:
|
||||
if each.name == "ResponseSchema":
|
||||
return each.input
|
||||
|
||||
if tool_calls:
|
||||
tool_call_messages = (
|
||||
await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
AnthropicAssistantMessage(
|
||||
role="assistant",
|
||||
content=[each.model_dump() for each in tool_calls],
|
||||
),
|
||||
AnthropicUserMessage(
|
||||
role="user",
|
||||
content=[each.model_dump() for each in tool_call_messages],
|
||||
),
|
||||
]
|
||||
return await self._generate_anthropic_structured(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _generate_ollama_structured(
|
||||
self,
|
||||
|
|
@ -690,6 +767,7 @@ class LLMClient:
|
|||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
|
|
@ -873,6 +951,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
|
|
@ -881,14 +960,17 @@ class LLMClient:
|
|||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
for message in self._get_anthropic_messages(messages)
|
||||
],
|
||||
max_tokens=max_tokens or 4000,
|
||||
tools=tools,
|
||||
) as stream:
|
||||
tool_calls: List[AnthropicToolCall] = []
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
if event.type == "text" and isinstance(event.text, str):
|
||||
yield event.text
|
||||
if event.type == "input_json":
|
||||
event.partial_json
|
||||
pass
|
||||
|
||||
def _stream_ollama(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@ import json
|
|||
from typing import Any, Callable, Coroutine, List, Optional
|
||||
from fastapi import HTTPException
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import GoogleToolCallMessage, OpenAIToolCallMessage
|
||||
from models.llm_tool_call import GoogleToolCall, OpenAIToolCall
|
||||
from models.llm_message import (
|
||||
AnthropicToolCallMessage,
|
||||
GoogleToolCallMessage,
|
||||
OpenAIToolCallMessage,
|
||||
)
|
||||
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
|
||||
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
|
||||
|
||||
|
||||
|
|
@ -88,7 +92,13 @@ class LLMToolCallsHandler:
|
|||
}
|
||||
|
||||
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
pass
|
||||
parsed = self.parse_tool_openai(tool)
|
||||
input_schema = parsed["function"]["parameters"]
|
||||
return {
|
||||
"name": parsed["function"]["name"],
|
||||
"description": parsed["function"]["description"],
|
||||
"input_schema": {"type": "object"} if input_schema == {} else input_schema,
|
||||
}
|
||||
|
||||
async def handle_tool_calls_openai(
|
||||
self,
|
||||
|
|
@ -130,6 +140,30 @@ class LLMToolCallsHandler:
|
|||
]
|
||||
return tool_call_messages
|
||||
|
||||
async def handle_tool_calls_anthropic(
|
||||
self,
|
||||
tool_calls: List[AnthropicToolCall],
|
||||
) -> List[AnthropicToolCallMessage]:
|
||||
async_tool_calls_tasks = []
|
||||
print("--------------------------------")
|
||||
print(tool_calls)
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_handler = self.get_tool_handler(tool_name)
|
||||
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input)))
|
||||
|
||||
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
|
||||
tool_call_messages = [
|
||||
AnthropicToolCallMessage(
|
||||
content=result,
|
||||
tool_use_id=tool_call.id,
|
||||
)
|
||||
for tool_call, result in zip(tool_calls, tool_call_results)
|
||||
]
|
||||
print("--------------------------------")
|
||||
print(tool_call_messages)
|
||||
return tool_call_messages
|
||||
|
||||
# ? Tool call handlers
|
||||
# Search web tool call handler
|
||||
async def search_web_tool_call_handler(self, arguments: str) -> str:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue