- Wire token usage from LLM agents through pipeline context to DB and frontend
- Agents 2 and 4 accumulate input/output tokens and cost into PipelineContext
- job_tasks.py saves token totals to locale instance after pipeline completion
- Monitoring cards show total tokens and estimated cost instead of broken 0/0
- Make feedback highlighting bolder: colored card borders, stronger button states
- Add estimated cost display to dashboard job cards
- Add Help page with full documentation and link in sidebar navigation
- Comprehensive README with ASCII architecture diagrams

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
369 lines
14 KiB
Python
369 lines
14 KiB
Python
"""Celery tasks for job processing."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.config import settings
|
|
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
|
|
from app.models.output import ConfidenceTier, OutputRow
|
|
from app.models.source import SourceLine
|
|
from app.pipeline.orchestrator import PipelineOrchestrator
|
|
from app.tasks.celery_app import celery_app
|
|
|
|
logger = logging.getLogger(__name__)


# Translation-memory lookup table. Keys are lower-case channel names; values
# are TM filename templates in which "{lc}" is substituted with the
# lower-cased locale code when the manifest for a locale is resolved.
TM_CHANNEL_REGISTRY: dict[str, str] = dict(
    mass="flat_MASS_{lc}.json",
    value="flat_value_{lc}.json",
    onsite="flat_Onsite_{lc}.json",
    outbound="flat_Outbound_{lc}.json",
)
|
|
|
|
|
|
def _resolve_file_manifest(
    locale_code: str, channel: str, client_id: str
) -> dict:
    """Build the manifest of TM and reference files that exist for a locale.

    Each candidate path is probed on disk. A reference file that is missing
    is reported as ``None``; a missing (or unregistered) TM file simply leaves
    ``tm_files`` empty.

    Candidate locations under ``settings.STORAGE_ROOT``:
    - amazon/tm/{locale_code}/{tm_file}
    - amazon/ref/glossary/{lc_norm}_glossary.json
    - amazon/ref/blacklist/{lc_norm}_blacklist.json
    - amazon/ref/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json
    - amazon/ref/tov_supplement/DE_AT_TOV_Guidelines.json (de-DE, de-AT only)
    - amazon/ref/locale_considerations/{lc_norm}_local_considerations.json
    - amazon/ref/date_pct_formats/{lc_norm}_date_percent_formats.json

    Args:
        locale_code: Locale code such as ``de-DE``.
        channel: Channel name, matched case-insensitively against
            ``TM_CHANNEL_REGISTRY``; may be empty.
        client_id: Accepted for the caller's convenience; not consulted here.

    Returns:
        Dict with ``tm_files`` (list of existing TM paths) and one
        ``*_file`` entry per reference file (path string or ``None``).
    """
    import os

    root = settings.STORAGE_ROOT
    ref_root = f"{root}/amazon/ref"

    # Reference files use underscores (de-DE -> de_DE); TM filenames use the
    # lower-cased code (de-DE -> de-de).
    ref_locale = locale_code.replace("-", "_")
    tm_locale = locale_code.lower()

    def _existing(candidate: str) -> str | None:
        if os.path.exists(candidate):
            return candidate
        return None

    # Translation memory: at most one file, chosen by channel.
    tm_files: list[str] = []
    template = TM_CHANNEL_REGISTRY.get(channel.lower() if channel else "")
    if template:
        tm_candidate = (
            f"{root}/amazon/tm/{locale_code}/" + template.replace("{lc}", tm_locale)
        )
        if os.path.exists(tm_candidate):
            tm_files.append(tm_candidate)

    # The TOV supplement only applies to the two German locales.
    if locale_code in ("de-DE", "de-AT"):
        supplement = _existing(f"{ref_root}/tov_supplement/DE_AT_TOV_Guidelines.json")
    else:
        supplement = None

    manifest: dict = {"tm_files": tm_files}
    manifest["glossary_file"] = _existing(
        f"{ref_root}/glossary/{ref_locale}_glossary.json"
    )
    manifest["blacklist_file"] = _existing(
        f"{ref_root}/blacklist/{ref_locale}_blacklist.json"
    )
    manifest["tov_global_file"] = _existing(
        f"{ref_root}/tov_global/Amazon_TOV_Guidelines_for_Transcreation_290224.json"
    )
    manifest["tov_supplement_file"] = supplement
    manifest["locale_considerations_file"] = _existing(
        f"{ref_root}/locale_considerations/{ref_locale}_local_considerations.json"
    )
    manifest["date_pct_formats_file"] = _existing(
        f"{ref_root}/date_pct_formats/{ref_locale}_date_percent_formats.json"
    )
    return manifest
|
|
|
|
|
|
def _get_async_session_factory() -> async_sessionmaker[AsyncSession]:
    """Create a fresh async session factory for use in Celery tasks.

    Every task drives its coroutine with its own ``asyncio.run`` call, so a
    new engine is created per invocation. Callers only receive the session
    factory and never dispose the engine, so connection pooling is disabled
    (``NullPool``): a connection is closed as soon as the session releases it,
    rather than being parked in a pool that outlives the task's event loop.

    Returns:
        An ``async_sessionmaker`` bound to the new engine, producing
        ``AsyncSession`` objects with ``expire_on_commit=False``.
    """
    # Local import keeps this block self-contained.
    from sqlalchemy.pool import NullPool

    engine = create_async_engine(
        settings.DATABASE_URL,
        pool_pre_ping=True,
        poolclass=NullPool,
    )
    return async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=2)
def process_job(self, job_id: str) -> dict:
    """Mark a job as running and queue one task per locale instance.

    Returns a dict with ``job_id``, ``dispatched_locales`` and ``task_ids``,
    or a dict holding only ``error`` when the job does not exist.
    """
    logger.info(f"Processing job {job_id}")

    async def _fan_out() -> dict:
        session_factory = _get_async_session_factory()
        async with session_factory() as session:
            job_query = await session.execute(select(Job).where(Job.id == job_id))
            job = job_query.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Persist the status change before any locale task is queued.
            job.status = JobStatus.running
            await session.commit()

            locale_query = await session.execute(
                select(LocaleInstance).where(LocaleInstance.job_id == job_id)
            )
            locales = locale_query.scalars().all()

            # One Celery task per locale; keep the ids for the caller.
            task_ids = [
                process_locale_instance.delay(job_id, locale.locale_code).id
                for locale in locales
            ]

            summary = {"job_id": job_id}
            summary["dispatched_locales"] = len(task_ids)
            summary["task_ids"] = task_ids
            return summary

    return asyncio.run(_fan_out())
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=1)
def process_locale_instance(self, job_id: str, locale_code: str) -> dict:
    """Process a single locale instance through the pipeline.

    Loads the locale instance and its job, marks the instance as running,
    runs ``PipelineOrchestrator`` over the job's source lines, stores one
    ``OutputRow`` per draft output and records token usage and estimated
    cost on the instance. Any exception raised after the instance has been
    marked as running is logged and recorded on the instance as an error.

    Args:
        job_id: Identifier of the job that owns the locale instance.
        locale_code: Locale code of the instance to process.

    Returns:
        On success: ``job_id``, ``locale_code``, ``status="complete"`` and
        ``output_rows``. On pipeline failure: ``job_id``, ``locale_code``,
        ``status="error"`` and ``error``. If the instance or job cannot be
        found: a dict containing only ``error``.
    """
    logger.info(f"Processing locale {locale_code} for job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get the locale instance
            result = await db.execute(
                select(LocaleInstance).where(
                    LocaleInstance.job_id == job_id,
                    LocaleInstance.locale_code == locale_code,
                )
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return {"error": f"Locale instance not found: {job_id}/{locale_code}"}

            # Get job
            job_result = await db.execute(select(Job).where(Job.id == job_id))
            job = job_result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update instance status
            instance.status = LocaleStatus.running
            instance.started_at = datetime.now(timezone.utc)
            instance.progress = 0.0
            instance.current_stage = "Starting"
            await db.commit()

            # Stage name mapping for user-friendly display
            _STAGE_LABELS = {
                "VALIDATE": "Loading Files",
                "TM_RETRIEVE": "Matching TM",
                "RANK": "Ranking Matches",
                "TRANSCREATE": "Translating",
                "COMPLY": "Reviewing",
                "FORMAT": "Formatting Output",
                "DONE": "Complete",
                "ERROR": "Error",
            }

            async def _on_progress(state: str, message: str, pct: float) -> None:
                """Update locale instance progress in DB at each pipeline stage."""
                try:
                    # Use batch-level info from message if available
                    if "batch" in message.lower():
                        label = message  # e.g. "Translating batch 2/4"
                    else:
                        label = _STAGE_LABELS.get(state, state)
                    # pct is a fraction; store a percentage clamped to 0..100.
                    instance.progress = max(0.0, min(100.0, pct * 100))
                    instance.current_stage = label
                    await db.commit()
                except Exception as exc:
                    # Progress reporting is best-effort: never abort the pipeline.
                    logger.warning("Progress update failed: %s", exc)

            try:
                # Get source lines
                source_result = await db.execute(
                    select(SourceLine)
                    .where(SourceLine.job_id == job_id)
                    .order_by(SourceLine.row_order)
                )
                source_lines = [
                    {
                        "id": str(sl.id),
                        "en_gb": sl.en_gb,
                        "copy_type": sl.copy_type,
                        "creative_guidance": sl.creative_guidance,
                        "visual_ref": sl.visual_ref,
                        "char_limit": sl.char_limit,
                        "is_display_format": sl.is_display_format,
                    }
                    for sl in source_result.scalars().all()
                ]

                # Build job params
                job_params = {
                    "job_id": str(job.id),
                    "client_id": str(job.client_id),
                    "locale_code": locale_code,
                    "channel": job.channel,
                    "sub_channel": job.sub_channel,
                    "programme": job.programme.value,
                    "campaign_name": job.campaign_name,
                    "context_prompt": job.context_prompt,
                }

                # Resolve file manifest for this locale
                file_manifest = _resolve_file_manifest(
                    locale_code, job.channel, str(job.client_id)
                )

                # Run pipeline with progress callback
                orchestrator = PipelineOrchestrator(
                    job_params=job_params,
                    source_lines=source_lines,
                    file_manifest=file_manifest,
                    output_dir=settings.STORAGE_ROOT,
                    on_progress=_on_progress,
                )
                context = await orchestrator.run()

                # Save output rows to database. Drafts, source lines, ranking
                # declarations and compliance results are matched by position.
                for i, draft in enumerate(context.draft_outputs):
                    # Find matching source line
                    source_line_id = None
                    if i < len(source_lines):
                        source_line_id = source_lines[i].get("id")

                    # Get confidence tier ("low" when no ranking is available)
                    tier = "low"
                    if i < len(context.ranking_declarations):
                        tier = context.ranking_declarations[i].confidence_tier

                    # Get character counts; tolerate a missing/None mapping
                    char_counts = {}
                    if i < len(context.compliance_results):
                        char_counts = context.compliance_results[i].character_counts or {}

                    output_row = OutputRow(
                        instance_id=instance.id,
                        line_id=source_line_id,
                        row_order=i + 1,
                        confidence_tier=ConfidenceTier(tier),
                        option_1=draft.option_1.text if draft.option_1 else "",
                        backtranslation_1=draft.option_1.backtranslation if draft.option_1 else "",
                        rationale_1=draft.option_1.rationale if draft.option_1 else "",
                        option_2=draft.option_2.text if draft.option_2 else None,
                        backtranslation_2=draft.option_2.backtranslation if draft.option_2 else None,
                        rationale_2=draft.option_2.rationale if draft.option_2 else None,
                        option_3=draft.option_3.text if draft.option_3 else None,
                        backtranslation_3=draft.option_3.backtranslation if draft.option_3 else None,
                        rationale_3=draft.option_3.rationale if draft.option_3 else None,
                        tm_entries_cited=draft.tm_entries_cited if draft.tm_entries_cited else None,
                        character_count_option_1=char_counts.get("option_1"),
                        character_count_option_2=char_counts.get("option_2"),
                        character_count_option_3=char_counts.get("option_3"),
                    )
                    db.add(output_row)

                # Update instance status and token usage
                instance.status = LocaleStatus.complete
                instance.completed_at = datetime.now(timezone.utc)
                instance.token_usage = context.total_input_tokens + context.total_output_tokens
                instance.estimated_cost = round(context.total_estimated_cost, 6)
                output_path = (
                    f"{settings.STORAGE_ROOT}/jobs/{job_id}/output/"
                    f"{locale_code}_{job_id}_output.xlsx"
                )
                instance.output_file_path = output_path
                instance.agent_version = "1.0.0"

                await db.commit()

                # Check if all instances are complete
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "complete",
                    "output_rows": len(context.draft_outputs),
                }

            except Exception as e:
                logger.error(
                    f"Error processing {locale_code} for job {job_id}: {e}",
                    exc_info=True,
                )
                # Discard whatever the failed attempt left in the session:
                # half-built output rows must not be committed together with
                # the error status, and a session whose flush/commit failed
                # refuses further work until it has been rolled back.
                await db.rollback()

                instance.status = LocaleStatus.error
                instance.error_log = str(e)
                instance.completed_at = datetime.now(timezone.utc)
                await db.commit()

                # Check if job should be marked as partial/error
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "error",
                    "error": str(e),
                }

    return asyncio.run(_process())
|
|
|
|
|
|
async def _check_job_completion(db: AsyncSession, job_id: str) -> None:
    """Check if all locale instances are done and update job status accordingly.

    Nothing is changed while any instance is still neither complete nor in
    error. Once every instance has finished, the job becomes:

    - ``complete`` when every instance completed,
    - ``partial_complete`` when some completed and some failed,
    - ``error`` when none completed.

    Token usage and estimated cost of the completed instances are summed onto
    the job whenever at least one instance completed, so a partially complete
    job still reports what its successful locales consumed.

    Args:
        db: Session used to read the instances and persist the job update.
        job_id: Identifier of the job to evaluate.
    """
    result = await db.execute(
        select(LocaleInstance).where(LocaleInstance.job_id == job_id)
    )
    instances = list(result.scalars().all())

    if not instances:
        return

    completed = [i for i in instances if i.status == LocaleStatus.complete]
    all_done = all(
        i.status in (LocaleStatus.complete, LocaleStatus.error)
        for i in instances
    )

    job_result = await db.execute(select(Job).where(Job.id == job_id))
    job = job_result.scalar_one_or_none()
    if job is None:
        return

    if all_done:
        if len(completed) == len(instances):
            job.status = JobStatus.complete
        elif completed:
            # Every instance is finished and at least one of them failed.
            job.status = JobStatus.partial_complete
        else:
            job.status = JobStatus.error

        if completed:
            # Sum up token usage; unset values count as zero.
            job.total_token_usage = sum(i.token_usage or 0 for i in completed)
            job.total_estimated_cost = sum(i.estimated_cost or 0 for i in completed)

    await db.commit()
|