Phase 4: Generation Pipeline — brand enforcement, enhanced LLM calls, ARQ job queue

- Step 14: Brand enforcement service (font/color/logo replacement, WCAG contrast check, LLM prompt context)
- Step 15: Enhanced outline & slide content generation with brand context, content summary, "no hallucination" instructions
- Step 15b: LLM auto-fallback retry logic across providers (FALLBACK_LLM_PROVIDERS env)
- Step 16: Redis/ARQ job queue — worker entry point, presentation & master deck workers, job status/SSE endpoints, graceful fallback to BackgroundTasks when Redis unavailable

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Vadym Samoilenko 2026-02-26 16:15:25 +00:00
parent a2bd4cfefa
commit a0d73b3b63
16 changed files with 1038 additions and 17 deletions

View file

@ -4,6 +4,7 @@ import os
from fastapi import FastAPI
from services.database import create_db_and_tables
from services.redis_service import close_arq_pool
from utils.get_env import get_app_data_directory_env
from utils.model_availability import (
check_llm_and_image_provider_api_or_model_availability,
@ -21,3 +22,4 @@ async def app_lifespan(_: FastAPI):
await create_db_and_tables()
await check_llm_and_image_provider_api_or_model_availability()
yield
await close_arq_pool()

View file

@ -13,6 +13,7 @@ from api.v1.admin.clients_router import CLIENTS_ROUTER
from api.v1.admin.audit_router import AUDIT_ROUTER
from api.v1.admin.brand_config_router import BRAND_CONFIG_ROUTER
from api.v1.admin.master_decks_router import MASTER_DECKS_ROUTER
from api.v1.ppt.endpoints.jobs import JOBS_ROUTER
from api.middlewares.audit_middleware import AuditMiddleware
@ -31,6 +32,7 @@ ADMIN_ROUTER.include_router(MASTER_DECKS_ROUTER)
app.include_router(AUTH_ROUTER)
app.include_router(ADMIN_ROUTER)
app.include_router(API_V1_PPT_ROUTER)
app.include_router(JOBS_ROUTER)
app.include_router(API_V1_WEBHOOK_ROUTER)
app.include_router(API_V1_MOCK_ROUTER)

View file

@ -114,11 +114,27 @@ async def upload_master_deck(
await session.commit()
await session.refresh(deck)
# Kick off async parsing
import asyncio
from services.master_deck_parser_service import parse_master_deck
# Kick off async parsing via ARQ (fallback to asyncio.create_task)
try:
from models.sql.job import JobModel
from services.redis_service import enqueue_job
asyncio.create_task(parse_master_deck(deck_id))
job = JobModel(
user_id=admin.id,
client_id=client_id,
presentation_id=deck_id, # reuse field for deck_id
job_type="parse_master_deck",
status="queued",
progress=0,
progress_message="Queued for parsing",
)
session.add(job)
await session.commit()
await enqueue_job("parse_master_deck_task", job_id=str(job.id))
except Exception:
import asyncio
from services.master_deck_parser_service import parse_master_deck
asyncio.create_task(parse_master_deck(deck_id))
return {
"id": str(deck.id),
@ -235,10 +251,26 @@ async def reparse_master_deck(
deck.parse_status = "pending"
await session.commit()
import asyncio
from services.master_deck_parser_service import parse_master_deck
try:
from models.sql.job import JobModel
from services.redis_service import enqueue_job
asyncio.create_task(parse_master_deck(deck_id))
job = JobModel(
user_id=admin.id,
client_id=deck.client_id,
presentation_id=deck_id,
job_type="parse_master_deck",
status="queued",
progress=0,
progress_message="Queued for re-parsing",
)
session.add(job)
await session.commit()
await enqueue_job("parse_master_deck_task", job_id=str(job.id))
except Exception:
import asyncio
from services.master_deck_parser_service import parse_master_deck
asyncio.create_task(parse_master_deck(deck_id))
return {"ok": True, "parse_status": "pending"}

View file

@ -0,0 +1,151 @@
"""Job status polling and SSE streaming endpoints."""
import asyncio
import json
import uuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from models.sse_response import SSEResponse, SSEErrorResponse
from models.sql.job import JobModel
from models.sql.user import UserModel
from services.database import get_async_session
from services.redis_service import get_arq_pool
from utils.auth_dependencies import get_current_user
JOBS_ROUTER = APIRouter(prefix="/api/v1/ppt", tags=["Jobs"])
@JOBS_ROUTER.get("/jobs/{job_id}")
async def get_job_status(
job_id: uuid.UUID,
_current_user: UserModel = Depends(get_current_user),
session: AsyncSession = Depends(get_async_session),
):
"""Poll job status."""
job = await session.get(JobModel, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return {
"id": str(job.id),
"job_type": job.job_type,
"status": job.status,
"progress": job.progress,
"progress_message": job.progress_message,
"error_message": job.error_message,
"presentation_id": str(job.presentation_id) if job.presentation_id else None,
"created_at": job.created_at.isoformat() if job.created_at else None,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
}
@JOBS_ROUTER.get("/jobs/{job_id}/stream")
async def stream_job_progress(
job_id: uuid.UUID,
_current_user: UserModel = Depends(get_current_user),
session: AsyncSession = Depends(get_async_session),
):
"""SSE stream of job progress events via Redis pub/sub."""
job = await session.get(JobModel, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
# If already completed/failed, return final event immediately
if job.status in ("completed", "failed"):
async def done_stream():
data = {
"type": "progress",
"job_id": str(job.id),
"progress": job.progress,
"message": job.progress_message or job.status,
"status": job.status,
}
yield SSEResponse(
event="response", data=json.dumps(data)
).to_string()
return StreamingResponse(done_stream(), media_type="text/event-stream")
async def progress_stream():
try:
pool = await get_arq_pool()
pubsub = pool.pubsub()
channel = f"job:{job_id}:progress"
await pubsub.subscribe(channel)
# Send initial status
yield SSEResponse(
event="response",
data=json.dumps({
"type": "progress",
"job_id": str(job_id),
"progress": job.progress,
"message": job.progress_message or "Waiting",
"status": job.status,
}),
).to_string()
# Listen for updates with timeout
timeout = 600 # 10 minutes max
elapsed = 0
while elapsed < timeout:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=2.0
)
if message and message["type"] == "message":
raw = message["data"]
if isinstance(raw, bytes):
raw = raw.decode()
payload = json.loads(raw)
yield SSEResponse(
event="response",
data=json.dumps({"type": "progress", **payload}),
).to_string()
# Stop streaming on terminal status
if payload.get("status") in ("completed", "failed"):
break
else:
elapsed += 2
await asyncio.sleep(0)
await pubsub.unsubscribe(channel)
await pubsub.aclose()
except Exception as e:
yield SSEErrorResponse(detail=str(e)[:200]).to_string()
return StreamingResponse(progress_stream(), media_type="text/event-stream")
@JOBS_ROUTER.delete("/jobs/{job_id}")
async def cancel_job(
job_id: uuid.UUID,
_current_user: UserModel = Depends(get_current_user),
session: AsyncSession = Depends(get_async_session),
):
"""Cancel a queued or processing job."""
job = await session.get(JobModel, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.status in ("completed", "failed"):
raise HTTPException(status_code=409, detail="Job already finished")
# Try to abort via ARQ
try:
pool = await get_arq_pool()
await pool.abort_job(str(job_id))
except Exception:
pass # Best effort
job.status = "failed"
job.error_message = "Cancelled by user"
job.progress_message = "Cancelled"
await session.commit()
return {"ok": True, "status": "cancelled"}

View file

@ -846,6 +846,18 @@ async def generate_presentation_async(
try:
(presentation_id,) = await check_if_api_request_is_valid(request, sql_session)
# Create a lightweight presentation record so the worker can load it
presentation = PresentationModel(
id=presentation_id,
content=request.content,
n_slides=request.n_slides,
language=request.language,
tone=request.tone.value,
verbosity=request.verbosity.value,
instructions=request.instructions,
)
sql_session.add(presentation)
async_status = AsyncPresentationGenerationTaskModel(
status="pending",
message="Queued for generation",
@ -854,13 +866,39 @@ async def generate_presentation_async(
sql_session.add(async_status)
await sql_session.commit()
background_tasks.add_task(
generate_presentation_handler,
request,
presentation_id,
async_status=async_status,
sql_session=sql_session,
)
# Try ARQ job queue first; fall back to BackgroundTasks
job_enqueued = False
try:
from models.sql.job import JobModel
from services.redis_service import enqueue_job
job = JobModel(
user_id=_current_user.id,
client_id=getattr(_current_user, "default_client_id", _current_user.id),
presentation_id=presentation_id,
job_type="generate_presentation",
status="queued",
progress=0,
progress_message="Queued for generation",
)
sql_session.add(job)
await sql_session.commit()
await enqueue_job("generate_presentation_task", job_id=str(job.id))
job_enqueued = True
except Exception:
# Redis/ARQ unavailable — fall back to in-process background task
pass
if not job_enqueued:
background_tasks.add_task(
generate_presentation_handler,
request,
presentation_id,
async_status=async_status,
sql_session=sql_session,
)
return async_status
except Exception as e:

View file

@ -28,6 +28,7 @@ dependencies = [
"python-jose[cryptography]>=3.3",
"openpyxl>=3.1",
"trafilatura>=2.0",
"arq>=0.26",
]
[[tool.uv.index]]

View file

@ -0,0 +1,225 @@
"""Brand Enforcement Service: apply client brand rules to generated content and PPTX output."""
import math
import uuid
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.pptx_models import (
PptxAutoShapeBoxModel,
PptxChartBoxModel,
PptxFontModel,
PptxPictureBoxModel,
PptxPictureModel,
PptxPositionModel,
PptxPresentationModel,
PptxTextBoxModel,
)
from models.sql.brand_config import BrandConfigModel
# Default slide dimensions (matching PptxPresentationCreator)
_SLIDE_WIDTH = 1280
_SLIDE_HEIGHT = 720
# Logo defaults
_LOGO_WIDTH = 80
_LOGO_HEIGHT = 40
_LOGO_MARGIN = 20
class BrandEnforcementService:
async def get_brand_context_for_llm(
self, client_id: uuid.UUID, session: AsyncSession
) -> str:
"""Build a text block for LLM system prompt injection with brand guidelines."""
brand = await self._load_brand(client_id, session)
if not brand:
return ""
parts: List[str] = []
if brand.primary_colors:
colors_str = ", ".join(brand.primary_colors[:6])
parts.append(f"Brand primary colors: {colors_str}")
if brand.secondary_colors:
colors_str = ", ".join(brand.secondary_colors[:6])
parts.append(f"Brand secondary colors: {colors_str}")
if brand.fonts:
heading = brand.fonts.get("heading", "")
body = brand.fonts.get("body", "")
if heading:
parts.append(f"Heading font: {heading}")
if body:
parts.append(f"Body font: {body}")
if brand.voice_rules:
parts.append(f"Brand voice guidelines: {brand.voice_rules}")
if brand.voice_examples:
examples = "\n - ".join(str(e) for e in brand.voice_examples[:3])
parts.append(f"Brand voice examples:\n - {examples}")
return "\n".join(parts)
def enforce_on_pptx_model(
self,
model: PptxPresentationModel,
brand: BrandConfigModel,
) -> PptxPresentationModel:
"""Apply brand fonts, colors, and logo to a PPTX model."""
heading_font = (brand.fonts or {}).get("heading")
body_font = (brand.fonts or {}).get("body")
brand_colors = self.get_brand_colors_list(brand)
logo_path = brand.logo_paths[0] if brand.logo_paths else None
for slide_idx, slide in enumerate(model.slides):
for shape in slide.shapes:
# Replace fonts
if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)):
self._enforce_fonts_on_shape(shape, heading_font, body_font, slide_idx)
# Set brand colors on charts
if isinstance(shape, PptxChartBoxModel):
if brand_colors:
shape.brand_colors = brand_colors
if body_font:
shape.font_name = body_font
# Contrast check on all text shapes
bg_color = slide.background.color if slide.background else "FFFFFF"
for shape in slide.shapes:
if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)):
self._fix_contrast(shape, bg_color)
# Add logo to non-title slides
if slide_idx > 0 and logo_path:
logo_shape = PptxPictureBoxModel(
position=PptxPositionModel(
left=_SLIDE_WIDTH - _LOGO_WIDTH - _LOGO_MARGIN,
top=_SLIDE_HEIGHT - _LOGO_HEIGHT - _LOGO_MARGIN,
width=_LOGO_WIDTH,
height=_LOGO_HEIGHT,
),
clip=False,
picture=PptxPictureModel(is_network=False, path=logo_path),
)
slide.shapes.append(logo_shape)
return model
def get_brand_colors_list(self, brand: BrandConfigModel) -> List[str]:
"""Return usable brand color list, with fallback defaults."""
colors: List[str] = []
if brand.primary_colors:
colors.extend(str(c).lstrip("#") for c in brand.primary_colors)
if brand.secondary_colors:
colors.extend(str(c).lstrip("#") for c in brand.secondary_colors)
if not colors:
# Fallback to default chart colors
colors = [
"4472C4", "ED7D31", "A5A5A5", "FFC000", "5B9BD5",
"70AD47", "264478", "9B57A0", "636363", "EB6E1F",
]
return colors
# --- Internal methods ---
async def _load_brand(
self, client_id: uuid.UUID, session: AsyncSession
) -> Optional[BrandConfigModel]:
stmt = select(BrandConfigModel).where(
BrandConfigModel.client_id == client_id
)
result = await session.execute(stmt)
return result.scalar_one_or_none()
def _enforce_fonts_on_shape(
self,
shape,
heading_font: Optional[str],
body_font: Optional[str],
slide_idx: int,
) -> None:
"""Replace font names in text shapes with brand fonts."""
paragraphs = getattr(shape, "paragraphs", None)
if not paragraphs:
return
for para in paragraphs:
# Title slides (idx 0) and large fonts use heading font
is_heading = slide_idx == 0 or (para.font and para.font.size >= 24)
target_font = heading_font if is_heading else body_font
if target_font:
if para.font:
para.font.name = target_font
if para.text_runs:
for run in para.text_runs:
if run.font:
run.font.name = target_font
def _fix_contrast(self, shape, bg_color: str) -> None:
"""Ensure text has sufficient contrast against slide background."""
paragraphs = getattr(shape, "paragraphs", None)
if not paragraphs:
return
bg_lum = _relative_luminance(bg_color)
for para in paragraphs:
self._fix_font_contrast(para.font, bg_lum)
if para.text_runs:
for run in para.text_runs:
self._fix_font_contrast(run.font, bg_lum)
def _fix_font_contrast(
self, font: Optional[PptxFontModel], bg_lum: float
) -> None:
if not font:
return
text_lum = _relative_luminance(font.color)
ratio = _contrast_ratio(text_lum, bg_lum)
# WCAG AA requires 4.5:1 for normal text, 3:1 for large text
if ratio < 3.0:
# Swap to white or black based on background
font.color = "FFFFFF" if bg_lum < 0.5 else "000000"
# --- Color utility functions ---
def _hex_to_rgb(hex_color: str) -> tuple:
"""Convert hex color string to (r, g, b) tuple with values 0-255."""
h = hex_color.lstrip("#")
if len(h) == 3:
h = h[0]*2 + h[1]*2 + h[2]*2
if len(h) < 6:
h = h.ljust(6, "0")
return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
def _relative_luminance(hex_color: str) -> float:
"""Calculate relative luminance per WCAG 2.0."""
r, g, b = _hex_to_rgb(hex_color)
rs = r / 255.0
gs = g / 255.0
bs = b / 255.0
def linearize(c):
return c / 12.92 if c <= 0.03928 else math.pow((c + 0.055) / 1.055, 2.4)
return 0.2126 * linearize(rs) + 0.7152 * linearize(gs) + 0.0722 * linearize(bs)
def _contrast_ratio(lum1: float, lum2: float) -> float:
"""Calculate contrast ratio between two luminance values."""
lighter = max(lum1, lum2)
darker = min(lum1, lum2)
return (lighter + 0.05) / (darker + 0.05)

View file

@ -47,6 +47,7 @@ from utils.get_env import (
get_custom_llm_api_key_env,
get_custom_llm_url_env,
get_disable_thinking_env,
get_fallback_llm_providers_env,
get_google_api_key_env,
get_ollama_url_env,
get_openai_api_key_env,
@ -62,6 +63,26 @@ from utils.schema_utils import (
)
import logging
_logger = logging.getLogger(__name__)
def _get_fallback_providers() -> list:
"""Parse FALLBACK_LLM_PROVIDERS env var into list of LLMProvider enums."""
raw = get_fallback_llm_providers_env()
if not raw:
return []
providers = []
for name in raw.split(","):
name = name.strip().lower()
try:
providers.append(LLMProvider(name))
except ValueError:
pass
return providers
class LLMClient:
def __init__(self):
self.llm_provider = get_llm_provider()
@ -407,6 +428,43 @@ class LLMClient:
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
try:
return await self._generate_impl(model, messages, max_tokens, tools)
except Exception as primary_error:
fallback_providers = _get_fallback_providers()
# Remove primary provider from fallback list
fallback_providers = [
p for p in fallback_providers if p != self.llm_provider
]
for fb_provider in fallback_providers:
_logger.warning(
"LLM generate failed with %s: %s. Trying fallback %s",
self.llm_provider.value, str(primary_error)[:200], fb_provider.value,
)
try:
fb_client = LLMClient.__new__(LLMClient)
fb_client.llm_provider = fb_provider
fb_client._client = fb_client._get_client()
fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client)
fb_model = get_model()
return await fb_client._generate_impl(
fb_model, messages, max_tokens, tools
)
except Exception as fb_error:
_logger.warning(
"Fallback %s also failed: %s",
fb_provider.value, str(fb_error)[:200],
)
continue
raise primary_error
async def _generate_impl(
self,
model: str,
messages: List[LLMMessage],
max_tokens: Optional[int] = None,
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
):
parsed_tools = self.tool_calls_handler.parse_tools(tools)
@ -781,6 +839,46 @@ class LLMClient:
strict: bool = False,
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
max_tokens: Optional[int] = None,
) -> dict:
try:
return await self._generate_structured_impl(
model, messages, response_format, strict, tools, max_tokens
)
except Exception as primary_error:
fallback_providers = _get_fallback_providers()
fallback_providers = [
p for p in fallback_providers if p != self.llm_provider
]
for fb_provider in fallback_providers:
_logger.warning(
"LLM generate_structured failed with %s: %s. Trying fallback %s",
self.llm_provider.value, str(primary_error)[:200], fb_provider.value,
)
try:
fb_client = LLMClient.__new__(LLMClient)
fb_client.llm_provider = fb_provider
fb_client._client = fb_client._get_client()
fb_client.tool_calls_handler = LLMToolCallsHandler(fb_client)
fb_model = get_model()
return await fb_client._generate_structured_impl(
fb_model, messages, response_format, strict, tools, max_tokens
)
except Exception as fb_error:
_logger.warning(
"Fallback %s also failed: %s",
fb_provider.value, str(fb_error)[:200],
)
continue
raise primary_error
async def _generate_structured_impl(
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
max_tokens: Optional[int] = None,
) -> dict:
parsed_tools = self.tool_calls_handler.parse_tools(tools)

View file

@ -0,0 +1,58 @@
"""Redis service: connection pool and job progress utilities."""
import json
import os
import uuid
from typing import Optional
from arq import create_pool
from arq.connections import ArqRedis, RedisSettings
_pool: Optional[ArqRedis] = None
def _get_redis_settings() -> RedisSettings:
"""Parse REDIS_URL env var into ARQ RedisSettings."""
url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
return RedisSettings.from_dsn(url)
async def get_arq_pool() -> ArqRedis:
"""Get or create the shared ARQ Redis connection pool."""
global _pool
if _pool is None:
_pool = await create_pool(_get_redis_settings())
return _pool
async def close_arq_pool() -> None:
"""Close the shared ARQ Redis pool (call on app shutdown)."""
global _pool
if _pool is not None:
await _pool.aclose()
_pool = None
async def enqueue_job(function_name: str, **kwargs) -> Optional[str]:
"""Enqueue a job via ARQ. Returns the ARQ job ID."""
pool = await get_arq_pool()
job = await pool.enqueue_job(function_name, **kwargs)
return job.job_id if job else None
async def publish_job_progress(
job_id: uuid.UUID,
progress: int,
message: str,
status: str = "processing",
) -> None:
"""Publish a progress event to Redis pub/sub for SSE consumers."""
pool = await get_arq_pool()
channel = f"job:{job_id}:progress"
payload = json.dumps({
"job_id": str(job_id),
"progress": progress,
"message": message,
"status": status,
})
await pool.publish(channel, payload.encode())

View file

@ -117,3 +117,7 @@ def get_dall_e_3_quality_env():
# Gpt Image 1.5 Quality
def get_gpt_image_1_5_quality_env():
return os.getenv("GPT_IMAGE_1_5_QUALITY")
def get_fallback_llm_providers_env():
return os.getenv("FALLBACK_LLM_PROVIDERS")

View file

@ -14,10 +14,30 @@ def get_system_prompt(
verbosity: Optional[str] = None,
instructions: Optional[str] = None,
include_title_slide: bool = True,
brand_context: Optional[str] = None,
available_layouts: Optional[str] = None,
):
brand_section = ""
if brand_context:
brand_section = f"""
## Brand Guidelines
{brand_context}
Ensure all text follows these brand voice and tone guidelines.
"""
layouts_section = ""
if available_layouts:
layouts_section = f"""
## Available Slide Layouts
{available_layouts}
Consider which content types best match these available layouts when structuring the outline.
"""
return f"""
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
You are restructuring and condensing existing content from a brief. Do NOT invent facts, statistics, or claims. Every data point must originate from the source material.
Try to use available tools for better results.
{"# User Instruction:" if instructions else ""}
@ -29,6 +49,10 @@ def get_system_prompt(
{"# Verbosity:" if verbosity else ""}
{verbosity or ""}
{brand_section}
{layouts_section}
- Provide content for each slide in markdown format.
- Make sure that flow of the presentation is logical and consistent.
- Place greater emphasis on numerical data.
@ -49,7 +73,12 @@ def get_user_prompt(
n_slides: int,
language: str,
additional_context: Optional[str] = None,
content_summary: Optional[str] = None,
):
summary_section = ""
if content_summary:
summary_section = f"- Content Analysis Summary: {content_summary}"
return f"""
**Input:**
- User provided content: {content or "Create presentation"}
@ -57,6 +86,7 @@ def get_user_prompt(
- Number of Slides: {n_slides}
- Current Date and Time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
- Additional Information: {additional_context or ""}
{summary_section}
"""
@ -69,15 +99,21 @@ def get_messages(
verbosity: Optional[str] = None,
instructions: Optional[str] = None,
include_title_slide: bool = True,
brand_context: Optional[str] = None,
available_layouts: Optional[str] = None,
content_summary: Optional[str] = None,
):
return [
LLMSystemMessage(
content=get_system_prompt(
tone, verbosity, instructions, include_title_slide
tone, verbosity, instructions, include_title_slide,
brand_context, available_layouts,
),
),
LLMUserMessage(
content=get_user_prompt(content, n_slides, language, additional_context),
content=get_user_prompt(
content, n_slides, language, additional_context, content_summary,
),
),
]
@ -92,6 +128,9 @@ async def generate_ppt_outline(
instructions: Optional[str] = None,
include_title_slide: bool = True,
web_search: bool = False,
brand_context: Optional[str] = None,
available_layouts: Optional[str] = None,
content_summary: Optional[str] = None,
):
model = get_model()
response_model = get_presentation_outline_model_with_n_slides(n_slides)
@ -110,6 +149,9 @@ async def generate_ppt_outline(
verbosity,
instructions,
include_title_slide,
brand_context,
available_layouts,
content_summary,
),
response_model.model_json_schema(),
strict=True,

View file

@ -13,10 +13,30 @@ def get_system_prompt(
tone: Optional[str] = None,
verbosity: Optional[str] = None,
instructions: Optional[str] = None,
brand_context: Optional[str] = None,
attachment_context: Optional[str] = None,
):
brand_section = ""
if brand_context:
brand_section = f"""
## Brand Guidelines
{brand_context}
Ensure all generated text follows these brand voice and tone guidelines.
"""
attachment_section = ""
if attachment_context:
attachment_section = f"""
## Attachment Data
The following data from attachments (tables, charts) is available for this slide. Use it directly do not invent data points.
{attachment_context}
"""
return f"""
Generate structured slide based on provided outline, follow mentioned steps and notes and provide structured output.
You are extracting and restructuring content from a brief. Do NOT invent facts, statistics, or claims not present in the source material or attachment data.
{"# User Instructions:" if instructions else ""}
{instructions or ""}
@ -26,6 +46,10 @@ def get_system_prompt(
{"# Verbosity:" if verbosity else ""}
{verbosity or ""}
{brand_section}
{attachment_section}
# Steps
1. Analyze the outline.
2. Generate structured slide based on the outline.
@ -86,11 +110,15 @@ def get_messages(
tone: Optional[str] = None,
verbosity: Optional[str] = None,
instructions: Optional[str] = None,
brand_context: Optional[str] = None,
attachment_context: Optional[str] = None,
):
return [
LLMSystemMessage(
content=get_system_prompt(tone, verbosity, instructions),
content=get_system_prompt(
tone, verbosity, instructions, brand_context, attachment_context,
),
),
LLMUserMessage(
content=get_user_prompt(outline, language),
@ -105,6 +133,8 @@ async def get_slide_content_from_type_and_outline(
tone: Optional[str] = None,
verbosity: Optional[str] = None,
instructions: Optional[str] = None,
brand_context: Optional[str] = None,
attachment_context: Optional[str] = None,
):
client = LLMClient()
model = get_model()
@ -134,6 +164,8 @@ async def get_slide_content_from_type_and_outline(
tone,
verbosity,
instructions,
brand_context,
attachment_context,
),
response_format=response_schema,
strict=False,

View file

24
backend/workers/main.py Normal file
View file

@ -0,0 +1,24 @@
"""ARQ worker entry point.
Run with: python -m arq workers.main.WorkerSettings
"""
import os
from arq.connections import RedisSettings
from workers.master_deck_worker import parse_master_deck_task
from workers.presentation_worker import generate_presentation_task
def _get_redis_settings() -> RedisSettings:
url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
return RedisSettings.from_dsn(url)
class WorkerSettings:
redis_settings = _get_redis_settings()
functions = [generate_presentation_task, parse_master_deck_task]
max_jobs = 5
job_timeout = 600 # 10 minutes
max_tries = 3
health_check_interval = 30

View file

@ -0,0 +1,59 @@
"""ARQ worker task: parse a master deck PPTX."""
import traceback
import uuid
from datetime import datetime
from models.sql.job import JobModel
from services.database import async_session_maker
from services.master_deck_parser_service import parse_master_deck
from services.redis_service import publish_job_progress
async def parse_master_deck_task(ctx: dict, job_id: str) -> None:
"""ARQ task: parse a master deck via the existing parser service."""
job_uuid = uuid.UUID(job_id)
async with async_session_maker() as session:
job = await session.get(JobModel, job_uuid)
if not job:
return
try:
job.status = "processing"
job.started_at = datetime.utcnow()
job.progress = 10
job.progress_message = "Parsing master deck"
await session.commit()
try:
await publish_job_progress(job_uuid, 10, "Parsing master deck")
except Exception:
pass
# The existing parser updates MasterDeckModel directly
# presentation_id is reused to store the deck_id for this job type
await parse_master_deck(job.presentation_id)
job.status = "completed"
job.progress = 100
job.progress_message = "Parsing complete"
job.completed_at = datetime.utcnow()
await session.commit()
try:
await publish_job_progress(job_uuid, 100, "Parsing complete", "completed")
except Exception:
pass
except Exception as e:
traceback.print_exc()
job.status = "failed"
job.error_message = str(e)[:500]
job.progress_message = "Parsing failed"
job.completed_at = datetime.utcnow()
await session.commit()
try:
await publish_job_progress(job_uuid, job.progress, "Parsing failed", "failed")
except Exception:
pass

View file

@ -0,0 +1,253 @@
"""ARQ worker task: generate a presentation end-to-end."""
import asyncio
import math
import random
import traceback
import uuid
from datetime import datetime
from typing import List
import dirtyjson
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from models.presentation_model import PresentationModel
from models.presentation_outline_model import PresentationOutlineModel, SlideOutlineModel
from models.presentation_structure_model import PresentationStructureModel
from models.slide_model import SlideModel
from models.sql.job import JobModel
from services.brand_enforcement_service import BrandEnforcementService
from services.content_intelligence_service import ContentIntelligenceService
from services.database import async_session_maker
from services.image_generation_service import ImageGenerationService
from services.redis_service import publish_job_progress
from services.slide_mapping_engine import SlideMappingEngine
from utils.asset_directory_utils import get_images_directory
from utils.export_utils import export_presentation
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
from utils.llm_calls.generate_presentation_structure import generate_presentation_structure
from utils.llm_calls.generate_slide_content import get_slide_content_from_type_and_outline
from utils.presentation_utils import (
get_layout_by_name,
get_presentation_title_from_outlines,
process_slide_and_fetch_assets,
select_toc_or_list_slide_layout_index,
)
async def generate_presentation_task(ctx: dict, job_id: str) -> None:
"""ARQ task: full presentation generation pipeline."""
job_uuid = uuid.UUID(job_id)
async with async_session_maker() as session:
job = await session.get(JobModel, job_uuid)
if not job:
return
try:
job.status = "processing"
job.started_at = datetime.utcnow()
job.progress = 0
job.progress_message = "Starting generation"
await session.commit()
await _publish(job_uuid, 0, "Starting generation")
# Load the stored request data from the presentation record
presentation = await session.get(PresentationModel, job.presentation_id)
if not presentation:
raise ValueError("Presentation record not found")
# Extract request parameters from the stored presentation
content = presentation.content or ""
n_slides = presentation.n_slides or 10
language = presentation.language or "en"
tone = presentation.tone or "professional"
verbosity = presentation.verbosity or "standard"
instructions = presentation.instructions
template = "default"
include_title_slide = True
# --- Step 1: Brand context ---
brand_context = ""
if job.client_id:
brand_svc = BrandEnforcementService()
brand_context = await brand_svc.get_brand_context_for_llm(
job.client_id, session
)
await _update_job(session, job, 5, "Analyzing content")
# --- Step 2: Content intelligence (if raw content provided) ---
content_summary = None
if content and len(content) > 100:
ci_service = ContentIntelligenceService()
classified = await ci_service.classify(content)
content_summary = classified.summary
await _update_job(session, job, 10, "Generating outlines")
# --- Step 3: Generate outlines ---
presentation_outlines_text = ""
async for chunk in generate_ppt_outline(
content,
n_slides,
language,
None, # additional_context
tone,
verbosity,
instructions,
include_title_slide,
False, # web_search
brand_context=brand_context,
content_summary=content_summary,
):
if isinstance(chunk, HTTPException):
raise chunk
presentation_outlines_text += chunk
try:
outlines_json = dict(dirtyjson.loads(presentation_outlines_text))
except Exception:
raise ValueError("Failed to parse generated outlines")
presentation_outlines = PresentationOutlineModel(**outlines_json)
total_outlines = n_slides
await _update_job(session, job, 25, "Selecting layouts")
# --- Step 4: Layout selection ---
layout_model = await get_layout_by_name(template)
total_slide_layouts = len(layout_model.slides)
if layout_model.ordered:
presentation_structure = layout_model.to_presentation_structure()
else:
presentation_structure = await generate_presentation_structure(
presentation_outlines, layout_model, instructions
)
presentation_structure.slides = presentation_structure.slides[:total_outlines]
for index in range(total_outlines):
random_slide_index = random.randint(0, total_slide_layouts - 1)
if index >= len(presentation_structure.slides):
presentation_structure.slides.append(random_slide_index)
elif presentation_structure.slides[index] >= total_slide_layouts:
presentation_structure.slides[index] = random_slide_index
# Update presentation model with outlines & structure
presentation.title = get_presentation_title_from_outlines(presentation_outlines)
presentation.outlines = presentation_outlines.model_dump()
presentation.layout = layout_model.model_dump()
presentation.structure = presentation_structure.model_dump()
await session.commit()
await _update_job(session, job, 35, "Generating slides")
# --- Step 5: Generate slide content ---
image_generation_service = ImageGenerationService(get_images_directory())
async_assets_generation_tasks = []
slides: List[SlideModel] = []
slide_layout_indices = presentation_structure.slides
slide_layouts = [layout_model.slides[idx] for idx in slide_layout_indices]
batch_size = 10
for start in range(0, len(slide_layouts), batch_size):
end = min(start + batch_size, len(slide_layouts))
content_tasks = [
get_slide_content_from_type_and_outline(
slide_layouts[i],
presentation_outlines.slides[i],
language,
tone,
verbosity,
instructions,
brand_context=brand_context,
)
for i in range(start, end)
]
batch_contents = await asyncio.gather(*content_tasks)
batch_slides = []
for offset, slide_content in enumerate(batch_contents):
i = start + offset
slide = SlideModel(
presentation=job.presentation_id,
layout_group=layout_model.name,
layout=slide_layouts[i].id,
index=i,
speaker_note=slide_content.get("__speaker_note__"),
content=slide_content,
)
slides.append(slide)
batch_slides.append(slide)
asset_tasks = [
process_slide_and_fetch_assets(image_generation_service, slide)
for slide in batch_slides
]
async_assets_generation_tasks.extend(asset_tasks)
pct = 35 + int((end / len(slide_layouts)) * 40)
await _update_job(session, job, pct, f"Generating slide {end}/{len(slide_layouts)}")
await _update_job(session, job, 80, "Fetching assets")
# --- Step 6: Fetch assets ---
generated_assets_list = await asyncio.gather(*async_assets_generation_tasks)
generated_assets = []
for assets_list in generated_assets_list:
generated_assets.extend(assets_list)
await _update_job(session, job, 90, "Saving presentation")
# --- Step 7: Save ---
session.add(presentation)
session.add_all(slides)
session.add_all(generated_assets)
await session.commit()
await _update_job(session, job, 95, "Exporting PPTX")
# --- Step 8: Export ---
await export_presentation(
job.presentation_id,
presentation.title or str(uuid.uuid4()),
"pptx",
)
# --- Done ---
job.status = "completed"
job.progress = 100
job.progress_message = "Generation complete"
job.completed_at = datetime.utcnow()
await session.commit()
await _publish(job_uuid, 100, "Generation complete", "completed")
except Exception as e:
traceback.print_exc()
job.status = "failed"
job.error_message = str(e)[:500]
job.progress_message = "Generation failed"
job.completed_at = datetime.utcnow()
await session.commit()
await _publish(job_uuid, job.progress, "Generation failed", "failed")
async def _update_job(
session: AsyncSession, job: JobModel, progress: int, message: str
) -> None:
job.progress = progress
job.progress_message = message
await session.commit()
await publish_job_progress(job.id, progress, message)
async def _publish(
job_id: uuid.UUID, progress: int, message: str, status: str = "processing"
) -> None:
try:
await publish_job_progress(job_id, progress, message, status)
except Exception:
pass # Redis unavailable is not fatal