diff --git a/backend/api/lifespan.py b/backend/api/lifespan.py index 86f5537..55194c2 100644 --- a/backend/api/lifespan.py +++ b/backend/api/lifespan.py @@ -4,6 +4,7 @@ import os from fastapi import FastAPI from services.database import create_db_and_tables +from services.redis_service import close_arq_pool from utils.get_env import get_app_data_directory_env from utils.model_availability import ( check_llm_and_image_provider_api_or_model_availability, @@ -21,3 +22,4 @@ async def app_lifespan(_: FastAPI): await create_db_and_tables() await check_llm_and_image_provider_api_or_model_availability() yield + await close_arq_pool() diff --git a/backend/api/main.py b/backend/api/main.py index 8c984dd..7f2dbcb 100644 --- a/backend/api/main.py +++ b/backend/api/main.py @@ -13,6 +13,7 @@ from api.v1.admin.clients_router import CLIENTS_ROUTER from api.v1.admin.audit_router import AUDIT_ROUTER from api.v1.admin.brand_config_router import BRAND_CONFIG_ROUTER from api.v1.admin.master_decks_router import MASTER_DECKS_ROUTER +from api.v1.ppt.endpoints.jobs import JOBS_ROUTER from api.middlewares.audit_middleware import AuditMiddleware @@ -31,6 +32,7 @@ ADMIN_ROUTER.include_router(MASTER_DECKS_ROUTER) app.include_router(AUTH_ROUTER) app.include_router(ADMIN_ROUTER) app.include_router(API_V1_PPT_ROUTER) +app.include_router(JOBS_ROUTER) app.include_router(API_V1_WEBHOOK_ROUTER) app.include_router(API_V1_MOCK_ROUTER) diff --git a/backend/api/v1/admin/master_decks_router.py b/backend/api/v1/admin/master_decks_router.py index f9a30f3..bf57503 100644 --- a/backend/api/v1/admin/master_decks_router.py +++ b/backend/api/v1/admin/master_decks_router.py @@ -114,11 +114,27 @@ async def upload_master_deck( await session.commit() await session.refresh(deck) - # Kick off async parsing - import asyncio - from services.master_deck_parser_service import parse_master_deck + # Kick off async parsing via ARQ (fallback to asyncio.create_task) + try: + from models.sql.job import JobModel + from 
services.redis_service import enqueue_job - asyncio.create_task(parse_master_deck(deck_id)) + job = JobModel( + user_id=admin.id, + client_id=client_id, + presentation_id=deck_id, # reuse field for deck_id + job_type="parse_master_deck", + status="queued", + progress=0, + progress_message="Queued for parsing", + ) + session.add(job) + await session.commit() + await enqueue_job("parse_master_deck_task", job_id=str(job.id)) + except Exception: + import asyncio + from services.master_deck_parser_service import parse_master_deck + asyncio.create_task(parse_master_deck(deck_id)) return { "id": str(deck.id), @@ -235,10 +251,26 @@ async def reparse_master_deck( deck.parse_status = "pending" await session.commit() - import asyncio - from services.master_deck_parser_service import parse_master_deck + try: + from models.sql.job import JobModel + from services.redis_service import enqueue_job - asyncio.create_task(parse_master_deck(deck_id)) + job = JobModel( + user_id=admin.id, + client_id=deck.client_id, + presentation_id=deck_id, + job_type="parse_master_deck", + status="queued", + progress=0, + progress_message="Queued for re-parsing", + ) + session.add(job) + await session.commit() + await enqueue_job("parse_master_deck_task", job_id=str(job.id)) + except Exception: + import asyncio + from services.master_deck_parser_service import parse_master_deck + asyncio.create_task(parse_master_deck(deck_id)) return {"ok": True, "parse_status": "pending"} diff --git a/backend/api/v1/ppt/endpoints/jobs.py b/backend/api/v1/ppt/endpoints/jobs.py new file mode 100644 index 0000000..f9a58b6 --- /dev/null +++ b/backend/api/v1/ppt/endpoints/jobs.py @@ -0,0 +1,151 @@ +"""Job status polling and SSE streaming endpoints.""" +import asyncio +import json +import uuid + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from models.sse_response import SSEResponse, SSEErrorResponse +from 
models.sql.job import JobModel +from models.sql.user import UserModel +from services.database import get_async_session +from services.redis_service import get_arq_pool +from utils.auth_dependencies import get_current_user + +JOBS_ROUTER = APIRouter(prefix="/api/v1/ppt", tags=["Jobs"]) + + +@JOBS_ROUTER.get("/jobs/{job_id}") +async def get_job_status( + job_id: uuid.UUID, + _current_user: UserModel = Depends(get_current_user), + session: AsyncSession = Depends(get_async_session), +): + """Poll job status.""" + job = await session.get(JobModel, job_id) + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + return { + "id": str(job.id), + "job_type": job.job_type, + "status": job.status, + "progress": job.progress, + "progress_message": job.progress_message, + "error_message": job.error_message, + "presentation_id": str(job.presentation_id) if job.presentation_id else None, + "created_at": job.created_at.isoformat() if job.created_at else None, + "started_at": job.started_at.isoformat() if job.started_at else None, + "completed_at": job.completed_at.isoformat() if job.completed_at else None, + } + + +@JOBS_ROUTER.get("/jobs/{job_id}/stream") +async def stream_job_progress( + job_id: uuid.UUID, + _current_user: UserModel = Depends(get_current_user), + session: AsyncSession = Depends(get_async_session), +): + """SSE stream of job progress events via Redis pub/sub.""" + job = await session.get(JobModel, job_id) + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + # If already completed/failed, return final event immediately + if job.status in ("completed", "failed"): + async def done_stream(): + data = { + "type": "progress", + "job_id": str(job.id), + "progress": job.progress, + "message": job.progress_message or job.status, + "status": job.status, + } + yield SSEResponse( + event="response", data=json.dumps(data) + ).to_string() + + return StreamingResponse(done_stream(), media_type="text/event-stream") + + async 
def progress_stream(): + try: + pool = await get_arq_pool() + pubsub = pool.pubsub() + channel = f"job:{job_id}:progress" + await pubsub.subscribe(channel) + + # Send initial status + yield SSEResponse( + event="response", + data=json.dumps({ + "type": "progress", + "job_id": str(job_id), + "progress": job.progress, + "message": job.progress_message or "Waiting", + "status": job.status, + }), + ).to_string() + + # Listen for updates with timeout + timeout = 600 # 10 minutes max + elapsed = 0 + while elapsed < timeout: + message = await pubsub.get_message( + ignore_subscribe_messages=True, timeout=2.0 + ) + if message and message["type"] == "message": + raw = message["data"] + if isinstance(raw, bytes): + raw = raw.decode() + payload = json.loads(raw) + yield SSEResponse( + event="response", + data=json.dumps({"type": "progress", **payload}), + ).to_string() + + # Stop streaming on terminal status + if payload.get("status") in ("completed", "failed"): + break + else: + elapsed += 2 + + await asyncio.sleep(0) + + await pubsub.unsubscribe(channel) + await pubsub.aclose() + + except Exception as e: + yield SSEErrorResponse(detail=str(e)[:200]).to_string() + + return StreamingResponse(progress_stream(), media_type="text/event-stream") + + +@JOBS_ROUTER.delete("/jobs/{job_id}") +async def cancel_job( + job_id: uuid.UUID, + _current_user: UserModel = Depends(get_current_user), + session: AsyncSession = Depends(get_async_session), +): + """Cancel a queued or processing job.""" + job = await session.get(JobModel, job_id) + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + if job.status in ("completed", "failed"): + raise HTTPException(status_code=409, detail="Job already finished") + + # Try to abort via ARQ + try: + pool = await get_arq_pool() + await pool.abort_job(str(job_id)) + except Exception: + pass # Best effort + + job.status = "failed" + job.error_message = "Cancelled by user" + job.progress_message = "Cancelled" + await 
session.commit() + + return {"ok": True, "status": "cancelled"} diff --git a/backend/api/v1/ppt/endpoints/presentation.py b/backend/api/v1/ppt/endpoints/presentation.py index f5f030f..ec5e72b 100644 --- a/backend/api/v1/ppt/endpoints/presentation.py +++ b/backend/api/v1/ppt/endpoints/presentation.py @@ -846,6 +846,18 @@ async def generate_presentation_async( try: (presentation_id,) = await check_if_api_request_is_valid(request, sql_session) + # Create a lightweight presentation record so the worker can load it + presentation = PresentationModel( + id=presentation_id, + content=request.content, + n_slides=request.n_slides, + language=request.language, + tone=request.tone.value, + verbosity=request.verbosity.value, + instructions=request.instructions, + ) + sql_session.add(presentation) + async_status = AsyncPresentationGenerationTaskModel( status="pending", message="Queued for generation", @@ -854,13 +866,39 @@ async def generate_presentation_async( sql_session.add(async_status) await sql_session.commit() - background_tasks.add_task( - generate_presentation_handler, - request, - presentation_id, - async_status=async_status, - sql_session=sql_session, - ) + # Try ARQ job queue first; fall back to BackgroundTasks + job_enqueued = False + try: + from models.sql.job import JobModel + from services.redis_service import enqueue_job + + job = JobModel( + user_id=_current_user.id, + client_id=getattr(_current_user, "default_client_id", _current_user.id), + presentation_id=presentation_id, + job_type="generate_presentation", + status="queued", + progress=0, + progress_message="Queued for generation", + ) + sql_session.add(job) + await sql_session.commit() + + await enqueue_job("generate_presentation_task", job_id=str(job.id)) + job_enqueued = True + except Exception: + # Redis/ARQ unavailable — fall back to in-process background task + pass + + if not job_enqueued: + background_tasks.add_task( + generate_presentation_handler, + request, + presentation_id, + 
async_status=async_status, + sql_session=sql_session, + ) + return async_status except Exception as e: diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 923017c..90a38bb 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "python-jose[cryptography]>=3.3", "openpyxl>=3.1", "trafilatura>=2.0", + "arq>=0.26", ] [[tool.uv.index]] diff --git a/backend/services/brand_enforcement_service.py b/backend/services/brand_enforcement_service.py new file mode 100644 index 0000000..4a39331 --- /dev/null +++ b/backend/services/brand_enforcement_service.py @@ -0,0 +1,225 @@ +"""Brand Enforcement Service: apply client brand rules to generated content and PPTX output.""" +import math +import uuid +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from models.pptx_models import ( + PptxAutoShapeBoxModel, + PptxChartBoxModel, + PptxFontModel, + PptxPictureBoxModel, + PptxPictureModel, + PptxPositionModel, + PptxPresentationModel, + PptxTextBoxModel, +) +from models.sql.brand_config import BrandConfigModel + + +# Default slide dimensions (matching PptxPresentationCreator) +_SLIDE_WIDTH = 1280 +_SLIDE_HEIGHT = 720 + +# Logo defaults +_LOGO_WIDTH = 80 +_LOGO_HEIGHT = 40 +_LOGO_MARGIN = 20 + + +class BrandEnforcementService: + + async def get_brand_context_for_llm( + self, client_id: uuid.UUID, session: AsyncSession + ) -> str: + """Build a text block for LLM system prompt injection with brand guidelines.""" + brand = await self._load_brand(client_id, session) + if not brand: + return "" + + parts: List[str] = [] + + if brand.primary_colors: + colors_str = ", ".join(brand.primary_colors[:6]) + parts.append(f"Brand primary colors: {colors_str}") + + if brand.secondary_colors: + colors_str = ", ".join(brand.secondary_colors[:6]) + parts.append(f"Brand secondary colors: {colors_str}") + + if brand.fonts: + heading = brand.fonts.get("heading", "") + body = 
brand.fonts.get("body", "") + if heading: + parts.append(f"Heading font: {heading}") + if body: + parts.append(f"Body font: {body}") + + if brand.voice_rules: + parts.append(f"Brand voice guidelines: {brand.voice_rules}") + + if brand.voice_examples: + examples = "\n - ".join(str(e) for e in brand.voice_examples[:3]) + parts.append(f"Brand voice examples:\n - {examples}") + + return "\n".join(parts) + + def enforce_on_pptx_model( + self, + model: PptxPresentationModel, + brand: BrandConfigModel, + ) -> PptxPresentationModel: + """Apply brand fonts, colors, and logo to a PPTX model.""" + heading_font = (brand.fonts or {}).get("heading") + body_font = (brand.fonts or {}).get("body") + brand_colors = self.get_brand_colors_list(brand) + logo_path = brand.logo_paths[0] if brand.logo_paths else None + + for slide_idx, slide in enumerate(model.slides): + for shape in slide.shapes: + # Replace fonts + if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)): + self._enforce_fonts_on_shape(shape, heading_font, body_font, slide_idx) + + # Set brand colors on charts + if isinstance(shape, PptxChartBoxModel): + if brand_colors: + shape.brand_colors = brand_colors + if body_font: + shape.font_name = body_font + + # Contrast check on all text shapes + bg_color = slide.background.color if slide.background else "FFFFFF" + for shape in slide.shapes: + if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)): + self._fix_contrast(shape, bg_color) + + # Add logo to non-title slides + if slide_idx > 0 and logo_path: + logo_shape = PptxPictureBoxModel( + position=PptxPositionModel( + left=_SLIDE_WIDTH - _LOGO_WIDTH - _LOGO_MARGIN, + top=_SLIDE_HEIGHT - _LOGO_HEIGHT - _LOGO_MARGIN, + width=_LOGO_WIDTH, + height=_LOGO_HEIGHT, + ), + clip=False, + picture=PptxPictureModel(is_network=False, path=logo_path), + ) + slide.shapes.append(logo_shape) + + return model + + def get_brand_colors_list(self, brand: BrandConfigModel) -> List[str]: + """Return usable brand color list, with 
fallback defaults.""" + colors: List[str] = [] + if brand.primary_colors: + colors.extend(str(c).lstrip("#") for c in brand.primary_colors) + if brand.secondary_colors: + colors.extend(str(c).lstrip("#") for c in brand.secondary_colors) + if not colors: + # Fallback to default chart colors + colors = [ + "4472C4", "ED7D31", "A5A5A5", "FFC000", "5B9BD5", + "70AD47", "264478", "9B57A0", "636363", "EB6E1F", + ] + return colors + + # --- Internal methods --- + + async def _load_brand( + self, client_id: uuid.UUID, session: AsyncSession + ) -> Optional[BrandConfigModel]: + stmt = select(BrandConfigModel).where( + BrandConfigModel.client_id == client_id + ) + result = await session.execute(stmt) + return result.scalar_one_or_none() + + def _enforce_fonts_on_shape( + self, + shape, + heading_font: Optional[str], + body_font: Optional[str], + slide_idx: int, + ) -> None: + """Replace font names in text shapes with brand fonts.""" + paragraphs = getattr(shape, "paragraphs", None) + if not paragraphs: + return + + for para in paragraphs: + # Title slides (idx 0) and large fonts use heading font + is_heading = slide_idx == 0 or (para.font and para.font.size >= 24) + target_font = heading_font if is_heading else body_font + + if target_font: + if para.font: + para.font.name = target_font + if para.text_runs: + for run in para.text_runs: + if run.font: + run.font.name = target_font + + def _fix_contrast(self, shape, bg_color: str) -> None: + """Ensure text has sufficient contrast against slide background.""" + paragraphs = getattr(shape, "paragraphs", None) + if not paragraphs: + return + + bg_lum = _relative_luminance(bg_color) + + for para in paragraphs: + self._fix_font_contrast(para.font, bg_lum) + if para.text_runs: + for run in para.text_runs: + self._fix_font_contrast(run.font, bg_lum) + + def _fix_font_contrast( + self, font: Optional[PptxFontModel], bg_lum: float + ) -> None: + if not font: + return + + text_lum = _relative_luminance(font.color) + ratio = 
_contrast_ratio(text_lum, bg_lum) + + # WCAG AA requires 4.5:1 for normal text, 3:1 for large text + if ratio < 3.0: + # Swap to white or black based on background + font.color = "FFFFFF" if bg_lum < 0.5 else "000000" + + +# --- Color utility functions --- + + +def _hex_to_rgb(hex_color: str) -> tuple: + """Convert hex color string to (r, g, b) tuple with values 0-255.""" + h = hex_color.lstrip("#") + if len(h) == 3: + h = h[0]*2 + h[1]*2 + h[2]*2 + if len(h) < 6: + h = h.ljust(6, "0") + return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) + + +def _relative_luminance(hex_color: str) -> float: + """Calculate relative luminance per WCAG 2.0.""" + r, g, b = _hex_to_rgb(hex_color) + rs = r / 255.0 + gs = g / 255.0 + bs = b / 255.0 + + def linearize(c): + return c / 12.92 if c <= 0.03928 else math.pow((c + 0.055) / 1.055, 2.4) + + return 0.2126 * linearize(rs) + 0.7152 * linearize(gs) + 0.0722 * linearize(bs) + + +def _contrast_ratio(lum1: float, lum2: float) -> float: + """Calculate contrast ratio between two luminance values.""" + lighter = max(lum1, lum2) + darker = min(lum1, lum2) + return (lighter + 0.05) / (darker + 0.05) diff --git a/backend/services/llm_client.py b/backend/services/llm_client.py index 9662122..cf4a5d7 100644 --- a/backend/services/llm_client.py +++ b/backend/services/llm_client.py @@ -47,6 +47,7 @@ from utils.get_env import ( get_custom_llm_api_key_env, get_custom_llm_url_env, get_disable_thinking_env, + get_fallback_llm_providers_env, get_google_api_key_env, get_ollama_url_env, get_openai_api_key_env, @@ -62,6 +63,26 @@ from utils.schema_utils import ( ) +import logging + +_logger = logging.getLogger(__name__) + + +def _get_fallback_providers() -> list: + """Parse FALLBACK_LLM_PROVIDERS env var into list of LLMProvider enums.""" + raw = get_fallback_llm_providers_env() + if not raw: + return [] + providers = [] + for name in raw.split(","): + name = name.strip().lower() + try: + providers.append(LLMProvider(name)) + except ValueError: + 
pass + return providers + + class LLMClient: def __init__(self): self.llm_provider = get_llm_provider() @@ -407,6 +428,43 @@ class LLMClient: messages: List[LLMMessage], max_tokens: Optional[int] = None, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, + ): + try: + return await self._generate_impl(model, messages, max_tokens, tools) + except Exception as primary_error: + fallback_providers = _get_fallback_providers() + # Remove primary provider from fallback list + fallback_providers = [ + p for p in fallback_providers if p != self.llm_provider + ] + for fb_provider in fallback_providers: + _logger.warning( + "LLM generate failed with %s: %s. Trying fallback %s", + self.llm_provider.value, str(primary_error)[:200], fb_provider.value, + ) + try: + fb_client = LLMClient.__new__(LLMClient) + fb_client.llm_provider = fb_provider + fb_client._client = fb_client._get_client() + fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client) + fb_model = get_model() + return await fb_client._generate_impl( + fb_model, messages, max_tokens, tools + ) + except Exception as fb_error: + _logger.warning( + "Fallback %s also failed: %s", + fb_provider.value, str(fb_error)[:200], + ) + continue + raise primary_error + + async def _generate_impl( + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): parsed_tools = self.tool_calls_handler.parse_tools(tools) @@ -781,6 +839,46 @@ class LLMClient: strict: bool = False, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, + ) -> dict: + try: + return await self._generate_structured_impl( + model, messages, response_format, strict, tools, max_tokens + ) + except Exception as primary_error: + fallback_providers = _get_fallback_providers() + fallback_providers = [ + p for p in fallback_providers if p != self.llm_provider + ] + for fb_provider in fallback_providers: + _logger.warning( 
+ "LLM generate_structured failed with %s: %s. Trying fallback %s", + self.llm_provider.value, str(primary_error)[:200], fb_provider.value, + ) + try: + fb_client = LLMClient.__new__(LLMClient) + fb_client.llm_provider = fb_provider + fb_client._client = fb_client._get_client() + fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client) + fb_model = get_model() + return await fb_client._generate_structured_impl( + fb_model, messages, response_format, strict, tools, max_tokens + ) + except Exception as fb_error: + _logger.warning( + "Fallback %s also failed: %s", + fb_provider.value, str(fb_error)[:200], + ) + continue + raise primary_error + + async def _generate_structured_impl( + self, + model: str, + messages: List[LLMMessage], + response_format: dict, + strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, + max_tokens: Optional[int] = None, ) -> dict: parsed_tools = self.tool_calls_handler.parse_tools(tools) diff --git a/backend/services/redis_service.py b/backend/services/redis_service.py new file mode 100644 index 0000000..7ea27a8 --- /dev/null +++ b/backend/services/redis_service.py @@ -0,0 +1,58 @@ +"""Redis service: connection pool and job progress utilities.""" +import json +import os +import uuid +from typing import Optional + +from arq import create_pool +from arq.connections import ArqRedis, RedisSettings + + +_pool: Optional[ArqRedis] = None + + +def _get_redis_settings() -> RedisSettings: + """Parse REDIS_URL env var into ARQ RedisSettings.""" + url = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + return RedisSettings.from_dsn(url) + + +async def get_arq_pool() -> ArqRedis: + """Get or create the shared ARQ Redis connection pool.""" + global _pool + if _pool is None: + _pool = await create_pool(_get_redis_settings()) + return _pool + + +async def close_arq_pool() -> None: + """Close the shared ARQ Redis pool (call on app shutdown).""" + global _pool + if _pool is not None: + await _pool.aclose() + _pool 
= None + + +async def enqueue_job(function_name: str, **kwargs) -> Optional[str]: + """Enqueue a job via ARQ. Returns the ARQ job ID.""" + pool = await get_arq_pool() + job = await pool.enqueue_job(function_name, **kwargs) + return job.job_id if job else None + + +async def publish_job_progress( + job_id: uuid.UUID, + progress: int, + message: str, + status: str = "processing", +) -> None: + """Publish a progress event to Redis pub/sub for SSE consumers.""" + pool = await get_arq_pool() + channel = f"job:{job_id}:progress" + payload = json.dumps({ + "job_id": str(job_id), + "progress": progress, + "message": message, + "status": status, + }) + await pool.publish(channel, payload.encode()) diff --git a/backend/utils/get_env.py b/backend/utils/get_env.py index c7dc16d..b05f0e8 100644 --- a/backend/utils/get_env.py +++ b/backend/utils/get_env.py @@ -117,3 +117,7 @@ def get_dall_e_3_quality_env(): # Gpt Image 1.5 Quality def get_gpt_image_1_5_quality_env(): return os.getenv("GPT_IMAGE_1_5_QUALITY") + + +def get_fallback_llm_providers_env(): + return os.getenv("FALLBACK_LLM_PROVIDERS") diff --git a/backend/utils/llm_calls/generate_presentation_outlines.py b/backend/utils/llm_calls/generate_presentation_outlines.py index cb044d4..f76f6dd 100644 --- a/backend/utils/llm_calls/generate_presentation_outlines.py +++ b/backend/utils/llm_calls/generate_presentation_outlines.py @@ -14,10 +14,30 @@ def get_system_prompt( verbosity: Optional[str] = None, instructions: Optional[str] = None, include_title_slide: bool = True, + brand_context: Optional[str] = None, + available_layouts: Optional[str] = None, ): + brand_section = "" + if brand_context: + brand_section = f""" + ## Brand Guidelines + {brand_context} + Ensure all text follows these brand voice and tone guidelines. 
+ """ + + layouts_section = "" + if available_layouts: + layouts_section = f""" + ## Available Slide Layouts + {available_layouts} + Consider which content types best match these available layouts when structuring the outline. + """ + return f""" You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content. + You are restructuring and condensing existing content from a brief. Do NOT invent facts, statistics, or claims. Every data point must originate from the source material. + Try to use available tools for better results. {"# User Instruction:" if instructions else ""} @@ -29,6 +49,10 @@ def get_system_prompt( {"# Verbosity:" if verbosity else ""} {verbosity or ""} + {brand_section} + + {layouts_section} + - Provide content for each slide in markdown format. - Make sure that flow of the presentation is logical and consistent. - Place greater emphasis on numerical data. @@ -49,7 +73,12 @@ def get_user_prompt( n_slides: int, language: str, additional_context: Optional[str] = None, + content_summary: Optional[str] = None, ): + summary_section = "" + if content_summary: + summary_section = f"- Content Analysis Summary: {content_summary}" + return f""" **Input:** - User provided content: {content or "Create presentation"} @@ -57,6 +86,7 @@ def get_user_prompt( - Number of Slides: {n_slides} - Current Date and Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} - Additional Information: {additional_context or ""} + {summary_section} """ @@ -69,15 +99,21 @@ def get_messages( verbosity: Optional[str] = None, instructions: Optional[str] = None, include_title_slide: bool = True, + brand_context: Optional[str] = None, + available_layouts: Optional[str] = None, + content_summary: Optional[str] = None, ): return [ LLMSystemMessage( content=get_system_prompt( - tone, verbosity, instructions, include_title_slide + tone, verbosity, instructions, 
include_title_slide, + brand_context, available_layouts, ), ), LLMUserMessage( - content=get_user_prompt(content, n_slides, language, additional_context), + content=get_user_prompt( + content, n_slides, language, additional_context, content_summary, + ), ), ] @@ -92,6 +128,9 @@ async def generate_ppt_outline( instructions: Optional[str] = None, include_title_slide: bool = True, web_search: bool = False, + brand_context: Optional[str] = None, + available_layouts: Optional[str] = None, + content_summary: Optional[str] = None, ): model = get_model() response_model = get_presentation_outline_model_with_n_slides(n_slides) @@ -110,6 +149,9 @@ async def generate_ppt_outline( verbosity, instructions, include_title_slide, + brand_context, + available_layouts, + content_summary, ), response_model.model_json_schema(), strict=True, diff --git a/backend/utils/llm_calls/generate_slide_content.py b/backend/utils/llm_calls/generate_slide_content.py index fcdb9f1..8b50f9b 100644 --- a/backend/utils/llm_calls/generate_slide_content.py +++ b/backend/utils/llm_calls/generate_slide_content.py @@ -13,10 +13,30 @@ def get_system_prompt( tone: Optional[str] = None, verbosity: Optional[str] = None, instructions: Optional[str] = None, + brand_context: Optional[str] = None, + attachment_context: Optional[str] = None, ): + brand_section = "" + if brand_context: + brand_section = f""" + ## Brand Guidelines + {brand_context} + Ensure all generated text follows these brand voice and tone guidelines. + """ + + attachment_section = "" + if attachment_context: + attachment_section = f""" + ## Attachment Data + The following data from attachments (tables, charts) is available for this slide. Use it directly — do not invent data points. + {attachment_context} + """ + return f""" Generate structured slide based on provided outline, follow mentioned steps and notes and provide structured output. + You are extracting and restructuring content from a brief. 
Do NOT invent facts, statistics, or claims not present in the source material or attachment data. + {"# User Instructions:" if instructions else ""} {instructions or ""} @@ -26,6 +46,10 @@ def get_system_prompt( {"# Verbosity:" if verbosity else ""} {verbosity or ""} + {brand_section} + + {attachment_section} + # Steps 1. Analyze the outline. 2. Generate structured slide based on the outline. @@ -86,11 +110,15 @@ def get_messages( tone: Optional[str] = None, verbosity: Optional[str] = None, instructions: Optional[str] = None, + brand_context: Optional[str] = None, + attachment_context: Optional[str] = None, ): return [ LLMSystemMessage( - content=get_system_prompt(tone, verbosity, instructions), + content=get_system_prompt( + tone, verbosity, instructions, brand_context, attachment_context, + ), ), LLMUserMessage( content=get_user_prompt(outline, language), @@ -105,6 +133,8 @@ async def get_slide_content_from_type_and_outline( tone: Optional[str] = None, verbosity: Optional[str] = None, instructions: Optional[str] = None, + brand_context: Optional[str] = None, + attachment_context: Optional[str] = None, ): client = LLMClient() model = get_model() @@ -134,6 +164,8 @@ async def get_slide_content_from_type_and_outline( tone, verbosity, instructions, + brand_context, + attachment_context, ), response_format=response_schema, strict=False, diff --git a/backend/workers/__init__.py b/backend/workers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/workers/main.py b/backend/workers/main.py new file mode 100644 index 0000000..33abf11 --- /dev/null +++ b/backend/workers/main.py @@ -0,0 +1,24 @@ +"""ARQ worker entry point. 
+ +Run with: python -m arq workers.main.WorkerSettings +""" +import os + +from arq.connections import RedisSettings + +from workers.master_deck_worker import parse_master_deck_task +from workers.presentation_worker import generate_presentation_task + + +def _get_redis_settings() -> RedisSettings: + url = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + return RedisSettings.from_dsn(url) + + +class WorkerSettings: + redis_settings = _get_redis_settings() + functions = [generate_presentation_task, parse_master_deck_task] + max_jobs = 5 + job_timeout = 600 # 10 minutes + max_tries = 3 + health_check_interval = 30 diff --git a/backend/workers/master_deck_worker.py b/backend/workers/master_deck_worker.py new file mode 100644 index 0000000..e2fd376 --- /dev/null +++ b/backend/workers/master_deck_worker.py @@ -0,0 +1,59 @@ +"""ARQ worker task: parse a master deck PPTX.""" +import traceback +import uuid +from datetime import datetime + +from models.sql.job import JobModel +from services.database import async_session_maker +from services.master_deck_parser_service import parse_master_deck +from services.redis_service import publish_job_progress + + +async def parse_master_deck_task(ctx: dict, job_id: str) -> None: + """ARQ task: parse a master deck via the existing parser service.""" + job_uuid = uuid.UUID(job_id) + + async with async_session_maker() as session: + job = await session.get(JobModel, job_uuid) + if not job: + return + + try: + job.status = "processing" + job.started_at = datetime.utcnow() + job.progress = 10 + job.progress_message = "Parsing master deck" + await session.commit() + + try: + await publish_job_progress(job_uuid, 10, "Parsing master deck") + except Exception: + pass + + # The existing parser updates MasterDeckModel directly + # presentation_id is reused to store the deck_id for this job type + await parse_master_deck(job.presentation_id) + + job.status = "completed" + job.progress = 100 + job.progress_message = "Parsing complete" + 
job.completed_at = datetime.utcnow() + await session.commit() + + try: + await publish_job_progress(job_uuid, 100, "Parsing complete", "completed") + except Exception: + pass + + except Exception as e: + traceback.print_exc() + job.status = "failed" + job.error_message = str(e)[:500] + job.progress_message = "Parsing failed" + job.completed_at = datetime.utcnow() + await session.commit() + + try: + await publish_job_progress(job_uuid, job.progress, "Parsing failed", "failed") + except Exception: + pass diff --git a/backend/workers/presentation_worker.py b/backend/workers/presentation_worker.py new file mode 100644 index 0000000..342bbfc --- /dev/null +++ b/backend/workers/presentation_worker.py @@ -0,0 +1,253 @@ +"""ARQ worker task: generate a presentation end-to-end.""" +import asyncio +import math +import random +import traceback +import uuid +from datetime import datetime +from typing import List + +import dirtyjson +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from models.presentation_model import PresentationModel +from models.presentation_outline_model import PresentationOutlineModel, SlideOutlineModel +from models.presentation_structure_model import PresentationStructureModel +from models.slide_model import SlideModel +from models.sql.job import JobModel +from services.brand_enforcement_service import BrandEnforcementService +from services.content_intelligence_service import ContentIntelligenceService +from services.database import async_session_maker +from services.image_generation_service import ImageGenerationService +from services.redis_service import publish_job_progress +from services.slide_mapping_engine import SlideMappingEngine +from utils.asset_directory_utils import get_images_directory +from utils.export_utils import export_presentation +from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline +from utils.llm_calls.generate_presentation_structure import 
async def generate_presentation_task(ctx: dict, job_id: str) -> None:
    """ARQ task: full presentation generation pipeline.

    Loads the job and its presentation record, then runs the pipeline:
    brand context -> content intelligence -> outlines -> layout selection ->
    slide content -> assets -> save -> PPTX export. Progress is committed to
    the job row and published (best-effort) after each stage.

    Args:
        ctx: ARQ job context (not used by this task).
        job_id: String form of the ``JobModel`` primary key (a UUID).

    Any ``Exception`` raised by the pipeline is printed, recorded on the job
    (``status="failed"``, truncated ``error_message``) and NOT re-raised.
    If no job row exists for ``job_id`` the task returns without doing anything.

    NOTE(review): cancellation (e.g. an ARQ job timeout) surfaces as
    ``asyncio.CancelledError``, which ``except Exception`` does not catch on
    Python 3.8+; such a job would presumably stay in "processing" — confirm
    whether that needs explicit handling.
    """
    job_uuid = uuid.UUID(job_id)

    async with async_session_maker() as session:
        job = await session.get(JobModel, job_uuid)
        if not job:
            return

        try:
            job.status = "processing"
            job.started_at = datetime.utcnow()
            job.progress = 0
            job.progress_message = "Starting generation"
            await session.commit()
            await _publish(job_uuid, 0, "Starting generation")

            # Load the stored request data from the presentation record
            presentation = await session.get(PresentationModel, job.presentation_id)
            if not presentation:
                raise ValueError("Presentation record not found")

            # Extract request parameters from the stored presentation
            content = presentation.content or ""
            n_slides = presentation.n_slides or 10
            language = presentation.language or "en"
            tone = presentation.tone or "professional"
            verbosity = presentation.verbosity or "standard"
            instructions = presentation.instructions
            template = "default"
            include_title_slide = True

            # --- Step 1: Brand context ---
            brand_context = ""
            if job.client_id:
                brand_svc = BrandEnforcementService()
                brand_context = await brand_svc.get_brand_context_for_llm(
                    job.client_id, session
                )

            await _update_job(session, job, 5, "Analyzing content")

            # --- Step 2: Content intelligence (if raw content provided) ---
            content_summary = None
            if content and len(content) > 100:
                ci_service = ContentIntelligenceService()
                classified = await ci_service.classify(content)
                content_summary = classified.summary

            await _update_job(session, job, 10, "Generating outlines")

            # --- Step 3: Generate outlines ---
            outline_chunks = []
            async for chunk in generate_ppt_outline(
                content,
                n_slides,
                language,
                None,  # additional_context
                tone,
                verbosity,
                instructions,
                include_title_slide,
                False,  # web_search
                brand_context=brand_context,
                content_summary=content_summary,
            ):
                # The generator yields an HTTPException object to signal failure.
                if isinstance(chunk, HTTPException):
                    raise chunk
                outline_chunks.append(chunk)
            presentation_outlines_text = "".join(outline_chunks)

            try:
                outlines_json = dict(dirtyjson.loads(presentation_outlines_text))
            except Exception as exc:
                raise ValueError("Failed to parse generated outlines") from exc

            presentation_outlines = PresentationOutlineModel(**outlines_json)
            # Never plan more slides than outlines were actually produced:
            # presentation_outlines.slides[i] is indexed below, so trusting
            # n_slides alone would raise IndexError on a short LLM response.
            total_outlines = min(n_slides, len(presentation_outlines.slides))

            await _update_job(session, job, 25, "Selecting layouts")

            # --- Step 4: Layout selection ---
            layout_model = await get_layout_by_name(template)
            total_slide_layouts = len(layout_model.slides)

            if layout_model.ordered:
                presentation_structure = layout_model.to_presentation_structure()
            else:
                presentation_structure = await generate_presentation_structure(
                    presentation_outlines, layout_model, instructions
                )

            # Trim to the slide count, then pad missing entries and replace
            # out-of-range layout indices with a random valid layout.
            presentation_structure.slides = presentation_structure.slides[:total_outlines]
            for index in range(total_outlines):
                random_slide_index = random.randint(0, total_slide_layouts - 1)
                if index >= len(presentation_structure.slides):
                    presentation_structure.slides.append(random_slide_index)
                elif presentation_structure.slides[index] >= total_slide_layouts:
                    presentation_structure.slides[index] = random_slide_index

            # Update presentation model with outlines & structure
            presentation.title = get_presentation_title_from_outlines(presentation_outlines)
            presentation.outlines = presentation_outlines.model_dump()
            presentation.layout = layout_model.model_dump()
            presentation.structure = presentation_structure.model_dump()
            await session.commit()

            await _update_job(session, job, 35, "Generating slides")

            # --- Step 5: Generate slide content ---
            image_generation_service = ImageGenerationService(get_images_directory())
            async_assets_generation_tasks = []
            slides: List[SlideModel] = []

            slide_layout_indices = presentation_structure.slides
            slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices]

            batch_size = 10
            for start in range(0, len(slide_layouts), batch_size):
                end = min(start + batch_size, len(slide_layouts))

                content_tasks = [
                    get_slide_content_from_type_and_outline(
                        slide_layouts[i],
                        presentation_outlines.slides[i],
                        language,
                        tone,
                        verbosity,
                        instructions,
                        brand_context=brand_context,
                    )
                    for i in range(start, end)
                ]
                batch_contents = await asyncio.gather(*content_tasks)

                batch_slides = []
                for offset, slide_content in enumerate(batch_contents):
                    i = start + offset
                    slide = SlideModel(
                        presentation=job.presentation_id,
                        layout_group=layout_model.name,
                        layout=slide_layouts[i].id,
                        index=i,
                        speaker_note=slide_content.get("__speaker_note__"),
                        content=slide_content,
                    )
                    slides.append(slide)
                    batch_slides.append(slide)

                # Asset coroutines are only collected here; they are awaited
                # together in step 6.
                asset_tasks = [
                    process_slide_and_fetch_assets(image_generation_service, slide)
                    for slide in batch_slides
                ]
                async_assets_generation_tasks.extend(asset_tasks)

                # Slide generation spans the 35%..75% progress window.
                pct = 35 + int((end / len(slide_layouts)) * 40)
                await _update_job(
                    session, job, pct, f"Generating slide {end}/{len(slide_layouts)}"
                )

            await _update_job(session, job, 80, "Fetching assets")

            # --- Step 6: Fetch assets ---
            generated_assets_list = await asyncio.gather(*async_assets_generation_tasks)
            generated_assets = []
            for assets_list in generated_assets_list:
                generated_assets.extend(assets_list)

            await _update_job(session, job, 90, "Saving presentation")

            # --- Step 7: Save ---
            session.add(presentation)
            session.add_all(slides)
            session.add_all(generated_assets)
            await session.commit()

            await _update_job(session, job, 95, "Exporting PPTX")

            # --- Step 8: Export ---
            await export_presentation(
                job.presentation_id,
                presentation.title or str(uuid.uuid4()),
                "pptx",
            )

            # --- Done ---
            job.status = "completed"
            job.progress = 100
            job.progress_message = "Generation complete"
            job.completed_at = datetime.utcnow()
            await session.commit()
            await _publish(job_uuid, 100, "Generation complete", "completed")

        except Exception as e:
            traceback.print_exc()
            job.status = "failed"
            job.error_message = str(e)[:500]  # keep the stored message short
            job.progress_message = "Generation failed"
            job.completed_at = datetime.utcnow()
            await session.commit()
            await _publish(job_uuid, job.progress, "Generation failed", "failed")


async def _update_job(
    session: AsyncSession, job: JobModel, progress: int, message: str
) -> None:
    """Persist a progress update on the job row, then broadcast it.

    The database commit is authoritative and may raise. The broadcast goes
    through ``_publish`` so a publish failure cannot abort the running job.
    """
    job.progress = progress
    job.progress_message = message
    await session.commit()
    await _publish(job.id, progress, message)


async def _publish(
    job_id: uuid.UUID, progress: int, message: str, status: str = "processing"
) -> None:
    """Publish a job progress event, ignoring any publish failure."""
    try:
        await publish_job_progress(job_id, progress, message, status)
    except Exception:
        pass  # Redis unavailable is not fatal