amazon-transcreation/backend/app/tasks/job_tasks.py
DJP 5e0a148b96 feat: add token usage tracking, feedback highlighting, cost on cards, help page
- Wire token usage from LLM agents through pipeline context to DB and frontend
- Agents 2 and 4 accumulate input/output tokens and cost into PipelineContext
- job_tasks.py saves token totals to locale instance after pipeline completion
- Monitoring cards show total tokens and estimated cost instead of broken 0/0
- Make feedback highlighting bolder: colored card borders, stronger button states
- Add estimated cost display to dashboard job cards
- Add Help page with full documentation and link in sidebar navigation
- Comprehensive README with ASCII architecture diagrams

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 16:47:36 -04:00

369 lines
14 KiB
Python

"""Celery tasks for job processing."""
import asyncio
import logging
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config import settings
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
from app.models.output import ConfidenceTier, OutputRow
from app.models.source import SourceLine
from app.pipeline.orchestrator import PipelineOrchestrator
from app.tasks.celery_app import celery_app
logger = logging.getLogger(__name__)

# Maps a lower-cased job channel name to the TM file-name template used for
# that channel. "{lc}" in the template is substituted with the lower-cased
# locale code (e.g. "de-de") when the manifest is resolved.
TM_CHANNEL_REGISTRY: dict[str, str] = dict(
    mass="flat_MASS_{lc}.json",
    value="flat_value_{lc}.json",
    onsite="flat_Onsite_{lc}.json",
    outbound="flat_Outbound_{lc}.json",
)
def _resolve_file_manifest(
    locale_code: str, channel: str, client_id: str
) -> dict:
    """Resolve all reference and TM file paths for a locale.

    File layout in storage (relative to ``settings.STORAGE_ROOT``):
    - amazon/tm/{locale_code}/{tm_file}
    - amazon/ref/glossary/{lc_norm}_glossary.json
    - amazon/ref/blacklist/{lc_norm}_blacklist.json
    - amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json
    - amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json (de-DE, de-AT only)
    - amazon/ref/locale_considerations/{lc_norm}_local_considerations.json
    - amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json

    Args:
        locale_code: Locale such as ``"de-DE"``. Reference files use the
            underscore form (``de_DE``); TM file names use the lower-cased
            hyphen form (``de-de``); the TM directory uses it verbatim.
        channel: Channel name, looked up case-insensitively in
            ``TM_CHANNEL_REGISTRY``. An empty or unknown channel yields no
            TM files.
        client_id: Currently unused by this function; kept so the call
            signature stays stable.

    Returns:
        Dict with ``tm_files`` (list of existing TM file paths, possibly
        empty) and one ``*_file`` entry per reference document whose value is
        the path when that file exists, otherwise ``None``.
    """
    import os

    storage = settings.STORAGE_ROOT
    # Normalise locale: de-DE -> de_DE for ref files, de-de for TM files
    lc_norm = locale_code.replace("-", "_")
    lc_lower = locale_code.lower()

    def _check(path: str) -> str | None:
        # isfile (not exists): a directory at the expected path must not be
        # handed to the pipeline as if it were a readable file.
        if os.path.isfile(path):
            return path
        logger.debug("Reference file not found: %s", path)
        return None

    # Resolve TM file
    tm_files: list[str] = []
    tm_pattern = TM_CHANNEL_REGISTRY.get(channel.lower() if channel else "")
    if tm_pattern is None:
        logger.warning(
            "No TM registered for channel %r (locale %s)", channel, locale_code
        )
    else:
        tm_filename = tm_pattern.replace("{lc}", lc_lower)
        tm_path = _check(f"{storage}/amazon/tm/{locale_code}/{tm_filename}")
        if tm_path is not None:
            tm_files.append(tm_path)

    # Resolve reference files
    glossary = _check(f"{storage}/amazon/ref/glossary/{lc_norm}_glossary.json")
    blacklist = _check(f"{storage}/amazon/ref/blacklist/{lc_norm}_blacklist.json")
    tov_global = _check(
        f"{storage}/amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json"
    )
    locale_considerations = _check(
        f"{storage}/amazon/ref/locale_considerations/{lc_norm}_local_considerations.json"
    )
    date_pct = _check(
        f"{storage}/amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json"
    )

    # DE/AT-specific TOV supplement
    tov_supplement = None
    if locale_code in ("de-DE", "de-AT"):
        tov_supplement = _check(
            f"{storage}/amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json"
        )

    return {
        "tm_files": tm_files,
        "glossary_file": glossary,
        "blacklist_file": blacklist,
        "tov_global_file": tov_global,
        "tov_supplement_file": tov_supplement,
        "locale_considerations_file": locale_considerations,
        "date_pct_formats_file": date_pct,
    }
def _get_async_session_factory() -> async_sessionmaker[AsyncSession]:
    """Create a fresh async session factory for use in Celery tasks.

    Every task in this module drives its coroutine with ``asyncio.run`` (a
    brand-new event loop per call) and builds its own engine here. That
    engine is never disposed, so it is configured with ``NullPool``: each
    connection is closed when its session releases it, inside the loop that
    opened it. With the default pool, connections would linger bound to an
    event loop that ``asyncio.run`` has already closed, leaking them.

    Returns:
        A session factory producing ``AsyncSession`` objects with
        ``expire_on_commit=False``, so ORM objects keep their loaded values
        after a commit instead of needing a refresh.
    """
    from sqlalchemy.pool import NullPool

    # pool_pre_ping is unnecessary with NullPool: every checkout opens a
    # brand-new connection, so there is no stale pooled connection to ping.
    engine = create_async_engine(settings.DATABASE_URL, poolclass=NullPool)
    return async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
@celery_app.task(bind=True, max_retries=2)
def process_job(self, job_id: str) -> dict:
    """Mark a job as running and fan out one Celery task per locale instance.

    Loads the job, sets its status to ``JobStatus.running`` and commits, then
    enqueues ``process_locale_instance`` for every ``LocaleInstance`` that
    belongs to the job.

    Returns:
        ``{"error": ...}`` when the job does not exist; otherwise a summary
        holding the job id, the number of dispatched locales and the ids of
        the dispatched Celery tasks.
    """
    logger.info(f"Processing job {job_id}")

    async def _fan_out() -> dict:
        session_factory = _get_async_session_factory()
        async with session_factory() as session:
            job = (
                await session.execute(select(Job).where(Job.id == job_id))
            ).scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            job.status = JobStatus.running
            await session.commit()

            locale_rows = await session.execute(
                select(LocaleInstance).where(LocaleInstance.job_id == job_id)
            )
            # One per-locale task each; keep only the Celery task ids.
            dispatched_ids = [
                process_locale_instance.delay(job_id, locale.locale_code).id
                for locale in locale_rows.scalars().all()
            ]
            return {
                "job_id": job_id,
                "dispatched_locales": len(dispatched_ids),
                "task_ids": dispatched_ids,
            }

    return asyncio.run(_fan_out())
def _build_output_row(instance_id, index: int, draft, source_lines: list[dict], context) -> OutputRow:
    """Build the ``OutputRow`` for the ``index``-th draft output of a pipeline run.

    Drafts are paired with source lines, ranking declarations and compliance
    results purely by position. When one of those parallel lists is shorter
    than ``context.draft_outputs``, the value falls back to ``None`` (source
    line id), ``"low"`` (confidence tier) or ``{}`` (character counts).
    """
    source_line_id = (
        source_lines[index].get("id") if index < len(source_lines) else None
    )
    tier = (
        context.ranking_declarations[index].confidence_tier
        if index < len(context.ranking_declarations)
        else "low"
    )
    char_counts = (
        context.compliance_results[index].character_counts
        if index < len(context.compliance_results)
        else {}
    )
    opt1, opt2, opt3 = draft.option_1, draft.option_2, draft.option_3
    return OutputRow(
        instance_id=instance_id,
        line_id=source_line_id,
        row_order=index + 1,
        confidence_tier=ConfidenceTier(tier),
        # Option 1 is always stored (empty strings when absent); options 2/3
        # are stored as NULL when the draft did not produce them.
        option_1=opt1.text if opt1 else "",
        backtranslation_1=opt1.backtranslation if opt1 else "",
        rationale_1=opt1.rationale if opt1 else "",
        option_2=opt2.text if opt2 else None,
        backtranslation_2=opt2.backtranslation if opt2 else None,
        rationale_2=opt2.rationale if opt2 else None,
        option_3=opt3.text if opt3 else None,
        backtranslation_3=opt3.backtranslation if opt3 else None,
        rationale_3=opt3.rationale if opt3 else None,
        tm_entries_cited=draft.tm_entries_cited or None,
        character_count_option_1=char_counts.get("option_1"),
        character_count_option_2=char_counts.get("option_2"),
        character_count_option_3=char_counts.get("option_3"),
    )


@celery_app.task(bind=True, max_retries=1)
def process_locale_instance(self, job_id: str, locale_code: str) -> dict:
    """Process a single locale instance through the pipeline.

    Runs the pipeline orchestrator for the given job + locale combination,
    stores one ``OutputRow`` per draft output, records completion time, token
    usage and estimated cost on the locale instance, and then rolls the
    outcome up into the parent job via ``_check_job_completion``.

    Args:
        job_id: Id of the parent job.
        locale_code: Locale of the instance to process (e.g. ``"de-DE"``).

    Returns:
        ``{"error": ...}`` if the locale instance or job is missing. Otherwise
        a dict with ``"status": "complete"`` and the number of output rows,
        or ``"status": "error"`` and the error message if the pipeline or the
        save step raised. On error, no partially added output rows are kept.
    """
    logger.info(f"Processing locale {locale_code} for job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get the locale instance
            result = await db.execute(
                select(LocaleInstance).where(
                    LocaleInstance.job_id == job_id,
                    LocaleInstance.locale_code == locale_code,
                )
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return {"error": f"Locale instance not found: {job_id}/{locale_code}"}
            # Captured up front: a rollback expires every ORM attribute, and
            # lazily reloading one inside async code is not allowed.
            instance_id = instance.id

            # Get job
            job_result = await db.execute(select(Job).where(Job.id == job_id))
            job = job_result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update instance status
            instance.status = LocaleStatus.running
            instance.started_at = datetime.now(timezone.utc)
            instance.progress = 0.0
            instance.current_stage = "Starting"
            await db.commit()

            # Stage name mapping for user-friendly display
            _STAGE_LABELS = {
                "VALIDATE": "Loading Files",
                "TM_RETRIEVE": "Matching TM",
                "RANK": "Ranking Matches",
                "TRANSCREATE": "Translating",
                "COMPLY": "Reviewing",
                "FORMAT": "Formatting Output",
                "DONE": "Complete",
                "ERROR": "Error",
            }

            async def _on_progress(state: str, message: str, pct: float) -> None:
                """Best-effort progress update; failures are logged, never raised."""
                try:
                    # Use batch-level info from message if available
                    if "batch" in message.lower():
                        label = message  # e.g. "Translating batch 2/4"
                    else:
                        label = _STAGE_LABELS.get(state, state)
                    instance.progress = max(0.0, min(100.0, pct * 100))
                    instance.current_stage = label
                    await db.commit()
                except Exception as exc:
                    logger.warning("Progress update failed: %s", exc)
                    # A failed commit leaves the session unusable until it is
                    # rolled back; otherwise the final output save would fail.
                    try:
                        await db.rollback()
                    except Exception:
                        logger.warning(
                            "Rollback after failed progress update failed",
                            exc_info=True,
                        )

            try:
                # Get source lines
                source_result = await db.execute(
                    select(SourceLine)
                    .where(SourceLine.job_id == job_id)
                    .order_by(SourceLine.row_order)
                )
                source_lines = [
                    {
                        "id": str(sl.id),
                        "en_gb": sl.en_gb,
                        "copy_type": sl.copy_type,
                        "creative_guidance": sl.creative_guidance,
                        "visual_ref": sl.visual_ref,
                        "char_limit": sl.char_limit,
                        "is_display_format": sl.is_display_format,
                    }
                    for sl in source_result.scalars().all()
                ]

                # Build job params
                job_params = {
                    "job_id": str(job.id),
                    "client_id": str(job.client_id),
                    "locale_code": locale_code,
                    "channel": job.channel,
                    "sub_channel": job.sub_channel,
                    "programme": job.programme.value,
                    "campaign_name": job.campaign_name,
                    "context_prompt": job.context_prompt,
                }

                # Resolve file manifest for this locale
                file_manifest = _resolve_file_manifest(
                    locale_code, job.channel, str(job.client_id)
                )

                # Run pipeline with progress callback
                orchestrator = PipelineOrchestrator(
                    job_params=job_params,
                    source_lines=source_lines,
                    file_manifest=file_manifest,
                    output_dir=settings.STORAGE_ROOT,
                    on_progress=_on_progress,
                )
                context = await orchestrator.run()

                # Save output rows to database
                for i, draft in enumerate(context.draft_outputs):
                    db.add(_build_output_row(instance_id, i, draft, source_lines, context))

                # Update instance status and token usage
                instance.status = LocaleStatus.complete
                instance.completed_at = datetime.now(timezone.utc)
                instance.token_usage = context.total_input_tokens + context.total_output_tokens
                instance.estimated_cost = round(context.total_estimated_cost, 6)
                output_path = (
                    f"{settings.STORAGE_ROOT}/jobs/{job_id}/output/"
                    f"{locale_code}_{job_id}_output.xlsx"
                )
                instance.output_file_path = output_path
                instance.agent_version = "1.0.0"
                await db.commit()
            except Exception as e:
                logger.error(
                    f"Error processing {locale_code} for job {job_id}: {e}",
                    exc_info=True,
                )
                # Discard anything pending (e.g. output rows added before the
                # failure) and clear a failed-flush state before writing the
                # error status; refresh reloads the now-expired instance.
                await db.rollback()
                await db.refresh(instance)
                instance.status = LocaleStatus.error
                instance.error_log = str(e)
                instance.completed_at = datetime.now(timezone.utc)
                await db.commit()
                # Check if job should be marked as partial/error
                await _check_job_completion(db, job_id)
                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "error",
                    "error": str(e),
                }

            # Outside the try on purpose: a failure while rolling up the job
            # status must not relabel this already-committed locale as error.
            await _check_job_completion(db, job_id)
            return {
                "job_id": job_id,
                "locale_code": locale_code,
                "status": "complete",
                "output_rows": len(context.draft_outputs),
            }

    return asyncio.run(_process())
async def _check_job_completion(db: AsyncSession, job_id: str) -> None:
    """Roll locale-instance outcomes up into the parent job's status.

    Does nothing if the job has no locale instances, if any instance is still
    in a non-terminal state (anything other than ``complete``/``error``), or
    if the job row no longer exists. Once every instance is terminal the job
    becomes:

    - ``JobStatus.complete`` if every instance completed,
    - ``JobStatus.partial_complete`` if some completed and some errored,
    - ``JobStatus.error`` if every instance errored.

    In all three terminal cases the job's token usage and estimated cost are
    summed over its instances, so partially completed or failed jobs still
    report what they consumed. Missing (``None``) per-instance values count
    as zero. The session is committed after the update.
    """
    result = await db.execute(
        select(LocaleInstance).where(LocaleInstance.job_id == job_id)
    )
    instances = list(result.scalars().all())
    if not instances:
        return

    terminal = (LocaleStatus.complete, LocaleStatus.error)
    if not all(i.status in terminal for i in instances):
        # Some locales are still in flight; each locale task calls this again
        # when it finishes.
        return

    job_result = await db.execute(select(Job).where(Job.id == job_id))
    job = job_result.scalar_one_or_none()
    if job is None:
        return

    completed = sum(1 for i in instances if i.status == LocaleStatus.complete)
    if completed == len(instances):
        job.status = JobStatus.complete
    elif completed:
        job.status = JobStatus.partial_complete
    else:
        job.status = JobStatus.error

    # `or 0` (an int) tolerates NULL columns and adds cleanly to int, float
    # or Decimal values alike.
    job.total_token_usage = sum(i.token_usage or 0 for i in instances)
    job.total_estimated_cost = sum(i.estimated_cost or 0 for i in instances)
    await db.commit()