feat(fastapi): adds tool call support for anthropic stream and stream structured
This commit is contained in:
parent
84fd0dee1a
commit
5c106bd664
4 changed files with 120 additions and 23 deletions
|
|
@ -35,7 +35,7 @@ async def test():
|
|||
|
||||
text_content = ""
|
||||
|
||||
response = await client.generate_structured(
|
||||
async for chunk in client.stream_structured(
|
||||
model=get_model(),
|
||||
messages=[
|
||||
LLMUserMessage(
|
||||
|
|
@ -47,6 +47,7 @@ async def test():
|
|||
SearchWebTool,
|
||||
get_current_datetime_tool,
|
||||
],
|
||||
)
|
||||
):
|
||||
text_content += chunk
|
||||
|
||||
return {"data": response}
|
||||
return {"data": text_content}
|
||||
|
|
|
|||
|
|
@ -955,6 +955,7 @@ class LLMClient:
|
|||
depth: int = 0,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
|
|
@ -968,9 +969,48 @@ class LLMClient:
|
|||
tool_calls: List[AnthropicToolCall] = []
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
if event.type == "input_json":
|
||||
event.partial_json
|
||||
pass
|
||||
|
||||
if event.type == "text":
|
||||
yield event.text
|
||||
|
||||
if (
|
||||
event.type == "content_block_stop"
|
||||
and event.content_block.type == "tool_use"
|
||||
):
|
||||
tool_calls.append(
|
||||
AnthropicToolCall(
|
||||
id=event.content_block.id,
|
||||
type=event.content_block.type,
|
||||
name=event.content_block.name,
|
||||
input=event.content_block.input,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
tool_call_messages = (
|
||||
await self.tool_calls_handler.handle_tool_calls_anthropic(
|
||||
tool_calls
|
||||
)
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
AnthropicAssistantMessage(
|
||||
role="assistant",
|
||||
content=[each.model_dump() for each in tool_calls],
|
||||
),
|
||||
AnthropicUserMessage(
|
||||
role="user",
|
||||
content=[each.model_dump() for each in tool_call_messages],
|
||||
),
|
||||
]
|
||||
async for event in self._stream_anthropic(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield event
|
||||
|
||||
def _stream_ollama(
|
||||
self,
|
||||
|
|
@ -1025,7 +1065,10 @@ class LLMClient:
|
|||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self._stream_anthropic(
|
||||
model=model, messages=messages, max_tokens=max_tokens
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=parsed_tools,
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._stream_ollama(
|
||||
|
|
@ -1271,6 +1314,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
tools: Optional[List[dict]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
depth: int = 0,
|
||||
):
|
||||
|
|
@ -1280,7 +1324,7 @@ class LLMClient:
|
|||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
for message in self._get_anthropic_messages(messages)
|
||||
],
|
||||
max_tokens=max_tokens or 4000,
|
||||
tools=[
|
||||
|
|
@ -1288,17 +1332,72 @@ class LLMClient:
|
|||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"input_schema": response_format,
|
||||
}
|
||||
},
|
||||
*(tools or []),
|
||||
],
|
||||
tool_choice={
|
||||
"type": "tool",
|
||||
"name": "ResponseSchema",
|
||||
},
|
||||
) as stream:
|
||||
tool_calls: List[AnthropicToolCall] = []
|
||||
has_response_schema_tool_call = False
|
||||
is_response_schema_tool_call_started = False
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
if event.type == "input_json" and isinstance(event.partial_json, str):
|
||||
yield event.partial_json
|
||||
if (
|
||||
event.type == "content_block_start"
|
||||
and event.content_block.type == "tool_use"
|
||||
):
|
||||
if event.content_block.name == "ResponseSchema":
|
||||
has_response_schema_tool_call = True
|
||||
is_response_schema_tool_call_started = True
|
||||
|
||||
if (
|
||||
event.type == "content_block_delta"
|
||||
and event.delta.type == "input_json_delta"
|
||||
and is_response_schema_tool_call_started
|
||||
):
|
||||
yield event.delta.partial_json
|
||||
|
||||
if has_response_schema_tool_call:
|
||||
continue
|
||||
|
||||
if (
|
||||
event.type == "content_block_stop"
|
||||
and event.content_block.type == "tool_use"
|
||||
):
|
||||
tool_calls.append(
|
||||
AnthropicToolCall(
|
||||
id=event.content_block.id,
|
||||
type=event.content_block.type,
|
||||
name=event.content_block.name,
|
||||
input=event.content_block.input,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
tool_call_messages = (
|
||||
await self.tool_calls_handler.handle_tool_calls_anthropic(
|
||||
tool_calls
|
||||
)
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
AnthropicAssistantMessage(
|
||||
role="assistant",
|
||||
content=[each.model_dump() for each in tool_calls],
|
||||
),
|
||||
AnthropicUserMessage(
|
||||
role="user",
|
||||
content=[each.model_dump() for each in tool_call_messages],
|
||||
),
|
||||
]
|
||||
async for event in self._stream_anthropic_structured(
|
||||
model=model,
|
||||
messages=new_messages,
|
||||
max_tokens=max_tokens,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
depth=depth + 1,
|
||||
):
|
||||
yield event
|
||||
|
||||
def _stream_ollama_structured(
|
||||
self,
|
||||
|
|
@ -1372,6 +1471,7 @@ class LLMClient:
|
|||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
|
|
@ -1416,4 +1516,4 @@ class LLMClient:
|
|||
contents=query,
|
||||
config=config,
|
||||
)
|
||||
return response.text
|
||||
return response.text
|
||||
|
|
@ -145,8 +145,6 @@ class LLMToolCallsHandler:
|
|||
tool_calls: List[AnthropicToolCall],
|
||||
) -> List[AnthropicToolCallMessage]:
|
||||
async_tool_calls_tasks = []
|
||||
print("--------------------------------")
|
||||
print(tool_calls)
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_handler = self.get_tool_handler(tool_name)
|
||||
|
|
@ -160,8 +158,6 @@ class LLMToolCallsHandler:
|
|||
)
|
||||
for tool_call, result in zip(tool_calls, tool_call_results)
|
||||
]
|
||||
print("--------------------------------")
|
||||
print(tool_call_messages)
|
||||
return tool_call_messages
|
||||
|
||||
# ? Tool call handlers
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage
|
||||
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
|
||||
from services.llm_client import LLMClient
|
||||
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
|
||||
|
|
@ -29,11 +29,11 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
|
|||
|
||||
def get_messages(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
LLMMessage(
|
||||
LLMSystemMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
LLMUserMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(prompt, n_slides, language, content),
|
||||
),
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue