# Channel-to-filename mapping. Keys are lower-case channel names; ``{lc}`` is
# replaced with the lower-cased locale code (e.g. "de-de").
CHANNEL_FILE_MAP: dict[str, str] = {
    "value": "flat_value_{lc}.json",
    "mass": "flat_MASS_{lc}.json",
    "onsite": "flat_Onsite_{lc}.json",
    "outbound": "flat_Outbound_{lc}.json",
}


def _resolve_tm_paths(
    locale_code: str,
    channel: str,
    manifest_tm_files: list[str],
) -> list[str]:
    """Resolve TM file paths from the manifest or the channel mapping.

    If the manifest already lists TM files, those are used as-is.
    Otherwise a single path is derived from the channel and locale:
    ``<STORAGE_ROOT>/amazon/tm/<locale_code>/<channel file for locale>``.

    Args:
        locale_code: Job locale, e.g. "de-DE". Used verbatim for the folder
            name and lower-cased for the file name.
        channel: Job channel name; matched case-insensitively and ignoring
            surrounding whitespace against ``CHANNEL_FILE_MAP``.
        manifest_tm_files: TM paths already listed in the file manifest.

    Returns:
        A new list of TM file paths; empty when the channel is unknown.
    """
    if manifest_tm_files:
        # Return a copy so callers can never mutate the manifest's own list.
        return list(manifest_tm_files)

    # Normalise the same way the ranker compares channels (strip + lower);
    # a missing channel falls through to the "unknown channel" branch.
    channel_key = (channel or "").strip().lower()
    pattern = CHANNEL_FILE_MAP.get(channel_key)
    if not pattern:
        logger.warning("Unknown channel %r; cannot resolve TM file", channel)
        return []

    filename = pattern.format(lc=locale_code.lower())  # e.g. "de-DE" -> "de-de"
    # TM directory uses the original locale_code casing for the folder name
    tm_dir = os.path.join(settings.STORAGE_ROOT, "amazon", "tm", locale_code)
    return [os.path.join(tm_dir, filename)]
line in context.source_lines + ] class Agent2TMRetrieval(BaseAgent): - """STUB: TM retrieval agent returning empty sweep results.""" + """TM retrieval agent: finds semantic TM matches for each source line.""" name = "agent_2_tm_retrieval" - description = "Retrieves TM matches for source lines (STUB)" + description = "Retrieves TM matches for source lines using LLM semantic matching" def get_system_prompt(self) -> str: - return "You are a Translation Memory retrieval agent." + return ( + "You are a Translation Memory (TM) retrieval specialist. " + "Your job is to find the best semantic matches between English " + "source lines and existing TM entries.\n\n" + "A 'match' means the TM entry's English text is semantically " + "equivalent or very close in meaning to the source line. It does " + "NOT need to be an exact string match -- paraphrases, minor " + "wording changes, and formatting differences (e.g. line breaks) " + "should still count as matches.\n\n" + "For each source line, return up to 5 best matches ordered by " + "similarity. If no TM entry is semantically close, return an " + "empty matches array for that line.\n\n" + "You MUST respond with ONLY a JSON array (no markdown fences, " + "no commentary). Each element has this shape:\n" + "{\n" + ' "line_id": "",\n' + ' "matches": [\n' + " {\n" + ' "tm_index": <1-based index from TM list>,\n' + ' "similarity": "",\n' + ' "pass_found": <1 for exact, 2 for high, 3 for medium>\n' + " }\n" + " ]\n" + "}\n\n" + "Rules:\n" + '- "exact" (pass_found=1): identical or near-identical wording\n' + '- "high" (pass_found=2): same meaning, minor wording differences\n' + '- "medium" (pass_found=3): closely related meaning, could be reused with adaptation\n' + "- Omit matches below medium similarity\n" + "- Return the JSON array directly, no wrapping object" + ) def build_user_message(self, context: PipelineContext) -> str: - return "Retrieve TM matches for the provided source lines." 
def _parse_llm_response(
    self,
    response: str,
    context: PipelineContext,
    tm_entries: list[TMEntry],
) -> list[TMSweepResult]:
    """Parse the LLM JSON response into TMSweepResult objects.

    Malformed output never raises: an unparseable response, or one that is
    not a JSON array, yields a no-match result for every source line, and
    individual malformed items/matches are skipped.

    Args:
        response: Raw LLM response text (should be a JSON array).
        context: Pipeline context with source lines.
        tm_entries: The loaded TM entries (for 1-based index lookup).

    Returns:
        List of TMSweepResult, one per source line, in source-line order.
    """
    # Strip markdown fences if the LLM included them
    text = response.strip()
    if text.startswith("```"):
        # Drop the whole opening fence line (possibly ```json)
        _, _, text = text.partition("\n")
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        logger.error("Failed to parse LLM response as JSON: %s", exc)
        logger.debug("Raw response: %s", response[:500])
        return _empty_results(context)

    if not isinstance(data, list):
        logger.error("LLM response is not a JSON array")
        return _empty_results(context)

    # Build a lookup: line_id -> parsed matches. Keys are stringified so an
    # unhashable line_id (list/dict) cannot raise, and an id returned as a
    # JSON number still matches its source line.
    matches_by_line: dict[str, list[ConfirmedMatch]] = {}
    for item in data:
        if not isinstance(item, dict):
            continue
        raw_matches = item.get("matches", [])
        if not isinstance(raw_matches, list):
            continue

        confirmed: list[ConfirmedMatch] = []
        for m in raw_matches:
            if not isinstance(m, dict):
                continue
            tm_idx = m.get("tm_index")
            # bool is a subclass of int: reject true/false as an index.
            if isinstance(tm_idx, bool) or not isinstance(tm_idx, int):
                continue
            if tm_idx < 1 or tm_idx > len(tm_entries):
                continue

            entry = tm_entries[tm_idx - 1]  # 1-based -> 0-based
            pass_found = m.get("pass_found", 3)
            # Anything other than a real 1/2/3 falls back to the weakest pass.
            if isinstance(pass_found, bool) or pass_found not in (1, 2, 3):
                pass_found = 3

            confirmed.append(
                ConfirmedMatch(
                    seg_key=entry.seg_key,
                    pass_found=pass_found,
                    date=entry.date,
                    en=entry.en,
                    tx=entry.tx,
                    nt=entry.nt,
                    channel=entry.channel,
                    sub_channel=entry.sub_channel,
                    is_cross_channel=False,
                )
            )
        matches_by_line[str(item.get("line_id", ""))] = confirmed

    # Build results for every source line, preserving order
    results: list[TMSweepResult] = []
    for sl in context.source_lines:
        confirmed = matches_by_line.get(str(sl.line_id), [])
        results.append(
            TMSweepResult(
                line_id=sl.line_id,
                confirmed_matches=confirmed,
                no_match=not confirmed,
            )
        )
    return results
"""STUB: Return empty TM sweep results for all source lines.""" - context.tm_sweep_results = [ - TMSweepResult( - line_id=line.line_id, - confirmed_matches=[], - pass_4_triggered=False, - pass_4_result=None, - no_match=True, + """Execute TM retrieval: load TM files, call LLM, parse results.""" + locale_code = context.job_params.locale_code + channel = context.job_params.channel + lc_lower = locale_code.lower() + + logger.info( + "Agent 2 TM Retrieval starting: locale=%s channel=%s lines=%d", + locale_code, + channel, + len(context.source_lines), + ) + + # ── Step 1: Resolve and load TM files ─────────────────────────── + tm_paths = _resolve_tm_paths( + locale_code, channel, context.file_manifest.tm_files + ) + if not tm_paths: + logger.warning("No TM file paths resolved; returning no-match for all lines") + context.tm_sweep_results = _empty_results(context) + return context + + tm_entries = _load_all_tm_entries(tm_paths, lc_lower) + if not tm_entries: + logger.warning("No TM entries loaded; returning no-match for all lines") + context.tm_sweep_results = _empty_results(context) + return context + + logger.info("Loaded %d total TM entries from %d file(s)", len(tm_entries), len(tm_paths)) + + # ── Step 2: Build prompt and call LLM ──────────────────────────── + system_prompt = self.get_system_prompt() + user_message = self._build_full_user_message(context, tm_entries) + + try: + llm = LLMClient() + response_text, usage = await llm.acreate_message( + system_prompt, + user_message, + max_tokens=8192, + temperature=0.3, ) - for line in context.source_lines - ] + logger.info( + "Agent 2 LLM call complete: input_tokens=%s output_tokens=%s cost=$%s", + usage.get("input_tokens"), + usage.get("output_tokens"), + usage.get("estimated_cost_usd"), + ) + except Exception as exc: + logger.error("Agent 2 LLM call failed: %s", exc) + context.tm_sweep_results = _empty_results(context) + return context + + # ── Step 3: Parse response into TMSweepResult objects ──────────── + 
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_YEAR_RE = re.compile(r"\b(\d{2})\b")

# A winner must carry at least this 2-digit year to be rated "high".
_HIGH_CONFIDENCE_MIN_YEAR = 23


def _normalise_label(value: str | None) -> str:
    """Return *value* stripped and lower-cased; ``None`` becomes ``""``."""
    return (value or "").strip().lower()


def _extract_year(seg_key: str) -> int:
    """Extract the most recent 2-digit year from a seg_key string.

    Seg keys look like "Value Q1 24 Radio" -- we want 24.
    If multiple stand-alone 2-digit numbers exist we take the *largest*
    that is plausibly a year (>= 18 and <= 30 to cover 2018-2030).
    If none is in that window the largest 2-digit number is returned
    anyway, so the result is not guaranteed to be a real year.
    Falls back to 0 if nothing is found.
    """
    candidates = [int(m) for m in _YEAR_RE.findall(seg_key)]
    plausible = [y for y in candidates if 18 <= y <= 30]
    if plausible:
        return max(plausible)
    # If nothing in the plausible range, return highest candidate anyway
    return max(candidates, default=0)


def _channel_matches(match: ConfirmedMatch, job_channel: str) -> bool:
    """Case-insensitive, whitespace-insensitive channel comparison.

    A missing (``None``) channel on either side is treated as empty text
    instead of raising.
    """
    return _normalise_label(match.channel) == _normalise_label(job_channel)


def _sub_channel_matches(match: ConfirmedMatch, job_sub_channel: str | None) -> bool:
    """Case-insensitive sub-channel comparison.

    Returns False when the job has no sub-channel, and also when the TM
    entry has none (instead of raising on ``None``).
    """
    if not job_sub_channel:
        return False
    return _normalise_label(match.sub_channel) == _normalise_label(job_sub_channel)


def _score_match(
    match: ConfirmedMatch,
    job_channel: str,
    job_sub_channel: str | None,
) -> tuple[int, int, int]:
    """Return a sort key tuple (higher is better).

    Priority order:
      1. Channel match (1 = yes, 0 = no)
      2. Sub-channel match (1 = yes, 0 = no)
      3. Recency (2-digit year extracted from seg_key)
    """
    ch = 1 if _channel_matches(match, job_channel) else 0
    sc = 1 if _sub_channel_matches(match, job_sub_channel) else 0
    yr = _extract_year(match.seg_key)
    return (ch, sc, yr)


# ---------------------------------------------------------------------------
# Confidence classification
# ---------------------------------------------------------------------------


def _classify_confidence(
    winner: ConfirmedMatch,
    job_channel: str,
    job_sub_channel: str | None,
) -> str:
    """Determine the confidence tier for a winning match.

    HIGH     - same channel, same sub-channel (only checked when the job
               has one), and seg_key year >= 23.
    MODERATE - every other winner: cross-channel, different sub-channel,
               or older year.

    "low" is never returned here; the no-match case is handled by the
    caller before a winner exists.
    """
    ch_ok = _channel_matches(winner, job_channel)
    sc_ok = _sub_channel_matches(winner, job_sub_channel) if job_sub_channel else True
    year = _extract_year(winner.seg_key)

    if ch_ok and sc_ok and year >= _HIGH_CONFIDENCE_MIN_YEAR:
        return "high"
    # The previous extra branches (year >= 22, cross-channel ternary) all
    # produced "moderate", so a single fall-through is equivalent.
    return "moderate"


def _option_count_for_tier(tier: str) -> int:
    """Return the number of draft options required for a tier (default 3)."""
    return {"high": 1, "moderate": 2, "low": 3}.get(tier, 3)
+ if sweep.pass_4_triggered and sweep.pass_4_result is not None: + # Avoid duplicates (same seg_key) + existing_keys = {c.seg_key for c in candidates} + if sweep.pass_4_result.seg_key not in existing_keys: + candidates.append(sweep.pass_4_result) + + # Edge case: confirmed_matches is empty but pass_4_result exists + if not candidates and sweep.pass_4_result is not None: + candidates = [sweep.pass_4_result] + + # Still nothing after gathering? + if not candidates: + return RankingDeclaration( + line_id=sweep.line_id, + winning_entry=None, + runner_ups=[], + confidence_tier="low", + option_count=3, + is_new_creative_line=True, + notes="No usable TM candidates after filtering.", + ) + + # ------------------------------------------------------------------ + # 3. Score and sort (descending -- best first) + # ------------------------------------------------------------------ + scored = sorted( + candidates, + key=lambda m: _score_match(m, job_channel, job_sub_channel), + reverse=True, + ) + + winner = scored[0] + runner_ups = scored[1:] + + # Mark cross-channel on runner-ups for downstream awareness + for ru in runner_ups: + if not _channel_matches(ru, job_channel): + ru.is_cross_channel = True + if not _channel_matches(winner, job_channel): + winner.is_cross_channel = True + + # ------------------------------------------------------------------ + # 4. 
Determine confidence tier + # ------------------------------------------------------------------ + tier = _classify_confidence(winner, job_channel, job_sub_channel) + opt_count = _option_count_for_tier(tier) + + # Build notes + notes_parts: list[str] = [] + notes_parts.append(f"Winner: {winner.seg_key} (pass {winner.pass_found})") + if winner.is_cross_channel: + notes_parts.append("cross-channel match") + notes_parts.append(f"year={_extract_year(winner.seg_key)}") + notes_parts.append(f"{len(runner_ups)} runner-up(s)") + + return RankingDeclaration( + line_id=sweep.line_id, + winning_entry=winner, + runner_ups=runner_ups, + confidence_tier=tier, + option_count=opt_count, + is_new_creative_line=False, + notes="; ".join(notes_parts), + ) + + +# --------------------------------------------------------------------------- +# Agent +# --------------------------------------------------------------------------- class Agent3Ranker(BaseAgent): - """STUB: Ranking agent returning LOW confidence for all lines.""" + """Deterministic ranking agent -- no LLM call required. + + Applies the V25 rules-based selection algorithm to TM sweep + results and populates ``context.ranking_declarations``. + """ name = "agent_3_ranker" - description = "Ranks TM matches and declares confidence (STUB)" + description = "Ranks TM matches and declares confidence tiers (deterministic)" + + # -- LLM interface stubs (not used) ---------------------------------- def get_system_prompt(self) -> str: - return "You are a ranking and confidence declaration agent." + return "" def build_user_message(self, context: PipelineContext) -> str: - return "Rank the TM matches for each source line." 
+ return "" def parse_response(self, response: str, context: PipelineContext) -> Any: - return [] + return None + + # -- Core logic ------------------------------------------------------- async def run(self, context: PipelineContext) -> PipelineContext: - """STUB: Return LOW confidence ranking for all source lines.""" - context.ranking_declarations = [ - RankingDeclaration( - line_id=line.line_id, - winning_entry=None, - runner_ups=[], - confidence_tier="low", - option_count=3, - is_new_creative_line=True, - notes="STUB: No TM matches available", - ) - for line in context.source_lines - ] + """Rank TM matches for every source line and populate declarations.""" + + job_channel = context.job_params.channel + job_sub_channel = context.job_params.sub_channel + + # Build a lookup from line_id -> TMSweepResult + sweep_map: dict[str, TMSweepResult] = { + s.line_id: s for s in context.tm_sweep_results + } + + declarations: list[RankingDeclaration] = [] + + for line in context.source_lines: + sweep = sweep_map.get(line.line_id) + + if sweep is None: + # No sweep result for this line -- treat as new creative + logger.warning( + "No TMSweepResult for line_id=%s; treating as new creative.", + line.line_id, + ) + declarations.append( + RankingDeclaration( + line_id=line.line_id, + winning_entry=None, + runner_ups=[], + confidence_tier="low", + option_count=3, + is_new_creative_line=True, + notes="Missing TMSweepResult -- new creative line.", + ) + ) + continue + + declaration = _rank_line(sweep, job_channel, job_sub_channel) + declarations.append(declaration) + + context.ranking_declarations = declarations + + # Log summary + tier_counts = {"high": 0, "moderate": 0, "low": 0} + for d in declarations: + tier_counts[d.confidence_tier] = tier_counts.get(d.confidence_tier, 0) + 1 + logger.info( + "Agent 3 Ranker complete: %d lines ranked -- high=%d, moderate=%d, low=%d", + len(declarations), + tier_counts["high"], + tier_counts["moderate"], + tier_counts["low"], + ) + return 
# ---------------------------------------------------------------------------
# Locale display names for prompt context
# ---------------------------------------------------------------------------
LOCALE_NAMES: dict[str, str] = {
    "de_DE": "German (Germany)",
    "fr_FR": "French (France)",
    "it_IT": "Italian (Italy)",
    "es_ES": "Spanish (Spain)",
    "nl_NL": "Dutch (Netherlands)",
    "pl_PL": "Polish (Poland)",
    "sv_SE": "Swedish (Sweden)",
    "pt_PT": "Portuguese (Portugal)",
    "da_DK": "Danish (Denmark)",
    "nb_NO": "Norwegian Bokmal (Norway)",
    "fi_FI": "Finnish (Finland)",
    "tr_TR": "Turkish (Turkey)",
}

# ---------------------------------------------------------------------------
# V25 voice guidance per programme
# ---------------------------------------------------------------------------
VOICE_PROFILES: dict[str, str] = {
    "retail": (
        "Voice: Real, Clear, Playful, Witty.\n"
        "Mission: Communicate value, convenience, and selection. "
        "Speak like a smart friend who knows the best deal. "
        "Keep the tone grounded and helpful with a light touch of humour."
    ),
    "prime": (
        "Voice: Optimistic, Honest, Self-aware, Witty, Relatable.\n"
        "Mission: Transform expectations of what a membership can offer. "
        "Be enthusiastic but never arrogant. "
        "Surprise and delight with self-aware, culturally relevant wit."
    ),
    "brand": (
        "Voice: Authentic, Customer-obsessed, Intelligent, Warm, Understated.\n"
        "Mission: Build long-term trust and emotional connection. "
        "Lead with empathy, not spectacle. "
        "Every word should feel intentional and earned."
    ),
}


def _get_voice_profile(programme: str) -> str:
    """Return the V25 voice profile for the given programme.

    Lookup is case-insensitive and ignores surrounding whitespace; an
    unknown or missing programme falls back to the "retail" profile.
    """
    key = (programme or "").strip().lower()
    return VOICE_PROFILES.get(key, VOICE_PROFILES["retail"])


def _get_locale_display(locale_code: str) -> str:
    """Return a human-readable locale name.

    Accepts both the underscore form used as ``LOCALE_NAMES`` keys
    ("de_DE") and hyphenated or differently-cased forms ("de-DE",
    "de-de"). Unknown codes are returned unchanged.
    """
    if locale_code in LOCALE_NAMES:
        return LOCALE_NAMES[locale_code]

    # Canonicalise to "<lang lower>_<REGION upper>" before the second lookup.
    lang, sep, region = locale_code.strip().replace("-", "_").partition("_")
    canonical = f"{lang.lower()}_{region.upper()}" if sep else lang.lower()
    return LOCALE_NAMES.get(canonical, locale_code)


# ---------------------------------------------------------------------------
# Reference data formatting helpers
# ---------------------------------------------------------------------------

def _format_glossary_section(glossary_data: Any) -> str:
    """Format glossary data into a prompt-friendly markdown table.

    Accepts either a list of entry dicts or a ``{"entries": [...]}``
    wrapper. Each entry supplies the English term under ``en_GB``/``en``/
    ``source`` and the target term under ``term``/``tx``/``target``.
    Only the first 80 entries are considered; entries that are not dicts
    or lack either term are skipped. Returns "" when there is nothing
    usable.
    """
    if not glossary_data:
        return ""

    entries: Any = []
    # Handle {"entries": [...]} wrapper format
    if isinstance(glossary_data, dict) and "entries" in glossary_data:
        entries = glossary_data["entries"]
    elif isinstance(glossary_data, list):
        entries = glossary_data

    # A wrapper whose "entries" value is not a list has nothing to render.
    if not isinstance(entries, list) or not entries:
        return ""

    lines = ["## Glossary (Term Locks)", ""]
    lines.append(
        "These terms MUST be preserved exactly as shown in the target column. "
        "Do NOT translate, paraphrase, or alter them."
    )
    lines.append("")
    lines.append("| English (en_GB) | Target Term |")
    lines.append("|---|---|")
    for entry in entries[:80]:  # Cap to keep prompt manageable
        if not isinstance(entry, dict):
            continue  # tolerate stray non-dict items instead of raising
        en = entry.get("en_GB") or entry.get("en") or entry.get("source", "")
        tx = entry.get("term") or entry.get("tx") or entry.get("target", "")
        if en and tx:
            lines.append(f"| {en} | {tx} |")

    return "\n".join(lines)


def _format_blacklist_section(blacklist_data: Any) -> str:
    """Format blacklist data into a prompt-friendly bullet list.

    Expects a list whose items are dicts with ``term`` and optional
    ``reason``; plain strings are accepted as bare terms. Only the first
    60 items are considered and any other item type is skipped. Returns
    "" when the data is empty or not a list.
    """
    if not blacklist_data:
        return ""

    entries = blacklist_data if isinstance(blacklist_data, list) else []
    if not entries:
        return ""

    lines = [
        "## Blacklisted Terms",
        "",
        "The following terms or roots must NEVER appear in any transcreation output.",
        "",
    ]
    for entry in entries[:60]:
        if isinstance(entry, str):
            term, reason = entry, ""
        elif isinstance(entry, dict):
            term = entry.get("term", "")
            reason = entry.get("reason", "")
        else:
            continue  # unsupported item shape
        if term:
            lines.append(f"- **{term}**" + (f" ({reason})" if reason else ""))

    return "\n".join(lines)
lines.append(f"**{key}**: {value}") + elif isinstance(value, list): + lines.append(f"**{key}**:") + for item in value[:15]: + lines.append(f" - {item}" if isinstance(item, str) else f" - {json.dumps(item)}") + elif isinstance(value, dict): + lines.append(f"**{key}**: {json.dumps(value, ensure_ascii=False)}") + lines.append("") + + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# System prompt builder +# --------------------------------------------------------------------------- + +def build_system_prompt( + locale_code: str, + channel: str, + sub_channel: str | None, + programme: str, + campaign_name: str, + ref_data: dict[str, Any], +) -> str: + """Build the comprehensive system prompt for the transcreation LLM call.""" + + locale_display = _get_locale_display(locale_code) + voice = _get_voice_profile(programme) + + channel_str = channel + if sub_channel: + channel_str += f" / {sub_channel}" + + sections: list[str] = [] + + # --- Role --- + sections.append( + f"You are a Senior Transcreation Specialist for Amazon marketing campaigns.\n" + f"Your target locale is **{locale_display}** (`{locale_code}`).\n" + f"Channel: **{channel_str}**\n" + f"Programme: **{programme}**\n" + f"Campaign: **{campaign_name}**" + ) + + # --- Voice profile --- + sections.append( + f"# Amazon V25 Voice Profile ({programme})\n\n{voice}" + ) + + # --- Hard rules --- + sections.append( + "# Hard Rules\n\n" + "1. **Positivity**: Never use negative constructions (e.g. 'Don't miss out' -> 'Discover now'). " + "All copy must feel uplifting, empowering, and forward-looking.\n" + "2. **Term Locks**: Glossary terms listed below MUST be preserved exactly as given. Do not translate, " + "paraphrase, or alter locked terms.\n" + "3. **No Literal Translations of Idioms**: English idioms must be replaced with culturally equivalent " + "expressions in the target locale. Never translate idioms word-for-word.\n" + "4. 
**Cultural Adaptation**: Adapt humour, references, and tone to resonate authentically with the " + "target audience. What works in British English may not land the same way.\n" + "5. **Character Limits**: Respect the character limit for each line. If a limit is provided, your " + "transcreation must not exceed it. Apply minimal relaxation only when absolutely necessary and note it.\n" + "6. **Display Format Lines**: Lines marked as display format may contain HTML-like tags or placeholders. " + "Preserve all tags/placeholders exactly.\n" + "7. **Backtranslation**: Every option must include a backtranslation into English that accurately " + "conveys the meaning and tone of the transcreated text." + ) + + # --- Reference data sections --- + glossary_sec = _format_glossary_section(ref_data.get("glossary_file")) + if glossary_sec: + sections.append(glossary_sec) + + blacklist_sec = _format_blacklist_section(ref_data.get("blacklist_file")) + if blacklist_sec: + sections.append(blacklist_sec) + + tov_sec = _format_tov_section(ref_data.get("tov_global_file")) + if tov_sec: + sections.append(tov_sec) + + tov_supp = _format_tov_section(ref_data.get("tov_supplement_file")) + if tov_supp: + sections.append(tov_supp.replace( + "## Tone of Voice (Global Guidelines)", + "## Tone of Voice (Supplementary Guidelines)", + )) + + lc_sec = _format_locale_considerations_section( + ref_data.get("locale_considerations_file") + ) + if lc_sec: + sections.append(lc_sec) + + # --- Output format --- + sections.append( + "# Output Format\n\n" + "Return ONLY a valid JSON array (no markdown fences, no commentary). 
" + "Each element must match this schema:\n" + "```\n" + "{\n" + ' "line_id": "",\n' + ' "option_1": {"text": "", "backtranslation": "", "rationale": ""},\n' + ' "option_2": ,\n' + ' "option_3": ,\n' + ' "tm_entries_cited": ["", ...],\n' + ' "adaptations_applied": ["", ...]\n' + "}\n" + "```\n\n" + "Rules for options:\n" + "- **HIGH confidence** lines (with a winning TM entry): provide 1 option. " + "option_1 MUST be anchored to the winning TM entry (refine, don't reinvent). " + "option_2 and option_3 should be null.\n" + "- **MODERATE confidence** lines: provide exactly 2 options (option_3 = null).\n" + "- **LOW confidence** or **new creative** lines: provide all 3 options, " + "each offering a distinct creative angle.\n\n" + "For every option:\n" + "- `text`: The transcreated copy in the target locale.\n" + "- `backtranslation`: An English back-translation that captures meaning and tone.\n" + "- `rationale`: 1-2 sentences explaining the creative choices.\n" + "- `tm_entries_cited`: List any TM seg_keys you drew from.\n" + "- `adaptations_applied`: Brief notes on cultural/linguistic adaptations made." 
+ ) + + return "\n\n---\n\n".join(sections) + + +# --------------------------------------------------------------------------- +# User message builder +# --------------------------------------------------------------------------- + +def build_user_message_for_batch( + batch_lines: list[SourceLineContract], + ranking_map: dict[str, RankingDeclaration], + locale_code: str, + context_prompt: str | None, + glossary_entries: list[dict] | None, +) -> str: + """Build the user message JSON for a single batch of source lines.""" + + lines_payload: list[dict[str, Any]] = [] + for line in batch_lines: + ranking = ranking_map.get(line.line_id) + confidence = ranking.confidence_tier if ranking else "low" + option_count = ranking.option_count if ranking else 3 + is_new = ranking.is_new_creative_line if ranking else True + + entry: dict[str, Any] = { + "line_id": line.line_id, + "en_gb": line.en_gb, + "copy_type": line.copy_type, + "creative_guidance": line.creative_guidance, + "char_limit": line.char_limit, + "is_display_format": line.is_display_format, + "confidence_tier": confidence, + "option_count": option_count, + "is_new_creative_line": is_new, + } + + # Include winning TM entry if present + if ranking and ranking.winning_entry: + entry["winning_tm_entry"] = { + "seg_key": ranking.winning_entry.seg_key, + "en": ranking.winning_entry.en, + "tx": ranking.winning_entry.tx, + "notes": ranking.winning_entry.nt, + } + + # Include runner-up TM entries for reference + if ranking and ranking.runner_ups: + entry["runner_up_tm_entries"] = [ + { + "seg_key": ru.seg_key, + "en": ru.en, + "tx": ru.tx, + } + for ru in ranking.runner_ups[:3] + ] + + lines_payload.append(entry) + + message: dict[str, Any] = { + "target_locale": locale_code, + "source_lines": lines_payload, + } + + if context_prompt: + message["additional_context"] = context_prompt + + # Include a compact glossary summary (top relevant entries) + if glossary_entries: + compact = [] + for g in glossary_entries[:40]: + en = 
g.get("en_GB") or g.get("en") or g.get("source", "") + tx = g.get("term") or g.get("tx") or g.get("target", "") + if en and tx: + compact.append({"en": en, "tx": tx}) + if compact: + message["glossary_quick_ref"] = compact + + return json.dumps(message, ensure_ascii=False, indent=2) + + +# --------------------------------------------------------------------------- +# Fallback prompt for retry / error recovery +# --------------------------------------------------------------------------- + +def _build_fallback_user_message( + batch_lines: list[SourceLineContract], + locale_code: str, +) -> str: + """Build a simplified user message for fallback/retry attempts.""" + lines = [] + for line in batch_lines: + lines.append({ + "line_id": line.line_id, + "en_gb": line.en_gb, + "char_limit": line.char_limit, + "confidence_tier": "low", + "option_count": 3, + }) + + return json.dumps({ + "target_locale": locale_code, + "source_lines": lines, + "instruction": ( + "Transcreate each line. Return a JSON array. " + "Each element: {\"line_id\": \"...\", " + "\"option_1\": {\"text\": \"...\", \"backtranslation\": \"...\", \"rationale\": \"...\"}, " + "\"option_2\": ..., \"option_3\": ..., " + "\"tm_entries_cited\": [], \"adaptations_applied\": []}. " + "Return ONLY the JSON array, no other text." + ), + }, ensure_ascii=False, indent=2) + + +FALLBACK_SYSTEM_PROMPT = ( + "You are a professional marketing transcreation specialist. " + "Translate the given English (en_GB) marketing copy into the target locale. " + "Produce creative, culturally adapted translations - not literal translations. " + "Return ONLY a valid JSON array with no markdown fences or commentary." 
+) + + +# --------------------------------------------------------------------------- +# JSON parsing helpers +# --------------------------------------------------------------------------- + +def _extract_json_array(text: str) -> list[dict]: + """Extract a JSON array from LLM response text, handling markdown fences.""" + # Strip markdown code fences if present + cleaned = text.strip() + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + cleaned = cleaned.strip() + + # Try direct parse + try: + parsed = json.loads(cleaned) + if isinstance(parsed, list): + return parsed + if isinstance(parsed, dict): + # Maybe the model wrapped in an object + for key in ("results", "translations", "output", "data"): + if key in parsed and isinstance(parsed[key], list): + return parsed[key] + return [parsed] + except json.JSONDecodeError: + pass + + # Try to find a JSON array in the text + match = re.search(r"\[[\s\S]*\]", cleaned) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + raise ValueError("Could not extract valid JSON array from LLM response") + + +def _parse_draft_option(data: dict | None) -> DraftOption | None: + """Parse a single draft option dict into a DraftOption model.""" + if not data or not isinstance(data, dict): + return None + text = data.get("text", "").strip() + if not text: + return None + return DraftOption( + text=text, + backtranslation=data.get("backtranslation", "").strip(), + rationale=data.get("rationale", "").strip(), + ) + + +def parse_batch_response( + response_text: str, + batch_lines: list[SourceLineContract], +) -> list[DraftOutput]: + """Parse the LLM response for a batch into DraftOutput objects. + + Raises ValueError if the JSON cannot be parsed. 
+ """ + items = _extract_json_array(response_text) + + # Build a lookup for quick access + items_by_id: dict[str, dict] = {} + for item in items: + lid = str(item.get("line_id", "")) + if lid: + items_by_id[lid] = item + + outputs: list[DraftOutput] = [] + for line in batch_lines: + item = items_by_id.get(line.line_id) + if not item: + # If the model omitted this line, create a minimal fallback + logger.warning( + f"Line {line.line_id} missing from LLM response; " + "generating placeholder." + ) + outputs.append(DraftOutput( + line_id=line.line_id, + option_1=DraftOption( + text=f"[MISSING - {line.en_gb}]", + backtranslation=line.en_gb, + rationale="LLM did not return this line; placeholder generated.", + ), + )) + continue + + opt1 = _parse_draft_option(item.get("option_1")) + if not opt1: + # option_1 is required; use a fallback + opt1 = DraftOption( + text=f"[PARSE ERROR - {line.en_gb}]", + backtranslation=line.en_gb, + rationale="Could not parse option_1 from LLM response.", + ) + + outputs.append(DraftOutput( + line_id=line.line_id, + option_1=opt1, + option_2=_parse_draft_option(item.get("option_2")), + option_3=_parse_draft_option(item.get("option_3")), + tm_entries_cited=item.get("tm_entries_cited", []) or [], + adaptations_applied=item.get("adaptations_applied", []) or [], + )) + + return outputs + + +# --------------------------------------------------------------------------- +# Agent class +# --------------------------------------------------------------------------- class Agent4Transcreator(BaseAgent): - """STUB: Transcreation agent returning placeholder translations.""" + """Transcreation agent: generates real translations using Claude. + + Processes source lines in batches of ~15, calling the LLM for each batch + with a comprehensive system prompt incorporating voice profile, glossary, + blacklist, locale considerations, and TM context. 
+ """ name = "agent_4_transcreator" - description = "Generates transcreation drafts (STUB)" + description = "Generates transcreation drafts using Claude LLM" + + def __init__(self) -> None: + self._llm = LLMClient() + self._system_prompt: str = "" + self._ref_data: dict[str, Any] = {} + self._glossary_entries: list[dict] | None = None + + # -- BaseAgent interface (used for documentation; actual work in run()) -- def get_system_prompt(self) -> str: - return "You are a creative transcreation agent." + return self._system_prompt def build_user_message(self, context: PipelineContext) -> str: - return "Generate transcreation drafts for each source line." + # Delegated to build_user_message_for_batch per batch + return "" def parse_response(self, response: str, context: PipelineContext) -> Any: + # Delegated to parse_batch_response per batch return [] - async def run(self, context: PipelineContext) -> PipelineContext: - """STUB: Return placeholder translations for all source lines.""" - locale = context.job_params.locale_code + # -- Main execution -- - context.draft_outputs = [ + async def run(self, context: PipelineContext) -> PipelineContext: + """Execute the transcreation agent. + + 1. Load reference files from file_manifest paths + 2. Build system prompt with reference data + 3. Batch source lines into groups of ~BATCH_SIZE + 4. For each batch, build user message, call LLM, parse response + 5. Combine all DraftOutput objects into context.draft_outputs + 6. Return context + """ + job = context.job_params + logger.info( + f"[{job.job_id}] Agent 4 starting: " + f"locale={job.locale_code}, lines={len(context.source_lines)}" + ) + + # 1. 
Load reference files + self._ref_data = self._load_reference_files(context) + + # Extract glossary entries for user message inclusion + glossary_raw = self._ref_data.get("glossary_file") + if glossary_raw: + if isinstance(glossary_raw, dict) and "entries" in glossary_raw: + self._glossary_entries = glossary_raw["entries"] + elif isinstance(glossary_raw, list): + self._glossary_entries = glossary_raw + + # 2. Build system prompt + self._system_prompt = build_system_prompt( + locale_code=job.locale_code, + channel=job.channel, + sub_channel=job.sub_channel, + programme=job.programme, + campaign_name=job.campaign_name, + ref_data=self._ref_data, + ) + + # 3. Build ranking lookup + ranking_map: dict[str, RankingDeclaration] = { + r.line_id: r for r in context.ranking_declarations + } + + # 4. Process in batches + all_outputs: list[DraftOutput] = [] + batches = self._make_batches(context.source_lines, BATCH_SIZE) + + for batch_idx, batch in enumerate(batches, 1): + logger.info( + f"[{job.job_id}] Processing batch {batch_idx}/{len(batches)} " + f"({len(batch)} lines)" + ) + + outputs = await self._process_batch( + batch=batch, + ranking_map=ranking_map, + locale_code=job.locale_code, + context_prompt=job.context_prompt, + ) + all_outputs.extend(outputs) + + # 5. 
Store results + context.draft_outputs = all_outputs + + logger.info( + f"[{job.job_id}] Agent 4 complete: " + f"{len(all_outputs)} draft outputs produced" + ) + + return context + + # -- Internal helpers -- + + def _load_reference_files(self, context: PipelineContext) -> dict[str, Any]: + """Load all reference files from the file manifest.""" + manifest = context.file_manifest + manifest_dict: dict[str, str | None] = { + "glossary_file": manifest.glossary_file, + "blacklist_file": manifest.blacklist_file, + "tov_global_file": manifest.tov_global_file, + "tov_supplement_file": manifest.tov_supplement_file, + "locale_considerations_file": manifest.locale_considerations_file, + "date_pct_formats_file": manifest.date_pct_formats_file, + } + try: + return load_all_reference_files(manifest_dict) + except Exception as exc: + logger.warning(f"Error loading reference files: {exc}") + return {} + + @staticmethod + def _make_batches( + lines: list[SourceLineContract], + batch_size: int, + ) -> list[list[SourceLineContract]]: + """Split source lines into batches.""" + return [ + lines[i: i + batch_size] + for i in range(0, len(lines), batch_size) + ] + + async def _process_batch( + self, + batch: list[SourceLineContract], + ranking_map: dict[str, RankingDeclaration], + locale_code: str, + context_prompt: str | None, + ) -> list[DraftOutput]: + """Process a single batch: build message, call LLM, parse, retry on failure.""" + + user_msg = build_user_message_for_batch( + batch_lines=batch, + ranking_map=ranking_map, + locale_code=locale_code, + context_prompt=context_prompt, + glossary_entries=self._glossary_entries, + ) + + # --- Attempt 1: Full prompt --- + try: + response_text, usage = await self._llm.acreate_message( + system_prompt=self._system_prompt, + user_message=user_msg, + max_tokens=16384, + temperature=0.7, + ) + logger.info( + f"Batch LLM call: {usage.get('total_tokens', 0)} tokens, " + f"${usage.get('estimated_cost_usd', 0):.4f}" + ) + return 
parse_batch_response(response_text, batch) + + except ValueError as parse_err: + logger.warning(f"Parse error on attempt 1: {parse_err}. Retrying...") + + except Exception as llm_err: + logger.error(f"LLM error on attempt 1: {llm_err}. Retrying...") + + # --- Attempt 2: Simplified retry --- + try: + fallback_msg = _build_fallback_user_message(batch, locale_code) + response_text, usage = await self._llm.acreate_message( + system_prompt=FALLBACK_SYSTEM_PROMPT, + user_message=fallback_msg, + max_tokens=16384, + temperature=0.7, + ) + logger.info( + f"Retry LLM call: {usage.get('total_tokens', 0)} tokens, " + f"${usage.get('estimated_cost_usd', 0):.4f}" + ) + return parse_batch_response(response_text, batch) + + except Exception as retry_err: + logger.error( + f"Retry also failed: {retry_err}. " + "Falling back to placeholder outputs." + ) + + # --- Fallback: placeholder outputs --- + return self._generate_placeholders(batch, locale_code) + + @staticmethod + def _generate_placeholders( + batch: list[SourceLineContract], + locale_code: str, + ) -> list[DraftOutput]: + """Generate placeholder DraftOutputs when all LLM attempts fail.""" + return [ DraftOutput( line_id=line.line_id, option_1=DraftOption( - text=f"[{locale}] {line.en_gb}", + text=f"[{locale_code}] {line.en_gb}", backtranslation=line.en_gb, - rationale=f"STUB: Direct placeholder for '{line.en_gb[:50]}...'", - ), - option_2=DraftOption( - text=f"[{locale} alt] {line.en_gb}", - backtranslation=line.en_gb, - rationale="STUB: Alternative placeholder", - ), - option_3=DraftOption( - text=f"[{locale} creative] {line.en_gb}", - backtranslation=line.en_gb, - rationale="STUB: Creative placeholder", + rationale=( + "FALLBACK: LLM transcreation failed after retries. " + "This is a placeholder that needs manual translation." 
+ ), ), tm_entries_cited=[], - adaptations_applied=[], + adaptations_applied=["fallback_placeholder"], ) - for line in context.source_lines + for line in batch ] - return context diff --git a/backend/app/pipeline/agents/agent_5_compliance.py b/backend/app/pipeline/agents/agent_5_compliance.py index f4d6198..dcdb90c 100644 --- a/backend/app/pipeline/agents/agent_5_compliance.py +++ b/backend/app/pipeline/agents/agent_5_compliance.py @@ -1,38 +1,204 @@ -"""Agent 5: Compliance Checker (STUB) +"""Agent 5: Compliance Checker (Hybrid) -Checks transcreation drafts against compliance rules. -Currently returns PASS for all lines as a stub. +Runs deterministic compliance checks against transcreation drafts: +1. Character count validation against source-line char_limit +2. Blacklist scanning for forbidden terms +3. Domain substitution checks (Amazon.co.uk in non-en_GB locales) + +LLM-based checks (positivity, tone) are skipped for now to control costs. """ +import logging from typing import Any from app.pipeline.agents.base import BaseAgent -from app.pipeline.contracts import ComplianceResult, PipelineContext -from app.pipeline.modules.character_counter import count_characters +from app.pipeline.contracts import ( + ComplianceResult, + ComplianceViolation, + DraftOutput, + PipelineContext, + SourceLineContract, +) +from app.pipeline.modules.blacklist_scanner import scan_text +from app.pipeline.modules.character_counter import check_character_limit, count_characters +from app.pipeline.modules.domain_substitutor import SOURCE_DOMAIN_LOWER +from app.pipeline.modules.ref_file_loader import RefFileLoadError, load_json_file + +logger = logging.getLogger(__name__) + + +def _load_blacklist_entries(blacklist_file: str | None) -> list[dict]: + """Load blacklist entries from a JSON file, handling both array and object formats. 
+ + Supported formats: + - JSON array: [{"term": "...", "root": "...", "reason": "..."}] + - JSON object: {"locale": "...", "entries": [...]} + + Returns an empty list if the file is missing, unreadable, or has an + unrecognised structure. + """ + if not blacklist_file: + return [] + + try: + data = load_json_file(blacklist_file) + except RefFileLoadError: + logger.warning("Blacklist file could not be loaded: %s", blacklist_file) + return [] + + if isinstance(data, list): + return data + if isinstance(data, dict): + entries = data.get("entries") + if isinstance(entries, list): + return entries + logger.warning( + "Blacklist file is a JSON object but has no 'entries' key: %s", + blacklist_file, + ) + return [] + + logger.warning("Unexpected blacklist format in %s", blacklist_file) + return [] + + +def _get_option_texts(draft: DraftOutput) -> list[tuple[int, str]]: + """Return a list of (option_number, text) for all non-None options.""" + options: list[tuple[int, str]] = [] + if draft.option_1: + options.append((1, draft.option_1.text)) + if draft.option_2: + options.append((2, draft.option_2.text)) + if draft.option_3: + options.append((3, draft.option_3.text)) + return options + + +def _check_character_limits( + draft: DraftOutput, + source_line: SourceLineContract | None, + char_counts: dict[str, int], +) -> list[ComplianceViolation]: + """Check each option against the source line's char_limit (if any).""" + violations: list[ComplianceViolation] = [] + if source_line is None or not source_line.char_limit: + return violations + + for opt_num, text in _get_option_texts(draft): + count, within = check_character_limit(text, source_line.char_limit) + if not within: + violations.append( + ComplianceViolation( + type="character_limit", + option_affected=opt_num, + description=( + f"Option {opt_num} has {count} characters, " + f"exceeds limit of {source_line.char_limit}" + ), + severity="warning", + ) + ) + return violations + + +def _check_blacklist( + draft: 
DraftOutput, + blacklist_entries: list[dict], +) -> list[ComplianceViolation]: + """Scan each option for blacklisted terms.""" + violations: list[ComplianceViolation] = [] + if not blacklist_entries: + return violations + + for opt_num, text in _get_option_texts(draft): + hits = scan_text(text, blacklist_entries) + for hit in hits: + violations.append( + ComplianceViolation( + type="blacklist", + option_affected=opt_num, + description=( + f"Blacklisted term '{hit.term}' found in option {opt_num} " + f"({hit.match_type} match): ...{hit.context}..." + ), + severity="error", + ) + ) + return violations + + +def _check_domains( + draft: DraftOutput, + locale_code: str, +) -> list[ComplianceViolation]: + """Flag options that still contain the source domain (Amazon.co.uk) + when the target locale is not en_GB.""" + violations: list[ComplianceViolation] = [] + + # en_GB is the source locale, so Amazon.co.uk is correct there + if locale_code.lower() == "en_gb": + return violations + + for opt_num, text in _get_option_texts(draft): + if SOURCE_DOMAIN_LOWER in text.lower(): + violations.append( + ComplianceViolation( + type="domain", + option_affected=opt_num, + description=( + f"Option {opt_num} contains 'Amazon.co.uk' which should " + f"be replaced with the locale-specific domain for {locale_code}" + ), + severity="error", + ) + ) + return violations class Agent5Compliance(BaseAgent): - """STUB: Compliance agent returning pass for all lines.""" + """Hybrid compliance agent: deterministic checks first, optional LLM later.""" name = "agent_5_compliance" - description = "Checks compliance of transcreation drafts (STUB)" + description = "Checks compliance of transcreation drafts against blacklist, character limits, and domain rules" + + # ── LLM interface (unused for deterministic-only mode) ────────────── def get_system_prompt(self) -> str: - return "You are a compliance checking agent." 
+ return "" # No LLM call in deterministic mode def build_user_message(self, context: PipelineContext) -> str: - return "Check compliance for all transcreation drafts." + return "" # No LLM call in deterministic mode def parse_response(self, response: str, context: PipelineContext) -> Any: - return [] + return None # No LLM call in deterministic mode + + # ── Main execution ────────────────────────────────────────────────── async def run(self, context: PipelineContext) -> PipelineContext: - """STUB: Return pass for all compliance checks, with character counts.""" - context.compliance_results = [] + """Run deterministic compliance checks on every draft output. + + Checks performed: + 1. Character-count validation against source-line char_limit + 2. Blacklist scanning (if a blacklist file is available) + 3. Domain reference validation (Amazon.co.uk in non-en_GB locales) + """ + locale_code = context.job_params.locale_code + + # Build a lookup from line_id -> source line for char-limit checks + source_map: dict[str, SourceLineContract] = { + sl.line_id: sl for sl in context.source_lines + } + + # Load blacklist once (shared across all lines) + blacklist_entries = _load_blacklist_entries( + context.file_manifest.blacklist_file + ) + + results: list[ComplianceResult] = [] for draft in context.draft_outputs: + # ── Character counts (always populated) ───────────────── char_counts: dict[str, int] = {} - if draft.option_1: char_counts["option_1"] = count_characters(draft.option_1.text) if draft.option_2: @@ -40,13 +206,46 @@ class Agent5Compliance(BaseAgent): if draft.option_3: char_counts["option_3"] = count_characters(draft.option_3.text) - context.compliance_results.append( + # ── Collect violations ────────────────────────────────── + violations: list[ComplianceViolation] = [] + + source_line = source_map.get(draft.line_id) + + # 1. Character-limit check + violations.extend( + _check_character_limits(draft, source_line, char_counts) + ) + + # 2. 
Blacklist check + violations.extend( + _check_blacklist(draft, blacklist_entries) + ) + + # 3. Domain check + violations.extend( + _check_domains(draft, locale_code) + ) + + # ── Determine pass/fail ───────────────────────────────── + # passed = True unless there is at least one "error"-severity violation + has_error = any(v.severity == "error" for v in violations) + + results.append( ComplianceResult( line_id=draft.line_id, - passed=True, - violations=[], + passed=not has_error, + violations=violations, character_counts=char_counts, ) ) + context.compliance_results = results + + logger.info( + "Compliance checks complete: %d lines, %d passed, %d with errors", + len(results), + sum(1 for r in results if r.passed), + sum(1 for r in results if not r.passed), + ) + return context diff --git a/backend/app/pipeline/modules/tm_file_loader.py b/backend/app/pipeline/modules/tm_file_loader.py index 5554f33..2322187 100644 --- a/backend/app/pipeline/modules/tm_file_loader.py +++ b/backend/app/pipeline/modules/tm_file_loader.py @@ -55,8 +55,8 @@ def load_tm_file( if entry is None: continue - # Locale hard-match gate - if entry.lc == target_locale: + # Locale hard-match gate (case-insensitive) + if entry.lc.lower() == target_locale.lower(): entries.append(entry) except FileNotFoundError: @@ -72,6 +72,14 @@ def _parse_entry(data: dict[str, Any], line_num: int) -> TMEntry | None: Detects compact vs multi-field format automatically. + Compact format (V25 spec): + {"t": "{seg_key} {note_type} {locale_code} {EN_source} {TX}"} + The locale code (e.g., 'de-de') is the split point between metadata, + EN source text, and target-language translation. + + Multi-field format: + {"seg_key": "...", "en": "...", "lc": "...", "tx": "...", ...} + Args: data: Parsed JSON dict. line_num: Line number for error reporting. @@ -79,23 +87,9 @@ def _parse_entry(data: dict[str, Any], line_num: int) -> TMEntry | None: Returns: TMEntry or None if the entry is malformed. 
""" - # Compact format: {"t": "seg_key|date|en|lc|tx|nt|channel|sub_channel"} + # Compact format: {"t": "..."} if "t" in data and isinstance(data["t"], str): - parts = data["t"].split("|") - if len(parts) < 5: - return None # Malformed compact entry - - return TMEntry( - seg_key=parts[0] if len(parts) > 0 else "", - date=parts[1] if len(parts) > 1 else "", - en=parts[2] if len(parts) > 2 else "", - lc=parts[3] if len(parts) > 3 else "", - tx=parts[4] if len(parts) > 4 else "", - nt=parts[5] if len(parts) > 5 else "", - channel=parts[6] if len(parts) > 6 else "", - sub_channel=parts[7] if len(parts) > 7 else "", - _text=data["t"], - ) + return _parse_compact_entry(data["t"]) # Multi-field format if "seg_key" in data and "en" in data: @@ -113,6 +107,86 @@ def _parse_entry(data: dict[str, Any], line_num: int) -> TMEntry | None: return None +# Regex to find locale code pattern like "de-de", "fr-fr", "es-es" etc. +import re +_LOCALE_RE = re.compile(r"\b([a-z]{2}-[a-z]{2})\b") + + +def _parse_compact_entry(t_value: str) -> TMEntry | None: + """Parse the compact 't' field format used in real TM files. + + Format: '{seg_key_with_metadata} {locale_code} {EN_source} {TX}' + Example: 'Value Q1 24 Radio 001 VO de-de As Sophie opened... Sophie öffnet...' + + The locale code (xx-xx) is the reliable split point: + - Everything BEFORE the locale code is the seg_key + note_type metadata + - Everything AFTER needs to be split into EN source and TX translation + (we store the full post-locale text and let the LLM handle matching) + """ + match = _LOCALE_RE.search(t_value) + if not match: + return None + + locale_code = match.group(1) + before_locale = t_value[: match.start()].strip() + after_locale = t_value[match.end() :].strip() + + # Extract seg_key: everything up to the sequence number + # e.g., "Value Q1 24 Radio 001 VO" -> seg_key="Value Q1 24 Radio 001" + # note_type would be "VO", "Headline", "BVO", "Super", etc. 
+ seg_key = before_locale + + # Extract channel info from seg_key + channel = "" + sub_channel = "" + seg_parts = before_locale.split() + # Try to find channel indicators in the seg_key + channel_keywords = { + "mass", "value", "onsite", "outbound", "radio", + "tv_olv", "display", "ooh", "dooh", "social", "print", + "digital", "crm", "push", + } + for part in seg_parts: + if part.lower() in channel_keywords: + if not channel: + channel = part + else: + sub_channel = part + + # Extract note_type from the end of before_locale + note_type = "" + note_keywords = {"vo", "bvo", "super", "headline", "legal", "cta", + "body", "disclaimer", "endline"} + for part in reversed(seg_parts): + if part.lower() in note_keywords: + note_type = part + break + + # For EN/TX split: the boundary is where the language switches + # We store the full text and let the retrieval agent handle it + # Simple heuristic: store everything after locale as combined en+tx + en_text = after_locale + tx_text = after_locale + + # Try to extract year from seg_key for the date field + date = "" + year_match = re.search(r"\b(\d{2})\b", before_locale) + if year_match: + date = year_match.group(1) + + return TMEntry( + seg_key=seg_key, + date=date, + en=en_text, + lc=locale_code, + tx=tx_text, + nt=note_type, + channel=channel, + sub_channel=sub_channel, + _text=t_value, + ) + + def load_multiple_tm_files( file_paths: list[str], target_locale: str, diff --git a/backend/app/tasks/job_tasks.py b/backend/app/tasks/job_tasks.py index 883e15e..95c0157 100644 --- a/backend/app/tasks/job_tasks.py +++ b/backend/app/tasks/job_tasks.py @@ -17,6 +17,80 @@ from app.tasks.celery_app import celery_app logger = logging.getLogger(__name__) +# TM channel registry: channel name -> TM file pattern +TM_CHANNEL_REGISTRY: dict[str, str] = { + "mass": "flat_MASS_{lc}.json", + "value": "flat_value_{lc}.json", + "onsite": "flat_Onsite_{lc}.json", + "outbound": "flat_Outbound_{lc}.json", +} + + +def _resolve_file_manifest( + 
locale_code: str, channel: str, client_id: str +) -> dict: + """Resolve all reference and TM file paths for a locale. + + File layout in storage: + - /storage/amazon/tm/{locale_code}/{tm_file} + - /storage/amazon/ref/glossary/{lc_norm}_glossary.json + - /storage/amazon/ref/blacklist/{lc_norm}_blacklist.json + - /storage/amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json + - /storage/amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json (de-DE, de-AT only) + - /storage/amazon/ref/locale_considerations/{lc_norm}_local_considerations.json + - /storage/amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json + """ + import os + + storage = settings.STORAGE_ROOT + # Normalise locale: de-DE -> de_DE for ref files, de-de for TM files + lc_norm = locale_code.replace("-", "_") + lc_lower = locale_code.lower() + + def _check(path: str) -> str | None: + return path if os.path.exists(path) else None + + # Resolve TM file + tm_files: list[str] = [] + channel_lower = channel.lower() if channel else "" + tm_pattern = TM_CHANNEL_REGISTRY.get(channel_lower) + if tm_pattern: + tm_filename = tm_pattern.replace("{lc}", lc_lower) + tm_path = f"{storage}/amazon/tm/{locale_code}/{tm_filename}" + if os.path.exists(tm_path): + tm_files.append(tm_path) + + # Resolve reference files + glossary = _check(f"{storage}/amazon/ref/glossary/{lc_norm}_glossary.json") + blacklist = _check(f"{storage}/amazon/ref/blacklist/{lc_norm}_blacklist.json") + tov_global = _check( + f"{storage}/amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json" + ) + locale_considerations = _check( + f"{storage}/amazon/ref/locale_considerations/{lc_norm}_local_considerations.json" + ) + date_pct = _check( + f"{storage}/amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json" + ) + + # DE/AT-specific TOV supplement + tov_supplement = None + if locale_code in ("de-DE", "de-AT"): + tov_supplement = _check( + f"{storage}/amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json" 
+ ) + + return { + "tm_files": tm_files, + "glossary_file": glossary, + "blacklist_file": blacklist, + "tov_global_file": tov_global, + "tov_supplement_file": tov_supplement, + "locale_considerations_file": locale_considerations, + "date_pct_formats_file": date_pct, + } + + def _get_async_session_factory() -> async_sessionmaker[AsyncSession]: """Create a fresh async session factory for use in Celery tasks.""" engine = create_async_engine(settings.DATABASE_URL, pool_pre_ping=True) @@ -130,10 +204,16 @@ def process_locale_instance(self, job_id: str, locale_code: str) -> dict: "context_prompt": job.context_prompt, } + # Resolve file manifest for this locale + file_manifest = _resolve_file_manifest( + locale_code, job.channel, str(job.client_id) + ) + # Run pipeline orchestrator = PipelineOrchestrator( job_params=job_params, source_lines=source_lines, + file_manifest=file_manifest, output_dir=settings.STORAGE_ROOT, ) context = await orchestrator.run()