feat: complete Phase 1-2 scaffold — backend, frontend, pipeline skeleton

Full-stack Amazon AI Transcreation Platform with:
- FastAPI backend (async, PostgreSQL, Redis, Celery) with 12 DB tables
- JWT auth (SSO-ready abstract provider pattern)
- 6-agent pipeline orchestrator with deterministic modules
- Next.js 14 frontend with Amazon branding (Ember fonts, orange/dark theme)
- Job wizard, monitoring HUD, output review, admin screens
- 154 TM/reference files imported, 12 locales configured
- Docker Compose for all services

Agents 2-5 (TM retrieval, ranker, transcreator, compliance) are stubs
pending Phase 3 LLM integration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
DJP 2026-04-10 12:31:43 -04:00
parent e3c3dccfe9
commit 98fa16bfc3
180 changed files with 21920 additions and 39 deletions

19
.env.example Normal file

@ -0,0 +1,19 @@
# Database
DATABASE_URL=postgresql+asyncpg://transcreation:transcreation@db:5432/transcreation
# Redis
REDIS_URL=redis://redis:6379/0
# Anthropic
ANTHROPIC_API_KEY=sk-ant-REPLACE_ME
# Auth
JWT_SECRET_KEY=CHANGE_ME_TO_A_RANDOM_SECRET
JWT_ALGORITHM=HS256
JWT_EXPIRY_HOURS=8
# Storage
STORAGE_ROOT=/storage
# LLM
LLM_MODEL=claude-sonnet-4-20250514

75
.gitignore vendored

@ -1,50 +1,47 @@
# These are some examples of commonly ignored file patterns.
# You should customize this list as applicable to your project.
# Learn more about .gitignore:
# https://www.atlassian.com/git/tutorials/saving-changes/gitignore
# Node artifact files
node_modules/
dist/
# Compiled Java class files
*.class
# Compiled Python bytecode
# Python
__pycache__/
*.py[cod]
# Log files
*.log
# Package files
*.jar
# Maven
target/
*$py.class
*.so
*.egg-info/
dist/
build/
.eggs/
venv/
.venv/
# JetBrains IDE
# Environment
.env
.env.local
.env.*.local
# IDE
.vscode/
.idea/
*.swp
*.swo
# Unit test reports
TEST*.xml
# Node
node_modules/
.next/
out/
# Generated by MacOS
# Storage (keep structure, ignore uploaded files)
storage/*
!storage/.gitkeep
# OS
.DS_Store
# Generated by Windows
Thumbs.db
# Applications
*.app
*.exe
*.war
# Docker
docker-compose.override.yml
# Large media files
*.mp4
*.tiff
*.avi
*.flv
*.mov
*.wmv
# Test
.coverage
htmlcov/
.pytest_cache/
# Misc
*.log
*.bak

35
Makefile Normal file

@ -0,0 +1,35 @@
.PHONY: up down build migrate seed test shell logs restart db-shell redis-cli
up:
	docker compose up -d
down:
	docker compose down
build:
	docker compose build
migrate:
	docker compose exec backend alembic upgrade head
seed:
	docker compose exec backend python -m seed.create_default_client
	docker compose exec backend python -m seed.create_test_users
test:
	docker compose exec backend python -m pytest tests/ -v
shell:
	docker compose exec backend bash
logs:
	docker compose logs -f
restart:
	docker compose restart backend celery_worker
db-shell:
	docker compose exec db psql -U transcreation
redis-cli:
	docker compose exec redis redis-cli

179
README.md Normal file

@ -0,0 +1,179 @@
# Amazon AI Transcreation Platform
An AI-powered platform for transcreating marketing content across European locales. Built for Amazon's creative workflows, it combines Claude LLM capabilities with translation memory, glossaries, and brand voice profiles to produce culturally adapted copy at scale.
## Tech Stack
| Layer | Technology |
|-------------|---------------------------------------------|
| Frontend | Next.js 14, React 18, Tailwind CSS, Radix UI |
| Backend | FastAPI, Python 3.12, SQLAlchemy 2 (async) |
| Database | PostgreSQL 16 |
| Cache/Queue | Redis 7, Celery |
| LLM | Anthropic Claude (via API) |
| Infra | Docker, Docker Compose |
## Prerequisites
- [Docker](https://docs.docker.com/get-docker/) and Docker Compose v2
- An [Anthropic API key](https://console.anthropic.com/) for Claude access
## Quick Start
```bash
# 1. Clone the repository
git clone <repo-url> amazon-transcreation
cd amazon-transcreation
# 2. Configure environment
cp .env.example .env
# Edit .env and set ANTHROPIC_API_KEY, JWT_SECRET_KEY
# 3. Start all services
docker compose up -d
# 4. Run database migrations
make migrate
# 5. Seed the database with the Amazon client and test users
make seed
# 6. Open the app
# Backend API docs: http://localhost:8000/docs
# Frontend: http://localhost:3000
```
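Step 2 asks for a real `JWT_SECRET_KEY`. One way to produce a value is sketched below using only the Python standard library. The helper name `generate_jwt_secret` and the 64-byte default are assumptions for illustration, not project requirements.

```python
import secrets


def generate_jwt_secret(num_bytes: int = 64) -> str:
    """Return a URL-safe random string usable as an HS256 signing secret.

    Hypothetical helper: the project does not ship this function.
    """
    # token_urlsafe reads num_bytes from the OS CSPRNG and Base64-encodes them.
    return secrets.token_urlsafe(num_bytes)


if __name__ == "__main__":
    # Paste the printed value into .env as JWT_SECRET_KEY=<value>
    print(generate_jwt_secret())
```

The output can be pasted into `.env` in place of the `CHANGE_ME_TO_A_RANDOM_SECRET` placeholder.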
## Architecture
```
+---------------------------------------------------------------+
| Frontend (Next.js) |
| Dashboard | Job Builder | Review UI | TM Manager |
+-------------------------------+-------------------------------+
| REST + WebSocket |
+-------------------------------+-------------------------------+
| Backend (FastAPI) |
| Auth | Jobs API | TM API | Files API | WebSocket |
+-------------------------------+-------------------------------+
| |
+--------------+------+ +-------+----------------------------+
| PostgreSQL | | Celery Workers |
| - clients | | - LLM transcreation tasks |
| - users | | - TM lookup & scoring |
| - jobs | | - File processing |
| - source_lines | +------------------------------------+
| - output_rows | |
| - feedback | +----------+-------------------------+
| - tm_file_registry| | Claude API (Anthropic) |
| - reference_files | | - Transcreation generation |
| - audit_logs | | - Quality scoring |
+---------------------+ +------------------------------------+
|
+---------------------+
| Storage (volume) |
| - TM files (JSONL) |
| - Reference files |
+---------------------+
```
## Development
### Backend
```bash
# Shell into the backend container
make shell
# Run tests
make test
# View logs
make logs
# Access the database
make db-shell
```
### Frontend
```bash
cd frontend
npm install
npm run dev
```
The frontend dev server runs on port 3000 and proxies API requests to the backend on port 8000.
### Running Tests
```bash
# Backend tests (inside Docker)
make test
# Frontend lint checks
cd frontend && npm run lint
```
## API Documentation
FastAPI auto-generates interactive API documentation:
- **Swagger UI**: [http://localhost:8000/docs](http://localhost:8000/docs)
- **ReDoc**: [http://localhost:8000/redoc](http://localhost:8000/redoc)
## Project Structure
```
amazon-transcreation/
├── backend/
│ ├── alembic/ # Database migrations
│ ├── app/
│ │ ├── api/v1/ # REST API routes
│ │ ├── models/ # SQLAlchemy models
│ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic
│ │ ├── tasks/ # Celery async tasks
│ │ ├── ws/ # WebSocket handlers
│ │ ├── config.py # Settings from environment
│ │ ├── dependencies.py # FastAPI dependency injection
│ │ └── main.py # Application factory
│ ├── Dockerfile
│ └── requirements.txt
├── frontend/
│ ├── public/ # Static assets, fonts
│ ├── src/ # Next.js pages and components
│ ├── Dockerfile
│ └── package.json
├── scripts/
│ ├── tm_format_migrator.py # Convert compact TM to multi-field JSONL
│ └── validate_tm_files.py # Audit TM files for format and coverage
├── seed/
│ ├── create_default_client.py # Seed Amazon client + voice profiles
│ └── import_reference_files.py # Import TMs and refs to storage/
├── storage/ # Mounted volume for TM and reference files
├── .env.example # Environment variable template
├── docker-compose.yml # Service orchestration
├── Makefile # Development shortcuts
└── README.md
```
## Key Concepts
- **Client**: An organization (e.g., Amazon) with voice profiles, supported locales, and channel configurations stored as JSON settings.
- **Voice Profile**: Brand personality (Retail, Prime, Brand) that guides tone and style during transcreation.
- **Channel**: A content distribution context (Mass TV, Onsite, etc.) with its own translation memory.
- **Translation Memory (TM)**: JSONL files of previously approved source/target pairs, used for context and consistency.
- **Reference Files**: Glossaries, blacklists, date/percent format guides, and locale-specific considerations that constrain the LLM output.
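As a rough illustration of the Translation Memory format, the sketch below reads a JSONL stream one record per line. The README does not specify the TM schema, so the `source` and `target` field names and the helper names are assumptions.

```python
import io
import json
from typing import Iterator, TextIO


def iter_tm_segments(handle: TextIO) -> Iterator[dict]:
    """Yield one dict per non-empty JSONL line, failing loudly on bad input."""
    for line_number, raw in enumerate(handle, start=1):
        raw = raw.strip()
        if not raw:
            continue  # tolerate blank lines between records
        try:
            record = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise ValueError(f"line {line_number}: invalid JSON ({exc.msg})") from exc
        if not isinstance(record, dict):
            raise ValueError(f"line {line_number}: expected a JSON object")
        yield record


def count_segments(handle: TextIO) -> int:
    """Count the records in a TM stream."""
    return sum(1 for _ in iter_tm_segments(handle))


# Hypothetical sample data; real TM files may carry different fields.
SAMPLE_TM = (
    '{"source": "Free delivery", "target": "Livraison gratuite"}\n'
    "\n"
    '{"source": "Shop now", "target": "Acheter maintenant"}\n'
)

if __name__ == "__main__":
    print(count_segments(io.StringIO(SAMPLE_TM)))
```

Reading line by line keeps memory use flat for large TM files, which is the usual reason to prefer JSONL over a single JSON array.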
## Environment Variables
| Variable | Description | Default |
|---------------------|--------------------------------------|------------------------------|
| `DATABASE_URL` | PostgreSQL async connection string | `postgresql+asyncpg://...` |
| `REDIS_URL` | Redis connection string | `redis://redis:6379/0` |
| `ANTHROPIC_API_KEY` | Anthropic API key for Claude | (required) |
| `JWT_SECRET_KEY` | Secret for signing JWT tokens | (required, change default) |
| `JWT_ALGORITHM` | JWT signing algorithm | `HS256` |
| `JWT_EXPIRY_HOURS` | Token expiry in hours | `8` |
| `STORAGE_ROOT` | Root path for file storage | `/storage` |
| `LLM_MODEL` | Claude model identifier | `claude-sonnet-4-20250514` |
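The table above can be read as a small loading routine: required values fail fast, optional ones fall back to the listed defaults. The sketch below is not the project's `app/config.py`; the `Settings` dataclass and `load_settings` function are hypothetical, and `DATABASE_URL` is treated as required here because its default is abbreviated in the table.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

# Placeholder values from .env.example that must not reach a running service.
PLACEHOLDERS = {"sk-ant-REPLACE_ME", "CHANGE_ME_TO_A_RANDOM_SECRET"}


@dataclass(frozen=True)
class Settings:
    database_url: str
    redis_url: str
    anthropic_api_key: str
    jwt_secret_key: str
    jwt_algorithm: str
    jwt_expiry_hours: int
    storage_root: str
    llm_model: str


def load_settings(env: Optional[Mapping[str, str]] = None) -> Settings:
    """Build Settings from a mapping (defaults to os.environ)."""
    source = os.environ if env is None else env

    def required(name: str) -> str:
        value = source.get(name, "")
        if not value or value in PLACEHOLDERS:
            raise RuntimeError(f"{name} must be set to a real value")
        return value

    return Settings(
        database_url=required("DATABASE_URL"),
        redis_url=source.get("REDIS_URL", "redis://redis:6379/0"),
        anthropic_api_key=required("ANTHROPIC_API_KEY"),
        jwt_secret_key=required("JWT_SECRET_KEY"),
        jwt_algorithm=source.get("JWT_ALGORITHM", "HS256"),
        jwt_expiry_hours=int(source.get("JWT_EXPIRY_HOURS", "8")),
        storage_root=source.get("STORAGE_ROOT", "/storage"),
        llm_model=source.get("LLM_MODEL", "claude-sonnet-4-20250514"),
    )
```

Rejecting the `.env.example` placeholders at startup is a design choice in this sketch; it turns a forgotten edit into an immediate error rather than a failed API call later.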

19
backend/Dockerfile Normal file

@ -0,0 +1,19 @@
FROM python:3.12-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends gcc libpq-dev && \
rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

40
backend/alembic.ini Normal file

@ -0,0 +1,40 @@
[alembic]
script_location = alembic
prepend_sys_path = .
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

71
backend/alembic/env.py Normal file

@ -0,0 +1,71 @@
import asyncio
from logging.config import fileConfig

from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config

from app.config import settings
# Import all models so Alembic can discover them
from app.models import Base  # noqa: F401

# Alembic Config object
config = context.config

# Set the SQLAlchemy URL from settings
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)

# Logging configuration
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# MetaData for autogenerate support
target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode."""
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Run migrations in 'online' mode with async engine."""
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}

@ -0,0 +1,219 @@
"""initial_schema

Revision ID: d4a016fd0817
Revises:
Create Date: 2026-04-10 16:25:45.318282
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = 'd4a016fd0817'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('clients',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('settings', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('users',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('email', sa.String(length=255), nullable=False),
    sa.Column('name', sa.String(length=255), nullable=False),
    sa.Column('password_hash', sa.String(length=255), nullable=False),
    sa.Column('role', sa.Enum('admin', 'tm_manager', 'reviewer', name='user_role', create_constraint=True), nullable=False),
    sa.Column('status', sa.Enum('active', 'inactive', name='user_status', create_constraint=True), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('email')
    )
    op.create_table('audit_logs',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('user_id', sa.Uuid(), nullable=True),
    sa.Column('action', sa.String(length=100), nullable=False),
    sa.Column('entity_type', sa.String(length=100), nullable=False),
    sa.Column('entity_id', sa.String(length=255), nullable=False),
    sa.Column('details', sa.JSON(), nullable=True),
    sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('ip_address', sa.String(length=45), nullable=True),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('jobs',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('client_id', sa.Uuid(), nullable=False),
    sa.Column('created_by', sa.Uuid(), nullable=False),
    sa.Column('job_ref', sa.String(length=100), nullable=True),
    sa.Column('campaign_name', sa.String(length=255), nullable=False),
    sa.Column('programme', sa.Enum('retail', 'prime', 'brand', name='programme_type', create_constraint=True), nullable=False),
    sa.Column('channel', sa.String(length=100), nullable=False),
    sa.Column('sub_channel', sa.String(length=100), nullable=True),
    sa.Column('context_prompt', sa.Text(), nullable=True),
    sa.Column('job_type', sa.Enum('main', 'derived', name='job_type', create_constraint=True), nullable=False),
    sa.Column('parent_job_id', sa.Uuid(), nullable=True),
    sa.Column('status', sa.Enum('created', 'validating', 'queued', 'running', 'partial_complete', 'complete', 'error', 'exported', name='job_status', create_constraint=True), nullable=False),
    sa.Column('total_token_usage', sa.Integer(), nullable=False),
    sa.Column('total_estimated_cost', sa.Float(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
    sa.ForeignKeyConstraint(['parent_job_id'], ['jobs.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('reference_files',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('client_id', sa.Uuid(), nullable=False),
    sa.Column('file_type', sa.Enum('glossary', 'blacklist', 'tov_global', 'tov_supplement', 'locale_considerations', 'date_pct_formats', name='reference_file_type', create_constraint=True), nullable=False),
    sa.Column('locale_scope', sa.String(length=10), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('uploaded_by', sa.Uuid(), nullable=True),
    sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('version', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['uploaded_by'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('tm_file_registry',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('client_id', sa.Uuid(), nullable=False),
    sa.Column('locale_code', sa.String(length=10), nullable=False),
    sa.Column('channel', sa.String(length=100), nullable=False),
    sa.Column('filename', sa.String(length=255), nullable=False),
    sa.Column('file_path', sa.String(length=500), nullable=False),
    sa.Column('segment_count', sa.Integer(), nullable=False),
    sa.Column('uploaded_by', sa.Uuid(), nullable=True),
    sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('version', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['uploaded_by'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('user_clients',
    sa.Column('user_id', sa.Uuid(), nullable=False),
    sa.Column('client_id', sa.Uuid(), nullable=False),
    sa.Column('role_override', sa.String(length=50), nullable=True),
    sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('user_id', 'client_id')
    )
    op.create_table('locale_instances',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('job_id', sa.Uuid(), nullable=False),
    sa.Column('locale_code', sa.String(length=10), nullable=False),
    sa.Column('locale_type', sa.Enum('main', 'derived', name='locale_type', create_constraint=True), nullable=False),
    sa.Column('status', sa.Enum('queued', 'running', 'complete', 'error', name='locale_status', create_constraint=True), nullable=False),
    sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
    sa.Column('token_usage', sa.Integer(), nullable=False),
    sa.Column('estimated_cost', sa.Float(), nullable=False),
    sa.Column('output_file_path', sa.String(length=500), nullable=True),
    sa.Column('error_log', sa.Text(), nullable=True),
    sa.Column('tm_files_loaded', sa.JSON(), nullable=True),
    sa.Column('ref_files_loaded', sa.JSON(), nullable=True),
    sa.Column('agent_version', sa.String(length=50), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('source_lines',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('job_id', sa.Uuid(), nullable=False),
    sa.Column('row_order', sa.Integer(), nullable=False),
    sa.Column('en_gb', sa.Text(), nullable=False),
    sa.Column('copy_type', sa.String(length=100), nullable=True),
    sa.Column('creative_guidance', sa.Text(), nullable=True),
    sa.Column('visual_ref', sa.String(length=500), nullable=True),
    sa.Column('char_limit', sa.String(length=50), nullable=True),
    sa.Column('is_display_format', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('output_rows',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('instance_id', sa.Uuid(), nullable=False),
    sa.Column('line_id', sa.Uuid(), nullable=False),
    sa.Column('row_order', sa.Integer(), nullable=False),
    sa.Column('confidence_tier', sa.Enum('high', 'moderate', 'low', name='confidence_tier', create_constraint=True), nullable=False),
    sa.Column('option_1', sa.Text(), nullable=False),
    sa.Column('backtranslation_1', sa.Text(), nullable=False),
    sa.Column('rationale_1', sa.Text(), nullable=False),
    sa.Column('option_2', sa.Text(), nullable=True),
    sa.Column('backtranslation_2', sa.Text(), nullable=True),
    sa.Column('rationale_2', sa.Text(), nullable=True),
    sa.Column('option_3', sa.Text(), nullable=True),
    sa.Column('backtranslation_3', sa.Text(), nullable=True),
    sa.Column('rationale_3', sa.Text(), nullable=True),
    sa.Column('tm_entries_cited', sa.JSON(), nullable=True),
    sa.Column('winning_seg_key', sa.String(length=255), nullable=True),
    sa.Column('character_count_option_1', sa.Integer(), nullable=True),
    sa.Column('character_count_option_2', sa.Integer(), nullable=True),
    sa.Column('character_count_option_3', sa.Integer(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['instance_id'], ['locale_instances.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['line_id'], ['source_lines.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('token_usage_logs',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('instance_id', sa.Uuid(), nullable=False),
    sa.Column('agent_name', sa.String(length=100), nullable=False),
    sa.Column('model', sa.String(length=100), nullable=False),
    sa.Column('input_tokens', sa.Integer(), nullable=False),
    sa.Column('output_tokens', sa.Integer(), nullable=False),
    sa.Column('total_tokens', sa.Integer(), nullable=False),
    sa.Column('estimated_cost_usd', sa.Float(), nullable=False),
    sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['instance_id'], ['locale_instances.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('feedback',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('output_id', sa.Uuid(), nullable=False),
    sa.Column('user_id', sa.Uuid(), nullable=False),
    sa.Column('option_column', sa.Integer(), nullable=False),
    sa.Column('flag_type', sa.Enum('approved', 'needs_revision', 'comment', name='flag_type', create_constraint=True), nullable=False),
    sa.Column('comment', sa.Text(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['output_id'], ['output_rows.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('feedback')
    op.drop_table('token_usage_logs')
    op.drop_table('output_rows')
    op.drop_table('source_lines')
    op.drop_table('locale_instances')
    op.drop_table('user_clients')
    op.drop_table('tm_file_registry')
    op.drop_table('reference_files')
    op.drop_table('jobs')
    op.drop_table('audit_logs')
    op.drop_table('users')
    op.drop_table('clients')
    # ### end Alembic commands ###
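
Review note: `downgrade()` drops the tables in exactly the reverse of the order in which `upgrade()` creates them, so no table is dropped while another still references it. The sketch below shows how such an order can be derived and checked with the standard-library `graphlib`. The `FOREIGN_KEYS` mapping is transcribed from the constraints in this migration; the helper names are hypothetical and this is not how Alembic itself computes the order.

```python
from graphlib import TopologicalSorter

# table -> tables it references through a foreign key (from this migration)
FOREIGN_KEYS = {
    "clients": set(),
    "users": set(),
    "audit_logs": {"users"},
    "jobs": {"clients", "users", "jobs"},  # parent_job_id is self-referential
    "reference_files": {"clients", "users"},
    "tm_file_registry": {"clients", "users"},
    "user_clients": {"users", "clients"},
    "locale_instances": {"jobs"},
    "source_lines": {"jobs"},
    "output_rows": {"locale_instances", "source_lines"},
    "token_usage_logs": {"locale_instances"},
    "feedback": {"output_rows", "users"},
}

# Creation order used by upgrade() in this migration.
MIGRATION_ORDER = [
    "clients", "users", "audit_logs", "jobs", "reference_files",
    "tm_file_registry", "user_clients", "locale_instances",
    "source_lines", "output_rows", "token_usage_logs", "feedback",
]


def create_order(foreign_keys: dict) -> list:
    """Return one valid creation order: referenced tables come first."""
    # Self-references are removed; a table cannot precede itself.
    graph = {
        table: {parent for parent in parents if parent != table}
        for table, parents in foreign_keys.items()
    }
    return list(TopologicalSorter(graph).static_order())


def is_valid_create_order(order: list, foreign_keys: dict) -> bool:
    """Check that every table appears after all tables it references."""
    position = {table: index for index, table in enumerate(order)}
    return all(
        position[parent] < position[table]
        for table, parents in foreign_keys.items()
        for parent in parents
        if parent != table
    )


def drop_order(foreign_keys: dict) -> list:
    """A safe drop order is any valid creation order, reversed."""
    return list(reversed(create_order(foreign_keys)))
```

Several valid orders usually exist, so the derived order need not match the migration line for line; the check function is what confirms the migration's own order is safe.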

0
backend/app/__init__.py Normal file

@ -0,0 +1,60 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.dependencies import get_db, require_role
from app.services.audit_service import AuditService

router = APIRouter(prefix="/audit", tags=["audit"])
audit_service = AuditService()


@router.get("/logs")
async def list_audit_logs(
    user_id: UUID | None = Query(None),
    action: str | None = Query(None),
    entity_type: str | None = Query(None),
    entity_id: str | None = Query(None),
    date_from: datetime | None = Query(None),
    date_to: datetime | None = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> dict[str, Any]:
    """List audit logs with filters (admin only)."""
    logs, total = await audit_service.list_logs(
        db,
        user_id=user_id,
        action=action,
        entity_type=entity_type,
        entity_id=entity_id,
        date_from=date_from,
        date_to=date_to,
        page=page,
        page_size=page_size,
    )
    pages = (total + page_size - 1) // page_size if total > 0 else 1
    return {
        "items": [
            {
                "id": str(log.id),
                "user_id": str(log.user_id) if log.user_id else None,
                "action": log.action,
                "entity_type": log.entity_type,
                "entity_id": log.entity_id,
                "details": log.details,
                "timestamp": log.timestamp.isoformat(),
                "ip_address": log.ip_address,
            }
            for log in logs
        ],
        "total": total,
        "page": page,
        "page_size": page_size,
        "pages": pages,
    }
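
Review note: the `pages` value in this endpoint is an integer ceiling division with a floor of one page for empty results. A standalone sketch of that arithmetic follows. The helper name `page_count` is hypothetical and does not exist in this commit, and the `page_size` guard is an addition; the route relies on `Query(..., ge=1)` for that instead.

```python
def page_count(total: int, page_size: int) -> int:
    """Number of pages needed for `total` items; an empty result still has 1 page."""
    if page_size < 1:
        raise ValueError("page_size must be at least 1")
    if total <= 0:
        return 1
    # Adding page_size - 1 before floor division rounds up without using floats.
    return (total + page_size - 1) // page_size
```

For example, 51 items at 50 per page need two pages, while exactly 50 items need one.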

@ -0,0 +1,105 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.dependencies import get_db, require_role
from app.models.client import Client
from app.schemas.client import ClientCreate, ClientResponse, ClientUpdate
from app.schemas.common import PaginatedResponse

router = APIRouter(prefix="/clients", tags=["clients"])


@router.post(
    "",
    response_model=ClientResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_client(
    body: ClientCreate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> ClientResponse:
    """Create a new client (admin only)."""
    client = Client(name=body.name, settings=body.settings)
    db.add(client)
    await db.flush()
    return ClientResponse.model_validate(client)


@router.get("", response_model=PaginatedResponse[ClientResponse])
async def list_clients(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> PaginatedResponse[ClientResponse]:
    """List all clients (admin only)."""
    count_result = await db.execute(select(func.count(Client.id)))
    total = count_result.scalar() or 0
    result = await db.execute(
        select(Client)
        .order_by(Client.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    clients = [ClientResponse.model_validate(c) for c in result.scalars().all()]
    pages = (total + page_size - 1) // page_size if total > 0 else 1
    return PaginatedResponse(
        items=clients, total=total, page=page, page_size=page_size, pages=pages
    )


@router.get("/{client_id}", response_model=ClientResponse)
async def get_client(
    client_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> ClientResponse:
    """Get a client by ID (admin only)."""
    result = await db.execute(select(Client).where(Client.id == client_id))
    client = result.scalar_one_or_none()
    if client is None:
        raise HTTPException(status_code=404, detail="Client not found")
    return ClientResponse.model_validate(client)


@router.put("/{client_id}", response_model=ClientResponse)
async def update_client(
    client_id: UUID,
    body: ClientUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> ClientResponse:
    """Update a client (admin only)."""
    result = await db.execute(select(Client).where(Client.id == client_id))
    client = result.scalar_one_or_none()
    if client is None:
        raise HTTPException(status_code=404, detail="Client not found")
    update_data = body.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(client, field, value)
    await db.flush()
    return ClientResponse.model_validate(client)


@router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_client(
    client_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> None:
    """Delete a client (admin only)."""
    result = await db.execute(select(Client).where(Client.id == client_id))
    client = result.scalar_one_or_none()
    if client is None:
        raise HTTPException(status_code=404, detail="Client not found")
    await db.delete(client)
    await db.flush()

179
backend/app/api/v1/files.py Normal file

@ -0,0 +1,179 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, status
from fastapi.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.dependencies import get_current_user, get_db
from app.models.files import ReferenceFileType
from app.schemas.files import (
    FileUploadResponse,
    ReferenceFileResponse,
    TMFileResponse,
)
from app.services.file_service import FileService

router = APIRouter(prefix="/files", tags=["files"])
file_service = FileService()


# ---- TM Files ----
@router.post("/tm", response_model=FileUploadResponse, status_code=status.HTTP_201_CREATED)
async def upload_tm_file(
    client_id: UUID = Query(...),
    locale_code: str = Query(...),
    channel: str = Query(...),
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> FileUploadResponse:
    """Upload a Translation Memory (JSONL) file."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="File must have a filename")
    if not file_service.validate_file_extension(file.filename, [".jsonl", ".json"]):
        raise HTTPException(status_code=400, detail="Only .jsonl/.json files accepted")
    tm = await file_service.upload_tm_file(
        db, client_id, locale_code, channel, file.file, file.filename,
        uploaded_by=current_user["user_id"],
    )
    return FileUploadResponse(
        id=tm.id,
        filename=tm.filename,
        file_path=tm.file_path,
        message=f"Uploaded TM file with {tm.segment_count} segments",
    )


@router.get("/tm", response_model=list[TMFileResponse])
async def list_tm_files(
    client_id: UUID = Query(...),
    locale_code: str | None = Query(None),
    channel: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> list[TMFileResponse]:
    """List TM files for a client."""
    files = await file_service.list_tm_files(db, client_id, locale_code, channel)
    return [TMFileResponse.model_validate(f) for f in files]


@router.get("/tm/{file_id}/download")
async def download_tm_file(
    file_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> FileResponse:
    """Download a TM file."""
    from sqlalchemy import select
    from app.models.files import TMFileRegistry

    result = await db.execute(
        select(TMFileRegistry).where(TMFileRegistry.id == file_id)
    )
    tm = result.scalar_one_or_none()
    if tm is None:
        raise HTTPException(status_code=404, detail="TM file not found")
    path = file_service.get_file_path(tm.file_path)
    if path is None:
        raise HTTPException(status_code=404, detail="File not found on disk")
    return FileResponse(path=str(path), filename=tm.filename)


@router.delete("/tm/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_tm_file(
    file_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> None:
    """Delete a TM file."""
    deleted = await file_service.delete_tm_file(db, file_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="TM file not found")


# ---- Reference Files ----
@router.post(
    "/reference",
    response_model=FileUploadResponse,
    status_code=status.HTTP_201_CREATED,
)
async def upload_reference_file(
    client_id: UUID = Query(...),
    file_type: ReferenceFileType = Query(...),
    locale_scope: str = Query(...),
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> FileUploadResponse:
    """Upload a reference file (glossary, blacklist, TOV, etc.)."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="File must have a filename")
    ref = await file_service.upload_reference_file(
        db, client_id, file_type, locale_scope, file.file, file.filename,
        uploaded_by=current_user["user_id"],
    )
    return FileUploadResponse(
        id=ref.id,
        filename=ref.filename,
        file_path=ref.file_path,
        message=f"Uploaded {file_type.value} reference file",
    )


@router.get("/reference", response_model=list[ReferenceFileResponse])
async def list_reference_files(
    client_id: UUID = Query(...),
    file_type: ReferenceFileType | None = Query(None),
    locale_scope: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> list[ReferenceFileResponse]:
    """List reference files for a client."""
    files = await file_service.list_reference_files(
        db, client_id, file_type, locale_scope
    )
    return [ReferenceFileResponse.model_validate(f) for f in files]


@router.get("/reference/{file_id}/download")
async def download_reference_file(
    file_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> FileResponse:
    """Download a reference file."""
    from sqlalchemy import select
    from app.models.files import ReferenceFile

    result = await db.execute(
        select(ReferenceFile).where(ReferenceFile.id == file_id)
    )
    ref = result.scalar_one_or_none()
    if ref is None:
        raise HTTPException(status_code=404, detail="Reference file not found")
    path = file_service.get_file_path(ref.file_path)
    if path is None:
        raise HTTPException(status_code=404, detail="File not found on disk")
    return FileResponse(path=str(path), filename=ref.filename)


@router.delete("/reference/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_reference_file(
    file_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> None:
"""Delete a reference file."""
deleted = await file_service.delete_reference_file(db, file_id)
if not deleted:
raise HTTPException(status_code=404, detail="Reference file not found")
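The download handlers in this file return 404 when `file_service.get_file_path` yields `None`. The service itself is not shown in this part of the diff, so the following is only a minimal sketch of how such a lookup might resolve a stored relative path under the storage root and refuse paths that escape it. `resolve_stored_path` and its parameters are hypothetical names, not the project's API.

```python
from __future__ import annotations

from pathlib import Path


def resolve_stored_path(storage_root: Path, relative_path: str) -> Path | None:
    """Return the absolute path of a stored file, or None if it is unusable.

    Hypothetical helper: the real FileService.get_file_path may differ.
    """
    root = storage_root.resolve()
    candidate = (root / relative_path).resolve()
    # Reject anything that resolves outside the storage root (e.g. "../x").
    if not candidate.is_relative_to(root):
        return None
    # A registry row can outlive its file, so check the disk as well.
    return candidate if candidate.is_file() else None
```

Returning `None` for both the "outside the root" and "missing on disk" cases keeps the handler's single 404 branch sufficient, at the cost of not distinguishing the two in logs.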

backend/app/api/v1/jobs.py Normal file

@ -0,0 +1,206 @@
from datetime import datetime
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.dependencies import get_current_user, get_db
from app.schemas.common import PaginatedResponse
from app.schemas.job import (
    JobCreate,
    JobListResponse,
    JobResponse,
    JobUpdate,
    LocaleInstanceResponse,
)
from app.services.audit_service import AuditService
from app.services.file_service import FileService
from app.services.job_service import JobService

router = APIRouter(prefix="/jobs", tags=["jobs"])
job_service = JobService()
file_service = FileService()
audit_service = AuditService()


@router.post("", response_model=JobResponse, status_code=status.HTTP_201_CREATED)
async def create_job(
    body: JobCreate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> JobResponse:
    """Create a new transcreation job."""
    job = await job_service.create_job(db, body, current_user["user_id"])
    await audit_service.log(
        db, "create", "job", str(job.id), current_user["user_id"],
        details={"campaign_name": body.campaign_name},
    )
    return JobResponse.model_validate(job)


@router.get("", response_model=PaginatedResponse[JobListResponse])
async def list_jobs(
    client_id: UUID | None = Query(None),
    job_status: str | None = Query(None, alias="status"),
    locale: str | None = Query(None),
    date_from: datetime | None = Query(None),
    date_to: datetime | None = Query(None),
    search: str | None = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> PaginatedResponse[JobListResponse]:
    """List jobs with filters and pagination."""
    jobs, total = await job_service.list_jobs(
        db,
        client_id=client_id,
        status=job_status,
        locale=locale,
        date_from=date_from,
        date_to=date_to,
        search=search,
        page=page,
        page_size=page_size,
    )
    items = []
    for job in jobs:
        item = JobListResponse(
            id=job.id,
            client_id=job.client_id,
            created_by=job.created_by,
            job_ref=job.job_ref,
            campaign_name=job.campaign_name,
            programme=job.programme,
            channel=job.channel,
            status=job.status,
            total_token_usage=job.total_token_usage,
            total_estimated_cost=job.total_estimated_cost,
            locale_count=len(job.locale_instances) if job.locale_instances else 0,
            created_at=job.created_at,
            updated_at=job.updated_at,
        )
        items.append(item)
    # Ceiling division; an empty result set is still reported as one page.
    pages = (total + page_size - 1) // page_size if total > 0 else 1
    return PaginatedResponse(
        items=items, total=total, page=page, page_size=page_size, pages=pages
    )


@router.get("/{job_id}", response_model=JobResponse)
async def get_job(
    job_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> JobResponse:
    """Get job details with locale instances."""
    job = await job_service.get_job(db, job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return JobResponse.model_validate(job)


@router.put("/{job_id}/source")
async def upload_source(
    job_id: UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> dict:
    """Upload source xlsx file for a job."""
    # Compare case-insensitively so that e.g. "SOURCE.XLSX" is accepted.
    if not file.filename or not file.filename.lower().endswith(".xlsx"):
        raise HTTPException(status_code=400, detail="Only .xlsx files are accepted")
    job = await job_service.get_job(db, job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    source_lines = await file_service.upload_source_file(
        db, job_id, file.file, file.filename
    )
    await audit_service.log(
        db, "upload_source", "job", str(job_id), current_user["user_id"],
        details={"filename": file.filename, "line_count": len(source_lines)},
    )
    return {
        "message": f"Uploaded {len(source_lines)} source lines",
        "line_count": len(source_lines),
    }


@router.post("/{job_id}/supplementary")
async def upload_supplementary(
    job_id: UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> dict:
    """Upload a supplementary file for a job."""
    job = await job_service.get_job(db, job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    file_path = await file_service.upload_supplementary_file(
        db, job_id, file.file, file.filename or "unknown"
    )
    return {"message": "File uploaded", "file_path": file_path}


@router.post("/{job_id}/launch", response_model=JobResponse)
async def launch_job(
    job_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> JobResponse:
    """Validate and queue a job for processing."""
    try:
        job = await job_service.launch_job(db, job_id)
    except ValueError as e:
        # Validation failures from the service surface as 400s.
        raise HTTPException(status_code=400, detail=str(e)) from e
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    await audit_service.log(
        db, "launch", "job", str(job_id), current_user["user_id"],
    )
    return JobResponse.model_validate(job)


@router.post("/{job_id}/cancel", response_model=JobResponse)
async def cancel_job(
    job_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> JobResponse:
    """Cancel a running job."""
    try:
        job = await job_service.cancel_job(db, job_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    await audit_service.log(
        db, "cancel", "job", str(job_id), current_user["user_id"],
    )
    return JobResponse.model_validate(job)


@router.post(
    "/{job_id}/locales/{locale_code}/rerun",
    response_model=LocaleInstanceResponse,
)
async def rerun_locale(
    job_id: UUID,
    locale_code: str,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> LocaleInstanceResponse:
    """Re-run a single locale instance."""
    instance = await job_service.rerun_locale(db, job_id, locale_code)
    if instance is None:
        raise HTTPException(
            status_code=404,
            detail=f"Locale instance '{locale_code}' not found for job",
        )
    await audit_service.log(
        db, "rerun_locale", "locale_instance", str(instance.id),
        current_user["user_id"],
        details={"locale_code": locale_code},
    )
    return LocaleInstanceResponse.model_validate(instance)
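The page arithmetic used by `list_jobs` can be checked in isolation. The sketch below restates the ceiling division from the handler and pairs it with the usual one-based page offset. `page_count` and `page_offset` are illustrative names, and the offset formula is assumed to match what `JobService.list_jobs` does internally (that code is not shown here).

```python
def page_count(total: int, page_size: int) -> int:
    """Number of pages for `total` rows; an empty result still counts as one page."""
    if page_size < 1:
        raise ValueError("page_size must be at least 1")
    # Adding page_size - 1 before the floor division rounds up.
    return (total + page_size - 1) // page_size if total > 0 else 1


def page_offset(page: int, page_size: int) -> int:
    """Zero-based row offset for a one-based page number."""
    if page < 1:
        raise ValueError("page must be at least 1")
    return (page - 1) * page_size
```

For example, 21 rows at 20 per page gives two pages, and page 3 starts at row offset 40.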


@ -0,0 +1,84 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.dependencies import get_current_user, get_db
from app.schemas.feedback import FeedbackCreate, FeedbackResponse
from app.schemas.output import OutputPreviewResponse
from app.services.feedback_service import FeedbackService
from app.services.output_service import OutputService

router = APIRouter(prefix="/output", tags=["output"])
output_service = OutputService()
feedback_service = FeedbackService()


@router.get(
    "/jobs/{job_id}/locales/{locale_code}/preview",
    response_model=OutputPreviewResponse,
)
async def get_output_preview(
    job_id: UUID,
    locale_code: str,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> OutputPreviewResponse:
    """Get preview data for a locale instance output."""
    preview = await output_service.get_preview(db, job_id, locale_code)
    if preview is None:
        raise HTTPException(
            status_code=404,
            detail="No output found for this job/locale combination",
        )
    return preview


@router.post(
    "/feedback",
    response_model=FeedbackResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_feedback(
    body: FeedbackCreate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> FeedbackResponse:
    """Submit feedback on an output row."""
    feedback = await feedback_service.create_feedback(
        db, body, current_user["user_id"]
    )
    return FeedbackResponse.model_validate(feedback)


@router.get("/feedback/{output_id}", response_model=list[FeedbackResponse])
async def list_feedback_for_output(
    output_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> list[FeedbackResponse]:
    """List all feedback for a specific output row."""
    items = await feedback_service.list_feedback(db, output_id=output_id)
    return [FeedbackResponse.model_validate(f) for f in items]


@router.get("/jobs/{job_id}/locales/{locale_code}/export")
async def export_output(
    job_id: UUID,
    locale_code: str,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> FileResponse:
    """Export output as xlsx for a locale instance."""
    file_path = await output_service.trigger_export(db, job_id, locale_code)
    if file_path is None:
        raise HTTPException(
            status_code=404,
            detail="No export available for this job/locale combination",
        )
    return FileResponse(
        path=file_path,
        filename=f"{job_id}_{locale_code}_output.xlsx",
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )


@ -0,0 +1,50 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.dependencies import get_current_user, get_db
from app.services.report_service import ReportService

router = APIRouter(prefix="/reports", tags=["reports"])
report_service = ReportService()


@router.get("/usage")
async def get_usage_stats(
    client_id: UUID | None = Query(None),
    date_from: datetime | None = Query(None),
    date_to: datetime | None = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> dict[str, Any]:
    """Get usage statistics (total jobs, tokens, cost, status breakdown)."""
    return await report_service.get_usage_stats(
        db, client_id=client_id, date_from=date_from, date_to=date_to
    )


@router.get("/tokens")
async def get_token_cost_data(
    client_id: UUID | None = Query(None),
    date_from: datetime | None = Query(None),
    date_to: datetime | None = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> list[dict[str, Any]]:
    """Get token usage and cost data grouped by day."""
    return await report_service.get_token_cost_data(
        db, client_id=client_id, date_from=date_from, date_to=date_to
    )


@router.get("/quality")
async def get_quality_metrics(
    client_id: UUID | None = Query(None),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(get_current_user),
) -> dict[str, Any]:
    """Get quality metrics (confidence tier distribution, feedback stats)."""
    return await report_service.get_quality_metrics(db, client_id=client_id)


@ -0,0 +1,21 @@
from fastapi import APIRouter
from app.auth.router import router as auth_router
from app.api.v1.jobs import router as jobs_router
from app.api.v1.users import router as users_router
from app.api.v1.clients import router as clients_router
from app.api.v1.files import router as files_router
from app.api.v1.output import router as output_router
from app.api.v1.reports import router as reports_router
from app.api.v1.audit import router as audit_router
api_v1_router = APIRouter(prefix="/api/v1")
api_v1_router.include_router(auth_router)
api_v1_router.include_router(jobs_router)
api_v1_router.include_router(users_router)
api_v1_router.include_router(clients_router)
api_v1_router.include_router(files_router)
api_v1_router.include_router(output_router)
api_v1_router.include_router(reports_router)
api_v1_router.include_router(audit_router)

backend/app/api/v1/users.py Normal file

@ -0,0 +1,130 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.service import AuthService
from app.dependencies import get_db, require_role
from app.models.user import User, UserClient, UserStatus
from app.schemas.common import PaginatedResponse
from app.schemas.user import UserCreate, UserResponse, UserUpdate

router = APIRouter(prefix="/users", tags=["users"])
auth_service = AuthService()


@router.post(
    "",
    response_model=UserResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_user(
    body: UserCreate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> UserResponse:
    """Create a new user (admin only)."""
    # Check for duplicate email
    existing = await db.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status_code=400, detail="Email already registered")
    user = User(
        email=body.email,
        name=body.name,
        password_hash=auth_service.hash_password(body.password),
        role=body.role,
        status=UserStatus.active,
    )
    db.add(user)
    await db.flush()
    # Associate with clients
    for client_id in body.client_ids:
        uc = UserClient(user_id=user.id, client_id=client_id)
        db.add(uc)
    await db.flush()
    return UserResponse.model_validate(user)


@router.get("", response_model=PaginatedResponse[UserResponse])
async def list_users(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> PaginatedResponse[UserResponse]:
    """List all users (admin only)."""
    count_result = await db.execute(select(func.count(User.id)))
    total = count_result.scalar() or 0
    result = await db.execute(
        select(User)
        .order_by(User.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    users = [UserResponse.model_validate(u) for u in result.scalars().all()]
    pages = (total + page_size - 1) // page_size if total > 0 else 1
    return PaginatedResponse(
        items=users, total=total, page=page, page_size=page_size, pages=pages
    )


@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> UserResponse:
    """Get a user by ID (admin only)."""
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return UserResponse.model_validate(user)


@router.put("/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: UUID,
    body: UserUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> UserResponse:
    """Update a user (admin only)."""
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")
    update_data = body.model_dump(exclude_unset=True)
    if "password" in update_data:
        update_data["password_hash"] = auth_service.hash_password(
            update_data.pop("password")
        )
    for field, value in update_data.items():
        setattr(user, field, value)
    await db.flush()
    return UserResponse.model_validate(user)


@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
    user_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: dict = Depends(require_role(["admin"])),
) -> None:
    """Soft-delete a user by setting status to inactive (admin only)."""
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")
    user.status = UserStatus.inactive
    await db.flush()
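`update_user` applies only the fields the client actually sent and swaps a plaintext `password` for a `password_hash` before touching the model. A minimal sketch of that step on plain dictionaries follows. `apply_partial_update` is a hypothetical helper and the hasher is passed in as a stand-in, so nothing here depends on bcrypt or on the real `User` model.

```python
from __future__ import annotations

from typing import Any, Callable


def apply_partial_update(
    record: dict[str, Any],
    changes: dict[str, Any],
    hash_password: Callable[[str], str],
) -> dict[str, Any]:
    """Return a copy of `record` with only the supplied `changes` applied."""
    updated = dict(record)
    pending = dict(changes)
    # Never store the plaintext: replace it with its hash under a new key.
    if "password" in pending:
        pending["password_hash"] = hash_password(pending.pop("password"))
    updated.update(pending)
    return updated
```

Fields absent from `changes` keep their stored values, which is the effect `model_dump(exclude_unset=True)` has in the handler.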


@ -0,0 +1,33 @@
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.auth.service import AuthService

security = HTTPBearer()
auth_service = AuthService()


async def decode_jwt(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Decode JWT from Authorization header and return user claims."""
    token = credentials.credentials
    claims = auth_service.validate_token(token)
    if claims is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
        )
    if claims.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )
    return {
        "user_id": UUID(claims["sub"]),
        "email": claims.get("email", ""),
        "role": claims.get("role", ""),
        "name": claims.get("name", ""),
    }


@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Any


class AbstractAuthProvider(ABC):
    """Abstract base for authentication providers."""

    @abstractmethod
    async def authenticate(self, email: str, password: str) -> dict[str, Any] | None:
        """Authenticate a user. Returns user dict or None if invalid."""
        ...

    @abstractmethod
    def create_access_token(self, data: dict[str, Any]) -> str:
        """Create an access token for the given data."""
        ...

    @abstractmethod
    def create_refresh_token(self, data: dict[str, Any]) -> str:
        """Create a refresh token for the given data."""
        ...

    @abstractmethod
    def validate_token(self, token: str) -> dict[str, Any] | None:
        """Validate a token and return the claims, or None if invalid."""
        ...

    @abstractmethod
    def hash_password(self, password: str) -> str:
        """Hash a plaintext password."""
        ...

    @abstractmethod
    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Verify a plaintext password against a hash."""
        ...


@ -0,0 +1,52 @@
from datetime import datetime, timedelta, timezone
from typing import Any

import bcrypt
from jose import JWTError, jwt

from app.config import settings
from app.auth.providers.base import AbstractAuthProvider


class JWTAuthProvider(AbstractAuthProvider):
    """JWT-based authentication provider with bcrypt password hashing."""

    def __init__(self) -> None:
        self.secret_key = settings.JWT_SECRET_KEY
        self.algorithm = settings.JWT_ALGORITHM
        self.access_token_expiry = timedelta(hours=settings.JWT_EXPIRY_HOURS)
        self.refresh_token_expiry = timedelta(days=7)

    async def authenticate(self, email: str, password: str) -> dict[str, Any] | None:
        """Not used directly - authentication happens through AuthService."""
        return None

    def create_access_token(self, data: dict[str, Any]) -> str:
        """Create a JWT access token (8-hour expiry by default)."""
        to_encode = data.copy()
        expire = datetime.now(timezone.utc) + self.access_token_expiry
        to_encode.update({"exp": expire, "type": "access"})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def create_refresh_token(self, data: dict[str, Any]) -> str:
        """Create a JWT refresh token (7-day expiry)."""
        to_encode = data.copy()
        expire = datetime.now(timezone.utc) + self.refresh_token_expiry
        to_encode.update({"exp": expire, "type": "refresh"})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def validate_token(self, token: str) -> dict[str, Any] | None:
        """Validate a JWT and return its claims, or None if invalid."""
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
            return payload
        except JWTError:
            return None

    def hash_password(self, password: str) -> str:
        """Hash a password using bcrypt."""
        return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Verify a password against a bcrypt hash."""
        return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
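The provider distinguishes access and refresh tokens only by a `type` claim and by their expiry. The sketch below builds the two claim sets without signing anything, so it runs on the standard library alone and shows why a consumer has to check `type` as well as `exp`. `build_claims` and `is_usable_as` are illustrative names, and the 8-hour and 7-day lifetimes are taken from the defaults visible in this diff.

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any

# Lifetimes mirror the defaults in the provider (assumed unchanged by config).
EXPIRY_BY_TYPE = {
    "access": timedelta(hours=8),
    "refresh": timedelta(days=7),
}


def build_claims(
    data: dict[str, Any], token_type: str, now: datetime | None = None
) -> dict[str, Any]:
    """Copy `data` and add the `exp` and `type` claims for the given token type."""
    if token_type not in EXPIRY_BY_TYPE:
        raise ValueError(f"unknown token type: {token_type!r}")
    issued = now or datetime.now(timezone.utc)
    claims = dict(data)
    claims.update({"exp": issued + EXPIRY_BY_TYPE[token_type], "type": token_type})
    return claims


def is_usable_as(
    claims: dict[str, Any], expected_type: str, now: datetime | None = None
) -> bool:
    """True only if the claims are unexpired and of the expected type."""
    current = now or datetime.now(timezone.utc)
    return claims.get("type") == expected_type and claims["exp"] > current
```

Both kinds of token are signed with the same key, so a signature check alone cannot tell them apart. The `type` comparison is what keeps a long-lived refresh token from being accepted where an access token is expected.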


@ -0,0 +1,44 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.schemas import LoginRequest, RefreshRequest, TokenResponse, UserClaims
from app.auth.service import AuthService
from app.dependencies import get_current_user, get_db

router = APIRouter(prefix="/auth", tags=["auth"])
auth_service = AuthService()


@router.post("/login", response_model=TokenResponse)
async def login(
    body: LoginRequest,
    db: AsyncSession = Depends(get_db),
) -> TokenResponse:
    """Authenticate user and return access + refresh tokens."""
    result = await auth_service.login(body.email, body.password, db)
    if result is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )
    return TokenResponse(**result)


@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(body: RefreshRequest) -> TokenResponse:
    """Exchange a valid refresh token for a new token pair."""
    result = auth_service.refresh_tokens(body.refresh_token)
    if result is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )
    return TokenResponse(**result)


@router.get("/me", response_model=UserClaims)
async def get_me(
    current_user: dict = Depends(get_current_user),
) -> UserClaims:
    """Return the current authenticated user's claims."""
    return UserClaims(**current_user)


@ -0,0 +1,25 @@
from uuid import UUID

from pydantic import BaseModel, EmailStr


class LoginRequest(BaseModel):
    email: EmailStr
    password: str


class TokenResponse(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class RefreshRequest(BaseModel):
    refresh_token: str


class UserClaims(BaseModel):
    user_id: UUID
    email: str
    role: str
    name: str


@ -0,0 +1,70 @@
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.providers.jwt_provider import JWTAuthProvider
from app.models.user import User


class AuthService:
    """Authentication service wrapping the JWT provider."""

    def __init__(self) -> None:
        self.provider = JWTAuthProvider()

    async def login(
        self, email: str, password: str, db: AsyncSession
    ) -> dict[str, str] | None:
        """Authenticate a user and return tokens, or None if invalid."""
        result = await db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if user is None:
            return None
        if not self.provider.verify_password(password, user.password_hash):
            return None
        if user.status.value != "active":
            return None
        token_data = {
            "sub": str(user.id),
            "email": user.email,
            "role": user.role.value,
            "name": user.name,
        }
        return {
            "access_token": self.provider.create_access_token(token_data),
            "refresh_token": self.provider.create_refresh_token(token_data),
            "token_type": "bearer",
        }

    def refresh_tokens(self, refresh_token: str) -> dict[str, str] | None:
        """Validate a refresh token and issue a new token pair."""
        claims = self.provider.validate_token(refresh_token)
        if claims is None:
            return None
        if claims.get("type") != "refresh":
            return None
        token_data = {
            "sub": claims["sub"],
            "email": claims.get("email", ""),
            "role": claims.get("role", ""),
            "name": claims.get("name", ""),
        }
        return {
            "access_token": self.provider.create_access_token(token_data),
            "refresh_token": self.provider.create_refresh_token(token_data),
            "token_type": "bearer",
        }

    def validate_token(self, token: str) -> dict[str, Any] | None:
        """Validate a token and return claims."""
        return self.provider.validate_token(token)

    def hash_password(self, password: str) -> str:
        """Hash a password."""
        return self.provider.hash_password(password)

backend/app/config.py Normal file

@ -0,0 +1,23 @@
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    DATABASE_URL: str = "postgresql+asyncpg://transcreation:transcreation@db:5432/transcreation"
    REDIS_URL: str = "redis://redis:6379/0"
    ANTHROPIC_API_KEY: str = ""
    JWT_SECRET_KEY: str = "CHANGE_ME_TO_A_RANDOM_SECRET"
    JWT_ALGORITHM: str = "HS256"
    JWT_EXPIRY_HOURS: int = 8
    STORAGE_ROOT: str = "/storage"
    LLM_MODEL: str = "claude-sonnet-4-20250514"

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
    }


settings = Settings()


@ -0,0 +1,82 @@
from typing import AsyncGenerator, Callable
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.config import settings

security = HTTPBearer()

_engine = create_async_engine(
    settings.DATABASE_URL,
    echo=False,
    pool_size=20,
    max_overflow=10,
    pool_pre_ping=True,
)

_async_session_factory = async_sessionmaker(
    bind=_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session; commit on success, roll back on error."""
    async with _async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Decode JWT from Authorization header and return user claims."""
    token = credentials.credentials
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Could not validate credentials: {exc}",
        ) from exc

    # Refresh tokens are signed with the same key, so reject anything that is
    # not explicitly an access token.
    if payload.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: missing subject",
        )
    try:
        parsed_user_id = UUID(str(user_id))
    except ValueError as exc:
        # A malformed subject is a bad token, not a server error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: malformed subject",
        ) from exc

    return {
        "user_id": parsed_user_id,
        "email": payload.get("email", ""),
        "role": payload.get("role", ""),
        "name": payload.get("name", ""),
    }


def require_role(roles: list[str]) -> Callable:
    """Dependency factory that enforces role-based access."""

    async def _check_role(
        current_user: dict = Depends(get_current_user),
    ) -> dict:
        if current_user["role"] not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Role '{current_user['role']}' not permitted. Required: {roles}",
            )
        return current_user

    return _check_role
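`require_role` is a dependency factory: calling it returns a new checker that has captured the allowed roles in a closure. The same shape can be shown without FastAPI. In the sketch below, `make_role_check` and `Forbidden` are stand-in names for illustration, and the checker takes the user dict directly instead of receiving it through `Depends`.

```python
from __future__ import annotations

from typing import Any, Callable, Iterable


class Forbidden(Exception):
    """Stand-in for an HTTP 403 response in this sketch."""


def make_role_check(roles: Iterable[str]) -> Callable[[dict[str, Any]], dict[str, Any]]:
    """Return a checker that only lets users with one of `roles` through."""
    allowed = frozenset(roles)

    def check(current_user: dict[str, Any]) -> dict[str, Any]:
        # `allowed` is captured from the enclosing call, one set per checker.
        if current_user.get("role") not in allowed:
            raise Forbidden(f"role {current_user.get('role')!r} not permitted")
        return current_user

    return check
```

Each call such as `make_role_check(["admin"])` yields an independent checker, which is why the route signatures in `users.py` can each state their own role list inline.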


backend/app/llm/client.py Normal file

@ -0,0 +1,143 @@
"""Anthropic SDK wrapper with retry logic and token tracking."""

import asyncio
import logging
import time
from typing import Any

import anthropic

from app.config import settings

logger = logging.getLogger(__name__)

# Cost per token (approximate, varies by model)
COST_PER_INPUT_TOKEN = 3.0 / 1_000_000  # $3 per 1M input tokens
COST_PER_OUTPUT_TOKEN = 15.0 / 1_000_000  # $15 per 1M output tokens


class LLMClient:
    """Wrapper around the Anthropic SDK with retry and token tracking.

    Provides exponential backoff retry on rate limit and server errors.
    Tracks token usage per call for cost monitoring.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 3,
        base_delay: float = 1.0,
    ) -> None:
        self.api_key = api_key or settings.ANTHROPIC_API_KEY
        self.model = model or settings.LLM_MODEL
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.client = anthropic.Anthropic(api_key=self.api_key)
        self.last_usage: dict[str, Any] = {}

    def create_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Send a message to Claude and return the response with usage data.

        Args:
            system_prompt: The system prompt.
            user_message: The user message.
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            Tuple of (response_text, usage_dict).
            usage_dict has keys: input_tokens, output_tokens, total_tokens,
            estimated_cost_usd, model.

        Raises:
            anthropic.APIError: If all retries are exhausted.
            ValueError: If max_retries is less than 1, so no call was attempted.
        """
        last_error: Exception | None = None
        for attempt in range(1, self.max_retries + 1):
            try:
                response = self.client.messages.create(
                    model=self.model,
                    max_tokens=max_tokens,
                    system=system_prompt,
                    messages=[{"role": "user", "content": user_message}],
                    temperature=temperature,
                )
                # Extract text
                response_text = ""
                for block in response.content:
                    if hasattr(block, "text"):
                        response_text += block.text
                # Track usage
                input_tokens = response.usage.input_tokens
                output_tokens = response.usage.output_tokens
                total_tokens = input_tokens + output_tokens
                estimated_cost = (
                    input_tokens * COST_PER_INPUT_TOKEN
                    + output_tokens * COST_PER_OUTPUT_TOKEN
                )
                usage = {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": total_tokens,
                    "estimated_cost_usd": round(estimated_cost, 6),
                    "model": self.model,
                }
                self.last_usage = usage
                return response_text, usage
            except anthropic.RateLimitError as e:
                last_error = e
                delay = self.base_delay * (2 ** (attempt - 1))
                logger.warning(
                    f"Rate limited (attempt {attempt}/{self.max_retries}), "
                    f"retrying in {delay}s"
                )
                time.sleep(delay)
            except anthropic.APIStatusError as e:
                if e.status_code >= 500:
                    last_error = e
                    delay = self.base_delay * (2 ** (attempt - 1))
                    logger.warning(
                        f"Server error {e.status_code} (attempt {attempt}/{self.max_retries}), "
                        f"retrying in {delay}s"
                    )
                    time.sleep(delay)
                else:
                    raise
        if last_error is None:
            # The loop body never ran, so there is no API error to re-raise.
            raise ValueError("max_retries must be at least 1")
        raise last_error

    async def acreate_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Async wrapper around create_message.

        Same interface as create_message. The synchronous client (including
        its blocking retry sleeps) runs in a worker thread so the event loop
        is not blocked; the SDK's async client is not used here.
        """
        return await asyncio.to_thread(
            self.create_message,
            system_prompt,
            user_message,
            max_tokens,
            temperature,
        )
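The retry delays and the cost estimate in `LLMClient` are both simple closed-form calculations, so they can be checked without calling the API. The sketch below reproduces them as pure functions. `backoff_delays` and `estimate_cost` are illustrative names, and the per-token rates are the approximate constants from this file, not authoritative pricing.

```python
def backoff_delays(max_retries: int, base_delay: float) -> list[float]:
    """Delay before each retry: base_delay doubled on every further attempt."""
    return [base_delay * (2 ** (attempt - 1)) for attempt in range(1, max_retries + 1)]


def estimate_cost(
    input_tokens: int,
    output_tokens: int,
    input_rate: float = 3.0 / 1_000_000,    # assumed: $3 per 1M input tokens
    output_rate: float = 15.0 / 1_000_000,  # assumed: $15 per 1M output tokens
) -> float:
    """Estimated cost in USD, rounded to six decimals as the client does."""
    return round(input_tokens * input_rate + output_tokens * output_rate, 6)
```

With the defaults of three retries and a one-second base delay, the waits are 1, 2, and 4 seconds. The client as written also sleeps after the final failed attempt before re-raising.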


@ -0,0 +1,101 @@
"""Token usage tracking - records LLM usage to the database."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.audit import TokenUsageLog
logger = logging.getLogger(__name__)
async def record_token_usage(
db: AsyncSession,
instance_id: UUID,
agent_name: str,
usage: dict[str, Any],
) -> TokenUsageLog:
"""Record token usage from an LLM call to the database.
Args:
db: Async database session.
instance_id: The locale instance ID this usage is for.
agent_name: Name of the agent that made the call.
usage: Usage dict from LLMClient with keys:
input_tokens, output_tokens, total_tokens,
estimated_cost_usd, model.
Returns:
The created TokenUsageLog record.
"""
log_entry = TokenUsageLog(
instance_id=instance_id,
agent_name=agent_name,
model=usage.get("model", "unknown"),
input_tokens=usage.get("input_tokens", 0),
output_tokens=usage.get("output_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
estimated_cost_usd=usage.get("estimated_cost_usd", 0.0),
)
db.add(log_entry)
await db.flush()
logger.info(
f"Token usage recorded: agent={agent_name}, "
f"tokens={usage.get('total_tokens', 0)}, "
f"cost=${usage.get('estimated_cost_usd', 0.0):.6f}"
)
return log_entry
async def get_total_usage_for_instance(
db: AsyncSession,
instance_id: UUID,
) -> dict[str, Any]:
"""Get aggregated token usage for a locale instance.
Args:
db: Async database session.
instance_id: The locale instance ID.
Returns:
Dict with total_input_tokens, total_output_tokens, total_tokens,
total_cost_usd, and a by_agent breakdown.
"""
from sqlalchemy import func, select
result = await db.execute(
select(
func.sum(TokenUsageLog.input_tokens).label("total_input"),
func.sum(TokenUsageLog.output_tokens).label("total_output"),
func.sum(TokenUsageLog.total_tokens).label("total_tokens"),
func.sum(TokenUsageLog.estimated_cost_usd).label("total_cost"),
).where(TokenUsageLog.instance_id == instance_id)
)
row = result.one()
# By-agent breakdown
agent_result = await db.execute(
select(
TokenUsageLog.agent_name,
func.sum(TokenUsageLog.total_tokens).label("tokens"),
func.sum(TokenUsageLog.estimated_cost_usd).label("cost"),
)
.where(TokenUsageLog.instance_id == instance_id)
.group_by(TokenUsageLog.agent_name)
)
by_agent = {
agent_name: {"tokens": tokens, "cost": float(cost)}
for agent_name, tokens, cost in agent_result.all()
}
return {
"total_input_tokens": row.total_input or 0,
"total_output_tokens": row.total_output or 0,
"total_tokens": row.total_tokens or 0,
"total_cost_usd": float(row.total_cost or 0.0),
"by_agent": by_agent,
}
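The shape of the aggregated result can be illustrated without a database. The sketch below is an in-memory stand-in for the two SQL queries above, assuming each record is a dict carrying the usage keys plus `agent_name`; `summarize_usage` is a hypothetical helper, not part of the commit.

```python
from collections import defaultdict
from typing import Any


def summarize_usage(records: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate usage records into totals plus a per-agent breakdown."""
    totals: dict[str, Any] = {
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_tokens": 0,
        "total_cost_usd": 0.0,
    }
    by_agent: dict[str, dict[str, Any]] = defaultdict(
        lambda: {"tokens": 0, "cost": 0.0}
    )
    for rec in records:
        totals["total_input_tokens"] += rec.get("input_tokens", 0)
        totals["total_output_tokens"] += rec.get("output_tokens", 0)
        totals["total_tokens"] += rec.get("total_tokens", 0)
        totals["total_cost_usd"] += rec.get("estimated_cost_usd", 0.0)
        # Per-agent totals mirror the GROUP BY agent_name query.
        agent = by_agent[rec["agent_name"]]
        agent["tokens"] += rec.get("total_tokens", 0)
        agent["cost"] += rec.get("estimated_cost_usd", 0.0)
    return {**totals, "by_agent": dict(by_agent)}
```

An empty record list yields zero totals and an empty `by_agent` dict, matching the `or 0` fallbacks in the database version.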

63
backend/app/main.py Normal file

@ -0,0 +1,63 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from app.config import settings
from app.api.v1.router import api_v1_router
from app.ws.handler import ws_router
engine: AsyncEngine | None = None
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Manage application lifespan: set up and tear down DB engine."""
global engine
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
pool_size=20,
max_overflow=10,
pool_pre_ping=True,
)
yield
if engine is not None:
await engine.dispose()
engine = None
def create_app() -> FastAPI:
"""Application factory."""
application = FastAPI(
title="Amazon AI Transcreation Platform",
description="Backend API for AI-powered marketing transcreation",
version="1.0.0",
lifespan=lifespan,
)
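# CORS middleware (permissive, for development only; restrict allow_origins to
# explicit origins before production because credentials are allowed)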
application.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API routes
application.include_router(api_v1_router)
# WebSocket routes
application.include_router(ws_router)
@application.get("/health", tags=["health"])
async def health_check() -> dict:
return {"status": "healthy", "version": "1.0.0"}
return application
app = create_app()
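The CORS block above allows every origin. One way to tighten it later is to derive an explicit origin list from configuration. The sketch below assumes a comma-separated setting; both `parse_allowed_origins` and the `CORS_ALLOWED_ORIGINS` setting named in the comment are hypothetical and do not appear in this commit.

```python
def parse_allowed_origins(raw: str) -> list[str]:
    """Split a comma-separated origin list, dropping blanks and duplicates."""
    origins: list[str] = []
    for item in raw.split(","):
        # Normalise whitespace and a trailing slash so origins compare exactly.
        origin = item.strip().rstrip("/")
        if origin and origin not in origins:
            origins.append(origin)
    return origins


# Hypothetical wiring inside create_app():
#   allow_origins=parse_allowed_origins(settings.CORS_ALLOWED_ORIGINS)
```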


@ -0,0 +1,11 @@
"""SQLAlchemy models - import all models so Alembic can discover them."""
from app.models.base import Base, TimestampMixin # noqa: F401
from app.models.client import Client # noqa: F401
from app.models.user import User, UserClient # noqa: F401
from app.models.job import Job, LocaleInstance # noqa: F401
from app.models.source import SourceLine # noqa: F401
from app.models.output import OutputRow # noqa: F401
from app.models.feedback import Feedback # noqa: F401
from app.models.files import TMFileRegistry, ReferenceFile # noqa: F401
from app.models.audit import AuditLog, TokenUsageLog # noqa: F401


@ -0,0 +1,56 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Float, ForeignKey, Integer, JSON, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, generate_uuid
class AuditLog(Base):
__tablename__ = "audit_logs"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
user_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("users.id"), nullable=True
)
action: Mapped[str] = mapped_column(String(100), nullable=False)
entity_type: Mapped[str] = mapped_column(String(100), nullable=False)
entity_id: Mapped[str] = mapped_column(String(255), nullable=False)
details: Mapped[dict | None] = mapped_column(JSON, nullable=True)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
# Relationships
user = relationship("User", back_populates="audit_logs")
class TokenUsageLog(Base):
__tablename__ = "token_usage_logs"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
instance_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("locale_instances.id", ondelete="CASCADE"), nullable=False
)
agent_name: Mapped[str] = mapped_column(String(100), nullable=False)
model: Mapped[str] = mapped_column(String(100), nullable=False)
input_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
output_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
total_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
estimated_cost_usd: Mapped[float] = mapped_column(Float, nullable=False)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
# Relationships
instance = relationship("LocaleInstance", back_populates="token_usage_logs")


@ -0,0 +1,30 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
"""Base class for all SQLAlchemy models."""
pass
class TimestampMixin:
"""Mixin that adds created_at and updated_at columns."""
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
def generate_uuid() -> uuid.UUID:
return uuid.uuid4()


@ -0,0 +1,22 @@
import uuid
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, generate_uuid
class Client(Base, TimestampMixin):
__tablename__ = "clients"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
settings: Mapped[dict | None] = mapped_column(JSON, nullable=True)
# Relationships
user_clients = relationship("UserClient", back_populates="client", lazy="selectin")
jobs = relationship("Job", back_populates="client", lazy="selectin")
tm_files = relationship("TMFileRegistry", back_populates="client", lazy="selectin")
reference_files = relationship("ReferenceFile", back_populates="client", lazy="selectin")


@ -0,0 +1,43 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, generate_uuid
class FlagType(str, enum.Enum):
approved = "approved"
needs_revision = "needs_revision"
comment = "comment"
class Feedback(Base):
__tablename__ = "feedback"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
output_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("output_rows.id", ondelete="CASCADE"), nullable=False
)
user_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("users.id"), nullable=False
)
option_column: Mapped[int] = mapped_column(Integer, nullable=False)
flag_type: Mapped[FlagType] = mapped_column(
Enum(FlagType, name="flag_type", create_constraint=True),
nullable=False,
)
comment: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
# Relationships
output = relationship("OutputRow", back_populates="feedback")
user = relationship("User", back_populates="feedback")


@ -0,0 +1,89 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, generate_uuid
class ReferenceFileType(str, enum.Enum):
glossary = "glossary"
blacklist = "blacklist"
tov_global = "tov_global"
tov_supplement = "tov_supplement"
locale_considerations = "locale_considerations"
date_pct_formats = "date_pct_formats"
class TMFileRegistry(Base):
__tablename__ = "tm_file_registry"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
client_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("clients.id", ondelete="CASCADE"), nullable=False
)
locale_code: Mapped[str] = mapped_column(String(10), nullable=False)
channel: Mapped[str] = mapped_column(String(100), nullable=False)
filename: Mapped[str] = mapped_column(String(255), nullable=False)
file_path: Mapped[str] = mapped_column(String(500), nullable=False)
segment_count: Mapped[int] = mapped_column(Integer, default=0)
uploaded_by: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("users.id"), nullable=True
)
uploaded_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
last_updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
version: Mapped[int] = mapped_column(Integer, default=1)
# Relationships
client = relationship("Client", back_populates="tm_files")
uploader = relationship("User", foreign_keys=[uploaded_by])
class ReferenceFile(Base):
__tablename__ = "reference_files"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
client_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("clients.id", ondelete="CASCADE"), nullable=False
)
file_type: Mapped[ReferenceFileType] = mapped_column(
Enum(ReferenceFileType, name="reference_file_type", create_constraint=True),
nullable=False,
)
locale_scope: Mapped[str] = mapped_column(String(10), nullable=False)
filename: Mapped[str] = mapped_column(String(255), nullable=False)
file_path: Mapped[str] = mapped_column(String(500), nullable=False)
uploaded_by: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("users.id"), nullable=True
)
uploaded_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
last_updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
version: Mapped[int] = mapped_column(Integer, default=1)
# Relationships
client = relationship("Client", back_populates="reference_files")
uploader = relationship("User", foreign_keys=[uploaded_by])

143
backend/app/models/job.py Normal file

@ -0,0 +1,143 @@
import enum
import uuid
from datetime import datetime
from sqlalchemy import (
JSON,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, generate_uuid
class Programme(str, enum.Enum):
retail = "retail"
prime = "prime"
brand = "brand"
class JobType(str, enum.Enum):
main = "main"
derived = "derived"
class JobStatus(str, enum.Enum):
created = "created"
validating = "validating"
queued = "queued"
running = "running"
partial_complete = "partial_complete"
complete = "complete"
error = "error"
exported = "exported"
class LocaleType(str, enum.Enum):
main = "main"
derived = "derived"
class LocaleStatus(str, enum.Enum):
queued = "queued"
running = "running"
complete = "complete"
error = "error"
class Job(Base, TimestampMixin):
__tablename__ = "jobs"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
client_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("clients.id", ondelete="CASCADE"), nullable=False
)
created_by: Mapped[uuid.UUID] = mapped_column(
ForeignKey("users.id"), nullable=False
)
job_ref: Mapped[str | None] = mapped_column(String(100), nullable=True)
campaign_name: Mapped[str] = mapped_column(String(255), nullable=False)
programme: Mapped[Programme] = mapped_column(
Enum(Programme, name="programme_type", create_constraint=True),
nullable=False,
)
channel: Mapped[str] = mapped_column(String(100), nullable=False)
sub_channel: Mapped[str | None] = mapped_column(String(100), nullable=True)
context_prompt: Mapped[str | None] = mapped_column(Text, nullable=True)
job_type: Mapped[JobType] = mapped_column(
Enum(JobType, name="job_type", create_constraint=True),
default=JobType.main,
nullable=False,
)
parent_job_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("jobs.id"), nullable=True
)
status: Mapped[JobStatus] = mapped_column(
Enum(JobStatus, name="job_status", create_constraint=True),
default=JobStatus.created,
nullable=False,
)
total_token_usage: Mapped[int] = mapped_column(Integer, default=0)
total_estimated_cost: Mapped[float] = mapped_column(Float, default=0.0)
# Relationships
client = relationship("Client", back_populates="jobs")
creator = relationship("User", back_populates="jobs_created")
parent_job = relationship("Job", remote_side="Job.id", lazy="selectin")
locale_instances = relationship(
"LocaleInstance", back_populates="job", lazy="selectin", cascade="all, delete-orphan"
)
source_lines = relationship(
"SourceLine", back_populates="job", lazy="selectin", cascade="all, delete-orphan"
)
class LocaleInstance(Base, TimestampMixin):
__tablename__ = "locale_instances"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
job_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
)
locale_code: Mapped[str] = mapped_column(String(10), nullable=False)
locale_type: Mapped[LocaleType] = mapped_column(
Enum(LocaleType, name="locale_type", create_constraint=True),
nullable=False,
)
status: Mapped[LocaleStatus] = mapped_column(
Enum(LocaleStatus, name="locale_status", create_constraint=True),
default=LocaleStatus.queued,
nullable=False,
)
started_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
completed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
token_usage: Mapped[int] = mapped_column(Integer, default=0)
estimated_cost: Mapped[float] = mapped_column(Float, default=0.0)
output_file_path: Mapped[str | None] = mapped_column(String(500), nullable=True)
error_log: Mapped[str | None] = mapped_column(Text, nullable=True)
tm_files_loaded: Mapped[dict | None] = mapped_column(JSON, nullable=True)
ref_files_loaded: Mapped[dict | None] = mapped_column(JSON, nullable=True)
agent_version: Mapped[str | None] = mapped_column(String(50), nullable=True)
# Relationships
job = relationship("Job", back_populates="locale_instances")
output_rows = relationship(
"OutputRow", back_populates="instance", lazy="selectin", cascade="all, delete-orphan"
)
token_usage_logs = relationship(
"TokenUsageLog", back_populates="instance", lazy="selectin", cascade="all, delete-orphan"
)
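The commit defines both job-level and locale-level statuses but does not show how one is derived from the other. The sketch below is one plausible roll-up rule, offered purely as an assumption: `rollup_job_status` is a hypothetical helper, and the enums are trimmed copies of the ones above.

```python
import enum


class LocaleStatus(str, enum.Enum):
    queued = "queued"
    running = "running"
    complete = "complete"
    error = "error"


class JobStatus(str, enum.Enum):
    created = "created"
    queued = "queued"
    running = "running"
    partial_complete = "partial_complete"
    complete = "complete"
    error = "error"


def rollup_job_status(locale_statuses: list[LocaleStatus]) -> JobStatus:
    """Derive a job status from its locale instances (assumed rule)."""
    if not locale_statuses:
        return JobStatus.created
    statuses = set(locale_statuses)
    if LocaleStatus.running in statuses:
        return JobStatus.running
    if statuses == {LocaleStatus.complete}:
        return JobStatus.complete
    if statuses == {LocaleStatus.error}:
        return JobStatus.error
    if LocaleStatus.complete in statuses:
        # Some locales finished while others failed or are still waiting.
        return JobStatus.partial_complete
    return JobStatus.queued
```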


@ -0,0 +1,53 @@
import enum
import uuid
from sqlalchemy import Enum, ForeignKey, Integer, JSON, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, generate_uuid
class ConfidenceTier(str, enum.Enum):
high = "high"
moderate = "moderate"
low = "low"
class OutputRow(Base, TimestampMixin):
__tablename__ = "output_rows"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
instance_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("locale_instances.id", ondelete="CASCADE"), nullable=False
)
line_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("source_lines.id", ondelete="CASCADE"), nullable=False
)
row_order: Mapped[int] = mapped_column(Integer, nullable=False)
confidence_tier: Mapped[ConfidenceTier] = mapped_column(
Enum(ConfidenceTier, name="confidence_tier", create_constraint=True),
nullable=False,
)
option_1: Mapped[str] = mapped_column(Text, nullable=False)
backtranslation_1: Mapped[str] = mapped_column(Text, nullable=False)
rationale_1: Mapped[str] = mapped_column(Text, nullable=False)
option_2: Mapped[str | None] = mapped_column(Text, nullable=True)
backtranslation_2: Mapped[str | None] = mapped_column(Text, nullable=True)
rationale_2: Mapped[str | None] = mapped_column(Text, nullable=True)
option_3: Mapped[str | None] = mapped_column(Text, nullable=True)
backtranslation_3: Mapped[str | None] = mapped_column(Text, nullable=True)
rationale_3: Mapped[str | None] = mapped_column(Text, nullable=True)
tm_entries_cited: Mapped[dict | None] = mapped_column(JSON, nullable=True)
winning_seg_key: Mapped[str | None] = mapped_column(String(255), nullable=True)
character_count_option_1: Mapped[int | None] = mapped_column(Integer, nullable=True)
character_count_option_2: Mapped[int | None] = mapped_column(Integer, nullable=True)
character_count_option_3: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Relationships
instance = relationship("LocaleInstance", back_populates="output_rows")
line = relationship("SourceLine", back_populates="output_rows")
feedback = relationship(
"Feedback", back_populates="output", lazy="selectin", cascade="all, delete-orphan"
)


@ -0,0 +1,30 @@
import uuid
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, generate_uuid
class SourceLine(Base, TimestampMixin):
__tablename__ = "source_lines"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
job_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
)
row_order: Mapped[int] = mapped_column(Integer, nullable=False)
en_gb: Mapped[str] = mapped_column(Text, nullable=False)
copy_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
creative_guidance: Mapped[str | None] = mapped_column(Text, nullable=True)
visual_ref: Mapped[str | None] = mapped_column(String(500), nullable=True)
char_limit: Mapped[str | None] = mapped_column(String(50), nullable=True)
is_display_format: Mapped[bool] = mapped_column(Boolean, default=False)
# Relationships
job = relationship("Job", back_populates="source_lines")
output_rows = relationship(
"OutputRow", back_populates="line", lazy="selectin", cascade="all, delete-orphan"
)


@ -0,0 +1,60 @@
import enum
import uuid
from sqlalchemy import Enum, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, TimestampMixin, generate_uuid
class UserRole(str, enum.Enum):
admin = "admin"
tm_manager = "tm_manager"
reviewer = "reviewer"
class UserStatus(str, enum.Enum):
active = "active"
inactive = "inactive"
class User(Base, TimestampMixin):
__tablename__ = "users"
id: Mapped[uuid.UUID] = mapped_column(
primary_key=True, default=generate_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, name="user_role", create_constraint=True),
nullable=False,
)
status: Mapped[UserStatus] = mapped_column(
Enum(UserStatus, name="user_status", create_constraint=True),
default=UserStatus.active,
nullable=False,
)
# Relationships
user_clients = relationship("UserClient", back_populates="user", lazy="selectin")
jobs_created = relationship("Job", back_populates="creator", lazy="selectin")
feedback = relationship("Feedback", back_populates="user", lazy="selectin")
audit_logs = relationship("AuditLog", back_populates="user", lazy="selectin")
class UserClient(Base):
__tablename__ = "user_clients"
user_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
client_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("clients.id", ondelete="CASCADE"), primary_key=True
)
role_override: Mapped[str | None] = mapped_column(String(50), nullable=True)
# Relationships
user = relationship("User", back_populates="user_clients")
client = relationship("Client", back_populates="user_clients")


@ -0,0 +1,106 @@
"""Agent 1: Validator
Validates source file, loads reference files, and builds the initial
PipelineContext. This agent is deterministic (no LLM call).
"""
from typing import Any
from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import (
FileManifest,
JobParams,
ParsedJob,
PipelineContext,
SourceLineContract,
)
from app.pipeline.modules.source_file_parser import parse_source_file
from app.pipeline.modules.ref_file_loader import load_all_reference_files
class Agent1Validator(BaseAgent):
"""Validates inputs and builds the initial pipeline context."""
name = "agent_1_validator"
description = "Validates source file and reference files, builds ParsedJob"
def __init__(
self,
source_file_path: str | None = None,
source_lines: list[dict] | None = None,
file_manifest: dict[str, Any] | None = None,
job_params: dict[str, Any] | None = None,
) -> None:
self.source_file_path = source_file_path
self.source_lines_raw = source_lines
self.file_manifest_raw = file_manifest or {}
self.job_params_raw = job_params or {}
def get_system_prompt(self) -> str:
return "" # No LLM call for this agent
def build_user_message(self, context: PipelineContext) -> str:
return "" # No LLM call for this agent
def parse_response(self, response: str, context: PipelineContext) -> Any:
return None # No LLM call for this agent
async def run(self, context: PipelineContext) -> PipelineContext:
"""Validate inputs and populate the pipeline context.
1. Parse source file or use pre-parsed source lines
2. Resolve and load reference files
3. Build initial PipelineContext
"""
# Parse source lines
if self.source_lines_raw:
raw_lines = self.source_lines_raw
elif self.source_file_path:
raw_lines = parse_source_file(self.source_file_path)
else:
raw_lines = []
source_lines = [
SourceLineContract(
line_id=str(i + 1),
row_order=i + 1,
en_gb=line.get("en_gb", ""),
copy_type=line.get("copy_type"),
creative_guidance=line.get("creative_guidance"),
visual_ref=line.get("visual_ref"),
char_limit=line.get("char_limit"),
is_display_format=line.get("is_display_format", False),
)
for i, line in enumerate(raw_lines)
]
# Build file manifest
file_manifest = FileManifest(
tm_files=self.file_manifest_raw.get("tm_files", []),
glossary_file=self.file_manifest_raw.get("glossary_file"),
blacklist_file=self.file_manifest_raw.get("blacklist_file"),
tov_global_file=self.file_manifest_raw.get("tov_global_file"),
tov_supplement_file=self.file_manifest_raw.get("tov_supplement_file"),
locale_considerations_file=self.file_manifest_raw.get(
"locale_considerations_file"
),
date_pct_formats_file=self.file_manifest_raw.get("date_pct_formats_file"),
)
# Load reference files (for validation)
if any([
file_manifest.glossary_file,
file_manifest.blacklist_file,
file_manifest.date_pct_formats_file,
]):
load_all_reference_files(self.file_manifest_raw)
# Build job params
job_params = JobParams(**self.job_params_raw)
# Update context
context.job_params = job_params
context.source_lines = source_lines
context.file_manifest = file_manifest
return context


@ -0,0 +1,40 @@
"""Agent 2: TM Retrieval (STUB)
Retrieves Translation Memory matches for each source line.
Currently returns empty results as a stub.
"""
from typing import Any
from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import PipelineContext, TMSweepResult
class Agent2TMRetrieval(BaseAgent):
"""STUB: TM retrieval agent returning empty sweep results."""
name = "agent_2_tm_retrieval"
description = "Retrieves TM matches for source lines (STUB)"
def get_system_prompt(self) -> str:
return "You are a Translation Memory retrieval agent."
def build_user_message(self, context: PipelineContext) -> str:
return "Retrieve TM matches for the provided source lines."
def parse_response(self, response: str, context: PipelineContext) -> Any:
return []
async def run(self, context: PipelineContext) -> PipelineContext:
"""STUB: Return empty TM sweep results for all source lines."""
context.tm_sweep_results = [
TMSweepResult(
line_id=line.line_id,
confirmed_matches=[],
pass_4_triggered=False,
pass_4_result=None,
no_match=True,
)
for line in context.source_lines
]
return context


@ -0,0 +1,42 @@
"""Agent 3: Ranker (STUB)
Ranks TM matches and declares confidence tiers for each source line.
Currently returns LOW confidence for all lines as a stub.
"""
from typing import Any
from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import PipelineContext, RankingDeclaration
class Agent3Ranker(BaseAgent):
"""STUB: Ranking agent returning LOW confidence for all lines."""
name = "agent_3_ranker"
description = "Ranks TM matches and declares confidence (STUB)"
def get_system_prompt(self) -> str:
return "You are a ranking and confidence declaration agent."
def build_user_message(self, context: PipelineContext) -> str:
return "Rank the TM matches for each source line."
def parse_response(self, response: str, context: PipelineContext) -> Any:
return []
async def run(self, context: PipelineContext) -> PipelineContext:
"""STUB: Return LOW confidence ranking for all source lines."""
context.ranking_declarations = [
RankingDeclaration(
line_id=line.line_id,
winning_entry=None,
runner_ups=[],
confidence_tier="low",
option_count=3,
is_new_creative_line=True,
notes="STUB: No TM matches available",
)
for line in context.source_lines
]
return context


@ -0,0 +1,55 @@
"""Agent 4: Transcreator (STUB)
Generates transcreation drafts for each source line.
Currently returns placeholder translations as a stub.
"""
from typing import Any
from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import DraftOption, DraftOutput, PipelineContext
class Agent4Transcreator(BaseAgent):
"""STUB: Transcreation agent returning placeholder translations."""
name = "agent_4_transcreator"
description = "Generates transcreation drafts (STUB)"
def get_system_prompt(self) -> str:
return "You are a creative transcreation agent."
def build_user_message(self, context: PipelineContext) -> str:
return "Generate transcreation drafts for each source line."
def parse_response(self, response: str, context: PipelineContext) -> Any:
return []
async def run(self, context: PipelineContext) -> PipelineContext:
"""STUB: Return placeholder translations for all source lines."""
locale = context.job_params.locale_code
context.draft_outputs = [
DraftOutput(
line_id=line.line_id,
option_1=DraftOption(
text=f"[{locale}] {line.en_gb}",
backtranslation=line.en_gb,
rationale=f"STUB: Direct placeholder for '{line.en_gb[:50]}...'",
),
option_2=DraftOption(
text=f"[{locale} alt] {line.en_gb}",
backtranslation=line.en_gb,
rationale="STUB: Alternative placeholder",
),
option_3=DraftOption(
text=f"[{locale} creative] {line.en_gb}",
backtranslation=line.en_gb,
rationale="STUB: Creative placeholder",
),
tm_entries_cited=[],
adaptations_applied=[],
)
for line in context.source_lines
]
return context


@ -0,0 +1,52 @@
"""Agent 5: Compliance Checker (STUB)
Checks transcreation drafts against compliance rules.
Currently returns PASS for all lines as a stub.
"""
from typing import Any
from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import ComplianceResult, PipelineContext
from app.pipeline.modules.character_counter import count_characters
class Agent5Compliance(BaseAgent):
"""STUB: Compliance agent returning pass for all lines."""
name = "agent_5_compliance"
description = "Checks compliance of transcreation drafts (STUB)"
def get_system_prompt(self) -> str:
return "You are a compliance checking agent."
def build_user_message(self, context: PipelineContext) -> str:
return "Check compliance for all transcreation drafts."
def parse_response(self, response: str, context: PipelineContext) -> Any:
return []
async def run(self, context: PipelineContext) -> PipelineContext:
"""STUB: Return pass for all compliance checks, with character counts."""
context.compliance_results = []
for draft in context.draft_outputs:
char_counts: dict[str, int] = {}
if draft.option_1:
char_counts["option_1"] = count_characters(draft.option_1.text)
if draft.option_2:
char_counts["option_2"] = count_characters(draft.option_2.text)
if draft.option_3:
char_counts["option_3"] = count_characters(draft.option_3.text)
context.compliance_results.append(
ComplianceResult(
line_id=draft.line_id,
passed=True,
violations=[],
character_counts=char_counts,
)
)
return context


@ -0,0 +1,112 @@
"""Agent 6: Formatter
Generates the output xlsx file and builds output row data.
This agent is deterministic (no LLM call).
"""
from datetime import datetime, timezone
from typing import Any
from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import PipelineContext
from app.pipeline.modules.excel_writer import generate_output_xlsx
from app.config import settings
class Agent6Formatter(BaseAgent):
"""Formats pipeline output into xlsx and structured data."""
name = "agent_6_formatter"
description = "Generates output xlsx and structured output rows"
def __init__(self, output_dir: str | None = None) -> None:
self.output_dir = output_dir or settings.STORAGE_ROOT
def get_system_prompt(self) -> str:
return "" # No LLM call
def build_user_message(self, context: PipelineContext) -> str:
return "" # No LLM call
def parse_response(self, response: str, context: PipelineContext) -> Any:
return None # No LLM call
async def run(self, context: PipelineContext) -> PipelineContext:
"""Write the output xlsx to disk and return the context unchanged.
Output rows for database persistence are built by the orchestrator,
and the output file path is not stored on the context.
"""
job_id = context.job_params.job_id
locale_code = context.job_params.locale_code
# Build source lines for excel
source_lines_data = [
{
"row_order": sl.row_order,
"en_gb": sl.en_gb,
"copy_type": sl.copy_type or "",
}
for sl in context.source_lines
]
# Build output rows for excel
output_rows_data = []
for i, draft in enumerate(context.draft_outputs):
row: dict[str, Any] = {
"row_order": i + 1,
"line_id": draft.line_id,
"option_1": draft.option_1.text if draft.option_1 else "",
"backtranslation_1": draft.option_1.backtranslation if draft.option_1 else "",
"rationale_1": draft.option_1.rationale if draft.option_1 else "",
}
if draft.option_2:
row["option_2"] = draft.option_2.text
row["backtranslation_2"] = draft.option_2.backtranslation
row["rationale_2"] = draft.option_2.rationale
if draft.option_3:
row["option_3"] = draft.option_3.text
row["backtranslation_3"] = draft.option_3.backtranslation
row["rationale_3"] = draft.option_3.rationale
output_rows_data.append(row)
# Tally confidence tiers for the summary
confidence_counts = {"high": 0, "moderate": 0, "low": 0}
for ranking in context.ranking_declarations:
tier = ranking.confidence_tier
if tier in confidence_counts:
confidence_counts[tier] += 1
summary = {
"job_id": job_id,
"campaign_name": context.job_params.campaign_name,
"locale_code": locale_code,
"channel": context.job_params.channel,
"programme": context.job_params.programme,
"total_source_lines": len(context.source_lines),
"total_output_rows": len(output_rows_data),
"high_confidence": confidence_counts["high"],
"moderate_confidence": confidence_counts["moderate"],
"low_confidence": confidence_counts["low"],
"total_tokens": 0,
"estimated_cost": 0.0,
"agent_version": "1.0.0",
"generated_at": datetime.now(timezone.utc).isoformat(),
}
# Generate xlsx
output_path = (
f"{self.output_dir}/jobs/{job_id}/output/"
f"{locale_code}_{job_id}_output.xlsx"
)
generate_output_xlsx(
output_path=output_path,
source_lines=source_lines_data,
output_rows=output_rows_data,
summary=summary,
)
return context


@ -0,0 +1,62 @@
"""Base agent abstract class for the transcreation pipeline."""
from abc import ABC, abstractmethod
from typing import Any
from app.pipeline.contracts import PipelineContext
class BaseAgent(ABC):
"""Abstract base class for pipeline agents.
Each agent in the pipeline implements this interface:
1. get_system_prompt() - Returns the system prompt for the LLM
2. build_user_message() - Builds the user message from pipeline context
3. parse_response() - Parses the LLM response into structured data
4. run() - Orchestrates the agent's execution
"""
name: str = "base_agent"
description: str = ""
@abstractmethod
def get_system_prompt(self) -> str:
"""Return the system prompt for this agent's LLM call."""
...
@abstractmethod
def build_user_message(self, context: PipelineContext) -> str:
"""Build the user message from pipeline context.
Args:
context: The current pipeline context.
Returns:
The user message string to send to the LLM.
"""
...
@abstractmethod
def parse_response(self, response: str, context: PipelineContext) -> Any:
"""Parse the LLM response into structured data.
Args:
response: The raw LLM response text.
context: The current pipeline context.
Returns:
Structured data appropriate for this agent's output.
"""
...
@abstractmethod
async def run(self, context: PipelineContext) -> PipelineContext:
"""Execute this agent and return the updated pipeline context.
Args:
context: The current pipeline context.
Returns:
Updated pipeline context with this agent's results.
"""
...
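
A minimal sketch of how a deterministic agent might satisfy the interface above. `StubContext` and `LineCountAgent` are hypothetical names used only for illustration, and the `BaseAgent` here is a trimmed copy typed against the stub rather than the real Pydantic `PipelineContext`.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


@dataclass
class StubContext:
    """Hypothetical stand-in for PipelineContext (the real one is a Pydantic model)."""
    source_lines: list = field(default_factory=list)
    notes: list = field(default_factory=list)


class BaseAgent(ABC):
    """Trimmed copy of the interface, typed against StubContext."""
    name: str = "base_agent"

    @abstractmethod
    def get_system_prompt(self) -> str: ...

    @abstractmethod
    def build_user_message(self, context: StubContext) -> str: ...

    @abstractmethod
    def parse_response(self, response: str, context: StubContext) -> Any: ...

    @abstractmethod
    async def run(self, context: StubContext) -> StubContext: ...


class LineCountAgent(BaseAgent):
    """Hypothetical deterministic agent: no LLM call, so the LLM hooks are no-ops."""
    name = "line_count_agent"

    def get_system_prompt(self) -> str:
        return ""  # No LLM call

    def build_user_message(self, context: StubContext) -> str:
        return ""  # No LLM call

    def parse_response(self, response: str, context: StubContext) -> Any:
        return None  # No LLM call

    async def run(self, context: StubContext) -> StubContext:
        # Record a note on the context and hand it back to the caller.
        context.notes.append(f"{self.name}: {len(context.source_lines)} lines")
        return context


ctx = asyncio.run(LineCountAgent().run(StubContext(source_lines=["Hello", "World"])))
print(ctx.notes)
```

Because all four methods are abstract, a subclass that omits any of them cannot be instantiated, which is why the deterministic agents still define the three LLM hooks as stubs.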


@ -0,0 +1,139 @@
"""Pipeline data contracts - Pydantic models for inter-agent communication."""
from typing import Any
from pydantic import BaseModel
class TMEntry(BaseModel):
"""A single Translation Memory entry."""
seg_key: str
date: str
en: str
lc: str
tx: str
nt: str = ""
channel: str = ""
sub_channel: str = ""
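    # NOTE: Pydantic treats leading-underscore names as private attributes, not
    # fields, so `_text` is not populated from constructor keyword arguments.
    _text: str = ""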
model_config = {"from_attributes": True}
class SourceLineContract(BaseModel):
"""A parsed source line from the input xlsx."""
line_id: str
row_order: int
en_gb: str
copy_type: str | None = None
creative_guidance: str | None = None
visual_ref: str | None = None
char_limit: str | None = None
is_display_format: bool = False
class FileManifest(BaseModel):
"""Manifest of all files loaded for a job."""
tm_files: list[str] = []
glossary_file: str | None = None
blacklist_file: str | None = None
tov_global_file: str | None = None
tov_supplement_file: str | None = None
locale_considerations_file: str | None = None
date_pct_formats_file: str | None = None
class JobParams(BaseModel):
"""Parameters for a transcreation job."""
job_id: str
client_id: str
locale_code: str
channel: str
sub_channel: str | None = None
programme: str
campaign_name: str
context_prompt: str | None = None
class ParsedJob(BaseModel):
"""Output of Agent 1 (Validator): validated job parameters + source."""
job_params: JobParams
source_lines: list[SourceLineContract]
file_manifest: FileManifest
class ConfirmedMatch(BaseModel):
"""A confirmed TM match for a source line."""
seg_key: str
pass_found: int
date: str
en: str
tx: str
nt: str = ""
channel: str = ""
sub_channel: str = ""
is_cross_channel: bool = False
class TMSweepResult(BaseModel):
"""TM sweep results for a single source line."""
line_id: str
confirmed_matches: list[ConfirmedMatch] = []
pass_4_triggered: bool = False
pass_4_result: ConfirmedMatch | None = None
no_match: bool = False
class RankingDeclaration(BaseModel):
"""Ranking decision for a single source line."""
line_id: str
winning_entry: ConfirmedMatch | None = None
runner_ups: list[ConfirmedMatch] = []
confidence_tier: str = "low"
option_count: int = 3
is_new_creative_line: bool = False
notes: str = ""
class DraftOption(BaseModel):
"""A single draft transcreation option."""
text: str
backtranslation: str
rationale: str
class DraftOutput(BaseModel):
"""Transcreation draft output for a single source line."""
line_id: str
option_1: DraftOption
option_2: DraftOption | None = None
option_3: DraftOption | None = None
tm_entries_cited: list[str] = []
adaptations_applied: list[str] = []
class ComplianceViolation(BaseModel):
"""A single compliance violation found during checking."""
type: str
option_affected: int
description: str
severity: str = "warning"
class ComplianceResult(BaseModel):
"""Compliance check result for a single source line."""
line_id: str
passed: bool
violations: list[ComplianceViolation] = []
character_counts: dict[str, int] = {}
class PipelineContext(BaseModel):
"""Full pipeline context passed between agents."""
job_params: JobParams
source_lines: list[SourceLineContract] = []
file_manifest: FileManifest = FileManifest()
tm_sweep_results: list[TMSweepResult] = []
ranking_declarations: list[RankingDeclaration] = []
draft_outputs: list[DraftOutput] = []
compliance_results: list[ComplianceResult] = []



@ -0,0 +1,108 @@
"""Blacklist scanner for forbidden terms and brand violations.
Supports:
- Exact match: term appears as-is in the text (case-insensitive).
- Root-based match: checks if the root/stem of a blacklisted term appears.
"""
import re
from dataclasses import dataclass
@dataclass
class BlacklistViolation:
"""A detected blacklist violation."""
term: str
match_type: str # "exact" or "root"
position: int # character position in text
context: str # surrounding text
def scan_text(
text: str,
blacklist_entries: list[dict],
) -> list[BlacklistViolation]:
"""Scan text against blacklist entries for violations.
Args:
text: The text to scan.
blacklist_entries: List of dicts with keys:
- term: str (the forbidden term)
- root: str | None (optional root form for root-based matching)
- severity: str (optional)
Returns:
List of BlacklistViolation instances found.
"""
if not text or not blacklist_entries:
return []
violations: list[BlacklistViolation] = []
text_lower = text.lower()
for entry in blacklist_entries:
        # Tolerate missing keys and explicit nulls for both term and root
        term = (entry.get("term") or "").strip()
        if not term:
            continue
        root = (entry.get("root") or "").strip() or None
# Exact match (case-insensitive, word boundary)
pattern = re.compile(r"\b" + re.escape(term) + r"\b", re.IGNORECASE)
for match in pattern.finditer(text):
start = max(0, match.start() - 20)
end = min(len(text), match.end() + 20)
violations.append(
BlacklistViolation(
term=term,
match_type="exact",
position=match.start(),
context=text[start:end],
)
)
# Root-based match
if root:
root_pattern = re.compile(
r"\b" + re.escape(root) + r"\w*\b", re.IGNORECASE
)
for match in root_pattern.finditer(text):
# Skip if this is already caught as an exact match
matched_text = match.group().lower()
if matched_text == term.lower():
continue
start = max(0, match.start() - 20)
end = min(len(text), match.end() + 20)
violations.append(
BlacklistViolation(
term=term,
match_type="root",
position=match.start(),
context=text[start:end],
)
)
return violations
def scan_all_options(
options: list[str],
blacklist_entries: list[dict],
) -> dict[int, list[BlacklistViolation]]:
"""Scan multiple text options against the blacklist.
Args:
options: List of text options to scan (index 0 = option 1, etc.)
blacklist_entries: The blacklist entries.
Returns:
Dict mapping option index (1-based) to violations found.
"""
results: dict[int, list[BlacklistViolation]] = {}
for i, option_text in enumerate(options):
if option_text:
violations = scan_text(option_text, blacklist_entries)
if violations:
results[i + 1] = violations
return results
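
A small self-contained sketch of the exact and root matching rules used by `scan_text`. `find_hits` is a hypothetical helper written for this illustration; it returns plain tuples instead of `BlacklistViolation` objects.

```python
import re
from typing import Optional


def find_hits(text: str, term: str, root: Optional[str] = None) -> list:
    """Return (match_type, matched_text) pairs for one blacklist entry."""
    hits = []
    # Exact match: whole-word, case-insensitive.
    exact = re.compile(r"\b" + re.escape(term) + r"\b", re.IGNORECASE)
    hits += [("exact", m.group()) for m in exact.finditer(text)]
    if root:
        # Root match: any word that starts with the root.
        stem = re.compile(r"\b" + re.escape(root) + r"\w*\b", re.IGNORECASE)
        hits += [
            ("root", m.group())
            for m in stem.finditer(text)
            if m.group().lower() != term.lower()  # already reported as exact
        ]
    return hits


english = find_hits("Cheap deals, cheaper prices, the cheapest offer", "cheap", "cheap")
japanese = find_hits("このセールは激安です", "激安")
print(english)
print(japanese)
```

One caveat worth checking before relying on this for `ja_JP`: `\b` needs a transition between word and non-word characters, and Japanese text has no spaces between words, so a term embedded in a run of kana or kanji is not found by the word-boundary pattern.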


@ -0,0 +1,67 @@
"""Unicode grapheme cluster character counting.
Uses the grapheme library for accurate character counting that handles
multi-codepoint characters (emoji, combining marks, etc.) correctly.
Rules:
- Strip leading/trailing whitespace before counting.
- Line breaks (\\n) count as 0 characters.
"""
import grapheme
def count_characters(text: str) -> int:
"""Count grapheme clusters in a string, excluding line breaks.
Args:
text: The text to count characters in.
Returns:
Number of grapheme clusters (visible characters), excluding line breaks.
"""
if not text:
return 0
# Strip whitespace
cleaned = text.strip()
# Remove line breaks (they count as 0)
cleaned = cleaned.replace("\n", "").replace("\r", "")
return grapheme.length(cleaned)
def check_character_limit(text: str, char_limit: str | None) -> tuple[int, bool]:
"""Count characters and check against a limit.
Args:
text: The text to measure.
char_limit: The character limit as a string (may be numeric or range like "20-30").
Returns:
Tuple of (character_count, within_limit). If no limit is set, within_limit is True.
"""
count = count_characters(text)
if not char_limit:
return count, True
char_limit = char_limit.strip()
# Handle range format "20-30"
if "-" in char_limit:
parts = char_limit.split("-")
try:
lower = int(parts[0].strip())
upper = int(parts[1].strip())
return count, lower <= count <= upper
except (ValueError, IndexError):
return count, True
# Handle simple numeric limit
try:
limit = int(char_limit)
return count, count <= limit
except ValueError:
return count, True
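
A sketch of the limit rules above, separated from the counting step so it runs without the third-party `grapheme` package. `within_limit` is a hypothetical helper that takes an already-computed count; the last lines only illustrate why code-point length is not a reliable character count.

```python
from typing import Optional


def within_limit(count: int, char_limit: Optional[str]) -> bool:
    """Apply the same limit rules to an already-computed character count."""
    if not char_limit:
        return True  # no limit set
    char_limit = char_limit.strip()
    if "-" in char_limit:  # range such as "20-30"
        try:
            lower, upper = (int(p.strip()) for p in char_limit.split("-")[:2])
        except ValueError:
            return True  # unparseable limits are treated as "no limit"
        return lower <= count <= upper
    try:
        return count <= int(char_limit)
    except ValueError:
        return True


# "e" followed by a combining acute accent renders as one character,
# but len() counts two code points.
decomposed = "e\u0301"
print(len(decomposed))
print(within_limit(25, "20-30"), within_limit(31, "20-30"), within_limit(26, "25"))
```

Note that a range also enforces a minimum: a 10-character line fails a "20-30" limit.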


@ -0,0 +1,121 @@
"""Validate date and percentage format strings against approved locale formats.
Checks that dates and percentages in transcreated text conform to the
locale-specific format rules defined in the date/percentage format file.
"""
import re
from dataclasses import dataclass
@dataclass
class FormatViolation:
"""A detected date/percentage format violation."""
found: str
expected_format: str
description: str
def validate_date_formats(
text: str,
approved_formats: list[dict],
) -> list[FormatViolation]:
"""Validate date strings in text against approved formats.
Args:
text: The text containing dates to validate.
approved_formats: List of dicts with keys:
- pattern: str (regex pattern for valid format)
- example: str (example of the correct format)
- description: str
Returns:
List of FormatViolation instances.
"""
if not text or not approved_formats:
return []
violations: list[FormatViolation] = []
# Common date-like patterns to detect
date_patterns = [
# DD/MM/YYYY, MM/DD/YYYY, YYYY/MM/DD
r"\b\d{1,2}[/\-.]\d{1,2}[/\-.]\d{2,4}\b",
# Month DD, YYYY
r"\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s+\d{1,2},?\s+\d{4}\b",
# DD Month YYYY
r"\b\d{1,2}\s+(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s+\d{4}\b",
]
for date_pattern in date_patterns:
for match in re.finditer(date_pattern, text, re.IGNORECASE):
found_date = match.group()
is_valid = False
for fmt in approved_formats:
pattern = fmt.get("pattern", "")
if pattern and re.match(pattern, found_date, re.IGNORECASE):
is_valid = True
break
if not is_valid and approved_formats:
examples = [
fmt.get("example", "") for fmt in approved_formats if fmt.get("example")
]
violations.append(
FormatViolation(
found=found_date,
expected_format=", ".join(examples[:3]),
description=f"Date format '{found_date}' does not match approved formats",
)
)
return violations
def validate_percentage_formats(
text: str,
approved_formats: list[dict],
) -> list[FormatViolation]:
"""Validate percentage strings in text against approved formats.
Args:
text: The text containing percentages to validate.
approved_formats: List of dicts with keys:
- pattern: str (regex)
- example: str
- description: str
Returns:
List of FormatViolation instances.
"""
if not text or not approved_formats:
return []
violations: list[FormatViolation] = []
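    # Find percentage-like patterns. No trailing \b: "%" is a non-word character,
    # so \b after it would only match when a word character follows directly.
    pct_pattern = r"\b\d+[.,]?\d*\s*%"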
for match in re.finditer(pct_pattern, text):
found_pct = match.group()
is_valid = False
for fmt in approved_formats:
pattern = fmt.get("pattern", "")
if pattern and re.match(pattern, found_pct):
is_valid = True
break
if not is_valid and approved_formats:
examples = [
fmt.get("example", "") for fmt in approved_formats if fmt.get("example")
]
violations.append(
FormatViolation(
found=found_pct,
expected_format=", ".join(examples[:3]),
description=f"Percentage format '{found_pct}' does not match approved formats",
)
)
return violations
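
A sketch of what an approved-format list might look like and how a detected date is checked against it. The entry in `APPROVED_DATES` is hypothetical (real entries come from the date/percentage format file), and `bad_dates` is an illustrative helper. It uses `re.fullmatch` rather than `re.match`, an assumption made here so that a pattern cannot pass on a prefix alone.

```python
import re

# Hypothetical approved formats for a locale that writes DD.MM.YYYY.
APPROVED_DATES = [
    {
        "pattern": r"\d{1,2}\.\d{1,2}\.\d{4}",
        "example": "31.12.2026",
        "description": "DD.MM.YYYY",
    },
]

# Same numeric date detector as the first entry in date_patterns.
DATE_LIKE = r"\b\d{1,2}[/\-.]\d{1,2}[/\-.]\d{2,4}\b"


def bad_dates(text: str, approved: list) -> list:
    """Return date-like strings that match none of the approved patterns."""
    found = [m.group() for m in re.finditer(DATE_LIKE, text)]
    return [
        d for d in found
        if not any(re.fullmatch(fmt["pattern"], d) for fmt in approved)
    ]


result = bad_dates("Offer ends 31.12.2026, not 12/31/2026.", APPROVED_DATES)
print(result)
```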


@ -0,0 +1,110 @@
"""Domain substitution for Amazon locales.
Maps Amazon.co.uk (source domain) to the correct locale-specific domain.
Handles both full domain URLs and bare "Amazon" references.
Emerging locales: bare "Amazon" stays as "Amazon"
Non-emerging locales: bare "Amazon" -> locale-specific brand name rules apply
"""
# Full domain map for all 12 supported locales
DOMAIN_MAP: dict[str, str] = {
"de_DE": "Amazon.de",
"fr_FR": "Amazon.fr",
"it_IT": "Amazon.it",
"es_ES": "Amazon.es",
"nl_NL": "Amazon.nl",
"pl_PL": "Amazon.pl",
"sv_SE": "Amazon.se",
"pt_BR": "Amazon.com.br",
"ja_JP": "Amazon.co.jp",
"en_AU": "Amazon.com.au",
"en_SG": "Amazon.sg",
"ar_AE": "Amazon.ae",
}
# Emerging locales where bare "Amazon" stays as-is
EMERGING_LOCALES: set[str] = {
"pl_PL",
"sv_SE",
"nl_NL",
"ar_AE",
"en_SG",
}
# Source domain to replace
SOURCE_DOMAIN = "Amazon.co.uk"
SOURCE_DOMAIN_LOWER = SOURCE_DOMAIN.lower()
def substitute_domains(text: str, locale_code: str) -> str:
"""Replace Amazon.co.uk with the locale-specific domain.
Args:
text: The text containing domain references.
locale_code: The target locale code (e.g., "de_DE").
Returns:
Text with domains substituted.
"""
if not text:
return text
target_domain = DOMAIN_MAP.get(locale_code)
if target_domain is None:
return text
# Replace full domain (case-insensitive)
result = text
idx = 0
while True:
lower_result = result.lower()
pos = lower_result.find(SOURCE_DOMAIN_LOWER, idx)
if pos == -1:
break
result = result[:pos] + target_domain + result[pos + len(SOURCE_DOMAIN):]
idx = pos + len(target_domain)
return result
def substitute_bare_amazon(text: str, locale_code: str) -> str:
"""Handle bare 'Amazon' references based on locale type.
    For emerging locales: leave bare 'Amazon' as-is.
    For non-emerging locales: also left as-is for now, since bare 'Amazon' is
    the brand name; locale-specific brand rules are not implemented yet.
Args:
text: The text with potential bare Amazon references.
locale_code: The target locale code.
Returns:
Text with bare Amazon references handled.
"""
if not text:
return text
if locale_code in EMERGING_LOCALES:
# Emerging locales: bare Amazon stays as Amazon
return text
# For non-emerging locales, bare "Amazon" (not followed by .)
# is kept as-is since it's the brand name
return text
def get_locale_domain(locale_code: str) -> str | None:
"""Get the domain for a locale code.
Args:
locale_code: The target locale code.
Returns:
The domain string or None if locale is not supported.
"""
return DOMAIN_MAP.get(locale_code)
def is_emerging_locale(locale_code: str) -> bool:
"""Check if a locale is classified as emerging."""
return locale_code in EMERGING_LOCALES
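
For comparison, a sketch of the same case-insensitive replacement written with a compiled regex instead of the manual `find` loop. This is an alternative formulation, not the module's code, and `DOMAIN_MAP` here is a two-entry subset for illustration.

```python
import re

DOMAIN_MAP = {"de_DE": "Amazon.de", "ja_JP": "Amazon.co.jp"}  # subset for illustration
_SOURCE = re.compile(re.escape("Amazon.co.uk"), re.IGNORECASE)


def substitute_domains(text: str, locale_code: str) -> str:
    """Replace every case variant of Amazon.co.uk with the locale domain."""
    target = DOMAIN_MAP.get(locale_code)
    if not text or target is None:
        return text
    # A function replacement avoids any escape processing of the target string.
    return _SOURCE.sub(lambda _match: target, text)


german = substitute_domains("Shop at amazon.co.uk or AMAZON.CO.UK", "de_DE")
unknown = substitute_domains("Shop at Amazon.co.uk", "xx_XX")
print(german)
print(unknown)
```

Both versions write the target domain with its canonical capitalisation, so a lowercase URL such as `www.amazon.co.uk/deals` comes out as `www.Amazon.de/deals`; whether that matters depends on how URLs are used in the copy.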


@ -0,0 +1,184 @@
"""Generate output xlsx files with structured output and summary tabs.
Tab 1: 11-column output table
Tab 2: Transcreation Summary
Column widths and formatting per specification.
"""
from pathlib import Path
from typing import Any
from openpyxl import Workbook
from openpyxl.styles import Alignment, Font, PatternFill
from openpyxl.utils import get_column_letter
from app.pipeline.modules.line_break_normaliser import normalise_for_excel
# Tab 1 column definitions
OUTPUT_COLUMNS = [
("EN_GB", 40),
("Copy Type", 15),
("Option 1", 40),
("Back-translation 1", 40),
("Rationale 1", 35),
("Option 2", 40),
("Back-translation 2", 40),
("Rationale 2", 35),
("Option 3", 40),
("Back-translation 3", 40),
("Rationale 3", 35),
]
# Header style
HEADER_FONT = Font(bold=True, size=11, color="FFFFFF")
HEADER_FILL = PatternFill(start_color="232F3E", end_color="232F3E", fill_type="solid")
HEADER_ALIGNMENT = Alignment(horizontal="center", vertical="center", wrap_text=True)
# Data style
DATA_ALIGNMENT = Alignment(vertical="top", wrap_text=True)
def generate_output_xlsx(
output_path: str,
source_lines: list[dict[str, Any]],
output_rows: list[dict[str, Any]],
summary: dict[str, Any] | None = None,
) -> str:
"""Generate the output xlsx file.
Args:
output_path: Absolute path where the xlsx should be saved.
source_lines: List of source line dicts (en_gb, copy_type, etc.).
output_rows: List of output row dicts with options, backtranslations, rationales.
summary: Optional summary data for Tab 2.
Returns:
The absolute path to the generated file.
"""
wb = Workbook()
# ---- Tab 1: Output Table ----
ws1 = wb.active
ws1.title = "Transcreation Output"
# Write headers
for col_idx, (header, width) in enumerate(OUTPUT_COLUMNS, start=1):
cell = ws1.cell(row=1, column=col_idx, value=header)
cell.font = HEADER_FONT
cell.fill = HEADER_FILL
cell.alignment = HEADER_ALIGNMENT
ws1.column_dimensions[get_column_letter(col_idx)].width = width
# Write data rows
for row_idx, output_row in enumerate(output_rows, start=2):
# Find matching source line
source_line = _find_source_line(source_lines, output_row)
ws1.cell(
row=row_idx, column=1,
value=normalise_for_excel(source_line.get("en_gb", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=2,
value=source_line.get("copy_type", ""),
).alignment = DATA_ALIGNMENT
# Option 1
ws1.cell(
row=row_idx, column=3,
value=normalise_for_excel(output_row.get("option_1", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=4,
value=normalise_for_excel(output_row.get("backtranslation_1", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=5,
value=output_row.get("rationale_1", ""),
).alignment = DATA_ALIGNMENT
# Option 2
ws1.cell(
row=row_idx, column=6,
value=normalise_for_excel(output_row.get("option_2", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=7,
value=normalise_for_excel(output_row.get("backtranslation_2", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=8,
value=output_row.get("rationale_2", ""),
).alignment = DATA_ALIGNMENT
# Option 3
ws1.cell(
row=row_idx, column=9,
value=normalise_for_excel(output_row.get("option_3", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=10,
value=normalise_for_excel(output_row.get("backtranslation_3", "")),
).alignment = DATA_ALIGNMENT
ws1.cell(
row=row_idx, column=11,
value=output_row.get("rationale_3", ""),
).alignment = DATA_ALIGNMENT
# ---- Tab 2: Transcreation Summary ----
ws2 = wb.create_sheet("Transcreation Summary")
summary_data = summary or {}
summary_rows = [
("Job ID", summary_data.get("job_id", "")),
("Campaign", summary_data.get("campaign_name", "")),
("Locale", summary_data.get("locale_code", "")),
("Channel", summary_data.get("channel", "")),
("Programme", summary_data.get("programme", "")),
("Total Source Lines", summary_data.get("total_source_lines", 0)),
("Total Output Rows", summary_data.get("total_output_rows", 0)),
("High Confidence", summary_data.get("high_confidence", 0)),
("Moderate Confidence", summary_data.get("moderate_confidence", 0)),
("Low Confidence", summary_data.get("low_confidence", 0)),
("Total Tokens Used", summary_data.get("total_tokens", 0)),
("Estimated Cost (USD)", summary_data.get("estimated_cost", 0.0)),
("Agent Version", summary_data.get("agent_version", "")),
("Generated At", summary_data.get("generated_at", "")),
]
ws2.column_dimensions["A"].width = 25
ws2.column_dimensions["B"].width = 40
for row_idx, (label, value) in enumerate(summary_rows, start=1):
label_cell = ws2.cell(row=row_idx, column=1, value=label)
label_cell.font = Font(bold=True)
ws2.cell(row=row_idx, column=2, value=str(value))
# Save
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
wb.save(output_path)
wb.close()
return output_path
def _find_source_line(
    source_lines: list[dict[str, Any]], output_row: dict[str, Any]
) -> dict[str, Any]:
    """Find the source line matching an output row by line_id, then row_order."""
    row_order = output_row.get("row_order")
    line_id = output_row.get("line_id")
    # Prefer an exact line_id match; accept either a "line_id" or an "id" key.
    if line_id:
        for sl in source_lines:
            if sl.get("line_id", sl.get("id")) == line_id:
                return sl
    if row_order is not None:
        for sl in source_lines:
            if sl.get("row_order") == row_order:
                return sl
        # Fallback: treat row_order as a 1-based index
        if 0 < row_order <= len(source_lines):
            return source_lines[row_order - 1]
    return {}


@ -0,0 +1,67 @@
"""Line break normalisation utilities.
Three modes:
- normalise_for_query: Strip line breaks, collapse multiple spaces to single.
Used when building search queries against TM.
- normalise_for_excel: Convert \\n to openpyxl-compatible line breaks.
Used when writing output cells.
- preserve_raw: Return text as-is (identity function for pipeline clarity).
"""
import re
def normalise_for_query(text: str) -> str:
"""Strip line breaks and collapse spaces for TM query matching.
Args:
text: Raw text potentially containing line breaks.
Returns:
Single-line text with normalised whitespace.
"""
if not text:
return ""
# Replace all line break variants with a space
result = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
# Collapse multiple spaces to one
result = re.sub(r"\s+", " ", result)
return result.strip()
def normalise_for_excel(text: str) -> str:
"""Convert line breaks to openpyxl-compatible format.
openpyxl uses \\n for in-cell line breaks when wrap_text is enabled.
This ensures consistent line break representation.
Args:
text: Text with potential line breaks.
Returns:
Text with standardised \\n line breaks.
"""
if not text:
return ""
# Normalise all line break variants to \\n
result = text.replace("\r\n", "\n").replace("\r", "\n")
return result
def preserve_raw(text: str) -> str:
"""Return text as-is (identity function).
Used in the pipeline to explicitly indicate no normalisation is applied.
Args:
text: Any text.
Returns:
The same text, unchanged.
"""
return text
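
A short usage sketch of the two normalising modes, re-implemented inline so it runs on its own. The query variant is condensed to a single `re.sub`, which should be equivalent because `\s` already covers `\r` and `\n`.

```python
import re


def normalise_for_query(text: str) -> str:
    """Collapse every whitespace run, including line breaks, to one space."""
    if not text:
        return ""
    return re.sub(r"\s+", " ", text).strip()


def normalise_for_excel(text: str) -> str:
    """Standardise \\r\\n and \\r to \\n for in-cell line breaks."""
    if not text:
        return ""
    return text.replace("\r\n", "\n").replace("\r", "\n")


query = normalise_for_query("Big deals\r\nthis  week\n")
cell = normalise_for_excel("Big deals\r\nthis week\rnow")
print(repr(query))
print(repr(cell))
```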


@ -0,0 +1,177 @@
"""Reference file loader.
Loads various reference files used in the transcreation pipeline:
- Glossary (JSON): locale-specific term glossary
- Blacklist (JSON): forbidden terms and roots
- Date/Percentage formats (JSON): approved format patterns
- Locale Considerations (JSON): locale-specific rules and notes
- TOV (Tone of Voice) files (JSON): global and supplementary voice profiles
"""
import json
from pathlib import Path
from typing import Any
class RefFileLoadError(Exception):
"""Raised when a reference file cannot be loaded or parsed."""
pass
def load_json_file(file_path: str) -> Any:
"""Load and parse a JSON file.
Args:
file_path: Absolute path to the JSON file.
Returns:
Parsed JSON data.
Raises:
RefFileLoadError: If file cannot be read or parsed.
"""
path = Path(file_path)
if not path.exists():
raise RefFileLoadError(f"Reference file not found: {file_path}")
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
    except json.JSONDecodeError as exc:
        raise RefFileLoadError(f"Invalid JSON in reference file: {exc}") from exc
    except UnicodeDecodeError as exc:
        raise RefFileLoadError(f"Encoding error reading reference file: {exc}") from exc
def load_glossary(file_path: str) -> list[dict[str, str]]:
"""Load a glossary file.
Expected format: list of dicts with keys like:
{"en": "source term", "tx": "translated term", "context": "usage notes"}
Args:
file_path: Path to glossary JSON file.
Returns:
List of glossary entry dicts.
"""
data = load_json_file(file_path)
if not isinstance(data, list):
raise RefFileLoadError("Glossary file must contain a JSON array")
return data
def load_blacklist(file_path: str) -> list[dict[str, str]]:
"""Load a blacklist file.
Expected format: list of dicts with keys:
{"term": "forbidden term", "root": "optional root", "reason": "why forbidden"}
Args:
file_path: Path to blacklist JSON file.
Returns:
List of blacklist entry dicts.
"""
data = load_json_file(file_path)
if not isinstance(data, list):
raise RefFileLoadError("Blacklist file must contain a JSON array")
return data
def load_date_pct_formats(file_path: str) -> dict[str, list[dict[str, str]]]:
"""Load date/percentage format rules.
Expected format:
{
"date_formats": [{"pattern": "...", "example": "...", "description": "..."}],
"percentage_formats": [{"pattern": "...", "example": "...", "description": "..."}]
}
Args:
file_path: Path to date/pct formats JSON file.
Returns:
Dict with "date_formats" and "percentage_formats" keys.
"""
data = load_json_file(file_path)
if not isinstance(data, dict):
raise RefFileLoadError("Date/pct format file must contain a JSON object")
return {
"date_formats": data.get("date_formats", []),
"percentage_formats": data.get("percentage_formats", []),
}
def load_locale_considerations(file_path: str) -> dict[str, Any]:
"""Load locale-specific considerations.
Expected format: JSON object with locale-specific rules, cultural notes, etc.
Args:
file_path: Path to locale considerations JSON file.
Returns:
Dict of locale considerations.
"""
data = load_json_file(file_path)
if not isinstance(data, dict):
raise RefFileLoadError(
"Locale considerations file must contain a JSON object"
)
return data
def load_tov(file_path: str) -> dict[str, Any]:
"""Load a Tone of Voice file (global or supplement).
Expected format: JSON object with voice profile data.
Args:
file_path: Path to TOV JSON file.
Returns:
Dict of TOV profile data.
"""
data = load_json_file(file_path)
if not isinstance(data, dict):
raise RefFileLoadError("TOV file must contain a JSON object")
return data
def load_all_reference_files(
file_manifest: dict[str, str | None],
) -> dict[str, Any]:
"""Load all reference files from a file manifest.
Args:
file_manifest: Dict mapping file types to file paths.
Keys: glossary_file, blacklist_file, tov_global_file,
tov_supplement_file, locale_considerations_file,
date_pct_formats_file
Returns:
Dict mapping file types to loaded data.
"""
result: dict[str, Any] = {}
loaders = {
"glossary_file": load_glossary,
"blacklist_file": load_blacklist,
"date_pct_formats_file": load_date_pct_formats,
"locale_considerations_file": load_locale_considerations,
"tov_global_file": load_tov,
"tov_supplement_file": load_tov,
}
for key, loader in loaders.items():
path = file_manifest.get(key)
if path:
try:
result[key] = loader(path)
except RefFileLoadError:
result[key] = None
else:
result[key] = None
return result


@ -0,0 +1,97 @@
"""Parse source xlsx files into structured source line data.
Validates the expected 5 column headers (case sensitive):
EN_GB, Copy Type, Creative Guidance, Visual Ref, Char Limit
Skips rows where EN_GB is empty. Detects \\n in EN_GB for is_display_format.
"""
from typing import Any
from openpyxl import load_workbook
REQUIRED_HEADERS = ["EN_GB", "Copy Type", "Creative Guidance", "Visual Ref", "Char Limit"]
class SourceFileParseError(Exception):
"""Raised when the source file has validation errors."""
pass
def parse_source_file(file_path: str) -> list[dict[str, Any]]:
"""Parse a source xlsx file and return a list of source line dicts.
Args:
file_path: Absolute path to the xlsx file.
Returns:
List of dicts with keys: en_gb, copy_type, creative_guidance,
visual_ref, char_limit, is_display_format.
Raises:
SourceFileParseError: If headers are invalid or file cannot be read.
"""
try:
wb = load_workbook(file_path, read_only=True, data_only=True)
    except Exception as exc:
        raise SourceFileParseError(f"Cannot open xlsx file: {exc}") from exc
ws = wb.active
if ws is None:
raise SourceFileParseError("Workbook has no active sheet")
# Read and validate headers from first row
rows = ws.iter_rows(min_row=1, max_row=1, values_only=True)
first_row = next(rows, None)
if first_row is None:
raise SourceFileParseError("File is empty - no header row found")
headers = [str(cell).strip() if cell else "" for cell in first_row]
# Validate all required headers exist (case sensitive)
for required in REQUIRED_HEADERS:
if required not in headers:
raise SourceFileParseError(
f"Missing required header '{required}'. "
f"Found headers: {headers}. "
f"Expected: {REQUIRED_HEADERS}"
)
# Build column index map
col_map = {header: idx for idx, header in enumerate(headers)}
# Parse data rows
source_lines: list[dict[str, Any]] = []
for row in ws.iter_rows(min_row=2, values_only=True):
en_gb_idx = col_map["EN_GB"]
en_gb_raw = row[en_gb_idx] if en_gb_idx < len(row) else None
# Skip rows where EN_GB is empty
if en_gb_raw is None or str(en_gb_raw).strip() == "":
continue
en_gb = str(en_gb_raw).strip()
# Detect display format: presence of \n in EN_GB text
is_display_format = "\n" in en_gb
def _get_cell(header: str) -> str | None:
idx = col_map.get(header)
if idx is None or idx >= len(row):
return None
val = row[idx]
if val is None:
return None
return str(val).strip() or None
source_lines.append({
"en_gb": en_gb,
"copy_type": _get_cell("Copy Type"),
"creative_guidance": _get_cell("Creative Guidance"),
"visual_ref": _get_cell("Visual Ref"),
"char_limit": _get_cell("Char Limit"),
"is_display_format": is_display_format,
})
wb.close()
return source_lines


@ -0,0 +1,133 @@
"""Translation Memory file loader.
Reads JSONL files in two formats:
1. Compact: {"t": "seg_key|date|en|lc|tx|nt|channel|sub_channel"}
2. Multi-field: {"seg_key": "...", "date": "...", "en": "...", ...}
Applies a locale hard-match gate: only entries matching the target locale are returned.
"""
import json
from typing import Any
from app.pipeline.contracts import TMEntry
class TMFileLoadError(Exception):
"""Raised when a TM file cannot be loaded or parsed."""
pass
def load_tm_file(
file_path: str,
target_locale: str,
) -> list[TMEntry]:
"""Load and parse a JSONL TM file, filtering by locale.
Args:
file_path: Absolute path to the JSONL file.
target_locale: Target locale code (e.g., "de_DE"). Only entries
matching this locale will be returned.
Returns:
List of TMEntry objects matching the target locale.
Raises:
TMFileLoadError: If the file cannot be read or parsed.
"""
entries: list[TMEntry] = []
try:
with open(file_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, start=1):
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
                except json.JSONDecodeError as exc:
                    raise TMFileLoadError(
                        f"Invalid JSON on line {line_num}: {exc}"
                    ) from exc
entry = _parse_entry(data, line_num)
if entry is None:
continue
# Locale hard-match gate
if entry.lc == target_locale:
entries.append(entry)
    except FileNotFoundError as exc:
        raise TMFileLoadError(f"TM file not found: {file_path}") from exc
    except UnicodeDecodeError as exc:
        raise TMFileLoadError(f"Encoding error reading TM file: {exc}") from exc
return entries
def _parse_entry(data: dict[str, Any], line_num: int) -> TMEntry | None:
"""Parse a single JSON object into a TMEntry.
Detects compact vs multi-field format automatically.
Args:
data: Parsed JSON dict.
line_num: Line number for error reporting.
Returns:
TMEntry or None if the entry is malformed.
"""
# Compact format: {"t": "seg_key|date|en|lc|tx|nt|channel|sub_channel"}
if "t" in data and isinstance(data["t"], str):
parts = data["t"].split("|")
if len(parts) < 5:
return None # Malformed compact entry
        # Pad the optional trailing fields (nt, channel, sub_channel) with "".
        parts += [""] * (8 - len(parts))
        return TMEntry(
            seg_key=parts[0],
            date=parts[1],
            en=parts[2],
            lc=parts[3],
            tx=parts[4],
            nt=parts[5],
            channel=parts[6],
            sub_channel=parts[7],
            _text=data["t"],
        )
# Multi-field format
if "seg_key" in data and "en" in data:
return TMEntry(
seg_key=str(data.get("seg_key", "")),
date=str(data.get("date", "")),
en=str(data.get("en", "")),
lc=str(data.get("lc", "")),
tx=str(data.get("tx", "")),
nt=str(data.get("nt", "")),
channel=str(data.get("channel", "")),
sub_channel=str(data.get("sub_channel", "")),
)
return None
def load_multiple_tm_files(
file_paths: list[str],
target_locale: str,
) -> list[TMEntry]:
"""Load and merge multiple TM files.
Args:
file_paths: List of file paths to load.
target_locale: Target locale code.
Returns:
Combined list of TMEntry objects from all files.
"""
all_entries: list[TMEntry] = []
for path in file_paths:
entries = load_tm_file(path, target_locale)
all_entries.extend(entries)
return all_entries
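
The compact format handled by `_parse_entry` can be illustrated with a small standalone sketch. `DemoEntry` and `parse_compact` are hypothetical names used only here (the project's `TMEntry` is not reproduced); the field order is taken from the comment in `_parse_entry`.

```python
from __future__ import annotations

from dataclasses import dataclass

# Field order copied from the compact-format comment in _parse_entry.
FIELDS = ("seg_key", "date", "en", "lc", "tx", "nt", "channel", "sub_channel")
REQUIRED = 5  # seg_key, date, en, lc and tx must all be present


@dataclass
class DemoEntry:
    """Hypothetical stand-in for TMEntry, used only in this sketch."""

    seg_key: str
    date: str
    en: str
    lc: str
    tx: str
    nt: str = ""
    channel: str = ""
    sub_channel: str = ""


def parse_compact(text: str) -> DemoEntry | None:
    """Split a pipe-delimited record; return None if required fields are missing."""
    parts = text.split("|")
    if len(parts) < REQUIRED:
        return None
    # Keep at most 8 fields and pad missing optional ones with "".
    padded = parts[: len(FIELDS)] + [""] * (len(FIELDS) - len(parts))
    return DemoEntry(**dict(zip(FIELDS, padded)))


entry = parse_compact("seg-001|2024-01-05|Buy now|de-DE|Jetzt kaufen")
too_short = parse_compact("seg-002|2024-01-05|Buy now")
```

Because the record is split on every `|`, a source or target string that itself contains a pipe would shift the later fields; the sketch does not attempt to handle that case.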


@ -0,0 +1,190 @@
"""Pipeline orchestrator - state machine managing the transcreation pipeline.
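States: INIT -> VALIDATE -> TM_RETRIEVE -> RANK -> TRANSCREATE -> COMPLY -> FORMAT -> DONE
Any agent failure moves the pipeline to ERROR and re-raises the exception.
Runs agents in sequence; the TRANSCREATE/COMPLY pair repeats until every
compliance check passes, up to MAX_COMPLIANCE_RETRIES (3) attempts.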
Progress callbacks enable WebSocket updates.
"""
import enum
import logging
from typing import Any, Callable
from app.pipeline.contracts import FileManifest, JobParams, PipelineContext
from app.pipeline.agents.agent_1_validator import Agent1Validator
from app.pipeline.agents.agent_2_tm_retrieval import Agent2TMRetrieval
from app.pipeline.agents.agent_3_ranker import Agent3Ranker
from app.pipeline.agents.agent_4_transcreator import Agent4Transcreator
from app.pipeline.agents.agent_5_compliance import Agent5Compliance
from app.pipeline.agents.agent_6_formatter import Agent6Formatter
logger = logging.getLogger(__name__)
MAX_COMPLIANCE_RETRIES = 3
class PipelineState(str, enum.Enum):
INIT = "INIT"
VALIDATE = "VALIDATE"
TM_RETRIEVE = "TM_RETRIEVE"
RANK = "RANK"
TRANSCREATE = "TRANSCREATE"
COMPLY = "COMPLY"
FORMAT = "FORMAT"
DONE = "DONE"
ERROR = "ERROR"
class PipelineOrchestrator:
"""Orchestrates the transcreation pipeline through its state machine.
Args:
job_params: Job parameters dict.
source_lines: Pre-parsed source lines (optional).
source_file_path: Path to source xlsx (optional, used if source_lines not given).
file_manifest: Dict of reference file paths.
output_dir: Directory for output files.
on_progress: Optional callback(state, message, pct) for progress updates.
"""
def __init__(
self,
job_params: dict[str, Any],
source_lines: list[dict] | None = None,
source_file_path: str | None = None,
file_manifest: dict[str, Any] | None = None,
output_dir: str | None = None,
on_progress: Callable[[str, str, float], None] | None = None,
) -> None:
self.job_params = job_params
self.source_lines = source_lines
self.source_file_path = source_file_path
self.file_manifest = file_manifest or {}
self.output_dir = output_dir
self.on_progress = on_progress
self.state = PipelineState.INIT
self.context: PipelineContext | None = None
def _emit_progress(self, message: str, pct: float) -> None:
"""Emit a progress callback if one is registered."""
if self.on_progress:
try:
self.on_progress(self.state.value, message, pct)
except Exception as e:
logger.warning(f"Progress callback error: {e}")
async def run(self) -> PipelineContext:
"""Run the full pipeline from INIT to DONE.
Returns:
The final PipelineContext with all results.
Raises:
Exception: If any agent fails.
"""
# Initialize context
self.context = PipelineContext(
job_params=JobParams(**self.job_params),
file_manifest=FileManifest(),
)
try:
# VALIDATE
self.state = PipelineState.VALIDATE
self._emit_progress("Validating inputs", 0.1)
logger.info(f"[{self.job_params.get('job_id')}] VALIDATE")
validator = Agent1Validator(
source_file_path=self.source_file_path,
source_lines=self.source_lines,
file_manifest=self.file_manifest,
job_params=self.job_params,
)
self.context = await validator.run(self.context)
# TM_RETRIEVE
self.state = PipelineState.TM_RETRIEVE
self._emit_progress("Retrieving TM matches", 0.25)
logger.info(f"[{self.job_params.get('job_id')}] TM_RETRIEVE")
tm_retriever = Agent2TMRetrieval()
self.context = await tm_retriever.run(self.context)
# RANK
self.state = PipelineState.RANK
self._emit_progress("Ranking matches", 0.4)
logger.info(f"[{self.job_params.get('job_id')}] RANK")
ranker = Agent3Ranker()
self.context = await ranker.run(self.context)
# TRANSCREATE + COMPLY loop
for attempt in range(1, MAX_COMPLIANCE_RETRIES + 1):
# TRANSCREATE
self.state = PipelineState.TRANSCREATE
self._emit_progress(
f"Generating transcreations (attempt {attempt})",
0.5 + (attempt - 1) * 0.1,
)
logger.info(
f"[{self.job_params.get('job_id')}] TRANSCREATE (attempt {attempt})"
)
transcreator = Agent4Transcreator()
self.context = await transcreator.run(self.context)
# COMPLY
self.state = PipelineState.COMPLY
self._emit_progress(
f"Checking compliance (attempt {attempt})",
0.6 + (attempt - 1) * 0.1,
)
logger.info(
f"[{self.job_params.get('job_id')}] COMPLY (attempt {attempt})"
)
compliance = Agent5Compliance()
self.context = await compliance.run(self.context)
# Check if all passed
all_passed = all(
cr.passed for cr in self.context.compliance_results
)
if all_passed:
logger.info(
f"[{self.job_params.get('job_id')}] All compliance checks passed"
)
break
else:
failed_count = sum(
1 for cr in self.context.compliance_results if not cr.passed
)
logger.warning(
f"[{self.job_params.get('job_id')}] "
f"{failed_count} compliance failures, "
f"attempt {attempt}/{MAX_COMPLIANCE_RETRIES}"
)
# FORMAT
self.state = PipelineState.FORMAT
self._emit_progress("Generating output file", 0.9)
logger.info(f"[{self.job_params.get('job_id')}] FORMAT")
formatter = Agent6Formatter(output_dir=self.output_dir)
self.context = await formatter.run(self.context)
# DONE
self.state = PipelineState.DONE
self._emit_progress("Complete", 1.0)
logger.info(f"[{self.job_params.get('job_id')}] DONE")
return self.context
except Exception as e:
self.state = PipelineState.ERROR
self._emit_progress(f"Error: {str(e)}", -1.0)
logger.error(
f"[{self.job_params.get('job_id')}] Pipeline error: {e}",
exc_info=True,
)
raise
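
The TRANSCREATE + COMPLY loop in `run()` is a bounded retry: draft, check, and stop early once the check passes. A minimal sketch of that control flow follows; `draft_until_compliant`, `draft` and `check` are hypothetical names standing in for the agents, not part of the orchestrator.

```python
import asyncio

MAX_RETRIES = 3  # mirrors MAX_COMPLIANCE_RETRIES in the orchestrator


async def draft_until_compliant(draft, check, max_retries: int = MAX_RETRIES):
    """Run draft -> check up to max_retries times.

    Returns (attempts_used, passed). When every attempt fails, the caller
    still gets control back, much as the orchestrator moves on to FORMAT.
    """
    for attempt in range(1, max_retries + 1):
        text = await draft(attempt)
        if await check(text):
            return attempt, True
    return max_retries, False


async def _demo():
    async def draft(attempt: int) -> str:
        return f"draft v{attempt}"

    async def check(text: str) -> bool:
        # Pretend only the second draft is compliant.
        return text.endswith("v2")

    return await draft_until_compliant(draft, check)


result = asyncio.run(_demo())
```

In this sketch a run that never passes returns `passed=False` instead of raising, which leaves the decision about shipping non-compliant output to the caller.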


@ -0,0 +1,29 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
class ClientCreate(BaseModel):
name: str
settings: dict[str, Any] | None = None
model_config = {"from_attributes": True}
class ClientUpdate(BaseModel):
name: str | None = None
settings: dict[str, Any] | None = None
model_config = {"from_attributes": True}
class ClientResponse(BaseModel):
id: UUID
name: str
settings: dict[str, Any] | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}


@ -0,0 +1,19 @@
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
T = TypeVar("T")
class PaginatedResponse(BaseModel, Generic[T]):
items: list[T]
total: int
page: int
page_size: int
pages: int
class ErrorResponse(BaseModel):
detail: str
code: str | None = None
errors: list[dict[str, Any]] | None = None
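
`PaginatedResponse` carries both `total` and `pages`, so something has to derive the page count. One way to do it is sketched below; `page_meta` is a hypothetical helper, and the `offset` key is only there to show the `(page - 1) * page_size` arithmetic the services use, it is not a field of the schema.

```python
def page_meta(total: int, page: int, page_size: int) -> dict[str, int]:
    """Compute pagination metadata for a 1-indexed page."""
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be >= 1")
    # Ceiling division without floats: 101 items at 50 per page -> 3 pages.
    pages = (total + page_size - 1) // page_size
    return {
        "total": total,
        "page": page,
        "page_size": page_size,
        "pages": pages,
        "offset": (page - 1) * page_size,
    }


meta = page_meta(total=101, page=3, page_size=50)
empty = page_meta(total=0, page=1, page_size=20)
```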


@ -0,0 +1,27 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from app.models.feedback import FlagType
class FeedbackCreate(BaseModel):
output_id: UUID
option_column: int
flag_type: FlagType
comment: str | None = None
model_config = {"from_attributes": True}
class FeedbackResponse(BaseModel):
id: UUID
output_id: UUID
user_id: UUID
option_column: int
flag_type: FlagType
comment: str | None = None
created_at: datetime
model_config = {"from_attributes": True}


@ -0,0 +1,46 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from app.models.files import ReferenceFileType
class TMFileResponse(BaseModel):
id: UUID
client_id: UUID
locale_code: str
channel: str
filename: str
file_path: str
segment_count: int
uploaded_by: UUID | None = None
uploaded_at: datetime
last_updated_at: datetime
version: int
model_config = {"from_attributes": True}
class ReferenceFileResponse(BaseModel):
id: UUID
client_id: UUID
file_type: ReferenceFileType
locale_scope: str
filename: str
file_path: str
uploaded_by: UUID | None = None
uploaded_at: datetime
last_updated_at: datetime
version: int
model_config = {"from_attributes": True}
class FileUploadResponse(BaseModel):
id: UUID
filename: str
file_path: str
message: str
model_config = {"from_attributes": True}


@ -0,0 +1,94 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from app.models.job import JobStatus, JobType, LocaleStatus, LocaleType, Programme
class JobCreate(BaseModel):
client_id: UUID
campaign_name: str
programme: Programme
channel: str
sub_channel: str | None = None
context_prompt: str | None = None
job_type: JobType = JobType.main
parent_job_id: UUID | None = None
job_ref: str | None = None
locale_codes: list[str] = []
model_config = {"from_attributes": True}
class JobUpdate(BaseModel):
campaign_name: str | None = None
programme: Programme | None = None
channel: str | None = None
sub_channel: str | None = None
context_prompt: str | None = None
job_ref: str | None = None
model_config = {"from_attributes": True}
class LocaleInstanceResponse(BaseModel):
id: UUID
job_id: UUID
locale_code: str
locale_type: LocaleType
status: LocaleStatus
started_at: datetime | None = None
completed_at: datetime | None = None
token_usage: int = 0
estimated_cost: float = 0.0
output_file_path: str | None = None
error_log: str | None = None
tm_files_loaded: dict[str, Any] | None = None
ref_files_loaded: dict[str, Any] | None = None
agent_version: str | None = None
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class JobResponse(BaseModel):
id: UUID
client_id: UUID
created_by: UUID
job_ref: str | None = None
campaign_name: str
programme: Programme
channel: str
sub_channel: str | None = None
context_prompt: str | None = None
job_type: JobType
parent_job_id: UUID | None = None
status: JobStatus
total_token_usage: int = 0
total_estimated_cost: float = 0.0
locale_instances: list[LocaleInstanceResponse] = []
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class JobListResponse(BaseModel):
id: UUID
client_id: UUID
created_by: UUID
job_ref: str | None = None
campaign_name: str
programme: Programme
channel: str
status: JobStatus
total_token_usage: int = 0
total_estimated_cost: float = 0.0
locale_count: int = 0
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}


@ -0,0 +1,51 @@
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from app.models.output import ConfidenceTier
class OutputRowResponse(BaseModel):
id: UUID
instance_id: UUID
line_id: UUID
row_order: int
confidence_tier: ConfidenceTier
option_1: str
backtranslation_1: str
rationale_1: str
option_2: str | None = None
backtranslation_2: str | None = None
rationale_2: str | None = None
option_3: str | None = None
backtranslation_3: str | None = None
rationale_3: str | None = None
tm_entries_cited: dict[str, Any] | None = None
winning_seg_key: str | None = None
character_count_option_1: int | None = None
character_count_option_2: int | None = None
character_count_option_3: int | None = None
model_config = {"from_attributes": True}
class SourceLinePreview(BaseModel):
id: UUID
row_order: int
en_gb: str
copy_type: str | None = None
creative_guidance: str | None = None
char_limit: str | None = None
model_config = {"from_attributes": True}
class OutputPreviewResponse(BaseModel):
locale_code: str
instance_id: UUID
source_lines: list[SourceLinePreview]
output_rows: list[OutputRowResponse]
total_rows: int
model_config = {"from_attributes": True}


@ -0,0 +1,38 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, EmailStr
from app.models.user import UserRole, UserStatus
class UserCreate(BaseModel):
email: EmailStr
name: str
password: str
role: UserRole = UserRole.reviewer
client_ids: list[UUID] = []
model_config = {"from_attributes": True}
class UserUpdate(BaseModel):
email: EmailStr | None = None
name: str | None = None
password: str | None = None
role: UserRole | None = None
status: UserStatus | None = None
model_config = {"from_attributes": True}
class UserResponse(BaseModel):
id: UUID
email: str
name: str
role: UserRole
status: UserStatus
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}


@ -0,0 +1,77 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.audit import AuditLog
class AuditService:
"""Service for audit log creation and retrieval."""
async def log(
self,
db: AsyncSession,
action: str,
entity_type: str,
entity_id: str,
user_id: UUID | None = None,
details: dict[str, Any] | None = None,
ip_address: str | None = None,
) -> AuditLog:
"""Create an audit log entry."""
entry = AuditLog(
user_id=user_id,
action=action,
entity_type=entity_type,
entity_id=entity_id,
details=details,
ip_address=ip_address,
)
db.add(entry)
await db.flush()
return entry
async def list_logs(
self,
db: AsyncSession,
user_id: UUID | None = None,
action: str | None = None,
entity_type: str | None = None,
entity_id: str | None = None,
date_from: datetime | None = None,
date_to: datetime | None = None,
page: int = 1,
page_size: int = 50,
) -> tuple[list[AuditLog], int]:
"""List audit logs with filters and pagination."""
query = select(AuditLog)
if user_id:
query = query.where(AuditLog.user_id == user_id)
if action:
query = query.where(AuditLog.action == action)
if entity_type:
query = query.where(AuditLog.entity_type == entity_type)
if entity_id:
query = query.where(AuditLog.entity_id == entity_id)
if date_from:
query = query.where(AuditLog.timestamp >= date_from)
if date_to:
query = query.where(AuditLog.timestamp <= date_to)
# Count
count_query = select(func.count()).select_from(query.subquery())
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Paginate
query = query.order_by(AuditLog.timestamp.desc())
query = query.offset((page - 1) * page_size).limit(page_size)
result = await db.execute(query)
logs = list(result.scalars().all())
return logs, total


@ -0,0 +1,66 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.feedback import Feedback
from app.schemas.feedback import FeedbackCreate
class FeedbackService:
"""Service for feedback CRUD operations."""
async def create_feedback(
self,
db: AsyncSession,
data: FeedbackCreate,
user_id: UUID,
) -> Feedback:
"""Create a new feedback entry."""
feedback = Feedback(
output_id=data.output_id,
user_id=user_id,
option_column=data.option_column,
flag_type=data.flag_type,
comment=data.comment,
)
db.add(feedback)
await db.flush()
return feedback
async def list_feedback(
self,
db: AsyncSession,
output_id: UUID | None = None,
user_id: UUID | None = None,
) -> list[Feedback]:
"""List feedback entries with optional filters."""
query = select(Feedback)
if output_id:
query = query.where(Feedback.output_id == output_id)
if user_id:
query = query.where(Feedback.user_id == user_id)
query = query.order_by(Feedback.created_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
async def get_feedback(
self, db: AsyncSession, feedback_id: UUID
) -> Feedback | None:
"""Get a single feedback entry by ID."""
result = await db.execute(
select(Feedback).where(Feedback.id == feedback_id)
)
return result.scalar_one_or_none()
async def delete_feedback(
self, db: AsyncSession, feedback_id: UUID
) -> bool:
"""Delete a feedback entry."""
feedback = await self.get_feedback(db, feedback_id)
if feedback is None:
return False
await db.delete(feedback)
await db.flush()
return True


@ -0,0 +1,234 @@
import os
import shutil
from pathlib import Path
from typing import BinaryIO
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.files import ReferenceFile, ReferenceFileType, TMFileRegistry
from app.models.source import SourceLine
from app.pipeline.modules.source_file_parser import parse_source_file
class FileService:
"""Service for file upload, download, path resolution, and storage management."""
def __init__(self) -> None:
self.storage_root = Path(settings.STORAGE_ROOT)
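    def _resolve_path(self, *parts: str) -> Path:
        """Resolve a storage path and ensure parent directories exist.

        Raises ValueError if the result escapes the storage root, e.g. when a
        user-supplied filename contains "..".
        """
        root = self.storage_root.resolve()
        path = root.joinpath(*parts).resolve()
        if not path.is_relative_to(root):
            raise ValueError(f"Path escapes storage root: {path}")
        path.parent.mkdir(parents=True, exist_ok=True)
        return path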
async def upload_source_file(
self,
db: AsyncSession,
job_id: UUID,
file: BinaryIO,
filename: str,
) -> list[SourceLine]:
"""Upload and parse a source xlsx file, creating SourceLine records."""
# Save to storage
file_path = self._resolve_path("jobs", str(job_id), "source", filename)
with open(file_path, "wb") as f:
shutil.copyfileobj(file, f)
# Parse the xlsx
parsed_lines = parse_source_file(str(file_path))
# Delete existing source lines for this job
existing = await db.execute(
select(SourceLine).where(SourceLine.job_id == job_id)
)
for line in existing.scalars().all():
await db.delete(line)
# Create new source lines
source_lines = []
for i, row in enumerate(parsed_lines):
source_line = SourceLine(
job_id=job_id,
row_order=i + 1,
en_gb=row["en_gb"],
copy_type=row.get("copy_type"),
creative_guidance=row.get("creative_guidance"),
visual_ref=row.get("visual_ref"),
char_limit=row.get("char_limit"),
is_display_format=row.get("is_display_format", False),
)
db.add(source_line)
source_lines.append(source_line)
await db.flush()
return source_lines
async def upload_supplementary_file(
self,
db: AsyncSession,
job_id: UUID,
file: BinaryIO,
filename: str,
) -> str:
"""Upload a supplementary file (TM, glossary, etc.) for a job."""
file_path = self._resolve_path("jobs", str(job_id), "supplementary", filename)
with open(file_path, "wb") as f:
shutil.copyfileobj(file, f)
return str(file_path)
async def upload_tm_file(
self,
db: AsyncSession,
client_id: UUID,
locale_code: str,
channel: str,
file: BinaryIO,
filename: str,
uploaded_by: UUID | None = None,
) -> TMFileRegistry:
"""Upload a TM file and create a registry entry."""
file_path = self._resolve_path(
"clients", str(client_id), "tm", locale_code, filename
)
with open(file_path, "wb") as f:
shutil.copyfileobj(file, f)
        # Count segments (non-blank lines in the JSONL file)
        with open(file_path, "r", encoding="utf-8") as f:
            segment_count = sum(1 for line in f if line.strip())
tm_file = TMFileRegistry(
client_id=client_id,
locale_code=locale_code,
channel=channel,
filename=filename,
file_path=str(file_path),
segment_count=segment_count,
uploaded_by=uploaded_by,
)
db.add(tm_file)
await db.flush()
return tm_file
async def upload_reference_file(
self,
db: AsyncSession,
client_id: UUID,
file_type: ReferenceFileType,
locale_scope: str,
file: BinaryIO,
filename: str,
uploaded_by: UUID | None = None,
) -> ReferenceFile:
"""Upload a reference file and create a registry entry."""
file_path = self._resolve_path(
"clients", str(client_id), "reference", file_type.value, filename
)
with open(file_path, "wb") as f:
shutil.copyfileobj(file, f)
ref_file = ReferenceFile(
client_id=client_id,
file_type=file_type,
locale_scope=locale_scope,
filename=filename,
file_path=str(file_path),
uploaded_by=uploaded_by,
)
db.add(ref_file)
await db.flush()
return ref_file
async def list_tm_files(
self,
db: AsyncSession,
client_id: UUID,
locale_code: str | None = None,
channel: str | None = None,
) -> list[TMFileRegistry]:
"""List TM files for a client with optional filters."""
query = select(TMFileRegistry).where(TMFileRegistry.client_id == client_id)
if locale_code:
query = query.where(TMFileRegistry.locale_code == locale_code)
if channel:
query = query.where(TMFileRegistry.channel == channel)
result = await db.execute(query.order_by(TMFileRegistry.uploaded_at.desc()))
return list(result.scalars().all())
async def list_reference_files(
self,
db: AsyncSession,
client_id: UUID,
file_type: ReferenceFileType | None = None,
locale_scope: str | None = None,
) -> list[ReferenceFile]:
"""List reference files for a client with optional filters."""
query = select(ReferenceFile).where(ReferenceFile.client_id == client_id)
if file_type:
query = query.where(ReferenceFile.file_type == file_type)
if locale_scope:
query = query.where(ReferenceFile.locale_scope == locale_scope)
result = await db.execute(query.order_by(ReferenceFile.uploaded_at.desc()))
return list(result.scalars().all())
def get_file_path(self, stored_path: str) -> Path | None:
"""Resolve a stored file path and verify it exists."""
path = Path(stored_path)
if path.exists():
return path
return None
async def delete_tm_file(
self, db: AsyncSession, file_id: UUID
) -> bool:
"""Delete a TM file from storage and database."""
result = await db.execute(
select(TMFileRegistry).where(TMFileRegistry.id == file_id)
)
tm_file = result.scalar_one_or_none()
if tm_file is None:
return False
# Remove from filesystem
file_path = Path(tm_file.file_path)
if file_path.exists():
os.remove(file_path)
await db.delete(tm_file)
await db.flush()
return True
async def delete_reference_file(
self, db: AsyncSession, file_id: UUID
) -> bool:
"""Delete a reference file from storage and database."""
result = await db.execute(
select(ReferenceFile).where(ReferenceFile.id == file_id)
)
ref_file = result.scalar_one_or_none()
if ref_file is None:
return False
file_path = Path(ref_file.file_path)
if file_path.exists():
os.remove(file_path)
await db.delete(ref_file)
await db.flush()
return True
def validate_file_extension(
self, filename: str, allowed_extensions: list[str]
) -> bool:
"""Validate that a file has an allowed extension."""
ext = Path(filename).suffix.lower()
return ext in allowed_extensions
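
The upload methods build storage paths from a caller-supplied `filename`, so a containment check is worth illustrating. The sketch below is one way to do it with `pathlib`; `resolve_inside` and `demo_root` are hypothetical names, and a temporary directory stands in for `STORAGE_ROOT`.

```python
import tempfile
from pathlib import Path


def resolve_inside(root: Path, *parts: str) -> Path:
    """Join parts under root and reject any result that escapes root."""
    base = root.resolve()
    candidate = base.joinpath(*parts).resolve()
    if not candidate.is_relative_to(base):
        raise ValueError(f"Path escapes storage root: {candidate}")
    return candidate


# A throwaway directory plays the role of the storage root.
demo_root = Path(tempfile.mkdtemp())
safe = resolve_inside(demo_root, "jobs", "123", "source", "brief.xlsx")

try:
    resolve_inside(demo_root, "jobs", "123", "source", "../../../../etc/passwd")
    escaped = True
except ValueError:
    escaped = False
```

Resolving both the root and the candidate before comparing them means symlinked roots and `..` segments are treated the same way. `Path.is_relative_to` requires Python 3.9 or later.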


@ -0,0 +1,202 @@
from datetime import datetime
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus, LocaleType
from app.models.source import SourceLine
from app.schemas.job import JobCreate, JobUpdate
class JobService:
"""Service for Job CRUD and launch logic."""
async def create_job(
self, db: AsyncSession, data: JobCreate, user_id: UUID
) -> Job:
"""Create a new job with locale instances."""
job = Job(
client_id=data.client_id,
created_by=user_id,
campaign_name=data.campaign_name,
programme=data.programme,
channel=data.channel,
sub_channel=data.sub_channel,
context_prompt=data.context_prompt,
job_type=data.job_type,
parent_job_id=data.parent_job_id,
job_ref=data.job_ref,
status=JobStatus.created,
)
db.add(job)
await db.flush()
# Create locale instances
for i, locale_code in enumerate(data.locale_codes):
locale_type = LocaleType.main if i == 0 else LocaleType.derived
instance = LocaleInstance(
job_id=job.id,
locale_code=locale_code,
locale_type=locale_type,
status=LocaleStatus.queued,
)
db.add(instance)
await db.flush()
return job
async def get_job(self, db: AsyncSession, job_id: UUID) -> Job | None:
"""Get a job by ID with locale instances."""
result = await db.execute(
select(Job)
.options(selectinload(Job.locale_instances))
.where(Job.id == job_id)
)
return result.scalar_one_or_none()
async def list_jobs(
self,
db: AsyncSession,
client_id: UUID | None = None,
status: str | None = None,
locale: str | None = None,
date_from: datetime | None = None,
date_to: datetime | None = None,
search: str | None = None,
page: int = 1,
page_size: int = 20,
) -> tuple[list[Job], int]:
"""List jobs with filters and pagination."""
query = select(Job).options(selectinload(Job.locale_instances))
if client_id:
query = query.where(Job.client_id == client_id)
if status:
query = query.where(Job.status == status)
if locale:
query = query.join(LocaleInstance).where(
LocaleInstance.locale_code == locale
)
if date_from:
query = query.where(Job.created_at >= date_from)
if date_to:
query = query.where(Job.created_at <= date_to)
if search:
query = query.where(
Job.campaign_name.ilike(f"%{search}%")
| Job.job_ref.ilike(f"%{search}%")
)
# Count total
count_query = select(func.count()).select_from(query.subquery())
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# Apply pagination
query = query.order_by(Job.created_at.desc())
query = query.offset((page - 1) * page_size).limit(page_size)
result = await db.execute(query)
jobs = list(result.scalars().unique().all())
return jobs, total
async def update_job(
self, db: AsyncSession, job_id: UUID, data: JobUpdate
) -> Job | None:
"""Update a job."""
job = await self.get_job(db, job_id)
if job is None:
return None
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(job, field, value)
await db.flush()
return job
async def launch_job(self, db: AsyncSession, job_id: UUID) -> Job | None:
"""Validate and queue a job for processing."""
job = await self.get_job(db, job_id)
if job is None:
return None
if job.status != JobStatus.created:
raise ValueError(
f"Job cannot be launched from status '{job.status.value}'"
)
# Check source lines exist
result = await db.execute(
select(func.count()).where(SourceLine.job_id == job_id)
)
source_count = result.scalar() or 0
if source_count == 0:
raise ValueError("Job has no source lines. Upload source file first.")
# Check locale instances exist
if not job.locale_instances:
raise ValueError("Job has no locale instances configured.")
# Update status
job.status = JobStatus.queued
await db.flush()
# Dispatch Celery tasks
from app.tasks.job_tasks import process_job
process_job.delay(str(job_id))
return job
async def cancel_job(self, db: AsyncSession, job_id: UUID) -> Job | None:
        """Cancel a queued or running job."""
job = await self.get_job(db, job_id)
if job is None:
return None
if job.status not in (JobStatus.queued, JobStatus.running):
raise ValueError(
f"Job cannot be cancelled from status '{job.status.value}'"
)
job.status = JobStatus.error
for instance in job.locale_instances:
if instance.status in (LocaleStatus.queued, LocaleStatus.running):
instance.status = LocaleStatus.error
instance.error_log = "Cancelled by user"
await db.flush()
return job
async def rerun_locale(
self, db: AsyncSession, job_id: UUID, locale_code: str
) -> LocaleInstance | None:
"""Re-run a single locale instance."""
result = await db.execute(
select(LocaleInstance).where(
LocaleInstance.job_id == job_id,
LocaleInstance.locale_code == locale_code,
)
)
instance = result.scalar_one_or_none()
if instance is None:
return None
instance.status = LocaleStatus.queued
instance.error_log = None
instance.started_at = None
instance.completed_at = None
instance.token_usage = 0
instance.estimated_cost = 0.0
await db.flush()
from app.tasks.job_tasks import process_locale_instance
process_locale_instance.delay(str(job_id), locale_code)
return instance
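
`launch_job` and `cancel_job` each guard on the current status before changing it. Those guards can be expressed as a small transition table, as sketched below. `DemoJobStatus`, `ALLOWED_FROM` and `ensure_allowed` are hypothetical; only the four statuses visible in `JobService` are listed, and the real `JobStatus` enum may define more.

```python
import enum


class DemoJobStatus(str, enum.Enum):
    """Subset of statuses seen in JobService; the real enum may have more."""

    created = "created"
    queued = "queued"
    running = "running"
    error = "error"


# Which statuses each action may start from, per the checks in JobService.
ALLOWED_FROM = {
    "launch": {DemoJobStatus.created},
    "cancel": {DemoJobStatus.queued, DemoJobStatus.running},
}


def ensure_allowed(action: str, status: DemoJobStatus) -> None:
    """Raise ValueError when the action is not valid from the given status."""
    if status not in ALLOWED_FROM[action]:
        raise ValueError(
            f"Action '{action}' is not allowed from status '{status.value}'"
        )


ensure_allowed("launch", DemoJobStatus.created)  # passes silently

try:
    ensure_allowed("cancel", DemoJobStatus.created)
    cancel_rejected = False
except ValueError:
    cancel_rejected = True
```

Keeping the rules in one table makes it easier to see at a glance which transitions exist, at the cost of one more indirection than the inline `if` checks.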


@ -0,0 +1,104 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.job import LocaleInstance
from app.models.output import OutputRow
from app.models.source import SourceLine
from app.schemas.output import (
OutputPreviewResponse,
OutputRowResponse,
SourceLinePreview,
)
class OutputService:
"""Service for assembling output preview data and triggering exports."""
async def get_preview(
self,
db: AsyncSession,
job_id: UUID,
locale_code: str,
) -> OutputPreviewResponse | None:
"""Assemble output preview data for a specific locale instance."""
# Get the locale instance
result = await db.execute(
select(LocaleInstance)
.where(
LocaleInstance.job_id == job_id,
LocaleInstance.locale_code == locale_code,
)
)
instance = result.scalar_one_or_none()
if instance is None:
return None
# Get source lines
source_result = await db.execute(
select(SourceLine)
.where(SourceLine.job_id == job_id)
.order_by(SourceLine.row_order)
)
source_lines = [
SourceLinePreview.model_validate(sl)
for sl in source_result.scalars().all()
]
# Get output rows
output_result = await db.execute(
select(OutputRow)
.where(OutputRow.instance_id == instance.id)
.order_by(OutputRow.row_order)
)
output_rows = [
OutputRowResponse.model_validate(row)
for row in output_result.scalars().all()
]
return OutputPreviewResponse(
locale_code=locale_code,
instance_id=instance.id,
source_lines=source_lines,
output_rows=output_rows,
total_rows=len(output_rows),
)
async def get_output_rows(
self,
db: AsyncSession,
instance_id: UUID,
) -> list[OutputRow]:
"""Get all output rows for a locale instance."""
result = await db.execute(
select(OutputRow)
.where(OutputRow.instance_id == instance_id)
.order_by(OutputRow.row_order)
)
return list(result.scalars().all())
async def trigger_export(
self,
db: AsyncSession,
job_id: UUID,
locale_code: str,
) -> str | None:
"""Trigger export generation for a locale and return the file path."""
result = await db.execute(
select(LocaleInstance)
.where(
LocaleInstance.job_id == job_id,
LocaleInstance.locale_code == locale_code,
)
)
instance = result.scalar_one_or_none()
if instance is None:
return None
if instance.output_file_path:
return instance.output_file_path
# Export would be triggered here; for now return None indicating no export yet
return None


@ -0,0 +1,133 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.audit import TokenUsageLog
from app.models.feedback import Feedback
from app.models.job import Job, LocaleInstance
from app.models.output import OutputRow
class ReportService:
"""Service for aggregation queries powering reports."""
async def get_usage_stats(
self,
db: AsyncSession,
client_id: UUID | None = None,
date_from: datetime | None = None,
date_to: datetime | None = None,
) -> dict[str, Any]:
"""Get overall usage statistics."""
job_query = select(
func.count(Job.id).label("total_jobs"),
func.sum(Job.total_token_usage).label("total_tokens"),
func.sum(Job.total_estimated_cost).label("total_cost"),
)
if client_id:
job_query = job_query.where(Job.client_id == client_id)
if date_from:
job_query = job_query.where(Job.created_at >= date_from)
if date_to:
job_query = job_query.where(Job.created_at <= date_to)
result = await db.execute(job_query)
row = result.one()
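        # Status breakdown (same filters as the totals above)
        status_query = select(
            Job.status, func.count(Job.id)
        ).group_by(Job.status)
        if client_id:
            status_query = status_query.where(Job.client_id == client_id)
        if date_from:
            status_query = status_query.where(Job.created_at >= date_from)
        if date_to:
            status_query = status_query.where(Job.created_at <= date_to)
        status_result = await db.execute(status_query)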
status_breakdown = {
status.value: count for status, count in status_result.all()
}
return {
"total_jobs": row.total_jobs or 0,
"total_tokens": row.total_tokens or 0,
"total_cost": float(row.total_cost or 0.0),
"status_breakdown": status_breakdown,
}
async def get_token_cost_data(
self,
db: AsyncSession,
client_id: UUID | None = None,
date_from: datetime | None = None,
date_to: datetime | None = None,
) -> list[dict[str, Any]]:
"""Get token usage and cost data grouped by date."""
query = (
select(
func.date_trunc("day", TokenUsageLog.timestamp).label("date"),
func.sum(TokenUsageLog.input_tokens).label("input_tokens"),
func.sum(TokenUsageLog.output_tokens).label("output_tokens"),
func.sum(TokenUsageLog.total_tokens).label("total_tokens"),
func.sum(TokenUsageLog.estimated_cost_usd).label("total_cost"),
)
.group_by(func.date_trunc("day", TokenUsageLog.timestamp))
.order_by(func.date_trunc("day", TokenUsageLog.timestamp))
)
if client_id:
query = query.join(LocaleInstance).join(Job).where(
Job.client_id == client_id
)
if date_from:
query = query.where(TokenUsageLog.timestamp >= date_from)
if date_to:
query = query.where(TokenUsageLog.timestamp <= date_to)
result = await db.execute(query)
return [
{
"date": str(row.date),
"input_tokens": row.input_tokens or 0,
"output_tokens": row.output_tokens or 0,
"total_tokens": row.total_tokens or 0,
"total_cost": float(row.total_cost or 0.0),
}
for row in result.all()
]
async def get_quality_metrics(
self,
db: AsyncSession,
client_id: UUID | None = None,
) -> dict[str, Any]:
"""Get quality metrics from output confidence tiers and feedback."""
# Confidence tier distribution
tier_query = select(
OutputRow.confidence_tier, func.count(OutputRow.id)
).group_by(OutputRow.confidence_tier)
if client_id:
tier_query = tier_query.join(LocaleInstance).join(Job).where(
Job.client_id == client_id
)
tier_result = await db.execute(tier_query)
tier_breakdown = {
tier.value: count for tier, count in tier_result.all()
}
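        # Feedback distribution (scoped to the client when one is given)
        feedback_query = select(
            Feedback.flag_type, func.count(Feedback.id)
        ).group_by(Feedback.flag_type)
        if client_id:
            feedback_query = (
                feedback_query
                .join(OutputRow, Feedback.output_id == OutputRow.id)
                .join(LocaleInstance, OutputRow.instance_id == LocaleInstance.id)
                .join(Job, LocaleInstance.job_id == Job.id)
                .where(Job.client_id == client_id)
            )
        feedback_result = await db.execute(feedback_query)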
feedback_breakdown = {
ft.value: count for ft, count in feedback_result.all()
}
return {
"confidence_tiers": tier_breakdown,
"feedback_distribution": feedback_breakdown,
}


@ -0,0 +1,26 @@
"""Celery application configuration."""

from celery import Celery

from app.config import settings

celery_app = Celery(
    "transcreation",
    broker=settings.REDIS_URL,
    backend=settings.REDIS_URL,
    include=["app.tasks.job_tasks"],
)

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    task_track_started=True,
    task_acks_late=True,
    worker_prefetch_multiplier=1,
    result_expires=3600,
    task_soft_time_limit=1800,  # 30 minutes
    task_time_limit=3600,  # 1 hour
)


@ -0,0 +1,258 @@
"""Celery tasks for job processing."""

import asyncio
import logging
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.config import settings
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
from app.models.output import ConfidenceTier, OutputRow
from app.models.source import SourceLine
from app.pipeline.orchestrator import PipelineOrchestrator
from app.tasks.celery_app import celery_app

logger = logging.getLogger(__name__)


def _get_async_session_factory() -> async_sessionmaker[AsyncSession]:
    """Create a fresh async session factory for use in Celery tasks."""
    # NOTE: each call builds a new engine (and connection pool); nothing here disposes it.
    engine = create_async_engine(settings.DATABASE_URL, pool_pre_ping=True)
    return async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)


@celery_app.task(bind=True, max_retries=2)
def process_job(self, job_id: str) -> dict:
    """Fan out locale instance processing for a job.

    For each locale instance in the job, dispatches a process_locale_instance task.
    """
    logger.info(f"Processing job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get job
            result = await db.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update status to running
            job.status = JobStatus.running
            await db.commit()

            # Get locale instances
            result = await db.execute(
                select(LocaleInstance).where(LocaleInstance.job_id == job_id)
            )
            instances = result.scalars().all()

            # Dispatch per-locale tasks
            task_ids = []
            for instance in instances:
                task = process_locale_instance.delay(job_id, instance.locale_code)
                task_ids.append(task.id)

            return {
                "job_id": job_id,
                "dispatched_locales": len(task_ids),
                "task_ids": task_ids,
            }

    return asyncio.run(_process())


@celery_app.task(bind=True, max_retries=1)
def process_locale_instance(self, job_id: str, locale_code: str) -> dict:
    """Process a single locale instance through the pipeline.

    Runs the pipeline orchestrator for the given job + locale combination.
    """
    logger.info(f"Processing locale {locale_code} for job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get the locale instance
            result = await db.execute(
                select(LocaleInstance).where(
                    LocaleInstance.job_id == job_id,
                    LocaleInstance.locale_code == locale_code,
                )
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return {"error": f"Locale instance not found: {job_id}/{locale_code}"}

            # Get job
            job_result = await db.execute(select(Job).where(Job.id == job_id))
            job = job_result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update instance status
            instance.status = LocaleStatus.running
            instance.started_at = datetime.now(timezone.utc)
            await db.commit()

            try:
                # Get source lines
                source_result = await db.execute(
                    select(SourceLine)
                    .where(SourceLine.job_id == job_id)
                    .order_by(SourceLine.row_order)
                )
                source_lines = [
                    {
                        "id": str(sl.id),
                        "en_gb": sl.en_gb,
                        "copy_type": sl.copy_type,
                        "creative_guidance": sl.creative_guidance,
                        "visual_ref": sl.visual_ref,
                        "char_limit": sl.char_limit,
                        "is_display_format": sl.is_display_format,
                    }
                    for sl in source_result.scalars().all()
                ]

                # Build job params
                job_params = {
                    "job_id": str(job.id),
                    "client_id": str(job.client_id),
                    "locale_code": locale_code,
                    "channel": job.channel,
                    "sub_channel": job.sub_channel,
                    "programme": job.programme.value,
                    "campaign_name": job.campaign_name,
                    "context_prompt": job.context_prompt,
                }

                # Run pipeline
                orchestrator = PipelineOrchestrator(
                    job_params=job_params,
                    source_lines=source_lines,
                    output_dir=settings.STORAGE_ROOT,
                )
                context = await orchestrator.run()

                # Save output rows to database
                for i, draft in enumerate(context.draft_outputs):
                    # Find matching source line (matched by position)
                    source_line_id = None
                    if i < len(source_lines):
                        source_line_id = source_lines[i].get("id")

                    # Get confidence tier; defaults to "low" when no ranking exists
                    tier = "low"
                    if i < len(context.ranking_declarations):
                        tier = context.ranking_declarations[i].confidence_tier

                    # Get character counts
                    char_counts = {}
                    if i < len(context.compliance_results):
                        char_counts = context.compliance_results[i].character_counts

                    output_row = OutputRow(
                        instance_id=instance.id,
                        line_id=source_line_id,
                        row_order=i + 1,
                        confidence_tier=ConfidenceTier(tier),
                        option_1=draft.option_1.text if draft.option_1 else "",
                        backtranslation_1=draft.option_1.backtranslation if draft.option_1 else "",
                        rationale_1=draft.option_1.rationale if draft.option_1 else "",
                        option_2=draft.option_2.text if draft.option_2 else None,
                        backtranslation_2=draft.option_2.backtranslation if draft.option_2 else None,
                        rationale_2=draft.option_2.rationale if draft.option_2 else None,
                        option_3=draft.option_3.text if draft.option_3 else None,
                        backtranslation_3=draft.option_3.backtranslation if draft.option_3 else None,
                        rationale_3=draft.option_3.rationale if draft.option_3 else None,
                        tm_entries_cited=draft.tm_entries_cited if draft.tm_entries_cited else None,
                        character_count_option_1=char_counts.get("option_1"),
                        character_count_option_2=char_counts.get("option_2"),
                        character_count_option_3=char_counts.get("option_3"),
                    )
                    db.add(output_row)

                # Update instance status
                instance.status = LocaleStatus.complete
                instance.completed_at = datetime.now(timezone.utc)
                output_path = (
                    f"{settings.STORAGE_ROOT}/jobs/{job_id}/output/"
                    f"{locale_code}_{job_id}_output.xlsx"
                )
                instance.output_file_path = output_path
                instance.agent_version = "1.0.0"
                await db.commit()

                # Check if all instances are complete
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "complete",
                    "output_rows": len(context.draft_outputs),
                }

            except Exception as e:
                logger.error(
                    f"Error processing {locale_code} for job {job_id}: {e}",
                    exc_info=True,
                )
                instance.status = LocaleStatus.error
                instance.error_log = str(e)
                instance.completed_at = datetime.now(timezone.utc)
                await db.commit()

                # Check if job should be marked as partial/error
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "error",
                    "error": str(e),
                }

    return asyncio.run(_process())


async def _check_job_completion(db: AsyncSession, job_id: str) -> None:
    """Check if all locale instances are done and update job status accordingly."""
    result = await db.execute(
        select(LocaleInstance).where(LocaleInstance.job_id == job_id)
    )
    instances = list(result.scalars().all())

    if not instances:
        return

    all_complete = all(i.status == LocaleStatus.complete for i in instances)
    all_done = all(
        i.status in (LocaleStatus.complete, LocaleStatus.error)
        for i in instances
    )
    any_error = any(i.status == LocaleStatus.error for i in instances)

    job_result = await db.execute(select(Job).where(Job.id == job_id))
    job = job_result.scalar_one_or_none()
    if job is None:
        return

    if all_complete:
        job.status = JobStatus.complete
        # Sum up token usage; treat unset (None) values as zero
        job.total_token_usage = sum(i.token_usage or 0 for i in instances)
        job.total_estimated_cost = sum(i.estimated_cost or 0 for i in instances)
    elif all_done and any_error:
        any_complete = any(i.status == LocaleStatus.complete for i in instances)
        if any_complete:
            job.status = JobStatus.partial_complete
        else:
            job.status = JobStatus.error

    await db.commit()


47
backend/app/ws/handler.py Normal file

@ -0,0 +1,47 @@
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect

from app.auth.service import AuthService
from app.ws.manager import manager

ws_router = APIRouter()
auth_service = AuthService()


@ws_router.websocket("/ws/jobs/{job_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    job_id: str,
    token: str = Query(...),
) -> None:
    """WebSocket endpoint for real-time job progress updates.

    Authentication is performed via query parameter token.
    """
    # Validate token
    claims = auth_service.validate_token(token)
    if claims is None:
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    await manager.connect(job_id, websocket)

    try:
        # Send initial connection message
        await manager.send_personal(
            websocket,
            {
                "type": "connected",
                "job_id": job_id,
                "message": "Connected to job progress stream",
            },
        )

        # Keep the connection open. Incoming messages are currently ignored;
        # waiting on receive is what surfaces a client disconnect.
        while True:
            await websocket.receive_text()

    except WebSocketDisconnect:
        manager.disconnect(job_id, websocket)
    except Exception:
        manager.disconnect(job_id, websocket)

59
backend/app/ws/manager.py Normal file

@ -0,0 +1,59 @@
import json
from typing import Any

from fastapi import WebSocket


class ConnectionManager:
    """Manage WebSocket connections grouped by job_id."""

    def __init__(self) -> None:
        self._connections: dict[str, list[WebSocket]] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """Accept a websocket connection and register it for a job."""
        await websocket.accept()
        if job_id not in self._connections:
            self._connections[job_id] = []
        self._connections[job_id].append(websocket)

    def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Remove a websocket connection from a job's connection list."""
        if job_id in self._connections:
            self._connections[job_id] = [
                ws for ws in self._connections[job_id] if ws != websocket
            ]
            if not self._connections[job_id]:
                del self._connections[job_id]

    async def broadcast(self, job_id: str, message: dict[str, Any]) -> None:
        """Broadcast a JSON message to all connections for a job."""
        if job_id not in self._connections:
            return

        payload = json.dumps(message)
        disconnected: list[WebSocket] = []

        for websocket in self._connections[job_id]:
            try:
                await websocket.send_text(payload)
            except Exception:
                disconnected.append(websocket)

        # Clean up dead connections
        for ws in disconnected:
            self.disconnect(job_id, ws)

    async def send_personal(
        self, websocket: WebSocket, message: dict[str, Any]
    ) -> None:
        """Send a message to a specific websocket."""
        await websocket.send_text(json.dumps(message))

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job."""
        return len(self._connections.get(job_id, []))


# Singleton instance
manager = ConnectionManager()
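
The broadcast path in `ConnectionManager` above drops any socket whose send fails. A minimal sketch of that behaviour, runnable without FastAPI, is shown here. `StubSocket` and `MiniManager` are hypothetical names introduced for illustration, `MiniManager` is a condensed copy of the broadcast and disconnect logic (it skips `accept()`), and the `progress`/`pct` message fields are invented for the demo rather than taken from the commit.

```python
import asyncio
import json
from typing import Any


class StubSocket:
    """Hypothetical stand-in for fastapi.WebSocket that records sent payloads."""

    def __init__(self, alive: bool = True) -> None:
        self.alive = alive
        self.sent: list[str] = []

    async def send_text(self, payload: str) -> None:
        if not self.alive:
            raise RuntimeError("socket closed")
        self.sent.append(payload)


class MiniManager:
    """Condensed copy of the broadcast/disconnect logic, for illustration only."""

    def __init__(self) -> None:
        self._connections: dict[str, list[StubSocket]] = {}

    def register(self, job_id: str, ws: StubSocket) -> None:
        self._connections.setdefault(job_id, []).append(ws)

    def disconnect(self, job_id: str, ws: StubSocket) -> None:
        remaining = [c for c in self._connections.get(job_id, []) if c is not ws]
        if remaining:
            self._connections[job_id] = remaining
        else:
            self._connections.pop(job_id, None)

    async def broadcast(self, job_id: str, message: dict[str, Any]) -> None:
        payload = json.dumps(message)
        dead: list[StubSocket] = []
        for ws in list(self._connections.get(job_id, [])):
            try:
                await ws.send_text(payload)
            except Exception:
                dead.append(ws)
        # Prune sockets that failed to receive the message.
        for ws in dead:
            self.disconnect(job_id, ws)

    def count(self, job_id: str) -> int:
        return len(self._connections.get(job_id, []))


async def demo() -> tuple[int, list[str]]:
    mgr = MiniManager()
    live, dead = StubSocket(), StubSocket(alive=False)
    mgr.register("job-1", live)
    mgr.register("job-1", dead)
    await mgr.broadcast("job-1", {"type": "progress", "pct": 50})
    return mgr.count("job-1"), live.sent
```

Calling `asyncio.run(demo())` leaves one registered connection for `job-1`, and the live socket holds the single JSON payload.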

17
backend/requirements.txt Normal file

@ -0,0 +1,17 @@
fastapi>=0.111.0
uvicorn[standard]>=0.29.0
sqlalchemy[asyncio]>=2.0.30
asyncpg>=0.29.0
alembic>=1.13.0
celery[redis]>=5.4.0
redis>=5.0.0
pydantic[email]>=2.7.0
pydantic-settings>=2.2.0
python-jose[cryptography]>=3.3.0
bcrypt>=4.0.0
python-multipart>=0.0.9
openpyxl>=3.1.0
grapheme>=0.6.0
anthropic>=0.25.0
websockets>=12.0
httpx>=0.27.0


0
backend/tests/fixtures/__init__.py vendored Normal file

71
docker-compose.yml Normal file

@ -0,0 +1,71 @@
version: "3.9"

services:
  db:
    image: postgres:16
    restart: unless-stopped
    environment:
      POSTGRES_USER: transcreation
      POSTGRES_PASSWORD: transcreation
      POSTGRES_DB: transcreation
    ports:
      - "5492:5432"
    volumes:
      - pgdata:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U transcreation"]
      interval: 5s
      timeout: 5s
      retries: 5

  redis:
    image: redis:7-alpine
    restart: unless-stopped
    ports:
      - "6389:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 5s
      timeout: 5s
      retries: 5

  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    restart: unless-stopped
    ports:
      - "8040:8000"
    env_file:
      - .env
    volumes:
      - ./backend:/app
      - ./storage:/storage
      - ./seed:/app/seed
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_healthy
    command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload

  celery_worker:
    build:
      context: ./backend
      dockerfile: Dockerfile
    restart: unless-stopped
    env_file:
      - .env
    volumes:
      - ./backend:/app
      - ./storage:/storage
      - ./seed:/app/seed
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_healthy
    command: celery -A app.tasks.celery_app worker --loglevel=info --concurrency=4

volumes:
  pgdata:


@ -0,0 +1,2 @@
NEXT_PUBLIC_API_URL=http://localhost:8000
NEXT_PUBLIC_WS_URL=ws://localhost:8000

35
frontend/.gitignore vendored Normal file

@ -0,0 +1,35 @@
# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz
# testing
/coverage
# next.js
/.next/
/out/
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# local env files
.env*.local
.env
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts

15
frontend/Dockerfile Normal file

@ -0,0 +1,15 @@
FROM node:20-alpine AS builder
WORKDIR /app
COPY package.json package-lock.json* ./
RUN npm ci
COPY . .
RUN npm run build
FROM node:20-alpine AS runner
WORKDIR /app
ENV NODE_ENV=production
COPY --from=builder /app/public ./public
COPY --from=builder /app/.next/standalone ./
COPY --from=builder /app/.next/static ./.next/static
EXPOSE 3000
CMD ["node", "server.js"]

16
frontend/components.json Normal file

@ -0,0 +1,16 @@
{
  "$schema": "https://ui.shadcn.com/schema.json",
  "style": "default",
  "rsc": true,
  "tsx": true,
  "tailwind": {
    "config": "tailwind.config.ts",
    "css": "src/app/globals.css",
    "baseColor": "slate",
    "cssVariables": true
  },
  "aliases": {
    "components": "@/components",
    "utils": "@/lib/utils"
  }
}

Some files were not shown because too many files have changed in this diff.