Merge pull request #153 from presenton/feat/use-presentation-as-template
feat(fastapi): adds an endpoint where you can use generated presentation as template and create new presentation
This commit is contained in:
commit
f6f70337bc
8 changed files with 197 additions and 66 deletions
|
|
@ -8,6 +8,7 @@ from fastapi import APIRouter, Body, HTTPException
|
|||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import select
|
||||
from models.presentation_from_template import GetPresentationUsingTemplateRequest
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
|
|
@ -24,6 +25,8 @@ from models.generate_presentation_api import (
|
|||
from services.get_layout_by_name import get_layout_by_name
|
||||
from services.icon_finder_service import IconFinderService
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.dict_utils import deep_update
|
||||
from utils.export_utils import export_presentation
|
||||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sse_response import SSECompleteResponse, SSEResponse
|
||||
|
|
@ -46,7 +49,7 @@ from utils.randomizers import get_random_uuid
|
|||
PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.get("/", response_model=PresentationWithSlides)
|
||||
@PRESENTATION_ROUTER.get("", response_model=PresentationWithSlides)
|
||||
def get_presentation(id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, id)
|
||||
|
|
@ -63,7 +66,7 @@ def get_presentation(id: str):
|
|||
)
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.delete("/", status_code=204)
|
||||
@PRESENTATION_ROUTER.delete("", status_code=204)
|
||||
def delete_presentation(id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, id)
|
||||
|
|
@ -317,7 +320,7 @@ async def create_pptx(pptx_model: Annotated[PptxPresentationModel, Body()]):
|
|||
return pptx_path
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/generate")
|
||||
@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath)
|
||||
async def generate_presentation_api(
|
||||
data: Annotated[GeneratePresentationRequest, Body()],
|
||||
):
|
||||
|
|
@ -447,63 +450,52 @@ async def generate_presentation_api(
|
|||
sql_session.add_all(slides)
|
||||
sql_session.commit()
|
||||
|
||||
# 8. Export as PPTX
|
||||
if data.export_as == "pptx":
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation as PPTX")
|
||||
|
||||
# Get the converted PPTX model from your existing Next.js service
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost/api/presentation_to_pptx_model?id={presentation_id}"
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
print(f"Failed to get PPTX model: {error_text}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to convert presentation to PPTX model",
|
||||
)
|
||||
pptx_model_data = await response.json()
|
||||
print(f"Received PPTX model data: {json.dumps(pptx_model_data, indent=2)}")
|
||||
|
||||
# Create PPTX file using the converted model
|
||||
pptx_model = PptxPresentationModel(**pptx_model_data)
|
||||
print(f"Creating PPTX with model: {pptx_model.model_dump_json(indent=2)}")
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
pptx_creator = PptxPresentationCreator(pptx_model, temp_dir)
|
||||
await pptx_creator.create_ppt()
|
||||
|
||||
export_directory = get_exports_directory()
|
||||
pptx_path = os.path.join(export_directory, f"{presentation_content.title}.pptx")
|
||||
pptx_creator.save(pptx_path)
|
||||
|
||||
presentation_and_path = PresentationAndPath(
|
||||
presentation_id=presentation_id,
|
||||
path=pptx_path,
|
||||
)
|
||||
else:
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation as PDF")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"http://localhost/api/export-as-pdf",
|
||||
json={
|
||||
"id": presentation_id,
|
||||
"title": presentation_content.title,
|
||||
},
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
print(f"Received PDF export response: {json.dumps(response_json, indent=2)}")
|
||||
|
||||
presentation_and_path = PresentationAndPath(
|
||||
presentation_id=presentation_id,
|
||||
path=response_json["path"],
|
||||
)
|
||||
# 8. Export
|
||||
presentation_and_path = await export_presentation(
|
||||
presentation_id, presentation_content.title, data.export_as
|
||||
)
|
||||
|
||||
return PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={presentation_id}",
|
||||
)
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/from-template", response_model=PresentationPathAndEditPath)
|
||||
async def from_template(
|
||||
data: Annotated[GetPresentationUsingTemplateRequest, Body()],
|
||||
):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, data.presentation_id)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
slides = sql_session.exec(
|
||||
select(SlideModel).where(SlideModel.presentation == data.presentation_id)
|
||||
).all()
|
||||
|
||||
new_presentation = presentation.get_new_presentation()
|
||||
new_slides = []
|
||||
for each_slide in slides:
|
||||
updated_content = None
|
||||
new_slide_data = list(filter(lambda x: x.index == each_slide.index, data.data))
|
||||
if new_slide_data:
|
||||
updated_content = deep_update(each_slide.content, new_slide_data[0].content)
|
||||
print(f"Updated content for slide {each_slide.index}: {updated_content}")
|
||||
new_slides.append(
|
||||
each_slide.get_new_slide(new_presentation.id, updated_content)
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(new_presentation)
|
||||
sql_session.add_all(new_slides)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(new_presentation)
|
||||
|
||||
presentation_and_path = await export_presentation(
|
||||
new_presentation.id, new_presentation.title, data.export_as
|
||||
)
|
||||
|
||||
return PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={new_presentation.id}",
|
||||
)
|
||||
|
|
|
|||
13
servers/fastapi/models/presentation_from_template.py
Normal file
13
servers/fastapi/models/presentation_from_template.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from typing import List, Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlideContentUpdate(BaseModel):
|
||||
index: int
|
||||
content: dict
|
||||
|
||||
|
||||
class GetPresentationUsingTemplateRequest(BaseModel):
|
||||
presentation_id: str
|
||||
data: List[SlideContentUpdate]
|
||||
export_as: Literal["pptx", "pdf"] = "pptx"
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
|
@ -9,6 +10,7 @@ from models.presentation_outline_model import (
|
|||
SlideOutlineModel,
|
||||
)
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class PresentationModel(SQLModel, table=True):
|
||||
|
|
@ -25,6 +27,20 @@ class PresentationModel(SQLModel, table=True):
|
|||
layout: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
structure: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
|
||||
def get_new_presentation(self):
|
||||
return PresentationModel(
|
||||
id=get_random_uuid(),
|
||||
prompt=self.prompt,
|
||||
n_slides=self.n_slides,
|
||||
language=self.language,
|
||||
title=self.title,
|
||||
notes=self.notes,
|
||||
outlines=self.outlines,
|
||||
summary=self.summary,
|
||||
layout=self.layout,
|
||||
structure=self.structure,
|
||||
)
|
||||
|
||||
def get_presentation_outline(self):
|
||||
if not self.outlines:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional
|
||||
import uuid
|
||||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
|
@ -13,3 +14,14 @@ class SlideModel(SQLModel, table=True):
|
|||
content: dict = Field(sa_column=Column(JSON))
|
||||
html_content: Optional[str]
|
||||
properties: Optional[dict] = Field(sa_column=Column(JSON))
|
||||
|
||||
def get_new_slide(self, presentation_id: str, content: Optional[dict] = None):
|
||||
return SlideModel(
|
||||
id=get_random_uuid(),
|
||||
presentation=presentation_id,
|
||||
layout_group=self.layout_group,
|
||||
layout=self.layout,
|
||||
index=self.index,
|
||||
content=content or self.content,
|
||||
properties=self.properties,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -73,10 +73,10 @@ class PptxPresentationCreator:
|
|||
if isinstance(each_shape, PptxPictureBoxModel):
|
||||
image_path = each_shape.picture.path
|
||||
if image_path.startswith("http"):
|
||||
if "app_data/images" in image_path:
|
||||
relative_path = image_path.split("/app_data/images/")[1]
|
||||
if "app_data/" in image_path:
|
||||
relative_path = image_path.split("/app_data/")[1]
|
||||
each_shape.picture.path = os.path.join(
|
||||
"app_data/images", relative_path
|
||||
"app_data", relative_path
|
||||
)
|
||||
each_shape.picture.is_network = False
|
||||
continue
|
||||
|
|
@ -88,10 +88,10 @@ class PptxPresentationCreator:
|
|||
if isinstance(each_shape, PptxPictureBoxModel):
|
||||
image_path = each_shape.picture.path
|
||||
if image_path.startswith("http"):
|
||||
if "app_data/images" in image_path:
|
||||
relative_path = image_path.split("/app_data/images/")[1]
|
||||
if "app_data" in image_path:
|
||||
relative_path = image_path.split("/app_data/")[1]
|
||||
each_shape.picture.path = os.path.join(
|
||||
"app_data/images", relative_path
|
||||
"app_data", relative_path
|
||||
)
|
||||
each_shape.picture.is_network = False
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -46,3 +46,35 @@ def set_dict_at_path(data: dict, path: JsonPathGuide, value: dict):
|
|||
current[final_guide.key] = value
|
||||
elif isinstance(final_guide, ListGuide):
|
||||
current[final_guide.index] = value
|
||||
|
||||
|
||||
def deep_update(original: dict, updates: dict) -> dict:
|
||||
for key, value in updates.items():
|
||||
if key in original:
|
||||
if isinstance(original[key], dict) and isinstance(value, dict):
|
||||
deep_update(original[key], value)
|
||||
elif isinstance(original[key], list) and isinstance(value, list):
|
||||
if len(value) == 0:
|
||||
continue
|
||||
elif len(value) == 1 and isinstance(value[0], dict):
|
||||
if len(original[key]) > 0 and isinstance(original[key][0], dict):
|
||||
deep_update(original[key][0], value[0])
|
||||
else:
|
||||
original[key][0] = (
|
||||
value[0] if len(original[key]) > 0 else value[0]
|
||||
)
|
||||
else:
|
||||
min_length = min(len(original[key]), len(value))
|
||||
for i in range(min_length):
|
||||
if isinstance(original[key][i], dict) and isinstance(
|
||||
value[i], dict
|
||||
):
|
||||
deep_update(original[key][i], value[i])
|
||||
else:
|
||||
original[key][i] = value[i]
|
||||
elif not isinstance(value, (dict, list)):
|
||||
original[key] = value
|
||||
else:
|
||||
if not isinstance(value, (dict, list)):
|
||||
original[key] = value
|
||||
return original
|
||||
|
|
|
|||
66
servers/fastapi/utils/export_utils.py
Normal file
66
servers/fastapi/utils/export_utils.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import json
|
||||
import os
|
||||
import aiohttp
|
||||
from typing import Literal
|
||||
from fastapi import HTTPException
|
||||
from pathvalidate import sanitize_filename
|
||||
|
||||
from models.pptx_models import PptxPresentationModel
|
||||
from models.generate_presentation_api import PresentationAndPath
|
||||
from services.pptx_presentation_creator import PptxPresentationCreator
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from utils.asset_directory_utils import get_exports_directory
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
async def export_presentation(
|
||||
presentation_id: str, title: str, export_as: Literal["pptx", "pdf"]
|
||||
) -> PresentationAndPath:
|
||||
if export_as == "pptx":
|
||||
|
||||
# Get the converted PPTX model from the Next.js service
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost/api/presentation_to_pptx_model?id={presentation_id}"
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
print(f"Failed to get PPTX model: {error_text}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to convert presentation to PPTX model",
|
||||
)
|
||||
pptx_model_data = await response.json()
|
||||
|
||||
# Create PPTX file using the converted model
|
||||
pptx_model = PptxPresentationModel(**pptx_model_data)
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
pptx_creator = PptxPresentationCreator(pptx_model, temp_dir)
|
||||
await pptx_creator.create_ppt()
|
||||
|
||||
export_directory = get_exports_directory()
|
||||
pptx_path = os.path.join(
|
||||
export_directory,
|
||||
f"{sanitize_filename(title or get_random_uuid())}.pptx",
|
||||
)
|
||||
pptx_creator.save(pptx_path)
|
||||
|
||||
return PresentationAndPath(
|
||||
presentation_id=presentation_id,
|
||||
path=pptx_path,
|
||||
)
|
||||
else:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"http://localhost/api/export-as-pdf",
|
||||
json={
|
||||
"id": presentation_id,
|
||||
"title": sanitize_filename(title or get_random_uuid()),
|
||||
},
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
return PresentationAndPath(
|
||||
presentation_id=presentation_id,
|
||||
path=response_json["path"],
|
||||
)
|
||||
|
|
@ -49,7 +49,7 @@ export class DashboardApi {
|
|||
static async getPresentation(id: string) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/v1/ppt/presentation/?id=${id}`,
|
||||
`/api/v1/ppt/presentation?id=${id}`,
|
||||
{
|
||||
method: "GET",
|
||||
}
|
||||
|
|
@ -65,7 +65,7 @@ export class DashboardApi {
|
|||
static async deletePresentation(presentation_id: string) {
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/v1/ppt/presentation/?id=${presentation_id}`,
|
||||
`/api/v1/ppt/presentation?id=${presentation_id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
headers: getHeader(),
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue