amazon-transcreation/backend/app/tasks/job_tasks.py
DJP bb8ed2a004 Round 2.7: three broken promises — empty TM, supplementary files, new-TM casing
Bug 1: Empty tm_channels was silently re-defaulted to [campaign channel]
  in both agent_single.py and job_tasks.py via `or [channel]`. Python's
  `or` treats [] as falsy, so the frontend's empty-list intent was lost.
  Fixed by replacing `or` with an explicit `is not None` check at both
  sites. Empty list now means "load no TMs"; None still falls back.

Bug 2: Supplementary files dropped by Agent1Validator. The validator
  built FileManifest(...) with explicit kwargs but forgot
  supplementary_files, so the raw field from _resolve_file_manifest
  never reached agent_single.run(). Files were uploaded to disk but
  never inlined into the LLM context. Fixed by adding
  supplementary_files=raw.get("supplementary_files", []) to the
  validator's FileManifest construction.

Bug 3: New TM channels lowercased in StepReview.tsx, breaking
  case-sensitive file lookup. On Linux, "flat_primecbmt_nl-be.json"
  ≠ "flat_PrimeCBMT_nl-be.json", so the file was silently skipped and
  zero TM entries loaded. Legacy channels worked only because the
  hardcoded CHANNEL_FILE_MAP has lowercase keys mapping to
  canonically-cased filenames; auto-discovered channels (PrimeCBM,
  PrimeCBMT, etc.) had no such safety net. Two-part fix:

  3a. StepReview.tsx no longer lowercases tm_channels — preserves case
      end-to-end from registry → frontend → backend → disk.

  3b. _resolve_all_tm_paths builds a case-insensitive index of the
      locale's TM directory once per call and resolves filenames
      against it. Forgives any historical case-drift between registry
      and disk.

Verified end-to-end with a standalone test script run inside the
backend container: 8/8 assertions pass covering empty tm_channels,
supplementary file passthrough, exact-case lookups, lowercase
fallback, missing channels, legacy MASS in both cases, and empty
tm_channels yielding no TM paths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-11 10:57:21 -04:00

430 lines
17 KiB
Python

"""Celery tasks for job processing."""
import asyncio
import logging
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config import settings
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
from app.models.output import ConfidenceTier, OutputRow
from app.models.source import SourceLine
from app.pipeline.orchestrator import PipelineOrchestrator
from app.tasks.celery_app import celery_app
logger = logging.getLogger(__name__)
# Legacy TM channel registry.
#
# Maps a lowercased channel name to the filename pattern of its TM file.
# The "{lc}" placeholder is substituted with the lowercased locale code
# (e.g. "de-de") when the manifest is resolved. Channels that are not
# listed here are handled by the generic "flat_{channel}_{lc}.json"
# pattern in _resolve_file_manifest.
#
# NOTE(review): "value" is the only entry whose filename segment is all
# lowercase — confirm it matches the on-disk filename.
TM_CHANNEL_REGISTRY: dict[str, str] = {
    "mass": "flat_MASS_{lc}.json",
    "value": "flat_value_{lc}.json",
    "onsite": "flat_Onsite_{lc}.json",
    "outbound": "flat_Outbound_{lc}.json",
}
def _resolve_file_manifest(
    locale_code: str, channel: str, client_id: str, job_id: str | None = None,
) -> dict:
    """Resolve all reference and TM file paths for a locale.

    File layout in storage:
    - /storage/amazon/tm/{locale_code}/{tm_file}
    - /storage/amazon/ref/glossary/{lc_norm}_glossary.json
    - /storage/amazon/ref/blacklist/{lc_norm}_blacklist.json
    - /storage/amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json
    - /storage/amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json (de-DE, de-AT only)
    - /storage/amazon/ref/locale_considerations/{lc_norm}_local_considerations.json
    - /storage/amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json

    The TM filename is looked up with its exact spelling first. If that
    path does not exist, the locale's TM directory is scanned for a
    filename that differs only by letter case, so a channel name whose
    casing drifted from the on-disk file still resolves instead of being
    silently skipped.

    Args:
        locale_code: Locale such as "de-DE".
        channel: Campaign channel name; may be empty, in which case no TM
            file is resolved.
        client_id: Currently unused by this function; kept so existing
            callers keep working.
        job_id: When given, per-job supplementary files under
            ``{STORAGE_ROOT}/jobs/{job_id}/supplementary`` are included.

    Returns:
        A dict with the keys ``tm_files`` and ``supplementary_files``
        (lists of existing paths) and ``glossary_file``,
        ``blacklist_file``, ``tov_global_file``, ``tov_supplement_file``,
        ``locale_considerations_file``, ``date_pct_formats_file`` (each a
        path, or ``None`` when the file does not exist).
    """
    import os

    storage = settings.STORAGE_ROOT
    # Normalise locale: de-DE -> de_DE for ref files, de-de for TM files
    lc_norm = locale_code.replace("-", "_")
    lc_lower = locale_code.lower()

    def _check(path: str) -> str | None:
        """Return *path* if it exists on disk, else None."""
        return path if os.path.exists(path) else None

    def _find_tm_file(tm_dir: str, filename: str) -> str | None:
        """Locate *filename* in *tm_dir*, tolerating letter-case drift."""
        exact = _check(f"{tm_dir}/{filename}")
        if exact is not None:
            return exact
        try:
            entries = sorted(os.listdir(tm_dir))
        except OSError:
            # Missing or unreadable TM directory: nothing to resolve.
            return None
        wanted = filename.lower()
        for entry in entries:
            if entry.lower() == wanted:
                return f"{tm_dir}/{entry}"
        return None

    # Resolve TM file. Try the legacy registry first; if the channel isn't
    # listed, fall back to the generic pattern so newly-registered channels
    # (e.g. PrimeCBM) work without code changes.
    tm_files: list[str] = []
    channel_lower = channel.lower() if channel else ""
    tm_pattern = TM_CHANNEL_REGISTRY.get(channel_lower)
    if tm_pattern:
        tm_filename = tm_pattern.replace("{lc}", lc_lower)
    elif channel:
        tm_filename = f"flat_{channel}_{lc_lower}.json"
    else:
        tm_filename = ""
    if tm_filename:
        tm_path = _find_tm_file(f"{storage}/amazon/tm/{locale_code}", tm_filename)
        if tm_path is not None:
            tm_files.append(tm_path)

    # Resolve reference files
    glossary = _check(f"{storage}/amazon/ref/glossary/{lc_norm}_glossary.json")
    blacklist = _check(f"{storage}/amazon/ref/blacklist/{lc_norm}_blacklist.json")
    tov_global = _check(
        f"{storage}/amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json"
    )
    locale_considerations = _check(
        f"{storage}/amazon/ref/locale_considerations/{lc_norm}_local_considerations.json"
    )
    date_pct = _check(
        f"{storage}/amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json"
    )

    # DE/AT-specific TOV supplement
    tov_supplement = None
    if locale_code in ("de-DE", "de-AT"):
        tov_supplement = _check(
            f"{storage}/amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json"
        )

    # Per-job supplementary files. Each filename may carry a locale
    # prefix (e.g. "de-DE_terms.txt") to restrict the file to that locale
    # only; files without a recognised locale prefix are global.
    supplementary_files: list[str] = []
    if job_id:
        supp_dir = os.path.join(storage, "jobs", str(job_id), "supplementary")
        if os.path.isdir(supp_dir):
            # Sorted so the manifest order is stable across runs.
            for fname in sorted(os.listdir(supp_dir)):
                fpath = os.path.join(supp_dir, fname)
                if not os.path.isfile(fpath):
                    continue
                if _supplementary_applies_to_locale(fname, locale_code):
                    supplementary_files.append(fpath)

    return {
        "tm_files": tm_files,
        "glossary_file": glossary,
        "blacklist_file": blacklist,
        "tov_global_file": tov_global,
        "tov_supplement_file": tov_supplement,
        "locale_considerations_file": locale_considerations,
        "date_pct_formats_file": date_pct,
        "supplementary_files": supplementary_files,
    }
# The 12 supported locale codes, lowercased. Used to recognise
# supplementary filenames that carry a locale prefix.
_LOCALE_CODES = {
    "de-de", "de-at", "fr-fr", "fr-be", "it-it", "es-es",
    "ca-es", "nl-nl", "nl-be", "sv-se", "pl-pl", "pt-pt",
}


def _supplementary_applies_to_locale(filename: str, locale_code: str) -> bool:
    """Decide whether a supplementary file should be used for a locale.

    A file is locale-gated when its name (compared case-insensitively)
    starts with a supported locale code followed by an underscore; the
    code itself may be written with a hyphen or an underscore. Anything
    else is treated as global.

        "de-DE_terms.txt"     -> de-DE only
        "de_DE_terms.txt"     -> de-DE only (underscore variant)
        "fr-FR_brief.docx"    -> fr-FR only
        "global_glossary.txt" -> every locale
        "campaign-brief.pdf"  -> every locale
    """
    lowered_name = filename.lower()
    wanted = locale_code.lower()

    # Find the locale code (if any) that prefixes the filename.
    gated_to = next(
        (
            code
            for code in _LOCALE_CODES
            if lowered_name.startswith((code + "_", code.replace("-", "_") + "_"))
        ),
        None,
    )

    if gated_to is None:
        # No recognised locale prefix: the file is global.
        return True
    return gated_to == wanted
def _get_async_session_factory() -> async_sessionmaker[AsyncSession]:
    """Build a new async engine and return a session factory bound to it.

    Every call creates a separate engine from ``settings.DATABASE_URL``
    with ``pool_pre_ping`` enabled. Sessions produced by the factory keep
    loaded attributes readable after commit (``expire_on_commit=False``).

    NOTE(review): the engine is not disposed here, and the callers visible
    in this module do not dispose it either — confirm that is intended.
    """
    return async_sessionmaker(
        bind=create_async_engine(settings.DATABASE_URL, pool_pre_ping=True),
        class_=AsyncSession,
        expire_on_commit=False,
    )
@celery_app.task(bind=True, max_retries=2)
def process_job(self, job_id: str) -> dict:
    """Mark a job as running and fan out one task per locale instance.

    Returns ``{"error": ...}`` when the job does not exist; otherwise a
    summary holding the job id, the number of dispatched locale tasks and
    their Celery task ids.
    """
    logger.info(f"Processing job {job_id}")

    async def _process() -> dict:
        session_factory = _get_async_session_factory()
        async with session_factory() as db:
            # Look the job up; bail out early if it is missing.
            job_lookup = await db.execute(select(Job).where(Job.id == job_id))
            job = job_lookup.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Flag the job as running before any locale work is queued.
            job.status = JobStatus.running
            await db.commit()

            # One process_locale_instance task per locale of this job.
            locale_lookup = await db.execute(
                select(LocaleInstance).where(LocaleInstance.job_id == job_id)
            )
            task_ids = [
                process_locale_instance.delay(job_id, locale.locale_code).id
                for locale in locale_lookup.scalars().all()
            ]

            return {
                "job_id": job_id,
                "dispatched_locales": len(task_ids),
                "task_ids": task_ids,
            }

    return asyncio.run(_process())
@celery_app.task(bind=True, max_retries=1)
def process_locale_instance(self, job_id: str, locale_code: str) -> dict:
    """Process a single locale instance through the pipeline.

    Runs the pipeline orchestrator for the given job + locale combination,
    stores one ``OutputRow`` per draft output, records token usage and
    estimated cost on the locale instance, and finally re-evaluates the
    parent job's overall status.

    If anything fails after the instance has been marked as running, the
    session is rolled back (so no partially built output rows are
    persisted), the instance is marked as errored with the exception text,
    and the parent job's status is re-evaluated.

    Args:
        self: Bound Celery task instance (not used directly).
        job_id: Identifier of the job that owns the locale instance.
        locale_code: Locale to process, e.g. "de-DE".

    Returns:
        ``{"error": ...}`` when the locale instance or the job cannot be
        found. Otherwise a dict with ``job_id``, ``locale_code`` and
        ``status``: ``"complete"`` together with ``output_rows``, or
        ``"error"`` together with ``error``.
    """
    logger.info(f"Processing locale {locale_code} for job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get the locale instance
            result = await db.execute(
                select(LocaleInstance).where(
                    LocaleInstance.job_id == job_id,
                    LocaleInstance.locale_code == locale_code,
                )
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return {"error": f"Locale instance not found: {job_id}/{locale_code}"}

            # Get job
            job_result = await db.execute(select(Job).where(Job.id == job_id))
            job = job_result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update instance status
            instance.status = LocaleStatus.running
            instance.started_at = datetime.now(timezone.utc)
            instance.progress = 0.0
            instance.current_stage = "Starting"
            await db.commit()

            # Stage name mapping for user-friendly display
            _STAGE_LABELS = {
                "VALIDATE": "Loading Files",
                "SINGLE_AGENT": "Transcreating",
                "TM_RETRIEVE": "Matching TM",
                "RANK": "Ranking Matches",
                "TRANSCREATE": "Translating",
                "COMPLY": "Reviewing",
                "FORMAT": "Formatting Output",
                "DONE": "Complete",
                "ERROR": "Error",
            }

            async def _on_progress(state: str, message: str, pct: float) -> None:
                """Update locale instance progress in DB at each pipeline stage."""
                try:
                    # Use batch-level info from message if available
                    if "batch" in message.lower():
                        label = message  # e.g. "Translating batch 2/4"
                    else:
                        # Unknown states are shown as-is.
                        label = _STAGE_LABELS.get(state, state)
                    # pct is scaled by 100 and clamped to the 0-100 range.
                    instance.progress = max(0.0, min(100.0, pct * 100))
                    instance.current_stage = label
                    await db.commit()
                except Exception as exc:
                    # Progress reporting is best-effort; never fail the
                    # pipeline because a progress write did not succeed.
                    logger.warning("Progress update failed: %s", exc)

            try:
                # Get source lines
                source_result = await db.execute(
                    select(SourceLine)
                    .where(SourceLine.job_id == job_id)
                    .order_by(SourceLine.row_order)
                )
                source_lines = [
                    {
                        "id": str(sl.id),
                        "en_gb": sl.en_gb,
                        "copy_type": sl.copy_type,
                        "creative_guidance": sl.creative_guidance,
                        "visual_ref": sl.visual_ref,
                        "char_limit": sl.char_limit,
                        "is_display_format": sl.is_display_format,
                    }
                    for sl in source_result.scalars().all()
                ]

                # Build job params
                job_params = {
                    "job_id": str(job.id),
                    "client_id": str(job.client_id),
                    "locale_code": locale_code,
                    "channel": job.channel,
                    "sub_channel": job.sub_channel,
                    "programme": job.programme.value,
                    "campaign_name": job.campaign_name,
                    "context_prompt": job.context_prompt,
                    # Preserve an explicit empty list ("no TMs") from
                    # the user; only fall back to the campaign channel
                    # when the field was never set (legacy jobs).
                    "tm_channels": (
                        job.tm_channels
                        if job.tm_channels is not None
                        else [job.channel]
                    ),
                    "llm_model": job.llm_model,
                }

                # Resolve file manifest for this locale
                file_manifest = _resolve_file_manifest(
                    locale_code, job.channel, str(job.client_id), job_id=str(job_id)
                )

                # Run pipeline with progress callback
                orchestrator = PipelineOrchestrator(
                    job_params=job_params,
                    source_lines=source_lines,
                    file_manifest=file_manifest,
                    output_dir=settings.STORAGE_ROOT,
                    on_progress=_on_progress,
                )
                context = await orchestrator.run()

                # Save output rows to database. Drafts are paired with
                # source lines, ranking declarations and compliance
                # results by position; a missing counterpart falls back
                # to a default instead of raising.
                for i, draft in enumerate(context.draft_outputs):
                    # Find matching source line
                    source_line_id = None
                    if i < len(source_lines):
                        source_line_id = source_lines[i].get("id")

                    # Get confidence tier
                    tier = "low"
                    if i < len(context.ranking_declarations):
                        tier = context.ranking_declarations[i].confidence_tier

                    # Get character counts
                    char_counts = {}
                    if i < len(context.compliance_results):
                        char_counts = context.compliance_results[i].character_counts

                    output_row = OutputRow(
                        instance_id=instance.id,
                        line_id=source_line_id,
                        row_order=i + 1,
                        confidence_tier=ConfidenceTier(tier),
                        option_1=draft.option_1.text if draft.option_1 else "",
                        backtranslation_1=draft.option_1.backtranslation if draft.option_1 else "",
                        rationale_1=draft.option_1.rationale if draft.option_1 else "",
                        option_2=draft.option_2.text if draft.option_2 else None,
                        backtranslation_2=draft.option_2.backtranslation if draft.option_2 else None,
                        rationale_2=draft.option_2.rationale if draft.option_2 else None,
                        option_3=draft.option_3.text if draft.option_3 else None,
                        backtranslation_3=draft.option_3.backtranslation if draft.option_3 else None,
                        rationale_3=draft.option_3.rationale if draft.option_3 else None,
                        tm_entries_cited=draft.tm_entries_cited if draft.tm_entries_cited else None,
                        character_count_option_1=char_counts.get("option_1"),
                        character_count_option_2=char_counts.get("option_2"),
                        character_count_option_3=char_counts.get("option_3"),
                    )
                    db.add(output_row)

                # Update instance status and token usage
                instance.status = LocaleStatus.complete
                instance.completed_at = datetime.now(timezone.utc)
                instance.token_usage = context.total_input_tokens + context.total_output_tokens
                instance.estimated_cost = round(context.total_estimated_cost, 6)
                # Only the path is recorded here; this function neither
                # writes the file nor checks that it exists.
                output_path = (
                    f"{settings.STORAGE_ROOT}/jobs/{job_id}/output/"
                    f"{locale_code}_{job_id}_output.xlsx"
                )
                instance.output_file_path = output_path
                instance.agent_version = "1.0.0"
                await db.commit()

                # Check if all instances are complete
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "complete",
                    "output_rows": len(context.draft_outputs),
                }
            except Exception as e:
                logger.error(
                    f"Error processing {locale_code} for job {job_id}: {e}",
                    exc_info=True,
                )
                # Discard whatever the failed attempt left in the session
                # before recording the error. Without this, output rows
                # added before the failure would be committed together
                # with the error status, and a session whose flush/commit
                # failed would reject the commit below.
                await db.rollback()
                instance.status = LocaleStatus.error
                instance.error_log = str(e)
                instance.completed_at = datetime.now(timezone.utc)
                await db.commit()

                # Check if job should be marked as partial/error
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "error",
                    "error": str(e),
                }

    return asyncio.run(_process())
async def _check_job_completion(db: AsyncSession, job_id: str) -> None:
    """Derive the job's status from its locale instances and commit it.

    - every instance complete         -> job complete, totals summed
    - all finished, at least one error -> partial_complete if any instance
      completed, otherwise error
    - anything still unfinished        -> job status left as it is

    Does nothing when the job has no locale instances or cannot be found.
    """
    locale_rows = await db.execute(
        select(LocaleInstance).where(LocaleInstance.job_id == job_id)
    )
    instances = list(locale_rows.scalars().all())
    if not instances:
        return

    # Snapshot the statuses once; every decision below is based on them.
    statuses = [inst.status for inst in instances]
    finished_states = (LocaleStatus.complete, LocaleStatus.error)
    every_complete = all(s == LocaleStatus.complete for s in statuses)
    every_finished = all(s in finished_states for s in statuses)
    has_error = any(s == LocaleStatus.error for s in statuses)

    job_lookup = await db.execute(select(Job).where(Job.id == job_id))
    job = job_lookup.scalar_one_or_none()
    if job is None:
        return

    if every_complete:
        job.status = JobStatus.complete
        # Roll the per-locale usage figures up onto the job.
        job.total_token_usage = sum(inst.token_usage for inst in instances)
        job.total_estimated_cost = sum(inst.estimated_cost for inst in instances)
    elif every_finished and has_error:
        has_complete = any(s == LocaleStatus.complete for s in statuses)
        job.status = (
            JobStatus.partial_complete if has_complete else JobStatus.error
        )

    await db.commit()