knowledge_base_service and analysis_service were local variables inside the lifespan() function — not module-level exports. Importing them via 'from app.main import ...' always failed with ImportError → 500. Use request.app.state (same pattern as analysis_routes.py) instead. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
479 lines
17 KiB
Python
479 lines
17 KiB
Python
"""REST API routes for Knowledge Base management."""
|
|
import difflib
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, UploadFile, File
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.knowledge_base_schemas import (
|
|
KnowledgeBaseListItem,
|
|
KnowledgeBaseDetail,
|
|
SourceDocumentResponse,
|
|
ProcessingJobResponse,
|
|
SpecVersionListItem,
|
|
SpecVersionDetail,
|
|
DiffResponse,
|
|
DiffLine,
|
|
)
|
|
from app.dependencies.auth import get_current_user
|
|
from app.models.database import get_db
|
|
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
|
|
from app.services.storage_service import storage_service
|
|
|
|
# Router for every knowledge-base endpoint in this module: all routes are
# mounted under the "/knowledge-base" prefix and grouped under one tag.
kb_router = APIRouter(tags=["knowledge-base"], prefix="/knowledge-base")
|
|
|
|
# MIME types accepted by the source-document upload endpoint. They are listed
# against the file kind they correspond to for readability; only the MIME
# strings (the dict values) end up in the set. The upload handler rejects a
# declared, non-empty content type that is not in this set.
ALLOWED_MIME_TYPES = set(
    {
        "pdf": "application/pdf",
        "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "html": "text/html",
        "txt": "text/plain",
        "md": "text/markdown",
        "png": "image/png",
        "jpg": "image/jpeg",
        "webp": "image/webp",
    }.values()
)
|
|
|
|
|
|
async def _get_user_info(session: AsyncSession, user_claims: dict) -> tuple:
    """Resolve the caller's local user ID and display name from token claims.

    The external object ID is read from the ``oid`` claim, falling back to
    ``sub``. When one is present, the matching user row is fetched or created
    via ``UserRepository.get_or_create_from_azure``.

    Args:
        session: Database session passed to ``UserRepository``.
        user_claims: Claims dict describing the current user.

    Returns:
        ``(user_id, user_name)``. ``user_id`` is ``None`` when neither ``oid``
        nor ``sub`` is set. ``user_name`` is the ``name`` claim, or
        ``"Unknown"`` when that key is absent.
    """
    # Function-level import; presumably avoids an import cycle — confirm.
    from app.repositories.user_repository import UserRepository

    users = UserRepository(session)
    object_id = user_claims.get("oid") or user_claims.get("sub")
    display_name = user_claims.get("name", "Unknown")

    if not object_id:
        return None, display_name

    # Prefer the "email" claim; fall back to "preferred_username", then "".
    contact_email = user_claims.get("email", user_claims.get("preferred_username", ""))
    record = await users.get_or_create_from_azure(
        azure_ad_oid=object_id,
        email=contact_email,
        name=display_name,
    )
    return record.id, display_name
|
|
|
|
|
|
@kb_router.get("", response_model=list[KnowledgeBaseListItem])
async def list_knowledge_bases(
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return one summary row per knowledge base.

    Each row reports the number of source documents, the version number and
    character count of the active spec version (if any), and the status and
    completion time of the first entry of ``kb.processing_jobs`` (if any).
    ``user`` is unused in the body; it is declared so ``get_current_user``
    runs for this request.
    """

    def summarize(kb) -> KnowledgeBaseListItem:
        # Build the summary row for a single knowledge base.
        active = next((sv for sv in kb.spec_versions if sv.is_active), None)
        # NOTE(review): treated as the latest job; assumes processing_jobs is
        # ordered newest-first (ordering is defined outside this file).
        job = kb.processing_jobs[0] if kb.processing_jobs else None
        return KnowledgeBaseListItem(
            id=kb.id,
            agent_key=kb.agent_key,
            display_name=kb.display_name,
            description=kb.description,
            source_document_count=len(kb.source_documents),
            active_spec_version=active.version_number if active else None,
            active_spec_char_count=active.char_count if active else None,
            latest_job_status=job.status if job else None,
            latest_job_completed_at=job.completed_at if job else None,
            created_at=kb.created_at,
        )

    knowledge_bases = await KnowledgeBaseRepository(db).list_knowledge_bases()
    return [summarize(kb) for kb in knowledge_bases]
|
|
|
|
|
|
@kb_router.get("/{kb_id}", response_model=KnowledgeBaseDetail)
async def get_knowledge_base(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return the full detail view of a single knowledge base.

    The response includes every source document (newest ``created_at``
    first), the active spec version's number and character count (if any),
    and the first entry of ``kb.processing_jobs`` (if any).

    Raises:
        HTTPException: 404 if no knowledge base has id ``kb_id``.
    """
    kb = await KnowledgeBaseRepository(db).get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    active = next((sv for sv in kb.spec_versions if sv.is_active), None)

    newest_first = sorted(kb.source_documents, key=lambda d: d.created_at, reverse=True)
    documents = [
        SourceDocumentResponse(
            id=doc.id,
            knowledge_base_id=doc.knowledge_base_id,
            filename=doc.filename,
            file_storage_key=doc.file_storage_key,
            file_size_bytes=doc.file_size_bytes,
            mime_type=doc.mime_type,
            uploaded_by_name=doc.uploaded_by_name,
            parse_status=doc.parse_status,
            parse_error=doc.parse_error,
            created_at=doc.created_at,
        )
        for doc in newest_first
    ]

    # NOTE(review): assumes processing_jobs is ordered newest-first so that
    # index 0 is the latest job (ordering is defined outside this file).
    job_summary = None
    if kb.processing_jobs:
        job = kb.processing_jobs[0]
        job_summary = ProcessingJobResponse(
            id=job.id,
            knowledge_base_id=job.knowledge_base_id,
            status=job.status,
            triggered_by_name=job.triggered_by_name,
            total_documents=job.total_documents,
            parsed_documents=job.parsed_documents,
            spec_version_id=job.spec_version_id,
            error_message=job.error_message,
            started_at=job.started_at,
            completed_at=job.completed_at,
            created_at=job.created_at,
        )

    return KnowledgeBaseDetail(
        id=kb.id,
        agent_key=kb.agent_key,
        display_name=kb.display_name,
        description=kb.description,
        source_documents=documents,
        active_spec_version=active.version_number if active else None,
        active_spec_char_count=active.char_count if active else None,
        latest_job=job_summary,
        created_at=kb.created_at,
    )
|
|
|
|
|
|
@kb_router.post("/{kb_id}/documents", response_model=SourceDocumentResponse, status_code=201)
async def upload_source_document(
    kb_id: uuid.UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Attach an uploaded file to a knowledge base as a new source document.

    Steps:
    1. Check that the knowledge base exists.
    2. Reject disallowed content types.
    3. Insert the document row with a placeholder storage key.
    4. Store the bytes using the new row's ID.
    5. Point the row at the returned storage key.

    Raises:
        HTTPException: 404 if the knowledge base does not exist; 400 if the
            upload declares a content type outside ``ALLOWED_MIME_TYPES``.
    """
    repo = KnowledgeBaseRepository(db)
    if not await repo.get_knowledge_base(kb_id):
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Only a declared (non-empty) content type is validated; uploads with no
    # content type pass this check and are stored as application/octet-stream.
    declared_type = file.content_type
    if declared_type and declared_type not in ALLOWED_MIME_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"File type '{file.content_type}' is not supported. Supported types: PDF, DOCX, PPTX, XLSX, HTML, TXT, MD, PNG, JPG, WebP.",
        )

    user_id, user_name = await _get_user_info(db, user)

    payload = await file.read()
    original_name = file.filename or "unknown"
    stored_type = file.content_type or "application/octet-stream"

    # The row is created first so its generated ID can be passed to storage;
    # the real storage key replaces the placeholder afterwards.
    doc = await repo.add_source_document(
        knowledge_base_id=kb_id,
        filename=original_name,
        file_storage_key="pending",
        file_size_bytes=len(payload),
        mime_type=stored_type,
        uploaded_by_id=user_id,
        uploaded_by_name=user_name,
    )

    # NOTE(review): if store_kb_document raises, the row above keeps the
    # "pending" key unless get_db rolls the session back — confirm.
    doc.file_storage_key = await storage_service.store_kb_document(
        file_data=payload,
        kb_id=kb_id,
        doc_id=doc.id,
        filename=original_name,
        mime_type=stored_type,
    )

    return SourceDocumentResponse(
        id=doc.id,
        knowledge_base_id=doc.knowledge_base_id,
        filename=doc.filename,
        file_storage_key=doc.file_storage_key,
        file_size_bytes=doc.file_size_bytes,
        mime_type=doc.mime_type,
        uploaded_by_name=doc.uploaded_by_name,
        parse_status=doc.parse_status,
        parse_error=doc.parse_error,
        created_at=doc.created_at,
    )
|
|
|
|
|
|
@kb_router.delete("/{kb_id}/documents/{doc_id}", status_code=204)
async def delete_source_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Remove a source document from a knowledge base and delete its stored file.

    Args:
        kb_id: Knowledge base the document must belong to.
        doc_id: Source document to remove.

    Raises:
        HTTPException: 404 if the document does not exist or does not belong
            to ``kb_id``.
    """
    repo = KnowledgeBaseRepository(db)

    # Verify ownership before removing anything, so a document cannot be
    # deleted through another knowledge base's URL. This mirrors the kb_id
    # checks in the processing-job and spec-version endpoints.
    kb_doc_ids = {d.id for d in await repo.get_source_documents(kb_id)}
    if doc_id not in kb_doc_ids:
        raise HTTPException(status_code=404, detail="Source document not found")

    doc = await repo.remove_source_document(doc_id)
    if not doc:
        raise HTTPException(status_code=404, detail="Source document not found")

    # Delete the stored file once the record has been removed.
    await storage_service.delete_file(doc.file_storage_key)
|
|
|
|
|
|
@kb_router.post("/{kb_id}/process", response_model=ProcessingJobResponse, status_code=201)
async def trigger_processing(
    kb_id: uuid.UUID,
    request: Request,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Start the document-processing pipeline for a knowledge base.

    The processing service is read from
    ``request.app.state.knowledge_base_service``. After validation, a
    processing job is created and committed, and
    ``process_documents`` is scheduled as a FastAPI background task.

    Raises:
        HTTPException:
            - 503 when no processing service is set on app state.
            - 404 if the knowledge base does not exist.
            - 409 if a job is still active after stale jobs are failed.
            - 400 if the knowledge base has no source documents.
    """
    processor = getattr(request.app.state, "knowledge_base_service", None)
    if processor is None:
        raise HTTPException(
            status_code=503,
            detail="Knowledge Base processing is not available. LLAMA_CLOUD_API_KEY is not configured.",
        )

    repo = KnowledgeBaseRepository(db)
    kb = await repo.get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Fail stale jobs first so they cannot block a new run; the staleness
    # rule itself lives in the repository.
    await repo.fail_stale_jobs(kb_id)
    if await repo.has_active_job(kb_id):
        raise HTTPException(status_code=409, detail="A processing job is already running for this knowledge base.")

    source_docs = await repo.get_source_documents(kb_id)
    if not source_docs:
        raise HTTPException(status_code=400, detail="No source documents to process.")

    user_id, user_name = await _get_user_info(db, user)

    job = await repo.create_processing_job(
        knowledge_base_id=kb_id,
        total_documents=len(source_docs),
        triggered_by_id=user_id,
        triggered_by_name=user_name,
    )
    # Commit now so the background task can see the job row.
    await db.commit()
    # NOTE(review): job and kb attributes are read after the commit; this
    # relies on the session not expiring them (expire_on_commit=False) —
    # confirm in get_db.

    background_tasks.add_task(
        processor.process_documents,
        kb_id=kb_id,
        job_id=job.id,
        agent_key=kb.agent_key,
        user_id=user_id,
        user_name=user_name,
    )

    return ProcessingJobResponse(
        id=job.id,
        knowledge_base_id=job.knowledge_base_id,
        status=job.status,
        triggered_by_name=job.triggered_by_name,
        total_documents=job.total_documents,
        parsed_documents=job.parsed_documents,
        spec_version_id=job.spec_version_id,
        error_message=job.error_message,
        started_at=job.started_at,
        completed_at=job.completed_at,
        created_at=job.created_at,
    )
|
|
|
|
|
|
@kb_router.get("/{kb_id}/jobs/{job_id}", response_model=ProcessingJobResponse)
async def get_processing_job(
    kb_id: uuid.UUID,
    job_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return the current state of one processing job.

    Raises:
        HTTPException: 404 if the job does not exist or belongs to a
            different knowledge base than ``kb_id``.
    """
    job = await KnowledgeBaseRepository(db).get_processing_job(job_id)
    # A job owned by another knowledge base is reported as not found.
    if not job or job.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Processing job not found")

    # Response fields are copied one-to-one from the job's attributes.
    response_fields = (
        "id",
        "knowledge_base_id",
        "status",
        "triggered_by_name",
        "total_documents",
        "parsed_documents",
        "spec_version_id",
        "error_message",
        "started_at",
        "completed_at",
        "created_at",
    )
    return ProcessingJobResponse(**{name: getattr(job, name) for name in response_fields})
|
|
|
|
|
|
@kb_router.get("/{kb_id}/versions", response_model=list[SpecVersionListItem])
async def list_spec_versions(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """List the spec versions of a knowledge base, without their content.

    The items come back in the order returned by the repository. The body
    does not check that ``kb_id`` exists; an unknown id simply yields
    whatever the repository returns for it.
    """

    def as_item(spec) -> SpecVersionListItem:
        # Project one spec version onto the lightweight list schema.
        return SpecVersionListItem(
            id=spec.id,
            knowledge_base_id=spec.knowledge_base_id,
            version_number=spec.version_number,
            generated_by_name=spec.generated_by_name,
            source_document_ids=spec.source_document_ids,
            is_active=spec.is_active,
            char_count=spec.char_count,
            created_at=spec.created_at,
        )

    versions = await KnowledgeBaseRepository(db).list_spec_versions(kb_id)
    return [as_item(spec) for spec in versions]
|
|
|
|
|
|
@kb_router.get("/{kb_id}/versions/{version_id}", response_model=SpecVersionDetail)
async def get_spec_version(
    kb_id: uuid.UUID,
    version_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Return one spec version of a knowledge base, including its full content.

    Raises:
        HTTPException: 404 if the version does not exist or belongs to a
            different knowledge base than ``kb_id``.
    """
    found = await KnowledgeBaseRepository(db).get_spec_version(version_id)
    if not found or found.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Spec version not found")

    # Response fields are copied one-to-one from the version's attributes.
    detail_fields = (
        "id",
        "knowledge_base_id",
        "version_number",
        "content",
        "generated_by_name",
        "source_document_ids",
        "is_active",
        "char_count",
        "created_at",
    )
    return SpecVersionDetail(**{name: getattr(found, name) for name in detail_fields})
|
|
|
|
|
|
@kb_router.get("/{kb_id}/versions/{v_a}/diff/{v_b}", response_model=DiffResponse)
async def get_spec_diff(
    kb_id: uuid.UUID,
    v_a: uuid.UUID,
    v_b: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Compute a line-level unified diff from spec version ``v_a`` to ``v_b``.

    Args:
        kb_id: Knowledge base that both versions must belong to.
        v_a: ID of the "from" spec version.
        v_b: ID of the "to" spec version.

    Returns:
        DiffResponse holding both version numbers, the number of added and
        removed lines, and one ``DiffLine`` per line of
        ``difflib.unified_diff`` output:
        - Added lines carry ``line_number_new``.
        - Removed lines carry ``line_number_old``.
        - Unchanged lines carry both.
        - The ``---``/``+++`` file headers and the ``@@`` hunk headers are
          emitted as ``context`` lines with no line numbers.

    Raises:
        HTTPException: 404 if either version is missing or belongs to a
            different knowledge base.
    """
    import re

    # Hunk header shape: "@@ -<old_start>[,<old_len>] +<new_start>[,<new_len>] @@"
    hunk_header = re.compile(r"^@@ -(\d+)(?:,\d+)? \+(\d+)")

    repo = KnowledgeBaseRepository(db)
    version_a = await repo.get_spec_version(v_a)
    version_b = await repo.get_spec_version(v_b)

    if not version_a or version_a.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Version A not found")
    if not version_b or version_b.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Version B not found")

    diff = difflib.unified_diff(
        version_a.content.splitlines(keepends=True),
        version_b.content.splitlines(keepends=True),
        fromfile=f"v{version_a.version_number}",
        tofile=f"v{version_b.version_number}",
        lineterm="",
    )

    additions = 0
    deletions = 0
    diff_lines = []
    # Each counter holds the number of the previous line on its side, so
    # incrementing before use yields the current line's number.
    old_line = 0
    new_line = 0
    # unified_diff emits the "---"/"+++" file headers once, before the first
    # hunk. They are recognized by position rather than by prefix. Otherwise
    # a removed Markdown rule ("-" + "---" == "----") or an added "++x" line
    # would be mistaken for a header and left uncounted.
    in_hunk = False

    for line in diff:
        if line.startswith("@@"):
            # Inside a hunk every content line starts with " ", "+" or "-",
            # so only real hunk headers reach this branch.
            in_hunk = True
            match = hunk_header.match(line)
            if match:
                old_line = int(match.group(1)) - 1
                new_line = int(match.group(2)) - 1
            diff_lines.append(DiffLine(type="context", content=line.rstrip("\n")))
        elif not in_hunk:
            # File headers ("--- vN" / "+++ vM").
            diff_lines.append(DiffLine(type="context", content=line.rstrip("\n")))
        elif line.startswith("+"):
            additions += 1
            new_line += 1
            diff_lines.append(DiffLine(
                type="add", content=line[1:].rstrip("\n"),
                line_number_new=new_line,
            ))
        elif line.startswith("-"):
            deletions += 1
            old_line += 1
            diff_lines.append(DiffLine(
                type="remove", content=line[1:].rstrip("\n"),
                line_number_old=old_line,
            ))
        else:
            old_line += 1
            new_line += 1
            # NOTE(review): unlike add/remove lines, context content keeps its
            # leading " " diff marker. This is the existing API behavior,
            # preserved in case clients rely on it.
            diff_lines.append(DiffLine(
                type="context", content=line.rstrip("\n"),
                line_number_old=old_line,
                line_number_new=new_line,
            ))

    return DiffResponse(
        version_a=version_a.version_number,
        version_b=version_b.version_number,
        additions=additions,
        deletions=deletions,
        lines=diff_lines,
    )
|
|
|
|
|
|
@kb_router.post("/{kb_id}/versions/{version_id}/activate", response_model=SpecVersionDetail)
async def activate_spec_version(
    kb_id: uuid.UUID,
    version_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Activate (revert to) spec version ``version_id`` of knowledge base ``kb_id``.

    After activation, the knowledge base's ``agent_key`` is passed to
    ``request.app.state.analysis_service.reference_docs.invalidate_cache``
    when that service is present on app state.

    Raises:
        HTTPException: 404 if the version does not exist or belongs to a
            different knowledge base.
    """
    repo = KnowledgeBaseRepository(db)

    # Check ownership BEFORE mutating anything. Activating first and checking
    # afterwards would let a mismatched kb_id switch the active version of
    # another knowledge base even though the request returns 404.
    target = await repo.get_spec_version(version_id)
    if not target or target.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Spec version not found")

    version = await repo.activate_spec_version(version_id)
    if not version:
        raise HTTPException(status_code=404, detail="Spec version not found")

    # Invalidate the reference-docs cache for this knowledge base's agent key
    # (presumably so the newly active spec is used on the next read — confirm).
    kb = await repo.get_knowledge_base(kb_id)
    if kb:
        analysis_service = getattr(request.app.state, "analysis_service", None)
        if analysis_service:
            analysis_service.reference_docs.invalidate_cache(kb.agent_key)

    return SpecVersionDetail(
        id=version.id,
        knowledge_base_id=version.knowledge_base_id,
        version_number=version.version_number,
        content=version.content,
        generated_by_name=version.generated_by_name,
        source_document_ids=version.source_document_ids,
        is_active=version.is_active,
        char_count=version.char_count,
        created_at=version.created_at,
    )
|