Phase 4: Generation Pipeline — brand enforcement, enhanced LLM calls, ARQ job queue
- Step 14: Brand enforcement service (font/color/logo replacement, WCAG contrast check, LLM prompt context) - Step 15: Enhanced outline & slide content generation with brand context, content summary, "no hallucination" instructions - Step 15b: LLM auto-fallback retry logic across providers (FALLBACK_LLM_PROVIDERS env) - Step 16: Redis/ARQ job queue — worker entry point, presentation & master deck workers, job status/SSE endpoints, graceful fallback to BackgroundTasks when Redis unavailable Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a2bd4cfefa
commit
a0d73b3b63
16 changed files with 1038 additions and 17 deletions
|
|
@ -4,6 +4,7 @@ import os
|
|||
from fastapi import FastAPI
|
||||
|
||||
from services.database import create_db_and_tables
|
||||
from services.redis_service import close_arq_pool
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
from utils.model_availability import (
|
||||
check_llm_and_image_provider_api_or_model_availability,
|
||||
|
|
@ -21,3 +22,4 @@ async def app_lifespan(_: FastAPI):
|
|||
await create_db_and_tables()
|
||||
await check_llm_and_image_provider_api_or_model_availability()
|
||||
yield
|
||||
await close_arq_pool()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from api.v1.admin.clients_router import CLIENTS_ROUTER
|
|||
from api.v1.admin.audit_router import AUDIT_ROUTER
|
||||
from api.v1.admin.brand_config_router import BRAND_CONFIG_ROUTER
|
||||
from api.v1.admin.master_decks_router import MASTER_DECKS_ROUTER
|
||||
from api.v1.ppt.endpoints.jobs import JOBS_ROUTER
|
||||
from api.middlewares.audit_middleware import AuditMiddleware
|
||||
|
||||
|
||||
|
|
@ -31,6 +32,7 @@ ADMIN_ROUTER.include_router(MASTER_DECKS_ROUTER)
|
|||
app.include_router(AUTH_ROUTER)
|
||||
app.include_router(ADMIN_ROUTER)
|
||||
app.include_router(API_V1_PPT_ROUTER)
|
||||
app.include_router(JOBS_ROUTER)
|
||||
app.include_router(API_V1_WEBHOOK_ROUTER)
|
||||
app.include_router(API_V1_MOCK_ROUTER)
|
||||
|
||||
|
|
|
|||
|
|
@ -114,11 +114,27 @@ async def upload_master_deck(
|
|||
await session.commit()
|
||||
await session.refresh(deck)
|
||||
|
||||
# Kick off async parsing
|
||||
import asyncio
|
||||
from services.master_deck_parser_service import parse_master_deck
|
||||
# Kick off async parsing via ARQ (fallback to asyncio.create_task)
|
||||
try:
|
||||
from models.sql.job import JobModel
|
||||
from services.redis_service import enqueue_job
|
||||
|
||||
asyncio.create_task(parse_master_deck(deck_id))
|
||||
job = JobModel(
|
||||
user_id=admin.id,
|
||||
client_id=client_id,
|
||||
presentation_id=deck_id, # reuse field for deck_id
|
||||
job_type="parse_master_deck",
|
||||
status="queued",
|
||||
progress=0,
|
||||
progress_message="Queued for parsing",
|
||||
)
|
||||
session.add(job)
|
||||
await session.commit()
|
||||
await enqueue_job("parse_master_deck_task", job_id=str(job.id))
|
||||
except Exception:
|
||||
import asyncio
|
||||
from services.master_deck_parser_service import parse_master_deck
|
||||
asyncio.create_task(parse_master_deck(deck_id))
|
||||
|
||||
return {
|
||||
"id": str(deck.id),
|
||||
|
|
@ -235,10 +251,26 @@ async def reparse_master_deck(
|
|||
deck.parse_status = "pending"
|
||||
await session.commit()
|
||||
|
||||
import asyncio
|
||||
from services.master_deck_parser_service import parse_master_deck
|
||||
try:
|
||||
from models.sql.job import JobModel
|
||||
from services.redis_service import enqueue_job
|
||||
|
||||
asyncio.create_task(parse_master_deck(deck_id))
|
||||
job = JobModel(
|
||||
user_id=admin.id,
|
||||
client_id=deck.client_id,
|
||||
presentation_id=deck_id,
|
||||
job_type="parse_master_deck",
|
||||
status="queued",
|
||||
progress=0,
|
||||
progress_message="Queued for re-parsing",
|
||||
)
|
||||
session.add(job)
|
||||
await session.commit()
|
||||
await enqueue_job("parse_master_deck_task", job_id=str(job.id))
|
||||
except Exception:
|
||||
import asyncio
|
||||
from services.master_deck_parser_service import parse_master_deck
|
||||
asyncio.create_task(parse_master_deck(deck_id))
|
||||
|
||||
return {"ok": True, "parse_status": "pending"}
|
||||
|
||||
|
|
|
|||
151
backend/api/v1/ppt/endpoints/jobs.py
Normal file
151
backend/api/v1/ppt/endpoints/jobs.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""Job status polling and SSE streaming endpoints."""
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.sse_response import SSEResponse, SSEErrorResponse
|
||||
from models.sql.job import JobModel
|
||||
from models.sql.user import UserModel
|
||||
from services.database import get_async_session
|
||||
from services.redis_service import get_arq_pool
|
||||
from utils.auth_dependencies import get_current_user
|
||||
|
||||
JOBS_ROUTER = APIRouter(prefix="/api/v1/ppt", tags=["Jobs"])
|
||||
|
||||
|
||||
@JOBS_ROUTER.get("/jobs/{job_id}")
|
||||
async def get_job_status(
|
||||
job_id: uuid.UUID,
|
||||
_current_user: UserModel = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Poll job status."""
|
||||
job = await session.get(JobModel, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return {
|
||||
"id": str(job.id),
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"progress": job.progress,
|
||||
"progress_message": job.progress_message,
|
||||
"error_message": job.error_message,
|
||||
"presentation_id": str(job.presentation_id) if job.presentation_id else None,
|
||||
"created_at": job.created_at.isoformat() if job.created_at else None,
|
||||
"started_at": job.started_at.isoformat() if job.started_at else None,
|
||||
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
||||
}
|
||||
|
||||
|
||||
@JOBS_ROUTER.get("/jobs/{job_id}/stream")
|
||||
async def stream_job_progress(
|
||||
job_id: uuid.UUID,
|
||||
_current_user: UserModel = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""SSE stream of job progress events via Redis pub/sub."""
|
||||
job = await session.get(JobModel, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
# If already completed/failed, return final event immediately
|
||||
if job.status in ("completed", "failed"):
|
||||
async def done_stream():
|
||||
data = {
|
||||
"type": "progress",
|
||||
"job_id": str(job.id),
|
||||
"progress": job.progress,
|
||||
"message": job.progress_message or job.status,
|
||||
"status": job.status,
|
||||
}
|
||||
yield SSEResponse(
|
||||
event="response", data=json.dumps(data)
|
||||
).to_string()
|
||||
|
||||
return StreamingResponse(done_stream(), media_type="text/event-stream")
|
||||
|
||||
async def progress_stream():
|
||||
try:
|
||||
pool = await get_arq_pool()
|
||||
pubsub = pool.pubsub()
|
||||
channel = f"job:{job_id}:progress"
|
||||
await pubsub.subscribe(channel)
|
||||
|
||||
# Send initial status
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({
|
||||
"type": "progress",
|
||||
"job_id": str(job_id),
|
||||
"progress": job.progress,
|
||||
"message": job.progress_message or "Waiting",
|
||||
"status": job.status,
|
||||
}),
|
||||
).to_string()
|
||||
|
||||
# Listen for updates with timeout
|
||||
timeout = 600 # 10 minutes max
|
||||
elapsed = 0
|
||||
while elapsed < timeout:
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=2.0
|
||||
)
|
||||
if message and message["type"] == "message":
|
||||
raw = message["data"]
|
||||
if isinstance(raw, bytes):
|
||||
raw = raw.decode()
|
||||
payload = json.loads(raw)
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "progress", **payload}),
|
||||
).to_string()
|
||||
|
||||
# Stop streaming on terminal status
|
||||
if payload.get("status") in ("completed", "failed"):
|
||||
break
|
||||
else:
|
||||
elapsed += 2
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.aclose()
|
||||
|
||||
except Exception as e:
|
||||
yield SSEErrorResponse(detail=str(e)[:200]).to_string()
|
||||
|
||||
return StreamingResponse(progress_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@JOBS_ROUTER.delete("/jobs/{job_id}")
|
||||
async def cancel_job(
|
||||
job_id: uuid.UUID,
|
||||
_current_user: UserModel = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Cancel a queued or processing job."""
|
||||
job = await session.get(JobModel, job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.status in ("completed", "failed"):
|
||||
raise HTTPException(status_code=409, detail="Job already finished")
|
||||
|
||||
# Try to abort via ARQ
|
||||
try:
|
||||
pool = await get_arq_pool()
|
||||
await pool.abort_job(str(job_id))
|
||||
except Exception:
|
||||
pass # Best effort
|
||||
|
||||
job.status = "failed"
|
||||
job.error_message = "Cancelled by user"
|
||||
job.progress_message = "Cancelled"
|
||||
await session.commit()
|
||||
|
||||
return {"ok": True, "status": "cancelled"}
|
||||
|
|
@ -846,6 +846,18 @@ async def generate_presentation_async(
|
|||
try:
|
||||
(presentation_id,) = await check_if_api_request_is_valid(request, sql_session)
|
||||
|
||||
# Create a lightweight presentation record so the worker can load it
|
||||
presentation = PresentationModel(
|
||||
id=presentation_id,
|
||||
content=request.content,
|
||||
n_slides=request.n_slides,
|
||||
language=request.language,
|
||||
tone=request.tone.value,
|
||||
verbosity=request.verbosity.value,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
sql_session.add(presentation)
|
||||
|
||||
async_status = AsyncPresentationGenerationTaskModel(
|
||||
status="pending",
|
||||
message="Queued for generation",
|
||||
|
|
@ -854,13 +866,39 @@ async def generate_presentation_async(
|
|||
sql_session.add(async_status)
|
||||
await sql_session.commit()
|
||||
|
||||
background_tasks.add_task(
|
||||
generate_presentation_handler,
|
||||
request,
|
||||
presentation_id,
|
||||
async_status=async_status,
|
||||
sql_session=sql_session,
|
||||
)
|
||||
# Try ARQ job queue first; fall back to BackgroundTasks
|
||||
job_enqueued = False
|
||||
try:
|
||||
from models.sql.job import JobModel
|
||||
from services.redis_service import enqueue_job
|
||||
|
||||
job = JobModel(
|
||||
user_id=_current_user.id,
|
||||
client_id=getattr(_current_user, "default_client_id", _current_user.id),
|
||||
presentation_id=presentation_id,
|
||||
job_type="generate_presentation",
|
||||
status="queued",
|
||||
progress=0,
|
||||
progress_message="Queued for generation",
|
||||
)
|
||||
sql_session.add(job)
|
||||
await sql_session.commit()
|
||||
|
||||
await enqueue_job("generate_presentation_task", job_id=str(job.id))
|
||||
job_enqueued = True
|
||||
except Exception:
|
||||
# Redis/ARQ unavailable — fall back to in-process background task
|
||||
pass
|
||||
|
||||
if not job_enqueued:
|
||||
background_tasks.add_task(
|
||||
generate_presentation_handler,
|
||||
request,
|
||||
presentation_id,
|
||||
async_status=async_status,
|
||||
sql_session=sql_session,
|
||||
)
|
||||
|
||||
return async_status
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ dependencies = [
|
|||
"python-jose[cryptography]>=3.3",
|
||||
"openpyxl>=3.1",
|
||||
"trafilatura>=2.0",
|
||||
"arq>=0.26",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
|
|
|
|||
225
backend/services/brand_enforcement_service.py
Normal file
225
backend/services/brand_enforcement_service.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
"""Brand Enforcement Service: apply client brand rules to generated content and PPTX output."""
|
||||
import math
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.pptx_models import (
|
||||
PptxAutoShapeBoxModel,
|
||||
PptxChartBoxModel,
|
||||
PptxFontModel,
|
||||
PptxPictureBoxModel,
|
||||
PptxPictureModel,
|
||||
PptxPositionModel,
|
||||
PptxPresentationModel,
|
||||
PptxTextBoxModel,
|
||||
)
|
||||
from models.sql.brand_config import BrandConfigModel
|
||||
|
||||
|
||||
# Default slide dimensions (matching PptxPresentationCreator)
|
||||
_SLIDE_WIDTH = 1280
|
||||
_SLIDE_HEIGHT = 720
|
||||
|
||||
# Logo defaults
|
||||
_LOGO_WIDTH = 80
|
||||
_LOGO_HEIGHT = 40
|
||||
_LOGO_MARGIN = 20
|
||||
|
||||
|
||||
class BrandEnforcementService:
|
||||
|
||||
async def get_brand_context_for_llm(
|
||||
self, client_id: uuid.UUID, session: AsyncSession
|
||||
) -> str:
|
||||
"""Build a text block for LLM system prompt injection with brand guidelines."""
|
||||
brand = await self._load_brand(client_id, session)
|
||||
if not brand:
|
||||
return ""
|
||||
|
||||
parts: List[str] = []
|
||||
|
||||
if brand.primary_colors:
|
||||
colors_str = ", ".join(brand.primary_colors[:6])
|
||||
parts.append(f"Brand primary colors: {colors_str}")
|
||||
|
||||
if brand.secondary_colors:
|
||||
colors_str = ", ".join(brand.secondary_colors[:6])
|
||||
parts.append(f"Brand secondary colors: {colors_str}")
|
||||
|
||||
if brand.fonts:
|
||||
heading = brand.fonts.get("heading", "")
|
||||
body = brand.fonts.get("body", "")
|
||||
if heading:
|
||||
parts.append(f"Heading font: {heading}")
|
||||
if body:
|
||||
parts.append(f"Body font: {body}")
|
||||
|
||||
if brand.voice_rules:
|
||||
parts.append(f"Brand voice guidelines: {brand.voice_rules}")
|
||||
|
||||
if brand.voice_examples:
|
||||
examples = "\n - ".join(str(e) for e in brand.voice_examples[:3])
|
||||
parts.append(f"Brand voice examples:\n - {examples}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def enforce_on_pptx_model(
|
||||
self,
|
||||
model: PptxPresentationModel,
|
||||
brand: BrandConfigModel,
|
||||
) -> PptxPresentationModel:
|
||||
"""Apply brand fonts, colors, and logo to a PPTX model."""
|
||||
heading_font = (brand.fonts or {}).get("heading")
|
||||
body_font = (brand.fonts or {}).get("body")
|
||||
brand_colors = self.get_brand_colors_list(brand)
|
||||
logo_path = brand.logo_paths[0] if brand.logo_paths else None
|
||||
|
||||
for slide_idx, slide in enumerate(model.slides):
|
||||
for shape in slide.shapes:
|
||||
# Replace fonts
|
||||
if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)):
|
||||
self._enforce_fonts_on_shape(shape, heading_font, body_font, slide_idx)
|
||||
|
||||
# Set brand colors on charts
|
||||
if isinstance(shape, PptxChartBoxModel):
|
||||
if brand_colors:
|
||||
shape.brand_colors = brand_colors
|
||||
if body_font:
|
||||
shape.font_name = body_font
|
||||
|
||||
# Contrast check on all text shapes
|
||||
bg_color = slide.background.color if slide.background else "FFFFFF"
|
||||
for shape in slide.shapes:
|
||||
if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)):
|
||||
self._fix_contrast(shape, bg_color)
|
||||
|
||||
# Add logo to non-title slides
|
||||
if slide_idx > 0 and logo_path:
|
||||
logo_shape = PptxPictureBoxModel(
|
||||
position=PptxPositionModel(
|
||||
left=_SLIDE_WIDTH - _LOGO_WIDTH - _LOGO_MARGIN,
|
||||
top=_SLIDE_HEIGHT - _LOGO_HEIGHT - _LOGO_MARGIN,
|
||||
width=_LOGO_WIDTH,
|
||||
height=_LOGO_HEIGHT,
|
||||
),
|
||||
clip=False,
|
||||
picture=PptxPictureModel(is_network=False, path=logo_path),
|
||||
)
|
||||
slide.shapes.append(logo_shape)
|
||||
|
||||
return model
|
||||
|
||||
def get_brand_colors_list(self, brand: BrandConfigModel) -> List[str]:
|
||||
"""Return usable brand color list, with fallback defaults."""
|
||||
colors: List[str] = []
|
||||
if brand.primary_colors:
|
||||
colors.extend(str(c).lstrip("#") for c in brand.primary_colors)
|
||||
if brand.secondary_colors:
|
||||
colors.extend(str(c).lstrip("#") for c in brand.secondary_colors)
|
||||
if not colors:
|
||||
# Fallback to default chart colors
|
||||
colors = [
|
||||
"4472C4", "ED7D31", "A5A5A5", "FFC000", "5B9BD5",
|
||||
"70AD47", "264478", "9B57A0", "636363", "EB6E1F",
|
||||
]
|
||||
return colors
|
||||
|
||||
# --- Internal methods ---
|
||||
|
||||
async def _load_brand(
|
||||
self, client_id: uuid.UUID, session: AsyncSession
|
||||
) -> Optional[BrandConfigModel]:
|
||||
stmt = select(BrandConfigModel).where(
|
||||
BrandConfigModel.client_id == client_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _enforce_fonts_on_shape(
|
||||
self,
|
||||
shape,
|
||||
heading_font: Optional[str],
|
||||
body_font: Optional[str],
|
||||
slide_idx: int,
|
||||
) -> None:
|
||||
"""Replace font names in text shapes with brand fonts."""
|
||||
paragraphs = getattr(shape, "paragraphs", None)
|
||||
if not paragraphs:
|
||||
return
|
||||
|
||||
for para in paragraphs:
|
||||
# Title slides (idx 0) and large fonts use heading font
|
||||
is_heading = slide_idx == 0 or (para.font and para.font.size >= 24)
|
||||
target_font = heading_font if is_heading else body_font
|
||||
|
||||
if target_font:
|
||||
if para.font:
|
||||
para.font.name = target_font
|
||||
if para.text_runs:
|
||||
for run in para.text_runs:
|
||||
if run.font:
|
||||
run.font.name = target_font
|
||||
|
||||
def _fix_contrast(self, shape, bg_color: str) -> None:
|
||||
"""Ensure text has sufficient contrast against slide background."""
|
||||
paragraphs = getattr(shape, "paragraphs", None)
|
||||
if not paragraphs:
|
||||
return
|
||||
|
||||
bg_lum = _relative_luminance(bg_color)
|
||||
|
||||
for para in paragraphs:
|
||||
self._fix_font_contrast(para.font, bg_lum)
|
||||
if para.text_runs:
|
||||
for run in para.text_runs:
|
||||
self._fix_font_contrast(run.font, bg_lum)
|
||||
|
||||
def _fix_font_contrast(
|
||||
self, font: Optional[PptxFontModel], bg_lum: float
|
||||
) -> None:
|
||||
if not font:
|
||||
return
|
||||
|
||||
text_lum = _relative_luminance(font.color)
|
||||
ratio = _contrast_ratio(text_lum, bg_lum)
|
||||
|
||||
# WCAG AA requires 4.5:1 for normal text, 3:1 for large text
|
||||
if ratio < 3.0:
|
||||
# Swap to white or black based on background
|
||||
font.color = "FFFFFF" if bg_lum < 0.5 else "000000"
|
||||
|
||||
|
||||
# --- Color utility functions ---
|
||||
|
||||
|
||||
def _hex_to_rgb(hex_color: str) -> tuple:
|
||||
"""Convert hex color string to (r, g, b) tuple with values 0-255."""
|
||||
h = hex_color.lstrip("#")
|
||||
if len(h) == 3:
|
||||
h = h[0]*2 + h[1]*2 + h[2]*2
|
||||
if len(h) < 6:
|
||||
h = h.ljust(6, "0")
|
||||
return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
|
||||
|
||||
|
||||
def _relative_luminance(hex_color: str) -> float:
|
||||
"""Calculate relative luminance per WCAG 2.0."""
|
||||
r, g, b = _hex_to_rgb(hex_color)
|
||||
rs = r / 255.0
|
||||
gs = g / 255.0
|
||||
bs = b / 255.0
|
||||
|
||||
def linearize(c):
|
||||
return c / 12.92 if c <= 0.03928 else math.pow((c + 0.055) / 1.055, 2.4)
|
||||
|
||||
return 0.2126 * linearize(rs) + 0.7152 * linearize(gs) + 0.0722 * linearize(bs)
|
||||
|
||||
|
||||
def _contrast_ratio(lum1: float, lum2: float) -> float:
|
||||
"""Calculate contrast ratio between two luminance values."""
|
||||
lighter = max(lum1, lum2)
|
||||
darker = min(lum1, lum2)
|
||||
return (lighter + 0.05) / (darker + 0.05)
|
||||
|
|
@ -47,6 +47,7 @@ from utils.get_env import (
|
|||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_disable_thinking_env,
|
||||
get_fallback_llm_providers_env,
|
||||
get_google_api_key_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
|
|
@ -62,6 +63,26 @@ from utils.schema_utils import (
|
|||
)
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_fallback_providers() -> list:
|
||||
"""Parse FALLBACK_LLM_PROVIDERS env var into list of LLMProvider enums."""
|
||||
raw = get_fallback_llm_providers_env()
|
||||
if not raw:
|
||||
return []
|
||||
providers = []
|
||||
for name in raw.split(","):
|
||||
name = name.strip().lower()
|
||||
try:
|
||||
providers.append(LLMProvider(name))
|
||||
except ValueError:
|
||||
pass
|
||||
return providers
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self):
|
||||
self.llm_provider = get_llm_provider()
|
||||
|
|
@ -407,6 +428,43 @@ class LLMClient:
|
|||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
|
||||
):
|
||||
try:
|
||||
return await self._generate_impl(model, messages, max_tokens, tools)
|
||||
except Exception as primary_error:
|
||||
fallback_providers = _get_fallback_providers()
|
||||
# Remove primary provider from fallback list
|
||||
fallback_providers = [
|
||||
p for p in fallback_providers if p != self.llm_provider
|
||||
]
|
||||
for fb_provider in fallback_providers:
|
||||
_logger.warning(
|
||||
"LLM generate failed with %s: %s. Trying fallback %s",
|
||||
self.llm_provider.value, str(primary_error)[:200], fb_provider.value,
|
||||
)
|
||||
try:
|
||||
fb_client = LLMClient.__new__(LLMClient)
|
||||
fb_client.llm_provider = fb_provider
|
||||
fb_client._client = fb_client._get_client()
|
||||
fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client)
|
||||
fb_model = get_model()
|
||||
return await fb_client._generate_impl(
|
||||
fb_model, messages, max_tokens, tools
|
||||
)
|
||||
except Exception as fb_error:
|
||||
_logger.warning(
|
||||
"Fallback %s also failed: %s",
|
||||
fb_provider.value, str(fb_error)[:200],
|
||||
)
|
||||
continue
|
||||
raise primary_error
|
||||
|
||||
async def _generate_impl(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
max_tokens: Optional[int] = None,
|
||||
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
|
||||
):
|
||||
parsed_tools = self.tool_calls_handler.parse_tools(tools)
|
||||
|
||||
|
|
@ -781,6 +839,46 @@ class LLMClient:
|
|||
strict: bool = False,
|
||||
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> dict:
|
||||
try:
|
||||
return await self._generate_structured_impl(
|
||||
model, messages, response_format, strict, tools, max_tokens
|
||||
)
|
||||
except Exception as primary_error:
|
||||
fallback_providers = _get_fallback_providers()
|
||||
fallback_providers = [
|
||||
p for p in fallback_providers if p != self.llm_provider
|
||||
]
|
||||
for fb_provider in fallback_providers:
|
||||
_logger.warning(
|
||||
"LLM generate_structured failed with %s: %s. Trying fallback %s",
|
||||
self.llm_provider.value, str(primary_error)[:200], fb_provider.value,
|
||||
)
|
||||
try:
|
||||
fb_client = LLMClient.__new__(LLMClient)
|
||||
fb_client.llm_provider = fb_provider
|
||||
fb_client._client = fb_client._get_client()
|
||||
fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client)
|
||||
fb_model = get_model()
|
||||
return await fb_client._generate_structured_impl(
|
||||
fb_model, messages, response_format, strict, tools, max_tokens
|
||||
)
|
||||
except Exception as fb_error:
|
||||
_logger.warning(
|
||||
"Fallback %s also failed: %s",
|
||||
fb_provider.value, str(fb_error)[:200],
|
||||
)
|
||||
continue
|
||||
raise primary_error
|
||||
|
||||
async def _generate_structured_impl(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> dict:
|
||||
parsed_tools = self.tool_calls_handler.parse_tools(tools)
|
||||
|
||||
|
|
|
|||
58
backend/services/redis_service.py
Normal file
58
backend/services/redis_service.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
"""Redis service: connection pool and job progress utilities."""
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from arq import create_pool
|
||||
from arq.connections import ArqRedis, RedisSettings
|
||||
|
||||
|
||||
_pool: Optional[ArqRedis] = None
|
||||
|
||||
|
||||
def _get_redis_settings() -> RedisSettings:
|
||||
"""Parse REDIS_URL env var into ARQ RedisSettings."""
|
||||
url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
||||
return RedisSettings.from_dsn(url)
|
||||
|
||||
|
||||
async def get_arq_pool() -> ArqRedis:
|
||||
"""Get or create the shared ARQ Redis connection pool."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
_pool = await create_pool(_get_redis_settings())
|
||||
return _pool
|
||||
|
||||
|
||||
async def close_arq_pool() -> None:
|
||||
"""Close the shared ARQ Redis pool (call on app shutdown)."""
|
||||
global _pool
|
||||
if _pool is not None:
|
||||
await _pool.aclose()
|
||||
_pool = None
|
||||
|
||||
|
||||
async def enqueue_job(function_name: str, **kwargs) -> Optional[str]:
|
||||
"""Enqueue a job via ARQ. Returns the ARQ job ID."""
|
||||
pool = await get_arq_pool()
|
||||
job = await pool.enqueue_job(function_name, **kwargs)
|
||||
return job.job_id if job else None
|
||||
|
||||
|
||||
async def publish_job_progress(
|
||||
job_id: uuid.UUID,
|
||||
progress: int,
|
||||
message: str,
|
||||
status: str = "processing",
|
||||
) -> None:
|
||||
"""Publish a progress event to Redis pub/sub for SSE consumers."""
|
||||
pool = await get_arq_pool()
|
||||
channel = f"job:{job_id}:progress"
|
||||
payload = json.dumps({
|
||||
"job_id": str(job_id),
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
"status": status,
|
||||
})
|
||||
await pool.publish(channel, payload.encode())
|
||||
|
|
@ -117,3 +117,7 @@ def get_dall_e_3_quality_env():
|
|||
# Gpt Image 1.5 Quality
|
||||
def get_gpt_image_1_5_quality_env():
|
||||
return os.getenv("GPT_IMAGE_1_5_QUALITY")
|
||||
|
||||
|
||||
def get_fallback_llm_providers_env():
|
||||
return os.getenv("FALLBACK_LLM_PROVIDERS")
|
||||
|
|
|
|||
|
|
@ -14,10 +14,30 @@ def get_system_prompt(
|
|||
verbosity: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
include_title_slide: bool = True,
|
||||
brand_context: Optional[str] = None,
|
||||
available_layouts: Optional[str] = None,
|
||||
):
|
||||
brand_section = ""
|
||||
if brand_context:
|
||||
brand_section = f"""
|
||||
## Brand Guidelines
|
||||
{brand_context}
|
||||
Ensure all text follows these brand voice and tone guidelines.
|
||||
"""
|
||||
|
||||
layouts_section = ""
|
||||
if available_layouts:
|
||||
layouts_section = f"""
|
||||
## Available Slide Layouts
|
||||
{available_layouts}
|
||||
Consider which content types best match these available layouts when structuring the outline.
|
||||
"""
|
||||
|
||||
return f"""
|
||||
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
|
||||
|
||||
You are restructuring and condensing existing content from a brief. Do NOT invent facts, statistics, or claims. Every data point must originate from the source material.
|
||||
|
||||
Try to use available tools for better results.
|
||||
|
||||
{"# User Instruction:" if instructions else ""}
|
||||
|
|
@ -29,6 +49,10 @@ def get_system_prompt(
|
|||
{"# Verbosity:" if verbosity else ""}
|
||||
{verbosity or ""}
|
||||
|
||||
{brand_section}
|
||||
|
||||
{layouts_section}
|
||||
|
||||
- Provide content for each slide in markdown format.
|
||||
- Make sure that flow of the presentation is logical and consistent.
|
||||
- Place greater emphasis on numerical data.
|
||||
|
|
@ -49,7 +73,12 @@ def get_user_prompt(
|
|||
n_slides: int,
|
||||
language: str,
|
||||
additional_context: Optional[str] = None,
|
||||
content_summary: Optional[str] = None,
|
||||
):
|
||||
summary_section = ""
|
||||
if content_summary:
|
||||
summary_section = f"- Content Analysis Summary: {content_summary}"
|
||||
|
||||
return f"""
|
||||
**Input:**
|
||||
- User provided content: {content or "Create presentation"}
|
||||
|
|
@ -57,6 +86,7 @@ def get_user_prompt(
|
|||
- Number of Slides: {n_slides}
|
||||
- Current Date and Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
- Additional Information: {additional_context or ""}
|
||||
{summary_section}
|
||||
"""
|
||||
|
||||
|
||||
|
|
@ -69,15 +99,21 @@ def get_messages(
|
|||
verbosity: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
include_title_slide: bool = True,
|
||||
brand_context: Optional[str] = None,
|
||||
available_layouts: Optional[str] = None,
|
||||
content_summary: Optional[str] = None,
|
||||
):
|
||||
return [
|
||||
LLMSystemMessage(
|
||||
content=get_system_prompt(
|
||||
tone, verbosity, instructions, include_title_slide
|
||||
tone, verbosity, instructions, include_title_slide,
|
||||
brand_context, available_layouts,
|
||||
),
|
||||
),
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(content, n_slides, language, additional_context),
|
||||
content=get_user_prompt(
|
||||
content, n_slides, language, additional_context, content_summary,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -92,6 +128,9 @@ async def generate_ppt_outline(
|
|||
instructions: Optional[str] = None,
|
||||
include_title_slide: bool = True,
|
||||
web_search: bool = False,
|
||||
brand_context: Optional[str] = None,
|
||||
available_layouts: Optional[str] = None,
|
||||
content_summary: Optional[str] = None,
|
||||
):
|
||||
model = get_model()
|
||||
response_model = get_presentation_outline_model_with_n_slides(n_slides)
|
||||
|
|
@ -110,6 +149,9 @@ async def generate_ppt_outline(
|
|||
verbosity,
|
||||
instructions,
|
||||
include_title_slide,
|
||||
brand_context,
|
||||
available_layouts,
|
||||
content_summary,
|
||||
),
|
||||
response_model.model_json_schema(),
|
||||
strict=True,
|
||||
|
|
|
|||
|
|
@ -13,10 +13,30 @@ def get_system_prompt(
|
|||
tone: Optional[str] = None,
|
||||
verbosity: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
brand_context: Optional[str] = None,
|
||||
attachment_context: Optional[str] = None,
|
||||
):
|
||||
brand_section = ""
|
||||
if brand_context:
|
||||
brand_section = f"""
|
||||
## Brand Guidelines
|
||||
{brand_context}
|
||||
Ensure all generated text follows these brand voice and tone guidelines.
|
||||
"""
|
||||
|
||||
attachment_section = ""
|
||||
if attachment_context:
|
||||
attachment_section = f"""
|
||||
## Attachment Data
|
||||
The following data from attachments (tables, charts) is available for this slide. Use it directly — do not invent data points.
|
||||
{attachment_context}
|
||||
"""
|
||||
|
||||
return f"""
|
||||
Generate structured slide based on provided outline, follow mentioned steps and notes and provide structured output.
|
||||
|
||||
You are extracting and restructuring content from a brief. Do NOT invent facts, statistics, or claims not present in the source material or attachment data.
|
||||
|
||||
{"# User Instructions:" if instructions else ""}
|
||||
{instructions or ""}
|
||||
|
||||
|
|
@ -26,6 +46,10 @@ def get_system_prompt(
|
|||
{"# Verbosity:" if verbosity else ""}
|
||||
{verbosity or ""}
|
||||
|
||||
{brand_section}
|
||||
|
||||
{attachment_section}
|
||||
|
||||
# Steps
|
||||
1. Analyze the outline.
|
||||
2. Generate structured slide based on the outline.
|
||||
|
|
@ -86,11 +110,15 @@ def get_messages(
|
|||
tone: Optional[str] = None,
|
||||
verbosity: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
brand_context: Optional[str] = None,
|
||||
attachment_context: Optional[str] = None,
|
||||
):
|
||||
|
||||
return [
|
||||
LLMSystemMessage(
|
||||
content=get_system_prompt(tone, verbosity, instructions),
|
||||
content=get_system_prompt(
|
||||
tone, verbosity, instructions, brand_context, attachment_context,
|
||||
),
|
||||
),
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(outline, language),
|
||||
|
|
@ -105,6 +133,8 @@ async def get_slide_content_from_type_and_outline(
|
|||
tone: Optional[str] = None,
|
||||
verbosity: Optional[str] = None,
|
||||
instructions: Optional[str] = None,
|
||||
brand_context: Optional[str] = None,
|
||||
attachment_context: Optional[str] = None,
|
||||
):
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
|
@ -134,6 +164,8 @@ async def get_slide_content_from_type_and_outline(
|
|||
tone,
|
||||
verbosity,
|
||||
instructions,
|
||||
brand_context,
|
||||
attachment_context,
|
||||
),
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
|
|
|
|||
0
backend/workers/__init__.py
Normal file
0
backend/workers/__init__.py
Normal file
24
backend/workers/main.py
Normal file
24
backend/workers/main.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""ARQ worker entry point.
|
||||
|
||||
Run with: python -m arq workers.main.WorkerSettings
|
||||
"""
|
||||
import os
|
||||
|
||||
from arq.connections import RedisSettings
|
||||
|
||||
from workers.master_deck_worker import parse_master_deck_task
|
||||
from workers.presentation_worker import generate_presentation_task
|
||||
|
||||
|
||||
def _get_redis_settings() -> RedisSettings:
|
||||
url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
||||
return RedisSettings.from_dsn(url)
|
||||
|
||||
|
||||
class WorkerSettings:
|
||||
redis_settings = _get_redis_settings()
|
||||
functions = [generate_presentation_task, parse_master_deck_task]
|
||||
max_jobs = 5
|
||||
job_timeout = 600 # 10 minutes
|
||||
max_tries = 3
|
||||
health_check_interval = 30
|
||||
59
backend/workers/master_deck_worker.py
Normal file
59
backend/workers/master_deck_worker.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""ARQ worker task: parse a master deck PPTX."""
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from models.sql.job import JobModel
|
||||
from services.database import async_session_maker
|
||||
from services.master_deck_parser_service import parse_master_deck
|
||||
from services.redis_service import publish_job_progress
|
||||
|
||||
|
||||
async def parse_master_deck_task(ctx: dict, job_id: str) -> None:
|
||||
"""ARQ task: parse a master deck via the existing parser service."""
|
||||
job_uuid = uuid.UUID(job_id)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
job = await session.get(JobModel, job_uuid)
|
||||
if not job:
|
||||
return
|
||||
|
||||
try:
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.utcnow()
|
||||
job.progress = 10
|
||||
job.progress_message = "Parsing master deck"
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
await publish_job_progress(job_uuid, 10, "Parsing master deck")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# The existing parser updates MasterDeckModel directly
|
||||
# presentation_id is reused to store the deck_id for this job type
|
||||
await parse_master_deck(job.presentation_id)
|
||||
|
||||
job.status = "completed"
|
||||
job.progress = 100
|
||||
job.progress_message = "Parsing complete"
|
||||
job.completed_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
await publish_job_progress(job_uuid, 100, "Parsing complete", "completed")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
job.status = "failed"
|
||||
job.error_message = str(e)[:500]
|
||||
job.progress_message = "Parsing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
await publish_job_progress(job_uuid, job.progress, "Parsing failed", "failed")
|
||||
except Exception:
|
||||
pass
|
||||
253
backend/workers/presentation_worker.py
Normal file
253
backend/workers/presentation_worker.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""ARQ worker task: generate a presentation end-to-end."""
|
||||
import asyncio
|
||||
import math
|
||||
import random
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import dirtyjson
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.presentation_model import PresentationModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel, SlideOutlineModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.slide_model import SlideModel
|
||||
from models.sql.job import JobModel
|
||||
from services.brand_enforcement_service import BrandEnforcementService
|
||||
from services.content_intelligence_service import ContentIntelligenceService
|
||||
from services.database import async_session_maker
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from services.redis_service import publish_job_progress
|
||||
from services.slide_mapping_engine import SlideMappingEngine
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
from utils.export_utils import export_presentation
|
||||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
from utils.llm_calls.generate_presentation_structure import generate_presentation_structure
|
||||
from utils.llm_calls.generate_slide_content import get_slide_content_from_type_and_outline
|
||||
from utils.presentation_utils import (
|
||||
get_layout_by_name,
|
||||
get_presentation_title_from_outlines,
|
||||
process_slide_and_fetch_assets,
|
||||
select_toc_or_list_slide_layout_index,
|
||||
)
|
||||
|
||||
|
||||
async def generate_presentation_task(ctx: dict, job_id: str) -> None:
|
||||
"""ARQ task: full presentation generation pipeline."""
|
||||
job_uuid = uuid.UUID(job_id)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
job = await session.get(JobModel, job_uuid)
|
||||
if not job:
|
||||
return
|
||||
|
||||
try:
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.utcnow()
|
||||
job.progress = 0
|
||||
job.progress_message = "Starting generation"
|
||||
await session.commit()
|
||||
await _publish(job_uuid, 0, "Starting generation")
|
||||
|
||||
# Load the stored request data from the presentation record
|
||||
presentation = await session.get(PresentationModel, job.presentation_id)
|
||||
if not presentation:
|
||||
raise ValueError("Presentation record not found")
|
||||
|
||||
# Extract request parameters from the stored presentation
|
||||
content = presentation.content or ""
|
||||
n_slides = presentation.n_slides or 10
|
||||
language = presentation.language or "en"
|
||||
tone = presentation.tone or "professional"
|
||||
verbosity = presentation.verbosity or "standard"
|
||||
instructions = presentation.instructions
|
||||
template = "default"
|
||||
include_title_slide = True
|
||||
|
||||
# --- Step 1: Brand context ---
|
||||
brand_context = ""
|
||||
if job.client_id:
|
||||
brand_svc = BrandEnforcementService()
|
||||
brand_context = await brand_svc.get_brand_context_for_llm(
|
||||
job.client_id, session
|
||||
)
|
||||
|
||||
await _update_job(session, job, 5, "Analyzing content")
|
||||
|
||||
# --- Step 2: Content intelligence (if raw content provided) ---
|
||||
content_summary = None
|
||||
if content and len(content) > 100:
|
||||
ci_service = ContentIntelligenceService()
|
||||
classified = await ci_service.classify(content)
|
||||
content_summary = classified.summary
|
||||
|
||||
await _update_job(session, job, 10, "Generating outlines")
|
||||
|
||||
# --- Step 3: Generate outlines ---
|
||||
presentation_outlines_text = ""
|
||||
async for chunk in generate_ppt_outline(
|
||||
content,
|
||||
n_slides,
|
||||
language,
|
||||
None, # additional_context
|
||||
tone,
|
||||
verbosity,
|
||||
instructions,
|
||||
include_title_slide,
|
||||
False, # web_search
|
||||
brand_context=brand_context,
|
||||
content_summary=content_summary,
|
||||
):
|
||||
if isinstance(chunk, HTTPException):
|
||||
raise chunk
|
||||
presentation_outlines_text += chunk
|
||||
|
||||
try:
|
||||
outlines_json = dict(dirtyjson.loads(presentation_outlines_text))
|
||||
except Exception:
|
||||
raise ValueError("Failed to parse generated outlines")
|
||||
|
||||
presentation_outlines = PresentationOutlineModel(**outlines_json)
|
||||
total_outlines = n_slides
|
||||
|
||||
await _update_job(session, job, 25, "Selecting layouts")
|
||||
|
||||
# --- Step 4: Layout selection ---
|
||||
layout_model = await get_layout_by_name(template)
|
||||
total_slide_layouts = len(layout_model.slides)
|
||||
|
||||
if layout_model.ordered:
|
||||
presentation_structure = layout_model.to_presentation_structure()
|
||||
else:
|
||||
presentation_structure = await generate_presentation_structure(
|
||||
presentation_outlines, layout_model, instructions
|
||||
)
|
||||
|
||||
presentation_structure.slides = presentation_structure.slides[:total_outlines]
|
||||
for index in range(total_outlines):
|
||||
random_slide_index = random.randint(0, total_slide_layouts - 1)
|
||||
if index >= len(presentation_structure.slides):
|
||||
presentation_structure.slides.append(random_slide_index)
|
||||
elif presentation_structure.slides[index] >= total_slide_layouts:
|
||||
presentation_structure.slides[index] = random_slide_index
|
||||
|
||||
# Update presentation model with outlines & structure
|
||||
presentation.title = get_presentation_title_from_outlines(presentation_outlines)
|
||||
presentation.outlines = presentation_outlines.model_dump()
|
||||
presentation.layout = layout_model.model_dump()
|
||||
presentation.structure = presentation_structure.model_dump()
|
||||
await session.commit()
|
||||
|
||||
await _update_job(session, job, 35, "Generating slides")
|
||||
|
||||
# --- Step 5: Generate slide content ---
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
async_assets_generation_tasks = []
|
||||
slides: List[SlideModel] = []
|
||||
|
||||
slide_layout_indices = presentation_structure.slides
|
||||
slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices]
|
||||
|
||||
batch_size = 10
|
||||
for start in range(0, len(slide_layouts), batch_size):
|
||||
end = min(start + batch_size, len(slide_layouts))
|
||||
|
||||
content_tasks = [
|
||||
get_slide_content_from_type_and_outline(
|
||||
slide_layouts[i],
|
||||
presentation_outlines.slides[i],
|
||||
language,
|
||||
tone,
|
||||
verbosity,
|
||||
instructions,
|
||||
brand_context=brand_context,
|
||||
)
|
||||
for i in range(start, end)
|
||||
]
|
||||
batch_contents = await asyncio.gather(*content_tasks)
|
||||
|
||||
batch_slides = []
|
||||
for offset, slide_content in enumerate(batch_contents):
|
||||
i = start + offset
|
||||
slide = SlideModel(
|
||||
presentation=job.presentation_id,
|
||||
layout_group=layout_model.name,
|
||||
layout=slide_layouts[i].id,
|
||||
index=i,
|
||||
speaker_note=slide_content.get("__speaker_note__"),
|
||||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
batch_slides.append(slide)
|
||||
|
||||
asset_tasks = [
|
||||
process_slide_and_fetch_assets(image_generation_service, slide)
|
||||
for slide in batch_slides
|
||||
]
|
||||
async_assets_generation_tasks.extend(asset_tasks)
|
||||
|
||||
pct = 35 + int((end / len(slide_layouts)) * 40)
|
||||
await _update_job(session, job, pct, f"Generating slide {end}/{len(slide_layouts)}")
|
||||
|
||||
await _update_job(session, job, 80, "Fetching assets")
|
||||
|
||||
# --- Step 6: Fetch assets ---
|
||||
generated_assets_list = await asyncio.gather(*async_assets_generation_tasks)
|
||||
generated_assets = []
|
||||
for assets_list in generated_assets_list:
|
||||
generated_assets.extend(assets_list)
|
||||
|
||||
await _update_job(session, job, 90, "Saving presentation")
|
||||
|
||||
# --- Step 7: Save ---
|
||||
session.add(presentation)
|
||||
session.add_all(slides)
|
||||
session.add_all(generated_assets)
|
||||
await session.commit()
|
||||
|
||||
await _update_job(session, job, 95, "Exporting PPTX")
|
||||
|
||||
# --- Step 8: Export ---
|
||||
await export_presentation(
|
||||
job.presentation_id,
|
||||
presentation.title or str(uuid.uuid4()),
|
||||
"pptx",
|
||||
)
|
||||
|
||||
# --- Done ---
|
||||
job.status = "completed"
|
||||
job.progress = 100
|
||||
job.progress_message = "Generation complete"
|
||||
job.completed_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await _publish(job_uuid, 100, "Generation complete", "completed")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
job.status = "failed"
|
||||
job.error_message = str(e)[:500]
|
||||
job.progress_message = "Generation failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
await session.commit()
|
||||
await _publish(job_uuid, job.progress, "Generation failed", "failed")
|
||||
|
||||
|
||||
async def _update_job(
|
||||
session: AsyncSession, job: JobModel, progress: int, message: str
|
||||
) -> None:
|
||||
job.progress = progress
|
||||
job.progress_message = message
|
||||
await session.commit()
|
||||
await publish_job_progress(job.id, progress, message)
|
||||
|
||||
|
||||
async def _publish(
|
||||
job_id: uuid.UUID, progress: int, message: str, status: str = "processing"
|
||||
) -> None:
|
||||
try:
|
||||
await publish_job_progress(job_id, progress, message, status)
|
||||
except Exception:
|
||||
pass # Redis unavailable is not fatal
|
||||
Loading…
Add table
Reference in a new issue