On server restart, stale active jobs are automatically resumed rather than failed. Docs already parsed in a prior run are skipped (resume from cache), docs stuck at 'parsing' are reset to 'pending' and re-parsed. - Repository: add get_all_stale_active_jobs() and reset_stuck_parsing_docs() - Service: skip already-parsed docs in _parse_doc(), reset stuck docs on start - Main: recover stale jobs via asyncio.create_task() in lifespan startup Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
342 lines
13 KiB
Python
342 lines
13 KiB
Python
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.models.models import KnowledgeBase, SourceDocument, SpecVersion, ProcessingJob
|
|
|
|
|
|
class KnowledgeBaseRepository:
    """Repository for knowledge base database operations.

    Groups the queries for knowledge bases, source documents, spec versions
    and processing jobs behind one ``AsyncSession``. Write methods only
    ``flush()`` the session; nothing in this class commits.
    """

    # Job statuses treated as "active" (non-terminal). Defined once so the
    # stale-job and active-job queries cannot drift apart.
    _ACTIVE_JOB_STATUSES = ("pending", "parsing_documents", "distilling")

    def __init__(self, session: AsyncSession):
        """Store the session that every query in this repository runs on."""
        self.session = session

    # ---- Internal helpers ----

    @staticmethod
    def _knowledge_base_query():
        """Build a KnowledgeBase select that eager-loads docs, specs and jobs."""
        return select(KnowledgeBase).options(
            selectinload(KnowledgeBase.source_documents),
            selectinload(KnowledgeBase.spec_versions),
            selectinload(KnowledgeBase.processing_jobs),
        )

    def _stale_active_jobs_query(self, stale_minutes: int, kb_id: Optional[uuid.UUID] = None):
        """Build the select for active jobs created more than ``stale_minutes`` ago.

        When ``kb_id`` is given the query is restricted to that knowledge base.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(minutes=stale_minutes)
        query = select(ProcessingJob)
        if kb_id is not None:
            query = query.where(ProcessingJob.knowledge_base_id == kb_id)
        return (
            query
            .where(ProcessingJob.status.in_(self._ACTIVE_JOB_STATUSES))
            .where(ProcessingJob.created_at < cutoff)
        )

    async def _deactivate_active_versions(self, knowledge_base_id: uuid.UUID) -> None:
        """Clear ``is_active`` on every active spec version of one knowledge base.

        Does not flush; callers flush once their own changes are staged.
        """
        query = (
            select(SpecVersion)
            .where(SpecVersion.knowledge_base_id == knowledge_base_id)
            # SQLAlchemy overloads ``==`` to build SQL, so ``is True`` would be wrong here.
            .where(SpecVersion.is_active == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        for version in result.scalars().all():
            version.is_active = False

    # ---- Knowledge Bases ----

    async def list_knowledge_bases(self) -> list[KnowledgeBase]:
        """List all knowledge bases, ordered by display name, with relationships loaded."""
        query = self._knowledge_base_query().order_by(KnowledgeBase.display_name)
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_knowledge_base(self, kb_id: uuid.UUID) -> Optional[KnowledgeBase]:
        """Get a knowledge base by ID with all relationships, or None if absent."""
        query = self._knowledge_base_query().where(KnowledgeBase.id == kb_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_knowledge_base_by_key(self, agent_key: str) -> Optional[KnowledgeBase]:
        """Get a knowledge base by agent_key (relationships are not eager-loaded)."""
        query = select(KnowledgeBase).where(KnowledgeBase.agent_key == agent_key)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    # ---- Source Documents ----

    async def add_source_document(
        self,
        knowledge_base_id: uuid.UUID,
        filename: str,
        file_storage_key: str,
        file_size_bytes: int,
        mime_type: str,
        uploaded_by_id: Optional[uuid.UUID] = None,
        uploaded_by_name: Optional[str] = None,
    ) -> SourceDocument:
        """Add a source document to a knowledge base.

        Returns:
            The new ``SourceDocument``, added to the session and flushed.
        """
        doc = SourceDocument(
            knowledge_base_id=knowledge_base_id,
            filename=filename,
            file_storage_key=file_storage_key,
            file_size_bytes=file_size_bytes,
            mime_type=mime_type,
            uploaded_by_id=uploaded_by_id,
            uploaded_by_name=uploaded_by_name,
        )
        self.session.add(doc)
        await self.session.flush()
        return doc

    async def remove_source_document(self, doc_id: uuid.UUID) -> Optional[SourceDocument]:
        """Remove a source document by ID.

        Returns:
            The deleted document, or None when no document has that ID.
        """
        doc = await self.get_source_document(doc_id)
        if doc is not None:
            await self.session.delete(doc)
            await self.session.flush()
        return doc

    async def get_source_documents(self, kb_id: uuid.UUID) -> list[SourceDocument]:
        """Get all source documents for a knowledge base, newest first."""
        query = (
            select(SourceDocument)
            .where(SourceDocument.knowledge_base_id == kb_id)
            .order_by(SourceDocument.created_at.desc())
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_source_document(self, doc_id: uuid.UUID) -> Optional[SourceDocument]:
        """Get a single source document by ID, or None if absent."""
        query = select(SourceDocument).where(SourceDocument.id == doc_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def update_source_document_parse_status(
        self,
        doc_id: uuid.UUID,
        status: str,
        parsed_markdown: Optional[str] = None,
        parse_error: Optional[str] = None,
    ) -> None:
        """Update parse status of a source document.

        ``parsed_markdown`` and ``parse_error`` are only written when provided;
        passing None leaves the stored value untouched, so this method never
        clears a previously recorded error. Does nothing when the document
        does not exist.
        """
        doc = await self.get_source_document(doc_id)
        if doc is None:
            return

        doc.parse_status = status
        if parsed_markdown is not None:
            doc.parsed_markdown = parsed_markdown
        if parse_error is not None:
            doc.parse_error = parse_error
        await self.session.flush()

    # ---- Spec Versions ----

    async def get_active_spec_by_key(self, agent_key: str) -> Optional[SpecVersion]:
        """Get the active spec version for a given agent key, or None."""
        query = (
            select(SpecVersion)
            .join(KnowledgeBase)
            .where(KnowledgeBase.agent_key == agent_key)
            # SQLAlchemy overloads ``==`` to build SQL, so ``is True`` would be wrong here.
            .where(SpecVersion.is_active == True)  # noqa: E712
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def list_spec_versions(self, kb_id: uuid.UUID) -> list[SpecVersion]:
        """List all spec versions for a knowledge base, highest version first."""
        query = (
            select(SpecVersion)
            .where(SpecVersion.knowledge_base_id == kb_id)
            .order_by(SpecVersion.version_number.desc())
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_spec_version(self, version_id: uuid.UUID) -> Optional[SpecVersion]:
        """Get a spec version by ID, or None if absent."""
        query = select(SpecVersion).where(SpecVersion.id == version_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def create_spec_version(
        self,
        knowledge_base_id: uuid.UUID,
        content: str,
        source_document_ids: Optional[list] = None,
        generated_by_id: Optional[uuid.UUID] = None,
        generated_by_name: Optional[str] = None,
        processing_job_id: Optional[uuid.UUID] = None,
    ) -> SpecVersion:
        """Create a new spec version and make it the only active one for its KB.

        The version number is the current maximum for the knowledge base plus
        one (1 for the first version). ``char_count`` is ``len(content)``.

        Returns:
            The new, flushed ``SpecVersion``.
        """
        # COALESCE makes the query yield 0 when the KB has no versions yet.
        # NOTE(review): the number is allocated read-then-write; nothing in
        # this method guards against two concurrent callers picking the same
        # value -- confirm callers serialize or the schema enforces uniqueness.
        max_query = (
            select(func.coalesce(func.max(SpecVersion.version_number), 0))
            .where(SpecVersion.knowledge_base_id == knowledge_base_id)
        )
        result = await self.session.execute(max_query)
        next_version = result.scalar() + 1

        await self._deactivate_active_versions(knowledge_base_id)

        spec = SpecVersion(
            knowledge_base_id=knowledge_base_id,
            version_number=next_version,
            content=content,
            source_document_ids=source_document_ids,
            generated_by_id=generated_by_id,
            generated_by_name=generated_by_name,
            processing_job_id=processing_job_id,
            is_active=True,
            char_count=len(content),
        )
        self.session.add(spec)
        await self.session.flush()
        return spec

    async def activate_spec_version(self, version_id: uuid.UUID) -> Optional[SpecVersion]:
        """Activate a specific version (revert), deactivating all others for same KB.

        Returns:
            The activated version, or None when ``version_id`` does not exist.
        """
        target = await self.get_spec_version(version_id)
        if target is None:
            return None

        # Clear every active flag first (possibly including the target's own),
        # then set the target so exactly one version ends up active.
        await self._deactivate_active_versions(target.knowledge_base_id)
        target.is_active = True
        await self.session.flush()
        return target

    # ---- Processing Jobs ----

    async def create_processing_job(
        self,
        knowledge_base_id: uuid.UUID,
        total_documents: int,
        triggered_by_id: Optional[uuid.UUID] = None,
        triggered_by_name: Optional[str] = None,
    ) -> ProcessingJob:
        """Create a new processing job with ``started_at`` set to the current UTC time."""
        job = ProcessingJob(
            knowledge_base_id=knowledge_base_id,
            total_documents=total_documents,
            triggered_by_id=triggered_by_id,
            triggered_by_name=triggered_by_name,
            started_at=datetime.now(timezone.utc),
        )
        self.session.add(job)
        await self.session.flush()
        return job

    async def update_processing_job(
        self,
        job_id: uuid.UUID,
        status: Optional[str] = None,
        parsed_documents: Optional[int] = None,
        spec_version_id: Optional[uuid.UUID] = None,
        error_message: Optional[str] = None,
        completed_at: Optional[datetime] = None,
    ) -> Optional[ProcessingJob]:
        """Update a processing job's fields.

        Only arguments that are not None are written, so a field cannot be
        reset to None through this method.

        Returns:
            The updated job, or None when ``job_id`` does not exist.
        """
        job = await self.get_processing_job(job_id)
        if job is None:
            return None

        changes = {
            "status": status,
            "parsed_documents": parsed_documents,
            "spec_version_id": spec_version_id,
            "error_message": error_message,
            "completed_at": completed_at,
        }
        for field_name, value in changes.items():
            if value is not None:
                setattr(job, field_name, value)

        await self.session.flush()
        return job

    async def get_processing_job(self, job_id: uuid.UUID) -> Optional[ProcessingJob]:
        """Get a processing job by ID, or None if absent."""
        query = select(ProcessingJob).where(ProcessingJob.id == job_id)
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def get_latest_processing_job(self, kb_id: uuid.UUID) -> Optional[ProcessingJob]:
        """Get the most recently created processing job for a knowledge base."""
        query = (
            select(ProcessingJob)
            .where(ProcessingJob.knowledge_base_id == kb_id)
            .order_by(ProcessingJob.created_at.desc())
            .limit(1)
        )
        result = await self.session.execute(query)
        return result.scalar_one_or_none()

    async def fail_stale_jobs(self, kb_id: uuid.UUID, stale_minutes: int = 5) -> None:
        """Mark stale pending/active jobs of one KB as failed.

        A job is stale when its status is active and it was created more than
        ``stale_minutes`` ago.
        NOTE(review): age is measured from ``created_at``, not from the last
        progress update, so a legitimately long-running job also matches.
        """
        query = self._stale_active_jobs_query(stale_minutes, kb_id)
        result = await self.session.execute(query)

        # One timestamp for the whole batch instead of one clock read per job.
        failed_at = datetime.now(timezone.utc)
        for job in result.scalars().all():
            job.status = "failed"
            job.error_message = "Job timed out (stale)"
            job.completed_at = failed_at
        await self.session.flush()

    async def has_active_job(self, kb_id: uuid.UUID) -> bool:
        """Check if there's an active (non-terminal) processing job for this KB."""
        query = (
            select(func.count())
            .select_from(ProcessingJob)
            .where(ProcessingJob.knowledge_base_id == kb_id)
            .where(ProcessingJob.status.in_(self._ACTIVE_JOB_STATUSES))
        )
        result = await self.session.execute(query)
        return result.scalar() > 0

    async def get_all_stale_active_jobs(self, stale_minutes: int = 5) -> list[ProcessingJob]:
        """Get all active jobs across all KBs older than stale_minutes (for startup recovery)."""
        result = await self.session.execute(self._stale_active_jobs_query(stale_minutes))
        return list(result.scalars().all())

    async def reset_stuck_parsing_docs(self, kb_id: uuid.UUID) -> int:
        """Reset docs stuck at 'parsing' back to 'pending' so they can be re-parsed.

        Returns:
            The number of documents that were reset.
        """
        query = (
            select(SourceDocument)
            .where(SourceDocument.knowledge_base_id == kb_id)
            .where(SourceDocument.parse_status == "parsing")
        )
        result = await self.session.execute(query)
        stuck_docs = result.scalars().all()
        for doc in stuck_docs:
            doc.parse_status = "pending"
        await self.session.flush()
        return len(stuck_docs)
|