diff --git a/servers/fastapi/api/main.py b/servers/fastapi/api/main.py index ae866738..d207a76e 100644 --- a/servers/fastapi/api/main.py +++ b/servers/fastapi/api/main.py @@ -2,8 +2,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from api.lifespan import app_lifespan from api.middlewares import UserConfigEnvUpdateMiddleware -from api.v1.mock.router import API_V1_MOCK_ROUTER from api.v1.ppt.router import API_V1_PPT_ROUTER +from api.v1.webhook.router import API_V1_WEBHOOK_ROUTER +from api.v1.mock.router import API_V1_MOCK_ROUTER app = FastAPI(lifespan=app_lifespan) @@ -11,6 +12,7 @@ app = FastAPI(lifespan=app_lifespan) # Routers app.include_router(API_V1_PPT_ROUTER) +app.include_router(API_V1_WEBHOOK_ROUTER) app.include_router(API_V1_MOCK_ROUTER) # Middlewares diff --git a/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/servers/fastapi/api/v1/ppt/endpoints/presentation.py index 4cd92cb3..11f098b9 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/presentation.py +++ b/servers/fastapi/api/v1/ppt/endpoints/presentation.py @@ -1,16 +1,20 @@ import asyncio +from datetime import datetime import json import math import os import random import traceback -from typing import Annotated, List, Literal, Optional +from typing import Annotated, List, Literal, Optional, Tuple import dirtyjson -from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Path from fastapi.responses import StreamingResponse from sqlalchemy import delete from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import select +from constants.presentation import DEFAULT_TEMPLATES +from enums.webhook_event import WebhookEvent +from models.api_error_model import APIErrorModel from models.generate_presentation_request import GeneratePresentationRequest from models.presentation_and_path import PresentationPathAndEditPath from models.presentation_from_template import EditPresentationRequest @@ -26,8 +30,10 @@ from models.presentation_structure_model import PresentationStructureModel from models.presentation_with_slides import ( PresentationWithSlides, ) +from models.sql.template import TemplateModel from services.documents_loader import DocumentsLoader +from services.webhook_service import WebhookService from utils.get_layout_by_name import get_layout_by_name from services.image_generation_service import ImageGenerationService from utils.dict_utils import deep_update @@ -38,8 +44,12 @@ from models.sse_response import SSECompleteResponse, SSEErrorResponse, SSERespon from services.database import get_async_session from services.temp_file_service import TEMP_FILE_SERVICE +from services.concurrent_service import CONCURRENT_SERVICE from models.sql.presentation import PresentationModel from services.pptx_presentation_creator import PptxPresentationCreator +from models.sql.async_presentation_generation_status import ( + AsyncPresentationGenerationTaskModel, +) from utils.asset_directory_utils import get_exports_directory, get_images_directory from utils.llm_calls.generate_presentation_structure import ( generate_presentation_structure, @@ -434,237 +444,427 @@ async def export_presentation_as_pptx_or_pdf( ) -@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath) -async def generate_presentation_api( +async def check_if_api_request_is_valid( request: GeneratePresentationRequest, sql_session: AsyncSession = Depends(get_async_session), -): +) -> Tuple[uuid.UUID,]: presentation_id = uuid.uuid4() + print(f"Presentation ID: {presentation_id}") - using_slides_markdown = False + # Making sure either content, slides markdown or files is provided + if not (request.content or request.slides_markdown or request.files): + raise HTTPException( + status_code=400, + detail="Either content or slides markdown or files is required to generate presentation", + ) - if request.slides_markdown: - using_slides_markdown = True - request.n_slides = len(request.slides_markdown) + # Making sure number of slides is greater than 0 + if request.n_slides <= 0: + raise HTTPException( + status_code=400, + detail="Number of slides must be greater than 0", + ) - if not using_slides_markdown: - additional_context = "" - - if request.files: - documents_loader = DocumentsLoader(file_paths=request.files) - await documents_loader.load_documents() - documents = documents_loader.documents - if documents: - additional_context = "\n\n".join(documents) - - # Finding number of slides to generate by considering table of contents - n_slides_to_generate = request.n_slides - if request.include_table_of_contents: - needed_toc_count = math.ceil( - ( - (request.n_slides - 1) - if request.include_title_slide - else request.n_slides - ) - / 10 - ) - n_slides_to_generate -= math.ceil( - (request.n_slides - needed_toc_count) / 10 - ) - - presentation_outlines_text = "" - async for chunk in generate_ppt_outline( - request.content, - n_slides_to_generate, - request.language, - additional_context, - request.tone.value, - request.verbosity.value, - request.instructions, - request.include_title_slide, - request.web_search, - ): - - if isinstance(chunk, HTTPException): - raise chunk - - presentation_outlines_text += chunk - - try: - presentation_outlines_json = dict( - dirtyjson.loads(presentation_outlines_text) - ) - except Exception as e: - traceback.print_exc() + # Checking if template is valid + if request.template not in DEFAULT_TEMPLATES: + request.template = request.template.lower() + if not request.template.startswith("custom-"): raise HTTPException( status_code=400, - detail="Failed to generate presentation outlines. Please try again.", + detail="Template not found. Please use a valid template.", ) - presentation_outlines = PresentationOutlineModel(**presentation_outlines_json) - total_outlines = n_slides_to_generate - - else: - # Setting outlines to slides markdown - presentation_outlines = PresentationOutlineModel( - slides=[ - SlideOutlineModel(content=slide) for slide in request.slides_markdown - ] - ) - total_outlines = len(request.slides_markdown) - - print("-" * 40) - print(f"Generated {total_outlines} outlines for the presentation") - - # Parse Layouts - layout_model = await get_layout_by_name(request.template) - total_slide_layouts = len(layout_model.slides) - - # Generate Structure - if layout_model.ordered: - presentation_structure = layout_model.to_presentation_structure() - else: - presentation_structure: PresentationStructureModel = ( - await generate_presentation_structure( - presentation_outlines, - layout_model, - request.instructions, - using_slides_markdown, + template_id = request.template.replace("custom-", "") + try: + template = await sql_session.get(TemplateModel, uuid.UUID(template_id)) + if not template: + raise Exception() + except Exception as e: + raise HTTPException( + status_code=400, + detail="Template not found. Please use a valid template.", ) - ) - presentation_structure.slides = presentation_structure.slides[:total_outlines] - for index in range(total_outlines): - random_slide_index = random.randint(0, total_slide_layouts - 1) - if index >= total_outlines: - presentation_structure.slides.append(random_slide_index) - continue - if presentation_structure.slides[index] >= total_slide_layouts: - presentation_structure.slides[index] = random_slide_index + return presentation_id - # Injecting table of contents to the presentation structure and outlines - if request.include_table_of_contents and not using_slides_markdown: - n_toc_slides = request.n_slides - total_outlines - toc_slide_layout_index = select_toc_or_list_slide_layout_index(layout_model) - if toc_slide_layout_index != -1: - outline_index = 1 if request.include_title_slide else 0 - for i in range(n_toc_slides): - outlines_to = outline_index + 10 - if total_outlines == outlines_to: - outlines_to -= 1 - presentation_structure.slides.insert( - i + 1 if request.include_title_slide else i, - toc_slide_layout_index, - ) - toc_outline = f"Table of Contents\n\n" +async def generate_presentation_handler( + request: GeneratePresentationRequest, + presentation_id: uuid.UUID, + async_status: Optional[AsyncPresentationGenerationTaskModel], + sql_session: AsyncSession = Depends(get_async_session), +): + try: + using_slides_markdown = False - for outline in presentation_outlines.slides[outline_index:outlines_to]: - page_number = ( - outline_index - i + n_toc_slides + 1 + if request.slides_markdown: + using_slides_markdown = True + request.n_slides = len(request.slides_markdown) + + if not using_slides_markdown: + additional_context = "" + + # Updating async status + if async_status: + async_status.message = "Generating presentation outlines" + async_status.updated_at = datetime.now() + sql_session.add(async_status) + await sql_session.commit() + + if request.files: + documents_loader = DocumentsLoader(file_paths=request.files) + await documents_loader.load_documents() + documents = documents_loader.documents + if documents: + additional_context = "\n\n".join(documents) + + # Finding number of slides to generate by considering table of contents + n_slides_to_generate = request.n_slides + if request.include_table_of_contents: + needed_toc_count = math.ceil( + ( + (request.n_slides - 1) if request.include_title_slide - else outline_index - i + n_toc_slides + else request.n_slides ) - toc_outline += f"Slide page number: {page_number}\n Slide Content: {outline.content[:100]}\n\n" - outline_index += 1 - - outline_index += 1 - - presentation_outlines.slides.insert( - i + 1 if request.include_title_slide else i, - SlideOutlineModel( - content=toc_outline, - ), + / 10 + ) + n_slides_to_generate -= math.ceil( + (request.n_slides - needed_toc_count) / 10 ) - # Create PresentationModel - presentation = PresentationModel( - id=presentation_id, - content=request.content, - n_slides=request.n_slides, - language=request.language, - title=get_presentation_title_from_outlines(presentation_outlines), - outlines=presentation_outlines.model_dump(), - layout=layout_model.model_dump(), - structure=presentation_structure.model_dump(), - tone=request.tone, - verbosity=request.verbosity, - instructions=request.instructions, - ) - - image_generation_service = ImageGenerationService(get_images_directory()) - async_assets_generation_tasks = [] - - # 7. Generate slide content concurrently (batched), then build slides and fetch assets - slides: List[SlideModel] = [] - - slide_layout_indices = presentation_structure.slides - slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices] - - # Schedule slide content generation and asset fetching in batches of 10 - batch_size = 10 - for start in range(0, len(slide_layouts), batch_size): - end = min(start + batch_size, len(slide_layouts)) - - print(f"Generating slides from {start} to {end}") - - # Generate contents for this batch concurrently - content_tasks = [ - get_slide_content_from_type_and_outline( - slide_layouts[i], - presentation_outlines.slides[i], + presentation_outlines_text = "" + async for chunk in generate_ppt_outline( + request.content, + n_slides_to_generate, request.language, + additional_context, request.tone.value, request.verbosity.value, request.instructions, + request.include_title_slide, + request.web_search, + ): + + if isinstance(chunk, HTTPException): + raise chunk + + presentation_outlines_text += chunk + + try: + presentation_outlines_json = dict( + dirtyjson.loads(presentation_outlines_text) + ) + except Exception as e: + traceback.print_exc() + raise HTTPException( + status_code=400, + detail="Failed to generate presentation outlines. Please try again.", + ) + presentation_outlines = PresentationOutlineModel( + **presentation_outlines_json ) - for i in range(start, end) - ] - batch_contents: List[dict] = await asyncio.gather(*content_tasks) + total_outlines = n_slides_to_generate - # Build slides for this batch - batch_slides: List[SlideModel] = [] - for offset, slide_content in enumerate(batch_contents): - i = start + offset - slide_layout = slide_layouts[i] - slide = SlideModel( - presentation=presentation_id, - layout_group=layout_model.name, - layout=slide_layout.id, - index=i, - speaker_note=slide_content.get("__speaker_note__"), - content=slide_content, + else: + # Setting outlines to slides markdown + presentation_outlines = PresentationOutlineModel( + slides=[ + SlideOutlineModel(content=slide) + for slide in request.slides_markdown + ] ) - slides.append(slide) - batch_slides.append(slide) + total_outlines = len(request.slides_markdown) - # Start asset fetch tasks for just-generated slides so they run while next batch is processed - asset_tasks = [ - process_slide_and_fetch_assets(image_generation_service, slide) - for slide in batch_slides - ] - async_assets_generation_tasks.extend(asset_tasks) + # Updating async status + if async_status: + async_status.message = f"Selecting layout for each slide" + async_status.updated_at = datetime.now() + sql_session.add(async_status) + await sql_session.commit() - # Run all asset tasks concurrently while batches may still be generating content - generated_assets_list = await asyncio.gather(*async_assets_generation_tasks) - generated_assets = [] - for assets_list in generated_assets_list: - generated_assets.extend(assets_list) + print("-" * 40) + print(f"Generated {total_outlines} outlines for the presentation") - # 8. Save PresentationModel and Slides - sql_session.add(presentation) - sql_session.add_all(slides) - sql_session.add_all(generated_assets) - await sql_session.commit() + # Parse Layouts + layout_model = await get_layout_by_name(request.template) + total_slide_layouts = len(layout_model.slides) - # 9. Export - presentation_and_path = await export_presentation( - presentation_id, presentation.title or str(uuid.uuid4()), request.export_as - ) + # Generate Structure + if layout_model.ordered: + presentation_structure = layout_model.to_presentation_structure() + else: + presentation_structure: PresentationStructureModel = ( + await generate_presentation_structure( + presentation_outlines, + layout_model, + request.instructions, + using_slides_markdown, + ) + ) - return PresentationPathAndEditPath( - **presentation_and_path.model_dump(), - edit_path=f"/presentation?id={presentation_id}", - ) + presentation_structure.slides = presentation_structure.slides[:total_outlines] + for index in range(total_outlines): + random_slide_index = random.randint(0, total_slide_layouts - 1) + if index >= total_outlines: + presentation_structure.slides.append(random_slide_index) + continue + if presentation_structure.slides[index] >= total_slide_layouts: + presentation_structure.slides[index] = random_slide_index + + # Injecting table of contents to the presentation structure and outlines + if request.include_table_of_contents and not using_slides_markdown: + n_toc_slides = request.n_slides - total_outlines + toc_slide_layout_index = select_toc_or_list_slide_layout_index(layout_model) + if toc_slide_layout_index != -1: + outline_index = 1 if request.include_title_slide else 0 + for i in range(n_toc_slides): + outlines_to = outline_index + 10 + if total_outlines == outlines_to: + outlines_to -= 1 + + presentation_structure.slides.insert( + i + 1 if request.include_title_slide else i, + toc_slide_layout_index, + ) + toc_outline = f"Table of Contents\n\n" + + for outline in presentation_outlines.slides[ + outline_index:outlines_to + ]: + page_number = ( + outline_index - i + n_toc_slides + 1 + if request.include_title_slide + else outline_index - i + n_toc_slides + ) + toc_outline += f"Slide page number: {page_number}\n Slide Content: {outline.content[:100]}\n\n" + outline_index += 1 + + outline_index += 1 + + presentation_outlines.slides.insert( + i + 1 if request.include_title_slide else i, + SlideOutlineModel( + content=toc_outline, + ), + ) + + # Create PresentationModel + presentation = PresentationModel( + id=presentation_id, + content=request.content, + n_slides=request.n_slides, + language=request.language, + title=get_presentation_title_from_outlines(presentation_outlines), + outlines=presentation_outlines.model_dump(), + layout=layout_model.model_dump(), + structure=presentation_structure.model_dump(), + tone=request.tone.value, + verbosity=request.verbosity.value, + instructions=request.instructions, + ) + + # Updating async status + if async_status: + async_status.message = "Generating slides" + async_status.updated_at = datetime.now() + sql_session.add(async_status) + await sql_session.commit() + + image_generation_service = ImageGenerationService(get_images_directory()) + async_assets_generation_tasks = [] + + # 7. Generate slide content concurrently (batched), then build slides and fetch assets + slides: List[SlideModel] = [] + + slide_layout_indices = presentation_structure.slides + slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices] + + # Schedule slide content generation and asset fetching in batches of 10 + batch_size = 10 + for start in range(0, len(slide_layouts), batch_size): + end = min(start + batch_size, len(slide_layouts)) + + print(f"Generating slides from {start} to {end}") + + # Generate contents for this batch concurrently + content_tasks = [ + get_slide_content_from_type_and_outline( + slide_layouts[i], + presentation_outlines.slides[i], + request.language, + request.tone.value, + request.verbosity.value, + request.instructions, + ) + for i in range(start, end) + ] + batch_contents: List[dict] = await asyncio.gather(*content_tasks) + + # Build slides for this batch + batch_slides: List[SlideModel] = [] + for offset, slide_content in enumerate(batch_contents): + i = start + offset + slide_layout = slide_layouts[i] + slide = SlideModel( + presentation=presentation_id, + layout_group=layout_model.name, + layout=slide_layout.id, + index=i, + speaker_note=slide_content.get("__speaker_note__"), + content=slide_content, + ) + slides.append(slide) + batch_slides.append(slide) + + # Start asset fetch tasks for just-generated slides so they run while next batch is processed + asset_tasks = [ + process_slide_and_fetch_assets(image_generation_service, slide) + for slide in batch_slides + ] + async_assets_generation_tasks.extend(asset_tasks) + + if async_status: + async_status.message = "Fetching assets for slides" + async_status.updated_at = datetime.now() + sql_session.add(async_status) + await sql_session.commit() + + # Run all asset tasks concurrently while batches may still be generating content + generated_assets_list = await asyncio.gather(*async_assets_generation_tasks) + generated_assets = [] + for assets_list in generated_assets_list: + generated_assets.extend(assets_list) + + # 8. Save PresentationModel and Slides + sql_session.add(presentation) + sql_session.add_all(slides) + sql_session.add_all(generated_assets) + await sql_session.commit() + + if async_status: + async_status.message = "Exporting presentation" + async_status.updated_at = datetime.now() + sql_session.add(async_status) + + # 9. Export + presentation_and_path = await export_presentation( + presentation_id, presentation.title or str(uuid.uuid4()), request.export_as + ) + + response = PresentationPathAndEditPath( + **presentation_and_path.model_dump(), + edit_path=f"/presentation?id={presentation_id}", + ) + + if async_status: + async_status.message = "Presentation generation completed" + async_status.status = "completed" + async_status.data = response.model_dump(mode="json") + async_status.updated_at = datetime.now() + sql_session.add(async_status) + await sql_session.commit() + + # Triggering webhook on success + CONCURRENT_SERVICE.run_task( + None, + WebhookService.send_webhook, + WebhookEvent.PRESENTATION_GENERATION_COMPLETED, + response.model_dump(mode="json"), + ) + + return response + + except Exception as e: + if not isinstance(e, HTTPException): + traceback.print_exc() + e = HTTPException(status_code=500, detail="Presentation generation failed") + + api_error_model = APIErrorModel.from_exception(e) + + # Triggering webhook on failure + CONCURRENT_SERVICE.run_task( + None, + WebhookService.send_webhook, + WebhookEvent.PRESENTATION_GENERATION_FAILED, + api_error_model.model_dump(mode="json"), + ) + + if async_status: + async_status.status = "error" + async_status.message = "Presentation generation failed" + async_status.updated_at = datetime.now() + async_status.error = api_error_model.model_dump(mode="json") + sql_session.add(async_status) + await sql_session.commit() + + else: + raise e + + +@PRESENTATION_ROUTER.post("/generate", response_model=PresentationPathAndEditPath) +async def generate_presentation_sync( + request: GeneratePresentationRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + try: + (presentation_id,) = await check_if_api_request_is_valid(request, sql_session) + return await generate_presentation_handler( + request, presentation_id, None, sql_session + ) + except Exception as e: + traceback.print_exc() + raise HTTPException(status_code=500, detail="Presentation generation failed") + + +@PRESENTATION_ROUTER.post( + "/generate/async", response_model=AsyncPresentationGenerationTaskModel +) +async def generate_presentation_async( + request: GeneratePresentationRequest, + background_tasks: BackgroundTasks, + sql_session: AsyncSession = Depends(get_async_session), +): + try: + (presentation_id,) = await check_if_api_request_is_valid(request, sql_session) + + async_status = AsyncPresentationGenerationTaskModel( + status="pending", + message="Queued for generation", + data=None, + ) + sql_session.add(async_status) + await sql_session.commit() + + background_tasks.add_task( + generate_presentation_handler, + request, + presentation_id, + async_status=async_status, + sql_session=sql_session, + ) + return async_status + + except Exception as e: + if not isinstance(e, HTTPException): + print(e) + e = HTTPException(status_code=500, detail="Presentation generation failed") + + raise e + + +@PRESENTATION_ROUTER.get( + "/status/{id}", response_model=AsyncPresentationGenerationTaskModel +) +async def check_async_presentation_generation_status( + id: str = Path(description="ID of the presentation generation task"), + sql_session: AsyncSession = Depends(get_async_session), +): + status = await sql_session.get(AsyncPresentationGenerationTaskModel, id) + if not status: + raise HTTPException( + status_code=404, detail="No presentation generation task found" + ) + return status @PRESENTATION_ROUTER.post("/edit", response_model=PresentationPathAndEditPath) diff --git a/servers/fastapi/api/v1/webhook/router.py b/servers/fastapi/api/v1/webhook/router.py new file mode 100644 index 00000000..44b5b635 --- /dev/null +++ b/servers/fastapi/api/v1/webhook/router.py @@ -0,0 +1,54 @@ +from typing import Optional +import uuid +from fastapi import APIRouter, Body, Depends, HTTPException, Path +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from enums.webhook_event import WebhookEvent +from models.sql.webhook_subscription import WebhookSubscription +from services.database import get_async_session + +API_V1_WEBHOOK_ROUTER = APIRouter(prefix="/api/v1/webhook", tags=["Webhook"]) + + +class SubscribeToWebhookRequest(BaseModel): + url: str = Field(description="The URL to send the webhook to") + secret: Optional[str] = Field(None, description="The secret to use for the webhook") + event: WebhookEvent = Field(description="The event to subscribe to") + + +class SubscribeToWebhookResponse(BaseModel): + id: uuid.UUID + + +@API_V1_WEBHOOK_ROUTER.post( + "/subscribe", response_model=SubscribeToWebhookResponse, status_code=201 +) +async def subscribe_to_webhook( + body: SubscribeToWebhookRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + webhook_subscription = WebhookSubscription( + url=body.url, + secret=body.secret, + event=body.event, + ) + sql_session.add(webhook_subscription) + await sql_session.commit() + return SubscribeToWebhookResponse(id=webhook_subscription.id) + + +@API_V1_WEBHOOK_ROUTER.delete("/unsubscribe", status_code=204) +async def unsubscribe_to_webhook( + id: uuid.UUID = Body( + embed=True, description="The ID of the webhook subscription to unsubscribe from" + ), + sql_session: AsyncSession = Depends(get_async_session), +): + + webhook_subscription = await sql_session.get(WebhookSubscription, id) + if not webhook_subscription: + raise HTTPException(404, "Webhook subscription not found") + + await sql_session.delete(webhook_subscription) + await sql_session.commit() diff --git a/servers/fastapi/constants/presentation.py b/servers/fastapi/constants/presentation.py new file mode 100644 index 00000000..22d6b012 --- /dev/null +++ b/servers/fastapi/constants/presentation.py @@ -0,0 +1 @@ +DEFAULT_TEMPLATES = ["general", "modern", "standard", "swift"] diff --git a/servers/fastapi/enums/webhook_event.py b/servers/fastapi/enums/webhook_event.py new file mode 100644 index 00000000..ce1ca4ac --- /dev/null +++ b/servers/fastapi/enums/webhook_event.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class WebhookEvent(str, Enum): + PRESENTATION_GENERATION_COMPLETED = "presentation.generation.completed" + PRESENTATION_GENERATION_FAILED = "presentation.generation.failed" diff --git a/servers/fastapi/models/sql/async_presentation_generation_status.py b/servers/fastapi/models/sql/async_presentation_generation_status.py new file mode 100644 index 00000000..0431b049 --- /dev/null +++ b/servers/fastapi/models/sql/async_presentation_generation_status.py @@ -0,0 +1,22 @@ +from datetime import datetime +import secrets +from typing import Optional +import uuid + +from sqlalchemy import JSON, Column +from sqlmodel import Field, SQLModel + + +class AsyncPresentationGenerationTaskModel(SQLModel, table=True): + + __tablename__ = "async_presentation_generation_tasks" + + id: str = Field( + default_factory=lambda: f"task-{secrets.token_hex(32)}", primary_key=True + ) + status: str + message: Optional[str] = None + error: Optional[dict] = Field(sa_column=Column(JSON), default=None) + created_at: datetime = Field(default_factory=datetime.now) + updated_at: datetime = Field(default_factory=datetime.now) + data: Optional[dict] = Field(sa_column=Column(JSON), default=None) diff --git a/servers/fastapi/models/sql/webhook_subscription.py b/servers/fastapi/models/sql/webhook_subscription.py new file mode 100644 index 00000000..530a1ad8 --- /dev/null +++ b/servers/fastapi/models/sql/webhook_subscription.py @@ -0,0 +1,19 @@ +from typing import Optional +import uuid +from datetime import datetime +from sqlmodel import Column, DateTime, Field, SQLModel + +from utils.datetime_utils import get_current_utc_datetime + + +class WebhookSubscription(SQLModel, table=True): + __tablename__ = "webhook_subscriptions" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False), + default_factory=get_current_utc_datetime, + ) + url: str + secret: Optional[str] = None + event: str = Field(index=True) diff --git a/servers/fastapi/services/concurrent_service.py b/servers/fastapi/services/concurrent_service.py new file mode 100644 index 00000000..e2c922f8 --- /dev/null +++ b/servers/fastapi/services/concurrent_service.py @@ -0,0 +1,35 @@ +import asyncio +from asyncio import Task +from typing import Any, Callable, Coroutine, Optional + + +class ConcurrentService: + def __init__(self): + self._background_tasks = set[Task]() + + def run_task( + self, + delay: Optional[int], + callable: Callable[..., Coroutine[Any, Any, Any]], + *args, + **kwargs, + ): + async def wrapper(): + if delay: + await asyncio.sleep(delay) + await callable(*args, **kwargs) + + task = asyncio.create_task(wrapper()) + + print(f"Running task: {task} - executing {callable.__name__}") + + self._background_tasks.add(task) + task.add_done_callback(self.on_task_done) + + def on_task_done(self, task: Task): + print(f"Task done: {task}") + + self._background_tasks.discard(task) + + +CONCURRENT_SERVICE = ConcurrentService() diff --git a/servers/fastapi/services/webhook_service.py b/servers/fastapi/services/webhook_service.py new file mode 100644 index 00000000..4525e15f --- /dev/null +++ b/servers/fastapi/services/webhook_service.py @@ -0,0 +1,55 @@ +import asyncio +import aiohttp +from sqlmodel import select +from enums.webhook_event import WebhookEvent +from models.sql.webhook_subscription import WebhookSubscription +from services.database import get_async_session + + +class WebhookService: + + @classmethod + async def send_webhook(cls, event: WebhookEvent, data: dict): + async for sql_session in get_async_session(): + webhook_subscriptions = await sql_session.scalars( + select(WebhookSubscription).where( + WebhookSubscription.event == event.value + ) + ) + webhook_subscriptions = list(webhook_subscriptions) + if not webhook_subscriptions: + return + + async_tasks = [] + for webhook_subscription in webhook_subscriptions: + async_tasks.append( + cls.send_request_to_webhook(webhook_subscription, data) + ) + + await asyncio.gather(*async_tasks) + + break + + @classmethod + async def send_request_to_webhook( + cls, subscription: WebhookSubscription, data: dict + ): + + headers = { + "Content-Type": "application/json", + } + if subscription.secret: + headers["Authorization"] = f"Bearer {subscription.secret}" + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + subscription.url, + json=data, + headers=headers, + ) as _: + pass + + except Exception as e: + print(f"Error sending request to webhook {subscription.id}: {e}") + pass