ppt-tool/backend/services/brand_enforcement_service.py
Vadym Samoilenko a0d73b3b63 Phase 4: Generation Pipeline — brand enforcement, enhanced LLM calls, ARQ job queue
- Step 14: Brand enforcement service (font/color/logo replacement, WCAG contrast check, LLM prompt context)
- Step 15: Enhanced outline & slide content generation with brand context, content summary, "no hallucination" instructions
- Step 15b: LLM auto-fallback retry logic across providers (FALLBACK_LLM_PROVIDERS env)
- Step 16: Redis/ARQ job queue — worker entry point, presentation & master deck workers, job status/SSE endpoints, graceful fallback to BackgroundTasks when Redis unavailable

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-26 16:15:25 +00:00

225 lines
7.6 KiB
Python

"""Brand Enforcement Service: apply client brand rules to generated content and PPTX output."""
import math
import uuid
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models.pptx_models import (
PptxAutoShapeBoxModel,
PptxChartBoxModel,
PptxFontModel,
PptxPictureBoxModel,
PptxPictureModel,
PptxPositionModel,
PptxPresentationModel,
PptxTextBoxModel,
)
from models.sql.brand_config import BrandConfigModel
# Default slide canvas size; kept in sync with PptxPresentationCreator.
_SLIDE_WIDTH, _SLIDE_HEIGHT = 1280, 720

# Logo box size and its inset from the right/bottom slide edges
# (used when placing the logo on non-title slides).
_LOGO_WIDTH, _LOGO_HEIGHT = 80, 40
_LOGO_MARGIN = 20
class BrandEnforcementService:
    """Apply a client's brand rules (fonts, colors, logo, contrast) to generated content."""

    # Minimum text/background contrast ratio enforced by ``_fix_font_contrast``.
    # WCAG AA requires 4.5:1 for normal text and 3:1 for large text; this
    # service enforces the 3:1 floor.
    min_contrast_ratio: float = 3.0

    # Paragraphs whose font size is at least this value (same units as
    # ``PptxFontModel.size``) receive the heading font.
    heading_min_font_size: float = 24

    # Palette returned by ``get_brand_colors_list`` when the brand defines no colors.
    _DEFAULT_CHART_COLORS = (
        "4472C4", "ED7D31", "A5A5A5", "FFC000", "5B9BD5",
        "70AD47", "264478", "9B57A0", "636363", "EB6E1F",
    )

    async def get_brand_context_for_llm(
        self, client_id: uuid.UUID, session: AsyncSession
    ) -> str:
        """Build a text block with brand guidelines for LLM system-prompt injection.

        Args:
            client_id: Client whose brand configuration should be loaded.
            session: Async SQLAlchemy session used for the lookup.

        Returns:
            Newline-separated guideline lines (colors, fonts, voice rules and up
            to three voice examples), or ``""`` if the client has no brand config.
        """
        brand = await self._load_brand(client_id, session)
        if not brand:
            return ""
        parts: List[str] = []
        if brand.primary_colors:
            # str() keeps non-string entries from breaking join, consistent
            # with get_brand_colors_list.
            colors_str = ", ".join(str(c) for c in brand.primary_colors[:6])
            parts.append(f"Brand primary colors: {colors_str}")
        if brand.secondary_colors:
            colors_str = ", ".join(str(c) for c in brand.secondary_colors[:6])
            parts.append(f"Brand secondary colors: {colors_str}")
        if brand.fonts:
            heading = brand.fonts.get("heading", "")
            body = brand.fonts.get("body", "")
            if heading:
                parts.append(f"Heading font: {heading}")
            if body:
                parts.append(f"Body font: {body}")
        if brand.voice_rules:
            parts.append(f"Brand voice guidelines: {brand.voice_rules}")
        if brand.voice_examples:
            examples = "\n - ".join(str(e) for e in brand.voice_examples[:3])
            parts.append(f"Brand voice examples:\n - {examples}")
        return "\n".join(parts)

    def enforce_on_pptx_model(
        self,
        model: PptxPresentationModel,
        brand: BrandConfigModel,
    ) -> PptxPresentationModel:
        """Apply brand fonts, chart colors, text contrast and logo to a PPTX model.

        ``model`` is modified in place and also returned for convenience.

        Args:
            model: Presentation whose slides are rewritten.
            brand: Brand configuration supplying fonts, colors and logo paths.

        Returns:
            The same ``model`` instance.
        """
        fonts = brand.fonts or {}
        heading_font = fonts.get("heading")
        body_font = fonts.get("body")
        brand_colors = self.get_brand_colors_list(brand)
        # Only the first configured logo is used.
        logo_path = brand.logo_paths[0] if brand.logo_paths else None

        for slide_idx, slide in enumerate(model.slides):
            # A missing background, or one without an explicit color, is treated
            # as white; a None color would otherwise crash the luminance math.
            bg_color = slide.background.color if slide.background else None
            if bg_color is None:
                bg_color = "FFFFFF"

            for shape in slide.shapes:
                if isinstance(shape, (PptxTextBoxModel, PptxAutoShapeBoxModel)):
                    # Font enforcement only changes font names and the contrast
                    # fix only changes font colors, so one pass handles both.
                    self._enforce_fonts_on_shape(shape, heading_font, body_font, slide_idx)
                    self._fix_contrast(shape, bg_color)
                if isinstance(shape, PptxChartBoxModel):
                    if brand_colors:
                        shape.brand_colors = brand_colors
                    if body_font:
                        shape.font_name = body_font

            # Slide 0 is treated as the title slide and gets no logo.
            if slide_idx > 0 and logo_path:
                slide.shapes.append(self._build_logo_shape(logo_path))
        return model

    def get_brand_colors_list(self, brand: BrandConfigModel) -> List[str]:
        """Return brand colors (primary, then secondary) as hex strings without ``#``.

        Falls back to a default 10-color chart palette when the brand defines no colors.
        """
        colors: List[str] = []
        if brand.primary_colors:
            colors.extend(str(c).lstrip("#") for c in brand.primary_colors)
        if brand.secondary_colors:
            colors.extend(str(c).lstrip("#") for c in brand.secondary_colors)
        if not colors:
            colors = list(self._DEFAULT_CHART_COLORS)
        return colors

    # --- Internal methods ---

    async def _load_brand(
        self, client_id: uuid.UUID, session: AsyncSession
    ) -> Optional[BrandConfigModel]:
        """Return the brand config row for ``client_id``, or None if there is none.

        ``scalar_one_or_none`` raises if more than one row matches.
        """
        stmt = select(BrandConfigModel).where(
            BrandConfigModel.client_id == client_id
        )
        result = await session.execute(stmt)
        return result.scalar_one_or_none()

    def _build_logo_shape(self, logo_path: str) -> PptxPictureBoxModel:
        """Create an unclipped, local-file logo picture box in the bottom-right slide corner."""
        return PptxPictureBoxModel(
            position=PptxPositionModel(
                left=_SLIDE_WIDTH - _LOGO_WIDTH - _LOGO_MARGIN,
                top=_SLIDE_HEIGHT - _LOGO_HEIGHT - _LOGO_MARGIN,
                width=_LOGO_WIDTH,
                height=_LOGO_HEIGHT,
            ),
            clip=False,
            picture=PptxPictureModel(is_network=False, path=logo_path),
        )

    def _enforce_fonts_on_shape(
        self,
        shape,
        heading_font: Optional[str],
        body_font: Optional[str],
        slide_idx: int,
    ) -> None:
        """Replace paragraph and run font names in a text shape with brand fonts.

        Paragraphs on the first slide, or whose paragraph font size is at least
        ``heading_min_font_size``, get ``heading_font``; all others get
        ``body_font``. A missing target font leaves the paragraph unchanged.
        """
        paragraphs = getattr(shape, "paragraphs", None)
        if not paragraphs:
            return
        for para in paragraphs:
            font = para.font
            # A font without an explicit size counts as body text (comparing
            # None with a number would raise TypeError).
            is_large = (
                font is not None
                and font.size is not None
                and font.size >= self.heading_min_font_size
            )
            is_heading = slide_idx == 0 or is_large
            target_font = heading_font if is_heading else body_font
            if not target_font:
                continue
            if font:
                font.name = target_font
            for run in para.text_runs or ():
                if run.font:
                    run.font.name = target_font

    def _fix_contrast(self, shape, bg_color: str) -> None:
        """Ensure paragraph and run text colors contrast enough with ``bg_color``."""
        paragraphs = getattr(shape, "paragraphs", None)
        if not paragraphs:
            return
        bg_lum = _relative_luminance(bg_color)
        for para in paragraphs:
            self._fix_font_contrast(para.font, bg_lum)
            for run in para.text_runs or ():
                self._fix_font_contrast(run.font, bg_lum)

    def _fix_font_contrast(
        self, font: Optional[PptxFontModel], bg_lum: float
    ) -> None:
        """Swap ``font.color`` to black or white if its contrast is below the minimum.

        Args:
            font: Font to check; None, or a font with no color, is left alone.
            bg_lum: Relative luminance of the slide background.
        """
        if not font or font.color is None:
            return
        ratio = _contrast_ratio(_relative_luminance(font.color), bg_lum)
        if ratio >= self.min_contrast_ratio:
            return
        # Choose whichever of white (L=1.0) or black (L=0.0) contrasts more.
        # A fixed "bg_lum < 0.5" cut-off picks white for mid-tone backgrounds,
        # e.g. #AAAAAA (L~0.40): white ~2.3:1 vs black ~9:1. The two tie at L~0.179.
        white_ratio = _contrast_ratio(1.0, bg_lum)
        black_ratio = _contrast_ratio(0.0, bg_lum)
        font.color = "FFFFFF" if white_ratio >= black_ratio else "000000"
# --- Color utility functions ---
def _hex_to_rgb(hex_color: str) -> tuple:
    """Parse a hex color string into an ``(r, g, b)`` tuple of ints in 0..255.

    Leading ``#`` characters are ignored. 3-digit shorthand such as ``"F0A"``
    expands to ``"FF00AA"``. Shorter strings are right-padded with ``"0"`` to
    six digits, and any digits after the sixth are ignored.
    """
    digits = hex_color.lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    # ljust is a no-op for strings already 6+ characters long.
    digits = digits.ljust(6, "0")
    return tuple(int(digits[start:start + 2], 16) for start in (0, 2, 4))
def _relative_luminance(hex_color: str) -> float:
    """Return the WCAG 2.0 relative luminance of a hex color (0.0 black to 1.0 white)."""

    def to_linear(channel: float) -> float:
        # Undo the sRGB transfer curve, using the WCAG 2.0 threshold 0.03928.
        if channel <= 0.03928:
            return channel / 12.92
        return math.pow((channel + 0.055) / 1.055, 2.4)

    red, green, blue = (to_linear(value / 255.0) for value in _hex_to_rgb(hex_color))
    return 0.2126 * red + 0.7152 * green + 0.0722 * blue
def _contrast_ratio(lum1: float, lum2: float) -> float:
    """Return the WCAG contrast ratio between two relative luminances.

    Argument order does not matter: the brighter value always goes in the
    numerator. For luminances in [0, 1] the result lies between 1 and 21.
    """
    if lum2 > lum1:
        lum1, lum2 = lum2, lum1
    return (lum1 + 0.05) / (lum2 + 0.05)