video-accessibility/backend/app/services/glossary_service.py
Vadym Samoilenko 4645e67611 fix(glossary-list): show real embedding progress in glossary list view
- Add current_version_embedding_status/embedded_count/term_count to GlossaryResponse
- Batch-fetch current versions in list endpoint (single extra query, not N queries)
- Add get_versions_by_ids() helper to glossary_service
- Fix GlossaryList.tsx: embeddingBadge('') → embeddingBadge(g) with real status + pct

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-13 19:00:56 +01:00

779 lines
27 KiB
Python

"""
Glossary service — per-client terminology management.
Responsibilities:
• parse_xlsx(bytes, source_col) → list of parsed terms (cid, tid, source_term, {locale: translation})
• ingest_glossary(...) → create Glossary + GlossaryVersion + GlossaryTerms in Mongo
• activate_version(...) → atomic swap of current_version_id
• match_terms_for_text(...) → hybrid exact + vector retrieval for prompt injection
• build_glossary_prompt_block(...) → formats matched terms for the Gemini prompt
"""
from __future__ import annotations
import io
import re
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime
from bson import ObjectId
from fastapi import UploadFile
from ..core.database import get_database
from ..core.logging import get_logger
from ..lib import locales as locale_lib
from ..models.glossary import (
EmbeddingStatus,
Glossary,
GlossaryStatus,
GlossaryVersion,
MatchedTerm,
glossary_from_doc,
glossary_version_from_doc,
)
# Module-level logger for this service.
logger = get_logger(__name__)
# Mongo collection names used throughout this module.
_COLL_GLOSSARIES = "glossaries"
_COLL_VERSIONS = "glossary_versions"
_COLL_TERMS = "glossary_terms"
# Maximum number of terms injected into a single Gemini prompt
_MAX_TERMS_IN_PROMPT = 50
# Atlas Vector Search index name (must exist on the collection)
_VECTOR_INDEX = "glossary_embedding_index"
# Embedding dimensionality. Not referenced by the code visible in this module;
# presumably it has to match the vector index definition — TODO confirm.
_VECTOR_DIMS = 768
# Vector hits whose vectorSearchScore is below this value are discarded.
_VECTOR_SIMILARITY_THRESHOLD = 0.75
# Number of results requested from $vectorSearch for one lookup.
_VECTOR_TOP_K = 20
# ── xlsx parsing ─────────────────────────────────────────────────────────────
@dataclass
class _ParsedTerm:
    """One glossary row as read from an uploaded spreadsheet."""

    # Content of the row's "CID" column, or None when absent/empty.
    cid: str | None
    # Content of the row's "TID" column, or None when absent/empty.
    tid: str | None
    # Text taken from the source-language column.
    source_term: str
    # Translations for this row: {normalized_locale: text}
    translations: dict[str, str]
def _cell(row: tuple, idx: int | None) -> str | None:
    """Return ``row[idx]`` as a stripped string, or None when there is no value.

    None is returned when ``idx`` is None, when ``idx`` lies past the end of
    the row, or when the cell itself holds None.
    """
    has_column = idx is not None and idx < len(row)
    if not has_column:
        return None
    raw = row[idx]
    if raw is None:
        return None
    return str(raw).strip()
def parse_xlsx(file_bytes: bytes, source_locale_col: str) -> list[_ParsedTerm]:
    """
    Parse an xlsx glossary file.

    The first row of the active worksheet is treated as the header. One column
    holds the source text, optional "CID" / "TID" columns are carried through,
    and every other column whose header ``locale_lib.normalize_code`` maps to a
    truthy code is read as a translation column.

    Args:
        file_bytes: Raw xlsx bytes.
        source_locale_col: The column header that contains the source text,
            e.g. "en_gb". Compared case-insensitively after stripping
            surrounding whitespace.

    Returns:
        List of parsed terms. Fully blank rows and rows where the source
        column is empty are skipped. A workbook without a header row (or
        without an active sheet) yields an empty list.

    Raises:
        ValueError: If no header matches ``source_locale_col``.
    """
    import openpyxl  # local import — only used during ingest

    wb = openpyxl.load_workbook(io.BytesIO(file_bytes), read_only=True, data_only=True)
    # Read-only workbooks have to be closed explicitly; the ``finally`` covers
    # the early returns and the ValueError path as well as the normal exit.
    try:
        ws = wb.active
        if ws is None:
            return []
        rows = ws.iter_rows(values_only=True)
        header_row = next(rows, None)
        if header_row is None:
            return []

        # Header cells as stripped strings (None for empty cells)
        headers: list[str | None] = [
            None if h is None else str(h).strip() for h in header_row
        ]

        # Locate the source column (case-insensitive comparison)
        wanted = source_locale_col.strip().lower()
        src_idx = next(
            (i for i, h in enumerate(headers) if h and h.lower() == wanted),
            None,
        )
        if src_idx is None:
            raise ValueError(f"Source column '{source_locale_col}' not found in xlsx. Available: {[h for h in headers if h]}")
        cid_idx = next((i for i, h in enumerate(headers) if h and h.upper() == "CID"), None)
        tid_idx = next((i for i, h in enumerate(headers) if h and h.upper() == "TID"), None)

        # All other columns with valid locale-like names become translation columns
        reserved = {src_idx, cid_idx, tid_idx}
        locale_cols: list[tuple[int, str]] = []  # [(col_index, normalized_locale_code)]
        for i, h in enumerate(headers):
            if h is None or i in reserved:
                continue
            norm = locale_lib.normalize_code(h)
            if norm:
                locale_cols.append((i, norm))

        terms: list[_ParsedTerm] = []
        for row in rows:
            if not row or all(v is None for v in row):
                continue
            source = _cell(row, src_idx)
            if not source:
                continue
            translations: dict[str, str] = {}
            for col_idx, locale_code in locale_cols:
                val = _cell(row, col_idx)
                if val:
                    translations[locale_code] = val
            terms.append(_ParsedTerm(
                cid=_cell(row, cid_idx),
                tid=_cell(row, tid_idx),
                source_term=source,
                translations=translations,
            ))
        return terms
    finally:
        wb.close()
# ── Ingest ────────────────────────────────────────────────────────────────────
async def ingest_glossary(
    client_id: str,
    name: str,
    source_locale: str,
    source_locale_col: str,
    file: UploadFile,
    user_id: str,
    description: str | None = None,
    change_note: str | None = None,
) -> tuple[Glossary, GlossaryVersion]:
    """
    Full glossary ingestion pipeline:
    1. Parse terms from the uploaded xlsx (fails before any side effect)
    2. Upload the original xlsx to GCS
    3. Create Glossary + GlossaryVersion + GlossaryTerm documents in Mongo
    4. Ensure indexes and drop the client's cached active-version id
    5. Kick off background embedding task

    Args:
        client_id: Owner of the glossary; also part of the GCS object path.
        name: Glossary display name.
        source_locale: Source locale; stored after ``locale_lib.normalize_code``.
        source_locale_col: Header of the xlsx column holding the source text.
        file: Uploaded xlsx file.
        user_id: Stored as ``created_by`` on the glossary and the version.
        description: Optional glossary description.
        change_note: Optional note stored on version 1.

    Returns:
        (Glossary, GlossaryVersion) on success.

    Raises:
        ValueError: Propagated from ``parse_xlsx`` when the source column is
            not present in the sheet.
    """
    db = await get_database()
    file_bytes = await file.read()
    glossary_id = str(ObjectId())
    version_id = str(ObjectId())

    # ── Parse ──
    # Done before the upload so a spreadsheet that cannot be parsed does not
    # leave an orphaned object behind in GCS.
    logger.info(f"Parsing xlsx for glossary {glossary_id}, source_col={source_locale_col}")
    parsed_terms = parse_xlsx(file_bytes, source_locale_col)
    logger.info(f"Parsed {len(parsed_terms)} terms")

    # ── Upload original xlsx to GCS ──
    gcs_path = f"glossaries/{client_id}/{glossary_id}/{version_id}/source.xlsx"
    await _upload_bytes_to_gcs(file_bytes, gcs_path,
                               content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

    # ── Create Glossary doc ──
    now = datetime.utcnow()
    glossary_doc = {
        "_id": ObjectId(glossary_id),
        "client_id": client_id,
        "name": name,
        "description": description,
        "source_locale": locale_lib.normalize_code(source_locale),
        "source": "xlsx_upload",
        "status": GlossaryStatus.ACTIVE.value,
        "current_version_id": version_id,
        "created_at": now,
        "created_by": user_id,
    }
    await db[_COLL_GLOSSARIES].insert_one(glossary_doc)

    # ── Create GlossaryVersion doc ──
    version_doc = {
        "_id": ObjectId(version_id),
        "glossary_id": glossary_id,
        "version_number": 1,
        "source_xlsx_gcs_path": gcs_path,
        "term_count": len(parsed_terms),
        "embedded_count": 0,
        "embedding_status": EmbeddingStatus.PENDING.value,
        "created_at": now,
        "created_by": user_id,
        "change_note": change_note,
    }
    await db[_COLL_VERSIONS].insert_one(version_doc)

    # ── Bulk insert GlossaryTerms ──
    if parsed_terms:
        term_docs = [
            {
                "_id": ObjectId(),
                "glossary_id": glossary_id,
                "version_id": version_id,
                "cid": t.cid,
                "tid": t.tid,
                "source_term": t.source_term,
                "source_term_lower": t.source_term.lower(),
                "translations": t.translations,
                "embedding": None,
            }
            for t in parsed_terms
        ]
        await db[_COLL_TERMS].insert_many(term_docs, ordered=False)

    # ── Create collection indexes (idempotent) ──
    await _ensure_indexes(db)

    # ── Invalidate Redis cache ──
    # The active version is resolved as "newest ACTIVE glossary of the client"
    # and cached per client; without this the previously cached version id
    # would keep being served until its TTL expires.
    await _invalidate_cache(glossary_id)

    # ── Kick off embedding Celery task ──
    try:
        from ..tasks.embed_glossary import embed_glossary_version_task
        embed_glossary_version_task.delay(version_id)
        logger.info(f"Queued embedding task for version {version_id}")
    except Exception as e:
        # Best effort: ingestion itself has succeeded at this point.
        logger.warning(f"Could not queue embedding task: {e}")

    glossary = glossary_from_doc(glossary_doc)
    version = glossary_version_from_doc(version_doc)
    return glossary, version
async def ingest_new_version(
    glossary_id: str,
    source_locale_col: str,
    file: UploadFile,
    user_id: str,
    change_note: str | None = None,
) -> GlossaryVersion:
    """Add a new version to an existing glossary without replacing it as active.

    The glossary's ``current_version_id`` is left untouched; the new version
    gets the next ``version_number`` and an embedding task is queued for it.

    Args:
        glossary_id: Id of the glossary receiving the new version.
        source_locale_col: Header of the xlsx column holding the source text.
        file: Uploaded xlsx file.
        user_id: Stored as ``created_by`` on the version.
        change_note: Optional note stored on the version.

    Returns:
        The newly created GlossaryVersion.

    Raises:
        ValueError: If the glossary does not exist, or (from ``parse_xlsx``)
            if the source column is not present in the sheet.
    """
    db = await get_database()
    glossary_doc = await db[_COLL_GLOSSARIES].find_one({"_id": ObjectId(glossary_id)})
    if not glossary_doc:
        raise ValueError(f"Glossary {glossary_id} not found")
    client_id = glossary_doc["client_id"]

    # Find next version number
    last_version = await db[_COLL_VERSIONS].find_one(
        {"glossary_id": glossary_id},
        sort=[("version_number", -1)],
    )
    next_version_num = (last_version["version_number"] + 1) if last_version else 1

    file_bytes = await file.read()
    # Parse before uploading so an unparseable spreadsheet does not leave an
    # orphaned object in GCS.
    parsed_terms = parse_xlsx(file_bytes, source_locale_col)

    version_id = str(ObjectId())
    gcs_path = f"glossaries/{client_id}/{glossary_id}/{version_id}/source.xlsx"
    await _upload_bytes_to_gcs(file_bytes, gcs_path,
                               content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

    now = datetime.utcnow()
    version_doc = {
        "_id": ObjectId(version_id),
        "glossary_id": glossary_id,
        "version_number": next_version_num,
        "source_xlsx_gcs_path": gcs_path,
        "term_count": len(parsed_terms),
        "embedded_count": 0,
        "embedding_status": EmbeddingStatus.PENDING.value,
        "created_at": now,
        "created_by": user_id,
        "change_note": change_note,
    }
    await db[_COLL_VERSIONS].insert_one(version_doc)

    if parsed_terms:
        term_docs = [
            {
                "_id": ObjectId(),
                "glossary_id": glossary_id,
                "version_id": version_id,
                "cid": t.cid,
                "tid": t.tid,
                "source_term": t.source_term,
                "source_term_lower": t.source_term.lower(),
                "translations": t.translations,
                "embedding": None,
            }
            for t in parsed_terms
        ]
        await db[_COLL_TERMS].insert_many(term_docs, ordered=False)

    try:
        from ..tasks.embed_glossary import embed_glossary_version_task
        embed_glossary_version_task.delay(version_id)
    except Exception as e:
        # Best effort: the version has been stored even if queuing fails.
        logger.warning(f"Could not queue embedding task: {e}")

    return glossary_version_from_doc(version_doc)
async def activate_version(glossary_id: str, version_id: str) -> None:
    """Set the active version of a glossary.

    The version must exist and belong to ``glossary_id``; otherwise the
    glossary would end up pointing at an id that matches no terms.

    Args:
        glossary_id: Id of the glossary to update.
        version_id: Id of the version that becomes ``current_version_id``.

    Raises:
        ValueError: If the glossary does not exist, or the version does not
            exist for that glossary.
    """
    db = await get_database()
    glossary_filter = {"_id": ObjectId(glossary_id)}

    if not await db[_COLL_GLOSSARIES].find_one(glossary_filter, {"_id": 1}):
        raise ValueError(f"Glossary {glossary_id} not found")

    # Version docs store the parent glossary id as a string in "glossary_id".
    version_doc = None
    if ObjectId.is_valid(version_id):
        version_doc = await db[_COLL_VERSIONS].find_one(
            {"_id": ObjectId(version_id), "glossary_id": glossary_id},
            {"_id": 1},
        )
    if not version_doc:
        raise ValueError(f"Version {version_id} not found for glossary {glossary_id}")

    # The swap itself is a single-document update.
    result = await db[_COLL_GLOSSARIES].update_one(
        glossary_filter,
        {"$set": {"current_version_id": version_id}},
    )
    if result.matched_count == 0:
        # Glossary disappeared between the existence check and the update.
        raise ValueError(f"Glossary {glossary_id} not found")

    # Invalidate Redis cache
    await _invalidate_cache(glossary_id)
async def archive_glossary(glossary_id: str) -> None:
    """Hard-delete the glossary and all its versions and terms.

    Also removes the owning client's cached active-version id from Redis
    (best effort).

    Args:
        glossary_id: Id of the glossary to delete.
    """
    db = await get_database()
    glossary_oid = ObjectId(glossary_id)

    # Read the owning client up front: once the glossary document is deleted
    # there is nothing left to derive the cache key from.
    glossary_doc = await db[_COLL_GLOSSARIES].find_one(
        {"_id": glossary_oid}, {"client_id": 1}
    )

    versions = await db[_COLL_VERSIONS].find(
        {"glossary_id": glossary_id}, {"_id": 1}
    ).to_list(length=None)
    version_ids = [str(v["_id"]) for v in versions]
    if version_ids:
        terms_result = await db[_COLL_TERMS].delete_many({"version_id": {"$in": version_ids}})
        logger.info(f"Deleted {terms_result.deleted_count} terms for glossary {glossary_id}")
        await db[_COLL_VERSIONS].delete_many({"glossary_id": glossary_id})
        logger.info(f"Deleted {len(version_ids)} versions for glossary {glossary_id}")

    await db[_COLL_GLOSSARIES].delete_one({"_id": glossary_oid})

    # Clear the cache after the delete so a concurrent lookup cannot re-cache
    # the version that has just been removed.
    client_id = glossary_doc.get("client_id") if glossary_doc else None
    if client_id:
        try:
            from ..core.redis import redis_client  # lazy import
            await redis_client.delete(f"glossary:active_version:{client_id}")
        except Exception as e:
            logger.debug(f"Cache invalidation skipped: {e}")

    logger.info(f"Deleted glossary {glossary_id}")
# ── Retrieval ─────────────────────────────────────────────────────────────────
async def match_terms_for_text(
    client_id: str,
    text: str,
    target_locale: str,
    top_k: int = _MAX_TERMS_IN_PROMPT,
) -> list[MatchedTerm]:
    """
    Hybrid glossary retrieval for a piece of source text.

    Runs the exact pass (Aho-Corasick) first and, if it returned fewer than
    `top_k` terms, the semantic pass (Atlas Vector Search). Exact matches come
    first in the result, vector matches after them, and the combined list is
    cut to at most `top_k` MatchedTerm objects carrying the translation for
    `target_locale`. A failing vector pass is logged and ignored. Returns []
    when the client has no active glossary version.
    """
    db = await get_database()
    locale_code = locale_lib.normalize_code(target_locale)
    version_id = await _get_active_version_id(client_id)
    if not version_id:
        return []

    # ── Exact pass ──
    exact_hits = await _exact_match(db, version_id, text, locale_code)
    exact_keys = {hit.source_term.lower() for hit in exact_hits}

    # ── Vector pass (only while there is room left) ──
    semantic_hits: list[MatchedTerm] = []
    if len(exact_hits) < top_k:
        try:
            semantic_hits = await _vector_match(
                db, version_id, text, locale_code,
                top_k=_VECTOR_TOP_K, exclude_terms=exact_keys,
            )
        except Exception as e:
            logger.warning(f"Vector search failed (non-fatal): {e}")

    ranked = exact_hits + semantic_hits
    if len(ranked) <= top_k:
        return ranked
    logger.info(f"glossary_terms_truncated: had {len(ranked)}, capped at {top_k}")
    return ranked[:top_k]
async def _get_active_version_id(client_id: str) -> str | None:
    """Return the active version_id for the active glossary of a client, or None.

    Lookup order: the Redis key ``glossary:active_version:{client_id}``, then
    the most recently created glossary with ACTIVE status in Mongo. A Mongo
    hit is written back to Redis with a 3600-second TTL. Redis is best effort:
    failures are logged at debug level and the lookup continues without it.
    """
    cache_key = f"glossary:active_version:{client_id}"

    try:
        from ..core.redis import redis_client  # lazy import
        cached = await redis_client.get(cache_key)
        if cached:
            return cached.decode() if isinstance(cached, bytes) else cached
    except Exception as e:
        logger.debug(f"Active-version cache read skipped: {e}")

    db = await get_database()
    glossary_doc = await db[_COLL_GLOSSARIES].find_one(
        {"client_id": client_id, "status": GlossaryStatus.ACTIVE.value},
        sort=[("created_at", -1)],
    )
    if not glossary_doc or not glossary_doc.get("current_version_id"):
        return None
    version_id = glossary_doc["current_version_id"]

    try:
        from ..core.redis import redis_client
        await redis_client.setex(cache_key, 3600, version_id)
    except Exception as e:
        logger.debug(f"Active-version cache write skipped: {e}")
    return version_id
async def _invalidate_cache(glossary_id: str) -> None:
    """Drop the cached active-version id of the client owning ``glossary_id``.

    The client is read from the glossary document, so nothing is deleted when
    that document cannot be found. Any error is logged at debug level and
    swallowed.
    """
    try:
        database = await get_database()
        owner_doc = await database[_COLL_GLOSSARIES].find_one({"_id": ObjectId(glossary_id)})
        if not owner_doc:
            return
        from ..core.redis import redis_client
        await redis_client.delete(f"glossary:active_version:{owner_doc['client_id']}")
    except Exception as e:
        logger.debug(f"Cache invalidation skipped: {e}")
async def _exact_match(
    db,
    version_id: str,
    text: str,
    target_locale: str,
) -> list[MatchedTerm]:
    """Find terms present in `text` using Aho-Corasick over the glossary terms.

    Matching is done on lower-cased text. A hit only counts when the characters
    directly before and after it are not alphanumeric, and when the term has a
    translation for `target_locale` (with locale fallback). Each source term is
    reported once, with ``match_kind="exact"`` and ``score=1.0``.

    Args:
        db: Database handle used to read the terms collection.
        version_id: Glossary version whose terms are searched.
        text: Text to scan.
        target_locale: Locale code used to pick the translation.

    Returns:
        Matched terms in the order the automaton reports them.
    """
    import ahocorasick  # pyahocorasick

    # Nothing can match empty text — skip loading the whole term list.
    if not text:
        return []

    # Load all terms for this version (source_term_lower + translations)
    cursor = db[_COLL_TERMS].find(
        {"version_id": version_id},
        {"source_term": 1, "source_term_lower": 1, "translations": 1},
    )
    terms = await cursor.to_list(length=None)
    if not terms:
        return []

    # Build automaton. The payload carries the length of the key actually
    # inserted so the match start can be derived from what was matched rather
    # than from re-lowercasing the display form of the term.
    automaton = ahocorasick.Automaton()
    for doc in terms:
        stl = doc.get("source_term_lower") or doc.get("source_term", "")
        if not stl:
            continue
        key = stl.lower()
        automaton.add_word(key, (len(key), doc["source_term"], doc.get("translations", {})))
    if not automaton:
        return []
    automaton.make_automaton()

    text_lower = text.lower()
    matched: list[MatchedTerm] = []
    seen: set[str] = set()
    # automaton.iter yields the index of the LAST character of each hit.
    for end_idx, (key_len, source_term, translations) in automaton.iter(text_lower):
        if source_term in seen:
            continue
        # Require word/phrase boundaries around the match
        start_idx = end_idx - key_len + 1
        if start_idx > 0 and text_lower[start_idx - 1].isalnum():
            continue
        end_after = end_idx + 1
        if end_after < len(text_lower) and text_lower[end_after].isalnum():
            continue
        target_text = _get_translation(translations, target_locale)
        if not target_text:
            continue
        seen.add(source_term)
        matched.append(MatchedTerm(
            source_term=source_term,
            target_translation=target_text,
            match_kind="exact",
            score=1.0,
        ))
    return matched
async def _vector_match(
    db,
    version_id: str,
    text: str,
    target_locale: str,
    top_k: int = 20,
    exclude_terms: set[str] | None = None,
) -> list[MatchedTerm]:
    """Semantic search via Atlas Vector Search ($vectorSearch).

    Args:
        db: Database handle used to run the aggregation on the terms collection.
        version_id: Glossary version the search is filtered to.
        text: Query text; only its first 2000 characters are embedded.
        target_locale: Locale code used to pick the translation.
        top_k: ``limit`` passed to $vectorSearch (``numCandidates`` is 4x that).
        exclude_terms: Lower-cased source terms to leave out, e.g. terms that
            were already found by the exact pass. Not modified.

    Returns:
        Hits scoring at least ``_VECTOR_SIMILARITY_THRESHOLD`` that have a
        translation for `target_locale`, in the order the pipeline returned
        them, one entry per distinct (case-insensitive) source term.
    """
    from ..services.embedding_service import embedding_service

    query_embedding = await embedding_service.embed_text(text[:2000])  # cap input length
    pipeline = [
        {
            "$vectorSearch": {
                "index": _VECTOR_INDEX,
                "path": "embedding",
                "queryVector": query_embedding,
                "numCandidates": top_k * 4,
                "limit": top_k,
                "filter": {"version_id": version_id},
            }
        },
        {
            "$project": {
                "source_term": 1,
                "translations": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]
    cursor = db[_COLL_TERMS].aggregate(pipeline)
    results = await cursor.to_list(length=top_k)

    # Work on a copy so the caller's set stays untouched while we also record
    # the terms emitted here (a glossary may list the same source term twice).
    skip: set[str] = set(exclude_terms or ())
    matched: list[MatchedTerm] = []
    for doc in results:
        score = doc.get("score", 0.0)
        if score < _VECTOR_SIMILARITY_THRESHOLD:
            continue
        source_term = doc["source_term"]
        term_key = source_term.lower()
        if term_key in skip:
            continue
        # A term document without translations is skipped instead of raising
        # KeyError and discarding every other vector hit.
        target_text = _get_translation(doc.get("translations") or {}, target_locale)
        if not target_text:
            continue
        skip.add(term_key)
        matched.append(MatchedTerm(
            source_term=source_term,
            target_translation=target_text,
            match_kind="vector",
            score=score,
        ))
    return matched
def _get_translation(translations: dict[str, str], target_locale: str) -> str | None:
    """Pick the translation for `target_locale`, falling back within the language.

    An exact key match wins. Otherwise the language part of the locale (the
    text before the first "-") is used, and the first entry in dict order whose
    code is that bare language or one of its regional variants is returned:
        fr-CA → first of fr / fr-*
        fr    → first of fr-*
    Returns None when either argument is empty or nothing matches.
    """
    if not (translations and target_locale):
        return None
    if target_locale in translations:
        return translations[target_locale]

    # For a bare code the language part is the code itself.
    language = target_locale.partition("-")[0]
    region_prefix = language + "-"
    return next(
        (
            text
            for code, text in translations.items()
            if code == language or code.startswith(region_prefix)
        ),
        None,
    )
# ── Prompt block ──────────────────────────────────────────────────────────────
def build_glossary_prompt_block(
    matched_terms: Sequence[MatchedTerm],
    target_locale: str,
) -> str:
    """
    Format matched terms for injection into a Gemini prompt.

    The block is a heading naming the target language (via
    ``locale_lib.get_gemini_label``), one instruction line, and one bullet per
    term in the form ``- "source" → "translation"``.

    Args:
        matched_terms: Terms to list, in the order they should appear.
        target_locale: Locale used for the heading label.

    Returns:
        The newline-joined block, or an empty string if no terms were matched.
    """
    if not matched_terms:
        return ""
    target_label = locale_lib.get_gemini_label(target_locale)
    lines = [
        f"## Approved {target_label} terminology",
        "Use these exact translations when the source terms appear — do not deviate:",
    ]
    # Source and translation need a visible separator; without one the two
    # quoted strings run together as "source""translation".
    lines.extend(
        f'- "{term.source_term}" → "{term.target_translation}"'
        for term in matched_terms
    )
    return "\n".join(lines)
# ── Helpers ───────────────────────────────────────────────────────────────────
async def _upload_bytes_to_gcs(data: bytes, gcs_path: str, content_type: str) -> None:
    """Upload `data` to `gcs_path` in the configured GCS bucket.

    The blocking storage client call runs in the event loop's default
    executor so the coroutine does not block the loop.

    Args:
        data: Bytes to store.
        gcs_path: Object name inside ``settings.gcs_bucket``.
        content_type: Content type set on the uploaded object.
    """
    import asyncio

    def _upload() -> None:
        from google.cloud import storage as gcs_storage
        from ..core.config import settings
        client = gcs_storage.Client(project=settings.gcp_project_id)
        bucket = client.bucket(settings.gcs_bucket)
        blob = bucket.blob(gcs_path)
        blob.content_type = content_type
        blob.upload_from_string(data, content_type=content_type)

    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine; get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _upload)
async def _ensure_indexes(db) -> None:
    """Create the indexes this service queries by, one after another.

    The first failure stops the remaining index creations; it is logged at
    debug level and not raised.
    """
    index_specs = (
        (_COLL_GLOSSARIES, [("client_id", 1), ("status", 1)]),
        (_COLL_VERSIONS, [("glossary_id", 1), ("version_number", -1)]),
        (_COLL_TERMS, [("version_id", 1), ("source_term_lower", 1)]),
        (_COLL_TERMS, [("glossary_id", 1)]),
    )
    try:
        for collection_name, keys in index_specs:
            await db[collection_name].create_index(keys)
    except Exception as e:
        logger.debug(f"Index creation skipped (likely already exist): {e}")
# ── Task helpers ─────────────────────────────────────────────────────────────
async def get_glossary_block_for_job(
    job_doc: dict,
    target_locale: str,
    db,
) -> str:
    """
    Build the glossary prompt block for a job (helper for Celery tasks).

    Resolution chain:
        job_doc["project_id"] → db.projects → client_id → active glossary version
    The text that terms are matched against is read from
    job_doc["_glossary_source_text"].

    Non-fatal: every skip condition and any failure during the lookup is
    logged and yields "" so the pipeline continues without a glossary.
    """
    import traceback

    try:
        job_id_for_log = job_doc.get("_id", "unknown")

        project_id = job_doc.get("project_id")
        if not project_id:
            logger.debug(f"Glossary skip job={job_id_for_log}: no project_id")
            return ""

        project = await db.projects.find_one({"_id": project_id})
        if not project:
            logger.warning(f"Glossary skip job={job_id_for_log}: project {project_id!r} not found")
            return ""

        client_id = project.get("client_id")
        if not client_id:
            logger.debug(f"Glossary skip job={job_id_for_log}: project has no client_id")
            return ""

        # Cache-backed lookup of the client's active version (Redis when available)
        active_version_id = await _get_active_version_id(client_id)
        if not active_version_id:
            logger.debug(f"Glossary skip job={job_id_for_log}: no active glossary for client {client_id!r}")
            return ""

        # Source text the caller attached to the job document for matching
        source_text = job_doc.get("_glossary_source_text", "")
        if not source_text:
            logger.debug(f"Glossary skip job={job_id_for_log}: no source text provided for matching")
            return ""

        logger.info(f"Glossary lookup job={job_id_for_log} client={client_id!r} version={active_version_id!r} locale={target_locale!r}")
        norm_target = locale_lib.normalize_code(target_locale)

        # Exact pass first, then the vector pass if there is still room
        exact_hits = await _exact_match(db, active_version_id, source_text, norm_target)
        exact_keys = {hit.source_term.lower() for hit in exact_hits}
        semantic_hits: list[MatchedTerm] = []
        if len(exact_hits) < _MAX_TERMS_IN_PROMPT:
            try:
                semantic_hits = await _vector_match(
                    db, active_version_id, source_text, norm_target,
                    top_k=_VECTOR_TOP_K, exclude_terms=exact_keys,
                )
            except Exception as ve:
                logger.debug(f"Vector search skipped in task context: {ve}")

        selected = exact_hits + semantic_hits
        if len(selected) > _MAX_TERMS_IN_PROMPT:
            logger.info(f"glossary_terms_truncated: capped at {_MAX_TERMS_IN_PROMPT}")
            selected = selected[:_MAX_TERMS_IN_PROMPT]
        return build_glossary_prompt_block(selected, target_locale)
    except Exception as e:
        logger.warning(f"Glossary lookup failed for job {job_doc.get('_id')} (non-fatal): {e}\n{traceback.format_exc()}")
        return ""
# ── Listing helpers ───────────────────────────────────────────────────────────
async def get_glossaries_for_client(client_id: str) -> list[Glossary]:
    """Return up to 100 non-archived glossaries of a client, newest first."""
    database = await get_database()
    not_archived = {
        "client_id": client_id,
        "status": {"$ne": GlossaryStatus.ARCHIVED.value},
    }
    rows = await database[_COLL_GLOSSARIES].find(
        not_archived,
        sort=[("created_at", -1)],
    ).to_list(length=100)
    return list(map(glossary_from_doc, rows))
async def get_glossary(glossary_id: str) -> Glossary | None:
    """Fetch one glossary by id; None when no such document exists."""
    database = await get_database()
    found = await database[_COLL_GLOSSARIES].find_one({"_id": ObjectId(glossary_id)})
    if not found:
        return None
    return glossary_from_doc(found)
async def get_versions_by_ids(version_ids: list[str]) -> dict[str, GlossaryVersion]:
    """Batch-fetch versions by ID, returns {version_id: GlossaryVersion}.

    Ids that are empty or not valid ObjectId strings are ignored, as are ids
    with no matching document, so one bad reference cannot fail the batch.
    Callers should treat a missing key as "version unavailable".
    """
    # dict.fromkeys drops duplicates while keeping the caller's order.
    object_ids = [
        ObjectId(vid)
        for vid in dict.fromkeys(version_ids or ())
        if vid and ObjectId.is_valid(vid)
    ]
    if not object_ids:
        return {}
    db = await get_database()
    docs = await db[_COLL_VERSIONS].find(
        {"_id": {"$in": object_ids}}
    ).to_list(length=len(object_ids))
    return {str(d["_id"]): glossary_version_from_doc(d) for d in docs}
async def get_versions(glossary_id: str) -> list[GlossaryVersion]:
    """Return up to 50 versions of a glossary, highest version_number first."""
    database = await get_database()
    newest_first = database[_COLL_VERSIONS].find(
        {"glossary_id": glossary_id},
        sort=[("version_number", -1)],
    )
    rows = await newest_first.to_list(length=50)
    return list(map(glossary_version_from_doc, rows))
async def get_terms_page(
    version_id: str,
    search: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[dict], int]:
    """Returns (terms, total_count) for paginated UI preview.

    Args:
        version_id: Glossary version whose terms are listed.
        search: Optional substring filter, matched literally (regex-escaped)
            against the lower-cased source term.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Terms per page; values below 1 are treated as 1.

    Returns:
        A list of ``{"_id", "source_term", "translations"}`` dicts sorted by
        lower-cased source term, and the total number of terms matching the
        filter.
    """
    # Clamp paging input: a negative skip is rejected by the driver and a
    # limit of 0 means "no limit" to Mongo, which would return every term.
    page = max(page, 1)
    page_size = max(page_size, 1)

    db = await get_database()
    query: dict = {"version_id": version_id}
    if search:
        query["source_term_lower"] = {"$regex": re.escape(search.lower())}
    total = await db[_COLL_TERMS].count_documents(query)
    cursor = db[_COLL_TERMS].find(
        query,
        {"_id": 1, "source_term": 1, "translations": 1},
        skip=(page - 1) * page_size,
        limit=page_size,
        sort=[("source_term_lower", 1)],
    )
    docs = await cursor.to_list(length=page_size)
    # Only source_term + translations are projected — build a minimal dict
    # rather than validating against GlossaryTerm (which requires more fields)
    terms = [
        {
            "_id": str(d["_id"]),
            "source_term": d.get("source_term", ""),
            "translations": d.get("translations", {}),
        }
        for d in docs
    ]
    return terms, total