From dc62eb72d171e8c2edfa49d6e0312983e3ca94aa Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Sat, 9 Aug 2025 01:36:16 +0545 Subject: [PATCH] feat(fastapi): adds anthropic web search, fix(fastapi): llm messages to system and user message --- .../fastapi/api/v1/ppt/endpoints/outlines.py | 3 ++ servers/fastapi/api/v1/test/router.py | 34 ++---------- servers/fastapi/models/llm_tools.py | 1 - servers/fastapi/services/llm_client.py | 45 ++++++++++------ .../services/llm_tool_calls_handler.py | 12 ++++- servers/fastapi/utils/llm_calls/edit_slide.py | 8 ++- .../utils/llm_calls/edit_slide_html.py | 6 +-- .../generate_presentation_outlines.py | 4 +- .../generate_presentation_structure.py | 8 ++- .../utils/llm_calls/generate_slide_content.py | 8 ++- .../llm_calls/select_slide_type_on_edit.py | 8 ++- servers/fastapi/utils/schema_utils.py | 53 +++++++++++++++++++ 12 files changed, 117 insertions(+), 73 deletions(-) diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py b/servers/fastapi/api/v1/ppt/endpoints/outlines.py index f1eff7ba..bf5b1489 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py +++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py @@ -72,6 +72,9 @@ async def stream_outlines( presentation_outlines_json = json.loads(presentation_outlines_text) except Exception as e: print(e) + with open("./debug/outlines.txt", "w") as f: + f.write(presentation_outlines_text) + print(presentation_outlines_text) raise HTTPException( status_code=400, detail="Failed to generate presentation outlines. Please try again.", diff --git a/servers/fastapi/api/v1/test/router.py b/servers/fastapi/api/v1/test/router.py index 9db101e0..71b77d2b 100644 --- a/servers/fastapi/api/v1/test/router.py +++ b/servers/fastapi/api/v1/test/router.py @@ -1,11 +1,10 @@ -from datetime import datetime -import json from fastapi import APIRouter from pydantic import BaseModel, Field from models.llm_message import LLMUserMessage -from models.llm_tools import LLMDynamicTool, SearchWebTool +from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient +from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline from utils.llm_provider import get_model API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"]) @@ -24,30 +23,7 @@ class ResponseContent(BaseModel): async def test(): client = LLMClient() - async def get_current_datetime_tool_handler(_) -> str: - return datetime.now().isoformat() + response = await client._search_anthropic("Trending AI tool now") + # print(response) - get_current_datetime_tool = LLMDynamicTool( - name="GetDateTimeDynamicTool", - description="Get the current date and time", - handler=get_current_datetime_tool_handler, - ) - - text_content = "" - - async for chunk in client.stream_structured( - model=get_model(), - messages=[ - LLMUserMessage( - content="What is the current date and time ? What is the trending AI tool now ? Use Available tools to get the information." - ), - ], - response_format=ResponseContent.model_json_schema(), - tools=[ - SearchWebTool, - get_current_datetime_tool, - ], - ): - text_content += chunk - - return {"data": text_content} + return {"data": ""} diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py index 4bede740..ccf64e67 100644 --- a/servers/fastapi/models/llm_tools.py +++ b/servers/fastapi/models/llm_tools.py @@ -9,7 +9,6 @@ class LLMTool(BaseModel): class LLMDynamicTool(LLMTool): name: str description: str - strict: bool = False parameters: dict = {} handler: Callable[..., Coroutine[Any, Any, str]] diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index 847a4e80..e220f577 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -43,13 +43,11 @@ from utils.get_env import ( get_google_api_key_env, get_ollama_url_env, get_openai_api_key_env, - get_openai_model_env, get_tool_calls_env, ) from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none -from utils.randomizers import get_random_uuid -from utils.schema_utils import ensure_strict_json_schema +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema class LLMClient: @@ -455,10 +453,10 @@ class LLMClient: LLMDynamicTool( name="ResponseSchema", description="Provide response to the user", - strict=strict, parameters=response_schema, handler=do_nothing_async, - ) + ), + strict=strict, ) ) @@ -557,7 +555,7 @@ class LLMClient: { "name": "ResponseSchema", "description": "Provide response to the user", - "parameters": response_format, + "parameters_json_schema": response_format, } ] ) @@ -571,7 +569,7 @@ class LLMClient: tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="application/json" if not tools else None, - response_json_schema=response_format if not tools else None, + response_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) @@ -1114,10 +1112,10 @@ class LLMClient: LLMDynamicTool( name="ResponseSchema", description="Provide response to the user", - strict=strict, parameters=response_schema, handler=do_nothing_async, - ) + ), + strict=strict, ) ) @@ -1235,10 +1233,11 @@ class LLMClient: max_tokens: Optional[int] = None, tools: Optional[List[dict]] = None, depth: int = 0, - ): + ) -> AsyncGenerator[str, None]: + client: genai.Client = self._client - google_tools = [] + google_tools = None if tools: google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] google_tools.append( @@ -1247,13 +1246,14 @@ class LLMClient: { "name": "ResponseSchema", "description": "Provide response to the user", - "parameters": response_format, + "parameters_json_schema": response_format, } ] ) ) tool_calls: List[GoogleToolCall] = [] + has_response_schema_tool_call = False async for event in iterator_to_async(client.models.generate_content_stream)( model=model, contents=self._get_google_messages(messages), @@ -1277,7 +1277,6 @@ class LLMClient: for each in event.function_calls ] - has_response_schema_tool_call = False for each in tool_calls: if each.name == "ResponseSchema": has_response_schema_tool_call = True @@ -1317,7 +1316,7 @@ class LLMClient: tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, depth: int = 0, - ): + ) -> AsyncGenerator[str, None]: client: AsyncAnthropic = self._client async with client.messages.stream( model=model, @@ -1516,4 +1515,20 @@ class LLMClient: contents=query, config=config, ) - return response.text \ No newline at end of file + return response.text + + async def _search_anthropic(self, query: str) -> str: + client: AsyncAnthropic = self._client + + response = await client.messages.create( + model=get_model(), + max_tokens=4000, + messages=[{"role": "user", "content": query}], + tools=[ + {"type": "web_search_20250305", "name": "web_search", "max_uses": 1} + ], + ) + result = "\n".join( + [each.text for each in response.content if each.type == "text"] + ) + return result diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py index 723fefe0..1d8ffec4 100644 --- a/servers/fastapi/services/llm_tool_calls_handler.py +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -11,6 +11,7 @@ from models.llm_message import ( ) from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema class LLMToolCallsHandler: @@ -73,6 +74,9 @@ class LLMToolCallsHandler: description = tool.__doc__ or "" parameters = tool.model_json_schema() + if strict: + parameters = ensure_strict_json_schema(parameters, path=(), root=parameters) + return { "type": "function", "function": { @@ -85,6 +89,9 @@ class LLMToolCallsHandler: def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): parsed = self.parse_tool_openai(tool) + # parsed["function"]["parameters"] = flatten_json_schema( + # parsed["function"]["parameters"] + # ) return { "name": parsed["function"]["name"], "description": parsed["function"]["description"], @@ -185,9 +192,10 @@ class LLMToolCallsHandler: return await self.client._search_google(args.query) async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: - return "test" + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_anthropic(args.query) # Get current datetime tool call handler - async def get_current_datetime_tool_call_handler(self, arguments: str) -> str: + async def get_current_datetime_tool_call_handler(self, _) -> str: current_time = datetime.now() return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}" diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py index a8df598a..30599d08 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide.py +++ b/servers/fastapi/utils/llm_calls/edit_slide.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.sql.slide import SlideModel from services.llm_client import LLMClient @@ -41,12 +41,10 @@ def get_messages( language: str, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, slide_data, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/edit_slide_html.py b/servers/fastapi/utils/llm_calls/edit_slide_html.py index a5e2dfad..cf58d185 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide_html.py +++ b/servers/fastapi/utils/llm_calls/edit_slide_html.py @@ -1,5 +1,5 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from services.llm_client import LLMClient from utils.llm_provider import get_model @@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str): response = await client.generate( model=model, messages=[ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=get_user_prompt(prompt, html)), + LLMSystemMessage(content=system_prompt), + LLMUserMessage(content=get_user_prompt(prompt, html)), ], ) return extract_html_from_response(response) or html diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 3bb95c5e..6c0ad512 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,6 +1,6 @@ from typing import Optional -from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides @@ -30,11 +30,9 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ LLMSystemMessage( - role="system", content=system_prompt, ), LLMUserMessage( - role="user", content=get_user_prompt(prompt, n_slides, language, content), ), ] diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py index 47f47dba..1bfc0cd0 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from services.llm_client import LLMClient @@ -11,8 +11,7 @@ def get_messages( presentation_layout: PresentationLayoutModel, n_slides: int, data: str ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" You're a professional presentation designer with creative freedom to design engaging presentations. @@ -47,8 +46,7 @@ def get_messages( Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals. """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" {data} """, diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py index 62b87e2b..be19b168 100644 --- a/servers/fastapi/utils/llm_calls/generate_slide_content.py +++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.presentation_outline_model import SlideOutlineModel from services.llm_client import LLMClient @@ -39,12 +39,10 @@ def get_user_prompt(outline: str, language: str): def get_messages(outline: str, language: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(outline, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py index f3532b48..7235e558 100644 --- a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py +++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel from models.slide_layout_index import SlideLayoutIndex from models.sql.slide import SlideModel @@ -13,8 +13,7 @@ def get_messages( current_slide_layout: int, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" Select a Slide Layout index based on provided user prompt and current slide data. {layout.to_string()} @@ -26,8 +25,7 @@ def get_messages( **Go through all notes and steps and make sure they are followed, including mentioned constraints** """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" - User Prompt: {prompt} - Current Slide Data: {slide_data} diff --git a/servers/fastapi/utils/schema_utils.py b/servers/fastapi/utils/schema_utils.py index ae65f002..6cb01a0e 100644 --- a/servers/fastapi/utils/schema_utils.py +++ b/servers/fastapi/utils/schema_utils.py @@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: return resolved +# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions +def flatten_json_schema(schema: dict) -> dict: + root_schema = deepcopy(schema) + + def _flatten(node: Any) -> Any: + if isinstance(node, dict): + # If node is a pure $ref (or combined with extra fields), inline it + if "$ref" in node: + ref_value = node["$ref"] + assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}" + resolved = resolve_ref(root=root_schema, ref=ref_value) + assert isinstance(resolved, dict), ( + f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}" + ) + # Merge: referenced first, then overlay current (excluding $ref) + merged: dict[str, Any] = deepcopy(resolved) + for key, value in node.items(): + if key == "$ref": + continue + merged[key] = value + return _flatten(merged) + + flattened: dict[str, Any] = {} + for key, value in node.items(): + # Drop defs/definitions in output + if key in ("$defs", "definitions"): + continue + if key == "properties" and isinstance(value, dict): + flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()} + elif key in ("items", "contains", "additionalProperties", "not"): + if isinstance(value, dict): + flattened[key] = _flatten(value) + elif isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = value + elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value + return flattened + if isinstance(node, list): + return [_flatten(v) for v in node] + return node + + result = _flatten(schema) + # Ensure top-level cleanup just in case + if isinstance(result, dict): + result.pop("$defs", None) + result.pop("definitions", None) + return result + + # ? Not used def generate_constraint_sentences(schema: dict) -> str: """