112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
"""
|
|
Celery task: compute and store Gemini embeddings for all terms in a glossary version.
|
|
|
|
Runs as a background job after glossary ingestion so the API response is fast.
|
|
Processes terms in concurrent batches of 250 (5 batches in parallel).
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any
|
|
|
|
from bson import ObjectId
|
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
|
|
|
from ..core.config import settings
|
|
from ..core.logging import get_logger
|
|
from ..models.glossary import EmbeddingStatus
|
|
from . import celery_app
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
_BATCH_SIZE = 250
|
|
_CONCURRENCY = 5
|
|
|
|
|
|
@celery_app.task(name="embed_glossary_version", bind=True, max_retries=3)
def embed_glossary_version_task(self, version_id: str) -> dict:
    """Celery entry point that embeds the terms of one glossary version.

    Drives the async implementation to completion on a fresh event loop and
    hands its result back to Celery.

    Args:
        version_id: Identifier of the glossary version, as a string.

    Returns:
        Whatever ``_async_embed_version`` returns for this version.

    Raises:
        Whatever ``self.retry`` produces: any exception from the embedding
        run is logged and turned into a retry request with a 60-second
        countdown (the task is registered with ``max_retries=3``).
    """
    try:
        # asyncio.run creates and tears down its own event loop per call.
        return asyncio.run(_async_embed_version(version_id))
    except Exception as exc:
        logger.error(f"embed_glossary_version_task failed for {version_id}: {exc}")
        # ``from None`` keeps the original traceback out of the retry's
        # exception chain; the error itself is still passed via ``exc=``.
        raise self.retry(exc=exc, countdown=60) from None
|
|
|
|
|
|
async def _embed_batch(
    db: AsyncIOMotorDatabase,
    version_id: str,
    batch: list[dict[str, Any]],
    sem: asyncio.Semaphore,
    counter: list[int],
    total: int,
) -> None:
    """Embed one batch of terms and persist the vectors and progress.

    Args:
        db: Database handle; ``glossary_terms`` and ``glossary_versions``
            collections are written.
        version_id: String id of the glossary version being processed.
        batch: Term documents; each must carry ``_id`` and ``source_term``.
        sem: Semaphore limiting how many batches run at the same time.
        counter: One-element list used as a counter shared by all batches;
            ``counter[0]`` is increased by ``len(batch)`` after the write.
        total: Number of terms in the whole run (used for the log line only).

    Raises:
        ValueError: If the embedding service returns a different number of
            vectors than texts submitted. Nothing is written in that case.
    """
    # Imported at call time, as in the original; presumably to keep module
    # import light / avoid an import cycle -- TODO confirm.
    from pymongo import UpdateOne

    from ..services.embedding_service import embedding_service

    async with sem:
        texts = [t["source_term"] for t in batch]
        ids = [t["_id"] for t in batch]
        embeddings = await embedding_service.embed_texts(texts)

        # strict=True: ids and vectors are paired purely by position, so a
        # short (or long) response must fail loudly. Silently truncating
        # would leave terms without an embedding while the progress counter
        # below still counts the whole batch as done.
        ops = [
            UpdateOne({"_id": tid}, {"$set": {"embedding": emb}})
            for tid, emb in zip(ids, embeddings, strict=True)
        ]
        if ops:
            # Updates are independent of one another, so ordering is not needed.
            await db.glossary_terms.bulk_write(ops, ordered=False)

        counter[0] += len(batch)
        # NOTE(review): concurrent batches each $set the running total, so
        # these progress writes may land out of order; the final update in
        # the caller overwrites the value once all batches finish.
        await db.glossary_versions.update_one(
            {"_id": ObjectId(version_id)},
            {"$set": {"embedded_count": counter[0]}},
        )
        logger.info(f"Version {version_id}: embedded {counter[0]}/{total}")
|
|
|
|
|
|
async def _async_embed_version(version_id: str) -> dict:
    """Embed every term of a glossary version that has no embedding yet.

    Marks the version as in progress, loads the terms whose ``embedding`` is
    null or missing, embeds them in batches of ``_BATCH_SIZE`` with at most
    ``_CONCURRENCY`` batches running at once, then marks the version as done.

    Args:
        version_id: String form of the glossary version's ObjectId.

    Returns:
        ``{"version_id": version_id, "total": total}`` where ``total`` is the
        number of terms processed by this run.

    Raises:
        Exception: Any error from the database or from a batch is re-raised
            after a best-effort attempt to mark the version as failed. An
            invalid ``version_id`` raises from ``ObjectId`` before any
            connection is opened.
    """
    # Parse the id once, before opening a connection: a malformed id fails
    # fast, and the failure handler below never has to re-parse it.
    version_filter = {"_id": ObjectId(version_id)}

    mongo_client = AsyncIOMotorClient(settings.mongodb_uri)
    db = mongo_client[settings.mongodb_db]

    try:
        await db.glossary_versions.update_one(
            version_filter,
            {"$set": {"embedding_status": EmbeddingStatus.IN_PROGRESS.value}},
        )

        # Only terms still lacking an embedding are fetched, so a retry picks
        # up where a previous partial run stopped.
        cursor = db.glossary_terms.find(
            {"version_id": version_id, "embedding": None},
            {"_id": 1, "source_term": 1},
        )
        terms = await cursor.to_list(length=None)
        total = len(terms)
        logger.info(f"Embedding {total} terms for version {version_id} (batch={_BATCH_SIZE}, concurrency={_CONCURRENCY})")

        batches = [terms[i: i + _BATCH_SIZE] for i in range(0, total, _BATCH_SIZE)]
        sem = asyncio.Semaphore(_CONCURRENCY)
        counter = [0]  # one-element list: mutable counter shared by all batches

        tasks = [
            asyncio.create_task(
                _embed_batch(db, version_id, batch, sem, counter, total)
            )
            for batch in batches
        ]
        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # gather() propagates the first error but leaves the sibling
            # batches running. Stop them and wait for them to unwind so none
            # is still using the client when it is closed below.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

        # NOTE(review): ``total`` counts only the terms that lacked an
        # embedding when this run started, so after a retry or re-run the
        # stored embedded_count reflects this run, not the whole version --
        # confirm that is the intended meaning.
        await db.glossary_versions.update_one(
            version_filter,
            {"$set": {
                "embedding_status": EmbeddingStatus.DONE.value,
                "embedded_count": total,
            }},
        )
        logger.info(f"Embedding complete for version {version_id}: {total} terms")
        return {"version_id": version_id, "total": total}

    except Exception:
        # Best effort: if the status write itself fails (e.g. the database is
        # the thing that is down), log it and still surface the original
        # error rather than replacing it with the secondary one.
        try:
            await db.glossary_versions.update_one(
                version_filter,
                {"$set": {"embedding_status": EmbeddingStatus.FAILED.value}},
            )
        except Exception as status_exc:
            logger.error(
                f"Could not mark version {version_id} as failed: {status_exc}"
            )
        raise
    finally:
        mongo_client.close()
|