"""REST API routes for Knowledge Base management."""

import difflib
import re
import uuid

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.knowledge_base_schemas import (
    KnowledgeBaseListItem,
    KnowledgeBaseDetail,
    SourceDocumentResponse,
    ProcessingJobResponse,
    SpecVersionListItem,
    SpecVersionDetail,
    DiffResponse,
    DiffLine,
)
from app.dependencies.auth import get_current_user
from app.models.database import get_db
from app.repositories.knowledge_base_repository import KnowledgeBaseRepository
from app.services.storage_service import storage_service

kb_router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])

# Allowed MIME types for source document upload
ALLOWED_MIME_TYPES = {
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # docx
    "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # pptx
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",  # xlsx
    "text/html",
    "text/plain",
    "text/markdown",
    "image/png",
    "image/jpeg",
    "image/webp",
}

# Matches a unified-diff hunk header such as "@@ -12,7 +12,9 @@" and captures
# the starting line number on the old and new side.
_HUNK_HEADER_RE = re.compile(r"@@ -(\d+)(?:,\d+)? \+(\d+)(?:,\d+)? @@")


async def _get_user_info(session: AsyncSession, user_claims: dict) -> tuple:
    """Extract user ID and name from claims.

    Args:
        session: Database session handed to the user repository.
        user_claims: Token claims of the authenticated caller.

    Returns:
        A ``(user_id, user_name)`` tuple. ``user_id`` is ``None`` when the
        claims carry neither an ``oid`` nor a ``sub``; ``user_name`` falls
        back to ``"Unknown"`` when the ``name`` claim is absent.
    """
    # Imported at call time rather than at module level, as in the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from app.repositories.user_repository import UserRepository

    user_repo = UserRepository(session)
    azure_oid = user_claims.get("oid") or user_claims.get("sub")

    user_id = None
    user_name = user_claims.get("name", "Unknown")

    if azure_oid:
        user = await user_repo.get_or_create_from_azure(
            azure_ad_oid=azure_oid,
            email=user_claims.get("email", user_claims.get("preferred_username", "")),
            name=user_name,
        )
        user_id = user.id

    return user_id, user_name


def _document_response(doc) -> SourceDocumentResponse:
    """Map a source-document record onto its API response schema."""
    return SourceDocumentResponse(
        id=doc.id,
        knowledge_base_id=doc.knowledge_base_id,
        filename=doc.filename,
        file_storage_key=doc.file_storage_key,
        file_size_bytes=doc.file_size_bytes,
        mime_type=doc.mime_type,
        uploaded_by_name=doc.uploaded_by_name,
        parse_status=doc.parse_status,
        parse_error=doc.parse_error,
        created_at=doc.created_at,
    )


def _job_response(job) -> ProcessingJobResponse:
    """Map a processing-job record onto its API response schema."""
    return ProcessingJobResponse(
        id=job.id,
        knowledge_base_id=job.knowledge_base_id,
        status=job.status,
        triggered_by_name=job.triggered_by_name,
        total_documents=job.total_documents,
        parsed_documents=job.parsed_documents,
        spec_version_id=job.spec_version_id,
        error_message=job.error_message,
        started_at=job.started_at,
        completed_at=job.completed_at,
        created_at=job.created_at,
    )


def _spec_detail_response(version) -> SpecVersionDetail:
    """Map a spec-version record (including content) onto its API schema."""
    return SpecVersionDetail(
        id=version.id,
        knowledge_base_id=version.knowledge_base_id,
        version_number=version.version_number,
        content=version.content,
        generated_by_name=version.generated_by_name,
        source_document_ids=version.source_document_ids,
        is_active=version.is_active,
        char_count=version.char_count,
        created_at=version.created_at,
    )


def _build_diff_lines(diff) -> tuple:
    """Convert unified-diff output into ``DiffLine`` entries with counts.

    Args:
        diff: Iterable of lines as produced by ``difflib.unified_diff``.

    Returns:
        ``(diff_lines, additions, deletions)`` where ``diff_lines`` is the
        list of ``DiffLine`` objects and the two ints count added and
        removed lines.
    """
    additions = 0
    deletions = 0
    diff_lines = []
    old_line = 0
    new_line = 0
    # The "---"/"+++" file headers only occur before the first hunk. Inside a
    # hunk, a line starting with "---" or "+++" is a removed/added line whose
    # own text begins with "--"/"++" (e.g. a Markdown horizontal rule), so it
    # must be classified by its first character only.
    in_hunk = False

    for line in diff:
        if line.startswith("@@"):
            in_hunk = True
            header = _HUNK_HEADER_RE.match(line)
            if header:
                # Counters are pre-incremented below, so start one short.
                old_line = int(header.group(1)) - 1
                new_line = int(header.group(2)) - 1
            diff_lines.append(DiffLine(type="context", content=line.rstrip("\n")))
        elif not in_hunk and line.startswith(("---", "+++")):
            diff_lines.append(DiffLine(type="context", content=line.rstrip("\n")))
        elif line.startswith("+"):
            additions += 1
            new_line += 1
            diff_lines.append(DiffLine(
                type="add",
                content=line[1:].rstrip("\n"),
                line_number_new=new_line,
            ))
        elif line.startswith("-"):
            deletions += 1
            old_line += 1
            diff_lines.append(DiffLine(
                type="remove",
                content=line[1:].rstrip("\n"),
                line_number_old=old_line,
            ))
        else:
            # Unchanged line: present on both sides. The leading marker
            # character is kept in the content, as before.
            old_line += 1
            new_line += 1
            diff_lines.append(DiffLine(
                type="context",
                content=line.rstrip("\n"),
                line_number_old=old_line,
                line_number_new=new_line,
            ))

    return diff_lines, additions, deletions


@kb_router.get("", response_model=list[KnowledgeBaseListItem])
async def list_knowledge_bases(
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """List all knowledge bases with summary info."""
    repo = KnowledgeBaseRepository(db)
    kbs = await repo.list_knowledge_bases()

    results = []
    for kb in kbs:
        active_spec = next((sv for sv in kb.spec_versions if sv.is_active), None)
        # NOTE(review): treats the first job as the latest -- assumes the
        # relationship is ordered newest-first; confirm in the model.
        latest_job = kb.processing_jobs[0] if kb.processing_jobs else None

        results.append(KnowledgeBaseListItem(
            id=kb.id,
            agent_key=kb.agent_key,
            display_name=kb.display_name,
            description=kb.description,
            source_document_count=len(kb.source_documents),
            active_spec_version=active_spec.version_number if active_spec else None,
            active_spec_char_count=active_spec.char_count if active_spec else None,
            latest_job_status=latest_job.status if latest_job else None,
            latest_job_completed_at=latest_job.completed_at if latest_job else None,
            created_at=kb.created_at,
        ))

    return results


@kb_router.get("/{kb_id}", response_model=KnowledgeBaseDetail)
async def get_knowledge_base(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Get full detail for a knowledge base.

    Raises:
        HTTPException: 404 when the knowledge base does not exist.
    """
    repo = KnowledgeBaseRepository(db)
    kb = await repo.get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    active_spec = next((sv for sv in kb.spec_versions if sv.is_active), None)
    # NOTE(review): treats the first job as the latest -- assumes the
    # relationship is ordered newest-first; confirm in the model.
    latest_job = kb.processing_jobs[0] if kb.processing_jobs else None

    return KnowledgeBaseDetail(
        id=kb.id,
        agent_key=kb.agent_key,
        display_name=kb.display_name,
        description=kb.description,
        # Newest uploads first.
        source_documents=[
            _document_response(doc)
            for doc in sorted(kb.source_documents, key=lambda d: d.created_at, reverse=True)
        ],
        active_spec_version=active_spec.version_number if active_spec else None,
        active_spec_char_count=active_spec.char_count if active_spec else None,
        latest_job=_job_response(latest_job) if latest_job else None,
        created_at=kb.created_at,
    )


@kb_router.post("/{kb_id}/documents", response_model=SourceDocumentResponse, status_code=201)
async def upload_source_document(
    kb_id: uuid.UUID,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Upload a source document to a knowledge base.

    Raises:
        HTTPException: 404 when the knowledge base does not exist; 400 when
            the upload declares a content type outside ``ALLOWED_MIME_TYPES``.
    """
    repo = KnowledgeBaseRepository(db)
    kb = await repo.get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Validate file type. Uploads that declare no content type are let
    # through and recorded as application/octet-stream below.
    if file.content_type and file.content_type not in ALLOWED_MIME_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"File type '{file.content_type}' is not supported. Supported types: PDF, DOCX, PPTX, XLSX, HTML, TXT, MD, PNG, JPG, WebP."
        )

    user_id, user_name = await _get_user_info(db, user)

    # Read file data (the whole upload is held in memory)
    file_data = await file.read()
    filename = file.filename or "unknown"
    mime_type = file.content_type or "application/octet-stream"

    # Create DB record first to get the ID
    doc = await repo.add_source_document(
        knowledge_base_id=kb_id,
        filename=filename,
        file_storage_key="pending",  # Will update after storing
        file_size_bytes=len(file_data),
        mime_type=mime_type,
        uploaded_by_id=user_id,
        uploaded_by_name=user_name,
    )

    # Store file using the DB-generated doc ID
    storage_key = await storage_service.store_kb_document(
        file_data=file_data,
        kb_id=kb_id,
        doc_id=doc.id,
        filename=filename,
        mime_type=mime_type,
    )

    # Update the storage key. No explicit commit here.
    # NOTE(review): presumably get_db commits at the end of the request -- confirm.
    doc.file_storage_key = storage_key

    return _document_response(doc)


@kb_router.delete("/{kb_id}/documents/{doc_id}", status_code=204)
async def delete_source_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Remove a source document from a knowledge base.

    Raises:
        HTTPException: 404 when the document does not exist or does not
            belong to the knowledge base named in the URL.
    """
    repo = KnowledgeBaseRepository(db)

    # Verify ownership BEFORE removing anything, so a document cannot be
    # deleted through the URL of a knowledge base it does not belong to.
    kb = await repo.get_knowledge_base(kb_id)
    if not kb or all(existing.id != doc_id for existing in kb.source_documents):
        raise HTTPException(status_code=404, detail="Source document not found")

    doc = await repo.remove_source_document(doc_id)
    if not doc:
        raise HTTPException(status_code=404, detail="Source document not found")

    # Delete file from storage
    await storage_service.delete_file(doc.file_storage_key)


@kb_router.post("/{kb_id}/process", response_model=ProcessingJobResponse, status_code=201)
async def trigger_processing(
    kb_id: uuid.UUID,
    request: Request,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Trigger the document processing pipeline for a knowledge base.

    Raises:
        HTTPException: 503 when no knowledge-base service is registered on
            the app state; 404 when the knowledge base does not exist; 409
            when a job is already active; 400 when there are no source
            documents.
    """
    kb_service = getattr(request.app.state, "knowledge_base_service", None)
    if kb_service is None:
        raise HTTPException(
            status_code=503,
            detail="Knowledge Base processing is not available. LLAMA_CLOUD_API_KEY is not configured."
        )

    repo = KnowledgeBaseRepository(db)
    kb = await repo.get_knowledge_base(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    # Auto-fail stale pending jobs (older than 5 minutes) before checking
    await repo.fail_stale_jobs(kb_id)

    # Check for active jobs
    has_active = await repo.has_active_job(kb_id)
    if has_active:
        raise HTTPException(status_code=409, detail="A processing job is already running for this knowledge base.")

    # Check that there are source documents
    docs = await repo.get_source_documents(kb_id)
    if not docs:
        raise HTTPException(status_code=400, detail="No source documents to process.")

    user_id, user_name = await _get_user_info(db, user)

    # Create the job and commit immediately so the background task can see it
    job = await repo.create_processing_job(
        knowledge_base_id=kb_id,
        total_documents=len(docs),
        triggered_by_id=user_id,
        triggered_by_name=user_name,
    )
    await db.commit()

    # Start background processing
    background_tasks.add_task(
        kb_service.process_documents,
        kb_id=kb_id,
        job_id=job.id,
        agent_key=kb.agent_key,
        user_id=user_id,
        user_name=user_name,
    )

    return _job_response(job)


@kb_router.get("/{kb_id}/jobs/{job_id}", response_model=ProcessingJobResponse)
async def get_processing_job(
    kb_id: uuid.UUID,
    job_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Get the status of a processing job.

    Raises:
        HTTPException: 404 when the job does not exist or belongs to a
            different knowledge base.
    """
    repo = KnowledgeBaseRepository(db)
    job = await repo.get_processing_job(job_id)
    if not job or job.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Processing job not found")

    return _job_response(job)


@kb_router.get("/{kb_id}/versions", response_model=list[SpecVersionListItem])
async def list_spec_versions(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """List all spec versions for a knowledge base (without content)."""
    repo = KnowledgeBaseRepository(db)
    versions = await repo.list_spec_versions(kb_id)
    return [
        SpecVersionListItem(
            id=v.id,
            knowledge_base_id=v.knowledge_base_id,
            version_number=v.version_number,
            generated_by_name=v.generated_by_name,
            source_document_ids=v.source_document_ids,
            is_active=v.is_active,
            char_count=v.char_count,
            created_at=v.created_at,
        )
        for v in versions
    ]


@kb_router.get("/{kb_id}/versions/{version_id}", response_model=SpecVersionDetail)
async def get_spec_version(
    kb_id: uuid.UUID,
    version_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Get full spec content for a version.

    Raises:
        HTTPException: 404 when the version does not exist or belongs to a
            different knowledge base.
    """
    repo = KnowledgeBaseRepository(db)
    version = await repo.get_spec_version(version_id)
    if not version or version.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Spec version not found")

    return _spec_detail_response(version)


@kb_router.get("/{kb_id}/versions/{v_a}/diff/{v_b}", response_model=DiffResponse)
async def get_spec_diff(
    kb_id: uuid.UUID,
    v_a: uuid.UUID,
    v_b: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Compute diff between two spec versions.

    Raises:
        HTTPException: 404 when either version does not exist or belongs to
            a different knowledge base.
    """
    repo = KnowledgeBaseRepository(db)
    version_a = await repo.get_spec_version(v_a)
    version_b = await repo.get_spec_version(v_b)

    if not version_a or version_a.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Version A not found")
    if not version_b or version_b.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Version B not found")

    lines_a = version_a.content.splitlines(keepends=True)
    lines_b = version_b.content.splitlines(keepends=True)

    diff = difflib.unified_diff(
        lines_a,
        lines_b,
        fromfile=f"v{version_a.version_number}",
        tofile=f"v{version_b.version_number}",
        lineterm="",
    )
    diff_lines, additions, deletions = _build_diff_lines(diff)

    return DiffResponse(
        version_a=version_a.version_number,
        version_b=version_b.version_number,
        additions=additions,
        deletions=deletions,
        lines=diff_lines,
    )


@kb_router.post("/{kb_id}/versions/{version_id}/activate", response_model=SpecVersionDetail)
async def activate_spec_version(
    kb_id: uuid.UUID,
    version_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
    user: dict = Depends(get_current_user),
):
    """Activate (revert to) a specific spec version.

    Raises:
        HTTPException: 404 when the version does not exist or belongs to a
            different knowledge base.
    """
    repo = KnowledgeBaseRepository(db)

    # Check ownership BEFORE activating: otherwise a request carrying the
    # wrong kb_id would flip the active version and only then be rejected.
    candidate = await repo.get_spec_version(version_id)
    if not candidate or candidate.knowledge_base_id != kb_id:
        raise HTTPException(status_code=404, detail="Spec version not found")

    version = await repo.activate_spec_version(version_id)
    if not version:
        raise HTTPException(status_code=404, detail="Spec version not found")

    # Invalidate reference docs cache
    kb = await repo.get_knowledge_base(kb_id)
    if kb:
        analysis_service = getattr(request.app.state, "analysis_service", None)
        if analysis_service:
            analysis_service.reference_docs.invalidate_cache(kb.agent_key)

    return _spec_detail_response(version)