diff --git a/servers/fastapi/api/main.py b/servers/fastapi/api/main.py index 80eea709..742ef913 100644 --- a/servers/fastapi/api/main.py +++ b/servers/fastapi/api/main.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from api.lifespan import app_lifespan from api.middlewares import UserConfigEnvUpdateMiddleware from api.v1.ppt.router import API_V1_PPT_ROUTER +from api.v1.webhook.router import API_V1_WEBHOOK_ROUTER app = FastAPI(lifespan=app_lifespan) @@ -10,6 +11,7 @@ app = FastAPI(lifespan=app_lifespan) # Routers app.include_router(API_V1_PPT_ROUTER) +app.include_router(API_V1_WEBHOOK_ROUTER) # Middlewares origins = ["*"] diff --git a/servers/fastapi/api/v1/webhook/router.py b/servers/fastapi/api/v1/webhook/router.py new file mode 100644 index 00000000..44b5b635 --- /dev/null +++ b/servers/fastapi/api/v1/webhook/router.py @@ -0,0 +1,54 @@ +from typing import Optional +import uuid +from fastapi import APIRouter, Body, Depends, HTTPException, Path +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from enums.webhook_event import WebhookEvent +from models.sql.webhook_subscription import WebhookSubscription +from services.database import get_async_session + +API_V1_WEBHOOK_ROUTER = APIRouter(prefix="/api/v1/webhook", tags=["Webhook"]) + + +class SubscribeToWebhookRequest(BaseModel): + url: str = Field(description="The URL to send the webhook to") + secret: Optional[str] = Field(None, description="The secret to use for the webhook") + event: WebhookEvent = Field(description="The event to subscribe to") + + +class SubscribeToWebhookResponse(BaseModel): + id: uuid.UUID + + +@API_V1_WEBHOOK_ROUTER.post( + "/subscribe", response_model=SubscribeToWebhookResponse, status_code=201 +) +async def subscribe_to_webhook( + body: SubscribeToWebhookRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + webhook_subscription = WebhookSubscription( + url=body.url, + secret=body.secret, + event=body.event, + ) + sql_session.add(webhook_subscription) + await sql_session.commit() + return SubscribeToWebhookResponse(id=webhook_subscription.id) + + +@API_V1_WEBHOOK_ROUTER.delete("/unsubscribe", status_code=204) +async def unsubscribe_to_webhook( + id: uuid.UUID = Body( + embed=True, description="The ID of the webhook subscription to unsubscribe from" + ), + sql_session: AsyncSession = Depends(get_async_session), +): + + webhook_subscription = await sql_session.get(WebhookSubscription, id) + if not webhook_subscription: + raise HTTPException(404, "Webhook subscription not found") + + await sql_session.delete(webhook_subscription) + await sql_session.commit() diff --git a/servers/fastapi/enums/webhook_event.py b/servers/fastapi/enums/webhook_event.py new file mode 100644 index 00000000..ce1ca4ac --- /dev/null +++ b/servers/fastapi/enums/webhook_event.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class WebhookEvent(str, Enum): + PRESENTATION_GENERATION_COMPLETED = "presentation.generation.completed" + PRESENTATION_GENERATION_FAILED = "presentation.generation.failed" diff --git a/servers/fastapi/models/sql/webhook_subscription.py b/servers/fastapi/models/sql/webhook_subscription.py new file mode 100644 index 00000000..530a1ad8 --- /dev/null +++ b/servers/fastapi/models/sql/webhook_subscription.py @@ -0,0 +1,19 @@ +from typing import Optional +import uuid +from datetime import datetime +from sqlmodel import Column, DateTime, Field, SQLModel + +from utils.datetime_utils import get_current_utc_datetime + + +class WebhookSubscription(SQLModel, table=True): + __tablename__ = "webhook_subscriptions" + + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + created_at: datetime = Field( + sa_column=Column(DateTime(timezone=True), nullable=False), + default_factory=get_current_utc_datetime, + ) + url: str + secret: Optional[str] = None + event: str = Field(index=True) diff --git a/servers/fastapi/services/concurrent_service.py b/servers/fastapi/services/concurrent_service.py new file mode 100644 index 00000000..e2c922f8 --- /dev/null +++ b/servers/fastapi/services/concurrent_service.py @@ -0,0 +1,35 @@ +import asyncio +from asyncio import Task +from typing import Any, Callable, Coroutine, Optional + + +class ConcurrentService: + def __init__(self): + self._background_tasks = set[Task]() + + def run_task( + self, + delay: Optional[int], + callable: Callable[..., Coroutine[Any, Any, Any]], + *args, + **kwargs, + ): + async def wrapper(): + if delay: + await asyncio.sleep(delay) + await callable(*args, **kwargs) + + task = asyncio.create_task(wrapper()) + + print(f"Running task: {task} - executing {callable.__name__}") + + self._background_tasks.add(task) + task.add_done_callback(self.on_task_done) + + def on_task_done(self, task: Task): + print(f"Task done: {task}") + + self._background_tasks.discard(task) + + +CONCURRENT_SERVICE = ConcurrentService() diff --git a/servers/fastapi/services/webhook_service.py b/servers/fastapi/services/webhook_service.py new file mode 100644 index 00000000..4525e15f --- /dev/null +++ b/servers/fastapi/services/webhook_service.py @@ -0,0 +1,55 @@ +import asyncio +import aiohttp +from sqlmodel import select +from enums.webhook_event import WebhookEvent +from models.sql.webhook_subscription import WebhookSubscription +from services.database import get_async_session + + +class WebhookService: + + @classmethod + async def send_webhook(cls, event: WebhookEvent, data: dict): + async for sql_session in get_async_session(): + webhook_subscriptions = await sql_session.scalars( + select(WebhookSubscription).where( + WebhookSubscription.event == event.value + ) + ) + webhook_subscriptions = list(webhook_subscriptions) + if not webhook_subscriptions: + return + + async_tasks = [] + for webhook_subscription in webhook_subscriptions: + async_tasks.append( + cls.send_request_to_webhook(webhook_subscription, data) + ) + + await asyncio.gather(*async_tasks) + + break + + @classmethod + async def send_request_to_webhook( + cls, subscription: WebhookSubscription, data: dict + ): + + headers = { + "Content-Type": "application/json", + } + if subscription.secret: + headers["Authorization"] = f"Bearer {subscription.secret}" + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + subscription.url, + json=data, + headers=headers, + ) as _: + pass + + except Exception as e: + print(f"Error sending request to webhook {subscription.id}: {e}") + pass