feat: complete Phase 1-2 scaffold — backend, frontend, pipeline skeleton
Full-stack Amazon AI Transcreation Platform with:
- FastAPI backend (async, PostgreSQL, Redis, Celery) with 11 DB tables
- JWT auth (SSO-ready abstract provider pattern)
- 6-agent pipeline orchestrator with deterministic modules
- Next.js 14 frontend with Amazon branding (Ember fonts, orange/dark theme)
- Job wizard, monitoring HUD, output review, admin screens
- 154 TM/reference files imported, 12 locales configured
- Docker Compose for all services

Agents 2-5 (TM retrieval, ranker, transcreator, compliance) are stubs pending Phase 3 LLM integration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent e3c3dccfe9
commit 98fa16bfc3
180 changed files with 21920 additions and 39 deletions
19
.env.example
Normal file
@ -0,0 +1,19 @@
# Database
DATABASE_URL=postgresql+asyncpg://transcreation:transcreation@db:5432/transcreation

# Redis
REDIS_URL=redis://redis:6379/0

# Anthropic
ANTHROPIC_API_KEY=sk-ant-REPLACE_ME

# Auth
JWT_SECRET_KEY=CHANGE_ME_TO_A_RANDOM_SECRET
JWT_ALGORITHM=HS256
JWT_EXPIRY_HOURS=8

# Storage
STORAGE_ROOT=/storage

# LLM
LLM_MODEL=claude-sonnet-4-20250514
75
.gitignore
vendored
|
|
@ -1,50 +1,47 @@
|
|||
# These are some examples of commonly ignored file patterns.
|
||||
# You should customize this list as applicable to your project.
|
||||
# Learn more about .gitignore:
|
||||
# https://www.atlassian.com/git/tutorials/saving-changes/gitignore
|
||||
|
||||
# Node artifact files
|
||||
node_modules/
|
||||
dist/
|
||||
|
||||
# Compiled Java class files
|
||||
*.class
|
||||
|
||||
# Compiled Python bytecode
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
# Log files
|
||||
*.log
|
||||
|
||||
# Package files
|
||||
*.jar
|
||||
|
||||
# Maven
|
||||
target/
|
||||
*$py.class
|
||||
*.so
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
.eggs/
|
||||
venv/
|
||||
.venv/
|
||||
|
||||
# JetBrains IDE
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# Unit test reports
|
||||
TEST*.xml
|
||||
# Node
|
||||
node_modules/
|
||||
.next/
|
||||
out/
|
||||
|
||||
# Generated by MacOS
|
||||
# Storage (keep structure, ignore uploaded files)
|
||||
storage/*
|
||||
!storage/.gitkeep
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
# Generated by Windows
|
||||
Thumbs.db
|
||||
|
||||
# Applications
|
||||
*.app
|
||||
*.exe
|
||||
*.war
|
||||
# Docker
|
||||
docker-compose.override.yml
|
||||
|
||||
# Large media files
|
||||
*.mp4
|
||||
*.tiff
|
||||
*.avi
|
||||
*.flv
|
||||
*.mov
|
||||
*.wmv
|
||||
# Test
|
||||
.coverage
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
|
||||
# Misc
|
||||
*.log
|
||||
*.bak
|
||||
|
|
|
|||
35
Makefile
Normal file
@ -0,0 +1,35 @@
.PHONY: up down build migrate seed test shell logs restart db-shell redis-cli

up:
	docker compose up -d

down:
	docker compose down

build:
	docker compose build

migrate:
	docker compose exec backend alembic upgrade head

seed:
	docker compose exec backend python -m seed.create_default_client
	docker compose exec backend python -m seed.create_test_users

test:
	docker compose exec backend python -m pytest tests/ -v

shell:
	docker compose exec backend bash

logs:
	docker compose logs -f

restart:
	docker compose restart backend celery_worker

db-shell:
	docker compose exec db psql -U transcreation

redis-cli:
	docker compose exec redis redis-cli
179
README.md
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
# Amazon AI Transcreation Platform
|
||||
|
||||
An AI-powered platform for transcreating marketing content across European locales. Built for Amazon's creative workflows, it combines Claude LLM capabilities with translation memory, glossaries, and brand voice profiles to produce culturally adapted copy at scale.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
| Layer | Technology |
|
||||
|-------------|---------------------------------------------|
|
||||
| Frontend | Next.js 14, React 18, Tailwind CSS, Radix UI |
|
||||
| Backend | FastAPI, Python 3.12, SQLAlchemy 2 (async) |
|
||||
| Database | PostgreSQL 16 |
|
||||
| Cache/Queue | Redis 7, Celery |
|
||||
| LLM | Anthropic Claude (via API) |
|
||||
| Infra | Docker, Docker Compose |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- [Docker](https://docs.docker.com/get-docker/) and Docker Compose v2
|
||||
- An [Anthropic API key](https://console.anthropic.com/) for Claude access
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Clone the repository
|
||||
git clone <repo-url> amazon-transcreation
|
||||
cd amazon-transcreation
|
||||
|
||||
# 2. Configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env and set ANTHROPIC_API_KEY, JWT_SECRET_KEY
|
||||
|
||||
# 3. Start all services
|
||||
docker compose up -d
|
||||
|
||||
# 4. Run database migrations
|
||||
make migrate
|
||||
|
||||
# 5. Seed the database with the default Amazon client and test users
|
||||
make seed
|
||||
|
||||
# 6. Open the app
|
||||
# Backend API docs: http://localhost:8000/docs
|
||||
# Frontend: http://localhost:3000
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
+---------------------------------------------------------------+
|
||||
| Frontend (Next.js) |
|
||||
| Dashboard | Job Builder | Review UI | TM Manager |
|
||||
+-------------------------------+-------------------------------+
|
||||
| REST + WebSocket |
|
||||
+-------------------------------+-------------------------------+
|
||||
| Backend (FastAPI) |
|
||||
| Auth | Jobs API | TM API | Files API | WebSocket |
|
||||
+-------------------------------+-------------------------------+
|
||||
| |
|
||||
+--------------+------+ +-------+----------------------------+
|
||||
| PostgreSQL | | Celery Workers |
|
||||
| - clients | | - LLM transcreation tasks |
|
||||
| - users | | - TM lookup & scoring |
|
||||
| - jobs | | - File processing |
|
||||
| - source_lines | +------------------------------------+
|
||||
| - output_rows | |
|
||||
| - feedback | +----------+-------------------------+
|
||||
| - tm_file_registry| | Claude API (Anthropic) |
|
||||
| - reference_files | | - Transcreation generation |
|
||||
| - audit_logs | | - Quality scoring |
|
||||
+---------------------+ +------------------------------------+
|
||||
|
|
||||
+---------------------+
|
||||
| Storage (volume) |
|
||||
| - TM files (JSONL) |
|
||||
| - Reference files |
|
||||
+---------------------+
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Backend
|
||||
|
||||
```bash
|
||||
# Shell into the backend container
|
||||
make shell
|
||||
|
||||
# Run tests
|
||||
make test
|
||||
|
||||
# View logs
|
||||
make logs
|
||||
|
||||
# Access the database
|
||||
make db-shell
|
||||
```
|
||||
|
||||
### Frontend
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
The frontend dev server runs on port 3000 and proxies API requests to the backend on port 8000.
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Backend tests (inside Docker)
|
||||
make test
|
||||
|
||||
# Frontend lint checks
|
||||
cd frontend && npm run lint
|
||||
```
|
||||
|
||||
## API Documentation
|
||||
|
||||
FastAPI auto-generates interactive API documentation:
|
||||
|
||||
- **Swagger UI**: [http://localhost:8000/docs](http://localhost:8000/docs)
|
||||
- **ReDoc**: [http://localhost:8000/redoc](http://localhost:8000/redoc)
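
Both pages are rendered from an OpenAPI schema. The sketch below is one way to list the routes in that schema from a script. It assumes the schema is served at FastAPI's default path, `/openapi.json`, which this project may override; `list_routes` and `fetch_spec` are hypothetical helper names, not part of the codebase.

```python
import json
from urllib.request import urlopen

# HTTP method keys defined for an OpenAPI path item; other keys
# (e.g. "parameters", "summary") are not operations and are skipped.
HTTP_METHODS = {"get", "put", "post", "delete", "options", "head", "patch", "trace"}


def list_routes(spec: dict) -> list[str]:
    """Return 'METHOD /path' strings for every operation in an OpenAPI spec."""
    routes = []
    for path, item in sorted(spec.get("paths", {}).items()):
        for method in sorted(item):
            if method in HTTP_METHODS:
                routes.append(f"{method.upper()} {path}")
    return routes


def fetch_spec(base_url: str = "http://localhost:8000") -> dict:
    """Download the schema; assumes the default /openapi.json location."""
    with urlopen(f"{base_url}/openapi.json") as resp:
        return json.load(resp)
```

With the backend running, `list_routes(fetch_spec())` should return one entry per registered endpoint.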
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
amazon-transcreation/
|
||||
├── backend/
|
||||
│ ├── alembic/ # Database migrations
|
||||
│ ├── app/
|
||||
│ │ ├── api/v1/ # REST API routes
|
||||
│ │ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||
│ │ ├── services/ # Business logic
|
||||
│ │ ├── tasks/ # Celery async tasks
|
||||
│ │ ├── ws/ # WebSocket handlers
|
||||
│ │ ├── config.py # Settings from environment
|
||||
│ │ ├── dependencies.py # FastAPI dependency injection
|
||||
│ │ └── main.py # Application factory
|
||||
│ ├── Dockerfile
|
||||
│ └── requirements.txt
|
||||
├── frontend/
|
||||
│ ├── public/ # Static assets, fonts
|
||||
│ ├── src/ # Next.js pages and components
|
||||
│ ├── Dockerfile
|
||||
│ └── package.json
|
||||
├── scripts/
|
||||
│ ├── tm_format_migrator.py # Convert compact TM to multi-field JSONL
|
||||
│ └── validate_tm_files.py # Audit TM files for format and coverage
|
||||
├── seed/
|
||||
│ ├── create_default_client.py # Seed Amazon client + voice profiles
|
||||
│ └── import_reference_files.py # Import TMs and refs to storage/
|
||||
├── storage/ # Mounted volume for TM and reference files
|
||||
├── .env.example # Environment variable template
|
||||
├── docker-compose.yml # Service orchestration
|
||||
├── Makefile # Development shortcuts
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Key Concepts
|
||||
|
||||
- **Client**: An organization (e.g., Amazon) with voice profiles, supported locales, and channel configurations stored as JSON settings.
|
||||
- **Voice Profile**: Brand personality (Retail, Prime, Brand) that guides tone and style during transcreation.
|
||||
- **Channel**: A content distribution context (Mass TV, Onsite, etc.) with its own translation memory.
|
||||
- **Translation Memory (TM)**: JSONL files of previously approved source/target pairs, used for context and consistency.
|
||||
- **Reference Files**: Glossaries, blacklists, date/percent format guides, and locale-specific considerations that constrain the LLM output.
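
As a rough illustration of the TM format, a JSONL file holds one JSON object per line. The sketch below reads such lines and checks that each record carries a source/target pair. The field names `source` and `target` and the helper `iter_tm_segments` are assumptions made for this example; the project's actual TM schema is not shown here.

```python
import json
from typing import Iterable, Iterator


def iter_tm_segments(lines: Iterable[str]) -> Iterator[dict]:
    """Yield one dict per non-empty JSONL line.

    Assumes each record has 'source' and 'target' keys (hypothetical names).
    """
    for line_no, raw in enumerate(lines, start=1):
        raw = raw.strip()
        if not raw:
            continue  # tolerate blank lines between records
        record = json.loads(raw)
        if "source" not in record or "target" not in record:
            raise ValueError(f"line {line_no}: missing 'source' or 'target'")
        yield record


# Usage with in-memory lines; an open file object works the same way.
sample = [
    '{"source": "Hello", "target": "Bonjour"}',
    "",
    '{"source": "Deal of the day", "target": "Offre du jour"}',
]
segments = list(iter_tm_segments(sample))
```

Reading line by line keeps memory use flat for large TM files, since only one segment is parsed at a time.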
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|---------------------|--------------------------------------|------------------------------|
|
||||
| `DATABASE_URL` | PostgreSQL async connection string | `postgresql+asyncpg://...` |
|
||||
| `REDIS_URL` | Redis connection string | `redis://redis:6379/0` |
|
||||
| `ANTHROPIC_API_KEY` | Anthropic API key for Claude | (required) |
|
||||
| `JWT_SECRET_KEY` | Secret for signing JWT tokens | (required, change default) |
|
||||
| `JWT_ALGORITHM` | JWT signing algorithm | `HS256` |
|
||||
| `JWT_EXPIRY_HOURS` | Token expiry in hours | `8` |
|
||||
| `STORAGE_ROOT` | Root path for file storage | `/storage` |
|
||||
| `LLM_MODEL` | Claude model identifier | `claude-sonnet-4-20250514` |
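
A minimal sketch of how these variables could be read with the defaults from the table, using only the standard library. This is not the project's `app/config.py`, which is not shown here; the `Settings` dataclass and `load_settings` function are hypothetical, and the `DATABASE_URL` default is the value from `.env.example`.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

# Variables the table marks as required (no usable default).
REQUIRED = ("ANTHROPIC_API_KEY", "JWT_SECRET_KEY")


@dataclass(frozen=True)
class Settings:
    database_url: str
    redis_url: str
    anthropic_api_key: str
    jwt_secret_key: str
    jwt_algorithm: str
    jwt_expiry_hours: int
    storage_root: str
    llm_model: str


def load_settings(env: Optional[Mapping[str, str]] = None) -> Settings:
    """Build Settings from a mapping (defaults to os.environ)."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED if not env.get(name)]
    if missing:
        raise RuntimeError("Missing required environment variables: " + ", ".join(missing))
    return Settings(
        database_url=env.get(
            "DATABASE_URL",
            "postgresql+asyncpg://transcreation:transcreation@db:5432/transcreation",
        ),
        redis_url=env.get("REDIS_URL", "redis://redis:6379/0"),
        anthropic_api_key=env["ANTHROPIC_API_KEY"],
        jwt_secret_key=env["JWT_SECRET_KEY"],
        jwt_algorithm=env.get("JWT_ALGORITHM", "HS256"),
        jwt_expiry_hours=int(env.get("JWT_EXPIRY_HOURS", "8")),
        storage_root=env.get("STORAGE_ROOT", "/storage"),
        llm_model=env.get("LLM_MODEL", "claude-sonnet-4-20250514"),
    )
```

Failing fast on the two required variables surfaces a misconfigured `.env` at startup rather than on the first authenticated request or LLM call.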
|
||||
19
backend/Dockerfile
Normal file
@ -0,0 +1,19 @@
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends gcc libpq-dev && \
    rm -rf /var/lib/apt/lists/*

# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
40
backend/alembic.ini
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
71
backend/alembic/env.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# Import all models so Alembic can discover them
|
||||
from app.models import Base # noqa: F401
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Set the SQLAlchemy URL from settings
|
||||
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)
|
||||
|
||||
# Logging configuration
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# MetaData for autogenerate support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in 'online' mode with async engine."""
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
backend/alembic/script.py.mako
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
0
backend/alembic/versions/.gitkeep
Normal file
219
backend/alembic/versions/d4a016fd0817_initial_schema.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""initial_schema
|
||||
|
||||
Revision ID: d4a016fd0817
|
||||
Revises:
|
||||
Create Date: 2026-04-10 16:25:45.318282
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'd4a016fd0817'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('clients',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('settings', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('users',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('email', sa.String(length=255), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=False),
|
||||
sa.Column('role', sa.Enum('admin', 'tm_manager', 'reviewer', name='user_role', create_constraint=True), nullable=False),
|
||||
sa.Column('status', sa.Enum('active', 'inactive', name='user_status', create_constraint=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('email')
|
||||
)
|
||||
op.create_table('audit_logs',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('action', sa.String(length=100), nullable=False),
|
||||
sa.Column('entity_type', sa.String(length=100), nullable=False),
|
||||
sa.Column('entity_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('details', sa.JSON(), nullable=True),
|
||||
sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('jobs',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('created_by', sa.Uuid(), nullable=False),
|
||||
sa.Column('job_ref', sa.String(length=100), nullable=True),
|
||||
sa.Column('campaign_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('programme', sa.Enum('retail', 'prime', 'brand', name='programme_type', create_constraint=True), nullable=False),
|
||||
sa.Column('channel', sa.String(length=100), nullable=False),
|
||||
sa.Column('sub_channel', sa.String(length=100), nullable=True),
|
||||
sa.Column('context_prompt', sa.Text(), nullable=True),
|
||||
sa.Column('job_type', sa.Enum('main', 'derived', name='job_type', create_constraint=True), nullable=False),
|
||||
sa.Column('parent_job_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('status', sa.Enum('created', 'validating', 'queued', 'running', 'partial_complete', 'complete', 'error', 'exported', name='job_status', create_constraint=True), nullable=False),
|
||||
sa.Column('total_token_usage', sa.Integer(), nullable=False),
|
||||
sa.Column('total_estimated_cost', sa.Float(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['parent_job_id'], ['jobs.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('reference_files',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('file_type', sa.Enum('glossary', 'blacklist', 'tov_global', 'tov_supplement', 'locale_considerations', 'date_pct_formats', name='reference_file_type', create_constraint=True), nullable=False),
|
||||
sa.Column('locale_scope', sa.String(length=10), nullable=False),
|
||||
sa.Column('filename', sa.String(length=255), nullable=False),
|
||||
sa.Column('file_path', sa.String(length=500), nullable=False),
|
||||
sa.Column('uploaded_by', sa.Uuid(), nullable=True),
|
||||
sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('version', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['uploaded_by'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('tm_file_registry',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('locale_code', sa.String(length=10), nullable=False),
|
||||
sa.Column('channel', sa.String(length=100), nullable=False),
|
||||
sa.Column('filename', sa.String(length=255), nullable=False),
|
||||
sa.Column('file_path', sa.String(length=500), nullable=False),
|
||||
sa.Column('segment_count', sa.Integer(), nullable=False),
|
||||
sa.Column('uploaded_by', sa.Uuid(), nullable=True),
|
||||
sa.Column('uploaded_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('version', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['uploaded_by'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('user_clients',
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('role_override', sa.String(length=50), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['clients.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('user_id', 'client_id')
|
||||
)
|
||||
op.create_table('locale_instances',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('job_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('locale_code', sa.String(length=10), nullable=False),
|
||||
sa.Column('locale_type', sa.Enum('main', 'derived', name='locale_type', create_constraint=True), nullable=False),
|
||||
sa.Column('status', sa.Enum('queued', 'running', 'complete', 'error', name='locale_status', create_constraint=True), nullable=False),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('token_usage', sa.Integer(), nullable=False),
|
||||
sa.Column('estimated_cost', sa.Float(), nullable=False),
|
||||
sa.Column('output_file_path', sa.String(length=500), nullable=True),
|
||||
sa.Column('error_log', sa.Text(), nullable=True),
|
||||
sa.Column('tm_files_loaded', sa.JSON(), nullable=True),
|
||||
sa.Column('ref_files_loaded', sa.JSON(), nullable=True),
|
||||
sa.Column('agent_version', sa.String(length=50), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('source_lines',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('job_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('row_order', sa.Integer(), nullable=False),
|
||||
sa.Column('en_gb', sa.Text(), nullable=False),
|
||||
sa.Column('copy_type', sa.String(length=100), nullable=True),
|
||||
sa.Column('creative_guidance', sa.Text(), nullable=True),
|
||||
sa.Column('visual_ref', sa.String(length=500), nullable=True),
|
||||
sa.Column('char_limit', sa.String(length=50), nullable=True),
|
||||
sa.Column('is_display_format', sa.Boolean(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['job_id'], ['jobs.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('output_rows',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('instance_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('line_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('row_order', sa.Integer(), nullable=False),
|
||||
sa.Column('confidence_tier', sa.Enum('high', 'moderate', 'low', name='confidence_tier', create_constraint=True), nullable=False),
|
||||
sa.Column('option_1', sa.Text(), nullable=False),
|
||||
sa.Column('backtranslation_1', sa.Text(), nullable=False),
|
||||
sa.Column('rationale_1', sa.Text(), nullable=False),
|
||||
sa.Column('option_2', sa.Text(), nullable=True),
|
||||
sa.Column('backtranslation_2', sa.Text(), nullable=True),
|
||||
sa.Column('rationale_2', sa.Text(), nullable=True),
|
||||
sa.Column('option_3', sa.Text(), nullable=True),
|
||||
sa.Column('backtranslation_3', sa.Text(), nullable=True),
|
||||
sa.Column('rationale_3', sa.Text(), nullable=True),
|
||||
sa.Column('tm_entries_cited', sa.JSON(), nullable=True),
|
||||
sa.Column('winning_seg_key', sa.String(length=255), nullable=True),
|
||||
sa.Column('character_count_option_1', sa.Integer(), nullable=True),
|
||||
sa.Column('character_count_option_2', sa.Integer(), nullable=True),
|
||||
sa.Column('character_count_option_3', sa.Integer(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['instance_id'], ['locale_instances.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['line_id'], ['source_lines.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('token_usage_logs',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('instance_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('agent_name', sa.String(length=100), nullable=False),
|
||||
sa.Column('model', sa.String(length=100), nullable=False),
|
||||
sa.Column('input_tokens', sa.Integer(), nullable=False),
|
||||
sa.Column('output_tokens', sa.Integer(), nullable=False),
|
||||
sa.Column('total_tokens', sa.Integer(), nullable=False),
|
||||
sa.Column('estimated_cost_usd', sa.Float(), nullable=False),
|
||||
sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['instance_id'], ['locale_instances.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('feedback',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('output_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('option_column', sa.Integer(), nullable=False),
|
||||
sa.Column('flag_type', sa.Enum('approved', 'needs_revision', 'comment', name='flag_type', create_constraint=True), nullable=False),
|
||||
sa.Column('comment', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['output_id'], ['output_rows.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('feedback')
|
||||
op.drop_table('token_usage_logs')
|
||||
op.drop_table('output_rows')
|
||||
op.drop_table('source_lines')
|
||||
op.drop_table('locale_instances')
|
||||
op.drop_table('user_clients')
|
||||
op.drop_table('tm_file_registry')
|
||||
op.drop_table('reference_files')
|
||||
op.drop_table('jobs')
|
||||
op.drop_table('audit_logs')
|
||||
op.drop_table('users')
|
||||
op.drop_table('clients')
|
||||
# ### end Alembic commands ###
|
||||
0
backend/app/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/v1/__init__.py
Normal file
60
backend/app/api/v1/audit.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_role
|
||||
from app.services.audit_service import AuditService
|
||||
|
||||
router = APIRouter(prefix="/audit", tags=["audit"])
|
||||
audit_service = AuditService()
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
async def list_audit_logs(
|
||||
user_id: UUID | None = Query(None),
|
||||
action: str | None = Query(None),
|
||||
entity_type: str | None = Query(None),
|
||||
entity_id: str | None = Query(None),
|
||||
date_from: datetime | None = Query(None),
|
||||
date_to: datetime | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> dict[str, Any]:
|
||||
"""List audit logs with filters (admin only)."""
|
||||
logs, total = await audit_service.list_logs(
|
||||
db,
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
pages = (total + page_size - 1) // page_size if total > 0 else 1
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": str(log.id),
|
||||
"user_id": str(log.user_id) if log.user_id else None,
|
||||
"action": log.action,
|
||||
"entity_type": log.entity_type,
|
||||
"entity_id": log.entity_id,
|
||||
"details": log.details,
|
||||
"timestamp": log.timestamp.isoformat(),
|
||||
"ip_address": log.ip_address,
|
||||
}
|
||||
for log in logs
|
||||
],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"pages": pages,
|
||||
}
|
||||
105
backend/app/api/v1/clients.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db, require_role
|
||||
from app.models.client import Client
|
||||
from app.schemas.client import ClientCreate, ClientResponse, ClientUpdate
|
||||
from app.schemas.common import PaginatedResponse
|
||||
|
||||
router = APIRouter(prefix="/clients", tags=["clients"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ClientResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_client(
|
||||
body: ClientCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> ClientResponse:
|
||||
"""Create a new client (admin only)."""
|
||||
client = Client(name=body.name, settings=body.settings)
|
||||
db.add(client)
|
||||
await db.flush()
|
||||
return ClientResponse.model_validate(client)
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[ClientResponse])
|
||||
async def list_clients(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> PaginatedResponse[ClientResponse]:
|
||||
"""List all clients (admin only)."""
|
||||
count_result = await db.execute(select(func.count(Client.id)))
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
select(Client)
|
||||
.order_by(Client.created_at.desc())
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
clients = [ClientResponse.model_validate(c) for c in result.scalars().all()]
|
||||
|
||||
pages = (total + page_size - 1) // page_size if total > 0 else 1
|
||||
return PaginatedResponse(
|
||||
items=clients, total=total, page=page, page_size=page_size, pages=pages
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{client_id}", response_model=ClientResponse)
|
||||
async def get_client(
|
||||
client_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> ClientResponse:
|
||||
"""Get a client by ID (admin only)."""
|
||||
result = await db.execute(select(Client).where(Client.id == client_id))
|
||||
client = result.scalar_one_or_none()
|
||||
if client is None:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
return ClientResponse.model_validate(client)
|
||||
|
||||
|
||||
@router.put("/{client_id}", response_model=ClientResponse)
|
||||
async def update_client(
|
||||
client_id: UUID,
|
||||
body: ClientUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> ClientResponse:
|
||||
"""Update a client (admin only)."""
|
||||
result = await db.execute(select(Client).where(Client.id == client_id))
|
||||
client = result.scalar_one_or_none()
|
||||
if client is None:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(client, field, value)
|
||||
|
||||
await db.flush()
|
||||
return ClientResponse.model_validate(client)
|
||||
|
||||
|
||||
@router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_client(
|
||||
client_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> None:
|
||||
"""Delete a client (admin only)."""
|
||||
result = await db.execute(select(Client).where(Client.id == client_id))
|
||||
client = result.scalar_one_or_none()
|
||||
if client is None:
|
||||
raise HTTPException(status_code=404, detail="Client not found")
|
||||
|
||||
await db.delete(client)
|
||||
await db.flush()
|
||||
179
backend/app/api/v1/files.py
Normal file
|
|
@@ -0,0 +1,179 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, status
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.files import ReferenceFileType
|
||||
from app.schemas.files import (
|
||||
FileUploadResponse,
|
||||
ReferenceFileResponse,
|
||||
TMFileResponse,
|
||||
)
|
||||
from app.services.file_service import FileService
|
||||
|
||||
router = APIRouter(prefix="/files", tags=["files"])
|
||||
file_service = FileService()
|
||||
|
||||
|
||||
# ---- TM Files ----
|
||||
|
||||
|
||||
@router.post("/tm", response_model=FileUploadResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_tm_file(
|
||||
client_id: UUID = Query(...),
|
||||
locale_code: str = Query(...),
|
||||
channel: str = Query(...),
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> FileUploadResponse:
|
||||
    """Upload a Translation Memory file (.jsonl or .json)."""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File must have a filename")
|
||||
if not file_service.validate_file_extension(file.filename, [".jsonl", ".json"]):
|
||||
raise HTTPException(status_code=400, detail="Only .jsonl/.json files accepted")
|
||||
|
||||
tm = await file_service.upload_tm_file(
|
||||
db, client_id, locale_code, channel, file.file, file.filename,
|
||||
uploaded_by=current_user["user_id"],
|
||||
)
|
||||
return FileUploadResponse(
|
||||
id=tm.id,
|
||||
filename=tm.filename,
|
||||
file_path=tm.file_path,
|
||||
message=f"Uploaded TM file with {tm.segment_count} segments",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tm", response_model=list[TMFileResponse])
|
||||
async def list_tm_files(
|
||||
client_id: UUID = Query(...),
|
||||
locale_code: str | None = Query(None),
|
||||
channel: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[TMFileResponse]:
|
||||
"""List TM files for a client."""
|
||||
files = await file_service.list_tm_files(db, client_id, locale_code, channel)
|
||||
return [TMFileResponse.model_validate(f) for f in files]
|
||||
|
||||
|
||||
@router.get("/tm/{file_id}/download")
|
||||
async def download_tm_file(
|
||||
file_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> FileResponse:
|
||||
"""Download a TM file."""
|
||||
from sqlalchemy import select
|
||||
from app.models.files import TMFileRegistry
|
||||
|
||||
result = await db.execute(
|
||||
select(TMFileRegistry).where(TMFileRegistry.id == file_id)
|
||||
)
|
||||
tm = result.scalar_one_or_none()
|
||||
if tm is None:
|
||||
raise HTTPException(status_code=404, detail="TM file not found")
|
||||
|
||||
path = file_service.get_file_path(tm.file_path)
|
||||
if path is None:
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
|
||||
return FileResponse(path=str(path), filename=tm.filename)
|
||||
|
||||
|
||||
@router.delete("/tm/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_tm_file(
|
||||
file_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> None:
|
||||
"""Delete a TM file."""
|
||||
deleted = await file_service.delete_tm_file(db, file_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="TM file not found")
|
||||
|
||||
|
||||
# ---- Reference Files ----
|
||||
|
||||
|
||||
@router.post(
|
||||
"/reference",
|
||||
response_model=FileUploadResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def upload_reference_file(
|
||||
client_id: UUID = Query(...),
|
||||
file_type: ReferenceFileType = Query(...),
|
||||
locale_scope: str = Query(...),
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> FileUploadResponse:
|
||||
"""Upload a reference file (glossary, blacklist, TOV, etc.)."""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File must have a filename")
|
||||
|
||||
ref = await file_service.upload_reference_file(
|
||||
db, client_id, file_type, locale_scope, file.file, file.filename,
|
||||
uploaded_by=current_user["user_id"],
|
||||
)
|
||||
return FileUploadResponse(
|
||||
id=ref.id,
|
||||
filename=ref.filename,
|
||||
file_path=ref.file_path,
|
||||
message=f"Uploaded {file_type.value} reference file",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/reference", response_model=list[ReferenceFileResponse])
|
||||
async def list_reference_files(
|
||||
client_id: UUID = Query(...),
|
||||
file_type: ReferenceFileType | None = Query(None),
|
||||
locale_scope: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[ReferenceFileResponse]:
|
||||
"""List reference files for a client."""
|
||||
files = await file_service.list_reference_files(
|
||||
db, client_id, file_type, locale_scope
|
||||
)
|
||||
return [ReferenceFileResponse.model_validate(f) for f in files]
|
||||
|
||||
|
||||
@router.get("/reference/{file_id}/download")
|
||||
async def download_reference_file(
|
||||
file_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> FileResponse:
|
||||
"""Download a reference file."""
|
||||
from sqlalchemy import select
|
||||
from app.models.files import ReferenceFile
|
||||
|
||||
result = await db.execute(
|
||||
select(ReferenceFile).where(ReferenceFile.id == file_id)
|
||||
)
|
||||
ref = result.scalar_one_or_none()
|
||||
if ref is None:
|
||||
raise HTTPException(status_code=404, detail="Reference file not found")
|
||||
|
||||
path = file_service.get_file_path(ref.file_path)
|
||||
if path is None:
|
||||
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||
|
||||
return FileResponse(path=str(path), filename=ref.filename)
|
||||
|
||||
|
||||
@router.delete("/reference/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_reference_file(
|
||||
file_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> None:
|
||||
"""Delete a reference file."""
|
||||
deleted = await file_service.delete_reference_file(db, file_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Reference file not found")
|
||||
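The TM upload endpoint delegates its extension check to `file_service.validate_file_extension(filename, allowed)`, whose body is not shown in this part of the diff. A minimal sketch of how such a check could work is given below, assuming a case-insensitive match on the final suffix; `has_allowed_extension` is a hypothetical name and not the service's actual implementation.

```python
from pathlib import PurePath


def has_allowed_extension(filename: str, allowed: list[str]) -> bool:
    """Return True when the final suffix of `filename` is in `allowed` (case-insensitive)."""
    # PurePath does no filesystem access; it only parses the name.
    suffix = PurePath(filename).suffix.lower()
    return suffix in {ext.lower() for ext in allowed}
```

Only the last suffix is compared, so `archive.tar.gz` is treated as `.gz`. An extension check is a convenience filter rather than content validation, so the parser that reads the upload still has to handle malformed input.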
206
backend/app/api/v1/jobs.py
Normal file
|
|
@@ -0,0 +1,206 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.schemas.common import PaginatedResponse
|
||||
from app.schemas.job import JobCreate, JobListResponse, JobResponse, JobUpdate, LocaleInstanceResponse
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.file_service import FileService
|
||||
from app.services.job_service import JobService
|
||||
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
job_service = JobService()
|
||||
file_service = FileService()
|
||||
audit_service = AuditService()
|
||||
|
||||
|
||||
@router.post("", response_model=JobResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_job(
|
||||
body: JobCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> JobResponse:
|
||||
"""Create a new transcreation job."""
|
||||
job = await job_service.create_job(db, body, current_user["user_id"])
|
||||
await audit_service.log(
|
||||
db, "create", "job", str(job.id), current_user["user_id"],
|
||||
details={"campaign_name": body.campaign_name},
|
||||
)
|
||||
return JobResponse.model_validate(job)
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[JobListResponse])
|
||||
async def list_jobs(
|
||||
client_id: UUID | None = Query(None),
|
||||
job_status: str | None = Query(None, alias="status"),
|
||||
locale: str | None = Query(None),
|
||||
date_from: datetime | None = Query(None),
|
||||
date_to: datetime | None = Query(None),
|
||||
search: str | None = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> PaginatedResponse[JobListResponse]:
|
||||
"""List jobs with filters and pagination."""
|
||||
jobs, total = await job_service.list_jobs(
|
||||
db,
|
||||
client_id=client_id,
|
||||
status=job_status,
|
||||
locale=locale,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
search=search,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
items = []
|
||||
for job in jobs:
|
||||
item = JobListResponse(
|
||||
id=job.id,
|
||||
client_id=job.client_id,
|
||||
created_by=job.created_by,
|
||||
job_ref=job.job_ref,
|
||||
campaign_name=job.campaign_name,
|
||||
programme=job.programme,
|
||||
channel=job.channel,
|
||||
status=job.status,
|
||||
total_token_usage=job.total_token_usage,
|
||||
total_estimated_cost=job.total_estimated_cost,
|
||||
locale_count=len(job.locale_instances) if job.locale_instances else 0,
|
||||
created_at=job.created_at,
|
||||
updated_at=job.updated_at,
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
pages = (total + page_size - 1) // page_size if total > 0 else 1
|
||||
return PaginatedResponse(
|
||||
items=items, total=total, page=page, page_size=page_size, pages=pages
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=JobResponse)
|
||||
async def get_job(
|
||||
job_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> JobResponse:
|
||||
"""Get job details with locale instances."""
|
||||
job = await job_service.get_job(db, job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return JobResponse.model_validate(job)
|
||||
|
||||
|
||||
@router.put("/{job_id}/source")
|
||||
async def upload_source(
|
||||
job_id: UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Upload source xlsx file for a job."""
|
||||
    if not file.filename or not file.filename.lower().endswith(".xlsx"):
|
||||
raise HTTPException(status_code=400, detail="Only .xlsx files are accepted")
|
||||
|
||||
job = await job_service.get_job(db, job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
source_lines = await file_service.upload_source_file(
|
||||
db, job_id, file.file, file.filename
|
||||
)
|
||||
await audit_service.log(
|
||||
db, "upload_source", "job", str(job_id), current_user["user_id"],
|
||||
details={"filename": file.filename, "line_count": len(source_lines)},
|
||||
)
|
||||
return {
|
||||
"message": f"Uploaded {len(source_lines)} source lines",
|
||||
"line_count": len(source_lines),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{job_id}/supplementary")
|
||||
async def upload_supplementary(
|
||||
job_id: UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> dict:
|
||||
    """Upload a supplementary file for a job."""
|
||||
job = await job_service.get_job(db, job_id)
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
file_path = await file_service.upload_supplementary_file(
|
||||
db, job_id, file.file, file.filename or "unknown"
|
||||
)
|
||||
return {"message": "File uploaded", "file_path": file_path}
|
||||
|
||||
|
||||
@router.post("/{job_id}/launch", response_model=JobResponse)
|
||||
async def launch_job(
|
||||
job_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> JobResponse:
|
||||
"""Validate and queue a job for processing."""
|
||||
try:
|
||||
job = await job_service.launch_job(db, job_id)
|
||||
except ValueError as e:
|
||||
        raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
await audit_service.log(
|
||||
db, "launch", "job", str(job_id), current_user["user_id"],
|
||||
)
|
||||
return JobResponse.model_validate(job)
|
||||
|
||||
|
||||
@router.post("/{job_id}/cancel", response_model=JobResponse)
|
||||
async def cancel_job(
|
||||
job_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> JobResponse:
|
||||
"""Cancel a running job."""
|
||||
try:
|
||||
job = await job_service.cancel_job(db, job_id)
|
||||
except ValueError as e:
|
||||
        raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
if job is None:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
await audit_service.log(
|
||||
db, "cancel", "job", str(job_id), current_user["user_id"],
|
||||
)
|
||||
return JobResponse.model_validate(job)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{job_id}/locales/{locale_code}/rerun",
|
||||
response_model=LocaleInstanceResponse,
|
||||
)
|
||||
async def rerun_locale(
|
||||
job_id: UUID,
|
||||
locale_code: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> LocaleInstanceResponse:
|
||||
"""Re-run a single locale instance."""
|
||||
instance = await job_service.rerun_locale(db, job_id, locale_code)
|
||||
if instance is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Locale instance '{locale_code}' not found for job",
|
||||
)
|
||||
await audit_service.log(
|
||||
db, "rerun_locale", "locale_instance", str(instance.id),
|
||||
current_user["user_id"],
|
||||
details={"locale_code": locale_code},
|
||||
)
|
||||
return LocaleInstanceResponse.model_validate(instance)
|
||||
84
backend/app/api/v1/output.py
Normal file
|
|
@@ -0,0 +1,84 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.schemas.feedback import FeedbackCreate, FeedbackResponse
|
||||
from app.schemas.output import OutputPreviewResponse
|
||||
from app.services.feedback_service import FeedbackService
|
||||
from app.services.output_service import OutputService
|
||||
|
||||
router = APIRouter(prefix="/output", tags=["output"])
|
||||
output_service = OutputService()
|
||||
feedback_service = FeedbackService()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/jobs/{job_id}/locales/{locale_code}/preview",
|
||||
response_model=OutputPreviewResponse,
|
||||
)
|
||||
async def get_output_preview(
|
||||
job_id: UUID,
|
||||
locale_code: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> OutputPreviewResponse:
|
||||
"""Get preview data for a locale instance output."""
|
||||
preview = await output_service.get_preview(db, job_id, locale_code)
|
||||
if preview is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No output found for this job/locale combination",
|
||||
)
|
||||
return preview
|
||||
|
||||
|
||||
@router.post(
|
||||
"/feedback",
|
||||
response_model=FeedbackResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_feedback(
|
||||
body: FeedbackCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> FeedbackResponse:
|
||||
"""Submit feedback on an output row."""
|
||||
feedback = await feedback_service.create_feedback(
|
||||
db, body, current_user["user_id"]
|
||||
)
|
||||
return FeedbackResponse.model_validate(feedback)
|
||||
|
||||
|
||||
@router.get("/feedback/{output_id}", response_model=list[FeedbackResponse])
|
||||
async def list_feedback_for_output(
|
||||
output_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[FeedbackResponse]:
|
||||
"""List all feedback for a specific output row."""
|
||||
items = await feedback_service.list_feedback(db, output_id=output_id)
|
||||
return [FeedbackResponse.model_validate(f) for f in items]
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/locales/{locale_code}/export")
|
||||
async def export_output(
|
||||
job_id: UUID,
|
||||
locale_code: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> FileResponse:
|
||||
"""Export output as xlsx for a locale instance."""
|
||||
file_path = await output_service.trigger_export(db, job_id, locale_code)
|
||||
if file_path is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No export available for this job/locale combination",
|
||||
)
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=f"{job_id}_{locale_code}_output.xlsx",
|
||||
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
50
backend/app/api/v1/reports.py
Normal file
|
|
@@ -0,0 +1,50 @@
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.services.report_service import ReportService
|
||||
|
||||
router = APIRouter(prefix="/reports", tags=["reports"])
|
||||
report_service = ReportService()
|
||||
|
||||
|
||||
@router.get("/usage")
|
||||
async def get_usage_stats(
|
||||
client_id: UUID | None = Query(None),
|
||||
date_from: datetime | None = Query(None),
|
||||
date_to: datetime | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Get usage statistics (total jobs, tokens, cost, status breakdown)."""
|
||||
return await report_service.get_usage_stats(
|
||||
db, client_id=client_id, date_from=date_from, date_to=date_to
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tokens")
|
||||
async def get_token_cost_data(
|
||||
client_id: UUID | None = Query(None),
|
||||
date_from: datetime | None = Query(None),
|
||||
date_to: datetime | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get token usage and cost data grouped by day."""
|
||||
return await report_service.get_token_cost_data(
|
||||
db, client_id=client_id, date_from=date_from, date_to=date_to
|
||||
)
|
||||
|
||||
|
||||
@router.get("/quality")
|
||||
async def get_quality_metrics(
|
||||
client_id: UUID | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Get quality metrics (confidence tier distribution, feedback stats)."""
|
||||
return await report_service.get_quality_metrics(db, client_id=client_id)
|
||||
21
backend/app/api/v1/router.py
Normal file
|
|
@@ -0,0 +1,21 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from app.auth.router import router as auth_router
|
||||
from app.api.v1.jobs import router as jobs_router
|
||||
from app.api.v1.users import router as users_router
|
||||
from app.api.v1.clients import router as clients_router
|
||||
from app.api.v1.files import router as files_router
|
||||
from app.api.v1.output import router as output_router
|
||||
from app.api.v1.reports import router as reports_router
|
||||
from app.api.v1.audit import router as audit_router
|
||||
|
||||
api_v1_router = APIRouter(prefix="/api/v1")
|
||||
|
||||
api_v1_router.include_router(auth_router)
|
||||
api_v1_router.include_router(jobs_router)
|
||||
api_v1_router.include_router(users_router)
|
||||
api_v1_router.include_router(clients_router)
|
||||
api_v1_router.include_router(files_router)
|
||||
api_v1_router.include_router(output_router)
|
||||
api_v1_router.include_router(reports_router)
|
||||
api_v1_router.include_router(audit_router)
|
||||
130
backend/app/api/v1/users.py
Normal file
|
|
@@ -0,0 +1,130 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.service import AuthService
|
||||
from app.dependencies import get_db, require_role
|
||||
from app.models.user import User, UserClient, UserStatus
|
||||
from app.schemas.common import PaginatedResponse
|
||||
from app.schemas.user import UserCreate, UserResponse, UserUpdate
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
auth_service = AuthService()
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_user(
|
||||
body: UserCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> UserResponse:
|
||||
"""Create a new user (admin only)."""
|
||||
# Check for duplicate email
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
user = User(
|
||||
email=body.email,
|
||||
name=body.name,
|
||||
password_hash=auth_service.hash_password(body.password),
|
||||
role=body.role,
|
||||
status=UserStatus.active,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
# Associate with clients
|
||||
for client_id in body.client_ids:
|
||||
uc = UserClient(user_id=user.id, client_id=client_id)
|
||||
db.add(uc)
|
||||
|
||||
await db.flush()
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[UserResponse])
|
||||
async def list_users(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> PaginatedResponse[UserResponse]:
|
||||
"""List all users (admin only)."""
|
||||
count_result = await db.execute(select(func.count(User.id)))
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
select(User)
|
||||
.order_by(User.created_at.desc())
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
users = [UserResponse.model_validate(u) for u in result.scalars().all()]
|
||||
|
||||
pages = (total + page_size - 1) // page_size if total > 0 else 1
|
||||
return PaginatedResponse(
|
||||
items=users, total=total, page=page, page_size=page_size, pages=pages
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> UserResponse:
|
||||
"""Get a user by ID (admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
body: UserUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> UserResponse:
|
||||
"""Update a user (admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = auth_service.hash_password(
|
||||
update_data.pop("password")
|
||||
)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
await db.flush()
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(require_role(["admin"])),
|
||||
) -> None:
|
||||
"""Soft-delete a user by setting status to inactive (admin only)."""
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.status = UserStatus.inactive
|
||||
await db.flush()
|
||||
0
backend/app/auth/__init__.py
Normal file
33
backend/app/auth/middleware.py
Normal file
|
|
@@ -0,0 +1,33 @@
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.auth.service import AuthService
|
||||
|
||||
security = HTTPBearer()
|
||||
auth_service = AuthService()
|
||||
|
||||
|
||||
async def decode_jwt(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> dict:
|
||||
"""Decode JWT from Authorization header and return user claims."""
|
||||
token = credentials.credentials
|
||||
claims = auth_service.validate_token(token)
|
||||
if claims is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
)
|
||||
if claims.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
)
|
||||
return {
|
||||
"user_id": UUID(claims["sub"]),
|
||||
"email": claims.get("email", ""),
|
||||
"role": claims.get("role", ""),
|
||||
"name": claims.get("name", ""),
|
||||
}
|
||||
0
backend/app/auth/providers/__init__.py
Normal file
36
backend/app/auth/providers/base.py
Normal file
|
|
@@ -0,0 +1,36 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AbstractAuthProvider(ABC):
|
||||
"""Abstract base for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, email: str, password: str) -> dict[str, Any] | None:
|
||||
"""Authenticate a user. Returns user dict or None if invalid."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def create_access_token(self, data: dict[str, Any]) -> str:
|
||||
"""Create an access token for the given data."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def create_refresh_token(self, data: dict[str, Any]) -> str:
|
||||
"""Create a refresh token for the given data."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def validate_token(self, token: str) -> dict[str, Any] | None:
|
||||
"""Validate a token and return the claims, or None if invalid."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def hash_password(self, password: str) -> str:
|
||||
"""Hash a plaintext password."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a plaintext password against a hash."""
|
||||
...
|
||||
52
backend/app/auth/providers/jwt_provider.py
Normal file
@@ -0,0 +1,52 @@
from datetime import datetime, timedelta, timezone
from typing import Any

import bcrypt
from jose import JWTError, jwt

from app.config import settings
from app.auth.providers.base import AbstractAuthProvider


class JWTAuthProvider(AbstractAuthProvider):
    """JWT-based authentication provider with bcrypt password hashing."""

    def __init__(self) -> None:
        self.secret_key = settings.JWT_SECRET_KEY
        self.algorithm = settings.JWT_ALGORITHM
        self.access_token_expiry = timedelta(hours=settings.JWT_EXPIRY_HOURS)
        self.refresh_token_expiry = timedelta(days=7)

    async def authenticate(self, email: str, password: str) -> dict[str, Any] | None:
        """Not used directly - authentication happens through AuthService."""
        return None

    def create_access_token(self, data: dict[str, Any]) -> str:
        """Create a JWT access token (8-hour expiry by default)."""
        to_encode = data.copy()
        expire = datetime.now(timezone.utc) + self.access_token_expiry
        to_encode.update({"exp": expire, "type": "access"})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def create_refresh_token(self, data: dict[str, Any]) -> str:
        """Create a JWT refresh token (7-day expiry)."""
        to_encode = data.copy()
        expire = datetime.now(timezone.utc) + self.refresh_token_expiry
        to_encode.update({"exp": expire, "type": "refresh"})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def validate_token(self, token: str) -> dict[str, Any] | None:
        """Validate a JWT and return its claims, or None if invalid."""
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
            return payload
        except JWTError:
            return None

    def hash_password(self, password: str) -> str:
        """Hash a password using bcrypt."""
        return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Verify a password against a bcrypt hash."""
        return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
44
backend/app/auth/router.py
Normal file
@@ -0,0 +1,44 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.schemas import LoginRequest, RefreshRequest, TokenResponse, UserClaims
from app.auth.service import AuthService
from app.dependencies import get_current_user, get_db

router = APIRouter(prefix="/auth", tags=["auth"])
auth_service = AuthService()


@router.post("/login", response_model=TokenResponse)
async def login(
    body: LoginRequest,
    db: AsyncSession = Depends(get_db),
) -> TokenResponse:
    """Authenticate user and return access + refresh tokens."""
    result = await auth_service.login(body.email, body.password, db)
    if result is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )
    return TokenResponse(**result)


@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(body: RefreshRequest) -> TokenResponse:
    """Exchange a valid refresh token for a new token pair."""
    result = auth_service.refresh_tokens(body.refresh_token)
    if result is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        )
    return TokenResponse(**result)


@router.get("/me", response_model=UserClaims)
async def get_me(
    current_user: dict = Depends(get_current_user),
) -> UserClaims:
    """Return the current authenticated user's claims."""
    return UserClaims(**current_user)
25
backend/app/auth/schemas.py
Normal file
@@ -0,0 +1,25 @@
from uuid import UUID

from pydantic import BaseModel, EmailStr


class LoginRequest(BaseModel):
    email: EmailStr
    password: str


class TokenResponse(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class RefreshRequest(BaseModel):
    refresh_token: str


class UserClaims(BaseModel):
    user_id: UUID
    email: str
    role: str
    name: str
70
backend/app/auth/service.py
Normal file
@@ -0,0 +1,70 @@
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.providers.jwt_provider import JWTAuthProvider
from app.models.user import User


class AuthService:
    """Authentication service wrapping the JWT provider."""

    def __init__(self) -> None:
        self.provider = JWTAuthProvider()

    async def login(
        self, email: str, password: str, db: AsyncSession
    ) -> dict[str, str] | None:
        """Authenticate a user and return tokens, or None if invalid."""
        result = await db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()

        if user is None:
            return None
        if not self.provider.verify_password(password, user.password_hash):
            return None
        if user.status.value != "active":
            return None

        token_data = {
            "sub": str(user.id),
            "email": user.email,
            "role": user.role.value,
            "name": user.name,
        }

        return {
            "access_token": self.provider.create_access_token(token_data),
            "refresh_token": self.provider.create_refresh_token(token_data),
            "token_type": "bearer",
        }

    def refresh_tokens(self, refresh_token: str) -> dict[str, str] | None:
        """Validate a refresh token and issue new token pair."""
        claims = self.provider.validate_token(refresh_token)
        if claims is None:
            return None
        if claims.get("type") != "refresh":
            return None

        token_data = {
            "sub": claims["sub"],
            "email": claims.get("email", ""),
            "role": claims.get("role", ""),
            "name": claims.get("name", ""),
        }

        return {
            "access_token": self.provider.create_access_token(token_data),
            "refresh_token": self.provider.create_refresh_token(token_data),
            "token_type": "bearer",
        }

    def validate_token(self, token: str) -> dict[str, Any] | None:
        """Validate a token and return claims."""
        return self.provider.validate_token(token)

    def hash_password(self, password: str) -> str:
        """Hash a password."""
        return self.provider.hash_password(password)
23
backend/app/config.py
Normal file
@@ -0,0 +1,23 @@
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    DATABASE_URL: str = "postgresql+asyncpg://transcreation:transcreation@db:5432/transcreation"
    REDIS_URL: str = "redis://redis:6379/0"
    ANTHROPIC_API_KEY: str = ""
    JWT_SECRET_KEY: str = "CHANGE_ME_TO_A_RANDOM_SECRET"
    JWT_ALGORITHM: str = "HS256"
    JWT_EXPIRY_HOURS: int = 8
    STORAGE_ROOT: str = "/storage"
    LLM_MODEL: str = "claude-sonnet-4-20250514"

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
    }


settings = Settings()
82
backend/app/dependencies.py
Normal file
@@ -0,0 +1,82 @@
from typing import AsyncGenerator, Callable
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.config import settings

security = HTTPBearer()

_engine = create_async_engine(
    settings.DATABASE_URL,
    echo=False,
    pool_size=20,
    max_overflow=10,
    pool_pre_ping=True,
)

_async_session_factory = async_sessionmaker(
    bind=_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session."""
    async with _async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Decode JWT from Authorization header and return user claims."""
    token = credentials.credentials
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError as exc:
        # Keep the decoder's error text out of the HTTP response.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
        ) from exc

    # Refresh tokens are signed with the same key, so reject them here.
    if payload.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )

    user_id: str | None = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: missing subject",
        )
    try:
        parsed_user_id = UUID(user_id)
    except ValueError as exc:
        # A malformed subject would otherwise surface as a 500.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token: malformed subject",
        ) from exc

    return {
        "user_id": parsed_user_id,
        "email": payload.get("email", ""),
        "role": payload.get("role", ""),
        "name": payload.get("name", ""),
    }


def require_role(roles: list[str]) -> Callable:
    """Dependency factory that enforces role-based access."""

    async def _check_role(
        current_user: dict = Depends(get_current_user),
    ) -> dict:
        if current_user["role"] not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Role '{current_user['role']}' not permitted. Required: {roles}",
            )
        return current_user

    return _check_role
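`require_role` is a dependency factory: it closes over the allowed roles and returns the function that performs the check. The same closure pattern can be shown without FastAPI. In this sketch the `Forbidden` exception and the role names `admin` and `linguist` are made up for illustration; the real dependency raises `HTTPException` with status 403 and receives the user through `Depends(get_current_user)`.

```python
from typing import Callable


class Forbidden(Exception):
    """Hypothetical stand-in for an HTTP 403 response."""


def require_role(roles: list[str]) -> Callable[[dict], dict]:
    """Return a checker bound to the given roles (framework-free sketch)."""
    allowed = frozenset(roles)

    def _check_role(current_user: dict) -> dict:
        role = current_user.get("role", "")
        if role not in allowed:
            raise Forbidden(f"Role '{role}' not permitted. Required: {sorted(allowed)}")
        return current_user

    return _check_role


# Each call to the factory yields an independent checker.
admin_only = require_role(["admin"])
editors = require_role(["admin", "linguist"])
```

In a route this would typically appear as `Depends(require_role([...]))`, so every endpoint states its own role list while sharing one implementation.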
0
backend/app/llm/__init__.py
Normal file
143
backend/app/llm/client.py
Normal file
@@ -0,0 +1,143 @@
"""Anthropic SDK wrapper with retry logic and token tracking."""

import logging
import time
from typing import Any

import anthropic

from app.config import settings

logger = logging.getLogger(__name__)

# Cost per token (approximate, varies by model)
COST_PER_INPUT_TOKEN = 3.0 / 1_000_000  # $3 per 1M input tokens
COST_PER_OUTPUT_TOKEN = 15.0 / 1_000_000  # $15 per 1M output tokens


class LLMClient:
    """Wrapper around the Anthropic SDK with retry and token tracking.

    Provides exponential backoff retry on rate limit and server errors.
    Tracks token usage per call for cost monitoring.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        max_retries: int = 3,
        base_delay: float = 1.0,
    ) -> None:
        self.api_key = api_key or settings.ANTHROPIC_API_KEY
        self.model = model or settings.LLM_MODEL
        self.max_retries = max_retries
        self.base_delay = base_delay

        self.client = anthropic.Anthropic(api_key=self.api_key)
        self.last_usage: dict[str, Any] = {}

    def create_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Send a message to Claude and return the response with usage data.

        Args:
            system_prompt: The system prompt.
            user_message: The user message.
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            Tuple of (response_text, usage_dict).
            usage_dict has keys: input_tokens, output_tokens, total_tokens,
            estimated_cost_usd, model.

        Raises:
            anthropic.APIError: If all retries are exhausted.
        """
        last_error = None

        for attempt in range(1, self.max_retries + 1):
            try:
                response = self.client.messages.create(
                    model=self.model,
                    max_tokens=max_tokens,
                    system=system_prompt,
                    messages=[{"role": "user", "content": user_message}],
                    temperature=temperature,
                )

                # Extract text
                response_text = ""
                for block in response.content:
                    if hasattr(block, "text"):
                        response_text += block.text

                # Track usage
                input_tokens = response.usage.input_tokens
                output_tokens = response.usage.output_tokens
                total_tokens = input_tokens + output_tokens
                estimated_cost = (
                    input_tokens * COST_PER_INPUT_TOKEN
                    + output_tokens * COST_PER_OUTPUT_TOKEN
                )

                usage = {
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "total_tokens": total_tokens,
                    "estimated_cost_usd": round(estimated_cost, 6),
                    "model": self.model,
                }
                self.last_usage = usage

                return response_text, usage

            except anthropic.RateLimitError as e:
                last_error = e
                delay = self.base_delay * (2 ** (attempt - 1))
                logger.warning(
                    f"Rate limited (attempt {attempt}/{self.max_retries}), "
                    f"retrying in {delay}s"
                )
                time.sleep(delay)

            except anthropic.APIStatusError as e:
                if e.status_code >= 500:
                    last_error = e
                    delay = self.base_delay * (2 ** (attempt - 1))
                    logger.warning(
                        f"Server error {e.status_code} (attempt {attempt}/{self.max_retries}), "
                        f"retrying in {delay}s"
                    )
                    time.sleep(delay)
                else:
                    raise

        raise last_error  # type: ignore[misc]

    async def acreate_message(
        self,
        system_prompt: str,
        user_message: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> tuple[str, dict[str, Any]]:
        """Async wrapper around create_message.

        Same interface as create_message, but runs the synchronous client in
        the default thread-pool executor so the event loop is not blocked.
        """
        import asyncio

        # Run sync client in executor to avoid blocking
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.create_message(
                system_prompt, user_message, max_tokens, temperature
            ),
        )
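The retry loop in `create_message` waits `base_delay * 2 ** (attempt - 1)` seconds between attempts. The schedule can be exercised in isolation with a sketch like the one below, which is not the client's code: `TransientError` is a hypothetical stand-in for the Anthropic rate-limit and 5xx errors, `sleep` is injectable so the delays can be observed without waiting, and unlike the loop above this version does not sleep after the final failed attempt.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class TransientError(Exception):
    """Hypothetical stand-in for retryable errors (rate limit, HTTP 5xx)."""


def backoff_delays(max_retries: int, base_delay: float = 1.0) -> list[float]:
    """Delays slept between attempts: base, 2*base, 4*base, ..."""
    return [base_delay * (2 ** (attempt - 1)) for attempt in range(1, max_retries)]


def call_with_retry(
    func: Callable[[], T],
    max_retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func, retrying on TransientError with exponential backoff."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    for attempt in range(1, max_retries + 1):
        try:
            return func()
        except TransientError:
            if attempt == max_retries:
                raise  # Out of attempts: surface the last error unchanged.
            sleep(base_delay * (2 ** (attempt - 1)))
    raise AssertionError("unreachable")
```

With the defaults used by `LLMClient` (three attempts, one-second base delay) the waits between attempts are 1 s and 2 s. Adding random jitter to each delay is a common refinement when many workers retry at once.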
101
backend/app/llm/token_tracker.py
Normal file
@@ -0,0 +1,101 @@
"""Token usage tracking - records LLM usage to the database."""

import logging
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.audit import TokenUsageLog

logger = logging.getLogger(__name__)


async def record_token_usage(
    db: AsyncSession,
    instance_id: UUID,
    agent_name: str,
    usage: dict[str, Any],
) -> TokenUsageLog:
    """Record token usage from an LLM call to the database.

    Args:
        db: Async database session.
        instance_id: The locale instance ID this usage is for.
        agent_name: Name of the agent that made the call.
        usage: Usage dict from LLMClient with keys:
            input_tokens, output_tokens, total_tokens,
            estimated_cost_usd, model.

    Returns:
        The created TokenUsageLog record.
    """
    log_entry = TokenUsageLog(
        instance_id=instance_id,
        agent_name=agent_name,
        model=usage.get("model", "unknown"),
        input_tokens=usage.get("input_tokens", 0),
        output_tokens=usage.get("output_tokens", 0),
        total_tokens=usage.get("total_tokens", 0),
        estimated_cost_usd=usage.get("estimated_cost_usd", 0.0),
    )
    db.add(log_entry)
    await db.flush()

    logger.info(
        f"Token usage recorded: agent={agent_name}, "
        f"tokens={usage.get('total_tokens', 0)}, "
        f"cost=${usage.get('estimated_cost_usd', 0.0):.6f}"
    )

    return log_entry


async def get_total_usage_for_instance(
    db: AsyncSession,
    instance_id: UUID,
) -> dict[str, Any]:
    """Get aggregated token usage for a locale instance.

    Args:
        db: Async database session.
        instance_id: The locale instance ID.

    Returns:
        Dict with total_tokens, total_cost, by_agent breakdown.
    """
    from sqlalchemy import func, select

    result = await db.execute(
        select(
            func.sum(TokenUsageLog.input_tokens).label("total_input"),
            func.sum(TokenUsageLog.output_tokens).label("total_output"),
            func.sum(TokenUsageLog.total_tokens).label("total_tokens"),
            func.sum(TokenUsageLog.estimated_cost_usd).label("total_cost"),
        ).where(TokenUsageLog.instance_id == instance_id)
    )
    row = result.one()

    # By-agent breakdown
    agent_result = await db.execute(
        select(
            TokenUsageLog.agent_name,
            func.sum(TokenUsageLog.total_tokens).label("tokens"),
            func.sum(TokenUsageLog.estimated_cost_usd).label("cost"),
        )
        .where(TokenUsageLog.instance_id == instance_id)
        .group_by(TokenUsageLog.agent_name)
    )

    by_agent = {
        agent_name: {"tokens": tokens, "cost": float(cost)}
        for agent_name, tokens, cost in agent_result.all()
    }

    return {
        "total_input_tokens": row.total_input or 0,
        "total_output_tokens": row.total_output or 0,
        "total_tokens": row.total_tokens or 0,
        "total_cost_usd": float(row.total_cost or 0.0),
        "by_agent": dict(by_agent),
    }
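The two queries in `get_total_usage_for_instance` compute overall sums and a `GROUP BY agent_name` breakdown. The same arithmetic can be checked in memory with a sketch like the following, which uses the per-token rates from `client.py`. The row dictionaries and the helper names `estimate_cost` and `summarise_usage` are assumptions for illustration and are not part of the commit.

```python
from collections import defaultdict
from typing import Any

# Rates copied from backend/app/llm/client.py (approximate, model-dependent).
COST_PER_INPUT_TOKEN = 3.0 / 1_000_000
COST_PER_OUTPUT_TOKEN = 15.0 / 1_000_000


def estimate_cost(input_tokens: int, output_tokens: int) -> float:
    """Estimated USD cost of one call, rounded as LLMClient does."""
    cost = input_tokens * COST_PER_INPUT_TOKEN + output_tokens * COST_PER_OUTPUT_TOKEN
    return round(cost, 6)


def summarise_usage(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate usage rows the way the SQL sums and GROUP BY do."""
    by_agent: dict[str, dict[str, Any]] = defaultdict(lambda: {"tokens": 0, "cost": 0.0})
    total_input = 0
    total_output = 0
    for row in rows:
        total_input += row["input_tokens"]
        total_output += row["output_tokens"]
        agent = by_agent[row["agent_name"]]
        agent["tokens"] += row["input_tokens"] + row["output_tokens"]
        agent["cost"] += estimate_cost(row["input_tokens"], row["output_tokens"])
    return {
        "total_input_tokens": total_input,
        "total_output_tokens": total_output,
        "total_tokens": total_input + total_output,
        "total_cost_usd": round(sum(a["cost"] for a in by_agent.values()), 6),
        "by_agent": dict(by_agent),
    }
```

For example, a call with 1,000 input tokens and 500 output tokens is estimated at 1000 × $3/1M + 500 × $15/1M = $0.0105.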
63
backend/app/main.py
Normal file
@@ -0,0 +1,63 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from app.config import settings
from app.api.v1.router import api_v1_router
from app.ws.handler import ws_router

engine: AsyncEngine | None = None


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage application lifespan: set up and tear down DB engine."""
    global engine
    engine = create_async_engine(
        settings.DATABASE_URL,
        echo=False,
        pool_size=20,
        max_overflow=10,
        pool_pre_ping=True,
    )
    yield
    if engine is not None:
        await engine.dispose()
        engine = None


def create_app() -> FastAPI:
    """Application factory."""
    application = FastAPI(
        title="Amazon AI Transcreation Platform",
        description="Backend API for AI-powered marketing transcreation",
        version="1.0.0",
        lifespan=lifespan,
    )

    # CORS middleware (permissive for development)
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # API routes
    application.include_router(api_v1_router)

    # WebSocket routes
    application.include_router(ws_router)

    @application.get("/health", tags=["health"])
    async def health_check() -> dict:
        return {"status": "healthy", "version": "1.0.0"}

    return application


app = create_app()
11
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""SQLAlchemy models - import all models so Alembic can discover them."""

from app.models.base import Base, TimestampMixin  # noqa: F401
from app.models.client import Client  # noqa: F401
from app.models.user import User, UserClient  # noqa: F401
from app.models.job import Job, LocaleInstance  # noqa: F401
from app.models.source import SourceLine  # noqa: F401
from app.models.output import OutputRow  # noqa: F401
from app.models.feedback import Feedback  # noqa: F401
from app.models.files import TMFileRegistry, ReferenceFile  # noqa: F401
from app.models.audit import AuditLog, TokenUsageLog  # noqa: F401
56
backend/app/models/audit.py
Normal file
@@ -0,0 +1,56 @@
import uuid
from datetime import datetime

from sqlalchemy import DateTime, Float, ForeignKey, Integer, JSON, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, generate_uuid


class AuditLog(Base):
    __tablename__ = "audit_logs"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("users.id"), nullable=True
    )
    action: Mapped[str] = mapped_column(String(100), nullable=False)
    entity_type: Mapped[str] = mapped_column(String(100), nullable=False)
    entity_id: Mapped[str] = mapped_column(String(255), nullable=False)
    details: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)

    # Relationships
    user = relationship("User", back_populates="audit_logs")


class TokenUsageLog(Base):
    __tablename__ = "token_usage_logs"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    instance_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("locale_instances.id", ondelete="CASCADE"), nullable=False
    )
    agent_name: Mapped[str] = mapped_column(String(100), nullable=False)
    model: Mapped[str] = mapped_column(String(100), nullable=False)
    input_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
    output_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
    total_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
    estimated_cost_usd: Mapped[float] = mapped_column(Float, nullable=False)
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    instance = relationship("LocaleInstance", back_populates="token_usage_logs")
30
backend/app/models/base.py
Normal file
@@ -0,0 +1,30 @@
import uuid
from datetime import datetime

from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Base class for all SQLAlchemy models."""
    pass


class TimestampMixin:
    """Mixin that adds created_at and updated_at columns."""

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )


def generate_uuid() -> uuid.UUID:
    return uuid.uuid4()
22
backend/app/models/client.py
Normal file
@@ -0,0 +1,22 @@
import uuid

from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, TimestampMixin, generate_uuid


class Client(Base, TimestampMixin):
    __tablename__ = "clients"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    settings: Mapped[dict | None] = mapped_column(JSON, nullable=True)

    # Relationships
    user_clients = relationship("UserClient", back_populates="client", lazy="selectin")
    jobs = relationship("Job", back_populates="client", lazy="selectin")
    tm_files = relationship("TMFileRegistry", back_populates="client", lazy="selectin")
    reference_files = relationship("ReferenceFile", back_populates="client", lazy="selectin")
43
backend/app/models/feedback.py
Normal file
@@ -0,0 +1,43 @@
import enum
import uuid
from datetime import datetime

from sqlalchemy import DateTime, Enum, ForeignKey, Integer, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, generate_uuid


class FlagType(str, enum.Enum):
    approved = "approved"
    needs_revision = "needs_revision"
    comment = "comment"


class Feedback(Base):
    __tablename__ = "feedback"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    output_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("output_rows.id", ondelete="CASCADE"), nullable=False
    )
    user_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("users.id"), nullable=False
    )
    option_column: Mapped[int] = mapped_column(Integer, nullable=False)
    flag_type: Mapped[FlagType] = mapped_column(
        Enum(FlagType, name="flag_type", create_constraint=True),
        nullable=False,
    )
    comment: Mapped[str | None] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )

    # Relationships
    output = relationship("OutputRow", back_populates="feedback")
    user = relationship("User", back_populates="feedback")
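The models in this commit declare their enums as `class X(str, enum.Enum)`. Mixing in `str` means each member is also a plain string, which keeps comparisons and JSON serialisation simple. A standalone sketch of that behaviour follows, re-declaring `FlagType` locally so it runs without the application package. `parse_flag` is a hypothetical helper, not something the commit defines.

```python
import enum
import json


class FlagType(str, enum.Enum):
    """Local copy of the enum from backend/app/models/feedback.py."""

    approved = "approved"
    needs_revision = "needs_revision"
    comment = "comment"


def parse_flag(raw: str) -> FlagType:
    """Convert request input to a FlagType, with a readable error (hypothetical)."""
    try:
        return FlagType(raw)  # Lookup by value, not by member name.
    except ValueError:
        allowed = ", ".join(flag.value for flag in FlagType)
        raise ValueError(f"unknown flag_type {raw!r}; expected one of: {allowed}") from None


# Members compare equal to their string values because of the str mixin.
is_plain_string = FlagType.approved == "approved"

# json.dumps treats the member as a str and emits its value.
encoded = json.dumps({"flag_type": FlagType.needs_revision})
```

Without the `str` mixin, `json.dumps` would raise `TypeError` for an enum member and the equality check against a raw string would be false.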
89
backend/app/models/files.py
Normal file
@@ -0,0 +1,89 @@
import enum
import uuid
from datetime import datetime

from sqlalchemy import DateTime, Enum, ForeignKey, Integer, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, generate_uuid


class ReferenceFileType(str, enum.Enum):
    glossary = "glossary"
    blacklist = "blacklist"
    tov_global = "tov_global"
    tov_supplement = "tov_supplement"
    locale_considerations = "locale_considerations"
    date_pct_formats = "date_pct_formats"


class TMFileRegistry(Base):
    __tablename__ = "tm_file_registry"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    client_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("clients.id", ondelete="CASCADE"), nullable=False
    )
    locale_code: Mapped[str] = mapped_column(String(10), nullable=False)
    channel: Mapped[str] = mapped_column(String(100), nullable=False)
    filename: Mapped[str] = mapped_column(String(255), nullable=False)
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)
    segment_count: Mapped[int] = mapped_column(Integer, default=0)
    uploaded_by: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("users.id"), nullable=True
    )
    uploaded_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    last_updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
    version: Mapped[int] = mapped_column(Integer, default=1)

    # Relationships
    client = relationship("Client", back_populates="tm_files")
    uploader = relationship("User", foreign_keys=[uploaded_by])


class ReferenceFile(Base):
    __tablename__ = "reference_files"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    client_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("clients.id", ondelete="CASCADE"), nullable=False
    )
    file_type: Mapped[ReferenceFileType] = mapped_column(
        Enum(ReferenceFileType, name="reference_file_type", create_constraint=True),
        nullable=False,
    )
    locale_scope: Mapped[str] = mapped_column(String(10), nullable=False)
    filename: Mapped[str] = mapped_column(String(255), nullable=False)
    file_path: Mapped[str] = mapped_column(String(500), nullable=False)
    uploaded_by: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("users.id"), nullable=True
    )
    uploaded_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    last_updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
    version: Mapped[int] = mapped_column(Integer, default=1)

    # Relationships
    client = relationship("Client", back_populates="reference_files")
    uploader = relationship("User", foreign_keys=[uploaded_by])
143
backend/app/models/job.py
Normal file
143
backend/app/models/job.py
Normal file
|
|
@ -0,0 +1,143 @@
import enum
import uuid
from datetime import datetime

from sqlalchemy import (
    JSON,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, TimestampMixin, generate_uuid


class Programme(str, enum.Enum):
    retail = "retail"
    prime = "prime"
    brand = "brand"


class JobType(str, enum.Enum):
    main = "main"
    derived = "derived"


class JobStatus(str, enum.Enum):
    created = "created"
    validating = "validating"
    queued = "queued"
    running = "running"
    partial_complete = "partial_complete"
    complete = "complete"
    error = "error"
    exported = "exported"


class LocaleType(str, enum.Enum):
    main = "main"
    derived = "derived"


class LocaleStatus(str, enum.Enum):
    queued = "queued"
    running = "running"
    complete = "complete"
    error = "error"


class Job(Base, TimestampMixin):
    __tablename__ = "jobs"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    client_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("clients.id", ondelete="CASCADE"), nullable=False
    )
    created_by: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("users.id"), nullable=False
    )
    job_ref: Mapped[str | None] = mapped_column(String(100), nullable=True)
    campaign_name: Mapped[str] = mapped_column(String(255), nullable=False)
    programme: Mapped[Programme] = mapped_column(
        Enum(Programme, name="programme_type", create_constraint=True),
        nullable=False,
    )
    channel: Mapped[str] = mapped_column(String(100), nullable=False)
    sub_channel: Mapped[str | None] = mapped_column(String(100), nullable=True)
    context_prompt: Mapped[str | None] = mapped_column(Text, nullable=True)
    job_type: Mapped[JobType] = mapped_column(
        Enum(JobType, name="job_type", create_constraint=True),
        default=JobType.main,
        nullable=False,
    )
    parent_job_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("jobs.id"), nullable=True
    )
    status: Mapped[JobStatus] = mapped_column(
        Enum(JobStatus, name="job_status", create_constraint=True),
        default=JobStatus.created,
        nullable=False,
    )
    total_token_usage: Mapped[int] = mapped_column(Integer, default=0)
    total_estimated_cost: Mapped[float] = mapped_column(Float, default=0.0)

    # Relationships
    client = relationship("Client", back_populates="jobs")
    creator = relationship("User", back_populates="jobs_created")
    parent_job = relationship("Job", remote_side="Job.id", lazy="selectin")
    locale_instances = relationship(
        "LocaleInstance", back_populates="job", lazy="selectin", cascade="all, delete-orphan"
    )
    source_lines = relationship(
        "SourceLine", back_populates="job", lazy="selectin", cascade="all, delete-orphan"
    )


class LocaleInstance(Base, TimestampMixin):
    __tablename__ = "locale_instances"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    job_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
    )
    locale_code: Mapped[str] = mapped_column(String(10), nullable=False)
    locale_type: Mapped[LocaleType] = mapped_column(
        Enum(LocaleType, name="locale_type", create_constraint=True),
        nullable=False,
    )
    status: Mapped[LocaleStatus] = mapped_column(
        Enum(LocaleStatus, name="locale_status", create_constraint=True),
        default=LocaleStatus.queued,
        nullable=False,
    )
    started_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    completed_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    token_usage: Mapped[int] = mapped_column(Integer, default=0)
    estimated_cost: Mapped[float] = mapped_column(Float, default=0.0)
    output_file_path: Mapped[str | None] = mapped_column(String(500), nullable=True)
    error_log: Mapped[str | None] = mapped_column(Text, nullable=True)
    tm_files_loaded: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    ref_files_loaded: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    agent_version: Mapped[str | None] = mapped_column(String(50), nullable=True)

    # Relationships
    job = relationship("Job", back_populates="locale_instances")
    output_rows = relationship(
        "OutputRow", back_populates="instance", lazy="selectin", cascade="all, delete-orphan"
    )
    token_usage_logs = relationship(
        "TokenUsageLog", back_populates="instance", lazy="selectin", cascade="all, delete-orphan"
    )
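The `JobStatus` and `LocaleStatus` enums in job.py suggest that a job's status is rolled up from its locale instances, but this commit does not show the rule. The following is a minimal sketch of one possible roll-up, assuming `partial_complete` means that every locale has finished and at least one of them errored. `rollup_status` is a hypothetical helper, and the enums are re-declared here only so the example runs on its own.

```python
import enum


class LocaleStatus(str, enum.Enum):
    queued = "queued"
    running = "running"
    complete = "complete"
    error = "error"


class JobStatus(str, enum.Enum):
    created = "created"
    validating = "validating"
    queued = "queued"
    running = "running"
    partial_complete = "partial_complete"
    complete = "complete"
    error = "error"
    exported = "exported"


def rollup_status(locales: list[LocaleStatus]) -> JobStatus:
    """Hypothetical roll-up of per-locale statuses into one job status."""
    if not locales:
        return JobStatus.queued  # assumption: nothing has been scheduled yet
    states = set(locales)
    if states == {LocaleStatus.complete}:
        return JobStatus.complete
    if states == {LocaleStatus.error}:
        return JobStatus.error
    if states <= {LocaleStatus.complete, LocaleStatus.error}:
        # All locales finished, with a mix of successes and failures.
        return JobStatus.partial_complete
    if states == {LocaleStatus.queued}:
        return JobStatus.queued
    return JobStatus.running  # at least one locale is still pending or running
```

The job-only phases (`created`, `validating`, `exported`) are never returned here; the sketch assumes they are set by other parts of the application.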
53
backend/app/models/output.py
Normal file
@ -0,0 +1,53 @@
import enum
import uuid

from sqlalchemy import Enum, ForeignKey, Integer, JSON, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, TimestampMixin, generate_uuid


class ConfidenceTier(str, enum.Enum):
    high = "high"
    moderate = "moderate"
    low = "low"


class OutputRow(Base, TimestampMixin):
    __tablename__ = "output_rows"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    instance_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("locale_instances.id", ondelete="CASCADE"), nullable=False
    )
    line_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("source_lines.id", ondelete="CASCADE"), nullable=False
    )
    row_order: Mapped[int] = mapped_column(Integer, nullable=False)
    confidence_tier: Mapped[ConfidenceTier] = mapped_column(
        Enum(ConfidenceTier, name="confidence_tier", create_constraint=True),
        nullable=False,
    )
    option_1: Mapped[str] = mapped_column(Text, nullable=False)
    backtranslation_1: Mapped[str] = mapped_column(Text, nullable=False)
    rationale_1: Mapped[str] = mapped_column(Text, nullable=False)
    option_2: Mapped[str | None] = mapped_column(Text, nullable=True)
    backtranslation_2: Mapped[str | None] = mapped_column(Text, nullable=True)
    rationale_2: Mapped[str | None] = mapped_column(Text, nullable=True)
    option_3: Mapped[str | None] = mapped_column(Text, nullable=True)
    backtranslation_3: Mapped[str | None] = mapped_column(Text, nullable=True)
    rationale_3: Mapped[str | None] = mapped_column(Text, nullable=True)
    tm_entries_cited: Mapped[dict | None] = mapped_column(JSON, nullable=True)
    winning_seg_key: Mapped[str | None] = mapped_column(String(255), nullable=True)
    character_count_option_1: Mapped[int | None] = mapped_column(Integer, nullable=True)
    character_count_option_2: Mapped[int | None] = mapped_column(Integer, nullable=True)
    character_count_option_3: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationships
    instance = relationship("LocaleInstance", back_populates="output_rows")
    line = relationship("SourceLine", back_populates="output_rows")
    feedback = relationship(
        "Feedback", back_populates="output", lazy="selectin", cascade="all, delete-orphan"
    )
30
backend/app/models/source.py
Normal file
@ -0,0 +1,30 @@
import uuid

from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, TimestampMixin, generate_uuid


class SourceLine(Base, TimestampMixin):
    __tablename__ = "source_lines"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    job_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
    )
    row_order: Mapped[int] = mapped_column(Integer, nullable=False)
    en_gb: Mapped[str] = mapped_column(Text, nullable=False)
    copy_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
    creative_guidance: Mapped[str | None] = mapped_column(Text, nullable=True)
    visual_ref: Mapped[str | None] = mapped_column(String(500), nullable=True)
    char_limit: Mapped[str | None] = mapped_column(String(50), nullable=True)
    is_display_format: Mapped[bool] = mapped_column(Boolean, default=False)

    # Relationships
    job = relationship("Job", back_populates="source_lines")
    output_rows = relationship(
        "OutputRow", back_populates="line", lazy="selectin", cascade="all, delete-orphan"
    )
60
backend/app/models/user.py
Normal file
@ -0,0 +1,60 @@
import enum
import uuid

from sqlalchemy import Enum, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.models.base import Base, TimestampMixin, generate_uuid


class UserRole(str, enum.Enum):
    admin = "admin"
    tm_manager = "tm_manager"
    reviewer = "reviewer"


class UserStatus(str, enum.Enum):
    active = "active"
    inactive = "inactive"


class User(Base, TimestampMixin):
    __tablename__ = "users"

    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=generate_uuid
    )
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    role: Mapped[UserRole] = mapped_column(
        Enum(UserRole, name="user_role", create_constraint=True),
        nullable=False,
    )
    status: Mapped[UserStatus] = mapped_column(
        Enum(UserStatus, name="user_status", create_constraint=True),
        default=UserStatus.active,
        nullable=False,
    )

    # Relationships
    user_clients = relationship("UserClient", back_populates="user", lazy="selectin")
    jobs_created = relationship("Job", back_populates="creator", lazy="selectin")
    feedback = relationship("Feedback", back_populates="user", lazy="selectin")
    audit_logs = relationship("AuditLog", back_populates="user", lazy="selectin")


class UserClient(Base):
    __tablename__ = "user_clients"

    user_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
    )
    client_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("clients.id", ondelete="CASCADE"), primary_key=True
    )
    role_override: Mapped[str | None] = mapped_column(String(50), nullable=True)

    # Relationships
    user = relationship("User", back_populates="user_clients")
    client = relationship("Client", back_populates="user_clients")
0
backend/app/pipeline/__init__.py
Normal file
0
backend/app/pipeline/agents/__init__.py
Normal file
106
backend/app/pipeline/agents/agent_1_validator.py
Normal file
@ -0,0 +1,106 @@
"""Agent 1: Validator

Validates source file, loads reference files, and builds the initial
PipelineContext. This agent is deterministic (no LLM call).
"""

from typing import Any

from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import (
    FileManifest,
    JobParams,
    ParsedJob,
    PipelineContext,
    SourceLineContract,
)
from app.pipeline.modules.source_file_parser import parse_source_file
from app.pipeline.modules.ref_file_loader import load_all_reference_files


class Agent1Validator(BaseAgent):
    """Validates inputs and builds the initial pipeline context."""

    name = "agent_1_validator"
    description = "Validates source file and reference files, builds ParsedJob"

    def __init__(
        self,
        source_file_path: str | None = None,
        source_lines: list[dict] | None = None,
        file_manifest: dict[str, str | None] | None = None,
        job_params: dict[str, Any] | None = None,
    ) -> None:
        self.source_file_path = source_file_path
        self.source_lines_raw = source_lines
        self.file_manifest_raw = file_manifest or {}
        self.job_params_raw = job_params or {}

    def get_system_prompt(self) -> str:
        return ""  # No LLM call for this agent

    def build_user_message(self, context: PipelineContext) -> str:
        return ""  # No LLM call for this agent

    def parse_response(self, response: str, context: PipelineContext) -> Any:
        return None  # No LLM call for this agent

    async def run(self, context: PipelineContext) -> PipelineContext:
        """Validate inputs and populate the pipeline context.

        1. Parse source file or use pre-parsed source lines
        2. Resolve and load reference files
        3. Build initial PipelineContext
        """
        # Parse source lines
        if self.source_lines_raw:
            raw_lines = self.source_lines_raw
        elif self.source_file_path:
            raw_lines = parse_source_file(self.source_file_path)
        else:
            raw_lines = []

        source_lines = [
            SourceLineContract(
                line_id=str(i + 1),
                row_order=i + 1,
                en_gb=line.get("en_gb", ""),
                copy_type=line.get("copy_type"),
                creative_guidance=line.get("creative_guidance"),
                visual_ref=line.get("visual_ref"),
                char_limit=line.get("char_limit"),
                is_display_format=line.get("is_display_format", False),
            )
            for i, line in enumerate(raw_lines)
        ]

        # Build file manifest
        file_manifest = FileManifest(
            tm_files=self.file_manifest_raw.get("tm_files", []),
            glossary_file=self.file_manifest_raw.get("glossary_file"),
            blacklist_file=self.file_manifest_raw.get("blacklist_file"),
            tov_global_file=self.file_manifest_raw.get("tov_global_file"),
            tov_supplement_file=self.file_manifest_raw.get("tov_supplement_file"),
            locale_considerations_file=self.file_manifest_raw.get(
                "locale_considerations_file"
            ),
            date_pct_formats_file=self.file_manifest_raw.get("date_pct_formats_file"),
        )

        # Load reference files (for validation)
        if any([
            file_manifest.glossary_file,
            file_manifest.blacklist_file,
            file_manifest.date_pct_formats_file,
        ]):
            load_all_reference_files(self.file_manifest_raw)

        # Build job params
        job_params = JobParams(**self.job_params_raw)

        # Update context
        context.job_params = job_params
        context.source_lines = source_lines
        context.file_manifest = file_manifest

        return context
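The first step of `Agent1Validator.run()` picks its input with a truthiness test, so pre-parsed lines take precedence over the file path, and an empty pre-parsed list falls through to the file. The sketch below isolates that precedence and the 1-based numbering. `resolve_raw_lines` and `number_lines` are hypothetical names, and the parser is passed in as a callable so the example does not depend on the real `parse_source_file`.

```python
from __future__ import annotations

from typing import Callable


def resolve_raw_lines(
    source_lines: list[dict] | None,
    source_file_path: str | None,
    parse: Callable[[str], list[dict]],
) -> list[dict]:
    """Mirror the precedence used in run(): lines, then file, then nothing."""
    if source_lines:  # None and [] are both falsy
        return source_lines
    if source_file_path:
        return parse(source_file_path)
    return []


def number_lines(raw_lines: list[dict]) -> list[dict]:
    """Assign 1-based line_id (as a string) and row_order, as the validator does."""
    return [
        {
            "line_id": str(i + 1),
            "row_order": i + 1,
            "en_gb": line.get("en_gb", ""),
        }
        for i, line in enumerate(raw_lines)
    ]
```

One consequence worth noting: a caller that passes `source_lines=[]` together with a file path gets the file's contents, not an empty job.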
40
backend/app/pipeline/agents/agent_2_tm_retrieval.py
Normal file
@ -0,0 +1,40 @@
"""Agent 2: TM Retrieval (STUB)

Retrieves Translation Memory matches for each source line.
Currently returns empty results as a stub.
"""

from typing import Any

from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import PipelineContext, TMSweepResult


class Agent2TMRetrieval(BaseAgent):
    """STUB: TM retrieval agent returning empty sweep results."""

    name = "agent_2_tm_retrieval"
    description = "Retrieves TM matches for source lines (STUB)"

    def get_system_prompt(self) -> str:
        return "You are a Translation Memory retrieval agent."

    def build_user_message(self, context: PipelineContext) -> str:
        return "Retrieve TM matches for the provided source lines."

    def parse_response(self, response: str, context: PipelineContext) -> Any:
        return []

    async def run(self, context: PipelineContext) -> PipelineContext:
        """STUB: Return empty TM sweep results for all source lines."""
        context.tm_sweep_results = [
            TMSweepResult(
                line_id=line.line_id,
                confirmed_matches=[],
                pass_4_triggered=False,
                pass_4_result=None,
                no_match=True,
            )
            for line in context.source_lines
        ]
        return context
42
backend/app/pipeline/agents/agent_3_ranker.py
Normal file
@ -0,0 +1,42 @@
"""Agent 3: Ranker (STUB)

Ranks TM matches and declares confidence tiers for each source line.
Currently returns LOW confidence for all lines as a stub.
"""

from typing import Any

from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import PipelineContext, RankingDeclaration


class Agent3Ranker(BaseAgent):
    """STUB: Ranking agent returning LOW confidence for all lines."""

    name = "agent_3_ranker"
    description = "Ranks TM matches and declares confidence (STUB)"

    def get_system_prompt(self) -> str:
        return "You are a ranking and confidence declaration agent."

    def build_user_message(self, context: PipelineContext) -> str:
        return "Rank the TM matches for each source line."

    def parse_response(self, response: str, context: PipelineContext) -> Any:
        return []

    async def run(self, context: PipelineContext) -> PipelineContext:
        """STUB: Return LOW confidence ranking for all source lines."""
        context.ranking_declarations = [
            RankingDeclaration(
                line_id=line.line_id,
                winning_entry=None,
                runner_ups=[],
                confidence_tier="low",
                option_count=3,
                is_new_creative_line=True,
                notes="STUB: No TM matches available",
            )
            for line in context.source_lines
        ]
        return context
55
backend/app/pipeline/agents/agent_4_transcreator.py
Normal file
@ -0,0 +1,55 @@
"""Agent 4: Transcreator (STUB)

Generates transcreation drafts for each source line.
Currently returns placeholder translations as a stub.
"""

from typing import Any

from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import DraftOption, DraftOutput, PipelineContext


class Agent4Transcreator(BaseAgent):
    """STUB: Transcreation agent returning placeholder translations."""

    name = "agent_4_transcreator"
    description = "Generates transcreation drafts (STUB)"

    def get_system_prompt(self) -> str:
        return "You are a creative transcreation agent."

    def build_user_message(self, context: PipelineContext) -> str:
        return "Generate transcreation drafts for each source line."

    def parse_response(self, response: str, context: PipelineContext) -> Any:
        return []

    async def run(self, context: PipelineContext) -> PipelineContext:
        """STUB: Return placeholder translations for all source lines."""
        locale = context.job_params.locale_code

        context.draft_outputs = [
            DraftOutput(
                line_id=line.line_id,
                option_1=DraftOption(
                    text=f"[{locale}] {line.en_gb}",
                    backtranslation=line.en_gb,
                    rationale=f"STUB: Direct placeholder for '{line.en_gb[:50]}...'",
                ),
                option_2=DraftOption(
                    text=f"[{locale} alt] {line.en_gb}",
                    backtranslation=line.en_gb,
                    rationale="STUB: Alternative placeholder",
                ),
                option_3=DraftOption(
                    text=f"[{locale} creative] {line.en_gb}",
                    backtranslation=line.en_gb,
                    rationale="STUB: Creative placeholder",
                ),
                tm_entries_cited=[],
                adaptations_applied=[],
            )
            for line in context.source_lines
        ]
        return context
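The stub's first rationale slices the source text to 50 characters and then appends `...` unconditionally, so a short line such as `Sale` is reported as `'Sale...'`. A small sketch of a helper that adds the ellipsis only when text was actually cut follows; `preview` is a hypothetical name and is not part of this commit.

```python
def preview(text: str, limit: int = 50) -> str:
    """Return text unchanged if it fits, else the first `limit` chars plus '...'."""
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


# How the stub rationale could use it (illustrative only):
rationale = f"STUB: Direct placeholder for '{preview('Sale')}'"
```

This keeps the rationale column honest about whether the quoted source was truncated, which matters once reviewers read it in the output xlsx.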
52
backend/app/pipeline/agents/agent_5_compliance.py
Normal file
@ -0,0 +1,52 @@
"""Agent 5: Compliance Checker (STUB)

Checks transcreation drafts against compliance rules.
Currently returns PASS for all lines as a stub.
"""

from typing import Any

from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import ComplianceResult, PipelineContext
from app.pipeline.modules.character_counter import count_characters


class Agent5Compliance(BaseAgent):
    """STUB: Compliance agent returning pass for all lines."""

    name = "agent_5_compliance"
    description = "Checks compliance of transcreation drafts (STUB)"

    def get_system_prompt(self) -> str:
        return "You are a compliance checking agent."

    def build_user_message(self, context: PipelineContext) -> str:
        return "Check compliance for all transcreation drafts."

    def parse_response(self, response: str, context: PipelineContext) -> Any:
        return []

    async def run(self, context: PipelineContext) -> PipelineContext:
        """STUB: Return pass for all compliance checks, with character counts."""
        context.compliance_results = []

        for draft in context.draft_outputs:
            char_counts: dict[str, int] = {}

            if draft.option_1:
                char_counts["option_1"] = count_characters(draft.option_1.text)
            if draft.option_2:
                char_counts["option_2"] = count_characters(draft.option_2.text)
            if draft.option_3:
                char_counts["option_3"] = count_characters(draft.option_3.text)

            context.compliance_results.append(
                ComplianceResult(
                    line_id=draft.line_id,
                    passed=True,
                    violations=[],
                    character_counts=char_counts,
                )
            )

        return context
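The compliance stub records character counts but never compares them with the source line's `char_limit`, which is stored as a string. The sketch below shows one way such a check might look once the agent is implemented. It rests on several assumptions: the limit string contains a plain integer (for example `"30"` or `"30 chars"`), `len()` stands in for the real `count_characters`, and `Violation` is a dataclass stand-in that copies the fields of `ComplianceViolation`. `check_char_limit` is a hypothetical helper.

```python
from __future__ import annotations

import re
from dataclasses import dataclass


@dataclass
class Violation:
    """Stand-in with the same fields as ComplianceViolation."""
    type: str
    option_affected: int
    description: str
    severity: str = "warning"


def check_char_limit(options: list[str], char_limit: str | None) -> list[Violation]:
    """Flag every option whose length exceeds the first integer in char_limit."""
    match = re.search(r"\d+", char_limit or "")
    if match is None:
        return []  # no usable limit, so nothing to check
    limit = int(match.group())
    violations = []
    for index, text in enumerate(options, start=1):
        if len(text) > limit:  # assumption: len() approximates count_characters
            violations.append(
                Violation(
                    type="char_limit",
                    option_affected=index,
                    description=f"{len(text)} characters exceeds limit of {limit}",
                )
            )
    return violations
```

A line would then pass only when the returned list is empty, instead of the unconditional `passed=True` used by the stub.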
112
backend/app/pipeline/agents/agent_6_formatter.py
Normal file
@ -0,0 +1,112 @@
"""Agent 6: Formatter

Generates the output xlsx file and builds output row data.
This agent is deterministic (no LLM call).
"""

from datetime import datetime, timezone
from typing import Any
from uuid import uuid4

from app.pipeline.agents.base import BaseAgent
from app.pipeline.contracts import PipelineContext
from app.pipeline.modules.excel_writer import generate_output_xlsx
from app.config import settings


class Agent6Formatter(BaseAgent):
    """Formats pipeline output into xlsx and structured data."""

    name = "agent_6_formatter"
    description = "Generates output xlsx and structured output rows"

    def __init__(self, output_dir: str | None = None) -> None:
        self.output_dir = output_dir or settings.STORAGE_ROOT

    def get_system_prompt(self) -> str:
        return ""  # No LLM call

    def build_user_message(self, context: PipelineContext) -> str:
        return ""  # No LLM call

    def parse_response(self, response: str, context: PipelineContext) -> Any:
        return None  # No LLM call

    async def run(self, context: PipelineContext) -> PipelineContext:
        """Generate the output xlsx and return the context unchanged.

        The xlsx file is written to disk. Output rows for database
        persistence are built by the orchestrator, not by this agent.
        """
        job_id = context.job_params.job_id
        locale_code = context.job_params.locale_code

        # Build source lines for excel
        source_lines_data = [
            {
                "row_order": sl.row_order,
                "en_gb": sl.en_gb,
                "copy_type": sl.copy_type or "",
            }
            for sl in context.source_lines
        ]

        # Build output rows for excel
        output_rows_data = []
        for i, draft in enumerate(context.draft_outputs):
            row: dict[str, Any] = {
                "row_order": i + 1,
                "line_id": draft.line_id,
                "option_1": draft.option_1.text if draft.option_1 else "",
                "backtranslation_1": draft.option_1.backtranslation if draft.option_1 else "",
                "rationale_1": draft.option_1.rationale if draft.option_1 else "",
            }
            if draft.option_2:
                row["option_2"] = draft.option_2.text
                row["backtranslation_2"] = draft.option_2.backtranslation
                row["rationale_2"] = draft.option_2.rationale
            if draft.option_3:
                row["option_3"] = draft.option_3.text
                row["backtranslation_3"] = draft.option_3.backtranslation
                row["rationale_3"] = draft.option_3.rationale

            output_rows_data.append(row)

        # Build summary: tally confidence tiers from the ranking declarations
        confidence_counts = {"high": 0, "moderate": 0, "low": 0}
        for ranking in context.ranking_declarations:
            tier = ranking.confidence_tier
            if tier in confidence_counts:
                confidence_counts[tier] += 1

        summary = {
            "job_id": job_id,
            "campaign_name": context.job_params.campaign_name,
            "locale_code": locale_code,
            "channel": context.job_params.channel,
            "programme": context.job_params.programme,
            "total_source_lines": len(context.source_lines),
            "total_output_rows": len(output_rows_data),
            "high_confidence": confidence_counts["high"],
            "moderate_confidence": confidence_counts["moderate"],
            "low_confidence": confidence_counts["low"],
            "total_tokens": 0,
            "estimated_cost": 0.0,
            "agent_version": "1.0.0",
            "generated_at": datetime.now(timezone.utc).isoformat(),
        }

        # Generate xlsx
        output_path = (
            f"{self.output_dir}/jobs/{job_id}/output/"
            f"{locale_code}_{job_id}_output.xlsx"
        )

        generate_output_xlsx(
            output_path=output_path,
            source_lines=source_lines_data,
            output_rows=output_rows_data,
            summary=summary,
        )

        return context
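The formatter's summary tallies how many lines landed in each confidence tier and silently ignores any tier name it does not recognise. The same tally can be sketched with `collections.Counter`; `tally_confidence` and `TIERS` are hypothetical names used only for this illustration.

```python
from collections import Counter

# The three tiers declared by ConfidenceTier in output.py.
TIERS = ("high", "moderate", "low")


def tally_confidence(tiers: list[str]) -> dict[str, int]:
    """Count known tiers, always returning a key for each one."""
    counts = Counter(tier for tier in tiers if tier in TIERS)
    return {tier: counts[tier] for tier in TIERS}
```

Returning every tier, including those with a zero count, keeps the summary sheet's layout stable regardless of which tiers a given job produced.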
62
backend/app/pipeline/agents/base.py
Normal file
@ -0,0 +1,62 @@
"""Base agent abstract class for the transcreation pipeline."""

from abc import ABC, abstractmethod
from typing import Any

from app.pipeline.contracts import PipelineContext


class BaseAgent(ABC):
    """Abstract base class for pipeline agents.

    Each agent in the pipeline implements this interface:
    1. get_system_prompt() - Returns the system prompt for the LLM
    2. build_user_message() - Builds the user message from pipeline context
    3. parse_response() - Parses the LLM response into structured data
    4. run() - Orchestrates the agent's execution
    """

    name: str = "base_agent"
    description: str = ""

    @abstractmethod
    def get_system_prompt(self) -> str:
        """Return the system prompt for this agent's LLM call."""
        ...

    @abstractmethod
    def build_user_message(self, context: PipelineContext) -> str:
        """Build the user message from pipeline context.

        Args:
            context: The current pipeline context.

        Returns:
            The user message string to send to the LLM.
        """
        ...

    @abstractmethod
    def parse_response(self, response: str, context: PipelineContext) -> Any:
        """Parse the LLM response into structured data.

        Args:
            response: The raw LLM response text.
            context: The current pipeline context.

        Returns:
            Structured data appropriate for this agent's output.
        """
        ...

    @abstractmethod
    async def run(self, context: PipelineContext) -> PipelineContext:
        """Execute this agent and return the updated pipeline context.

        Args:
            context: The current pipeline context.

        Returns:
            Updated pipeline context with this agent's results.
        """
        ...
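Because every agent exposes the same `async run(context) -> context` hook, a pipeline can be driven by awaiting the agents one after another and threading the context through. The orchestrator itself is not part of this excerpt, so the following is only a sketch of that pattern under stated assumptions: `Context` and `Agent` are reduced stand-ins for `PipelineContext` and `BaseAgent`, and `UppercaseAgent`, `CountAgent` and `run_pipeline` are hypothetical.

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class Context:
    """Stand-in for PipelineContext, reduced to two fields."""
    source_lines: list[str] = field(default_factory=list)
    notes: list[str] = field(default_factory=list)


class Agent(ABC):
    """Stand-in for BaseAgent, reduced to the run() hook."""
    name = "agent"

    @abstractmethod
    async def run(self, context: Context) -> Context:
        ...


class UppercaseAgent(Agent):
    name = "uppercase"

    async def run(self, context: Context) -> Context:
        context.source_lines = [line.upper() for line in context.source_lines]
        context.notes.append(self.name)
        return context


class CountAgent(Agent):
    name = "count"

    async def run(self, context: Context) -> Context:
        context.notes.append(f"{self.name}:{len(context.source_lines)}")
        return context


async def run_pipeline(agents: list[Agent], context: Context) -> Context:
    """Await each agent in order, passing the returned context to the next."""
    for agent in agents:
        context = await agent.run(context)
    return context


result = asyncio.run(
    run_pipeline([UppercaseAgent(), CountAgent()], Context(["hello"]))
)
```

Deterministic agents such as the validator and formatter fit this loop unchanged, since they simply skip the three LLM-related methods and do all their work in `run()`.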
0
backend/app/pipeline/agents/prompts/__init__.py
Normal file
139
backend/app/pipeline/contracts.py
Normal file
@ -0,0 +1,139 @@
"""Pipeline data contracts - Pydantic models for inter-agent communication."""

from typing import Any

from pydantic import BaseModel


class TMEntry(BaseModel):
    """A single Translation Memory entry."""
    seg_key: str
    date: str
    en: str
    lc: str
    tx: str
    nt: str = ""
    channel: str = ""
    sub_channel: str = ""
    # Note: under Pydantic v2 (implied by model_config below), a name with a
    # leading underscore is a private attribute, not a validated or
    # serialized field.
    _text: str = ""

    model_config = {"from_attributes": True}


class SourceLineContract(BaseModel):
    """A parsed source line from the input xlsx."""
    line_id: str
    row_order: int
    en_gb: str
    copy_type: str | None = None
    creative_guidance: str | None = None
    visual_ref: str | None = None
    char_limit: str | None = None
    is_display_format: bool = False


class FileManifest(BaseModel):
    """Manifest of all files loaded for a job."""
    tm_files: list[str] = []
    glossary_file: str | None = None
    blacklist_file: str | None = None
    tov_global_file: str | None = None
    tov_supplement_file: str | None = None
    locale_considerations_file: str | None = None
    date_pct_formats_file: str | None = None


class JobParams(BaseModel):
    """Parameters for a transcreation job."""
    job_id: str
    client_id: str
    locale_code: str
    channel: str
    sub_channel: str | None = None
    programme: str
    campaign_name: str
    context_prompt: str | None = None


class ParsedJob(BaseModel):
    """Output of Agent 1 (Validator): validated job parameters + source."""
    job_params: JobParams
    source_lines: list[SourceLineContract]
    file_manifest: FileManifest


class ConfirmedMatch(BaseModel):
    """A confirmed TM match for a source line."""
    seg_key: str
    pass_found: int
    date: str
    en: str
    tx: str
    nt: str = ""
    channel: str = ""
    sub_channel: str = ""
    is_cross_channel: bool = False


class TMSweepResult(BaseModel):
    """TM sweep results for a single source line."""
    line_id: str
    confirmed_matches: list[ConfirmedMatch] = []
    pass_4_triggered: bool = False
    pass_4_result: ConfirmedMatch | None = None
    no_match: bool = False


class RankingDeclaration(BaseModel):
    """Ranking decision for a single source line."""
    line_id: str
    winning_entry: ConfirmedMatch | None = None
    runner_ups: list[ConfirmedMatch] = []
    confidence_tier: str = "low"
    option_count: int = 3
    is_new_creative_line: bool = False
    notes: str = ""


class DraftOption(BaseModel):
    """A single draft transcreation option."""
    text: str
    backtranslation: str
    rationale: str


class DraftOutput(BaseModel):
    """Transcreation draft output for a single source line."""
    line_id: str
    option_1: DraftOption
    option_2: DraftOption | None = None
    option_3: DraftOption | None = None
    tm_entries_cited: list[str] = []
    adaptations_applied: list[str] = []


class ComplianceViolation(BaseModel):
    """A single compliance violation found during checking."""
    type: str
    option_affected: int
    description: str
    severity: str = "warning"


class ComplianceResult(BaseModel):
    """Compliance check result for a single source line."""
    line_id: str
    passed: bool
    violations: list[ComplianceViolation] = []
    character_counts: dict[str, int] = {}


class PipelineContext(BaseModel):
    """Full pipeline context passed between agents."""
    job_params: JobParams
    source_lines: list[SourceLineContract] = []
    file_manifest: FileManifest = FileManifest()
    tm_sweep_results: list[TMSweepResult] = []
    ranking_declarations: list[RankingDeclaration] = []
    draft_outputs: list[DraftOutput] = []
    compliance_results: list[ComplianceResult] = []
0
backend/app/pipeline/modules/__init__.py
Normal file
108
backend/app/pipeline/modules/blacklist_scanner.py
Normal file
@ -0,0 +1,108 @@
"""Blacklist scanner for forbidden terms and brand violations.

Supports:
- Exact match: term appears as-is in the text (case-insensitive).
- Root-based match: checks if the root/stem of a blacklisted term appears.
"""

import re
from dataclasses import dataclass


@dataclass
class BlacklistViolation:
    """A detected blacklist violation."""
    term: str
    match_type: str  # "exact" or "root"
    position: int  # character position in text
    context: str  # surrounding text


def scan_text(
    text: str,
    blacklist_entries: list[dict],
) -> list[BlacklistViolation]:
    """Scan text against blacklist entries for violations.

    Args:
        text: The text to scan.
        blacklist_entries: List of dicts with keys:
            - term: str (the forbidden term)
            - root: str | None (optional root form for root-based matching)
            - severity: str (optional)

    Returns:
        List of BlacklistViolation instances found.
    """
    if not text or not blacklist_entries:
        return []

    violations: list[BlacklistViolation] = []
    text_lower = text.lower()

    for entry in blacklist_entries:
        term = entry.get("term", "").strip()
        if not term:
            continue

        root = entry.get("root", "").strip() if entry.get("root") else None

        # Exact match (case-insensitive, word boundary)
        pattern = re.compile(r"\b" + re.escape(term) + r"\b", re.IGNORECASE)
        for match in pattern.finditer(text):
            start = max(0, match.start() - 20)
            end = min(len(text), match.end() + 20)
            violations.append(
                BlacklistViolation(
                    term=term,
                    match_type="exact",
                    position=match.start(),
                    context=text[start:end],
                )
            )

        # Root-based match
        if root:
            root_pattern = re.compile(
                r"\b" + re.escape(root) + r"\w*\b", re.IGNORECASE
            )
            for match in root_pattern.finditer(text):
                # Skip if this is already caught as an exact match
                matched_text = match.group().lower()
                if matched_text == term.lower():
                    continue

                start = max(0, match.start() - 20)
                end = min(len(text), match.end() + 20)
                violations.append(
                    BlacklistViolation(
                        term=term,
                        match_type="root",
                        position=match.start(),
                        context=text[start:end],
                    )
                )

    return violations


def scan_all_options(
    options: list[str],
    blacklist_entries: list[dict],
) -> dict[int, list[BlacklistViolation]]:
    """Scan multiple text options against the blacklist.

    Args:
        options: List of text options to scan (index 0 = option 1, etc.)
        blacklist_entries: The blacklist entries.

    Returns:
        Dict mapping option index (1-based) to violations found.
    """
    results: dict[int, list[BlacklistViolation]] = {}
    for i, option_text in enumerate(options):
        if option_text:
            violations = scan_text(option_text, blacklist_entries)
            if violations:
                results[i + 1] = violations
    return results
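A minimal, self-contained sketch of how the exact and root-based passes in the blacklist scanner differ on the same input. The helper name `find_hits` and the sample sentence are illustrative assumptions, not part of the commit; the two regular expressions mirror the ones built in `scan_text`.

```python
import re


def find_hits(text: str, term: str, root: str = "") -> list:
    """Return (match_type, matched_text) pairs for one blacklist entry.

    Hypothetical helper for illustration only; it condenses the two
    passes of scan_text into a few lines.
    """
    hits = []

    # Exact pass: the whole term, bounded by word boundaries.
    exact = re.compile(r"\b" + re.escape(term) + r"\b", re.IGNORECASE)
    hits += [("exact", m.group()) for m in exact.finditer(text)]

    # Root pass: the root followed by any word characters.
    if root:
        by_root = re.compile(r"\b" + re.escape(root) + r"\w*\b", re.IGNORECASE)
        for m in by_root.finditer(text):
            # Skip words already reported by the exact pass.
            if m.group().lower() != term.lower():
                hits.append(("root", m.group()))
    return hits


sample = "Cheap deals and the cheapest prices"
hits = find_hits(sample, term="cheap", root="cheap")
```

Here the exact pass reports only the standalone word, while the root pass adds the inflected form. Because both passes rely on `\b`, terms in scripts written without spaces between words may need a different matching strategy.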
67
backend/app/pipeline/modules/character_counter.py
Normal file
@ -0,0 +1,67 @@
"""Unicode grapheme cluster character counting.

Uses the grapheme library for accurate character counting that handles
multi-codepoint characters (emoji, combining marks, etc.) correctly.

Rules:
- Strip leading/trailing whitespace before counting.
- Line breaks (\\n) count as 0 characters.
"""

import grapheme


def count_characters(text: str) -> int:
    """Count grapheme clusters in a string, excluding line breaks.

    Args:
        text: The text to count characters in.

    Returns:
        Number of grapheme clusters (visible characters), excluding line breaks.
    """
    if not text:
        return 0

    # Strip whitespace
    cleaned = text.strip()

    # Remove line breaks (they count as 0)
    cleaned = cleaned.replace("\n", "").replace("\r", "")

    return grapheme.length(cleaned)


def check_character_limit(text: str, char_limit: str | None) -> tuple[int, bool]:
    """Count characters and check against a limit.

    Args:
        text: The text to measure.
        char_limit: The character limit as a string (may be numeric or range like "20-30").

    Returns:
        Tuple of (character_count, within_limit). If no limit is set, within_limit is True.
    """
    count = count_characters(text)

    if not char_limit:
        return count, True

    char_limit = char_limit.strip()

    # Handle range format "20-30"
    if "-" in char_limit:
        parts = char_limit.split("-")
        try:
            lower = int(parts[0].strip())
            upper = int(parts[1].strip())
            return count, lower <= count <= upper
        except (ValueError, IndexError):
            return count, True

    # Handle simple numeric limit
    try:
        limit = int(char_limit)
        return count, count <= limit
    except ValueError:
        return count, True
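The limit handling in `check_character_limit` can be sketched on its own, separate from the grapheme counting. The helper name `within_limit` is an assumption for illustration; it takes an already computed count so that the example needs no third-party package.

```python
def within_limit(count: int, char_limit) -> bool:
    """Hypothetical helper mirroring the limit rules of check_character_limit."""
    if not char_limit:
        return True  # no limit configured

    limit = char_limit.strip()
    try:
        if "-" in limit:
            # Range format such as "20-30": both bounds are inclusive.
            parts = limit.split("-")
            return int(parts[0]) <= count <= int(parts[1])
        # Simple numeric limit: an upper bound only.
        return count <= int(limit)
    except (ValueError, IndexError):
        # Unparseable limits are treated as "no limit".
        return True


checks = {
    "range_ok": within_limit(25, "20-30"),
    "range_too_short": within_limit(19, "20-30"),
    "over_simple_limit": within_limit(31, "30"),
    "unparseable": within_limit(999, "n/a"),
}
```

One consequence worth keeping in mind: a limit that cannot be parsed passes silently, so a typo in the Char Limit column never produces a violation.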
121
backend/app/pipeline/modules/date_format_validator.py
Normal file
@ -0,0 +1,121 @@
"""Validate date and percentage format strings against approved locale formats.

Checks that dates and percentages in transcreated text conform to the
locale-specific format rules defined in the date/percentage format file.
"""

import re
from dataclasses import dataclass


@dataclass
class FormatViolation:
    """A detected date/percentage format violation."""
    found: str
    expected_format: str
    description: str


def validate_date_formats(
    text: str,
    approved_formats: list[dict],
) -> list[FormatViolation]:
    """Validate date strings in text against approved formats.

    Args:
        text: The text containing dates to validate.
        approved_formats: List of dicts with keys:
            - pattern: str (regex pattern for valid format)
            - example: str (example of the correct format)
            - description: str

    Returns:
        List of FormatViolation instances.
    """
    if not text or not approved_formats:
        return []

    violations: list[FormatViolation] = []

    # Common date-like patterns to detect
    date_patterns = [
        # DD/MM/YYYY, MM/DD/YYYY, YYYY/MM/DD
        r"\b\d{1,2}[/\-.]\d{1,2}[/\-.]\d{2,4}\b",
        # Month DD, YYYY
        r"\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s+\d{1,2},?\s+\d{4}\b",
        # DD Month YYYY
        r"\b\d{1,2}\s+(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s+\d{4}\b",
    ]

    for date_pattern in date_patterns:
        for match in re.finditer(date_pattern, text, re.IGNORECASE):
            found_date = match.group()
            is_valid = False

            for fmt in approved_formats:
                pattern = fmt.get("pattern", "")
                if pattern and re.match(pattern, found_date, re.IGNORECASE):
                    is_valid = True
                    break

            if not is_valid and approved_formats:
                examples = [
                    fmt.get("example", "") for fmt in approved_formats if fmt.get("example")
                ]
                violations.append(
                    FormatViolation(
                        found=found_date,
                        expected_format=", ".join(examples[:3]),
                        description=f"Date format '{found_date}' does not match approved formats",
                    )
                )

    return violations


def validate_percentage_formats(
    text: str,
    approved_formats: list[dict],
) -> list[FormatViolation]:
    """Validate percentage strings in text against approved formats.

    Args:
        text: The text containing percentages to validate.
        approved_formats: List of dicts with keys:
            - pattern: str (regex)
            - example: str
            - description: str

    Returns:
        List of FormatViolation instances.
    """
    if not text or not approved_formats:
        return []

    violations: list[FormatViolation] = []

    # Find percentage-like patterns (no trailing \b: "%" is a non-word character, so it would block "50% off")
    pct_pattern = r"\b\d+[\.,]?\d*\s*%"
    for match in re.finditer(pct_pattern, text):
        found_pct = match.group()
        is_valid = False

        for fmt in approved_formats:
            pattern = fmt.get("pattern", "")
            if pattern and re.match(pattern, found_pct):
                is_valid = True
                break

        if not is_valid and approved_formats:
            examples = [
                fmt.get("example", "") for fmt in approved_formats if fmt.get("example")
            ]
            violations.append(
                FormatViolation(
                    found=found_pct,
                    expected_format=", ".join(examples[:3]),
                    description=f"Percentage format '{found_pct}' does not match approved formats",
                )
            )

    return violations
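To make the shape of `approved_formats` concrete, here is a small sketch of the detect-then-check flow used by the date validator. The German-style format entry, the sample sentence, and the helper name `unapproved_dates` are assumptions for illustration; real entries come from the locale's date/percentage format file.

```python
import re

# Numeric date detector, same shape as the first entry in date_patterns.
DATE_LIKE = r"\b\d{1,2}[/\-.]\d{1,2}[/\-.]\d{2,4}\b"

# Hypothetical approved format for a German locale: DD.MM.YYYY.
approved_formats = [
    {
        "pattern": r"\d{2}\.\d{2}\.\d{4}$",
        "example": "31.12.2025",
        "description": "DD.MM.YYYY",
    },
]


def unapproved_dates(text: str, formats: list) -> list:
    """Return date-like strings that match none of the approved patterns."""
    bad = []
    for m in re.finditer(DATE_LIKE, text):
        found = m.group()
        # re.match anchors at the start only; the "$" in the pattern
        # above anchors the end as well.
        if not any(re.match(f["pattern"], found) for f in formats):
            bad.append(found)
    return bad


flagged = unapproved_dates("Gültig bis 31.12.2025 oder 12/31/2025", approved_formats)
```

Since the validator calls `re.match`, approved patterns are checked as prefixes unless they carry their own end anchor, which is why the sample pattern ends in `$`.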
110
backend/app/pipeline/modules/domain_substitutor.py
Normal file
@ -0,0 +1,110 @@
"""Domain substitution for Amazon locales.

Maps Amazon.co.uk (source domain) to the correct locale-specific domain.
Handles both full domain URLs and bare "Amazon" references.

Emerging locales: bare "Amazon" stays as "Amazon"
Non-emerging locales: bare "Amazon" -> locale-specific brand name rules apply
"""

# Full domain map for all 12 supported locales
DOMAIN_MAP: dict[str, str] = {
    "de_DE": "Amazon.de",
    "fr_FR": "Amazon.fr",
    "it_IT": "Amazon.it",
    "es_ES": "Amazon.es",
    "nl_NL": "Amazon.nl",
    "pl_PL": "Amazon.pl",
    "sv_SE": "Amazon.se",
    "pt_BR": "Amazon.com.br",
    "ja_JP": "Amazon.co.jp",
    "en_AU": "Amazon.com.au",
    "en_SG": "Amazon.sg",
    "ar_AE": "Amazon.ae",
}

# Emerging locales where bare "Amazon" stays as-is
EMERGING_LOCALES: set[str] = {
    "pl_PL",
    "sv_SE",
    "nl_NL",
    "ar_AE",
    "en_SG",
}

# Source domain to replace
SOURCE_DOMAIN = "Amazon.co.uk"
SOURCE_DOMAIN_LOWER = SOURCE_DOMAIN.lower()


def substitute_domains(text: str, locale_code: str) -> str:
    """Replace Amazon.co.uk with the locale-specific domain.

    Args:
        text: The text containing domain references.
        locale_code: The target locale code (e.g., "de_DE").

    Returns:
        Text with domains substituted.
    """
    if not text:
        return text

    target_domain = DOMAIN_MAP.get(locale_code)
    if target_domain is None:
        return text

    # Replace full domain (case-insensitive)
    result = text
    idx = 0
    while True:
        lower_result = result.lower()
        pos = lower_result.find(SOURCE_DOMAIN_LOWER, idx)
        if pos == -1:
            break
        result = result[:pos] + target_domain + result[pos + len(SOURCE_DOMAIN):]
        idx = pos + len(target_domain)

    return result


def substitute_bare_amazon(text: str, locale_code: str) -> str:
    """Handle bare 'Amazon' references based on locale type.

    For emerging locales: leave bare 'Amazon' as-is.
    For non-emerging locales: currently also returned unchanged, since bare
    'Amazon' is the brand name.

    Args:
        text: The text with potential bare Amazon references.
        locale_code: The target locale code.

    Returns:
        Text with bare Amazon references handled.
    """
    if not text:
        return text

    if locale_code in EMERGING_LOCALES:
        # Emerging locales: bare Amazon stays as Amazon
        return text

    # For non-emerging locales, bare "Amazon" (not followed by .)
    # is kept as-is since it's the brand name
    return text


def get_locale_domain(locale_code: str) -> str | None:
    """Get the domain for a locale code.

    Args:
        locale_code: The target locale code.

    Returns:
        The domain string or None if locale is not supported.
    """
    return DOMAIN_MAP.get(locale_code)


def is_emerging_locale(locale_code: str) -> bool:
    """Check if a locale is classified as emerging."""
    return locale_code in EMERGING_LOCALES
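As a point of comparison, the case-insensitive replacement in `substitute_domains` could also be expressed with a compiled regular expression. This is an alternative sketch, not the commit's implementation; the two-entry `DOMAIN_MAP` is a shortened stand-in for the full map.

```python
import re

# Shortened stand-in for the module's DOMAIN_MAP (assumption for brevity).
DOMAIN_MAP = {"de_DE": "Amazon.de", "ja_JP": "Amazon.co.jp"}

# re.escape keeps the dots literal, so "Amazon-co-uk" is not matched.
_SOURCE_DOMAIN_RE = re.compile(re.escape("Amazon.co.uk"), re.IGNORECASE)


def substitute_domains(text: str, locale_code: str) -> str:
    """Replace every Amazon.co.uk occurrence, ignoring case."""
    target = DOMAIN_MAP.get(locale_code)
    if not text or target is None:
        return text
    # A callable replacement avoids any backslash handling in the target.
    return _SOURCE_DOMAIN_RE.sub(lambda _match: target, text)


converted = substitute_domains("Shop at amazon.CO.UK today", "de_DE")
untouched = substitute_domains("Shop at Amazon.co.uk today", "xx_XX")
```

The regex form matches against the original string directly, so it does not depend on a lowercased copy having the same length as the original, which is not guaranteed for every Unicode character.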
184
backend/app/pipeline/modules/excel_writer.py
Normal file
@ -0,0 +1,184 @@
"""Generate output xlsx files with structured output and summary tabs.

Tab 1: 11-column output table
Tab 2: Transcreation Summary

Column widths and formatting per specification.
"""

from pathlib import Path
from typing import Any

from openpyxl import Workbook
from openpyxl.styles import Alignment, Font, PatternFill
from openpyxl.utils import get_column_letter

from app.pipeline.modules.line_break_normaliser import normalise_for_excel

# Tab 1 column definitions
OUTPUT_COLUMNS = [
    ("EN_GB", 40),
    ("Copy Type", 15),
    ("Option 1", 40),
    ("Back-translation 1", 40),
    ("Rationale 1", 35),
    ("Option 2", 40),
    ("Back-translation 2", 40),
    ("Rationale 2", 35),
    ("Option 3", 40),
    ("Back-translation 3", 40),
    ("Rationale 3", 35),
]

# Header style
HEADER_FONT = Font(bold=True, size=11, color="FFFFFF")
HEADER_FILL = PatternFill(start_color="232F3E", end_color="232F3E", fill_type="solid")
HEADER_ALIGNMENT = Alignment(horizontal="center", vertical="center", wrap_text=True)

# Data style
DATA_ALIGNMENT = Alignment(vertical="top", wrap_text=True)


def generate_output_xlsx(
    output_path: str,
    source_lines: list[dict[str, Any]],
    output_rows: list[dict[str, Any]],
    summary: dict[str, Any] | None = None,
) -> str:
    """Generate the output xlsx file.

    Args:
        output_path: Absolute path where the xlsx should be saved.
        source_lines: List of source line dicts (en_gb, copy_type, etc.).
        output_rows: List of output row dicts with options, backtranslations, rationales.
        summary: Optional summary data for Tab 2.

    Returns:
        The absolute path to the generated file.
    """
    wb = Workbook()

    # ---- Tab 1: Output Table ----
    ws1 = wb.active
    ws1.title = "Transcreation Output"

    # Write headers
    for col_idx, (header, width) in enumerate(OUTPUT_COLUMNS, start=1):
        cell = ws1.cell(row=1, column=col_idx, value=header)
        cell.font = HEADER_FONT
        cell.fill = HEADER_FILL
        cell.alignment = HEADER_ALIGNMENT
        ws1.column_dimensions[get_column_letter(col_idx)].width = width

    # Write data rows
    for row_idx, output_row in enumerate(output_rows, start=2):
        # Find matching source line
        source_line = _find_source_line(source_lines, output_row)

        ws1.cell(
            row=row_idx, column=1,
            value=normalise_for_excel(source_line.get("en_gb", "")),
        ).alignment = DATA_ALIGNMENT

        ws1.cell(
            row=row_idx, column=2,
            value=source_line.get("copy_type", ""),
        ).alignment = DATA_ALIGNMENT

        # Option 1
        ws1.cell(
            row=row_idx, column=3,
            value=normalise_for_excel(output_row.get("option_1", "")),
        ).alignment = DATA_ALIGNMENT
        ws1.cell(
            row=row_idx, column=4,
            value=normalise_for_excel(output_row.get("backtranslation_1", "")),
        ).alignment = DATA_ALIGNMENT
        ws1.cell(
            row=row_idx, column=5,
            value=output_row.get("rationale_1", ""),
        ).alignment = DATA_ALIGNMENT

        # Option 2
        ws1.cell(
            row=row_idx, column=6,
            value=normalise_for_excel(output_row.get("option_2", "")),
        ).alignment = DATA_ALIGNMENT
        ws1.cell(
            row=row_idx, column=7,
            value=normalise_for_excel(output_row.get("backtranslation_2", "")),
        ).alignment = DATA_ALIGNMENT
        ws1.cell(
            row=row_idx, column=8,
            value=output_row.get("rationale_2", ""),
        ).alignment = DATA_ALIGNMENT

        # Option 3
        ws1.cell(
            row=row_idx, column=9,
            value=normalise_for_excel(output_row.get("option_3", "")),
        ).alignment = DATA_ALIGNMENT
        ws1.cell(
            row=row_idx, column=10,
            value=normalise_for_excel(output_row.get("backtranslation_3", "")),
        ).alignment = DATA_ALIGNMENT
        ws1.cell(
            row=row_idx, column=11,
            value=output_row.get("rationale_3", ""),
        ).alignment = DATA_ALIGNMENT

    # ---- Tab 2: Transcreation Summary ----
    ws2 = wb.create_sheet("Transcreation Summary")

    summary_data = summary or {}
    summary_rows = [
        ("Job ID", summary_data.get("job_id", "")),
        ("Campaign", summary_data.get("campaign_name", "")),
        ("Locale", summary_data.get("locale_code", "")),
        ("Channel", summary_data.get("channel", "")),
        ("Programme", summary_data.get("programme", "")),
        ("Total Source Lines", summary_data.get("total_source_lines", 0)),
        ("Total Output Rows", summary_data.get("total_output_rows", 0)),
        ("High Confidence", summary_data.get("high_confidence", 0)),
        ("Moderate Confidence", summary_data.get("moderate_confidence", 0)),
        ("Low Confidence", summary_data.get("low_confidence", 0)),
        ("Total Tokens Used", summary_data.get("total_tokens", 0)),
        ("Estimated Cost (USD)", summary_data.get("estimated_cost", 0.0)),
        ("Agent Version", summary_data.get("agent_version", "")),
        ("Generated At", summary_data.get("generated_at", "")),
    ]

    ws2.column_dimensions["A"].width = 25
    ws2.column_dimensions["B"].width = 40

    for row_idx, (label, value) in enumerate(summary_rows, start=1):
        label_cell = ws2.cell(row=row_idx, column=1, value=label)
        label_cell.font = Font(bold=True)
        ws2.cell(row=row_idx, column=2, value=str(value))

    # Save
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    wb.save(output_path)
    wb.close()

    return output_path


def _find_source_line(
    source_lines: list[dict[str, Any]], output_row: dict[str, Any]
) -> dict[str, Any]:
    """Find the source line matching an output row by row_order or line_id."""
    row_order = output_row.get("row_order")
    line_id = output_row.get("line_id")

    for sl in source_lines:
        if line_id and sl.get("id") == line_id:
            return sl
        if row_order is not None and sl.get("row_order") == row_order:
            return sl

    # Fallback: match by index
    if row_order is not None and 0 < row_order <= len(source_lines):
        return source_lines[row_order - 1]

    return {}
67
backend/app/pipeline/modules/line_break_normaliser.py
Normal file
@ -0,0 +1,67 @@
"""Line break normalisation utilities.

Three modes:
- normalise_for_query: Strip line breaks, collapse multiple spaces to single.
  Used when building search queries against TM.
- normalise_for_excel: Convert \\n to openpyxl-compatible line breaks.
  Used when writing output cells.
- preserve_raw: Return text as-is (identity function for pipeline clarity).
"""

import re


def normalise_for_query(text: str) -> str:
    """Strip line breaks and collapse spaces for TM query matching.

    Args:
        text: Raw text potentially containing line breaks.

    Returns:
        Single-line text with normalised whitespace.
    """
    if not text:
        return ""

    # Replace all line break variants with a space
    result = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")

    # Collapse multiple spaces to one
    result = re.sub(r"\s+", " ", result)

    return result.strip()


def normalise_for_excel(text: str) -> str:
    """Convert line breaks to openpyxl-compatible format.

    openpyxl uses \\n for in-cell line breaks when wrap_text is enabled.
    This ensures consistent line break representation.

    Args:
        text: Text with potential line breaks.

    Returns:
        Text with standardised \\n line breaks.
    """
    if not text:
        return ""

    # Normalise all line break variants to \\n
    result = text.replace("\r\n", "\n").replace("\r", "\n")

    return result


def preserve_raw(text: str) -> str:
    """Return text as-is (identity function).

    Used in the pipeline to explicitly indicate no normalisation is applied.

    Args:
        text: Any text.

    Returns:
        The same text, unchanged.
    """
    return text
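The difference between the query and Excel modes is easiest to see on one input. The sketch below re-declares condensed versions of the two functions so that it runs standalone; the sample string is made up.

```python
import re


def normalise_for_query(text: str) -> str:
    """Condensed copy: line breaks become spaces, whitespace runs collapse."""
    if not text:
        return ""
    result = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
    return re.sub(r"\s+", " ", result).strip()


def normalise_for_excel(text: str) -> str:
    """Condensed copy: every line break variant becomes a single \\n."""
    if not text:
        return ""
    return text.replace("\r\n", "\n").replace("\r", "\n")


raw = "Big\r\n  Deals\rToday "
as_query = normalise_for_query(raw)
as_excel = normalise_for_excel(raw)
```

The query form is a single line suitable for TM lookups, while the Excel form keeps the breaks (and the original spacing) so that wrapped cells render the display layout.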
177
backend/app/pipeline/modules/ref_file_loader.py
Normal file
@ -0,0 +1,177 @@
"""Reference file loader.

Loads various reference files used in the transcreation pipeline:
- Glossary (JSON): locale-specific term glossary
- Blacklist (JSON): forbidden terms and roots
- Date/Percentage formats (JSON): approved format patterns
- Locale Considerations (JSON): locale-specific rules and notes
- TOV (Tone of Voice) files (JSON): global and supplementary voice profiles
"""

import json
from pathlib import Path
from typing import Any


class RefFileLoadError(Exception):
    """Raised when a reference file cannot be loaded or parsed."""
    pass


def load_json_file(file_path: str) -> Any:
    """Load and parse a JSON file.

    Args:
        file_path: Absolute path to the JSON file.

    Returns:
        Parsed JSON data.

    Raises:
        RefFileLoadError: If file cannot be read or parsed.
    """
    path = Path(file_path)
    if not path.exists():
        raise RefFileLoadError(f"Reference file not found: {file_path}")

    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as exc:
        raise RefFileLoadError(f"Invalid JSON in reference file: {exc}") from exc
    except UnicodeDecodeError as exc:
        raise RefFileLoadError(f"Encoding error reading reference file: {exc}") from exc


def load_glossary(file_path: str) -> list[dict[str, str]]:
    """Load a glossary file.

    Expected format: list of dicts with keys like:
        {"en": "source term", "tx": "translated term", "context": "usage notes"}

    Args:
        file_path: Path to glossary JSON file.

    Returns:
        List of glossary entry dicts.
    """
    data = load_json_file(file_path)
    if not isinstance(data, list):
        raise RefFileLoadError("Glossary file must contain a JSON array")
    return data


def load_blacklist(file_path: str) -> list[dict[str, str]]:
    """Load a blacklist file.

    Expected format: list of dicts with keys:
        {"term": "forbidden term", "root": "optional root", "reason": "why forbidden"}

    Args:
        file_path: Path to blacklist JSON file.

    Returns:
        List of blacklist entry dicts.
    """
    data = load_json_file(file_path)
    if not isinstance(data, list):
        raise RefFileLoadError("Blacklist file must contain a JSON array")
    return data


def load_date_pct_formats(file_path: str) -> dict[str, list[dict[str, str]]]:
    """Load date/percentage format rules.

    Expected format:
        {
            "date_formats": [{"pattern": "...", "example": "...", "description": "..."}],
            "percentage_formats": [{"pattern": "...", "example": "...", "description": "..."}]
        }

    Args:
        file_path: Path to date/pct formats JSON file.

    Returns:
        Dict with "date_formats" and "percentage_formats" keys.
    """
    data = load_json_file(file_path)
    if not isinstance(data, dict):
        raise RefFileLoadError("Date/pct format file must contain a JSON object")
    return {
        "date_formats": data.get("date_formats", []),
        "percentage_formats": data.get("percentage_formats", []),
    }


def load_locale_considerations(file_path: str) -> dict[str, Any]:
    """Load locale-specific considerations.

    Expected format: JSON object with locale-specific rules, cultural notes, etc.

    Args:
        file_path: Path to locale considerations JSON file.

    Returns:
        Dict of locale considerations.
    """
    data = load_json_file(file_path)
    if not isinstance(data, dict):
        raise RefFileLoadError(
            "Locale considerations file must contain a JSON object"
        )
    return data


def load_tov(file_path: str) -> dict[str, Any]:
    """Load a Tone of Voice file (global or supplement).

    Expected format: JSON object with voice profile data.

    Args:
        file_path: Path to TOV JSON file.

    Returns:
        Dict of TOV profile data.
    """
    data = load_json_file(file_path)
    if not isinstance(data, dict):
        raise RefFileLoadError("TOV file must contain a JSON object")
    return data


def load_all_reference_files(
    file_manifest: dict[str, str | None],
) -> dict[str, Any]:
    """Load all reference files from a file manifest.

    Args:
        file_manifest: Dict mapping file types to file paths.
            Keys: glossary_file, blacklist_file, tov_global_file,
            tov_supplement_file, locale_considerations_file,
            date_pct_formats_file

    Returns:
        Dict mapping file types to loaded data.
    """
    result: dict[str, Any] = {}

    loaders = {
        "glossary_file": load_glossary,
        "blacklist_file": load_blacklist,
        "date_pct_formats_file": load_date_pct_formats,
        "locale_considerations_file": load_locale_considerations,
        "tov_global_file": load_tov,
        "tov_supplement_file": load_tov,
    }

    for key, loader in loaders.items():
        path = file_manifest.get(key)
        if path:
            try:
                result[key] = loader(path)
            except RefFileLoadError:
                result[key] = None
        else:
            result[key] = None

    return result
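For orientation, this is roughly what a blacklist reference file looks like on disk and how the array check behaves. The sample entries and the helper name `load_json_array` are illustrative assumptions; the key names follow the format described in `load_blacklist`.

```python
import json
import tempfile
from pathlib import Path

# Made-up entries using the documented keys: term, root (optional), reason.
sample_entries = [
    {"term": "cheap", "root": "cheap", "reason": "off-brand tone"},
    {"term": "guarantee", "reason": "legal claim"},
]


def load_json_array(file_path: str) -> list:
    """Hypothetical helper: parse a JSON file and insist on a top-level array."""
    data = json.loads(Path(file_path).read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("expected a JSON array")
    return data


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "blacklist.json"
    path.write_text(json.dumps(sample_entries), encoding="utf-8")
    loaded = load_json_array(str(path))

    # A top-level object is rejected, as in load_blacklist.
    bad_path = Path(tmp) / "bad.json"
    bad_path.write_text(json.dumps({"term": "cheap"}), encoding="utf-8")
    try:
        load_json_array(str(bad_path))
        rejected = False
    except ValueError:
        rejected = True
```

Note that `load_all_reference_files` converts any such load failure into `None` for that key, so callers need to distinguish "file not configured" from "file failed to load" by other means if that matters.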
97
backend/app/pipeline/modules/source_file_parser.py
Normal file
@ -0,0 +1,97 @@
|
|||
"""Parse source xlsx files into structured source line data.
|
||||
|
||||
Validates the expected 5 column headers (case sensitive):
|
||||
EN_GB, Copy Type, Creative Guidance, Visual Ref, Char Limit
|
||||
|
||||
Skips rows where EN_GB is empty. Detects \\n in EN_GB for is_display_format.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from openpyxl import load_workbook
|
||||
|
||||
REQUIRED_HEADERS = ["EN_GB", "Copy Type", "Creative Guidance", "Visual Ref", "Char Limit"]
|
||||
|
||||
|
||||
class SourceFileParseError(Exception):
|
||||
"""Raised when the source file has validation errors."""
|
||||
pass
|
||||
|
||||
|
||||
def parse_source_file(file_path: str) -> list[dict[str, Any]]:
|
||||
"""Parse a source xlsx file and return a list of source line dicts.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the xlsx file.
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: en_gb, copy_type, creative_guidance,
|
||||
visual_ref, char_limit, is_display_format.
|
||||
|
||||
Raises:
|
||||
        SourceFileParseError: If headers are invalid or file cannot be read.
    """
    try:
        wb = load_workbook(file_path, read_only=True, data_only=True)
    except Exception as exc:
        # Chain the original exception so the root cause stays in the traceback
        raise SourceFileParseError(f"Cannot open xlsx file: {exc}") from exc

    ws = wb.active
    if ws is None:
        raise SourceFileParseError("Workbook has no active sheet")

    # Read and validate headers from first row
    rows = ws.iter_rows(min_row=1, max_row=1, values_only=True)
    first_row = next(rows, None)
    if first_row is None:
        raise SourceFileParseError("File is empty - no header row found")

    headers = [str(cell).strip() if cell else "" for cell in first_row]

    # Validate all required headers exist (case sensitive)
    for required in REQUIRED_HEADERS:
        if required not in headers:
            raise SourceFileParseError(
                f"Missing required header '{required}'. "
                f"Found headers: {headers}. "
                f"Expected: {REQUIRED_HEADERS}"
            )

    # Build column index map
    col_map = {header: idx for idx, header in enumerate(headers)}

    # Parse data rows
    source_lines: list[dict[str, Any]] = []
    for row in ws.iter_rows(min_row=2, values_only=True):
        en_gb_idx = col_map["EN_GB"]
        en_gb_raw = row[en_gb_idx] if en_gb_idx < len(row) else None

        # Skip rows where EN_GB is empty
        if en_gb_raw is None or str(en_gb_raw).strip() == "":
            continue

        en_gb = str(en_gb_raw).strip()

        # Detect display format: presence of \n in EN_GB text
        is_display_format = "\n" in en_gb

        def _get_cell(header: str) -> str | None:
            idx = col_map.get(header)
            if idx is None or idx >= len(row):
                return None
            val = row[idx]
            if val is None:
                return None
            return str(val).strip() or None

        source_lines.append({
            "en_gb": en_gb,
            "copy_type": _get_cell("Copy Type"),
            "creative_guidance": _get_cell("Creative Guidance"),
            "visual_ref": _get_cell("Visual Ref"),
            "char_limit": _get_cell("Char Limit"),
            "is_display_format": is_display_format,
        })

    wb.close()
    return source_lines
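The header check and the empty-row skip in the parser can be exercised without a real workbook. The sketch below is a minimal illustration under stated assumptions, not the module's code: rows are plain tuples (the shape `iter_rows(values_only=True)` yields), the helper name `rows_to_source_lines` is hypothetical, and the `REQUIRED_HEADERS` value is a placeholder because the real constant is defined outside this hunk.

```python
from typing import Any

# Placeholder value for illustration; the real REQUIRED_HEADERS is defined
# earlier in source_file_parser.py and is not visible in this hunk.
REQUIRED_HEADERS = ["EN_GB", "Copy Type"]


def rows_to_source_lines(rows: list[tuple[Any, ...]]) -> list[dict[str, Any]]:
    """Validate the header row, then keep only rows with non-empty EN_GB text."""
    if not rows:
        raise ValueError("File is empty - no header row found")

    headers = [str(cell).strip() if cell is not None else "" for cell in rows[0]]
    missing = [h for h in REQUIRED_HEADERS if h not in headers]
    if missing:
        raise ValueError(f"Missing required headers: {missing}")

    col_map = {header: idx for idx, header in enumerate(headers)}
    lines: list[dict[str, Any]] = []
    for row in rows[1:]:
        idx = col_map["EN_GB"]
        raw = row[idx] if idx < len(row) else None
        if raw is None or str(raw).strip() == "":
            continue  # rows without source copy are skipped
        en_gb = str(raw).strip()
        # A line break inside the cell marks the row as display format
        lines.append({"en_gb": en_gb, "is_display_format": "\n" in en_gb})
    return lines


sample = [
    ("EN_GB", "Copy Type"),
    ("Shop now", "CTA"),
    (None, "CTA"),
    ("Big deals\nEvery day", "Headline"),
]
result = rows_to_source_lines(sample)
```

Working on tuples keeps the validation rules testable without an xlsx fixture; the real function could delegate to a helper of this shape after opening the workbook.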
133
backend/app/pipeline/modules/tm_file_loader.py
Normal file
@@ -0,0 +1,133 @@
"""Translation Memory file loader.

Reads JSONL files in two formats:
1. Compact: {"t": "seg_key|date|en|lc|tx|nt|channel|sub_channel"}
2. Multi-field: {"seg_key": "...", "date": "...", "en": "...", ...}

Applies a locale hard-match gate: only entries matching the target locale are returned.
"""

import json
from typing import Any

from app.pipeline.contracts import TMEntry


class TMFileLoadError(Exception):
    """Raised when a TM file cannot be loaded or parsed."""
    pass


def load_tm_file(
    file_path: str,
    target_locale: str,
) -> list[TMEntry]:
    """Load and parse a JSONL TM file, filtering by locale.

    Args:
        file_path: Absolute path to the JSONL file.
        target_locale: Target locale code (e.g., "de_DE"). Only entries
            matching this locale will be returned.

    Returns:
        List of TMEntry objects matching the target locale.

    Raises:
        TMFileLoadError: If the file cannot be read or parsed.
    """
    entries: list[TMEntry] = []

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as exc:
                    raise TMFileLoadError(
                        f"Invalid JSON on line {line_num}: {exc}"
                    ) from exc

                entry = _parse_entry(data, line_num)
                if entry is None:
                    continue

                # Locale hard-match gate
                if entry.lc == target_locale:
                    entries.append(entry)

    except FileNotFoundError as exc:
        raise TMFileLoadError(f"TM file not found: {file_path}") from exc
    except UnicodeDecodeError as exc:
        raise TMFileLoadError(f"Encoding error reading TM file: {exc}") from exc

    return entries


def _parse_entry(data: dict[str, Any], line_num: int) -> TMEntry | None:
    """Parse a single JSON object into a TMEntry.

    Detects compact vs multi-field format automatically.

    Args:
        data: Parsed JSON dict.
        line_num: Line number for error reporting.

    Returns:
        TMEntry or None if the entry is malformed.
    """
    # json.loads can return a list, string, or number; only objects are entries
    if not isinstance(data, dict):
        return None

    # Compact format: {"t": "seg_key|date|en|lc|tx|nt|channel|sub_channel"}
    if "t" in data and isinstance(data["t"], str):
        parts = data["t"].split("|")
        if len(parts) < 5:
            return None  # Malformed compact entry

        # The first five fields are guaranteed by the length check above;
        # the trailing three are optional.
        return TMEntry(
            seg_key=parts[0],
            date=parts[1],
            en=parts[2],
            lc=parts[3],
            tx=parts[4],
            nt=parts[5] if len(parts) > 5 else "",
            channel=parts[6] if len(parts) > 6 else "",
            sub_channel=parts[7] if len(parts) > 7 else "",
            _text=data["t"],
        )

    # Multi-field format
    if "seg_key" in data and "en" in data:
        return TMEntry(
            seg_key=str(data.get("seg_key", "")),
            date=str(data.get("date", "")),
            en=str(data.get("en", "")),
            lc=str(data.get("lc", "")),
            tx=str(data.get("tx", "")),
            nt=str(data.get("nt", "")),
            channel=str(data.get("channel", "")),
            sub_channel=str(data.get("sub_channel", "")),
        )

    return None


def load_multiple_tm_files(
    file_paths: list[str],
    target_locale: str,
) -> list[TMEntry]:
    """Load and merge multiple TM files.

    Args:
        file_paths: List of file paths to load.
        target_locale: Target locale code.

    Returns:
        Combined list of TMEntry objects from all files.
    """
    all_entries: list[TMEntry] = []
    for path in file_paths:
        entries = load_tm_file(path, target_locale)
        all_entries.extend(entries)
    return all_entries
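To make the two JSONL shapes and the locale gate concrete, here is a small self-contained sketch. It is an illustration only: plain dicts stand in for `TMEntry` (whose definition lives in `app.pipeline.contracts` and is not shown in this diff), and the names `FIELDS`, `parse_line`, and `filter_locale` are hypothetical.

```python
import json
from typing import Optional

# Field order of the compact "t" string, as documented in the module docstring
FIELDS = ("seg_key", "date", "en", "lc", "tx", "nt", "channel", "sub_channel")


def parse_line(line: str) -> Optional[dict[str, str]]:
    """Parse one JSONL line in either compact or multi-field form."""
    data = json.loads(line)
    if not isinstance(data, dict):
        return None
    if isinstance(data.get("t"), str):
        parts = data["t"].split("|")
        if len(parts) < 5:
            return None  # too few fields to carry a translation
        # Pad the optional trailing fields with empty strings
        parts += [""] * (len(FIELDS) - len(parts))
        return dict(zip(FIELDS, parts))
    if "seg_key" in data and "en" in data:
        return {name: str(data.get(name, "")) for name in FIELDS}
    return None


def filter_locale(lines: list[str], target_locale: str) -> list[dict[str, str]]:
    """Keep only well-formed entries whose locale matches exactly."""
    entries = (parse_line(raw) for raw in lines if raw.strip())
    return [e for e in entries if e is not None and e["lc"] == target_locale]


lines = [
    '{"t": "k1|2024-01-01|Shop now|de_DE|Jetzt einkaufen"}',
    '{"seg_key": "k2", "en": "Deals", "lc": "fr_FR", "tx": "Offres"}',
    '{"t": "too|short"}',
]
german = filter_locale(lines, "de_DE")
```

Note that splitting on `|` assumes the character never appears inside the English or translated text; a compact entry containing a literal pipe would shift every later field.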
190
backend/app/pipeline/orchestrator.py
Normal file
@@ -0,0 +1,190 @@
"""Pipeline orchestrator - state machine managing the transcreation pipeline.

States: INIT -> VALIDATE -> TM_RETRIEVE -> RANK -> TRANSCREATE -> COMPLY -> FORMAT -> DONE

Runs agents in sequence with compliance re-draft loop (max 3 iterations).
Progress callbacks enable WebSocket updates.
"""

import enum
import logging
from typing import Any, Callable

from app.pipeline.contracts import FileManifest, JobParams, PipelineContext
from app.pipeline.agents.agent_1_validator import Agent1Validator
from app.pipeline.agents.agent_2_tm_retrieval import Agent2TMRetrieval
from app.pipeline.agents.agent_3_ranker import Agent3Ranker
from app.pipeline.agents.agent_4_transcreator import Agent4Transcreator
from app.pipeline.agents.agent_5_compliance import Agent5Compliance
from app.pipeline.agents.agent_6_formatter import Agent6Formatter

logger = logging.getLogger(__name__)

MAX_COMPLIANCE_RETRIES = 3


class PipelineState(str, enum.Enum):
    INIT = "INIT"
    VALIDATE = "VALIDATE"
    TM_RETRIEVE = "TM_RETRIEVE"
    RANK = "RANK"
    TRANSCREATE = "TRANSCREATE"
    COMPLY = "COMPLY"
    FORMAT = "FORMAT"
    DONE = "DONE"
    ERROR = "ERROR"


class PipelineOrchestrator:
    """Orchestrates the transcreation pipeline through its state machine.

    Args:
        job_params: Job parameters dict.
        source_lines: Pre-parsed source lines (optional).
        source_file_path: Path to source xlsx (optional, used if source_lines not given).
        file_manifest: Dict of reference file paths.
        output_dir: Directory for output files.
        on_progress: Optional callback(state, message, pct) for progress updates.
    """

    def __init__(
        self,
        job_params: dict[str, Any],
        source_lines: list[dict] | None = None,
        source_file_path: str | None = None,
        file_manifest: dict[str, Any] | None = None,
        output_dir: str | None = None,
        on_progress: Callable[[str, str, float], None] | None = None,
    ) -> None:
        self.job_params = job_params
        self.source_lines = source_lines
        self.source_file_path = source_file_path
        self.file_manifest = file_manifest or {}
        self.output_dir = output_dir
        self.on_progress = on_progress
        self.state = PipelineState.INIT
        self.context: PipelineContext | None = None

    def _emit_progress(self, message: str, pct: float) -> None:
        """Emit a progress callback if one is registered."""
        if self.on_progress:
            try:
                self.on_progress(self.state.value, message, pct)
            except Exception as e:
                logger.warning(f"Progress callback error: {e}")

    async def run(self) -> PipelineContext:
        """Run the full pipeline from INIT to DONE.

        Returns:
            The final PipelineContext with all results.

        Raises:
            Exception: If any agent fails.
        """
        # Initialize context
        self.context = PipelineContext(
            job_params=JobParams(**self.job_params),
            file_manifest=FileManifest(),
        )

        try:
            # VALIDATE
            self.state = PipelineState.VALIDATE
            self._emit_progress("Validating inputs", 0.1)
            logger.info(f"[{self.job_params.get('job_id')}] VALIDATE")

            validator = Agent1Validator(
                source_file_path=self.source_file_path,
                source_lines=self.source_lines,
                file_manifest=self.file_manifest,
                job_params=self.job_params,
            )
            self.context = await validator.run(self.context)

            # TM_RETRIEVE
            self.state = PipelineState.TM_RETRIEVE
            self._emit_progress("Retrieving TM matches", 0.25)
            logger.info(f"[{self.job_params.get('job_id')}] TM_RETRIEVE")

            tm_retriever = Agent2TMRetrieval()
            self.context = await tm_retriever.run(self.context)

            # RANK
            self.state = PipelineState.RANK
            self._emit_progress("Ranking matches", 0.4)
            logger.info(f"[{self.job_params.get('job_id')}] RANK")

            ranker = Agent3Ranker()
            self.context = await ranker.run(self.context)

            # TRANSCREATE + COMPLY loop
            for attempt in range(1, MAX_COMPLIANCE_RETRIES + 1):
                # TRANSCREATE
                self.state = PipelineState.TRANSCREATE
                self._emit_progress(
                    f"Generating transcreations (attempt {attempt})",
                    0.5 + (attempt - 1) * 0.1,
                )
                logger.info(
                    f"[{self.job_params.get('job_id')}] TRANSCREATE (attempt {attempt})"
                )

                transcreator = Agent4Transcreator()
                self.context = await transcreator.run(self.context)

                # COMPLY
                self.state = PipelineState.COMPLY
                self._emit_progress(
                    f"Checking compliance (attempt {attempt})",
                    0.6 + (attempt - 1) * 0.1,
                )
                logger.info(
                    f"[{self.job_params.get('job_id')}] COMPLY (attempt {attempt})"
                )

                compliance = Agent5Compliance()
                self.context = await compliance.run(self.context)

                # Check if all passed
                all_passed = all(
                    cr.passed for cr in self.context.compliance_results
                )
                if all_passed:
                    logger.info(
                        f"[{self.job_params.get('job_id')}] All compliance checks passed"
                    )
                    break
                else:
                    failed_count = sum(
                        1 for cr in self.context.compliance_results if not cr.passed
                    )
                    logger.warning(
                        f"[{self.job_params.get('job_id')}] "
                        f"{failed_count} compliance failures, "
                        f"attempt {attempt}/{MAX_COMPLIANCE_RETRIES}"
                    )

            # FORMAT
            self.state = PipelineState.FORMAT
            self._emit_progress("Generating output file", 0.9)
            logger.info(f"[{self.job_params.get('job_id')}] FORMAT")

            formatter = Agent6Formatter(output_dir=self.output_dir)
            self.context = await formatter.run(self.context)

            # DONE
            self.state = PipelineState.DONE
            self._emit_progress("Complete", 1.0)
            logger.info(f"[{self.job_params.get('job_id')}] DONE")

            return self.context

        except Exception as e:
            self.state = PipelineState.ERROR
            self._emit_progress(f"Error: {str(e)}", -1.0)
            logger.error(
                f"[{self.job_params.get('job_id')}] Pipeline error: {e}",
                exc_info=True,
            )
            raise
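The orchestrator's re-draft loop proceeds to FORMAT whether or not the final compliance pass succeeded. The sketch below isolates that loop and shows one possible way to make retry exhaustion explicit with Python's `for`/`else`. It is an assumption-laden illustration, not the project's code: `Draft`, `transcreate`, `comply`, and `run_redraft_loop` are hypothetical stand-ins for `PipelineContext` and agents 4 and 5, and the "passes on the second attempt" rule exists only to drive the example.

```python
import asyncio
from dataclasses import dataclass, field

MAX_COMPLIANCE_RETRIES = 3


@dataclass
class Draft:
    """Hypothetical stand-in for PipelineContext."""
    attempts: int = 0
    passed: bool = False
    log: list[str] = field(default_factory=list)


async def transcreate(draft: Draft) -> Draft:
    """Stub for agent 4: records that a new draft was produced."""
    draft.attempts += 1
    draft.log.append(f"TRANSCREATE {draft.attempts}")
    return draft


async def comply(draft: Draft) -> Draft:
    """Stub for agent 5: in this sketch a draft passes from its second attempt on."""
    draft.passed = draft.attempts >= 2
    draft.log.append(f"COMPLY {draft.attempts} passed={draft.passed}")
    return draft


async def run_redraft_loop(draft: Draft) -> Draft:
    """Alternate transcreation and compliance until a pass or the retry cap."""
    for _ in range(MAX_COMPLIANCE_RETRIES):
        draft = await transcreate(draft)
        draft = await comply(draft)
        if draft.passed:
            break
    else:
        # Runs only when the loop finished without `break`
        draft.log.append("RETRIES_EXHAUSTED")
    return draft


final = asyncio.run(run_redraft_loop(Draft()))
```

Whether exhaustion should raise, mark the job for human review, or continue to formatting is a product decision; the `else` branch simply gives that decision a single place to live.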
0
backend/app/schemas/__init__.py
Normal file
29
backend/app/schemas/client.py
Normal file
@@ -0,0 +1,29 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from pydantic import BaseModel


class ClientCreate(BaseModel):
    name: str
    settings: dict[str, Any] | None = None

    model_config = {"from_attributes": True}


class ClientUpdate(BaseModel):
    name: str | None = None
    settings: dict[str, Any] | None = None

    model_config = {"from_attributes": True}


class ClientResponse(BaseModel):
    id: UUID
    name: str
    settings: dict[str, Any] | None = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
19
backend/app/schemas/common.py
Normal file
@@ -0,0 +1,19 @@
from typing import Any, Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class PaginatedResponse(BaseModel, Generic[T]):
    items: list[T]
    total: int
    page: int
    page_size: int
    pages: int


class ErrorResponse(BaseModel):
    detail: str
    code: str | None = None
    errors: list[dict[str, Any]] | None = None
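`PaginatedResponse` stores `total`, `page_size`, and `pages` as independent fields, so whoever builds the response has to keep them consistent. A minimal sketch of the arithmetic follows; the helper names `page_count` and `page_offset` are hypothetical and do not appear in this commit.

```python
import math


def page_count(total: int, page_size: int) -> int:
    """Number of pages needed to show `total` items, rounding up."""
    if page_size <= 0:
        raise ValueError("page_size must be positive")
    return math.ceil(total / page_size)


def page_offset(page: int, page_size: int) -> int:
    """Row offset for a 1-indexed page; pages below 1 are clamped to 1."""
    return (max(page, 1) - 1) * page_size
```

For example, 101 items at 50 per page need three pages, and page 3 starts at offset 100.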
27
backend/app/schemas/feedback.py
Normal file
@@ -0,0 +1,27 @@
from datetime import datetime
from uuid import UUID

from pydantic import BaseModel

from app.models.feedback import FlagType


class FeedbackCreate(BaseModel):
    output_id: UUID
    option_column: int
    flag_type: FlagType
    comment: str | None = None

    model_config = {"from_attributes": True}


class FeedbackResponse(BaseModel):
    id: UUID
    output_id: UUID
    user_id: UUID
    option_column: int
    flag_type: FlagType
    comment: str | None = None
    created_at: datetime

    model_config = {"from_attributes": True}
46
backend/app/schemas/files.py
Normal file
@@ -0,0 +1,46 @@
from datetime import datetime
from uuid import UUID

from pydantic import BaseModel

from app.models.files import ReferenceFileType


class TMFileResponse(BaseModel):
    id: UUID
    client_id: UUID
    locale_code: str
    channel: str
    filename: str
    file_path: str
    segment_count: int
    uploaded_by: UUID | None = None
    uploaded_at: datetime
    last_updated_at: datetime
    version: int

    model_config = {"from_attributes": True}


class ReferenceFileResponse(BaseModel):
    id: UUID
    client_id: UUID
    file_type: ReferenceFileType
    locale_scope: str
    filename: str
    file_path: str
    uploaded_by: UUID | None = None
    uploaded_at: datetime
    last_updated_at: datetime
    version: int

    model_config = {"from_attributes": True}


class FileUploadResponse(BaseModel):
    id: UUID
    filename: str
    file_path: str
    message: str

    model_config = {"from_attributes": True}
94
backend/app/schemas/job.py
Normal file
@@ -0,0 +1,94 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from pydantic import BaseModel

from app.models.job import JobStatus, JobType, LocaleStatus, LocaleType, Programme


class JobCreate(BaseModel):
    client_id: UUID
    campaign_name: str
    programme: Programme
    channel: str
    sub_channel: str | None = None
    context_prompt: str | None = None
    job_type: JobType = JobType.main
    parent_job_id: UUID | None = None
    job_ref: str | None = None
    locale_codes: list[str] = []

    model_config = {"from_attributes": True}


class JobUpdate(BaseModel):
    campaign_name: str | None = None
    programme: Programme | None = None
    channel: str | None = None
    sub_channel: str | None = None
    context_prompt: str | None = None
    job_ref: str | None = None

    model_config = {"from_attributes": True}


class LocaleInstanceResponse(BaseModel):
    id: UUID
    job_id: UUID
    locale_code: str
    locale_type: LocaleType
    status: LocaleStatus
    started_at: datetime | None = None
    completed_at: datetime | None = None
    token_usage: int = 0
    estimated_cost: float = 0.0
    output_file_path: str | None = None
    error_log: str | None = None
    tm_files_loaded: dict[str, Any] | None = None
    ref_files_loaded: dict[str, Any] | None = None
    agent_version: str | None = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}


class JobResponse(BaseModel):
    id: UUID
    client_id: UUID
    created_by: UUID
    job_ref: str | None = None
    campaign_name: str
    programme: Programme
    channel: str
    sub_channel: str | None = None
    context_prompt: str | None = None
    job_type: JobType
    parent_job_id: UUID | None = None
    status: JobStatus
    total_token_usage: int = 0
    total_estimated_cost: float = 0.0
    locale_instances: list[LocaleInstanceResponse] = []
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}


class JobListResponse(BaseModel):
    id: UUID
    client_id: UUID
    created_by: UUID
    job_ref: str | None = None
    campaign_name: str
    programme: Programme
    channel: str
    status: JobStatus
    total_token_usage: int = 0
    total_estimated_cost: float = 0.0
    locale_count: int = 0
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
51
backend/app/schemas/output.py
Normal file
@@ -0,0 +1,51 @@
from typing import Any
from uuid import UUID

from pydantic import BaseModel

from app.models.output import ConfidenceTier


class OutputRowResponse(BaseModel):
    id: UUID
    instance_id: UUID
    line_id: UUID
    row_order: int
    confidence_tier: ConfidenceTier
    option_1: str
    backtranslation_1: str
    rationale_1: str
    option_2: str | None = None
    backtranslation_2: str | None = None
    rationale_2: str | None = None
    option_3: str | None = None
    backtranslation_3: str | None = None
    rationale_3: str | None = None
    tm_entries_cited: dict[str, Any] | None = None
    winning_seg_key: str | None = None
    character_count_option_1: int | None = None
    character_count_option_2: int | None = None
    character_count_option_3: int | None = None

    model_config = {"from_attributes": True}


class SourceLinePreview(BaseModel):
    id: UUID
    row_order: int
    en_gb: str
    copy_type: str | None = None
    creative_guidance: str | None = None
    char_limit: str | None = None

    model_config = {"from_attributes": True}


class OutputPreviewResponse(BaseModel):
    locale_code: str
    instance_id: UUID
    source_lines: list[SourceLinePreview]
    output_rows: list[OutputRowResponse]
    total_rows: int

    model_config = {"from_attributes": True}
38
backend/app/schemas/user.py
Normal file
@@ -0,0 +1,38 @@
from datetime import datetime
from uuid import UUID

from pydantic import BaseModel, EmailStr

from app.models.user import UserRole, UserStatus


class UserCreate(BaseModel):
    email: EmailStr
    name: str
    password: str
    role: UserRole = UserRole.reviewer
    client_ids: list[UUID] = []

    model_config = {"from_attributes": True}


class UserUpdate(BaseModel):
    email: EmailStr | None = None
    name: str | None = None
    password: str | None = None
    role: UserRole | None = None
    status: UserStatus | None = None

    model_config = {"from_attributes": True}


class UserResponse(BaseModel):
    id: UUID
    email: str
    name: str
    role: UserRole
    status: UserStatus
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
0
backend/app/services/__init__.py
Normal file
77
backend/app/services/audit_service.py
Normal file
@@ -0,0 +1,77 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.audit import AuditLog


class AuditService:
    """Service for audit log creation and retrieval."""

    async def log(
        self,
        db: AsyncSession,
        action: str,
        entity_type: str,
        entity_id: str,
        user_id: UUID | None = None,
        details: dict[str, Any] | None = None,
        ip_address: str | None = None,
    ) -> AuditLog:
        """Create an audit log entry."""
        entry = AuditLog(
            user_id=user_id,
            action=action,
            entity_type=entity_type,
            entity_id=entity_id,
            details=details,
            ip_address=ip_address,
        )
        db.add(entry)
        await db.flush()
        return entry

    async def list_logs(
        self,
        db: AsyncSession,
        user_id: UUID | None = None,
        action: str | None = None,
        entity_type: str | None = None,
        entity_id: str | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[AuditLog], int]:
        """List audit logs with filters and pagination."""
        query = select(AuditLog)

        if user_id:
            query = query.where(AuditLog.user_id == user_id)
        if action:
            query = query.where(AuditLog.action == action)
        if entity_type:
            query = query.where(AuditLog.entity_type == entity_type)
        if entity_id:
            query = query.where(AuditLog.entity_id == entity_id)
        if date_from:
            query = query.where(AuditLog.timestamp >= date_from)
        if date_to:
            query = query.where(AuditLog.timestamp <= date_to)

        # Count
        count_query = select(func.count()).select_from(query.subquery())
        total_result = await db.execute(count_query)
        total = total_result.scalar() or 0

        # Paginate
        query = query.order_by(AuditLog.timestamp.desc())
        query = query.offset((page - 1) * page_size).limit(page_size)

        result = await db.execute(query)
        logs = list(result.scalars().all())

        return logs, total
66
backend/app/services/feedback_service.py
Normal file
@@ -0,0 +1,66 @@
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.feedback import Feedback
from app.schemas.feedback import FeedbackCreate


class FeedbackService:
    """Service for feedback CRUD operations."""

    async def create_feedback(
        self,
        db: AsyncSession,
        data: FeedbackCreate,
        user_id: UUID,
    ) -> Feedback:
        """Create a new feedback entry."""
        feedback = Feedback(
            output_id=data.output_id,
            user_id=user_id,
            option_column=data.option_column,
            flag_type=data.flag_type,
            comment=data.comment,
        )
        db.add(feedback)
        await db.flush()
        return feedback

    async def list_feedback(
        self,
        db: AsyncSession,
        output_id: UUID | None = None,
        user_id: UUID | None = None,
    ) -> list[Feedback]:
        """List feedback entries with optional filters."""
        query = select(Feedback)
        if output_id:
            query = query.where(Feedback.output_id == output_id)
        if user_id:
            query = query.where(Feedback.user_id == user_id)
        query = query.order_by(Feedback.created_at.desc())

        result = await db.execute(query)
        return list(result.scalars().all())

    async def get_feedback(
        self, db: AsyncSession, feedback_id: UUID
    ) -> Feedback | None:
        """Get a single feedback entry by ID."""
        result = await db.execute(
            select(Feedback).where(Feedback.id == feedback_id)
        )
        return result.scalar_one_or_none()

    async def delete_feedback(
        self, db: AsyncSession, feedback_id: UUID
    ) -> bool:
        """Delete a feedback entry."""
        feedback = await self.get_feedback(db, feedback_id)
        if feedback is None:
            return False
        await db.delete(feedback)
        await db.flush()
        return True
234
backend/app/services/file_service.py
Normal file
@@ -0,0 +1,234 @@
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.files import ReferenceFile, ReferenceFileType, TMFileRegistry
|
||||
from app.models.source import SourceLine
|
||||
from app.pipeline.modules.source_file_parser import parse_source_file
|
||||
|
||||
|
||||
class FileService:
|
||||
"""Service for file upload, download, path resolution, and storage management."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.storage_root = Path(settings.STORAGE_ROOT)
|
||||
|
||||
def _resolve_path(self, *parts: str) -> Path:
|
||||
"""Resolve a storage path and ensure parent directories exist."""
|
||||
path = self.storage_root.joinpath(*parts)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
async def upload_source_file(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
job_id: UUID,
|
||||
file: BinaryIO,
|
||||
filename: str,
|
||||
) -> list[SourceLine]:
|
||||
"""Upload and parse a source xlsx file, creating SourceLine records."""
|
||||
# Save to storage
|
||||
file_path = self._resolve_path("jobs", str(job_id), "source", filename)
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(file, f)
|
||||
|
||||
# Parse the xlsx
|
||||
parsed_lines = parse_source_file(str(file_path))
|
||||
|
||||
# Delete existing source lines for this job
|
||||
existing = await db.execute(
|
||||
select(SourceLine).where(SourceLine.job_id == job_id)
|
||||
)
|
||||
for line in existing.scalars().all():
|
||||
await db.delete(line)
|
||||
|
||||
# Create new source lines
|
||||
source_lines = []
|
||||
for i, row in enumerate(parsed_lines):
|
||||
source_line = SourceLine(
|
||||
job_id=job_id,
|
||||
row_order=i + 1,
|
||||
en_gb=row["en_gb"],
|
||||
copy_type=row.get("copy_type"),
|
||||
creative_guidance=row.get("creative_guidance"),
|
||||
visual_ref=row.get("visual_ref"),
|
||||
char_limit=row.get("char_limit"),
|
||||
is_display_format=row.get("is_display_format", False),
|
||||
)
|
||||
db.add(source_line)
|
||||
source_lines.append(source_line)
|
||||
|
||||
await db.flush()
|
||||
return source_lines
|
||||
|
||||
async def upload_supplementary_file(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
job_id: UUID,
|
||||
file: BinaryIO,
|
||||
filename: str,
|
||||
) -> str:
|
||||
"""Upload a supplementary file (TM, glossary, etc.) for a job."""
|
||||
file_path = self._resolve_path("jobs", str(job_id), "supplementary", filename)
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(file, f)
|
||||
return str(file_path)
|
||||
|
||||
async def upload_tm_file(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
client_id: UUID,
|
||||
locale_code: str,
|
||||
channel: str,
|
||||
file: BinaryIO,
|
||||
filename: str,
|
||||
uploaded_by: UUID | None = None,
|
||||
) -> TMFileRegistry:
|
||||
"""Upload a TM file and create a registry entry."""
|
||||
file_path = self._resolve_path(
|
||||
"clients", str(client_id), "tm", locale_code, filename
|
||||
)
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(file, f)
|
||||
|
||||
# Count segments (lines in JSONL)
|
||||
segment_count = 0
|
||||
with open(file_path, "r") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
segment_count += 1
|
||||
|
||||
tm_file = TMFileRegistry(
|
||||
client_id=client_id,
|
||||
locale_code=locale_code,
|
||||
channel=channel,
|
||||
filename=filename,
|
||||
file_path=str(file_path),
|
||||
segment_count=segment_count,
|
||||
uploaded_by=uploaded_by,
|
||||
)
|
||||
db.add(tm_file)
|
||||
await db.flush()
|
||||
return tm_file
|
||||
|
||||
async def upload_reference_file(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
client_id: UUID,
|
||||
file_type: ReferenceFileType,
|
||||
locale_scope: str,
|
||||
file: BinaryIO,
|
||||
filename: str,
|
||||
uploaded_by: UUID | None = None,
|
||||
) -> ReferenceFile:
|
||||
"""Upload a reference file and create a registry entry."""
|
||||
file_path = self._resolve_path(
|
||||
"clients", str(client_id), "reference", file_type.value, filename
|
||||
)
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(file, f)
|
||||
|
||||
ref_file = ReferenceFile(
|
||||
client_id=client_id,
|
||||
file_type=file_type,
|
||||
locale_scope=locale_scope,
|
||||
filename=filename,
|
||||
file_path=str(file_path),
|
||||
uploaded_by=uploaded_by,
|
||||
)
|
||||
db.add(ref_file)
|
||||
await db.flush()
|
||||
return ref_file
|
||||
|
||||
async def list_tm_files(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
client_id: UUID,
|
||||
locale_code: str | None = None,
|
||||
channel: str | None = None,
|
||||
    ) -> list[TMFileRegistry]:
        """List TM files for a client with optional filters."""
        query = select(TMFileRegistry).where(TMFileRegistry.client_id == client_id)
        if locale_code:
            query = query.where(TMFileRegistry.locale_code == locale_code)
        if channel:
            query = query.where(TMFileRegistry.channel == channel)

        result = await db.execute(query.order_by(TMFileRegistry.uploaded_at.desc()))
        return list(result.scalars().all())

    async def list_reference_files(
        self,
        db: AsyncSession,
        client_id: UUID,
        file_type: ReferenceFileType | None = None,
        locale_scope: str | None = None,
    ) -> list[ReferenceFile]:
        """List reference files for a client with optional filters."""
        query = select(ReferenceFile).where(ReferenceFile.client_id == client_id)
        if file_type:
            query = query.where(ReferenceFile.file_type == file_type)
        if locale_scope:
            query = query.where(ReferenceFile.locale_scope == locale_scope)

        result = await db.execute(query.order_by(ReferenceFile.uploaded_at.desc()))
        return list(result.scalars().all())

    def get_file_path(self, stored_path: str) -> Path | None:
        """Resolve a stored file path and verify it exists."""
        path = Path(stored_path)
        if path.exists():
            return path
        return None

    async def delete_tm_file(
        self, db: AsyncSession, file_id: UUID
    ) -> bool:
        """Delete a TM file from storage and database."""
        result = await db.execute(
            select(TMFileRegistry).where(TMFileRegistry.id == file_id)
        )
        tm_file = result.scalar_one_or_none()
        if tm_file is None:
            return False

        # Remove from filesystem
        file_path = Path(tm_file.file_path)
        if file_path.exists():
            os.remove(file_path)

        await db.delete(tm_file)
        await db.flush()
        return True

    async def delete_reference_file(
        self, db: AsyncSession, file_id: UUID
    ) -> bool:
        """Delete a reference file from storage and database."""
        result = await db.execute(
            select(ReferenceFile).where(ReferenceFile.id == file_id)
        )
        ref_file = result.scalar_one_or_none()
        if ref_file is None:
            return False

        file_path = Path(ref_file.file_path)
        if file_path.exists():
            os.remove(file_path)

        await db.delete(ref_file)
        await db.flush()
        return True

    def validate_file_extension(
        self, filename: str, allowed_extensions: list[str]
    ) -> bool:
        """Validate that a file has an allowed extension."""
        ext = Path(filename).suffix.lower()
        return ext in allowed_extensions
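A small sketch of how the extension check in `validate_file_extension` behaves, rewritten as a standalone function so it can be run in isolation (the free-function form and the sample filenames are assumptions for illustration, not part of the commit). Because `Path.suffix` keeps the leading dot and returns only the final suffix, the allowed list needs entries such as `.xlsx` rather than `xlsx`, and compound extensions like `.tar.gz` never match.

```python
from pathlib import Path


def has_allowed_extension(filename: str, allowed_extensions: list[str]) -> bool:
    """Standalone copy of the check: compare the lower-cased final suffix."""
    ext = Path(filename).suffix.lower()
    return ext in allowed_extensions


allowed = [".xlsx", ".tmx"]
print(has_allowed_extension("Glossary_DE.XLSX", allowed))    # upper-case suffix is normalised first
print(has_allowed_extension("archive.tar.gz", [".tar.gz"]))  # only ".gz" is compared
print(has_allowed_extension("notes", allowed))               # a name without any suffix
```

Callers that accept user-supplied extension lists may want to normalise them (lower-case, leading dot) before passing them in.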
202
backend/app/services/job_service.py
Normal file
@@ -0,0 +1,202 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus, LocaleType
from app.models.source import SourceLine
from app.schemas.job import JobCreate, JobUpdate


class JobService:
    """Service for Job CRUD and launch logic."""

    async def create_job(
        self, db: AsyncSession, data: JobCreate, user_id: UUID
    ) -> Job:
        """Create a new job with locale instances."""
        job = Job(
            client_id=data.client_id,
            created_by=user_id,
            campaign_name=data.campaign_name,
            programme=data.programme,
            channel=data.channel,
            sub_channel=data.sub_channel,
            context_prompt=data.context_prompt,
            job_type=data.job_type,
            parent_job_id=data.parent_job_id,
            job_ref=data.job_ref,
            status=JobStatus.created,
        )
        db.add(job)
        await db.flush()

        # Create locale instances
        for i, locale_code in enumerate(data.locale_codes):
            locale_type = LocaleType.main if i == 0 else LocaleType.derived
            instance = LocaleInstance(
                job_id=job.id,
                locale_code=locale_code,
                locale_type=locale_type,
                status=LocaleStatus.queued,
            )
            db.add(instance)

        await db.flush()
        return job

    async def get_job(self, db: AsyncSession, job_id: UUID) -> Job | None:
        """Get a job by ID with locale instances."""
        result = await db.execute(
            select(Job)
            .options(selectinload(Job.locale_instances))
            .where(Job.id == job_id)
        )
        return result.scalar_one_or_none()

    async def list_jobs(
        self,
        db: AsyncSession,
        client_id: UUID | None = None,
        status: str | None = None,
        locale: str | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
        search: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[Job], int]:
        """List jobs with filters and pagination."""
        query = select(Job).options(selectinload(Job.locale_instances))

        if client_id:
            query = query.where(Job.client_id == client_id)
        if status:
            query = query.where(Job.status == status)
        if locale:
            query = query.join(LocaleInstance).where(
                LocaleInstance.locale_code == locale
            )
        if date_from:
            query = query.where(Job.created_at >= date_from)
        if date_to:
            query = query.where(Job.created_at <= date_to)
        if search:
            query = query.where(
                Job.campaign_name.ilike(f"%{search}%")
                | Job.job_ref.ilike(f"%{search}%")
            )

        # Count total
        count_query = select(func.count()).select_from(query.subquery())
        total_result = await db.execute(count_query)
        total = total_result.scalar() or 0

        # Apply pagination
        query = query.order_by(Job.created_at.desc())
        query = query.offset((page - 1) * page_size).limit(page_size)

        result = await db.execute(query)
        jobs = list(result.scalars().unique().all())

        return jobs, total

    async def update_job(
        self, db: AsyncSession, job_id: UUID, data: JobUpdate
    ) -> Job | None:
        """Update a job."""
        job = await self.get_job(db, job_id)
        if job is None:
            return None

        update_data = data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(job, field, value)

        await db.flush()
        return job

    async def launch_job(self, db: AsyncSession, job_id: UUID) -> Job | None:
        """Validate and queue a job for processing."""
        job = await self.get_job(db, job_id)
        if job is None:
            return None

        if job.status != JobStatus.created:
            raise ValueError(
                f"Job cannot be launched from status '{job.status.value}'"
            )

        # Check source lines exist
        result = await db.execute(
            select(func.count()).where(SourceLine.job_id == job_id)
        )
        source_count = result.scalar() or 0
        if source_count == 0:
            raise ValueError("Job has no source lines. Upload source file first.")

        # Check locale instances exist
        if not job.locale_instances:
            raise ValueError("Job has no locale instances configured.")

        # Update status
        job.status = JobStatus.queued
        await db.flush()

        # Dispatch Celery tasks
        from app.tasks.job_tasks import process_job

        process_job.delay(str(job_id))

        return job

    async def cancel_job(self, db: AsyncSession, job_id: UUID) -> Job | None:
        """Cancel a running job."""
        job = await self.get_job(db, job_id)
        if job is None:
            return None

        if job.status not in (JobStatus.queued, JobStatus.running):
            raise ValueError(
                f"Job cannot be cancelled from status '{job.status.value}'"
            )

        job.status = JobStatus.error
        for instance in job.locale_instances:
            if instance.status in (LocaleStatus.queued, LocaleStatus.running):
                instance.status = LocaleStatus.error
                instance.error_log = "Cancelled by user"

        await db.flush()
        return job

    async def rerun_locale(
        self, db: AsyncSession, job_id: UUID, locale_code: str
    ) -> LocaleInstance | None:
        """Re-run a single locale instance."""
        result = await db.execute(
            select(LocaleInstance).where(
                LocaleInstance.job_id == job_id,
                LocaleInstance.locale_code == locale_code,
            )
        )
        instance = result.scalar_one_or_none()
        if instance is None:
            return None

        instance.status = LocaleStatus.queued
        instance.error_log = None
        instance.started_at = None
        instance.completed_at = None
        instance.token_usage = 0
        instance.estimated_cost = 0.0
        await db.flush()

        from app.tasks.job_tasks import process_locale_instance

        process_locale_instance.delay(str(job_id), locale_code)

        return instance
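The locale handling in `create_job` depends on list order: the first entry of `locale_codes` becomes the main locale and every later entry is derived. A minimal sketch of that rule follows. The `LocaleType` enum here is a stand-in (its string values are assumed), and `assign_locale_types` is a hypothetical helper, not something the commit defines.

```python
from enum import Enum


class LocaleType(str, Enum):
    """Stand-in for app.models.job.LocaleType; the string values are assumed."""

    main = "main"
    derived = "derived"


def assign_locale_types(locale_codes: list[str]) -> list[tuple[str, LocaleType]]:
    """Mirror create_job: index 0 is the main locale, every later one is derived."""
    return [
        (code, LocaleType.main if i == 0 else LocaleType.derived)
        for i, code in enumerate(locale_codes)
    ]


for code, locale_type in assign_locale_types(["de-DE", "de-AT", "de-CH"]):
    print(code, locale_type.value)
```

Because the rule is purely positional, a client that sorts or deduplicates locale codes before submitting may change which locale ends up as the main one.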
104
backend/app/services/output_service.py
Normal file
@@ -0,0 +1,104 @@
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.job import LocaleInstance
from app.models.output import OutputRow
from app.models.source import SourceLine
from app.schemas.output import (
    OutputPreviewResponse,
    OutputRowResponse,
    SourceLinePreview,
)


class OutputService:
    """Service for assembling output preview data and triggering exports."""

    async def get_preview(
        self,
        db: AsyncSession,
        job_id: UUID,
        locale_code: str,
    ) -> OutputPreviewResponse | None:
        """Assemble output preview data for a specific locale instance."""
        # Get the locale instance
        result = await db.execute(
            select(LocaleInstance)
            .where(
                LocaleInstance.job_id == job_id,
                LocaleInstance.locale_code == locale_code,
            )
        )
        instance = result.scalar_one_or_none()
        if instance is None:
            return None

        # Get source lines
        source_result = await db.execute(
            select(SourceLine)
            .where(SourceLine.job_id == job_id)
            .order_by(SourceLine.row_order)
        )
        source_lines = [
            SourceLinePreview.model_validate(sl)
            for sl in source_result.scalars().all()
        ]

        # Get output rows
        output_result = await db.execute(
            select(OutputRow)
            .where(OutputRow.instance_id == instance.id)
            .order_by(OutputRow.row_order)
        )
        output_rows = [
            OutputRowResponse.model_validate(row)
            for row in output_result.scalars().all()
        ]

        return OutputPreviewResponse(
            locale_code=locale_code,
            instance_id=instance.id,
            source_lines=source_lines,
            output_rows=output_rows,
            total_rows=len(output_rows),
        )

    async def get_output_rows(
        self,
        db: AsyncSession,
        instance_id: UUID,
    ) -> list[OutputRow]:
        """Get all output rows for a locale instance."""
        result = await db.execute(
            select(OutputRow)
            .where(OutputRow.instance_id == instance_id)
            .order_by(OutputRow.row_order)
        )
        return list(result.scalars().all())

    async def trigger_export(
        self,
        db: AsyncSession,
        job_id: UUID,
        locale_code: str,
    ) -> str | None:
        """Trigger export generation for a locale and return the file path."""
        result = await db.execute(
            select(LocaleInstance)
            .where(
                LocaleInstance.job_id == job_id,
                LocaleInstance.locale_code == locale_code,
            )
        )
        instance = result.scalar_one_or_none()
        if instance is None:
            return None

        if instance.output_file_path:
            return instance.output_file_path

        # Export would be triggered here; for now return None indicating no export yet
        return None
137
backend/app/services/report_service.py
Normal file
@@ -0,0 +1,137 @@
from datetime import datetime
from typing import Any
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.audit import TokenUsageLog
from app.models.feedback import Feedback, FlagType
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
from app.models.output import OutputRow, ConfidenceTier


class ReportService:
    """Service for aggregation queries powering reports."""

    async def get_usage_stats(
        self,
        db: AsyncSession,
        client_id: UUID | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> dict[str, Any]:
        """Get overall usage statistics."""
        job_query = select(
            func.count(Job.id).label("total_jobs"),
            func.sum(Job.total_token_usage).label("total_tokens"),
            func.sum(Job.total_estimated_cost).label("total_cost"),
        )
        if client_id:
            job_query = job_query.where(Job.client_id == client_id)
        if date_from:
            job_query = job_query.where(Job.created_at >= date_from)
        if date_to:
            job_query = job_query.where(Job.created_at <= date_to)

        result = await db.execute(job_query)
        row = result.one()

        # Status breakdown (same filters as the totals above)
        status_query = select(
            Job.status, func.count(Job.id)
        ).group_by(Job.status)
        if client_id:
            status_query = status_query.where(Job.client_id == client_id)
        if date_from:
            status_query = status_query.where(Job.created_at >= date_from)
        if date_to:
            status_query = status_query.where(Job.created_at <= date_to)
        status_result = await db.execute(status_query)
        status_breakdown = {
            status.value: count for status, count in status_result.all()
        }

        return {
            "total_jobs": row.total_jobs or 0,
            "total_tokens": row.total_tokens or 0,
            "total_cost": float(row.total_cost or 0.0),
            "status_breakdown": status_breakdown,
        }

    async def get_token_cost_data(
        self,
        db: AsyncSession,
        client_id: UUID | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[dict[str, Any]]:
        """Get token usage and cost data grouped by date."""
        query = (
            select(
                func.date_trunc("day", TokenUsageLog.timestamp).label("date"),
                func.sum(TokenUsageLog.input_tokens).label("input_tokens"),
                func.sum(TokenUsageLog.output_tokens).label("output_tokens"),
                func.sum(TokenUsageLog.total_tokens).label("total_tokens"),
                func.sum(TokenUsageLog.estimated_cost_usd).label("total_cost"),
            )
            .group_by(func.date_trunc("day", TokenUsageLog.timestamp))
            .order_by(func.date_trunc("day", TokenUsageLog.timestamp))
        )

        if client_id:
            query = query.join(LocaleInstance).join(Job).where(
                Job.client_id == client_id
            )
        if date_from:
            query = query.where(TokenUsageLog.timestamp >= date_from)
        if date_to:
            query = query.where(TokenUsageLog.timestamp <= date_to)

        result = await db.execute(query)
        return [
            {
                "date": str(row.date),
                "input_tokens": row.input_tokens or 0,
                "output_tokens": row.output_tokens or 0,
                "total_tokens": row.total_tokens or 0,
                "total_cost": float(row.total_cost or 0.0),
            }
            for row in result.all()
        ]

    async def get_quality_metrics(
        self,
        db: AsyncSession,
        client_id: UUID | None = None,
    ) -> dict[str, Any]:
        """Get quality metrics from output confidence tiers and feedback."""
        # Confidence tier distribution
        tier_query = select(
            OutputRow.confidence_tier, func.count(OutputRow.id)
        ).group_by(OutputRow.confidence_tier)

        if client_id:
            tier_query = tier_query.join(LocaleInstance).join(Job).where(
                Job.client_id == client_id
            )

        tier_result = await db.execute(tier_query)
        tier_breakdown = {
            tier.value: count for tier, count in tier_result.all()
        }

        # Feedback distribution (not scoped by client_id)
        feedback_query = select(
            Feedback.flag_type, func.count(Feedback.id)
        ).group_by(Feedback.flag_type)

        feedback_result = await db.execute(feedback_query)
        feedback_breakdown = {
            ft.value: count for ft, count in feedback_result.all()
        }

        return {
            "confidence_tiers": tier_breakdown,
            "feedback_distribution": feedback_breakdown,
        }
0
backend/app/tasks/__init__.py
Normal file
26
backend/app/tasks/celery_app.py
Normal file
@@ -0,0 +1,26 @@
"""Celery application configuration."""

from celery import Celery

from app.config import settings

celery_app = Celery(
    "transcreation",
    broker=settings.REDIS_URL,
    backend=settings.REDIS_URL,
    include=["app.tasks.job_tasks"],
)

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    task_track_started=True,
    task_acks_late=True,
    worker_prefetch_multiplier=1,
    result_expires=3600,
    task_soft_time_limit=1800,  # 30 minutes
    task_time_limit=3600,  # 1 hour
)
262
backend/app/tasks/job_tasks.py
Normal file
@@ -0,0 +1,262 @@
"""Celery tasks for job processing."""

import asyncio
import logging
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool

from app.config import settings
from app.models.job import Job, JobStatus, LocaleInstance, LocaleStatus
from app.models.output import ConfidenceTier, OutputRow
from app.models.source import SourceLine
from app.pipeline.orchestrator import PipelineOrchestrator
from app.tasks.celery_app import celery_app

logger = logging.getLogger(__name__)


def _get_async_session_factory() -> async_sessionmaker[AsyncSession]:
    """Create a fresh async session factory for use in Celery tasks."""
    # NullPool: every task runs in its own event loop, so no connection may outlive it
    engine = create_async_engine(settings.DATABASE_URL, poolclass=NullPool)
    return async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)


@celery_app.task(bind=True, max_retries=2)
def process_job(self, job_id: str) -> dict:
    """Fan out locale instance processing for a job.

    For each locale instance in the job, dispatches a process_locale_instance task.
    """
    logger.info(f"Processing job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get job
            result = await db.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update status to running
            job.status = JobStatus.running
            await db.commit()

            # Get locale instances
            result = await db.execute(
                select(LocaleInstance).where(LocaleInstance.job_id == job_id)
            )
            instances = result.scalars().all()

            # Dispatch per-locale tasks
            task_ids = []
            for instance in instances:
                task = process_locale_instance.delay(job_id, instance.locale_code)
                task_ids.append(task.id)

            return {
                "job_id": job_id,
                "dispatched_locales": len(task_ids),
                "task_ids": task_ids,
            }

    return asyncio.run(_process())


@celery_app.task(bind=True, max_retries=1)
def process_locale_instance(self, job_id: str, locale_code: str) -> dict:
    """Process a single locale instance through the pipeline.

    Runs the pipeline orchestrator for the given job + locale combination.
    """
    logger.info(f"Processing locale {locale_code} for job {job_id}")

    async def _process() -> dict:
        factory = _get_async_session_factory()
        async with factory() as db:
            # Get the locale instance
            result = await db.execute(
                select(LocaleInstance).where(
                    LocaleInstance.job_id == job_id,
                    LocaleInstance.locale_code == locale_code,
                )
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return {"error": f"Locale instance not found: {job_id}/{locale_code}"}

            # Get job
            job_result = await db.execute(select(Job).where(Job.id == job_id))
            job = job_result.scalar_one_or_none()
            if job is None:
                return {"error": f"Job {job_id} not found"}

            # Update instance status
            instance.status = LocaleStatus.running
            instance.started_at = datetime.now(timezone.utc)
            await db.commit()

            try:
                # Get source lines
                source_result = await db.execute(
                    select(SourceLine)
                    .where(SourceLine.job_id == job_id)
                    .order_by(SourceLine.row_order)
                )
                source_lines = [
                    {
                        "id": str(sl.id),
                        "en_gb": sl.en_gb,
                        "copy_type": sl.copy_type,
                        "creative_guidance": sl.creative_guidance,
                        "visual_ref": sl.visual_ref,
                        "char_limit": sl.char_limit,
                        "is_display_format": sl.is_display_format,
                    }
                    for sl in source_result.scalars().all()
                ]

                # Build job params
                job_params = {
                    "job_id": str(job.id),
                    "client_id": str(job.client_id),
                    "locale_code": locale_code,
                    "channel": job.channel,
                    "sub_channel": job.sub_channel,
                    "programme": job.programme.value,
                    "campaign_name": job.campaign_name,
                    "context_prompt": job.context_prompt,
                }

                # Run pipeline
                orchestrator = PipelineOrchestrator(
                    job_params=job_params,
                    source_lines=source_lines,
                    output_dir=settings.STORAGE_ROOT,
                )
                context = await orchestrator.run()

                # Save output rows to database
                for i, draft in enumerate(context.draft_outputs):
                    # Find matching source line
                    source_line_id = None
                    if i < len(source_lines):
                        source_line_id = source_lines[i].get("id")

                    # Get confidence tier
                    tier = "low"
                    if i < len(context.ranking_declarations):
                        tier = context.ranking_declarations[i].confidence_tier

                    # Get character counts
                    char_counts = {}
                    if i < len(context.compliance_results):
                        char_counts = context.compliance_results[i].character_counts

                    output_row = OutputRow(
                        instance_id=instance.id,
                        line_id=source_line_id,
                        row_order=i + 1,
                        confidence_tier=ConfidenceTier(tier),
                        option_1=draft.option_1.text if draft.option_1 else "",
                        backtranslation_1=draft.option_1.backtranslation if draft.option_1 else "",
                        rationale_1=draft.option_1.rationale if draft.option_1 else "",
                        option_2=draft.option_2.text if draft.option_2 else None,
                        backtranslation_2=draft.option_2.backtranslation if draft.option_2 else None,
                        rationale_2=draft.option_2.rationale if draft.option_2 else None,
                        option_3=draft.option_3.text if draft.option_3 else None,
                        backtranslation_3=draft.option_3.backtranslation if draft.option_3 else None,
                        rationale_3=draft.option_3.rationale if draft.option_3 else None,
                        tm_entries_cited=draft.tm_entries_cited if draft.tm_entries_cited else None,
                        character_count_option_1=char_counts.get("option_1"),
                        character_count_option_2=char_counts.get("option_2"),
                        character_count_option_3=char_counts.get("option_3"),
                    )
                    db.add(output_row)

                # Update instance status
                instance.status = LocaleStatus.complete
                instance.completed_at = datetime.now(timezone.utc)
                output_path = (
                    f"{settings.STORAGE_ROOT}/jobs/{job_id}/output/"
                    f"{locale_code}_{job_id}_output.xlsx"
                )
                instance.output_file_path = output_path
                instance.agent_version = "1.0.0"

                await db.commit()

                # Check if all instances are complete
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "complete",
                    "output_rows": len(context.draft_outputs),
                }

            except Exception as e:
                logger.error(
                    f"Error processing {locale_code} for job {job_id}: {e}",
                    exc_info=True,
                )
                # Discard partially added output rows before recording the failure
                await db.rollback()
                instance.status = LocaleStatus.error
                instance.error_log = str(e)
                instance.completed_at = datetime.now(timezone.utc)
                await db.commit()

                # Check if job should be marked as partial/error
                await _check_job_completion(db, job_id)

                return {
                    "job_id": job_id,
                    "locale_code": locale_code,
                    "status": "error",
                    "error": str(e),
                }

    return asyncio.run(_process())


async def _check_job_completion(db: AsyncSession, job_id: str) -> None:
    """Check if all locale instances are done and update job status accordingly."""
    result = await db.execute(
        select(LocaleInstance).where(LocaleInstance.job_id == job_id)
    )
    instances = list(result.scalars().all())

    if not instances:
        return

    all_complete = all(i.status == LocaleStatus.complete for i in instances)
    all_done = all(
        i.status in (LocaleStatus.complete, LocaleStatus.error)
        for i in instances
    )
    any_error = any(i.status == LocaleStatus.error for i in instances)

    job_result = await db.execute(select(Job).where(Job.id == job_id))
    job = job_result.scalar_one_or_none()
    if job is None:
        return

    if all_complete:
        job.status = JobStatus.complete
        # Sum up token usage (treat unset values as zero)
        job.total_token_usage = sum(i.token_usage or 0 for i in instances)
        job.total_estimated_cost = sum(i.estimated_cost or 0.0 for i in instances)
    elif all_done and any_error:
        any_complete = any(i.status == LocaleStatus.complete for i in instances)
        if any_complete:
            job.status = JobStatus.partial_complete
        else:
            job.status = JobStatus.error

    await db.commit()
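The roll-up rule in `_check_job_completion` can be read as a pure function from locale statuses to a job status. The sketch below restates it without a database. The `LocaleStatus` enum is a stand-in with assumed string values, `roll_up` is a hypothetical name, and the returned strings correspond to the `JobStatus` members used above.

```python
from enum import Enum
from typing import Optional


class LocaleStatus(str, Enum):
    """Stand-in for app.models.job.LocaleStatus; the string values are assumed."""

    queued = "queued"
    running = "running"
    complete = "complete"
    error = "error"


def roll_up(statuses: list[LocaleStatus]) -> Optional[str]:
    """Return the job status implied by the locale statuses, or None to leave it unchanged."""
    if not statuses:
        return None
    finished = {LocaleStatus.complete, LocaleStatus.error}
    if all(s == LocaleStatus.complete for s in statuses):
        return "complete"
    if all(s in finished for s in statuses):
        # At least one locale failed here, because the all-complete case returned above.
        any_complete = any(s == LocaleStatus.complete for s in statuses)
        return "partial_complete" if any_complete else "error"
    return None  # something is still queued or running


print(roll_up([LocaleStatus.complete, LocaleStatus.error]))
print(roll_up([LocaleStatus.complete, LocaleStatus.running]))
```

Keeping the decision separate from the session handling makes it straightforward to unit-test each combination of locale outcomes.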
0
backend/app/ws/__init__.py
Normal file
47
backend/app/ws/handler.py
Normal file
@@ -0,0 +1,47 @@
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect

from app.auth.service import AuthService
from app.ws.manager import manager

ws_router = APIRouter()
auth_service = AuthService()


@ws_router.websocket("/ws/jobs/{job_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    job_id: str,
    token: str = Query(...),
) -> None:
    """WebSocket endpoint for real-time job progress updates.

    Authentication is performed via query parameter token.
    """
    # Validate token
    claims = auth_service.validate_token(token)
    if claims is None:
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    await manager.connect(job_id, websocket)

    try:
        # Send initial connection message
        await manager.send_personal(
            websocket,
            {
                "type": "connected",
                "job_id": job_id,
                "message": "Connected to job progress stream",
            },
        )

        # Keep connection alive and listen for client messages
        while True:
            await websocket.receive_text()
            # Currently we don't process incoming messages,
            # but the loop keeps the connection alive
    except WebSocketDisconnect:
        manager.disconnect(job_id, websocket)
    except Exception:
        manager.disconnect(job_id, websocket)
59
backend/app/ws/manager.py
Normal file
@@ -0,0 +1,59 @@
import json
from typing import Any

from fastapi import WebSocket


class ConnectionManager:
    """Manage WebSocket connections grouped by job_id."""

    def __init__(self) -> None:
        self._connections: dict[str, list[WebSocket]] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """Accept a websocket connection and register it for a job."""
        await websocket.accept()
        if job_id not in self._connections:
            self._connections[job_id] = []
        self._connections[job_id].append(websocket)

    def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Remove a websocket connection from a job's connection list."""
        if job_id in self._connections:
            self._connections[job_id] = [
                ws for ws in self._connections[job_id] if ws != websocket
            ]
            if not self._connections[job_id]:
                del self._connections[job_id]

    async def broadcast(self, job_id: str, message: dict[str, Any]) -> None:
        """Broadcast a JSON message to all connections for a job."""
        if job_id not in self._connections:
            return

        payload = json.dumps(message)
        disconnected: list[WebSocket] = []

        for websocket in self._connections[job_id]:
            try:
                await websocket.send_text(payload)
            except Exception:
                disconnected.append(websocket)

        # Clean up dead connections
        for ws in disconnected:
            self.disconnect(job_id, ws)

    async def send_personal(
        self, websocket: WebSocket, message: dict[str, Any]
    ) -> None:
        """Send a message to a specific websocket."""
        await websocket.send_text(json.dumps(message))

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job."""
        return len(self._connections.get(job_id, []))


# Singleton instance
manager = ConnectionManager()
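To illustrate the broadcast-and-prune behaviour of `ConnectionManager`, here is a self-contained sketch that swaps the real `fastapi.WebSocket` for a fake object. `FakeSocket`, `JobConnections`, and the `progress` message shape are all hypothetical names chosen for the example. Only the rule itself (send to every socket for a job, then drop the ones that raised) is taken from the code above.

```python
import asyncio
import json
from typing import Any


class FakeSocket:
    """Hypothetical stand-in for fastapi.WebSocket; alive=False simulates a dropped client."""

    def __init__(self, alive: bool = True) -> None:
        self.alive = alive
        self.sent: list[str] = []

    async def send_text(self, payload: str) -> None:
        if not self.alive:
            raise RuntimeError("socket closed")
        self.sent.append(payload)


class JobConnections:
    """Condensed per-job registry with the same broadcast and cleanup rule."""

    def __init__(self) -> None:
        self._connections: dict[str, list[FakeSocket]] = {}

    def register(self, job_id: str, ws: FakeSocket) -> None:
        self._connections.setdefault(job_id, []).append(ws)

    def count(self, job_id: str) -> int:
        return len(self._connections.get(job_id, []))

    async def broadcast(self, job_id: str, message: dict[str, Any]) -> None:
        payload = json.dumps(message)
        dead: list[FakeSocket] = []
        for ws in self._connections.get(job_id, []):
            try:
                await ws.send_text(payload)
            except Exception:
                dead.append(ws)
        # Drop sockets that failed so later broadcasts skip them.
        remaining = [ws for ws in self._connections.get(job_id, []) if ws not in dead]
        if remaining:
            self._connections[job_id] = remaining
        else:
            self._connections.pop(job_id, None)


async def demo() -> tuple[int, list[str]]:
    registry = JobConnections()
    live, dropped = FakeSocket(), FakeSocket(alive=False)
    registry.register("job-1", live)
    registry.register("job-1", dropped)
    await registry.broadcast("job-1", {"type": "progress", "percent": 40})
    return registry.count("job-1"), live.sent


print(asyncio.run(demo()))
```

Note that a registry like this lives in one process. Celery workers run in separate processes, so progress events would need some cross-process channel (for example Redis pub/sub) to reach the sockets held by the API process.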
17
backend/requirements.txt
Normal file
@@ -0,0 +1,17 @@
fastapi>=0.111.0
uvicorn[standard]>=0.29.0
sqlalchemy[asyncio]>=2.0.30
asyncpg>=0.29.0
alembic>=1.13.0
celery[redis]>=5.4.0
redis>=5.0.0
pydantic[email]>=2.7.0
pydantic-settings>=2.2.0
python-jose[cryptography]>=3.3.0
bcrypt>=4.0.0
python-multipart>=0.0.9
openpyxl>=3.1.0
grapheme>=0.6.0
anthropic>=0.25.0
websockets>=12.0
httpx>=0.27.0
0
backend/tests/__init__.py
Normal file
0
backend/tests/fixtures/__init__.py
vendored
Normal file
0
backend/tests/test_api/__init__.py
Normal file
0
backend/tests/test_auth/__init__.py
Normal file
0
backend/tests/test_modules/__init__.py
Normal file
0
backend/tests/test_pipeline/__init__.py
Normal file
71
docker-compose.yml
Normal file
@@ -0,0 +1,71 @@
version: "3.9"

services:
  db:
    image: postgres:16
    restart: unless-stopped
    environment:
      POSTGRES_USER: transcreation
      POSTGRES_PASSWORD: transcreation
      POSTGRES_DB: transcreation
    ports:
      - "5492:5432"
    volumes:
      - pgdata:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U transcreation"]
      interval: 5s
      timeout: 5s
      retries: 5

  redis:
    image: redis:7-alpine
    restart: unless-stopped
    ports:
      - "6389:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 5s
      timeout: 5s
      retries: 5

  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    restart: unless-stopped
    ports:
      - "8040:8000"
    env_file:
      - .env
    volumes:
      - ./backend:/app
      - ./storage:/storage
      - ./seed:/app/seed
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_healthy
    command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload

  celery_worker:
    build:
      context: ./backend
      dockerfile: Dockerfile
    restart: unless-stopped
    env_file:
      - .env
    volumes:
      - ./backend:/app
      - ./storage:/storage
      - ./seed:/app/seed
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_healthy
    command: celery -A app.tasks.celery_app worker --loglevel=info --concurrency=4

volumes:
  pgdata:
2
frontend/.env.local.example
Normal file
@@ -0,0 +1,2 @@
NEXT_PUBLIC_API_URL=http://localhost:8000
NEXT_PUBLIC_WS_URL=ws://localhost:8000
35
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,35 @@
# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz

# testing
/coverage

# next.js
/.next/
/out/

# production
/build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env*.local
.env

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts
15
frontend/Dockerfile
Normal file
@@ -0,0 +1,15 @@
FROM node:20-alpine AS builder
WORKDIR /app
COPY package.json package-lock.json* ./
RUN npm ci
COPY . .
RUN npm run build

FROM node:20-alpine AS runner
WORKDIR /app
ENV NODE_ENV production
COPY --from=builder /app/public ./public
COPY --from=builder /app/.next/standalone ./
COPY --from=builder /app/.next/static ./.next/static
EXPOSE 3000
CMD ["node", "server.js"]
16
frontend/components.json
Normal file
@@ -0,0 +1,16 @@
{
  "$schema": "https://ui.shadcn.com/schema.json",
  "style": "default",
  "rsc": true,
  "tsx": true,
  "tailwind": {
    "config": "tailwind.config.ts",
    "css": "src/app/globals.css",
    "baseColor": "slate",
    "cssVariables": true
  },
  "aliases": {
    "components": "@/components",
    "utils": "@/lib/utils"
  }
}
Some files were not shown because too many files have changed in this diff.