"""HTTP endpoints for creating, preparing, streaming, updating and exporting presentations."""

import asyncio
import json
import os
import random
from typing import Annotated, List, Literal, Optional
import uuid

from annotated_types import Len
from fastapi import APIRouter, Body, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from sqlalchemy import delete
from sqlmodel import select

from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
from models.presentation_and_path import PresentationPathAndEditPath
from models.presentation_from_template import GetPresentationUsingTemplateRequest
from models.presentation_outline_model import (
    PresentationOutlineModel,
    SlideOutlineModel,
)
from models.pptx_models import PptxPresentationModel
from models.presentation_layout import PresentationLayoutModel
from models.presentation_structure_model import PresentationStructureModel
from models.presentation_with_slides import PresentationWithSlides
from services.get_layout_by_name import get_layout_by_name
from services.icon_finder_service import IconFinderService
from services.image_generation_service import ImageGenerationService
from utils.dict_utils import deep_update
from utils.export_utils import export_presentation
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
from models.sql.slide import SlideModel
from models.sse_response import SSECompleteResponse, SSEResponse
from services import TEMP_FILE_SERVICE
from services.database import get_sql_session
from services.documents_loader import DocumentsLoader
from models.sql.presentation import PresentationModel
from services.pptx_presentation_creator import PptxPresentationCreator
from utils.asset_directory_utils import get_exports_directory, get_images_directory
from utils.llm_calls.generate_document_summary import generate_document_summary
from utils.llm_calls.generate_presentation_structure import (
    generate_presentation_structure,
)
from utils.llm_calls.generate_slide_content import (
    get_slide_content_from_type_and_outline,
)
from utils.process_slides import process_slide_and_fetch_assets
from utils.randomizers import get_random_uuid
from utils.validators import validate_files

PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])


def _fit_structure_to_outlines(
    presentation_structure: PresentationStructureModel,
    total_outlines: int,
    total_slide_layouts: int,
) -> None:
    """Make ``presentation_structure.slides`` usable for ``total_outlines`` outlines.

    The slide list is truncated to ``total_outlines`` entries, any entry that
    is not a valid index into the layout's slides is replaced by a random
    valid index, and the list is padded with random valid indices when it is
    shorter than the number of outlines. The structure is mutated in place.

    Raises:
        HTTPException: 400 when the layout has no slides to choose from.
    """
    if total_slide_layouts < 1:
        # random.randint(0, -1) would raise ValueError; report it as a
        # client error instead of an unhandled exception.
        raise HTTPException(
            status_code=400, detail="Layout must contain at least one slide"
        )

    fitted = list(presentation_structure.slides[:total_outlines])
    for position, layout_index in enumerate(fitted):
        # Negative values are rejected as well: they would otherwise be
        # silently interpreted as "count from the end" when indexing.
        if not 0 <= layout_index < total_slide_layouts:
            fitted[position] = random.randint(0, total_slide_layouts - 1)

    # Pad when the structure has fewer entries than there are outlines.
    while len(fitted) < total_outlines:
        fitted.append(random.randint(0, total_slide_layouts - 1))

    presentation_structure.slides = fitted


@PRESENTATION_ROUTER.get("", response_model=PresentationWithSlides)
def get_presentation(id: str):
    """Return the presentation with the given id together with its slides.

    Slides are ordered by ``SlideModel.index``.

    Raises:
        HTTPException: 404 when no presentation has this id.
    """
    with get_sql_session() as sql_session:
        presentation = sql_session.get(PresentationModel, id)
        if not presentation:
            raise HTTPException(404, "Presentation not found")

        # Materialise the rows while the session is still open, matching the
        # other endpoints in this module.
        slides = sql_session.exec(
            select(SlideModel)
            .where(SlideModel.presentation == id)
            .order_by(SlideModel.index)
        ).all()

        return PresentationWithSlides(
            **presentation.model_dump(),
            slides=slides,
        )


@PRESENTATION_ROUTER.delete("", status_code=204)
def delete_presentation(id: str):
    """Delete the presentation with the given id and every slide that references it.

    Raises:
        HTTPException: 404 when no presentation has this id.
    """
    with get_sql_session() as sql_session:
        presentation = sql_session.get(PresentationModel, id)
        if not presentation:
            raise HTTPException(404, "Presentation not found")

        slides = sql_session.exec(
            select(SlideModel).where(SlideModel.presentation == id)
        ).all()
        for slide in slides:
            sql_session.delete(slide)

        sql_session.delete(presentation)
        sql_session.commit()


@PRESENTATION_ROUTER.get("/all", response_model=List[PresentationWithSlides])
def get_all_presentations():
    """Return every presentation that has a slide at index 0.

    Each returned item carries only the slide(s) whose index is 0;
    presentations without such a slide are skipped.
    """
    with get_sql_session() as sql_session:
        presentations_with_slides = []
        # Fetch all rows up front so the per-presentation queries below do not
        # run while this result is still being iterated.
        presentations = sql_session.exec(select(PresentationModel)).all()
        for presentation in presentations:
            slides = sql_session.exec(
                select(SlideModel)
                .where(SlideModel.presentation == presentation.id)
                .where(SlideModel.index == 0)
            ).all()
            if not slides:
                continue
            presentations_with_slides.append(
                PresentationWithSlides(
                    **presentation.model_dump(),
                    slides=slides,
                )
            )
        return presentations_with_slides


@PRESENTATION_ROUTER.post("/create", response_model=PresentationModel)
async def create_presentation(
    prompt: Annotated[str, Body()],
    n_slides: Annotated[int, Body()],
    language: Annotated[str, Body()],
    file_paths: Annotated[Optional[List[str]], Body()] = None,
):
    """Create and persist a new presentation record.

    When ``file_paths`` is given, the documents are loaded into a temporary
    directory created for the new presentation id and a summary is generated
    from them; otherwise the summary is left as ``None``.

    Returns:
        The stored ``PresentationModel``.
    """
    presentation_id = str(uuid.uuid4())

    summary = None
    if file_paths:
        temp_dir = TEMP_FILE_SERVICE.create_temp_dir(presentation_id)
        documents_loader = DocumentsLoader(file_paths=file_paths)
        await documents_loader.load_documents(temp_dir)
        summary = await generate_document_summary(documents_loader.documents)

    presentation = PresentationModel(
        id=presentation_id,
        prompt=prompt,
        n_slides=n_slides,
        language=language,
        summary=summary,
    )

    with get_sql_session() as sql_session:
        sql_session.add(presentation)
        sql_session.commit()
        sql_session.refresh(presentation)

    return presentation


@PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel)
async def prepare_presentation(
    presentation_id: Annotated[str, Body()],
    outlines: Annotated[List[SlideOutlineModel], Body()],
    layout: Annotated[PresentationLayoutModel, Body()],
    title: Annotated[Optional[str], Body()] = None,
):
    """Attach outlines, layout and a slide structure to an existing presentation.

    For an ordered layout the structure is taken from the layout itself;
    otherwise it is generated from the presentation's outline and the layout.
    The structure is then fitted so it holds exactly one valid layout index
    per outline.

    Raises:
        HTTPException: 400 when ``outlines`` is empty or the layout has no
            slides; 404 when the presentation does not exist.
    """
    if not outlines:
        raise HTTPException(status_code=400, detail="Outlines are required")

    with get_sql_session() as sql_session:
        presentation = sql_session.get(PresentationModel, presentation_id)

    if not presentation:
        raise HTTPException(status_code=404, detail="Presentation not found")

    total_slide_layouts = len(layout.slides)
    total_outlines = len(outlines)

    if layout.ordered:
        presentation_structure = layout.to_presentation_structure()
    else:
        # NOTE(review): the structure is generated from the outline already
        # stored on the presentation, not from the `outlines` argument that is
        # saved below - confirm this is intended.
        presentation_structure: PresentationStructureModel = (
            await generate_presentation_structure(
                presentation_outline=presentation.get_presentation_outline(),
                presentation_layout=layout,
            )
        )

    _fit_structure_to_outlines(
        presentation_structure, total_outlines, total_slide_layouts
    )

    with get_sql_session() as sql_session:
        sql_session.add(presentation)
        presentation.outlines = [each.model_dump() for each in outlines]
        presentation.title = title or presentation.title
        presentation.set_layout(layout)
        presentation.set_structure(presentation_structure)
        sql_session.commit()
        sql_session.refresh(presentation)

    return presentation


@PRESENTATION_ROUTER.get("/stream", response_model=PresentationWithSlides)
async def stream_presentation(presentation_id: str):
    """Generate the slides of a prepared presentation and stream them as SSE.

    Every slide is sent as a ``chunk`` event as soon as its content has been
    generated. The asset-generation awaitables collected per slide are awaited
    together after the last chunk; the presentation, slides and generated
    assets are then persisted and a final complete event carries the whole
    presentation.

    Raises:
        HTTPException: 404 when the presentation does not exist; 400 when it
            has no structure or no outlines.
    """
    with get_sql_session() as sql_session:
        presentation = sql_session.get(PresentationModel, presentation_id)
        if not presentation:
            raise HTTPException(status_code=404, detail="Presentation not found")
        if not presentation.structure:
            raise HTTPException(
                status_code=400,
                detail="Presentation not prepared for stream",
            )
        if not presentation.outlines:
            raise HTTPException(
                status_code=400,
                detail="Outlines can not be empty",
            )

    image_generation_service = ImageGenerationService(get_images_directory())
    icon_finder_service = IconFinderService()

    async def inner():
        structure = presentation.get_structure()
        layout = presentation.get_layout()
        outline = presentation.get_presentation_outline()

        # These tasks will be gathered and awaited after all slides are generated
        async_assets_generation_tasks = []

        slides: List[SlideModel] = []
        yield SSEResponse(
            event="response",
            data=json.dumps({"type": "chunk", "chunk": '{ "slides": [ '}),
        ).to_string()

        for i, slide_layout_index in enumerate(structure.slides):
            slide_layout = layout.slides[slide_layout_index]

            slide_content = await get_slide_content_from_type_and_outline(
                slide_layout, outline.slides[i]
            )

            slide = SlideModel(
                presentation=presentation_id,
                layout_group=layout.name,
                layout=slide_layout.id,
                index=i,
                content=slide_content,
            )
            slides.append(slide)

            # This will mutate slide
            async_assets_generation_tasks.append(
                process_slide_and_fetch_assets(
                    image_generation_service, icon_finder_service, slide
                )
            )

            # Give control to the event loop
            await asyncio.sleep(0)

            yield SSEResponse(
                event="response",
                data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
            ).to_string()

        yield SSEResponse(
            event="response",
            data=json.dumps({"type": "chunk", "chunk": " ] }"}),
        ).to_string()

        generated_assets_lists = await asyncio.gather(*async_assets_generation_tasks)
        generated_assets = []
        for assets_list in generated_assets_lists:
            generated_assets.extend(assets_list)

        with get_sql_session() as sql_session:
            sql_session.add(presentation)
            sql_session.add_all(slides)
            sql_session.add_all(generated_assets)
            sql_session.commit()
            sql_session.refresh(presentation)
            for each_slide in slides:
                sql_session.refresh(each_slide)

        response = PresentationWithSlides(
            **presentation.model_dump(),
            slides=slides,
        )

        yield SSECompleteResponse(
            key="presentation",
            value=response.model_dump(mode="json"),
        ).to_string()

    return StreamingResponse(inner(), media_type="text/event-stream")


@PRESENTATION_ROUTER.put("/update", response_model=PresentationWithSlides)
def update_presentation(
    presentation_with_slides: Annotated[PresentationWithSlides, Body()],
):
    """Overwrite a stored presentation and replace all of its slides.

    The existing slide rows of the presentation are deleted and the slides
    from the request body are inserted in their place.

    Raises:
        HTTPException: 404 when the presentation does not exist.
    """
    updated_presentation = presentation_with_slides.to_presentation_model()
    updated_slides = presentation_with_slides.slides

    with get_sql_session() as sql_session:
        presentation = sql_session.get(PresentationModel, updated_presentation.id)
        if not presentation:
            raise HTTPException(status_code=404, detail="Presentation not found")

        presentation.sqlmodel_update(updated_presentation)

        sql_session.exec(
            delete(SlideModel).where(SlideModel.presentation == updated_presentation.id)
        )
        sql_session.add_all(updated_slides)
        sql_session.commit()

        sql_session.refresh(presentation)
        for slide in updated_slides:
            sql_session.refresh(slide)

        return PresentationWithSlides(
            **presentation.model_dump(),
            slides=updated_slides,
        )


@PRESENTATION_ROUTER.post("/export/pptx", response_model=str)
async def create_pptx(pptx_model: Annotated[PptxPresentationModel, Body()]):
    """Build a PPTX file from the given model and return the path it was saved to.

    The file is written to the exports directory and named after
    ``pptx_model.name``, or after a random UUID when the name is empty.
    """
    temp_dir = TEMP_FILE_SERVICE.create_temp_dir()

    pptx_creator = PptxPresentationCreator(pptx_model, temp_dir)
    await pptx_creator.create_ppt()

    export_directory = get_exports_directory()
    # NOTE(review): `pptx_model.name` comes from the request body and is used
    # verbatim in the output path - confirm it is validated upstream.
    pptx_path = os.path.join(
        export_directory, f"{pptx_model.name or get_random_uuid()}.pptx"
    )
    pptx_creator.save(pptx_path)

    return pptx_path


@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath)
async def generate_presentation_api(
    prompt: Annotated[str, Body()],
    n_slides: Annotated[int, Body()] = 8,
    language: Annotated[str, Body()] = "English",
    layout: Annotated[str, Body()] = "general",
    files: Annotated[Optional[List[UploadFile]], File()] = None,
    export_as: Annotated[Literal["pptx", "pdf"], Body()] = "pptx",
):
    """Generate, persist and export a complete presentation in one request.

    Steps: save the uploads, summarise them, generate outlines, resolve the
    layout by name, build the slide structure, generate slide content and
    assets, store everything, and export in the requested format.

    Returns:
        The export result extended with an ``edit_path`` for the new
        presentation.

    Raises:
        HTTPException: 400 when the resolved layout has no slides.
    """
    validate_files(files, True, True, 50, UPLOAD_ACCEPTED_FILE_TYPES)

    presentation_id = get_random_uuid()

    # 1. Save uploaded files
    file_paths = []
    if files:
        temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
        for upload in files:
            # The client controls the filename: keep only its final path
            # component so it cannot escape temp_dir, and fall back to a
            # random name when nothing usable remains.
            raw_name = (upload.filename or "").replace("\\", "/")
            safe_name = os.path.basename(raw_name) or get_random_uuid()
            file_path = os.path.join(temp_dir, safe_name)
            with open(file_path, "wb") as f:
                f.write(await upload.read())
            file_paths.append(file_path)

    # 2. Create Presentation Summary (if documents are provided)
    summary = None
    if file_paths:
        temp_dir = TEMP_FILE_SERVICE.create_temp_dir(presentation_id)
        documents_loader = DocumentsLoader(file_paths=file_paths)
        await documents_loader.load_documents(temp_dir)
        summary = await generate_document_summary(documents_loader.documents)

    # 3. Generate Outlines
    outline_chunks = []
    async for chunk in generate_ppt_outline(
        prompt,
        n_slides,
        language,
        summary,
    ):
        outline_chunks.append(chunk)
    presentation_content_text = "".join(outline_chunks)

    presentation_content_json = json.loads(presentation_content_text)
    presentation_content = PresentationOutlineModel(**presentation_content_json)

    outlines = presentation_content.slides[:n_slides]
    total_outlines = len(outlines)

    print("-" * 40)
    print(f"Generated {total_outlines} outlines for the presentation")

    # 4. Parse Layouts
    layout = await get_layout_by_name(layout)
    total_slide_layouts = len(layout.slides)

    # 5. Generate Structure
    if layout.ordered:
        presentation_structure = layout.to_presentation_structure()
    else:
        presentation_structure: PresentationStructureModel = (
            await generate_presentation_structure(
                presentation_outline=PresentationOutlineModel(
                    title=presentation_content.title,
                    slides=outlines,
                    notes=presentation_content.notes,
                ),
                presentation_layout=layout,
            )
        )

    _fit_structure_to_outlines(
        presentation_structure, total_outlines, total_slide_layouts
    )

    # 6. Create PresentationModel
    presentation = PresentationModel(
        id=presentation_id,
        prompt=prompt,
        n_slides=n_slides,
        language=language,
        title=presentation_content.title,
        summary=summary,
        outlines=[each.model_dump() for each in outlines],
        notes=presentation_content.notes,
        layout=layout.model_dump(),
        structure=presentation_structure.model_dump(),
    )

    image_generation_service = ImageGenerationService(get_images_directory())
    icon_finder_service = IconFinderService()
    async_asset_generation_tasks = []

    # 7. Generate slide content and save slides
    slides: List[SlideModel] = []
    slide_contents: List[dict] = []
    for i, slide_layout_index in enumerate(presentation_structure.slides):
        slide_layout = layout.slides[slide_layout_index]
        print(f"Generating content for slide {i} with layout {slide_layout.id}")
        slide_content = await get_slide_content_from_type_and_outline(
            slide_layout, outlines[i]
        )
        slide = SlideModel(
            presentation=presentation_id,
            layout_group=layout.name,
            layout=slide_layout.id,
            index=i,
            content=slide_content,
        )
        async_asset_generation_tasks.append(
            process_slide_and_fetch_assets(
                image_generation_service, icon_finder_service, slide
            )
        )
        slides.append(slide)
        slide_contents.append(slide_content)

    generated_assets_lists = await asyncio.gather(*async_asset_generation_tasks)
    generated_assets = []
    for assets_list in generated_assets_lists:
        generated_assets.extend(assets_list)

    # 8. Save PresentationModel and Slides
    with get_sql_session() as sql_session:
        sql_session.add(presentation)
        sql_session.add_all(slides)
        sql_session.add_all(generated_assets)
        sql_session.commit()

    # 9. Export
    presentation_and_path = await export_presentation(
        presentation_id, presentation_content.title, export_as
    )

    return PresentationPathAndEditPath(
        **presentation_and_path.model_dump(),
        edit_path=f"/presentation?id={presentation_id}",
    )


@PRESENTATION_ROUTER.post("/from-template", response_model=PresentationPathAndEditPath)
async def from_template(
    data: Annotated[GetPresentationUsingTemplateRequest, Body()],
):
    """Clone an existing presentation, apply per-slide content updates and export it.

    For each slide of the source presentation, the first entry in
    ``data.data`` with the same index (if any) is deep-merged into the slide's
    content before the slide is copied into the new presentation.

    Raises:
        HTTPException: 404 when the source presentation does not exist.
    """
    with get_sql_session() as sql_session:
        presentation = sql_session.get(PresentationModel, data.presentation_id)
        if not presentation:
            raise HTTPException(status_code=404, detail="Presentation not found")
        slides = sql_session.exec(
            select(SlideModel).where(SlideModel.presentation == data.presentation_id)
        ).all()

    new_presentation = presentation.get_new_presentation()
    new_slides = []
    for each_slide in slides:
        updated_content = None
        # First matching update for this slide index, or None when absent.
        new_slide_data = next(
            (item for item in data.data if item.index == each_slide.index),
            None,
        )
        if new_slide_data:
            updated_content = deep_update(each_slide.content, new_slide_data.content)
            print(f"Updated content for slide {each_slide.index}: {updated_content}")
        new_slides.append(
            each_slide.get_new_slide(new_presentation.id, updated_content)
        )

    with get_sql_session() as sql_session:
        sql_session.add(new_presentation)
        sql_session.add_all(new_slides)
        sql_session.commit()
        sql_session.refresh(new_presentation)

    presentation_and_path = await export_presentation(
        new_presentation.id, new_presentation.title, data.export_as
    )

    return PresentationPathAndEditPath(
        **presentation_and_path.model_dump(),
        edit_path=f"/presentation?id={new_presentation.id}",
    )