video-accessibility/backend/app/tasks/embed_glossary.py
Vadym Samoilenko 31199f8705 chore: push all session changes — backend hardening, tests, apache config, deploy scripts
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-30 15:52:14 +01:00

112 lines
3.6 KiB
Python

"""
Celery task: compute and store Gemini embeddings for all terms in a glossary version.
Runs as a background job after glossary ingestion so the API response is fast.
Processes terms in concurrent batches of 250 (5 batches in parallel).
"""
from __future__ import annotations
import asyncio
from typing import Any
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from ..core.config import settings
from ..core.logging import get_logger
from ..models.glossary import EmbeddingStatus
from . import celery_app
# Terms sent to the embedding service in a single embed_texts call.
_BATCH_SIZE: int = 250
# Upper bound on batches processed at the same time (semaphore size).
_CONCURRENCY: int = 5

logger = get_logger(__name__)

@celery_app.task(name="embed_glossary_version", bind=True, max_retries=3)
def embed_glossary_version_task(self, version_id: str) -> dict:
    """Embed the not-yet-embedded terms of a glossary version (Celery entry point).

    Runs ``_async_embed_version`` on a fresh event loop and returns its
    summary dict (``{"version_id": ..., "total": ...}``).

    Any failure is logged with its traceback and retried (at most
    ``max_retries`` times, 60 seconds apart). The one exception is a
    malformed ``version_id``: it is re-raised immediately, because retrying
    can never turn it into a valid ObjectId.

    Args:
        version_id: Hex string of the ``glossary_versions`` document ``_id``.

    Raises:
        bson.errors.InvalidId: If ``version_id`` is not a valid ObjectId.
        Whatever ``self.retry`` raises to schedule the retry (or the original
            exception once retries are exhausted).
    """
    from bson.errors import InvalidId

    try:
        return asyncio.run(_async_embed_version(version_id))
    except InvalidId:
        # Permanent input error: fail fast instead of burning retries.
        logger.error(f"embed_glossary_version_task: invalid version id {version_id!r}")
        raise
    except Exception as exc:
        logger.error(
            f"embed_glossary_version_task failed for {version_id}: {exc}",
            exc_info=True,
        )
        raise self.retry(exc=exc, countdown=60) from None
async def _embed_batch(
    db: AsyncIOMotorDatabase,
    version_id: str,
    batch: list[dict[str, Any]],
    sem: asyncio.Semaphore,
    counter: list[int],
    total: int,
) -> None:
    """Embed one batch of glossary terms and store the vectors.

    Waits on ``sem`` so only a bounded number of batches run at once. It
    embeds the ``source_term`` of every term in ``batch`` with a single
    ``embedding_service.embed_texts`` call and writes each vector to its
    ``glossary_terms`` document. It then advances the shared progress
    counter and mirrors it into the version's ``embedded_count``.

    Args:
        db: Database holding ``glossary_terms`` and ``glossary_versions``.
        version_id: Hex string ``_id`` of the glossary version being processed.
        batch: Term documents; each must carry ``_id`` and ``source_term``.
        sem: Semaphore shared by all batches of the run; bounds concurrency.
        counter: One-element list shared by all batches of the run, used as a
            mutable running count of embedded terms.
        total: Expected final counter value; only used in the progress log.

    Raises:
        ValueError: If the embedding service returns a different number of
            vectors than it was given texts. Vectors are paired with terms
            by position, so a mismatch could attach vectors to the wrong
            terms or skip terms while progress still counted them.
    """
    from pymongo import UpdateOne

    from ..services.embedding_service import embedding_service

    async with sem:
        texts = [t["source_term"] for t in batch]
        ids = [t["_id"] for t in batch]
        # list() so the length check below works for any iterable result.
        embeddings = list(await embedding_service.embed_texts(texts))
        if len(embeddings) != len(texts):
            raise ValueError(
                f"Version {version_id}: embedding service returned "
                f"{len(embeddings)} vectors for {len(texts)} terms"
            )
        ops = [
            UpdateOne({"_id": tid}, {"$set": {"embedding": emb}})
            for tid, emb in zip(ids, embeddings, strict=True)
        ]
        if ops:
            await db.glossary_terms.bulk_write(ops, ordered=False)
        counter[0] += len(ops)
        # Batches run concurrently, so their progress writes can land out of
        # order; $max keeps a late, smaller snapshot from moving the stored
        # count backwards.
        await db.glossary_versions.update_one(
            {"_id": ObjectId(version_id)},
            {"$max": {"embedded_count": counter[0]}},
        )
        logger.info(f"Version {version_id}: embedded {counter[0]}/{total}")
async def _count_embedded(db: AsyncIOMotorDatabase, version_id: str) -> int:
    """Return how many terms of ``version_id`` have a non-null ``embedding``."""
    return await db.glossary_terms.count_documents(
        {"version_id": version_id, "embedding": {"$ne": None}}
    )


async def _gather_or_cancel(coros: list[Any]) -> None:
    """Run ``coros`` concurrently; on the first failure cancel the rest and re-raise."""
    tasks = [asyncio.ensure_future(c) for c in coros]
    try:
        await asyncio.gather(*tasks)
    except BaseException:
        # Plain gather() would leave the sibling batches running (and still
        # calling the embedding API / Mongo) after the first error.
        for task in tasks:
            task.cancel()
        # Let the cancellations settle so no task is left pending behind us.
        await asyncio.gather(*tasks, return_exceptions=True)
        raise


async def _mark_failed(
    db: AsyncIOMotorDatabase, version_oid: ObjectId, version_id: str
) -> None:
    """Best-effort: set the version's status to FAILED, logging (not raising) on error.

    Swallowing here keeps a broken status write from replacing the original
    exception that the caller is about to re-raise.
    """
    try:
        await db.glossary_versions.update_one(
            {"_id": version_oid},
            {"$set": {"embedding_status": EmbeddingStatus.FAILED.value}},
        )
    except Exception as status_exc:
        logger.error(f"Could not mark version {version_id} as FAILED: {status_exc}")


async def _async_embed_version(version_id: str) -> dict:
    """Embed every term of ``version_id`` that does not have an embedding yet.

    Steps:
    1. Mark the version IN_PROGRESS.
    2. Load ``_id``/``source_term`` of all terms whose ``embedding`` is null
       or missing.
    3. Embed them in batches of ``_BATCH_SIZE``, with at most
       ``_CONCURRENCY`` batches in flight.
    4. Mark the version DONE, with ``embedded_count`` set to the number of
       its terms that now carry an embedding.

    On any failure the version is marked FAILED and the exception is
    re-raised so the Celery task can retry it.

    Only un-embedded terms are loaded, so a retried or re-triggered run
    resumes where the previous one stopped. Progress therefore starts from
    the number of terms already embedded, not from zero.

    Args:
        version_id: Hex string ``_id`` of the ``glossary_versions`` document.

    Returns:
        ``{"version_id": version_id, "total": n}``, where ``n`` is the number
        of terms this run submitted for embedding.

    Raises:
        bson.errors.InvalidId: If ``version_id`` is not a valid ObjectId
            (raised before any database access).
    """
    # Parse once, up front. An invalid id is then reported directly instead
    # of being raised a second time from the FAILED-status handler.
    version_oid = ObjectId(version_id)
    mongo_client = AsyncIOMotorClient(settings.mongodb_uri)
    db = mongo_client[settings.mongodb_db]
    try:
        already_embedded = await _count_embedded(db, version_id)
        await db.glossary_versions.update_one(
            {"_id": version_oid},
            {"$set": {
                "embedding_status": EmbeddingStatus.IN_PROGRESS.value,
                "embedded_count": already_embedded,
            }},
        )
        cursor = db.glossary_terms.find(
            {"version_id": version_id, "embedding": None},
            {"_id": 1, "source_term": 1},
        )
        terms = await cursor.to_list(length=None)
        total = len(terms)
        logger.info(f"Embedding {total} terms for version {version_id} (batch={_BATCH_SIZE}, concurrency={_CONCURRENCY})")
        batches = [terms[i: i + _BATCH_SIZE] for i in range(0, total, _BATCH_SIZE)]
        sem = asyncio.Semaphore(_CONCURRENCY)
        # Progress continues from the terms embedded by earlier runs.
        counter = [already_embedded]
        await _gather_or_cancel([
            _embed_batch(db, version_id, batch, sem, counter, already_embedded + total)
            for batch in batches
        ])
        # Re-count rather than trust `total`: it is the authoritative number
        # of embedded terms, whatever happened in this or earlier runs.
        embedded_count = await _count_embedded(db, version_id)
        await db.glossary_versions.update_one(
            {"_id": version_oid},
            {"$set": {
                "embedding_status": EmbeddingStatus.DONE.value,
                "embedded_count": embedded_count,
            }},
        )
        logger.info(f"Embedding complete for version {version_id}: {total} terms")
        return {"version_id": version_id, "total": total}
    except Exception:
        await _mark_failed(db, version_oid, version_id)
        raise
    finally:
        mongo_client.close()