1029 lines
37 KiB
Python
1029 lines
37 KiB
Python
import asyncio
|
|
from datetime import datetime
|
|
import json
|
|
import os
|
|
import random
|
|
import traceback
|
|
from typing import Annotated, List, Literal, Optional, Tuple
|
|
import dirtyjson
|
|
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Path
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy import delete
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlmodel import select
|
|
from constants.presentation import DEFAULT_TEMPLATES, MAX_NUMBER_OF_SLIDES
|
|
from enums.webhook_event import WebhookEvent
|
|
from models.api_error_model import APIErrorModel
|
|
from models.generate_presentation_request import GeneratePresentationRequest
|
|
from models.presentation_and_path import PresentationPathAndEditPath
|
|
from models.presentation_from_template import EditPresentationRequest
|
|
from models.presentation_outline_model import (
|
|
PresentationOutlineModel,
|
|
SlideOutlineModel,
|
|
)
|
|
from enums.tone import Tone
|
|
from enums.verbosity import Verbosity
|
|
from models.presentation_structure_model import PresentationStructureModel
|
|
from models.presentation_with_slides import (
|
|
PresentationWithSlides,
|
|
)
|
|
from models.sql.template import TemplateModel
|
|
from services.documents_loader import DocumentsLoader
|
|
from services.webhook_service import WebhookService
|
|
from services.image_generation_service import ImageGenerationService
|
|
from utils.dict_utils import deep_update
|
|
from utils.export_utils import export_presentation
|
|
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
|
from models.sql.slide import SlideModel
|
|
from models.sql.presentation_layout_code import PresentationLayoutCodeModel
|
|
from models.sse_response import SSECompleteResponse, SSEErrorResponse, SSEResponse
|
|
|
|
from services.database import get_async_session
|
|
from services.concurrent_service import CONCURRENT_SERVICE
|
|
from models.sql.presentation import PresentationModel
|
|
from models.sql.async_presentation_generation_status import (
|
|
AsyncPresentationGenerationTaskModel,
|
|
)
|
|
from utils.asset_directory_utils import get_images_directory
|
|
from utils.llm_calls.generate_presentation_structure import (
|
|
generate_presentation_structure,
|
|
)
|
|
from utils.llm_calls.generate_slide_content import (
|
|
get_slide_content_from_type_and_outline,
|
|
)
|
|
from utils.ppt_utils import (
|
|
select_toc_or_list_slide_layout_index,
|
|
)
|
|
from utils.outline_utils import (
|
|
get_images_for_slides_from_outline,
|
|
get_no_of_outlines_to_generate_for_n_slides,
|
|
get_no_of_toc_required_for_n_outlines,
|
|
get_presentation_outline_model_with_toc,
|
|
get_presentation_title_from_presentation_outline,
|
|
)
|
|
from utils.process_slides import (
|
|
process_slide_add_placeholder_assets,
|
|
process_slide_and_fetch_assets,
|
|
)
|
|
from templates.get_layout_by_name import get_layout_by_name
|
|
from templates.presentation_layout import PresentationLayoutModel
|
|
import uuid
|
|
|
|
|
|
# Router that every endpoint in this module registers on; all of its routes
# are served under the "/presentation" prefix and tagged "Presentation".
PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])
|
|
|
|
|
|
def _extract_custom_template_id(layout_name: Optional[str]) -> Optional[uuid.UUID]:
    """Return the template UUID embedded in a ``custom-<uuid>`` layout name.

    Args:
        layout_name: Layout (or layout group) name; may be ``None`` or empty.

    Returns:
        The parsed UUID, or ``None`` when the name is missing, does not start
        with ``custom-``, or what remains is not a valid UUID string.
    """
    if not layout_name or not layout_name.startswith("custom-"):
        return None
    try:
        return uuid.UUID(layout_name.replace("custom-", ""))
    except ValueError:
        # uuid.UUID signals a malformed hex string with ValueError; anything
        # else would be a programming error and should not be hidden.
        return None
|
|
|
|
|
|
async def _resolve_presentation_fonts(
    presentation: PresentationModel,
    slides: List[SlideModel],
    sql_session: AsyncSession,
):
    """Find the fonts stored for the first custom template a presentation uses.

    Candidate template ids are taken from the presentation's layout name (only
    when the layout is stored as a dict) followed by every slide's
    ``layout_group``, keeping first-seen order and dropping duplicates. The
    candidates are queried one by one and the first non-``None`` ``fonts``
    value is returned.

    Returns:
        The stored fonts value, or ``None`` when no candidate provides one.
    """
    names_to_inspect = []
    if isinstance(presentation.layout, dict):
        names_to_inspect.append(presentation.layout.get("name"))
    names_to_inspect.extend(slide.layout_group for slide in slides)

    parsed_ids = [_extract_custom_template_id(name) for name in names_to_inspect]
    # dict.fromkeys de-duplicates while preserving insertion order.
    ordered_template_ids = list(dict.fromkeys(tid for tid in parsed_ids if tid))

    for candidate_id in ordered_template_ids:
        fonts_query = select(PresentationLayoutCodeModel.fonts).where(
            PresentationLayoutCodeModel.presentation == candidate_id
        )
        query_result = await sql_session.execute(fonts_query)
        for fonts_value in query_result.scalars().all():
            if fonts_value is not None:
                return fonts_value

    return None
|
|
|
|
|
|
def _insert_toc_layouts(
    structure: PresentationStructureModel,
    n_toc_slides: int,
    include_title_slide: bool,
    toc_slide_layout_index: int,
):
    """Insert table-of-contents layout indices into ``structure.slides`` in place.

    Args:
        structure: Structure whose ``slides`` list of layout indices is mutated.
        n_toc_slides: How many table-of-contents entries to insert.
        include_title_slide: When true the entries go after the first slide,
            otherwise at the very beginning.
        toc_slide_layout_index: Layout index to insert; ``-1`` means no
            suitable layout exists and nothing is inserted.
    """
    nothing_to_insert = n_toc_slides <= 0 or toc_slide_layout_index == -1
    if nothing_to_insert:
        return

    first_toc_position = 1 if include_title_slide else 0
    # An empty slice assignment inserts the whole run at once.
    structure.slides[first_toc_position:first_toc_position] = [
        toc_slide_layout_index
    ] * n_toc_slides
|
|
|
|
|
|
@PRESENTATION_ROUTER.get("/all", response_model=List[PresentationWithSlides])
async def get_all_presentations(sql_session: AsyncSession = Depends(get_async_session)):
    """List presentations, newest first, each paired with its first slide.

    Presentations are joined to the slide whose ``index`` is 0, so a
    presentation without such a slide does not appear in the result.
    """
    is_first_slide_of_presentation = (
        SlideModel.presentation == PresentationModel.id
    ) & (SlideModel.index == 0)
    statement = (
        select(PresentationModel, SlideModel)
        .join(SlideModel, is_first_slide_of_presentation)
        .order_by(PresentationModel.created_at.desc())
    )
    query_result = await sql_session.execute(statement)

    listing = []
    for presentation, cover_slide in query_result.all():
        cover_slides = [cover_slide]
        resolved_fonts = await _resolve_presentation_fonts(
            presentation, cover_slides, sql_session
        )
        entry = PresentationWithSlides(
            **presentation.model_dump(),
            slides=cover_slides,
            fonts=resolved_fonts,
        )
        listing.append(entry)
    return listing
|
|
|
|
|
|
@PRESENTATION_ROUTER.get("/{id}", response_model=PresentationWithSlides)
async def get_presentation(
    id: uuid.UUID, sql_session: AsyncSession = Depends(get_async_session)
):
    """Return one presentation with all of its slides ordered by index.

    Raises:
        HTTPException: 404 when no presentation has the given id.
    """
    presentation = await sql_session.get(PresentationModel, id)
    if not presentation:
        raise HTTPException(404, "Presentation not found")

    slides_statement = (
        select(SlideModel)
        .where(SlideModel.presentation == id)
        .order_by(SlideModel.index)
    )
    ordered_slides = list(await sql_session.scalars(slides_statement))
    resolved_fonts = await _resolve_presentation_fonts(
        presentation, ordered_slides, sql_session
    )

    return PresentationWithSlides(
        **presentation.model_dump(),
        slides=ordered_slides,
        fonts=resolved_fonts,
    )
|
|
|
|
|
|
@PRESENTATION_ROUTER.delete("/{id}", status_code=204)
async def delete_presentation(
    id: uuid.UUID, sql_session: AsyncSession = Depends(get_async_session)
):
    """Delete the presentation with the given id and commit the change.

    Raises:
        HTTPException: 404 when no presentation has the given id.
    """
    target = await sql_session.get(PresentationModel, id)
    if target is None:
        raise HTTPException(404, "Presentation not found")

    await sql_session.delete(target)
    await sql_session.commit()
|
|
|
|
|
|
@PRESENTATION_ROUTER.post("/create", response_model=PresentationModel)
async def create_presentation(
    content: Annotated[str, Body()],
    n_slides: Annotated[Optional[int], Body()] = None,
    language: Annotated[Optional[str], Body()] = None,
    file_paths: Annotated[Optional[List[str]], Body()] = None,
    tone: Annotated[Tone, Body()] = Tone.DEFAULT,
    verbosity: Annotated[Verbosity, Body()] = Verbosity.STANDARD,
    instructions: Annotated[Optional[str], Body()] = None,
    include_table_of_contents: Annotated[bool, Body()] = False,
    include_title_slide: Annotated[bool, Body()] = True,
    web_search: Annotated[bool, Body()] = False,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Validate the request and store a new presentation record.

    ``n_slides`` is optional; when given it must lie between 1 and
    ``MAX_NUMBER_OF_SLIDES`` and be at least 3 if a table of contents is
    requested.

    Returns:
        The newly committed ``PresentationModel``.

    Raises:
        HTTPException: 400 when ``n_slides`` violates one of the limits above.
    """
    if n_slides is not None:
        if n_slides < 1:
            raise HTTPException(
                status_code=400,
                detail="Number of slides must be greater than 0",
            )
        if n_slides > MAX_NUMBER_OF_SLIDES:
            raise HTTPException(
                status_code=400,
                detail=f"Number of slides cannot be greater than {MAX_NUMBER_OF_SLIDES}",
            )
        if include_table_of_contents and n_slides < 3:
            raise HTTPException(
                status_code=400,
                detail="Number of slides cannot be less than 3 if table of contents is included",
            )

    # The column is a plain int, so 0 is stored as the internal marker for
    # "slide count chosen automatically".
    stored_slide_count = 0 if n_slides is None else n_slides
    stored_language = (language or "").strip()

    new_presentation = PresentationModel(
        id=uuid.uuid4(),
        content=content,
        n_slides=stored_slide_count,
        language=stored_language,
        file_paths=file_paths,
        tone=tone.value,
        verbosity=verbosity.value,
        instructions=instructions,
        include_table_of_contents=include_table_of_contents,
        include_title_slide=include_title_slide,
        web_search=web_search,
    )
    sql_session.add(new_presentation)
    await sql_session.commit()

    return new_presentation
|
|
|
|
|
|
@PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel)
async def prepare_presentation(
    presentation_id: Annotated[uuid.UUID, Body()],
    outlines: Annotated[List[SlideOutlineModel], Body()],
    layout: Annotated[PresentationLayoutModel, Body()],
    title: Annotated[Optional[str], Body()] = None,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Attach outlines, layout and a per-slide structure to a presentation.

    The structure comes straight from the layout when it is ordered, otherwise
    from ``generate_presentation_structure``. It is then trimmed or padded so
    there is exactly one valid layout index per outline, and table-of-contents
    entries are added when the presentation asks for them.

    Args:
        presentation_id: Id of an existing presentation.
        outlines: Slide outlines; must not be empty.
        layout: Layout whose slide layouts the structure indexes into.
        title: Optional new title; the stored title is kept when falsy.

    Returns:
        The updated, committed ``PresentationModel``.

    Raises:
        HTTPException: 400 when ``outlines`` is empty, 404 when the
            presentation does not exist.
    """
    if not outlines:
        raise HTTPException(status_code=400, detail="Outlines are required")

    presentation = await sql_session.get(PresentationModel, presentation_id)
    if not presentation:
        raise HTTPException(status_code=404, detail="Presentation not found")

    presentation_outline_model = PresentationOutlineModel(slides=outlines)

    total_slide_layouts = len(layout.slides)
    total_outlines = len(outlines)

    if layout.ordered:
        presentation_structure = layout.to_presentation_structure()
    else:
        presentation_structure: PresentationStructureModel = (
            await generate_presentation_structure(
                presentation_outline=presentation_outline_model,
                presentation_layout=layout,
                instructions=presentation.instructions,
            )
        )

    # Keep exactly one layout index per outline: drop any surplus, pad a short
    # structure with random layouts and replace indices beyond the layout list.
    presentation_structure.slides = presentation_structure.slides[:total_outlines]
    for index in range(total_outlines):
        random_slide_index = random.randint(0, total_slide_layouts - 1)
        # Compare against the structure's current length. The previous check
        # against total_outlines could never be true inside this loop, so a
        # structure shorter than the outlines raised IndexError below.
        if index >= len(presentation_structure.slides):
            presentation_structure.slides.append(random_slide_index)
            continue
        if presentation_structure.slides[index] >= total_slide_layouts:
            presentation_structure.slides[index] = random_slide_index

    if presentation.include_table_of_contents:
        n_toc_slides = get_no_of_toc_required_for_n_outlines(
            n_outlines=total_outlines,
            title_slide=presentation.include_title_slide,
            # A stored n_slides of 0 (or less) means no explicit target count.
            target_total_slides=(
                presentation.n_slides if presentation.n_slides > 0 else None
            ),
        )
        toc_slide_layout_index = select_toc_or_list_slide_layout_index(layout)
        _insert_toc_layouts(
            presentation_structure,
            n_toc_slides,
            presentation.include_title_slide,
            toc_slide_layout_index,
        )
        # Only extend the outlines when layouts were actually inserted, so the
        # outlines and the structure stay the same length.
        if toc_slide_layout_index != -1 and n_toc_slides > 0:
            presentation_outline_model = get_presentation_outline_model_with_toc(
                outline=presentation_outline_model,
                n_toc_slides=n_toc_slides,
                title_slide=presentation.include_title_slide,
            )

    sql_session.add(presentation)
    presentation.outlines = presentation_outline_model.model_dump(mode="json")
    presentation.title = title or presentation.title
    presentation.set_layout(layout)
    presentation.set_structure(presentation_structure)
    await sql_session.commit()

    return presentation
|
|
|
|
|
|
@PRESENTATION_ROUTER.get("/stream/{id}", response_model=PresentationWithSlides)
async def stream_presentation(
    id: uuid.UUID, sql_session: AsyncSession = Depends(get_async_session)
):
    """Generate the slides of a prepared presentation and stream them as SSE.

    The event stream consists of an opening chunk, one chunk per generated
    slide, a closing chunk and finally a "complete" event carrying the saved
    presentation. If slide content generation raises ``HTTPException`` an
    error event is emitted and the stream ends without saving anything.

    Raises:
        HTTPException: 404 when the presentation does not exist, 400 when it
            has no structure or no outlines yet.
    """
    presentation = await sql_session.get(PresentationModel, id)
    if not presentation:
        raise HTTPException(status_code=404, detail="Presentation not found")
    if not presentation.structure:
        raise HTTPException(
            status_code=400,
            detail="Presentation not prepared for stream",
        )
    if not presentation.outlines:
        raise HTTPException(
            status_code=400,
            detail="Outlines can not be empty",
        )

    image_generation_service = ImageGenerationService(get_images_directory())

    def chunk_event(text: str) -> str:
        # Serialise one piece of streamed text as a "response" SSE event.
        payload = json.dumps({"type": "chunk", "chunk": text})
        return SSEResponse(event="response", data=payload).to_string()

    async def event_stream():
        structure = presentation.get_structure()
        layout = presentation.get_layout()
        outline = presentation.get_presentation_outline()
        outline_image_urls = get_images_for_slides_from_outline(outline.slides)

        # Asset tasks are started per slide and awaited together at the end.
        pending_asset_tasks = []
        generated_slides: List[SlideModel] = []

        yield chunk_event('{ "slides": [ ')

        for position, layout_index in enumerate(structure.slides):
            chosen_layout = layout.slides[layout_index]

            try:
                content = await get_slide_content_from_type_and_outline(
                    chosen_layout,
                    outline.slides[position],
                    presentation.language,
                    presentation.tone,
                    presentation.verbosity,
                    presentation.instructions,
                )
            except HTTPException as error:
                yield SSEErrorResponse(detail=error.detail).to_string()
                return

            new_slide = SlideModel(
                presentation=id,
                layout_group=layout.name,
                layout=chosen_layout.id,
                index=position,
                speaker_note=content.get("__speaker_note__", ""),
                content=content,
            )
            generated_slides.append(new_slide)

            # Mutates the slide: adds placeholder assets.
            process_slide_add_placeholder_assets(new_slide)

            # Also mutates the slide. Scheduled right away so it runs in
            # parallel with the LLM call for the next slide.
            if position < len(outline_image_urls):
                urls_for_slide = outline_image_urls[position]
            else:
                urls_for_slide = None
            asset_task = asyncio.create_task(
                process_slide_and_fetch_assets(
                    image_generation_service,
                    new_slide,
                    outline_image_urls=urls_for_slide,
                )
            )
            pending_asset_tasks.append(asset_task)

            yield chunk_event(new_slide.model_dump_json())

        yield chunk_event(" ] }")

        per_slide_assets = await asyncio.gather(*pending_asset_tasks)
        all_assets = [asset for assets in per_slide_assets for asset in assets]

        # Old slides are removed only now, after the new ones were generated.
        await sql_session.execute(
            delete(SlideModel).where(SlideModel.presentation == id)
        )
        await sql_session.commit()

        sql_session.add(presentation)
        sql_session.add_all(generated_slides)
        sql_session.add_all(all_assets)
        await sql_session.commit()

        final_payload = PresentationWithSlides(
            **presentation.model_dump(),
            slides=generated_slides,
            fonts=await _resolve_presentation_fonts(
                presentation, generated_slides, sql_session
            ),
        )

        yield SSECompleteResponse(
            key="presentation",
            value=final_payload.model_dump(mode="json"),
        ).to_string()

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
|
|
@PRESENTATION_ROUTER.patch("/update", response_model=PresentationWithSlides)
async def update_presentation(
    id: Annotated[uuid.UUID, Body()],
    n_slides: Annotated[Optional[int], Body()] = None,
    title: Annotated[Optional[str], Body()] = None,
    theme: Annotated[Optional[dict], Body()] = None,
    slides: Annotated[Optional[List[SlideModel]], Body()] = None,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Update presentation fields and, optionally, replace all of its slides.

    Args:
        id: Id of the presentation to update.
        n_slides: New slide count; applied when not ``None``.
        title: New title; applied when truthy.
        theme: New theme; applied when truthy or ``None`` (see note below).
        slides: When non-empty, the stored slides are deleted and replaced
            with these.

    Returns:
        The presentation together with the slides passed in (an empty list
        when none were passed).

    Raises:
        HTTPException: 404 when the presentation does not exist.
    """
    presentation = await sql_session.get(PresentationModel, id)
    if not presentation:
        raise HTTPException(status_code=404, detail="Presentation not found")

    presentation_update_dict = {}
    if n_slides is not None:
        presentation_update_dict["n_slides"] = n_slides
    if title:
        presentation_update_dict["title"] = title
    # NOTE(review): an omitted theme is indistinguishable from an explicit
    # null, so both store None; only an empty dict leaves the stored theme
    # untouched. Confirm this is what clients rely on.
    if theme or theme is None:
        presentation_update_dict["theme"] = theme

    if presentation_update_dict:
        presentation.sqlmodel_update(presentation_update_dict)

    if slides:
        # Make sure ids are UUID objects. Values that already are UUIDs are
        # left alone because uuid.UUID() rejects a UUID instance.
        for slide in slides:
            if not isinstance(slide.presentation, uuid.UUID):
                slide.presentation = uuid.UUID(slide.presentation)
            if not isinstance(slide.id, uuid.UUID):
                slide.id = uuid.UUID(slide.id)

        await sql_session.execute(
            delete(SlideModel).where(SlideModel.presentation == presentation.id)
        )
        sql_session.add_all(slides)

    await sql_session.commit()

    response_slides = slides or []
    fonts = await _resolve_presentation_fonts(
        presentation,
        response_slides,
        sql_session,
    )

    return PresentationWithSlides(
        **presentation.model_dump(),
        slides=response_slides,
        fonts=fonts,
    )
|
|
async def check_if_api_request_is_valid(
    request: GeneratePresentationRequest,
    sql_session: AsyncSession = Depends(get_async_session),
) -> Tuple[uuid.UUID,]:
    """Validate a generation request and allocate an id for the presentation.

    Checks that some input (content, slides markdown or files) is present,
    that ``n_slides`` is within limits, and that the template is either one of
    ``DEFAULT_TEMPLATES`` or a ``custom-<uuid>`` name whose template exists.
    A template that is not a default one is lower-cased on the request.

    Returns:
        A one-element tuple holding the freshly generated presentation id.

    Raises:
        HTTPException: 400 for any failed check.
    """
    presentation_id = uuid.uuid4()
    print(f"Presentation ID: {presentation_id}")

    has_any_input = request.content or request.slides_markdown or request.files
    if not has_any_input:
        raise HTTPException(
            status_code=400,
            detail="Either content or slides markdown or files is required to generate presentation",
        )

    requested_slides = request.n_slides
    if requested_slides is not None:
        if requested_slides <= 0:
            raise HTTPException(
                status_code=400,
                detail="Number of slides must be greater than 0",
            )
        if requested_slides > MAX_NUMBER_OF_SLIDES:
            raise HTTPException(
                status_code=400,
                detail=f"Number of slides cannot be greater than {MAX_NUMBER_OF_SLIDES}",
            )
        if request.include_table_of_contents and requested_slides < 3:
            raise HTTPException(
                status_code=400,
                detail="Number of slides cannot be less than 3 if table of contents is included",
            )

    # Anything that is not a built-in template must be a stored custom one.
    if request.template not in DEFAULT_TEMPLATES:
        request.template = request.template.lower()
        if not request.template.startswith("custom-"):
            raise HTTPException(
                status_code=400,
                detail="Template not found. Please use a valid template.",
            )

        raw_template_id = request.template.replace("custom-", "")
        try:
            stored_template = await sql_session.get(
                TemplateModel, uuid.UUID(raw_template_id)
            )
        except Exception:
            # A malformed id or a failed lookup is reported exactly like a
            # missing template.
            stored_template = None
        if not stored_template:
            raise HTTPException(
                status_code=400,
                detail="Template not found. Please use a valid template.",
            )

    return (presentation_id,)
|
|
|
|
|
|
async def generate_presentation_handler(
    request: GeneratePresentationRequest,
    presentation_id: uuid.UUID,
    async_status: Optional[AsyncPresentationGenerationTaskModel],
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Generate, save and export a presentation for an API request.

    Outlines are taken verbatim from ``request.slides_markdown`` when given,
    otherwise they are generated from the request content and any loaded
    documents. A layout index is then chosen per outline, table-of-contents
    slides are added when requested, slide contents are generated in batches
    while asset fetching runs in the background, and finally everything is
    saved and exported. A webhook is triggered on both success and failure.

    Args:
        request: The validated generation request. ``request.n_slides`` is
            overwritten when slides markdown is used.
        presentation_id: Id to give the new presentation.
        async_status: Task record to keep updated with progress, or ``None``
            for a synchronous call.
        sql_session: Session used for all database writes.

    Returns:
        ``PresentationPathAndEditPath`` on success. When ``async_status`` is
        given and generation fails, the error is stored on it and ``None`` is
        returned.

    Raises:
        HTTPException: On failure when ``async_status`` is ``None``; errors
            that are not already ``HTTPException`` are converted to a 500.
    """

    async def _report_progress(message: str) -> None:
        # Persist a progress message on the async task record, if there is one.
        if not async_status:
            return
        async_status.message = message
        async_status.updated_at = datetime.now()
        sql_session.add(async_status)
        await sql_session.commit()

    try:
        using_slides_markdown = False
        language_to_use = (request.language or "").strip() or None

        if request.slides_markdown:
            using_slides_markdown = True
            request.n_slides = len(request.slides_markdown)

        if not using_slides_markdown:
            additional_context = ""

            await _report_progress("Generating presentation outlines")

            if request.files:
                documents_loader = DocumentsLoader(
                    file_paths=request.files,
                    presentation_language=request.language,
                )
                await documents_loader.load_documents()
                documents = documents_loader.documents
                if documents:
                    additional_context = "\n\n".join(documents)

            # Finding number of slides to generate by considering table of contents
            n_slides_to_generate = request.n_slides
            if request.include_table_of_contents and request.n_slides is not None:
                n_slides_to_generate = get_no_of_outlines_to_generate_for_n_slides(
                    n_slides=request.n_slides,
                    toc=True,
                    title_slide=request.include_title_slide,
                )

            presentation_outlines_text = ""
            async for chunk in generate_ppt_outline(
                request.content,
                n_slides_to_generate,
                language_to_use,
                additional_context,
                request.tone.value,
                request.verbosity.value,
                request.instructions,
                request.include_title_slide,
                request.web_search,
                request.include_table_of_contents,
            ):
                # The generator may yield an exception object instead of text.
                if isinstance(chunk, HTTPException):
                    raise chunk

                presentation_outlines_text += chunk

            try:
                presentation_outlines_json = dict(
                    dirtyjson.loads(presentation_outlines_text)
                )
            except Exception:
                traceback.print_exc()
                raise HTTPException(
                    status_code=400,
                    detail="Failed to generate presentation outlines. Please try again.",
                )
            presentation_outlines = PresentationOutlineModel(
                **presentation_outlines_json
            )

            if (
                n_slides_to_generate is not None
                and len(presentation_outlines.slides) != n_slides_to_generate
            ):
                raise HTTPException(
                    status_code=400,
                    detail=(
                        "Failed to generate presentation outlines with requested "
                        "number of slides. Please try again."
                    ),
                )

            total_outlines = len(presentation_outlines.slides)

        else:
            # Setting outlines to slides markdown
            presentation_outlines = PresentationOutlineModel(
                slides=[
                    SlideOutlineModel(content=slide)
                    for slide in request.slides_markdown
                ]
            )
            total_outlines = len(request.slides_markdown)

        await _report_progress("Selecting layout for each slide")

        print("-" * 40)
        print(f"Generated {total_outlines} outlines for the presentation")

        # Parse Layouts
        layout_model = await get_layout_by_name(request.template)
        total_slide_layouts = len(layout_model.slides)

        # Generate Structure
        if layout_model.ordered:
            presentation_structure = layout_model.to_presentation_structure()
        else:
            presentation_structure: PresentationStructureModel = (
                await generate_presentation_structure(
                    presentation_outlines,
                    layout_model,
                    request.instructions,
                    using_slides_markdown,
                )
            )

        # Keep exactly one layout index per outline: drop any surplus, pad a
        # short structure with random layouts and replace indices beyond the
        # layout list.
        presentation_structure.slides = presentation_structure.slides[:total_outlines]
        for index in range(total_outlines):
            random_slide_index = random.randint(0, total_slide_layouts - 1)
            # Compare against the structure's current length. The previous
            # check against total_outlines could never be true inside this
            # loop, so a structure shorter than the outlines raised IndexError.
            if index >= len(presentation_structure.slides):
                presentation_structure.slides.append(random_slide_index)
                continue
            if presentation_structure.slides[index] >= total_slide_layouts:
                presentation_structure.slides[index] = random_slide_index

        should_include_toc = (
            request.include_table_of_contents and not using_slides_markdown
        )
        if should_include_toc:
            n_toc_slides = get_no_of_toc_required_for_n_outlines(
                n_outlines=total_outlines,
                title_slide=request.include_title_slide,
                target_total_slides=request.n_slides,
            )
            toc_slide_layout_index = select_toc_or_list_slide_layout_index(layout_model)
            _insert_toc_layouts(
                presentation_structure,
                n_toc_slides,
                request.include_title_slide,
                toc_slide_layout_index,
            )
            # Only extend the outlines when layouts were actually inserted, so
            # the outlines and the structure stay the same length.
            if toc_slide_layout_index != -1 and n_toc_slides > 0:
                presentation_outlines = get_presentation_outline_model_with_toc(
                    outline=presentation_outlines,
                    n_toc_slides=n_toc_slides,
                    title_slide=request.include_title_slide,
                )

        final_n_slides = request.n_slides
        if final_n_slides is None:
            final_n_slides = len(presentation_outlines.slides)

        # Create PresentationModel
        presentation = PresentationModel(
            id=presentation_id,
            content=request.content,
            n_slides=final_n_slides,
            language=language_to_use or "",
            title=get_presentation_title_from_presentation_outline(
                presentation_outlines
            ),
            outlines=presentation_outlines.model_dump(),
            layout=layout_model.model_dump(),
            structure=presentation_structure.model_dump(),
            tone=request.tone.value,
            verbosity=request.verbosity.value,
            instructions=request.instructions,
        )

        await _report_progress("Generating slides")

        image_generation_service = ImageGenerationService(get_images_directory())
        async_assets_generation_tasks = []

        # 7. Generate slide content concurrently (batched), then build slides and fetch assets
        slides: List[SlideModel] = []

        slide_layout_indices = presentation_structure.slides
        slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices]

        # Schedule slide content generation and asset fetching in batches of 10
        batch_size = 10
        for start in range(0, len(slide_layouts), batch_size):
            end = min(start + batch_size, len(slide_layouts))

            print(f"Generating slides from {start} to {end}")

            # Generate contents for this batch concurrently
            content_tasks = [
                get_slide_content_from_type_and_outline(
                    slide_layouts[i],
                    presentation_outlines.slides[i],
                    language_to_use,
                    request.tone.value,
                    request.verbosity.value,
                    request.instructions,
                )
                for i in range(start, end)
            ]
            batch_contents: List[dict] = await asyncio.gather(*content_tasks)

            # Build slides for this batch
            batch_slides: List[SlideModel] = []
            for offset, slide_content in enumerate(batch_contents):
                i = start + offset
                slide_layout = slide_layouts[i]
                slide = SlideModel(
                    presentation=presentation_id,
                    layout_group=layout_model.name,
                    layout=slide_layout.id,
                    index=i,
                    speaker_note=slide_content.get("__speaker_note__"),
                    content=slide_content,
                )
                slides.append(slide)
                batch_slides.append(slide)

            # Outline-provided image URLs are only looked up for markdown
            # input; generated outlines get an empty list per slide.
            if using_slides_markdown:
                image_urls_for_batch = get_images_for_slides_from_outline(
                    presentation_outlines.slides[start:end]
                )
            else:
                image_urls_for_batch = [[] for _ in batch_slides]

            # Start asset fetch tasks immediately so they run in parallel with next batch's LLM calls
            asset_tasks = [
                asyncio.create_task(
                    process_slide_and_fetch_assets(
                        image_generation_service,
                        slide,
                        outline_image_urls=image_urls_for_batch[offset],
                    )
                )
                for offset, slide in enumerate(batch_slides)
            ]
            async_assets_generation_tasks.extend(asset_tasks)

        await _report_progress("Fetching assets for slides")

        # Wait for every asset task started while the batches were generated
        generated_assets_list = await asyncio.gather(*async_assets_generation_tasks)
        generated_assets = []
        for assets_list in generated_assets_list:
            generated_assets.extend(assets_list)

        # 8. Save PresentationModel and Slides
        sql_session.add(presentation)
        sql_session.add_all(slides)
        sql_session.add_all(generated_assets)
        await sql_session.commit()

        # Committed like every other progress update so the status is
        # persisted before the export starts (it was previously only added to
        # the session).
        await _report_progress("Exporting presentation")

        # 9. Export
        presentation_and_path = await export_presentation(
            presentation_id, presentation.title or str(uuid.uuid4()), request.export_as
        )

        response = PresentationPathAndEditPath(
            **presentation_and_path.model_dump(),
            edit_path=f"/presentation?id={presentation_id}",
        )

        if async_status:
            async_status.message = "Presentation generation completed"
            async_status.status = "completed"
            async_status.data = response.model_dump(mode="json")
            async_status.updated_at = datetime.now()
            sql_session.add(async_status)
            await sql_session.commit()

        # Triggering webhook on success
        CONCURRENT_SERVICE.run_task(
            None,
            WebhookService.send_webhook,
            WebhookEvent.PRESENTATION_GENERATION_COMPLETED,
            response.model_dump(mode="json"),
        )

        return response

    except Exception as e:
        # Anything unexpected is logged and reported as a generic 500.
        if not isinstance(e, HTTPException):
            traceback.print_exc()
            e = HTTPException(status_code=500, detail="Presentation generation failed")

        api_error_model = APIErrorModel.from_exception(e)

        # Triggering webhook on failure
        CONCURRENT_SERVICE.run_task(
            None,
            WebhookService.send_webhook,
            WebhookEvent.PRESENTATION_GENERATION_FAILED,
            api_error_model.model_dump(mode="json"),
        )

        if async_status:
            # Background run: record the failure on the task instead of raising.
            async_status.status = "error"
            async_status.message = "Presentation generation failed"
            async_status.updated_at = datetime.now()
            async_status.error = api_error_model.model_dump(mode="json")
            sql_session.add(async_status)
            await sql_session.commit()

        else:
            raise e
|
|
|
|
|
|
@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath)
async def generate_presentation_sync(
    request: GeneratePresentationRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Validate the request, then generate the presentation within this call.

    Raises:
        HTTPException: Passed through unchanged when raised by validation or
            generation; any other error is logged and reported as a 500.
    """
    try:
        [presentation_id] = await check_if_api_request_is_valid(request, sql_session)
        generation_result = await generate_presentation_handler(
            request, presentation_id, None, sql_session
        )
        return generation_result
    except Exception as error:
        if isinstance(error, HTTPException):
            raise
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Presentation generation failed")
|
|
|
|
|
|
@PRESENTATION_ROUTER.post(
    "/generate/async", response_model=AsyncPresentationGenerationTaskModel
)
async def generate_presentation_async(
    request: GeneratePresentationRequest,
    background_tasks: BackgroundTasks,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Validate the request and queue presentation generation in the background.

    A task record with status "pending" is committed and returned right away;
    ``generate_presentation_handler`` is scheduled as a background task and
    given that record to update.

    Raises:
        HTTPException: Passed through unchanged; any other error is printed
            and reported as a 500.
    """
    try:
        [presentation_id] = await check_if_api_request_is_valid(request, sql_session)

        task_record = AsyncPresentationGenerationTaskModel(
            status="pending",
            message="Queued for generation",
            data=None,
        )
        sql_session.add(task_record)
        await sql_session.commit()

        # NOTE(review): the request-scoped session is handed to the background
        # task; confirm it is still usable once the response has been sent.
        background_tasks.add_task(
            generate_presentation_handler,
            request,
            presentation_id,
            async_status=task_record,
            sql_session=sql_session,
        )
        return task_record

    except HTTPException:
        raise
    except Exception as error:
        print(error)
        raise HTTPException(status_code=500, detail="Presentation generation failed")
|
|
|
|
|
|
@PRESENTATION_ROUTER.get(
    "/status/{id}", response_model=AsyncPresentationGenerationTaskModel
)
async def check_async_presentation_generation_status(
    id: str = Path(description="ID of the presentation generation task"),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Return the stored record of an asynchronous generation task.

    Raises:
        HTTPException: 404 when no task record has the given id.
    """
    task_record = await sql_session.get(AsyncPresentationGenerationTaskModel, id)
    if task_record:
        return task_record
    raise HTTPException(
        status_code=404, detail="No presentation generation task found"
    )
|
|
|
|
|
|
@PRESENTATION_ROUTER.post("/edit", response_model=PresentationPathAndEditPath)
async def edit_presentation_with_new_content(
    data: Annotated[EditPresentationRequest, Body()],
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Apply content edits to a presentation's slides and export the result.

    Every stored slide is replaced by a new slide built with
    ``get_new_slide``. Where ``data.slides`` has an entry with the same
    index, its content is deep-merged into the stored content first (the
    first matching entry wins); otherwise ``None`` is passed as the content.

    Raises:
        HTTPException: 404 when the presentation does not exist.
    """
    presentation = await sql_session.get(PresentationModel, data.presentation_id)
    if not presentation:
        raise HTTPException(status_code=404, detail="Presentation not found")

    stored_slides = await sql_session.scalars(
        select(SlideModel).where(SlideModel.presentation == data.presentation_id)
    )

    replacement_slides = []
    replaced_slide_ids = []
    for stored_slide in stored_slides:
        edit = next(
            (item for item in data.slides if item.index == stored_slide.index),
            None,
        )
        merged_content = (
            deep_update(stored_slide.content, edit.content)
            if edit is not None
            else None
        )
        replacement_slides.append(
            stored_slide.get_new_slide(presentation.id, merged_content)
        )
        replaced_slide_ids.append(stored_slide.id)

    await sql_session.execute(
        delete(SlideModel).where(SlideModel.id.in_(replaced_slide_ids))
    )

    sql_session.add_all(replacement_slides)
    await sql_session.commit()

    export_name = presentation.title or str(uuid.uuid4())
    presentation_and_path = await export_presentation(
        presentation.id, export_name, data.export_as
    )

    return PresentationPathAndEditPath(
        **presentation_and_path.model_dump(),
        edit_path=f"/presentation?id={presentation.id}",
    )
|
|
|
|
|
|
@PRESENTATION_ROUTER.post("/derive", response_model=PresentationPathAndEditPath)
async def derive_presentation_from_existing_one(
    data: Annotated[EditPresentationRequest, Body()],
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Create a new presentation from an existing one, with edits, and export it.

    The source presentation and its slides are left as they are. Each source
    slide yields a new slide for the derived presentation; where
    ``data.slides`` has an entry with the same index, its content is
    deep-merged into the source content first (the first matching entry
    wins), otherwise ``None`` is passed as the content.

    Raises:
        HTTPException: 404 when the source presentation does not exist.
    """
    source = await sql_session.get(PresentationModel, data.presentation_id)
    if not source:
        raise HTTPException(status_code=404, detail="Presentation not found")

    source_slides = await sql_session.scalars(
        select(SlideModel).where(SlideModel.presentation == data.presentation_id)
    )

    derived = source.get_new_presentation()
    derived_slides = []
    for source_slide in source_slides:
        edit = next(
            (item for item in data.slides if item.index == source_slide.index),
            None,
        )
        merged_content = (
            deep_update(source_slide.content, edit.content)
            if edit is not None
            else None
        )
        derived_slides.append(source_slide.get_new_slide(derived.id, merged_content))

    sql_session.add(derived)
    sql_session.add_all(derived_slides)
    await sql_session.commit()

    export_name = derived.title or str(uuid.uuid4())
    presentation_and_path = await export_presentation(
        derived.id, export_name, data.export_as
    )

    return PresentationPathAndEditPath(
        **presentation_and_path.model_dump(),
        edit_path=f"/presentation?id={derived.id}",
    )
|