The background task runs in its own DB session but the job row hadn't been committed yet by the request session. The background task couldn't find the job, causing FK violations when trying to create spec_versions. Fix: explicitly commit the request session after creating the job and before adding the background task, ensuring the job row is visible. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
477 lines
16 KiB
Python
477 lines
16 KiB
Python
"""REST API routes for Knowledge Base management."""
|
|
import difflib
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.knowledge_base_schemas import (
|
|
KnowledgeBaseListItem,
|
|
KnowledgeBaseDetail,
|
|
SourceDocumentResponse,
|
|
ProcessingJobResponse,
|
|
SpecVersionListItem,
|
|
SpecVersionDetail,
|
|
DiffResponse,
|
|
DiffLine,
|
|
)
|
|
from app.dependencies.auth import get_current_user
|
|
from app.models.database import get_db
|
|
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
|
from app.services.storage_service import storage_service
|
|
|
|
# Router for the Knowledge Base endpoints defined in this module; all of them
# are served under the "/knowledge-base" prefix.
kb_router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
|
|
|
# MIME types accepted when a source document is uploaded to a knowledge base.
ALLOWED_MIME_TYPES = {
    # Office-style documents
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # docx
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # pptx
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",  # xlsx
    # Text and markup
    "text/html",
    "text/plain",
    "text/markdown",
    # Images
    "image/png",
    "image/jpeg",
    "image/webp",
}
|
|
|
|
|
|
async def _get_user_info(session: AsyncSession, user_claims: dict) -> tuple:
    """Resolve the caller's database user ID and display name from token claims.

    The identity is read from the ``oid`` claim, falling back to ``sub``. When
    one is present, the matching user is looked up (or created) through
    ``UserRepository.get_or_create_from_azure``.

    Args:
        session: Database session handed to the user repository.
        user_claims: Claims dictionary of the authenticated caller.

    Returns:
        A ``(user_id, user_name)`` tuple. ``user_id`` is ``None`` when the
        claims carry neither ``oid`` nor ``sub``; ``user_name`` falls back to
        ``"Unknown"`` when there is no ``name`` claim.
    """
    from app.repositories.user_repository import UserRepository

    repository = UserRepository(session)
    display_name = user_claims.get("name", "Unknown")
    object_id = user_claims.get("oid") or user_claims.get("sub")

    # Without an identity claim there is nothing to look up.
    if not object_id:
        return None, display_name

    fallback_email = user_claims.get("preferred_username", "")
    account = await repository.get_or_create_from_azure(
        azure_ad_oid=object_id,
        email=user_claims.get("email", fallback_email),
        name=display_name,
    )
    return account.id, display_name
|
|
|
|
|
|
@kb_router.get("", response_model=list[KnowledgeBaseListItem])
async def list_knowledge_bases(
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return a summary entry for every knowledge base.

    Each entry reports the number of source documents, the version number and
    character count of the active spec (if any), and the status and completion
    time of the first job in ``processing_jobs`` (if any).
    """
    knowledge_bases = await KnowledgeBaseRepository(db).list_knowledge_bases()

    def _summarize(kb):
        # The active spec is the first version flagged ``is_active``.
        current_spec = next(
            (spec for spec in kb.spec_versions if spec.is_active), None
        )
        # NOTE(review): the first job is treated as the latest one; this
        # presumably relies on the relationship's ordering -- confirm.
        newest_job = kb.processing_jobs[0] if kb.processing_jobs else None

        spec_version = spec_chars = None
        if current_spec:
            spec_version = current_spec.version_number
            spec_chars = current_spec.char_count

        job_status = job_completed_at = None
        if newest_job:
            job_status = newest_job.status
            job_completed_at = newest_job.completed_at

        return KnowledgeBaseListItem(
            id=kb.id,
            agent_key=kb.agent_key,
            display_name=kb.display_name,
            description=kb.description,
            source_document_count=len(kb.source_documents),
            active_spec_version=spec_version,
            active_spec_char_count=spec_chars,
            latest_job_status=job_status,
            latest_job_completed_at=job_completed_at,
            created_at=kb.created_at,
        )

    return [_summarize(kb) for kb in knowledge_bases]
|
|
|
|
|
|
@kb_router.get("/{kb_id}", response_model=KnowledgeBaseDetail)
async def get_knowledge_base(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return the full detail view of a single knowledge base.

    The response lists the source documents newest-first, summarizes the
    active spec version, and embeds the first job in ``processing_jobs``.

    Raises:
        HTTPException: 404 when no knowledge base exists for ``kb_id``.
    """
    kb = await KnowledgeBaseRepository(db).get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    current_spec = next(
        (spec for spec in kb.spec_versions if spec.is_active), None
    )
    # NOTE(review): the first job is treated as the latest one; this
    # presumably relies on the relationship's ordering -- confirm.
    newest_job = kb.processing_jobs[0] if kb.processing_jobs else None

    # Most recently created documents come first.
    newest_first = sorted(
        kb.source_documents, key=lambda item: item.created_at, reverse=True
    )
    documents = [
        SourceDocumentResponse(
            id=item.id,
            knowledge_base_id=item.knowledge_base_id,
            filename=item.filename,
            file_storage_key=item.file_storage_key,
            file_size_bytes=item.file_size_bytes,
            mime_type=item.mime_type,
            uploaded_by_name=item.uploaded_by_name,
            parse_status=item.parse_status,
            parse_error=item.parse_error,
            created_at=item.created_at,
        )
        for item in newest_first
    ]

    spec_version = spec_chars = None
    if current_spec:
        spec_version = current_spec.version_number
        spec_chars = current_spec.char_count

    job_payload = None
    if newest_job:
        job_payload = ProcessingJobResponse(
            id=newest_job.id,
            knowledge_base_id=newest_job.knowledge_base_id,
            status=newest_job.status,
            triggered_by_name=newest_job.triggered_by_name,
            total_documents=newest_job.total_documents,
            parsed_documents=newest_job.parsed_documents,
            spec_version_id=newest_job.spec_version_id,
            error_message=newest_job.error_message,
            started_at=newest_job.started_at,
            completed_at=newest_job.completed_at,
            created_at=newest_job.created_at,
        )

    return KnowledgeBaseDetail(
        id=kb.id,
        agent_key=kb.agent_key,
        display_name=kb.display_name,
        description=kb.description,
        source_documents=documents,
        active_spec_version=spec_version,
        active_spec_char_count=spec_chars,
        latest_job=job_payload,
        created_at=kb.created_at,
    )
|
|
|
|
|
|
@kb_router.post("/{kb_id}/documents", response_model=SourceDocumentResponse, status_code=201)
async def upload_source_document(
    kb_id: uuid.UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Attach an uploaded file to a knowledge base as a source document.

    The document row is created first (with a placeholder storage key) so that
    its generated ID can be used when storing the file; the real storage key
    is written onto the row afterwards.

    Raises:
        HTTPException: 404 when the knowledge base does not exist; 400 when
            the upload declares a content type outside ``ALLOWED_MIME_TYPES``.
    """
    repo = KnowledgeBaseRepository(db)
    if not await repo.get_knowledge_base(kb_id):
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # An upload that declares no content type skips this check.
    declared_type = file.content_type
    if declared_type and declared_type not in ALLOWED_MIME_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"File type '{declared_type}' is not supported. Supported types: PDF, DOCX, PPTX, XLSX, HTML, TXT, MD, PNG, JPG, WebP."
        )

    uploader_id, uploader_name = await _get_user_info(db, user)

    payload = await file.read()

    # Fallbacks shared by the DB record and the storage call.
    stored_name = file.filename or "unknown"
    stored_type = file.content_type or "application/octet-stream"

    # The row comes first: storage needs the DB-generated document ID.
    doc = await repo.add_source_document(
        knowledge_base_id=kb_id,
        filename=stored_name,
        file_storage_key="pending",  # replaced below once the file is stored
        file_size_bytes=len(payload),
        mime_type=stored_type,
        uploaded_by_id=uploader_id,
        uploaded_by_name=uploader_name,
    )

    doc.file_storage_key = await storage_service.store_kb_document(
        file_data=payload,
        kb_id=kb_id,
        doc_id=doc.id,
        filename=stored_name,
        mime_type=stored_type,
    )

    return SourceDocumentResponse(
        id=doc.id,
        knowledge_base_id=doc.knowledge_base_id,
        filename=doc.filename,
        file_storage_key=doc.file_storage_key,
        file_size_bytes=doc.file_size_bytes,
        mime_type=doc.mime_type,
        uploaded_by_name=doc.uploaded_by_name,
        parse_status=doc.parse_status,
        parse_error=doc.parse_error,
        created_at=doc.created_at,
    )
|
|
|
|
|
|
@kb_router.delete("/{kb_id}/documents/{doc_id}", status_code=204)
async def delete_source_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Remove a source document from a knowledge base and delete its file.

    The document is only removed when it belongs to the knowledge base named
    in the URL, so a document cannot be deleted through another knowledge
    base's path.

    Raises:
        HTTPException: 404 when the knowledge base does not exist, or when it
            has no source document with ``doc_id``.
    """
    repo = KnowledgeBaseRepository(db)

    kb = await repo.get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Check membership BEFORE removing anything: the repository removes by
    # document ID alone, so without this check ``kb_id`` would be ignored.
    if not any(existing.id == doc_id for existing in kb.source_documents):
        raise HTTPException(status_code=404, detail="Source document not found")

    doc = await repo.remove_source_document(doc_id)
    if not doc:
        raise HTTPException(status_code=404, detail="Source document not found")

    # Delete file from storage
    await storage_service.delete_file(doc.file_storage_key)
|
|
|
|
|
|
@kb_router.post("/{kb_id}/process", response_model=ProcessingJobResponse, status_code=201)
async def trigger_processing(
    kb_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Trigger the document processing pipeline for a knowledge base.

    Creates a processing job, commits it so that it is visible outside the
    request's session, and schedules ``kb_service.process_documents`` as a
    background task.

    Raises:
        HTTPException: 503 when the knowledge base service is not configured;
            404 when the knowledge base does not exist; 409 when a job is
            already active for it; 400 when it has no source documents.
    """
    from app.main import knowledge_base_service as kb_service

    if kb_service is None:
        raise HTTPException(
            status_code=503,
            detail="Knowledge Base processing is not available. LLAMA_CLOUD_API_KEY is not configured."
        )

    repo = KnowledgeBaseRepository(db)
    kb = await repo.get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Auto-fail stale pending jobs (older than 5 minutes) before checking
    await repo.fail_stale_jobs(kb_id)

    # Check for active jobs
    if await repo.has_active_job(kb_id):
        raise HTTPException(status_code=409, detail="A processing job is already running for this knowledge base.")

    # Check that there are source documents
    docs = await repo.get_source_documents(kb_id)
    if not docs:
        raise HTTPException(status_code=400, detail="No source documents to process.")

    user_id, user_name = await _get_user_info(db, user)

    # Read what we need from already-loaded ORM objects BEFORE committing.
    # With SQLAlchemy's default ``expire_on_commit=True`` a commit expires
    # loaded attributes, and an async session cannot lazily reload them on
    # plain attribute access.
    # NOTE(review): harmless if the session factory already sets
    # ``expire_on_commit=False`` -- confirm against get_db.
    agent_key = kb.agent_key
    total_documents = len(docs)

    # Create the job and commit immediately so the background task (which
    # runs in its own DB session) can see the row.
    job = await repo.create_processing_job(
        knowledge_base_id=kb_id,
        total_documents=total_documents,
        triggered_by_id=user_id,
        triggered_by_name=user_name,
    )
    await db.commit()

    # Explicitly reload the job so its attributes are populated for the
    # background-task arguments and the response below.
    await db.refresh(job)

    # Start background processing
    background_tasks.add_task(
        kb_service.process_documents,
        kb_id=kb_id,
        job_id=job.id,
        agent_key=agent_key,
        user_id=user_id,
        user_name=user_name,
    )

    return ProcessingJobResponse(
        id=job.id,
        knowledge_base_id=job.knowledge_base_id,
        status=job.status,
        triggered_by_name=job.triggered_by_name,
        total_documents=job.total_documents,
        parsed_documents=job.parsed_documents,
        spec_version_id=job.spec_version_id,
        error_message=job.error_message,
        started_at=job.started_at,
        completed_at=job.completed_at,
        created_at=job.created_at,
    )
|
|
|
|
|
|
@kb_router.get("/{kb_id}/jobs/{job_id}", response_model=ProcessingJobResponse)
async def get_processing_job(
    kb_id: uuid.UUID,
    job_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Report the current state of one processing job.

    Raises:
        HTTPException: 404 when the job does not exist or belongs to a
            different knowledge base than ``kb_id``.
    """
    job = await KnowledgeBaseRepository(db).get_processing_job(job_id)

    if job and job.knowledge_base_id == kb_id:
        # Every response field is copied straight from the job row.
        copied_fields = (
            "id",
            "knowledge_base_id",
            "status",
            "triggered_by_name",
            "total_documents",
            "parsed_documents",
            "spec_version_id",
            "error_message",
            "started_at",
            "completed_at",
            "created_at",
        )
        return ProcessingJobResponse(
            **{name: getattr(job, name) for name in copied_fields}
        )

    raise HTTPException(status_code=404, detail="Processing job not found")
|
|
|
|
|
|
@kb_router.get("/{kb_id}/versions", response_model=list[SpecVersionListItem])
async def list_spec_versions(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return the spec versions recorded for a knowledge base.

    Items are returned in the order the repository yields them; the full spec
    content is not included in list items.
    """
    stored_versions = await KnowledgeBaseRepository(db).list_spec_versions(kb_id)

    # Every list-item field is copied straight from the version row.
    copied_fields = (
        "id",
        "knowledge_base_id",
        "version_number",
        "generated_by_name",
        "source_document_ids",
        "is_active",
        "char_count",
        "created_at",
    )

    items = []
    for stored in stored_versions:
        attributes = {name: getattr(stored, name) for name in copied_fields}
        items.append(SpecVersionListItem(**attributes))
    return items
|
|
|
|
|
|
@kb_router.get("/{kb_id}/versions/{version_id}", response_model=SpecVersionDetail)
async def get_spec_version(
    kb_id: uuid.UUID,
    version_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return one spec version, including its full content.

    Raises:
        HTTPException: 404 when the version does not exist or belongs to a
            different knowledge base than ``kb_id``.
    """
    spec = await KnowledgeBaseRepository(db).get_spec_version(version_id)

    if spec and spec.knowledge_base_id == kb_id:
        return SpecVersionDetail(
            id=spec.id,
            knowledge_base_id=spec.knowledge_base_id,
            version_number=spec.version_number,
            content=spec.content,
            generated_by_name=spec.generated_by_name,
            source_document_ids=spec.source_document_ids,
            is_active=spec.is_active,
            char_count=spec.char_count,
            created_at=spec.created_at,
        )

    raise HTTPException(status_code=404, detail="Spec version not found")
|
|
|
|
|
|
def _classify_unified_diff(diff):
    """Turn unified-diff output lines into keyword arguments for ``DiffLine``.

    Args:
        diff: Iterable of lines as produced by ``difflib.unified_diff``.

    Returns:
        A list of dicts, one per diff line, each with ``type`` (``"context"``,
        ``"add"`` or ``"remove"``) and ``content``, plus ``line_number_old``
        and/or ``line_number_new`` where they apply. File and hunk headers are
        emitted as ``"context"`` entries without line numbers.
    """
    import re

    # Captures the starting line numbers of the old and new side of a hunk.
    hunk_header = re.compile(r"@@ -(\d+)(?:,\d+)? \+(\d+)")

    entries = []
    old_line = 0
    new_line = 0
    # The "---"/"+++" file headers only occur before the first hunk. Inside a
    # hunk, a line starting with "---" or "+++" is a removed/added line whose
    # own text begins with "--"/"++" (e.g. a markdown "---" rule) and must be
    # counted as a change, not as a header.
    in_hunk = False

    for line in diff:
        if line.startswith("@@"):
            in_hunk = True
            match = hunk_header.match(line)
            if match:
                # Counters are incremented before use, hence the "- 1".
                old_line = int(match.group(1)) - 1
                new_line = int(match.group(2)) - 1
            entries.append({"type": "context", "content": line.rstrip("\n")})
        elif not in_hunk and line.startswith(("---", "+++")):
            entries.append({"type": "context", "content": line.rstrip("\n")})
        elif line.startswith("+"):
            new_line += 1
            entries.append({
                "type": "add",
                "content": line[1:].rstrip("\n"),
                "line_number_new": new_line,
            })
        elif line.startswith("-"):
            old_line += 1
            entries.append({
                "type": "remove",
                "content": line[1:].rstrip("\n"),
                "line_number_old": old_line,
            })
        else:
            old_line += 1
            new_line += 1
            entries.append({
                "type": "context",
                "content": line.rstrip("\n"),
                "line_number_old": old_line,
                "line_number_new": new_line,
            })

    return entries


@kb_router.get("/{kb_id}/versions/{v_a}/diff/{v_b}", response_model=DiffResponse)
async def get_spec_diff(
    kb_id: uuid.UUID,
    v_a: uuid.UUID,
    v_b: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Compute a line-based unified diff between two spec versions.

    Returns:
        A ``DiffResponse`` with both version numbers, the number of added and
        removed lines, and the classified diff lines.

    Raises:
        HTTPException: 404 when either version does not exist or does not
            belong to the knowledge base ``kb_id``.
    """
    repo = KnowledgeBaseRepository(db)
    version_a = await repo.get_spec_version(v_a)
    version_b = await repo.get_spec_version(v_b)

    if not version_a or version_a.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Version A not found")
    if not version_b or version_b.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Version B not found")

    # Treat missing content as empty text rather than failing on ``None``.
    lines_a = (version_a.content or "").splitlines(keepends=True)
    lines_b = (version_b.content or "").splitlines(keepends=True)

    diff = difflib.unified_diff(
        lines_a, lines_b,
        fromfile=f"v{version_a.version_number}",
        tofile=f"v{version_b.version_number}",
        lineterm="",
    )

    entries = _classify_unified_diff(diff)
    additions = sum(1 for entry in entries if entry["type"] == "add")
    deletions = sum(1 for entry in entries if entry["type"] == "remove")

    return DiffResponse(
        version_a=version_a.version_number,
        version_b=version_b.version_number,
        additions=additions,
        deletions=deletions,
        lines=[DiffLine(**entry) for entry in entries],
    )
|
|
|
|
|
|
@kb_router.post("/{kb_id}/versions/{version_id}/activate", response_model=SpecVersionDetail)
async def activate_spec_version(
    kb_id: uuid.UUID,
    version_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Activate (revert to) a specific spec version.

    After activation the reference-docs cache entry for the knowledge base's
    agent is invalidated, when the analysis service is available.

    Raises:
        HTTPException: 404 when the version does not exist or belongs to a
            different knowledge base than ``kb_id``.
    """
    repo = KnowledgeBaseRepository(db)

    # Validate ownership BEFORE mutating anything: activating first and
    # checking afterwards would change state for a request that then gets
    # rejected with a 404.
    candidate = await repo.get_spec_version(version_id)
    if not candidate or candidate.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Spec version not found")

    version = await repo.activate_spec_version(version_id)
    if not version:
        raise HTTPException(status_code=404, detail="Spec version not found")

    # Invalidate reference docs cache
    kb = await repo.get_knowledge_base(kb_id)
    if kb:
        from app.main import analysis_service

        if analysis_service:
            analysis_service.reference_docs.invalidate_cache(kb.agent_key)

    return SpecVersionDetail(
        id=version.id,
        knowledge_base_id=version.knowledge_base_id,
        version_number=version.version_number,
        content=version.content,
        generated_by_name=version.generated_by_name,
        source_document_ids=version.source_document_ids,
        is_active=version.is_active,
        char_count=version.char_count,
        created_at=version.created_at,
    )
|