modcomms/backend/app/repositories/knowledge_base_repository.py
Vadym Samoilenko 1982d5d76e feat(knowledge-base): smart resume for interrupted processing jobs
On server restart, stale active jobs are automatically resumed rather
than failed. Docs already parsed in a prior run are skipped (resumed from
cache); docs stuck at 'parsing' are reset to 'pending' and re-parsed.

- Repository: add get_all_stale_active_jobs() and reset_stuck_parsing_docs()
- Service: skip already-parsed docs in _parse_doc(), reset stuck docs on start
- Main: recover stale jobs via asyncio.create_task() in lifespan startup

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 10:20:35 +01:00

342 lines
13 KiB
Python

import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.models import KnowledgeBase, SourceDocument, SpecVersion, ProcessingJob
class KnowledgeBaseRepository:
    """Repository for knowledge base database operations.

    Wraps an ``AsyncSession`` and exposes query/mutation helpers for knowledge
    bases, their source documents, spec versions and processing jobs.
    Mutating methods only ``flush()``; nothing in this class commits, so the
    caller owns the transaction boundary.
    """

    # Processing-job statuses that count as active (non-terminal). Shared by
    # has_active_job and both stale-job queries so the definitions can't drift.
    ACTIVE_JOB_STATUSES = ("pending", "parsing_documents", "distilling")

    def __init__(self, session: AsyncSession):
        self.session = session

    # ---- Knowledge Bases ----

    async def list_knowledge_bases(self) -> list[KnowledgeBase]:
        """List all knowledge bases ordered by display name.

        Source documents, spec versions and processing jobs are eager-loaded
        with ``selectinload`` so they can be accessed without lazy loads.
        """
        query = (
            select(KnowledgeBase)
            .options(
                selectinload(KnowledgeBase.source_documents),
                selectinload(KnowledgeBase.spec_versions),
                selectinload(KnowledgeBase.processing_jobs),
            )
            .order_by(KnowledgeBase.display_name)
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_knowledge_base(self, kb_id: uuid.UUID) -> Optional[KnowledgeBase]:
        """Get a knowledge base by ID with its relationships eager-loaded.

        Returns ``None`` if no knowledge base has that ID.
        """
        query = (
            select(KnowledgeBase)
            .options(
                selectinload(KnowledgeBase.source_documents),
                selectinload(KnowledgeBase.spec_versions),
                selectinload(KnowledgeBase.processing_jobs),
            )
            .where(KnowledgeBase.id == kb_id)
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_knowledge_base_by_key(self, agent_key: str) -> Optional[KnowledgeBase]:
        """Get a knowledge base by ``agent_key`` (relationships not eager-loaded).

        Returns ``None`` if no knowledge base has that key.
        """
        query = select(KnowledgeBase).where(KnowledgeBase.agent_key == agent_key)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    # ---- Source Documents ----

    async def add_source_document(
        self,
        knowledge_base_id: uuid.UUID,
        filename: str,
        file_storage_key: str,
        file_size_bytes: int,
        mime_type: str,
        uploaded_by_id: Optional[uuid.UUID] = None,
        uploaded_by_name: Optional[str] = None,
    ) -> SourceDocument:
        """Add a source document to a knowledge base and flush it.

        Returns the new ``SourceDocument``; the flush lets DB-generated
        values (such as its primary key) be populated.
        """
        doc = SourceDocument(
            knowledge_base_id=knowledge_base_id,
            filename=filename,
            file_storage_key=file_storage_key,
            file_size_bytes=file_size_bytes,
            mime_type=mime_type,
            uploaded_by_id=uploaded_by_id,
            uploaded_by_name=uploaded_by_name,
        )
        self.session.add(doc)
        await self.session.flush()
        return doc

    async def remove_source_document(self, doc_id: uuid.UUID) -> Optional[SourceDocument]:
        """Delete a source document by ID.

        Returns the deleted document, or ``None`` if it did not exist.
        """
        query = select(SourceDocument).where(SourceDocument.id == doc_id)
        result = await self.session.execute(query)
        doc = result.scalar_one_or_none()
        if doc:
            await self.session.delete(doc)
            await self.session.flush()
        return doc

    async def get_source_documents(self, kb_id: uuid.UUID) -> list[SourceDocument]:
        """Get all source documents for a knowledge base, newest first."""
        query = (
            select(SourceDocument)
            .where(SourceDocument.knowledge_base_id == kb_id)
            .order_by(SourceDocument.created_at.desc())
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_source_document(self, doc_id: uuid.UUID) -> Optional[SourceDocument]:
        """Get a single source document by ID, or ``None`` if it does not exist."""
        query = select(SourceDocument).where(SourceDocument.id == doc_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def update_source_document_parse_status(
        self,
        doc_id: uuid.UUID,
        status: str,
        parsed_markdown: Optional[str] = None,
        parse_error: Optional[str] = None,
    ) -> None:
        """Update the parse status (and optionally output/error) of a document.

        ``parsed_markdown`` and ``parse_error`` are only written when not
        ``None``, so passing ``None`` leaves the stored value unchanged (it
        cannot be used to clear a field). Does nothing if the document is
        missing.
        """
        query = select(SourceDocument).where(SourceDocument.id == doc_id)
        result = await self.session.execute(query)
        doc = result.scalar_one_or_none()
        if doc:
            doc.parse_status = status
            if parsed_markdown is not None:
                doc.parsed_markdown = parsed_markdown
            if parse_error is not None:
                doc.parse_error = parse_error
            await self.session.flush()

    # ---- Spec Versions ----

    async def _deactivate_active_specs(self, knowledge_base_id: uuid.UUID) -> None:
        """Clear ``is_active`` on every active spec version of a KB (no flush)."""
        query = (
            select(SpecVersion)
            .where(SpecVersion.knowledge_base_id == knowledge_base_id)
            .where(SpecVersion.is_active == True)  # noqa: E712 - SQL expression, not a Python truth test
        )
        result = await self.session.execute(query)
        for spec in result.scalars().all():
            spec.is_active = False

    async def get_active_spec_by_key(self, agent_key: str) -> Optional[SpecVersion]:
        """Get the active spec version for the knowledge base with ``agent_key``.

        Returns ``None`` when the KB does not exist or has no active version.
        If more than one version is flagged active, the one with the highest
        ``version_number`` is returned instead of raising. That state can
        arise when two ``create_spec_version`` calls race, unless a DB
        constraint prevents it.
        """
        query = (
            select(SpecVersion)
            .join(KnowledgeBase)
            .where(KnowledgeBase.agent_key == agent_key)
            .where(SpecVersion.is_active == True)  # noqa: E712 - SQL expression
            .order_by(SpecVersion.version_number.desc())
            .limit(1)
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list_spec_versions(self, kb_id: uuid.UUID) -> list[SpecVersion]:
        """List all spec versions for a knowledge base, highest version first."""
        query = (
            select(SpecVersion)
            .where(SpecVersion.knowledge_base_id == kb_id)
            .order_by(SpecVersion.version_number.desc())
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_spec_version(self, version_id: uuid.UUID) -> Optional[SpecVersion]:
        """Get a spec version by ID, or ``None`` if it does not exist."""
        query = select(SpecVersion).where(SpecVersion.id == version_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def create_spec_version(
        self,
        knowledge_base_id: uuid.UUID,
        content: str,
        source_document_ids: Optional[list] = None,
        generated_by_id: Optional[uuid.UUID] = None,
        generated_by_name: Optional[str] = None,
        processing_job_id: Optional[uuid.UUID] = None,
    ) -> SpecVersion:
        """Create a new active spec version and deactivate all prior ones.

        The version number is ``max(existing) + 1`` (starting at 1 for the
        first version of a KB). ``char_count`` is set to ``len(content)``.

        NOTE(review): the read-max-then-insert sequence is not protected by a
        lock here. Concurrent calls for the same KB could pick the same
        version number or leave two active rows unless the schema enforces
        uniqueness.
        """
        # COALESCE(MAX(...), 0) so the first version for a KB becomes 1.
        max_query = (
            select(func.coalesce(func.max(SpecVersion.version_number), 0))
            .where(SpecVersion.knowledge_base_id == knowledge_base_id)
        )
        result = await self.session.execute(max_query)
        next_version = result.scalar() + 1

        await self._deactivate_active_specs(knowledge_base_id)

        spec = SpecVersion(
            knowledge_base_id=knowledge_base_id,
            version_number=next_version,
            content=content,
            source_document_ids=source_document_ids,
            generated_by_id=generated_by_id,
            generated_by_name=generated_by_name,
            processing_job_id=processing_job_id,
            is_active=True,
            char_count=len(content),
        )
        self.session.add(spec)
        await self.session.flush()
        return spec

    async def activate_spec_version(self, version_id: uuid.UUID) -> Optional[SpecVersion]:
        """Make ``version_id`` the single active version of its KB (revert).

        Returns the activated version, or ``None`` if the ID is unknown.
        """
        query = select(SpecVersion).where(SpecVersion.id == version_id)
        result = await self.session.execute(query)
        target = result.scalar_one_or_none()
        if not target:
            return None
        # The target may itself be deactivated here; it is re-activated below
        # before the flush, so the net change for it is only "active".
        await self._deactivate_active_specs(target.knowledge_base_id)
        target.is_active = True
        await self.session.flush()
        return target

    # ---- Processing Jobs ----

    def _stale_active_jobs_query(
        self, cutoff: datetime, kb_id: Optional[uuid.UUID] = None
    ):
        """Build a SELECT of active jobs created before ``cutoff``, optionally for one KB."""
        query = (
            select(ProcessingJob)
            .where(ProcessingJob.status.in_(self.ACTIVE_JOB_STATUSES))
            .where(ProcessingJob.created_at < cutoff)
        )
        if kb_id is not None:
            query = query.where(ProcessingJob.knowledge_base_id == kb_id)
        return query

    async def create_processing_job(
        self,
        knowledge_base_id: uuid.UUID,
        total_documents: int,
        triggered_by_id: Optional[uuid.UUID] = None,
        triggered_by_name: Optional[str] = None,
    ) -> ProcessingJob:
        """Create and flush a new processing job with ``started_at`` set to now (UTC)."""
        job = ProcessingJob(
            knowledge_base_id=knowledge_base_id,
            total_documents=total_documents,
            triggered_by_id=triggered_by_id,
            triggered_by_name=triggered_by_name,
            started_at=datetime.now(timezone.utc),
        )
        self.session.add(job)
        await self.session.flush()
        return job

    async def update_processing_job(
        self,
        job_id: uuid.UUID,
        status: Optional[str] = None,
        parsed_documents: Optional[int] = None,
        spec_version_id: Optional[uuid.UUID] = None,
        error_message: Optional[str] = None,
        completed_at: Optional[datetime] = None,
    ) -> Optional[ProcessingJob]:
        """Update a processing job's fields.

        Only arguments that are not ``None`` are written, so ``None`` means
        "leave unchanged". Returns the job, or ``None`` if it does not exist.
        """
        query = select(ProcessingJob).where(ProcessingJob.id == job_id)
        result = await self.session.execute(query)
        job = result.scalar_one_or_none()
        if not job:
            return None
        if status is not None:
            job.status = status
        if parsed_documents is not None:
            job.parsed_documents = parsed_documents
        if spec_version_id is not None:
            job.spec_version_id = spec_version_id
        if error_message is not None:
            job.error_message = error_message
        if completed_at is not None:
            job.completed_at = completed_at
        await self.session.flush()
        return job

    async def get_processing_job(self, job_id: uuid.UUID) -> Optional[ProcessingJob]:
        """Get a processing job by ID, or ``None`` if it does not exist."""
        query = select(ProcessingJob).where(ProcessingJob.id == job_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_latest_processing_job(self, kb_id: uuid.UUID) -> Optional[ProcessingJob]:
        """Get the most recently created processing job for a knowledge base."""
        query = (
            select(ProcessingJob)
            .where(ProcessingJob.knowledge_base_id == kb_id)
            .order_by(ProcessingJob.created_at.desc())
            .limit(1)
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def fail_stale_jobs(self, kb_id: uuid.UUID, stale_minutes: int = 5) -> None:
        """Mark this KB's active jobs created over ``stale_minutes`` ago as failed.

        Staleness is measured from ``created_at``, not from the job's last
        progress. A single ``now`` is used for both the cutoff and every
        ``completed_at`` so the timestamps are consistent.
        """
        now = datetime.now(timezone.utc)
        query = self._stale_active_jobs_query(now - timedelta(minutes=stale_minutes), kb_id)
        result = await self.session.execute(query)
        for job in result.scalars().all():
            job.status = "failed"
            job.error_message = "Job timed out (stale)"
            job.completed_at = now
        await self.session.flush()

    async def has_active_job(self, kb_id: uuid.UUID) -> bool:
        """Return True if the KB has any job in ``ACTIVE_JOB_STATUSES``."""
        query = (
            select(func.count())
            .select_from(ProcessingJob)
            .where(ProcessingJob.knowledge_base_id == kb_id)
            .where(ProcessingJob.status.in_(self.ACTIVE_JOB_STATUSES))
        )
        result = await self.session.execute(query)
        return result.scalar() > 0

    async def get_all_stale_active_jobs(self, stale_minutes: int = 5) -> list[ProcessingJob]:
        """Get active jobs across all KBs created over ``stale_minutes`` ago.

        Intended for startup recovery. Results are ordered oldest first so
        recovery proceeds in a deterministic order.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_minutes)
        query = self._stale_active_jobs_query(cutoff).order_by(ProcessingJob.created_at)
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def reset_stuck_parsing_docs(self, kb_id: uuid.UUID) -> int:
        """Reset this KB's docs stuck at 'parsing' back to 'pending'.

        Other fields (e.g. ``parse_error``) are left untouched.
        Returns the number of documents reset.
        """
        query = (
            select(SourceDocument)
            .where(SourceDocument.knowledge_base_id == kb_id)
            .where(SourceDocument.parse_status == "parsing")
        )
        result = await self.session.execute(query)
        docs = result.scalars().all()
        for doc in docs:
            doc.parse_status = "pending"
        await self.session.flush()
        return len(docs)