diff --git a/README.md b/README.md
index 72a53771..a7851cc4 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep
- **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom**
- **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output.
- **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled.
+- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results.
You can also set the following environment variables to customize the image generation provider and API keys:
diff --git a/badu.js b/badu.js
new file mode 100644
index 00000000..e69de29b
diff --git a/docker-compose.yml b/docker-compose.yml
index 42502b5f..85ca9210 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -25,6 +25,9 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
+ - TOOL_CALLS=${TOOL_CALLS}
+ - DISABLE_THINKING=${DISABLE_THINKING}
+ - WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}
production-gpu:
@@ -60,6 +63,9 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
+ - TOOL_CALLS=${TOOL_CALLS}
+ - DISABLE_THINKING=${DISABLE_THINKING}
+ - WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}
development:
@@ -87,6 +93,9 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
+ - TOOL_CALLS=${TOOL_CALLS}
+ - DISABLE_THINKING=${DISABLE_THINKING}
+ - WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}
development-gpu:
@@ -120,5 +129,8 @@ services:
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- - DATABASE_URL=${DATABASE_URL}
- EXTENDED_REASONING=${EXTENDED_REASONING}
+ - TOOL_CALLS=${TOOL_CALLS}
+ - DISABLE_THINKING=${DISABLE_THINKING}
+ - WEB_GROUNDING=${WEB_GROUNDING}
+ - DATABASE_URL=${DATABASE_URL}
diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py
index adde8669..0dafa3e1 100644
--- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py
+++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py
@@ -64,7 +64,7 @@ async def pull_model(
# If the model is being pulled, return the model
if saved_model_status:
# If the model is being pulled, return the model
- # ? If the model status is pulled in redis but was not found while listing pulled models,
+ # ? If the model status is pulled in database but was not found while listing pulled models,
# ? it means the model was deleted and we need to pull it again
if (
saved_model_status["status"] == "error"
diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py b/servers/fastapi/api/v1/ppt/endpoints/outlines.py
index b0ec47af..bf5b1489 100644
--- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py
+++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py
@@ -72,6 +72,9 @@ async def stream_outlines(
presentation_outlines_json = json.loads(presentation_outlines_text)
except Exception as e:
print(e)
+ with open("./debug/outlines.txt", "w") as f:
+ f.write(presentation_outlines_text)
+ print(presentation_outlines_text)
raise HTTPException(
status_code=400,
detail="Failed to generate presentation outlines. Please try again.",
@@ -87,7 +90,8 @@ async def stream_outlines(
presentation.outlines = presentation_outlines.model_dump()
presentation.title = (
- presentation_outlines.slides[0][:50]
+ presentation_outlines.slides[0]
+ .content[:50]
.replace("#", "")
.replace("/", "")
.replace("\\", "")
diff --git a/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/servers/fastapi/api/v1/ppt/endpoints/presentation.py
index 5b4589a2..2650782d 100644
--- a/servers/fastapi/api/v1/ppt/endpoints/presentation.py
+++ b/servers/fastapi/api/v1/ppt/endpoints/presentation.py
@@ -11,7 +11,10 @@ from sqlmodel import select
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
from models.presentation_and_path import PresentationPathAndEditPath
from models.presentation_from_template import GetPresentationUsingTemplateRequest
-from models.presentation_outline_model import PresentationOutlineModel
+from models.presentation_outline_model import (
+ PresentationOutlineModel,
+ SlideOutlineModel,
+)
from models.pptx_models import PptxPresentationModel
from models.presentation_layout import PresentationLayoutModel
from models.presentation_structure_model import PresentationStructureModel
@@ -126,7 +129,7 @@ async def create_presentation(
@PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel)
async def prepare_presentation(
presentation_id: Annotated[str, Body()],
- outlines: Annotated[List[str], Body()],
+ outlines: Annotated[List[SlideOutlineModel], Body()],
layout: Annotated[PresentationLayoutModel, Body()],
title: Annotated[Optional[str], Body()] = None,
sql_session: AsyncSession = Depends(get_async_session),
@@ -161,7 +164,9 @@ async def prepare_presentation(
presentation_structure.slides[index] = random_slide_index
sql_session.add(presentation)
- presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump()
+ presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump(
+ mode="json"
+ )
presentation.title = title or presentation.title
presentation.set_layout(layout)
presentation.set_structure(presentation_structure)
diff --git a/servers/fastapi/api/v1/ppt/endpoints/prompts.py b/servers/fastapi/api/v1/ppt/endpoints/prompts.py
index e1c0c0ba..500a3e44 100644
--- a/servers/fastapi/api/v1/ppt/endpoints/prompts.py
+++ b/servers/fastapi/api/v1/ppt/endpoints/prompts.py
@@ -23,7 +23,7 @@ Follow these rules strictly:
- If there is a box/card enclosing a text, make it grow as well when the text grows, so that the text does not overflow the box/card.
- Give out only HTML and Tailwind code. No other texts or explanations.
- Do not give entire HTML structure with head, body, etc. Just give the respective HTML and Tailwind code inside div with above classes.
-- If a list of fonts is provided, you must use the provided fonts (normalized root families) in font-family declarations, prioritizing them over inferred fonts. Use the first matching family wherever applicable.
+- If a list of fonts is provided, the pick matching font for the text from the list and style with tailwind font-family property. Use following format: font-["font-name"]
"""
HTML_TO_REACT_SYSTEM_PROMPT = """
@@ -34,7 +34,7 @@ Convert given static HTML and Tailwind slide to a TSX React component so that it
3) For similar components in the layouts (eg, team members), they should be represented by array of such components in the schema.
4) For image and icons icons should be a different schema with two dunder fields for prompt and url separately.
5) Default value for schema fields should be populated with the respective static value in HTML input.
-6) In schema max and min value for characters in string and items in array should be specified as per the given image of the slide. You should accurately evaluate the maximum and minimum possible characters respective fields can handle visually through the image.
+6) In schema max and min value for characters in string and items in array should be specified as per the given image of the slide. You should accurately evaluate the maximum and minimum possible characters respective fields can handle visually through the image. ALso give out maximum number of words it can handle in the meta.
7) For image and icons schema should be compulsorily declared with two dunder fields for prompt and url separately.
8) Component name at the end should always yo 'dynamicSlideLayout'.
9) **Import or export statements should not be present in the output.**
@@ -53,6 +53,12 @@ Convert given static HTML and Tailwind slide to a TSX React component so that it
14. Always complete the reference, do not give "slideData .? .cards" instead give "slideData?.cards".
15. Do not add anything other than code. Do not add "use client", "json", "typescript", "javascript" and other prefix or suffix, just give out code exactly formatted like example.
16. In schema, give default for all fields irrespective of their types, give defualt values for array and objects as well.
+17. For charts use recharts.js library and follow these rules strictly:
+ - Do not import rechart, it will already be imported.
+ - There should support for multiple chart types including bar, line, pie and donut in the same size as given.
+ - Use an attribute in the schema to select between chart types.
+ - All data should be properly represented in schema.
+18. For diagrams use mermaid with appropriate placeholder which can render any daigram. Schema should have a field for code. Render in the placeholder properly.
For example:
Input:
Effects of Global Warming
Global warming triggers a cascade of effects on our planet. These changes impact everything from our oceans to our ecosystems.
Rising Sea Levels
Rising sea levels threaten coastal communities and ecosystems due to melting glaciers and thermal expansion.
Intense Heatwaves
Heatwaves are becoming more frequent and intense, posing significant risks to human health and agriculture.
Changes in Precipitation
Altered precipitation patterns lead to increased droughts in some regions and severe flooding in others, affecting water resources.
@@ -62,7 +68,7 @@ const ImageSchema = z.object({
description: "URL to image",
}),
__image_prompt__: z.string().meta({
- description: "Prompt used to generate the image",
+ description: "Prompt used to generate the image. Max 30 words",
}).min(10).max(50),
})
@@ -71,7 +77,7 @@ const IconSchema = z.object({
description: "URL to icon",
}),
__icon_query__: z.string().meta({
- description: "Query used to search the icon",
+ description: "Query used to search the icon. Max 3 words",
}).min(5).max(20),
})
const layoutId = "bullet-with-icons-slide"
@@ -80,23 +86,23 @@ const layoutDescription = "A bullets style slide with main content, supporting i
const Schema = z.object({
title: z.string().min(3).max(40).default("Problem").meta({
- description: "Main title of the slide",
+ description: "Main title of the slide. Max 5 words",
}),
description: z.string().max(150).default("Businesses face challenges with outdated technology and rising costs, limiting efficiency and growth in competitive markets.").meta({
- description: "Main description text explaining the problem or topic",
+ description: "Main description text explaining the problem or topic. Max 30 words",
}),
image: ImageSchema.default({
__image_url__: 'https://images.unsplash.com/photo-1552664730-d307ca884978?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1000&q=80',
__image_prompt__: "Business people analyzing documents and charts in office"
}).meta({
- description: "Supporting image for the slide",
+ description: "Supporting image for the slide. Max 30 words",
}),
bulletPoints: z.array(z.object({
title: z.string().min(2).max(80).meta({
- description: "Bullet point title",
+ description: "Bullet point title. Max 4 words",
}),
description: z.string().min(10).max(150).meta({
- description: "Bullet point description",
+ description: "Bullet point description. Max 15 words",
}),
icon: IconSchema,
})).min(1).max(3).default([
@@ -117,7 +123,7 @@ const Schema = z.object({
}
}
]).meta({
- description: "List of bullet points with icons and descriptions",
+ description: "List of bullet points with icons and descriptions. Max 3 points",
})
})
diff --git a/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py b/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py
index 5e09179d..67539938 100644
--- a/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py
+++ b/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py
@@ -4,9 +4,8 @@ from datetime import datetime
from typing import Optional, List, Dict
from fastapi import APIRouter, HTTPException, File, UploadFile, Form, Depends
from pydantic import BaseModel
-from google import genai
-from google.genai import errors
-from google.genai import types
+from openai import OpenAI
+from openai import APIError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, func
from utils.asset_directory_utils import get_images_directory
@@ -26,6 +25,7 @@ LAYOUT_MANAGEMENT_ROUTER = APIRouter(prefix="/layout-management", tags=["layout-
class SlideToHtmlRequest(BaseModel):
image: str # Partial path to image file (e.g., "/app_data/images/uuid/slide_1.png")
xml: str # OXML content as text
+ fonts: Optional[List[str]] = None # Optional normalized root fonts for this slide
class SlideToHtmlResponse(BaseModel):
success: bool
@@ -42,6 +42,7 @@ class HtmlEditResponse(BaseModel):
# Request/Response models for html-to-react endpoint
class HtmlToReactRequest(BaseModel):
html: str # HTML content to convert to React component
+ image: Optional[str] = None # Optional image path to provide visual context
class HtmlToReactResponse(BaseModel):
@@ -94,15 +95,16 @@ class ErrorResponse(BaseModel):
error_code: Optional[str] = None
-async def generate_html_from_slide(base64_image: str, media_type: str, xml_content: str, api_key: str) -> str:
+async def generate_html_from_slide(base64_image: str, media_type: str, xml_content: str, api_key: str, fonts: Optional[List[str]] = None) -> str:
"""
- Generate HTML content from slide image and XML using Google Gen AI API.
+ Generate HTML content from slide image and XML using OpenAI GPT-5 Responses API.
Args:
base64_image: Base64 encoded image data
media_type: MIME type of the image (e.g., 'image/png')
xml_content: OXML content as text
- api_key: Google Gen AI API key
+ api_key: OpenAI API key
+ fonts: Optional list of normalized root font families to prefer in output
Returns:
Generated HTML content as string
@@ -110,62 +112,51 @@ async def generate_html_from_slide(base64_image: str, media_type: str, xml_conte
Raises:
HTTPException: If API call fails or no content is generated
"""
- print(f"Generating HTML from slide image and XML using Google Gen AI API...")
+ print(f"Generating HTML from slide image and XML using OpenAI GPT-5 Responses API...")
try:
- # Initialize Google Gen AI client
- client = genai.Client(api_key=api_key)
-
- # Convert base64 to bytes
- image_bytes = base64.b64decode(base64_image)
-
- print("Starting non-streaming request to Google Gen AI for HTML generation...")
-
- # Create content with image and text
- contents = [
- types.Part.from_bytes(
- mime_type=media_type,
- data=image_bytes,
- ),
- types.Part.from_text(text=f"\nOXML: \n\n{xml_content}"),
+ client = OpenAI(api_key=api_key)
+
+ # Compose input for Responses API. Include system prompt, image (separate), OXML and optional fonts text.
+ data_url = f"data:{media_type};base64,{base64_image}"
+ fonts_text = f"\nFONTS (Normalized root families used in this slide, use where it is required): {', '.join(fonts)}" if fonts else ""
+ user_text = f"OXML: \n\n{fonts_text}"
+ input_payload = [
+ {"role": "system", "content": GENERATE_HTML_SYSTEM_PROMPT},
+ {
+ "role": "user",
+ "content": [
+ {"type": "input_image", "image_url": data_url},
+ {"type": "input_text", "text": user_text},
+ ],
+ },
]
-
- # Generate content config with thinking enabled
- generate_content_config = types.GenerateContentConfig(
- system_instruction=GENERATE_HTML_SYSTEM_PROMPT,
- max_output_tokens=65536,
- temperature=1.0,
- thinking_config=types.ThinkingConfig(
- thinking_budget=32768,
- ),
- )
-
- print("Making non-streaming request for HTML generation...")
-
- # Generate content in non-streaming mode
- response = client.models.generate_content(
- model="gemini-2.5-pro",
- contents=contents,
- config=generate_content_config,
+
+ print("Making Responses API request for HTML generation...")
+ response = client.responses.create(
+ model="gpt-5",
+ input=input_payload,
+ reasoning={"effort": "high"},
+ text={"verbosity": "low"},
)
# Extract the response text
- html_content = response.text if response.text else ""
+ html_content = getattr(response, "output_text", None) or getattr(response, "text", None) or ""
print(f"Received HTML content length: {len(html_content)}")
if not html_content:
raise HTTPException(
status_code=500,
- detail="No HTML content generated by Google Gen AI API"
+ detail="No HTML content generated by OpenAI GPT-5"
)
return html_content
- except errors.APIError as e:
- print(f"Google API Error: {e}")
+ except APIError as e:
+ print(f"OpenAI API Error: {e}")
raise HTTPException(
status_code=500,
- detail=f"Google API error during HTML generation: {str(e)}"
+ detail=f"OpenAI API error during HTML generation: {str(e)}"
)
except Exception as e:
# Handle various API errors
@@ -175,27 +166,27 @@ async def generate_html_from_slide(base64_image: str, media_type: str, xml_conte
if "timeout" in error_msg.lower():
raise HTTPException(
status_code=408,
- detail=f"Google Gen AI API timeout during HTML generation: {error_msg}"
+ detail=f"OpenAI API timeout during HTML generation: {error_msg}"
)
elif "connection" in error_msg.lower():
raise HTTPException(
status_code=503,
- detail=f"Google Gen AI API connection error during HTML generation: {error_msg}"
+ detail=f"OpenAI API connection error during HTML generation: {error_msg}"
)
else:
raise HTTPException(
status_code=500,
- detail=f"Google Gen AI API error during HTML generation: {error_msg}"
+ detail=f"OpenAI API error during HTML generation: {error_msg}"
)
-async def generate_react_component_from_html(html_content: str, api_key: str) -> str:
+async def generate_react_component_from_html(html_content: str, api_key: str, image_base64: Optional[str] = None, media_type: Optional[str] = None) -> str:
"""
- Convert HTML content to TSX React component using Google Gen AI API.
+ Convert HTML content to TSX React component using OpenAI GPT-5 Responses API.
Args:
html_content: Generated HTML content
- api_key: Google Gen AI API key
+ api_key: OpenAI API key
Returns:
Generated TSX React component code as string
@@ -204,42 +195,36 @@ async def generate_react_component_from_html(html_content: str, api_key: str) ->
HTTPException: If API call fails or no content is generated
"""
try:
- # Initialize Google Gen AI client
- client = genai.Client(api_key=api_key)
-
- print("Starting non-streaming request to Google Gen AI for React component generation...")
-
- # Create content with text
- contents = types.Part.from_text(text=html_content)
-
- # Generate content config with thinking enabled
- generate_content_config = types.GenerateContentConfig(
- system_instruction=HTML_TO_REACT_SYSTEM_PROMPT,
- max_output_tokens=65536,
- temperature=1.0,
- thinking_config=types.ThinkingConfig(
- thinking_budget=15000,
- ),
- )
-
- print("Making non-streaming request for React component...")
-
- # Generate content in non-streaming mode
- response = client.models.generate_content(
- model="gemini-2.5-pro",
- contents=contents,
- config=generate_content_config,
+ client = OpenAI(api_key=api_key)
+
+ print("Making Responses API request for React component generation...")
+
+ # Build payload with optional image
+ content_parts = [{"type": "input_text", "text": f"HTML INPUT:\n{html_content}"}]
+ if image_base64 and media_type:
+ data_url = f"data:{media_type};base64,{image_base64}"
+ content_parts.insert(0, {"type": "input_image", "image_url": data_url})
+
+ input_payload = [
+ {"role": "system", "content": HTML_TO_REACT_SYSTEM_PROMPT},
+ {"role": "user", "content": content_parts},
+ ]
+
+ response = client.responses.create(
+ model="gpt-5",
+ input=input_payload,
+ reasoning={"effort": "minimal"},
+ text={"verbosity": "low"},
)
- # Extract the response text
- react_content = response.text if response.text else ""
+ react_content = getattr(response, "output_text", None) or getattr(response, "text", None) or ""
print(f"Received React content length: {len(react_content)}")
if not react_content:
raise HTTPException(
status_code=500,
- detail="No React component generated by Google Gen AI API"
+ detail="No React component generated by OpenAI GPT-5"
)
react_content = react_content.replace("```tsx", "").replace("```", "").replace("typescript", "").replace("javascript", "")
@@ -256,11 +241,11 @@ async def generate_react_component_from_html(html_content: str, api_key: str) ->
print(f"Filtered React content length: {len(filtered_react_content)}")
return filtered_react_content
- except errors.APIError as e:
- print(f"Google API Error: {e}")
+ except APIError as e:
+ print(f"OpenAI API Error: {e}")
raise HTTPException(
status_code=500,
- detail=f"Google API error during React generation: {str(e)}"
+ detail=f"OpenAI API error during React generation: {str(e)}"
)
except Exception as e:
# Handle various API errors
@@ -270,23 +255,23 @@ async def generate_react_component_from_html(html_content: str, api_key: str) ->
if "timeout" in error_msg.lower():
raise HTTPException(
status_code=408,
- detail=f"Google Gen AI API timeout during React generation: {error_msg}"
+ detail=f"OpenAI API timeout during React generation: {error_msg}"
)
elif "connection" in error_msg.lower():
raise HTTPException(
status_code=503,
- detail=f"Google Gen AI API connection error during React generation: {error_msg}"
+ detail=f"OpenAI API connection error during React generation: {error_msg}"
)
else:
raise HTTPException(
status_code=500,
- detail=f"Google Gen AI API error during React generation: {error_msg}"
+ detail=f"OpenAI API error during React generation: {error_msg}"
)
async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[str], media_type: str, html_content: str, prompt: str, api_key: str) -> str:
"""
- Edit HTML content based on one or two images and a text prompt using Google Gen AI API.
+ Edit HTML content based on one or two images and a text prompt using OpenAI GPT-5 Responses API.
Args:
current_ui_base64: Base64 encoded current UI image data
@@ -294,7 +279,7 @@ async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[
media_type: MIME type of the images (e.g., 'image/png')
html_content: Current HTML content to edit
prompt: Text prompt describing the changes
- api_key: Google Gen AI API key
+ api_key: OpenAI API key
Returns:
Edited HTML content as string
@@ -303,70 +288,50 @@ async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[
HTTPException: If API call fails or no content is generated
"""
try:
- # Initialize Google Gen AI client
- client = genai.Client(api_key=api_key)
-
- print("Starting non-streaming request to Google Gen AI for HTML editing...")
-
- # Convert base64 images to bytes
- current_ui_bytes = base64.b64decode(current_ui_base64)
-
- # Build content array - always include text and current UI image
- contents = [
- types.Part.from_text(text=f"Current HTML to edit:\n\n{html_content}\n\nText prompt for changes: {prompt}"),
- types.Part.from_bytes(
- mime_type=media_type,
- data=current_ui_bytes,
- )
+ client = OpenAI(api_key=api_key)
+
+ print("Making Responses API request for HTML editing...")
+
+ current_data_url = f"data:{media_type};base64,{current_ui_base64}"
+ sketch_data_url = f"data:{media_type};base64,{sketch_base64}" if sketch_base64 else None
+
+ content_parts = [
+ {"type": "input_image", "image_url": current_data_url},
+ {"type": "input_text", "text": f"CURRENT HTML TO EDIT:\n{html_content}\n\nTEXT PROMPT FOR CHANGES:\n{prompt}"},
]
-
- # Only add sketch image if provided
- if sketch_base64:
- sketch_bytes = base64.b64decode(sketch_base64)
- contents.append(
- types.Part.from_bytes(
- mime_type=media_type,
- data=sketch_bytes,
- )
- )
-
- # Generate content config with thinking enabled
- generate_content_config = types.GenerateContentConfig(
- system_instruction=HTML_EDIT_SYSTEM_PROMPT,
- max_output_tokens=65536,
- temperature=1.0,
- thinking_config=types.ThinkingConfig(
- thinking_budget=16000,
- ),
+ if sketch_data_url:
+ # Insert sketch image after current UI image for context
+ content_parts.insert(1, {"type": "input_image", "image_url": sketch_data_url})
+
+ input_payload = [
+ {"role": "system", "content": HTML_EDIT_SYSTEM_PROMPT},
+ {"role": "user", "content": content_parts},
+ ]
+
+ response = client.responses.create(
+ model="gpt-5",
+ input=input_payload,
+ reasoning={"effort": "low"},
+ text={"verbosity": "low"},
)
-
- print("Making non-streaming request for HTML editing...")
-
- # Generate content in non-streaming mode
- response = client.models.generate_content(
- model="gemini-2.5-pro",
- contents=contents,
- config=generate_content_config,
- )
-
- # Extract the response text
- edited_html = response.text if response.text else ""
+
+ edited_html = getattr(response, "output_text", None) or getattr(response, "text", None) or ""
print(f"Received edited HTML content length: {len(edited_html)}")
if not edited_html:
raise HTTPException(
status_code=500,
- detail="No edited HTML content generated by Google Gen AI API"
+ detail="No edited HTML content generated by OpenAI GPT-5"
)
return edited_html
- except errors.APIError as e:
- print(f"Google API Error: {e}")
+ except APIError as e:
+ print(f"OpenAI API Error: {e}")
raise HTTPException(
status_code=500,
- detail=f"Google API error during HTML editing: {str(e)}"
+ detail=f"OpenAI API error during HTML editing: {str(e)}"
)
except Exception as e:
# Handle various API errors
@@ -376,17 +341,17 @@ async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[
if "timeout" in error_msg.lower():
raise HTTPException(
status_code=408,
- detail=f"Google Gen AI API timeout during HTML editing: {error_msg}"
+ detail=f"OpenAI API timeout during HTML editing: {error_msg}"
)
elif "connection" in error_msg.lower():
raise HTTPException(
status_code=503,
- detail=f"Google Gen AI API connection error during HTML editing: {error_msg}"
+ detail=f"OpenAI API connection error during HTML editing: {error_msg}"
)
else:
raise HTTPException(
status_code=500,
- detail=f"Google Gen AI API error during HTML editing: {error_msg}"
+ detail=f"OpenAI API error during HTML editing: {error_msg}"
)
@@ -403,12 +368,12 @@ async def convert_slide_to_html(request: SlideToHtmlRequest):
SlideToHtmlResponse with generated HTML
"""
try:
- # Get Google Gen AI API key from environment
- api_key = os.getenv("GOOGLE_API_KEY")
+ # Get OpenAI API key from environment
+ api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise HTTPException(
status_code=500,
- detail="GOOGLE_API_KEY environment variable not set"
+ detail="OPENAI_API_KEY environment variable not set"
)
# Resolve image path to actual file system path
@@ -458,7 +423,8 @@ async def convert_slide_to_html(request: SlideToHtmlRequest):
base64_image=base64_image,
media_type=media_type,
xml_content=request.xml,
- api_key=api_key
+ api_key=api_key,
+ fonts=request.fonts,
)
html_content = html_content.replace("```html", "").replace("```", "")
@@ -493,12 +459,12 @@ async def convert_html_to_react(request: HtmlToReactRequest):
HtmlToReactResponse with generated React component
"""
try:
- # Get Google Gen AI API key from environment
- api_key = os.getenv("GOOGLE_API_KEY")
+ # Get OpenAI API key from environment
+ api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise HTTPException(
status_code=500,
- detail="GOOGLE_API_KEY environment variable not set"
+ detail="OPENAI_API_KEY environment variable not set"
)
# Validate HTML content
@@ -508,10 +474,31 @@ async def convert_html_to_react(request: HtmlToReactRequest):
detail="HTML content cannot be empty"
)
+ # Optionally resolve image and encode to base64
+ image_b64 = None
+ media_type = None
+ if request.image:
+ image_path = request.image
+ if image_path.startswith("/app_data/images/"):
+ relative_path = image_path[len("/app_data/images/"):]
+ actual_image_path = os.path.join(get_images_directory(), relative_path)
+ elif image_path.startswith("/static/"):
+ relative_path = image_path[len("/static/"):]
+ actual_image_path = os.path.join("static", relative_path)
+ else:
+ actual_image_path = image_path if os.path.isabs(image_path) else os.path.join(get_images_directory(), image_path)
+ if os.path.exists(actual_image_path):
+ with open(actual_image_path, "rb") as f:
+ image_b64 = base64.b64encode(f.read()).decode("utf-8")
+ ext = os.path.splitext(actual_image_path)[1].lower()
+ media_type = {'.png':'image/png','.jpg':'image/jpeg','.jpeg':'image/jpeg','.gif':'image/gif','.webp':'image/webp'}.get(ext, 'image/png')
+
# Convert HTML to React component
react_component = await generate_react_component_from_html(
html_content=request.html,
- api_key=api_key
+ api_key=api_key,
+ image_base64=image_b64,
+ media_type=media_type
)
react_component = react_component.replace("```tsx", "").replace("```", "")
@@ -555,12 +542,12 @@ async def edit_html_with_images_endpoint(
HtmlEditResponse with edited HTML
"""
try:
- # Get Google Gen AI API key from environment
- api_key = os.getenv("GOOGLE_API_KEY")
+ # Get OpenAI API key from environment
+ api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise HTTPException(
status_code=500,
- detail="GOOGLE_API_KEY environment variable not set"
+ detail="OPENAI_API_KEY environment variable not set"
)
# Validate inputs
diff --git a/servers/fastapi/constants/llm.py b/servers/fastapi/constants/llm.py
index ac4bd527..7d374f30 100644
--- a/servers/fastapi/constants/llm.py
+++ b/servers/fastapi/constants/llm.py
@@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1"
# Default models
DEFAULT_OPENAI_MODEL = "gpt-4.1"
-DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash"
-DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"
+DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash"
+DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514"
diff --git a/servers/fastapi/constants/supported_ollama_models.py b/servers/fastapi/constants/supported_ollama_models.py
index 455a1217..a02a7a18 100644
--- a/servers/fastapi/constants/supported_ollama_models.py
+++ b/servers/fastapi/constants/supported_ollama_models.py
@@ -6,61 +6,51 @@ SUPPORTED_OLLAMA_MODELS = {
label="Llama 3:8b",
value="llama3:8b",
size="4.7GB",
- icon="/static/icons/meta.png",
),
"llama3:70b": OllamaModelMetadata(
label="Llama 3:70b",
value="llama3:70b",
size="40GB",
- icon="/static/icons/meta.png",
),
"llama3.1:8b": OllamaModelMetadata(
label="Llama 3.1:8b",
value="llama3.1:8b",
size="4.9GB",
- icon="/static/icons/meta.png",
),
"llama3.1:70b": OllamaModelMetadata(
label="Llama 3.1:70b",
value="llama3.1:70b",
size="43GB",
- icon="/static/icons/meta.png",
),
"llama3.1:405b": OllamaModelMetadata(
label="Llama 3.1:405b",
value="llama3.1:405b",
size="243GB",
- icon="/static/icons/meta.png",
),
"llama3.2:1b": OllamaModelMetadata(
label="Llama 3.2:1b",
value="llama3.2:1b",
size="1.3GB",
- icon="/static/icons/meta.png",
),
"llama3.2:3b": OllamaModelMetadata(
label="Llama 3.2:3b",
value="llama3.2:3b",
size="2GB",
- icon="/static/icons/meta.png",
),
"llama3.3:70b": OllamaModelMetadata(
label="Llama 3.3:70b",
value="llama3.3:70b",
size="43GB",
- icon="/static/icons/meta.png",
),
"llama4:16x17b": OllamaModelMetadata(
label="Llama 4:16x17b",
value="llama4:16x17b",
size="67GB",
- icon="/static/icons/meta.png",
),
"llama4:128x17b": OllamaModelMetadata(
label="Llama 4:128x17b",
value="llama4:128x17b",
size="245GB",
- icon="/static/icons/meta.png",
),
}
@@ -69,25 +59,21 @@ SUPPORTED_GEMMA_MODELS = {
label="Gemma 3:1b",
value="gemma3:1b",
size="815MB",
- icon="/static/icons/gemma.png",
),
"gemma3:4b": OllamaModelMetadata(
label="Gemma 3:4b",
value="gemma3:4b",
size="3.3GB",
- icon="/static/icons/gemma.png",
),
"gemma3:12b": OllamaModelMetadata(
label="Gemma 3:12b",
value="gemma3:12b",
size="8.1GB",
- icon="/static/icons/gemma.png",
),
"gemma3:27b": OllamaModelMetadata(
label="Gemma 3:27b",
value="gemma3:27b",
size="17GB",
- icon="/static/icons/gemma.png",
),
}
@@ -96,43 +82,36 @@ SUPPORTED_DEEPSEEK_MODELS = {
label="DeepSeek R1:1.5b",
value="deepseek-r1:1.5b",
size="1.1GB",
- icon="/static/icons/deepseek.png",
),
"deepseek-r1:7b": OllamaModelMetadata(
label="DeepSeek R1:7b",
value="deepseek-r1:7b",
size="4.7GB",
- icon="/static/icons/deepseek.png",
),
"deepseek-r1:8b": OllamaModelMetadata(
label="DeepSeek R1:8b",
value="deepseek-r1:8b",
size="5.2GB",
- icon="/static/icons/deepseek.png",
),
"deepseek-r1:14b": OllamaModelMetadata(
label="DeepSeek R1:14b",
value="deepseek-r1:14b",
size="9GB",
- icon="/static/icons/deepseek.png",
),
"deepseek-r1:32b": OllamaModelMetadata(
label="DeepSeek R1:32b",
value="deepseek-r1:32b",
size="20GB",
- icon="/static/icons/deepseek.png",
),
"deepseek-r1:70b": OllamaModelMetadata(
label="DeepSeek R1:70b",
value="deepseek-r1:70b",
size="43GB",
- icon="/static/icons/deepseek.png",
),
"deepseek-r1:671b": OllamaModelMetadata(
label="DeepSeek R1:671b",
value="deepseek-r1:671b",
size="404GB",
- icon="/static/icons/deepseek.png",
),
}
@@ -141,49 +120,54 @@ SUPPORTED_QWEN_MODELS = {
label="Qwen 3:0.6b",
value="qwen3:0.6b",
size="523MB",
- icon="/static/icons/qwen.png",
),
"qwen3:1.7b": OllamaModelMetadata(
label="Qwen 3:1.7b",
value="qwen3:1.7b",
size="1.4GB",
- icon="/static/icons/qwen.png",
),
"qwen3:4b": OllamaModelMetadata(
label="Qwen 3:4b",
value="qwen3:4b",
size="2.6GB",
- icon="/static/icons/qwen.png",
),
"qwen3:8b": OllamaModelMetadata(
label="Qwen 3:8b",
value="qwen3:8b",
size="5.2GB",
- icon="/static/icons/qwen.png",
),
"qwen3:14b": OllamaModelMetadata(
label="Qwen 3:14b",
value="qwen3:14b",
size="9.3GB",
- icon="/static/icons/qwen.png",
),
"qwen3:30b": OllamaModelMetadata(
label="Qwen 3:30b",
value="qwen3:30b",
size="19GB",
- icon="/static/icons/qwen.png",
),
"qwen3:32b": OllamaModelMetadata(
label="Qwen 3:32b",
value="qwen3:32b",
size="20GB",
- icon="/static/icons/qwen.png",
),
"qwen3:235b": OllamaModelMetadata(
label="Qwen 3:235b",
value="qwen3:235b",
size="142GB",
- icon="/static/icons/qwen.png",
+ ),
+}
+
+SUPPORTED_GPT_OSS_MODELS = {
+ "gpt-oss:20b": OllamaModelMetadata(
+ label="GPT-OSS 20b",
+ value="gpt-oss:20b",
+ size="14GB",
+ ),
+ "gpt-oss:120b": OllamaModelMetadata(
+ label="GPT-OSS 120b",
+ value="gpt-oss:120b",
+ size="65GB",
),
}
@@ -192,4 +176,5 @@ SUPPORTED_OLLAMA_MODELS = {
**SUPPORTED_GEMMA_MODELS,
**SUPPORTED_DEEPSEEK_MODELS,
**SUPPORTED_QWEN_MODELS,
+ **SUPPORTED_GPT_OSS_MODELS,
}
diff --git a/servers/fastapi/enums/llm_call_type.py b/servers/fastapi/enums/llm_call_type.py
new file mode 100644
index 00000000..e37fe4ae
--- /dev/null
+++ b/servers/fastapi/enums/llm_call_type.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class LLMCallType(Enum):
+ UNSTRUCTURED = "unstructured"
+ UNSTRUCTURED_STREAM = "unstructured_stream"
+ STRUCTURED = "structured"
+ STRUCTURED_STREAM = "structured_stream"
diff --git a/servers/fastapi/models/document_chunk.py b/servers/fastapi/models/document_chunk.py
index 6861e4fb..a7500be9 100644
--- a/servers/fastapi/models/document_chunk.py
+++ b/servers/fastapi/models/document_chunk.py
@@ -1,5 +1,7 @@
from pydantic import BaseModel
+from models.presentation_outline_model import SlideOutlineModel
+
class DocumentChunk(BaseModel):
heading: str
@@ -7,5 +9,5 @@ class DocumentChunk(BaseModel):
heading_index: int
score: float
- def to_slide_outline(self) -> str:
- return f"{self.heading}\n{self.content}"
+ def to_slide_outline(self) -> SlideOutlineModel:
+ return SlideOutlineModel(content=f"{self.heading}\n{self.content}")
diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py
index 51284173..db741ca4 100644
--- a/servers/fastapi/models/llm_message.py
+++ b/servers/fastapi/models/llm_message.py
@@ -1,7 +1,58 @@
-from typing import Literal
+from typing import Any, List, Literal, Optional
from pydantic import BaseModel
+from google.genai.types import Content as GoogleContent
+
+from models.llm_tool_call import AnthropicToolCall
class LLMMessage(BaseModel):
- role: Literal["user", "system"]
+ pass
+
+
+class LLMUserMessage(LLMMessage):
+ role: Literal["user"] = "user"
content: str
+
+
+class LLMSystemMessage(LLMMessage):
+ role: Literal["system"] = "system"
+ content: str
+
+
+class OpenAIAssistantMessage(LLMMessage):
+ role: Literal["assistant"] = "assistant"
+ content: str | None = None
+ tool_calls: Optional[List[dict]] = None
+
+
+class GoogleAssistantMessage(LLMMessage):
+ role: Literal["assistant"] = "assistant"
+ content: GoogleContent
+
+
+class AnthropicAssistantMessage(LLMMessage):
+ role: Literal["assistant"] = "assistant"
+ content: List[AnthropicToolCall]
+
+
+class AnthropicToolCallMessage(LLMMessage):
+ type: Literal["tool_result"] = "tool_result"
+ tool_use_id: str
+ content: str
+
+
+class AnthropicUserMessage(LLMMessage):
+ role: Literal["user"] = "user"
+ content: List[AnthropicToolCallMessage]
+
+
+class OpenAIToolCallMessage(LLMMessage):
+ role: Literal["tool"] = "tool"
+ content: str
+ tool_call_id: str
+
+
+class GoogleToolCallMessage(LLMMessage):
+ role: Literal["tool"] = "tool"
+ name: str
+ response: dict
diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py
new file mode 100644
index 00000000..5eb1f008
--- /dev/null
+++ b/servers/fastapi/models/llm_tool_call.py
@@ -0,0 +1,29 @@
+from typing import Literal, Optional
+from pydantic import BaseModel
+
+
+class LLMToolCall(BaseModel):
+ pass
+
+
+class OpenAIToolCallFunction(BaseModel):
+ name: str
+ arguments: str
+
+
+class OpenAIToolCall(LLMToolCall):
+ id: str
+ type: Literal["function"] = "function"
+ function: OpenAIToolCallFunction
+
+
+class GoogleToolCall(LLMToolCall):
+ name: str
+ arguments: Optional[dict] = None
+
+
+class AnthropicToolCall(LLMToolCall):
+ type: Literal["tool_use"] = "tool_use"
+ id: str
+ name: str
+ input: object
diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py
new file mode 100644
index 00000000..ccf64e67
--- /dev/null
+++ b/servers/fastapi/models/llm_tools.py
@@ -0,0 +1,29 @@
+from typing import Any, Callable, Coroutine, Optional
+from pydantic import BaseModel, Field
+
+
+class LLMTool(BaseModel):
+ pass
+
+
+class LLMDynamicTool(LLMTool):
+ name: str
+ description: str
+ parameters: dict = {}
+ handler: Callable[..., Coroutine[Any, Any, str]]
+
+
+class SearchWebTool(LLMTool):
+ """
+ Search the web for information.
+ """
+
+ query: str = Field(description="The query to search the web for")
+
+
+class GetCurrentDatetimeTool(LLMTool):
+ """
+ Get the current datetime.
+ """
+
+ pass
diff --git a/servers/fastapi/models/ollama_model_metadata.py b/servers/fastapi/models/ollama_model_metadata.py
index 1f8ed985..88c89cc6 100644
--- a/servers/fastapi/models/ollama_model_metadata.py
+++ b/servers/fastapi/models/ollama_model_metadata.py
@@ -4,5 +4,4 @@ from pydantic import BaseModel
class OllamaModelMetadata(BaseModel):
label: str
value: str
- icon: str
size: str
diff --git a/servers/fastapi/models/presentation_outline_model.py b/servers/fastapi/models/presentation_outline_model.py
index ad55ae4b..01a3b2b7 100644
--- a/servers/fastapi/models/presentation_outline_model.py
+++ b/servers/fastapi/models/presentation_outline_model.py
@@ -2,8 +2,12 @@ from typing import List
from pydantic import BaseModel
+class SlideOutlineModel(BaseModel):
+ content: str
+
+
class PresentationOutlineModel(BaseModel):
- slides: List[str]
+ slides: List[SlideOutlineModel]
def to_string(self):
message = ""
diff --git a/servers/fastapi/models/user_config.py b/servers/fastapi/models/user_config.py
index 50783544..c040d22c 100644
--- a/servers/fastapi/models/user_config.py
+++ b/servers/fastapi/models/user_config.py
@@ -35,3 +35,6 @@ class UserConfig(BaseModel):
TOOL_CALLS: Optional[bool] = None
DISABLE_THINKING: Optional[bool] = None
EXTENDED_REASONING: Optional[bool] = None
+
+ # Web Search
+ WEB_GROUNDING: Optional[bool] = None
diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml
index f0caf8e3..14240244 100644
--- a/servers/fastapi/pyproject.toml
+++ b/servers/fastapi/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
"openai>=1.98.0",
"pathvalidate>=3.3.1",
"pdfplumber>=0.11.7",
+ "pytest>=8.4.1",
"python-pptx>=1.0.2",
"redis>=6.2.0",
"sqlmodel>=0.0.24",
diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py
index f016e763..3e8b35f2 100644
--- a/servers/fastapi/services/llm_client.py
+++ b/servers/fastapi/services/llm_client.py
@@ -1,16 +1,40 @@
import asyncio
import json
-from typing import List, Optional
+from typing import AsyncGenerator, List, Optional
from fastapi import HTTPException
from openai import AsyncOpenAI
+from openai.types.chat.chat_completion_chunk import (
+ ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
from google import genai
-from google.genai.types import GenerateContentConfig
+from google.genai.types import Content as GoogleContent, Part as GoogleContentPart
+from google.genai.types import GenerateContentConfig, GoogleSearch
+from google.genai.types import Tool as GoogleTool
from anthropic import AsyncAnthropic
from anthropic.types import Message as AnthropicMessage
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
from enums.llm_provider import LLMProvider
-from models.llm_message import LLMMessage
+from models.llm_message import (
+ AnthropicAssistantMessage,
+ AnthropicUserMessage,
+ GoogleAssistantMessage,
+ GoogleToolCallMessage,
+ OpenAIAssistantMessage,
+ LLMMessage,
+ LLMSystemMessage,
+ LLMUserMessage,
+)
+from models.llm_tool_call import (
+ AnthropicToolCall,
+ GoogleToolCall,
+ LLMToolCall,
+ OpenAIToolCall,
+ OpenAIToolCallFunction,
+)
+from models.llm_tools import LLMDynamicTool, LLMTool
+from services.llm_tool_calls_handler import LLMToolCallsHandler
from utils.async_iterator import iterator_to_async
+from utils.dummy_functions import do_nothing_async
from utils.get_env import (
get_anthropic_api_key_env,
get_custom_llm_api_key_env,
@@ -20,8 +44,9 @@ from utils.get_env import (
get_ollama_url_env,
get_openai_api_key_env,
get_tool_calls_env,
+ get_web_grounding_env,
)
-from utils.llm_provider import get_llm_provider
+from utils.llm_provider import get_llm_provider, get_model
from utils.parsers import parse_bool_or_none
from utils.schema_utils import ensure_strict_json_schema
@@ -30,17 +55,25 @@ class LLMClient:
def __init__(self):
self.llm_provider = get_llm_provider()
self._client = self._get_client()
+ self.tool_calls_handler = LLMToolCallsHandler(self)
# ? Use tool calls
- def use_tool_calls(self) -> bool:
+ def use_tool_calls_for_structured_output(self) -> bool:
if self.llm_provider != LLMProvider.CUSTOM:
return False
return parse_bool_or_none(get_tool_calls_env()) or False
+ # ? Web Grounding
+ def enable_web_grounding(self) -> bool:
+ if (
+ self.llm_provider == LLMProvider.OLLAMA
+ or self.llm_provider == LLMProvider.CUSTOM
+ ):
+ return False
+ return parse_bool_or_none(get_web_grounding_env()) or False
+
# ? Disable thinking
def disable_thinking(self) -> bool:
- if self.llm_provider != LLMProvider.CUSTOM:
- return False
return parse_bool_or_none(get_disable_thinking_env()) or False
# ? Clients
@@ -106,15 +139,39 @@ class LLMClient:
# ? Prompts
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
for message in messages:
- if message.role == "system":
+ if isinstance(message, LLMSystemMessage):
return message.content
return ""
- def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]:
- return [message.content for message in messages if message.role == "user"]
+ def _get_google_messages(self, messages: List[LLMMessage]) -> List[str]:
+ contents = []
+ for message in messages:
+ if isinstance(message, LLMUserMessage):
+ contents.append(
+ GoogleContent(
+ role="user", parts=[GoogleContentPart(text=message.content)]
+ )
+ )
+ elif isinstance(message, GoogleAssistantMessage):
+ contents.append(message.content)
+ elif isinstance(message, GoogleToolCallMessage):
+ contents.append(
+ GoogleContent(
+ role="user",
+ parts=[
+ GoogleContentPart.from_function_response(
+ name=message.name, response=message.response
+ )
+ ],
+ )
+ )
- def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
- return [message for message in messages if message.role == "user"]
+ return contents
+
+ def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
+ return [
+ message for message in messages if not isinstance(message, LLMSystemMessage)
+ ]
# ? Generate Unstructured Content
async def _generate_openai(
@@ -122,89 +179,250 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ extra_body: Optional[dict] = None,
+ depth: int = 0,
+ ) -> str | None:
client: AsyncOpenAI = self._client
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
- extra_body={
- "enable_thinking": not self.disable_thinking(),
- },
+ tools=tools,
+ extra_body=extra_body,
)
+ tool_calls = response.choices[0].message.tool_calls
+ if tool_calls:
+ parsed_tool_calls = [
+ OpenAIToolCall(
+ id=tool_call.id,
+ type=tool_call.type,
+ function=OpenAIToolCallFunction(
+ name=tool_call.function.name,
+ arguments=tool_call.function.arguments,
+ ),
+ )
+ for tool_call in tool_calls
+ ]
+ tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
+ parsed_tool_calls
+ )
+ assistant_message = OpenAIAssistantMessage(
+ role="assistant",
+ content=response.choices[0].message.content,
+ tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls],
+ )
+ new_messages = [
+ *messages,
+ assistant_message,
+ *tool_call_messages,
+ ]
+ return await self._generate_openai(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ tools=tools,
+ extra_body=extra_body,
+ depth=depth + 1,
+ )
+
return response.choices[0].message.content
async def _generate_google(
self,
model: str,
messages: List[LLMMessage],
+ tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
- ):
+ depth: int = 0,
+ ) -> str | None:
client: genai.Client = self._client
+
+ google_tools = None
+ if tools:
+ google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
+
response = await asyncio.to_thread(
client.models.generate_content,
model=model,
- contents=self._get_user_prompts(messages),
+ contents=self._get_google_messages(messages),
config=GenerateContentConfig(
+ tools=google_tools,
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
max_output_tokens=max_tokens,
),
)
- return response.text
+
+ content = response.candidates[0].content
+ response_parts = content.parts
+
+ if not response_parts:
+ return None
+
+ text_content = None
+ tool_calls = []
+ for each_part in response_parts:
+ if each_part.function_call:
+ tool_calls.append(
+ GoogleToolCall(
+ name=each_part.function_call.name,
+ arguments=each_part.function_call.args,
+ )
+ )
+ if each_part.text:
+ text_content = each_part.text
+
+ if tool_calls:
+ tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
+ tool_calls
+ )
+ new_messages = [
+ *messages,
+ GoogleAssistantMessage(
+ role="assistant",
+ content=content,
+ ),
+ *tool_call_messages,
+ ]
+ return await self._generate_google(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ tools=tools,
+ depth=depth + 1,
+ )
+
+ return text_content
async def _generate_anthropic(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ depth: int = 0,
+ ) -> str | None:
client: AsyncAnthropic = self._client
+
response: AnthropicMessage = await client.messages.create(
model=model,
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
- for message in self._get_user_llm_messages(messages)
+ for message in self._get_anthropic_messages(messages)
],
+ tools=tools,
max_tokens=max_tokens or 4000,
)
- text = ""
+ text_content = None
+ tool_calls: List[AnthropicToolCall] = []
for content in response.content:
if content.type == "text" and isinstance(content.text, str):
- text += content.text
- if text == "":
- return None
- return text
+ text_content = content.text
+
+ if content.type == "tool_use":
+ tool_calls.append(
+ AnthropicToolCall(
+ id=content.id,
+ type=content.type,
+ name=content.name,
+ input=content.input,
+ )
+ )
+
+ if tool_calls:
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
+ )
+ new_messages = [
+ *messages,
+ AnthropicAssistantMessage(
+ role="assistant",
+ content=[each.model_dump() for each in tool_calls],
+ ),
+ AnthropicUserMessage(
+ role="user",
+ content=[each.model_dump() for each in tool_call_messages],
+ ),
+ ]
+ return await self._generate_anthropic(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ tools=tools,
+ depth=depth + 1,
+ )
+
+ return text_content
async def _generate_ollama(
- self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ max_tokens: Optional[int] = None,
+ depth: int = 0,
):
- return await self._generate_openai(model, messages, max_tokens)
+ return await self._generate_openai(
+ model=model, messages=messages, max_tokens=max_tokens, depth=depth
+ )
async def _generate_custom(
- self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ max_tokens: Optional[int] = None,
+ depth: int = 0,
):
- return await self._generate_openai(model, messages, max_tokens)
+ extra_body = {"enable_thinking": not self.disable_thinking()}
+ return await self._generate_openai(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ extra_body=extra_body,
+ depth=depth,
+ )
async def generate(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
+ tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
+ parsed_tools = self.tool_calls_handler.parse_tools(tools)
+
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
- content = await self._generate_openai(model, messages, max_tokens)
+ content = await self._generate_openai(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ tools=parsed_tools,
+ )
case LLMProvider.GOOGLE:
- content = await self._generate_google(model, messages, max_tokens)
+ content = await self._generate_google(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ tools=parsed_tools,
+ )
case LLMProvider.ANTHROPIC:
- content = await self._generate_anthropic(model, messages, max_tokens)
+ content = await self._generate_anthropic(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ tools=parsed_tools,
+ )
case LLMProvider.OLLAMA:
- content = await self._generate_ollama(model, messages, max_tokens)
+ content = await self._generate_ollama(
+ model=model, messages=messages, max_tokens=max_tokens
+ )
case LLMProvider.CUSTOM:
- content = await self._generate_custom(model, messages, max_tokens)
+ content = await self._generate_custom(
+ model=model, messages=messages, max_tokens=max_tokens
+ )
if content is None:
raise HTTPException(
status_code=400,
@@ -220,21 +438,43 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ extra_body: Optional[dict] = None,
+ depth: int = 0,
+ ) -> dict | None:
client: AsyncOpenAI = self._client
- use_tool_calls = self.use_tool_calls()
response_schema = response_format
- if strict:
+ all_tools = [*tools] if tools else None
+
+ use_tool_calls_for_structured_output = (
+ self.use_tool_calls_for_structured_output()
+ )
+ if strict and depth == 0:
response_schema = ensure_strict_json_schema(
response_schema,
path=(),
root=response_schema,
)
- if not use_tool_calls:
- response = await client.chat.completions.create(
- model=model,
- messages=[message.model_dump() for message in messages],
- response_format={
+ if use_tool_calls_for_structured_output and depth == 0:
+ if all_tools is None:
+ all_tools = []
+ all_tools.append(
+ self.tool_calls_handler.parse_tool(
+ LLMDynamicTool(
+ name="ResponseSchema",
+ description="Provide response to the user",
+ parameters=response_schema,
+ handler=do_nothing_async,
+ ),
+ strict=strict,
+ )
+ )
+
+ response = await client.chat.completions.create(
+ model=model,
+ messages=[message.model_dump() for message in messages],
+ response_format=(
+ {
"type": "json_schema",
"json_schema": (
{
@@ -243,40 +483,66 @@ class LLMClient:
"schema": response_schema,
}
),
- },
- max_completion_tokens=max_tokens,
- extra_body={
- "enable_thinking": not self.disable_thinking(),
- },
- )
- content = response.choices[0].message.content
- else:
- response = await client.chat.completions.create(
- model=model,
- messages=[message.model_dump() for message in messages],
- tools=[
- {
- "type": "function",
- "function": {
- "name": "ResponseSchema",
- "description": "A response to the user's message",
- "strict": strict,
- "parameters": response_format,
- },
- }
- ],
- tool_choice="required",
- max_completion_tokens=max_tokens,
- extra_body={
- "enable_thinking": not self.disable_thinking(),
- },
- )
- tool_calls = response.choices[0].message.tool_calls
- if tool_calls:
- content = tool_calls[0].function.arguments
+ }
+ if not use_tool_calls_for_structured_output
+ else None
+ ),
+ max_completion_tokens=max_tokens,
+ tools=all_tools,
+ extra_body=extra_body,
+ )
+ content = response.choices[0].message.content
+
+ tool_calls = response.choices[0].message.tool_calls
+ has_response_schema = False
+
+ if tool_calls:
+ for tool_call in tool_calls:
+ if tool_call.function.name == "ResponseSchema":
+ content = tool_call.function.arguments
+ has_response_schema = True
+
+ if not has_response_schema:
+ parsed_tool_calls = [
+ OpenAIToolCall(
+ id=tool_call.id,
+ type=tool_call.type,
+ function=OpenAIToolCallFunction(
+ name=tool_call.function.name,
+ arguments=tool_call.function.arguments,
+ ),
+ )
+ for tool_call in tool_calls
+ ]
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_openai(
+ parsed_tool_calls
+ )
+ )
+ new_messages = [
+ *messages,
+ OpenAIAssistantMessage(
+ role="assistant",
+ content=response.choices[0].message.content,
+ tool_calls=[each.model_dump() for each in parsed_tool_calls],
+ ),
+ *tool_call_messages,
+ ]
+ content = await self._generate_openai_structured(
+ model=model,
+ messages=new_messages,
+ response_format=response_schema,
+ strict=strict,
+ max_tokens=max_tokens,
+ tools=all_tools,
+ extra_body=extra_body,
+ depth=depth + 1,
+ )
if content:
- return json.loads(content)
+ if depth == 0:
+ return json.loads(content)
+ return content
return None
async def _generate_google_structured(
@@ -285,31 +551,96 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ depth: int = 0,
+ ) -> dict | None:
client: genai.Client = self._client
+
+ google_tools = None
+ if tools:
+ google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
+ google_tools.append(
+ GoogleTool(
+ function_declarations=[
+ {
+ "name": "ResponseSchema",
+ "description": "Provide response to the user",
+ "parameters_json_schema": response_format,
+ }
+ ]
+ )
+ )
+
response = await asyncio.to_thread(
client.models.generate_content,
model=model,
- contents=self._get_user_prompts(messages),
+ contents=self._get_google_messages(messages),
config=GenerateContentConfig(
+ tools=google_tools,
system_instruction=self._get_system_prompt(messages),
- response_mime_type="application/json",
- response_json_schema=response_format,
+ response_mime_type="application/json" if not tools else None,
+ response_json_schema=response_format if not tools else None,
max_output_tokens=max_tokens,
),
)
- content = None
- if response.text:
- content = json.loads(response.text)
- return content
+ content = response.candidates[0].content
+ response_parts = content.parts
+ text_content = None
+
+ if not response_parts:
+ return None
+
+ tool_calls: List[GoogleToolCall] = []
+ for each_part in response_parts:
+ if each_part.function_call:
+ tool_calls.append(
+ GoogleToolCall(
+ name=each_part.function_call.name,
+ arguments=each_part.function_call.args,
+ )
+ )
+
+ if each_part.text:
+ text_content = each_part.text
+
+ for each in tool_calls:
+ if each.name == "ResponseSchema":
+ return each.arguments
+
+ if tool_calls:
+ tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google(
+ tool_calls
+ )
+ new_messages = [
+ *messages,
+ GoogleAssistantMessage(
+ role="assistant",
+ content=content,
+ ),
+ *tool_call_messages,
+ ]
+ return await self._generate_google_structured(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ response_format=response_format,
+ tools=tools,
+ depth=depth + 1,
+ )
+
+ if text_content:
+ return json.loads(text_content)
+ return None
async def _generate_anthropic_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
+ tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
client: AsyncAnthropic = self._client
response: AnthropicMessage = await client.messages.create(
@@ -317,7 +648,7 @@ class LLMClient:
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
- for message in self._get_user_llm_messages(messages)
+ for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=[
@@ -325,19 +656,51 @@ class LLMClient:
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": response_format,
- }
+ },
+ *(tools or []),
],
- tool_choice={
- "type": "tool",
- "name": "ResponseSchema",
- },
)
- content: dict | None = None
- for content_block in response.content:
- if content_block.type == "tool_use":
- content = content_block.input
+ tool_calls: List[AnthropicToolCall] = []
+ for content in response.content:
+ if content.type == "tool_use":
+ tool_calls.append(
+ AnthropicToolCall(
+ id=content.id,
+ type=content.type,
+ name=content.name,
+ input=content.input,
+ )
+ )
- return content
+ for each in tool_calls:
+ if each.name == "ResponseSchema":
+ return each.input
+
+ if tool_calls:
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls)
+ )
+ new_messages = [
+ *messages,
+ AnthropicAssistantMessage(
+ role="assistant",
+ content=[each.model_dump() for each in tool_calls],
+ ),
+ AnthropicUserMessage(
+ role="user",
+ content=[each.model_dump() for each in tool_call_messages],
+ ),
+ ]
+ return await self._generate_anthropic_structured(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ response_format=response_format,
+ tools=tools,
+ depth=depth + 1,
+ )
+
+ return None
async def _generate_ollama_structured(
self,
@@ -346,9 +709,15 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
return await self._generate_openai_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
+ depth=depth,
)
async def _generate_custom_structured(
@@ -358,9 +727,17 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
+ extra_body = {"enable_thinking": not self.disable_thinking()}
return await self._generate_openai_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
+ extra_body=extra_body,
+ depth=depth,
)
async def generate_structured(
@@ -369,29 +746,53 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
+ tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
max_tokens: Optional[int] = None,
) -> dict:
+ parsed_tools = self.tool_calls_handler.parse_tools(tools)
+
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
content = await self._generate_openai_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ tools=parsed_tools,
+ max_tokens=max_tokens,
)
case LLMProvider.GOOGLE:
content = await self._generate_google_structured(
- model, messages, response_format, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ tools=parsed_tools,
+ max_tokens=max_tokens,
)
case LLMProvider.ANTHROPIC:
content = await self._generate_anthropic_structured(
- model, messages, response_format, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ tools=parsed_tools,
+ max_tokens=max_tokens,
)
case LLMProvider.OLLAMA:
content = await self._generate_ollama_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
)
case LLMProvider.CUSTOM:
content = await self._generate_custom_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
)
if content is None:
raise HTTPException(
@@ -406,90 +807,285 @@ class LLMClient:
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ extra_body: Optional[dict] = None,
+ depth: int = 0,
+ ) -> AsyncGenerator[str, None]:
client: AsyncOpenAI = self._client
- async with client.chat.completions.stream(
+
+ tool_calls: List[LLMToolCall] = []
+ current_index = 0
+ current_id = None
+ current_name = None
+ current_arguments = None
+ async for event in await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=max_tokens,
- extra_body={
- "enable_thinking": not self.disable_thinking(),
- },
- ) as stream:
- async for event in stream:
- if event.type == "content.delta":
- yield event.delta
+ tools=tools,
+ extra_body=extra_body,
+ stream=True,
+ ):
+ event: OpenAIChatCompletionChunk = event
+ content_chunk = event.choices[0].delta.content
+ if content_chunk:
+ yield content_chunk
+
+ tool_call_chunk = event.choices[0].delta.tool_calls
+ if tool_call_chunk:
+ tool_index = tool_call_chunk[0].index
+ tool_id = tool_call_chunk[0].id
+ tool_name = tool_call_chunk[0].function.name
+ tool_arguments = tool_call_chunk[0].function.arguments
+
+ if current_index != tool_index:
+ tool_calls.append(
+ OpenAIToolCall(
+ id=current_id,
+ type="function",
+ function=OpenAIToolCallFunction(
+ name=current_name,
+ arguments=current_arguments,
+ ),
+ )
+ )
+ current_index = tool_index
+ current_id = tool_id
+ current_name = tool_name
+ current_arguments = tool_arguments
+ else:
+ current_name = tool_name or current_name
+ current_id = tool_id or current_id
+ if current_arguments is None:
+ current_arguments = tool_arguments
+ else:
+ current_arguments += tool_arguments
+
+ if current_id is not None:
+ tool_calls.append(
+ OpenAIToolCall(
+ id=current_id,
+ type="function",
+ function=OpenAIToolCallFunction(
+ name=current_name,
+ arguments=current_arguments,
+ ),
+ )
+ )
+
+ if tool_calls:
+ tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
+ tool_calls
+ )
+ new_messages = [
+ *messages,
+ OpenAIAssistantMessage(
+ role="assistant",
+ content=None,
+ tool_calls=[each.model_dump() for each in tool_calls],
+ ),
+ *tool_call_messages,
+ ]
+ async for event in self._stream_openai(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ tools=tools,
+ extra_body=extra_body,
+ depth=depth + 1,
+ ):
+ yield event
async def _stream_google(
self,
model: str,
messages: List[LLMMessage],
+ tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
- ):
+ depth: int = 0,
+ ) -> AsyncGenerator[str, None]:
client: genai.Client = self._client
+
+ google_tools = None
+ if tools:
+ google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
+
+ tool_calls = None
async for event in iterator_to_async(client.models.generate_content_stream)(
model=model,
- contents=self._get_user_prompts(messages),
+ contents=self._get_google_messages(messages),
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
+ tools=google_tools,
max_output_tokens=max_tokens,
),
):
if event.text:
yield event.text
+ if event.function_calls:
+ tool_calls = [
+ GoogleToolCall(
+ name=each.name,
+ arguments=each.args,
+ )
+ for each in event.function_calls
+ ]
+
+ if tool_calls:
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_google(tool_calls)
+ )
+ new_messages = [
+ *messages,
+ GoogleAssistantMessage(
+ role="assistant",
+ content=event.candidates[0].content,
+ ),
+ *tool_call_messages,
+ ]
+ async for event in self._stream_google(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ tools=tools,
+ depth=depth + 1,
+ ):
+ yield event
+
async def _stream_anthropic(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
+ tools: Optional[List[dict]] = None,
+ depth: int = 0,
):
client: AsyncAnthropic = self._client
+
async with client.messages.stream(
model=model,
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
- for message in self._get_user_llm_messages(messages)
+ for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
+ tools=tools,
) as stream:
+ tool_calls: List[AnthropicToolCall] = []
async for event in stream:
event: AnthropicMessageStreamEvent = event
- if event.type == "text" and isinstance(event.text, str):
+
+ if event.type == "text":
yield event.text
+ if (
+ event.type == "content_block_stop"
+ and event.content_block.type == "tool_use"
+ ):
+ tool_calls.append(
+ AnthropicToolCall(
+ id=event.content_block.id,
+ type=event.content_block.type,
+ name=event.content_block.name,
+ input=event.content_block.input,
+ )
+ )
+
+ if tool_calls:
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_anthropic(
+ tool_calls
+ )
+ )
+ new_messages = [
+ *messages,
+ AnthropicAssistantMessage(
+ role="assistant",
+ content=[each.model_dump() for each in tool_calls],
+ ),
+ AnthropicUserMessage(
+ role="user",
+ content=[each.model_dump() for each in tool_call_messages],
+ ),
+ ]
+ async for event in self._stream_anthropic(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ tools=tools,
+ depth=depth + 1,
+ ):
+ yield event
+
def _stream_ollama(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
- return self._stream_openai(model, messages, max_tokens)
+ return self._stream_openai(
+ model=model, messages=messages, max_tokens=max_tokens, depth=depth
+ )
def _stream_custom(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
- return self._stream_openai(model, messages, max_tokens)
+ extra_body = {"enable_thinking": not self.disable_thinking()}
+ return self._stream_openai(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ extra_body=extra_body,
+ depth=depth,
+ )
def stream(
- self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ max_tokens: Optional[int] = None,
+ tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
+ parsed_tools = self.tool_calls_handler.parse_tools(tools)
+
match self.llm_provider:
case LLMProvider.OPENAI:
- return self._stream_openai(model, messages, max_tokens)
+ return self._stream_openai(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ tools=parsed_tools,
+ )
case LLMProvider.GOOGLE:
- return self._stream_google(model, messages, max_tokens)
+ return self._stream_google(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ tools=parsed_tools,
+ )
case LLMProvider.ANTHROPIC:
- return self._stream_anthropic(model, messages, max_tokens)
+ return self._stream_anthropic(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ tools=parsed_tools,
+ )
case LLMProvider.OLLAMA:
- return self._stream_ollama(model, messages, max_tokens)
+ return self._stream_ollama(
+ model=model, messages=messages, max_tokens=max_tokens
+ )
case LLMProvider.CUSTOM:
- return self._stream_custom(model, messages, max_tokens)
+ return self._stream_custom(
+ model=model, messages=messages, max_tokens=max_tokens
+ )
# ? Stream Structured Content
async def _stream_openai_structured(
@@ -499,62 +1095,145 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ extra_body: Optional[dict] = None,
+ depth: int = 0,
+ ) -> AsyncGenerator[str, None]:
client: AsyncOpenAI = self._client
- use_tool_calls = self.use_tool_calls()
+
response_schema = response_format
- if strict:
+ all_tools = [*tools] if tools else None
+
+ use_tool_calls_for_structured_output = (
+ self.use_tool_calls_for_structured_output()
+ )
+ if strict and depth == 0:
response_schema = ensure_strict_json_schema(
response_schema,
path=(),
root=response_schema,
)
- if not use_tool_calls:
- async with client.chat.completions.stream(
- model=model,
- messages=[message.model_dump() for message in messages],
- max_completion_tokens=max_tokens,
- response_format=(
- {
- "type": "json_schema",
- "json_schema": {
+
+ if use_tool_calls_for_structured_output and depth == 0:
+ if all_tools is None:
+ all_tools = []
+ all_tools.append(
+ self.tool_calls_handler.parse_tool(
+ LLMDynamicTool(
+ name="ResponseSchema",
+ description="Provide response to the user",
+ parameters=response_schema,
+ handler=do_nothing_async,
+ ),
+ strict=strict,
+ )
+ )
+
+ tool_calls: List[LLMToolCall] = []
+ current_index = 0
+ current_id = None
+ current_name = None
+ current_arguments = None
+
+ has_response_schema_tool_call = False
+ async for event in await client.chat.completions.create(
+ model=model,
+ messages=[message.model_dump() for message in messages],
+ max_completion_tokens=max_tokens,
+ tools=all_tools,
+ response_format=(
+ {
+ "type": "json_schema",
+ "json_schema": (
+ {
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
- },
- }
+ }
+ ),
+ }
+ if not use_tool_calls_for_structured_output
+ else None
+ ),
+ extra_body=extra_body,
+ stream=True,
+ ):
+ event: OpenAIChatCompletionChunk = event
+ content_chunk = event.choices[0].delta.content
+ if content_chunk:
+ yield content_chunk
+
+ tool_call_chunk = event.choices[0].delta.tool_calls
+ if tool_call_chunk:
+ tool_index = tool_call_chunk[0].index
+ tool_id = tool_call_chunk[0].id
+ tool_name = tool_call_chunk[0].function.name
+ tool_arguments = tool_call_chunk[0].function.arguments
+
+ if current_index != tool_index:
+ tool_calls.append(
+ OpenAIToolCall(
+ id=current_id,
+ type="function",
+ function=OpenAIToolCallFunction(
+ name=current_name,
+ arguments=current_arguments,
+ ),
+ )
+ )
+ current_index = tool_index
+ current_id = tool_id
+ current_name = tool_name
+ current_arguments = tool_arguments
+ else:
+ current_name = tool_name or current_name
+ current_id = tool_id or current_id
+ if current_arguments is None:
+ current_arguments = tool_arguments
+ else:
+ current_arguments += tool_arguments
+
+ if current_name == "ResponseSchema":
+ if tool_arguments:
+ yield tool_arguments
+ has_response_schema_tool_call = True
+
+ if current_id is not None:
+ tool_calls.append(
+ OpenAIToolCall(
+ id=current_id,
+ type="function",
+ function=OpenAIToolCallFunction(
+ name=current_name,
+ arguments=current_arguments,
+ ),
+ )
+ )
+
+ if tool_calls and not has_response_schema_tool_call:
+ tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai(
+ tool_calls
+ )
+ new_messages = [
+ *messages,
+ OpenAIAssistantMessage(
+ role="assistant",
+ content=None,
+ tool_calls=[each.model_dump() for each in tool_calls],
),
- extra_body={
- "enable_thinking": not self.disable_thinking(),
- },
- ) as stream:
- async for event in stream:
- if event.type == "content.delta":
- yield event.delta
- else:
- async with client.chat.completions.stream(
+ *tool_call_messages,
+ ]
+ async for event in self._stream_openai_structured(
model=model,
- messages=[message.model_dump() for message in messages],
- max_completion_tokens=max_tokens,
- tools=[
- {
- "type": "function",
- "function": {
- "name": "ResponseSchema",
- "description": "A response to the user's message",
- "strict": strict,
- "parameters": response_format,
- },
- }
- ],
- tool_choice="required",
- extra_body={
- "enable_thinking": not self.disable_thinking(),
- },
- ) as stream:
- async for event in stream:
- if event.type == "tool_calls.function.arguments.delta":
- yield event.arguments_delta
+ messages=new_messages,
+ max_tokens=max_tokens,
+ strict=strict,
+ tools=all_tools,
+ response_format=response_schema,
+ extra_body=extra_body,
+ depth=depth + 1,
+ ):
+ yield event
async def _stream_google_structured(
self,
@@ -562,35 +1241,99 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
max_tokens: Optional[int] = None,
- ):
+ tools: Optional[List[dict]] = None,
+ depth: int = 0,
+ ) -> AsyncGenerator[str, None]:
+
client: genai.Client = self._client
+
+ google_tools = None
+ if tools:
+ google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools]
+ google_tools.append(
+ GoogleTool(
+ function_declarations=[
+ {
+ "name": "ResponseSchema",
+ "description": "Provide response to the user",
+ "parameters_json_schema": response_format,
+ }
+ ]
+ )
+ )
+
+ tool_calls: List[GoogleToolCall] = []
+ has_response_schema_tool_call = False
async for event in iterator_to_async(client.models.generate_content_stream)(
model=model,
- contents=self._get_user_prompts(messages),
+ contents=self._get_google_messages(messages),
config=GenerateContentConfig(
+ tools=google_tools,
system_instruction=self._get_system_prompt(messages),
- response_mime_type="application/json",
- response_json_schema=response_format,
+ response_mime_type="application/json" if not tools else None,
+ response_json_schema=response_format if not tools else None,
max_output_tokens=max_tokens,
),
):
if event.text:
yield event.text
+ if event.function_calls:
+ tool_calls = [
+ GoogleToolCall(
+ name=each.name,
+ arguments=each.args,
+ )
+ for each in event.function_calls
+ ]
+
+ for each in tool_calls:
+ if each.name == "ResponseSchema":
+ has_response_schema_tool_call = True
+ if each.arguments:
+ yield json.dumps(each.arguments)
+
+ if has_response_schema_tool_call:
+ continue
+
+ if tool_calls:
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_google(tool_calls)
+ )
+ new_messages = [
+ *messages,
+ GoogleAssistantMessage(
+ role="assistant",
+ content=event.candidates[0].content,
+ ),
+ *tool_call_messages,
+ ]
+ async for event in self._stream_google_structured(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ response_format=response_format,
+ tools=tools,
+ depth=depth + 1,
+ ):
+ yield event
+
async def _stream_anthropic_structured(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
+ tools: Optional[List[dict]] = None,
max_tokens: Optional[int] = None,
- ):
+ depth: int = 0,
+ ) -> AsyncGenerator[str, None]:
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
- for message in self._get_user_llm_messages(messages)
+ for message in self._get_anthropic_messages(messages)
],
max_tokens=max_tokens or 4000,
tools=[
@@ -598,17 +1341,72 @@ class LLMClient:
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": response_format,
- }
+ },
+ *(tools or []),
],
- tool_choice={
- "type": "tool",
- "name": "ResponseSchema",
- },
) as stream:
+ tool_calls: List[AnthropicToolCall] = []
+ has_response_schema_tool_call = False
+ is_response_schema_tool_call_started = False
async for event in stream:
event: AnthropicMessageStreamEvent = event
- if event.type == "input_json" and isinstance(event.partial_json, str):
- yield event.partial_json
+ if (
+ event.type == "content_block_start"
+ and event.content_block.type == "tool_use"
+ ):
+ if event.content_block.name == "ResponseSchema":
+ has_response_schema_tool_call = True
+ is_response_schema_tool_call_started = True
+
+ if (
+ event.type == "content_block_delta"
+ and event.delta.type == "input_json_delta"
+ and is_response_schema_tool_call_started
+ ):
+ yield event.delta.partial_json
+
+ if has_response_schema_tool_call:
+ continue
+
+ if (
+ event.type == "content_block_stop"
+ and event.content_block.type == "tool_use"
+ ):
+ tool_calls.append(
+ AnthropicToolCall(
+ id=event.content_block.id,
+ type=event.content_block.type,
+ name=event.content_block.name,
+ input=event.content_block.input,
+ )
+ )
+
+ if tool_calls:
+ tool_call_messages = (
+ await self.tool_calls_handler.handle_tool_calls_anthropic(
+ tool_calls
+ )
+ )
+ new_messages = [
+ *messages,
+ AnthropicAssistantMessage(
+ role="assistant",
+ content=[each.model_dump() for each in tool_calls],
+ ),
+ AnthropicUserMessage(
+ role="user",
+ content=[each.model_dump() for each in tool_call_messages],
+ ),
+ ]
+ async for event in self._stream_anthropic_structured(
+ model=model,
+ messages=new_messages,
+ max_tokens=max_tokens,
+ response_format=response_format,
+ tools=tools,
+ depth=depth + 1,
+ ):
+ yield event
def _stream_ollama_structured(
self,
@@ -617,9 +1415,15 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
return self._stream_openai_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
+ depth=depth,
)
def _stream_custom_structured(
@@ -629,9 +1433,17 @@ class LLMClient:
response_format: dict,
strict: bool = False,
max_tokens: Optional[int] = None,
+ depth: int = 0,
):
+ extra_body = {"enable_thinking": not self.disable_thinking()}
return self._stream_openai_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
+ extra_body=extra_body,
+ depth=depth,
)
def stream_structured(
@@ -640,26 +1452,93 @@ class LLMClient:
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
+ tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
max_tokens: Optional[int] = None,
):
+ parsed_tools = self.tool_calls_handler.parse_tools(tools)
+
match self.llm_provider:
case LLMProvider.OPENAI:
return self._stream_openai_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ tools=parsed_tools,
+ max_tokens=max_tokens,
)
case LLMProvider.GOOGLE:
return self._stream_google_structured(
- model, messages, response_format, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ tools=parsed_tools,
+ max_tokens=max_tokens,
)
case LLMProvider.ANTHROPIC:
return self._stream_anthropic_structured(
- model, messages, response_format, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ tools=parsed_tools,
+ max_tokens=max_tokens,
)
case LLMProvider.OLLAMA:
return self._stream_ollama_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
)
case LLMProvider.CUSTOM:
return self._stream_custom_structured(
- model, messages, response_format, strict, max_tokens
+ model=model,
+ messages=messages,
+ response_format=response_format,
+ strict=strict,
+ max_tokens=max_tokens,
)
+
+ # ? Web search
+ async def _search_openai(self, query: str) -> str:
+ client: AsyncOpenAI = self._client
+ response = await client.responses.create(
+ model=get_model(),
+ tools=[
+ {
+ "type": "web_search_preview",
+ }
+ ],
+ input=query,
+ )
+ return response.output_text
+
+ async def _search_google(self, query: str) -> str:
+ client: genai.Client = self._client
+ grounding_tool = GoogleTool(google_search=GoogleSearch())
+ config = GenerateContentConfig(tools=[grounding_tool])
+
+ response = await asyncio.to_thread(
+ client.models.generate_content,
+ model=get_model(),
+ contents=query,
+ config=config,
+ )
+ return response.text
+
+ async def _search_anthropic(self, query: str) -> str:
+ client: AsyncAnthropic = self._client
+
+ response = await client.messages.create(
+ model=get_model(),
+ max_tokens=4000,
+ messages=[{"role": "user", "content": query}],
+ tools=[
+ {"type": "web_search_20250305", "name": "web_search", "max_uses": 1}
+ ],
+ )
+ result = "\n".join(
+ [each.text for each in response.content if each.type == "text"]
+ )
+ return result
diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py
new file mode 100644
index 00000000..ed0d51ee
--- /dev/null
+++ b/servers/fastapi/services/llm_tool_calls_handler.py
@@ -0,0 +1,201 @@
+import asyncio
+from datetime import datetime
+import json
+from typing import Any, Callable, Coroutine, List, Optional
+from fastapi import HTTPException
+from enums.llm_provider import LLMProvider
+from models.llm_message import (
+ AnthropicToolCallMessage,
+ GoogleToolCallMessage,
+ OpenAIToolCallMessage,
+)
+from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
+from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
+from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
+
+
+class LLMToolCallsHandler:
+ def __init__(self, client):
+ from services.llm_client import LLMClient
+
+ self.client: LLMClient = client
+
+ self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = {
+ "SearchWebTool": self.search_web_tool_call_handler,
+ "GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler,
+ }
+ self.dynamic_tools: List[LLMDynamicTool] = []
+
+ def get_tool_handler(
+ self, tool_name: str
+ ) -> Callable[..., Coroutine[Any, Any, str]]:
+ handler = self.tools_map.get(tool_name)
+ if handler:
+ return handler
+ else:
+ dynamic_tools = list(
+ filter(lambda tool: tool.name == tool_name, self.dynamic_tools)
+ )
+ if dynamic_tools:
+ return dynamic_tools[0].handler
+ raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found")
+
+ def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None):
+ if tools is None:
+ return None
+ parsed_tools = map(self.parse_tool, tools)
+ return list(parsed_tools)
+
+ def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False):
+ if isinstance(tool, LLMDynamicTool):
+ self.dynamic_tools.append(tool)
+
+ match self.client.llm_provider:
+ case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM:
+ return self.parse_tool_openai(tool, strict)
+ case LLMProvider.ANTHROPIC:
+ return self.parse_tool_anthropic(tool)
+ case LLMProvider.GOOGLE:
+ return self.parse_tool_google(tool)
+ case _:
+ raise ValueError(
+ f"LLM provider must be either openai, anthropic, or google"
+ )
+
+ def parse_tool_openai(
+ self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False
+ ):
+ if isinstance(tool, LLMDynamicTool):
+ name = tool.name
+ description = tool.description
+ parameters = tool.parameters
+ else:
+ name = tool.__name__
+ description = tool.__doc__ or ""
+ parameters = tool.model_json_schema()
+
+ if strict:
+ parameters = ensure_strict_json_schema(parameters, path=(), root=parameters)
+
+ return {
+ "type": "function",
+ "function": {
+ "name": name,
+ "description": description,
+ "strict": strict,
+ "parameters": parameters,
+ },
+ }
+
+ def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
+ parsed = self.parse_tool_openai(tool)
+ # parsed["function"]["parameters"] = flatten_json_schema(
+ # parsed["function"]["parameters"]
+ # )
+ return {
+ "name": parsed["function"]["name"],
+ "description": parsed["function"]["description"],
+ "parameters": parsed["function"]["parameters"],
+ }
+
+ def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
+ parsed = self.parse_tool_openai(tool)
+ input_schema = parsed["function"]["parameters"]
+ return {
+ "name": parsed["function"]["name"],
+ "description": parsed["function"]["description"],
+ "input_schema": {"type": "object"} if input_schema == {} else input_schema,
+ }
+
+ async def handle_tool_calls_openai(
+ self,
+ tool_calls: List[OpenAIToolCall],
+ ) -> List[OpenAIToolCallMessage]:
+ async_tool_calls_tasks = []
+ for tool_call in tool_calls:
+ tool_name = tool_call.function.name
+ tool_handler = self.get_tool_handler(tool_name)
+ async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments))
+
+ tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
+ tool_call_messages = [
+ OpenAIToolCallMessage(
+ content=result,
+ tool_call_id=tool_call.id,
+ )
+ for tool_call, result in zip(tool_calls, tool_call_results)
+ ]
+ return tool_call_messages
+
+ async def handle_tool_calls_google(
+ self,
+ tool_calls: List[GoogleToolCall],
+ ) -> List[GoogleToolCallMessage]:
+ async_tool_calls_tasks = []
+ for tool_call in tool_calls:
+ tool_name = tool_call.name
+ tool_handler = self.get_tool_handler(tool_name)
+ async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments)))
+
+ tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
+ tool_call_messages = [
+ GoogleToolCallMessage(
+ name=tool_call.name,
+ response={"result": result},
+ )
+ for tool_call, result in zip(tool_calls, tool_call_results)
+ ]
+ return tool_call_messages
+
+ async def handle_tool_calls_anthropic(
+ self,
+ tool_calls: List[AnthropicToolCall],
+ ) -> List[AnthropicToolCallMessage]:
+ async_tool_calls_tasks = []
+ for tool_call in tool_calls:
+ tool_name = tool_call.name
+ tool_handler = self.get_tool_handler(tool_name)
+ async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input)))
+
+ tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
+ tool_call_messages = [
+ AnthropicToolCallMessage(
+ content=result,
+ tool_use_id=tool_call.id,
+ )
+ for tool_call, result in zip(tool_calls, tool_call_results)
+ ]
+ return tool_call_messages
+
+ # ? Tool call handlers
+ # Search web tool call handler
+ async def search_web_tool_call_handler(self, arguments: str) -> str:
+ match self.client.llm_provider:
+ case LLMProvider.OPENAI:
+ return await self.search_web_tool_call_handler_openai(arguments)
+ case LLMProvider.ANTHROPIC:
+ return await self.search_web_tool_call_handler_anthropic(arguments)
+ case LLMProvider.GOOGLE:
+ return await self.search_web_tool_call_handler_google(arguments)
+ case _:
+ return (
+ "Web search tool call handler not implemented for this LLM provider: "
+ + self.client.llm_provider.value
+ )
+
+ async def search_web_tool_call_handler_openai(self, arguments: str) -> str:
+ args = SearchWebTool.model_validate_json(arguments)
+ return await self.client._search_openai(args.query)
+
+ async def search_web_tool_call_handler_google(self, arguments: str) -> str:
+ args = SearchWebTool.model_validate_json(arguments)
+ return await self.client._search_google(args.query)
+
+ async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str:
+ args = SearchWebTool.model_validate_json(arguments)
+ return await self.client._search_anthropic(args.query)
+
+ # Get current datetime tool call handler
+ async def get_current_datetime_tool_call_handler(self, _) -> str:
+ current_time = datetime.now()
+ return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}"
diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py
deleted file mode 100644
index f2e3d8c9..00000000
--- a/servers/fastapi/services/redis_service.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from typing import Any, Optional
-import redis
-from redis.exceptions import RedisError
-
-from utils.get_env import (
- get_redis_db_env,
- get_redis_host_env,
- get_redis_password_env,
- get_redis_port_env,
-)
-
-
-class RedisService:
- def __init__(self):
- self.redis_host = get_redis_host_env() or "localhost"
- self.redis_port = int(get_redis_port_env() or "6379")
- self.redis_db = int(get_redis_db_env() or "0")
- self.redis_password = get_redis_password_env() or None
- self.client = self._create_client()
-
- def _create_client(self) -> redis.Redis:
- return redis.Redis(
- host=self.redis_host,
- port=self.redis_port,
- db=self.redis_db,
- password=self.redis_password,
- decode_responses=True,
- )
-
- def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
- try:
- return self.client.set(key, value, ex=expire)
- except RedisError:
- return False
-
- def get(self, key: str) -> Optional[str]:
- try:
- return self.client.get(key)
- except RedisError:
- return None
-
- def delete(self, key: str) -> bool:
- try:
- return bool(self.client.delete(key))
- except RedisError:
- return False
-
- def exists(self, key: str) -> bool:
- try:
- return bool(self.client.exists(key))
- except RedisError:
- return False
-
- def set_hash(self, name: str, mapping: dict) -> bool:
- try:
- return self.client.hmset(name, mapping)
- except RedisError:
- return False
-
- def get_hash(self, name: str) -> Optional[dict]:
- try:
- return self.client.hgetall(name)
- except RedisError:
- return None
-
- def delete_hash(self, name: str, *fields: str) -> int:
- try:
- return self.client.hdel(name, *fields)
- except RedisError:
- return 0
-
- def set_list(self, name: str, values: list) -> bool:
- try:
- self.client.delete(name)
- if values:
- self.client.rpush(name, *values)
- return True
- except RedisError:
- return False
-
- def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
- try:
- return self.client.lrange(name, start, end)
- except RedisError:
- return None
-
- def add_to_set(self, name: str, *values: str) -> int:
- try:
- return self.client.sadd(name, *values)
- except RedisError:
- return 0
-
- def get_set(self, name: str) -> Optional[set]:
- try:
- return self.client.smembers(name)
- except RedisError:
- return None
-
- def remove_from_set(self, name: str, *values: str) -> int:
- try:
- return self.client.srem(name, *values)
- except RedisError:
- return 0
-
- def clear(self) -> bool:
- try:
- return self.client.flushdb()
- except RedisError:
- return False
-
- def close(self):
- try:
- self.client.close()
- except RedisError:
- pass
diff --git a/servers/fastapi/static/icons/deepseek.png b/servers/fastapi/static/icons/deepseek.png
deleted file mode 100644
index 798b8f18..00000000
Binary files a/servers/fastapi/static/icons/deepseek.png and /dev/null differ
diff --git a/servers/fastapi/static/icons/gemma.png b/servers/fastapi/static/icons/gemma.png
deleted file mode 100644
index 647d87a2..00000000
Binary files a/servers/fastapi/static/icons/gemma.png and /dev/null differ
diff --git a/servers/fastapi/static/icons/meta.png b/servers/fastapi/static/icons/meta.png
deleted file mode 100644
index 0a3d82c1..00000000
Binary files a/servers/fastapi/static/icons/meta.png and /dev/null differ
diff --git a/servers/fastapi/static/icons/qwen.png b/servers/fastapi/static/icons/qwen.png
deleted file mode 100644
index 2cee1c36..00000000
Binary files a/servers/fastapi/static/icons/qwen.png and /dev/null differ
diff --git a/servers/fastapi/utils/dummy_functions.py b/servers/fastapi/utils/dummy_functions.py
new file mode 100644
index 00000000..461e9695
--- /dev/null
+++ b/servers/fastapi/utils/dummy_functions.py
@@ -0,0 +1,2 @@
+async def do_nothing_async(_):
+ return None
diff --git a/servers/fastapi/utils/get_dynamic_models.py b/servers/fastapi/utils/get_dynamic_models.py
index 744a6a5a..fd4b2bda 100644
--- a/servers/fastapi/utils/get_dynamic_models.py
+++ b/servers/fastapi/utils/get_dynamic_models.py
@@ -1,13 +1,23 @@
from typing import List
from pydantic import Field
-from models.presentation_outline_model import PresentationOutlineModel
+from models.presentation_outline_model import (
+ PresentationOutlineModel,
+ SlideOutlineModel,
+)
from models.presentation_structure_model import PresentationStructureModel
def get_presentation_outline_model_with_n_slides(n_slides: int):
+ class SlideOutlineModelWithNSlides(SlideOutlineModel):
+ content: str = Field(
+ description="Markdown content for each slide",
+ min_length=100,
+ max_length=300,
+ )
+
class PresentationOutlineModelWithNSlides(PresentationOutlineModel):
- slides: List[str] = Field(
- description="Markdown content for each slide in about 100 to 200 words",
+ slides: List[SlideOutlineModelWithNSlides] = Field(
+ description="List of slide outlines",
min_items=n_slides,
max_items=n_slides,
)
diff --git a/servers/fastapi/utils/get_env.py b/servers/fastapi/utils/get_env.py
index c2c72efd..fa80b2a2 100644
--- a/servers/fastapi/utils/get_env.py
+++ b/servers/fastapi/utils/get_env.py
@@ -81,22 +81,6 @@ def get_pixabay_api_key_env():
return os.getenv("PIXABAY_API_KEY")
-def get_redis_host_env():
- return os.getenv("REDIS_HOST")
-
-
-def get_redis_port_env():
- return os.getenv("REDIS_PORT")
-
-
-def get_redis_db_env():
- return os.getenv("REDIS_DB")
-
-
-def get_redis_password_env():
- return os.getenv("REDIS_PASSWORD")
-
-
def get_tool_calls_env():
return os.getenv("TOOL_CALLS")
@@ -107,3 +91,7 @@ def get_disable_thinking_env():
def get_extended_reasoning_env():
return os.getenv("EXTENDED_REASONING")
+
+
+def get_web_grounding_env():
+ return os.getenv("WEB_GROUNDING")
diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py
index a8df598a..30599d08 100644
--- a/servers/fastapi/utils/llm_calls/edit_slide.py
+++ b/servers/fastapi/utils/llm_calls/edit_slide.py
@@ -1,4 +1,4 @@
-from models.llm_message import LLMMessage
+from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import SlideLayoutModel
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
@@ -41,12 +41,10 @@ def get_messages(
language: str,
):
return [
- LLMMessage(
- role="system",
+ LLMSystemMessage(
content=system_prompt,
),
- LLMMessage(
- role="user",
+ LLMUserMessage(
content=get_user_prompt(prompt, slide_data, language),
),
]
diff --git a/servers/fastapi/utils/llm_calls/edit_slide_html.py b/servers/fastapi/utils/llm_calls/edit_slide_html.py
index a5e2dfad..cf58d185 100644
--- a/servers/fastapi/utils/llm_calls/edit_slide_html.py
+++ b/servers/fastapi/utils/llm_calls/edit_slide_html.py
@@ -1,5 +1,5 @@
from typing import Optional
-from models.llm_message import LLMMessage
+from models.llm_message import LLMSystemMessage, LLMUserMessage
from services.llm_client import LLMClient
from utils.llm_provider import get_model
@@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str):
response = await client.generate(
model=model,
messages=[
- LLMMessage(role="system", content=system_prompt),
- LLMMessage(role="user", content=get_user_prompt(prompt, html)),
+ LLMSystemMessage(content=system_prompt),
+ LLMUserMessage(content=get_user_prompt(prompt, html)),
],
)
return extract_html_from_response(response) or html
diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py
index 507bb6eb..892b9cff 100644
--- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py
+++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py
@@ -1,10 +1,13 @@
-import asyncio
from typing import Optional
-from models.llm_message import LLMMessage
+from models.llm_message import LLMSystemMessage, LLMUserMessage
+from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
from services.llm_client import LLMClient
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
+from utils.get_env import get_web_grounding_env
from utils.llm_provider import get_model
+from utils.parsers import parse_bool_or_none
+from utils.user_config import get_user_config
system_prompt = """
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
@@ -29,12 +32,10 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
def get_messages(prompt: str, n_slides: int, language: str, content: str):
return [
- LLMMessage(
- role="system",
+ LLMSystemMessage(
content=system_prompt,
),
- LLMMessage(
- role="user",
+ LLMUserMessage(
content=get_user_prompt(prompt, n_slides, language, content),
),
]
@@ -51,10 +52,13 @@ async def generate_ppt_outline(
client = LLMClient()
+ tools = [SearchWebTool, GetCurrentDatetimeTool]
+
async for chunk in client.stream_structured(
model,
get_messages(prompt, n_slides, language, content),
response_model.model_json_schema(),
strict=True,
+ tools=tools if client.enable_web_grounding() else None,
):
yield chunk
diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py
index 47f47dba..1bfc0cd0 100644
--- a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py
+++ b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py
@@ -1,4 +1,4 @@
-from models.llm_message import LLMMessage
+from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import PresentationLayoutModel
from models.presentation_outline_model import PresentationOutlineModel
from services.llm_client import LLMClient
@@ -11,8 +11,7 @@ def get_messages(
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
):
return [
- LLMMessage(
- role="system",
+ LLMSystemMessage(
content=f"""
You're a professional presentation designer with creative freedom to design engaging presentations.
@@ -47,8 +46,7 @@ def get_messages(
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
""",
),
- LLMMessage(
- role="user",
+ LLMUserMessage(
content=f"""
{data}
""",
diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py
index ecff518a..be19b168 100644
--- a/servers/fastapi/utils/llm_calls/generate_slide_content.py
+++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py
@@ -1,5 +1,6 @@
-from models.llm_message import LLMMessage
+from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import SlideLayoutModel
+from models.presentation_outline_model import SlideOutlineModel
from services.llm_client import LLMClient
from utils.llm_provider import get_model
from utils.schema_utils import remove_fields_from_schema
@@ -38,19 +39,17 @@ def get_user_prompt(outline: str, language: str):
def get_messages(outline: str, language: str):
return [
- LLMMessage(
- role="system",
+ LLMSystemMessage(
content=system_prompt,
),
- LLMMessage(
- role="user",
+ LLMUserMessage(
content=get_user_prompt(outline, language),
),
]
async def get_slide_content_from_type_and_outline(
- slide_layout: SlideLayoutModel, outline: str, language: str
+ slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
):
client = LLMClient()
model = get_model()
@@ -62,7 +61,7 @@ async def get_slide_content_from_type_and_outline(
response = await client.generate_structured(
model=model,
messages=get_messages(
- outline,
+ outline.content,
language,
),
response_format=response_schema,
diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
index f3532b48..7235e558 100644
--- a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
+++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
@@ -1,4 +1,4 @@
-from models.llm_message import LLMMessage
+from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from models.slide_layout_index import SlideLayoutIndex
from models.sql.slide import SlideModel
@@ -13,8 +13,7 @@ def get_messages(
current_slide_layout: int,
):
return [
- LLMMessage(
- role="system",
+ LLMSystemMessage(
content=f"""
Select a Slide Layout index based on provided user prompt and current slide data.
{layout.to_string()}
@@ -26,8 +25,7 @@ def get_messages(
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
),
- LLMMessage(
- role="user",
+ LLMUserMessage(
content=f"""
- User Prompt: {prompt}
- Current Slide Data: {slide_data}
diff --git a/servers/fastapi/utils/schema_utils.py b/servers/fastapi/utils/schema_utils.py
index ae65f002..6cb01a0e 100644
--- a/servers/fastapi/utils/schema_utils.py
+++ b/servers/fastapi/utils/schema_utils.py
@@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
return resolved
+# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions
+def flatten_json_schema(schema: dict) -> dict:
+ root_schema = deepcopy(schema)
+
+ def _flatten(node: Any) -> Any:
+ if isinstance(node, dict):
+ # If node is a pure $ref (or combined with extra fields), inline it
+ if "$ref" in node:
+ ref_value = node["$ref"]
+ assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}"
+ resolved = resolve_ref(root=root_schema, ref=ref_value)
+ assert isinstance(resolved, dict), (
+ f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}"
+ )
+ # Merge: referenced first, then overlay current (excluding $ref)
+ merged: dict[str, Any] = deepcopy(resolved)
+ for key, value in node.items():
+ if key == "$ref":
+ continue
+ merged[key] = value
+ return _flatten(merged)
+
+ flattened: dict[str, Any] = {}
+ for key, value in node.items():
+ # Drop defs/definitions in output
+ if key in ("$defs", "definitions"):
+ continue
+ if key == "properties" and isinstance(value, dict):
+ flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()}
+ elif key in ("items", "contains", "additionalProperties", "not"):
+ if isinstance(value, dict):
+ flattened[key] = _flatten(value)
+ elif isinstance(value, list):
+ flattened[key] = [_flatten(v) for v in value]
+ else:
+ flattened[key] = value
+ elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list):
+ flattened[key] = [_flatten(v) for v in value]
+ else:
+ flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value
+ return flattened
+ if isinstance(node, list):
+ return [_flatten(v) for v in node]
+ return node
+
+ result = _flatten(schema)
+ # Ensure top-level cleanup just in case
+ if isinstance(result, dict):
+ result.pop("$defs", None)
+ result.pop("definitions", None)
+ return result
+
+
# ? Not used
def generate_constraint_sentences(schema: dict) -> str:
"""
diff --git a/servers/fastapi/utils/set_env.py b/servers/fastapi/utils/set_env.py
index 7ac0e335..ea3758f3 100644
--- a/servers/fastapi/utils/set_env.py
+++ b/servers/fastapi/utils/set_env.py
@@ -79,3 +79,7 @@ def set_disable_thinking_env(value):
def set_extended_reasoning_env(value):
os.environ["EXTENDED_REASONING"] = value
+
+
+def set_web_grounding_env(value):
+ os.environ["WEB_GROUNDING"] = value
\ No newline at end of file
diff --git a/servers/fastapi/utils/user_config.py b/servers/fastapi/utils/user_config.py
index 06235d5a..49fd1722 100644
--- a/servers/fastapi/utils/user_config.py
+++ b/servers/fastapi/utils/user_config.py
@@ -22,6 +22,7 @@ from utils.get_env import (
get_image_provider_env,
get_pixabay_api_key_env,
get_extended_reasoning_env,
+ get_web_grounding_env,
)
from utils.parsers import parse_bool_or_none
from utils.set_env import (
@@ -43,6 +44,7 @@ from utils.set_env import (
set_image_provider_env,
set_pixabay_api_key_env,
set_tool_calls_env,
+ set_web_grounding_env,
)
@@ -76,12 +78,26 @@ def get_user_config():
IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(),
PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(),
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(),
- TOOL_CALLS=existing_config.TOOL_CALLS
- or parse_bool_or_none(get_tool_calls_env()),
- DISABLE_THINKING=existing_config.DISABLE_THINKING
- or parse_bool_or_none(get_disable_thinking_env()),
- EXTENDED_REASONING=existing_config.EXTENDED_REASONING
- or parse_bool_or_none(get_extended_reasoning_env()),
+ TOOL_CALLS=(
+ existing_config.TOOL_CALLS
+ if existing_config.TOOL_CALLS is not None
+ else (parse_bool_or_none(get_tool_calls_env()) or False)
+ ),
+ DISABLE_THINKING=(
+ existing_config.DISABLE_THINKING
+ if existing_config.DISABLE_THINKING is not None
+ else (parse_bool_or_none(get_disable_thinking_env()) or False)
+ ),
+ EXTENDED_REASONING=(
+ existing_config.EXTENDED_REASONING
+ if existing_config.EXTENDED_REASONING is not None
+ else (parse_bool_or_none(get_extended_reasoning_env()) or False)
+ ),
+ WEB_GROUNDING=(
+ existing_config.WEB_GROUNDING
+ if existing_config.WEB_GROUNDING is not None
+ else (parse_bool_or_none(get_web_grounding_env()) or False)
+ ),
)
@@ -122,5 +138,6 @@ def update_env_with_user_config():
if user_config.DISABLE_THINKING:
set_disable_thinking_env(str(user_config.DISABLE_THINKING))
if user_config.EXTENDED_REASONING:
- if user_config.EXTENDED_REASONING:
- set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
+ set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
+ if user_config.WEB_GROUNDING:
+ set_web_grounding_env(str(user_config.WEB_GROUNDING))
diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock
index b579f42a..4e19c02b 100644
--- a/servers/fastapi/uv.lock
+++ b/servers/fastapi/uv.lock
@@ -1061,6 +1061,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" },
]
+[[package]]
+name = "iniconfig"
+version = "2.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
+]
+
[[package]]
name = "isodate"
version = "0.7.2"
@@ -1907,6 +1916,7 @@ dependencies = [
{ name = "openai" },
{ name = "pathvalidate" },
{ name = "pdfplumber" },
+ { name = "pytest" },
{ name = "python-pptx" },
{ name = "redis" },
{ name = "sqlmodel" },
@@ -1928,6 +1938,7 @@ requires-dist = [
{ name = "openai", specifier = ">=1.98.0" },
{ name = "pathvalidate", specifier = ">=3.3.1" },
{ name = "pdfplumber", specifier = ">=0.11.7" },
+ { name = "pytest", specifier = ">=8.4.1" },
{ name = "python-pptx", specifier = ">=1.0.2" },
{ name = "redis", specifier = ">=6.2.0" },
{ name = "sqlmodel", specifier = ">=0.0.24" },
@@ -2211,6 +2222,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" },
]
+[[package]]
+name = "pytest"
+version = "8.4.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "iniconfig" },
+ { name = "packaging" },
+ { name = "pluggy" },
+ { name = "pygments" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
+]
+
[[package]]
name = "python-bidi"
version = "0.6.6"
diff --git a/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx b/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx
index f7ef3756..d1612d74 100644
--- a/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx
+++ b/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx
@@ -12,6 +12,7 @@ import * as z from "zod";
import { useDispatch } from "react-redux";
import { setLayoutLoading } from "@/store/slices/presentationGeneration";
import * as Babel from "@babel/standalone";
+import * as Recharts from "recharts";
export interface LayoutInfo {
id: string;
name?: string;
@@ -99,14 +100,17 @@ const compileCustomLayout = (layoutCode: string, React: any, z: any) => {
const factory = new Function(
"React",
"_z",
+ "Recharts",
`
const z = _z;
+ // Expose commonly used Recharts components to compiled layouts
+ const { ResponsiveContainer, LineChart, Line, BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, Legend, PieChart, Pie, Cell, AreaChart, Area, RadarChart, Radar, PolarGrid, PolarAngleAxis, PolarRadiusAxis } = Recharts || {};
${compiled}
/* everything declared in the string is in scope here */
return {
__esModule: true,
- default: dynamicSlideLayout,
+ default: typeof dynamicSlideLayout !== 'undefined' ? dynamicSlideLayout : (typeof DefaultLayout !== 'undefined' ? DefaultLayout : undefined),
layoutName,
layoutId,
layoutDescription,
@@ -115,7 +119,7 @@ const compileCustomLayout = (layoutCode: string, React: any, z: any) => {
`
);
- return factory(React, z);
+ return factory(React, z, Recharts);
};
export const LayoutProvider: React.FC<{
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useAPIKeyCheck.ts b/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useAPIKeyCheck.ts
deleted file mode 100644
index 5164e7df..00000000
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useAPIKeyCheck.ts
+++ /dev/null
@@ -1,17 +0,0 @@
-import { useState, useEffect } from "react";
-
-export const useAPIKeyCheck = () => {
- const [hasRequiredKey, setHasAnthropicKey] = useState(false);
- const [isRequiredKeyLoading, setIsAnthropicKeyLoading] = useState(true);
-
- useEffect(() => {
- fetch("/api/has-required-key")
- .then((res) => res.json())
- .then((data) => {
- setHasAnthropicKey(data.hasKey);
- setIsAnthropicKeyLoading(false);
- });
- }, []);
-
- return { hasRequiredKey, isRequiredKeyLoading };
-};
\ No newline at end of file
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/APIKeyWarning.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/APIKeyWarning.tsx
similarity index 65%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/APIKeyWarning.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/APIKeyWarning.tsx
index 44b2e0d2..ad803232 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/APIKeyWarning.tsx
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/APIKeyWarning.tsx
@@ -10,6 +10,10 @@ export const APIKeyWarning: React.FC = () => {
Please add "GOOGLE_API_KEY" to enable template creation via AI.
+
Please add your OpenAI API Key to process the layout
+
+ This feature requires an OpenAI model GPT-5. Configure your key in settings or via environment variables.
+
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/EditControls.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/EditControls.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/EditControls.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/EditControls.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/HtmlEditor.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/HtmlEditor.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/HtmlEditor.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/HtmlEditor.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/NewEachSlide.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/NewEachSlide.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/NewEachSlide.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/NewEachSlide.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideActions.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideActions.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideActions.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideActions.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideContentDisplay.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideContentDisplay.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideContentDisplay.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideContentDisplay.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/FileUploadSection.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/FileUploadSection.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/FileUploadSection.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/FileUploadSection.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/FontManager.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/FontManager.tsx
similarity index 99%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/FontManager.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/FontManager.tsx
index 7ee5df52..5ce0f53b 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/FontManager.tsx
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/FontManager.tsx
@@ -224,7 +224,7 @@ const FontManager: React.FC = ({
onClick={processSlideToHtml}
className="text-xs px-8 py-2 font-semibold bg-blue-600 text-white hover:text-white hover:bg-blue-700 border-blue-600"
>
- Extract layouts
+ Extract Template
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/LoadingSpinner.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/LoadingSpinner.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/LoadingSpinner.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/LoadingSpinner.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutButton.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutButton.tsx
similarity index 94%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutButton.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutButton.tsx
index 9123b912..6badcf5b 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutButton.tsx
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutButton.tsx
@@ -25,12 +25,12 @@ export const SaveLayoutButton: React.FC = ({
{isSaving ? (
<>
- Saving Layout...
+ Saving Template...
>
) : (
<>
- Save Layout
+ Save Template
>
)}
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutModal.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutModal.tsx
similarity index 91%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutModal.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutModal.tsx
index 57924e7a..40fa3a89 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutModal.tsx
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutModal.tsx
@@ -53,22 +53,22 @@ export const SaveLayoutModal: React.FC = ({
- Save Layout
+ Save Template
- Enter a name and description for your layout. This will help you identify it later.
+ Enter a name and description for your template. This will help you identify it later.
setLayoutName(e.target.value)}
- placeholder="Enter layout name..."
+ placeholder="Enter template name..."
disabled={isSaving}
className="w-full"
/>
@@ -81,7 +81,7 @@ export const SaveLayoutModal: React.FC = ({
id="description"
value={description}
onChange={(e) => setDescription(e.target.value)}
- placeholder="Enter a description for your layout..."
+ placeholder="Enter a description for your template..."
disabled={isSaving}
className="w-full resize-none"
rows={3}
@@ -109,7 +109,7 @@ export const SaveLayoutModal: React.FC = ({
) : (
<>
- Save Layout
+ Save Template
>
)}
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SlideContent.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/SlideContent.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/SlideContent.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/SlideContent.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/Timer.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/Timer.tsx
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/Timer.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/components/Timer.tsx
diff --git a/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useAPIKeyCheck.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useAPIKeyCheck.ts
new file mode 100644
index 00000000..f16dd738
--- /dev/null
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useAPIKeyCheck.ts
@@ -0,0 +1,18 @@
+import { useState, useEffect } from "react";
+
+export const useAPIKeyCheck = () => {
+ const [hasRequiredKey, setHasRequiredKey] = useState(false);
+ const [isRequiredKeyLoading, setIsRequiredKeyLoading] = useState(true);
+
+ useEffect(() => {
+ fetch("/api/has-required-key")
+ .then((res) => res.json())
+ .then((data) => {
+ setHasRequiredKey(Boolean(data.hasKey));
+ setIsRequiredKeyLoading(false);
+ })
+ .catch(() => setIsRequiredKeyLoading(false));
+ }, []);
+
+ return { hasRequiredKey, isRequiredKeyLoading };
+};
\ No newline at end of file
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useCustomLayout.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useCustomLayout.ts
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useCustomLayout.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useCustomLayout.ts
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useDrawingCanvas.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useDrawingCanvas.ts
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useDrawingCanvas.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useDrawingCanvas.ts
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFileUpload.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFileUpload.ts
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFileUpload.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFileUpload.ts
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFontManagement.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFontManagement.ts
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFontManagement.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFontManagement.ts
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useHtmlEdit.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useHtmlEdit.ts
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useHtmlEdit.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useHtmlEdit.ts
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useLayoutSaving.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useLayoutSaving.ts
similarity index 97%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useLayoutSaving.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useLayoutSaving.ts
index 8433a86d..570e73f2 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useLayoutSaving.ts
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useLayoutSaving.ts
@@ -27,6 +27,11 @@ export const useLayoutSaving = (
const maxRetries = 3;
let retryCount = 0;
+ console.log("Slide to convert to react", {
+ html: slide.html,
+ image: slide.screenshot_url,
+ })
+
while (retryCount < maxRetries) {
try {
const response = await fetch("/api/v1/ppt/html-to-react/", {
@@ -36,6 +41,7 @@ export const useLayoutSaving = (
},
body: JSON.stringify({
html: slide.html,
+ image: slide.screenshot_url,
}),
});
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideEdit.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideEdit.ts
similarity index 100%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideEdit.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideEdit.ts
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideProcessing.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideProcessing.ts
similarity index 96%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideProcessing.ts
rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideProcessing.ts
index 4a677a68..544c4804 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideProcessing.ts
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideProcessing.ts
@@ -35,6 +35,7 @@ export const useSlideProcessing = (
body: JSON.stringify({
image: slide.screenshot_url,
xml: slide.xml_content,
+ fonts: slide.normalized_fonts ?? [],
}),
});
@@ -157,7 +158,10 @@ export const useSlideProcessing = (
setSlides(initialSlides);
toast.success(
- `Successfully extracted ${pptxData.slides.length} slides! Converting to HTML...`
+ `Template Processing Finished`,
+ {
+ description: `Please Upload the not supported fonts, and click Extract Template`
+ }
);
diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/page.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/page.tsx
similarity index 95%
rename from servers/nextjs/app/(presentation-generator)/custom-layout/page.tsx
rename to servers/nextjs/app/(presentation-generator)/custom-template/page.tsx
index 1a6d2415..652e1c33 100644
--- a/servers/nextjs/app/(presentation-generator)/custom-layout/page.tsx
+++ b/servers/nextjs/app/(presentation-generator)/custom-template/page.tsx
@@ -18,7 +18,7 @@ import { APIKeyWarning } from "./components/APIKeyWarning";
import { useAPIKeyCheck } from "./hooks/useAPIKeyCheck";
-const CustomLayoutPage = () => {
+const CustomTemplatePage = () => {
const { refetch } = useLayout();
// Custom hooks for different concerns
@@ -66,8 +66,9 @@ const CustomLayoutPage = () => {
// Anthropic key warning
if (!hasRequiredKey) {
return ;
- }
+
+ }
return (