diff --git a/README.md b/README.md index 72a53771..a7851cc4 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep - **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom** - **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output. - **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled. +- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results. You can also set the following environment variables to customize the image generation provider and API keys: diff --git a/badu.js b/badu.js new file mode 100644 index 00000000..e69de29b diff --git a/docker-compose.yml b/docker-compose.yml index 42502b5f..85ca9210 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} production-gpu: @@ -60,6 +63,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development: @@ -87,6 +93,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development-gpu: @@ -120,5 +129,8 @@ services: - CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY} - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - - DATABASE_URL=${DATABASE_URL} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} + - DATABASE_URL=${DATABASE_URL} diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index adde8669..0dafa3e1 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -64,7 +64,7 @@ async def pull_model( # If the model is being pulled, return the model if saved_model_status: # If the model is being pulled, return the model - # ? If the model status is pulled in redis but was not found while listing pulled models, + # ? If the model status is pulled in database but was not found while listing pulled models, # ? it means the model was deleted and we need to pull it again if ( saved_model_status["status"] == "error" diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py b/servers/fastapi/api/v1/ppt/endpoints/outlines.py index b0ec47af..bf5b1489 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py +++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py @@ -72,6 +72,9 @@ async def stream_outlines( presentation_outlines_json = json.loads(presentation_outlines_text) except Exception as e: print(e) + with open("./debug/outlines.txt", "w") as f: + f.write(presentation_outlines_text) + print(presentation_outlines_text) raise HTTPException( status_code=400, detail="Failed to generate presentation outlines. Please try again.", @@ -87,7 +90,8 @@ async def stream_outlines( presentation.outlines = presentation_outlines.model_dump() presentation.title = ( - presentation_outlines.slides[0][:50] + presentation_outlines.slides[0] + .content[:50] .replace("#", "") .replace("/", "") .replace("\\", "") diff --git a/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/servers/fastapi/api/v1/ppt/endpoints/presentation.py index 5b4589a2..2650782d 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/presentation.py +++ b/servers/fastapi/api/v1/ppt/endpoints/presentation.py @@ -11,7 +11,10 @@ from sqlmodel import select from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES from models.presentation_and_path import PresentationPathAndEditPath from models.presentation_from_template import GetPresentationUsingTemplateRequest -from models.presentation_outline_model import PresentationOutlineModel +from models.presentation_outline_model import ( + PresentationOutlineModel, + SlideOutlineModel, +) from models.pptx_models import PptxPresentationModel from models.presentation_layout import PresentationLayoutModel from models.presentation_structure_model import PresentationStructureModel @@ -126,7 +129,7 @@ async def create_presentation( @PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel) async def prepare_presentation( presentation_id: Annotated[str, Body()], - outlines: Annotated[List[str], Body()], + outlines: Annotated[List[SlideOutlineModel], Body()], layout: Annotated[PresentationLayoutModel, Body()], title: Annotated[Optional[str], Body()] = None, sql_session: AsyncSession = Depends(get_async_session), @@ -161,7 +164,9 @@ async def prepare_presentation( presentation_structure.slides[index] = random_slide_index sql_session.add(presentation) - presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump() + presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump( + mode="json" + ) presentation.title = title or presentation.title presentation.set_layout(layout) presentation.set_structure(presentation_structure) diff --git a/servers/fastapi/api/v1/ppt/endpoints/prompts.py b/servers/fastapi/api/v1/ppt/endpoints/prompts.py index e1c0c0ba..500a3e44 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/prompts.py +++ b/servers/fastapi/api/v1/ppt/endpoints/prompts.py @@ -23,7 +23,7 @@ Follow these rules strictly: - If there is a box/card enclosing a text, make it grow as well when the text grows, so that the text does not overflow the box/card. - Give out only HTML and Tailwind code. No other texts or explanations. - Do not give entire HTML structure with head, body, etc. Just give the respective HTML and Tailwind code inside div with above classes. -- If a list of fonts is provided, you must use the provided fonts (normalized root families) in font-family declarations, prioritizing them over inferred fonts. Use the first matching family wherever applicable. +- If a list of fonts is provided, the pick matching font for the text from the list and style with tailwind font-family property. Use following format: font-["font-name"] """ HTML_TO_REACT_SYSTEM_PROMPT = """ @@ -34,7 +34,7 @@ Convert given static HTML and Tailwind slide to a TSX React component so that it 3) For similar components in the layouts (eg, team members), they should be represented by array of such components in the schema. 4) For image and icons icons should be a different schema with two dunder fields for prompt and url separately. 5) Default value for schema fields should be populated with the respective static value in HTML input. -6) In schema max and min value for characters in string and items in array should be specified as per the given image of the slide. You should accurately evaluate the maximum and minimum possible characters respective fields can handle visually through the image. +6) In schema max and min value for characters in string and items in array should be specified as per the given image of the slide. You should accurately evaluate the maximum and minimum possible characters respective fields can handle visually through the image. ALso give out maximum number of words it can handle in the meta. 7) For image and icons schema should be compulsorily declared with two dunder fields for prompt and url separately. 8) Component name at the end should always yo 'dynamicSlideLayout'. 9) **Import or export statements should not be present in the output.** @@ -53,6 +53,12 @@ Convert given static HTML and Tailwind slide to a TSX React component so that it 14. Always complete the reference, do not give "slideData .? .cards" instead give "slideData?.cards". 15. Do not add anything other than code. Do not add "use client", "json", "typescript", "javascript" and other prefix or suffix, just give out code exactly formatted like example. 16. In schema, give default for all fields irrespective of their types, give defualt values for array and objects as well. +17. For charts use recharts.js library and follow these rules strictly: + - Do not import rechart, it will already be imported. + - There should support for multiple chart types including bar, line, pie and donut in the same size as given. + - Use an attribute in the schema to select between chart types. + - All data should be properly represented in schema. +18. For diagrams use mermaid with appropriate placeholder which can render any daigram. Schema should have a field for code. Render in the placeholder properly. For example: Input:

Effects of Global Warming

global warming effects on earth

Global warming triggers a cascade of effects on our planet. These changes impact everything from our oceans to our ecosystems.

sea level rising icon

Rising Sea Levels

Rising sea levels threaten coastal communities and ecosystems due to melting glaciers and thermal expansion.

heatwave icon

Intense Heatwaves

Heatwaves are becoming more frequent and intense, posing significant risks to human health and agriculture.

precipitation changes icon

Changes in Precipitation

Altered precipitation patterns lead to increased droughts in some regions and severe flooding in others, affecting water resources.

@@ -62,7 +68,7 @@ const ImageSchema = z.object({ description: "URL to image", }), __image_prompt__: z.string().meta({ - description: "Prompt used to generate the image", + description: "Prompt used to generate the image. Max 30 words", }).min(10).max(50), }) @@ -71,7 +77,7 @@ const IconSchema = z.object({ description: "URL to icon", }), __icon_query__: z.string().meta({ - description: "Query used to search the icon", + description: "Query used to search the icon. Max 3 words", }).min(5).max(20), }) const layoutId = "bullet-with-icons-slide" @@ -80,23 +86,23 @@ const layoutDescription = "A bullets style slide with main content, supporting i const Schema = z.object({ title: z.string().min(3).max(40).default("Problem").meta({ - description: "Main title of the slide", + description: "Main title of the slide. Max 5 words", }), description: z.string().max(150).default("Businesses face challenges with outdated technology and rising costs, limiting efficiency and growth in competitive markets.").meta({ - description: "Main description text explaining the problem or topic", + description: "Main description text explaining the problem or topic. Max 30 words", }), image: ImageSchema.default({ __image_url__: 'https://images.unsplash.com/photo-1552664730-d307ca884978?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1000&q=80', __image_prompt__: "Business people analyzing documents and charts in office" }).meta({ - description: "Supporting image for the slide", + description: "Supporting image for the slide. Max 30 words", }), bulletPoints: z.array(z.object({ title: z.string().min(2).max(80).meta({ - description: "Bullet point title", + description: "Bullet point title. Max 4 words", }), description: z.string().min(10).max(150).meta({ - description: "Bullet point description", + description: "Bullet point description. Max 15 words", }), icon: IconSchema, })).min(1).max(3).default([ @@ -117,7 +123,7 @@ const Schema = z.object({ } } ]).meta({ - description: "List of bullet points with icons and descriptions", + description: "List of bullet points with icons and descriptions. Max 3 points", }) }) diff --git a/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py b/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py index 5e09179d..67539938 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py +++ b/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py @@ -4,9 +4,8 @@ from datetime import datetime from typing import Optional, List, Dict from fastapi import APIRouter, HTTPException, File, UploadFile, Form, Depends from pydantic import BaseModel -from google import genai -from google.genai import errors -from google.genai import types +from openai import OpenAI +from openai import APIError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, delete, func from utils.asset_directory_utils import get_images_directory @@ -26,6 +25,7 @@ LAYOUT_MANAGEMENT_ROUTER = APIRouter(prefix="/layout-management", tags=["layout- class SlideToHtmlRequest(BaseModel): image: str # Partial path to image file (e.g., "/app_data/images/uuid/slide_1.png") xml: str # OXML content as text + fonts: Optional[List[str]] = None # Optional normalized root fonts for this slide class SlideToHtmlResponse(BaseModel): success: bool @@ -42,6 +42,7 @@ class HtmlEditResponse(BaseModel): # Request/Response models for html-to-react endpoint class HtmlToReactRequest(BaseModel): html: str # HTML content to convert to React component + image: Optional[str] = None # Optional image path to provide visual context class HtmlToReactResponse(BaseModel): @@ -94,15 +95,16 @@ class ErrorResponse(BaseModel): error_code: Optional[str] = None -async def generate_html_from_slide(base64_image: str, media_type: str, xml_content: str, api_key: str) -> str: +async def generate_html_from_slide(base64_image: str, media_type: str, xml_content: str, api_key: str, fonts: Optional[List[str]] = None) -> str: """ - Generate HTML content from slide image and XML using Google Gen AI API. + Generate HTML content from slide image and XML using OpenAI GPT-5 Responses API. Args: base64_image: Base64 encoded image data media_type: MIME type of the image (e.g., 'image/png') xml_content: OXML content as text - api_key: Google Gen AI API key + api_key: OpenAI API key + fonts: Optional list of normalized root font families to prefer in output Returns: Generated HTML content as string @@ -110,62 +112,51 @@ async def generate_html_from_slide(base64_image: str, media_type: str, xml_conte Raises: HTTPException: If API call fails or no content is generated """ - print(f"Generating HTML from slide image and XML using Google Gen AI API...") + print(f"Generating HTML from slide image and XML using OpenAI GPT-5 Responses API...") try: - # Initialize Google Gen AI client - client = genai.Client(api_key=api_key) - - # Convert base64 to bytes - image_bytes = base64.b64decode(base64_image) - - print("Starting non-streaming request to Google Gen AI for HTML generation...") - - # Create content with image and text - contents = [ - types.Part.from_bytes( - mime_type=media_type, - data=image_bytes, - ), - types.Part.from_text(text=f"\nOXML: \n\n{xml_content}"), + client = OpenAI(api_key=api_key) + + # Compose input for Responses API. Include system prompt, image (separate), OXML and optional fonts text. + data_url = f"data:{media_type};base64,{base64_image}" + fonts_text = f"\nFONTS (Normalized root families used in this slide, use where it is required): {', '.join(fonts)}" if fonts else "" + user_text = f"OXML: \n\n{fonts_text}" + input_payload = [ + {"role": "system", "content": GENERATE_HTML_SYSTEM_PROMPT}, + { + "role": "user", + "content": [ + {"type": "input_image", "image_url": data_url}, + {"type": "input_text", "text": user_text}, + ], + }, ] - - # Generate content config with thinking enabled - generate_content_config = types.GenerateContentConfig( - system_instruction=GENERATE_HTML_SYSTEM_PROMPT, - max_output_tokens=65536, - temperature=1.0, - thinking_config=types.ThinkingConfig( - thinking_budget=32768, - ), - ) - - print("Making non-streaming request for HTML generation...") - - # Generate content in non-streaming mode - response = client.models.generate_content( - model="gemini-2.5-pro", - contents=contents, - config=generate_content_config, + + print("Making Responses API request for HTML generation...") + response = client.responses.create( + model="gpt-5", + input=input_payload, + reasoning={"effort": "high"}, + text={"verbosity": "low"}, ) # Extract the response text - html_content = response.text if response.text else "" + html_content = getattr(response, "output_text", None) or getattr(response, "text", None) or "" print(f"Received HTML content length: {len(html_content)}") if not html_content: raise HTTPException( status_code=500, - detail="No HTML content generated by Google Gen AI API" + detail="No HTML content generated by OpenAI GPT-5" ) return html_content - except errors.APIError as e: - print(f"Google API Error: {e}") + except APIError as e: + print(f"OpenAI API Error: {e}") raise HTTPException( status_code=500, - detail=f"Google API error during HTML generation: {str(e)}" + detail=f"OpenAI API error during HTML generation: {str(e)}" ) except Exception as e: # Handle various API errors @@ -175,27 +166,27 @@ async def generate_html_from_slide(base64_image: str, media_type: str, xml_conte if "timeout" in error_msg.lower(): raise HTTPException( status_code=408, - detail=f"Google Gen AI API timeout during HTML generation: {error_msg}" + detail=f"OpenAI API timeout during HTML generation: {error_msg}" ) elif "connection" in error_msg.lower(): raise HTTPException( status_code=503, - detail=f"Google Gen AI API connection error during HTML generation: {error_msg}" + detail=f"OpenAI API connection error during HTML generation: {error_msg}" ) else: raise HTTPException( status_code=500, - detail=f"Google Gen AI API error during HTML generation: {error_msg}" + detail=f"OpenAI API error during HTML generation: {error_msg}" ) -async def generate_react_component_from_html(html_content: str, api_key: str) -> str: +async def generate_react_component_from_html(html_content: str, api_key: str, image_base64: Optional[str] = None, media_type: Optional[str] = None) -> str: """ - Convert HTML content to TSX React component using Google Gen AI API. + Convert HTML content to TSX React component using OpenAI GPT-5 Responses API. Args: html_content: Generated HTML content - api_key: Google Gen AI API key + api_key: OpenAI API key Returns: Generated TSX React component code as string @@ -204,42 +195,36 @@ async def generate_react_component_from_html(html_content: str, api_key: str) -> HTTPException: If API call fails or no content is generated """ try: - # Initialize Google Gen AI client - client = genai.Client(api_key=api_key) - - print("Starting non-streaming request to Google Gen AI for React component generation...") - - # Create content with text - contents = types.Part.from_text(text=html_content) - - # Generate content config with thinking enabled - generate_content_config = types.GenerateContentConfig( - system_instruction=HTML_TO_REACT_SYSTEM_PROMPT, - max_output_tokens=65536, - temperature=1.0, - thinking_config=types.ThinkingConfig( - thinking_budget=15000, - ), - ) - - print("Making non-streaming request for React component...") - - # Generate content in non-streaming mode - response = client.models.generate_content( - model="gemini-2.5-pro", - contents=contents, - config=generate_content_config, + client = OpenAI(api_key=api_key) + + print("Making Responses API request for React component generation...") + + # Build payload with optional image + content_parts = [{"type": "input_text", "text": f"HTML INPUT:\n{html_content}"}] + if image_base64 and media_type: + data_url = f"data:{media_type};base64,{image_base64}" + content_parts.insert(0, {"type": "input_image", "image_url": data_url}) + + input_payload = [ + {"role": "system", "content": HTML_TO_REACT_SYSTEM_PROMPT}, + {"role": "user", "content": content_parts}, + ] + + response = client.responses.create( + model="gpt-5", + input=input_payload, + reasoning={"effort": "minimal"}, + text={"verbosity": "low"}, ) - # Extract the response text - react_content = response.text if response.text else "" + react_content = getattr(response, "output_text", None) or getattr(response, "text", None) or "" print(f"Received React content length: {len(react_content)}") if not react_content: raise HTTPException( status_code=500, - detail="No React component generated by Google Gen AI API" + detail="No React component generated by OpenAI GPT-5" ) react_content = react_content.replace("```tsx", "").replace("```", "").replace("typescript", "").replace("javascript", "") @@ -256,11 +241,11 @@ async def generate_react_component_from_html(html_content: str, api_key: str) -> print(f"Filtered React content length: {len(filtered_react_content)}") return filtered_react_content - except errors.APIError as e: - print(f"Google API Error: {e}") + except APIError as e: + print(f"OpenAI API Error: {e}") raise HTTPException( status_code=500, - detail=f"Google API error during React generation: {str(e)}" + detail=f"OpenAI API error during React generation: {str(e)}" ) except Exception as e: # Handle various API errors @@ -270,23 +255,23 @@ async def generate_react_component_from_html(html_content: str, api_key: str) -> if "timeout" in error_msg.lower(): raise HTTPException( status_code=408, - detail=f"Google Gen AI API timeout during React generation: {error_msg}" + detail=f"OpenAI API timeout during React generation: {error_msg}" ) elif "connection" in error_msg.lower(): raise HTTPException( status_code=503, - detail=f"Google Gen AI API connection error during React generation: {error_msg}" + detail=f"OpenAI API connection error during React generation: {error_msg}" ) else: raise HTTPException( status_code=500, - detail=f"Google Gen AI API error during React generation: {error_msg}" + detail=f"OpenAI API error during React generation: {error_msg}" ) async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[str], media_type: str, html_content: str, prompt: str, api_key: str) -> str: """ - Edit HTML content based on one or two images and a text prompt using Google Gen AI API. + Edit HTML content based on one or two images and a text prompt using OpenAI GPT-5 Responses API. Args: current_ui_base64: Base64 encoded current UI image data @@ -294,7 +279,7 @@ async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[ media_type: MIME type of the images (e.g., 'image/png') html_content: Current HTML content to edit prompt: Text prompt describing the changes - api_key: Google Gen AI API key + api_key: OpenAI API key Returns: Edited HTML content as string @@ -303,70 +288,50 @@ async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[ HTTPException: If API call fails or no content is generated """ try: - # Initialize Google Gen AI client - client = genai.Client(api_key=api_key) - - print("Starting non-streaming request to Google Gen AI for HTML editing...") - - # Convert base64 images to bytes - current_ui_bytes = base64.b64decode(current_ui_base64) - - # Build content array - always include text and current UI image - contents = [ - types.Part.from_text(text=f"Current HTML to edit:\n\n{html_content}\n\nText prompt for changes: {prompt}"), - types.Part.from_bytes( - mime_type=media_type, - data=current_ui_bytes, - ) + client = OpenAI(api_key=api_key) + + print("Making Responses API request for HTML editing...") + + current_data_url = f"data:{media_type};base64,{current_ui_base64}" + sketch_data_url = f"data:{media_type};base64,{sketch_base64}" if sketch_base64 else None + + content_parts = [ + {"type": "input_image", "image_url": current_data_url}, + {"type": "input_text", "text": f"CURRENT HTML TO EDIT:\n{html_content}\n\nTEXT PROMPT FOR CHANGES:\n{prompt}"}, ] - - # Only add sketch image if provided - if sketch_base64: - sketch_bytes = base64.b64decode(sketch_base64) - contents.append( - types.Part.from_bytes( - mime_type=media_type, - data=sketch_bytes, - ) - ) - - # Generate content config with thinking enabled - generate_content_config = types.GenerateContentConfig( - system_instruction=HTML_EDIT_SYSTEM_PROMPT, - max_output_tokens=65536, - temperature=1.0, - thinking_config=types.ThinkingConfig( - thinking_budget=16000, - ), + if sketch_data_url: + # Insert sketch image after current UI image for context + content_parts.insert(1, {"type": "input_image", "image_url": sketch_data_url}) + + input_payload = [ + {"role": "system", "content": HTML_EDIT_SYSTEM_PROMPT}, + {"role": "user", "content": content_parts}, + ] + + response = client.responses.create( + model="gpt-5", + input=input_payload, + reasoning={"effort": "low"}, + text={"verbosity": "low"}, ) - - print("Making non-streaming request for HTML editing...") - - # Generate content in non-streaming mode - response = client.models.generate_content( - model="gemini-2.5-pro", - contents=contents, - config=generate_content_config, - ) - - # Extract the response text - edited_html = response.text if response.text else "" + + edited_html = getattr(response, "output_text", None) or getattr(response, "text", None) or "" print(f"Received edited HTML content length: {len(edited_html)}") if not edited_html: raise HTTPException( status_code=500, - detail="No edited HTML content generated by Google Gen AI API" + detail="No edited HTML content generated by OpenAI GPT-5" ) return edited_html - except errors.APIError as e: - print(f"Google API Error: {e}") + except APIError as e: + print(f"OpenAI API Error: {e}") raise HTTPException( status_code=500, - detail=f"Google API error during HTML editing: {str(e)}" + detail=f"OpenAI API error during HTML editing: {str(e)}" ) except Exception as e: # Handle various API errors @@ -376,17 +341,17 @@ async def edit_html_with_images(current_ui_base64: str, sketch_base64: Optional[ if "timeout" in error_msg.lower(): raise HTTPException( status_code=408, - detail=f"Google Gen AI API timeout during HTML editing: {error_msg}" + detail=f"OpenAI API timeout during HTML editing: {error_msg}" ) elif "connection" in error_msg.lower(): raise HTTPException( status_code=503, - detail=f"Google Gen AI API connection error during HTML editing: {error_msg}" + detail=f"OpenAI API connection error during HTML editing: {error_msg}" ) else: raise HTTPException( status_code=500, - detail=f"Google Gen AI API error during HTML editing: {error_msg}" + detail=f"OpenAI API error during HTML editing: {error_msg}" ) @@ -403,12 +368,12 @@ async def convert_slide_to_html(request: SlideToHtmlRequest): SlideToHtmlResponse with generated HTML """ try: - # Get Google Gen AI API key from environment - api_key = os.getenv("GOOGLE_API_KEY") + # Get OpenAI API key from environment + api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise HTTPException( status_code=500, - detail="GOOGLE_API_KEY environment variable not set" + detail="OPENAI_API_KEY environment variable not set" ) # Resolve image path to actual file system path @@ -458,7 +423,8 @@ async def convert_slide_to_html(request: SlideToHtmlRequest): base64_image=base64_image, media_type=media_type, xml_content=request.xml, - api_key=api_key + api_key=api_key, + fonts=request.fonts, ) html_content = html_content.replace("```html", "").replace("```", "") @@ -493,12 +459,12 @@ async def convert_html_to_react(request: HtmlToReactRequest): HtmlToReactResponse with generated React component """ try: - # Get Google Gen AI API key from environment - api_key = os.getenv("GOOGLE_API_KEY") + # Get OpenAI API key from environment + api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise HTTPException( status_code=500, - detail="GOOGLE_API_KEY environment variable not set" + detail="OPENAI_API_KEY environment variable not set" ) # Validate HTML content @@ -508,10 +474,31 @@ async def convert_html_to_react(request: HtmlToReactRequest): detail="HTML content cannot be empty" ) + # Optionally resolve image and encode to base64 + image_b64 = None + media_type = None + if request.image: + image_path = request.image + if image_path.startswith("/app_data/images/"): + relative_path = image_path[len("/app_data/images/"):] + actual_image_path = os.path.join(get_images_directory(), relative_path) + elif image_path.startswith("/static/"): + relative_path = image_path[len("/static/"):] + actual_image_path = os.path.join("static", relative_path) + else: + actual_image_path = image_path if os.path.isabs(image_path) else os.path.join(get_images_directory(), image_path) + if os.path.exists(actual_image_path): + with open(actual_image_path, "rb") as f: + image_b64 = base64.b64encode(f.read()).decode("utf-8") + ext = os.path.splitext(actual_image_path)[1].lower() + media_type = {'.png':'image/png','.jpg':'image/jpeg','.jpeg':'image/jpeg','.gif':'image/gif','.webp':'image/webp'}.get(ext, 'image/png') + # Convert HTML to React component react_component = await generate_react_component_from_html( html_content=request.html, - api_key=api_key + api_key=api_key, + image_base64=image_b64, + media_type=media_type ) react_component = react_component.replace("```tsx", "").replace("```", "") @@ -555,12 +542,12 @@ async def edit_html_with_images_endpoint( HtmlEditResponse with edited HTML """ try: - # Get Google Gen AI API key from environment - api_key = os.getenv("GOOGLE_API_KEY") + # Get OpenAI API key from environment + api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise HTTPException( status_code=500, - detail="GOOGLE_API_KEY environment variable not set" + detail="OPENAI_API_KEY environment variable not set" ) # Validate inputs diff --git a/servers/fastapi/constants/llm.py b/servers/fastapi/constants/llm.py index ac4bd527..7d374f30 100644 --- a/servers/fastapi/constants/llm.py +++ b/servers/fastapi/constants/llm.py @@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1" # Default models DEFAULT_OPENAI_MODEL = "gpt-4.1" -DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash" -DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620" +DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash" +DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514" diff --git a/servers/fastapi/constants/supported_ollama_models.py b/servers/fastapi/constants/supported_ollama_models.py index 455a1217..a02a7a18 100644 --- a/servers/fastapi/constants/supported_ollama_models.py +++ b/servers/fastapi/constants/supported_ollama_models.py @@ -6,61 +6,51 @@ SUPPORTED_OLLAMA_MODELS = { label="Llama 3:8b", value="llama3:8b", size="4.7GB", - icon="/static/icons/meta.png", ), "llama3:70b": OllamaModelMetadata( label="Llama 3:70b", value="llama3:70b", size="40GB", - icon="/static/icons/meta.png", ), "llama3.1:8b": OllamaModelMetadata( label="Llama 3.1:8b", value="llama3.1:8b", size="4.9GB", - icon="/static/icons/meta.png", ), "llama3.1:70b": OllamaModelMetadata( label="Llama 3.1:70b", value="llama3.1:70b", size="43GB", - icon="/static/icons/meta.png", ), "llama3.1:405b": OllamaModelMetadata( label="Llama 3.1:405b", value="llama3.1:405b", size="243GB", - icon="/static/icons/meta.png", ), "llama3.2:1b": OllamaModelMetadata( label="Llama 3.2:1b", value="llama3.2:1b", size="1.3GB", - icon="/static/icons/meta.png", ), "llama3.2:3b": OllamaModelMetadata( label="Llama 3.2:3b", value="llama3.2:3b", size="2GB", - icon="/static/icons/meta.png", ), "llama3.3:70b": OllamaModelMetadata( label="Llama 3.3:70b", value="llama3.3:70b", size="43GB", - icon="/static/icons/meta.png", ), "llama4:16x17b": OllamaModelMetadata( label="Llama 4:16x17b", value="llama4:16x17b", size="67GB", - icon="/static/icons/meta.png", ), "llama4:128x17b": OllamaModelMetadata( label="Llama 4:128x17b", value="llama4:128x17b", size="245GB", - icon="/static/icons/meta.png", ), } @@ -69,25 +59,21 @@ SUPPORTED_GEMMA_MODELS = { label="Gemma 3:1b", value="gemma3:1b", size="815MB", - icon="/static/icons/gemma.png", ), "gemma3:4b": OllamaModelMetadata( label="Gemma 3:4b", value="gemma3:4b", size="3.3GB", - icon="/static/icons/gemma.png", ), "gemma3:12b": OllamaModelMetadata( label="Gemma 3:12b", value="gemma3:12b", size="8.1GB", - icon="/static/icons/gemma.png", ), "gemma3:27b": OllamaModelMetadata( label="Gemma 3:27b", value="gemma3:27b", size="17GB", - icon="/static/icons/gemma.png", ), } @@ -96,43 +82,36 @@ SUPPORTED_DEEPSEEK_MODELS = { label="DeepSeek R1:1.5b", value="deepseek-r1:1.5b", size="1.1GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:7b": OllamaModelMetadata( label="DeepSeek R1:7b", value="deepseek-r1:7b", size="4.7GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:8b": OllamaModelMetadata( label="DeepSeek R1:8b", value="deepseek-r1:8b", size="5.2GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:14b": OllamaModelMetadata( label="DeepSeek R1:14b", value="deepseek-r1:14b", size="9GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:32b": OllamaModelMetadata( label="DeepSeek R1:32b", value="deepseek-r1:32b", size="20GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:70b": OllamaModelMetadata( label="DeepSeek R1:70b", value="deepseek-r1:70b", size="43GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:671b": OllamaModelMetadata( label="DeepSeek R1:671b", value="deepseek-r1:671b", size="404GB", - icon="/static/icons/deepseek.png", ), } @@ -141,49 +120,54 @@ SUPPORTED_QWEN_MODELS = { label="Qwen 3:0.6b", value="qwen3:0.6b", size="523MB", - icon="/static/icons/qwen.png", ), "qwen3:1.7b": OllamaModelMetadata( label="Qwen 3:1.7b", value="qwen3:1.7b", size="1.4GB", - icon="/static/icons/qwen.png", ), "qwen3:4b": OllamaModelMetadata( label="Qwen 3:4b", value="qwen3:4b", size="2.6GB", - icon="/static/icons/qwen.png", ), "qwen3:8b": OllamaModelMetadata( label="Qwen 3:8b", value="qwen3:8b", size="5.2GB", - icon="/static/icons/qwen.png", ), "qwen3:14b": OllamaModelMetadata( label="Qwen 3:14b", value="qwen3:14b", size="9.3GB", - icon="/static/icons/qwen.png", ), "qwen3:30b": OllamaModelMetadata( label="Qwen 3:30b", value="qwen3:30b", size="19GB", - icon="/static/icons/qwen.png", ), "qwen3:32b": OllamaModelMetadata( label="Qwen 3:32b", value="qwen3:32b", size="20GB", - icon="/static/icons/qwen.png", ), "qwen3:235b": OllamaModelMetadata( label="Qwen 3:235b", value="qwen3:235b", size="142GB", - icon="/static/icons/qwen.png", + ), +} + +SUPPORTED_GPT_OSS_MODELS = { + "gpt-oss:20b": OllamaModelMetadata( + label="GPT-OSS 20b", + value="gpt-oss:20b", + size="14GB", + ), + "gpt-oss:120b": OllamaModelMetadata( + label="GPT-OSS 120b", + value="gpt-oss:120b", + size="65GB", ), } @@ -192,4 +176,5 @@ SUPPORTED_OLLAMA_MODELS = { **SUPPORTED_GEMMA_MODELS, **SUPPORTED_DEEPSEEK_MODELS, **SUPPORTED_QWEN_MODELS, + **SUPPORTED_GPT_OSS_MODELS, } diff --git a/servers/fastapi/enums/llm_call_type.py b/servers/fastapi/enums/llm_call_type.py new file mode 100644 index 00000000..e37fe4ae --- /dev/null +++ b/servers/fastapi/enums/llm_call_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LLMCallType(Enum): + UNSTRUCTURED = "unstructured" + UNSTRUCTURED_STREAM = "unstructured_stream" + STRUCTURED = "structured" + STRUCTURED_STREAM = "structured_stream" diff --git a/servers/fastapi/models/document_chunk.py b/servers/fastapi/models/document_chunk.py index 6861e4fb..a7500be9 100644 --- a/servers/fastapi/models/document_chunk.py +++ b/servers/fastapi/models/document_chunk.py @@ -1,5 +1,7 @@ from pydantic import BaseModel +from models.presentation_outline_model import SlideOutlineModel + class DocumentChunk(BaseModel): heading: str @@ -7,5 +9,5 @@ class DocumentChunk(BaseModel): heading_index: int score: float - def to_slide_outline(self) -> str: - return f"{self.heading}\n{self.content}" + def to_slide_outline(self) -> SlideOutlineModel: + return SlideOutlineModel(content=f"{self.heading}\n{self.content}") diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 51284173..db741ca4 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -1,7 +1,58 @@ -from typing import Literal +from typing import Any, List, Literal, Optional from pydantic import BaseModel +from google.genai.types import Content as GoogleContent + +from models.llm_tool_call import AnthropicToolCall class LLMMessage(BaseModel): - role: Literal["user", "system"] + pass + + +class LLMUserMessage(LLMMessage): + role: Literal["user"] = "user" content: str + + +class LLMSystemMessage(LLMMessage): + role: Literal["system"] = "system" + content: str + + +class OpenAIAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: str | None = None + tool_calls: Optional[List[dict]] = None + + +class GoogleAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: GoogleContent + + +class AnthropicAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: List[AnthropicToolCall] + + +class AnthropicToolCallMessage(LLMMessage): + type: Literal["tool_result"] = "tool_result" + tool_use_id: str + content: str + + +class AnthropicUserMessage(LLMMessage): + role: Literal["user"] = "user" + content: List[AnthropicToolCallMessage] + + +class OpenAIToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + content: str + tool_call_id: str + + +class GoogleToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + name: str + response: dict diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py new file mode 100644 index 00000000..5eb1f008 --- /dev/null +++ b/servers/fastapi/models/llm_tool_call.py @@ -0,0 +1,29 @@ +from typing import Literal, Optional +from pydantic import BaseModel + + +class LLMToolCall(BaseModel): + pass + + +class OpenAIToolCallFunction(BaseModel): + name: str + arguments: str + + +class OpenAIToolCall(LLMToolCall): + id: str + type: Literal["function"] = "function" + function: OpenAIToolCallFunction + + +class GoogleToolCall(LLMToolCall): + name: str + arguments: Optional[dict] = None + + +class AnthropicToolCall(LLMToolCall): + type: Literal["tool_use"] = "tool_use" + id: str + name: str + input: object diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py new file mode 100644 index 00000000..ccf64e67 --- /dev/null +++ b/servers/fastapi/models/llm_tools.py @@ -0,0 +1,29 @@ +from typing import Any, Callable, Coroutine, Optional +from pydantic import BaseModel, Field + + +class LLMTool(BaseModel): + pass + + +class LLMDynamicTool(LLMTool): + name: str + description: str + parameters: dict = {} + handler: Callable[..., Coroutine[Any, Any, str]] + + +class SearchWebTool(LLMTool): + """ + Search the web for information. + """ + + query: str = Field(description="The query to search the web for") + + +class GetCurrentDatetimeTool(LLMTool): + """ + Get the current datetime. + """ + + pass diff --git a/servers/fastapi/models/ollama_model_metadata.py b/servers/fastapi/models/ollama_model_metadata.py index 1f8ed985..88c89cc6 100644 --- a/servers/fastapi/models/ollama_model_metadata.py +++ b/servers/fastapi/models/ollama_model_metadata.py @@ -4,5 +4,4 @@ from pydantic import BaseModel class OllamaModelMetadata(BaseModel): label: str value: str - icon: str size: str diff --git a/servers/fastapi/models/presentation_outline_model.py b/servers/fastapi/models/presentation_outline_model.py index ad55ae4b..01a3b2b7 100644 --- a/servers/fastapi/models/presentation_outline_model.py +++ b/servers/fastapi/models/presentation_outline_model.py @@ -2,8 +2,12 @@ from typing import List from pydantic import BaseModel +class SlideOutlineModel(BaseModel): + content: str + + class PresentationOutlineModel(BaseModel): - slides: List[str] + slides: List[SlideOutlineModel] def to_string(self): message = "" diff --git a/servers/fastapi/models/user_config.py b/servers/fastapi/models/user_config.py index 50783544..c040d22c 100644 --- a/servers/fastapi/models/user_config.py +++ b/servers/fastapi/models/user_config.py @@ -35,3 +35,6 @@ class UserConfig(BaseModel): TOOL_CALLS: Optional[bool] = None DISABLE_THINKING: Optional[bool] = None EXTENDED_REASONING: Optional[bool] = None + + # Web Search + WEB_GROUNDING: Optional[bool] = None diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml index f0caf8e3..14240244 100644 --- a/servers/fastapi/pyproject.toml +++ b/servers/fastapi/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "openai>=1.98.0", "pathvalidate>=3.3.1", "pdfplumber>=0.11.7", + "pytest>=8.4.1", "python-pptx>=1.0.2", "redis>=6.2.0", "sqlmodel>=0.0.24", diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index f016e763..3e8b35f2 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -1,16 +1,40 @@ import asyncio import json -from typing import List, Optional +from typing import AsyncGenerator, List, Optional from fastapi import HTTPException from openai import AsyncOpenAI +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) from google import genai -from google.genai.types import GenerateContentConfig +from google.genai.types import Content as GoogleContent, Part as GoogleContentPart +from google.genai.types import GenerateContentConfig, GoogleSearch +from google.genai.types import Tool as GoogleTool from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider -from models.llm_message import LLMMessage +from models.llm_message import ( + AnthropicAssistantMessage, + AnthropicUserMessage, + GoogleAssistantMessage, + GoogleToolCallMessage, + OpenAIAssistantMessage, + LLMMessage, + LLMSystemMessage, + LLMUserMessage, +) +from models.llm_tool_call import ( + AnthropicToolCall, + GoogleToolCall, + LLMToolCall, + OpenAIToolCall, + OpenAIToolCallFunction, +) +from models.llm_tools import LLMDynamicTool, LLMTool +from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async +from utils.dummy_functions import do_nothing_async from utils.get_env import ( get_anthropic_api_key_env, get_custom_llm_api_key_env, @@ -20,8 +44,9 @@ from utils.get_env import ( get_ollama_url_env, get_openai_api_key_env, get_tool_calls_env, + get_web_grounding_env, ) -from utils.llm_provider import get_llm_provider +from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none from utils.schema_utils import ensure_strict_json_schema @@ -30,17 +55,25 @@ class LLMClient: def __init__(self): self.llm_provider = get_llm_provider() self._client = self._get_client() + self.tool_calls_handler = LLMToolCallsHandler(self) # ? Use tool calls - def use_tool_calls(self) -> bool: + def use_tool_calls_for_structured_output(self) -> bool: if self.llm_provider != LLMProvider.CUSTOM: return False return parse_bool_or_none(get_tool_calls_env()) or False + # ? Web Grounding + def enable_web_grounding(self) -> bool: + if ( + self.llm_provider == LLMProvider.OLLAMA + or self.llm_provider == LLMProvider.CUSTOM + ): + return False + return parse_bool_or_none(get_web_grounding_env()) or False + # ? Disable thinking def disable_thinking(self) -> bool: - if self.llm_provider != LLMProvider.CUSTOM: - return False return parse_bool_or_none(get_disable_thinking_env()) or False # ? Clients @@ -106,15 +139,39 @@ class LLMClient: # ? Prompts def _get_system_prompt(self, messages: List[LLMMessage]) -> str: for message in messages: - if message.role == "system": + if isinstance(message, LLMSystemMessage): return message.content return "" - def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]: - return [message.content for message in messages if message.role == "user"] + def _get_google_messages(self, messages: List[LLMMessage]) -> List[str]: + contents = [] + for message in messages: + if isinstance(message, LLMUserMessage): + contents.append( + GoogleContent( + role="user", parts=[GoogleContentPart(text=message.content)] + ) + ) + elif isinstance(message, GoogleAssistantMessage): + contents.append(message.content) + elif isinstance(message, GoogleToolCallMessage): + contents.append( + GoogleContent( + role="user", + parts=[ + GoogleContentPart.from_function_response( + name=message.name, response=message.response + ) + ], + ) + ) - def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if message.role == "user"] + return contents + + def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + return [ + message for message in messages if not isinstance(message, LLMSystemMessage) + ] # ? Generate Unstructured Content async def _generate_openai( @@ -122,89 +179,250 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> str | None: client: AsyncOpenAI = self._client response = await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, + tools=tools, + extra_body=extra_body, ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + assistant_message = OpenAIAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls], + ) + new_messages = [ + *messages, + assistant_message, + *tool_call_messages, + ] + return await self._generate_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ) + return response.choices[0].message.content async def _generate_google( self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> str | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", max_output_tokens=max_tokens, ), ) - return response.text + + content = response.candidates[0].content + response_parts = content.parts + + if not response_parts: + return None + + text_content = None + tool_calls = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + if each_part.text: + text_content = each_part.text + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_anthropic( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> str | None: client: AsyncAnthropic = self._client + response: AnthropicMessage = await client.messages.create( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], + tools=tools, max_tokens=max_tokens or 4000, ) - text = "" + text_content = None + tool_calls: List[AnthropicToolCall] = [] for content in response.content: if content.type == "text" and isinstance(content.text, str): - text += content.text - if text == "": - return None - return text + text_content = content.text + + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_ollama( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): - return await self._generate_openai(model, messages, max_tokens) + return await self._generate_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) async def _generate_custom( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): - return await self._generate_openai(model, messages, max_tokens) + extra_body = {"enable_thinking": not self.disable_thinking()} + return await self._generate_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, + ) async def generate( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: - content = await self._generate_openai(model, messages, max_tokens) + content = await self._generate_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - content = await self._generate_google(model, messages, max_tokens) + content = await self._generate_google( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.ANTHROPIC: - content = await self._generate_anthropic(model, messages, max_tokens) + content = await self._generate_anthropic( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.OLLAMA: - content = await self._generate_ollama(model, messages, max_tokens) + content = await self._generate_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - content = await self._generate_custom(model, messages, max_tokens) + content = await self._generate_custom( + model=model, messages=messages, max_tokens=max_tokens + ) if content is None: raise HTTPException( status_code=400, @@ -220,21 +438,43 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> dict | None: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - response_format={ + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + parameters=response_schema, + handler=do_nothing_async, + ), + strict=strict, + ) + ) + + response = await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + response_format=( + { "type": "json_schema", "json_schema": ( { @@ -243,40 +483,66 @@ class LLMClient: "schema": response_schema, } ), - }, - max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) - content = response.choices[0].message.content - else: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", - max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) - tool_calls = response.choices[0].message.tool_calls - if tool_calls: - content = tool_calls[0].function.arguments + } + if not use_tool_calls_for_structured_output + else None + ), + max_completion_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + ) + content = response.choices[0].message.content + + tool_calls = response.choices[0].message.tool_calls + has_response_schema = False + + if tool_calls: + for tool_call in tool_calls: + if tool_call.function.name == "ResponseSchema": + content = tool_call.function.arguments + has_response_schema = True + + if not has_response_schema: + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[each.model_dump() for each in parsed_tool_calls], + ), + *tool_call_messages, + ] + content = await self._generate_openai_structured( + model=model, + messages=new_messages, + response_format=response_schema, + strict=strict, + max_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + depth=depth + 1, + ) if content: - return json.loads(content) + if depth == 0: + return json.loads(content) + return content return None async def _generate_google_structured( @@ -285,31 +551,96 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> dict | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters_json_schema": response_format, + } + ] + ) + ) + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) - content = None - if response.text: - content = json.loads(response.text) - return content + content = response.candidates[0].content + response_parts = content.parts + text_content = None + + if not response_parts: + return None + + tool_calls: List[GoogleToolCall] = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + + if each_part.text: + text_content = each_part.text + + for each in tool_calls: + if each.name == "ResponseSchema": + return each.arguments + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + if text_content: + return json.loads(text_content) + return None async def _generate_anthropic_structured( self, model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client response: AnthropicMessage = await client.messages.create( @@ -317,7 +648,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -325,19 +656,51 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) - content: dict | None = None - for content_block in response.content: - if content_block.type == "tool_use": - content = content_block.input + tool_calls: List[AnthropicToolCall] = [] + for content in response.content: + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) - return content + for each in tool_calls: + if each.name == "ResponseSchema": + return each.input + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + return None async def _generate_ollama_structured( self, @@ -346,9 +709,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) async def _generate_custom_structured( @@ -358,9 +727,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): + extra_body = {"enable_thinking": not self.disable_thinking()} return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) async def generate_structured( @@ -369,29 +746,53 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ) -> dict: + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: content = await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: content = await self._generate_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: content = await self._generate_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: content = await self._generate_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) if content is None: raise HTTPException( @@ -406,90 +807,285 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - async with client.chat.completions.stream( + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + async for event in await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta + tools=tools, + extra_body=extra_body, + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], + ), + *tool_call_messages, + ] + async for event in self._stream_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ): + yield event async def _stream_google( self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + + tool_calls = None async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", + tools=google_tools, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client + async with client.messages.stream( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, + tools=tools, ) as stream: + tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "text" and isinstance(event.text, str): + + if event.type == "text": yield event.text + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + def _stream_ollama( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) def _stream_custom( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): - return self._stream_openai(model, messages, max_tokens) + extra_body = {"enable_thinking": not self.disable_thinking()} + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, + ) def stream( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - return self._stream_google(model, messages, max_tokens) + return self._stream_google( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.ANTHROPIC: - return self._stream_anthropic(model, messages, max_tokens) + return self._stream_anthropic( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.OLLAMA: - return self._stream_ollama(model, messages, max_tokens) + return self._stream_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - return self._stream_custom(model, messages, max_tokens) + return self._stream_custom( + model=model, messages=messages, max_tokens=max_tokens + ) # ? Stream Structured Content async def _stream_openai_structured( @@ -499,62 +1095,145 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() + response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - async with client.chat.completions.stream( - model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - response_format=( - { - "type": "json_schema", - "json_schema": { + + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + parameters=response_schema, + handler=do_nothing_async, + ), + strict=strict, + ) + ) + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + + has_response_schema_tool_call = False + async for event in await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + max_completion_tokens=max_tokens, + tools=all_tools, + response_format=( + { + "type": "json_schema", + "json_schema": ( + { "name": "ResponseSchema", "strict": strict, "schema": response_schema, - }, - } + } + ), + } + if not use_tool_calls_for_structured_output + else None + ), + extra_body=extra_body, + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_name == "ResponseSchema": + if tool_arguments: + yield tool_arguments + has_response_schema_tool_call = True + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls and not has_response_schema_tool_call: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], ), - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta - else: - async with client.chat.completions.stream( + *tool_call_messages, + ] + async for event in self._stream_openai_structured( model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) as stream: - async for event in stream: - if event.type == "tool_calls.function.arguments.delta": - yield event.arguments_delta + messages=new_messages, + max_tokens=max_tokens, + strict=strict, + tools=all_tools, + response_format=response_schema, + extra_body=extra_body, + depth=depth + 1, + ): + yield event async def _stream_google_structured( self, @@ -562,35 +1241,99 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: + client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters_json_schema": response_format, + } + ] + ) + ) + + tool_calls: List[GoogleToolCall] = [] + has_response_schema_tool_call = False async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + for each in tool_calls: + if each.name == "ResponseSchema": + has_response_schema_tool_call = True + if each.arguments: + yield json.dumps(each.arguments) + + if has_response_schema_tool_call: + continue + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic_structured( self, model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncAnthropic = self._client async with client.messages.stream( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -598,17 +1341,72 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) as stream: + tool_calls: List[AnthropicToolCall] = [] + has_response_schema_tool_call = False + is_response_schema_tool_call_started = False async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json" and isinstance(event.partial_json, str): - yield event.partial_json + if ( + event.type == "content_block_start" + and event.content_block.type == "tool_use" + ): + if event.content_block.name == "ResponseSchema": + has_response_schema_tool_call = True + is_response_schema_tool_call_started = True + + if ( + event.type == "content_block_delta" + and event.delta.type == "input_json_delta" + and is_response_schema_tool_call_started + ): + yield event.delta.partial_json + + if has_response_schema_tool_call: + continue + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama_structured( self, @@ -617,9 +1415,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) def _stream_custom_structured( @@ -629,9 +1433,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): + extra_body = {"enable_thinking": not self.disable_thinking()} return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) def stream_structured( @@ -640,26 +1452,93 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: return self._stream_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: return self._stream_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: return self._stream_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: return self._stream_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) + + # ? Web search + async def _search_openai(self, query: str) -> str: + client: AsyncOpenAI = self._client + response = await client.responses.create( + model=get_model(), + tools=[ + { + "type": "web_search_preview", + } + ], + input=query, + ) + return response.output_text + + async def _search_google(self, query: str) -> str: + client: genai.Client = self._client + grounding_tool = GoogleTool(google_search=GoogleSearch()) + config = GenerateContentConfig(tools=[grounding_tool]) + + response = await asyncio.to_thread( + client.models.generate_content, + model=get_model(), + contents=query, + config=config, + ) + return response.text + + async def _search_anthropic(self, query: str) -> str: + client: AsyncAnthropic = self._client + + response = await client.messages.create( + model=get_model(), + max_tokens=4000, + messages=[{"role": "user", "content": query}], + tools=[ + {"type": "web_search_20250305", "name": "web_search", "max_uses": 1} + ], + ) + result = "\n".join( + [each.text for each in response.content if each.type == "text"] + ) + return result diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py new file mode 100644 index 00000000..ed0d51ee --- /dev/null +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -0,0 +1,201 @@ +import asyncio +from datetime import datetime +import json +from typing import Any, Callable, Coroutine, List, Optional +from fastapi import HTTPException +from enums.llm_provider import LLMProvider +from models.llm_message import ( + AnthropicToolCallMessage, + GoogleToolCallMessage, + OpenAIToolCallMessage, +) +from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall +from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema + + +class LLMToolCallsHandler: + def __init__(self, client): + from services.llm_client import LLMClient + + self.client: LLMClient = client + + self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = { + "SearchWebTool": self.search_web_tool_call_handler, + "GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler, + } + self.dynamic_tools: List[LLMDynamicTool] = [] + + def get_tool_handler( + self, tool_name: str + ) -> Callable[..., Coroutine[Any, Any, str]]: + handler = self.tools_map.get(tool_name) + if handler: + return handler + else: + dynamic_tools = list( + filter(lambda tool: tool.name == tool_name, self.dynamic_tools) + ) + if dynamic_tools: + return dynamic_tools[0].handler + raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found") + + def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): + if tools is None: + return None + parsed_tools = map(self.parse_tool, tools) + return list(parsed_tools) + + def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False): + if isinstance(tool, LLMDynamicTool): + self.dynamic_tools.append(tool) + + match self.client.llm_provider: + case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM: + return self.parse_tool_openai(tool, strict) + case LLMProvider.ANTHROPIC: + return self.parse_tool_anthropic(tool) + case LLMProvider.GOOGLE: + return self.parse_tool_google(tool) + case _: + raise ValueError( + f"LLM provider must be either openai, anthropic, or google" + ) + + def parse_tool_openai( + self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False + ): + if isinstance(tool, LLMDynamicTool): + name = tool.name + description = tool.description + parameters = tool.parameters + else: + name = tool.__name__ + description = tool.__doc__ or "" + parameters = tool.model_json_schema() + + if strict: + parameters = ensure_strict_json_schema(parameters, path=(), root=parameters) + + return { + "type": "function", + "function": { + "name": name, + "description": description, + "strict": strict, + "parameters": parameters, + }, + } + + def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + # parsed["function"]["parameters"] = flatten_json_schema( + # parsed["function"]["parameters"] + # ) + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "parameters": parsed["function"]["parameters"], + } + + def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + input_schema = parsed["function"]["parameters"] + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "input_schema": {"type": "object"} if input_schema == {} else input_schema, + } + + async def handle_tool_calls_openai( + self, + tool_calls: List[OpenAIToolCall], + ) -> List[OpenAIToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments)) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + OpenAIToolCallMessage( + content=result, + tool_call_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_google( + self, + tool_calls: List[GoogleToolCall], + ) -> List[GoogleToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + GoogleToolCallMessage( + name=tool_call.name, + response={"result": result}, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_anthropic( + self, + tool_calls: List[AnthropicToolCall], + ) -> List[AnthropicToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + AnthropicToolCallMessage( + content=result, + tool_use_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + # ? Tool call handlers + # Search web tool call handler + async def search_web_tool_call_handler(self, arguments: str) -> str: + match self.client.llm_provider: + case LLMProvider.OPENAI: + return await self.search_web_tool_call_handler_openai(arguments) + case LLMProvider.ANTHROPIC: + return await self.search_web_tool_call_handler_anthropic(arguments) + case LLMProvider.GOOGLE: + return await self.search_web_tool_call_handler_google(arguments) + case _: + return ( + "Web search tool call handler not implemented for this LLM provider: " + + self.client.llm_provider.value + ) + + async def search_web_tool_call_handler_openai(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_openai(args.query) + + async def search_web_tool_call_handler_google(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_google(args.query) + + async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_anthropic(args.query) + + # Get current datetime tool call handler + async def get_current_datetime_tool_call_handler(self, _) -> str: + current_time = datetime.now() + return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}" diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py deleted file mode 100644 index f2e3d8c9..00000000 --- a/servers/fastapi/services/redis_service.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Optional -import redis -from redis.exceptions import RedisError - -from utils.get_env import ( - get_redis_db_env, - get_redis_host_env, - get_redis_password_env, - get_redis_port_env, -) - - -class RedisService: - def __init__(self): - self.redis_host = get_redis_host_env() or "localhost" - self.redis_port = int(get_redis_port_env() or "6379") - self.redis_db = int(get_redis_db_env() or "0") - self.redis_password = get_redis_password_env() or None - self.client = self._create_client() - - def _create_client(self) -> redis.Redis: - return redis.Redis( - host=self.redis_host, - port=self.redis_port, - db=self.redis_db, - password=self.redis_password, - decode_responses=True, - ) - - def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: - try: - return self.client.set(key, value, ex=expire) - except RedisError: - return False - - def get(self, key: str) -> Optional[str]: - try: - return self.client.get(key) - except RedisError: - return None - - def delete(self, key: str) -> bool: - try: - return bool(self.client.delete(key)) - except RedisError: - return False - - def exists(self, key: str) -> bool: - try: - return bool(self.client.exists(key)) - except RedisError: - return False - - def set_hash(self, name: str, mapping: dict) -> bool: - try: - return self.client.hmset(name, mapping) - except RedisError: - return False - - def get_hash(self, name: str) -> Optional[dict]: - try: - return self.client.hgetall(name) - except RedisError: - return None - - def delete_hash(self, name: str, *fields: str) -> int: - try: - return self.client.hdel(name, *fields) - except RedisError: - return 0 - - def set_list(self, name: str, values: list) -> bool: - try: - self.client.delete(name) - if values: - self.client.rpush(name, *values) - return True - except RedisError: - return False - - def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: - try: - return self.client.lrange(name, start, end) - except RedisError: - return None - - def add_to_set(self, name: str, *values: str) -> int: - try: - return self.client.sadd(name, *values) - except RedisError: - return 0 - - def get_set(self, name: str) -> Optional[set]: - try: - return self.client.smembers(name) - except RedisError: - return None - - def remove_from_set(self, name: str, *values: str) -> int: - try: - return self.client.srem(name, *values) - except RedisError: - return 0 - - def clear(self) -> bool: - try: - return self.client.flushdb() - except RedisError: - return False - - def close(self): - try: - self.client.close() - except RedisError: - pass diff --git a/servers/fastapi/static/icons/deepseek.png b/servers/fastapi/static/icons/deepseek.png deleted file mode 100644 index 798b8f18..00000000 Binary files a/servers/fastapi/static/icons/deepseek.png and /dev/null differ diff --git a/servers/fastapi/static/icons/gemma.png b/servers/fastapi/static/icons/gemma.png deleted file mode 100644 index 647d87a2..00000000 Binary files a/servers/fastapi/static/icons/gemma.png and /dev/null differ diff --git a/servers/fastapi/static/icons/meta.png b/servers/fastapi/static/icons/meta.png deleted file mode 100644 index 0a3d82c1..00000000 Binary files a/servers/fastapi/static/icons/meta.png and /dev/null differ diff --git a/servers/fastapi/static/icons/qwen.png b/servers/fastapi/static/icons/qwen.png deleted file mode 100644 index 2cee1c36..00000000 Binary files a/servers/fastapi/static/icons/qwen.png and /dev/null differ diff --git a/servers/fastapi/utils/dummy_functions.py b/servers/fastapi/utils/dummy_functions.py new file mode 100644 index 00000000..461e9695 --- /dev/null +++ b/servers/fastapi/utils/dummy_functions.py @@ -0,0 +1,2 @@ +async def do_nothing_async(_): + return None diff --git a/servers/fastapi/utils/get_dynamic_models.py b/servers/fastapi/utils/get_dynamic_models.py index 744a6a5a..fd4b2bda 100644 --- a/servers/fastapi/utils/get_dynamic_models.py +++ b/servers/fastapi/utils/get_dynamic_models.py @@ -1,13 +1,23 @@ from typing import List from pydantic import Field -from models.presentation_outline_model import PresentationOutlineModel +from models.presentation_outline_model import ( + PresentationOutlineModel, + SlideOutlineModel, +) from models.presentation_structure_model import PresentationStructureModel def get_presentation_outline_model_with_n_slides(n_slides: int): + class SlideOutlineModelWithNSlides(SlideOutlineModel): + content: str = Field( + description="Markdown content for each slide", + min_length=100, + max_length=300, + ) + class PresentationOutlineModelWithNSlides(PresentationOutlineModel): - slides: List[str] = Field( - description="Markdown content for each slide in about 100 to 200 words", + slides: List[SlideOutlineModelWithNSlides] = Field( + description="List of slide outlines", min_items=n_slides, max_items=n_slides, ) diff --git a/servers/fastapi/utils/get_env.py b/servers/fastapi/utils/get_env.py index c2c72efd..fa80b2a2 100644 --- a/servers/fastapi/utils/get_env.py +++ b/servers/fastapi/utils/get_env.py @@ -81,22 +81,6 @@ def get_pixabay_api_key_env(): return os.getenv("PIXABAY_API_KEY") -def get_redis_host_env(): - return os.getenv("REDIS_HOST") - - -def get_redis_port_env(): - return os.getenv("REDIS_PORT") - - -def get_redis_db_env(): - return os.getenv("REDIS_DB") - - -def get_redis_password_env(): - return os.getenv("REDIS_PASSWORD") - - def get_tool_calls_env(): return os.getenv("TOOL_CALLS") @@ -107,3 +91,7 @@ def get_disable_thinking_env(): def get_extended_reasoning_env(): return os.getenv("EXTENDED_REASONING") + + +def get_web_grounding_env(): + return os.getenv("WEB_GROUNDING") diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py index a8df598a..30599d08 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide.py +++ b/servers/fastapi/utils/llm_calls/edit_slide.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.sql.slide import SlideModel from services.llm_client import LLMClient @@ -41,12 +41,10 @@ def get_messages( language: str, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, slide_data, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/edit_slide_html.py b/servers/fastapi/utils/llm_calls/edit_slide_html.py index a5e2dfad..cf58d185 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide_html.py +++ b/servers/fastapi/utils/llm_calls/edit_slide_html.py @@ -1,5 +1,5 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from services.llm_client import LLMClient from utils.llm_provider import get_model @@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str): response = await client.generate( model=model, messages=[ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=get_user_prompt(prompt, html)), + LLMSystemMessage(content=system_prompt), + LLMUserMessage(content=get_user_prompt(prompt, html)), ], ) return extract_html_from_response(response) or html diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 507bb6eb..892b9cff 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,10 +1,13 @@ -import asyncio from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage +from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides +from utils.get_env import get_web_grounding_env from utils.llm_provider import get_model +from utils.parsers import parse_bool_or_none +from utils.user_config import get_user_config system_prompt = """ You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content. @@ -29,12 +32,10 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, n_slides, language, content), ), ] @@ -51,10 +52,13 @@ async def generate_ppt_outline( client = LLMClient() + tools = [SearchWebTool, GetCurrentDatetimeTool] + async for chunk in client.stream_structured( model, get_messages(prompt, n_slides, language, content), response_model.model_json_schema(), strict=True, + tools=tools if client.enable_web_grounding() else None, ): yield chunk diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py index 47f47dba..1bfc0cd0 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from services.llm_client import LLMClient @@ -11,8 +11,7 @@ def get_messages( presentation_layout: PresentationLayoutModel, n_slides: int, data: str ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" You're a professional presentation designer with creative freedom to design engaging presentations. @@ -47,8 +46,7 @@ def get_messages( Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals. """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" {data} """, diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py index ecff518a..be19b168 100644 --- a/servers/fastapi/utils/llm_calls/generate_slide_content.py +++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py @@ -1,5 +1,6 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel +from models.presentation_outline_model import SlideOutlineModel from services.llm_client import LLMClient from utils.llm_provider import get_model from utils.schema_utils import remove_fields_from_schema @@ -38,19 +39,17 @@ def get_user_prompt(outline: str, language: str): def get_messages(outline: str, language: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(outline, language), ), ] async def get_slide_content_from_type_and_outline( - slide_layout: SlideLayoutModel, outline: str, language: str + slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str ): client = LLMClient() model = get_model() @@ -62,7 +61,7 @@ async def get_slide_content_from_type_and_outline( response = await client.generate_structured( model=model, messages=get_messages( - outline, + outline.content, language, ), response_format=response_schema, diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py index f3532b48..7235e558 100644 --- a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py +++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel from models.slide_layout_index import SlideLayoutIndex from models.sql.slide import SlideModel @@ -13,8 +13,7 @@ def get_messages( current_slide_layout: int, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" Select a Slide Layout index based on provided user prompt and current slide data. {layout.to_string()} @@ -26,8 +25,7 @@ def get_messages( **Go through all notes and steps and make sure they are followed, including mentioned constraints** """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" - User Prompt: {prompt} - Current Slide Data: {slide_data} diff --git a/servers/fastapi/utils/schema_utils.py b/servers/fastapi/utils/schema_utils.py index ae65f002..6cb01a0e 100644 --- a/servers/fastapi/utils/schema_utils.py +++ b/servers/fastapi/utils/schema_utils.py @@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: return resolved +# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions +def flatten_json_schema(schema: dict) -> dict: + root_schema = deepcopy(schema) + + def _flatten(node: Any) -> Any: + if isinstance(node, dict): + # If node is a pure $ref (or combined with extra fields), inline it + if "$ref" in node: + ref_value = node["$ref"] + assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}" + resolved = resolve_ref(root=root_schema, ref=ref_value) + assert isinstance(resolved, dict), ( + f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}" + ) + # Merge: referenced first, then overlay current (excluding $ref) + merged: dict[str, Any] = deepcopy(resolved) + for key, value in node.items(): + if key == "$ref": + continue + merged[key] = value + return _flatten(merged) + + flattened: dict[str, Any] = {} + for key, value in node.items(): + # Drop defs/definitions in output + if key in ("$defs", "definitions"): + continue + if key == "properties" and isinstance(value, dict): + flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()} + elif key in ("items", "contains", "additionalProperties", "not"): + if isinstance(value, dict): + flattened[key] = _flatten(value) + elif isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = value + elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value + return flattened + if isinstance(node, list): + return [_flatten(v) for v in node] + return node + + result = _flatten(schema) + # Ensure top-level cleanup just in case + if isinstance(result, dict): + result.pop("$defs", None) + result.pop("definitions", None) + return result + + # ? Not used def generate_constraint_sentences(schema: dict) -> str: """ diff --git a/servers/fastapi/utils/set_env.py b/servers/fastapi/utils/set_env.py index 7ac0e335..ea3758f3 100644 --- a/servers/fastapi/utils/set_env.py +++ b/servers/fastapi/utils/set_env.py @@ -79,3 +79,7 @@ def set_disable_thinking_env(value): def set_extended_reasoning_env(value): os.environ["EXTENDED_REASONING"] = value + + +def set_web_grounding_env(value): + os.environ["WEB_GROUNDING"] = value \ No newline at end of file diff --git a/servers/fastapi/utils/user_config.py b/servers/fastapi/utils/user_config.py index 06235d5a..49fd1722 100644 --- a/servers/fastapi/utils/user_config.py +++ b/servers/fastapi/utils/user_config.py @@ -22,6 +22,7 @@ from utils.get_env import ( get_image_provider_env, get_pixabay_api_key_env, get_extended_reasoning_env, + get_web_grounding_env, ) from utils.parsers import parse_bool_or_none from utils.set_env import ( @@ -43,6 +44,7 @@ from utils.set_env import ( set_image_provider_env, set_pixabay_api_key_env, set_tool_calls_env, + set_web_grounding_env, ) @@ -76,12 +78,26 @@ def get_user_config(): IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(), PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(), - TOOL_CALLS=existing_config.TOOL_CALLS - or parse_bool_or_none(get_tool_calls_env()), - DISABLE_THINKING=existing_config.DISABLE_THINKING - or parse_bool_or_none(get_disable_thinking_env()), - EXTENDED_REASONING=existing_config.EXTENDED_REASONING - or parse_bool_or_none(get_extended_reasoning_env()), + TOOL_CALLS=( + existing_config.TOOL_CALLS + if existing_config.TOOL_CALLS is not None + else (parse_bool_or_none(get_tool_calls_env()) or False) + ), + DISABLE_THINKING=( + existing_config.DISABLE_THINKING + if existing_config.DISABLE_THINKING is not None + else (parse_bool_or_none(get_disable_thinking_env()) or False) + ), + EXTENDED_REASONING=( + existing_config.EXTENDED_REASONING + if existing_config.EXTENDED_REASONING is not None + else (parse_bool_or_none(get_extended_reasoning_env()) or False) + ), + WEB_GROUNDING=( + existing_config.WEB_GROUNDING + if existing_config.WEB_GROUNDING is not None + else (parse_bool_or_none(get_web_grounding_env()) or False) + ), ) @@ -122,5 +138,6 @@ def update_env_with_user_config(): if user_config.DISABLE_THINKING: set_disable_thinking_env(str(user_config.DISABLE_THINKING)) if user_config.EXTENDED_REASONING: - if user_config.EXTENDED_REASONING: - set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + if user_config.WEB_GROUNDING: + set_web_grounding_env(str(user_config.WEB_GROUNDING)) diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock index b579f42a..4e19c02b 100644 --- a/servers/fastapi/uv.lock +++ b/servers/fastapi/uv.lock @@ -1061,6 +1061,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "isodate" version = "0.7.2" @@ -1907,6 +1916,7 @@ dependencies = [ { name = "openai" }, { name = "pathvalidate" }, { name = "pdfplumber" }, + { name = "pytest" }, { name = "python-pptx" }, { name = "redis" }, { name = "sqlmodel" }, @@ -1928,6 +1938,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.98.0" }, { name = "pathvalidate", specifier = ">=3.3.1" }, { name = "pdfplumber", specifier = ">=0.11.7" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "python-pptx", specifier = ">=1.0.2" }, { name = "redis", specifier = ">=6.2.0" }, { name = "sqlmodel", specifier = ">=0.0.24" }, @@ -2211,6 +2222,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + [[package]] name = "python-bidi" version = "0.6.6" diff --git a/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx b/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx index f7ef3756..d1612d74 100644 --- a/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx +++ b/servers/nextjs/app/(presentation-generator)/context/LayoutContext.tsx @@ -12,6 +12,7 @@ import * as z from "zod"; import { useDispatch } from "react-redux"; import { setLayoutLoading } from "@/store/slices/presentationGeneration"; import * as Babel from "@babel/standalone"; +import * as Recharts from "recharts"; export interface LayoutInfo { id: string; name?: string; @@ -99,14 +100,17 @@ const compileCustomLayout = (layoutCode: string, React: any, z: any) => { const factory = new Function( "React", "_z", + "Recharts", ` const z = _z; + // Expose commonly used Recharts components to compiled layouts + const { ResponsiveContainer, LineChart, Line, BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, Legend, PieChart, Pie, Cell, AreaChart, Area, RadarChart, Radar, PolarGrid, PolarAngleAxis, PolarRadiusAxis } = Recharts || {}; ${compiled} /* everything declared in the string is in scope here */ return { __esModule: true, - default: dynamicSlideLayout, + default: typeof dynamicSlideLayout !== 'undefined' ? dynamicSlideLayout : (typeof DefaultLayout !== 'undefined' ? DefaultLayout : undefined), layoutName, layoutId, layoutDescription, @@ -115,7 +119,7 @@ const compileCustomLayout = (layoutCode: string, React: any, z: any) => { ` ); - return factory(React, z); + return factory(React, z, Recharts); }; export const LayoutProvider: React.FC<{ diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useAPIKeyCheck.ts b/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useAPIKeyCheck.ts deleted file mode 100644 index 5164e7df..00000000 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useAPIKeyCheck.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { useState, useEffect } from "react"; - -export const useAPIKeyCheck = () => { - const [hasRequiredKey, setHasAnthropicKey] = useState(false); - const [isRequiredKeyLoading, setIsAnthropicKeyLoading] = useState(true); - - useEffect(() => { - fetch("/api/has-required-key") - .then((res) => res.json()) - .then((data) => { - setHasAnthropicKey(data.hasKey); - setIsAnthropicKeyLoading(false); - }); - }, []); - - return { hasRequiredKey, isRequiredKeyLoading }; -}; \ No newline at end of file diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/APIKeyWarning.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/APIKeyWarning.tsx similarity index 65% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/APIKeyWarning.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/APIKeyWarning.tsx index 44b2e0d2..ad803232 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/APIKeyWarning.tsx +++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/APIKeyWarning.tsx @@ -10,6 +10,10 @@ export const APIKeyWarning: React.FC = () => {

Please add "GOOGLE_API_KEY" to enable template creation via AI.

+

Please add your OpenAI API Key to process the layout

+

+ This feature requires an OpenAI model GPT-5. Configure your key in settings or via environment variables. +

diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/EditControls.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/EditControls.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/EditControls.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/EditControls.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/HtmlEditor.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/HtmlEditor.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/HtmlEditor.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/HtmlEditor.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/NewEachSlide.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/NewEachSlide.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/NewEachSlide.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/NewEachSlide.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideActions.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideActions.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideActions.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideActions.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideContentDisplay.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideContentDisplay.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/EachSlide/SlideContentDisplay.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/EachSlide/SlideContentDisplay.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/FileUploadSection.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/FileUploadSection.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/FileUploadSection.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/FileUploadSection.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/FontManager.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/FontManager.tsx similarity index 99% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/FontManager.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/FontManager.tsx index 7ee5df52..5ce0f53b 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/FontManager.tsx +++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/FontManager.tsx @@ -224,7 +224,7 @@ const FontManager: React.FC = ({ onClick={processSlideToHtml} className="text-xs px-8 py-2 font-semibold bg-blue-600 text-white hover:text-white hover:bg-blue-700 border-blue-600" > - Extract layouts + Extract Template diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/LoadingSpinner.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/LoadingSpinner.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/LoadingSpinner.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/LoadingSpinner.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutButton.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutButton.tsx similarity index 94% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutButton.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutButton.tsx index 9123b912..6badcf5b 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutButton.tsx +++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutButton.tsx @@ -25,12 +25,12 @@ export const SaveLayoutButton: React.FC = ({ {isSaving ? ( <> - Saving Layout... + Saving Template... ) : ( <> - Save Layout + Save Template )} diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutModal.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutModal.tsx similarity index 91% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutModal.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutModal.tsx index 57924e7a..40fa3a89 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SaveLayoutModal.tsx +++ b/servers/nextjs/app/(presentation-generator)/custom-template/components/SaveLayoutModal.tsx @@ -53,22 +53,22 @@ export const SaveLayoutModal: React.FC = ({ - Save Layout + Save Template - Enter a name and description for your layout. This will help you identify it later. + Enter a name and description for your template. This will help you identify it later.
setLayoutName(e.target.value)} - placeholder="Enter layout name..." + placeholder="Enter template name..." disabled={isSaving} className="w-full" /> @@ -81,7 +81,7 @@ export const SaveLayoutModal: React.FC = ({ id="description" value={description} onChange={(e) => setDescription(e.target.value)} - placeholder="Enter a description for your layout..." + placeholder="Enter a description for your template..." disabled={isSaving} className="w-full resize-none" rows={3} @@ -109,7 +109,7 @@ export const SaveLayoutModal: React.FC = ({ ) : ( <> - Save Layout + Save Template )} diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/SlideContent.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/SlideContent.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/SlideContent.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/SlideContent.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/components/Timer.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/components/Timer.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/components/Timer.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/components/Timer.tsx diff --git a/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useAPIKeyCheck.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useAPIKeyCheck.ts new file mode 100644 index 00000000..f16dd738 --- /dev/null +++ b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useAPIKeyCheck.ts @@ -0,0 +1,18 @@ +import { useState, useEffect } from "react"; + +export const useAPIKeyCheck = () => { + const [hasRequiredKey, setHasRequiredKey] = useState(false); + const [isRequiredKeyLoading, setIsRequiredKeyLoading] = useState(true); + + useEffect(() => { + fetch("/api/has-required-key") + .then((res) => res.json()) + .then((data) => { + setHasRequiredKey(Boolean(data.hasKey)); + setIsRequiredKeyLoading(false); + }) + .catch(() => setIsRequiredKeyLoading(false)); + }, []); + + return { hasRequiredKey, isRequiredKeyLoading }; +}; \ No newline at end of file diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useCustomLayout.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useCustomLayout.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useCustomLayout.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useCustomLayout.ts diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useDrawingCanvas.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useDrawingCanvas.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useDrawingCanvas.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useDrawingCanvas.ts diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFileUpload.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFileUpload.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFileUpload.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFileUpload.ts diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFontManagement.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFontManagement.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useFontManagement.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useFontManagement.ts diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useHtmlEdit.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useHtmlEdit.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useHtmlEdit.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useHtmlEdit.ts diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useLayoutSaving.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useLayoutSaving.ts similarity index 97% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useLayoutSaving.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useLayoutSaving.ts index 8433a86d..570e73f2 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useLayoutSaving.ts +++ b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useLayoutSaving.ts @@ -27,6 +27,11 @@ export const useLayoutSaving = ( const maxRetries = 3; let retryCount = 0; + console.log("Slide to convert to react", { + html: slide.html, + image: slide.screenshot_url, + }) + while (retryCount < maxRetries) { try { const response = await fetch("/api/v1/ppt/html-to-react/", { @@ -36,6 +41,7 @@ export const useLayoutSaving = ( }, body: JSON.stringify({ html: slide.html, + image: slide.screenshot_url, }), }); diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideEdit.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideEdit.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideEdit.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideEdit.ts diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideProcessing.ts b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideProcessing.ts similarity index 96% rename from servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideProcessing.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideProcessing.ts index 4a677a68..544c4804 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/hooks/useSlideProcessing.ts +++ b/servers/nextjs/app/(presentation-generator)/custom-template/hooks/useSlideProcessing.ts @@ -35,6 +35,7 @@ export const useSlideProcessing = ( body: JSON.stringify({ image: slide.screenshot_url, xml: slide.xml_content, + fonts: slide.normalized_fonts ?? [], }), }); @@ -157,7 +158,10 @@ export const useSlideProcessing = ( setSlides(initialSlides); toast.success( - `Successfully extracted ${pptxData.slides.length} slides! Converting to HTML...` + `Template Processing Finished`, + { + description: `Please Upload the not supported fonts, and click Extract Template` + } ); diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/page.tsx b/servers/nextjs/app/(presentation-generator)/custom-template/page.tsx similarity index 95% rename from servers/nextjs/app/(presentation-generator)/custom-layout/page.tsx rename to servers/nextjs/app/(presentation-generator)/custom-template/page.tsx index 1a6d2415..652e1c33 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/page.tsx +++ b/servers/nextjs/app/(presentation-generator)/custom-template/page.tsx @@ -18,7 +18,7 @@ import { APIKeyWarning } from "./components/APIKeyWarning"; import { useAPIKeyCheck } from "./hooks/useAPIKeyCheck"; -const CustomLayoutPage = () => { +const CustomTemplatePage = () => { const { refetch } = useLayout(); // Custom hooks for different concerns @@ -66,8 +66,9 @@ const CustomLayoutPage = () => { // Anthropic key warning if (!hasRequiredKey) { return ; - } + + } return (
@@ -75,7 +76,7 @@ const CustomLayoutPage = () => { {/* Header */}

- Custom Layout Processor + Custom Template Processor

Upload your PPTX file to extract slides and convert them to @@ -126,7 +127,7 @@ const CustomLayoutPage = () => {

)} - {/* Floating Save Layout Button */} + {/* Floating Save Template Button */} {slides.length > 0 && slides.some((s) => s.processed) && ( { /> )} - {/* Save Layout Modal */} + {/* Save Template Modal */} { ); }; -export default CustomLayoutPage; +export default CustomTemplatePage; + + diff --git a/servers/nextjs/app/(presentation-generator)/custom-layout/types/index.ts b/servers/nextjs/app/(presentation-generator)/custom-template/types/index.ts similarity index 99% rename from servers/nextjs/app/(presentation-generator)/custom-layout/types/index.ts rename to servers/nextjs/app/(presentation-generator)/custom-template/types/index.ts index 64e260ad..8e57a035 100644 --- a/servers/nextjs/app/(presentation-generator)/custom-layout/types/index.ts +++ b/servers/nextjs/app/(presentation-generator)/custom-template/types/index.ts @@ -3,6 +3,7 @@ export interface SlideData { slide_number: number; screenshot_url: string; xml_content: string; + normalized_fonts?: string[]; } export interface UploadedFont { diff --git a/servers/nextjs/app/(presentation-generator)/dashboard/components/Header.tsx b/servers/nextjs/app/(presentation-generator)/dashboard/components/Header.tsx index 13ebb205..c6b1a507 100644 --- a/servers/nextjs/app/(presentation-generator)/dashboard/components/Header.tsx +++ b/servers/nextjs/app/(presentation-generator)/dashboard/components/Header.tsx @@ -25,13 +25,13 @@ const Header = () => {
- Layouts + Templates
diff --git a/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx b/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx index bc5ee297..329a5ef1 100644 --- a/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx +++ b/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx @@ -1,10 +1,10 @@ import React from "react"; import { Button } from "@/components/ui/button"; -import { LoadingState, StreamState, LayoutGroup } from "../types/index"; +import { LoadingState, LayoutGroup } from "../types/index"; interface GenerateButtonProps { loadingState: LoadingState; - streamState: StreamState; + streamState: { isStreaming: boolean, isLoading: boolean }; selectedLayoutGroup: LayoutGroup | null; onSubmit: () => void; } @@ -23,7 +23,7 @@ const GenerateButton: React.FC = ({ const getButtonText = () => { if (loadingState.isLoading) return loadingState.message; if (streamState.isLoading || streamState.isStreaming) return "Loading..."; - if (!selectedLayoutGroup) return "Select a Layout Style"; + if (!selectedLayoutGroup) return "Select a Templae"; return "Generate Presentation"; }; diff --git a/servers/nextjs/app/(presentation-generator)/outline/components/LayoutSelection.tsx b/servers/nextjs/app/(presentation-generator)/outline/components/LayoutSelection.tsx index 0824fa1c..e5538cff 100644 --- a/servers/nextjs/app/(presentation-generator)/outline/components/LayoutSelection.tsx +++ b/servers/nextjs/app/(presentation-generator)/outline/components/LayoutSelection.tsx @@ -82,10 +82,10 @@ const LayoutSelection: React.FC = ({
- No Layout Styles Available + No Templates Available

- No presentation layout styles could be loaded. Please try refreshing the page. + No presentation templates could be loaded. Please try refreshing the page.

diff --git a/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx b/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx index a305da91..5764fe6d 100644 --- a/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx +++ b/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx @@ -18,7 +18,7 @@ import { Button } from "@/components/ui/button"; import { FileText } from "lucide-react"; interface OutlineContentProps { - outlines: string[] | null; + outlines: { content: string }[] | null; isLoading: boolean; isStreaming: boolean; onDragEnd: (event: any) => void; @@ -32,7 +32,7 @@ const OutlineContent: React.FC = ({ onDragEnd, onAddSlide }) => { - + console.log('isLoading', isLoading) const sensors = useSensors( useSensor(PointerSensor), useSensor(KeyboardSensor, { @@ -83,7 +83,18 @@ const OutlineContent: React.FC = ({ collisionDetection={closestCenter} onDragEnd={onDragEnd} > - ( + + )) + ) : + ({ id: `slide-${index}` })) || []} strategy={verticalListSortingStrategy} > @@ -95,7 +106,7 @@ const OutlineContent: React.FC = ({ isStreaming={isStreaming} /> ))} - + } + + {slug.includes('custom-') && } +
+ +
+

+ {layoutGroup[0].groupName} Layouts +

+

+ {layoutGroup.length} layout{layoutGroup.length !== 1 ? "s" : ""} •{" "} + {layoutGroup[0].groupName} +

+
+ +
+ + + {/* Layout Grid */} +
+
+ {layoutGroup.map((layout: any, index: number) => { + const { + component: LayoutComponent, + sampleData, + name, + fileName, + } = layout; + + return ( + + {/* Layout Header */} +
+
+
+

+ {name} +

+
+ + {fileName} + + + {layoutGroup[0].groupName} + +
+
+
+
+ Layout #{index + 1} +
+
+
+
+ + {/* Layout Content */} +
+ +
+
+ ); + })} +
+
+ + {/* Footer */} +
+
+
+

+ {layoutGroup[0].groupName} • {layoutGroup.length} components +

+
+
+
+ + ); +}; + +export default GroupLayoutPreview; \ No newline at end of file diff --git a/servers/nextjs/app/(presentation-generator)/layout-preview/components/LoadingStates.tsx b/servers/nextjs/app/(presentation-generator)/template-preview/components/LoadingStates.tsx similarity index 100% rename from servers/nextjs/app/(presentation-generator)/layout-preview/components/LoadingStates.tsx rename to servers/nextjs/app/(presentation-generator)/template-preview/components/LoadingStates.tsx diff --git a/servers/nextjs/app/(presentation-generator)/layout-preview/hooks/useGroupLayoutLoader.ts b/servers/nextjs/app/(presentation-generator)/template-preview/hooks/useGroupLayoutLoader.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/layout-preview/hooks/useGroupLayoutLoader.ts rename to servers/nextjs/app/(presentation-generator)/template-preview/hooks/useGroupLayoutLoader.ts diff --git a/servers/nextjs/app/(presentation-generator)/layout-preview/hooks/useLayoutLoader.ts b/servers/nextjs/app/(presentation-generator)/template-preview/hooks/useLayoutLoader.ts similarity index 84% rename from servers/nextjs/app/(presentation-generator)/layout-preview/hooks/useLayoutLoader.ts rename to servers/nextjs/app/(presentation-generator)/template-preview/hooks/useLayoutLoader.ts index 686e360e..014e8514 100644 --- a/servers/nextjs/app/(presentation-generator)/layout-preview/hooks/useLayoutLoader.ts +++ b/servers/nextjs/app/(presentation-generator)/template-preview/hooks/useLayoutLoader.ts @@ -23,9 +23,9 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { setLoading(true) setError(null) - const response = await fetch('/api/layouts') + const response = await fetch('/api/templates') if (!response.ok) { - toast.error('Error loading layouts', { + toast.error('Error loading templates', { description: response.statusText, }) return @@ -38,7 +38,7 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { const groupLayouts: LayoutInfo[] = [] const groupSettings: GroupSetting = groupData.settings ? groupData.settings : { - description: `${groupData.groupName} presentation layouts`, + description: `${groupData.groupName} presentation templates`, ordered: false, default: false } @@ -50,7 +50,7 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { if (!module.default) { toast.error(`${layoutName} has no default export`, { - description: 'Please ensure the layout file exports a default component', + description: 'Please ensure the template file exports a default component', }) console.warn(`${layoutName} has no default export`) throw new Error(`${layoutName} has no default export`) @@ -58,14 +58,12 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { if (!module.Schema) { toast.error(`${layoutName} is missing required Schema export`, { - description: 'Please ensure the layout file exports a Schema', + description: 'Please ensure the template file exports a Schema', }) console.error(`${layoutName} is missing required Schema export`) throw new Error(`${layoutName} is missing required Schema export`) } - // Use empty object to let schema apply its default values - // User will need to provide actual data when using the layouts const sampleData = module.Schema.parse({}) const layoutId = module.layoutId || layoutName.toLowerCase().replace(/layout$/, '') @@ -85,15 +83,12 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { } catch (importError) { console.error(`Failed to import ${fileName} from ${groupData.groupName}:`, importError) - // Try alternative import path try { const layoutName = fileName.replace('.tsx', '').replace('.ts', '') const module = await import(`@/presentation-layouts/${groupData.groupName}/${layoutName}`) if (module.default && module.Schema) { - // Use empty object to let schema apply its default values const sampleData = module.Schema.parse({}) - // if layoutId is not provided, use the layoutName const layoutId = module.layoutId || layoutName.toLowerCase().replace(/layout$/, '') const layoutInfo: LayoutInfo = { @@ -108,7 +103,7 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { groupLayouts.push(layoutInfo) allLayouts.push(layoutInfo) } else { - console.error(`${layoutName} is missing required exports (default component or Schema)`) + console.error(`${layoutName} is missing required exports (default component or Schema)`) } } catch (altError) { console.error(`Alternative import also failed for ${fileName} from ${groupData.groupName}:`, altError) @@ -126,10 +121,10 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { } if (allLayouts.length === 0) { - toast.error('No valid layouts found', { - description: 'Make sure your layout files export both a default component and a Schema.', + toast.error('No valid templates found', { + description: 'Make sure your template files export both a default component and a Schema.', }) - setError('No valid layouts found. Make sure your layout files export both a default component and a Schema.') + setError('No valid templates found. Make sure your template files export both a default component and a Schema.') } else { setLayoutGroups(loadedGroups) setLayouts(allLayouts) @@ -137,8 +132,8 @@ export const useLayoutLoader = (): UseLayoutLoaderReturn => { } } catch (error) { - console.error('Error loading layouts:', error) - setError(error instanceof Error ? error.message : 'Failed to load layouts') + console.error('Error loading templates:', error) + setError(error instanceof Error ? error.message : 'Failed to load templates') } finally { setLoading(false) } diff --git a/servers/nextjs/app/(presentation-generator)/layout-preview/page.tsx b/servers/nextjs/app/(presentation-generator)/template-preview/page.tsx similarity index 98% rename from servers/nextjs/app/(presentation-generator)/layout-preview/page.tsx rename to servers/nextjs/app/(presentation-generator)/template-preview/page.tsx index 42cbebee..2938de74 100644 --- a/servers/nextjs/app/(presentation-generator)/layout-preview/page.tsx +++ b/servers/nextjs/app/(presentation-generator)/template-preview/page.tsx @@ -74,7 +74,7 @@ const LayoutPreview = () => { key={group.groupName} className="cursor-pointer hover:shadow-md transition-all duration-200 group" onClick={() => - router.push(`/layout-preview/${group.groupName}`) + router.push(`/template-preview/${group.groupName}`) } >
diff --git a/servers/nextjs/app/(presentation-generator)/layout-preview/types/index.ts b/servers/nextjs/app/(presentation-generator)/template-preview/types/index.ts similarity index 100% rename from servers/nextjs/app/(presentation-generator)/layout-preview/types/index.ts rename to servers/nextjs/app/(presentation-generator)/template-preview/types/index.ts diff --git a/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx b/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx index f8799d51..8b14b69e 100644 --- a/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx +++ b/servers/nextjs/app/(presentation-generator)/upload/components/UploadPage.tsx @@ -131,7 +131,7 @@ const UploadPage = () => { config, files: responses, })); - dispatch(clearOutlines()); + dispatch(clearOutlines()) router.push("/documents-preview"); }; @@ -155,7 +155,7 @@ const UploadPage = () => { }); dispatch(setPresentationId(createResponse.id)); - dispatch(clearOutlines()); + dispatch(clearOutlines()) router.push("/outline"); }; diff --git a/servers/nextjs/app/api/has-required-key/route.ts b/servers/nextjs/app/api/has-required-key/route.ts index 9e4a3b40..05efdbe8 100644 --- a/servers/nextjs/app/api/has-required-key/route.ts +++ b/servers/nextjs/app/api/has-required-key/route.ts @@ -1,8 +1,25 @@ import { NextResponse } from "next/server"; +import fs from "fs"; export const dynamic = "force-dynamic"; export async function GET() { - const hasKey = process.env.GOOGLE_API_KEY !== ""; + const userConfigPath = process.env.USER_CONFIG_PATH; + + let keyFromFile = ""; + if (userConfigPath && fs.existsSync(userConfigPath)) { + try { + const raw = fs.readFileSync(userConfigPath, "utf-8"); + const cfg = JSON.parse(raw || "{}"); + keyFromFile = cfg?.OPENAI_API_KEY || ""; + } catch {} + } + + console.log(keyFromFile); + + const keyFromEnv = process.env.OPENAI_API_KEY || ""; + console.log(keyFromEnv); + const hasKey = Boolean((keyFromFile || keyFromEnv).trim()); + return NextResponse.json({ hasKey }); -} +} \ No newline at end of file diff --git a/servers/nextjs/app/api/layouts/route.ts b/servers/nextjs/app/api/layouts/route.ts index 4d16ac98..93f88e76 100644 --- a/servers/nextjs/app/api/layouts/route.ts +++ b/servers/nextjs/app/api/layouts/route.ts @@ -1,7 +1,7 @@ import { NextResponse } from 'next/server' import { promises as fs } from 'fs' import path from 'path' -import { GroupSetting } from '@/app/(presentation-generator)/layout-preview/types' +import { GroupSetting } from '@/app/(presentation-generator)/template-preview/types' export async function GET() { try { diff --git a/servers/nextjs/app/api/template/route.ts b/servers/nextjs/app/api/template/route.ts new file mode 100644 index 00000000..2b862d33 --- /dev/null +++ b/servers/nextjs/app/api/template/route.ts @@ -0,0 +1,60 @@ +import { NextResponse } from "next/server"; +import puppeteer from "puppeteer"; + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const groupName = searchParams.get("group"); + + if (!groupName) { + return NextResponse.json({ error: "Missing group name" }, { status: 400 }); + } + + const schemaPageUrl = `http://localhost/schema?group=${encodeURIComponent(groupName)}`; + + let browser; + try { + browser = await puppeteer.launch({ headless: true, args: ["--no-sandbox", "--disable-web-security"] }); + const page = await browser.newPage(); + await page.setViewport({ width: 1280, height: 720 }); + await page.goto(schemaPageUrl, { waitUntil: "networkidle0", timeout: 80000 }); + + await page.waitForSelector("[data-layouts]", { timeout: 30000 }); + + const { dataLayouts, dataGroupSettings } = await page.$eval( + "[data-layouts]", + (el) => ({ + dataLayouts: el.getAttribute("data-layouts"), + dataGroupSettings: el.getAttribute("data-group-settings"), + }) + ); + + let slides, groupSettings; + try { + slides = JSON.parse(dataLayouts || "[]"); + } catch (e) { + slides = []; + } + try { + groupSettings = JSON.parse(dataGroupSettings || "null"); + } catch (e) { + groupSettings = null; + } + + const response = { + name: groupName, + ordered: groupSettings?.ordered ?? false, + slides: slides.map((slide: any) => ({ + id: slide.id, + name: slide.name, + description: slide.description, + json_schema: slide.json_schema, + })), + }; + + return NextResponse.json(response); + } catch (err) { + return NextResponse.json({ error: "Failed to fetch or parse client page" }, { status: 500 }); + } finally { + if (browser) await browser.close(); + } +} \ No newline at end of file diff --git a/servers/nextjs/app/api/templates/route.ts b/servers/nextjs/app/api/templates/route.ts new file mode 100644 index 00000000..50215cdb --- /dev/null +++ b/servers/nextjs/app/api/templates/route.ts @@ -0,0 +1,57 @@ +import { NextResponse } from 'next/server' +import { promises as fs } from 'fs' +import path from 'path' +import { GroupSetting } from '@/app/(presentation-generator)/template-preview/types' + +export async function GET() { + try { + const layoutsDirectory = path.join(process.cwd(), 'presentation-layouts') + const items = await fs.readdir(layoutsDirectory, { withFileTypes: true }) + + const groupDirectories = items.filter(item => item.isDirectory()).map(dir => dir.name) + + const allLayouts: { groupName: string; files: string[]; settings: GroupSetting | null }[] = [] + + for (const groupName of groupDirectories) { + try { + const groupPath = path.join(layoutsDirectory, groupName) + const groupFiles = await fs.readdir(groupPath) + + const layoutFiles = groupFiles.filter(file => + file.endsWith('.tsx') && + !file.startsWith('.') && + !file.includes('.test.') && + !file.includes('.spec.') && + file !== 'settings.json' + ) + + let settings: GroupSetting | null = null + const settingsPath = path.join(groupPath, 'settings.json') + try { + const settingsContent = await fs.readFile(settingsPath, 'utf-8') + settings = JSON.parse(settingsContent) as GroupSetting + } catch { + settings = { + description: `${groupName} presentation templates`, + ordered: false, + default: false, + } + } + + if (layoutFiles.length > 0) { + allLayouts.push({ groupName, files: layoutFiles, settings }) + } + } catch (error) { + console.error(`Error reading group directory ${groupName}:`, error) + } + } + + return NextResponse.json(allLayouts) + } catch (error) { + console.error('Error reading presentation-layouts directory:', error) + return NextResponse.json( + { error: 'Failed to read presentation layouts directory' }, + { status: 500 } + ) + } +} \ No newline at end of file diff --git a/servers/nextjs/app/api/user-config/route.ts b/servers/nextjs/app/api/user-config/route.ts index ff3c643a..03b801e3 100644 --- a/servers/nextjs/app/api/user-config/route.ts +++ b/servers/nextjs/app/api/user-config/route.ts @@ -57,6 +57,10 @@ export async function POST(request: Request) { userConfig.EXTENDED_REASONING === undefined ? existingConfig.EXTENDED_REASONING : userConfig.EXTENDED_REASONING, + WEB_GROUNDING: + userConfig.WEB_GROUNDING === undefined + ? existingConfig.WEB_GROUNDING + : userConfig.WEB_GROUNDING, USE_CUSTOM_URL: userConfig.USE_CUSTOM_URL === undefined ? existingConfig.USE_CUSTOM_URL diff --git a/servers/nextjs/components/AnthropicConfig.tsx b/servers/nextjs/components/AnthropicConfig.tsx index 567846b5..4b61bb65 100644 --- a/servers/nextjs/components/AnthropicConfig.tsx +++ b/servers/nextjs/components/AnthropicConfig.tsx @@ -19,6 +19,7 @@ interface AnthropicConfigProps { anthropicApiKey: string; anthropicModel: string; extendedReasoning: boolean; + webGrounding?: boolean; onInputChange: (value: string | boolean, field: string) => void; } @@ -27,6 +28,7 @@ export default function AnthropicConfig({ anthropicApiKey, anthropicModel, extendedReasoning, + webGrounding, onInputChange, }: AnthropicConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -65,7 +67,7 @@ export default function AnthropicConfig({ const data = await response.json(); setAvailableModels(data); setModelsChecked(true); - onInputChange("claude-3-5-sonnet-20241022", "anthropic_model"); + onInputChange("claude-sonnet-4-20250514", "anthropic_model"); } else { console.error('Failed to fetch models'); setAvailableModels([]); @@ -226,6 +228,23 @@ export default function AnthropicConfig({
) : null} + + {/* Web Grounding Toggle - at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/components/GoogleConfig.tsx b/servers/nextjs/components/GoogleConfig.tsx index 6746f779..8d333dd3 100644 --- a/servers/nextjs/components/GoogleConfig.tsx +++ b/servers/nextjs/components/GoogleConfig.tsx @@ -13,16 +13,19 @@ import { import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { cn } from "@/lib/utils"; import { toast } from "sonner"; +import { Switch } from "./ui/switch"; interface GoogleConfigProps { googleApiKey: string; googleModel: string; - onInputChange: (value: string, field: string) => void; + webGrounding?: boolean; + onInputChange: (value: string | boolean, field: string) => void; } export default function GoogleConfig({ googleApiKey, googleModel, + webGrounding, onInputChange }: GoogleConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -61,7 +64,7 @@ export default function GoogleConfig({ const data = await response.json(); setAvailableModels(data); setModelsChecked(true); - onInputChange("models/gemini-2.0-flash", "google_model"); + onInputChange("models/gemini-2.5-flash", "google_model"); } else { console.error('Failed to fetch models'); setAvailableModels([]); @@ -205,6 +208,23 @@ export default function GoogleConfig({ ) : null} + + {/* Web Grounding Toggle - at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/components/Header.tsx b/servers/nextjs/components/Header.tsx index 13ebb205..ea324d3a 100644 --- a/servers/nextjs/components/Header.tsx +++ b/servers/nextjs/components/Header.tsx @@ -1,43 +1,27 @@ "use client"; -import Wrapper from "@/components/Wrapper"; import React from "react"; import Link from "next/link"; -import BackBtn from "@/components/BackBtn"; -import { usePathname } from "next/navigation"; -import HeaderNav from "@/app/(presentation-generator)/components/HeaderNab"; import { Layout } from "lucide-react"; -const Header = () => { - const pathname = usePathname(); + +const Header: React.FC = () => { return ( -
- -
-
- {pathname !== "/upload" && } - - Presentation logo - -
-
- +
+
+
+ + Presenton + + +
+
- -
+
+ ); }; diff --git a/servers/nextjs/components/LLMSelection.tsx b/servers/nextjs/components/LLMSelection.tsx index 422e8333..ed308226 100644 --- a/servers/nextjs/components/LLMSelection.tsx +++ b/servers/nextjs/components/LLMSelection.tsx @@ -149,6 +149,7 @@ export default function LLMProviderSelection({ @@ -158,6 +159,7 @@ export default function LLMProviderSelection({ @@ -168,6 +170,7 @@ export default function LLMProviderSelection({ anthropicApiKey={llmConfig.ANTHROPIC_API_KEY || ""} anthropicModel={llmConfig.ANTHROPIC_MODEL || ""} extendedReasoning={llmConfig.EXTENDED_REASONING || false} + webGrounding={llmConfig.WEB_GROUNDING || false} onInputChange={input_field_changed} /> diff --git a/servers/nextjs/components/OllamaConfig.tsx b/servers/nextjs/components/OllamaConfig.tsx index e29fa4e9..74f30987 100644 --- a/servers/nextjs/components/OllamaConfig.tsx +++ b/servers/nextjs/components/OllamaConfig.tsx @@ -19,7 +19,6 @@ interface OllamaModel { label: string; value: string; size: string; - icon: string; } interface OllamaConfigProps { @@ -128,19 +127,6 @@ export default function OllamaConfig({ className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between" >
- {ollamaModel && ( -
- m.value === ollamaModel - )?.icon - } - alt={`${ollamaModel} icon`} - className="rounded-sm" - /> -
- )} {ollamaModel ? ollamaModels?.find( @@ -189,13 +175,6 @@ export default function OllamaConfig({ )} />
-
- {`${model.label} -
diff --git a/servers/nextjs/components/OpenAIConfig.tsx b/servers/nextjs/components/OpenAIConfig.tsx index b73695e9..7c465a99 100644 --- a/servers/nextjs/components/OpenAIConfig.tsx +++ b/servers/nextjs/components/OpenAIConfig.tsx @@ -13,16 +13,19 @@ import { import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover"; import { cn } from "@/lib/utils"; import { toast } from "sonner"; +import { Switch } from "./ui/switch"; interface OpenAIConfigProps { openaiApiKey: string; openaiModel: string; - onInputChange: (value: string, field: string) => void; + webGrounding?: boolean; + onInputChange: (value: string | boolean, field: string) => void; } export default function OpenAIConfig({ openaiApiKey, openaiModel, + webGrounding, onInputChange }: OpenAIConfigProps) { const [openModelSelect, setOpenModelSelect] = useState(false); @@ -210,6 +213,23 @@ export default function OpenAIConfig({
) : null} + + {/* Web Grounding Toggle - show at the end, below models dropdown */} +
+
+ + onInputChange(checked, "web_grounding")} + /> +
+

+ + If enabled, the model can use web search grounding when available. +

+
); } \ No newline at end of file diff --git a/servers/nextjs/store/slices/presentationGeneration.ts b/servers/nextjs/store/slices/presentationGeneration.ts index 623496a7..67c03cbc 100644 --- a/servers/nextjs/store/slices/presentationGeneration.ts +++ b/servers/nextjs/store/slices/presentationGeneration.ts @@ -18,7 +18,7 @@ interface PresentationGenerationState { presentation_id: string | null; isLoading: boolean; isStreaming: boolean | null; - outlines: string[]; + outlines: { content: string }[]; error: string | null; presentationData: PresentationData | null; isSlidesRendered: boolean; @@ -72,7 +72,7 @@ const presentationGenerationSlice = createSlice({ state.outlines = []; }, // Set outlines - setOutlines: (state, action: PayloadAction) => { + setOutlines: (state, action: PayloadAction<{ content: string }[]>) => { state.outlines = action.payload; }, // Set presentation data diff --git a/servers/nextjs/types/llm_config.ts b/servers/nextjs/types/llm_config.ts index 0a44e639..5b73b215 100644 --- a/servers/nextjs/types/llm_config.ts +++ b/servers/nextjs/types/llm_config.ts @@ -31,6 +31,7 @@ export interface LLMConfig { TOOL_CALLS?: boolean; DISABLE_THINKING?: boolean; EXTENDED_REASONING?: boolean; + WEB_GROUNDING?: boolean; // Only used in UI settings USE_CUSTOM_URL?: boolean; diff --git a/servers/nextjs/utils/providerUtils.ts b/servers/nextjs/utils/providerUtils.ts index a0efb4f8..5e4dea0e 100644 --- a/servers/nextjs/utils/providerUtils.ts +++ b/servers/nextjs/utils/providerUtils.ts @@ -3,9 +3,7 @@ import { LLMConfig } from "@/types/llm_config"; export interface OllamaModel { label: string; value: string; - description: string; size: string; - icon: string; } export interface DownloadingModel { @@ -48,6 +46,7 @@ export const updateLLMConfig = ( tool_calls: "TOOL_CALLS", disable_thinking: "DISABLE_THINKING", extended_reasoning: "EXTENDED_REASONING", + web_grounding: "WEB_GROUNDING", }; const configKey = fieldMappings[field]; diff --git a/start.js b/start.js index 9db651ff..2e7533e5 100644 --- a/start.js +++ b/start.js @@ -81,6 +81,7 @@ const setupUserConfigFromEnv = () => { TOOL_CALLS: process.env.TOOL_CALLS || existingConfig.TOOL_CALLS, DISABLE_THINKING: process.env.DISABLE_THINKING || existingConfig.DISABLE_THINKING, EXTENDED_REASONING: process.env.EXTENDED_REASONING || existingConfig.EXTENDED_REASONING, + WEB_GROUNDING: process.env.WEB_GROUNDING || existingConfig.WEB_GROUNDING, USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL, };