- Add current_version_embedding_status/embedded_count/term_count to GlossaryResponse
- Batch-fetch current versions in list endpoint (single extra query, not N queries)
- Add get_versions_by_ids() helper to glossary_service
- Fix GlossaryList.tsx: embeddingBadge('') → embeddingBadge(g) with real status + pct
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
779 lines
27 KiB
Python
779 lines
27 KiB
Python
"""
|
|
Glossary service — per-client terminology management.
|
|
|
|
Responsibilities:
|
|
• parse_xlsx(bytes, source_col) → list of _ParsedTerm (cid, tid, source_term, {locale: translation})
|
|
• ingest_glossary(...) → create Glossary + GlossaryVersion + GlossaryTerms in Mongo
|
|
• activate_version(...) → atomic swap of current_version_id
|
|
• match_terms_for_text(...) → hybrid exact + vector retrieval for prompt injection
|
|
• build_glossary_prompt_block(...) → formats matched terms for the Gemini prompt
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import re
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
|
|
from bson import ObjectId
|
|
from fastapi import UploadFile
|
|
|
|
from ..core.database import get_database
|
|
from ..core.logging import get_logger
|
|
from ..lib import locales as locale_lib
|
|
from ..models.glossary import (
|
|
EmbeddingStatus,
|
|
Glossary,
|
|
GlossaryStatus,
|
|
GlossaryVersion,
|
|
MatchedTerm,
|
|
glossary_from_doc,
|
|
glossary_version_from_doc,
|
|
)
|
|
|
|
logger = get_logger(__name__)

# Mongo collection names used throughout this service.
_COLL_GLOSSARIES: str = "glossaries"
_COLL_VERSIONS: str = "glossary_versions"
_COLL_TERMS: str = "glossary_terms"

# Upper bound on how many matched terms are injected into one Gemini prompt.
_MAX_TERMS_IN_PROMPT: int = 50

# Atlas Vector Search settings. The index named here must already exist on the
# terms collection; this module does not create it.
_VECTOR_INDEX: str = "glossary_embedding_index"
# NOTE(review): not referenced by the code in this file; presumably the
# embedding width expected by the index — confirm against the embedding task.
_VECTOR_DIMS: int = 768
# Vector hits whose vectorSearchScore is below this are discarded.
_VECTOR_SIMILARITY_THRESHOLD: float = 0.75
# Result limit used for the vector pass of retrieval.
_VECTOR_TOP_K: int = 20
|
|
|
|
|
|
# ── xlsx parsing ─────────────────────────────────────────────────────────────
|
|
|
|
@dataclass
class _ParsedTerm:
    """One data row of a glossary xlsx, as produced by ``parse_xlsx``."""

    # Stripped "CID" column value; None when the sheet has no CID column or the cell is empty.
    cid: str | None
    # Stripped "TID" column value; None when the sheet has no TID column or the cell is empty.
    tid: str | None
    # Stripped text of the source-locale column; parse_xlsx skips rows where it is empty.
    source_term: str
    # Only non-empty translation cells are included; keys come from locale_lib.normalize_code.
    translations: dict[str, str] # {normalized_locale: text}
|
|
|
|
|
|
def _cell(row: tuple, idx: int | None) -> str | None:
|
|
if idx is None or idx >= len(row):
|
|
return None
|
|
v = row[idx]
|
|
return str(v).strip() if v is not None else None
|
|
|
|
|
|
def parse_xlsx(file_bytes: bytes, source_locale_col: str) -> list[_ParsedTerm]:
    """
    Parse an xlsx glossary file.

    Only the active worksheet is read. Its first row is the header row; every
    following row with a non-empty source cell becomes one ``_ParsedTerm``.

    Args:
        file_bytes: Raw xlsx bytes.
        source_locale_col: The column header that contains the source text,
            e.g. "en_gb" or "en-GB". Matched case-insensitively against the
            stripped header text (no locale normalisation is applied here).

    Returns:
        List of parsed terms. Rows where the source column is empty are skipped.
        An empty list is returned for a sheet with no rows at all.

    Raises:
        ValueError: if no header matches ``source_locale_col``.
    """
    import openpyxl  # local import — only used during ingest

    wb = openpyxl.load_workbook(io.BytesIO(file_bytes), read_only=True, data_only=True)
    try:
        rows = wb.active.iter_rows(values_only=True)
        header_row = next(rows, None)
        if header_row is None:
            return []

        # Header cells as stripped strings; None keeps column positions aligned.
        headers: list[str | None] = [
            None if h is None else str(h).strip() for h in header_row
        ]

        # Locate the source column (single case-insensitive pass).
        wanted = source_locale_col.strip().lower()
        src_idx: int | None = next(
            (i for i, h in enumerate(headers) if h and h.lower() == wanted),
            None,
        )
        if src_idx is None:
            raise ValueError(
                f"Source column '{source_locale_col}' not found in xlsx. "
                f"Available: {[h for h in headers if h]}"
            )

        cid_idx = next((i for i, h in enumerate(headers) if h and h.upper() == "CID"), None)
        tid_idx = next((i for i, h in enumerate(headers) if h and h.upper() == "TID"), None)

        # Every other header that normalises to a locale code is a translation column.
        reserved = {src_idx, cid_idx, tid_idx}
        locale_cols: list[tuple[int, str]] = []  # [(col_index, normalized_locale_code)]
        for i, h in enumerate(headers):
            if h is None or i in reserved:
                continue
            norm = locale_lib.normalize_code(h)
            if norm:
                locale_cols.append((i, norm))

        terms: list[_ParsedTerm] = []
        for row in rows:
            if not row or all(v is None for v in row):
                continue

            source = _cell(row, src_idx)
            if not source:
                continue

            translations: dict[str, str] = {}
            for col_idx, locale_code in locale_cols:
                val = _cell(row, col_idx)
                if val:
                    translations[locale_code] = val

            terms.append(_ParsedTerm(
                cid=_cell(row, cid_idx),
                tid=_cell(row, tid_idx),
                source_term=source,
                translations=translations,
            ))

        return terms
    finally:
        # Close the workbook on every exit path: the early return for an empty
        # sheet and the ValueError above previously skipped wb.close().
        wb.close()
|
|
|
|
|
|
# ── Ingest ────────────────────────────────────────────────────────────────────
|
|
|
|
async def ingest_glossary(
    client_id: str,
    name: str,
    source_locale: str,
    source_locale_col: str,
    file: UploadFile,
    user_id: str,
    description: str | None = None,
    change_note: str | None = None,
) -> tuple[Glossary, GlossaryVersion]:
    """
    Full glossary ingestion pipeline:
      1. Parse terms from the uploaded xlsx
      2. Upload the original xlsx to GCS
      3. Create GlossaryVersion + GlossaryTerm documents, then the Glossary document
      4. Clear the client's cached active-version id
      5. Kick off background embedding task (best-effort)

    Parsing happens before any side effect. A sheet without the requested
    source column therefore raises ValueError without leaving an orphaned GCS
    object or Mongo documents. The Glossary document (status ACTIVE, pointing
    at the new version) is inserted only after its version and terms exist.

    Returns (Glossary, GlossaryVersion) on success.

    Raises:
        ValueError: propagated from ``parse_xlsx`` when the source column is missing.
    """
    db = await get_database()

    file_bytes = await file.read()
    glossary_id = str(ObjectId())
    version_id = str(ObjectId())

    # ── Parse (before any side effect) ──
    logger.info(f"Parsing xlsx for glossary {glossary_id}, source_col={source_locale_col}")
    parsed_terms = parse_xlsx(file_bytes, source_locale_col)
    logger.info(f"Parsed {len(parsed_terms)} terms")

    # ── Upload original xlsx to GCS ──
    gcs_path = f"glossaries/{client_id}/{glossary_id}/{version_id}/source.xlsx"
    await _upload_bytes_to_gcs(
        file_bytes,
        gcs_path,
        content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    now = datetime.utcnow()

    # ── Create GlossaryVersion doc ──
    version_doc = {
        "_id": ObjectId(version_id),
        "glossary_id": glossary_id,
        "version_number": 1,
        "source_xlsx_gcs_path": gcs_path,
        "term_count": len(parsed_terms),
        "embedded_count": 0,
        "embedding_status": EmbeddingStatus.PENDING.value,
        "created_at": now,
        "created_by": user_id,
        "change_note": change_note,
    }
    await db[_COLL_VERSIONS].insert_one(version_doc)

    # ── Bulk insert GlossaryTerms ──
    if parsed_terms:
        term_docs = [
            {
                "_id": ObjectId(),
                "glossary_id": glossary_id,
                "version_id": version_id,
                "cid": t.cid,
                "tid": t.tid,
                "source_term": t.source_term,
                "source_term_lower": t.source_term.lower(),
                "translations": t.translations,
                "embedding": None,
            }
            for t in parsed_terms
        ]
        await db[_COLL_TERMS].insert_many(term_docs, ordered=False)

    # ── Create Glossary doc last ──
    # _get_active_version_id resolves a client's version through ACTIVE glossary
    # docs. Inserting this last means the pointer is never visible while its
    # terms are still being written.
    glossary_doc = {
        "_id": ObjectId(glossary_id),
        "client_id": client_id,
        "name": name,
        "description": description,
        "source_locale": locale_lib.normalize_code(source_locale),
        "source": "xlsx_upload",
        "status": GlossaryStatus.ACTIVE.value,
        "current_version_id": version_id,
        "created_at": now,
        "created_by": user_id,
    }
    await db[_COLL_GLOSSARIES].insert_one(glossary_doc)

    # ── Create collection indexes (idempotent) ──
    await _ensure_indexes(db)

    # _get_active_version_id picks the newest ACTIVE glossary for the client and
    # caches the answer; this new glossary changes that answer, so drop the cache.
    await _invalidate_cache(glossary_id)

    # ── Kick off embedding Celery task (best-effort) ──
    try:
        from ..tasks.embed_glossary import embed_glossary_version_task
        embed_glossary_version_task.delay(version_id)
        logger.info(f"Queued embedding task for version {version_id}")
    except Exception as e:
        logger.warning(f"Could not queue embedding task: {e}")

    glossary = glossary_from_doc(glossary_doc)
    version = glossary_version_from_doc(version_doc)
    return glossary, version
|
|
|
|
|
|
async def ingest_new_version(
    glossary_id: str,
    source_locale_col: str,
    file: UploadFile,
    user_id: str,
    change_note: str | None = None,
) -> GlossaryVersion:
    """Add a new version to an existing glossary without making it the active one.

    The new version's ``version_number`` is one higher than the current highest
    (1 if the glossary has none). Use ``activate_version`` to switch to it.
    The sheet is parsed before anything is uploaded or written, so a bad file
    leaves no GCS object or Mongo documents behind.

    Raises:
        ValueError: if the glossary does not exist, or (from ``parse_xlsx``)
            if the source column is missing from the sheet.
    """
    db = await get_database()

    glossary_doc = await db[_COLL_GLOSSARIES].find_one({"_id": ObjectId(glossary_id)})
    if not glossary_doc:
        raise ValueError(f"Glossary {glossary_id} not found")
    client_id = glossary_doc["client_id"]

    # Parse first: fail fast before any side effect.
    file_bytes = await file.read()
    parsed_terms = parse_xlsx(file_bytes, source_locale_col)

    # Find next version number.
    # NOTE(review): read-then-insert with no unique index visible here, so two
    # concurrent uploads could pick the same number.
    last_version = await db[_COLL_VERSIONS].find_one(
        {"glossary_id": glossary_id},
        sort=[("version_number", -1)],
    )
    next_version_num = (last_version["version_number"] + 1) if last_version else 1

    version_id = str(ObjectId())
    gcs_path = f"glossaries/{client_id}/{glossary_id}/{version_id}/source.xlsx"
    await _upload_bytes_to_gcs(
        file_bytes,
        gcs_path,
        content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    )

    now = datetime.utcnow()
    version_doc = {
        "_id": ObjectId(version_id),
        "glossary_id": glossary_id,
        "version_number": next_version_num,
        "source_xlsx_gcs_path": gcs_path,
        "term_count": len(parsed_terms),
        "embedded_count": 0,
        "embedding_status": EmbeddingStatus.PENDING.value,
        "created_at": now,
        "created_by": user_id,
        "change_note": change_note,
    }
    await db[_COLL_VERSIONS].insert_one(version_doc)

    if parsed_terms:
        term_docs = [
            {
                "_id": ObjectId(),
                "glossary_id": glossary_id,
                "version_id": version_id,
                "cid": t.cid,
                "tid": t.tid,
                "source_term": t.source_term,
                "source_term_lower": t.source_term.lower(),
                "translations": t.translations,
                "embedding": None,
            }
            for t in parsed_terms
        ]
        await db[_COLL_TERMS].insert_many(term_docs, ordered=False)

    # Best-effort: a queueing failure is logged, not raised.
    try:
        from ..tasks.embed_glossary import embed_glossary_version_task
        embed_glossary_version_task.delay(version_id)
        logger.info(f"Queued embedding task for version {version_id}")
    except Exception as e:
        logger.warning(f"Could not queue embedding task: {e}")

    return glossary_version_from_doc(version_doc)
|
|
|
|
|
|
async def activate_version(glossary_id: str, version_id: str) -> None:
    """Set the active version of a glossary and clear the client's cached version id.

    The version must exist and belong to ``glossary_id``. Without this check the
    glossary's pointer could reference another glossary's terms, or none at all.
    The pointer swap itself is a single ``update_one``.

    Raises:
        ValueError: if ``version_id`` is malformed, unknown, or owned by a
            different glossary, or if the glossary does not exist.
    """
    db = await get_database()

    # Validate ownership before moving the pointer.
    if not ObjectId.is_valid(version_id):
        raise ValueError(f"Version {version_id} not found for glossary {glossary_id}")
    owned = await db[_COLL_VERSIONS].find_one(
        {"_id": ObjectId(version_id), "glossary_id": glossary_id},
        {"_id": 1},
    )
    if not owned:
        raise ValueError(f"Version {version_id} not found for glossary {glossary_id}")

    result = await db[_COLL_GLOSSARIES].update_one(
        {"_id": ObjectId(glossary_id)},
        {"$set": {"current_version_id": version_id}},
    )
    if result.matched_count == 0:
        raise ValueError(f"Glossary {glossary_id} not found")

    # Invalidate Redis cache
    await _invalidate_cache(glossary_id)
|
|
|
|
|
|
async def archive_glossary(glossary_id: str) -> None:
    """Hard-delete the glossary and all its versions and terms.

    Despite the name, documents are removed rather than marked
    ``GlossaryStatus.ARCHIVED``. The client's cached active-version id is
    cleared after the deletes. The client id is captured up front because
    ``_invalidate_cache`` can no longer resolve it once the glossary doc is gone.
    """
    db = await get_database()

    # Capture the owning client while the glossary doc still exists.
    glossary_doc = await db[_COLL_GLOSSARIES].find_one(
        {"_id": ObjectId(glossary_id)}, {"client_id": 1}
    )

    versions = await db[_COLL_VERSIONS].find(
        {"glossary_id": glossary_id}, {"_id": 1}
    ).to_list(length=None)
    version_ids = [str(v["_id"]) for v in versions]

    # Terms carry both glossary_id and version_id. Matching either also removes
    # terms whose version doc is already missing.
    terms_result = await db[_COLL_TERMS].delete_many(
        {"$or": [{"glossary_id": glossary_id}, {"version_id": {"$in": version_ids}}]}
    )
    logger.info(f"Deleted {terms_result.deleted_count} terms for glossary {glossary_id}")

    await db[_COLL_VERSIONS].delete_many({"glossary_id": glossary_id})
    logger.info(f"Deleted {len(version_ids)} versions for glossary {glossary_id}")

    await db[_COLL_GLOSSARIES].delete_one({"_id": ObjectId(glossary_id)})

    # Clear the cache last, so a lookup racing with the deletes cannot re-cache
    # the removed version. The key format must match _get_active_version_id.
    if glossary_doc:
        try:
            from ..core.redis import redis_client
            await redis_client.delete(f"glossary:active_version:{glossary_doc['client_id']}")
        except Exception as e:
            logger.debug(f"Cache invalidation skipped: {e}")

    logger.info(f"Deleted glossary {glossary_id}")
|
|
|
|
|
|
# ── Retrieval ─────────────────────────────────────────────────────────────────
|
|
|
|
async def match_terms_for_text(
    client_id: str,
    text: str,
    target_locale: str,
    top_k: int = _MAX_TERMS_IN_PROMPT,
) -> list[MatchedTerm]:
    """
    Hybrid retrieval: exact-match (Aho-Corasick) + semantic (Atlas Vector Search).

    Returns a ranked, deduplicated list of up to `top_k` MatchedTerm objects,
    each with the source term and its translation in `target_locale`.
    Exact matches rank before vector matches. Vector-search failures are
    logged as warnings and treated as "no vector matches".

    Returns [] when ``text`` is blank, when ``top_k`` is not positive, or when
    the client has no active glossary.
    """
    # Guards: a negative top_k would make combined[:top_k] drop items from the
    # end instead of returning nothing. Blank text cannot exact-match a term and
    # would still cost an embedding call in the vector pass.
    if top_k <= 0 or not text or not text.strip():
        return []

    db = await get_database()
    norm_target = locale_lib.normalize_code(target_locale)

    active_version_id = await _get_active_version_id(client_id)
    if not active_version_id:
        return []

    # ── Exact pass ──
    exact_matches = await _exact_match(db, active_version_id, text, norm_target)

    # ── Vector pass (only if exact matches left room) ──
    remaining = top_k - len(exact_matches)
    already_found = {m.source_term.lower() for m in exact_matches}
    vector_matches: list[MatchedTerm] = []
    if remaining > 0:
        try:
            vector_matches = await _vector_match(
                db, active_version_id, text, norm_target,
                top_k=_VECTOR_TOP_K, exclude_terms=already_found,
            )
        except Exception as e:
            logger.warning(f"Vector search failed (non-fatal): {e}")

    combined = exact_matches + vector_matches
    if len(combined) > top_k:
        logger.info(f"glossary_terms_truncated: had {len(combined)}, capped at {top_k}")
        combined = combined[:top_k]

    return combined
|
|
|
|
|
|
async def _get_active_version_id(client_id: str) -> str | None:
    """Return the active version_id for the active glossary of a client, or None.

    Resolution: the Redis key ``glossary:active_version:{client_id}`` if set.
    Otherwise, the ``current_version_id`` of the most recently created ACTIVE
    glossary for the client, which is then cached for 3600 s. A None result is
    not cached. Redis is best-effort: any Redis error is logged at debug level
    and treated as a cache miss or a skipped write.
    """
    cache_key = f"glossary:active_version:{client_id}"

    try:
        from ..core.redis import redis_client  # lazy import
        cached = await redis_client.get(cache_key)
        if cached:
            return cached.decode() if isinstance(cached, bytes) else cached
    except Exception as e:
        logger.debug(f"Active-version cache read skipped: {e}")

    db = await get_database()
    glossary_doc = await db[_COLL_GLOSSARIES].find_one(
        {"client_id": client_id, "status": GlossaryStatus.ACTIVE.value},
        sort=[("created_at", -1)],
    )
    if not glossary_doc or not glossary_doc.get("current_version_id"):
        return None

    version_id = glossary_doc["current_version_id"]

    try:
        from ..core.redis import redis_client
        await redis_client.setex(cache_key, 3600, version_id)
    except Exception as e:
        logger.debug(f"Active-version cache write skipped: {e}")

    return version_id
|
|
|
|
|
|
async def _invalidate_cache(glossary_id: str) -> None:
    """Best-effort removal of the cached active-version id for a glossary's client.

    The glossary is looked up to find its ``client_id``; if it no longer exists,
    nothing is cleared. Any error (bad id, Mongo or Redis failure) is logged at
    debug level and swallowed.
    """
    try:
        database = await get_database()
        glossary = await database[_COLL_GLOSSARIES].find_one({"_id": ObjectId(glossary_id)})
        if not glossary:
            return

        from ..core.redis import redis_client

        owner = glossary["client_id"]
        await redis_client.delete(f"glossary:active_version:{owner}")
    except Exception as e:
        logger.debug(f"Cache invalidation skipped: {e}")
|
|
|
|
|
|
async def _exact_match(
    db,
    version_id: str,
    text: str,
    target_locale: str,
) -> list[MatchedTerm]:
    """Find glossary terms that occur case-insensitively in ``text``.

    Builds an Aho-Corasick automaton over every term of ``version_id`` and scans
    ``text.lower()`` once. A hit counts only when:
      - it sits on word boundaries (the neighbouring characters are not alphanumeric), and
      - the term has a translation for ``target_locale`` (using the fallback of ``_get_translation``).
    Each source term is returned at most once, in order of first accepted
    occurrence, with ``match_kind="exact"`` and ``score=1.0``.
    """
    import ahocorasick  # pyahocorasick

    if not text:
        return []

    # Load all terms for this version (source_term_lower + translations)
    terms = await db[_COLL_TERMS].find(
        {"version_id": version_id},
        {"source_term": 1, "source_term_lower": 1, "translations": 1},
    ).to_list(length=None)
    if not terms:
        return []

    automaton = ahocorasick.Automaton()
    for doc in terms:
        source_term = doc.get("source_term")
        key = (doc.get("source_term_lower") or source_term or "").lower()
        if not key or not source_term:
            continue
        # Store the key's own length: the boundary check must use the length of
        # the string actually matched, which need not equal len(source_term.lower())
        # if source_term_lower was stored differently.
        automaton.add_word(key, (source_term, doc.get("translations") or {}, len(key)))
    if len(automaton) == 0:
        return []
    automaton.make_automaton()

    text_lower = text.lower()
    matched: list[MatchedTerm] = []
    seen: set[str] = set()

    for end_idx, (source_term, translations, key_len) in automaton.iter(text_lower):
        if source_term in seen:
            continue
        # Require word/phrase boundaries around the match
        start_idx = end_idx - key_len + 1
        if start_idx > 0 and text_lower[start_idx - 1].isalnum():
            continue
        after = end_idx + 1
        if after < len(text_lower) and text_lower[after].isalnum():
            continue

        target_text = _get_translation(translations, target_locale)
        if not target_text:
            continue
        seen.add(source_term)
        matched.append(MatchedTerm(
            source_term=source_term,
            target_translation=target_text,
            match_kind="exact",
            score=1.0,
        ))

    return matched
|
|
|
|
|
|
async def _vector_match(
    db,
    version_id: str,
    text: str,
    target_locale: str,
    top_k: int = 20,
    exclude_terms: set[str] | None = None,
) -> list[MatchedTerm]:
    """Semantic search via Atlas Vector Search ($vectorSearch).

    Embeds at most the first 2000 characters of ``text`` and retrieves up to
    ``top_k`` nearest terms of ``version_id``. A hit is dropped when:
      - its score is below ``_VECTOR_SIMILARITY_THRESHOLD``;
      - its lower-cased source term is in ``exclude_terms`` or was already returned;
      - it has no translation for ``target_locale``.

    Exceptions from the embedding service or the aggregation propagate.
    """
    from ..services.embedding_service import embedding_service

    query_embedding = await embedding_service.embed_text(text[:2000])  # cap input length

    pipeline = [
        {
            "$vectorSearch": {
                "index": _VECTOR_INDEX,
                "path": "embedding",
                "queryVector": query_embedding,
                "numCandidates": top_k * 4,
                "limit": top_k,
                "filter": {"version_id": version_id},
            }
        },
        {
            "$project": {
                "source_term": 1,
                "translations": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]

    results = await db[_COLL_TERMS].aggregate(pipeline).to_list(length=top_k)

    excluded = exclude_terms or set()
    seen: set[str] = set()  # dedupe repeated source terms, like _exact_match does
    matched: list[MatchedTerm] = []
    for doc in results:
        score = doc.get("score", 0.0)
        if score < _VECTOR_SIMILARITY_THRESHOLD:
            continue
        source_term = doc.get("source_term")
        if not source_term:
            continue
        key = source_term.lower()
        if key in excluded or key in seen:
            continue
        target_text = _get_translation(doc.get("translations") or {}, target_locale)
        if not target_text:
            continue
        seen.add(key)
        matched.append(MatchedTerm(
            source_term=source_term,
            target_translation=target_text,
            match_kind="vector",
            score=score,
        ))

    return matched
|
|
|
|
|
|
def _get_translation(translations: dict[str, str], target_locale: str) -> str | None:
|
|
"""Look up a translation with locale-fallback.
|
|
|
|
Specific → bare: fr-CA → fr-FR siblings → fr
|
|
Bare → specific: fr → fr-FR, fr-CA (first match)
|
|
"""
|
|
if not translations or not target_locale:
|
|
return None
|
|
if target_locale in translations:
|
|
return translations[target_locale]
|
|
if "-" in target_locale:
|
|
# Specific locale: try sibling regions and bare parent (fr-CA → fr-FR → fr)
|
|
parent = target_locale.split("-")[0]
|
|
for code, text in translations.items():
|
|
if code.startswith(parent + "-") or code == parent:
|
|
return text
|
|
else:
|
|
# Bare code (fr): try any fr-* region variant stored in the glossary
|
|
for code, text in translations.items():
|
|
if code == target_locale or code.startswith(target_locale + "-"):
|
|
return text
|
|
return None
|
|
|
|
|
|
# ── Prompt block ──────────────────────────────────────────────────────────────
|
|
|
|
def build_glossary_prompt_block(
    matched_terms: Sequence[MatchedTerm],
    target_locale: str,
) -> str:
    """
    Format matched terms for injection into a Gemini prompt.

    Output: a markdown heading naming the target language (via
    ``locale_lib.get_gemini_label``), one instruction line, then one bullet per
    term of the form ``- "<source>" → "<translation>"``. Terms containing line
    breaks (e.g. multi-line spreadsheet cells) are joined onto one line so each
    term stays on its own bullet.

    Returns an empty string if no terms were matched.
    """
    if not matched_terms:
        return ""

    def _one_line(value: object) -> str:
        # Glossary cells are user-supplied; an embedded newline would otherwise
        # start a new, unbulleted line inside the prompt block.
        text = str(value)
        if not any(ch in text for ch in "\r\n\u2028\u2029"):
            return text
        return " ".join(part.strip() for part in text.splitlines() if part.strip())

    target_label = locale_lib.get_gemini_label(target_locale)
    lines = [
        f"## Approved {target_label} terminology",
        "Use these exact translations when the source terms appear — do not deviate:",
    ]
    lines.extend(
        f'- "{_one_line(term.source_term)}" → "{_one_line(term.target_translation)}"'
        for term in matched_terms
    )
    return "\n".join(lines)
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
async def _upload_bytes_to_gcs(data: bytes, gcs_path: str, content_type: str) -> None:
    """Upload ``data`` to ``gcs_path`` in ``settings.gcs_bucket``.

    The google-cloud-storage calls are synchronous, so they run in the event
    loop's default executor to keep the loop responsive. Exceptions raised by
    the upload propagate to the awaiting caller.
    """
    import asyncio

    def _upload() -> None:
        from google.cloud import storage as gcs_storage

        from ..core.config import settings

        client = gcs_storage.Client(project=settings.gcp_project_id)
        blob = client.bucket(settings.gcs_bucket).blob(gcs_path)
        blob.content_type = content_type
        blob.upload_from_string(data, content_type=content_type)

    # Inside a coroutine, get_running_loop() is the documented way to reach the
    # current loop; get_event_loop() is the legacy API for this purpose.
    await asyncio.get_running_loop().run_in_executor(None, _upload)
|
|
|
|
|
|
async def _ensure_indexes(db) -> None:
    """Create the secondary indexes this service queries on (best-effort).

    Each index is attempted on its own, so one failure (e.g. a conflicting
    existing index) no longer prevents the remaining indexes from being
    created. Failures are logged at debug level and never raised.
    """
    index_specs = (
        (_COLL_GLOSSARIES, [("client_id", 1), ("status", 1)]),
        (_COLL_VERSIONS, [("glossary_id", 1), ("version_number", -1)]),
        (_COLL_TERMS, [("version_id", 1), ("source_term_lower", 1)]),
        (_COLL_TERMS, [("glossary_id", 1)]),
    )
    for collection, keys in index_specs:
        try:
            await db[collection].create_index(keys)
        except Exception as e:
            logger.debug(f"Index creation skipped (likely already exist): {e}")
|
|
|
|
|
|
# ── Task helpers ─────────────────────────────────────────────────────────────
|
|
|
|
async def get_glossary_block_for_job(
    job_doc: dict,
    target_locale: str,
    db,
) -> str:
    """
    Build the glossary prompt block for a Celery job, or "" when none applies.

    Resolution chain:
        job_doc["project_id"] → db.projects → project["client_id"]
        → _get_active_version_id(client_id)

    The text matched against comes from ``job_doc["_glossary_source_text"]``
    (presumably filled in by the calling task — confirm). Exact matches come
    first, then vector matches fill any remaining slots. The combined list is
    capped at _MAX_TERMS_IN_PROMPT. Vector-search errors are logged at debug
    level and ignored.

    Non-fatal by design: any exception is logged as a warning and "" is returned.
    """
    try:
        log_id = job_doc.get("_id", "unknown")

        project_id = job_doc.get("project_id")
        if not project_id:
            logger.debug(f"Glossary skip job={log_id}: no project_id")
            return ""

        project = await db.projects.find_one({"_id": project_id})
        if not project:
            logger.warning(f"Glossary skip job={log_id}: project {project_id!r} not found")
            return ""

        client_id = project.get("client_id")
        if not client_id:
            logger.debug(f"Glossary skip job={log_id}: project has no client_id")
            return ""

        # Cache-backed lookup (Redis first, Mongo on miss).
        version_id = await _get_active_version_id(client_id)
        if not version_id:
            logger.debug(f"Glossary skip job={log_id}: no active glossary for client {client_id!r}")
            return ""

        text = job_doc.get("_glossary_source_text", "")
        if not text:
            logger.debug(f"Glossary skip job={log_id}: no source text provided for matching")
            return ""

        logger.info(
            f"Glossary lookup job={log_id} client={client_id!r} "
            f"version={version_id!r} locale={target_locale!r}"
        )
        locale_code = locale_lib.normalize_code(target_locale)

        found = await _exact_match(db, version_id, text, locale_code)
        if len(found) < _MAX_TERMS_IN_PROMPT:
            try:
                found = found + await _vector_match(
                    db, version_id, text, locale_code,
                    top_k=_VECTOR_TOP_K,
                    exclude_terms={m.source_term.lower() for m in found},
                )
            except Exception as ve:
                logger.debug(f"Vector search skipped in task context: {ve}")

        if len(found) > _MAX_TERMS_IN_PROMPT:
            logger.info(f"glossary_terms_truncated: capped at {_MAX_TERMS_IN_PROMPT}")
            found = found[:_MAX_TERMS_IN_PROMPT]

        return build_glossary_prompt_block(found, target_locale)

    except Exception as e:
        import traceback

        logger.warning(
            f"Glossary lookup failed for job {job_doc.get('_id')} (non-fatal): "
            f"{e}\n{traceback.format_exc()}"
        )
        return ""
|
|
|
|
|
|
# ── Listing helpers ───────────────────────────────────────────────────────────
|
|
|
|
async def get_glossaries_for_client(client_id: str) -> list[Glossary]:
    """List a client's non-archived glossaries, newest first (at most 100)."""
    db = await get_database()
    query = {"client_id": client_id, "status": {"$ne": GlossaryStatus.ARCHIVED.value}}
    newest_first = [("created_at", -1)]
    docs = await db[_COLL_GLOSSARIES].find(query, sort=newest_first).to_list(length=100)
    return list(map(glossary_from_doc, docs))
|
|
|
|
|
|
async def get_glossary(glossary_id: str) -> Glossary | None:
    """Fetch one glossary by id.

    Returns None when no glossary has that id. A string that is not a valid
    ObjectId also returns None, since it cannot match any document, rather
    than raising bson's InvalidId.
    """
    if not ObjectId.is_valid(glossary_id):
        return None
    db = await get_database()
    doc = await db[_COLL_GLOSSARIES].find_one({"_id": ObjectId(glossary_id)})
    return glossary_from_doc(doc) if doc else None
|
|
|
|
|
|
async def get_versions_by_ids(version_ids: list[str]) -> dict[str, GlossaryVersion]:
    """Batch-fetch versions by ID in one query; returns {version_id: GlossaryVersion}.

    Falsy, duplicate, and non-ObjectId entries are ignored. This keeps one bad
    ``current_version_id`` from failing the whole batch (e.g. a list endpoint).
    Such ids, and ids with no matching document, are simply absent from the result.
    """
    # dict.fromkeys dedupes while preserving order; invalid ids would make
    # ObjectId() raise, and ObjectId(None) would mint a fresh, meaningless id.
    object_ids = [
        ObjectId(vid)
        for vid in dict.fromkeys(version_ids or [])
        if vid and ObjectId.is_valid(vid)
    ]
    if not object_ids:
        return {}

    db = await get_database()
    docs = await db[_COLL_VERSIONS].find(
        {"_id": {"$in": object_ids}}
    ).to_list(length=len(object_ids))
    return {str(d["_id"]): glossary_version_from_doc(d) for d in docs}
|
|
|
|
|
|
async def get_versions(glossary_id: str) -> list[GlossaryVersion]:
    """Return up to 50 versions of a glossary, highest version_number first."""
    db = await get_database()
    cursor = db[_COLL_VERSIONS].find(
        {"glossary_id": glossary_id},
        sort=[("version_number", -1)],
    )
    return [glossary_version_from_doc(doc) for doc in await cursor.to_list(length=50)]
|
|
|
|
|
|
async def get_terms_page(
    version_id: str,
    search: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[dict], int]:
    """Returns (terms, total_count) for paginated UI preview.

    Args:
        version_id: Version whose terms are listed.
        search: Optional case-insensitive substring filter on the source term
            (regex metacharacters in it are escaped).
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Terms per page; values below 1 are treated as 1.

    Returns:
        ``(terms, total)``. Each term is
        ``{"_id": str, "source_term": str, "translations": dict}``, ordered by
        ``source_term_lower``; ``total`` counts every matching term.
    """
    # Clamp: page < 1 would produce a negative skip (rejected by MongoDB), and
    # page_size <= 0 would pass limit 0, which MongoDB treats as "no limit".
    page = max(1, page)
    page_size = max(1, page_size)

    db = await get_database()
    query: dict = {"version_id": version_id}
    if search:
        query["source_term_lower"] = {"$regex": re.escape(search.lower())}

    total = await db[_COLL_TERMS].count_documents(query)
    cursor = db[_COLL_TERMS].find(
        query,
        {"_id": 1, "source_term": 1, "translations": 1},
        skip=(page - 1) * page_size,
        limit=page_size,
        sort=[("source_term_lower", 1)],
    )
    docs = await cursor.to_list(length=page_size)

    # Only source_term + translations are projected, so build minimal dicts
    # rather than validating against GlossaryTerm (which requires more fields).
    terms = [
        {
            "_id": str(d["_id"]),
            "source_term": d.get("source_term", ""),
            "translations": d.get("translations", {}),
        }
        for d in docs
    ]
    return terms, total
|