feat(fastapi): adds anthropic web search, fix(fastapi): llm messages to system and user message
This commit is contained in:
parent
5c106bd664
commit
dc62eb72d1
12 changed files with 117 additions and 73 deletions
|
|
@ -72,6 +72,9 @@ async def stream_outlines(
|
|||
presentation_outlines_json = json.loads(presentation_outlines_text)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
with open("./debug/outlines.txt", "w") as f:
|
||||
f.write(presentation_outlines_text)
|
||||
print(presentation_outlines_text)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to generate presentation outlines. Please try again.",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from datetime import datetime
|
||||
import json
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.llm_message import LLMUserMessage
|
||||
from models.llm_tools import LLMDynamicTool, SearchWebTool
|
||||
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
API_V1_TEST_ROUTER = APIRouter(prefix="/api/v1/test", tags=["test"])
|
||||
|
|
@ -24,30 +23,7 @@ class ResponseContent(BaseModel):
|
|||
async def test():
|
||||
client = LLMClient()
|
||||
|
||||
async def get_current_datetime_tool_handler(_) -> str:
|
||||
return datetime.now().isoformat()
|
||||
response = await client._search_anthropic("Trending AI tool now")
|
||||
# print(response)
|
||||
|
||||
get_current_datetime_tool = LLMDynamicTool(
|
||||
name="GetDateTimeDynamicTool",
|
||||
description="Get the current date and time",
|
||||
handler=get_current_datetime_tool_handler,
|
||||
)
|
||||
|
||||
text_content = ""
|
||||
|
||||
async for chunk in client.stream_structured(
|
||||
model=get_model(),
|
||||
messages=[
|
||||
LLMUserMessage(
|
||||
content="What is the current date and time ? What is the trending AI tool now ? Use Available tools to get the information."
|
||||
),
|
||||
],
|
||||
response_format=ResponseContent.model_json_schema(),
|
||||
tools=[
|
||||
SearchWebTool,
|
||||
get_current_datetime_tool,
|
||||
],
|
||||
):
|
||||
text_content += chunk
|
||||
|
||||
return {"data": text_content}
|
||||
return {"data": ""}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ class LLMTool(BaseModel):
|
|||
class LLMDynamicTool(LLMTool):
|
||||
name: str
|
||||
description: str
|
||||
strict: bool = False
|
||||
parameters: dict = {}
|
||||
handler: Callable[..., Coroutine[Any, Any, str]]
|
||||
|
||||
|
|
|
|||
|
|
@ -43,13 +43,11 @@ from utils.get_env import (
|
|||
get_google_api_key_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
get_tool_calls_env,
|
||||
)
|
||||
from utils.llm_provider import get_llm_provider, get_model
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.randomizers import get_random_uuid
|
||||
from utils.schema_utils import ensure_strict_json_schema
|
||||
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
|
||||
|
||||
|
||||
class LLMClient:
|
||||
|
|
@ -455,10 +453,10 @@ class LLMClient:
|
|||
LLMDynamicTool(
|
||||
name="ResponseSchema",
|
||||
description="Provide response to the user",
|
||||
strict=strict,
|
||||
parameters=response_schema,
|
||||
handler=do_nothing_async,
|
||||
)
|
||||
),
|
||||
strict=strict,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -557,7 +555,7 @@ class LLMClient:
|
|||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "Provide response to the user",
|
||||
"parameters": response_format,
|
||||
"parameters_json_schema": response_format,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
|
@ -571,7 +569,7 @@ class LLMClient:
|
|||
tools=google_tools,
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json" if not tools else None,
|
||||
response_json_schema=response_format if not tools else None,
|
||||
response_schema=response_format if not tools else None,
|
||||
max_output_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
|
|
@ -1114,10 +1112,10 @@ class LLMClient:
|
|||
LLMDynamicTool(
|
||||
name="ResponseSchema",
|
||||
description="Provide response to the user",
|
||||
strict=strict,
|
||||
parameters=response_schema,
|
||||
handler=do_nothing_async,
|
||||
)
|
||||
),
|
||||
strict=strict,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1235,10 +1233,11 @@ class LLMClient:
|
|||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
depth: int = 0,
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
|
||||
client: genai.Client = self._client
|
||||
|
||||
google_tools = []
|
||||
google_tools = None
|
||||
if tools:
|
||||
google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
|
||||
google_tools.append(
|
||||
|
|
@ -1247,13 +1246,14 @@ class LLMClient:
|
|||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "Provide response to the user",
|
||||
"parameters": response_format,
|
||||
"parameters_json_schema": response_format,
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
tool_calls: List[GoogleToolCall] = []
|
||||
has_response_schema_tool_call = False
|
||||
async for event in iterator_to_async(client.models.generate_content_stream)(
|
||||
model=model,
|
||||
contents=self._get_google_messages(messages),
|
||||
|
|
@ -1277,7 +1277,6 @@ class LLMClient:
|
|||
for each in event.function_calls
|
||||
]
|
||||
|
||||
has_response_schema_tool_call = False
|
||||
for each in tool_calls:
|
||||
if each.name == "ResponseSchema":
|
||||
has_response_schema_tool_call = True
|
||||
|
|
@ -1317,7 +1316,7 @@ class LLMClient:
|
|||
tools: Optional[List[dict]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
depth: int = 0,
|
||||
):
|
||||
) -> AsyncGenerator[str, None]:
|
||||
client: AsyncAnthropic = self._client
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
|
|
@ -1516,4 +1515,20 @@ class LLMClient:
|
|||
contents=query,
|
||||
config=config,
|
||||
)
|
||||
return response.text
|
||||
return response.text
|
||||
|
||||
async def _search_anthropic(self, query: str) -> str:
    """Run ``query`` through Anthropic with the web search tool enabled.

    A single user message containing ``query`` is sent to the model returned
    by ``get_model()``, with the ``web_search_20250305`` tool limited to one
    use and the reply capped at 4000 tokens.

    Args:
        query: The search question, passed through as the user message.

    Returns:
        The ``text`` of every content block whose ``type`` is ``"text"``,
        joined with newlines. Blocks of any other type are ignored.
    """
    anthropic_client: AsyncAnthropic = self._client

    user_turn = {"role": "user", "content": query}
    web_search_tool = {
        "type": "web_search_20250305",
        "name": "web_search",
        "max_uses": 1,
    }

    message = await anthropic_client.messages.create(
        model=get_model(),
        max_tokens=4000,
        messages=[user_turn],
        tools=[web_search_tool],
    )

    # Keep only the plain-text blocks of the reply, in their original order.
    text_parts = []
    for block in message.content:
        if block.type == "text":
            text_parts.append(block.text)
    return "\n".join(text_parts)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from models.llm_message import (
|
|||
)
|
||||
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
|
||||
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
|
||||
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
|
||||
|
||||
|
||||
class LLMToolCallsHandler:
|
||||
|
|
@ -73,6 +74,9 @@ class LLMToolCallsHandler:
|
|||
description = tool.__doc__ or ""
|
||||
parameters = tool.model_json_schema()
|
||||
|
||||
if strict:
|
||||
parameters = ensure_strict_json_schema(parameters, path=(), root=parameters)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -85,6 +89,9 @@ class LLMToolCallsHandler:
|
|||
|
||||
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
parsed = self.parse_tool_openai(tool)
|
||||
# parsed["function"]["parameters"] = flatten_json_schema(
|
||||
# parsed["function"]["parameters"]
|
||||
# )
|
||||
return {
|
||||
"name": parsed["function"]["name"],
|
||||
"description": parsed["function"]["description"],
|
||||
|
|
@ -185,9 +192,10 @@ class LLMToolCallsHandler:
|
|||
return await self.client._search_google(args.query)
|
||||
|
||||
async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str:
|
||||
return "test"
|
||||
args = SearchWebTool.model_validate_json(arguments)
|
||||
return await self.client._search_anthropic(args.query)
|
||||
|
||||
# Get current datetime tool call handler
|
||||
async def get_current_datetime_tool_call_handler(self, arguments: str) -> str:
|
||||
async def get_current_datetime_tool_call_handler(self, _) -> str:
|
||||
current_time = datetime.now()
|
||||
return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
|
|
@ -41,12 +41,10 @@ def get_messages(
|
|||
language: str,
|
||||
):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(prompt, slide_data, language),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional
|
||||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
|
@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str):
|
|||
response = await client.generate(
|
||||
model=model,
|
||||
messages=[
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
|
||||
LLMSystemMessage(content=system_prompt),
|
||||
LLMUserMessage(content=get_user_prompt(prompt, html)),
|
||||
],
|
||||
)
|
||||
return extract_html_from_response(response) or html
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from models.llm_message import LLMMessage, LLMSystemMessage, LLMUserMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
|
||||
from services.llm_client import LLMClient
|
||||
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
|
||||
|
|
@ -30,11 +30,9 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
|
|||
def get_messages(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
LLMSystemMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMUserMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(prompt, n_slides, language, content),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
|
|
@ -11,8 +11,7 @@ def get_messages(
|
|||
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
|
||||
):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=f"""
|
||||
You're a professional presentation designer with creative freedom to design engaging presentations.
|
||||
|
||||
|
|
@ -47,8 +46,7 @@ def get_messages(
|
|||
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
|
||||
""",
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=f"""
|
||||
{data}
|
||||
""",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
|
|
@ -39,12 +39,10 @@ def get_user_prompt(outline: str, language: str):
|
|||
def get_messages(outline: str, language: str):
|
||||
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(outline, language),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
|
|
@ -13,8 +13,7 @@ def get_messages(
|
|||
current_slide_layout: int,
|
||||
):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=f"""
|
||||
Select a Slide Layout index based on provided user prompt and current slide data.
|
||||
{layout.to_string()}
|
||||
|
|
@ -26,8 +25,7 @@ def get_messages(
|
|||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=f"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
|
|
|
|||
|
|
@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
|
|||
return resolved
|
||||
|
||||
|
||||
# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions
|
||||
def flatten_json_schema(schema: dict) -> dict:
    """Return a new schema with every ``$ref`` inlined and definitions removed.

    Each ``$ref`` is looked up with ``resolve_ref`` against ``schema`` and
    replaced by the flattened target; keys written next to the ``$ref``
    override keys of the target. ``$defs`` and ``definitions`` sections are
    dropped from the output. ``schema`` itself is never modified and the
    result shares no mutable containers with it.

    Args:
        schema: The JSON schema to flatten.

    Returns:
        The flattened schema.

    Raises:
        TypeError: If a ``$ref`` value is not a string, or if it resolves to
            something other than a dictionary.
        ValueError: If a ``$ref`` refers back to a schema that is currently
            being expanded; a recursive schema has no finite inlined form.
    """
    # Keywords that only hold definitions; dropped once every $ref is inlined.
    definition_keys = ("$defs", "definitions")
    # Keywords whose value maps arbitrary names to subschemas. Those names are
    # not schema keywords, so an entry called "$ref" or "definitions" must be
    # kept as an ordinary entry.
    schema_map_keys = ("properties", "patternProperties", "dependentSchemas")
    # Keywords whose value is instance data rather than a schema; it is copied
    # verbatim instead of being walked as if it were a schema.
    literal_keys = ("const", "default", "enum", "examples")

    def _flatten(node: Any, active_refs: frozenset[str]) -> Any:
        if isinstance(node, list):
            return [_flatten(item, active_refs) for item in node]
        if not isinstance(node, dict):
            return node

        if "$ref" in node:
            ref_value = node["$ref"]
            if not isinstance(ref_value, str):
                raise TypeError(f"Received non-string $ref - {ref_value}")
            if ref_value in active_refs:
                # Expanding this again would never terminate.
                raise ValueError(f"Cannot flatten recursive $ref - {ref_value}")
            resolved = resolve_ref(root=schema, ref=ref_value)
            if not isinstance(resolved, dict):
                raise TypeError(
                    f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}"
                )
            # Only the referenced schema is expanded "inside" this ref; the
            # sibling keys belong to the referencing node, so they are
            # flattened with the outer set of active refs.
            target = _flatten(resolved, active_refs | {ref_value})
            siblings = {key: value for key, value in node.items() if key != "$ref"}
            return {**target, **_flatten(siblings, active_refs)}

        flattened: dict[str, Any] = {}
        for key, value in node.items():
            if key in definition_keys:
                continue
            if key in schema_map_keys and isinstance(value, dict):
                flattened[key] = {
                    name: _flatten(subschema, active_refs)
                    for name, subschema in value.items()
                }
            elif key in literal_keys:
                flattened[key] = deepcopy(value)
            else:
                flattened[key] = _flatten(value, active_refs)
        return flattened

    return _flatten(schema, frozenset())
|
||||
|
||||
|
||||
# ? Not used
|
||||
def generate_constraint_sentences(schema: dict) -> str:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue