feat(fastapi): adds anthropic web search, fix(fastapi): llm messages to system and user message

This commit is contained in:
sauravniraula 2025-08-09 01:36:16 +05:45
parent 5c106bd664
commit dc62eb72d1
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
12 changed files with 117 additions and 73 deletions

View file

@ -72,6 +72,9 @@ async def stream_outlines(
presentation_outlines_json = json.loads(presentation_outlines_text)
except Exception as e:
print(e)
with open("./debug/outlines.txt", "w") as f:
f.write(presentation_outlines_text)
print(presentation_outlines_text)
raise HTTPException(
status_code=400,
detail="Failed to generate presentation outlines. Please try again.",

View file

@ -1,11 +1,10 @@
from datetime import datetime
import json
from fastapi import APIRouter
from pydantic import BaseModel, Field
from models.llm_message import LLMUserMessage
from models.llm_tools import LLMDynamicTool, SearchWebTool
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
from services.llm_client import LLMClient
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
from utils.llm_provider import get_model
API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"])
@ -24,30 +23,7 @@ class ResponseContent(BaseModel):
async def test():
client = LLMClient()
async def get_current_datetime_tool_handler(_) -> str:
return datetime.now().isoformat()
response = await client._search_anthropic("Trending AI tool now")
# print(response)
get_current_datetime_tool = LLMDynamicTool(
name="GetDateTimeDynamicTool",
description="Get the current date and time",
handler=get_current_datetime_tool_handler,
)
text_content = ""
async for chunk in client.stream_structured(
model=get_model(),
messages=[
LLMUserMessage(
content="What is the current date and time ? What is the trending AI tool now ? Use Available tools to get the information."
),
],
response_format=ResponseContent.model_json_schema(),
tools=[
SearchWebTool,
get_current_datetime_tool,
],
):
text_content += chunk
return {"data": text_content}
return {"data": ""}

View file

@ -9,7 +9,6 @@ class LLMTool(BaseModel):
class LLMDynamicTool(LLMTool):
name: str
description: str
strict: bool = False
parameters: dict = {}
handler: Callable[..., Coroutine[Any, Any, str]]

View file

@ -43,13 +43,11 @@ from utils.get_env import (
get_google_api_key_env,
get_ollama_url_env,
get_openai_api_key_env,
get_openai_model_env,
get_tool_calls_env,
)
from utils.llm_provider import get_llm_provider, get_model
from utils.parsers import parse_bool_or_none
from utils.randomizers import get_random_uuid
from utils.schema_utils import ensure_strict_json_schema
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
class LLMClient:
@ -455,10 +453,10 @@ class LLMClient:
LLMDynamicTool(
name="ResponseSchema",
description="Provide response to the user",
strict=strict,
parameters=response_schema,
handler=do_nothing_async,
)
),
strict=strict,
)
)
@ -557,7 +555,7 @@ class LLMClient:
{
"name": "ResponseSchema",
"description": "Provide response to the user",
"parameters": response_format,
"parameters_json_schema": response_format,
}
]
)
@ -571,7 +569,7 @@ class LLMClient:
tools=google_tools,
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json" if not tools else None,
response_json_schema=response_format if not tools else None,
response_schema=response_format if not tools else None,
max_output_tokens=max_tokens,
),
)
@ -1114,10 +1112,10 @@ class LLMClient:
LLMDynamicTool(
name="ResponseSchema",
description="Provide response to the user",
strict=strict,
parameters=response_schema,
handler=do_nothing_async,
)
),
strict=strict,
)
)
@ -1235,10 +1233,11 @@ class LLMClient:
max_tokens: Optional[int] = None,
tools: Optional[List[dict]] = None,
depth: int = 0,
):
) -> AsyncGenerator[str, None]:
client: genai.Client = self._client
google_tools = []
google_tools = None
if tools:
google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
google_tools.append(
@ -1247,13 +1246,14 @@ class LLMClient:
{
"name": "ResponseSchema",
"description": "Provide response to the user",
"parameters": response_format,
"parameters_json_schema": response_format,
}
]
)
)
tool_calls: List[GoogleToolCall] = []
has_response_schema_tool_call = False
async for event in iterator_to_async(client.models.generate_content_stream)(
model=model,
contents=self._get_google_messages(messages),
@ -1277,7 +1277,6 @@ class LLMClient:
for each in event.function_calls
]
has_response_schema_tool_call = False
for each in tool_calls:
if each.name == "ResponseSchema":
has_response_schema_tool_call = True
@ -1317,7 +1316,7 @@ class LLMClient:
tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
depth: int = 0,
):
) -> AsyncGenerator[str, None]:
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
@ -1516,4 +1515,20 @@ class LLMClient:
contents=query,
config=config,
)
return response.text
return response.text
async def _search_anthropic(self, query: str) -> str:
client: AsyncAnthropic = self._client
response = await client.messages.create(
model=get_model(),
max_tokens=4000,
messages=[{"role": "user", "content": query}],
tools=[
{"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
],
)
result = "\n".join(
[each.text for each in response.content if each.type == "text"]
)
return result

View file

@ -11,6 +11,7 @@ from models.llm_message import (
)
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
class LLMToolCallsHandler:
@ -73,6 +74,9 @@ class LLMToolCallsHandler:
description = tool.__doc__ or ""
parameters = tool.model_json_schema()
if strict:
parameters = ensure_strict_json_schema(parameters, path=(), root=parameters)
return {
"type": "function",
"function": {
@ -85,6 +89,9 @@ class LLMToolCallsHandler:
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
parsed = self.parse_tool_openai(tool)
# parsed["function"]["parameters"] = flatten_json_schema(
# parsed["function"]["parameters"]
# )
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],
@ -185,9 +192,10 @@ class LLMToolCallsHandler:
return await self.client._search_google(args.query)
async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str:
return "test"
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_anthropic(args.query)
# Get current datetime tool call handler
async def get_current_datetime_tool_call_handler(self, arguments: str) -> str:
async def get_current_datetime_tool_call_handler(self, _) -> str:
current_time = datetime.now()
return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}"

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import SlideLayoutModel
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
@ -41,12 +41,10 @@ def get_messages(
language: str,
):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=system_prompt,
),
LLMMessage(
role="user",
LLMUserMessage(
content=get_user_prompt(prompt, slide_data, language),
),
]

View file

@ -1,5 +1,5 @@
from typing import Optional
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from services.llm_client import LLMClient
from utils.llm_provider import get_model
@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str):
response = await client.generate(
model=model,
messages=[
LLMMessage(role="system", content=system_prompt),
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
LLMSystemMessage(content=system_prompt),
LLMUserMessage(content=get_user_prompt(prompt, html)),
],
)
return extract_html_from_response(response) or html

View file

@ -1,6 +1,6 @@
from typing import Optional
from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
from services.llm_client import LLMClient
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
@ -30,11 +30,9 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
def get_messages(prompt: str, n_slides: int, language: str, content: str):
return [
LLMSystemMessage(
role="system",
content=system_prompt,
),
LLMUserMessage(
role="user",
content=get_user_prompt(prompt, n_slides, language, content),
),
]

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import PresentationLayoutModel
from models.presentation_outline_model import PresentationOutlineModel
from services.llm_client import LLMClient
@ -11,8 +11,7 @@ def get_messages(
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=f"""
You're a professional presentation designer with creative freedom to design engaging presentations.
@ -47,8 +46,7 @@ def get_messages(
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
""",
),
LLMMessage(
role="user",
LLMUserMessage(
content=f"""
{data}
""",

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import SlideLayoutModel
from models.presentation_outline_model import SlideOutlineModel
from services.llm_client import LLMClient
@ -39,12 +39,10 @@ def get_user_prompt(outline: str, language: str):
def get_messages(outline: str, language: str):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=system_prompt,
),
LLMMessage(
role="user",
LLMUserMessage(
content=get_user_prompt(outline, language),
),
]

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from models.slide_layout_index import SlideLayoutIndex
from models.sql.slide import SlideModel
@ -13,8 +13,7 @@ def get_messages(
current_slide_layout: int,
):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=f"""
Select a Slide Layout index based on provided user prompt and current slide data.
{layout.to_string()}
@ -26,8 +25,7 @@ def get_messages(
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
),
LLMMessage(
role="user",
LLMUserMessage(
content=f"""
- User Prompt: {prompt}
- Current Slide Data: {slide_data}

View file

@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
return resolved
# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions
def flatten_json_schema(schema: dict) -> dict:
root_schema = deepcopy(schema)
def _flatten(node: Any) -> Any:
if isinstance(node, dict):
# If node is a pure $ref (or combined with extra fields), inline it
if "$ref" in node:
ref_value = node["$ref"]
assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}"
resolved = resolve_ref(root=root_schema, ref=ref_value)
assert isinstance(resolved, dict), (
f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}"
)
# Merge: referenced first, then overlay current (excluding $ref)
merged: dict[str, Any] = deepcopy(resolved)
for key, value in node.items():
if key == "$ref":
continue
merged[key] = value
return _flatten(merged)
flattened: dict[str, Any] = {}
for key, value in node.items():
# Drop defs/definitions in output
if key in ("$defs", "definitions"):
continue
if key == "properties" and isinstance(value, dict):
flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()}
elif key in ("items", "contains", "additionalProperties", "not"):
if isinstance(value, dict):
flattened[key] = _flatten(value)
elif isinstance(value, list):
flattened[key] = [_flatten(v) for v in value]
else:
flattened[key] = value
elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list):
flattened[key] = [_flatten(v) for v in value]
else:
flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value
return flattened
if isinstance(node, list):
return [_flatten(v) for v in node]
return node
result = _flatten(schema)
# Ensure top-level cleanup just in case
if isinstance(result, dict):
result.pop("$defs", None)
result.pop("definitions", None)
return result
# ? Not used
def generate_constraint_sentences(schema: dict) -> str:
"""