presenton/servers/fastapi/templates/handler.py

707 lines
22 KiB
Python

import os
import random
import re
import uuid
from datetime import datetime
from typing import Any, List, Optional
import aiohttp
from fastapi import Body, Depends, File, Form, HTTPException, Path, Query, UploadFile
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import delete, select
from constants.presentation import DEFAULT_TEMPLATES
from models.sql.presentation_layout_code import PresentationLayoutCodeModel
from models.sql.template import TemplateModel
from models.sql.template_create_info import TemplateCreateInfoModel
from services.database import get_async_session
from services.export_task_service import EXPORT_TASK_SERVICE
from templates.example import build_template_example
from templates.get_layout_by_name import get_layout_by_name
from templates.presentation_layout import PresentationLayoutModel
from templates.preview import (
FontsUploadAndSlidesPreviewResponse,
upload_fonts_and_slides_preview_handler,
)
from templates.prompts import (
SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT,
SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT,
SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT,
)
from templates.providers import edit_slide_layout_code, generate_slide_layout_code
from utils.asset_directory_utils import (
resolve_app_path_to_filesystem,
resolve_image_path_to_filesystem,
)
class TemplateDetail(BaseModel):
    """Summary entry returned by ``get_all_templates``.

    Built-in templates use their name as ``id``; custom templates use
    ``"custom-<uuid>"`` and also report their layout count.
    """

    id: str
    name: str
    # Number of stored layouts; only filled in for custom templates.
    total_layouts: Optional[int] = None
class TemplateLayoutData(BaseModel):
    """One stored layout of a custom template, as returned by ``get_layouts``."""

    # UUID of the owning template (the layout row's ``presentation`` column).
    template: uuid.UUID
    layout_id: str
    layout_name: str
    # Layout component source code, stored as text.
    layout_code: str
    # Font data copied from the layout row; structure is not visible here.
    fonts: Optional[Any] = None
class TemplateData(BaseModel):
    """Metadata of a custom template (mirrors a ``TemplateModel`` row)."""

    id: uuid.UUID
    # Always None when built by ``get_layouts``.
    init_id: Optional[uuid.UUID] = None
    name: str
    description: Optional[str] = None
    created_at: datetime
class GetTemplateLayoutsResponse(BaseModel):
    """Response of ``get_layouts``: all layouts of a template plus its metadata."""

    layouts: list[TemplateLayoutData]
    # None when layout rows exist but the template row itself is missing.
    template: Optional[TemplateData] = None
    # Fonts taken from the first layout in ``layouts``.
    fonts: Optional[Any] = None
class TemplateExample(BaseModel):
    """Example slide data for a template, built from ``build_template_example``."""

    # NOTE(review): presumably the template id; confirm against build_template_example.
    template: str
    slides: List[dict]
class CreateTemplateInitRequest(BaseModel):
    """Request body of ``init_create_template``."""

    # App URL of the uploaded PPTX; resolved to a local file before conversion.
    pptx_url: str
    # Slide preview image URLs (must be non-empty); one slide HTML is kept per entry.
    slide_image_urls: List[str]
    # Template fonts; the keys are listed as font names in the generation prompt.
    fonts: dict = {}
class CreateSlideLayoutRequest(BaseModel):
    """Request body of ``create_slide_layout``."""

    # Id of the TemplateCreateInfoModel row returned by init_create_template.
    id: uuid.UUID
    # Zero-based index of the slide to turn into a layout.
    index: int
class CreateSlideLayoutResponse(BaseModel):
    """Generated layout component code; also the base of the edit responses."""

    react_component: str
class EditSlideLayoutRequest(BaseModel):
    """Request body of ``edit_slide_layout``."""

    # Current layout component code to modify.
    react_component: str
    # Edit instruction, sent to the model under the "#Prompt" heading.
    prompt: str
class EditSlideLayoutResponse(CreateSlideLayoutResponse):
    """Response of ``edit_slide_layout``; same shape as CreateSlideLayoutResponse."""

    pass
class EditSlideLayoutSectionRequest(BaseModel):
    """Request body of ``edit_slide_layout_section``."""

    # Current layout component code to modify.
    react_component: str
    # Part of the component the edit should be made around.
    section: str
    # Edit instruction, sent to the model under the "#Prompt" heading.
    prompt: str
class EditSlideLayoutSectionResponse(CreateSlideLayoutResponse):
    """Response of ``edit_slide_layout_section``; same shape as CreateSlideLayoutResponse."""

    pass
class SaveTemplateLayoutData(BaseModel):
    """One layout as sent when saving/updating a template; also returned by clone_slide_layout."""

    layout_id: str
    layout_name: str
    # Layout component source code, stored as text.
    layout_code: str
class SaveTemplateRequest(BaseModel):
    """Request body of ``save_template``."""

    # TemplateCreateInfoModel id; its fonts are attached to every saved layout.
    template_info_id: uuid.UUID
    name: str
    description: Optional[str] = None
    # Must be non-empty.
    layouts: List[SaveTemplateLayoutData]
class SaveTemplateResponse(BaseModel):
    """Template metadata returned by save_template, clone_template and update_template."""

    id: uuid.UUID
    name: str
    description: Optional[str] = None
    created_at: datetime
class CloneTemplateRequest(BaseModel):
    """Request body of ``clone_template``."""

    # Source template id, with or without the "custom-" prefix.
    id: str
    # Name of the new copy.
    name: str
    # Description of the copy; None keeps the source template's description.
    description: Optional[str] = None
class UpdateTemplateRequest(BaseModel):
    """Request body of ``update_template``."""

    # Template UUID (no "custom-" prefix).
    id: uuid.UUID
    # Replaces all existing layouts of the template; must be non-empty.
    layouts: List[SaveTemplateLayoutData]
class SaveSlideLayoutRequest(BaseModel):
    """Request body of ``save_slide_layout``."""

    template_id: uuid.UUID
    # Identifies the layout row to overwrite within the template.
    layout_id: str
    # New layout source code.
    layout_code: str
class CloneSlideLayoutRequest(BaseModel):
    """Request body of ``clone_slide_layout``."""

    # Template id, with or without the "custom-" prefix.
    template_id: str
    # Layout to duplicate.
    layout_id: str
    # Name of the copy; empty/None keeps the source layout's name.
    layout_name: Optional[str] = None
def _strip_code_fences(value: str) -> str:
    """Remove Markdown code fences from model output.

    Every triple-backtick fence is removed. The tags ``tsx``, ``typescript``
    and ``ts`` are removed wherever they follow a fence. Any other language
    tag (``jsx``, ``javascript``, ...) is removed when it is the only thing
    on the fence line. Leading/trailing whitespace is stripped at the end.

    Args:
        value: Raw model output, possibly wrapped in fenced code blocks.

    Returns:
        The text without fences, stripped of surrounding whitespace.
    """
    # Previously only tsx/typescript/ts were recognised, so "```jsx" left a
    # stray "jsx" line in the returned code. A generic tag is only consumed
    # when it ends the line, so "```code..." on one line still keeps "code...".
    return re.sub(
        r"```(?:tsx|typescript|ts|[A-Za-z0-9_+-]+(?=[ \t]*(?:\r?\n|$)))?",
        "",
        value,
    ).strip()
def _normalize_layout_code_for_create(code: str) -> str:
    """Clean up model-generated layout code before returning it to the client.

    Steps, in order:
      1. Strip Markdown code fences.
      2. Wrap the placeholder names ``image_url``, ``icon_url``,
         ``image_prompt`` and ``icon_query`` in double underscores
         (``image_url`` -> ``__image_url__``). Occurrences that are already
         exactly wrapped are kept as-is.
      3. Drop everything before the first ``import`` line and everything
         from the first ``export`` line onwards.
      4. Remove the remaining ``import``/``export`` statements.
      5. Append a random ``-NNNN`` suffix to each ``layoutId = "..."`` value
         that does not already end in a dash plus four digits.

    Args:
        code: Raw layout code produced by the model.

    Returns:
        The normalized code, stripped of surrounding whitespace.
    """
    normalized = _strip_code_fences(code)
    # Matching "__name__" first and rewriting it to itself keeps the step
    # idempotent: the plain str.replace turned "__image_url__" into
    # "____image_url____" whenever the model already emitted the wrapped form.
    for placeholder in ("image_url", "icon_url", "image_prompt", "icon_query"):
        normalized = re.sub(
            rf"__{placeholder}__|{placeholder}",
            f"__{placeholder}__",
            normalized,
        )
    # Keep the code starting at the first import statement.
    first_import_match = re.search(r"(?m)^\s*import\b", normalized)
    if first_import_match:
        normalized = normalized[first_import_match.start() :]
    # Cut everything from the first export statement onwards.
    # NOTE(review): this also drops any code after e.g. an early
    # "export const ..."; presumably the prompt only asks for a trailing
    # export — confirm against SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT.
    first_export_match = re.search(r"(?m)^\s*export\b", normalized)
    if first_export_match:
        normalized = normalized[: first_export_match.start()]
    # Remove import/export statements terminated by ";" (may span lines)...
    normalized = re.sub(
        r"(?ms)^\s*(?:import|export)\b.*?;(?:\r?\n|$)",
        "",
        normalized,
    )
    # ...then any remaining single-line import/export without a semicolon.
    normalized = re.sub(
        r"(?m)^\s*(?:import|export)\b.*(?:\r?\n|$)",
        "",
        normalized,
    )
    normalized = normalized.strip()
    # Give each layoutId a random 4-digit suffix unless it already has one.
    normalized = re.sub(
        r'(layoutId\s*=\s*["\'])([^"\']+)(["\'])',
        lambda match: (
            match.group(0)
            if re.search(r"-\d{4}$", match.group(2))
            else f"{match.group(1)}{match.group(2)}-{random.randint(1000, 9999)}{match.group(3)}"
        ),
        normalized,
    )
    return normalized
def _update_layout_id_in_code(code: str) -> tuple[str, str]:
    """Give the first ``layoutId = "..."`` value in *code* a fresh random suffix.

    A trailing ``-NNNN`` (four digits) on the current id is replaced by a new
    random four-digit suffix. Ids without such a suffix get one appended. The
    returned id always differs from the current one.

    Args:
        code: Layout source code containing a ``layoutId`` assignment.

    Returns:
        ``(new_code, new_id)``, where only the first layoutId value is changed.

    Raises:
        HTTPException: 400 if no ``layoutId`` assignment is found.
    """
    match = re.search(r'(layoutId\s*=\s*["\'])([^"\']+)(["\'])', code)
    if not match:
        raise HTTPException(status_code=400, detail="layoutId not found in layout code")
    current_id = match.group(2)
    base_id = re.sub(r"-\d{4}$", "", current_id)
    new_id = current_id
    # Re-roll in the rare case the random suffix equals the existing one.
    while new_id == current_id:
        new_id = f"{base_id}-{random.randint(1000, 9999)}"
    # Splice by position rather than re.sub with a "\\1...\\3" template.
    # An id starting with a digit (e.g. "1-intro") turned "\\1" into "\\11",
    # which raised "invalid group reference". Backslashes in the id were also
    # interpreted as escapes.
    new_code = code[: match.start(2)] + new_id + code[match.end(2) :]
    return new_code, new_id
async def _download_image_bytes(image_url: str) -> bytes:
    """Download *image_url* over HTTP and return the response body.

    Args:
        image_url: Absolute URL of the slide image.

    Returns:
        The raw response body.

    Raises:
        HTTPException: 400 if the server answers with a non-200 status, or if
            the request itself fails (aiohttp client error or timeout).
    """
    import asyncio

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to download slide image: {image_url}",
                    )
                return await response.read()
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        # Transport failures previously escaped as unhandled 500s, while a bad
        # status already produced this 400; report both the same way.
        raise HTTPException(
            status_code=400,
            detail=f"Failed to download slide image: {image_url}",
        ) from exc
async def _read_image_bytes_and_media_type(image_url: str) -> tuple[bytes, str]:
    """Load a slide image and infer its media type from the file extension.

    The URL is first resolved to a local file. If that file exists it is read
    from disk; otherwise the URL is downloaded with ``_download_image_bytes``.

    Args:
        image_url: App path or remote URL of the image.

    Returns:
        ``(image_bytes, media_type)``; unknown extensions fall back to
        ``"image/png"``.

    Raises:
        HTTPException: propagated from ``_download_image_bytes`` when the
            remote download fails.
    """
    from urllib.parse import urlparse

    actual_image_path = resolve_image_path_to_filesystem(image_url)
    if actual_image_path and os.path.isfile(actual_image_path):
        with open(actual_image_path, "rb") as image_file:
            image_bytes = image_file.read()
        file_extension = os.path.splitext(actual_image_path)[1].lower()
    else:
        image_bytes = await _download_image_bytes(image_url)
        # Use only the URL path: a query string or fragment ("a.jpg?sig=...")
        # would otherwise end up in the extension and force the PNG fallback.
        file_extension = os.path.splitext(urlparse(image_url).path)[1].lower()
    media_type_map = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    return image_bytes, media_type_map.get(file_extension, "image/png")
async def get_all_templates(
    include_defaults: bool = Query(
        default=True, description="Whether to include default templates"
    ),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """List the available templates.

    Built-in templates (optional) come first, followed by every custom
    template that has at least one layout (inner join on the layout table),
    each with its layout count.
    """
    layout_count = func.count(PresentationLayoutCodeModel.id).label("total_layouts")
    statement = (
        select(TemplateModel.id, TemplateModel.name, layout_count)
        .join(
            PresentationLayoutCodeModel,
            PresentationLayoutCodeModel.presentation == TemplateModel.id,
        )
        .group_by(TemplateModel.id, TemplateModel.name)
    )
    custom_rows = (await sql_session.execute(statement)).all()

    builtin_entries = (
        [TemplateDetail(id=name, name=name) for name in DEFAULT_TEMPLATES]
        if include_defaults
        else []
    )
    custom_entries = [
        TemplateDetail(
            id=f"custom-{row_id}",
            name=row_name,
            total_layouts=row_layout_count,
        )
        for row_id, row_name, row_layout_count in custom_rows
    ]
    return builtin_entries + custom_entries
async def get_layouts(
    template_id: str = Path(..., description="The id of the template"),
    session: AsyncSession = Depends(get_async_session),
):
    """Return every stored layout of a custom template.

    *template_id* may carry the ``custom-`` prefix used by get_all_templates.

    Raises:
        HTTPException: 400 for an empty or non-UUID id, 404 when no layouts
            are stored for the template.
    """
    if not (template_id and template_id.strip()):
        raise HTTPException(status_code=400, detail="Template ID cannot be empty")
    try:
        template_uuid = uuid.UUID(template_id.replace("custom-", ""))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc

    layout_rows = (
        await session.execute(
            select(PresentationLayoutCodeModel).where(
                PresentationLayoutCodeModel.presentation == template_uuid
            )
        )
    ).scalars().all()
    if not layout_rows:
        raise HTTPException(
            status_code=404,
            detail=f"No layouts found for template ID: {template_id}",
        )

    template_row = await session.get(TemplateModel, template_uuid)
    template_data = (
        TemplateData(
            id=template_uuid,
            init_id=None,
            name=template_row.name,
            description=template_row.description,
            created_at=template_row.created_at,
        )
        if template_row
        else None
    )

    layouts = [
        TemplateLayoutData(
            template=template_uuid,
            layout_id=row.layout_id,
            layout_name=row.layout_name,
            layout_code=row.layout_code,
            fonts=row.fonts,
        )
        for row in layout_rows
    ]
    return GetTemplateLayoutsResponse(
        layouts=layouts,
        template=template_data,
        fonts=layouts[0].fonts if layouts else None,
    )
async def get_template_by_id(
    id: str = Path(
        ...,
        description=f"The id of the template, must be one of {', '.join(DEFAULT_TEMPLATES)} or your custom template",
    ),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Resolve a template id to its layout definition.

    Custom ids (``custom-<uuid>``) must exist in the templates table; every
    id is then passed to ``get_layout_by_name``.

    Raises:
        HTTPException: 400 if a custom id is malformed or unknown.
    """
    not_found_detail = "Template not found. Please use a valid template."
    if id.startswith("custom-"):
        try:
            custom_uuid = uuid.UUID(id.replace("custom-", ""))
        except Exception as exc:
            raise HTTPException(status_code=400, detail=not_found_detail) from exc
        if not await sql_session.get(TemplateModel, custom_uuid):
            raise HTTPException(status_code=400, detail=not_found_detail)
    return await get_layout_by_name(id)
async def get_template_example(
    id: str = Path(
        ...,
        description=f"The id of the template, must be one of {', '.join(DEFAULT_TEMPLATES)} or your custom template",
    ),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Build example slide content for template *id*.

    Errors raised by ``get_template_by_id`` (malformed/unknown custom ids)
    propagate unchanged.
    """
    example_payload = build_template_example(
        id, await get_template_by_id(id=id, sql_session=sql_session)
    )
    return TemplateExample(**example_payload)
async def upload_fonts_and_slides_preview(
    pptx_file: UploadFile = File(..., description="PPTX file to preview"),
    font_files: Optional[List[UploadFile]] = File(
        default=None, description="Font files to upload"
    ),
    original_font_names: Optional[List[str]] = Form(default=None),
):
    """Upload a PPTX (plus optional font files) and render slide previews.

    Thin wrapper around ``upload_fonts_and_slides_preview_handler`` that caps
    the preview at 25 slides.
    """
    max_preview_slides = 25
    return await upload_fonts_and_slides_preview_handler(
        max_slides=max_preview_slides,
        original_font_names=original_font_names,
        font_files=font_files,
        pptx_file=pptx_file,
    )
async def init_create_template(
    request: CreateTemplateInitRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Start template creation from an uploaded PPTX and its slide previews.

    Converts the PPTX to per-slide HTML, keeps one HTML page per preview
    image, and stores fonts, URLs and HTML in a new TemplateCreateInfoModel.

    Returns:
        The id of the created TemplateCreateInfoModel row.

    Raises:
        HTTPException: 400 when no preview images are given, the PPTX file is
            missing, or the export yields fewer slides than images; 500 when
            the export yields no slides at all.
    """
    image_urls = request.slide_image_urls
    if not image_urls:
        raise HTTPException(
            status_code=400, detail="At least one slide image is required"
        )

    pptx_path = resolve_app_path_to_filesystem(request.pptx_url)
    if not (pptx_path and os.path.isfile(pptx_path)):
        raise HTTPException(status_code=400, detail="PPTX file not found")

    exported = await EXPORT_TASK_SERVICE.convert_pptx_to_html(
        pptx_path, get_fonts=False
    )
    exported_slides = exported.slides
    if not exported_slides:
        raise HTTPException(
            status_code=500,
            detail="PPTX-to-HTML export returned no slides",
        )

    expected_count, actual_count = len(image_urls), len(exported_slides)
    if actual_count < expected_count:
        raise HTTPException(
            status_code=400,
            detail=(
                "PPTX-to-HTML export returned fewer slides than the preview images. "
                f"Expected at least {expected_count}, got {actual_count}."
            ),
        )

    create_info = TemplateCreateInfoModel(
        fonts=request.fonts or {},
        pptx_url=request.pptx_url,
        slide_image_urls=image_urls,
        slide_htmls=exported_slides[:expected_count],
    )
    sql_session.add(create_info)
    await sql_session.commit()
    await sql_session.refresh(create_info)
    return create_info.id
async def create_slide_layout(
    request: CreateSlideLayoutRequest = Body(...),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Generate a layout component for one slide of a template draft.

    The slide's exported HTML, its preview image and the draft's font names
    are sent to the model. The result is normalized with
    ``_normalize_layout_code_for_create``.

    Raises:
        HTTPException: 400 if the draft does not exist or ``request.index``
            is outside the slides that have both HTML and a preview image.
    """
    template_info = await sql_session.get(TemplateCreateInfoModel, request.id)
    if not template_info:
        raise HTTPException(status_code=400, detail="Template not found")

    slide_htmls = template_info.slide_htmls or []
    slide_image_urls = template_info.slide_image_urls or []
    # The index is used on both lists, so validate it against the shorter
    # one. Checking slide_htmls alone let a mismatched row raise IndexError (500).
    total_slides = min(len(slide_htmls), len(slide_image_urls))
    if not 0 <= request.index < total_slides:
        raise HTTPException(status_code=400, detail="Invalid slide index")

    slide_html = slide_htmls[request.index]
    slide_image_url = slide_image_urls[request.index]
    image_bytes, media_type = await _read_image_bytes_and_media_type(slide_image_url)

    fonts_text = ""
    if template_info.fonts:
        # Font names are listed with spaces replaced by underscores.
        font_names = [font.replace(" ", "_") for font in template_info.fonts.keys()]
        fonts_text = "#PROVIDED FONTS\n- " + "\n- ".join(font_names)
    user_text = f"{fonts_text}\n\n#SLIDE HTML REFERENCE\n{slide_html}"

    react_component = await generate_slide_layout_code(
        system_prompt=SLIDE_LAYOUT_CREATION_SYSTEM_PROMPT,
        user_text=user_text,
        image_bytes=image_bytes,
        media_type=media_type,
    )
    return CreateSlideLayoutResponse(
        react_component=_normalize_layout_code_for_create(react_component)
    )
async def edit_slide_layout(
    request: EditSlideLayoutRequest,
):
    """Apply a free-form edit prompt to a layout component via the model.

    Returns:
        EditSlideLayoutResponse with the edited code, code fences removed.
    """
    prompt_sections = [
        f"#Prompt\n{request.prompt}",
        f"#TSX code\n{request.react_component}",
    ]
    edited = await edit_slide_layout_code(
        system_prompt=SLIDE_LAYOUT_EDIT_SYSTEM_PROMPT,
        user_text="\n\n".join(prompt_sections),
    )
    return EditSlideLayoutResponse(react_component=_strip_code_fences(edited))
async def edit_slide_layout_section(
    request: EditSlideLayoutSectionRequest,
):
    """Edit a layout component around a specific section via the model.

    Returns:
        EditSlideLayoutSectionResponse with the edited code, fences removed.
    """
    prompt_sections = (
        f"#Prompt\n{request.prompt}",
        f"#Section to make changes around\n{request.section}",
        f"#TSX code\n{request.react_component}",
    )
    edited = await edit_slide_layout_code(
        system_prompt=SLIDE_LAYOUT_EDIT_SECTION_SYSTEM_PROMPT,
        user_text="\n\n".join(prompt_sections),
    )
    return EditSlideLayoutSectionResponse(react_component=_strip_code_fences(edited))
async def save_template(
    request: SaveTemplateRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Persist a new custom template built from a creation draft.

    Every layout row gets the fonts stored on the TemplateCreateInfoModel
    referenced by ``request.template_info_id``.

    Raises:
        HTTPException: 400 if no layouts are supplied or the draft is unknown.
    """
    if not request.layouts:
        raise HTTPException(status_code=400, detail="Layouts are required")
    create_info = await sql_session.get(TemplateCreateInfoModel, request.template_info_id)
    if not create_info:
        raise HTTPException(status_code=400, detail="Template info not found")

    new_template = TemplateModel(
        id=uuid.uuid4(),
        name=request.name,
        description=request.description,
    )
    sql_session.add(new_template)
    draft_fonts = create_info.fonts
    sql_session.add_all(
        [
            PresentationLayoutCodeModel(
                presentation=new_template.id,
                layout_id=item.layout_id,
                layout_name=item.layout_name,
                layout_code=item.layout_code,
                fonts=draft_fonts,
            )
            for item in request.layouts
        ]
    )
    await sql_session.commit()
    await sql_session.refresh(new_template)
    return SaveTemplateResponse(
        id=new_template.id,
        name=new_template.name,
        description=new_template.description,
        created_at=new_template.created_at,
    )
async def clone_template(
    request: CloneTemplateRequest = Body(...),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Copy a custom template and all of its layouts under a new id.

    The copy gets ``request.name``, and ``request.description`` unless that
    is None, in which case the source description is kept.

    Raises:
        HTTPException: 400 for an empty/invalid id, an unknown template, or
            a template without layouts.
    """
    if not (request.id and request.id.strip()):
        raise HTTPException(status_code=400, detail="Template ID cannot be empty")
    try:
        source_uuid = uuid.UUID(request.id.replace("custom-", ""))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc

    source_template = await sql_session.get(TemplateModel, source_uuid)
    if not source_template:
        raise HTTPException(
            status_code=400,
            detail="Template not found. Please use a valid template.",
        )

    source_layouts = (
        await sql_session.execute(
            select(PresentationLayoutCodeModel).where(
                PresentationLayoutCodeModel.presentation == source_uuid
            )
        )
    ).scalars().all()
    if not source_layouts:
        raise HTTPException(status_code=400, detail="No layouts found for template")

    description = request.description
    if description is None:
        description = source_template.description
    cloned_template = TemplateModel(
        id=uuid.uuid4(),
        name=request.name,
        description=description,
    )
    sql_session.add(cloned_template)
    sql_session.add_all(
        [
            PresentationLayoutCodeModel(
                presentation=cloned_template.id,
                layout_id=source.layout_id,
                layout_name=source.layout_name,
                layout_code=source.layout_code,
                fonts=source.fonts,
            )
            for source in source_layouts
        ]
    )
    await sql_session.commit()
    await sql_session.refresh(cloned_template)
    return SaveTemplateResponse(
        id=cloned_template.id,
        name=cloned_template.name,
        description=cloned_template.description,
        created_at=cloned_template.created_at,
    )
async def update_template(
    request: UpdateTemplateRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Replace all layouts of an existing custom template.

    The template's current layout rows are deleted and ``request.layouts``
    are inserted in their place. The new rows reuse the fonts of the first
    previous layout row found (None if there was none).

    Returns:
        SaveTemplateResponse with the template's metadata.

    Raises:
        HTTPException: 400 if no layouts are supplied or the template is unknown.
    """
    if not request.layouts:
        raise HTTPException(status_code=400, detail="Layouts are required")
    template = await sql_session.get(TemplateModel, request.id)
    if not template:
        raise HTTPException(status_code=400, detail="Template not found")
    # Read the fonts before the existing rows are deleted.
    existing_layout = await sql_session.scalar(
        select(PresentationLayoutCodeModel).where(
            PresentationLayoutCodeModel.presentation == request.id
        )
    )
    fonts = existing_layout.fonts if existing_layout else None
    await sql_session.execute(
        delete(PresentationLayoutCodeModel).where(
            PresentationLayoutCodeModel.presentation == request.id
        )
    )
    sql_session.add_all(
        [
            PresentationLayoutCodeModel(
                presentation=template.id,
                layout_id=layout.layout_id,
                layout_name=layout.layout_name,
                layout_code=layout.layout_code,
                fonts=fonts,
            )
            for layout in request.layouts
        ]
    )
    await sql_session.commit()
    # Reload after commit, as the other handlers do. If the session expires
    # objects on commit (SQLAlchemy's default), reading attributes below would
    # otherwise trigger an implicit lazy load, which AsyncSession cannot do.
    await sql_session.refresh(template)
    return SaveTemplateResponse(
        id=template.id,
        name=template.name,
        description=template.description,
        created_at=template.created_at,
    )
async def save_slide_layout(
    request: SaveSlideLayoutRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Overwrite the code of a single layout of a custom template.

    Returns nothing on success.

    Raises:
        HTTPException: 400 if the template or the layout does not exist.
    """
    if not await sql_session.get(TemplateModel, request.template_id):
        raise HTTPException(status_code=400, detail="Template not found")

    lookup = select(PresentationLayoutCodeModel).where(
        PresentationLayoutCodeModel.presentation == request.template_id,
        PresentationLayoutCodeModel.layout_id == request.layout_id,
    )
    stored_layout = await sql_session.scalar(lookup)
    if not stored_layout:
        raise HTTPException(status_code=400, detail="Layout not found")

    stored_layout.layout_code = request.layout_code
    sql_session.add(stored_layout)
    await sql_session.commit()
async def clone_slide_layout(
    request: CloneSlideLayoutRequest = Body(...),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Duplicate one layout inside the same custom template.

    The copy gets a fresh layoutId, both in its code and in its ``layout_id``
    column, via ``_update_layout_id_in_code``. It keeps the source fonts and
    is named ``request.layout_name``, or the source name if that is empty.

    Returns:
        SaveTemplateLayoutData describing the new layout.

    Raises:
        HTTPException: 400 for an empty/invalid template id, an unknown
            template or layout, or layout code without a layoutId.
    """
    raw_template_id = request.template_id
    if not (raw_template_id and raw_template_id.strip()):
        raise HTTPException(status_code=400, detail="Template ID cannot be empty")
    try:
        owner_uuid = uuid.UUID(raw_template_id.replace("custom-", ""))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid custom template ID") from exc

    if not await sql_session.get(TemplateModel, owner_uuid):
        raise HTTPException(status_code=400, detail="Template not found")

    source_layout = await sql_session.scalar(
        select(PresentationLayoutCodeModel).where(
            PresentationLayoutCodeModel.presentation == owner_uuid,
            PresentationLayoutCodeModel.layout_id == request.layout_id,
        )
    )
    if not source_layout:
        raise HTTPException(status_code=400, detail="Layout not found")

    cloned_code, cloned_id = _update_layout_id_in_code(source_layout.layout_code)
    cloned_layout = PresentationLayoutCodeModel(
        presentation=owner_uuid,
        layout_id=cloned_id,
        layout_name=request.layout_name or source_layout.layout_name,
        layout_code=cloned_code,
        fonts=source_layout.fonts,
    )
    sql_session.add(cloned_layout)
    await sql_session.commit()
    await sql_session.refresh(cloned_layout)
    return SaveTemplateLayoutData(
        layout_id=cloned_layout.layout_id,
        layout_name=cloned_layout.layout_name,
        layout_code=cloned_layout.layout_code,
    )