From 087d64ed621ae2ee5e286bb14d12c587c904317f Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Wed, 8 Apr 2026 11:25:13 +0545 Subject: [PATCH] feat: implements new template generation flow; refactor: removes old template generation flow and adds new endpoints --- electron/.gitignore | 1 + electron/servers/fastapi/alembic/env.py | 1 + .../fastapi/api/v1/ppt/endpoints/fonts.py | 8 +- .../fastapi/api/v1/ppt/endpoints/layouts.py | 5 +- .../api/v1/ppt/endpoints/pdf_slides.py | 116 -- .../api/v1/ppt/endpoints/pptx_slides.py | 641 ----------- .../api/v1/ppt/endpoints/presentation.py | 7 +- .../api/v1/ppt/endpoints/slide_to_html.py | 1013 ----------------- electron/servers/fastapi/api/v1/ppt/router.py | 13 +- .../fastapi/models/sql/presentation.py | 2 +- .../models/sql/presentation_layout_code.py | 15 +- .../servers/fastapi/models/sql/template.py | 3 +- .../models/sql/template_create_info.py | 25 + electron/servers/fastapi/pyproject.toml | 2 +- electron/servers/fastapi/services/database.py | 6 +- .../servers/fastapi/templates/__init__.py | 1 + electron/servers/fastapi/templates/example.py | 98 ++ .../servers/fastapi/templates/font_utils.py | 167 +++ .../get_layout_by_name.py | 8 +- electron/servers/fastapi/templates/handler.py | 683 +++++++++++ .../fastapi/templates/pptx_html_stub.py | 30 + .../presentation_layout.py | 13 +- electron/servers/fastapi/templates/preview.py | 477 ++++++++ electron/servers/fastapi/templates/prompts.py | 219 ++++ .../servers/fastapi/templates/providers.py | 365 ++++++ electron/servers/fastapi/templates/router.py | 65 ++ .../tests/test_pptx_slides_processing.py | 140 --- .../tests/test_presentation_generation_api.py | 272 ++--- .../fastapi/tests/test_slide_to_html.py | 115 -- .../fastapi/tests/test_template_api.py | 479 ++++++++ .../fastapi/utils/llm_calls/edit_slide.py | 2 +- .../generate_presentation_structure.py | 2 +- .../utils/llm_calls/generate_slide_content.py | 2 +- .../llm_calls/select_slide_type_on_edit.py | 2 +- 
electron/servers/fastapi/utils/ppt_utils.py | 2 +- 35 files changed, 2756 insertions(+), 2244 deletions(-) delete mode 100644 electron/servers/fastapi/api/v1/ppt/endpoints/pdf_slides.py delete mode 100644 electron/servers/fastapi/api/v1/ppt/endpoints/pptx_slides.py delete mode 100644 electron/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py create mode 100644 electron/servers/fastapi/models/sql/template_create_info.py create mode 100644 electron/servers/fastapi/templates/__init__.py create mode 100644 electron/servers/fastapi/templates/example.py create mode 100644 electron/servers/fastapi/templates/font_utils.py rename electron/servers/fastapi/{utils => templates}/get_layout_by_name.py (81%) create mode 100644 electron/servers/fastapi/templates/handler.py create mode 100644 electron/servers/fastapi/templates/pptx_html_stub.py rename electron/servers/fastapi/{models => templates}/presentation_layout.py (78%) create mode 100644 electron/servers/fastapi/templates/preview.py create mode 100644 electron/servers/fastapi/templates/prompts.py create mode 100644 electron/servers/fastapi/templates/providers.py create mode 100644 electron/servers/fastapi/templates/router.py delete mode 100644 electron/servers/fastapi/tests/test_pptx_slides_processing.py delete mode 100644 electron/servers/fastapi/tests/test_slide_to_html.py create mode 100644 electron/servers/fastapi/tests/test_template_api.py diff --git a/electron/.gitignore b/electron/.gitignore index 588efb1a..65cc4836 100644 --- a/electron/.gitignore +++ b/electron/.gitignore @@ -10,6 +10,7 @@ app_data tmp debug .fastembed_cache +.codex generated_models nltk diff --git a/electron/servers/fastapi/alembic/env.py b/electron/servers/fastapi/alembic/env.py index 72da76b3..5524c68a 100644 --- a/electron/servers/fastapi/alembic/env.py +++ b/electron/servers/fastapi/alembic/env.py @@ -24,6 +24,7 @@ from models.sql.presentation_layout_code import ( # noqa: F401, E402 ) from models.sql.slide import SlideModel # noqa: F401, 
E402 from models.sql.template import TemplateModel # noqa: F401, E402 +from models.sql.template_create_info import TemplateCreateInfoModel # noqa: F401, E402 from models.sql.webhook_subscription import WebhookSubscription # noqa: F401, E402 alembic_config = context.config diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/fonts.py b/electron/servers/fastapi/api/v1/ppt/endpoints/fonts.py index a2a821cd..af6a9b8f 100644 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/fonts.py +++ b/electron/servers/fastapi/api/v1/ppt/endpoints/fonts.py @@ -1,12 +1,11 @@ import os import uuid import shutil -from pathlib import Path from typing import List, Dict, Any, Optional from fastapi import APIRouter, HTTPException, File, UploadFile from pydantic import BaseModel +from templates.preview import FontCheckResponse, check_fonts_in_pptx_handler from utils.asset_directory_utils import get_app_data_directory_env -import uuid try: from fontTools.ttLib import TTFont @@ -287,6 +286,9 @@ async def get_uploaded_fonts(): ) +FONTS_ROUTER.post("/check", response_model=FontCheckResponse)(check_fonts_in_pptx_handler) + + @FONTS_ROUTER.delete("/delete/{filename}") async def delete_font(filename: str): """ @@ -330,4 +332,4 @@ async def delete_font(filename: str): raise HTTPException( status_code=500, detail=f"Error deleting font: {str(e)}" - ) \ No newline at end of file + ) diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/layouts.py b/electron/servers/fastapi/api/v1/ppt/endpoints/layouts.py index 6e1051bc..c851175e 100644 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/layouts.py +++ b/electron/servers/fastapi/api/v1/ppt/endpoints/layouts.py @@ -1,8 +1,7 @@ from fastapi import APIRouter, HTTPException import aiohttp -from typing import List, Any -from utils.get_layout_by_name import get_layout_by_name -from models.presentation_layout import PresentationLayoutModel +from templates.get_layout_by_name import get_layout_by_name +from templates.presentation_layout import 
PresentationLayoutModel LAYOUTS_ROUTER = APIRouter(prefix="/layouts", tags=["Layouts"]) diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/pdf_slides.py b/electron/servers/fastapi/api/v1/ppt/endpoints/pdf_slides.py deleted file mode 100644 index 606cb12f..00000000 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/pdf_slides.py +++ /dev/null @@ -1,116 +0,0 @@ -import os -import shutil -import tempfile -import subprocess -from typing import List, Optional -from fastapi import APIRouter, UploadFile, File, HTTPException -from pydantic import BaseModel - -from services.documents_loader import DocumentsLoader -from utils.asset_directory_utils import get_images_directory -import uuid -from constants.documents import PDF_MIME_TYPES - - -PDF_SLIDES_ROUTER = APIRouter(prefix="/pdf-slides", tags=["PDF Slides"]) - - -class PdfSlideData(BaseModel): - slide_number: int - screenshot_url: str - - -class PdfSlidesResponse(BaseModel): - success: bool - slides: List[PdfSlideData] - total_slides: int - - -@PDF_SLIDES_ROUTER.post("/process", response_model=PdfSlidesResponse) -async def process_pdf_slides( - pdf_file: UploadFile = File(..., description="PDF file to process") -): - """ - Process a PDF file to extract slide screenshots. - - This endpoint: - 1. Validates the uploaded PDF file - 2. Uses ImageMagick to convert PDF pages to PNG images - 3. Returns screenshot URLs for each slide/page - - Note: Font installation is not needed since PDFs already have fonts embedded. - """ - - # Validate PDF file - if pdf_file.content_type not in PDF_MIME_TYPES: - raise HTTPException( - status_code=400, - detail=f"Invalid file type. 
Expected PDF file, got {pdf_file.content_type}", - ) - # Enforce 100MB size limit - if ( - hasattr(pdf_file, "size") - and pdf_file.size - and pdf_file.size > (100 * 1024 * 1024) - ): - raise HTTPException( - status_code=400, - detail="PDF file exceeded max upload size of 100 MB", - ) - - # Create temporary directory for processing - with tempfile.TemporaryDirectory() as temp_dir: - try: - # Save uploaded PDF file - pdf_path = os.path.join(temp_dir, "presentation.pdf") - with open(pdf_path, "wb") as f: - pdf_content = await pdf_file.read() - f.write(pdf_content) - - # Generate screenshots from PDF using ImageMagick - screenshot_paths = await DocumentsLoader.get_page_images_from_pdf_async( - pdf_path, temp_dir - ) - print(f"Generated {len(screenshot_paths)} PDF screenshots") - - # Move screenshots to images directory and generate URLs - images_dir = get_images_directory() - presentation_id = uuid.uuid4() - presentation_images_dir = os.path.join(images_dir, str(presentation_id)) - os.makedirs(presentation_images_dir, exist_ok=True) - - slides_data = [] - - for i, screenshot_path in enumerate(screenshot_paths, 1): - # Move screenshot to permanent location - screenshot_filename = f"slide_{i}.png" - permanent_screenshot_path = os.path.join( - presentation_images_dir, screenshot_filename - ) - - if ( - os.path.exists(screenshot_path) - and os.path.getsize(screenshot_path) > 0 - ): - # Use shutil.copy2 instead of os.rename to handle cross-device moves - shutil.copy2(screenshot_path, permanent_screenshot_path) - screenshot_url = ( - f"/app_data/images/{presentation_id}/{screenshot_filename}" - ) - else: - # Fallback if screenshot generation failed or file is empty placeholder - screenshot_url = "/static/images/placeholder.jpg" - - slides_data.append( - PdfSlideData(slide_number=i, screenshot_url=screenshot_url) - ) - - return PdfSlidesResponse( - success=True, slides=slides_data, total_slides=len(slides_data) - ) - - except Exception as e: - print(f"Error processing PDF 
slides: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Failed to process PDF: {str(e)}" - ) diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/pptx_slides.py b/electron/servers/fastapi/api/v1/ppt/endpoints/pptx_slides.py deleted file mode 100644 index 65200b7f..00000000 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/pptx_slides.py +++ /dev/null @@ -1,641 +0,0 @@ -import os -import shutil -import zipfile -import tempfile -import subprocess -import uuid -from typing import List, Optional, Dict -from fastapi import APIRouter, UploadFile, File, HTTPException -from pydantic import BaseModel -import aiohttp -import asyncio -import xml.etree.ElementTree as ET -import re - -from services.documents_loader import DocumentsLoader -from utils.asset_directory_utils import get_images_directory -import uuid -from constants.documents import PPTX_MIME_TYPES - - -def _get_soffice_binary() -> str: - """Return the soffice binary to use for LibreOffice subprocess calls. - - When running inside the Electron desktop app, the main process resolves the - exact soffice binary path at startup and forwards it via the ``SOFFICE_PATH`` - environment variable. Falling back to the bare ``"soffice"`` command keeps - Docker / server deployments working unchanged. 
- """ - configured = os.environ.get("SOFFICE_PATH") - if configured: - return configured - return "soffice.exe" if os.name == "nt" else "soffice" - - -def _windows_hidden_subprocess_kwargs() -> Dict[str, object]: - """Return subprocess kwargs that suppress Windows console windows.""" - if os.name != "nt": - return {} - - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - return { - "creationflags": getattr(subprocess, "CREATE_NO_WINDOW", 0), - "startupinfo": startupinfo, - } - - -PPTX_SLIDES_ROUTER = APIRouter(prefix="/pptx-slides", tags=["PPTX Slides"]) - - -class SlideData(BaseModel): - slide_number: int - screenshot_url: str - xml_content: str - normalized_fonts: List[str] - - -class FontAnalysisResult(BaseModel): - internally_supported_fonts: List[ - Dict[str, str] - ] # [{"name": "Open Sans", "google_fonts_url": "..."}] - not_supported_fonts: List[str] # ["Custom Font Name"] - - -class PptxSlidesResponse(BaseModel): - success: bool - slides: List[SlideData] - total_slides: int - fonts: Optional[FontAnalysisResult] = None - - -# NEW: Fonts-only router and response for PPTX -class PptxFontsResponse(BaseModel): - success: bool - fonts: FontAnalysisResult - - -PPTX_FONTS_ROUTER = APIRouter(prefix="/pptx-fonts", tags=["PPTX Fonts"]) - -# NEW: Normalize font family names by removing style/weight/stretch descriptors and splitting camel case -_STYLE_TOKENS = { - # styles - "italic", - "italics", - "ital", - "oblique", - "roman", - # combined style shortcuts - "bolditalic", - "bolditalics", - # weights - "thin", - "hairline", - "extralight", - "ultralight", - "light", - "demilight", - "semilight", - "book", - "regular", - "normal", - "medium", - "semibold", - "demibold", - "bold", - "extrabold", - "ultrabold", - "black", - "extrablack", - "ultrablack", - "heavy", - # width/stretch - "narrow", - "condensed", - "semicondensed", - "extracondensed", - "ultracondensed", - "expanded", - "semiexpanded", - "extraexpanded", - 
"ultraexpanded", -} -# Modifiers commonly used with style tokens -_STYLE_MODIFIERS = {"semi", "demi", "extra", "ultra"} - - -def _insert_spaces_in_camel_case(value: str) -> str: - # Insert space before capital letters preceded by lowercase or digits (e.g., MontserratBold -> Montserrat Bold) - value = re.sub(r"(?<=[a-z0-9])([A-Z])", r" \1", value) - # Handle sequences like BoldItalic -> Bold Italic - value = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", value) - return value - - -def normalize_font_family_name(raw_name: str) -> str: - if not raw_name: - return raw_name - # Replace separators with spaces - name = raw_name.replace("_", " ").replace("-", " ") - # Insert spaces in camel case - name = _insert_spaces_in_camel_case(name) - # Collapse multiple spaces - name = re.sub(r"\s+", " ", name).strip() - # Lowercase helper for matching but keep original casing for output - lower_name = name.lower() - # Quick cut: if the full string ends with a pure style suffix, trim it - for style in sorted(_STYLE_TOKENS, key=len, reverse=True): - if lower_name.endswith(" " + style): - name = name[: -(len(style) + 1)] - lower_name = lower_name[: -(len(style) + 1)] - break - # Tokenize - tokens_original = name.split(" ") - tokens_filtered: List[str] = [] - for index, tok in enumerate(tokens_original): - lower_tok = tok.lower() - # Always keep the first token to avoid stripping families like "Black Ops One" - if index == 0: - tokens_filtered.append(tok) - continue - # Drop style tokens and standalone modifiers - if lower_tok in _STYLE_TOKENS or lower_tok in _STYLE_MODIFIERS: - continue - tokens_filtered.append(tok) - # If everything except first token was dropped and first token is a style token (unlikely), fallback to original - if not tokens_filtered: - tokens_filtered = tokens_original - normalized = " ".join(tokens_filtered).strip() - # Final cleanup of leftover multiple spaces - normalized = re.sub(r"\s+", " ", normalized) - return normalized - - -def 
extract_fonts_from_oxml(xml_content: str) -> List[str]: - """ - Extract font names from OXML content. - - Args: - xml_content: OXML content as string - - Returns: - List of unique font names found in the OXML - """ - fonts = set() - - try: - # Parse the XML content - root = ET.fromstring(xml_content) - - # Define namespaces commonly used in OXML - namespaces = { - "a": "http://schemas.openxmlformats.org/drawingml/2006/main", - "p": "http://schemas.openxmlformats.org/presentationml/2006/main", - "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships", - } - - # Search for font references in various OXML elements - # Look for latin fonts - for font_elem in root.findall(".//a:latin", namespaces): - if "typeface" in font_elem.attrib: - fonts.add(font_elem.attrib["typeface"]) - - # Look for east asian fonts - for font_elem in root.findall(".//a:ea", namespaces): - if "typeface" in font_elem.attrib: - fonts.add(font_elem.attrib["typeface"]) - - # Look for complex script fonts - for font_elem in root.findall(".//a:cs", namespaces): - if "typeface" in font_elem.attrib: - fonts.add(font_elem.attrib["typeface"]) - - # Look for font references in theme elements - for font_elem in root.findall(".//a:font", namespaces): - if "typeface" in font_elem.attrib: - fonts.add(font_elem.attrib["typeface"]) - - # Look for rPr (run properties) font references - for rpr_elem in root.findall(".//a:rPr", namespaces): - for font_elem in rpr_elem.findall(".//a:latin", namespaces): - if "typeface" in font_elem.attrib: - fonts.add(font_elem.attrib["typeface"]) - - # Also search without namespace prefix for compatibility - for font_elem in root.findall(".//latin"): - if "typeface" in font_elem.attrib: - fonts.add(font_elem.attrib["typeface"]) - - # Regex fallback for fonts that might be missed - font_pattern = r'typeface="([^"]+)"' - regex_fonts = re.findall(font_pattern, xml_content) - fonts.update(regex_fonts) - - # Filter out system fonts and empty values - system_fonts = 
{"+mn-lt", "+mj-lt", "+mn-ea", "+mj-ea", "+mn-cs", "+mj-cs", ""} - fonts = {font for font in fonts if font not in system_fonts and font.strip()} - - return list(fonts) - - except Exception as e: - print(f"Error extracting fonts from OXML: {e}") - return [] - - -async def check_google_font_availability(font_name: str) -> bool: - """ - Check if a font is available in Google Fonts. - - Args: - font_name: Name of the font to check - - Returns: - True if font is available in Google Fonts, False otherwise - """ - try: - formatted_name = font_name.replace(" ", "+") - url = f"https://fonts.googleapis.com/css2?family={formatted_name}&display=swap" - - async with aiohttp.ClientSession() as session: - async with session.head( - url, timeout=aiohttp.ClientTimeout(total=10) - ) as response: - return response.status == 200 - - except Exception as e: - print(f"Error checking Google Font availability for {font_name}: {e}") - return False - - -async def analyze_fonts_in_all_slides(slide_xmls: List[str]) -> FontAnalysisResult: - """ - Analyze fonts across all slides and determine Google Fonts availability. 
- - Args: - slide_xmls: List of OXML content strings from all slides - - Returns: - FontAnalysisResult with supported and unsupported fonts - """ - # Extract fonts from all slides - raw_fonts = set() - for xml_content in slide_xmls: - slide_fonts = extract_fonts_from_oxml(xml_content) - raw_fonts.update(slide_fonts) - - # Normalize to root families (e.g., "Montserrat Italic" -> "Montserrat") - normalized_fonts = {normalize_font_family_name(f) for f in raw_fonts} - # Remove empties if any - normalized_fonts = {f for f in normalized_fonts if f} - - if not normalized_fonts: - return FontAnalysisResult(internally_supported_fonts=[], not_supported_fonts=[]) - - # Check each normalized font's availability in Google Fonts concurrently - tasks = [check_google_font_availability(font) for font in normalized_fonts] - results = await asyncio.gather(*tasks) - - internally_supported_fonts = [] - not_supported_fonts = [] - - for font, is_available in zip(normalized_fonts, results): - if is_available: - formatted_name = font.replace(" ", "+") - google_fonts_url = f"https://fonts.googleapis.com/css2?family={formatted_name}&display=swap" - internally_supported_fonts.append( - {"name": font, "google_fonts_url": google_fonts_url} - ) - else: - not_supported_fonts.append(font) - - return FontAnalysisResult( - internally_supported_fonts=internally_supported_fonts, not_supported_fonts=[] - ) - - -@PPTX_SLIDES_ROUTER.post("/process", response_model=PptxSlidesResponse) -async def process_pptx_slides( - pptx_file: UploadFile = File(..., description="PPTX file to process"), - fonts: Optional[List[UploadFile]] = File(None, description="Optional font files"), -): - """ - Process a PPTX file to extract slide screenshots and XML content. - - This endpoint: - 1. Validates the uploaded PPTX file - 2. Installs any provided font files - 3. Unzips the PPTX to extract slide XMLs - 4. Uses LibreOffice to generate slide screenshots - 5. 
Returns both screenshot URLs and XML content for each slide - """ - - # Validate PPTX file - if pptx_file.content_type not in PPTX_MIME_TYPES: - raise HTTPException( - status_code=400, - detail=f"Invalid file type. Expected PPTX file, got {pptx_file.content_type}", - ) - # Enforce 100MB size limit - if ( - hasattr(pptx_file, "size") - and pptx_file.size - and pptx_file.size > (100 * 1024 * 1024) - ): - raise HTTPException( - status_code=400, - detail="PPTX file exceeded max upload size of 100 MB", - ) - - # Create temporary directory for processing - with tempfile.TemporaryDirectory() as temp_dir: - if True: - # Save uploaded PPTX file - pptx_path = os.path.join(temp_dir, "presentation.pptx") - with open(pptx_path, "wb") as f: - pptx_content = await pptx_file.read() - f.write(pptx_content) - - # Install fonts if provided - if fonts: - await _install_fonts(fonts, temp_dir) - - # Extract slide XMLs from PPTX - slide_xmls = _extract_slide_xmls(pptx_path, temp_dir) - - # Convert PPTX to PDF - pdf_path = await _convert_pptx_to_pdf(pptx_path, temp_dir) - - # Generate screenshots using LibreOffice - screenshot_paths = await DocumentsLoader.get_page_images_from_pdf_async( - pdf_path, temp_dir - ) - print(f"Screenshot paths: {screenshot_paths}") - - # Analyze fonts across all slides - font_analysis = await analyze_fonts_in_all_slides(slide_xmls) - print( - f"Font analysis completed: {len(font_analysis.internally_supported_fonts)} supported, {len(font_analysis.not_supported_fonts)} not supported" - ) - - # Move screenshots to images directory and generate URLs - images_dir = get_images_directory() - presentation_id = uuid.uuid4() - presentation_images_dir = os.path.join(images_dir, str(presentation_id)) - os.makedirs(presentation_images_dir, exist_ok=True) - - slides_data = [] - - for i, (xml_content, screenshot_path) in enumerate( - zip(slide_xmls, screenshot_paths), 1 - ): - # Move screenshot to permanent location - screenshot_filename = f"slide_{i}.png" - 
permanent_screenshot_path = os.path.join( - presentation_images_dir, screenshot_filename - ) - - if ( - os.path.exists(screenshot_path) - and os.path.getsize(screenshot_path) > 0 - ): - # Use shutil.copy2 instead of os.rename to handle cross-device moves - shutil.copy2(screenshot_path, permanent_screenshot_path) - screenshot_url = ( - f"/app_data/images/{presentation_id}/{screenshot_filename}" - ) - else: - # Fallback if screenshot generation failed or file is empty placeholder - screenshot_url = "/static/images/placeholder.jpg" - - # Compute normalized fonts for this slide - raw_slide_fonts = extract_fonts_from_oxml(xml_content) - normalized_fonts = sorted( - {normalize_font_family_name(f) for f in raw_slide_fonts if f} - ) - - slides_data.append( - SlideData( - slide_number=i, - screenshot_url=screenshot_url, - xml_content=xml_content, - normalized_fonts=normalized_fonts, - ) - ) - - return PptxSlidesResponse( - success=True, - slides=slides_data, - total_slides=len(slides_data), - fonts=font_analysis, - ) - - -# NEW: Fonts-only endpoint leveraging the same font extraction/analysis -@PPTX_FONTS_ROUTER.post("/process", response_model=PptxFontsResponse) -async def process_pptx_fonts( - pptx_file: UploadFile = File(..., description="PPTX file to analyze fonts from") -): - """ - Analyze a PPTX file and return only the fonts used in the document. - - Uses the exact same font extraction and analysis utilities as the /pptx-slides endpoint. - """ - # Validate PPTX file - if pptx_file.content_type not in PPTX_MIME_TYPES: - raise HTTPException( - status_code=400, - detail=f"Invalid file type. 
Expected PPTX file, got {pptx_file.content_type}", - ) - - # Create temporary directory for processing - with tempfile.TemporaryDirectory() as temp_dir: - # Save uploaded PPTX file - pptx_path = os.path.join(temp_dir, "presentation.pptx") - with open(pptx_path, "wb") as f: - pptx_content = await pptx_file.read() - f.write(pptx_content) - - # Extract slide XMLs from PPTX - slide_xmls = _extract_slide_xmls(pptx_path, temp_dir) - - # Analyze fonts across all slides (same logic as in /pptx-slides) - font_analysis = await analyze_fonts_in_all_slides(slide_xmls) - - return PptxFontsResponse( - success=True, - fonts=font_analysis, - ) - - -def _create_font_alias_config(raw_fonts: List[str]) -> str: - """Create a temporary fontconfig configuration that aliases variant family names to normalized root families. - Returns the path to the config file. - """ - # Build mapping from raw -> normalized where different - mappings: Dict[str, str] = {} - for f in raw_fonts: - normalized = normalize_font_family_name(f) - if normalized and normalized != f: - mappings[f] = normalized - # Create config only if we have mappings - fd, fonts_conf_path = tempfile.mkstemp(prefix="fonts_alias_", suffix=".conf") - os.close(fd) - with open(fonts_conf_path, "w", encoding="utf-8") as cfg: - cfg.write( - """ - - - /etc/fonts/fonts.conf -""" - ) - for src, dst in mappings.items(): - cfg.write( - f""" - - - {src} - - - {dst} - - -""" - ) - cfg.write("\n\n") - return fonts_conf_path - - -async def _install_fonts(fonts: List[UploadFile], temp_dir: str) -> None: - """Install provided font files to the system.""" - fonts_dir = os.path.join(temp_dir, "fonts") - os.makedirs(fonts_dir, exist_ok=True) - - for font_file in fonts: - # Save font file - font_path = os.path.join(fonts_dir, font_file.filename) - with open(font_path, "wb") as f: - font_content = await font_file.read() - f.write(font_content) - - # Install font (copy to system fonts directory) - try: - subprocess.run( - ["cp", font_path, 
"/usr/share/fonts/truetype/"], - check=True, - capture_output=True, - ) - except subprocess.CalledProcessError as e: - print(f"Warning: Failed to install font {font_file.filename}: {e}") - - # Refresh font cache - try: - subprocess.run(["fc-cache", "-f", "-v"], check=True, capture_output=True) - except subprocess.CalledProcessError as e: - print(f"Warning: Failed to refresh font cache: {e}") - - -def _extract_slide_xmls(pptx_path: str, temp_dir: str) -> List[str]: - """Extract slide XML content from PPTX file.""" - slide_xmls = [] - extract_dir = os.path.join(temp_dir, "pptx_extract") - - try: - # Unzip PPTX file - with zipfile.ZipFile(pptx_path, "r") as zip_ref: - zip_ref.extractall(extract_dir) - - # Look for slides in ppt/slides/ directory - slides_dir = os.path.join(extract_dir, "ppt", "slides") - - if not os.path.exists(slides_dir): - raise Exception("No slides directory found in PPTX file") - - # Get all slide XML files and sort them numerically - slide_files = [ - f - for f in os.listdir(slides_dir) - if f.startswith("slide") and f.endswith(".xml") - ] - slide_files.sort(key=lambda x: int(x.replace("slide", "").replace(".xml", ""))) - - # Read XML content from each slide - for slide_file in slide_files: - slide_path = os.path.join(slides_dir, slide_file) - with open(slide_path, "r", encoding="utf-8") as f: - slide_xmls.append(f.read()) - - return slide_xmls - - except Exception as e: - raise Exception(f"Failed to extract slide XMLs: {str(e)}") - - -async def _convert_pptx_to_pdf(pptx_path: str, temp_dir: str) -> str: - """Generate PNG screenshots of PPTX slides using LibreOffice + ImageMagick.""" - screenshots_dir = os.path.join(temp_dir, "screenshots") - os.makedirs(screenshots_dir, exist_ok=True) - - try: - # First, get the number of slides by extracting XMLs - slide_xmls = _extract_slide_xmls(pptx_path, temp_dir) - slide_count = len(slide_xmls) - - # Build font alias config to force variant families to resolve to normalized root families - raw_fonts: 
List[str] = [] - for xml in slide_xmls: - raw_fonts.extend(extract_fonts_from_oxml(xml)) - raw_fonts = list({f for f in raw_fonts if f}) - fonts_conf_path = _create_font_alias_config(raw_fonts) - env = os.environ.copy() - env["FONTCONFIG_FILE"] = fonts_conf_path - - print(f"Found {slide_count} slides in presentation") - - # Step 1: Convert PPTX to PDF using LibreOffice - print("Starting LibreOffice PDF conversion...") - pdf_filename = "temp_presentation.pdf" - pdf_path = os.path.join(screenshots_dir, pdf_filename) - - try: - result = subprocess.run( - [ - _get_soffice_binary(), - "--headless", - "--convert-to", - "pdf", - "--outdir", - screenshots_dir, - pptx_path, - ], - check=True, - capture_output=True, - text=True, - timeout=500, - env=env, - **_windows_hidden_subprocess_kwargs(), - ) - - print(f"LibreOffice PDF conversion output: {result.stdout}") - if result.stderr: - print(f"LibreOffice PDF conversion warnings: {result.stderr}") - except subprocess.TimeoutExpired: - raise Exception("LibreOffice PDF conversion timed out after 120 seconds") - except subprocess.CalledProcessError as e: - error_msg = e.stderr if e.stderr else str(e) - raise Exception(f"LibreOffice PDF conversion failed: {error_msg}") - - # Find the generated PDF file (LibreOffice uses original filename) - pdf_files = [f for f in os.listdir(screenshots_dir) if f.endswith(".pdf")] - if not pdf_files: - raise Exception("LibreOffice failed to generate PDF file") - - actual_pdf_path = os.path.join(screenshots_dir, pdf_files[0]) - print(f"Generated PDF: {actual_pdf_path}") - return actual_pdf_path - - except Exception as e: - # Re-raise the specific exceptions we've already handled - if "timed out" in str(e) or "failed:" in str(e): - raise - # Handle any other unexpected exceptions - raise Exception(f"Screenshot generation failed: {str(e)}") diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/electron/servers/fastapi/api/v1/ppt/endpoints/presentation.py index 
1b596f9a..ec147dce 100644 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/presentation.py +++ b/electron/servers/fastapi/api/v1/ppt/endpoints/presentation.py @@ -24,16 +24,13 @@ from models.presentation_outline_model import ( from enums.tone import Tone from enums.verbosity import Verbosity from models.pptx_models import PptxPresentationModel -from models.presentation_layout import PresentationLayoutModel from models.presentation_structure_model import PresentationStructureModel from models.presentation_with_slides import ( PresentationWithSlides, ) from models.sql.template import TemplateModel - from services.documents_loader import DocumentsLoader from services.webhook_service import WebhookService -from utils.get_layout_by_name import get_layout_by_name from services.image_generation_service import ImageGenerationService from utils.dict_utils import deep_update from utils.export_utils import export_presentation @@ -70,6 +67,8 @@ from utils.process_slides import ( process_slide_add_placeholder_assets, process_slide_and_fetch_assets, ) +from templates.get_layout_by_name import get_layout_by_name +from templates.presentation_layout import PresentationLayoutModel import uuid @@ -881,6 +880,8 @@ async def generate_presentation_sync( return await generate_presentation_handler( request, presentation_id, None, sql_session ) + except HTTPException: + raise except Exception: traceback.print_exc() raise HTTPException(status_code=500, detail="Presentation generation failed") diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py b/electron/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py deleted file mode 100644 index 00db8019..00000000 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/slide_to_html.py +++ /dev/null @@ -1,1013 +0,0 @@ -import os -import base64 -from datetime import datetime -from typing import Optional, List, Dict -from uuid import UUID -from fastapi import APIRouter, HTTPException, File, UploadFile, Form, Depends -from 
pydantic import BaseModel -from openai import OpenAI -from openai import APIError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete, func -from utils.asset_directory_utils import get_images_directory, resolve_image_path_to_filesystem -from services.database import get_async_session -from models.sql.presentation_layout_code import PresentationLayoutCodeModel -from .prompts import ( - GENERATE_HTML_SYSTEM_PROMPT, - HTML_TO_REACT_SYSTEM_PROMPT, - HTML_EDIT_SYSTEM_PROMPT, -) -from models.sql.template import TemplateModel - - -# Create separate routers for each functionality -SLIDE_TO_HTML_ROUTER = APIRouter(prefix="/slide-to-html", tags=["slide-to-html"]) -HTML_TO_REACT_ROUTER = APIRouter(prefix="/html-to-react", tags=["html-to-react"]) -HTML_EDIT_ROUTER = APIRouter(prefix="/html-edit", tags=["html-edit"]) -LAYOUT_MANAGEMENT_ROUTER = APIRouter( - prefix="/template-management", tags=["template-management"] -) - - -# Request/Response models for slide-to-html endpoint -class SlideToHtmlRequest(BaseModel): - image: str # Partial path to image file (e.g., "/app_data/images/uuid/slide_1.png") - xml: str # OXML content as text - fonts: Optional[List[str]] = None # Optional normalized root fonts for this slide - - -class SlideToHtmlResponse(BaseModel): - success: bool - html: str - - -# Request/Response models for html-edit endpoint -class HtmlEditResponse(BaseModel): - success: bool - edited_html: str - message: Optional[str] = None - - -# Request/Response models for html-to-react endpoint -class HtmlToReactRequest(BaseModel): - html: str # HTML content to convert to React component - image: Optional[str] = None # Optional image path to provide visual context - - -class HtmlToReactResponse(BaseModel): - success: bool - react_component: str - message: Optional[str] = None - - -# Request/Response models for layout management endpoints -class LayoutData(BaseModel): - presentation: UUID # UUID of the presentation - layout_id: str # Unique 
identifier for the layout - layout_name: str # Display name of the layout - layout_code: str # TSX/React component code for the layout - fonts: Optional[List[str]] = None # Optional list of font links - - -class SaveLayoutsRequest(BaseModel): - layouts: list[LayoutData] - - -class SaveLayoutsResponse(BaseModel): - success: bool - saved_count: int - message: Optional[str] = None - - -class GetLayoutsResponse(BaseModel): - success: bool - layouts: list[LayoutData] - message: Optional[str] = None - template: Optional[dict] = None - fonts: Optional[List[str]] = None - - -class PresentationSummary(BaseModel): - presentation_id: UUID - layout_count: int - last_updated_at: Optional[datetime] = None - template: Optional[dict] = None - - -class GetPresentationSummaryResponse(BaseModel): - success: bool - presentations: List[PresentationSummary] - total_presentations: int - total_layouts: int - message: Optional[str] = None - - -class ErrorResponse(BaseModel): - success: bool = False - detail: str - error_code: Optional[str] = None - - -class TemplateCreateRequest(BaseModel): - id: UUID - name: str - description: Optional[str] = None - - -class TemplateCreateResponse(BaseModel): - success: bool - template: dict - message: Optional[str] = None - - -class TemplateInfo(BaseModel): - id: UUID - name: Optional[str] = None - description: Optional[str] = None - created_at: Optional[datetime] = None - - -async def generate_html_from_slide( - base64_image: str, - media_type: str, - xml_content: str, - api_key: str, - fonts: Optional[List[str]] = None, -) -> str: - """ - Generate HTML content from slide image and XML using OpenAI GPT-5 Responses API. 
- - Args: - base64_image: Base64 encoded image data - media_type: MIME type of the image (e.g., 'image/png') - xml_content: OXML content as text - api_key: OpenAI API key - fonts: Optional list of normalized root font families to prefer in output - - Returns: - Generated HTML content as string - - Raises: - HTTPException: If API call fails or no content is generated - """ - print( - f"Generating HTML from slide image and XML using OpenAI GPT-5 Responses API..." - ) - try: - client = OpenAI(api_key=api_key) - - # Compose input for Responses API. Include system prompt, image (separate), OXML and optional fonts text. - data_url = f"data:{media_type};base64,{base64_image}" - fonts_text = ( - f"\nFONTS (Normalized root families used in this slide, use where it is required): {', '.join(fonts)}" - if fonts - else "" - ) - user_text = f"OXML: \n\n{fonts_text}" - input_payload = [ - {"role": "system", "content": GENERATE_HTML_SYSTEM_PROMPT}, - { - "role": "user", - "content": [ - {"type": "input_image", "image_url": data_url}, - {"type": "input_text", "text": user_text}, - ], - }, - ] - - print("Making Responses API request for HTML generation...") - response = client.responses.create( - model="gpt-5", - input=input_payload, - reasoning={"effort": "high"}, - text={"verbosity": "low"}, - ) - - # Extract the response text - html_content = ( - getattr(response, "output_text", None) - or getattr(response, "text", None) - or "" - ) - - print(f"Received HTML content length: {len(html_content)}") - - if not html_content: - raise HTTPException( - status_code=500, detail="No HTML content generated by OpenAI GPT-5" - ) - - return html_content - - except APIError as e: - print(f"OpenAI API Error: {e}") - raise HTTPException( - status_code=500, detail=f"OpenAI API error during HTML generation: {str(e)}" - ) - except Exception as e: - # Handle various API errors - error_msg = str(e) - print(f"Exception occurred: {error_msg}") - print(f"Exception type: {type(e)}") - if "timeout" in 
error_msg.lower(): - raise HTTPException( - status_code=408, - detail=f"OpenAI API timeout during HTML generation: {error_msg}", - ) - elif "connection" in error_msg.lower(): - raise HTTPException( - status_code=503, - detail=f"OpenAI API connection error during HTML generation: {error_msg}", - ) - else: - raise HTTPException( - status_code=500, - detail=f"OpenAI API error during HTML generation: {error_msg}", - ) - - -async def generate_react_component_from_html( - html_content: str, - api_key: str, - image_base64: Optional[str] = None, - media_type: Optional[str] = None, -) -> str: - """ - Convert HTML content to TSX React component using OpenAI GPT-5 Responses API. - - Args: - html_content: Generated HTML content - api_key: OpenAI API key - - Returns: - Generated TSX React component code as string - - Raises: - HTTPException: If API call fails or no content is generated - """ - try: - client = OpenAI(api_key=api_key) - - print("Making Responses API request for React component generation...") - - # Build payload with optional image - content_parts = [{"type": "input_text", "text": f"HTML INPUT:\n{html_content}"}] - if image_base64 and media_type: - data_url = f"data:{media_type};base64,{image_base64}" - content_parts.insert(0, {"type": "input_image", "image_url": data_url}) - - input_payload = [ - {"role": "system", "content": HTML_TO_REACT_SYSTEM_PROMPT}, - {"role": "user", "content": content_parts}, - ] - - response = client.responses.create( - model="gpt-5", - input=input_payload, - reasoning={"effort": "minimal"}, - text={"verbosity": "low"}, - ) - - react_content = ( - getattr(response, "output_text", None) - or getattr(response, "text", None) - or "" - ) - - print(f"Received React content length: {len(react_content)}") - - if not react_content: - raise HTTPException( - status_code=500, detail="No React component generated by OpenAI GPT-5" - ) - - react_content = ( - react_content.replace("```tsx", "") - .replace("```", "") - .replace("typescript", "") - 
.replace("javascript", "") - ) - - # Filter out lines that start with import or export - filtered_lines = [] - for line in react_content.split("\n"): - stripped_line = line.strip() - if not ( - stripped_line.startswith("import ") - or stripped_line.startswith("export ") - ): - filtered_lines.append(line) - - filtered_react_content = "\n".join(filtered_lines) - print(f"Filtered React content length: {len(filtered_react_content)}") - - return filtered_react_content - except APIError as e: - print(f"OpenAI API Error: {e}") - raise HTTPException( - status_code=500, - detail=f"OpenAI API error during React generation: {str(e)}", - ) - except Exception as e: - # Handle various API errors - error_msg = str(e) - print(f"Exception occurred: {error_msg}") - print(f"Exception type: {type(e)}") - if "timeout" in error_msg.lower(): - raise HTTPException( - status_code=408, - detail=f"OpenAI API timeout during React generation: {error_msg}", - ) - elif "connection" in error_msg.lower(): - raise HTTPException( - status_code=503, - detail=f"OpenAI API connection error during React generation: {error_msg}", - ) - else: - raise HTTPException( - status_code=500, - detail=f"OpenAI API error during React generation: {error_msg}", - ) - - -async def edit_html_with_images( - current_ui_base64: str, - sketch_base64: Optional[str], - media_type: str, - html_content: str, - prompt: str, - api_key: str, -) -> str: - """ - Edit HTML content based on one or two images and a text prompt using OpenAI GPT-5 Responses API. 
- - Args: - current_ui_base64: Base64 encoded current UI image data - sketch_base64: Base64 encoded sketch/indication image data (optional) - media_type: MIME type of the images (e.g., 'image/png') - html_content: Current HTML content to edit - prompt: Text prompt describing the changes - api_key: OpenAI API key - - Returns: - Edited HTML content as string - - Raises: - HTTPException: If API call fails or no content is generated - """ - try: - client = OpenAI(api_key=api_key) - - print("Making Responses API request for HTML editing...") - - current_data_url = f"data:{media_type};base64,{current_ui_base64}" - sketch_data_url = ( - f"data:{media_type};base64,{sketch_base64}" if sketch_base64 else None - ) - - content_parts = [ - {"type": "input_image", "image_url": current_data_url}, - { - "type": "input_text", - "text": f"CURRENT HTML TO EDIT:\n{html_content}\n\nTEXT PROMPT FOR CHANGES:\n{prompt}", - }, - ] - if sketch_data_url: - # Insert sketch image after current UI image for context - content_parts.insert( - 1, {"type": "input_image", "image_url": sketch_data_url} - ) - - input_payload = [ - {"role": "system", "content": HTML_EDIT_SYSTEM_PROMPT}, - {"role": "user", "content": content_parts}, - ] - - response = client.responses.create( - model="gpt-5", - input=input_payload, - reasoning={"effort": "low"}, - text={"verbosity": "low"}, - ) - - edited_html = ( - getattr(response, "output_text", None) - or getattr(response, "text", None) - or "" - ) - - print(f"Received edited HTML content length: {len(edited_html)}") - - if not edited_html: - raise HTTPException( - status_code=500, - detail="No edited HTML content generated by OpenAI GPT-5", - ) - - return edited_html - - except APIError as e: - print(f"OpenAI API Error: {e}") - raise HTTPException( - status_code=500, detail=f"OpenAI API error during HTML editing: {str(e)}" - ) - except Exception as e: - # Handle various API errors - error_msg = str(e) - print(f"Exception occurred: {error_msg}") - print(f"Exception 
type: {type(e)}") - if "timeout" in error_msg.lower(): - raise HTTPException( - status_code=408, - detail=f"OpenAI API timeout during HTML editing: {error_msg}", - ) - elif "connection" in error_msg.lower(): - raise HTTPException( - status_code=503, - detail=f"OpenAI API connection error during HTML editing: {error_msg}", - ) - else: - raise HTTPException( - status_code=500, - detail=f"OpenAI API error during HTML editing: {error_msg}", - ) - - -# ENDPOINT 1: Slide to HTML conversion -@SLIDE_TO_HTML_ROUTER.post("/", response_model=SlideToHtmlResponse) -async def convert_slide_to_html(request: SlideToHtmlRequest): - """ - Convert a slide image and its OXML data to HTML using Anthropic Claude API. - - Args: - request: JSON request containing image path and XML content - - Returns: - SlideToHtmlResponse with generated HTML - """ - try: - # Get OpenAI API key from environment - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise HTTPException( - status_code=500, detail="OPENAI_API_KEY environment variable not set" - ) - - # Resolve image path to actual file system path - actual_image_path = resolve_image_path_to_filesystem(request.image) - if not actual_image_path: - raise HTTPException( - status_code=404, detail=f"Image file not found: {request.image}" - ) - - # Read and encode image to base64 - with open(actual_image_path, "rb") as image_file: - image_content = image_file.read() - base64_image = base64.b64encode(image_content).decode("utf-8") - - # Determine media type from file extension - file_extension = os.path.splitext(actual_image_path)[1].lower() - media_type_map = { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".webp": "image/webp", - } - media_type = media_type_map.get(file_extension, "image/png") - - # Generate HTML using the extracted function - html_content = await generate_html_from_slide( - base64_image=base64_image, - media_type=media_type, - xml_content=request.xml, - api_key=api_key, - 
fonts=request.fonts, - ) - - html_content = html_content.replace("```html", "").replace("```", "") - - return SlideToHtmlResponse(success=True, html=html_content) - - except HTTPException: - # Re-raise HTTP exceptions as-is - raise - except Exception as e: - # Log the full error for debugging - print(f"Unexpected error during slide to HTML processing: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Error processing slide to HTML: {str(e)}" - ) - - -# ENDPOINT 2: HTML to React component conversion -@HTML_TO_REACT_ROUTER.post("/", response_model=HtmlToReactResponse) -async def convert_html_to_react(request: HtmlToReactRequest): - """ - Convert HTML content to TSX React component using Anthropic Claude API. - - Args: - request: JSON request containing HTML content - - Returns: - HtmlToReactResponse with generated React component - """ - try: - # Get OpenAI API key from environment - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise HTTPException( - status_code=500, detail="OPENAI_API_KEY environment variable not set" - ) - - # Validate HTML content - if not request.html or not request.html.strip(): - raise HTTPException(status_code=400, detail="HTML content cannot be empty") - - # Optionally resolve image and encode to base64 - image_b64 = None - media_type = None - if request.image: - actual_image_path = resolve_image_path_to_filesystem(request.image) - if actual_image_path: - with open(actual_image_path, "rb") as f: - image_b64 = base64.b64encode(f.read()).decode("utf-8") - ext = os.path.splitext(actual_image_path)[1].lower() - media_type = { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".webp": "image/webp", - }.get(ext, "image/png") - - # Convert HTML to React component - react_component = await generate_react_component_from_html( - html_content=request.html, - api_key=api_key, - image_base64=image_b64, - media_type=media_type, - ) - - react_component = 
react_component.replace("```tsx", "").replace("```", "") - - return HtmlToReactResponse( - success=True, - react_component=react_component, - message="React component generated successfully", - ) - - except HTTPException: - # Re-raise HTTP exceptions as-is - raise - except Exception as e: - # Log the full error for debugging - print(f"Unexpected error during HTML to React processing: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Error processing HTML to React: {str(e)}" - ) - - -# ENDPOINT 3: HTML editing with images -@HTML_EDIT_ROUTER.post("/", response_model=HtmlEditResponse) -async def edit_html_with_images_endpoint( - current_ui_image: UploadFile = File(..., description="Current UI image file"), - sketch_image: Optional[UploadFile] = File( - None, description="Sketch/indication image file (optional)" - ), - html: str = Form(..., description="Current HTML content to edit"), - prompt: str = Form(..., description="Text prompt describing the changes"), -): - """ - Edit HTML content based on one or two uploaded images and a text prompt using Anthropic Claude API. 
- - Args: - current_ui_image: Uploaded current UI image file - sketch_image: Uploaded sketch/indication image file (optional) - html: Current HTML content to edit (form data) - prompt: Text prompt describing the changes (form data) - - Returns: - HtmlEditResponse with edited HTML - """ - try: - # Get OpenAI API key from environment - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise HTTPException( - status_code=500, detail="OPENAI_API_KEY environment variable not set" - ) - - # Validate inputs - if not html or not html.strip(): - raise HTTPException(status_code=400, detail="HTML content cannot be empty") - - if not prompt or not prompt.strip(): - raise HTTPException(status_code=400, detail="Text prompt cannot be empty") - - # Validate current UI image file - if ( - not current_ui_image.content_type - or not current_ui_image.content_type.startswith("image/") - ): - raise HTTPException( - status_code=400, detail="Current UI file must be an image" - ) - - # Validate sketch image file only if provided - if sketch_image and ( - not sketch_image.content_type - or not sketch_image.content_type.startswith("image/") - ): - raise HTTPException(status_code=400, detail="Sketch file must be an image") - - # Read and encode current UI image to base64 - current_ui_content = await current_ui_image.read() - current_ui_base64 = base64.b64encode(current_ui_content).decode("utf-8") - - # Read and encode sketch image to base64 only if provided - sketch_base64 = None - if sketch_image: - sketch_content = await sketch_image.read() - sketch_base64 = base64.b64encode(sketch_content).decode("utf-8") - - # Use the content type from the uploaded files - media_type = current_ui_image.content_type - - # Edit HTML using the function - edited_html = await edit_html_with_images( - current_ui_base64=current_ui_base64, - sketch_base64=sketch_base64, - media_type=media_type, - html_content=html, - prompt=prompt, - api_key=api_key, - ) - - edited_html = edited_html.replace("```html", 
"").replace("```", "") - - return HtmlEditResponse( - success=True, edited_html=edited_html, message="HTML edited successfully" - ) - - except HTTPException: - # Re-raise HTTP exceptions as-is - raise - except Exception as e: - # Log the full error for debugging - print(f"Unexpected error during HTML editing: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Error processing HTML editing: {str(e)}" - ) - - -# ENDPOINT 4: Save layouts for a presentation -@LAYOUT_MANAGEMENT_ROUTER.post( - "/save-templates", - response_model=SaveLayoutsResponse, - responses={ - 400: {"model": ErrorResponse, "description": "Validation error"}, - 500: {"model": ErrorResponse, "description": "Internal server error"}, - }, -) -async def save_layouts( - request: SaveLayoutsRequest, session: AsyncSession = Depends(get_async_session) -): - """ - Save multiple layouts for presentations. - - Args: - request: JSON request containing array of layout data - session: Database session - - Returns: - SaveLayoutsResponse with success status and count of saved layouts - - Raises: - HTTPException: 400 for validation errors, 500 for server errors - """ - try: - # Validate request data - if not request.layouts: - raise HTTPException(status_code=400, detail="Layouts array cannot be empty") - - if len(request.layouts) > 50: # Reasonable limit - raise HTTPException( - status_code=400, detail="Cannot save more than 50 layouts at once" - ) - - saved_count = 0 - - for i, layout_data in enumerate(request.layouts): - # Validate individual layout data - if ( - not layout_data.presentation - or not str(layout_data.presentation).strip() - ): - raise HTTPException( - status_code=400, - detail=f"Layout {i+1}: presentation_id cannot be empty", - ) - - if not layout_data.layout_id or not layout_data.layout_id.strip(): - raise HTTPException( - status_code=400, detail=f"Layout {i+1}: layout_id cannot be empty" - ) - - if not layout_data.layout_name or not layout_data.layout_name.strip(): - raise 
HTTPException( - status_code=400, detail=f"Layout {i+1}: layout_name cannot be empty" - ) - - if not layout_data.layout_code or not layout_data.layout_code.strip(): - raise HTTPException( - status_code=400, detail=f"Layout {i+1}: layout_code cannot be empty" - ) - - # Check if layout already exists for this presentation and layout_id - stmt = select(PresentationLayoutCodeModel).where( - PresentationLayoutCodeModel.presentation == layout_data.presentation, - PresentationLayoutCodeModel.layout_id == layout_data.layout_id, - ) - result = await session.execute(stmt) - existing_layout = result.scalar_one_or_none() - - if existing_layout: - # Update existing layout - existing_layout.layout_name = layout_data.layout_name - existing_layout.layout_code = layout_data.layout_code - existing_layout.fonts = layout_data.fonts - existing_layout.updated_at = datetime.now() - else: - # Create new layout - new_layout = PresentationLayoutCodeModel( - presentation=layout_data.presentation, - layout_id=layout_data.layout_id, - layout_name=layout_data.layout_name, - layout_code=layout_data.layout_code, - fonts=layout_data.fonts, - ) - session.add(new_layout) - - saved_count += 1 - - await session.commit() - - return SaveLayoutsResponse( - success=True, - saved_count=saved_count, - message=f"Successfully saved {saved_count} layout(s)", - ) - - except HTTPException: - # Re-raise HTTP exceptions as-is - await session.rollback() - raise - except Exception as e: - await session.rollback() - print(f"Unexpected error saving layouts: {str(e)}") - raise HTTPException( - status_code=500, - detail=f"Internal server error while saving layouts: {str(e)}", - ) - - -# ENDPOINT 5: Get layouts for a presentation -@LAYOUT_MANAGEMENT_ROUTER.get( - "/get-templates/{presentation}", - response_model=GetLayoutsResponse, - responses={ - 400: {"model": ErrorResponse, "description": "Invalid presentation ID"}, - 404: { - "model": ErrorResponse, - "description": "No layouts found for presentation", - }, - 500: 
{"model": ErrorResponse, "description": "Internal server error"}, - }, -) -async def get_layouts( - presentation: UUID, session: AsyncSession = Depends(get_async_session) -): - """ - Retrieve all layouts for a specific presentation. - - Args: - presentation: UUID of the presentation - session: Database session - - Returns: - GetLayoutsResponse with layouts data - - Raises: - HTTPException: 404 if no layouts found, 400 for invalid UUID, 500 for server errors - """ - try: - # Validate presentation_id format (basic UUID check) - if not presentation or len(str(presentation).strip()) == 0: - raise HTTPException( - status_code=400, detail="Presentation ID cannot be empty" - ) - - # Query layouts for the given presentation_id - stmt = select(PresentationLayoutCodeModel).where( - PresentationLayoutCodeModel.presentation == presentation - ) - result = await session.execute(stmt) - layouts_db = result.scalars().all() - - # Check if any layouts were found - if not layouts_db: - raise HTTPException( - status_code=404, - detail=f"No layouts found for presentation ID: {presentation}", - ) - - # Convert to response format - layouts = [ - LayoutData( - presentation=layout.presentation, - layout_id=layout.layout_id, - layout_name=layout.layout_name, - layout_code=layout.layout_code, - fonts=layout.fonts, - ) - for layout in layouts_db - ] - - # Aggregate unique fonts across all layouts - aggregated_fonts: set[str] = set() - for layout in layouts_db: - if layout.fonts: - aggregated_fonts.update([f for f in layout.fonts if isinstance(f, str)]) - fonts_list = sorted(list(aggregated_fonts)) if aggregated_fonts else None - - # Fetch template meta - template_meta = await session.get(TemplateModel, presentation) - template = None - if template_meta: - template = { - "id": template_meta.id, - "name": template_meta.name, - "description": template_meta.description, - "created_at": template_meta.created_at, - } - - return GetLayoutsResponse( - success=True, - layouts=layouts, - 
message=f"Retrieved {len(layouts)} layout(s) for presentation {presentation}", - template=template, - fonts=fonts_list, - ) - - except HTTPException: - # Re-raise HTTP exceptions as-is - raise - except Exception as e: - print(f"Error retrieving layouts for presentation {presentation}: {str(e)}") - raise HTTPException( - status_code=500, - detail=f"Internal server error while retrieving layouts: {str(e)}", - ) - - -# ENDPOINT: Get all presentations with layout counts -@LAYOUT_MANAGEMENT_ROUTER.get( - "/summary", - response_model=GetPresentationSummaryResponse, - summary="Get all presentations with layout counts", - description="Retrieve a summary of all presentations and the number of layouts in each", - responses={ - 200: { - "model": GetPresentationSummaryResponse, - "description": "Presentations summary retrieved successfully", - }, - 500: {"model": ErrorResponse, "description": "Internal server error"}, - }, -) -async def get_presentations_summary( - session: AsyncSession = Depends(get_async_session), -): - """ - Get summary of all presentations with their layout counts. 
- """ - try: - # Query to get presentation_id, count of layouts, and MAX(updated_at) - stmt = select( - PresentationLayoutCodeModel.presentation, - func.count(PresentationLayoutCodeModel.id).label("layout_count"), - func.max(PresentationLayoutCodeModel.updated_at).label("last_updated_at"), - ).group_by(PresentationLayoutCodeModel.presentation) - - result = await session.execute(stmt) - presentation_data = result.all() - - # Convert to response format with template info if available - presentations = [] - for row in presentation_data: - template_meta = await session.get(TemplateModel, row.presentation) - template = None - if template_meta: - template = { - "id": template_meta.id, - "name": template_meta.name, - "description": template_meta.description, - "created_at": template_meta.created_at, - } - presentations.append( - PresentationSummary( - presentation_id=row.presentation, - layout_count=row.layout_count, - last_updated_at=row.last_updated_at, - template=template, - ) - ) - - # Calculate totals - total_presentations = len(presentations) - total_layouts = sum(p.layout_count for p in presentations) - - return GetPresentationSummaryResponse( - success=True, - presentations=presentations, - total_presentations=total_presentations, - total_layouts=total_layouts, - message=f"Retrieved {total_presentations} presentation(s) with {total_layouts} total layout(s)", - ) - - except Exception as e: - print(f"Error retrieving presentations summary: {str(e)}") - raise HTTPException( - status_code=500, - detail=f"Internal server error while retrieving presentations summary: {str(e)}", - ) - - -@LAYOUT_MANAGEMENT_ROUTER.post( - "/templates", - response_model=TemplateCreateResponse, - responses={ - 400: {"model": ErrorResponse, "description": "Validation error"}, - 500: {"model": ErrorResponse, "description": "Internal server error"}, - }, -) -async def create_template( - request: TemplateCreateRequest, - session: AsyncSession = Depends(get_async_session), -): - try: - if not 
request.id or not request.name: - raise HTTPException(status_code=400, detail="id and name are required") - - # Upsert template by id - existing = await session.get(TemplateModel, request.id) - if existing: - existing.name = request.name - existing.description = request.description - else: - session.add( - TemplateModel( - id=request.id, name=request.name, description=request.description - ) - ) - await session.commit() - - # Read back - template = await session.get(TemplateModel, request.id) - return TemplateCreateResponse( - success=True, - template={ - "id": template.id, - "name": template.name, - "description": template.description, - "created_at": template.created_at, - }, - message="Template saved", - ) - except HTTPException: - await session.rollback() - raise - except Exception as e: - await session.rollback() - raise HTTPException( - status_code=500, detail=f"Failed to save template: {str(e)}" - ) - - -@LAYOUT_MANAGEMENT_ROUTER.delete("/delete-templates/{template_id}", status_code=204) -async def delete_template( - template_id: UUID, - session: AsyncSession = Depends(get_async_session), -): - try: - await session.execute( - delete(TemplateModel).where(TemplateModel.id == template_id) - ) - await session.execute( - delete(PresentationLayoutCodeModel).where( - PresentationLayoutCodeModel.presentation == template_id, - ) - ) - await session.commit() - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to delete template") diff --git a/electron/servers/fastapi/api/v1/ppt/router.py b/electron/servers/fastapi/api/v1/ppt/router.py index 9439cea1..23411e8c 100644 --- a/electron/servers/fastapi/api/v1/ppt/router.py +++ b/electron/servers/fastapi/api/v1/ppt/router.py @@ -1,13 +1,10 @@ from fastapi import APIRouter -from api.v1.ppt.endpoints.slide_to_html import SLIDE_TO_HTML_ROUTER, HTML_TO_REACT_ROUTER, HTML_EDIT_ROUTER, LAYOUT_MANAGEMENT_ROUTER from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER from 
api.v1.ppt.endpoints.anthropic import ANTHROPIC_ROUTER from api.v1.ppt.endpoints.google import GOOGLE_ROUTER from api.v1.ppt.endpoints.openai import OPENAI_ROUTER from api.v1.ppt.endpoints.files import FILES_ROUTER -from api.v1.ppt.endpoints.pptx_slides import PPTX_SLIDES_ROUTER -from api.v1.ppt.endpoints.pdf_slides import PDF_SLIDES_ROUTER from api.v1.ppt.endpoints.fonts import FONTS_ROUTER from api.v1.ppt.endpoints.icons import ICONS_ROUTER from api.v1.ppt.endpoints.images import IMAGES_ROUTER @@ -15,9 +12,9 @@ from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER from api.v1.ppt.endpoints.slide import SLIDE_ROUTER from api.v1.ppt.endpoints.codex_auth import CODEX_AUTH_ROUTER -from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER from api.v1.ppt.endpoints.theme import THEMES_ROUTER from api.v1.ppt.endpoints.theme_generate import THEME_ROUTER +from templates.router import TEMPLATE_ROUTER API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt") @@ -26,20 +23,14 @@ API_V1_PPT_ROUTER.include_router(FILES_ROUTER) API_V1_PPT_ROUTER.include_router(FONTS_ROUTER) API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER) API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER) -API_V1_PPT_ROUTER.include_router(PPTX_SLIDES_ROUTER) API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER) -API_V1_PPT_ROUTER.include_router(SLIDE_TO_HTML_ROUTER) -API_V1_PPT_ROUTER.include_router(HTML_TO_REACT_ROUTER) -API_V1_PPT_ROUTER.include_router(HTML_EDIT_ROUTER) -API_V1_PPT_ROUTER.include_router(LAYOUT_MANAGEMENT_ROUTER) +API_V1_PPT_ROUTER.include_router(TEMPLATE_ROUTER) API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER) API_V1_PPT_ROUTER.include_router(ICONS_ROUTER) API_V1_PPT_ROUTER.include_router(THEMES_ROUTER) API_V1_PPT_ROUTER.include_router(THEME_ROUTER) API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER) -API_V1_PPT_ROUTER.include_router(PDF_SLIDES_ROUTER) API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER) 
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER) API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER) API_V1_PPT_ROUTER.include_router(CODEX_AUTH_ROUTER) -API_V1_PPT_ROUTER.include_router(PPTX_FONTS_ROUTER) diff --git a/electron/servers/fastapi/models/sql/presentation.py b/electron/servers/fastapi/models/sql/presentation.py index d2d57e6c..2645ab63 100644 --- a/electron/servers/fastapi/models/sql/presentation.py +++ b/electron/servers/fastapi/models/sql/presentation.py @@ -4,9 +4,9 @@ import uuid from sqlalchemy import JSON, Column, DateTime, String from sqlmodel import Boolean, Field, SQLModel -from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from models.presentation_structure_model import PresentationStructureModel +from templates.presentation_layout import PresentationLayoutModel from utils.datetime_utils import get_current_utc_datetime diff --git a/electron/servers/fastapi/models/sql/presentation_layout_code.py b/electron/servers/fastapi/models/sql/presentation_layout_code.py index fe57c01e..636ba585 100644 --- a/electron/servers/fastapi/models/sql/presentation_layout_code.py +++ b/electron/servers/fastapi/models/sql/presentation_layout_code.py @@ -1,15 +1,14 @@ from datetime import datetime -from typing import Optional, List +from typing import Optional import uuid -from sqlalchemy import Column, DateTime, Text, JSON -from sqlmodel import SQLModel, Field + +from sqlalchemy import JSON, Column, DateTime, Text +from sqlmodel import Field, SQLModel from utils.datetime_utils import get_current_utc_datetime class PresentationLayoutCodeModel(SQLModel, table=True): - """Model for storing presentation layout codes""" - __tablename__ = "presentation_layout_codes" id: Optional[int] = Field(default=None, primary_key=True) @@ -19,8 +18,10 @@ class PresentationLayoutCodeModel(SQLModel, table=True): layout_code: str = Field( sa_column=Column(Text), description="TSX/React component code for the 
layout" ) - fonts: Optional[List[str]] = Field( - sa_column=Column(JSON), default=None, description="Optional list of font links" + fonts: Optional[dict[str, str] | list[str]] = Field( + default=None, + sa_column=Column(JSON, nullable=True), + description="Optional font metadata associated with the layout", ) created_at: datetime = Field( sa_column=Column( diff --git a/electron/servers/fastapi/models/sql/template.py b/electron/servers/fastapi/models/sql/template.py index a5ca53fe..f3727151 100644 --- a/electron/servers/fastapi/models/sql/template.py +++ b/electron/servers/fastapi/models/sql/template.py @@ -1,8 +1,9 @@ from datetime import datetime from typing import Optional import uuid + from sqlalchemy import Column, DateTime -from sqlmodel import SQLModel, Field +from sqlmodel import Field, SQLModel from utils.datetime_utils import get_current_utc_datetime diff --git a/electron/servers/fastapi/models/sql/template_create_info.py b/electron/servers/fastapi/models/sql/template_create_info.py new file mode 100644 index 00000000..ce363587 --- /dev/null +++ b/electron/servers/fastapi/models/sql/template_create_info.py @@ -0,0 +1,25 @@ +from datetime import datetime +import uuid + +from sqlalchemy import JSON, Column, DateTime +from sqlmodel import Field, SQLModel + +from utils.datetime_utils import get_current_utc_datetime + + +class TemplateCreateInfoModel(SQLModel, table=True): + __tablename__ = "template_create_infos" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + fonts: dict[str, str] | None = Field( + default=None, + sa_column=Column(JSON, nullable=True), + ) + pptx_url: str | None = Field(default=None) + slide_htmls: list[str] = Field(sa_column=Column(JSON, nullable=False)) + slide_image_urls: list[str] = Field(sa_column=Column(JSON, nullable=False)) + created_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), nullable=False, default=get_current_utc_datetime + ) + ) diff --git a/electron/servers/fastapi/pyproject.toml 
b/electron/servers/fastapi/pyproject.toml index 460739d7..b7ed26ec 100644 --- a/electron/servers/fastapi/pyproject.toml +++ b/electron/servers/fastapi/pyproject.toml @@ -42,4 +42,4 @@ dev = [ [tool.setuptools.packages.find] where = ["."] -include = ["api*", "enums*", "models*", "services*", "constants*", "utils*"] +include = ["api*", "enums*", "models*", "services*", "constants*", "utils*", "templates*"] diff --git a/electron/servers/fastapi/services/database.py b/electron/servers/fastapi/services/database.py index 7188cd92..dcc03b01 100644 --- a/electron/servers/fastapi/services/database.py +++ b/electron/servers/fastapi/services/database.py @@ -14,10 +14,11 @@ from models.sql.async_presentation_generation_status import ( from models.sql.image_asset import ImageAsset from models.sql.key_value import KeyValueSqlModel from models.sql.ollama_pull_status import OllamaPullStatus -from models.sql.presentation import PresentationModel -from models.sql.slide import SlideModel from models.sql.presentation_layout_code import PresentationLayoutCodeModel +from models.sql.presentation import PresentationModel from models.sql.template import TemplateModel +from models.sql.template_create_info import TemplateCreateInfoModel +from models.sql.slide import SlideModel from models.sql.webhook_subscription import WebhookSubscription from utils.db_utils import get_database_url_and_connect_args from utils.get_env import get_app_data_directory_env @@ -65,6 +66,7 @@ async def create_db_and_tables(): KeyValueSqlModel.__table__, ImageAsset.__table__, PresentationLayoutCodeModel.__table__, + TemplateCreateInfoModel.__table__, TemplateModel.__table__, WebhookSubscription.__table__, AsyncPresentationGenerationTaskModel.__table__, diff --git a/electron/servers/fastapi/templates/__init__.py b/electron/servers/fastapi/templates/__init__.py new file mode 100644 index 00000000..a9a2c5b3 --- /dev/null +++ b/electron/servers/fastapi/templates/__init__.py @@ -0,0 +1 @@ +__all__ = [] diff --git 
PLACEHOLDER_IMAGE_URL = "https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png"
PLACEHOLDER_ICON_URL = "https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg"


def build_schema_example(schema: dict) -> Any:
    """Build a sample value that matches a JSON-schema fragment.

    Resolution order: an explicit ``default``, a ``const``, the first
    ``anyOf``/``oneOf``/``allOf`` option that produces a value, the first
    ``enum`` member, and finally a synthetic value derived from ``type``.

    Args:
        schema: A JSON-schema node. Anything that is not a dict yields ``None``.

    Returns:
        An example value, or ``None`` when nothing can be derived.
    """
    if not isinstance(schema, dict):
        return None

    if "default" in schema:
        return schema["default"]

    # A ``const`` schema admits exactly one value, so that value is the example.
    if "const" in schema:
        return schema["const"]

    for key in ("anyOf", "oneOf", "allOf"):
        options = schema.get(key)
        if isinstance(options, list):
            for option in options:
                example = build_schema_example(option)
                if example is not None:
                    return example

    enum_values = schema.get("enum")
    if enum_values:
        return enum_values[0]

    schema_type = schema.get("type")
    # JSON Schema allows ``"type": ["string", "null"]``; use the first
    # non-null entry so nullable fields still get a meaningful example.
    if isinstance(schema_type, list):
        schema_type = next(
            (type_name for type_name in schema_type if type_name != "null"), None
        )

    if schema_type == "object":
        properties = schema.get("properties") or {}
        return {
            field_name: build_schema_example(field_schema)
            for field_name, field_schema in properties.items()
        }

    if schema_type == "array":
        # ``default`` was already handled at the top of the function.
        item_example = build_schema_example(schema.get("items", {}))
        return [] if item_example is None else [item_example]

    if schema_type == "string":
        schema_description = (schema.get("description") or "").lower()
        if "icon" in schema_description:
            return PLACEHOLDER_ICON_URL
        if "image" in schema_description or "url" in schema_description:
            return PLACEHOLDER_IMAGE_URL
        return "Sample text"

    if schema_type in ("integer", "number"):
        return schema.get("minimum", 1)

    if schema_type == "boolean":
        return False

    return None


def replace_special_placeholders(value: Any) -> Any:
    """Recursively swap ``__image_url__`` / ``__icon_url__`` markers for URLs.

    A marker may appear as a dict key (its value is replaced by the
    placeholder URL) or as a plain value. Dicts and lists are rebuilt; other
    values are returned unchanged.
    """
    if isinstance(value, dict):
        result = {}
        for key, child in value.items():
            if key == "__image_url__":
                result[key] = PLACEHOLDER_IMAGE_URL
            elif key == "__icon_url__":
                result[key] = PLACEHOLDER_ICON_URL
            else:
                result[key] = replace_special_placeholders(child)
        return result

    if isinstance(value, list):
        return [replace_special_placeholders(item) for item in value]

    if value == "__image_url__":
        return PLACEHOLDER_IMAGE_URL
    if value == "__icon_url__":
        return PLACEHOLDER_ICON_URL
    return value


def build_template_example(
    template_id: str, layout: "PresentationLayoutModel"
) -> dict[str, Any]:
    """Build an example payload with one sample slide per layout slide.

    Args:
        template_id: Identifier echoed back under the ``template`` key.
        layout: Object exposing ``slides``; each slide's ``id`` and
            ``json_schema`` are read.

    Returns:
        ``{"template": template_id, "slides": [{"layout", "content"}, ...]}``.
    """
    slides = [
        {
            "layout": slide.id,
            "content": replace_special_placeholders(
                build_schema_example(slide.json_schema)
            ),
        }
        for slide in layout.slides
    ]
    return {
        "template": template_id,
        "slides": slides,
    }
_STYLE_TOKENS = {
    "italic",
    "italics",
    "ital",
    "oblique",
    "roman",
    "bolditalic",
    "bolditalics",
    "thin",
    "hairline",
    "extralight",
    "ultralight",
    "light",
    "demilight",
    "semilight",
    "book",
    "regular",
    "normal",
    "medium",
    "semibold",
    "demibold",
    "bold",
    "extrabold",
    "ultrabold",
    "black",
    "extrablack",
    "ultrablack",
    "heavy",
    "narrow",
    "condensed",
    "semicondensed",
    "extracondensed",
    "ultracondensed",
    "expanded",
    "semiexpanded",
    "extraexpanded",
    "ultraexpanded",
}
_STYLE_MODIFIERS = {"semi", "demi", "extra", "ultra"}

# Theme-font placeholders used by DrawingML; they are not real family names.
_THEME_FONT_PLACEHOLDERS = {"+mn-lt", "+mj-lt", "+mn-ea", "+mj-ea", "+mn-cs", "+mj-cs"}


def _insert_spaces_in_camel_case(value: str) -> str:
    """Insert spaces at camel-case boundaries (``OpenSans`` -> ``Open Sans``)."""
    value = re.sub(r"(?<=[a-z0-9])([A-Z])", r" \1", value)
    value = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", value)
    return value


def normalize_font_family_name(raw_name: str) -> str:
    """Reduce a font face name to its family name.

    Underscores and hyphens become spaces, camel-case words are split, one
    trailing style token (longest match first) is dropped, and every later
    token that is a style word or style modifier is removed. The first token
    is always kept, so the result is never emptied by the filter.

    Args:
        raw_name: Face name such as ``"OpenSans-BoldItalic"``.

    Returns:
        The family name (``"Open Sans"``); falsy input is returned unchanged.
    """
    if not raw_name:
        return raw_name

    name = raw_name.replace("_", " ").replace("-", " ")
    name = _insert_spaces_in_camel_case(name)
    name = re.sub(r"\s+", " ", name).strip()
    lower_name = name.lower()

    for style in sorted(_STYLE_TOKENS, key=len, reverse=True):
        suffix = " " + style
        if lower_name.endswith(suffix):
            name = name[: -len(suffix)]
            break

    first_token, *other_tokens = name.split(" ")
    kept_tokens = [first_token]
    for token in other_tokens:
        lower_token = token.lower()
        if lower_token in _STYLE_TOKENS or lower_token in _STYLE_MODIFIERS:
            continue
        kept_tokens.append(token)

    return re.sub(r"\s+", " ", " ".join(kept_tokens).strip())


def extract_fonts_from_oxml(xml_content: str) -> list[str]:
    """Collect the typeface names referenced by an Office Open XML document.

    Typefaces are read from the parsed tree (``a:latin``, ``a:ea``, ``a:cs``,
    ``a:font`` and un-namespaced ``latin`` elements) and from a regex scan of
    the raw text. The regex scan runs even when the XML does not parse, so a
    malformed document still yields the typefaces that can be recovered.

    Args:
        xml_content: Raw XML text. Empty or non-string input yields ``[]``.

    Returns:
        Sorted typeface names, without theme placeholders and blank values.
    """
    if not isinstance(xml_content, str) or not xml_content:
        return []

    fonts: set[str] = set()

    try:
        root = ET.fromstring(xml_content)
    except (ET.ParseError, ValueError):
        # Malformed XML: rely on the regex scan below.
        root = None

    if root is not None:
        namespaces = {
            "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
            "p": "http://schemas.openxmlformats.org/presentationml/2006/main",
            "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
        }
        # ".//a:latin" already covers the a:latin children of a:rPr.
        for xpath in (".//a:latin", ".//a:ea", ".//a:cs", ".//a:font"):
            for font_elem in root.findall(xpath, namespaces):
                typeface = font_elem.attrib.get("typeface")
                if typeface:
                    fonts.add(typeface)

        for font_elem in root.findall(".//latin"):
            typeface = font_elem.attrib.get("typeface")
            if typeface:
                fonts.add(typeface)

    fonts.update(re.findall(r'typeface="([^"]+)"', xml_content))

    return sorted(
        font
        for font in fonts
        if font not in _THEME_FONT_PLACEHOLDERS and font.strip()
    )


def get_google_font_css_url(font_name: str) -> str:
    """Return the Google Fonts CSS2 URL for ``font_name``.

    The family name is URL-encoded (spaces become ``+``) so characters such as
    ``&`` or ``#`` cannot break the query string. ``:``, ``@``, ``,`` and ``;``
    are left literal because the CSS2 axis syntax uses them.
    """
    from urllib.parse import quote_plus

    family = quote_plus(font_name, safe=":@,;")
    return f"https://fonts.googleapis.com/css2?family={family}&display=swap"


async def check_google_font_availability(font_name: str) -> bool:
    """Return ``True`` when a HEAD request for the font's CSS URL answers 200.

    Any failure (network error, timeout, ...) is reported as "not available".
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.head(
                get_google_font_css_url(font_name),
                timeout=aiohttp.ClientTimeout(total=10),
            ) as response:
                return response.status == 200
    except Exception:
        # Deliberate best effort: an unreachable service means "unavailable".
        return False


def collect_normalized_fonts_from_xmls(slide_xmls: Iterable[str]) -> list[str]:
    """Return the sorted, de-duplicated family names used across ``slide_xmls``."""
    raw_fonts: set[str] = set()
    for xml_content in slide_xmls:
        raw_fonts.update(extract_fonts_from_oxml(xml_content))

    normalized_fonts = {normalize_font_family_name(font) for font in raw_fonts}
    return sorted(font for font in normalized_fonts if font)


async def get_available_and_unavailable_fonts(
    font_names: Iterable[str],
) -> tuple[list[tuple[str, str]], list[tuple[str, None]]]:
    """Split ``font_names`` by Google Fonts availability.

    Names are de-duplicated and sorted, then checked concurrently.

    Returns:
        ``(available, unavailable)`` where available entries are
        ``(name, css_url)`` and unavailable entries are ``(name, None)``.
    """
    normalized_fonts = sorted({font for font in font_names if font})
    if not normalized_fonts:
        return [], []

    results = await asyncio.gather(
        *[check_google_font_availability(font) for font in normalized_fonts]
    )

    available_fonts: list[tuple[str, str]] = []
    unavailable_fonts: list[tuple[str, None]] = []
    for font_name, is_available in zip(normalized_fonts, results):
        if is_available:
            available_fonts.append((font_name, get_google_font_css_url(font_name)))
        else:
            unavailable_fonts.append((font_name, None))
    return available_fonts, unavailable_fonts
ec68dd6e..f69251ff 100644 --- a/electron/servers/fastapi/utils/get_layout_by_name.py +++ b/electron/servers/fastapi/templates/get_layout_by_name.py @@ -1,7 +1,8 @@ import aiohttp from fastapi import HTTPException -from models.presentation_layout import PresentationLayoutModel -from typing import List + +from templates.presentation_layout import PresentationLayoutModel + async def get_layout_by_name(layout_name: str) -> PresentationLayoutModel: url = f"http://localhost/api/template?group={layout_name}" @@ -11,8 +12,7 @@ async def get_layout_by_name(layout_name: str) -> PresentationLayoutModel: error_text = await response.text() raise HTTPException( status_code=404, - detail=f"Template '{layout_name}' not found: {error_text}" + detail=f"Template '{layout_name}' not found: {error_text}", ) layout_json = await response.json() - # Parse the JSON into your Pydantic model return PresentationLayoutModel(**layout_json) diff --git a/electron/servers/fastapi/templates/handler.py b/electron/servers/fastapi/templates/handler.py new file mode 100644 index 00000000..39f06fdf --- /dev/null +++ b/electron/servers/fastapi/templates/handler.py @@ -0,0 +1,683 @@ +import os +import random +import re +import uuid +from datetime import datetime +from typing import Any, List, Optional + +import aiohttp +from fastapi import Body, Depends, File, Form, HTTPException, Path, Query, UploadFile +from pydantic import BaseModel +from sqlalchemy import func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import delete, select + +from constants.presentation import DEFAULT_TEMPLATES +from models.sql.presentation_layout_code import PresentationLayoutCodeModel +from models.sql.template import TemplateModel +from models.sql.template_create_info import TemplateCreateInfoModel +from services.database import get_async_session +from templates.example import build_template_example +from templates.get_layout_by_name import get_layout_by_name +from templates.pptx_html_stub import 
BASIC_TEMPLATE_HTML +from templates.presentation_layout import PresentationLayoutModel +from templates.preview import ( + FontsUploadAndSlidesPreviewResponse, + upload_fonts_and_slides_preview_handler, +) +from templates.prompts import ( + SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT, + SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT, + SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT, +) +from templates.providers import edit_slide_layout_code, generate_slide_layout_code +from utils.asset_directory_utils import resolve_image_path_to_filesystem + + +class TemplateDetail(BaseModel): + id: str + name: str + total_layouts: Optional[int] = None + + +class TemplateLayoutData(BaseModel): + template: uuid.UUID + layout_id: str + layout_name: str + layout_code: str + fonts: Optional[Any] = None + + +class TemplateData(BaseModel): + id: uuid.UUID + init_id: Optional[uuid.UUID] = None + name: str + description: Optional[str] = None + created_at: datetime + + +class GetTemplateLayoutsResponse(BaseModel): + layouts: list[TemplateLayoutData] + template: Optional[TemplateData] = None + fonts: Optional[Any] = None + + +class TemplateExample(BaseModel): + template: str + slides: List[dict] + + +class CreateTemplateInitRequest(BaseModel): + pptx_url: str + slide_image_urls: List[str] + fonts: dict = {} + + +class CreateSlideLayoutRequest(BaseModel): + id: uuid.UUID + index: int + + +class CreateSlideLayoutResponse(BaseModel): + react_component: str + + +class EditSlideLayoutRequest(BaseModel): + react_component: str + prompt: str + + +class EditSlideLayoutResponse(CreateSlideLayoutResponse): + pass + + +class EditSlideLayoutSectionRequest(BaseModel): + react_component: str + section: str + prompt: str + + +class EditSlideLayoutSectionResponse(CreateSlideLayoutResponse): + pass + + +class SaveTemplateLayoutData(BaseModel): + layout_id: str + layout_name: str + layout_code: str + + +class SaveTemplateRequest(BaseModel): + template_info_id: uuid.UUID + name: str + description: Optional[str] = None + layouts: 
# layoutId = "<id>" with either quote style; group 2 is the id itself.
_LAYOUT_ID_PATTERN = re.compile(r'(layoutId\s*=\s*["\'])([^"\']+)(["\'])')

# Field names that must be wrapped in double underscores. The look-arounds
# skip names that are already wrapped, so normalisation is idempotent.
_PLACEHOLDER_FIELD_PATTERN = re.compile(
    r"(?<!__)(image_url|icon_url|image_prompt|icon_query)(?!__)"
)

# One complete import statement, possibly spanning several lines, with or
# without a trailing semicolon. Only identifier/brace characters may precede
# the module specifier, so the match cannot run on into following code.
_IMPORT_STATEMENT_PATTERN = re.compile(
    r"(?m)^\s*import\b[\w\s{},*$]*[\"'][^\"'\n]*[\"'][ \t]*;?[ \t]*(?:\r?\n|$)"
)

_IMAGE_MEDIA_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
}


def _strip_code_fences(value: str) -> str:
    """Remove Markdown code-fence markers and surrounding whitespace."""
    return (
        value.replace("```tsx", "")
        .replace("```typescript", "")
        .replace("```ts", "")
        .replace("```", "")
        .strip()
    )


def _normalize_layout_code_for_create(code: str) -> str:
    """Clean generated TSX so it can be stored as layout code.

    Steps, in order: strip code fences; wrap the special field names in double
    underscores (leaving already wrapped ones alone); drop everything before
    the first ``import`` line and everything from the first ``export`` line
    on; remove the import statements; and append a random four-digit suffix
    to every ``layoutId`` that does not already end in one.
    """
    normalized = _strip_code_fences(code)
    normalized = _PLACEHOLDER_FIELD_PATTERN.sub(r"__\1__", normalized)

    first_import_match = re.search(r"(?m)^\s*import\b", normalized)
    if first_import_match:
        normalized = normalized[first_import_match.start() :]

    first_export_match = re.search(r"(?m)^\s*export\b", normalized)
    if first_export_match:
        normalized = normalized[: first_export_match.start()]

    normalized = _IMPORT_STATEMENT_PATTERN.sub("", normalized)
    # Catch-all for any remaining line that starts with import/export.
    normalized = re.sub(
        r"(?m)^\s*(?:import|export)\b.*(?:\r?\n|$)",
        "",
        normalized,
    )
    normalized = normalized.strip()

    def _with_suffix(match: "re.Match[str]") -> str:
        if re.search(r"-\d{4}$", match.group(2)):
            return match.group(0)
        return (
            f"{match.group(1)}{match.group(2)}-"
            f"{random.randint(1000, 9999)}{match.group(3)}"
        )

    return _LAYOUT_ID_PATTERN.sub(_with_suffix, normalized)


def _update_layout_id_in_code(code: str) -> tuple[str, str]:
    """Give the first ``layoutId`` in ``code`` a fresh four-digit suffix.

    An existing ``-dddd`` suffix is replaced; otherwise one is appended. If
    the random draw reproduces the current id, the suffix is appended instead
    so the returned id always differs from the original.

    Returns:
        ``(new_code, new_id)``.

    Raises:
        HTTPException: 400 when no ``layoutId`` assignment is found.
    """
    match = _LAYOUT_ID_PATTERN.search(code)
    if not match:
        raise HTTPException(status_code=400, detail="layoutId not found in layout code")

    current_id = match.group(2)
    suffix = f"{random.randint(1000, 9999)}"
    new_id = re.sub(r"-\d{4}$", f"-{suffix}", current_id)
    if new_id == current_id:
        new_id = f"{current_id}-{suffix}"

    # Splice by position rather than through a replacement template: a
    # template like "\1<id>\3" breaks when the id starts with a digit.
    new_code = code[: match.start(2)] + new_id + code[match.end(2) :]
    return new_code, new_id


async def _download_image_bytes(image_url: str) -> bytes:
    """Download ``image_url`` and return the response body.

    Raises:
        HTTPException: 400 when the response status is not 200.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as response:
            if response.status != 200:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to download slide image: {image_url}",
                )
            return await response.read()


def _media_type_for_url(image_url: str) -> str:
    """Map the extension of the URL's path component to an image media type.

    Query strings and fragments are ignored; unknown extensions fall back to
    ``image/png``.
    """
    from urllib.parse import urlparse

    file_extension = os.path.splitext(urlparse(image_url).path)[1].lower()
    return _IMAGE_MEDIA_TYPES.get(file_extension, "image/png")


async def _read_image_bytes_and_media_type(image_url: str) -> tuple[bytes, str]:
    """Load an image from local storage or over HTTP.

    The URL is first resolved to a filesystem path; when that path is an
    existing file it is read directly, otherwise the image is downloaded.

    Returns:
        ``(image_bytes, media_type)``; the media type defaults to
        ``image/png`` for unknown extensions.
    """
    actual_image_path = resolve_image_path_to_filesystem(image_url)
    if actual_image_path and os.path.isfile(actual_image_path):
        with open(actual_image_path, "rb") as image_file:
            image_bytes = image_file.read()
        file_extension = os.path.splitext(actual_image_path)[1].lower()
        return image_bytes, _IMAGE_MEDIA_TYPES.get(file_extension, "image/png")

    image_bytes = await _download_image_bytes(image_url)
    return image_bytes, _media_type_for_url(image_url)
.group_by(TemplateModel.id, TemplateModel.name) + ) + rows = result.all() + + templates: list[TemplateDetail] = [] + if include_defaults: + templates.extend( + TemplateDetail(id=template, name=template) for template in DEFAULT_TEMPLATES + ) + + templates.extend( + TemplateDetail( + id=f"custom-{template_id}", + name=template_name, + total_layouts=total_layouts, + ) + for template_id, template_name, total_layouts in rows + ) + return templates + + +async def get_layouts( + template_id: str = Path(..., description="The id of the template"), + session: AsyncSession = Depends(get_async_session), +): + if not template_id or not template_id.strip(): + raise HTTPException(status_code=400, detail="Template ID cannot be empty") + + try: + cleaned_template_id = template_id.replace("custom-", "") + template_id_uuid = uuid.UUID(cleaned_template_id) + except Exception as exc: + raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc + + result = await session.execute( + select(PresentationLayoutCodeModel).where( + PresentationLayoutCodeModel.presentation == template_id_uuid + ) + ) + layouts_db = result.scalars().all() + if not layouts_db: + raise HTTPException( + status_code=404, + detail=f"No layouts found for template ID: {template_id}", + ) + + template_meta = await session.get(TemplateModel, template_id_uuid) + template = None + if template_meta: + template = TemplateData( + id=template_id_uuid, + init_id=None, + name=template_meta.name, + description=template_meta.description, + created_at=template_meta.created_at, + ) + + layouts = [ + TemplateLayoutData( + template=template_id_uuid, + layout_id=layout.layout_id, + layout_name=layout.layout_name, + layout_code=layout.layout_code, + fonts=layout.fonts, + ) + for layout in layouts_db + ] + return GetTemplateLayoutsResponse( + layouts=layouts, + template=template, + fonts=layouts[0].fonts if layouts else None, + ) + + +async def get_template_by_id( + id: str = Path( + ..., + description=f"The id 
of the template, must be one of {', '.join(DEFAULT_TEMPLATES)} or your custom template", + ), + sql_session: AsyncSession = Depends(get_async_session), +): + if id.startswith("custom-"): + try: + template_id = uuid.UUID(id.replace("custom-", "")) + except Exception as exc: + raise HTTPException( + status_code=400, + detail="Template not found. Please use a valid template.", + ) from exc + + template = await sql_session.get(TemplateModel, template_id) + if not template: + raise HTTPException( + status_code=400, + detail="Template not found. Please use a valid template.", + ) + + return await get_layout_by_name(id) + + +async def get_template_example( + id: str = Path( + ..., + description=f"The id of the template, must be one of {', '.join(DEFAULT_TEMPLATES)} or your custom template", + ), + sql_session: AsyncSession = Depends(get_async_session), +): + template = await get_template_by_id(id=id, sql_session=sql_session) + return TemplateExample(**build_template_example(id, template)) + + +async def upload_fonts_and_slides_preview( + pptx_file: UploadFile = File(..., description="PPTX file to preview"), + font_files: Optional[List[UploadFile]] = File( + default=None, description="Font files to upload" + ), + original_font_names: Optional[List[str]] = Form(default=None), + max_slides: Optional[int] = Query(default=None), +): + return await upload_fonts_and_slides_preview_handler( + pptx_file=pptx_file, + font_files=font_files, + original_font_names=original_font_names, + max_slides=max_slides, + ) + + +async def init_create_template( + request: CreateTemplateInitRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + if not request.slide_image_urls: + raise HTTPException( + status_code=400, detail="At least one slide image is required" + ) + + template_create_info = TemplateCreateInfoModel( + fonts=request.fonts or {}, + pptx_url=request.pptx_url, + slide_image_urls=request.slide_image_urls, + slide_htmls=[BASIC_TEMPLATE_HTML for _ in 
request.slide_image_urls], + ) + sql_session.add(template_create_info) + await sql_session.commit() + await sql_session.refresh(template_create_info) + return template_create_info.id + + +async def create_slide_layout( + request: CreateSlideLayoutRequest = Body(...), + sql_session: AsyncSession = Depends(get_async_session), +): + template_info = await sql_session.get(TemplateCreateInfoModel, request.id) + if not template_info: + raise HTTPException(status_code=400, detail="Template not found") + + total_slides = len(template_info.slide_htmls) + if request.index < 0 or request.index >= total_slides: + raise HTTPException(status_code=400, detail="Invalid slide index") + + slide_html = template_info.slide_htmls[request.index] + slide_image_url = template_info.slide_image_urls[request.index] + image_bytes, media_type = await _read_image_bytes_and_media_type(slide_image_url) + + fonts_text = "" + if template_info.fonts: + font_names = [font.replace(" ", "_") for font in template_info.fonts.keys()] + fonts_text = "#PROVIDED FONTS\n- " + "\n- ".join(font_names) + + user_text = f"{fonts_text}\n\n#SLIDE HTML REFERENCE\n{slide_html}" + react_component = await generate_slide_layout_code( + system_prompt=SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT, + user_text=user_text, + image_bytes=image_bytes, + media_type=media_type, + ) + + return CreateSlideLayoutResponse( + react_component=_normalize_layout_code_for_create(react_component) + ) + + +async def edit_slide_layout( + request: EditSlideLayoutRequest, +): + user_text = f"#Prompt\n{request.prompt}\n\n#TSX code\n{request.react_component}" + react_component = await edit_slide_layout_code( + system_prompt=SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT, + user_text=user_text, + ) + return EditSlideLayoutResponse(react_component=_strip_code_fences(react_component)) + + +async def edit_slide_layout_section( + request: EditSlideLayoutSectionRequest, +): + user_text = ( + f"#Prompt\n{request.prompt}\n\n" + f"#Section to make changes 
around\n{request.section}\n\n" + f"#TSX code\n{request.react_component}" + ) + react_component = await edit_slide_layout_code( + system_prompt=SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT, + user_text=user_text, + ) + return EditSlideLayoutSectionResponse( + react_component=_strip_code_fences(react_component) + ) + + +async def save_template( + request: SaveTemplateRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + if not request.layouts: + raise HTTPException(status_code=400, detail="Layouts are required") + + template_info = await sql_session.get(TemplateCreateInfoModel, request.template_info_id) + if not template_info: + raise HTTPException(status_code=400, detail="Template info not found") + + template = TemplateModel( + id=uuid.uuid4(), + name=request.name, + description=request.description, + ) + sql_session.add(template) + + sql_session.add_all( + [ + PresentationLayoutCodeModel( + presentation=template.id, + layout_id=layout.layout_id, + layout_name=layout.layout_name, + layout_code=layout.layout_code, + fonts=template_info.fonts, + ) + for layout in request.layouts + ] + ) + await sql_session.commit() + await sql_session.refresh(template) + + return SaveTemplateResponse( + id=template.id, + name=template.name, + description=template.description, + created_at=template.created_at, + ) + + +async def clone_template( + request: CloneTemplateRequest = Body(...), + sql_session: AsyncSession = Depends(get_async_session), +): + if not request.id or not request.id.strip(): + raise HTTPException(status_code=400, detail="Template ID cannot be empty") + + try: + template_id_uuid = uuid.UUID(request.id.replace("custom-", "")) + except Exception as exc: + raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc + + template = await sql_session.get(TemplateModel, template_id_uuid) + if not template: + raise HTTPException( + status_code=400, + detail="Template not found. 
Please use a valid template.", + ) + + result = await sql_session.execute( + select(PresentationLayoutCodeModel).where( + PresentationLayoutCodeModel.presentation == template_id_uuid + ) + ) + layouts_db = result.scalars().all() + if not layouts_db: + raise HTTPException(status_code=400, detail="No layouts found for template") + + new_template = TemplateModel( + id=uuid.uuid4(), + name=request.name, + description=template.description + if request.description is None + else request.description, + ) + sql_session.add(new_template) + + sql_session.add_all( + [ + PresentationLayoutCodeModel( + presentation=new_template.id, + layout_id=layout.layout_id, + layout_name=layout.layout_name, + layout_code=layout.layout_code, + fonts=layout.fonts, + ) + for layout in layouts_db + ] + ) + await sql_session.commit() + await sql_session.refresh(new_template) + + return SaveTemplateResponse( + id=new_template.id, + name=new_template.name, + description=new_template.description, + created_at=new_template.created_at, + ) + + +async def update_template( + request: UpdateTemplateRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + if not request.layouts: + raise HTTPException(status_code=400, detail="Layouts are required") + + template = await sql_session.get(TemplateModel, request.id) + if not template: + raise HTTPException(status_code=400, detail="Template not found") + + existing_layout = await sql_session.scalar( + select(PresentationLayoutCodeModel).where( + PresentationLayoutCodeModel.presentation == request.id + ) + ) + fonts = existing_layout.fonts if existing_layout else None + + await sql_session.execute( + delete(PresentationLayoutCodeModel).where( + PresentationLayoutCodeModel.presentation == request.id + ) + ) + sql_session.add_all( + [ + PresentationLayoutCodeModel( + presentation=template.id, + layout_id=layout.layout_id, + layout_name=layout.layout_name, + layout_code=layout.layout_code, + fonts=fonts, + ) + for layout in request.layouts + ] + ) + 
await sql_session.commit() + + return SaveTemplateResponse( + id=template.id, + name=template.name, + description=template.description, + created_at=template.created_at, + ) + + +async def save_slide_layout( + request: SaveSlideLayoutRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + template = await sql_session.get(TemplateModel, request.template_id) + if not template: + raise HTTPException(status_code=400, detail="Template not found") + + layout = await sql_session.scalar( + select(PresentationLayoutCodeModel).where( + PresentationLayoutCodeModel.presentation == request.template_id, + PresentationLayoutCodeModel.layout_id == request.layout_id, + ) + ) + if not layout: + raise HTTPException(status_code=400, detail="Layout not found") + + layout.layout_code = request.layout_code + sql_session.add(layout) + await sql_session.commit() + + +async def clone_slide_layout( + request: CloneSlideLayoutRequest = Body(...), + sql_session: AsyncSession = Depends(get_async_session), +): + if not request.template_id or not request.template_id.strip(): + raise HTTPException(status_code=400, detail="Template ID cannot be empty") + + try: + template_id_uuid = uuid.UUID(request.template_id.replace("custom-", "")) + except Exception as exc: + raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc + + template = await sql_session.get(TemplateModel, template_id_uuid) + if not template: + raise HTTPException(status_code=400, detail="Template not found") + + layout = await sql_session.scalar( + select(PresentationLayoutCodeModel).where( + PresentationLayoutCodeModel.presentation == template_id_uuid, + PresentationLayoutCodeModel.layout_id == request.layout_id, + ) + ) + if not layout: + raise HTTPException(status_code=400, detail="Layout not found") + + new_layout_code, new_layout_id = _update_layout_id_in_code(layout.layout_code) + new_layout = PresentationLayoutCodeModel( + presentation=template_id_uuid, + layout_id=new_layout_id, + 
layout_name=request.layout_name or layout.layout_name, + layout_code=new_layout_code, + fonts=layout.fonts, + ) + sql_session.add(new_layout) + await sql_session.commit() + await sql_session.refresh(new_layout) + + return SaveTemplateLayoutData( + layout_id=new_layout.layout_id, + layout_name=new_layout.layout_name, + layout_code=new_layout.layout_code, + ) diff --git a/electron/servers/fastapi/templates/pptx_html_stub.py b/electron/servers/fastapi/templates/pptx_html_stub.py new file mode 100644 index 00000000..e07cb462 --- /dev/null +++ b/electron/servers/fastapi/templates/pptx_html_stub.py @@ -0,0 +1,30 @@ +BASIC_TEMPLATE_HTML = """ + +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+""".strip() diff --git a/electron/servers/fastapi/models/presentation_layout.py b/electron/servers/fastapi/templates/presentation_layout.py similarity index 78% rename from electron/servers/fastapi/models/presentation_layout.py rename to electron/servers/fastapi/templates/presentation_layout.py index 784e41fc..2bf9a4cb 100644 --- a/electron/servers/fastapi/models/presentation_layout.py +++ b/electron/servers/fastapi/templates/presentation_layout.py @@ -1,4 +1,5 @@ from typing import List, Optional + from fastapi import HTTPException from pydantic import BaseModel, Field @@ -25,15 +26,15 @@ class PresentationLayoutModel(BaseModel): status_code=404, detail=f"Slide layout {slide_layout_id} not found" ) - def to_presentation_structure(self): + def to_presentation_structure(self) -> PresentationStructureModel: return PresentationStructureModel( slides=[index for index in range(len(self.slides))] ) - def to_string(self): - message = f"## Presentation Layout\n\n" + def to_string(self) -> str: + message = "## Presentation Layout\n\n" for index, slide in enumerate(self.slides): - message += f"### Slide Layout: {index}: \n" - message += f"- Name: {slide.name or slide.json_schema.get('title')} \n" - message += f"- Description: {slide.description} \n\n" + message += f"### Slide Layout: {index}\n" + message += f"- Name: {slide.name or slide.json_schema.get('title')}\n" + message += f"- Description: {slide.description}\n\n" return message diff --git a/electron/servers/fastapi/templates/preview.py b/electron/servers/fastapi/templates/preview.py new file mode 100644 index 00000000..0a045319 --- /dev/null +++ b/electron/servers/fastapi/templates/preview.py @@ -0,0 +1,477 @@ +import asyncio +from dataclasses import dataclass +import os +import re +import shutil +import subprocess +import tempfile +import uuid +import zipfile +from pathlib import Path +from typing import Dict, List, Optional + +from fastapi import File, HTTPException, UploadFile +from pydantic import BaseModel + 
+from constants.documents import PPTX_MIME_TYPES +from services.documents_loader import DocumentsLoader +from templates.font_utils import ( + collect_normalized_fonts_from_xmls, + get_available_and_unavailable_fonts, +) +from utils.get_env import get_app_data_directory_env + +try: + from fontTools.ttLib import TTFont + + FONTTOOLS_AVAILABLE = True +except ImportError: + FONTTOOLS_AVAILABLE = False + + +SUPPORTED_FONT_EXTENSIONS = { + ".ttf": "font/ttf", + ".otf": "font/otf", + ".woff": "font/woff", + ".woff2": "font/woff2", + ".eot": "application/vnd.ms-fontobject", +} + + +class FontInfo(BaseModel): + name: str + url: str | None = None + + +class FontCheckResponse(BaseModel): + available_fonts: List[FontInfo] + unavailable_fonts: List[FontInfo] + + +class FontsUploadAndSlidesPreviewResponse(BaseModel): + slide_image_urls: List[str] + pptx_url: str + modified_pptx_url: str + fonts: dict + + +@dataclass +class StoredFont: + display_name: str + url: str + temp_path: str + + +def _get_soffice_binary() -> str: + configured = os.environ.get("SOFFICE_PATH") + if configured: + return configured + return "soffice.exe" if os.name == "nt" else "soffice" + + +def _windows_hidden_subprocess_kwargs() -> Dict[str, object]: + if os.name != "nt": + return {} + + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + return { + "creationflags": getattr(subprocess, "CREATE_NO_WINDOW", 0), + "startupinfo": startupinfo, + } + + +def _app_data_directory() -> str: + app_data_dir = get_app_data_directory_env() or "/tmp/presenton" + os.makedirs(app_data_dir, exist_ok=True) + return app_data_dir + + +def _get_fonts_directory() -> str: + fonts_dir = os.path.join(_app_data_directory(), "fonts") + os.makedirs(fonts_dir, exist_ok=True) + return fonts_dir + + +def _get_images_directory() -> str: + images_dir = os.path.join(_app_data_directory(), "images") + os.makedirs(images_dir, exist_ok=True) + return images_dir + + +def 
_get_template_uploads_directory() -> str: + uploads_dir = os.path.join(_app_data_directory(), "uploads", "template-previews") + os.makedirs(uploads_dir, exist_ok=True) + return uploads_dir + + +def _write_bytes_to_path(path: str, data: bytes) -> None: + with open(path, "wb") as file: + file.write(data) + + +def _copy_file(source_path: str, destination_path: str) -> None: + shutil.copy2(source_path, destination_path) + + +def _extract_font_name_from_file(file_path: str) -> str: + filename = os.path.basename(file_path) + base_name = os.path.splitext(filename)[0] + if not FONTTOOLS_AVAILABLE: + return base_name + + try: + font = TTFont(file_path) + if "name" in font: + name_table = font["name"] + for name_id in (1, 4, 6): + for record in name_table.names: + if record.nameID != name_id: + continue + if record.langID in (0x409, 0): + font_name = record.toUnicode().strip() + if font_name: + font.close() + return font_name + for record in name_table.names: + if record.nameID != 1: + continue + font_name = record.toUnicode().strip() + if font_name: + font.close() + return font_name + font.close() + except Exception: + pass + + return base_name + + +def _validate_pptx_file(pptx_file: UploadFile) -> None: + filename = getattr(pptx_file, "filename", "") or "" + if not filename.lower().endswith(".pptx"): + raise HTTPException( + status_code=400, + detail="Invalid file type. Expected PPTX file", + ) + if pptx_file.content_type and pptx_file.content_type not in PPTX_MIME_TYPES: + raise HTTPException( + status_code=400, + detail=f"Invalid file type. Expected PPTX file, got {pptx_file.content_type}", + ) + + +def _ensure_valid_font_file(font_file: UploadFile) -> None: + filename = font_file.filename or "" + extension = os.path.splitext(filename)[1].lower() + if extension not in SUPPORTED_FONT_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Invalid font file. 
Supported formats: {', '.join(SUPPORTED_FONT_EXTENSIONS.keys())}", + ) + + +async def _persist_custom_fonts( + font_files: Optional[List[UploadFile]], + original_font_names: Optional[List[str]], + temp_dir: str, +) -> list[StoredFont]: + if not font_files: + return [] + + stored_fonts: list[StoredFont] = [] + fonts_dir = _get_fonts_directory() + + for index, font_file in enumerate(font_files): + _ensure_valid_font_file(font_file) + + original_name = ( + original_font_names[index] + if original_font_names and index < len(original_font_names) + else None + ) + extension = os.path.splitext(font_file.filename or "")[1].lower() + unique_name = f"{Path(font_file.filename or f'font_{index}').stem}_{uuid.uuid4().hex[:8]}{extension}" + temp_font_path = os.path.join(temp_dir, unique_name) + permanent_font_path = os.path.join(fonts_dir, unique_name) + font_bytes = await font_file.read() + + await asyncio.to_thread(_write_bytes_to_path, temp_font_path, font_bytes) + await asyncio.to_thread(_write_bytes_to_path, permanent_font_path, font_bytes) + + actual_font_name = await asyncio.to_thread( + _extract_font_name_from_file, permanent_font_path + ) + display_name = original_name or actual_font_name + stored_fonts.append( + StoredFont( + display_name=display_name, + url=f"/app_data/fonts/{unique_name}", + temp_path=temp_font_path, + ) + ) + + return stored_fonts + + +def _create_font_alias_config(raw_fonts: List[str]) -> str: + mappings: Dict[str, str] = {} + for font_name in raw_fonts: + normalized = font_name + if not normalized: + continue + mappings[font_name] = normalized + + fd, fonts_conf_path = tempfile.mkstemp(prefix="fonts_alias_", suffix=".conf") + os.close(fd) + with open(fonts_conf_path, "w", encoding="utf-8") as cfg: + cfg.write( + """ + + + /etc/fonts/fonts.conf +""" + ) + for source_family, destination_family in mappings.items(): + if source_family == destination_family: + continue + cfg.write( + f""" + + + {source_family} + + + {destination_family} + + +""" + ) + 
cfg.write("\n\n") + return fonts_conf_path + + +async def _install_fonts(font_paths: List[str]) -> None: + if not font_paths: + return + + for font_path in font_paths: + try: + subprocess.run( + ["cp", font_path, "/usr/share/fonts/truetype/"], + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError: + continue + + try: + subprocess.run(["fc-cache", "-f", "-v"], check=True, capture_output=True) + except subprocess.CalledProcessError: + pass + + +def extract_slide_xmls(pptx_path: str, temp_dir: str) -> List[str]: + slide_xmls: list[str] = [] + extract_dir = os.path.join(temp_dir, "pptx_extract") + + with zipfile.ZipFile(pptx_path, "r") as zip_ref: + zip_ref.extractall(extract_dir) + + slides_dir = os.path.join(extract_dir, "ppt", "slides") + if not os.path.exists(slides_dir): + raise HTTPException(status_code=400, detail="No slides directory found in PPTX") + + slide_files = [ + file_name + for file_name in os.listdir(slides_dir) + if file_name.startswith("slide") and file_name.endswith(".xml") + ] + slide_files.sort(key=lambda value: int(re.sub(r"[^0-9]", "", value) or "0")) + + for slide_file in slide_files: + slide_path = os.path.join(slides_dir, slide_file) + with open(slide_path, "r", encoding="utf-8") as slide_handle: + slide_xmls.append(slide_handle.read()) + + return slide_xmls + + +async def convert_pptx_to_pdf( + pptx_path: str, + temp_dir: str, + slide_xmls: Optional[List[str]] = None, +) -> str: + screenshots_dir = os.path.join(temp_dir, "screenshots") + os.makedirs(screenshots_dir, exist_ok=True) + + slide_xmls = slide_xmls or extract_slide_xmls(pptx_path, temp_dir) + raw_fonts = collect_normalized_fonts_from_xmls(slide_xmls) + fonts_conf_path = _create_font_alias_config(raw_fonts) + env = os.environ.copy() + env["FONTCONFIG_FILE"] = fonts_conf_path + + try: + subprocess.run( + [ + _get_soffice_binary(), + "--headless", + "--convert-to", + "pdf", + "--outdir", + screenshots_dir, + pptx_path, + ], + check=True, + 
capture_output=True, + text=True, + timeout=500, + env=env, + **_windows_hidden_subprocess_kwargs(), + ) + except subprocess.TimeoutExpired as exc: + raise HTTPException( + status_code=500, + detail="LibreOffice PDF conversion timed out after 500 seconds", + ) from exc + except subprocess.CalledProcessError as exc: + error_message = exc.stderr if exc.stderr else str(exc) + raise HTTPException( + status_code=500, + detail=f"LibreOffice PDF conversion failed: {error_message}", + ) from exc + + pdf_files = [file_name for file_name in os.listdir(screenshots_dir) if file_name.endswith(".pdf")] + if not pdf_files: + raise HTTPException( + status_code=500, detail="LibreOffice failed to generate a PDF file" + ) + + return os.path.join(screenshots_dir, pdf_files[0]) + + +async def store_slide_images( + screenshot_paths: List[str], + session_id: uuid.UUID, +) -> List[str]: + images_dir = _get_images_directory() + target_dir = os.path.join(images_dir, str(session_id)) + os.makedirs(target_dir, exist_ok=True) + + slide_image_urls: list[str] = [] + for index, screenshot_path in enumerate(screenshot_paths, start=1): + file_name = f"slide_{index}.png" + destination_path = os.path.join(target_dir, file_name) + + if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 0: + await asyncio.to_thread(_copy_file, screenshot_path, destination_path) + slide_image_urls.append(f"/app_data/images/{session_id}/{file_name}") + else: + slide_image_urls.append("/static/images/placeholder.jpg") + + return slide_image_urls + + +async def store_uploaded_pptx( + pptx_path: str, + session_id: uuid.UUID, +) -> str: + uploads_dir = _get_template_uploads_directory() + target_dir = os.path.join(uploads_dir, str(session_id)) + os.makedirs(target_dir, exist_ok=True) + + destination_path = os.path.join(target_dir, "presentation.pptx") + await asyncio.to_thread(_copy_file, pptx_path, destination_path) + return f"/app_data/uploads/template-previews/{session_id}/presentation.pptx" + + +async 
def get_available_and_unavailable_fonts_for_pptx( + pptx_path: str, temp_dir: str +) -> tuple[list[tuple[str, str]], list[tuple[str, None]]]: + slide_xmls = extract_slide_xmls(pptx_path, temp_dir) + normalized_fonts = collect_normalized_fonts_from_xmls(slide_xmls) + return await get_available_and_unavailable_fonts(normalized_fonts) + + +async def check_fonts_in_pptx_handler( + pptx_file: UploadFile = File(..., description="PPTX file to analyze fonts from") +) -> FontCheckResponse: + _validate_pptx_file(pptx_file) + + with tempfile.TemporaryDirectory() as temp_dir: + pptx_path = os.path.join(temp_dir, "presentation.pptx") + pptx_content = await pptx_file.read() + await asyncio.to_thread(_write_bytes_to_path, pptx_path, pptx_content) + + available_fonts_data, unavailable_fonts_data = ( + await get_available_and_unavailable_fonts_for_pptx(pptx_path, temp_dir) + ) + + return FontCheckResponse( + available_fonts=[ + FontInfo(name=name, url=url) for name, url in available_fonts_data + ], + unavailable_fonts=[ + FontInfo(name=name, url=url) for name, url in unavailable_fonts_data + ], + ) + + +async def upload_fonts_and_slides_preview_handler( + pptx_file: UploadFile, + font_files: Optional[List[UploadFile]] = None, + original_font_names: Optional[List[str]] = None, + max_slides: Optional[int] = None, +) -> FontsUploadAndSlidesPreviewResponse: + if (font_files and not original_font_names) or ( + original_font_names and not font_files + ): + raise HTTPException( + status_code=400, + detail="Both font_files and original_font_names must be provided together", + ) + if font_files and original_font_names and len(font_files) != len(original_font_names): + raise HTTPException( + status_code=400, + detail="Number of font files must match number of original font names", + ) + + _validate_pptx_file(pptx_file) + + with tempfile.TemporaryDirectory() as temp_dir: + pptx_path = os.path.join(temp_dir, "presentation.pptx") + pptx_content = await pptx_file.read() + await 
asyncio.to_thread(_write_bytes_to_path, pptx_path, pptx_content) + + stored_fonts = await _persist_custom_fonts( + font_files=font_files, + original_font_names=original_font_names, + temp_dir=temp_dir, + ) + await _install_fonts([font.temp_path for font in stored_fonts]) + + slide_xmls = extract_slide_xmls(pptx_path, temp_dir) + pdf_path = await convert_pptx_to_pdf(pptx_path, temp_dir, slide_xmls=slide_xmls) + screenshot_paths = await DocumentsLoader.get_page_images_from_pdf_async( + pdf_path, temp_dir + ) + + if max_slides and len(screenshot_paths) > max_slides: + screenshot_paths = screenshot_paths[:max_slides] + + session_id = uuid.uuid4() + slide_image_urls = await store_slide_images(screenshot_paths, session_id) + pptx_url = await store_uploaded_pptx(pptx_path, session_id) + + available_fonts, _ = await get_available_and_unavailable_fonts( + collect_normalized_fonts_from_xmls(slide_xmls) + ) + fonts: dict[str, str] = {name: url for name, url in available_fonts} + fonts.update({font.display_name: font.url for font in stored_fonts}) + + return FontsUploadAndSlidesPreviewResponse( + slide_image_urls=slide_image_urls, + pptx_url=pptx_url, + modified_pptx_url=pptx_url, + fonts=fonts, + ) diff --git a/electron/servers/fastapi/templates/prompts.py b/electron/servers/fastapi/templates/prompts.py new file mode 100644 index 00000000..619f9d26 --- /dev/null +++ b/electron/servers/fastapi/templates/prompts.py @@ -0,0 +1,219 @@ +SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT = """ +You need to generate a Zod schema and a TSX React component and provide it as output. +Provide reusable TSX code which can be used as template to generate new slides with different content. + +# Steps: +1. Analyze the slide image to understand the visual hierarchy. +3. Classify elements into decorative and content elements. +4. Group content elements into logical sections like Header, Body, BulletPoints, etc. +5. Generate a Zod schema for the content elements. +6. 
Generate id, name and description for the layout. +6. Generate a TSX React component using the Zod schema and the HTML reference. + +# Decorative Elements: +- Arrows, Lines, Shapes, etc. +- Images with Grid patterns, background patterns, gradients, solid colors, etc. +- Background of infographics like funnel, timeline, etc. +- Company name, logos, etc. +- Images covering the entire slide. +- Images containing company name, logos, etc. + +# Decorative Elements Rules: +- Use them exactly as they are in the HTML reference. +- Do not change decorative images and icons urls. +- Images containing company name, logos, etc should be identified as decorative elements. + +# Content Elements: +- Title, Description, BulletPoints, etc. +- Graphs, Charts, etc. +- Images and Icons representing textual content like title, description, bullet points, etc. +- Meaningful Images and Icons. +- Icons in infographics that represent the data. + +# Content Elements Rules: +- Properly identify between images and icons elements. +- Image content: + - Image field should be 'z.object({"image_url": z.string(), "image_prompt": z.string().max(100)})' + - Replace actual image url with 'https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png' +- Icon content: + - Icon field should be 'z.object({"icon_url": z.string(), "icon_query": z.string().max(30)})' + - Replace actual icon url with 'https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg' + - Add color styling to the icon to match the color in the image. +- Make sure the urls are correct. + +# Layout Rules: +- The layout should be fixed 1280px width and 720px height. +- Adjust the positions and sizes of elements to fit the layout. +- Try to keep the positions and sizes of elements as close to HTML reference as possible. + +# Flexible Positioning and Sizes Rules: +- Must not use 'absolute' positioning for elements. 
+- Must use 'flex', 'grid', 'margin', 'padding', 'gap', 'basis', 'justify', 'align', etc for positioning of elements. +- For variable length lists, wrap list into a container and center it. +- Don't use specific sizes (height, width) for elements if not necessary. + +# Schema Field Name and Description Rules: +- Must not use content specific words. +- Only use words based on what content types are present in the slide image. +- Use words like 'title', 'description', 'heading', 'image', 'graph', 'table', 'bullet points', etc. +- Must not use words like 'budget', 'market', 'revenue', 'sales', 'growth', 'workflow', 'channel', 'plannedValue', 'actualValue', etc. + +# Layout ID, Name and Description Rules: +- Must only use slide structure to derive layout id, name and description. +- Information such as the type of content, position of content, etc. should be used. +- layoutId example: title-description-right-image. +- layoutName example: Title Description Image. +- layoutDescription example: A slide with a title, description, and an image on right. + +# Zod Schema Rules: +- "describe" must be added for every field. +- "default" must be added in top level fields of schema. +- Top level fields are those not nested inside other fields. +- Don't mention string type in schema like "url()", "email()", etc. +- Table must be object with "columns" and "rows" fields. +- "columns" must be an array of strings. +- "rows" must be an array of arrays of strings. +- Graph must be object with "categories" and "series" fields. +- "categories" must be an array of strings. +- "series" must be an array of objects with {"name": string, "data": array of numbers}. +- Must not use z.record() anywhere in the schema. + +# String and Array Field Rules: +- Every string field must include `.max(...)`; every array field must include `.max(...)`. +- For strings, set `max` to the exact character count of the text content it represents. 
+- For arrays, set `max` to the exact item count of the array content it represents. +- Choose a `max` that keeps the longest allowed content from overflowing its container. + +# Table Rules: +- Construct "tr -> th" by iterating over the "columns" field. +- Construct "tr -> td" by iterating over the "rows" field. +- Make sure table height and width adjust to fit the content. + +# Graphs, Charts, etc Rules: +- Identify if graphs, charts, etc are present in the slide image. +- Identify the type of graph, chart, etc. +- If present, generate a zod schema for the graph, chart, etc. +- Generate TSX code for the graph, chart, etc. even if it is not present in the HTML reference. +- Use graph schema and image to generate the TSX code. +- Use Recharts library for graphs. + +# Fonts Rules: +- Check for "PROVIDED FONTS". +- Must use fonts only from "PROVIDED FONTS". +- Add "font-[\"font-name\"]" to every text element in the slide. + +# Page Number Rules: +- Identify if the slide contains page number from provided HTML reference and image. +- If page number is present, add a "page: z.number().min(1).meta({ description: "Page number" })" field in the schema. + +# React Component Rules: +- React component must be named dynamicSlideLayout. +- dynamicSlideLayout must take "{ data }: { data: Partial<z.infer<typeof Schema>> }" as props. +- Wrap the code inside these classes: "relative w-full rounded-sm max-w-[1280px] shadow-lg max-h-[720px] aspect-video bg-white z-20 mx-auto overflow-hidden". +- Make sure camelCase is used for all styles. E.g. "letter-spacing" should be "letterSpacing". +- Schema.parse must not be used in the code. +- Use 'const {field1, field2, ...} = data;' to access the data. +- field1 or field2 or ... can be undefined, so use optional chaining to access them. +- Don't use "min-height" on cards and instead make its height grow/shrink to fit the content. +- Make sure cards/items are centered vertically and horizontally in the available space. +- Make sure no element is scrollable. 
+- Don't add any animations, transitions, or effects. +- Make sure no content elements are overflowing the slide boundaries. + +# Import and Export Rules: +- All import statements must be defined at the top. +- Export using 'export {Schema, layoutId, layoutName, layoutDescription, dynamicSlideLayout}' statement at the bottom. +- There must be only one 'export' statement in the whole TSX code. + +# Output Code Rules: +- Code should be in following order: + - Zod Schema (Schema) + - Layout ID, Name and Description (layoutId, layoutName, layoutDescription) + - React Component (dynamicSlideLayout) +- Give just one valid TSX code as output. +- Don't add comments in the code. +- Make sure the generated code is valid TSX code. +- Give only code as output and nothing else. (no json, no markdown, no text, no explanation) + +- Go through generated code and make sure all rules are followed. +- Think as long as you can and iterate as many times as necessary to make sure all rules are followed. +""" + +SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT = """ +You need to edit the given TSX code of the slide layout code according to the prompt and provide it as output. + +# Steps +1. Analyze the TSX code to understand the slide layout. +2. Analyze the prompt to understand the changes to be made. +3. Edit the TSX code according to the prompt. +4. Provide the updated TSX code as output. + +# Rules +- Make sure the changes does not break the existing code. +- Make sure to follow the pattern of the existing code. +- Make sure there are no unused schema fields after the changes are made. 
+ +# Icons and Images Rules +Follow these rules if new icons/images are asked: +- Image field should be 'z.object({"image_url": z.string(), "image_prompt": z.string().max(100)})' +- Use this as default image url: 'https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png' +- Icon field should be 'z.object({"icon_url": z.string(), "icon_query": z.string().max(30)})' +- Use this as default icon url: 'https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg' + +# Schema Rules +- "describe" must be added for every fields. +- "default" must be added in top level fields of schema. +- Top level fields are those not nested inside other fields. +- Must set max for every string and array fields. +- Must set max to a number that will not cause overflow on max content. + +# Graphs And Table Rules +Follow these rules if new graphs/tables are asked: +1. Schema Rules +- Table must be object with "columns" and "rows" fields. +- "columns" must be an array of strings. +- "rows" must be an array of arrays of strings. +- Graph must be object with "categories" and "series" fields. +- "categories" must be an array of strings. +- "series" must be an array of objects with {"name": string, "data": array of numbers}. +2. React Component Rules +- Use recharts library for graphs. + +# Common Prompts +1. Fix the slide +- Check if text/cards/items is overflowing the slide boundaries or text/cards/items are overlapping. +- If yes, fix by moving the element to a better position or resizing the element. + +# Output Rules +- Make sure the schema and react component are valid. +- No matter what prompt is given, don't break the code. +- Provide only the updated TSX code as output and nothing else. (no json, no markdown, no text, no explanation) +""" + +SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT = """ +You need to edit the given TSX code of the slide layout code according to the prompt and provide it as output. + +# Steps +1. 
Analyze the TSX code to understand the slide layout. +2. Analyze the prompt to understand the changes to be made. +3. Edit the TSX code according to the prompt. +4. Provide the updated TSX code as output. + +# Rules +- Changes should be made only around the mentioned "section to make changes around". +- Make sure the changes does not break the existing code. +- Make sure to follow the pattern of the existing code. +- Make sure there are no unused schema fields after the changes are made. + +# Icons and Images Rules +Follow these rules if new icons/images are asked: +- Image field should be 'z.object({"image_url": z.string(), "image_prompt": z.string().max(100)})' +- Use this as default image url: 'https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png' +- Icon field should be 'z.object({"icon_url": z.string(), "icon_query": z.string().max(30)})' +- Use this as default icon url: 'https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg' + +# Output Rules +- Make sure the schema and react component are valid. +- No matter what prompt is given, don't break the code. +- Provide only the updated TSX code as output and nothing else. 
(no json, no markdown, no text, no explanation) +""" diff --git a/electron/servers/fastapi/templates/providers.py b/electron/servers/fastapi/templates/providers.py new file mode 100644 index 00000000..452c0bf5 --- /dev/null +++ b/electron/servers/fastapi/templates/providers.py @@ -0,0 +1,365 @@ +import asyncio +import base64 +from dataclasses import dataclass +import time +from typing import Awaitable, Callable, Optional + +from anthropic import AsyncAnthropic +from fastapi import HTTPException +from google import genai +from google.genai import types as google_types +from openai import AsyncOpenAI + +from enums.llm_provider import LLMProvider +from utils.get_env import ( + get_anthropic_api_key_env, + get_codex_access_token_env, + get_codex_account_id_env, + get_codex_refresh_token_env, + get_codex_token_expires_env, + get_google_api_key_env, + get_openai_api_key_env, +) +from utils.llm_provider import get_llm_provider +from utils.set_env import ( + set_codex_access_token_env, + set_codex_account_id_env, + set_codex_refresh_token_env, + set_codex_token_expires_env, +) + +OPENAI_TEMPLATE_MODEL = "gpt-5.4" +CODEX_TEMPLATE_MODEL = "gpt-5.4" +GOOGLE_TEMPLATE_MODEL = "gemini-3.1" +ANTHROPIC_TEMPLATE_MODEL = "opus 4.6" +MAX_ATTEMPTS_PER_PROVIDER = 4 + + +@dataclass(frozen=True) +class TemplateProviderSpec: + provider: LLMProvider + model: str + + +@dataclass(frozen=True) +class PlainLLMProvider: + name: str + call: Callable[[], Awaitable[str]] + + +def get_template_provider_spec() -> TemplateProviderSpec: + provider = get_llm_provider() + if provider == LLMProvider.OPENAI: + return TemplateProviderSpec(provider=provider, model=OPENAI_TEMPLATE_MODEL) + if provider == LLMProvider.CODEX: + return TemplateProviderSpec(provider=provider, model=CODEX_TEMPLATE_MODEL) + if provider == LLMProvider.GOOGLE: + return TemplateProviderSpec(provider=provider, model=GOOGLE_TEMPLATE_MODEL) + if provider == LLMProvider.ANTHROPIC: + return TemplateProviderSpec(provider=provider, 
model=ANTHROPIC_TEMPLATE_MODEL) + + raise HTTPException( + status_code=400, + detail="Template generation only supports OpenAI, Codex, Google, or Anthropic.", + ) + + +async def run_plain_provider_buckets(*, providers: list[PlainLLMProvider]) -> str: + last_exception: Optional[Exception] = None + + for provider in providers: + for _ in range(MAX_ATTEMPTS_PER_PROVIDER): + try: + response_text = await provider.call() + if response_text: + return response_text + raise ValueError("No output from template generation provider") + except Exception as exc: + last_exception = exc + + if isinstance(last_exception, HTTPException): + raise last_exception + raise HTTPException(status_code=500, detail="Failed to generate template output") + + +def _read_openai_response_text(response) -> str: + output_text = getattr(response, "output_text", None) + if output_text: + return output_text + text = getattr(response, "text", None) + if text: + return text + return "" + + +def _get_openai_client() -> AsyncOpenAI: + api_key = get_openai_api_key_env() + if not api_key: + raise HTTPException(status_code=400, detail="OPENAI_API_KEY is not set") + return AsyncOpenAI(api_key=api_key, timeout=120.0) + + +def _get_codex_headers() -> dict: + access_token = get_codex_access_token_env() + if not access_token: + raise HTTPException( + status_code=400, + detail="Codex OAuth access token is not set. 
Please authenticate via /api/v1/ppt/codex/auth/initiate", + ) + + expires_str = get_codex_token_expires_env() + if expires_str: + try: + expires_ms = int(expires_str) + now_ms = int(time.time() * 1000) + if now_ms >= expires_ms - 60_000: + refresh_token = get_codex_refresh_token_env() + if refresh_token: + from utils.oauth.openai_codex import ( + TokenSuccess, + get_account_id, + refresh_access_token, + ) + + result = refresh_access_token(refresh_token) + if isinstance(result, TokenSuccess): + set_codex_access_token_env(result.access) + set_codex_refresh_token_env(result.refresh) + set_codex_token_expires_env(str(result.expires)) + account_id = get_account_id(result.access) + if account_id: + set_codex_account_id_env(account_id) + access_token = result.access + except (TypeError, ValueError): + pass + + account_id = get_codex_account_id_env() or "" + return { + "Authorization": f"Bearer {access_token}", + "chatgpt-account-id": account_id, + "OpenAI-Beta": "responses=experimental", + "originator": "pi", + } + + +def _get_codex_client() -> AsyncOpenAI: + headers = _get_codex_headers() + access_token = (headers.get("Authorization") or "").replace("Bearer ", "").strip() + default_headers = { + key: value + for key, value in headers.items() + if key.lower() not in {"authorization", "content-type", "accept"} + } + return AsyncOpenAI( + base_url="https://chatgpt.com/backend-api/codex", + api_key=access_token or "codex", + default_headers=default_headers, + timeout=120.0, + ) + + +def _get_google_client() -> genai.Client: + api_key = get_google_api_key_env() + if not api_key: + raise HTTPException(status_code=400, detail="GOOGLE_API_KEY is not set") + return genai.Client(api_key=api_key) + + +def _get_anthropic_client() -> AsyncAnthropic: + api_key = get_anthropic_api_key_env() + if not api_key: + raise HTTPException(status_code=400, detail="ANTHROPIC_API_KEY is not set") + return AsyncAnthropic(api_key=api_key) + + +async def _call_openai_like( + *, + client: AsyncOpenAI, 
+ model: str, + system_prompt: str, + user_text: str, + image_bytes: Optional[bytes] = None, + media_type: str = "image/png", + reasoning_effort: str = "medium", +) -> str: + content = [{"type": "input_text", "text": user_text}] + if image_bytes: + content.insert( + 0, + { + "type": "input_image", + "image_url": f"data:{media_type};base64,{base64.b64encode(image_bytes).decode('utf-8')}", + }, + ) + + response = await client.responses.create( + model=model, + instructions=system_prompt, + input=[{"role": "user", "content": content}], + reasoning={"effort": reasoning_effort}, + text={"verbosity": "medium"}, + store=False, + ) + output_text = _read_openai_response_text(response) + if not output_text: + raise HTTPException(status_code=500, detail="No output from template provider") + return output_text + + +async def _call_google( + *, + model: str, + system_prompt: str, + user_text: str, + image_bytes: Optional[bytes] = None, + media_type: str = "image/png", +) -> str: + client = _get_google_client() + parts = [google_types.Part.from_text(text=user_text)] + if image_bytes: + parts.append(google_types.Part.from_bytes(data=image_bytes, mime_type=media_type)) + + response = await asyncio.to_thread( + client.models.generate_content, + model=model, + contents=[google_types.Content(role="user", parts=parts)], + config=google_types.GenerateContentConfig( + system_instruction=system_prompt, + response_mime_type="text/plain", + ), + ) + output_text = getattr(response, "text", None) or "" + if not output_text: + raise HTTPException(status_code=500, detail="No output from template provider") + return output_text + + +async def _call_anthropic( + *, + model: str, + system_prompt: str, + user_text: str, + image_bytes: Optional[bytes] = None, + media_type: str = "image/png", +) -> str: + client = _get_anthropic_client() + content = [{"type": "text", "text": user_text}] + if image_bytes: + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": 
media_type, + "data": base64.b64encode(image_bytes).decode("utf-8"), + }, + } + ) + + response = await client.messages.create( + model=model, + max_tokens=8192, + system=system_prompt, + messages=[{"role": "user", "content": content}], + ) + output_text = "".join( + block.text for block in response.content if getattr(block, "type", None) == "text" + ) + if not output_text: + raise HTTPException(status_code=500, detail="No output from template provider") + return output_text + + +def _build_provider_call( + *, + system_prompt: str, + user_text: str, + image_bytes: Optional[bytes] = None, + media_type: str = "image/png", + reasoning_effort: str = "medium", +) -> PlainLLMProvider: + spec = get_template_provider_spec() + + if spec.provider == LLMProvider.OPENAI: + return PlainLLMProvider( + name="OpenAI", + call=lambda: _call_openai_like( + client=_get_openai_client(), + model=spec.model, + system_prompt=system_prompt, + user_text=user_text, + image_bytes=image_bytes, + media_type=media_type, + reasoning_effort=reasoning_effort, + ), + ) + if spec.provider == LLMProvider.CODEX: + return PlainLLMProvider( + name="Codex", + call=lambda: _call_openai_like( + client=_get_codex_client(), + model=spec.model, + system_prompt=system_prompt, + user_text=user_text, + image_bytes=image_bytes, + media_type=media_type, + reasoning_effort=reasoning_effort, + ), + ) + if spec.provider == LLMProvider.GOOGLE: + return PlainLLMProvider( + name="Google", + call=lambda: _call_google( + model=spec.model, + system_prompt=system_prompt, + user_text=user_text, + image_bytes=image_bytes, + media_type=media_type, + ), + ) + if spec.provider == LLMProvider.ANTHROPIC: + return PlainLLMProvider( + name="Anthropic", + call=lambda: _call_anthropic( + model=spec.model, + system_prompt=system_prompt, + user_text=user_text, + image_bytes=image_bytes, + media_type=media_type, + ), + ) + + raise HTTPException( + status_code=400, + detail="Template generation only supports OpenAI, Codex, Google, or 
Anthropic.", + ) + + +async def generate_slide_layout_code( + *, + system_prompt: str, + user_text: str, + image_bytes: bytes, + media_type: str = "image/png", +) -> str: + provider = _build_provider_call( + system_prompt=system_prompt, + user_text=user_text, + image_bytes=image_bytes, + media_type=media_type, + reasoning_effort="high", + ) + return await run_plain_provider_buckets(providers=[provider]) + + +async def edit_slide_layout_code( + *, + system_prompt: str, + user_text: str, +) -> str: + provider = _build_provider_call( + system_prompt=system_prompt, + user_text=user_text, + reasoning_effort="medium", + ) + return await run_plain_provider_buckets(providers=[provider]) diff --git a/electron/servers/fastapi/templates/router.py b/electron/servers/fastapi/templates/router.py new file mode 100644 index 00000000..3d303e4f --- /dev/null +++ b/electron/servers/fastapi/templates/router.py @@ -0,0 +1,65 @@ +import uuid + +from fastapi import APIRouter + +from templates.handler import ( + CreateSlideLayoutResponse, + EditSlideLayoutResponse, + EditSlideLayoutSectionResponse, + FontsUploadAndSlidesPreviewResponse, + GetTemplateLayoutsResponse, + PresentationLayoutModel, + SaveTemplateLayoutData, + SaveTemplateResponse, + TemplateDetail, + TemplateExample, + clone_slide_layout, + clone_template, + create_slide_layout, + edit_slide_layout, + edit_slide_layout_section, + get_all_templates, + get_layouts, + get_template_by_id, + get_template_example, + init_create_template, + save_slide_layout, + save_template, + update_template, + upload_fonts_and_slides_preview, +) + +TEMPLATE_ROUTER = APIRouter(prefix="/template", tags=["Template"]) + +TEMPLATE_ROUTER.get("/all", response_model=list[TemplateDetail])(get_all_templates) +TEMPLATE_ROUTER.get( + "/{template_id}/layouts", response_model=GetTemplateLayoutsResponse +)(get_layouts) +TEMPLATE_ROUTER.get("/{id}", response_model=PresentationLayoutModel)(get_template_by_id) +TEMPLATE_ROUTER.get("/{id}/example", 
response_model=TemplateExample)( + get_template_example +) +TEMPLATE_ROUTER.post( + "/fonts-upload-and-slides-preview", + response_model=FontsUploadAndSlidesPreviewResponse, +)(upload_fonts_and_slides_preview) +TEMPLATE_ROUTER.post("/create/init", response_model=uuid.UUID)(init_create_template) +TEMPLATE_ROUTER.post("/slide-layout/create", response_model=CreateSlideLayoutResponse)( + create_slide_layout +) +TEMPLATE_ROUTER.post("/create/slide-layout", response_model=CreateSlideLayoutResponse)( + create_slide_layout +) +TEMPLATE_ROUTER.post("/slide-layout/edit", response_model=EditSlideLayoutResponse)( + edit_slide_layout +) +TEMPLATE_ROUTER.post( + "/slide-layout/edit-section", response_model=EditSlideLayoutSectionResponse +)(edit_slide_layout_section) +TEMPLATE_ROUTER.post("/save", response_model=SaveTemplateResponse)(save_template) +TEMPLATE_ROUTER.post("/clone", response_model=SaveTemplateResponse)(clone_template) +TEMPLATE_ROUTER.put("/update", response_model=SaveTemplateResponse)(update_template) +TEMPLATE_ROUTER.post("/slide-layout/save", status_code=200)(save_slide_layout) +TEMPLATE_ROUTER.post("/slide-layout/clone", response_model=SaveTemplateLayoutData)( + clone_slide_layout +) diff --git a/electron/servers/fastapi/tests/test_pptx_slides_processing.py b/electron/servers/fastapi/tests/test_pptx_slides_processing.py deleted file mode 100644 index 13bff010..00000000 --- a/electron/servers/fastapi/tests/test_pptx_slides_processing.py +++ /dev/null @@ -1,140 +0,0 @@ -import os -import tempfile -import zipfile -from fastapi.testclient import TestClient -from fastapi import UploadFile -import pytest - -from api.main import app - - -client = TestClient(app) - - -def create_sample_pptx(): - """Create a minimal PPTX file for testing.""" - # This creates a very basic PPTX structure for testing - pptx_content = { - '[Content_Types].xml': ''' - - - - - -''', - '_rels/.rels': ''' - - -''', - 'ppt/presentation.xml': ''' - - - - - - -''', - 
'ppt/_rels/presentation.xml.rels': ''' - - -''', - 'ppt/slides/slide1.xml': ''' - - - - - - - - - - - - - - - - - - - - - -''' - } - - with tempfile.NamedTemporaryFile(suffix='.pptx', delete=False) as temp_file: - with zipfile.ZipFile(temp_file.name, 'w') as zip_file: - for path, content in pptx_content.items(): - zip_file.writestr(path, content) - return temp_file.name - - -def test_pptx_slides_processing(): - """Test the PPTX slides processing endpoint.""" - - # Create a sample PPTX file - pptx_path = create_sample_pptx() - - try: - with open(pptx_path, 'rb') as pptx_file: - files = {'pptx_file': ('test.pptx', pptx_file, 'application/vnd.openxmlformats-officedocument.presentationml.presentation')} - - response = client.post("/api/v1/ppt/pptx-slides/process", files=files) - - # Check response - assert response.status_code == 200 - data = response.json() - - assert data['success'] == True - assert 'slides' in data - assert 'total_slides' in data - assert data['total_slides'] > 0 - - # Check slide data structure - if data['slides']: - slide = data['slides'][0] - assert 'slide_number' in slide - assert 'screenshot_url' in slide - assert 'xml_content' in slide - assert slide['slide_number'] == 1 - assert slide['xml_content'] != '' - - print(f"โœ… Test passed! 
Processed {data['total_slides']} slides successfully") - - finally: - # Clean up - if os.path.exists(pptx_path): - os.unlink(pptx_path) - - -def test_invalid_file_type(): - """Test that non-PPTX files are rejected.""" - - # Create a text file and try to upload it - with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as temp_file: - temp_file.write(b"This is not a PPTX file") - temp_file.flush() - - try: - with open(temp_file.name, 'rb') as txt_file: - files = {'pptx_file': ('test.txt', txt_file, 'text/plain')} - - response = client.post("/api/v1/ppt/pptx-slides/process", files=files) - - # Should return 400 for invalid file type - assert response.status_code == 400 - data = response.json() - assert 'Invalid file type' in data['detail'] - - print("โœ… Invalid file type test passed!") - - finally: - os.unlink(temp_file.name) - - -if __name__ == "__main__": - print("Running PPTX slides processing tests...") - test_pptx_slides_processing() - test_invalid_file_type() - print("๐ŸŽ‰ All tests completed!") \ No newline at end of file diff --git a/electron/servers/fastapi/tests/test_presentation_generation_api.py b/electron/servers/fastapi/tests/test_presentation_generation_api.py index 3806dc61..b2dc536e 100644 --- a/electron/servers/fastapi/tests/test_presentation_generation_api.py +++ b/electron/servers/fastapi/tests/test_presentation_generation_api.py @@ -1,189 +1,117 @@ -from unittest.mock import patch, AsyncMock, MagicMock +import asyncio +import uuid +from unittest.mock import AsyncMock, patch + import pytest -from fastapi.testclient import TestClient -from fastapi import FastAPI -from models.presentation_layout import PresentationLayoutModel -from models.presentation_structure_model import PresentationStructureModel -from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER +from fastapi import HTTPException +from pydantic import ValidationError -class MockAiohttpResponse: - def __init__(self, status=200, json_data=None): - self.status = status - 
self._json_data = json_data or {"path": "/tmp/exports/test.pdf"} +from api.v1.ppt.endpoints.presentation import generate_presentation_sync +from models.generate_presentation_request import GeneratePresentationRequest +from models.presentation_and_path import PresentationPathAndEditPath - async def __aenter__(self): - return self - async def __aexit__(self, exc_type, exc, tb): - pass +class FakeAsyncSession: + async def get(self, *_args, **_kwargs): + return None - async def json(self): - return self._json_data + def add(self, *_args, **_kwargs): + return None - async def text(self): - return str(self._json_data) + def add_all(self, *_args, **_kwargs): + return None -class MockAiohttpSession: - def __init__(self, *args, **kwargs): - pass + async def commit(self): + return None - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - pass - - def post(self, *args, **kwargs): - return MockAiohttpResponse() - - def get(self, *args, **kwargs): - pptx_model_data = { - "slides": [], - "title": "Test", - "notes": [], - "layout": {}, - "structure": {}, - } - return MockAiohttpResponse(json_data=pptx_model_data) - -@pytest.fixture -def app(): - app = FastAPI() - app.include_router(PRESENTATION_ROUTER, prefix="/api/v1/ppt") - return app - -@pytest.fixture -def client(app): - return TestClient(app) - -@pytest.fixture -def mock_get_layout(): - async def _mock_get_layout_by_name(layout_name: str): - mock_slide = MagicMock() - mock_slide.name = "Mock Slide" - mock_slide.json_schema = {"title": "Mock Slide Title"} - mock_slide.description = "Mock slide description" - mock_layout = MagicMock(spec=PresentationLayoutModel) - mock_layout.name = layout_name - mock_layout.ordered = True - mock_layout.slides = [mock_slide] - mock_layout.model_dump = lambda: {} - mock_layout.to_presentation_structure = lambda: PresentationStructureModel( - slides=[index for index in range(len(mock_layout.slides))] - ) - def to_string(): - message = f"## Presentation 
Layout\n\n" - for index, slide in enumerate(mock_layout.slides): - message += f"### Slide Layout: {index}: \n" - message += f"- Name: {slide.name or slide.json_schema.get('title')} \n" - message += f"- Description: {slide.description} \n\n" - return message - mock_layout.to_string = to_string - return mock_layout - return _mock_get_layout_by_name - -async def mock_generate_ppt_outline(*args, **kwargs): - yield '{"title": "Test", "slides": [{"title": "Slide 1", "body": "Body 1"}], "notes": []}' - -@pytest.fixture(autouse=True) -def patch_presentation_api(monkeypatch, mock_get_layout): - # Patch all dependencies used in the API - patches = [ - patch('api.v1.ppt.endpoints.presentation.get_layout_by_name', new=AsyncMock(side_effect=mock_get_layout)), - patch('api.v1.ppt.endpoints.presentation.TEMP_FILE_SERVICE.create_temp_dir', return_value='/tmp/mockdir'), - patch('api.v1.ppt.endpoints.presentation.DocumentsLoader'), - patch('api.v1.ppt.endpoints.presentation.generate_document_summary', new_callable=AsyncMock, return_value="mock_summary"), - patch('api.v1.ppt.endpoints.presentation.generate_ppt_outline', side_effect=mock_generate_ppt_outline), - patch('api.v1.ppt.endpoints.presentation.get_sql_session'), - patch('api.v1.ppt.endpoints.presentation.get_slide_content_from_type_and_outline', new_callable=AsyncMock, return_value={"mock": "slide_content"}), - patch('api.v1.ppt.endpoints.presentation.process_slide_and_fetch_assets', new_callable=AsyncMock), - patch('api.v1.ppt.endpoints.presentation.get_exports_directory', return_value='/tmp/exports'), - patch('api.v1.ppt.endpoints.presentation.PptxPresentationCreator'), - patch('api.v1.ppt.endpoints.presentation.aiohttp.ClientSession', return_value=MockAiohttpSession()), - ] - mocks = [p.start() for p in patches] - - # Setup DocumentsLoader mock - docs_loader = mocks[2] - docs_loader.return_value.load_documents = AsyncMock() - docs_loader.return_value.documents = [] - - # Setup PptxPresentationCreator mock for pptx test - 
pptx_creator = mocks[9] - pptx_creator.return_value.create_ppt = AsyncMock() - pptx_creator.return_value.save = MagicMock() - - yield - - for p in patches: - p.stop() class TestPresentationGenerationAPI: - def test_generate_presentation_export_as_pdf(self, client): - response = client.post( - "/api/v1/ppt/presentation/generate", - json={ - "content": "Create a presentation about artificial intelligence and machine learning", - "n_slides": 5, - "language": "English", - "export_as": "pdf", - "layout": "general" - } + def test_generate_presentation_export_as_pdf(self): + request = GeneratePresentationRequest( + content="Create a presentation about artificial intelligence and machine learning", + n_slides=5, + language="English", + export_as="pdf", + template="general", ) - assert response.status_code == 200 - assert "presentation_id" in response.json() - assert "pdf" in response.json()["path"] - - def test_generate_presentation_export_as_pptx(self, client): - response = client.post( - "/api/v1/ppt/presentation/generate", - json={ - "content": "Create a presentation about artificial intelligence and machine learning", - "n_slides": 5, - "language": "English", - "export_as": "pptx", - "layout": "general" - } + response_payload = PresentationPathAndEditPath( + presentation_id=uuid.uuid4(), + path="/tmp/exports/test.pdf", + edit_path="/presentation?id=test", ) - assert response.status_code == 200 - assert "presentation_id" in response.json() - assert "pptx" in response.json()["path"] - def test_generate_presentation_with_no_content(self, client): - response = client.post( - "/api/v1/ppt/presentation/generate", - json={ - "n_slides": 5, - "language": "English", - "export_as": "pdf", - "layout": "general" - } + with patch( + "api.v1.ppt.endpoints.presentation.generate_presentation_handler", + new=AsyncMock(return_value=response_payload), + ) as mock_handler: + response = asyncio.run( + generate_presentation_sync(request, sql_session=FakeAsyncSession()) + ) + + assert 
response == response_payload + mock_handler.assert_awaited_once() + + def test_generate_presentation_export_as_pptx(self): + request = GeneratePresentationRequest( + content="Create a presentation about artificial intelligence and machine learning", + n_slides=5, + language="English", + export_as="pptx", + template="general", ) - assert response.status_code == 422 - - - def test_generate_presentation_with_n_slides_less_than_one(self, client): - response = client.post( - "/api/v1/ppt/presentation/generate", - json={ - "content": "Create a presentation about artificial intelligence and machine learning", - "n_slides": 0, - "language": "English", - "export_as": "pdf", - "layout": "general" - } + response_payload = PresentationPathAndEditPath( + presentation_id=uuid.uuid4(), + path="/tmp/exports/test.pptx", + edit_path="/presentation?id=test", ) - assert response.status_code == 422 - def test_generate_presentation_with_invalid_export_type(self, client): - response = client.post( - "/api/v1/ppt/presentation/generate", - json={ - "content": "Create a presentation about artificial intelligence and machine learning", - "n_slides": 5, - "language": "English", - "export_as": "invalid_type", - "layout": "general" - } + with patch( + "api.v1.ppt.endpoints.presentation.generate_presentation_handler", + new=AsyncMock(return_value=response_payload), + ) as mock_handler: + response = asyncio.run( + generate_presentation_sync(request, sql_session=FakeAsyncSession()) + ) + + assert response == response_payload + mock_handler.assert_awaited_once() + + def test_generate_presentation_with_no_content(self): + with pytest.raises(ValidationError): + GeneratePresentationRequest.model_validate( + { + "n_slides": 5, + "language": "English", + "export_as": "pdf", + "template": "general", + } + ) + + def test_generate_presentation_with_n_slides_less_than_one(self): + request = GeneratePresentationRequest( + content="Create a presentation about artificial intelligence and machine learning", + 
n_slides=0, + language="English", + export_as="pdf", + template="general", ) - assert response.status_code == 422 + + with pytest.raises(HTTPException) as exc: + asyncio.run( + generate_presentation_sync(request, sql_session=FakeAsyncSession()) + ) + + assert exc.value.status_code == 400 + assert exc.value.detail == "Number of slides must be greater than 0" + + def test_generate_presentation_with_invalid_export_type(self): + with pytest.raises(ValidationError): + GeneratePresentationRequest.model_validate( + { + "content": "Create a presentation about artificial intelligence and machine learning", + "n_slides": 5, + "language": "English", + "export_as": "invalid_type", + "template": "general", + } + ) diff --git a/electron/servers/fastapi/tests/test_slide_to_html.py b/electron/servers/fastapi/tests/test_slide_to_html.py deleted file mode 100644 index 63f9fe92..00000000 --- a/electron/servers/fastapi/tests/test_slide_to_html.py +++ /dev/null @@ -1,115 +0,0 @@ -import pytest -import os -from fastapi.testclient import TestClient - -# Import the main app -from server import app - -client = TestClient(app) - - -def test_slide_to_html_endpoint(): - """Test the slide-to-html endpoint with streaming API support.""" - - # Sample XML data (simplified version of OXML) - test_xml = ''' - - - - - - - - - - - - - - - Test Slide - - - - - - -''' - - # Skip this test if ANTHROPIC_API_KEY is not set - if not os.getenv('ANTHROPIC_API_KEY'): - pytest.skip("ANTHROPIC_API_KEY not set - skipping API test") - - # Use a placeholder image path (since we can't easily test with real files) - test_data = { - "image": "/static/images/placeholder.jpg", - "xml": test_xml - } - - # Make the request with JSON - response = client.post( - "/api/v1/ppt/slide-to-html/", - json=test_data - ) - - # Check response (may take several minutes due to streaming) - print("Note: This test may take several minutes due to Claude's streaming processing...") - - if response.status_code == 200: - data = 
response.json() - assert data["success"] is True - assert "html" in data - assert len(data["html"]) > 0 - print(f"Generated HTML preview: {data['html'][:200]}...") - print("โœ… Streaming API test completed successfully") - else: - print(f"Request failed with status {response.status_code}: {response.text}") - # Don't fail the test if API key is missing or invalid - if "ANTHROPIC_API_KEY" in response.text: - pytest.skip("Invalid API key - skipping test") - elif "Streaming is required" in response.text: - print("โœ… Streaming error handled correctly by endpoint") - - -def test_slide_to_html_invalid_path(): - """Test the endpoint with an invalid image path.""" - - test_data = { - "image": "/app_data/images/nonexistent/image.png", - "xml": "xml" - } - - response = client.post( - "/api/v1/ppt/slide-to-html/", - json=test_data - ) - - assert response.status_code == 404 - assert "Image file not found" in response.json()["detail"] - - -def test_slide_to_html_missing_xml(): - """Test the endpoint with missing XML data.""" - - test_data = { - "image": "/static/images/placeholder.jpg" - # No XML data provided - } - - response = client.post( - "/api/v1/ppt/slide-to-html/", - json=test_data - ) - - assert response.status_code == 422 # Validation error - - -if __name__ == "__main__": - # Run a simple test - test_slide_to_html_invalid_path() - print("โœ… Invalid path test passed") - - test_slide_to_html_missing_xml() - print("โœ… Missing XML test passed") - - print("๐Ÿงช Run full tests with: pytest test_slide_to_html.py") \ No newline at end of file diff --git a/electron/servers/fastapi/tests/test_template_api.py b/electron/servers/fastapi/tests/test_template_api.py new file mode 100644 index 00000000..6a702c7b --- /dev/null +++ b/electron/servers/fastapi/tests/test_template_api.py @@ -0,0 +1,479 @@ +import asyncio +import base64 +import re +from datetime import datetime, timezone + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from 
sqlalchemy.sql import Delete, Select + +from api.v1.ppt.router import API_V1_PPT_ROUTER +from enums.llm_provider import LLMProvider +from models.sql.presentation_layout_code import PresentationLayoutCodeModel +from models.sql.template import TemplateModel +from models.sql.template_create_info import TemplateCreateInfoModel +from templates.handler import ( + CloneSlideLayoutRequest, + CreateSlideLayoutRequest, + EditSlideLayoutRequest, + EditSlideLayoutSectionRequest, + SaveSlideLayoutRequest, + SaveTemplateLayoutData, + SaveTemplateRequest, + UpdateTemplateRequest, + clone_slide_layout, + create_slide_layout, + edit_slide_layout, + edit_slide_layout_section, + init_create_template, + save_slide_layout, + save_template, + update_template, + upload_fonts_and_slides_preview, +) +from templates.pptx_html_stub import BASIC_TEMPLATE_HTML +from templates.preview import ( + FontCheckResponse, + FontsUploadAndSlidesPreviewResponse, + check_fonts_in_pptx_handler, +) +from templates.providers import ( + ANTHROPIC_TEMPLATE_MODEL, + CODEX_TEMPLATE_MODEL, + GOOGLE_TEMPLATE_MODEL, + OPENAI_TEMPLATE_MODEL, + get_template_provider_spec, +) + + +PNG_BYTES = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAusB9s5WzxQAAAAASUVORK5CYII=" +) + + +@pytest.fixture +def api_client(tmp_path, monkeypatch): + monkeypatch.setenv("APP_DATA_DIRECTORY", str(tmp_path / "app_data")) + + app = FastAPI() + app.include_router(API_V1_PPT_ROUTER) + client = TestClient(app) + yield client, tmp_path + client.close() + + +class FakeScalarResult: + def __init__(self, items): + self._items = list(items) + + def all(self): + return list(self._items) + + +class FakeExecuteResult: + def __init__(self, items): + self._items = list(items) + + def scalars(self): + return FakeScalarResult(self._items) + + def all(self): + return list(self._items) + + +class FakeAsyncSession: + def __init__(self): + self.template_infos: dict = {} + self.templates: dict = {} + self.layouts: 
list[PresentationLayoutCodeModel] = [] + self._next_layout_row_id = 1 + + async def get(self, model, key): + if model is TemplateCreateInfoModel: + return self.template_infos.get(key) + if model is TemplateModel: + return self.templates.get(key) + return None + + def add(self, obj): + if isinstance(obj, TemplateCreateInfoModel): + if getattr(obj, "created_at", None) is None: + obj.created_at = datetime.now(timezone.utc) + self.template_infos[obj.id] = obj + elif isinstance(obj, TemplateModel): + if getattr(obj, "created_at", None) is None: + obj.created_at = datetime.now(timezone.utc) + self.templates[obj.id] = obj + elif isinstance(obj, PresentationLayoutCodeModel): + if obj.id is None: + obj.id = self._next_layout_row_id + self._next_layout_row_id += 1 + if getattr(obj, "created_at", None) is None: + obj.created_at = datetime.now(timezone.utc) + if getattr(obj, "updated_at", None) is None: + obj.updated_at = datetime.now(timezone.utc) + self.layouts.append(obj) + + def add_all(self, objects): + for obj in objects: + self.add(obj) + + async def commit(self): + return None + + async def refresh(self, _obj): + return None + + async def scalar(self, statement): + items = self._execute_select(statement) + return items[0] if items else None + + async def execute(self, statement): + if isinstance(statement, Delete): + self._execute_delete(statement) + return FakeExecuteResult([]) + return FakeExecuteResult(self._execute_select(statement)) + + def _execute_select(self, statement): + if not isinstance(statement, Select): + return [] + entity = statement.column_descriptions[0].get("entity") + if entity is not PresentationLayoutCodeModel: + return [] + return [ + layout + for layout in self.layouts + if all(self._matches_clause(layout, clause) for clause in statement._where_criteria) + ] + + def _execute_delete(self, statement): + if statement.table.name != PresentationLayoutCodeModel.__tablename__: + return + self.layouts = [ + layout + for layout in self.layouts + if not 
all(self._matches_clause(layout, clause) for clause in statement._where_criteria) + ] + + def _matches_clause(self, obj, clause): + if hasattr(clause, "clauses"): + return all(self._matches_clause(obj, child) for child in clause.clauses) + left = getattr(clause.left, "key", None) or getattr(clause.left, "name", None) + right = getattr(clause.right, "value", None) + return getattr(obj, left) == right + + +class SimpleUploadFile: + def __init__(self, filename: str, content_type: str, data: bytes): + self.filename = filename + self.content_type = content_type + self._data = data + + async def read(self): + return self._data + + +def test_router_registration_replaces_old_routes(api_client): + client, _ = api_client + paths = { + route.path + for route in client.app.routes + if hasattr(route, "path") and route.path.startswith("/api/v1/ppt") + } + + assert "/api/v1/ppt/template/all" in paths + assert "/api/v1/ppt/template/create/init" in paths + assert "/api/v1/ppt/template/slide-layout/create" in paths + assert "/api/v1/ppt/template/fonts-upload-and-slides-preview" in paths + assert "/api/v1/ppt/fonts/check" in paths + + assert "/api/v1/ppt/fonts/upload" in paths + assert "/api/v1/ppt/fonts/list" in paths + assert "/api/v1/ppt/fonts/uploaded" in paths + + assert "/api/v1/ppt/slide-to-html/" not in paths + assert "/api/v1/ppt/html-to-react/" not in paths + assert "/api/v1/ppt/html-edit/" not in paths + assert "/api/v1/ppt/pptx-slides/process" not in paths + assert "/api/v1/ppt/pdf-slides/process" not in paths + assert "/api/v1/ppt/pptx-fonts/process" not in paths + + +def test_template_create_init_stores_stub_htmls(): + session = FakeAsyncSession() + + template_info_id = asyncio.run( + init_create_template( + request=type( + "Request", + (), + { + "pptx_url": "/app_data/uploads/template-previews/test/presentation.pptx", + "slide_image_urls": [ + "/app_data/images/a/slide_1.png", + "/app_data/images/a/slide_2.png", + ], + "fonts": { + "Inter": 
"https://fonts.googleapis.com/css2?family=Inter&display=swap" + }, + }, + )(), + sql_session=session, + ) + ) + + template_info = session.template_infos[template_info_id] + assert template_info.slide_htmls == [BASIC_TEMPLATE_HTML, BASIC_TEMPLATE_HTML] + assert template_info.slide_image_urls == [ + "/app_data/images/a/slide_1.png", + "/app_data/images/a/slide_2.png", + ] + + +def test_fonts_check_endpoint(monkeypatch): + async def fake_font_check(_pptx_path: str, _temp_dir: str): + return [("Inter", "https://fonts.googleapis.com/css2?family=Inter&display=swap")], [("Custom Font", None)] + + async def fake_to_thread(func, *args, **kwargs): + return func(*args, **kwargs) + + monkeypatch.setattr( + "templates.preview.get_available_and_unavailable_fonts_for_pptx", + fake_font_check, + ) + monkeypatch.setattr("templates.preview.asyncio.to_thread", fake_to_thread) + + upload = SimpleUploadFile( + filename="deck.pptx", + content_type="application/vnd.openxmlformats-officedocument.presentationml.presentation", + data=b"fake-pptx", + ) + response = asyncio.run(check_fonts_in_pptx_handler(pptx_file=upload)) + + assert response == FontCheckResponse( + available_fonts=[ + { + "name": "Inter", + "url": "https://fonts.googleapis.com/css2?family=Inter&display=swap", + } + ], + unavailable_fonts=[{"name": "Custom Font", "url": None}], + ) + + +def test_fonts_upload_and_preview_route_uses_new_handler(monkeypatch): + async def fake_preview_handler(**kwargs): + assert kwargs["pptx_file"].filename == "deck.pptx" + return FontsUploadAndSlidesPreviewResponse( + slide_image_urls=["/app_data/images/1/slide_1.png"], + pptx_url="/app_data/uploads/template-previews/1/presentation.pptx", + modified_pptx_url="/app_data/uploads/template-previews/1/presentation.pptx", + fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"}, + ) + + monkeypatch.setattr( + "templates.handler.upload_fonts_and_slides_preview_handler", + fake_preview_handler, + ) + + upload = SimpleUploadFile( 
+ filename="deck.pptx", + content_type="application/vnd.openxmlformats-officedocument.presentationml.presentation", + data=b"fake-pptx", + ) + response = asyncio.run(upload_fonts_and_slides_preview(pptx_file=upload)) + + assert response == FontsUploadAndSlidesPreviewResponse( + slide_image_urls=["/app_data/images/1/slide_1.png"], + pptx_url="/app_data/uploads/template-previews/1/presentation.pptx", + modified_pptx_url="/app_data/uploads/template-previews/1/presentation.pptx", + fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"}, + ) + + +def test_provider_spec_mapping_and_restrictions(monkeypatch): + monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.OPENAI) + spec = get_template_provider_spec() + assert spec.provider == LLMProvider.OPENAI + assert spec.model == OPENAI_TEMPLATE_MODEL + + monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.CODEX) + spec = get_template_provider_spec() + assert spec.provider == LLMProvider.CODEX + assert spec.model == CODEX_TEMPLATE_MODEL + + monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.GOOGLE) + spec = get_template_provider_spec() + assert spec.provider == LLMProvider.GOOGLE + assert spec.model == GOOGLE_TEMPLATE_MODEL + + monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.ANTHROPIC) + spec = get_template_provider_spec() + assert spec.provider == LLMProvider.ANTHROPIC + assert spec.model == ANTHROPIC_TEMPLATE_MODEL + + monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.OLLAMA) + with pytest.raises(Exception) as exc: + get_template_provider_spec() + assert "Template generation only supports OpenAI, Codex, Google, or Anthropic." 
in str(exc.value) + + +def test_create_and_edit_slide_layout_routes_use_provider_layer(tmp_path, monkeypatch): + session = FakeAsyncSession() + image_path = tmp_path / "slide.png" + image_path.write_bytes(PNG_BYTES) + + template_info = TemplateCreateInfoModel( + fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"}, + pptx_url="/app_data/uploads/template-previews/seed/presentation.pptx", + slide_htmls=["
seed html
"], + slide_image_urls=[str(image_path)], + ) + session.add(template_info) + + create_calls = [] + edit_calls = [] + + async def fake_generate_layout(**kwargs): + create_calls.append(kwargs) + return """```tsx +import { z } from "zod"; +const layoutId = "title-image"; +const layoutName = "Title Image"; +const layoutDescription = "desc"; +function dynamicSlideLayout() { return
{image_url}{icon_url}{image_prompt}{icon_query}
; } +export { layoutId }; +```""" + + async def fake_edit_layout(**kwargs): + edit_calls.append(kwargs) + return "```tsx\nconst updatedLayout = true;\n```" + + monkeypatch.setattr("templates.handler.generate_slide_layout_code", fake_generate_layout) + monkeypatch.setattr("templates.handler.edit_slide_layout_code", fake_edit_layout) + + create_response = asyncio.run( + create_slide_layout( + request=CreateSlideLayoutRequest(id=template_info.id, index=0), + sql_session=session, + ) + ) + assert "__image_url__" in create_response.react_component + assert "__icon_url__" in create_response.react_component + assert "__image_prompt__" in create_response.react_component + assert "__icon_query__" in create_response.react_component + assert "import " not in create_response.react_component + assert "export " not in create_response.react_component + assert re.search(r'layoutId\s*=\s*"title-image-\d{4}"', create_response.react_component) + + assert create_calls + assert create_calls[0]["system_prompt"] + assert "#SLIDE HTML REFERENCE" in create_calls[0]["user_text"] + assert create_calls[0]["media_type"] == "image/png" + assert create_calls[0]["image_bytes"] == PNG_BYTES + + edit_response = asyncio.run( + edit_slide_layout( + request=EditSlideLayoutRequest( + react_component="const x = 1;", + prompt="Move title up", + ) + ) + ) + assert edit_response.react_component == "const updatedLayout = true;" + + section_response = asyncio.run( + edit_slide_layout_section( + request=EditSlideLayoutSectionRequest( + react_component="const x = 1;", + section="header", + prompt="Change spacing", + ) + ) + ) + assert section_response.react_component == "const updatedLayout = true;" + + assert len(edit_calls) == 2 + assert "#Prompt\nMove title up" in edit_calls[0]["user_text"] + assert "#Section to make changes around\nheader" in edit_calls[1]["user_text"] + + +def test_save_update_and_clone_template_flow(): + session = FakeAsyncSession() + template_info = TemplateCreateInfoModel( + 
fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"}, + pptx_url="/app_data/uploads/template-previews/seed/presentation.pptx", + slide_htmls=["
seed html
"], + slide_image_urls=["/app_data/images/seed/slide_1.png"], + ) + session.add(template_info) + + save_response = asyncio.run( + save_template( + request=SaveTemplateRequest( + template_info_id=template_info.id, + name="My Template", + description="Saved from test", + layouts=[ + SaveTemplateLayoutData( + layout_id="title-image-1000", + layout_name="Title Image", + layout_code='const layoutId = "title-image-1000";', + ) + ], + ), + sql_session=session, + ) + ) + template_id = save_response.id + + clone_layout_response = asyncio.run( + clone_slide_layout( + request=CloneSlideLayoutRequest( + template_id=f"custom-{template_id}", + layout_id="title-image-1000", + layout_name="Cloned Layout", + ), + sql_session=session, + ) + ) + assert clone_layout_response.layout_name == "Cloned Layout" + assert clone_layout_response.layout_id != "title-image-1000" + + asyncio.run( + save_slide_layout( + request=SaveSlideLayoutRequest( + template_id=template_id, + layout_id="title-image-1000", + layout_code='const layoutId = "title-image-1000"; const edited = true;', + ), + sql_session=session, + ) + ) + assert any("edited = true" in layout.layout_code for layout in session.layouts) + + update_response = asyncio.run( + update_template( + request=UpdateTemplateRequest( + id=template_id, + layouts=[ + SaveTemplateLayoutData( + layout_id="updated-layout-2000", + layout_name="Updated Layout", + layout_code='const layoutId = "updated-layout-2000";', + ) + ], + ), + sql_session=session, + ) + ) + assert update_response.id == template_id + + layouts = [layout for layout in session.layouts if layout.presentation == template_id] + assert session.templates[template_id] is not None + assert len(layouts) == 1 + assert layouts[0].layout_id == "updated-layout-2000" + assert layouts[0].fonts == { + "Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap" + } diff --git a/electron/servers/fastapi/utils/llm_calls/edit_slide.py b/electron/servers/fastapi/utils/llm_calls/edit_slide.py 
index 6f929aa4..00d2f9b5 100644 --- a/electron/servers/fastapi/utils/llm_calls/edit_slide.py +++ b/electron/servers/fastapi/utils/llm_calls/edit_slide.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional from models.llm_message import LLMSystemMessage, LLMUserMessage -from models.presentation_layout import SlideLayoutModel +from templates.presentation_layout import SlideLayoutModel from models.sql.slide import SlideModel from services.llm_client import LLMClient from utils.llm_client_error_handler import handle_llm_client_exceptions diff --git a/electron/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/electron/servers/fastapi/utils/llm_calls/generate_presentation_structure.py index bbe26172..65c623e2 100644 --- a/electron/servers/fastapi/utils/llm_calls/generate_presentation_structure.py +++ b/electron/servers/fastapi/utils/llm_calls/generate_presentation_structure.py @@ -1,7 +1,7 @@ from typing import Optional, Dict from models.llm_message import LLMSystemMessage, LLMUserMessage -from models.presentation_layout import PresentationLayoutModel +from templates.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from services.llm_client import LLMClient from utils.llm_client_error_handler import handle_llm_client_exceptions diff --git a/electron/servers/fastapi/utils/llm_calls/generate_slide_content.py b/electron/servers/fastapi/utils/llm_calls/generate_slide_content.py index a5010cf2..773c54dc 100644 --- a/electron/servers/fastapi/utils/llm_calls/generate_slide_content.py +++ b/electron/servers/fastapi/utils/llm_calls/generate_slide_content.py @@ -2,7 +2,7 @@ from datetime import datetime import json from typing import Optional from models.llm_message import LLMSystemMessage, LLMUserMessage -from models.presentation_layout import SlideLayoutModel +from templates.presentation_layout import SlideLayoutModel from models.presentation_outline_model import 
SlideOutlineModel from services.llm_client import LLMClient from utils.llm_client_error_handler import handle_llm_client_exceptions diff --git a/electron/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/electron/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py index d0e52379..23bbc2f9 100644 --- a/electron/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py +++ b/electron/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py @@ -1,5 +1,5 @@ from models.llm_message import LLMSystemMessage, LLMUserMessage -from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel +from templates.presentation_layout import PresentationLayoutModel, SlideLayoutModel from models.slide_layout_index import SlideLayoutIndex from models.sql.slide import SlideModel from services.llm_client import LLMClient diff --git a/electron/servers/fastapi/utils/ppt_utils.py b/electron/servers/fastapi/utils/ppt_utils.py index 9cd3b2e0..846e6068 100644 --- a/electron/servers/fastapi/utils/ppt_utils.py +++ b/electron/servers/fastapi/utils/ppt_utils.py @@ -1,4 +1,4 @@ -from models.presentation_layout import PresentationLayoutModel +from templates.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel import re from typing import List