feat(fastapi): adding llm agent that supports tool calls and handling

This commit is contained in:
sauravniraula 2025-08-07 04:44:39 +05:45
parent dcfe8a68e1
commit a3e81da767
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
6 changed files with 325 additions and 44 deletions

View file

@ -0,0 +1,8 @@
from enum import Enum
class LLMCallType(Enum):
UNSTRUCTURED = "unstructured"
UNSTRUCTURED_STREAM = "unstructured_stream"
STRUCTURED = "structured"
STRUCTURED_STREAM = "structured_stream"

View file

@ -1,7 +1,28 @@
from typing import Literal
from typing import List, Literal, Optional
from pydantic import BaseModel
class LLMMessage(BaseModel):
role: Literal["user", "system"]
pass
class LLMUserMessage(LLMMessage):
role: Literal["user"]
content: str
class LLMSystemMessage(LLMMessage):
role: Literal["system"]
content: str
class LLMToolCallMessage(LLMMessage):
role: Literal["tool"]
content: str
tool_call_id: str
class LLMAssistantMessage(LLMMessage):
role: Literal["assistant"]
content: str | None = None
tool_calls: Optional[List[dict]] = None

View file

@ -0,0 +1,8 @@
from typing import Optional
from pydantic import BaseModel
class LLMToolCall(BaseModel):
id: Optional[str] = None
name: str
arguments: Optional[str] = None

View file

@ -0,0 +1,21 @@
from typing import Any, Callable, Coroutine
from pydantic import BaseModel, Field
class LLMTool(BaseModel):
pass
class LLMDynamicTool(LLMTool):
name: str
description: str
parameters: dict
handler: Callable[..., Coroutine[Any, Any, str]]
class SearchWebTool(LLMTool):
query: str = Field(description="The query to search the web for")
class GetCurrentDatetimeTool(LLMTool):
pass

View file

@ -3,13 +3,22 @@ import json
from typing import List, Optional
from fastapi import HTTPException
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
from google import genai
from google.genai.types import GenerateContentConfig
from anthropic import AsyncAnthropic
from anthropic.types import Message as AnthropicMessage
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
from enums.llm_provider import LLMProvider
from models.llm_message import LLMMessage
from models.llm_message import (
LLMAssistantMessage,
LLMMessage,
LLMSystemMessage,
LLMUserMessage,
)
from models.llm_tool_call import LLMToolCall
from models.llm_tools import LLMDynamicTool, LLMTool
from services.llm_tool_calls_handler import LLMToolCallsHandler
from utils.async_iterator import iterator_to_async
from utils.get_env import (
get_anthropic_api_key_env,
@ -23,6 +32,7 @@ from utils.get_env import (
)
from utils.llm_provider import get_llm_provider
from utils.parsers import parse_bool_or_none
from utils.randomizers import get_random_uuid
from utils.schema_utils import ensure_strict_json_schema
@ -30,6 +40,7 @@ class LLMClient:
def __init__(self):
self.llm_provider = get_llm_provider()
self._client = self._get_client()
self.tool_calls_handler = LLMToolCallsHandler(self)
# ? Use tool calls
def use_tool_calls(self) -> bool:
@ -104,15 +115,19 @@ class LLMClient:
# ? Prompts
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
for message in messages:
if message.role == "system":
if isinstance(message, LLMSystemMessage):
return message.content
return ""
def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]:
return [message.content for message in messages if message.role == "user"]
return [
message.content
for message in messages
if isinstance(message, LLMUserMessage)
]
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
return [message for message in messages if message.role == "user"]
return [message for message in messages if isinstance(message, LLMUserMessage)]
# ? Generate Unstructured Content
async def _generate_openai(
@ -120,6 +135,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
):
client: AsyncOpenAI = self._client
@ -127,8 +143,36 @@ class LLMClient:
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
tools=tools,
extra_body=extra_body,
)
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
[
LLMToolCall(
id=tool_call.id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_calls
]
)
new_messages = [
*messages,
LLMAssistantMessage(
role="assistant",
content=response.choices[0].message.content,
tool_calls=[
tool_call.model_dump() for tool_call in tool_call_messages
],
),
*tool_call_messages,
]
return await self._generate_openai(
model, new_messages, max_tokens, tools, extra_body
)
return response.choices[0].message.content
async def _generate_google(
@ -192,7 +236,10 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
parsed_tools = self.get_tools(tools)
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
@ -220,6 +267,7 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
):
client: AsyncOpenAI = self._client
@ -231,46 +279,75 @@ class LLMClient:
path=(),
root=response_schema,
)
if not use_tool_calls:
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
{
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
}
),
},
max_completion_tokens=max_tokens,
extra_body=extra_body,
)
content = response.choices[0].message.content
else:
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
tools=[
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
{
"type": "function",
"function": {
"name": "ResponseSchema",
"description": "A response to the user's message",
"strict": strict,
"parameters": response_format,
},
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
}
],
tool_choice="required",
max_completion_tokens=max_tokens,
extra_body=extra_body,
),
},
max_completion_tokens=max_tokens,
extra_body=extra_body,
)
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
[
LLMToolCall(
id=tool_call.id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_calls
]
)
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
content = tool_calls[0].function.arguments
new_messages = [
*messages,
LLMAssistantMessage(
role="assistant",
content=response.choices[0].message.content,
tool_calls=[each.model_dump() for each in tool_calls],
),
*tool_call_messages,
]
return await self._generate_openai_structured(
model,
new_messages,
response_format,
strict,
max_tokens,
tools,
extra_body,
)
content = response.choices[0].message.content
# else:
# response = await client.chat.completions.create(
# model=model,
# messages=[message.model_dump() for message in messages],
# tools=[
# {
# "type": "function",
# "function": {
# "name": "ResponseSchema",
# "description": "A response to the user's message",
# "strict": strict,
# "parameters": response_format,
# },
# }
# ],
# tool_choice="required",
# max_completion_tokens=max_tokens,
# extra_body=extra_body,
# )
# tool_calls = response.choices[0].message.tool_calls
# if tool_calls:
# content = tool_calls[0].function.arguments
if content:
return json.loads(content)
@ -404,6 +481,7 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
):
client: AsyncOpenAI = self._client
@ -413,9 +491,35 @@ class LLMClient:
max_completion_tokens=max_tokens,
extra_body=extra_body,
) as stream:
tool_calls = []
async for event in stream:
if event.type == "content.delta":
if (
event.type == "tool_calls.function.arguments.delta"
and event.name == "ResponseSchema"
):
yield event.arguments_delta
elif event.type == "content.delta":
yield event.delta
elif (
event.type == "tool_calls.function.arguments.done"
and event.name != "ResponseSchema"
):
tool_calls.append(
LLMToolCall(
id=get_random_uuid(),
name=event.name,
arguments=event.arguments,
)
)
if tool_calls:
tool_call_messages = (
await self.tool_calls_handler.handle_tool_calls_openai(tool_calls)
)
new_messages = [
*messages,
*tool_call_messages,
]
async def _stream_google(
self,
@ -497,6 +601,7 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
extra_body: Optional[dict] = None,
):
client: AsyncOpenAI = self._client
@ -659,3 +764,10 @@ class LLMClient:
return self._stream_custom_structured(
model, messages, response_format, strict, max_tokens
)
# ? Tool call handling
def get_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None):
if tools is None:
return None
parsed_tools = map(self.tool_calls_handler.parse_tool, tools)
return list(parsed_tools)

View file

@ -0,0 +1,111 @@
import asyncio
from datetime import datetime
from typing import Any, Callable, Coroutine, List
from fastapi import HTTPException
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
from enums.llm_call_type import LLMCallType
from enums.llm_provider import LLMProvider
from models.llm_message import LLMMessage, LLMToolCallMessage
from models.llm_tool_call import LLMToolCall
from models.llm_tools import (
GetCurrentDatetimeTool,
LLMDynamicTool,
LLMTool,
SearchWebTool,
)
class LLMToolCallsHandler:
def __init__(self, client):
from services.llm_client import LLMClient
self.client: LLMClient = client
self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = {
"SearchWebTool": self.search_web_tool_call_handler,
"GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler,
}
self.dynamic_tools: List[LLMDynamicTool] = []
def get_tool_handler(
self, tool_name: str
) -> Callable[..., Coroutine[Any, Any, str]]:
handler = self.tools_map.get(tool_name)
if not handler:
dynamic_tools = list(
filter(lambda tool: tool.name == tool_name, self.dynamic_tools)
)
if dynamic_tools:
return dynamic_tools[0].handler
raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found")
def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False):
if isinstance(tool, LLMDynamicTool):
self.dynamic_tools.append(tool)
match self.client.llm_provider:
case LLMProvider.OPENAI:
return self.parse_tool_openai(tool, strict)
case LLMProvider.ANTHROPIC:
return self.parse_tool_anthropic(tool)
case LLMProvider.GOOGLE:
return self.parse_tool_google(tool)
case _:
raise ValueError(
f"LLM provider must be either openai, anthropic, or google"
)
def parse_tool_openai(
self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False
):
if isinstance(tool, LLMDynamicTool):
name = tool.name
description = tool.description
parameters = tool.parameters
else:
name = tool.__class__.__name__
description = tool.__class__.__doc__ or ""
parameters = tool.model_dump(mode="json")
return {
"type": "function",
"function": {
"name": name,
"description": description,
"strict": strict,
"parameters": parameters,
},
}
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
pass
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
pass
async def handle_tool_calls_openai(
self,
tool_calls: List[LLMToolCall],
) -> List[LLMToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(tool_call.arguments))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
return [
LLMToolCallMessage(
role="tool",
content=result,
tool_call_id=tool_call.id,
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
# ? Tool call handlers
async def search_web_tool_call_handler(self, tool_call: dict) -> str:
pass
async def get_current_datetime_tool_call_handler(self, tool_call: dict) -> str:
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")