"""Client document upload and AI matching endpoints.""" import logging from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db, async_session from app.models.gmal import GmalAsset from app.models.project import Project, ClientAsset, Match, ProjectStatus, MatchConfidence from app.schemas.project import ClientAssetOut, ClientAssetUpdate, MatchOut, MatchSelectRequest, ManualMatchRequest from app.services.doc_parser import extract_text_from_file, parse_text_with_ai from app.services.ai_matching import match_client_assets router = APIRouter() logger = logging.getLogger(__name__) async def _background_parse(project_id: int, filename: str, text: str, metadata: dict): """Run AI parsing and save results in the background (own DB session).""" async with async_session() as db: try: result = await db.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if not project: return # Stage 3: AI parsing try: extracted, usage_info = parse_text_with_ai(text) except Exception as e: logger.error(f"AI parsing failed for project {project_id}: {e}") project.status = ProjectStatus.DRAFT project.parse_stage = f"AI parsing failed: {str(e)}" await db.commit() return # Save AI costs project.ai_input_tokens = (project.ai_input_tokens or 0) + usage_info.get("input_tokens", 0) project.ai_output_tokens = (project.ai_output_tokens or 0) + usage_info.get("output_tokens", 0) project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage_info.get("cost_usd", 0) project.ai_call_count = (project.ai_call_count or 0) + 1 # Stage 4: Saving results project.parse_stage = f"AI found {len(extracted)} assets. Saving..." 
await db.commit() # Clear existing client assets existing = await db.execute( select(ClientAsset).where(ClientAsset.project_id == project_id) ) for ca in existing.scalars().all(): await db.delete(ca) # Create client asset records (skip zero-quantity assets) assets = [] for idx, item in enumerate(extracted): volume = item.get("volume", 1) if volume <= 0: continue ca = ClientAsset( project_id=project_id, raw_name=item.get("name", "Unknown"), raw_description=item.get("description", ""), volume=volume, sort_order=idx + 1, ) db.add(ca) assets.append(ca) project.status = ProjectStatus.REVIEW project.parse_stage = f"Done! {len(assets)} assets extracted." await db.commit() logger.info(f"Background parse complete for project {project_id}: {len(assets)} assets") except Exception as e: logger.error(f"Background parse error for project {project_id}: {e}") @router.post("/{project_id}/upload") async def upload_client_document( project_id: int, background_tasks: BackgroundTasks, file: UploadFile = File(...), db: AsyncSession = Depends(get_db), ): """Upload a client document and extract assets using AI.""" project = await _get_project(project_id, db) # Stage 1: Read file and save to data dir import os from app.config import settings content = await file.read() save_path = os.path.join(settings.data_dir, file.filename) with open(save_path, "wb") as f: f.write(content) project.source_filename = file.filename project.status = ProjectStatus.PARSING project.parse_stage = f"Uploading {file.filename}..." await db.commit() # Stage 2: Extract text (fast, synchronous) project.parse_stage = "Extracting text from document..." 
await db.commit() try: text, metadata = extract_text_from_file(content, file.filename) except Exception as e: project.status = ProjectStatus.DRAFT project.parse_stage = None await db.commit() raise HTTPException(status_code=400, detail=f"Failed to extract text: {str(e)}") sheets_info = f" ({metadata['sheet_count']} sheets)" if metadata['sheet_count'] else "" project.parse_stage = f"Extracted {metadata['char_count']:,} characters{sheets_info}. Sending to AI..." await db.commit() # Stage 3+4: AI parsing runs in background — return 202 immediately background_tasks.add_task(_background_parse, project_id, file.filename, text, metadata) return { "message": f"Document received. AI parsing started for {file.filename}.", "status": "parsing", } @router.get("/{project_id}/brief-analysis") async def get_brief_analysis(project_id: int, db: AsyncSession = Depends(get_db)): """Get the structured brief analysis for a project.""" project = await _get_project(project_id, db) if not project.brief_analysis: return {"status": "not_analyzed"} import json try: return {"status": "analyzed", "analysis": json.loads(project.brief_analysis)} except json.JSONDecodeError: return {"status": "error", "analysis": None} @router.post("/{project_id}/analyze-brief") async def analyze_brief_endpoint(project_id: int, db: AsyncSession = Depends(get_db)): """Run AI analysis on the uploaded document.""" from app.services.rfp_analysis import analyze_brief from app.services.doc_parser import extract_text_from_file project = await _get_project(project_id, db) if not project.source_filename: raise HTTPException(status_code=400, detail="No document uploaded yet") # Read the file from data dir import os from app.config import settings filepath = os.path.join(settings.data_dir, project.source_filename) # Try to read from the stored file - if not available, re-extract from any recent upload # For now, we store the extracted text in parse_stage temporarily during upload # We need the raw text - let's store it if 
not os.path.exists(filepath): return {"status": "error", "detail": "Source file not found on disk. Re-upload the document."} with open(filepath, "rb") as f: content = f.read() text, metadata = extract_text_from_file(content, project.source_filename) analysis = await analyze_brief(db, project, text) return {"status": "analyzed", "analysis": analysis} @router.get("/{project_id}/client-assets", response_model=list[ClientAssetOut]) async def list_client_assets(project_id: int, db: AsyncSession = Depends(get_db)): result = await db.execute( select(ClientAsset) .where(ClientAsset.project_id == project_id) .order_by(ClientAsset.sort_order) ) return result.scalars().all() @router.put("/{project_id}/client-assets/{asset_id}", response_model=ClientAssetOut) async def update_client_asset( project_id: int, asset_id: int, data: ClientAssetUpdate, db: AsyncSession = Depends(get_db), ): result = await db.execute( select(ClientAsset).where(ClientAsset.id == asset_id, ClientAsset.project_id == project_id) ) ca = result.scalar_one_or_none() if not ca: raise HTTPException(status_code=404, detail="Client asset not found") if data.raw_name is not None: ca.raw_name = data.raw_name if data.raw_description is not None: ca.raw_description = data.raw_description if data.volume is not None: ca.volume = data.volume await db.commit() await db.refresh(ca) return ca async def _background_match(project_id: int, asset_snapshots: list): """Run AI matching in the background (own DB session).""" async with async_session() as db: try: result = await db.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if not project: return # Reconstruct ClientAsset-like objects from snapshots for match_client_assets ca_result = await db.execute( select(ClientAsset).where(ClientAsset.id.in_([s["id"] for s in asset_snapshots])) .order_by(ClientAsset.sort_order) ) client_assets = ca_result.scalars().all() matches = await match_client_assets(db, project_id, client_assets) 
await db.refresh(project) project.status = ProjectStatus.REVIEW await db.commit() logger.info(f"Background match complete for project {project_id}: {len(matches)} matches") except Exception as e: logger.error(f"Background match error for project {project_id}: {e}") try: async with async_session() as db2: result = await db2.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if project: project.status = ProjectStatus.REVIEW await db2.commit() except Exception: pass @router.post("/{project_id}/match") async def run_matching( project_id: int, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db), ): """Trigger AI matching for all client assets in this project.""" project = await _get_project(project_id, db) # Get client assets result = await db.execute( select(ClientAsset).where(ClientAsset.project_id == project_id).order_by(ClientAsset.sort_order) ) client_assets = result.scalars().all() if not client_assets: raise HTTPException(status_code=400, detail="No client assets to match. 
Upload a document first.") # Snapshot IDs before clearing (ORM objects expire after commit) asset_snapshots = [{"id": ca.id} for ca in client_assets] # Clear existing matches for ca in client_assets: matches_result = await db.execute(select(Match).where(Match.client_asset_id == ca.id)) for m in matches_result.scalars().all(): await db.delete(m) project.status = ProjectStatus.MATCHING await db.commit() # Run matching in background — return 202 immediately background_tasks.add_task(_background_match, project_id, asset_snapshots) return { "message": f"Matching started for {len(client_assets)} client assets.", "status": "matching", } @router.post("/{project_id}/match/cancel") async def cancel_matching_endpoint(project_id: int, db: AsyncSession = Depends(get_db)): """Cancel an in-progress matching run.""" from app.services.ai_matching import cancel_matching cancel_matching(project_id) project = await _get_project(project_id, db) project.status = ProjectStatus.REVIEW await db.commit() return {"detail": "Matching cancellation requested"} @router.post("/{project_id}/refine") async def refine_matches_endpoint( project_id: int, data: dict, db: AsyncSession = Depends(get_db), ): """Interpret a natural language instruction to refine matches.""" from app.services.match_refiner import refine_matches instruction = data.get("instruction", "") if not instruction: raise HTTPException(status_code=400, detail="No instruction provided") result = await refine_matches(db, project_id, instruction) # If there are assets to re-match, trigger matching for just those if result.get("rematch_count", 0) > 0: rematch_ids = result["rematch_asset_ids"] ca_result = await db.execute( select(ClientAsset).where(ClientAsset.id.in_(rematch_ids)).order_by(ClientAsset.sort_order) ) client_assets = ca_result.scalars().all() if client_assets: matches = await match_client_assets(db, project_id, client_assets) result["new_matches"] = len(matches) return result 
@router.post("/{project_id}/matches/{match_id}/feedback")
async def submit_match_feedback(
    project_id: int,
    match_id: int,
    data: dict,
    db: AsyncSession = Depends(get_db),
):
    """Store feedback on a match (confirm or reject) for the learning system.

    Body keys: ``confirmed`` (default True) and ``comment`` (default "").
    The client asset's name, stripped and lower-cased, is stored as the
    feedback's ``client_term``.

    Raises:
        HTTPException: 404 if the match does not exist or is not part of this project.
    """
    from app.models.feedback import MatchFeedback

    match, ca = await _get_project_match(project_id, match_id, db)

    confirmed = data.get("confirmed", True)
    comment = data.get("comment", "")

    feedback = MatchFeedback(
        client_term=(ca.raw_name or "").strip().lower(),
        client_description=ca.raw_description,
        gmal_asset_id=match.gmal_asset_id,
        confirmed=confirmed,
        user_comment=comment,
    )
    db.add(feedback)
    await db.commit()
    return {"detail": f"Feedback {'confirmed' if confirmed else 'rejected'} stored"}


@router.get("/{project_id}/matches", response_model=list[MatchOut])
async def list_matches(project_id: int, db: AsyncSession = Depends(get_db)):
    """Get all matches for a project's client assets.

    The result is a flat list, ordered by client asset ID and then by rank.
    Each match is joined with its GMAL asset for display fields.
    """
    result = await db.execute(
        select(Match, GmalAsset)
        .join(GmalAsset, Match.gmal_asset_id == GmalAsset.id)
        .join(ClientAsset, Match.client_asset_id == ClientAsset.id)
        .where(ClientAsset.project_id == project_id)
        .order_by(Match.client_asset_id, Match.rank)
    )
    return [
        MatchOut(
            id=match.id,
            client_asset_id=match.client_asset_id,
            gmal_asset_id=match.gmal_asset_id,
            gmal_id=gmal.gmal_id,
            gmal_name=gmal.asset_name,
            gmal_unique_name=gmal.unique_name,
            confidence=match.confidence.value,
            # Compare against None: a legitimate score of 0 must stay 0.0, not None.
            confidence_score=float(match.confidence_score) if match.confidence_score is not None else None,
            ai_reasoning=match.ai_reasoning,
            caveat_text=match.caveat_text,
            is_selected=match.is_selected,
            rank=match.rank,
        )
        for match, gmal in result.all()
    ]


@router.put("/{project_id}/matches/{match_id}/select")
async def select_match(
    project_id: int,
    match_id: int,
    data: MatchSelectRequest,
    db: AsyncSession = Depends(get_db),
):
    """Select or deselect a match.

    Selecting a match deselects every other match for the same client asset,
    so at most one is selected per asset.

    Raises:
        HTTPException: 404 if the match does not exist or is not part of this project.
    """
    match, _client_asset = await _get_project_match(project_id, match_id, db)

    if data.is_selected:
        siblings = await db.execute(
            select(Match).where(Match.client_asset_id == match.client_asset_id)
        )
        for sibling in siblings.scalars().all():
            sibling.is_selected = False

    match.is_selected = data.is_selected
    await db.commit()
    return {"detail": "Match updated"}


@router.post("/{project_id}/matches/{client_asset_id}/manual")
async def manual_match(
    project_id: int,
    client_asset_id: int,
    data: ManualMatchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Manually assign a GMAL asset to a client asset.

    Deselects the asset's existing matches. Then adds a new selected match
    with EXACT confidence, a score of 1.0 and rank 0.

    Raises:
        HTTPException: 404 if the client asset is not in this project or the
            GMAL asset does not exist.
    """
    # Verify client asset belongs to project
    ca_result = await db.execute(
        select(ClientAsset).where(ClientAsset.id == client_asset_id, ClientAsset.project_id == project_id)
    )
    ca = ca_result.scalar_one_or_none()
    if not ca:
        raise HTTPException(status_code=404, detail="Client asset not found")

    # Verify GMAL asset exists
    gmal_result = await db.execute(select(GmalAsset).where(GmalAsset.id == data.gmal_asset_id))
    gmal = gmal_result.scalar_one_or_none()
    if not gmal:
        raise HTTPException(status_code=404, detail="GMAL asset not found")

    # Deselect existing matches
    existing = await db.execute(select(Match).where(Match.client_asset_id == client_asset_id))
    for m in existing.scalars().all():
        m.is_selected = False

    # Create manual match
    match = Match(
        client_asset_id=client_asset_id,
        gmal_asset_id=data.gmal_asset_id,
        confidence=MatchConfidence.EXACT,
        confidence_score=1.0,
        ai_reasoning="Manually assigned by user",
        caveat_text="",
        is_selected=True,
        rank=0,
    )
    db.add(match)
    await db.commit()
    return {"detail": f"Manually matched to {gmal.gmal_id}"}


async def _get_project_match(
    project_id: int, match_id: int, db: AsyncSession
) -> "tuple[Match, ClientAsset]":
    """Load a match and its client asset, requiring the asset to belong to ``project_id``.

    Raises:
        HTTPException: 404 if the match does not exist or belongs to another project.
    """
    result = await db.execute(
        select(Match, ClientAsset)
        .join(ClientAsset, Match.client_asset_id == ClientAsset.id)
        .where(Match.id == match_id, ClientAsset.project_id == project_id)
    )
    row = result.first()
    if row is None:
        raise HTTPException(status_code=404, detail="Match not found")
    match, client_asset = row
    return match, client_asset


async def _get_project(project_id: int, db: AsyncSession) -> Project:
    """Return the project with ``project_id``.

    Raises:
        HTTPException: 404 if it does not exist.
    """
    result = await db.execute(select(Project).where(Project.id == project_id))
    project = result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project