Bug 1: Empty tm_channels was silently re-defaulted to [campaign channel]
in both agent_single.py and job_tasks.py via `or [channel]`. Python's
`or` treats [] as falsy, so the frontend's empty-list intent was lost.
Fixed by replacing `or` with an explicit `is not None` check at both
sites. Empty list now means "load no TMs"; None still falls back.
Bug 2: Supplementary files dropped by Agent1Validator. The validator
built FileManifest(...) with explicit kwargs but forgot
supplementary_files, so the raw field from _resolve_file_manifest
never reached agent_single.run(). Files were uploaded to disk but
never inlined into the LLM context. Fixed by adding
supplementary_files=raw.get("supplementary_files", []) to the
validator's FileManifest construction.
Bug 3: New TM channels lowercased in StepReview.tsx, breaking
case-sensitive file lookup. On Linux, "flat_primecbmt_nl-be.json"
≠ "flat_PrimeCBMT_nl-be.json", so the file was silently skipped and
zero TM entries loaded. Legacy channels worked only because the
hardcoded CHANNEL_FILE_MAP has lowercase keys mapping to
canonically-cased filenames; auto-discovered channels (PrimeCBM,
PrimeCBMT, etc.) had no such safety net. Two-part fix:
3a. StepReview.tsx no longer lowercases tm_channels — preserves case
end-to-end from registry → frontend → backend → disk.
3b. _resolve_all_tm_paths builds a case-insensitive index of the
locale's TM directory once per call and resolves filenames
against it. Forgives any historical case-drift between registry
and disk.
Verified end-to-end with a standalone test script run inside the
backend container: 8/8 assertions pass covering empty tm_channels,
supplementary file passthrough, exact-case lookups, lowercase
fallback, missing channels, legacy MASS in both cases, and empty
tm_channels yielding no TM paths.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
430 lines
17 KiB
Python
430 lines
17 KiB
Python
"""Celery tasks for job processing."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.config import settings
|
|
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
|
|
from app.models.output import ConfidenceTier, OutputRow
|
|
from app.models.source import SourceLine
|
|
from app.pipeline.orchestrator import PipelineOrchestrator
|
|
from app.tasks.celery_app import celery_app
|
|
|
|
# Module-level logger shared by every task in this file.
logger = logging.getLogger(__name__)


# Legacy TM channels: lowercase channel name -> TM filename template.
# The "{lc}" placeholder is replaced with the lowercased locale code when
# the file manifest is resolved.
TM_CHANNEL_REGISTRY: dict[str, str] = dict(
    mass="flat_MASS_{lc}.json",
    value="flat_value_{lc}.json",
    onsite="flat_Onsite_{lc}.json",
    outbound="flat_Outbound_{lc}.json",
)
|
|
|
|
|
|
def _resolve_file_manifest(
    locale_code: str, channel: str, client_id: str, job_id: str | None = None,
) -> dict:
    """Resolve all reference, TM and supplementary file paths for a locale.

    File layout in storage:
    - /storage/amazon/tm/{locale_code}/{tm_file}
    - /storage/amazon/ref/glossary/{lc_norm}_glossary.json
    - /storage/amazon/ref/blacklist/{lc_norm}_blacklist.json
    - /storage/amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json
    - /storage/amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json (de-DE, de-AT only)
    - /storage/amazon/ref/locale_considerations/{lc_norm}_local_considerations.json
    - /storage/amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json
    - /storage/jobs/{job_id}/supplementary/* (only when job_id is given)

    Args:
        locale_code: Locale such as "de-DE". Used verbatim for the TM
            directory name, lowercased inside TM filenames, and with "-"
            replaced by "_" inside reference filenames.
        channel: Campaign channel. Its lowercase form is looked up in
            TM_CHANNEL_REGISTRY; channels not listed there fall back to the
            generic "flat_{channel}_{locale}.json" pattern.
        client_id: Not used by this function; kept so existing callers
            keep working.
        job_id: When given, files in the job's supplementary directory that
            apply to this locale are included in the manifest.

    Returns:
        A dict with the keys "tm_files" and "supplementary_files" (lists of
        existing paths, possibly empty) and "glossary_file",
        "blacklist_file", "tov_global_file", "tov_supplement_file",
        "locale_considerations_file", "date_pct_formats_file" (a path, or
        None when the file does not exist).
    """
    import os

    storage = settings.STORAGE_ROOT
    # Normalise locale: de-DE -> de_DE for ref files, de-de for TM files
    lc_norm = locale_code.replace("-", "_")
    lc_lower = locale_code.lower()

    def _check(path: str) -> str | None:
        """Return path when it exists on disk, otherwise None."""
        return path if os.path.exists(path) else None

    def _find_tm_file(directory: str, filename: str) -> str | None:
        """Locate filename in directory, tolerating differences in letter case.

        The exact-case path wins when it exists. Otherwise the directory
        listing is scanned for a name that matches case-insensitively, so a
        channel stored as "primecbmt" still finds "flat_PrimeCBMT_nl-be.json"
        on a case-sensitive filesystem.
        """
        exact = f"{directory}/{filename}"
        if os.path.exists(exact):
            return exact
        try:
            entries = os.listdir(directory)
        except OSError:
            # Missing or unreadable TM directory: nothing to match against.
            return None
        wanted = filename.lower()
        # Sorted so the pick is deterministic if several case variants exist.
        for entry in sorted(entries):
            if entry.lower() == wanted:
                return f"{directory}/{entry}"
        return None

    # Resolve TM file. Try the legacy registry first; if the channel isn't
    # listed, fall back to the generic pattern so newly-registered channels
    # (e.g. PrimeCBM) work without code changes.
    tm_files: list[str] = []
    channel_lower = channel.lower() if channel else ""
    tm_pattern = TM_CHANNEL_REGISTRY.get(channel_lower)
    if tm_pattern:
        tm_filename = tm_pattern.replace("{lc}", lc_lower)
    elif channel:
        tm_filename = f"flat_{channel}_{lc_lower}.json"
    else:
        tm_filename = ""
    if tm_filename:
        tm_path = _find_tm_file(f"{storage}/amazon/tm/{locale_code}", tm_filename)
        if tm_path is not None:
            tm_files.append(tm_path)

    # Resolve reference files
    ref_root = f"{storage}/amazon/ref"
    glossary = _check(f"{ref_root}/glossary/{lc_norm}_glossary.json")
    blacklist = _check(f"{ref_root}/blacklist/{lc_norm}_blacklist.json")
    tov_global = _check(
        f"{ref_root}/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json"
    )
    locale_considerations = _check(
        f"{ref_root}/locale_considerations/{lc_norm}_local_considerations.json"
    )
    date_pct = _check(
        f"{ref_root}/date_pct_formats/{lc_norm}_date_percent_formats.json"
    )

    # DE/AT-specific TOV supplement
    tov_supplement = None
    if locale_code in ("de-DE", "de-AT"):
        tov_supplement = _check(
            f"{ref_root}/tov_supplement/DE_AT_TOV_Guidelines.json"
        )

    # Per-job supplementary files. Each filename may carry a locale
    # prefix (e.g. "de-DE_terms.txt") to restrict the file to that locale
    # only; files without a recognised locale prefix are global.
    supplementary_files: list[str] = []
    if job_id:
        supp_dir = os.path.join(storage, "jobs", str(job_id), "supplementary")
        if os.path.isdir(supp_dir):
            supplementary_files = [
                os.path.join(supp_dir, fname)
                for fname in sorted(os.listdir(supp_dir))
                if os.path.isfile(os.path.join(supp_dir, fname))
                and _supplementary_applies_to_locale(fname, locale_code)
            ]

    return {
        "tm_files": tm_files,
        "glossary_file": glossary,
        "blacklist_file": blacklist,
        "tov_global_file": tov_global,
        "tov_supplement_file": tov_supplement,
        "locale_considerations_file": locale_considerations,
        "date_pct_formats_file": date_pct,
        "supplementary_files": supplementary_files,
    }
|
|
|
|
|
|
# All 12 supported locale codes (lowercased) — used to detect locale-
|
|
# prefixed supplementary filenames.
|
|
_LOCALE_CODES = {
|
|
"de-de", "de-at", "fr-fr", "fr-be", "it-it", "es-es",
|
|
"ca-es", "nl-nl", "nl-be", "sv-se", "pl-pl", "pt-pt",
|
|
}
|
|
|
|
|
|
def _supplementary_applies_to_locale(filename: str, locale_code: str) -> bool:
|
|
"""Return True if a supplementary file applies to the given locale.
|
|
|
|
Locale gating is by filename prefix. Examples:
|
|
"de-DE_terms.txt" → only de-DE
|
|
"de_DE_terms.txt" → only de-DE (underscore variant)
|
|
"fr-FR_brief.docx" → only fr-FR
|
|
"global_glossary.txt" → all locales (no recognised prefix)
|
|
"campaign-brief.pdf" → all locales (no recognised prefix)
|
|
"""
|
|
name = filename.lower()
|
|
target = locale_code.lower()
|
|
# Match either "de-de_..." or "de_de_..." prefixes
|
|
for code in _LOCALE_CODES:
|
|
underscored = code.replace("-", "_")
|
|
if name.startswith(code + "_") or name.startswith(underscored + "_"):
|
|
# File is locale-gated; include only if it matches.
|
|
return code == target
|
|
return True # No locale prefix → global, applies to all locales
|
|
|
|
|
|
def _get_async_session_factory() -> async_sessionmaker[AsyncSession]:
    """Create a fresh async session factory for use in Celery tasks.

    Every task in this module runs its coroutine through ``asyncio.run``,
    i.e. on an event loop of its own, and a new engine is built on each
    call. The engine uses ``NullPool`` so a connection is closed as soon as
    the session releases it. With the default pool, connections would stay
    open in an engine nobody disposes, bound to an event loop that has
    already been closed.

    Returns:
        An ``async_sessionmaker`` bound to the new engine. Sessions keep
        their loaded attributes after commit (``expire_on_commit=False``).
    """
    # Local import keeps this block self-contained; sqlalchemy is already a
    # dependency of this module.
    from sqlalchemy.pool import NullPool

    engine = create_async_engine(
        settings.DATABASE_URL,
        poolclass=NullPool,
        pool_pre_ping=True,
    )
    return async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=2)
def process_job(self, job_id: str) -> dict:
    """Mark a job as running and queue one task per locale instance.

    Returns a dict with the job id, the number of locale tasks dispatched
    and their Celery task ids, or ``{"error": ...}`` when the job row does
    not exist.
    """
    logger.info(f"Processing job {job_id}")

    async def _fan_out() -> dict:
        session_factory = _get_async_session_factory()
        async with session_factory() as db:
            job_row = await db.execute(select(Job).where(Job.id == job_id))
            job = job_row.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Flip the job to "running" before any locale work is queued.
            job.status = JobStatus.running
            await db.commit()

            locale_rows = await db.execute(
                select(LocaleInstance).where(LocaleInstance.job_id == job_id)
            )

            # One Celery task per locale instance; keep the ids for the caller.
            task_ids = [
                process_locale_instance.delay(job_id, locale.locale_code).id
                for locale in locale_rows.scalars().all()
            ]

            return {
                "job_id": job_id,
                "dispatched_locales": len(task_ids),
                "task_ids": task_ids,
            }

    return asyncio.run(_fan_out())
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=1)
def process_locale_instance(self, job_id: str, locale_code: str) -> dict:
    """Process a single locale instance through the pipeline.

    Runs the pipeline orchestrator for the given job + locale combination,
    stores one OutputRow per draft output, records the outcome on the
    LocaleInstance and finally re-evaluates the status of the parent job.

    Args:
        job_id: Identifier of the job that owns the locale instance.
        locale_code: Locale of the instance to process.

    Returns:
        On success: job_id, locale_code, status "complete" and the number
        of output rows. On a pipeline failure: job_id, locale_code, status
        "error" and the error text. When the locale instance or the job row
        cannot be found: ``{"error": ...}``.
    """
    logger.info(f"Processing locale {locale_code} for job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get the locale instance
            result = await db.execute(
                select(LocaleInstance).where(
                    LocaleInstance.job_id == job_id,
                    LocaleInstance.locale_code == locale_code,
                )
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return {"error": f"Locale instance not found: {job_id}/{locale_code}"}

            # Get job
            job_result = await db.execute(select(Job).where(Job.id == job_id))
            job = job_result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update instance status
            instance.status = LocaleStatus.running
            instance.started_at = datetime.now(timezone.utc)
            instance.progress = 0.0
            instance.current_stage = "Starting"
            await db.commit()

            # Stage name mapping for user-friendly display
            _STAGE_LABELS = {
                "VALIDATE": "Loading Files",
                "SINGLE_AGENT": "Transcreating",
                "TM_RETRIEVE": "Matching TM",
                "RANK": "Ranking Matches",
                "TRANSCREATE": "Translating",
                "COMPLY": "Reviewing",
                "FORMAT": "Formatting Output",
                "DONE": "Complete",
                "ERROR": "Error",
            }

            async def _on_progress(state: str, message: str, pct: float) -> None:
                """Update locale instance progress in DB at each pipeline stage."""
                try:
                    # Use batch-level info from message if available
                    if "batch" in message.lower():
                        label = message  # e.g. "Translating batch 2/4"
                    else:
                        label = _STAGE_LABELS.get(state, state)
                    # pct is a 0..1 fraction; store it as a clamped percentage.
                    instance.progress = max(0.0, min(100.0, pct * 100))
                    instance.current_stage = label
                    await db.commit()
                except Exception as exc:
                    # Progress reporting is best-effort and must not abort the run.
                    logger.warning("Progress update failed: %s", exc)

            try:
                # Get source lines
                source_result = await db.execute(
                    select(SourceLine)
                    .where(SourceLine.job_id == job_id)
                    .order_by(SourceLine.row_order)
                )
                source_lines = [
                    {
                        "id": str(sl.id),
                        "en_gb": sl.en_gb,
                        "copy_type": sl.copy_type,
                        "creative_guidance": sl.creative_guidance,
                        "visual_ref": sl.visual_ref,
                        "char_limit": sl.char_limit,
                        "is_display_format": sl.is_display_format,
                    }
                    for sl in source_result.scalars().all()
                ]

                # Build job params
                job_params = {
                    "job_id": str(job.id),
                    "client_id": str(job.client_id),
                    "locale_code": locale_code,
                    "channel": job.channel,
                    "sub_channel": job.sub_channel,
                    "programme": job.programme.value,
                    "campaign_name": job.campaign_name,
                    "context_prompt": job.context_prompt,
                    # Preserve an explicit empty list ("no TMs") from
                    # the user; only fall back to the campaign channel
                    # when the field was never set (legacy jobs).
                    "tm_channels": (
                        job.tm_channels
                        if job.tm_channels is not None
                        else [job.channel]
                    ),
                    "llm_model": job.llm_model,
                }

                # Resolve file manifest for this locale
                file_manifest = _resolve_file_manifest(
                    locale_code, job.channel, str(job.client_id), job_id=str(job_id)
                )

                # Run pipeline with progress callback
                orchestrator = PipelineOrchestrator(
                    job_params=job_params,
                    source_lines=source_lines,
                    file_manifest=file_manifest,
                    output_dir=settings.STORAGE_ROOT,
                    on_progress=_on_progress,
                )
                context = await orchestrator.run()

                # Save output rows to database
                for i, draft in enumerate(context.draft_outputs):
                    # Find matching source line (drafts are paired by position)
                    source_line_id = None
                    if i < len(source_lines):
                        source_line_id = source_lines[i].get("id")

                    # Get confidence tier; "low" when no ranking exists for the row
                    tier = "low"
                    if i < len(context.ranking_declarations):
                        tier = context.ranking_declarations[i].confidence_tier

                    # Get character counts
                    char_counts = {}
                    if i < len(context.compliance_results):
                        char_counts = context.compliance_results[i].character_counts

                    output_row = OutputRow(
                        instance_id=instance.id,
                        line_id=source_line_id,
                        row_order=i + 1,
                        confidence_tier=ConfidenceTier(tier),
                        option_1=draft.option_1.text if draft.option_1 else "",
                        backtranslation_1=draft.option_1.backtranslation if draft.option_1 else "",
                        rationale_1=draft.option_1.rationale if draft.option_1 else "",
                        option_2=draft.option_2.text if draft.option_2 else None,
                        backtranslation_2=draft.option_2.backtranslation if draft.option_2 else None,
                        rationale_2=draft.option_2.rationale if draft.option_2 else None,
                        option_3=draft.option_3.text if draft.option_3 else None,
                        backtranslation_3=draft.option_3.backtranslation if draft.option_3 else None,
                        rationale_3=draft.option_3.rationale if draft.option_3 else None,
                        tm_entries_cited=draft.tm_entries_cited if draft.tm_entries_cited else None,
                        character_count_option_1=char_counts.get("option_1"),
                        character_count_option_2=char_counts.get("option_2"),
                        character_count_option_3=char_counts.get("option_3"),
                    )
                    db.add(output_row)

                # Update instance status and token usage
                instance.status = LocaleStatus.complete
                instance.completed_at = datetime.now(timezone.utc)
                instance.token_usage = context.total_input_tokens + context.total_output_tokens
                instance.estimated_cost = round(context.total_estimated_cost, 6)
                output_path = (
                    f"{settings.STORAGE_ROOT}/jobs/{job_id}/output/"
                    f"{locale_code}_{job_id}_output.xlsx"
                )
                instance.output_file_path = output_path
                instance.agent_version = "1.0.0"

                await db.commit()

                # Check if all instances are complete
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "complete",
                    "output_rows": len(context.draft_outputs),
                }

            except Exception as e:
                logger.error(
                    f"Error processing {locale_code} for job {job_id}: {e}",
                    exc_info=True,
                )
                # Drop whatever is still pending in the session before the
                # error is recorded. Without this, OutputRow objects added
                # before the failure would be committed together with the
                # error status, and a session left in a failed-flush state
                # would make the commit below raise, leaving the instance
                # stuck in "running".
                await db.rollback()
                # rollback() expires loaded attributes; reload the instance
                # so it is fully populated again before it is modified.
                await db.refresh(instance)

                instance.status = LocaleStatus.error
                instance.error_log = str(e)
                instance.completed_at = datetime.now(timezone.utc)
                await db.commit()

                # Check if job should be marked as partial/error
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "error",
                    "error": str(e),
                }

    return asyncio.run(_process())
|
|
|
|
|
|
async def _check_job_completion(db: AsyncSession, job_id: str) -> None:
    """Roll the statuses of a job's locale instances up into the job status.

    - every instance complete -> job complete, with token usage and
      estimated cost summed over the instances
    - every instance complete or errored, at least one errored ->
      partial_complete if any instance completed, otherwise error
    - anything else (some instance still unfinished) -> status untouched

    Does nothing when the job has no locale instances or the job row is
    missing; otherwise the session is committed.
    """
    instance_rows = await db.execute(
        select(LocaleInstance).where(LocaleInstance.job_id == job_id)
    )
    locale_instances = list(instance_rows.scalars().all())
    if len(locale_instances) == 0:
        return

    # Tally finished instances by outcome.
    statuses = [locale.status for locale in locale_instances]
    total = len(statuses)
    completed = sum(1 for status in statuses if status == LocaleStatus.complete)
    failed = sum(1 for status in statuses if status == LocaleStatus.error)

    job_row = await db.execute(select(Job).where(Job.id == job_id))
    job = job_row.scalar_one_or_none()
    if job is None:
        return

    if completed == total:
        job.status = JobStatus.complete
        # Sum up token usage
        job.total_token_usage = sum(
            locale.token_usage for locale in locale_instances
        )
        job.total_estimated_cost = sum(
            locale.estimated_cost for locale in locale_instances
        )
    elif failed > 0 and completed + failed == total:
        # Everything has finished, but not everything succeeded.
        job.status = (
            JobStatus.partial_complete if completed > 0 else JobStatus.error
        )

    await db.commit()
|