feat(fastapi): adds tool call support for anthropic stream and stream structured

This commit is contained in:
sauravniraula 2025-08-08 22:11:41 +05:45
parent 84fd0dee1a
commit 5c106bd664
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
4 changed files with 120 additions and 23 deletions

View file

@ -35,7 +35,7 @@ async def test():
text_content = ""
response = await client.generate_structured(
async for chunk in client.stream_structured(
model=get_model(),
messages=[
LLMUserMessage(
@ -47,6 +47,7 @@ async def test():
SearchWebTool,
get_current_datetime_tool,
],
)
):
text_content += chunk
return {"data": response}
return {"data": text_content}

View file

@ -955,6 +955,7 @@ class LLMClient:
depth: int = 0,
):
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
system=self._get_system_prompt(messages),
@ -968,9 +969,48 @@ class LLMClient:
tool_calls: List[AnthropicToolCall] = []
async for event in stream:
event: AnthropicMessageStreamEvent = event
if event.type == "input_json":
event.partial_json
pass
if event.type == "text":
yield event.text
if (
event.type == "content_block_stop"
and event.content_block.type == "tool_use"
):
tool_calls.append(
AnthropicToolCall(
id=event.content_block.id,
type=event.content_block.type,
name=event.content_block.name,
input=event.content_block.input,
)
)
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_anthropic(
tool_calls
)
)
new_messages = [
*messages,
AnthropicAssistantMessage(
role="assistant",
content=[each.model_dump() for each in tool_calls],
),
AnthropicUserMessage(
role="user",
content=[each.model_dump() for each in tool_call_messages],
),
]
async for event in self._stream_anthropic(
model=model,
messages=new_messages,
max_tokens=max_tokens,
tools=tools,
depth=depth + 1,
):
yield event
def _stream_ollama(
self,
@ -1025,7 +1065,10 @@ class LLMClient:
)
case LLMProvider.ANTHROPIC:
return self._stream_anthropic(
model=model, messages=messages, max_tokens=max_tokens
model=model,
messages=messages,
max_tokens=max_tokens,
tools=parsed_tools,
)
case LLMProvider.OLLAMA:
return self._stream_ollama(
@ -1271,6 +1314,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
response_format: dict,
tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
depth: int = 0,
):
@ -1280,7 +1324,7 @@ class LLMClient:
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=[
@ -1288,17 +1332,72 @@ class LLMClient:
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": response_format,
}
},
*(tools or []),
],
tool_choice={
"type": "tool",
"name": "ResponseSchema",
},
) as stream:
tool_calls: List[AnthropicToolCall] = []
has_response_schema_tool_call = False
is_response_schema_tool_call_started = False
async for event in stream:
event: AnthropicMessageStreamEvent = event
if event.type == "input_json" and isinstance(event.partial_json, str):
yield event.partial_json
if (
event.type == "content_block_start"
and event.content_block.type == "tool_use"
):
if event.content_block.name == "ResponseSchema":
has_response_schema_tool_call = True
is_response_schema_tool_call_started = True
if (
event.type == "content_block_delta"
and event.delta.type == "input_json_delta"
and is_response_schema_tool_call_started
):
yield event.delta.partial_json
if has_response_schema_tool_call:
continue
if (
event.type == "content_block_stop"
and event.content_block.type == "tool_use"
):
tool_calls.append(
AnthropicToolCall(
id=event.content_block.id,
type=event.content_block.type,
name=event.content_block.name,
input=event.content_block.input,
)
)
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_anthropic(
tool_calls
)
)
new_messages = [
*messages,
AnthropicAssistantMessage(
role="assistant",
content=[each.model_dump() for each in tool_calls],
),
AnthropicUserMessage(
role="user",
content=[each.model_dump() for each in tool_call_messages],
),
]
async for event in self._stream_anthropic_structured(
model=model,
messages=new_messages,
max_tokens=max_tokens,
response_format=response_format,
tools=tools,
depth=depth + 1,
):
yield event
def _stream_ollama_structured(
self,
@ -1372,6 +1471,7 @@ class LLMClient:
model=model,
messages=messages,
response_format=response_format,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.OLLAMA:
@ -1416,4 +1516,4 @@ class LLMClient:
contents=query,
config=config,
)
return response.text
return response.text

View file

@ -145,8 +145,6 @@ class LLMToolCallsHandler:
tool_calls: List[AnthropicToolCall],
) -> List[AnthropicToolCallMessage]:
async_tool_calls_tasks = []
print("--------------------------------")
print(tool_calls)
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
@ -160,8 +158,6 @@ class LLMToolCallsHandler:
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
print("--------------------------------")
print(tool_call_messages)
return tool_call_messages
# ? Tool call handlers

View file

@ -1,6 +1,6 @@
from typing import Optional
from models.llm_message import LLMMessage
from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
from services.llm_client import LLMClient
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
@ -29,11 +29,11 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
def get_messages(prompt: str, n_slides: int, language: str, content: str):
return [
LLMMessage(
LLMSystemMessage(
role="system",
content=system_prompt,
),
LLMMessage(
LLMUserMessage(
role="user",
content=get_user_prompt(prompt, n_slides, language, content),
),