feat(fastapi): adding llm agent that supports tool calls and handling
This commit is contained in:
parent
dcfe8a68e1
commit
a3e81da767
6 changed files with 325 additions and 44 deletions
8
servers/fastapi/enums/llm_call_type.py
Normal file
8
servers/fastapi/enums/llm_call_type.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class LLMCallType(Enum):
|
||||
UNSTRUCTURED = "unstructured"
|
||||
UNSTRUCTURED_STREAM = "unstructured_stream"
|
||||
STRUCTURED = "structured"
|
||||
STRUCTURED_STREAM = "structured_stream"
|
||||
|
|
@ -1,7 +1,28 @@
|
|||
from typing import Literal
|
||||
from typing import List, Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
role: Literal["user", "system"]
|
||||
pass
|
||||
|
||||
|
||||
class LLMUserMessage(LLMMessage):
|
||||
role: Literal["user"]
|
||||
content: str
|
||||
|
||||
|
||||
class LLMSystemMessage(LLMMessage):
|
||||
role: Literal["system"]
|
||||
content: str
|
||||
|
||||
|
||||
class LLMToolCallMessage(LLMMessage):
|
||||
role: Literal["tool"]
|
||||
content: str
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class LLMAssistantMessage(LLMMessage):
|
||||
role: Literal["assistant"]
|
||||
content: str | None = None
|
||||
tool_calls: Optional[List[dict]] = None
|
||||
|
|
|
|||
8
servers/fastapi/models/llm_tool_call.py
Normal file
8
servers/fastapi/models/llm_tool_call.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMToolCall(BaseModel):
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
arguments: Optional[str] = None
|
||||
21
servers/fastapi/models/llm_tools.py
Normal file
21
servers/fastapi/models/llm_tools.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from typing import Any, Callable, Coroutine
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LLMTool(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class LLMDynamicTool(LLMTool):
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
handler: Callable[..., Coroutine[Any, Any, str]]
|
||||
|
||||
|
||||
class SearchWebTool(LLMTool):
|
||||
query: str = Field(description="The query to search the web for")
|
||||
|
||||
|
||||
class GetCurrentDatetimeTool(LLMTool):
|
||||
pass
|
||||
|
|
@ -3,13 +3,22 @@ import json
|
|||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import (
|
||||
LLMAssistantMessage,
|
||||
LLMMessage,
|
||||
LLMSystemMessage,
|
||||
LLMUserMessage,
|
||||
)
|
||||
from models.llm_tool_call import LLMToolCall
|
||||
from models.llm_tools import LLMDynamicTool, LLMTool
|
||||
from services.llm_tool_calls_handler import LLMToolCallsHandler
|
||||
from utils.async_iterator import iterator_to_async
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
|
|
@ -23,6 +32,7 @@ from utils.get_env import (
|
|||
)
|
||||
from utils.llm_provider import get_llm_provider
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.randomizers import get_random_uuid
|
||||
from utils.schema_utils import ensure_strict_json_schema
|
||||
|
||||
|
||||
|
|
@ -30,6 +40,7 @@ class LLMClient:
|
|||
def __init__(self):
|
||||
self.llm_provider = get_llm_provider()
|
||||
self._client = self._get_client()
|
||||
self.tool_calls_handler = LLMToolCallsHandler(self)
|
||||
|
||||
# ? Use tool calls
|
||||
def use_tool_calls(self) -> bool:
|
||||
|
|
@ -104,15 +115,19 @@ class LLMClient:
|
|||
# ? Prompts
|
||||
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
|
||||
for message in messages:
|
||||
if message.role == "system":
|
||||
if isinstance(message, LLMSystemMessage):
|
||||
return message.content
|
||||
return ""
|
||||
|
||||
def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]:
|
||||
return [message.content for message in messages if message.role == "user"]
|
||||
return [
|
||||
message.content
|
||||
for message in messages
|
||||
if isinstance(message, LLMUserMessage)
|
||||
]
|
||||
|
||||
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
return [message for message in messages if message.role == "user"]
|
||||
return [message for message in messages if isinstance(message, LLMUserMessage)]
|
||||
|
||||
# ? Generate Unstructured Content
|
||||
async def _generate_openai(
|
||||
|
|
@ -120,6 +135,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
extra_body: Optional[dict] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
|
|
@ -127,8 +143,36 @@ class LLMClient:
|
|||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
if tool_calls:
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
[
|
||||
LLMToolCall(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
)
|
||||
new_messages = [
|
||||
*messages,
|
||||
LLMAssistantMessage(
|
||||
role="assistant",
|
||||
content=response.choices[0].message.content,
|
||||
tool_calls=[
|
||||
tool_call.model_dump() for tool_call in tool_call_messages
|
||||
],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
return await self._generate_openai(
|
||||
model, new_messages, max_tokens, tools, extra_body
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _generate_google(
|
||||
|
|
@ -192,7 +236,10 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
|
||||
):
|
||||
parsed_tools = self.get_tools(tools)
|
||||
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
|
|
@ -220,6 +267,7 @@ class LLMClient:
|
|||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
extra_body: Optional[dict] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
|
|
@ -231,46 +279,75 @@ class LLMClient:
|
|||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
if not use_tool_calls:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=max_tokens,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
tools=[
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"strict": strict,
|
||||
"parameters": response_format,
|
||||
},
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
}
|
||||
],
|
||||
tool_choice="required",
|
||||
max_completion_tokens=max_tokens,
|
||||
extra_body=extra_body,
|
||||
),
|
||||
},
|
||||
max_completion_tokens=max_tokens,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
if tool_calls:
|
||||
tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
|
||||
[
|
||||
LLMToolCall(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
)
|
||||
tool_calls = response.choices[0].message.tool_calls
|
||||
if tool_calls:
|
||||
content = tool_calls[0].function.arguments
|
||||
new_messages = [
|
||||
*messages,
|
||||
LLMAssistantMessage(
|
||||
role="assistant",
|
||||
content=response.choices[0].message.content,
|
||||
tool_calls=[each.model_dump() for each in tool_calls],
|
||||
),
|
||||
*tool_call_messages,
|
||||
]
|
||||
return await self._generate_openai_structured(
|
||||
model,
|
||||
new_messages,
|
||||
response_format,
|
||||
strict,
|
||||
max_tokens,
|
||||
tools,
|
||||
extra_body,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
# else:
|
||||
# response = await client.chat.completions.create(
|
||||
# model=model,
|
||||
# messages=[message.model_dump() for message in messages],
|
||||
# tools=[
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "ResponseSchema",
|
||||
# "description": "A response to the user's message",
|
||||
# "strict": strict,
|
||||
# "parameters": response_format,
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# tool_choice="required",
|
||||
# max_completion_tokens=max_tokens,
|
||||
# extra_body=extra_body,
|
||||
# )
|
||||
# tool_calls = response.choices[0].message.tool_calls
|
||||
# if tool_calls:
|
||||
# content = tool_calls[0].function.arguments
|
||||
|
||||
if content:
|
||||
return json.loads(content)
|
||||
|
|
@ -404,6 +481,7 @@ class LLMClient:
|
|||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
extra_body: Optional[dict] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
|
|
@ -413,9 +491,35 @@ class LLMClient:
|
|||
max_completion_tokens=max_tokens,
|
||||
extra_body=extra_body,
|
||||
) as stream:
|
||||
tool_calls = []
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
if (
|
||||
event.type == "tool_calls.function.arguments.delta"
|
||||
and event.name == "ResponseSchema"
|
||||
):
|
||||
yield event.arguments_delta
|
||||
elif event.type == "content.delta":
|
||||
yield event.delta
|
||||
elif (
|
||||
event.type == "tool_calls.function.arguments.done"
|
||||
and event.name != "ResponseSchema"
|
||||
):
|
||||
tool_calls.append(
|
||||
LLMToolCall(
|
||||
id=get_random_uuid(),
|
||||
name=event.name,
|
||||
arguments=event.arguments,
|
||||
)
|
||||
)
|
||||
if tool_calls:
|
||||
tool_call_messages = (
|
||||
await self.tool_calls_handler.handle_tool_calls_openai(tool_calls)
|
||||
)
|
||||
|
||||
new_messages = [
|
||||
*messages,
|
||||
*tool_call_messages,
|
||||
]
|
||||
|
||||
async def _stream_google(
|
||||
self,
|
||||
|
|
@ -497,6 +601,7 @@ class LLMClient:
|
|||
response_format: dict,
|
||||
strict: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
extra_body: Optional[dict] = None,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
|
|
@ -659,3 +764,10 @@ class LLMClient:
|
|||
return self._stream_custom_structured(
|
||||
model, messages, response_format, strict, max_tokens
|
||||
)
|
||||
|
||||
# ? Tool call handling
|
||||
def get_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None):
|
||||
if tools is None:
|
||||
return None
|
||||
parsed_tools = map(self.tool_calls_handler.parse_tool, tools)
|
||||
return list(parsed_tools)
|
||||
|
|
|
|||
111
servers/fastapi/services/llm_tool_calls_handler.py
Normal file
111
servers/fastapi/services/llm_tool_calls_handler.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Coroutine, List
|
||||
from fastapi import HTTPException
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
|
||||
from enums.llm_call_type import LLMCallType
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import LLMMessage, LLMToolCallMessage
|
||||
from models.llm_tool_call import LLMToolCall
|
||||
from models.llm_tools import (
|
||||
GetCurrentDatetimeTool,
|
||||
LLMDynamicTool,
|
||||
LLMTool,
|
||||
SearchWebTool,
|
||||
)
|
||||
|
||||
|
||||
class LLMToolCallsHandler:
|
||||
def __init__(self, client):
|
||||
from services.llm_client import LLMClient
|
||||
|
||||
self.client: LLMClient = client
|
||||
|
||||
self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = {
|
||||
"SearchWebTool": self.search_web_tool_call_handler,
|
||||
"GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler,
|
||||
}
|
||||
self.dynamic_tools: List[LLMDynamicTool] = []
|
||||
|
||||
def get_tool_handler(
|
||||
self, tool_name: str
|
||||
) -> Callable[..., Coroutine[Any, Any, str]]:
|
||||
handler = self.tools_map.get(tool_name)
|
||||
if not handler:
|
||||
dynamic_tools = list(
|
||||
filter(lambda tool: tool.name == tool_name, self.dynamic_tools)
|
||||
)
|
||||
if dynamic_tools:
|
||||
return dynamic_tools[0].handler
|
||||
raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found")
|
||||
|
||||
def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False):
|
||||
if isinstance(tool, LLMDynamicTool):
|
||||
self.dynamic_tools.append(tool)
|
||||
|
||||
match self.client.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self.parse_tool_openai(tool, strict)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self.parse_tool_anthropic(tool)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self.parse_tool_google(tool)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"LLM provider must be either openai, anthropic, or google"
|
||||
)
|
||||
|
||||
def parse_tool_openai(
|
||||
self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False
|
||||
):
|
||||
if isinstance(tool, LLMDynamicTool):
|
||||
name = tool.name
|
||||
description = tool.description
|
||||
parameters = tool.parameters
|
||||
else:
|
||||
name = tool.__class__.__name__
|
||||
description = tool.__class__.__doc__ or ""
|
||||
parameters = tool.model_dump(mode="json")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"strict": strict,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
pass
|
||||
|
||||
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
pass
|
||||
|
||||
async def handle_tool_calls_openai(
|
||||
self,
|
||||
tool_calls: List[LLMToolCall],
|
||||
) -> List[LLMToolCallMessage]:
|
||||
async_tool_calls_tasks = []
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_handler = self.get_tool_handler(tool_name)
|
||||
async_tool_calls_tasks.append(tool_handler(tool_call.arguments))
|
||||
|
||||
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
|
||||
return [
|
||||
LLMToolCallMessage(
|
||||
role="tool",
|
||||
content=result,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
for tool_call, result in zip(tool_calls, tool_call_results)
|
||||
]
|
||||
|
||||
# ? Tool call handlers
|
||||
async def search_web_tool_call_handler(self, tool_call: dict) -> str:
|
||||
pass
|
||||
|
||||
async def get_current_datetime_tool_call_handler(self, tool_call: dict) -> str:
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
Loading…
Add table
Reference in a new issue