diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 0bc78101..9db101e0 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -35,7 +35,7 @@ async def test(): text_content = "" - response = await client.generate_structured( + async for chunk in client.stream_structured( model=get_model(), messages=[ LLMUserMessage( @@ -47,6 +47,7 @@ async def test(): SearchWebTool, get_current_datetime_tool, ], - ) + ): + text_content += chunk - return {"data": response} + return {"data": text_content} diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 07789ed4..847a4e80 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -955,6 +955,7 @@ class LLMClient: depth: int = 0, ): client: AsyncAnthropic = self._client + async with client.messages.stream( model=model, system=self._get_system_prompt(messages), @@ -968,9 +969,48 @@ class LLMClient: tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json": - event.partial_json - pass + + if event.type == "text": + yield event.text + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama( self, @@ -1025,7 +1065,10 @@ class LLMClient: ) case LLMProvider.ANTHROPIC: return self._stream_anthropic( - model=model, messages=messages, max_tokens=max_tokens + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, ) case LLMProvider.OLLAMA: return self._stream_ollama( @@ -1271,6 +1314,7 @@ class LLMClient: model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, ): @@ -1280,7 +1324,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -1288,17 +1332,72 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) as stream: + tool_calls: List[AnthropicToolCall] = [] + has_response_schema_tool_call = False + is_response_schema_tool_call_started = False async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json" and isinstance(event.partial_json, str): - yield event.partial_json + if ( + event.type == "content_block_start" + and event.content_block.type == "tool_use" + ): + if event.content_block.name == "ResponseSchema": + has_response_schema_tool_call = True + is_response_schema_tool_call_started = True + + if ( + event.type == "content_block_delta" + and event.delta.type == "input_json_delta" + and is_response_schema_tool_call_started + ): + yield event.delta.partial_json + + if has_response_schema_tool_call: + continue + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama_structured( self, @@ -1372,6 +1471,7 @@ class LLMClient: model=model, messages=messages, response_format=response_format, + tools=parsed_tools, max_tokens=max_tokens, ) case LLMProvider.OLLAMA: @@ -1416,4 +1516,4 @@ class LLMClient: contents=query, config=config, ) - return response.text + return response.text \ No newline at end of file diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index a61d24ad..723fefe0 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -145,8 +145,6 @@ class LLMToolCallsHandler: tool_calls: List[AnthropicToolCall], ) -> List[AnthropicToolCallMessage]: async_tool_calls_tasks = [] - print("--------------------------------") - print(tool_calls) for tool_call in tool_calls: tool_name = tool_call.name tool_handler = self.get_tool_handler(tool_name) @@ -160,8 +158,6 @@ class LLMToolCallsHandler: ) for tool_call, result in zip(tool_calls, tool_call_results) ] - print("--------------------------------") - print(tool_call_messages) return tool_call_messages # ? Tool call handlers diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index ee5f0224..3bb95c5e 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,6 +1,6 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides @@ -29,11 +29,11 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ - LLMMessage( + LLMSystemMessage( role="system", content=system_prompt, ), - LLMMessage( + LLMUserMessage( role="user", content=get_user_prompt(prompt, n_slides, language, content), ),