Merge branch 'main' into feat/mock-endpoints
This commit is contained in:
commit
527d3eb890
9 changed files with 598 additions and 204 deletions
|
|
@ -2,8 +2,9 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from api.lifespan import app_lifespan
|
||||
from api.middlewares import UserConfigEnvUpdateMiddleware
|
||||
from api.v1.mock.router import API_V1_MOCK_ROUTER
|
||||
from api.v1.ppt.router import API_V1_PPT_ROUTER
|
||||
from api.v1.webhook.router import API_V1_WEBHOOK_ROUTER
|
||||
from api.v1.mock.router import API_V1_MOCK_ROUTER
|
||||
|
||||
|
||||
app = FastAPI(lifespan=app_lifespan)
|
||||
|
|
@ -11,6 +12,7 @@ app = FastAPI(lifespan=app_lifespan)
|
|||
|
||||
# Routers
|
||||
app.include_router(API_V1_PPT_ROUTER)
|
||||
app.include_router(API_V1_WEBHOOK_ROUTER)
|
||||
app.include_router(API_V1_MOCK_ROUTER)
|
||||
|
||||
# Middlewares
|
||||
|
|
|
|||
|
|
@ -1,16 +1,20 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from typing import Annotated, List, Literal, Optional
|
||||
from typing import Annotated, List, Literal, Optional, Tuple
|
||||
import dirtyjson
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
from constants.presentation import DEFAULT_TEMPLATES
|
||||
from enums.webhook_event import WebhookEvent
|
||||
from models.api_error_model import APIErrorModel
|
||||
from models.generate_presentation_request import GeneratePresentationRequest
|
||||
from models.presentation_and_path import PresentationPathAndEditPath
|
||||
from models.presentation_from_template import EditPresentationRequest
|
||||
|
|
@ -26,8 +30,10 @@ from models.presentation_structure_model import PresentationStructureModel
|
|||
from models.presentation_with_slides import (
|
||||
PresentationWithSlides,
|
||||
)
|
||||
from models.sql.template import TemplateModel
|
||||
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from services.webhook_service import WebhookService
|
||||
from utils.get_layout_by_name import get_layout_by_name
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.dict_utils import deep_update
|
||||
|
|
@ -38,8 +44,12 @@ from models.sse_response import SSECompleteResponse, SSEErrorResponse, SSERespon
|
|||
|
||||
from services.database import get_async_session
|
||||
from services.temp_file_service import TEMP_FILE_SERVICE
|
||||
from services.concurrent_service import CONCURRENT_SERVICE
|
||||
from models.sql.presentation import PresentationModel
|
||||
from services.pptx_presentation_creator import PptxPresentationCreator
|
||||
from models.sql.async_presentation_generation_status import (
|
||||
AsyncPresentationGenerationTaskModel,
|
||||
)
|
||||
from utils.asset_directory_utils import get_exports_directory, get_images_directory
|
||||
from utils.llm_calls.generate_presentation_structure import (
|
||||
generate_presentation_structure,
|
||||
|
|
@ -434,237 +444,427 @@ async def export_presentation_as_pptx_or_pdf(
|
|||
)
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath)
|
||||
async def generate_presentation_api(
|
||||
async def check_if_api_request_is_valid(
|
||||
request: GeneratePresentationRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
) -> Tuple[uuid.UUID,]:
|
||||
presentation_id = uuid.uuid4()
|
||||
print(f"Presentation ID: {presentation_id}")
|
||||
|
||||
using_slides_markdown = False
|
||||
# Making sure either content, slides markdown or files is provided
|
||||
if not (request.content or request.slides_markdown or request.files):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either content or slides markdown or files is required to generate presentation",
|
||||
)
|
||||
|
||||
if request.slides_markdown:
|
||||
using_slides_markdown = True
|
||||
request.n_slides = len(request.slides_markdown)
|
||||
# Making sure number of slides is greater than 0
|
||||
if request.n_slides <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Number of slides must be greater than 0",
|
||||
)
|
||||
|
||||
if not using_slides_markdown:
|
||||
additional_context = ""
|
||||
|
||||
if request.files:
|
||||
documents_loader = DocumentsLoader(file_paths=request.files)
|
||||
await documents_loader.load_documents()
|
||||
documents = documents_loader.documents
|
||||
if documents:
|
||||
additional_context = "\n\n".join(documents)
|
||||
|
||||
# Finding number of slides to generate by considering table of contents
|
||||
n_slides_to_generate = request.n_slides
|
||||
if request.include_table_of_contents:
|
||||
needed_toc_count = math.ceil(
|
||||
(
|
||||
(request.n_slides - 1)
|
||||
if request.include_title_slide
|
||||
else request.n_slides
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
n_slides_to_generate -= math.ceil(
|
||||
(request.n_slides - needed_toc_count) / 10
|
||||
)
|
||||
|
||||
presentation_outlines_text = ""
|
||||
async for chunk in generate_ppt_outline(
|
||||
request.content,
|
||||
n_slides_to_generate,
|
||||
request.language,
|
||||
additional_context,
|
||||
request.tone.value,
|
||||
request.verbosity.value,
|
||||
request.instructions,
|
||||
request.include_title_slide,
|
||||
request.web_search,
|
||||
):
|
||||
|
||||
if isinstance(chunk, HTTPException):
|
||||
raise chunk
|
||||
|
||||
presentation_outlines_text += chunk
|
||||
|
||||
try:
|
||||
presentation_outlines_json = dict(
|
||||
dirtyjson.loads(presentation_outlines_text)
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# Checking if template is valid
|
||||
if request.template not in DEFAULT_TEMPLATES:
|
||||
request.template = request.template.lower()
|
||||
if not request.template.startswith("custom-"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to generate presentation outlines. Please try again.",
|
||||
detail="Template not found. Please use a valid template.",
|
||||
)
|
||||
presentation_outlines = PresentationOutlineModel(**presentation_outlines_json)
|
||||
total_outlines = n_slides_to_generate
|
||||
|
||||
else:
|
||||
# Setting outlines to slides markdown
|
||||
presentation_outlines = PresentationOutlineModel(
|
||||
slides=[
|
||||
SlideOutlineModel(content=slide) for slide in request.slides_markdown
|
||||
]
|
||||
)
|
||||
total_outlines = len(request.slides_markdown)
|
||||
|
||||
print("-" * 40)
|
||||
print(f"Generated {total_outlines} outlines for the presentation")
|
||||
|
||||
# Parse Layouts
|
||||
layout_model = await get_layout_by_name(request.template)
|
||||
total_slide_layouts = len(layout_model.slides)
|
||||
|
||||
# Generate Structure
|
||||
if layout_model.ordered:
|
||||
presentation_structure = layout_model.to_presentation_structure()
|
||||
else:
|
||||
presentation_structure: PresentationStructureModel = (
|
||||
await generate_presentation_structure(
|
||||
presentation_outlines,
|
||||
layout_model,
|
||||
request.instructions,
|
||||
using_slides_markdown,
|
||||
template_id = request.template.replace("custom-", "")
|
||||
try:
|
||||
template = await sql_session.get(TemplateModel, uuid.UUID(template_id))
|
||||
if not template:
|
||||
raise Exception()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Template not found. Please use a valid template.",
|
||||
)
|
||||
)
|
||||
|
||||
presentation_structure.slides = presentation_structure.slides[:total_outlines]
|
||||
for index in range(total_outlines):
|
||||
random_slide_index = random.randint(0, total_slide_layouts - 1)
|
||||
if index >= total_outlines:
|
||||
presentation_structure.slides.append(random_slide_index)
|
||||
continue
|
||||
if presentation_structure.slides[index] >= total_slide_layouts:
|
||||
presentation_structure.slides[index] = random_slide_index
|
||||
return presentation_id
|
||||
|
||||
# Injecting table of contents to the presentation structure and outlines
|
||||
if request.include_table_of_contents and not using_slides_markdown:
|
||||
n_toc_slides = request.n_slides - total_outlines
|
||||
toc_slide_layout_index = select_toc_or_list_slide_layout_index(layout_model)
|
||||
if toc_slide_layout_index != -1:
|
||||
outline_index = 1 if request.include_title_slide else 0
|
||||
for i in range(n_toc_slides):
|
||||
outlines_to = outline_index + 10
|
||||
if total_outlines == outlines_to:
|
||||
outlines_to -= 1
|
||||
|
||||
presentation_structure.slides.insert(
|
||||
i + 1 if request.include_title_slide else i,
|
||||
toc_slide_layout_index,
|
||||
)
|
||||
toc_outline = f"Table of Contents\n\n"
|
||||
async def generate_presentation_handler(
|
||||
request: GeneratePresentationRequest,
|
||||
presentation_id: uuid.UUID,
|
||||
async_status: Optional[AsyncPresentationGenerationTaskModel],
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
try:
|
||||
using_slides_markdown = False
|
||||
|
||||
for outline in presentation_outlines.slides[outline_index:outlines_to]:
|
||||
page_number = (
|
||||
outline_index - i + n_toc_slides + 1
|
||||
if request.slides_markdown:
|
||||
using_slides_markdown = True
|
||||
request.n_slides = len(request.slides_markdown)
|
||||
|
||||
if not using_slides_markdown:
|
||||
additional_context = ""
|
||||
|
||||
# Updating async status
|
||||
if async_status:
|
||||
async_status.message = "Generating presentation outlines"
|
||||
async_status.updated_at = datetime.now()
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
if request.files:
|
||||
documents_loader = DocumentsLoader(file_paths=request.files)
|
||||
await documents_loader.load_documents()
|
||||
documents = documents_loader.documents
|
||||
if documents:
|
||||
additional_context = "\n\n".join(documents)
|
||||
|
||||
# Finding number of slides to generate by considering table of contents
|
||||
n_slides_to_generate = request.n_slides
|
||||
if request.include_table_of_contents:
|
||||
needed_toc_count = math.ceil(
|
||||
(
|
||||
(request.n_slides - 1)
|
||||
if request.include_title_slide
|
||||
else outline_index - i + n_toc_slides
|
||||
else request.n_slides
|
||||
)
|
||||
toc_outline += f"Slide page number: {page_number}\n Slide Content: {outline.content[:100]}\n\n"
|
||||
outline_index += 1
|
||||
|
||||
outline_index += 1
|
||||
|
||||
presentation_outlines.slides.insert(
|
||||
i + 1 if request.include_title_slide else i,
|
||||
SlideOutlineModel(
|
||||
content=toc_outline,
|
||||
),
|
||||
/ 10
|
||||
)
|
||||
n_slides_to_generate -= math.ceil(
|
||||
(request.n_slides - needed_toc_count) / 10
|
||||
)
|
||||
|
||||
# Create PresentationModel
|
||||
presentation = PresentationModel(
|
||||
id=presentation_id,
|
||||
content=request.content,
|
||||
n_slides=request.n_slides,
|
||||
language=request.language,
|
||||
title=get_presentation_title_from_outlines(presentation_outlines),
|
||||
outlines=presentation_outlines.model_dump(),
|
||||
layout=layout_model.model_dump(),
|
||||
structure=presentation_structure.model_dump(),
|
||||
tone=request.tone,
|
||||
verbosity=request.verbosity,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
async_assets_generation_tasks = []
|
||||
|
||||
# 7. Generate slide content concurrently (batched), then build slides and fetch assets
|
||||
slides: List[SlideModel] = []
|
||||
|
||||
slide_layout_indices = presentation_structure.slides
|
||||
slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices]
|
||||
|
||||
# Schedule slide content generation and asset fetching in batches of 10
|
||||
batch_size = 10
|
||||
for start in range(0, len(slide_layouts), batch_size):
|
||||
end = min(start + batch_size, len(slide_layouts))
|
||||
|
||||
print(f"Generating slides from {start} to {end}")
|
||||
|
||||
# Generate contents for this batch concurrently
|
||||
content_tasks = [
|
||||
get_slide_content_from_type_and_outline(
|
||||
slide_layouts[i],
|
||||
presentation_outlines.slides[i],
|
||||
presentation_outlines_text = ""
|
||||
async for chunk in generate_ppt_outline(
|
||||
request.content,
|
||||
n_slides_to_generate,
|
||||
request.language,
|
||||
additional_context,
|
||||
request.tone.value,
|
||||
request.verbosity.value,
|
||||
request.instructions,
|
||||
request.include_title_slide,
|
||||
request.web_search,
|
||||
):
|
||||
|
||||
if isinstance(chunk, HTTPException):
|
||||
raise chunk
|
||||
|
||||
presentation_outlines_text += chunk
|
||||
|
||||
try:
|
||||
presentation_outlines_json = dict(
|
||||
dirtyjson.loads(presentation_outlines_text)
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to generate presentation outlines. Please try again.",
|
||||
)
|
||||
presentation_outlines = PresentationOutlineModel(
|
||||
**presentation_outlines_json
|
||||
)
|
||||
for i in range(start, end)
|
||||
]
|
||||
batch_contents: List[dict] = await asyncio.gather(*content_tasks)
|
||||
total_outlines = n_slides_to_generate
|
||||
|
||||
# Build slides for this batch
|
||||
batch_slides: List[SlideModel] = []
|
||||
for offset, slide_content in enumerate(batch_contents):
|
||||
i = start + offset
|
||||
slide_layout = slide_layouts[i]
|
||||
slide = SlideModel(
|
||||
presentation=presentation_id,
|
||||
layout_group=layout_model.name,
|
||||
layout=slide_layout.id,
|
||||
index=i,
|
||||
speaker_note=slide_content.get("__speaker_note__"),
|
||||
content=slide_content,
|
||||
else:
|
||||
# Setting outlines to slides markdown
|
||||
presentation_outlines = PresentationOutlineModel(
|
||||
slides=[
|
||||
SlideOutlineModel(content=slide)
|
||||
for slide in request.slides_markdown
|
||||
]
|
||||
)
|
||||
slides.append(slide)
|
||||
batch_slides.append(slide)
|
||||
total_outlines = len(request.slides_markdown)
|
||||
|
||||
# Start asset fetch tasks for just-generated slides so they run while next batch is processed
|
||||
asset_tasks = [
|
||||
process_slide_and_fetch_assets(image_generation_service, slide)
|
||||
for slide in batch_slides
|
||||
]
|
||||
async_assets_generation_tasks.extend(asset_tasks)
|
||||
# Updating async status
|
||||
if async_status:
|
||||
async_status.message = f"Selecting layout for each slide"
|
||||
async_status.updated_at = datetime.now()
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
# Run all asset tasks concurrently while batches may still be generating content
|
||||
generated_assets_list = await asyncio.gather(*async_assets_generation_tasks)
|
||||
generated_assets = []
|
||||
for assets_list in generated_assets_list:
|
||||
generated_assets.extend(assets_list)
|
||||
print("-" * 40)
|
||||
print(f"Generated {total_outlines} outlines for the presentation")
|
||||
|
||||
# 8. Save PresentationModel and Slides
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slides)
|
||||
sql_session.add_all(generated_assets)
|
||||
await sql_session.commit()
|
||||
# Parse Layouts
|
||||
layout_model = await get_layout_by_name(request.template)
|
||||
total_slide_layouts = len(layout_model.slides)
|
||||
|
||||
# 9. Export
|
||||
presentation_and_path = await export_presentation(
|
||||
presentation_id, presentation.title or str(uuid.uuid4()), request.export_as
|
||||
)
|
||||
# Generate Structure
|
||||
if layout_model.ordered:
|
||||
presentation_structure = layout_model.to_presentation_structure()
|
||||
else:
|
||||
presentation_structure: PresentationStructureModel = (
|
||||
await generate_presentation_structure(
|
||||
presentation_outlines,
|
||||
layout_model,
|
||||
request.instructions,
|
||||
using_slides_markdown,
|
||||
)
|
||||
)
|
||||
|
||||
return PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={presentation_id}",
|
||||
)
|
||||
presentation_structure.slides = presentation_structure.slides[:total_outlines]
|
||||
for index in range(total_outlines):
|
||||
random_slide_index = random.randint(0, total_slide_layouts - 1)
|
||||
if index >= total_outlines:
|
||||
presentation_structure.slides.append(random_slide_index)
|
||||
continue
|
||||
if presentation_structure.slides[index] >= total_slide_layouts:
|
||||
presentation_structure.slides[index] = random_slide_index
|
||||
|
||||
# Injecting table of contents to the presentation structure and outlines
|
||||
if request.include_table_of_contents and not using_slides_markdown:
|
||||
n_toc_slides = request.n_slides - total_outlines
|
||||
toc_slide_layout_index = select_toc_or_list_slide_layout_index(layout_model)
|
||||
if toc_slide_layout_index != -1:
|
||||
outline_index = 1 if request.include_title_slide else 0
|
||||
for i in range(n_toc_slides):
|
||||
outlines_to = outline_index + 10
|
||||
if total_outlines == outlines_to:
|
||||
outlines_to -= 1
|
||||
|
||||
presentation_structure.slides.insert(
|
||||
i + 1 if request.include_title_slide else i,
|
||||
toc_slide_layout_index,
|
||||
)
|
||||
toc_outline = f"Table of Contents\n\n"
|
||||
|
||||
for outline in presentation_outlines.slides[
|
||||
outline_index:outlines_to
|
||||
]:
|
||||
page_number = (
|
||||
outline_index - i + n_toc_slides + 1
|
||||
if request.include_title_slide
|
||||
else outline_index - i + n_toc_slides
|
||||
)
|
||||
toc_outline += f"Slide page number: {page_number}\n Slide Content: {outline.content[:100]}\n\n"
|
||||
outline_index += 1
|
||||
|
||||
outline_index += 1
|
||||
|
||||
presentation_outlines.slides.insert(
|
||||
i + 1 if request.include_title_slide else i,
|
||||
SlideOutlineModel(
|
||||
content=toc_outline,
|
||||
),
|
||||
)
|
||||
|
||||
# Create PresentationModel
|
||||
presentation = PresentationModel(
|
||||
id=presentation_id,
|
||||
content=request.content,
|
||||
n_slides=request.n_slides,
|
||||
language=request.language,
|
||||
title=get_presentation_title_from_outlines(presentation_outlines),
|
||||
outlines=presentation_outlines.model_dump(),
|
||||
layout=layout_model.model_dump(),
|
||||
structure=presentation_structure.model_dump(),
|
||||
tone=request.tone.value,
|
||||
verbosity=request.verbosity.value,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
|
||||
# Updating async status
|
||||
if async_status:
|
||||
async_status.message = "Generating slides"
|
||||
async_status.updated_at = datetime.now()
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
async_assets_generation_tasks = []
|
||||
|
||||
# 7. Generate slide content concurrently (batched), then build slides and fetch assets
|
||||
slides: List[SlideModel] = []
|
||||
|
||||
slide_layout_indices = presentation_structure.slides
|
||||
slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices]
|
||||
|
||||
# Schedule slide content generation and asset fetching in batches of 10
|
||||
batch_size = 10
|
||||
for start in range(0, len(slide_layouts), batch_size):
|
||||
end = min(start + batch_size, len(slide_layouts))
|
||||
|
||||
print(f"Generating slides from {start} to {end}")
|
||||
|
||||
# Generate contents for this batch concurrently
|
||||
content_tasks = [
|
||||
get_slide_content_from_type_and_outline(
|
||||
slide_layouts[i],
|
||||
presentation_outlines.slides[i],
|
||||
request.language,
|
||||
request.tone.value,
|
||||
request.verbosity.value,
|
||||
request.instructions,
|
||||
)
|
||||
for i in range(start, end)
|
||||
]
|
||||
batch_contents: List[dict] = await asyncio.gather(*content_tasks)
|
||||
|
||||
# Build slides for this batch
|
||||
batch_slides: List[SlideModel] = []
|
||||
for offset, slide_content in enumerate(batch_contents):
|
||||
i = start + offset
|
||||
slide_layout = slide_layouts[i]
|
||||
slide = SlideModel(
|
||||
presentation=presentation_id,
|
||||
layout_group=layout_model.name,
|
||||
layout=slide_layout.id,
|
||||
index=i,
|
||||
speaker_note=slide_content.get("__speaker_note__"),
|
||||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
batch_slides.append(slide)
|
||||
|
||||
# Start asset fetch tasks for just-generated slides so they run while next batch is processed
|
||||
asset_tasks = [
|
||||
process_slide_and_fetch_assets(image_generation_service, slide)
|
||||
for slide in batch_slides
|
||||
]
|
||||
async_assets_generation_tasks.extend(asset_tasks)
|
||||
|
||||
if async_status:
|
||||
async_status.message = "Fetching assets for slides"
|
||||
async_status.updated_at = datetime.now()
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
# Run all asset tasks concurrently while batches may still be generating content
|
||||
generated_assets_list = await asyncio.gather(*async_assets_generation_tasks)
|
||||
generated_assets = []
|
||||
for assets_list in generated_assets_list:
|
||||
generated_assets.extend(assets_list)
|
||||
|
||||
# 8. Save PresentationModel and Slides
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slides)
|
||||
sql_session.add_all(generated_assets)
|
||||
await sql_session.commit()
|
||||
|
||||
if async_status:
|
||||
async_status.message = "Exporting presentation"
|
||||
async_status.updated_at = datetime.now()
|
||||
sql_session.add(async_status)
|
||||
|
||||
# 9. Export
|
||||
presentation_and_path = await export_presentation(
|
||||
presentation_id, presentation.title or str(uuid.uuid4()), request.export_as
|
||||
)
|
||||
|
||||
response = PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={presentation_id}",
|
||||
)
|
||||
|
||||
if async_status:
|
||||
async_status.message = "Presentation generation completed"
|
||||
async_status.status = "completed"
|
||||
async_status.data = response.model_dump(mode="json")
|
||||
async_status.updated_at = datetime.now()
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
# Triggering webhook on success
|
||||
CONCURRENT_SERVICE.run_task(
|
||||
None,
|
||||
WebhookService.send_webhook,
|
||||
WebhookEvent.PRESENTATION_GENERATION_COMPLETED,
|
||||
response.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
if not isinstance(e, HTTPException):
|
||||
traceback.print_exc()
|
||||
e = HTTPException(status_code=500, detail="Presentation generation failed")
|
||||
|
||||
api_error_model = APIErrorModel.from_exception(e)
|
||||
|
||||
# Triggering webhook on failure
|
||||
CONCURRENT_SERVICE.run_task(
|
||||
None,
|
||||
WebhookService.send_webhook,
|
||||
WebhookEvent.PRESENTATION_GENERATION_FAILED,
|
||||
api_error_model.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
if async_status:
|
||||
async_status.status = "error"
|
||||
async_status.message = "Presentation generation failed"
|
||||
async_status.updated_at = datetime.now()
|
||||
async_status.error = api_error_model.model_dump(mode="json")
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath)
|
||||
async def generate_presentation_sync(
|
||||
request: GeneratePresentationRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
try:
|
||||
(presentation_id,) = await check_if_api_request_is_valid(request, sql_session)
|
||||
return await generate_presentation_handler(
|
||||
request, presentation_id, None, sql_session
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail="Presentation generation failed")
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post(
|
||||
"/generate/async", response_model=AsyncPresentationGenerationTaskModel
|
||||
)
|
||||
async def generate_presentation_async(
|
||||
request: GeneratePresentationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
try:
|
||||
(presentation_id,) = await check_if_api_request_is_valid(request, sql_session)
|
||||
|
||||
async_status = AsyncPresentationGenerationTaskModel(
|
||||
status="pending",
|
||||
message="Queued for generation",
|
||||
data=None,
|
||||
)
|
||||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
background_tasks.add_task(
|
||||
generate_presentation_handler,
|
||||
request,
|
||||
presentation_id,
|
||||
async_status=async_status,
|
||||
sql_session=sql_session,
|
||||
)
|
||||
return async_status
|
||||
|
||||
except Exception as e:
|
||||
if not isinstance(e, HTTPException):
|
||||
print(e)
|
||||
e = HTTPException(status_code=500, detail="Presentation generation failed")
|
||||
|
||||
raise e
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.get(
|
||||
"/status/{id}", response_model=AsyncPresentationGenerationTaskModel
|
||||
)
|
||||
async def check_async_presentation_generation_status(
|
||||
id: str = Path(description="ID of the presentation generation task"),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
status = await sql_session.get(AsyncPresentationGenerationTaskModel, id)
|
||||
if not status:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No presentation generation task found"
|
||||
)
|
||||
return status
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/edit", response_model=PresentationPathAndEditPath)
|
||||
|
|
|
|||
54
servers/fastapi/api/v1/webhook/router.py
Normal file
54
servers/fastapi/api/v1/webhook/router.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from typing import Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from enums.webhook_event import WebhookEvent
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from services.database import get_async_session
|
||||
|
||||
API_V1_WEBHOOK_ROUTER = APIRouter(prefix="/api/v1/webhook", tags=["Webhook"])
|
||||
|
||||
|
||||
class SubscribeToWebhookRequest(BaseModel):
|
||||
url: str = Field(description="The URL to send the webhook to")
|
||||
secret: Optional[str] = Field(None, description="The secret to use for the webhook")
|
||||
event: WebhookEvent = Field(description="The event to subscribe to")
|
||||
|
||||
|
||||
class SubscribeToWebhookResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
|
||||
|
||||
@API_V1_WEBHOOK_ROUTER.post(
|
||||
"/subscribe", response_model=SubscribeToWebhookResponse, status_code=201
|
||||
)
|
||||
async def subscribe_to_webhook(
|
||||
body: SubscribeToWebhookRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
webhook_subscription = WebhookSubscription(
|
||||
url=body.url,
|
||||
secret=body.secret,
|
||||
event=body.event,
|
||||
)
|
||||
sql_session.add(webhook_subscription)
|
||||
await sql_session.commit()
|
||||
return SubscribeToWebhookResponse(id=webhook_subscription.id)
|
||||
|
||||
|
||||
@API_V1_WEBHOOK_ROUTER.delete("/unsubscribe", status_code=204)
|
||||
async def unsubscribe_to_webhook(
|
||||
id: uuid.UUID = Body(
|
||||
embed=True, description="The ID of the webhook subscription to unsubscribe from"
|
||||
),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
|
||||
webhook_subscription = await sql_session.get(WebhookSubscription, id)
|
||||
if not webhook_subscription:
|
||||
raise HTTPException(404, "Webhook subscription not found")
|
||||
|
||||
await sql_session.delete(webhook_subscription)
|
||||
await sql_session.commit()
|
||||
1
servers/fastapi/constants/presentation.py
Normal file
1
servers/fastapi/constants/presentation.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
DEFAULT_TEMPLATES = ["general", "modern", "standard", "swift"]
|
||||
6
servers/fastapi/enums/webhook_event.py
Normal file
6
servers/fastapi/enums/webhook_event.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class WebhookEvent(str, Enum):
|
||||
PRESENTATION_GENERATION_COMPLETED = "presentation.generation.completed"
|
||||
PRESENTATION_GENERATION_FAILED = "presentation.generation.failed"
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
from datetime import datetime
|
||||
import secrets
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class AsyncPresentationGenerationTaskModel(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "async_presentation_generation_tasks"
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: f"task-{secrets.token_hex(32)}", primary_key=True
|
||||
)
|
||||
status: str
|
||||
message: Optional[str] = None
|
||||
error: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
data: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
19
servers/fastapi/models/sql/webhook_subscription.py
Normal file
19
servers/fastapi/models/sql/webhook_subscription.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from typing import Optional
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlmodel import Column, DateTime, Field, SQLModel
|
||||
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
||||
class WebhookSubscription(SQLModel, table=True):
|
||||
__tablename__ = "webhook_subscriptions"
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False),
|
||||
default_factory=get_current_utc_datetime,
|
||||
)
|
||||
url: str
|
||||
secret: Optional[str] = None
|
||||
event: str = Field(index=True)
|
||||
35
servers/fastapi/services/concurrent_service.py
Normal file
35
servers/fastapi/services/concurrent_service.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import asyncio
|
||||
from asyncio import Task
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
|
||||
class ConcurrentService:
|
||||
def __init__(self):
|
||||
self._background_tasks = set[Task]()
|
||||
|
||||
def run_task(
|
||||
self,
|
||||
delay: Optional[int],
|
||||
callable: Callable[..., Coroutine[Any, Any, Any]],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
async def wrapper():
|
||||
if delay:
|
||||
await asyncio.sleep(delay)
|
||||
await callable(*args, **kwargs)
|
||||
|
||||
task = asyncio.create_task(wrapper())
|
||||
|
||||
print(f"Running task: {task} - executing {callable.__name__}")
|
||||
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self.on_task_done)
|
||||
|
||||
def on_task_done(self, task: Task):
|
||||
print(f"Task done: {task}")
|
||||
|
||||
self._background_tasks.discard(task)
|
||||
|
||||
|
||||
CONCURRENT_SERVICE = ConcurrentService()
|
||||
55
servers/fastapi/services/webhook_service.py
Normal file
55
servers/fastapi/services/webhook_service.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
import asyncio
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from enums.webhook_event import WebhookEvent
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from services.database import get_async_session
|
||||
|
||||
|
||||
class WebhookService:
|
||||
|
||||
@classmethod
|
||||
async def send_webhook(cls, event: WebhookEvent, data: dict):
|
||||
async for sql_session in get_async_session():
|
||||
webhook_subscriptions = await sql_session.scalars(
|
||||
select(WebhookSubscription).where(
|
||||
WebhookSubscription.event == event.value
|
||||
)
|
||||
)
|
||||
webhook_subscriptions = list(webhook_subscriptions)
|
||||
if not webhook_subscriptions:
|
||||
return
|
||||
|
||||
async_tasks = []
|
||||
for webhook_subscription in webhook_subscriptions:
|
||||
async_tasks.append(
|
||||
cls.send_request_to_webhook(webhook_subscription, data)
|
||||
)
|
||||
|
||||
await asyncio.gather(*async_tasks)
|
||||
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def send_request_to_webhook(
|
||||
cls, subscription: WebhookSubscription, data: dict
|
||||
):
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if subscription.secret:
|
||||
headers["Authorization"] = f"Bearer {subscription.secret}"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
subscription.url,
|
||||
json=data,
|
||||
headers=headers,
|
||||
) as _:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error sending request to webhook {subscription.id}: {e}")
|
||||
pass
|
||||
Loading…
Add table
Reference in a new issue