feat: implements new template generation flow; refactor: removes old template generation flow and adds new endpoints
This commit is contained in:
parent
ca186c6d20
commit
087d64ed62
35 changed files with 2756 additions and 2244 deletions
1
electron/.gitignore
vendored
1
electron/.gitignore
vendored
|
|
@ -10,6 +10,7 @@ app_data
|
|||
tmp
|
||||
debug
|
||||
.fastembed_cache
|
||||
.codex
|
||||
|
||||
generated_models
|
||||
nltk
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from models.sql.presentation_layout_code import ( # noqa: F401, E402
|
|||
)
|
||||
from models.sql.slide import SlideModel # noqa: F401, E402
|
||||
from models.sql.template import TemplateModel # noqa: F401, E402
|
||||
from models.sql.template_create_info import TemplateCreateInfoModel # noqa: F401, E402
|
||||
from models.sql.webhook_subscription import WebhookSubscription # noqa: F401, E402
|
||||
|
||||
alembic_config = context.config
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from templates.preview import FontCheckResponse, check_fonts_in_pptx_handler
|
||||
from utils.asset_directory_utils import get_app_data_directory_env
|
||||
import uuid
|
||||
|
||||
try:
|
||||
from fontTools.ttLib import TTFont
|
||||
|
|
@ -287,6 +286,9 @@ async def get_uploaded_fonts():
|
|||
)
|
||||
|
||||
|
||||
FONTS_ROUTER.post("/check", response_model=FontCheckResponse)(check_fonts_in_pptx_handler)
|
||||
|
||||
|
||||
@FONTS_ROUTER.delete("/delete/{filename}")
|
||||
async def delete_font(filename: str):
|
||||
"""
|
||||
|
|
@ -330,4 +332,4 @@ async def delete_font(filename: str):
|
|||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error deleting font: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from fastapi import APIRouter, HTTPException
|
||||
import aiohttp
|
||||
from typing import List, Any
|
||||
from utils.get_layout_by_name import get_layout_by_name
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from templates.get_layout_by_name import get_layout_by_name
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
|
||||
LAYOUTS_ROUTER = APIRouter(prefix="/layouts", tags=["Layouts"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,116 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
import uuid
|
||||
from constants.documents import PDF_MIME_TYPES
|
||||
|
||||
|
||||
PDF_SLIDES_ROUTER = APIRouter(prefix="/pdf-slides", tags=["PDF Slides"])
|
||||
|
||||
|
||||
class PdfSlideData(BaseModel):
|
||||
slide_number: int
|
||||
screenshot_url: str
|
||||
|
||||
|
||||
class PdfSlidesResponse(BaseModel):
|
||||
success: bool
|
||||
slides: List[PdfSlideData]
|
||||
total_slides: int
|
||||
|
||||
|
||||
@PDF_SLIDES_ROUTER.post("/process", response_model=PdfSlidesResponse)
|
||||
async def process_pdf_slides(
|
||||
pdf_file: UploadFile = File(..., description="PDF file to process")
|
||||
):
|
||||
"""
|
||||
Process a PDF file to extract slide screenshots.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the uploaded PDF file
|
||||
2. Uses ImageMagick to convert PDF pages to PNG images
|
||||
3. Returns screenshot URLs for each slide/page
|
||||
|
||||
Note: Font installation is not needed since PDFs already have fonts embedded.
|
||||
"""
|
||||
|
||||
# Validate PDF file
|
||||
if pdf_file.content_type not in PDF_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Expected PDF file, got {pdf_file.content_type}",
|
||||
)
|
||||
# Enforce 100MB size limit
|
||||
if (
|
||||
hasattr(pdf_file, "size")
|
||||
and pdf_file.size
|
||||
and pdf_file.size > (100 * 1024 * 1024)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="PDF file exceeded max upload size of 100 MB",
|
||||
)
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
# Save uploaded PDF file
|
||||
pdf_path = os.path.join(temp_dir, "presentation.pdf")
|
||||
with open(pdf_path, "wb") as f:
|
||||
pdf_content = await pdf_file.read()
|
||||
f.write(pdf_content)
|
||||
|
||||
# Generate screenshots from PDF using ImageMagick
|
||||
screenshot_paths = await DocumentsLoader.get_page_images_from_pdf_async(
|
||||
pdf_path, temp_dir
|
||||
)
|
||||
print(f"Generated {len(screenshot_paths)} PDF screenshots")
|
||||
|
||||
# Move screenshots to images directory and generate URLs
|
||||
images_dir = get_images_directory()
|
||||
presentation_id = uuid.uuid4()
|
||||
presentation_images_dir = os.path.join(images_dir, str(presentation_id))
|
||||
os.makedirs(presentation_images_dir, exist_ok=True)
|
||||
|
||||
slides_data = []
|
||||
|
||||
for i, screenshot_path in enumerate(screenshot_paths, 1):
|
||||
# Move screenshot to permanent location
|
||||
screenshot_filename = f"slide_{i}.png"
|
||||
permanent_screenshot_path = os.path.join(
|
||||
presentation_images_dir, screenshot_filename
|
||||
)
|
||||
|
||||
if (
|
||||
os.path.exists(screenshot_path)
|
||||
and os.path.getsize(screenshot_path) > 0
|
||||
):
|
||||
# Use shutil.copy2 instead of os.rename to handle cross-device moves
|
||||
shutil.copy2(screenshot_path, permanent_screenshot_path)
|
||||
screenshot_url = (
|
||||
f"/app_data/images/{presentation_id}/{screenshot_filename}"
|
||||
)
|
||||
else:
|
||||
# Fallback if screenshot generation failed or file is empty placeholder
|
||||
screenshot_url = "/static/images/placeholder.jpg"
|
||||
|
||||
slides_data.append(
|
||||
PdfSlideData(slide_number=i, screenshot_url=screenshot_url)
|
||||
)
|
||||
|
||||
return PdfSlidesResponse(
|
||||
success=True, slides=slides_data, total_slides=len(slides_data)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing PDF slides: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to process PDF: {str(e)}"
|
||||
)
|
||||
|
|
@ -1,641 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
import subprocess
|
||||
import uuid
|
||||
from typing import List, Optional, Dict
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import xml.etree.ElementTree as ET
|
||||
import re
|
||||
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
import uuid
|
||||
from constants.documents import PPTX_MIME_TYPES
|
||||
|
||||
|
||||
def _get_soffice_binary() -> str:
|
||||
"""Return the soffice binary to use for LibreOffice subprocess calls.
|
||||
|
||||
When running inside the Electron desktop app, the main process resolves the
|
||||
exact soffice binary path at startup and forwards it via the ``SOFFICE_PATH``
|
||||
environment variable. Falling back to the bare ``"soffice"`` command keeps
|
||||
Docker / server deployments working unchanged.
|
||||
"""
|
||||
configured = os.environ.get("SOFFICE_PATH")
|
||||
if configured:
|
||||
return configured
|
||||
return "soffice.exe" if os.name == "nt" else "soffice"
|
||||
|
||||
|
||||
def _windows_hidden_subprocess_kwargs() -> Dict[str, object]:
|
||||
"""Return subprocess kwargs that suppress Windows console windows."""
|
||||
if os.name != "nt":
|
||||
return {}
|
||||
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
return {
|
||||
"creationflags": getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
"startupinfo": startupinfo,
|
||||
}
|
||||
|
||||
|
||||
PPTX_SLIDES_ROUTER = APIRouter(prefix="/pptx-slides", tags=["PPTX Slides"])
|
||||
|
||||
|
||||
class SlideData(BaseModel):
|
||||
slide_number: int
|
||||
screenshot_url: str
|
||||
xml_content: str
|
||||
normalized_fonts: List[str]
|
||||
|
||||
|
||||
class FontAnalysisResult(BaseModel):
|
||||
internally_supported_fonts: List[
|
||||
Dict[str, str]
|
||||
] # [{"name": "Open Sans", "google_fonts_url": "..."}]
|
||||
not_supported_fonts: List[str] # ["Custom Font Name"]
|
||||
|
||||
|
||||
class PptxSlidesResponse(BaseModel):
|
||||
success: bool
|
||||
slides: List[SlideData]
|
||||
total_slides: int
|
||||
fonts: Optional[FontAnalysisResult] = None
|
||||
|
||||
|
||||
# NEW: Fonts-only router and response for PPTX
|
||||
class PptxFontsResponse(BaseModel):
|
||||
success: bool
|
||||
fonts: FontAnalysisResult
|
||||
|
||||
|
||||
PPTX_FONTS_ROUTER = APIRouter(prefix="/pptx-fonts", tags=["PPTX Fonts"])
|
||||
|
||||
# NEW: Normalize font family names by removing style/weight/stretch descriptors and splitting camel case
|
||||
_STYLE_TOKENS = {
|
||||
# styles
|
||||
"italic",
|
||||
"italics",
|
||||
"ital",
|
||||
"oblique",
|
||||
"roman",
|
||||
# combined style shortcuts
|
||||
"bolditalic",
|
||||
"bolditalics",
|
||||
# weights
|
||||
"thin",
|
||||
"hairline",
|
||||
"extralight",
|
||||
"ultralight",
|
||||
"light",
|
||||
"demilight",
|
||||
"semilight",
|
||||
"book",
|
||||
"regular",
|
||||
"normal",
|
||||
"medium",
|
||||
"semibold",
|
||||
"demibold",
|
||||
"bold",
|
||||
"extrabold",
|
||||
"ultrabold",
|
||||
"black",
|
||||
"extrablack",
|
||||
"ultrablack",
|
||||
"heavy",
|
||||
# width/stretch
|
||||
"narrow",
|
||||
"condensed",
|
||||
"semicondensed",
|
||||
"extracondensed",
|
||||
"ultracondensed",
|
||||
"expanded",
|
||||
"semiexpanded",
|
||||
"extraexpanded",
|
||||
"ultraexpanded",
|
||||
}
|
||||
# Modifiers commonly used with style tokens
|
||||
_STYLE_MODIFIERS = {"semi", "demi", "extra", "ultra"}
|
||||
|
||||
|
||||
def _insert_spaces_in_camel_case(value: str) -> str:
|
||||
# Insert space before capital letters preceded by lowercase or digits (e.g., MontserratBold -> Montserrat Bold)
|
||||
value = re.sub(r"(?<=[a-z0-9])([A-Z])", r" \1", value)
|
||||
# Handle sequences like BoldItalic -> Bold Italic
|
||||
value = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", value)
|
||||
return value
|
||||
|
||||
|
||||
def normalize_font_family_name(raw_name: str) -> str:
|
||||
if not raw_name:
|
||||
return raw_name
|
||||
# Replace separators with spaces
|
||||
name = raw_name.replace("_", " ").replace("-", " ")
|
||||
# Insert spaces in camel case
|
||||
name = _insert_spaces_in_camel_case(name)
|
||||
# Collapse multiple spaces
|
||||
name = re.sub(r"\s+", " ", name).strip()
|
||||
# Lowercase helper for matching but keep original casing for output
|
||||
lower_name = name.lower()
|
||||
# Quick cut: if the full string ends with a pure style suffix, trim it
|
||||
for style in sorted(_STYLE_TOKENS, key=len, reverse=True):
|
||||
if lower_name.endswith(" " + style):
|
||||
name = name[: -(len(style) + 1)]
|
||||
lower_name = lower_name[: -(len(style) + 1)]
|
||||
break
|
||||
# Tokenize
|
||||
tokens_original = name.split(" ")
|
||||
tokens_filtered: List[str] = []
|
||||
for index, tok in enumerate(tokens_original):
|
||||
lower_tok = tok.lower()
|
||||
# Always keep the first token to avoid stripping families like "Black Ops One"
|
||||
if index == 0:
|
||||
tokens_filtered.append(tok)
|
||||
continue
|
||||
# Drop style tokens and standalone modifiers
|
||||
if lower_tok in _STYLE_TOKENS or lower_tok in _STYLE_MODIFIERS:
|
||||
continue
|
||||
tokens_filtered.append(tok)
|
||||
# If everything except first token was dropped and first token is a style token (unlikely), fallback to original
|
||||
if not tokens_filtered:
|
||||
tokens_filtered = tokens_original
|
||||
normalized = " ".join(tokens_filtered).strip()
|
||||
# Final cleanup of leftover multiple spaces
|
||||
normalized = re.sub(r"\s+", " ", normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def extract_fonts_from_oxml(xml_content: str) -> List[str]:
|
||||
"""
|
||||
Extract font names from OXML content.
|
||||
|
||||
Args:
|
||||
xml_content: OXML content as string
|
||||
|
||||
Returns:
|
||||
List of unique font names found in the OXML
|
||||
"""
|
||||
fonts = set()
|
||||
|
||||
try:
|
||||
# Parse the XML content
|
||||
root = ET.fromstring(xml_content)
|
||||
|
||||
# Define namespaces commonly used in OXML
|
||||
namespaces = {
|
||||
"a": "http://schemas.openxmlformats.org/drawingml/2006/main",
|
||||
"p": "http://schemas.openxmlformats.org/presentationml/2006/main",
|
||||
"r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
|
||||
}
|
||||
|
||||
# Search for font references in various OXML elements
|
||||
# Look for latin fonts
|
||||
for font_elem in root.findall(".//a:latin", namespaces):
|
||||
if "typeface" in font_elem.attrib:
|
||||
fonts.add(font_elem.attrib["typeface"])
|
||||
|
||||
# Look for east asian fonts
|
||||
for font_elem in root.findall(".//a:ea", namespaces):
|
||||
if "typeface" in font_elem.attrib:
|
||||
fonts.add(font_elem.attrib["typeface"])
|
||||
|
||||
# Look for complex script fonts
|
||||
for font_elem in root.findall(".//a:cs", namespaces):
|
||||
if "typeface" in font_elem.attrib:
|
||||
fonts.add(font_elem.attrib["typeface"])
|
||||
|
||||
# Look for font references in theme elements
|
||||
for font_elem in root.findall(".//a:font", namespaces):
|
||||
if "typeface" in font_elem.attrib:
|
||||
fonts.add(font_elem.attrib["typeface"])
|
||||
|
||||
# Look for rPr (run properties) font references
|
||||
for rpr_elem in root.findall(".//a:rPr", namespaces):
|
||||
for font_elem in rpr_elem.findall(".//a:latin", namespaces):
|
||||
if "typeface" in font_elem.attrib:
|
||||
fonts.add(font_elem.attrib["typeface"])
|
||||
|
||||
# Also search without namespace prefix for compatibility
|
||||
for font_elem in root.findall(".//latin"):
|
||||
if "typeface" in font_elem.attrib:
|
||||
fonts.add(font_elem.attrib["typeface"])
|
||||
|
||||
# Regex fallback for fonts that might be missed
|
||||
font_pattern = r'typeface="([^"]+)"'
|
||||
regex_fonts = re.findall(font_pattern, xml_content)
|
||||
fonts.update(regex_fonts)
|
||||
|
||||
# Filter out system fonts and empty values
|
||||
system_fonts = {"+mn-lt", "+mj-lt", "+mn-ea", "+mj-ea", "+mn-cs", "+mj-cs", ""}
|
||||
fonts = {font for font in fonts if font not in system_fonts and font.strip()}
|
||||
|
||||
return list(fonts)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting fonts from OXML: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def check_google_font_availability(font_name: str) -> bool:
|
||||
"""
|
||||
Check if a font is available in Google Fonts.
|
||||
|
||||
Args:
|
||||
font_name: Name of the font to check
|
||||
|
||||
Returns:
|
||||
True if font is available in Google Fonts, False otherwise
|
||||
"""
|
||||
try:
|
||||
formatted_name = font_name.replace(" ", "+")
|
||||
url = f"https://fonts.googleapis.com/css2?family={formatted_name}&display=swap"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.head(
|
||||
url, timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
return response.status == 200
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error checking Google Font availability for {font_name}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def analyze_fonts_in_all_slides(slide_xmls: List[str]) -> FontAnalysisResult:
|
||||
"""
|
||||
Analyze fonts across all slides and determine Google Fonts availability.
|
||||
|
||||
Args:
|
||||
slide_xmls: List of OXML content strings from all slides
|
||||
|
||||
Returns:
|
||||
FontAnalysisResult with supported and unsupported fonts
|
||||
"""
|
||||
# Extract fonts from all slides
|
||||
raw_fonts = set()
|
||||
for xml_content in slide_xmls:
|
||||
slide_fonts = extract_fonts_from_oxml(xml_content)
|
||||
raw_fonts.update(slide_fonts)
|
||||
|
||||
# Normalize to root families (e.g., "Montserrat Italic" -> "Montserrat")
|
||||
normalized_fonts = {normalize_font_family_name(f) for f in raw_fonts}
|
||||
# Remove empties if any
|
||||
normalized_fonts = {f for f in normalized_fonts if f}
|
||||
|
||||
if not normalized_fonts:
|
||||
return FontAnalysisResult(internally_supported_fonts=[], not_supported_fonts=[])
|
||||
|
||||
# Check each normalized font's availability in Google Fonts concurrently
|
||||
tasks = [check_google_font_availability(font) for font in normalized_fonts]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
internally_supported_fonts = []
|
||||
not_supported_fonts = []
|
||||
|
||||
for font, is_available in zip(normalized_fonts, results):
|
||||
if is_available:
|
||||
formatted_name = font.replace(" ", "+")
|
||||
google_fonts_url = f"https://fonts.googleapis.com/css2?family={formatted_name}&display=swap"
|
||||
internally_supported_fonts.append(
|
||||
{"name": font, "google_fonts_url": google_fonts_url}
|
||||
)
|
||||
else:
|
||||
not_supported_fonts.append(font)
|
||||
|
||||
return FontAnalysisResult(
|
||||
internally_supported_fonts=internally_supported_fonts, not_supported_fonts=[]
|
||||
)
|
||||
|
||||
|
||||
@PPTX_SLIDES_ROUTER.post("/process", response_model=PptxSlidesResponse)
|
||||
async def process_pptx_slides(
|
||||
pptx_file: UploadFile = File(..., description="PPTX file to process"),
|
||||
fonts: Optional[List[UploadFile]] = File(None, description="Optional font files"),
|
||||
):
|
||||
"""
|
||||
Process a PPTX file to extract slide screenshots and XML content.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the uploaded PPTX file
|
||||
2. Installs any provided font files
|
||||
3. Unzips the PPTX to extract slide XMLs
|
||||
4. Uses LibreOffice to generate slide screenshots
|
||||
5. Returns both screenshot URLs and XML content for each slide
|
||||
"""
|
||||
|
||||
# Validate PPTX file
|
||||
if pptx_file.content_type not in PPTX_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Expected PPTX file, got {pptx_file.content_type}",
|
||||
)
|
||||
# Enforce 100MB size limit
|
||||
if (
|
||||
hasattr(pptx_file, "size")
|
||||
and pptx_file.size
|
||||
and pptx_file.size > (100 * 1024 * 1024)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="PPTX file exceeded max upload size of 100 MB",
|
||||
)
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if True:
|
||||
# Save uploaded PPTX file
|
||||
pptx_path = os.path.join(temp_dir, "presentation.pptx")
|
||||
with open(pptx_path, "wb") as f:
|
||||
pptx_content = await pptx_file.read()
|
||||
f.write(pptx_content)
|
||||
|
||||
# Install fonts if provided
|
||||
if fonts:
|
||||
await _install_fonts(fonts, temp_dir)
|
||||
|
||||
# Extract slide XMLs from PPTX
|
||||
slide_xmls = _extract_slide_xmls(pptx_path, temp_dir)
|
||||
|
||||
# Convert PPTX to PDF
|
||||
pdf_path = await _convert_pptx_to_pdf(pptx_path, temp_dir)
|
||||
|
||||
# Generate screenshots using LibreOffice
|
||||
screenshot_paths = await DocumentsLoader.get_page_images_from_pdf_async(
|
||||
pdf_path, temp_dir
|
||||
)
|
||||
print(f"Screenshot paths: {screenshot_paths}")
|
||||
|
||||
# Analyze fonts across all slides
|
||||
font_analysis = await analyze_fonts_in_all_slides(slide_xmls)
|
||||
print(
|
||||
f"Font analysis completed: {len(font_analysis.internally_supported_fonts)} supported, {len(font_analysis.not_supported_fonts)} not supported"
|
||||
)
|
||||
|
||||
# Move screenshots to images directory and generate URLs
|
||||
images_dir = get_images_directory()
|
||||
presentation_id = uuid.uuid4()
|
||||
presentation_images_dir = os.path.join(images_dir, str(presentation_id))
|
||||
os.makedirs(presentation_images_dir, exist_ok=True)
|
||||
|
||||
slides_data = []
|
||||
|
||||
for i, (xml_content, screenshot_path) in enumerate(
|
||||
zip(slide_xmls, screenshot_paths), 1
|
||||
):
|
||||
# Move screenshot to permanent location
|
||||
screenshot_filename = f"slide_{i}.png"
|
||||
permanent_screenshot_path = os.path.join(
|
||||
presentation_images_dir, screenshot_filename
|
||||
)
|
||||
|
||||
if (
|
||||
os.path.exists(screenshot_path)
|
||||
and os.path.getsize(screenshot_path) > 0
|
||||
):
|
||||
# Use shutil.copy2 instead of os.rename to handle cross-device moves
|
||||
shutil.copy2(screenshot_path, permanent_screenshot_path)
|
||||
screenshot_url = (
|
||||
f"/app_data/images/{presentation_id}/{screenshot_filename}"
|
||||
)
|
||||
else:
|
||||
# Fallback if screenshot generation failed or file is empty placeholder
|
||||
screenshot_url = "/static/images/placeholder.jpg"
|
||||
|
||||
# Compute normalized fonts for this slide
|
||||
raw_slide_fonts = extract_fonts_from_oxml(xml_content)
|
||||
normalized_fonts = sorted(
|
||||
{normalize_font_family_name(f) for f in raw_slide_fonts if f}
|
||||
)
|
||||
|
||||
slides_data.append(
|
||||
SlideData(
|
||||
slide_number=i,
|
||||
screenshot_url=screenshot_url,
|
||||
xml_content=xml_content,
|
||||
normalized_fonts=normalized_fonts,
|
||||
)
|
||||
)
|
||||
|
||||
return PptxSlidesResponse(
|
||||
success=True,
|
||||
slides=slides_data,
|
||||
total_slides=len(slides_data),
|
||||
fonts=font_analysis,
|
||||
)
|
||||
|
||||
|
||||
# NEW: Fonts-only endpoint leveraging the same font extraction/analysis
|
||||
@PPTX_FONTS_ROUTER.post("/process", response_model=PptxFontsResponse)
|
||||
async def process_pptx_fonts(
|
||||
pptx_file: UploadFile = File(..., description="PPTX file to analyze fonts from")
|
||||
):
|
||||
"""
|
||||
Analyze a PPTX file and return only the fonts used in the document.
|
||||
|
||||
Uses the exact same font extraction and analysis utilities as the /pptx-slides endpoint.
|
||||
"""
|
||||
# Validate PPTX file
|
||||
if pptx_file.content_type not in PPTX_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Expected PPTX file, got {pptx_file.content_type}",
|
||||
)
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save uploaded PPTX file
|
||||
pptx_path = os.path.join(temp_dir, "presentation.pptx")
|
||||
with open(pptx_path, "wb") as f:
|
||||
pptx_content = await pptx_file.read()
|
||||
f.write(pptx_content)
|
||||
|
||||
# Extract slide XMLs from PPTX
|
||||
slide_xmls = _extract_slide_xmls(pptx_path, temp_dir)
|
||||
|
||||
# Analyze fonts across all slides (same logic as in /pptx-slides)
|
||||
font_analysis = await analyze_fonts_in_all_slides(slide_xmls)
|
||||
|
||||
return PptxFontsResponse(
|
||||
success=True,
|
||||
fonts=font_analysis,
|
||||
)
|
||||
|
||||
|
||||
def _create_font_alias_config(raw_fonts: List[str]) -> str:
|
||||
"""Create a temporary fontconfig configuration that aliases variant family names to normalized root families.
|
||||
Returns the path to the config file.
|
||||
"""
|
||||
# Build mapping from raw -> normalized where different
|
||||
mappings: Dict[str, str] = {}
|
||||
for f in raw_fonts:
|
||||
normalized = normalize_font_family_name(f)
|
||||
if normalized and normalized != f:
|
||||
mappings[f] = normalized
|
||||
# Create config only if we have mappings
|
||||
fd, fonts_conf_path = tempfile.mkstemp(prefix="fonts_alias_", suffix=".conf")
|
||||
os.close(fd)
|
||||
with open(fonts_conf_path, "w", encoding="utf-8") as cfg:
|
||||
cfg.write(
|
||||
"""<?xml version='1.0'?>
|
||||
<!DOCTYPE fontconfig SYSTEM "urn:fontconfig:fonts.dtd">
|
||||
<fontconfig>
|
||||
<include>/etc/fonts/fonts.conf</include>
|
||||
"""
|
||||
)
|
||||
for src, dst in mappings.items():
|
||||
cfg.write(
|
||||
f"""
|
||||
<match target="pattern">
|
||||
<test name="family" compare="eq">
|
||||
<string>{src}</string>
|
||||
</test>
|
||||
<edit name="family" mode="assign" binding="strong">
|
||||
<string>{dst}</string>
|
||||
</edit>
|
||||
</match>
|
||||
"""
|
||||
)
|
||||
cfg.write("\n</fontconfig>\n")
|
||||
return fonts_conf_path
|
||||
|
||||
|
||||
async def _install_fonts(fonts: List[UploadFile], temp_dir: str) -> None:
|
||||
"""Install provided font files to the system."""
|
||||
fonts_dir = os.path.join(temp_dir, "fonts")
|
||||
os.makedirs(fonts_dir, exist_ok=True)
|
||||
|
||||
for font_file in fonts:
|
||||
# Save font file
|
||||
font_path = os.path.join(fonts_dir, font_file.filename)
|
||||
with open(font_path, "wb") as f:
|
||||
font_content = await font_file.read()
|
||||
f.write(font_content)
|
||||
|
||||
# Install font (copy to system fonts directory)
|
||||
try:
|
||||
subprocess.run(
|
||||
["cp", font_path, "/usr/share/fonts/truetype/"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Warning: Failed to install font {font_file.filename}: {e}")
|
||||
|
||||
# Refresh font cache
|
||||
try:
|
||||
subprocess.run(["fc-cache", "-f", "-v"], check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Warning: Failed to refresh font cache: {e}")
|
||||
|
||||
|
||||
def _extract_slide_xmls(pptx_path: str, temp_dir: str) -> List[str]:
|
||||
"""Extract slide XML content from PPTX file."""
|
||||
slide_xmls = []
|
||||
extract_dir = os.path.join(temp_dir, "pptx_extract")
|
||||
|
||||
try:
|
||||
# Unzip PPTX file
|
||||
with zipfile.ZipFile(pptx_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_dir)
|
||||
|
||||
# Look for slides in ppt/slides/ directory
|
||||
slides_dir = os.path.join(extract_dir, "ppt", "slides")
|
||||
|
||||
if not os.path.exists(slides_dir):
|
||||
raise Exception("No slides directory found in PPTX file")
|
||||
|
||||
# Get all slide XML files and sort them numerically
|
||||
slide_files = [
|
||||
f
|
||||
for f in os.listdir(slides_dir)
|
||||
if f.startswith("slide") and f.endswith(".xml")
|
||||
]
|
||||
slide_files.sort(key=lambda x: int(x.replace("slide", "").replace(".xml", "")))
|
||||
|
||||
# Read XML content from each slide
|
||||
for slide_file in slide_files:
|
||||
slide_path = os.path.join(slides_dir, slide_file)
|
||||
with open(slide_path, "r", encoding="utf-8") as f:
|
||||
slide_xmls.append(f.read())
|
||||
|
||||
return slide_xmls
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to extract slide XMLs: {str(e)}")
|
||||
|
||||
|
||||
async def _convert_pptx_to_pdf(pptx_path: str, temp_dir: str) -> str:
|
||||
"""Generate PNG screenshots of PPTX slides using LibreOffice + ImageMagick."""
|
||||
screenshots_dir = os.path.join(temp_dir, "screenshots")
|
||||
os.makedirs(screenshots_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
# First, get the number of slides by extracting XMLs
|
||||
slide_xmls = _extract_slide_xmls(pptx_path, temp_dir)
|
||||
slide_count = len(slide_xmls)
|
||||
|
||||
# Build font alias config to force variant families to resolve to normalized root families
|
||||
raw_fonts: List[str] = []
|
||||
for xml in slide_xmls:
|
||||
raw_fonts.extend(extract_fonts_from_oxml(xml))
|
||||
raw_fonts = list({f for f in raw_fonts if f})
|
||||
fonts_conf_path = _create_font_alias_config(raw_fonts)
|
||||
env = os.environ.copy()
|
||||
env["FONTCONFIG_FILE"] = fonts_conf_path
|
||||
|
||||
print(f"Found {slide_count} slides in presentation")
|
||||
|
||||
# Step 1: Convert PPTX to PDF using LibreOffice
|
||||
print("Starting LibreOffice PDF conversion...")
|
||||
pdf_filename = "temp_presentation.pdf"
|
||||
pdf_path = os.path.join(screenshots_dir, pdf_filename)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
_get_soffice_binary(),
|
||||
"--headless",
|
||||
"--convert-to",
|
||||
"pdf",
|
||||
"--outdir",
|
||||
screenshots_dir,
|
||||
pptx_path,
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=500,
|
||||
env=env,
|
||||
**_windows_hidden_subprocess_kwargs(),
|
||||
)
|
||||
|
||||
print(f"LibreOffice PDF conversion output: {result.stdout}")
|
||||
if result.stderr:
|
||||
print(f"LibreOffice PDF conversion warnings: {result.stderr}")
|
||||
except subprocess.TimeoutExpired:
|
||||
raise Exception("LibreOffice PDF conversion timed out after 120 seconds")
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_msg = e.stderr if e.stderr else str(e)
|
||||
raise Exception(f"LibreOffice PDF conversion failed: {error_msg}")
|
||||
|
||||
# Find the generated PDF file (LibreOffice uses original filename)
|
||||
pdf_files = [f for f in os.listdir(screenshots_dir) if f.endswith(".pdf")]
|
||||
if not pdf_files:
|
||||
raise Exception("LibreOffice failed to generate PDF file")
|
||||
|
||||
actual_pdf_path = os.path.join(screenshots_dir, pdf_files[0])
|
||||
print(f"Generated PDF: {actual_pdf_path}")
|
||||
return actual_pdf_path
|
||||
|
||||
except Exception as e:
|
||||
# Re-raise the specific exceptions we've already handled
|
||||
if "timed out" in str(e) or "failed:" in str(e):
|
||||
raise
|
||||
# Handle any other unexpected exceptions
|
||||
raise Exception(f"Screenshot generation failed: {str(e)}")
|
||||
|
|
@ -24,16 +24,13 @@ from models.presentation_outline_model import (
|
|||
from enums.tone import Tone
|
||||
from enums.verbosity import Verbosity
|
||||
from models.pptx_models import PptxPresentationModel
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.presentation_with_slides import (
|
||||
PresentationWithSlides,
|
||||
)
|
||||
from models.sql.template import TemplateModel
|
||||
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from services.webhook_service import WebhookService
|
||||
from utils.get_layout_by_name import get_layout_by_name
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.dict_utils import deep_update
|
||||
from utils.export_utils import export_presentation
|
||||
|
|
@ -70,6 +67,8 @@ from utils.process_slides import (
|
|||
process_slide_add_placeholder_assets,
|
||||
process_slide_and_fetch_assets,
|
||||
)
|
||||
from templates.get_layout_by_name import get_layout_by_name
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
import uuid
|
||||
|
||||
|
||||
|
|
@ -881,6 +880,8 @@ async def generate_presentation_sync(
|
|||
return await generate_presentation_handler(
|
||||
request, presentation_id, None, sql_session
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail="Presentation generation failed")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,13 +1,10 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from api.v1.ppt.endpoints.slide_to_html import SLIDE_TO_HTML_ROUTER, HTML_TO_REACT_ROUTER, HTML_EDIT_ROUTER, LAYOUT_MANAGEMENT_ROUTER
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
from api.v1.ppt.endpoints.anthropic import ANTHROPIC_ROUTER
|
||||
from api.v1.ppt.endpoints.google import GOOGLE_ROUTER
|
||||
from api.v1.ppt.endpoints.openai import OPENAI_ROUTER
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
from api.v1.ppt.endpoints.pptx_slides import PPTX_SLIDES_ROUTER
|
||||
from api.v1.ppt.endpoints.pdf_slides import PDF_SLIDES_ROUTER
|
||||
from api.v1.ppt.endpoints.fonts import FONTS_ROUTER
|
||||
from api.v1.ppt.endpoints.icons import ICONS_ROUTER
|
||||
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
||||
|
|
@ -15,9 +12,9 @@ from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
|
|||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
|
||||
from api.v1.ppt.endpoints.codex_auth import CODEX_AUTH_ROUTER
|
||||
from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER
|
||||
from api.v1.ppt.endpoints.theme import THEMES_ROUTER
|
||||
from api.v1.ppt.endpoints.theme_generate import THEME_ROUTER
|
||||
from templates.router import TEMPLATE_ROUTER
|
||||
|
||||
|
||||
API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
|
||||
|
|
@ -26,20 +23,14 @@ API_V1_PPT_ROUTER.include_router(FILES_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(FONTS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PPTX_SLIDES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(SLIDE_TO_HTML_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(HTML_TO_REACT_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(HTML_EDIT_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(LAYOUT_MANAGEMENT_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(TEMPLATE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(THEMES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(THEME_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PDF_SLIDES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CODEX_AUTH_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PPTX_FONTS_ROUTER)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ import uuid
|
|||
from sqlalchemy import JSON, Column, DateTime, String
|
||||
from sqlmodel import Boolean, Field, SQLModel
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from sqlalchemy import Column, DateTime, Text, JSON
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, Text
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
||||
class PresentationLayoutCodeModel(SQLModel, table=True):
|
||||
"""Model for storing presentation layout codes"""
|
||||
|
||||
__tablename__ = "presentation_layout_codes"
|
||||
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
|
@ -19,8 +18,10 @@ class PresentationLayoutCodeModel(SQLModel, table=True):
|
|||
layout_code: str = Field(
|
||||
sa_column=Column(Text), description="TSX/React component code for the layout"
|
||||
)
|
||||
fonts: Optional[List[str]] = Field(
|
||||
sa_column=Column(JSON), default=None, description="Optional list of font links"
|
||||
fonts: Optional[dict[str, str] | list[str]] = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON, nullable=True),
|
||||
description="Optional font metadata associated with the layout",
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
|
|
|||
25
electron/servers/fastapi/models/sql/template_create_info.py
Normal file
25
electron/servers/fastapi/models/sql/template_create_info.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
||||
class TemplateCreateInfoModel(SQLModel, table=True):
|
||||
__tablename__ = "template_create_infos"
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
fonts: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(JSON, nullable=True),
|
||||
)
|
||||
pptx_url: str | None = Field(default=None)
|
||||
slide_htmls: list[str] = Field(sa_column=Column(JSON, nullable=False))
|
||||
slide_image_urls: list[str] = Field(sa_column=Column(JSON, nullable=False))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, default=get_current_utc_datetime
|
||||
)
|
||||
)
|
||||
|
|
@ -42,4 +42,4 @@ dev = [
|
|||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["api*", "enums*", "models*", "services*", "constants*", "utils*"]
|
||||
include = ["api*", "enums*", "models*", "services*", "constants*", "utils*", "templates*"]
|
||||
|
|
|
|||
|
|
@ -14,10 +14,11 @@ from models.sql.async_presentation_generation_status import (
|
|||
from models.sql.image_asset import ImageAsset
|
||||
from models.sql.key_value import KeyValueSqlModel
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sql.presentation_layout_code import PresentationLayoutCodeModel
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.template import TemplateModel
|
||||
from models.sql.template_create_info import TemplateCreateInfoModel
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from utils.db_utils import get_database_url_and_connect_args
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
|
|
@ -65,6 +66,7 @@ async def create_db_and_tables():
|
|||
KeyValueSqlModel.__table__,
|
||||
ImageAsset.__table__,
|
||||
PresentationLayoutCodeModel.__table__,
|
||||
TemplateCreateInfoModel.__table__,
|
||||
TemplateModel.__table__,
|
||||
WebhookSubscription.__table__,
|
||||
AsyncPresentationGenerationTaskModel.__table__,
|
||||
|
|
|
|||
1
electron/servers/fastapi/templates/__init__.py
Normal file
1
electron/servers/fastapi/templates/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
__all__ = []
|
||||
98
electron/servers/fastapi/templates/example.py
Normal file
98
electron/servers/fastapi/templates/example.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from typing import Any
|
||||
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
|
||||
PLACEHOLDER_IMAGE_URL = "https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png"
|
||||
PLACEHOLDER_ICON_URL = "https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg"
|
||||
|
||||
|
||||
def build_schema_example(schema: dict) -> Any:
|
||||
if not isinstance(schema, dict):
|
||||
return None
|
||||
|
||||
if "default" in schema:
|
||||
return schema["default"]
|
||||
|
||||
for key in ("anyOf", "oneOf", "allOf"):
|
||||
options = schema.get(key)
|
||||
if isinstance(options, list):
|
||||
for option in options:
|
||||
example = build_schema_example(option)
|
||||
if example is not None:
|
||||
return example
|
||||
|
||||
enum_values = schema.get("enum")
|
||||
if enum_values:
|
||||
return enum_values[0]
|
||||
|
||||
schema_type = schema.get("type")
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties", {})
|
||||
result = {}
|
||||
for field_name, field_schema in properties.items():
|
||||
result[field_name] = build_schema_example(field_schema)
|
||||
return result
|
||||
|
||||
if schema_type == "array":
|
||||
items_schema = schema.get("items", {})
|
||||
if "default" in schema:
|
||||
return schema["default"]
|
||||
item_example = build_schema_example(items_schema)
|
||||
return [] if item_example is None else [item_example]
|
||||
|
||||
if schema_type == "string":
|
||||
schema_description = (schema.get("description") or "").lower()
|
||||
if "icon" in schema_description:
|
||||
return PLACEHOLDER_ICON_URL
|
||||
if "image" in schema_description or "url" in schema_description:
|
||||
return PLACEHOLDER_IMAGE_URL
|
||||
return "Sample text"
|
||||
|
||||
if schema_type == "integer":
|
||||
return schema.get("minimum", 1)
|
||||
|
||||
if schema_type == "number":
|
||||
return schema.get("minimum", 1)
|
||||
|
||||
if schema_type == "boolean":
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def replace_special_placeholders(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
result = {}
|
||||
for key, child in value.items():
|
||||
if key == "__image_url__":
|
||||
result[key] = PLACEHOLDER_IMAGE_URL
|
||||
elif key == "__icon_url__":
|
||||
result[key] = PLACEHOLDER_ICON_URL
|
||||
else:
|
||||
result[key] = replace_special_placeholders(child)
|
||||
return result
|
||||
|
||||
if isinstance(value, list):
|
||||
return [replace_special_placeholders(item) for item in value]
|
||||
|
||||
if value == "__image_url__":
|
||||
return PLACEHOLDER_IMAGE_URL
|
||||
if value == "__icon_url__":
|
||||
return PLACEHOLDER_ICON_URL
|
||||
return value
|
||||
|
||||
|
||||
def build_template_example(
|
||||
template_id: str, layout: PresentationLayoutModel
|
||||
) -> dict[str, Any]:
|
||||
slides = []
|
||||
for slide in layout.slides:
|
||||
example_content = replace_special_placeholders(
|
||||
build_schema_example(slide.json_schema)
|
||||
)
|
||||
slides.append({"layout": slide.id, "content": example_content})
|
||||
|
||||
return {
|
||||
"template": template_id,
|
||||
"slides": slides,
|
||||
}
|
||||
167
electron/servers/fastapi/templates/font_utils.py
Normal file
167
electron/servers/fastapi/templates/font_utils.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
import asyncio
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Iterable
|
||||
|
||||
import aiohttp
|
||||
|
||||
_STYLE_TOKENS = {
|
||||
"italic",
|
||||
"italics",
|
||||
"ital",
|
||||
"oblique",
|
||||
"roman",
|
||||
"bolditalic",
|
||||
"bolditalics",
|
||||
"thin",
|
||||
"hairline",
|
||||
"extralight",
|
||||
"ultralight",
|
||||
"light",
|
||||
"demilight",
|
||||
"semilight",
|
||||
"book",
|
||||
"regular",
|
||||
"normal",
|
||||
"medium",
|
||||
"semibold",
|
||||
"demibold",
|
||||
"bold",
|
||||
"extrabold",
|
||||
"ultrabold",
|
||||
"black",
|
||||
"extrablack",
|
||||
"ultrablack",
|
||||
"heavy",
|
||||
"narrow",
|
||||
"condensed",
|
||||
"semicondensed",
|
||||
"extracondensed",
|
||||
"ultracondensed",
|
||||
"expanded",
|
||||
"semiexpanded",
|
||||
"extraexpanded",
|
||||
"ultraexpanded",
|
||||
}
|
||||
_STYLE_MODIFIERS = {"semi", "demi", "extra", "ultra"}
|
||||
|
||||
|
||||
def _insert_spaces_in_camel_case(value: str) -> str:
|
||||
value = re.sub(r"(?<=[a-z0-9])([A-Z])", r" \1", value)
|
||||
value = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", value)
|
||||
return value
|
||||
|
||||
|
||||
def normalize_font_family_name(raw_name: str) -> str:
|
||||
if not raw_name:
|
||||
return raw_name
|
||||
|
||||
name = raw_name.replace("_", " ").replace("-", " ")
|
||||
name = _insert_spaces_in_camel_case(name)
|
||||
name = re.sub(r"\s+", " ", name).strip()
|
||||
lower_name = name.lower()
|
||||
|
||||
for style in sorted(_STYLE_TOKENS, key=len, reverse=True):
|
||||
suffix = " " + style
|
||||
if lower_name.endswith(suffix):
|
||||
name = name[: -len(suffix)]
|
||||
lower_name = lower_name[: -len(suffix)]
|
||||
break
|
||||
|
||||
tokens_original = name.split(" ")
|
||||
tokens_filtered: list[str] = []
|
||||
for index, token in enumerate(tokens_original):
|
||||
lower_token = token.lower()
|
||||
if index == 0:
|
||||
tokens_filtered.append(token)
|
||||
continue
|
||||
if lower_token in _STYLE_TOKENS or lower_token in _STYLE_MODIFIERS:
|
||||
continue
|
||||
tokens_filtered.append(token)
|
||||
|
||||
if not tokens_filtered:
|
||||
tokens_filtered = tokens_original
|
||||
|
||||
return re.sub(r"\s+", " ", " ".join(tokens_filtered).strip())
|
||||
|
||||
|
||||
def extract_fonts_from_oxml(xml_content: str) -> list[str]:
|
||||
fonts = set()
|
||||
|
||||
try:
|
||||
root = ET.fromstring(xml_content)
|
||||
namespaces = {
|
||||
"a": "http://schemas.openxmlformats.org/drawingml/2006/main",
|
||||
"p": "http://schemas.openxmlformats.org/presentationml/2006/main",
|
||||
"r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
|
||||
}
|
||||
|
||||
for xpath in (".//a:latin", ".//a:ea", ".//a:cs", ".//a:font"):
|
||||
for font_elem in root.findall(xpath, namespaces):
|
||||
typeface = font_elem.attrib.get("typeface")
|
||||
if typeface:
|
||||
fonts.add(typeface)
|
||||
|
||||
for rpr_elem in root.findall(".//a:rPr", namespaces):
|
||||
for font_elem in rpr_elem.findall(".//a:latin", namespaces):
|
||||
typeface = font_elem.attrib.get("typeface")
|
||||
if typeface:
|
||||
fonts.add(typeface)
|
||||
|
||||
for font_elem in root.findall(".//latin"):
|
||||
typeface = font_elem.attrib.get("typeface")
|
||||
if typeface:
|
||||
fonts.add(typeface)
|
||||
|
||||
fonts.update(re.findall(r'typeface="([^"]+)"', xml_content))
|
||||
|
||||
system_fonts = {"+mn-lt", "+mj-lt", "+mn-ea", "+mj-ea", "+mn-cs", "+mj-cs", ""}
|
||||
return sorted(font for font in fonts if font not in system_fonts and font.strip())
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def get_google_font_css_url(font_name: str) -> str:
|
||||
return f"https://fonts.googleapis.com/css2?family={font_name.replace(' ', '+')}&display=swap"
|
||||
|
||||
|
||||
async def check_google_font_availability(font_name: str) -> bool:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.head(
|
||||
get_google_font_css_url(font_name),
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def collect_normalized_fonts_from_xmls(slide_xmls: Iterable[str]) -> list[str]:
|
||||
raw_fonts = set()
|
||||
for xml_content in slide_xmls:
|
||||
raw_fonts.update(extract_fonts_from_oxml(xml_content))
|
||||
|
||||
normalized_fonts = {normalize_font_family_name(font) for font in raw_fonts}
|
||||
return sorted(font for font in normalized_fonts if font)
|
||||
|
||||
|
||||
async def get_available_and_unavailable_fonts(
|
||||
font_names: Iterable[str],
|
||||
) -> tuple[list[tuple[str, str]], list[tuple[str, None]]]:
|
||||
normalized_fonts = sorted({font for font in font_names if font})
|
||||
if not normalized_fonts:
|
||||
return [], []
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[check_google_font_availability(font) for font in normalized_fonts]
|
||||
)
|
||||
|
||||
available_fonts: list[tuple[str, str]] = []
|
||||
unavailable_fonts: list[tuple[str, None]] = []
|
||||
for font_name, is_available in zip(normalized_fonts, results):
|
||||
if is_available:
|
||||
available_fonts.append((font_name, get_google_font_css_url(font_name)))
|
||||
else:
|
||||
unavailable_fonts.append((font_name, None))
|
||||
return available_fonts, unavailable_fonts
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from typing import List
|
||||
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
|
||||
|
||||
async def get_layout_by_name(layout_name: str) -> PresentationLayoutModel:
|
||||
url = f"http://localhost/api/template?group={layout_name}"
|
||||
|
|
@ -11,8 +12,7 @@ async def get_layout_by_name(layout_name: str) -> PresentationLayoutModel:
|
|||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Template '{layout_name}' not found: {error_text}"
|
||||
detail=f"Template '{layout_name}' not found: {error_text}",
|
||||
)
|
||||
layout_json = await response.json()
|
||||
# Parse the JSON into your Pydantic model
|
||||
return PresentationLayoutModel(**layout_json)
|
||||
683
electron/servers/fastapi/templates/handler.py
Normal file
683
electron/servers/fastapi/templates/handler.py
Normal file
|
|
@ -0,0 +1,683 @@
|
|||
import os
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import Body, Depends, File, Form, HTTPException, Path, Query, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import delete, select
|
||||
|
||||
from constants.presentation import DEFAULT_TEMPLATES
|
||||
from models.sql.presentation_layout_code import PresentationLayoutCodeModel
|
||||
from models.sql.template import TemplateModel
|
||||
from models.sql.template_create_info import TemplateCreateInfoModel
|
||||
from services.database import get_async_session
|
||||
from templates.example import build_template_example
|
||||
from templates.get_layout_by_name import get_layout_by_name
|
||||
from templates.pptx_html_stub import BASIC_TEMPLATE_HTML
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
from templates.preview import (
|
||||
FontsUploadAndSlidesPreviewResponse,
|
||||
upload_fonts_and_slides_preview_handler,
|
||||
)
|
||||
from templates.prompts import (
|
||||
SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT,
|
||||
SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT,
|
||||
SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT,
|
||||
)
|
||||
from templates.providers import edit_slide_layout_code, generate_slide_layout_code
|
||||
from utils.asset_directory_utils import resolve_image_path_to_filesystem
|
||||
|
||||
|
||||
class TemplateDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
total_layouts: Optional[int] = None
|
||||
|
||||
|
||||
class TemplateLayoutData(BaseModel):
|
||||
template: uuid.UUID
|
||||
layout_id: str
|
||||
layout_name: str
|
||||
layout_code: str
|
||||
fonts: Optional[Any] = None
|
||||
|
||||
|
||||
class TemplateData(BaseModel):
|
||||
id: uuid.UUID
|
||||
init_id: Optional[uuid.UUID] = None
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class GetTemplateLayoutsResponse(BaseModel):
|
||||
layouts: list[TemplateLayoutData]
|
||||
template: Optional[TemplateData] = None
|
||||
fonts: Optional[Any] = None
|
||||
|
||||
|
||||
class TemplateExample(BaseModel):
|
||||
template: str
|
||||
slides: List[dict]
|
||||
|
||||
|
||||
class CreateTemplateInitRequest(BaseModel):
|
||||
pptx_url: str
|
||||
slide_image_urls: List[str]
|
||||
fonts: dict = {}
|
||||
|
||||
|
||||
class CreateSlideLayoutRequest(BaseModel):
|
||||
id: uuid.UUID
|
||||
index: int
|
||||
|
||||
|
||||
class CreateSlideLayoutResponse(BaseModel):
|
||||
react_component: str
|
||||
|
||||
|
||||
class EditSlideLayoutRequest(BaseModel):
|
||||
react_component: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class EditSlideLayoutResponse(CreateSlideLayoutResponse):
|
||||
pass
|
||||
|
||||
|
||||
class EditSlideLayoutSectionRequest(BaseModel):
|
||||
react_component: str
|
||||
section: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class EditSlideLayoutSectionResponse(CreateSlideLayoutResponse):
|
||||
pass
|
||||
|
||||
|
||||
class SaveTemplateLayoutData(BaseModel):
|
||||
layout_id: str
|
||||
layout_name: str
|
||||
layout_code: str
|
||||
|
||||
|
||||
class SaveTemplateRequest(BaseModel):
|
||||
template_info_id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
layouts: List[SaveTemplateLayoutData]
|
||||
|
||||
|
||||
class SaveTemplateResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CloneTemplateRequest(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class UpdateTemplateRequest(BaseModel):
|
||||
id: uuid.UUID
|
||||
layouts: List[SaveTemplateLayoutData]
|
||||
|
||||
|
||||
class SaveSlideLayoutRequest(BaseModel):
|
||||
template_id: uuid.UUID
|
||||
layout_id: str
|
||||
layout_code: str
|
||||
|
||||
|
||||
class CloneSlideLayoutRequest(BaseModel):
|
||||
template_id: str
|
||||
layout_id: str
|
||||
layout_name: Optional[str] = None
|
||||
|
||||
|
||||
def _strip_code_fences(value: str) -> str:
|
||||
return (
|
||||
value.replace("```tsx", "")
|
||||
.replace("```typescript", "")
|
||||
.replace("```ts", "")
|
||||
.replace("```", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
|
||||
def _normalize_layout_code_for_create(code: str) -> str:
|
||||
normalized = _strip_code_fences(code)
|
||||
normalized = (
|
||||
normalized.replace("image_url", "__image_url__")
|
||||
.replace("icon_url", "__icon_url__")
|
||||
.replace("image_prompt", "__image_prompt__")
|
||||
.replace("icon_query", "__icon_query__")
|
||||
)
|
||||
|
||||
first_import_match = re.search(r"(?m)^\s*import\b", normalized)
|
||||
if first_import_match:
|
||||
normalized = normalized[first_import_match.start() :]
|
||||
|
||||
first_export_match = re.search(r"(?m)^\s*export\b", normalized)
|
||||
if first_export_match:
|
||||
normalized = normalized[: first_export_match.start()]
|
||||
|
||||
normalized = re.sub(
|
||||
r"(?ms)^\s*(?:import|export)\b.*?;(?:\r?\n|$)",
|
||||
"",
|
||||
normalized,
|
||||
)
|
||||
normalized = re.sub(
|
||||
r"(?m)^\s*(?:import|export)\b.*(?:\r?\n|$)",
|
||||
"",
|
||||
normalized,
|
||||
)
|
||||
normalized = normalized.strip()
|
||||
normalized = re.sub(
|
||||
r'(layoutId\s*=\s*["\'])([^"\']+)(["\'])',
|
||||
lambda match: (
|
||||
match.group(0)
|
||||
if re.search(r"-\d{4}$", match.group(2))
|
||||
else f"{match.group(1)}{match.group(2)}-{random.randint(1000, 9999)}{match.group(3)}"
|
||||
),
|
||||
normalized,
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def _update_layout_id_in_code(code: str) -> tuple[str, str]:
|
||||
match = re.search(r'(layoutId\s*=\s*["\'])([^"\']+)(["\'])', code)
|
||||
if not match:
|
||||
raise HTTPException(status_code=400, detail="layoutId not found in layout code")
|
||||
|
||||
current_id = match.group(2)
|
||||
suffix = f"{random.randint(1000, 9999)}"
|
||||
new_id = re.sub(r"-\d{4}$", f"-{suffix}", current_id)
|
||||
if new_id == current_id:
|
||||
new_id = f"{current_id}-{suffix}"
|
||||
|
||||
new_code = re.sub(
|
||||
r'(layoutId\s*=\s*["\'])([^"\']+)(["\'])',
|
||||
f"\\1{new_id}\\3",
|
||||
code,
|
||||
count=1,
|
||||
)
|
||||
return new_code, new_id
|
||||
|
||||
|
||||
async def _download_image_bytes(image_url: str) -> bytes:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to download slide image: {image_url}",
|
||||
)
|
||||
return await response.read()
|
||||
|
||||
|
||||
async def _read_image_bytes_and_media_type(image_url: str) -> tuple[bytes, str]:
|
||||
actual_image_path = resolve_image_path_to_filesystem(image_url)
|
||||
if actual_image_path and os.path.isfile(actual_image_path):
|
||||
with open(actual_image_path, "rb") as image_file:
|
||||
image_bytes = image_file.read()
|
||||
file_extension = os.path.splitext(actual_image_path)[1].lower()
|
||||
else:
|
||||
image_bytes = await _download_image_bytes(image_url)
|
||||
file_extension = os.path.splitext(image_url)[1].lower()
|
||||
|
||||
media_type_map = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
}
|
||||
return image_bytes, media_type_map.get(file_extension, "image/png")
|
||||
|
||||
|
||||
async def get_all_templates(
|
||||
include_defaults: bool = Query(
|
||||
default=True, description="Whether to include default templates"
|
||||
),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
result = await sql_session.execute(
|
||||
select(
|
||||
TemplateModel.id,
|
||||
TemplateModel.name,
|
||||
func.count(PresentationLayoutCodeModel.id).label("total_layouts"),
|
||||
)
|
||||
.join(
|
||||
PresentationLayoutCodeModel,
|
||||
PresentationLayoutCodeModel.presentation == TemplateModel.id,
|
||||
)
|
||||
.group_by(TemplateModel.id, TemplateModel.name)
|
||||
)
|
||||
rows = result.all()
|
||||
|
||||
templates: list[TemplateDetail] = []
|
||||
if include_defaults:
|
||||
templates.extend(
|
||||
TemplateDetail(id=template, name=template) for template in DEFAULT_TEMPLATES
|
||||
)
|
||||
|
||||
templates.extend(
|
||||
TemplateDetail(
|
||||
id=f"custom-{template_id}",
|
||||
name=template_name,
|
||||
total_layouts=total_layouts,
|
||||
)
|
||||
for template_id, template_name, total_layouts in rows
|
||||
)
|
||||
return templates
|
||||
|
||||
|
||||
async def get_layouts(
|
||||
template_id: str = Path(..., description="The id of the template"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if not template_id or not template_id.strip():
|
||||
raise HTTPException(status_code=400, detail="Template ID cannot be empty")
|
||||
|
||||
try:
|
||||
cleaned_template_id = template_id.replace("custom-", "")
|
||||
template_id_uuid = uuid.UUID(cleaned_template_id)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc
|
||||
|
||||
result = await session.execute(
|
||||
select(PresentationLayoutCodeModel).where(
|
||||
PresentationLayoutCodeModel.presentation == template_id_uuid
|
||||
)
|
||||
)
|
||||
layouts_db = result.scalars().all()
|
||||
if not layouts_db:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No layouts found for template ID: {template_id}",
|
||||
)
|
||||
|
||||
template_meta = await session.get(TemplateModel, template_id_uuid)
|
||||
template = None
|
||||
if template_meta:
|
||||
template = TemplateData(
|
||||
id=template_id_uuid,
|
||||
init_id=None,
|
||||
name=template_meta.name,
|
||||
description=template_meta.description,
|
||||
created_at=template_meta.created_at,
|
||||
)
|
||||
|
||||
layouts = [
|
||||
TemplateLayoutData(
|
||||
template=template_id_uuid,
|
||||
layout_id=layout.layout_id,
|
||||
layout_name=layout.layout_name,
|
||||
layout_code=layout.layout_code,
|
||||
fonts=layout.fonts,
|
||||
)
|
||||
for layout in layouts_db
|
||||
]
|
||||
return GetTemplateLayoutsResponse(
|
||||
layouts=layouts,
|
||||
template=template,
|
||||
fonts=layouts[0].fonts if layouts else None,
|
||||
)
|
||||
|
||||
|
||||
async def get_template_by_id(
|
||||
id: str = Path(
|
||||
...,
|
||||
description=f"The id of the template, must be one of {', '.join(DEFAULT_TEMPLATES)} or your custom template",
|
||||
),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if id.startswith("custom-"):
|
||||
try:
|
||||
template_id = uuid.UUID(id.replace("custom-", ""))
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Template not found. Please use a valid template.",
|
||||
) from exc
|
||||
|
||||
template = await sql_session.get(TemplateModel, template_id)
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Template not found. Please use a valid template.",
|
||||
)
|
||||
|
||||
return await get_layout_by_name(id)
|
||||
|
||||
|
||||
async def get_template_example(
|
||||
id: str = Path(
|
||||
...,
|
||||
description=f"The id of the template, must be one of {', '.join(DEFAULT_TEMPLATES)} or your custom template",
|
||||
),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
template = await get_template_by_id(id=id, sql_session=sql_session)
|
||||
return TemplateExample(**build_template_example(id, template))
|
||||
|
||||
|
||||
async def upload_fonts_and_slides_preview(
|
||||
pptx_file: UploadFile = File(..., description="PPTX file to preview"),
|
||||
font_files: Optional[List[UploadFile]] = File(
|
||||
default=None, description="Font files to upload"
|
||||
),
|
||||
original_font_names: Optional[List[str]] = Form(default=None),
|
||||
max_slides: Optional[int] = Query(default=None),
|
||||
):
|
||||
return await upload_fonts_and_slides_preview_handler(
|
||||
pptx_file=pptx_file,
|
||||
font_files=font_files,
|
||||
original_font_names=original_font_names,
|
||||
max_slides=max_slides,
|
||||
)
|
||||
|
||||
|
||||
async def init_create_template(
|
||||
request: CreateTemplateInitRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if not request.slide_image_urls:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="At least one slide image is required"
|
||||
)
|
||||
|
||||
template_create_info = TemplateCreateInfoModel(
|
||||
fonts=request.fonts or {},
|
||||
pptx_url=request.pptx_url,
|
||||
slide_image_urls=request.slide_image_urls,
|
||||
slide_htmls=[BASIC_TEMPLATE_HTML for _ in request.slide_image_urls],
|
||||
)
|
||||
sql_session.add(template_create_info)
|
||||
await sql_session.commit()
|
||||
await sql_session.refresh(template_create_info)
|
||||
return template_create_info.id
|
||||
|
||||
|
||||
async def create_slide_layout(
|
||||
request: CreateSlideLayoutRequest = Body(...),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
template_info = await sql_session.get(TemplateCreateInfoModel, request.id)
|
||||
if not template_info:
|
||||
raise HTTPException(status_code=400, detail="Template not found")
|
||||
|
||||
total_slides = len(template_info.slide_htmls)
|
||||
if request.index < 0 or request.index >= total_slides:
|
||||
raise HTTPException(status_code=400, detail="Invalid slide index")
|
||||
|
||||
slide_html = template_info.slide_htmls[request.index]
|
||||
slide_image_url = template_info.slide_image_urls[request.index]
|
||||
image_bytes, media_type = await _read_image_bytes_and_media_type(slide_image_url)
|
||||
|
||||
fonts_text = ""
|
||||
if template_info.fonts:
|
||||
font_names = [font.replace(" ", "_") for font in template_info.fonts.keys()]
|
||||
fonts_text = "#PROVIDED FONTS\n- " + "\n- ".join(font_names)
|
||||
|
||||
user_text = f"{fonts_text}\n\n#SLIDE HTML REFERENCE\n{slide_html}"
|
||||
react_component = await generate_slide_layout_code(
|
||||
system_prompt=SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT,
|
||||
user_text=user_text,
|
||||
image_bytes=image_bytes,
|
||||
media_type=media_type,
|
||||
)
|
||||
|
||||
return CreateSlideLayoutResponse(
|
||||
react_component=_normalize_layout_code_for_create(react_component)
|
||||
)
|
||||
|
||||
|
||||
async def edit_slide_layout(
|
||||
request: EditSlideLayoutRequest,
|
||||
):
|
||||
user_text = f"#Prompt\n{request.prompt}\n\n#TSX code\n{request.react_component}"
|
||||
react_component = await edit_slide_layout_code(
|
||||
system_prompt=SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT,
|
||||
user_text=user_text,
|
||||
)
|
||||
return EditSlideLayoutResponse(react_component=_strip_code_fences(react_component))
|
||||
|
||||
|
||||
async def edit_slide_layout_section(
|
||||
request: EditSlideLayoutSectionRequest,
|
||||
):
|
||||
user_text = (
|
||||
f"#Prompt\n{request.prompt}\n\n"
|
||||
f"#Section to make changes around\n{request.section}\n\n"
|
||||
f"#TSX code\n{request.react_component}"
|
||||
)
|
||||
react_component = await edit_slide_layout_code(
|
||||
system_prompt=SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT,
|
||||
user_text=user_text,
|
||||
)
|
||||
return EditSlideLayoutSectionResponse(
|
||||
react_component=_strip_code_fences(react_component)
|
||||
)
|
||||
|
||||
|
||||
async def save_template(
|
||||
request: SaveTemplateRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if not request.layouts:
|
||||
raise HTTPException(status_code=400, detail="Layouts are required")
|
||||
|
||||
template_info = await sql_session.get(TemplateCreateInfoModel, request.template_info_id)
|
||||
if not template_info:
|
||||
raise HTTPException(status_code=400, detail="Template info not found")
|
||||
|
||||
template = TemplateModel(
|
||||
id=uuid.uuid4(),
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
)
|
||||
sql_session.add(template)
|
||||
|
||||
sql_session.add_all(
|
||||
[
|
||||
PresentationLayoutCodeModel(
|
||||
presentation=template.id,
|
||||
layout_id=layout.layout_id,
|
||||
layout_name=layout.layout_name,
|
||||
layout_code=layout.layout_code,
|
||||
fonts=template_info.fonts,
|
||||
)
|
||||
for layout in request.layouts
|
||||
]
|
||||
)
|
||||
await sql_session.commit()
|
||||
await sql_session.refresh(template)
|
||||
|
||||
return SaveTemplateResponse(
|
||||
id=template.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
created_at=template.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def clone_template(
|
||||
request: CloneTemplateRequest = Body(...),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if not request.id or not request.id.strip():
|
||||
raise HTTPException(status_code=400, detail="Template ID cannot be empty")
|
||||
|
||||
try:
|
||||
template_id_uuid = uuid.UUID(request.id.replace("custom-", ""))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc
|
||||
|
||||
template = await sql_session.get(TemplateModel, template_id_uuid)
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Template not found. Please use a valid template.",
|
||||
)
|
||||
|
||||
result = await sql_session.execute(
|
||||
select(PresentationLayoutCodeModel).where(
|
||||
PresentationLayoutCodeModel.presentation == template_id_uuid
|
||||
)
|
||||
)
|
||||
layouts_db = result.scalars().all()
|
||||
if not layouts_db:
|
||||
raise HTTPException(status_code=400, detail="No layouts found for template")
|
||||
|
||||
new_template = TemplateModel(
|
||||
id=uuid.uuid4(),
|
||||
name=request.name,
|
||||
description=template.description
|
||||
if request.description is None
|
||||
else request.description,
|
||||
)
|
||||
sql_session.add(new_template)
|
||||
|
||||
sql_session.add_all(
|
||||
[
|
||||
PresentationLayoutCodeModel(
|
||||
presentation=new_template.id,
|
||||
layout_id=layout.layout_id,
|
||||
layout_name=layout.layout_name,
|
||||
layout_code=layout.layout_code,
|
||||
fonts=layout.fonts,
|
||||
)
|
||||
for layout in layouts_db
|
||||
]
|
||||
)
|
||||
await sql_session.commit()
|
||||
await sql_session.refresh(new_template)
|
||||
|
||||
return SaveTemplateResponse(
|
||||
id=new_template.id,
|
||||
name=new_template.name,
|
||||
description=new_template.description,
|
||||
created_at=new_template.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def update_template(
|
||||
request: UpdateTemplateRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if not request.layouts:
|
||||
raise HTTPException(status_code=400, detail="Layouts are required")
|
||||
|
||||
template = await sql_session.get(TemplateModel, request.id)
|
||||
if not template:
|
||||
raise HTTPException(status_code=400, detail="Template not found")
|
||||
|
||||
existing_layout = await sql_session.scalar(
|
||||
select(PresentationLayoutCodeModel).where(
|
||||
PresentationLayoutCodeModel.presentation == request.id
|
||||
)
|
||||
)
|
||||
fonts = existing_layout.fonts if existing_layout else None
|
||||
|
||||
await sql_session.execute(
|
||||
delete(PresentationLayoutCodeModel).where(
|
||||
PresentationLayoutCodeModel.presentation == request.id
|
||||
)
|
||||
)
|
||||
sql_session.add_all(
|
||||
[
|
||||
PresentationLayoutCodeModel(
|
||||
presentation=template.id,
|
||||
layout_id=layout.layout_id,
|
||||
layout_name=layout.layout_name,
|
||||
layout_code=layout.layout_code,
|
||||
fonts=fonts,
|
||||
)
|
||||
for layout in request.layouts
|
||||
]
|
||||
)
|
||||
await sql_session.commit()
|
||||
|
||||
return SaveTemplateResponse(
|
||||
id=template.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
created_at=template.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def save_slide_layout(
|
||||
request: SaveSlideLayoutRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
template = await sql_session.get(TemplateModel, request.template_id)
|
||||
if not template:
|
||||
raise HTTPException(status_code=400, detail="Template not found")
|
||||
|
||||
layout = await sql_session.scalar(
|
||||
select(PresentationLayoutCodeModel).where(
|
||||
PresentationLayoutCodeModel.presentation == request.template_id,
|
||||
PresentationLayoutCodeModel.layout_id == request.layout_id,
|
||||
)
|
||||
)
|
||||
if not layout:
|
||||
raise HTTPException(status_code=400, detail="Layout not found")
|
||||
|
||||
layout.layout_code = request.layout_code
|
||||
sql_session.add(layout)
|
||||
await sql_session.commit()
|
||||
|
||||
|
||||
async def clone_slide_layout(
|
||||
request: CloneSlideLayoutRequest = Body(...),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if not request.template_id or not request.template_id.strip():
|
||||
raise HTTPException(status_code=400, detail="Template ID cannot be empty")
|
||||
|
||||
try:
|
||||
template_id_uuid = uuid.UUID(request.template_id.replace("custom-", ""))
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc
|
||||
|
||||
template = await sql_session.get(TemplateModel, template_id_uuid)
|
||||
if not template:
|
||||
raise HTTPException(status_code=400, detail="Template not found")
|
||||
|
||||
layout = await sql_session.scalar(
|
||||
select(PresentationLayoutCodeModel).where(
|
||||
PresentationLayoutCodeModel.presentation == template_id_uuid,
|
||||
PresentationLayoutCodeModel.layout_id == request.layout_id,
|
||||
)
|
||||
)
|
||||
if not layout:
|
||||
raise HTTPException(status_code=400, detail="Layout not found")
|
||||
|
||||
new_layout_code, new_layout_id = _update_layout_id_in_code(layout.layout_code)
|
||||
new_layout = PresentationLayoutCodeModel(
|
||||
presentation=template_id_uuid,
|
||||
layout_id=new_layout_id,
|
||||
layout_name=request.layout_name or layout.layout_name,
|
||||
layout_code=new_layout_code,
|
||||
fonts=layout.fonts,
|
||||
)
|
||||
sql_session.add(new_layout)
|
||||
await sql_session.commit()
|
||||
await sql_session.refresh(new_layout)
|
||||
|
||||
return SaveTemplateLayoutData(
|
||||
layout_id=new_layout.layout_id,
|
||||
layout_name=new_layout.layout_name,
|
||||
layout_code=new_layout.layout_code,
|
||||
)
|
||||
30
electron/servers/fastapi/templates/pptx_html_stub.py
Normal file
30
electron/servers/fastapi/templates/pptx_html_stub.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
BASIC_TEMPLATE_HTML = """
|
||||
<!-- TODO: pptx to html conversion -->
|
||||
<div class="relative w-full rounded-sm max-w-[1280px] shadow-lg max-h-[720px] aspect-video bg-white z-20 mx-auto overflow-hidden">
|
||||
<section class="flex h-full w-full flex-col justify-between bg-white px-[72px] py-[64px]">
|
||||
<div class="space-y-[18px]">
|
||||
<div class="h-[20px] w-[180px] rounded-full bg-slate-200"></div>
|
||||
<div class="h-[72px] w-[70%] rounded-[20px] bg-slate-100"></div>
|
||||
<div class="space-y-[10px] pt-[8px]">
|
||||
<div class="h-[16px] w-[82%] rounded-full bg-slate-100"></div>
|
||||
<div class="h-[16px] w-[78%] rounded-full bg-slate-100"></div>
|
||||
<div class="h-[16px] w-[66%] rounded-full bg-slate-100"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-[18px]">
|
||||
<div class="rounded-[24px] border border-slate-200 bg-slate-50 p-[24px]">
|
||||
<div class="h-[18px] w-[55%] rounded-full bg-slate-200"></div>
|
||||
<div class="mt-[18px] h-[96px] rounded-[18px] bg-white"></div>
|
||||
</div>
|
||||
<div class="rounded-[24px] border border-slate-200 bg-slate-50 p-[24px]">
|
||||
<div class="h-[18px] w-[55%] rounded-full bg-slate-200"></div>
|
||||
<div class="mt-[18px] h-[96px] rounded-[18px] bg-white"></div>
|
||||
</div>
|
||||
<div class="rounded-[24px] border border-slate-200 bg-slate-50 p-[24px]">
|
||||
<div class="h-[18px] w-[55%] rounded-full bg-slate-200"></div>
|
||||
<div class="mt-[18px] h-[96px] rounded-[18px] bg-white"></div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
""".strip()
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -25,15 +26,15 @@ class PresentationLayoutModel(BaseModel):
|
|||
status_code=404, detail=f"Slide layout {slide_layout_id} not found"
|
||||
)
|
||||
|
||||
def to_presentation_structure(self):
|
||||
def to_presentation_structure(self) -> PresentationStructureModel:
|
||||
return PresentationStructureModel(
|
||||
slides=[index for index in range(len(self.slides))]
|
||||
)
|
||||
|
||||
def to_string(self):
|
||||
message = f"## Presentation Layout\n\n"
|
||||
def to_string(self) -> str:
|
||||
message = "## Presentation Layout\n\n"
|
||||
for index, slide in enumerate(self.slides):
|
||||
message += f"### Slide Layout: {index}: \n"
|
||||
message += f"- Name: {slide.name or slide.json_schema.get('title')} \n"
|
||||
message += f"- Description: {slide.description} \n\n"
|
||||
message += f"### Slide Layout: {index}\n"
|
||||
message += f"- Name: {slide.name or slide.json_schema.get('title')}\n"
|
||||
message += f"- Description: {slide.description}\n\n"
|
||||
return message
|
||||
477
electron/servers/fastapi/templates/preview.py
Normal file
477
electron/servers/fastapi/templates/preview.py
Normal file
|
|
@ -0,0 +1,477 @@
|
|||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from constants.documents import PPTX_MIME_TYPES
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from templates.font_utils import (
|
||||
collect_normalized_fonts_from_xmls,
|
||||
get_available_and_unavailable_fonts,
|
||||
)
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
|
||||
try:
|
||||
from fontTools.ttLib import TTFont
|
||||
|
||||
FONTTOOLS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FONTTOOLS_AVAILABLE = False
|
||||
|
||||
|
||||
SUPPORTED_FONT_EXTENSIONS = {
|
||||
".ttf": "font/ttf",
|
||||
".otf": "font/otf",
|
||||
".woff": "font/woff",
|
||||
".woff2": "font/woff2",
|
||||
".eot": "application/vnd.ms-fontobject",
|
||||
}
|
||||
|
||||
|
||||
class FontInfo(BaseModel):
|
||||
name: str
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class FontCheckResponse(BaseModel):
|
||||
available_fonts: List[FontInfo]
|
||||
unavailable_fonts: List[FontInfo]
|
||||
|
||||
|
||||
class FontsUploadAndSlidesPreviewResponse(BaseModel):
|
||||
slide_image_urls: List[str]
|
||||
pptx_url: str
|
||||
modified_pptx_url: str
|
||||
fonts: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredFont:
|
||||
display_name: str
|
||||
url: str
|
||||
temp_path: str
|
||||
|
||||
|
||||
def _get_soffice_binary() -> str:
|
||||
configured = os.environ.get("SOFFICE_PATH")
|
||||
if configured:
|
||||
return configured
|
||||
return "soffice.exe" if os.name == "nt" else "soffice"
|
||||
|
||||
|
||||
def _windows_hidden_subprocess_kwargs() -> Dict[str, object]:
|
||||
if os.name != "nt":
|
||||
return {}
|
||||
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
return {
|
||||
"creationflags": getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
"startupinfo": startupinfo,
|
||||
}
|
||||
|
||||
|
||||
def _app_data_directory() -> str:
|
||||
app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
|
||||
os.makedirs(app_data_dir, exist_ok=True)
|
||||
return app_data_dir
|
||||
|
||||
|
||||
def _get_fonts_directory() -> str:
|
||||
fonts_dir = os.path.join(_app_data_directory(), "fonts")
|
||||
os.makedirs(fonts_dir, exist_ok=True)
|
||||
return fonts_dir
|
||||
|
||||
|
||||
def _get_images_directory() -> str:
|
||||
images_dir = os.path.join(_app_data_directory(), "images")
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
return images_dir
|
||||
|
||||
|
||||
def _get_template_uploads_directory() -> str:
|
||||
uploads_dir = os.path.join(_app_data_directory(), "uploads", "template-previews")
|
||||
os.makedirs(uploads_dir, exist_ok=True)
|
||||
return uploads_dir
|
||||
|
||||
|
||||
def _write_bytes_to_path(path: str, data: bytes) -> None:
|
||||
with open(path, "wb") as file:
|
||||
file.write(data)
|
||||
|
||||
|
||||
def _copy_file(source_path: str, destination_path: str) -> None:
|
||||
shutil.copy2(source_path, destination_path)
|
||||
|
||||
|
||||
def _extract_font_name_from_file(file_path: str) -> str:
|
||||
filename = os.path.basename(file_path)
|
||||
base_name = os.path.splitext(filename)[0]
|
||||
if not FONTTOOLS_AVAILABLE:
|
||||
return base_name
|
||||
|
||||
try:
|
||||
font = TTFont(file_path)
|
||||
if "name" in font:
|
||||
name_table = font["name"]
|
||||
for name_id in (1, 4, 6):
|
||||
for record in name_table.names:
|
||||
if record.nameID != name_id:
|
||||
continue
|
||||
if record.langID in (0x409, 0):
|
||||
font_name = record.toUnicode().strip()
|
||||
if font_name:
|
||||
font.close()
|
||||
return font_name
|
||||
for record in name_table.names:
|
||||
if record.nameID != 1:
|
||||
continue
|
||||
font_name = record.toUnicode().strip()
|
||||
if font_name:
|
||||
font.close()
|
||||
return font_name
|
||||
font.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return base_name
|
||||
|
||||
|
||||
def _validate_pptx_file(pptx_file: UploadFile) -> None:
|
||||
filename = getattr(pptx_file, "filename", "") or ""
|
||||
if not filename.lower().endswith(".pptx"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid file type. Expected PPTX file",
|
||||
)
|
||||
if pptx_file.content_type and pptx_file.content_type not in PPTX_MIME_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Expected PPTX file, got {pptx_file.content_type}",
|
||||
)
|
||||
|
||||
|
||||
def _ensure_valid_font_file(font_file: UploadFile) -> None:
|
||||
filename = font_file.filename or ""
|
||||
extension = os.path.splitext(filename)[1].lower()
|
||||
if extension not in SUPPORTED_FONT_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid font file. Supported formats: {', '.join(SUPPORTED_FONT_EXTENSIONS.keys())}",
|
||||
)
|
||||
|
||||
|
||||
async def _persist_custom_fonts(
|
||||
font_files: Optional[List[UploadFile]],
|
||||
original_font_names: Optional[List[str]],
|
||||
temp_dir: str,
|
||||
) -> list[StoredFont]:
|
||||
if not font_files:
|
||||
return []
|
||||
|
||||
stored_fonts: list[StoredFont] = []
|
||||
fonts_dir = _get_fonts_directory()
|
||||
|
||||
for index, font_file in enumerate(font_files):
|
||||
_ensure_valid_font_file(font_file)
|
||||
|
||||
original_name = (
|
||||
original_font_names[index]
|
||||
if original_font_names and index < len(original_font_names)
|
||||
else None
|
||||
)
|
||||
extension = os.path.splitext(font_file.filename or "")[1].lower()
|
||||
unique_name = f"{Path(font_file.filename or f'font_{index}').stem}_{uuid.uuid4().hex[:8]}{extension}"
|
||||
temp_font_path = os.path.join(temp_dir, unique_name)
|
||||
permanent_font_path = os.path.join(fonts_dir, unique_name)
|
||||
font_bytes = await font_file.read()
|
||||
|
||||
await asyncio.to_thread(_write_bytes_to_path, temp_font_path, font_bytes)
|
||||
await asyncio.to_thread(_write_bytes_to_path, permanent_font_path, font_bytes)
|
||||
|
||||
actual_font_name = await asyncio.to_thread(
|
||||
_extract_font_name_from_file, permanent_font_path
|
||||
)
|
||||
display_name = original_name or actual_font_name
|
||||
stored_fonts.append(
|
||||
StoredFont(
|
||||
display_name=display_name,
|
||||
url=f"/app_data/fonts/{unique_name}",
|
||||
temp_path=temp_font_path,
|
||||
)
|
||||
)
|
||||
|
||||
return stored_fonts
|
||||
|
||||
|
||||
def _create_font_alias_config(raw_fonts: List[str]) -> str:
|
||||
mappings: Dict[str, str] = {}
|
||||
for font_name in raw_fonts:
|
||||
normalized = font_name
|
||||
if not normalized:
|
||||
continue
|
||||
mappings[font_name] = normalized
|
||||
|
||||
fd, fonts_conf_path = tempfile.mkstemp(prefix="fonts_alias_", suffix=".conf")
|
||||
os.close(fd)
|
||||
with open(fonts_conf_path, "w", encoding="utf-8") as cfg:
|
||||
cfg.write(
|
||||
"""<?xml version='1.0'?>
|
||||
<!DOCTYPE fontconfig SYSTEM "urn:fontconfig:fonts.dtd">
|
||||
<fontconfig>
|
||||
<include>/etc/fonts/fonts.conf</include>
|
||||
"""
|
||||
)
|
||||
for source_family, destination_family in mappings.items():
|
||||
if source_family == destination_family:
|
||||
continue
|
||||
cfg.write(
|
||||
f"""
|
||||
<match target="pattern">
|
||||
<test name="family" compare="eq">
|
||||
<string>{source_family}</string>
|
||||
</test>
|
||||
<edit name="family" mode="assign" binding="strong">
|
||||
<string>{destination_family}</string>
|
||||
</edit>
|
||||
</match>
|
||||
"""
|
||||
)
|
||||
cfg.write("\n</fontconfig>\n")
|
||||
return fonts_conf_path
|
||||
|
||||
|
||||
async def _install_fonts(font_paths: List[str]) -> None:
|
||||
if not font_paths:
|
||||
return
|
||||
|
||||
for font_path in font_paths:
|
||||
try:
|
||||
subprocess.run(
|
||||
["cp", font_path, "/usr/share/fonts/truetype/"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
continue
|
||||
|
||||
try:
|
||||
subprocess.run(["fc-cache", "-f", "-v"], check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
def extract_slide_xmls(pptx_path: str, temp_dir: str) -> List[str]:
|
||||
slide_xmls: list[str] = []
|
||||
extract_dir = os.path.join(temp_dir, "pptx_extract")
|
||||
|
||||
with zipfile.ZipFile(pptx_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_dir)
|
||||
|
||||
slides_dir = os.path.join(extract_dir, "ppt", "slides")
|
||||
if not os.path.exists(slides_dir):
|
||||
raise HTTPException(status_code=400, detail="No slides directory found in PPTX")
|
||||
|
||||
slide_files = [
|
||||
file_name
|
||||
for file_name in os.listdir(slides_dir)
|
||||
if file_name.startswith("slide") and file_name.endswith(".xml")
|
||||
]
|
||||
slide_files.sort(key=lambda value: int(re.sub(r"[^0-9]", "", value) or "0"))
|
||||
|
||||
for slide_file in slide_files:
|
||||
slide_path = os.path.join(slides_dir, slide_file)
|
||||
with open(slide_path, "r", encoding="utf-8") as slide_handle:
|
||||
slide_xmls.append(slide_handle.read())
|
||||
|
||||
return slide_xmls
|
||||
|
||||
|
||||
async def convert_pptx_to_pdf(
|
||||
pptx_path: str,
|
||||
temp_dir: str,
|
||||
slide_xmls: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
screenshots_dir = os.path.join(temp_dir, "screenshots")
|
||||
os.makedirs(screenshots_dir, exist_ok=True)
|
||||
|
||||
slide_xmls = slide_xmls or extract_slide_xmls(pptx_path, temp_dir)
|
||||
raw_fonts = collect_normalized_fonts_from_xmls(slide_xmls)
|
||||
fonts_conf_path = _create_font_alias_config(raw_fonts)
|
||||
env = os.environ.copy()
|
||||
env["FONTCONFIG_FILE"] = fonts_conf_path
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
_get_soffice_binary(),
|
||||
"--headless",
|
||||
"--convert-to",
|
||||
"pdf",
|
||||
"--outdir",
|
||||
screenshots_dir,
|
||||
pptx_path,
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=500,
|
||||
env=env,
|
||||
**_windows_hidden_subprocess_kwargs(),
|
||||
)
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="LibreOffice PDF conversion timed out after 500 seconds",
|
||||
) from exc
|
||||
except subprocess.CalledProcessError as exc:
|
||||
error_message = exc.stderr if exc.stderr else str(exc)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"LibreOffice PDF conversion failed: {error_message}",
|
||||
) from exc
|
||||
|
||||
pdf_files = [file_name for file_name in os.listdir(screenshots_dir) if file_name.endswith(".pdf")]
|
||||
if not pdf_files:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="LibreOffice failed to generate a PDF file"
|
||||
)
|
||||
|
||||
return os.path.join(screenshots_dir, pdf_files[0])
|
||||
|
||||
|
||||
async def store_slide_images(
|
||||
screenshot_paths: List[str],
|
||||
session_id: uuid.UUID,
|
||||
) -> List[str]:
|
||||
images_dir = _get_images_directory()
|
||||
target_dir = os.path.join(images_dir, str(session_id))
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
slide_image_urls: list[str] = []
|
||||
for index, screenshot_path in enumerate(screenshot_paths, start=1):
|
||||
file_name = f"slide_{index}.png"
|
||||
destination_path = os.path.join(target_dir, file_name)
|
||||
|
||||
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 0:
|
||||
await asyncio.to_thread(_copy_file, screenshot_path, destination_path)
|
||||
slide_image_urls.append(f"/app_data/images/{session_id}/{file_name}")
|
||||
else:
|
||||
slide_image_urls.append("/static/images/placeholder.jpg")
|
||||
|
||||
return slide_image_urls
|
||||
|
||||
|
||||
async def store_uploaded_pptx(
|
||||
pptx_path: str,
|
||||
session_id: uuid.UUID,
|
||||
) -> str:
|
||||
uploads_dir = _get_template_uploads_directory()
|
||||
target_dir = os.path.join(uploads_dir, str(session_id))
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
destination_path = os.path.join(target_dir, "presentation.pptx")
|
||||
await asyncio.to_thread(_copy_file, pptx_path, destination_path)
|
||||
return f"/app_data/uploads/template-previews/{session_id}/presentation.pptx"
|
||||
|
||||
|
||||
async def get_available_and_unavailable_fonts_for_pptx(
|
||||
pptx_path: str, temp_dir: str
|
||||
) -> tuple[list[tuple[str, str]], list[tuple[str, None]]]:
|
||||
slide_xmls = extract_slide_xmls(pptx_path, temp_dir)
|
||||
normalized_fonts = collect_normalized_fonts_from_xmls(slide_xmls)
|
||||
return await get_available_and_unavailable_fonts(normalized_fonts)
|
||||
|
||||
|
||||
async def check_fonts_in_pptx_handler(
|
||||
pptx_file: UploadFile = File(..., description="PPTX file to analyze fonts from")
|
||||
) -> FontCheckResponse:
|
||||
_validate_pptx_file(pptx_file)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
pptx_path = os.path.join(temp_dir, "presentation.pptx")
|
||||
pptx_content = await pptx_file.read()
|
||||
await asyncio.to_thread(_write_bytes_to_path, pptx_path, pptx_content)
|
||||
|
||||
available_fonts_data, unavailable_fonts_data = (
|
||||
await get_available_and_unavailable_fonts_for_pptx(pptx_path, temp_dir)
|
||||
)
|
||||
|
||||
return FontCheckResponse(
|
||||
available_fonts=[
|
||||
FontInfo(name=name, url=url) for name, url in available_fonts_data
|
||||
],
|
||||
unavailable_fonts=[
|
||||
FontInfo(name=name, url=url) for name, url in unavailable_fonts_data
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def upload_fonts_and_slides_preview_handler(
|
||||
pptx_file: UploadFile,
|
||||
font_files: Optional[List[UploadFile]] = None,
|
||||
original_font_names: Optional[List[str]] = None,
|
||||
max_slides: Optional[int] = None,
|
||||
) -> FontsUploadAndSlidesPreviewResponse:
|
||||
if (font_files and not original_font_names) or (
|
||||
original_font_names and not font_files
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Both font_files and original_font_names must be provided together",
|
||||
)
|
||||
if font_files and original_font_names and len(font_files) != len(original_font_names):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Number of font files must match number of original font names",
|
||||
)
|
||||
|
||||
_validate_pptx_file(pptx_file)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
pptx_path = os.path.join(temp_dir, "presentation.pptx")
|
||||
pptx_content = await pptx_file.read()
|
||||
await asyncio.to_thread(_write_bytes_to_path, pptx_path, pptx_content)
|
||||
|
||||
stored_fonts = await _persist_custom_fonts(
|
||||
font_files=font_files,
|
||||
original_font_names=original_font_names,
|
||||
temp_dir=temp_dir,
|
||||
)
|
||||
await _install_fonts([font.temp_path for font in stored_fonts])
|
||||
|
||||
slide_xmls = extract_slide_xmls(pptx_path, temp_dir)
|
||||
pdf_path = await convert_pptx_to_pdf(pptx_path, temp_dir, slide_xmls=slide_xmls)
|
||||
screenshot_paths = await DocumentsLoader.get_page_images_from_pdf_async(
|
||||
pdf_path, temp_dir
|
||||
)
|
||||
|
||||
if max_slides and len(screenshot_paths) > max_slides:
|
||||
screenshot_paths = screenshot_paths[:max_slides]
|
||||
|
||||
session_id = uuid.uuid4()
|
||||
slide_image_urls = await store_slide_images(screenshot_paths, session_id)
|
||||
pptx_url = await store_uploaded_pptx(pptx_path, session_id)
|
||||
|
||||
available_fonts, _ = await get_available_and_unavailable_fonts(
|
||||
collect_normalized_fonts_from_xmls(slide_xmls)
|
||||
)
|
||||
fonts: dict[str, str] = {name: url for name, url in available_fonts}
|
||||
fonts.update({font.display_name: font.url for font in stored_fonts})
|
||||
|
||||
return FontsUploadAndSlidesPreviewResponse(
|
||||
slide_image_urls=slide_image_urls,
|
||||
pptx_url=pptx_url,
|
||||
modified_pptx_url=pptx_url,
|
||||
fonts=fonts,
|
||||
)
|
||||
219
electron/servers/fastapi/templates/prompts.py
Normal file
219
electron/servers/fastapi/templates/prompts.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT = """
|
||||
You need to generate a Zod schema and a TSX React component and provide it as output.
|
||||
Provide reusable TSX code which can be used as template to generate new slides with different content.
|
||||
|
||||
# Steps:
|
||||
1. Analyze the slide image to understand the visual hierarchy.
|
||||
3. Classify elements into decorative and content elements.
|
||||
4. Group content elements into logical sections like Header, Body, BulletPoints, etc.
|
||||
5. Generate a Zod schema for the content elements.
|
||||
6. Generate id, name and description for the layout.
|
||||
6. Generate a TSX React component using the Zod schema and the HTML reference.
|
||||
|
||||
# Decorative Elements:
|
||||
- Arrows, Lines, Shapes, etc.
|
||||
- Images with Grid patterns, background patterns, gradients, solid colors, etc.
|
||||
- Background of infographics like funnel, timeline, etc.
|
||||
- Company name, logos, etc.
|
||||
- Images covering the entire slide.
|
||||
- Images containing company name, logos, etc.
|
||||
|
||||
# Decorative Elements Rules:
|
||||
- Use them exactly as they are in the HTML reference.
|
||||
- Do not change decorative images and icons urls.
|
||||
- Images containing company name, logos, etc should be identified as decorative elements.
|
||||
|
||||
# Content Elements:
|
||||
- Title, Description, BulletPoints, etc.
|
||||
- Graphs, Charts, etc.
|
||||
- Images and Icons representing textual content like title, description, bullet points, etc.
|
||||
- Meaningful Images and Icons.
|
||||
- Icons in infographics that represent the data.
|
||||
|
||||
# Content Elements Rules:
|
||||
- Properly identify between images and icons elements.
|
||||
- Image content:
|
||||
- Image field should be 'z.object({"image_url": z.string(), "image_prompt": z.string().max(100)})'
|
||||
- Replace actual image url with 'https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png'
|
||||
- Icon content:
|
||||
- Icon field should be 'z.object({"icon_url": z.string(), "icon_query": z.string().max(30)})'
|
||||
- Replace actual icon url with 'https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg'
|
||||
- Add color styling to the icon to match the color in the image.
|
||||
- Make sure the urls are correct.
|
||||
|
||||
# Layout Rules:
|
||||
- The layout should be fixed 1280px width and 720px height.
|
||||
- Adjust the positions and sizes of elements to fit the layout.
|
||||
- Try to keep the positions and sizes of elements as close to HTML reference as possible.
|
||||
|
||||
# Flexible Positioning and Sizes Rules:
|
||||
- Must not use 'absolute' positioning for elements.
|
||||
- Must use 'flex', 'grid', 'margin', 'padding', 'gap', 'basis', 'justify', 'align', etc for positioning of elements.
|
||||
- For variable length lists, wrap list into a container and center it.
|
||||
- Don't use specific sizes (height, width) for elements if not necessary.
|
||||
|
||||
# Schema Field Name and Description Rules:
|
||||
- Must not use content specific words.
|
||||
- Only use words based on what content types are present in the slide image.
|
||||
- Use words like 'title', description', 'heading', 'image', 'graph', 'table', 'bullet points', etc.
|
||||
- Must not use words like 'budget', 'market', 'revenue', 'sales', 'growth', 'workflow', 'channel', 'plannedValue', 'actualValue', etc.
|
||||
|
||||
# Layout ID, Name and Description Rules:
|
||||
- Must only use slide structure to derive layout id, name and description.
|
||||
- Informations like: Type of content, position of content, etc. should be used.
|
||||
- layoutId example: title-description-right-image.
|
||||
- layoutName example: Title Description Image.
|
||||
- layoutDescription example: A slide with a title, description, and an image on right.
|
||||
|
||||
# Zod Schema Rules:
|
||||
- "describe" must be added for every fields.
|
||||
- "default" must be added in top level fields of schema.
|
||||
- Top level fields are those not nested inside other fields.
|
||||
- Don't mention string type in schema like "url()", "email()", etc.
|
||||
- Table must be object with "columns" and "rows" fields.
|
||||
- "columns" must be an array of strings.
|
||||
- "rows" must be an array of arrays of strings.
|
||||
- Graph must be object with "categories" and "series" fields.
|
||||
- "categories" must be an array of strings.
|
||||
- "series" must be an array of objects with {"name": string, "data": array of numbers}.
|
||||
- Must not use z.record() anywhere in the schema.
|
||||
|
||||
# String and Array Field Rules:
|
||||
- Every string field must include `.max(...)`; every array field must include `.max(...)`.
|
||||
- For strings, set `max` to the exact character count of the text content it represents.
|
||||
- For arrays, set `max` to the exact item count of the array content it represents.
|
||||
- Choose a `max` that keeps the longest allowed content from overflowing its container.
|
||||
|
||||
# Table Rules:
|
||||
- Construct "tr -> th" by iterating over the "columns" field.
|
||||
- Construct "tr -> td" by iterating over the "rows" field.
|
||||
- Make sure table height and width adjusts to fit the content.
|
||||
|
||||
# Grahps, Charts, etc Rules:
|
||||
- Identify if graphs, charts, etc are present in the slide image.
|
||||
- Identify the type of graph, chart, etc.
|
||||
- If present, generate a zod schema for the graph, chart, etc.
|
||||
- Generate TSX code for the graph, chart, etc. even if it is not present in the HTML reference.
|
||||
- Use graph schema and image to generate the TSX code.
|
||||
- Use Recharts library for graphs.
|
||||
|
||||
# Fonts Rules:
|
||||
- Check for "PROVIDED FONTS".
|
||||
- Must use fonts only from "PROVIDED FONTS".
|
||||
- Add "font-[\"font-name\"]" to every text element in the slide.
|
||||
|
||||
# Page Number Rules:
|
||||
- Identify if the slide contains page number from provided HTML reference and image.
|
||||
- If page number is present, add a "page: z.number().min(1).meta({ description: "Page number" })" field in the schema.
|
||||
|
||||
# React Component Rules:
|
||||
- React component must be named dynamicSlideLayout.
|
||||
- dynamicSlideLayout must take "{ data }: { data: Partial<z.infer<typeof Schema>> }" as props.
|
||||
- Wrap the code inside these classes: "relative w-full rounded-sm max-w-[1280px] shadow-lg max-h-[720px] aspect-video bg-white z-20 mx-auto overflow-hidden".
|
||||
- Make sure camelCase is used for all styles. For e.g. "letter-spacing" should be "letterSpacing".
|
||||
- Schema.parse must not be used in the code.
|
||||
- Use 'const {field1, field2, ...} = data;' to access the data.
|
||||
- field1 or field2 or ... can be undefined, so use optional chaining to access them.
|
||||
- Don't use "min-height" on cards and instead make its height grow/shrink to fit the content.
|
||||
- Make sure cards/items are centered vertically and horizontally in the available space.
|
||||
- Make sure no element is scrollable.
|
||||
- Don't add any animations, transitions, or effects.
|
||||
- Make sure no content elements are overflowing the slide boundaries.
|
||||
|
||||
# Import and Export Rules:
|
||||
- All import statements must be defined at the top.
|
||||
- Export using 'export {Schema, layoutId, layoutName, layoutDescription, dynamicSlideLayout}' statement at the bottom.
|
||||
- There must be only one 'export' statement in the whole TSX code.
|
||||
|
||||
# Output Code Rules:
|
||||
- Code should be in following order:
|
||||
- Zod Schema (Schema)
|
||||
- Layout ID, Name and Description (layoutId, layoutName, layoutDescription)
|
||||
- React Component (dynamicSlideLayout)
|
||||
- Give just one valid TSX code as output.
|
||||
- Don't add comments in the code.
|
||||
- Make sure the generated code is valid TSX code.
|
||||
- Give only code as output and nothing else. (no json, no markdown, no text, no explanation)
|
||||
|
||||
- Go through generated code and make sure all rules are followed.
|
||||
- Think as long as you can and iterate as many times as necessary to make sure all rules are followed.
|
||||
"""
|
||||
|
||||
SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT = """
|
||||
You need to edit the given TSX code of the slide layout code according to the prompt and provide it as output.
|
||||
|
||||
# Steps
|
||||
1. Analyze the TSX code to understand the slide layout.
|
||||
2. Analyze the prompt to understand the changes to be made.
|
||||
3. Edit the TSX code according to the prompt.
|
||||
4. Provide the updated TSX code as output.
|
||||
|
||||
# Rules
|
||||
- Make sure the changes does not break the existing code.
|
||||
- Make sure to follow the pattern of the existing code.
|
||||
- Make sure there are no unused schema fields after the changes are made.
|
||||
|
||||
# Icons and Images Rules
|
||||
Follow these rules if new icons/images are asked:
|
||||
- Image field should be 'z.object({"image_url": z.string(), "image_prompt": z.string().max(100)})'
|
||||
- Use this as default image url: 'https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png'
|
||||
- Icon field should be 'z.object({"icon_url": z.string(), "icon_query": z.string().max(30)})'
|
||||
- Use this as default icon url: 'https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg'
|
||||
|
||||
# Schema Rules
|
||||
- "describe" must be added for every fields.
|
||||
- "default" must be added in top level fields of schema.
|
||||
- Top level fields are those not nested inside other fields.
|
||||
- Must set max for every string and array fields.
|
||||
- Must set max to a number that will not cause overflow on max content.
|
||||
|
||||
# Graphs And Table Rules
|
||||
Follow these rules if new graphs/tables are asked:
|
||||
1. Schema Rules
|
||||
- Table must be object with "columns" and "rows" fields.
|
||||
- "columns" must be an array of strings.
|
||||
- "rows" must be an array of arrays of strings.
|
||||
- Graph must be object with "categories" and "series" fields.
|
||||
- "categories" must be an array of strings.
|
||||
- "series" must be an array of objects with {"name": string, "data": array of numbers}.
|
||||
2. React Component Rules
|
||||
- Use recharts library for graphs.
|
||||
|
||||
# Common Prompts
|
||||
1. Fix the slide
|
||||
- Check if text/cards/items is overflowing the slide boundaries or text/cards/items are overlapping.
|
||||
- If yes, fix by moving the element to a better position or resizing the element.
|
||||
|
||||
# Output Rules
|
||||
- Make sure the schema and react component are valid.
|
||||
- No matter what prompt is given, don't break the code.
|
||||
- Provide only the updated TSX code as output and nothing else. (no json, no markdown, no text, no explanation)
|
||||
"""
|
||||
|
||||
SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT = """
|
||||
You need to edit the given TSX code of the slide layout code according to the prompt and provide it as output.
|
||||
|
||||
# Steps
|
||||
1. Analyze the TSX code to understand the slide layout.
|
||||
2. Analyze the prompt to understand the changes to be made.
|
||||
3. Edit the TSX code according to the prompt.
|
||||
4. Provide the updated TSX code as output.
|
||||
|
||||
# Rules
|
||||
- Changes should be made only around the mentioned "section to make changes around".
|
||||
- Make sure the changes does not break the existing code.
|
||||
- Make sure to follow the pattern of the existing code.
|
||||
- Make sure there are no unused schema fields after the changes are made.
|
||||
|
||||
# Icons and Images Rules
|
||||
Follow these rules if new icons/images are asked:
|
||||
- Image field should be 'z.object({"image_url": z.string(), "image_prompt": z.string().max(100)})'
|
||||
- Use this as default image url: 'https://presenton-public-assets.s3.ap-southeast-1.amazonaws.com/replaceable_template_image.png'
|
||||
- Icon field should be 'z.object({"icon_url": z.string(), "icon_query": z.string().max(30)})'
|
||||
- Use this as default icon url: 'https://presenton-public.s3.ap-southeast-1.amazonaws.com/static/icons/placeholder.svg'
|
||||
|
||||
# Output Rules
|
||||
- Make sure the schema and react component are valid.
|
||||
- No matter what prompt is given, don't break the code.
|
||||
- Provide only the updated TSX code as output and nothing else. (no json, no markdown, no text, no explanation)
|
||||
"""
|
||||
365
electron/servers/fastapi/templates/providers.py
Normal file
365
electron/servers/fastapi/templates/providers.py
Normal file
|
|
@ -0,0 +1,365 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
from fastapi import HTTPException
|
||||
from google import genai
|
||||
from google.genai import types as google_types
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_codex_access_token_env,
|
||||
get_codex_account_id_env,
|
||||
get_codex_refresh_token_env,
|
||||
get_codex_token_expires_env,
|
||||
get_google_api_key_env,
|
||||
get_openai_api_key_env,
|
||||
)
|
||||
from utils.llm_provider import get_llm_provider
|
||||
from utils.set_env import (
|
||||
set_codex_access_token_env,
|
||||
set_codex_account_id_env,
|
||||
set_codex_refresh_token_env,
|
||||
set_codex_token_expires_env,
|
||||
)
|
||||
|
||||
OPENAI_TEMPLATE_MODEL = "gpt-5.4"
|
||||
CODEX_TEMPLATE_MODEL = "gpt-5.4"
|
||||
GOOGLE_TEMPLATE_MODEL = "gemini-3.1"
|
||||
ANTHROPIC_TEMPLATE_MODEL = "opus 4.6"
|
||||
MAX_ATTEMPTS_PER_PROVIDER = 4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemplateProviderSpec:
|
||||
provider: LLMProvider
|
||||
model: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PlainLLMProvider:
|
||||
name: str
|
||||
call: Callable[[], Awaitable[str]]
|
||||
|
||||
|
||||
def get_template_provider_spec() -> TemplateProviderSpec:
|
||||
provider = get_llm_provider()
|
||||
if provider == LLMProvider.OPENAI:
|
||||
return TemplateProviderSpec(provider=provider, model=OPENAI_TEMPLATE_MODEL)
|
||||
if provider == LLMProvider.CODEX:
|
||||
return TemplateProviderSpec(provider=provider, model=CODEX_TEMPLATE_MODEL)
|
||||
if provider == LLMProvider.GOOGLE:
|
||||
return TemplateProviderSpec(provider=provider, model=GOOGLE_TEMPLATE_MODEL)
|
||||
if provider == LLMProvider.ANTHROPIC:
|
||||
return TemplateProviderSpec(provider=provider, model=ANTHROPIC_TEMPLATE_MODEL)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Template generation only supports OpenAI, Codex, Google, or Anthropic.",
|
||||
)
|
||||
|
||||
|
||||
async def run_plain_provider_buckets(*, providers: list[PlainLLMProvider]) -> str:
|
||||
last_exception: Optional[Exception] = None
|
||||
|
||||
for provider in providers:
|
||||
for _ in range(MAX_ATTEMPTS_PER_PROVIDER):
|
||||
try:
|
||||
response_text = await provider.call()
|
||||
if response_text:
|
||||
return response_text
|
||||
raise ValueError("No output from template generation provider")
|
||||
except Exception as exc:
|
||||
last_exception = exc
|
||||
|
||||
if isinstance(last_exception, HTTPException):
|
||||
raise last_exception
|
||||
raise HTTPException(status_code=500, detail="Failed to generate template output")
|
||||
|
||||
|
||||
def _read_openai_response_text(response) -> str:
|
||||
output_text = getattr(response, "output_text", None)
|
||||
if output_text:
|
||||
return output_text
|
||||
text = getattr(response, "text", None)
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
|
||||
|
||||
def _get_openai_client() -> AsyncOpenAI:
|
||||
api_key = get_openai_api_key_env()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="OPENAI_API_KEY is not set")
|
||||
return AsyncOpenAI(api_key=api_key, timeout=120.0)
|
||||
|
||||
|
||||
def _get_codex_headers() -> dict:
|
||||
access_token = get_codex_access_token_env()
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Codex OAuth access token is not set. Please authenticate via /api/v1/ppt/codex/auth/initiate",
|
||||
)
|
||||
|
||||
expires_str = get_codex_token_expires_env()
|
||||
if expires_str:
|
||||
try:
|
||||
expires_ms = int(expires_str)
|
||||
now_ms = int(time.time() * 1000)
|
||||
if now_ms >= expires_ms - 60_000:
|
||||
refresh_token = get_codex_refresh_token_env()
|
||||
if refresh_token:
|
||||
from utils.oauth.openai_codex import (
|
||||
TokenSuccess,
|
||||
get_account_id,
|
||||
refresh_access_token,
|
||||
)
|
||||
|
||||
result = refresh_access_token(refresh_token)
|
||||
if isinstance(result, TokenSuccess):
|
||||
set_codex_access_token_env(result.access)
|
||||
set_codex_refresh_token_env(result.refresh)
|
||||
set_codex_token_expires_env(str(result.expires))
|
||||
account_id = get_account_id(result.access)
|
||||
if account_id:
|
||||
set_codex_account_id_env(account_id)
|
||||
access_token = result.access
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
account_id = get_codex_account_id_env() or ""
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": "pi",
|
||||
}
|
||||
|
||||
|
||||
def _get_codex_client() -> AsyncOpenAI:
|
||||
headers = _get_codex_headers()
|
||||
access_token = (headers.get("Authorization") or "").replace("Bearer ", "").strip()
|
||||
default_headers = {
|
||||
key: value
|
||||
for key, value in headers.items()
|
||||
if key.lower() not in {"authorization", "content-type", "accept"}
|
||||
}
|
||||
return AsyncOpenAI(
|
||||
base_url="https://chatgpt.com/backend-api/codex",
|
||||
api_key=access_token or "codex",
|
||||
default_headers=default_headers,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
||||
def _get_google_client() -> genai.Client:
|
||||
api_key = get_google_api_key_env()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="GOOGLE_API_KEY is not set")
|
||||
return genai.Client(api_key=api_key)
|
||||
|
||||
|
||||
def _get_anthropic_client() -> AsyncAnthropic:
|
||||
api_key = get_anthropic_api_key_env()
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="ANTHROPIC_API_KEY is not set")
|
||||
return AsyncAnthropic(api_key=api_key)
|
||||
|
||||
|
||||
async def _call_openai_like(
|
||||
*,
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
system_prompt: str,
|
||||
user_text: str,
|
||||
image_bytes: Optional[bytes] = None,
|
||||
media_type: str = "image/png",
|
||||
reasoning_effort: str = "medium",
|
||||
) -> str:
|
||||
content = [{"type": "input_text", "text": user_text}]
|
||||
if image_bytes:
|
||||
content.insert(
|
||||
0,
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:{media_type};base64,{base64.b64encode(image_bytes).decode('utf-8')}",
|
||||
},
|
||||
)
|
||||
|
||||
response = await client.responses.create(
|
||||
model=model,
|
||||
instructions=system_prompt,
|
||||
input=[{"role": "user", "content": content}],
|
||||
reasoning={"effort": reasoning_effort},
|
||||
text={"verbosity": "medium"},
|
||||
store=False,
|
||||
)
|
||||
output_text = _read_openai_response_text(response)
|
||||
if not output_text:
|
||||
raise HTTPException(status_code=500, detail="No output from template provider")
|
||||
return output_text
|
||||
|
||||
|
||||
async def _call_google(
|
||||
*,
|
||||
model: str,
|
||||
system_prompt: str,
|
||||
user_text: str,
|
||||
image_bytes: Optional[bytes] = None,
|
||||
media_type: str = "image/png",
|
||||
) -> str:
|
||||
client = _get_google_client()
|
||||
parts = [google_types.Part.from_text(text=user_text)]
|
||||
if image_bytes:
|
||||
parts.append(google_types.Part.from_bytes(data=image_bytes, mime_type=media_type))
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[google_types.Content(role="user", parts=parts)],
|
||||
config=google_types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="text/plain",
|
||||
),
|
||||
)
|
||||
output_text = getattr(response, "text", None) or ""
|
||||
if not output_text:
|
||||
raise HTTPException(status_code=500, detail="No output from template provider")
|
||||
return output_text
|
||||
|
||||
|
||||
async def _call_anthropic(
|
||||
*,
|
||||
model: str,
|
||||
system_prompt: str,
|
||||
user_text: str,
|
||||
image_bytes: Optional[bytes] = None,
|
||||
media_type: str = "image/png",
|
||||
) -> str:
|
||||
client = _get_anthropic_client()
|
||||
content = [{"type": "text", "text": user_text}]
|
||||
if image_bytes:
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64.b64encode(image_bytes).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.messages.create(
|
||||
model=model,
|
||||
max_tokens=8192,
|
||||
system=system_prompt,
|
||||
messages=[{"role": "user", "content": content}],
|
||||
)
|
||||
output_text = "".join(
|
||||
block.text for block in response.content if getattr(block, "type", None) == "text"
|
||||
)
|
||||
if not output_text:
|
||||
raise HTTPException(status_code=500, detail="No output from template provider")
|
||||
return output_text
|
||||
|
||||
|
||||
def _build_provider_call(
|
||||
*,
|
||||
system_prompt: str,
|
||||
user_text: str,
|
||||
image_bytes: Optional[bytes] = None,
|
||||
media_type: str = "image/png",
|
||||
reasoning_effort: str = "medium",
|
||||
) -> PlainLLMProvider:
|
||||
spec = get_template_provider_spec()
|
||||
|
||||
if spec.provider == LLMProvider.OPENAI:
|
||||
return PlainLLMProvider(
|
||||
name="OpenAI",
|
||||
call=lambda: _call_openai_like(
|
||||
client=_get_openai_client(),
|
||||
model=spec.model,
|
||||
system_prompt=system_prompt,
|
||||
user_text=user_text,
|
||||
image_bytes=image_bytes,
|
||||
media_type=media_type,
|
||||
reasoning_effort=reasoning_effort,
|
||||
),
|
||||
)
|
||||
if spec.provider == LLMProvider.CODEX:
|
||||
return PlainLLMProvider(
|
||||
name="Codex",
|
||||
call=lambda: _call_openai_like(
|
||||
client=_get_codex_client(),
|
||||
model=spec.model,
|
||||
system_prompt=system_prompt,
|
||||
user_text=user_text,
|
||||
image_bytes=image_bytes,
|
||||
media_type=media_type,
|
||||
reasoning_effort=reasoning_effort,
|
||||
),
|
||||
)
|
||||
if spec.provider == LLMProvider.GOOGLE:
|
||||
return PlainLLMProvider(
|
||||
name="Google",
|
||||
call=lambda: _call_google(
|
||||
model=spec.model,
|
||||
system_prompt=system_prompt,
|
||||
user_text=user_text,
|
||||
image_bytes=image_bytes,
|
||||
media_type=media_type,
|
||||
),
|
||||
)
|
||||
if spec.provider == LLMProvider.ANTHROPIC:
|
||||
return PlainLLMProvider(
|
||||
name="Anthropic",
|
||||
call=lambda: _call_anthropic(
|
||||
model=spec.model,
|
||||
system_prompt=system_prompt,
|
||||
user_text=user_text,
|
||||
image_bytes=image_bytes,
|
||||
media_type=media_type,
|
||||
),
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Template generation only supports OpenAI, Codex, Google, or Anthropic.",
|
||||
)
|
||||
|
||||
|
||||
async def generate_slide_layout_code(
|
||||
*,
|
||||
system_prompt: str,
|
||||
user_text: str,
|
||||
image_bytes: bytes,
|
||||
media_type: str = "image/png",
|
||||
) -> str:
|
||||
provider = _build_provider_call(
|
||||
system_prompt=system_prompt,
|
||||
user_text=user_text,
|
||||
image_bytes=image_bytes,
|
||||
media_type=media_type,
|
||||
reasoning_effort="high",
|
||||
)
|
||||
return await run_plain_provider_buckets(providers=[provider])
|
||||
|
||||
|
||||
async def edit_slide_layout_code(
|
||||
*,
|
||||
system_prompt: str,
|
||||
user_text: str,
|
||||
) -> str:
|
||||
provider = _build_provider_call(
|
||||
system_prompt=system_prompt,
|
||||
user_text=user_text,
|
||||
reasoning_effort="medium",
|
||||
)
|
||||
return await run_plain_provider_buckets(providers=[provider])
|
||||
65
electron/servers/fastapi/templates/router.py
Normal file
65
electron/servers/fastapi/templates/router.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from templates.handler import (
|
||||
CreateSlideLayoutResponse,
|
||||
EditSlideLayoutResponse,
|
||||
EditSlideLayoutSectionResponse,
|
||||
FontsUploadAndSlidesPreviewResponse,
|
||||
GetTemplateLayoutsResponse,
|
||||
PresentationLayoutModel,
|
||||
SaveTemplateLayoutData,
|
||||
SaveTemplateResponse,
|
||||
TemplateDetail,
|
||||
TemplateExample,
|
||||
clone_slide_layout,
|
||||
clone_template,
|
||||
create_slide_layout,
|
||||
edit_slide_layout,
|
||||
edit_slide_layout_section,
|
||||
get_all_templates,
|
||||
get_layouts,
|
||||
get_template_by_id,
|
||||
get_template_example,
|
||||
init_create_template,
|
||||
save_slide_layout,
|
||||
save_template,
|
||||
update_template,
|
||||
upload_fonts_and_slides_preview,
|
||||
)
|
||||
|
||||
TEMPLATE_ROUTER = APIRouter(prefix="/template", tags=["Template"])
|
||||
|
||||
TEMPLATE_ROUTER.get("/all", response_model=list[TemplateDetail])(get_all_templates)
|
||||
TEMPLATE_ROUTER.get(
|
||||
"/{template_id}/layouts", response_model=GetTemplateLayoutsResponse
|
||||
)(get_layouts)
|
||||
TEMPLATE_ROUTER.get("/{id}", response_model=PresentationLayoutModel)(get_template_by_id)
|
||||
TEMPLATE_ROUTER.get("/{id}/example", response_model=TemplateExample)(
|
||||
get_template_example
|
||||
)
|
||||
TEMPLATE_ROUTER.post(
|
||||
"/fonts-upload-and-slides-preview",
|
||||
response_model=FontsUploadAndSlidesPreviewResponse,
|
||||
)(upload_fonts_and_slides_preview)
|
||||
TEMPLATE_ROUTER.post("/create/init", response_model=uuid.UUID)(init_create_template)
|
||||
TEMPLATE_ROUTER.post("/slide-layout/create", response_model=CreateSlideLayoutResponse)(
|
||||
create_slide_layout
|
||||
)
|
||||
TEMPLATE_ROUTER.post("/create/slide-layout", response_model=CreateSlideLayoutResponse)(
|
||||
create_slide_layout
|
||||
)
|
||||
TEMPLATE_ROUTER.post("/slide-layout/edit", response_model=EditSlideLayoutResponse)(
|
||||
edit_slide_layout
|
||||
)
|
||||
TEMPLATE_ROUTER.post(
|
||||
"/slide-layout/edit-section", response_model=EditSlideLayoutSectionResponse
|
||||
)(edit_slide_layout_section)
|
||||
TEMPLATE_ROUTER.post("/save", response_model=SaveTemplateResponse)(save_template)
|
||||
TEMPLATE_ROUTER.post("/clone", response_model=SaveTemplateResponse)(clone_template)
|
||||
TEMPLATE_ROUTER.put("/update", response_model=SaveTemplateResponse)(update_template)
|
||||
TEMPLATE_ROUTER.post("/slide-layout/save", status_code=200)(save_slide_layout)
|
||||
TEMPLATE_ROUTER.post("/slide-layout/clone", response_model=SaveTemplateLayoutData)(
|
||||
clone_slide_layout
|
||||
)
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import UploadFile
|
||||
import pytest
|
||||
|
||||
from api.main import app
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def create_sample_pptx():
|
||||
"""Create a minimal PPTX file for testing."""
|
||||
# This creates a very basic PPTX structure for testing
|
||||
pptx_content = {
|
||||
'[Content_Types].xml': '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<Types xmlns="http://schemas.openxmlformats.org/package/2006/content-types">
|
||||
<Default Extension="xml" ContentType="application/xml"/>
|
||||
<Default Extension="rels" ContentType="application/vnd.openxmlformats-package.relationships+xml"/>
|
||||
<Override PartName="/ppt/presentation.xml" ContentType="application/vnd.openxmlformats-officedocument.presentationml.presentation.main+xml"/>
|
||||
<Override PartName="/ppt/slides/slide1.xml" ContentType="application/vnd.openxmlformats-officedocument.presentationml.slide+xml"/>
|
||||
</Types>''',
|
||||
'_rels/.rels': '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">
|
||||
<Relationship Id="rId1" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/officeDocument" Target="ppt/presentation.xml"/>
|
||||
</Relationships>''',
|
||||
'ppt/presentation.xml': '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<p:presentation xmlns:p="http://schemas.openxmlformats.org/presentationml/2006/main">
|
||||
<p:sldMasterIdLst/>
|
||||
<p:sldIdLst>
|
||||
<p:sldId id="256" r:id="rId2"/>
|
||||
</p:sldIdLst>
|
||||
<p:sldSz cx="9144000" cy="6858000"/>
|
||||
</p:presentation>''',
|
||||
'ppt/_rels/presentation.xml.rels': '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">
|
||||
<Relationship Id="rId2" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/slide" Target="slides/slide1.xml"/>
|
||||
</Relationships>''',
|
||||
'ppt/slides/slide1.xml': '''<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<p:sld xmlns:p="http://schemas.openxmlformats.org/presentationml/2006/main">
|
||||
<p:cSld>
|
||||
<p:spTree>
|
||||
<p:nvGrpSpPr>
|
||||
<p:cNvPr id="1" name=""/>
|
||||
<p:cNvGrpSpPr/>
|
||||
<p:nvPr/>
|
||||
</p:nvGrpSpPr>
|
||||
<p:grpSpPr>
|
||||
<a:xfrm xmlns:a="http://schemas.openxmlformats.org/drawingml/2006/main">
|
||||
<a:off x="0" y="0"/>
|
||||
<a:ext cx="0" cy="0"/>
|
||||
<a:chOff x="0" y="0"/>
|
||||
<a:chExt cx="0" cy="0"/>
|
||||
</a:xfrm>
|
||||
</p:grpSpPr>
|
||||
</p:spTree>
|
||||
</p:cSld>
|
||||
<p:clrMapOvr>
|
||||
<a:masterClrMapping xmlns:a="http://schemas.openxmlformats.org/drawingml/2006/main"/>
|
||||
</p:clrMapOvr>
|
||||
</p:sld>'''
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.pptx', delete=False) as temp_file:
|
||||
with zipfile.ZipFile(temp_file.name, 'w') as zip_file:
|
||||
for path, content in pptx_content.items():
|
||||
zip_file.writestr(path, content)
|
||||
return temp_file.name
|
||||
|
||||
|
||||
def test_pptx_slides_processing():
|
||||
"""Test the PPTX slides processing endpoint."""
|
||||
|
||||
# Create a sample PPTX file
|
||||
pptx_path = create_sample_pptx()
|
||||
|
||||
try:
|
||||
with open(pptx_path, 'rb') as pptx_file:
|
||||
files = {'pptx_file': ('test.pptx', pptx_file, 'application/vnd.openxmlformats-officedocument.presentationml.presentation')}
|
||||
|
||||
response = client.post("/api/v1/ppt/pptx-slides/process", files=files)
|
||||
|
||||
# Check response
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data['success'] == True
|
||||
assert 'slides' in data
|
||||
assert 'total_slides' in data
|
||||
assert data['total_slides'] > 0
|
||||
|
||||
# Check slide data structure
|
||||
if data['slides']:
|
||||
slide = data['slides'][0]
|
||||
assert 'slide_number' in slide
|
||||
assert 'screenshot_url' in slide
|
||||
assert 'xml_content' in slide
|
||||
assert slide['slide_number'] == 1
|
||||
assert slide['xml_content'] != ''
|
||||
|
||||
print(f"✅ Test passed! Processed {data['total_slides']} slides successfully")
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
if os.path.exists(pptx_path):
|
||||
os.unlink(pptx_path)
|
||||
|
||||
|
||||
def test_invalid_file_type():
|
||||
"""Test that non-PPTX files are rejected."""
|
||||
|
||||
# Create a text file and try to upload it
|
||||
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as temp_file:
|
||||
temp_file.write(b"This is not a PPTX file")
|
||||
temp_file.flush()
|
||||
|
||||
try:
|
||||
with open(temp_file.name, 'rb') as txt_file:
|
||||
files = {'pptx_file': ('test.txt', txt_file, 'text/plain')}
|
||||
|
||||
response = client.post("/api/v1/ppt/pptx-slides/process", files=files)
|
||||
|
||||
# Should return 400 for invalid file type
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert 'Invalid file type' in data['detail']
|
||||
|
||||
print("✅ Invalid file type test passed!")
|
||||
|
||||
finally:
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running PPTX slides processing tests...")
|
||||
test_pptx_slides_processing()
|
||||
test_invalid_file_type()
|
||||
print("🎉 All tests completed!")
|
||||
|
|
@ -1,189 +1,117 @@
|
|||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
import asyncio
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
from fastapi import HTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
class MockAiohttpResponse:
|
||||
def __init__(self, status=200, json_data=None):
|
||||
self.status = status
|
||||
self._json_data = json_data or {"path": "/tmp/exports/test.pdf"}
|
||||
from api.v1.ppt.endpoints.presentation import generate_presentation_sync
|
||||
from models.generate_presentation_request import GeneratePresentationRequest
|
||||
from models.presentation_and_path import PresentationPathAndEditPath
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
pass
|
||||
class FakeAsyncSession:
|
||||
async def get(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
def add(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
async def text(self):
|
||||
return str(self._json_data)
|
||||
def add_all(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
class MockAiohttpSession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
async def commit(self):
|
||||
return None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
pass
|
||||
|
||||
def post(self, *args, **kwargs):
|
||||
return MockAiohttpResponse()
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
pptx_model_data = {
|
||||
"slides": [],
|
||||
"title": "Test",
|
||||
"notes": [],
|
||||
"layout": {},
|
||||
"structure": {},
|
||||
}
|
||||
return MockAiohttpResponse(json_data=pptx_model_data)
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = FastAPI()
|
||||
app.include_router(PRESENTATION_ROUTER, prefix="/api/v1/ppt")
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_layout():
|
||||
async def _mock_get_layout_by_name(layout_name: str):
|
||||
mock_slide = MagicMock()
|
||||
mock_slide.name = "Mock Slide"
|
||||
mock_slide.json_schema = {"title": "Mock Slide Title"}
|
||||
mock_slide.description = "Mock slide description"
|
||||
mock_layout = MagicMock(spec=PresentationLayoutModel)
|
||||
mock_layout.name = layout_name
|
||||
mock_layout.ordered = True
|
||||
mock_layout.slides = [mock_slide]
|
||||
mock_layout.model_dump = lambda: {}
|
||||
mock_layout.to_presentation_structure = lambda: PresentationStructureModel(
|
||||
slides=[index for index in range(len(mock_layout.slides))]
|
||||
)
|
||||
def to_string():
|
||||
message = f"## Presentation Layout\n\n"
|
||||
for index, slide in enumerate(mock_layout.slides):
|
||||
message += f"### Slide Layout: {index}: \n"
|
||||
message += f"- Name: {slide.name or slide.json_schema.get('title')} \n"
|
||||
message += f"- Description: {slide.description} \n\n"
|
||||
return message
|
||||
mock_layout.to_string = to_string
|
||||
return mock_layout
|
||||
return _mock_get_layout_by_name
|
||||
|
||||
async def mock_generate_ppt_outline(*args, **kwargs):
|
||||
yield '{"title": "Test", "slides": [{"title": "Slide 1", "body": "Body 1"}], "notes": []}'
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_presentation_api(monkeypatch, mock_get_layout):
|
||||
# Patch all dependencies used in the API
|
||||
patches = [
|
||||
patch('api.v1.ppt.endpoints.presentation.get_layout_by_name', new=AsyncMock(side_effect=mock_get_layout)),
|
||||
patch('api.v1.ppt.endpoints.presentation.TEMP_FILE_SERVICE.create_temp_dir', return_value='/tmp/mockdir'),
|
||||
patch('api.v1.ppt.endpoints.presentation.DocumentsLoader'),
|
||||
patch('api.v1.ppt.endpoints.presentation.generate_document_summary', new_callable=AsyncMock, return_value="mock_summary"),
|
||||
patch('api.v1.ppt.endpoints.presentation.generate_ppt_outline', side_effect=mock_generate_ppt_outline),
|
||||
patch('api.v1.ppt.endpoints.presentation.get_sql_session'),
|
||||
patch('api.v1.ppt.endpoints.presentation.get_slide_content_from_type_and_outline', new_callable=AsyncMock, return_value={"mock": "slide_content"}),
|
||||
patch('api.v1.ppt.endpoints.presentation.process_slide_and_fetch_assets', new_callable=AsyncMock),
|
||||
patch('api.v1.ppt.endpoints.presentation.get_exports_directory', return_value='/tmp/exports'),
|
||||
patch('api.v1.ppt.endpoints.presentation.PptxPresentationCreator'),
|
||||
patch('api.v1.ppt.endpoints.presentation.aiohttp.ClientSession', return_value=MockAiohttpSession()),
|
||||
]
|
||||
mocks = [p.start() for p in patches]
|
||||
|
||||
# Setup DocumentsLoader mock
|
||||
docs_loader = mocks[2]
|
||||
docs_loader.return_value.load_documents = AsyncMock()
|
||||
docs_loader.return_value.documents = []
|
||||
|
||||
# Setup PptxPresentationCreator mock for pptx test
|
||||
pptx_creator = mocks[9]
|
||||
pptx_creator.return_value.create_ppt = AsyncMock()
|
||||
pptx_creator.return_value.save = MagicMock()
|
||||
|
||||
yield
|
||||
|
||||
for p in patches:
|
||||
p.stop()
|
||||
|
||||
class TestPresentationGenerationAPI:
|
||||
def test_generate_presentation_export_as_pdf(self, client):
|
||||
response = client.post(
|
||||
"/api/v1/ppt/presentation/generate",
|
||||
json={
|
||||
"content": "Create a presentation about artificial intelligence and machine learning",
|
||||
"n_slides": 5,
|
||||
"language": "English",
|
||||
"export_as": "pdf",
|
||||
"layout": "general"
|
||||
}
|
||||
def test_generate_presentation_export_as_pdf(self):
|
||||
request = GeneratePresentationRequest(
|
||||
content="Create a presentation about artificial intelligence and machine learning",
|
||||
n_slides=5,
|
||||
language="English",
|
||||
export_as="pdf",
|
||||
template="general",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "presentation_id" in response.json()
|
||||
assert "pdf" in response.json()["path"]
|
||||
|
||||
def test_generate_presentation_export_as_pptx(self, client):
|
||||
response = client.post(
|
||||
"/api/v1/ppt/presentation/generate",
|
||||
json={
|
||||
"content": "Create a presentation about artificial intelligence and machine learning",
|
||||
"n_slides": 5,
|
||||
"language": "English",
|
||||
"export_as": "pptx",
|
||||
"layout": "general"
|
||||
}
|
||||
response_payload = PresentationPathAndEditPath(
|
||||
presentation_id=uuid.uuid4(),
|
||||
path="/tmp/exports/test.pdf",
|
||||
edit_path="/presentation?id=test",
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "presentation_id" in response.json()
|
||||
assert "pptx" in response.json()["path"]
|
||||
|
||||
def test_generate_presentation_with_no_content(self, client):
|
||||
response = client.post(
|
||||
"/api/v1/ppt/presentation/generate",
|
||||
json={
|
||||
"n_slides": 5,
|
||||
"language": "English",
|
||||
"export_as": "pdf",
|
||||
"layout": "general"
|
||||
}
|
||||
with patch(
|
||||
"api.v1.ppt.endpoints.presentation.generate_presentation_handler",
|
||||
new=AsyncMock(return_value=response_payload),
|
||||
) as mock_handler:
|
||||
response = asyncio.run(
|
||||
generate_presentation_sync(request, sql_session=FakeAsyncSession())
|
||||
)
|
||||
|
||||
assert response == response_payload
|
||||
mock_handler.assert_awaited_once()
|
||||
|
||||
def test_generate_presentation_export_as_pptx(self):
|
||||
request = GeneratePresentationRequest(
|
||||
content="Create a presentation about artificial intelligence and machine learning",
|
||||
n_slides=5,
|
||||
language="English",
|
||||
export_as="pptx",
|
||||
template="general",
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_generate_presentation_with_n_slides_less_than_one(self, client):
|
||||
response = client.post(
|
||||
"/api/v1/ppt/presentation/generate",
|
||||
json={
|
||||
"content": "Create a presentation about artificial intelligence and machine learning",
|
||||
"n_slides": 0,
|
||||
"language": "English",
|
||||
"export_as": "pdf",
|
||||
"layout": "general"
|
||||
}
|
||||
response_payload = PresentationPathAndEditPath(
|
||||
presentation_id=uuid.uuid4(),
|
||||
path="/tmp/exports/test.pptx",
|
||||
edit_path="/presentation?id=test",
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_generate_presentation_with_invalid_export_type(self, client):
|
||||
response = client.post(
|
||||
"/api/v1/ppt/presentation/generate",
|
||||
json={
|
||||
"content": "Create a presentation about artificial intelligence and machine learning",
|
||||
"n_slides": 5,
|
||||
"language": "English",
|
||||
"export_as": "invalid_type",
|
||||
"layout": "general"
|
||||
}
|
||||
with patch(
|
||||
"api.v1.ppt.endpoints.presentation.generate_presentation_handler",
|
||||
new=AsyncMock(return_value=response_payload),
|
||||
) as mock_handler:
|
||||
response = asyncio.run(
|
||||
generate_presentation_sync(request, sql_session=FakeAsyncSession())
|
||||
)
|
||||
|
||||
assert response == response_payload
|
||||
mock_handler.assert_awaited_once()
|
||||
|
||||
def test_generate_presentation_with_no_content(self):
|
||||
with pytest.raises(ValidationError):
|
||||
GeneratePresentationRequest.model_validate(
|
||||
{
|
||||
"n_slides": 5,
|
||||
"language": "English",
|
||||
"export_as": "pdf",
|
||||
"template": "general",
|
||||
}
|
||||
)
|
||||
|
||||
def test_generate_presentation_with_n_slides_less_than_one(self):
|
||||
request = GeneratePresentationRequest(
|
||||
content="Create a presentation about artificial intelligence and machine learning",
|
||||
n_slides=0,
|
||||
language="English",
|
||||
export_as="pdf",
|
||||
template="general",
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
asyncio.run(
|
||||
generate_presentation_sync(request, sql_session=FakeAsyncSession())
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
assert exc.value.detail == "Number of slides must be greater than 0"
|
||||
|
||||
def test_generate_presentation_with_invalid_export_type(self):
|
||||
with pytest.raises(ValidationError):
|
||||
GeneratePresentationRequest.model_validate(
|
||||
{
|
||||
"content": "Create a presentation about artificial intelligence and machine learning",
|
||||
"n_slides": 5,
|
||||
"language": "English",
|
||||
"export_as": "invalid_type",
|
||||
"template": "general",
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,115 +0,0 @@
|
|||
import pytest
|
||||
import os
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Import the main app
|
||||
from server import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_slide_to_html_endpoint():
|
||||
"""Test the slide-to-html endpoint with streaming API support."""
|
||||
|
||||
# Sample XML data (simplified version of OXML)
|
||||
test_xml = '''<?xml version="1.0" encoding="UTF-8"?>
|
||||
<p:sld xmlns:p="http://schemas.openxmlformats.org/presentationml/2006/main">
|
||||
<p:cSld>
|
||||
<p:bg>
|
||||
<p:bgPr>
|
||||
<a:solidFill xmlns:a="http://schemas.openxmlformats.org/drawingml/2006/main">
|
||||
<a:srgbClr val="FFFFFF"/>
|
||||
</a:solidFill>
|
||||
</p:bgPr>
|
||||
</p:bg>
|
||||
<p:spTree>
|
||||
<p:sp>
|
||||
<p:txBody>
|
||||
<a:p xmlns:a="http://schemas.openxmlformats.org/drawingml/2006/main">
|
||||
<a:r>
|
||||
<a:t>Test Slide</a:t>
|
||||
</a:r>
|
||||
</a:p>
|
||||
</p:txBody>
|
||||
</p:sp>
|
||||
</p:spTree>
|
||||
</p:cSld>
|
||||
</p:sld>'''
|
||||
|
||||
# Skip this test if ANTHROPIC_API_KEY is not set
|
||||
if not os.getenv('ANTHROPIC_API_KEY'):
|
||||
pytest.skip("ANTHROPIC_API_KEY not set - skipping API test")
|
||||
|
||||
# Use a placeholder image path (since we can't easily test with real files)
|
||||
test_data = {
|
||||
"image": "/static/images/placeholder.jpg",
|
||||
"xml": test_xml
|
||||
}
|
||||
|
||||
# Make the request with JSON
|
||||
response = client.post(
|
||||
"/api/v1/ppt/slide-to-html/",
|
||||
json=test_data
|
||||
)
|
||||
|
||||
# Check response (may take several minutes due to streaming)
|
||||
print("Note: This test may take several minutes due to Claude's streaming processing...")
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "html" in data
|
||||
assert len(data["html"]) > 0
|
||||
print(f"Generated HTML preview: {data['html'][:200]}...")
|
||||
print("✅ Streaming API test completed successfully")
|
||||
else:
|
||||
print(f"Request failed with status {response.status_code}: {response.text}")
|
||||
# Don't fail the test if API key is missing or invalid
|
||||
if "ANTHROPIC_API_KEY" in response.text:
|
||||
pytest.skip("Invalid API key - skipping test")
|
||||
elif "Streaming is required" in response.text:
|
||||
print("✅ Streaming error handled correctly by endpoint")
|
||||
|
||||
|
||||
def test_slide_to_html_invalid_path():
|
||||
"""Test the endpoint with an invalid image path."""
|
||||
|
||||
test_data = {
|
||||
"image": "/app_data/images/nonexistent/image.png",
|
||||
"xml": "<simple>xml</simple>"
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ppt/slide-to-html/",
|
||||
json=test_data
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Image file not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_slide_to_html_missing_xml():
|
||||
"""Test the endpoint with missing XML data."""
|
||||
|
||||
test_data = {
|
||||
"image": "/static/images/placeholder.jpg"
|
||||
# No XML data provided
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ppt/slide-to-html/",
|
||||
json=test_data
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run a simple test
|
||||
test_slide_to_html_invalid_path()
|
||||
print("✅ Invalid path test passed")
|
||||
|
||||
test_slide_to_html_missing_xml()
|
||||
print("✅ Missing XML test passed")
|
||||
|
||||
print("🧪 Run full tests with: pytest test_slide_to_html.py")
|
||||
479
electron/servers/fastapi/tests/test_template_api.py
Normal file
479
electron/servers/fastapi/tests/test_template_api.py
Normal file
|
|
@ -0,0 +1,479 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.sql import Delete, Select
|
||||
|
||||
from api.v1.ppt.router import API_V1_PPT_ROUTER
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.sql.presentation_layout_code import PresentationLayoutCodeModel
|
||||
from models.sql.template import TemplateModel
|
||||
from models.sql.template_create_info import TemplateCreateInfoModel
|
||||
from templates.handler import (
|
||||
CloneSlideLayoutRequest,
|
||||
CreateSlideLayoutRequest,
|
||||
EditSlideLayoutRequest,
|
||||
EditSlideLayoutSectionRequest,
|
||||
SaveSlideLayoutRequest,
|
||||
SaveTemplateLayoutData,
|
||||
SaveTemplateRequest,
|
||||
UpdateTemplateRequest,
|
||||
clone_slide_layout,
|
||||
create_slide_layout,
|
||||
edit_slide_layout,
|
||||
edit_slide_layout_section,
|
||||
init_create_template,
|
||||
save_slide_layout,
|
||||
save_template,
|
||||
update_template,
|
||||
upload_fonts_and_slides_preview,
|
||||
)
|
||||
from templates.pptx_html_stub import BASIC_TEMPLATE_HTML
|
||||
from templates.preview import (
|
||||
FontCheckResponse,
|
||||
FontsUploadAndSlidesPreviewResponse,
|
||||
check_fonts_in_pptx_handler,
|
||||
)
|
||||
from templates.providers import (
|
||||
ANTHROPIC_TEMPLATE_MODEL,
|
||||
CODEX_TEMPLATE_MODEL,
|
||||
GOOGLE_TEMPLATE_MODEL,
|
||||
OPENAI_TEMPLATE_MODEL,
|
||||
get_template_provider_spec,
|
||||
)
|
||||
|
||||
|
||||
PNG_BYTES = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAusB9s5WzxQAAAAASUVORK5CYII="
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_client(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("APP_DATA_DIRECTORY", str(tmp_path / "app_data"))
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(API_V1_PPT_ROUTER)
|
||||
client = TestClient(app)
|
||||
yield client, tmp_path
|
||||
client.close()
|
||||
|
||||
|
||||
class FakeScalarResult:
|
||||
def __init__(self, items):
|
||||
self._items = list(items)
|
||||
|
||||
def all(self):
|
||||
return list(self._items)
|
||||
|
||||
|
||||
class FakeExecuteResult:
|
||||
def __init__(self, items):
|
||||
self._items = list(items)
|
||||
|
||||
def scalars(self):
|
||||
return FakeScalarResult(self._items)
|
||||
|
||||
def all(self):
|
||||
return list(self._items)
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self):
|
||||
self.template_infos: dict = {}
|
||||
self.templates: dict = {}
|
||||
self.layouts: list[PresentationLayoutCodeModel] = []
|
||||
self._next_layout_row_id = 1
|
||||
|
||||
async def get(self, model, key):
|
||||
if model is TemplateCreateInfoModel:
|
||||
return self.template_infos.get(key)
|
||||
if model is TemplateModel:
|
||||
return self.templates.get(key)
|
||||
return None
|
||||
|
||||
def add(self, obj):
|
||||
if isinstance(obj, TemplateCreateInfoModel):
|
||||
if getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(timezone.utc)
|
||||
self.template_infos[obj.id] = obj
|
||||
elif isinstance(obj, TemplateModel):
|
||||
if getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(timezone.utc)
|
||||
self.templates[obj.id] = obj
|
||||
elif isinstance(obj, PresentationLayoutCodeModel):
|
||||
if obj.id is None:
|
||||
obj.id = self._next_layout_row_id
|
||||
self._next_layout_row_id += 1
|
||||
if getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(timezone.utc)
|
||||
if getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(timezone.utc)
|
||||
self.layouts.append(obj)
|
||||
|
||||
def add_all(self, objects):
|
||||
for obj in objects:
|
||||
self.add(obj)
|
||||
|
||||
async def commit(self):
|
||||
return None
|
||||
|
||||
async def refresh(self, _obj):
|
||||
return None
|
||||
|
||||
async def scalar(self, statement):
|
||||
items = self._execute_select(statement)
|
||||
return items[0] if items else None
|
||||
|
||||
async def execute(self, statement):
|
||||
if isinstance(statement, Delete):
|
||||
self._execute_delete(statement)
|
||||
return FakeExecuteResult([])
|
||||
return FakeExecuteResult(self._execute_select(statement))
|
||||
|
||||
def _execute_select(self, statement):
|
||||
if not isinstance(statement, Select):
|
||||
return []
|
||||
entity = statement.column_descriptions[0].get("entity")
|
||||
if entity is not PresentationLayoutCodeModel:
|
||||
return []
|
||||
return [
|
||||
layout
|
||||
for layout in self.layouts
|
||||
if all(self._matches_clause(layout, clause) for clause in statement._where_criteria)
|
||||
]
|
||||
|
||||
def _execute_delete(self, statement):
|
||||
if statement.table.name != PresentationLayoutCodeModel.__tablename__:
|
||||
return
|
||||
self.layouts = [
|
||||
layout
|
||||
for layout in self.layouts
|
||||
if not all(self._matches_clause(layout, clause) for clause in statement._where_criteria)
|
||||
]
|
||||
|
||||
def _matches_clause(self, obj, clause):
|
||||
if hasattr(clause, "clauses"):
|
||||
return all(self._matches_clause(obj, child) for child in clause.clauses)
|
||||
left = getattr(clause.left, "key", None) or getattr(clause.left, "name", None)
|
||||
right = getattr(clause.right, "value", None)
|
||||
return getattr(obj, left) == right
|
||||
|
||||
|
||||
class SimpleUploadFile:
|
||||
def __init__(self, filename: str, content_type: str, data: bytes):
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
self._data = data
|
||||
|
||||
async def read(self):
|
||||
return self._data
|
||||
|
||||
|
||||
def test_router_registration_replaces_old_routes(api_client):
|
||||
client, _ = api_client
|
||||
paths = {
|
||||
route.path
|
||||
for route in client.app.routes
|
||||
if hasattr(route, "path") and route.path.startswith("/api/v1/ppt")
|
||||
}
|
||||
|
||||
assert "/api/v1/ppt/template/all" in paths
|
||||
assert "/api/v1/ppt/template/create/init" in paths
|
||||
assert "/api/v1/ppt/template/slide-layout/create" in paths
|
||||
assert "/api/v1/ppt/template/fonts-upload-and-slides-preview" in paths
|
||||
assert "/api/v1/ppt/fonts/check" in paths
|
||||
|
||||
assert "/api/v1/ppt/fonts/upload" in paths
|
||||
assert "/api/v1/ppt/fonts/list" in paths
|
||||
assert "/api/v1/ppt/fonts/uploaded" in paths
|
||||
|
||||
assert "/api/v1/ppt/slide-to-html/" not in paths
|
||||
assert "/api/v1/ppt/html-to-react/" not in paths
|
||||
assert "/api/v1/ppt/html-edit/" not in paths
|
||||
assert "/api/v1/ppt/pptx-slides/process" not in paths
|
||||
assert "/api/v1/ppt/pdf-slides/process" not in paths
|
||||
assert "/api/v1/ppt/pptx-fonts/process" not in paths
|
||||
|
||||
|
||||
def test_template_create_init_stores_stub_htmls():
|
||||
session = FakeAsyncSession()
|
||||
|
||||
template_info_id = asyncio.run(
|
||||
init_create_template(
|
||||
request=type(
|
||||
"Request",
|
||||
(),
|
||||
{
|
||||
"pptx_url": "/app_data/uploads/template-previews/test/presentation.pptx",
|
||||
"slide_image_urls": [
|
||||
"/app_data/images/a/slide_1.png",
|
||||
"/app_data/images/a/slide_2.png",
|
||||
],
|
||||
"fonts": {
|
||||
"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"
|
||||
},
|
||||
},
|
||||
)(),
|
||||
sql_session=session,
|
||||
)
|
||||
)
|
||||
|
||||
template_info = session.template_infos[template_info_id]
|
||||
assert template_info.slide_htmls == [BASIC_TEMPLATE_HTML, BASIC_TEMPLATE_HTML]
|
||||
assert template_info.slide_image_urls == [
|
||||
"/app_data/images/a/slide_1.png",
|
||||
"/app_data/images/a/slide_2.png",
|
||||
]
|
||||
|
||||
|
||||
def test_fonts_check_endpoint(monkeypatch):
|
||||
async def fake_font_check(_pptx_path: str, _temp_dir: str):
|
||||
return [("Inter", "https://fonts.googleapis.com/css2?family=Inter&display=swap")], [("Custom Font", None)]
|
||||
|
||||
async def fake_to_thread(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"templates.preview.get_available_and_unavailable_fonts_for_pptx",
|
||||
fake_font_check,
|
||||
)
|
||||
monkeypatch.setattr("templates.preview.asyncio.to_thread", fake_to_thread)
|
||||
|
||||
upload = SimpleUploadFile(
|
||||
filename="deck.pptx",
|
||||
content_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
data=b"fake-pptx",
|
||||
)
|
||||
response = asyncio.run(check_fonts_in_pptx_handler(pptx_file=upload))
|
||||
|
||||
assert response == FontCheckResponse(
|
||||
available_fonts=[
|
||||
{
|
||||
"name": "Inter",
|
||||
"url": "https://fonts.googleapis.com/css2?family=Inter&display=swap",
|
||||
}
|
||||
],
|
||||
unavailable_fonts=[{"name": "Custom Font", "url": None}],
|
||||
)
|
||||
|
||||
|
||||
def test_fonts_upload_and_preview_route_uses_new_handler(monkeypatch):
|
||||
async def fake_preview_handler(**kwargs):
|
||||
assert kwargs["pptx_file"].filename == "deck.pptx"
|
||||
return FontsUploadAndSlidesPreviewResponse(
|
||||
slide_image_urls=["/app_data/images/1/slide_1.png"],
|
||||
pptx_url="/app_data/uploads/template-previews/1/presentation.pptx",
|
||||
modified_pptx_url="/app_data/uploads/template-previews/1/presentation.pptx",
|
||||
fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"templates.handler.upload_fonts_and_slides_preview_handler",
|
||||
fake_preview_handler,
|
||||
)
|
||||
|
||||
upload = SimpleUploadFile(
|
||||
filename="deck.pptx",
|
||||
content_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
data=b"fake-pptx",
|
||||
)
|
||||
response = asyncio.run(upload_fonts_and_slides_preview(pptx_file=upload))
|
||||
|
||||
assert response == FontsUploadAndSlidesPreviewResponse(
|
||||
slide_image_urls=["/app_data/images/1/slide_1.png"],
|
||||
pptx_url="/app_data/uploads/template-previews/1/presentation.pptx",
|
||||
modified_pptx_url="/app_data/uploads/template-previews/1/presentation.pptx",
|
||||
fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"},
|
||||
)
|
||||
|
||||
|
||||
def test_provider_spec_mapping_and_restrictions(monkeypatch):
|
||||
monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.OPENAI)
|
||||
spec = get_template_provider_spec()
|
||||
assert spec.provider == LLMProvider.OPENAI
|
||||
assert spec.model == OPENAI_TEMPLATE_MODEL
|
||||
|
||||
monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.CODEX)
|
||||
spec = get_template_provider_spec()
|
||||
assert spec.provider == LLMProvider.CODEX
|
||||
assert spec.model == CODEX_TEMPLATE_MODEL
|
||||
|
||||
monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.GOOGLE)
|
||||
spec = get_template_provider_spec()
|
||||
assert spec.provider == LLMProvider.GOOGLE
|
||||
assert spec.model == GOOGLE_TEMPLATE_MODEL
|
||||
|
||||
monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.ANTHROPIC)
|
||||
spec = get_template_provider_spec()
|
||||
assert spec.provider == LLMProvider.ANTHROPIC
|
||||
assert spec.model == ANTHROPIC_TEMPLATE_MODEL
|
||||
|
||||
monkeypatch.setattr("templates.providers.get_llm_provider", lambda: LLMProvider.OLLAMA)
|
||||
with pytest.raises(Exception) as exc:
|
||||
get_template_provider_spec()
|
||||
assert "Template generation only supports OpenAI, Codex, Google, or Anthropic." in str(exc.value)
|
||||
|
||||
|
||||
def test_create_and_edit_slide_layout_routes_use_provider_layer(tmp_path, monkeypatch):
|
||||
session = FakeAsyncSession()
|
||||
image_path = tmp_path / "slide.png"
|
||||
image_path.write_bytes(PNG_BYTES)
|
||||
|
||||
template_info = TemplateCreateInfoModel(
|
||||
fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"},
|
||||
pptx_url="/app_data/uploads/template-previews/seed/presentation.pptx",
|
||||
slide_htmls=["<div>seed html</div>"],
|
||||
slide_image_urls=[str(image_path)],
|
||||
)
|
||||
session.add(template_info)
|
||||
|
||||
create_calls = []
|
||||
edit_calls = []
|
||||
|
||||
async def fake_generate_layout(**kwargs):
|
||||
create_calls.append(kwargs)
|
||||
return """```tsx
|
||||
import { z } from "zod";
|
||||
const layoutId = "title-image";
|
||||
const layoutName = "Title Image";
|
||||
const layoutDescription = "desc";
|
||||
function dynamicSlideLayout() { return <div>{image_url}{icon_url}{image_prompt}{icon_query}</div>; }
|
||||
export { layoutId };
|
||||
```"""
|
||||
|
||||
async def fake_edit_layout(**kwargs):
|
||||
edit_calls.append(kwargs)
|
||||
return "```tsx\nconst updatedLayout = true;\n```"
|
||||
|
||||
monkeypatch.setattr("templates.handler.generate_slide_layout_code", fake_generate_layout)
|
||||
monkeypatch.setattr("templates.handler.edit_slide_layout_code", fake_edit_layout)
|
||||
|
||||
create_response = asyncio.run(
|
||||
create_slide_layout(
|
||||
request=CreateSlideLayoutRequest(id=template_info.id, index=0),
|
||||
sql_session=session,
|
||||
)
|
||||
)
|
||||
assert "__image_url__" in create_response.react_component
|
||||
assert "__icon_url__" in create_response.react_component
|
||||
assert "__image_prompt__" in create_response.react_component
|
||||
assert "__icon_query__" in create_response.react_component
|
||||
assert "import " not in create_response.react_component
|
||||
assert "export " not in create_response.react_component
|
||||
assert re.search(r'layoutId\s*=\s*"title-image-\d{4}"', create_response.react_component)
|
||||
|
||||
assert create_calls
|
||||
assert create_calls[0]["system_prompt"]
|
||||
assert "#SLIDE HTML REFERENCE" in create_calls[0]["user_text"]
|
||||
assert create_calls[0]["media_type"] == "image/png"
|
||||
assert create_calls[0]["image_bytes"] == PNG_BYTES
|
||||
|
||||
edit_response = asyncio.run(
|
||||
edit_slide_layout(
|
||||
request=EditSlideLayoutRequest(
|
||||
react_component="const x = 1;",
|
||||
prompt="Move title up",
|
||||
)
|
||||
)
|
||||
)
|
||||
assert edit_response.react_component == "const updatedLayout = true;"
|
||||
|
||||
section_response = asyncio.run(
|
||||
edit_slide_layout_section(
|
||||
request=EditSlideLayoutSectionRequest(
|
||||
react_component="const x = 1;",
|
||||
section="header",
|
||||
prompt="Change spacing",
|
||||
)
|
||||
)
|
||||
)
|
||||
assert section_response.react_component == "const updatedLayout = true;"
|
||||
|
||||
assert len(edit_calls) == 2
|
||||
assert "#Prompt\nMove title up" in edit_calls[0]["user_text"]
|
||||
assert "#Section to make changes around\nheader" in edit_calls[1]["user_text"]
|
||||
|
||||
|
||||
def test_save_update_and_clone_template_flow():
|
||||
session = FakeAsyncSession()
|
||||
template_info = TemplateCreateInfoModel(
|
||||
fonts={"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"},
|
||||
pptx_url="/app_data/uploads/template-previews/seed/presentation.pptx",
|
||||
slide_htmls=["<div>seed html</div>"],
|
||||
slide_image_urls=["/app_data/images/seed/slide_1.png"],
|
||||
)
|
||||
session.add(template_info)
|
||||
|
||||
save_response = asyncio.run(
|
||||
save_template(
|
||||
request=SaveTemplateRequest(
|
||||
template_info_id=template_info.id,
|
||||
name="My Template",
|
||||
description="Saved from test",
|
||||
layouts=[
|
||||
SaveTemplateLayoutData(
|
||||
layout_id="title-image-1000",
|
||||
layout_name="Title Image",
|
||||
layout_code='const layoutId = "title-image-1000";',
|
||||
)
|
||||
],
|
||||
),
|
||||
sql_session=session,
|
||||
)
|
||||
)
|
||||
template_id = save_response.id
|
||||
|
||||
clone_layout_response = asyncio.run(
|
||||
clone_slide_layout(
|
||||
request=CloneSlideLayoutRequest(
|
||||
template_id=f"custom-{template_id}",
|
||||
layout_id="title-image-1000",
|
||||
layout_name="Cloned Layout",
|
||||
),
|
||||
sql_session=session,
|
||||
)
|
||||
)
|
||||
assert clone_layout_response.layout_name == "Cloned Layout"
|
||||
assert clone_layout_response.layout_id != "title-image-1000"
|
||||
|
||||
asyncio.run(
|
||||
save_slide_layout(
|
||||
request=SaveSlideLayoutRequest(
|
||||
template_id=template_id,
|
||||
layout_id="title-image-1000",
|
||||
layout_code='const layoutId = "title-image-1000"; const edited = true;',
|
||||
),
|
||||
sql_session=session,
|
||||
)
|
||||
)
|
||||
assert any("edited = true" in layout.layout_code for layout in session.layouts)
|
||||
|
||||
update_response = asyncio.run(
|
||||
update_template(
|
||||
request=UpdateTemplateRequest(
|
||||
id=template_id,
|
||||
layouts=[
|
||||
SaveTemplateLayoutData(
|
||||
layout_id="updated-layout-2000",
|
||||
layout_name="Updated Layout",
|
||||
layout_code='const layoutId = "updated-layout-2000";',
|
||||
)
|
||||
],
|
||||
),
|
||||
sql_session=session,
|
||||
)
|
||||
)
|
||||
assert update_response.id == template_id
|
||||
|
||||
layouts = [layout for layout in session.layouts if layout.presentation == template_id]
|
||||
assert session.templates[template_id] is not None
|
||||
assert len(layouts) == 1
|
||||
assert layouts[0].layout_id == "updated-layout-2000"
|
||||
assert layouts[0].fonts == {
|
||||
"Inter": "https://fonts.googleapis.com/css2?family=Inter&display=swap"
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from templates.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Dict
|
||||
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from datetime import datetime
|
|||
import json
|
||||
from typing import Optional
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from templates.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from templates.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.presentation_layout import PresentationLayoutModel
|
||||
from templates.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
import re
|
||||
from typing import List
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue