"""Client document upload and AI matching endpoints.""" import logging from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db, async_session from app.models.gmal import GmalAsset from app.models.project import Project, ClientAsset, Match, ProjectStatus, MatchConfidence from app.schemas.project import ClientAssetOut, ClientAssetUpdate, MatchOut, MatchSelectRequest, ManualMatchRequest from app.services.doc_parser import extract_text_from_file, parse_text_with_ai, deep_pass1_structure_analysis, deep_pass2_guided_extraction, SYSTEM_PROMPT, EXTRACT_TOOLS from app.services.ai_matching import match_client_assets router = APIRouter() logger = logging.getLogger(__name__) async def _background_parse(project_id: int, filename: str, text: str, metadata: dict, mode: str = "normal", user_context: str = ""): """Run AI parsing and save results in the background (own DB session).""" async with async_session() as db: try: result = await db.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if not project: return # Stage 3: AI parsing (normal or deep) try: if mode == "deep": # Pass 1: Structure analysis import time start = time.time() project.parse_stage = "Deep extraction Pass 1/2: Analyzing spreadsheet structure... (this takes 20-40 seconds)" await db.commit() structure_analysis, usage1 = deep_pass1_structure_analysis(text, user_context) elapsed1 = int(time.time() - start) project.ai_input_tokens = (project.ai_input_tokens or 0) + usage1.get("input_tokens", 0) project.ai_output_tokens = (project.ai_output_tokens or 0) + usage1.get("output_tokens", 0) project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage1.get("cost_usd", 0) project.ai_call_count = (project.ai_call_count or 0) + 1 project.parse_stage = f"Pass 1 complete ({elapsed1}s). Pass 2/2: Extracting assets using structure analysis..." await db.commit() # Pass 2: Guided extraction extracted, usage2 = deep_pass2_guided_extraction(text, structure_analysis, user_context) elapsed2 = int(time.time() - start) usage_info = { "input_tokens": usage1.get("input_tokens", 0) + usage2.get("input_tokens", 0), "output_tokens": usage1.get("output_tokens", 0) + usage2.get("output_tokens", 0), "cost_usd": usage1.get("cost_usd", 0) + usage2.get("cost_usd", 0), } project.parse_stage = f"Deep extraction complete ({elapsed2}s total). Found {len(extracted)} assets." await db.commit() else: extracted, usage_info = parse_text_with_ai(text, user_context) except Exception as e: logger.error(f"AI parsing failed for project {project_id}: {e}") project.status = ProjectStatus.DRAFT project.parse_stage = f"AI parsing failed: {str(e)}" await db.commit() return # Save AI costs project.ai_input_tokens = (project.ai_input_tokens or 0) + usage_info.get("input_tokens", 0) project.ai_output_tokens = (project.ai_output_tokens or 0) + usage_info.get("output_tokens", 0) project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage_info.get("cost_usd", 0) project.ai_call_count = (project.ai_call_count or 0) + 1 # Stage 4: Saving results project.parse_stage = f"AI found {len(extracted)} assets. Saving..." await db.commit() # Clear existing client assets existing = await db.execute( select(ClientAsset).where(ClientAsset.project_id == project_id) ) for ca in existing.scalars().all(): await db.delete(ca) # Create client asset records (skip zero-quantity assets) assets = [] for idx, item in enumerate(extracted): volume = item.get("volume", 1) if volume <= 0: continue ca = ClientAsset( project_id=project_id, raw_name=item.get("name", "Unknown"), raw_description=item.get("description", ""), client_tier=item.get("tier", "") or None, volume=volume, sort_order=idx + 1, ) db.add(ca) assets.append(ca) project.status = ProjectStatus.REVIEW project.parse_stage = f"Done! {len(assets)} assets extracted." await db.commit() logger.info(f"Background parse complete for project {project_id}: {len(assets)} assets") except Exception as e: logger.error(f"Background parse error for project {project_id}: {e}") @router.post("/{project_id}/upload") async def upload_client_document( project_id: int, background_tasks: BackgroundTasks, files: list[UploadFile] = File(...), mode: str = "normal", user_context: str = "", db: AsyncSession = Depends(get_db), ): """Upload one or more client documents and extract assets using AI.""" project = await _get_project(project_id, db) import os from app.config import settings filenames = [] all_text_parts = [] total_chars = 0 total_sheets = 0 project.status = ProjectStatus.PARSING project.parse_stage = f"Uploading {len(files)} file(s)..." await db.commit() # Stage 1+2: Read and extract text from each file for file in files: content = await file.read() save_path = os.path.join(settings.data_dir, file.filename) with open(save_path, "wb") as f: f.write(content) filenames.append(file.filename) project.parse_stage = f"Extracting text from {file.filename}..." await db.commit() try: text, metadata = extract_text_from_file(content, file.filename) all_text_parts.append(f"\n{'='*60}\nFILE: {file.filename}\n{'='*60}\n{text}") total_chars += metadata["char_count"] total_sheets += metadata.get("sheet_count", 0) except Exception as e: logger.warning(f"Failed to extract text from {file.filename}: {e}") continue if not all_text_parts: project.status = ProjectStatus.DRAFT project.parse_stage = None await db.commit() raise HTTPException(status_code=400, detail="Failed to extract text from any uploaded file.") combined_text = "\n".join(all_text_parts) project.source_filename = ", ".join(filenames) sheets_info = f" ({total_sheets} sheets)" if total_sheets else "" project.parse_stage = f"Extracted {total_chars:,} characters from {len(filenames)} file(s){sheets_info}. Sending to AI..." await db.commit() # Stage 3+4: AI parsing runs in background — return 202 immediately background_tasks.add_task(_background_parse, project_id, ", ".join(filenames), combined_text, {"char_count": total_chars, "sheet_count": total_sheets}, mode, user_context) return { "message": f"{len(filenames)} file(s) received. AI parsing started.", "status": "parsing", } @router.get("/{project_id}/tier-mapping") async def get_tier_mapping(project_id: int, db: AsyncSession = Depends(get_db)): """Get the tier mapping for a project.""" project = await _get_project(project_id, db) import json if project.tier_mapping: return json.loads(project.tier_mapping) return {"tiers": []} @router.put("/{project_id}/tier-mapping") async def set_tier_mapping(project_id: int, data: dict, db: AsyncSession = Depends(get_db)): """Set the tier mapping for a project.""" import json project = await _get_project(project_id, db) project.tier_mapping = json.dumps(data) await db.commit() return data @router.post("/{project_id}/expand-tiers") async def expand_tiers_endpoint(project_id: int, db: AsyncSession = Depends(get_db)): """Expand matched assets into complexity tier variants.""" from app.services.tier_expander import expand_to_tiers project = await _get_project(project_id, db) result = await expand_to_tiers(db, project) return result @router.get("/{project_id}/brief-analysis") async def get_brief_analysis(project_id: int, db: AsyncSession = Depends(get_db)): """Get the structured brief analysis for a project.""" project = await _get_project(project_id, db) if not project.brief_analysis: return {"status": "not_analyzed"} import json try: return {"status": "analyzed", "analysis": json.loads(project.brief_analysis)} except json.JSONDecodeError: return {"status": "error", "analysis": None} @router.post("/{project_id}/analyze-brief") async def analyze_brief_endpoint(project_id: int, data: dict | None = None, db: AsyncSession = Depends(get_db)): """Run AI analysis on uploaded document or pasted text.""" from app.services.rfp_analysis import analyze_brief from app.services.doc_parser import extract_text_from_file project = await _get_project(project_id, db) body = data or {} # Option 1: Pasted text if body.get("text"): text = body["text"] # Option 2: Read from uploaded file elif project.source_filename: import os from app.config import settings filepath = os.path.join(settings.data_dir, project.source_filename) if not os.path.exists(filepath): raise HTTPException(status_code=400, detail="Source file not found on disk. Re-upload or paste the brief text.") with open(filepath, "rb") as f: content = f.read() text, _ = extract_text_from_file(content, project.source_filename) else: raise HTTPException(status_code=400, detail="No document uploaded and no text provided. Upload a file or paste the brief.") if len(text.strip()) < 20: raise HTTPException(status_code=400, detail="Brief text is too short to analyze.") analysis = await analyze_brief(db, project, text) return {"status": "analyzed", "analysis": analysis} @router.get("/{project_id}/client-assets", response_model=list[ClientAssetOut]) async def list_client_assets(project_id: int, db: AsyncSession = Depends(get_db)): result = await db.execute( select(ClientAsset) .where(ClientAsset.project_id == project_id) .order_by(ClientAsset.sort_order) ) return result.scalars().all() @router.put("/{project_id}/client-assets/{asset_id}", response_model=ClientAssetOut) async def update_client_asset( project_id: int, asset_id: int, data: ClientAssetUpdate, db: AsyncSession = Depends(get_db), ): result = await db.execute( select(ClientAsset).where(ClientAsset.id == asset_id, ClientAsset.project_id == project_id) ) ca = result.scalar_one_or_none() if not ca: raise HTTPException(status_code=404, detail="Client asset not found") if data.raw_name is not None: ca.raw_name = data.raw_name if data.raw_description is not None: ca.raw_description = data.raw_description if data.volume is not None: ca.volume = data.volume await db.commit() await db.refresh(ca) return ca async def _background_match(project_id: int, asset_snapshots: list): """Run AI matching in the background (own DB session).""" async with async_session() as db: try: result = await db.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if not project: return # Reconstruct ClientAsset-like objects from snapshots for match_client_assets ca_result = await db.execute( select(ClientAsset).where(ClientAsset.id.in_([s["id"] for s in asset_snapshots])) .order_by(ClientAsset.sort_order) ) client_assets = ca_result.scalars().all() matches = await match_client_assets(db, project_id, client_assets) await db.refresh(project) project.status = ProjectStatus.REVIEW await db.commit() logger.info(f"Background match complete for project {project_id}: {len(matches)} matches") except Exception as e: logger.error(f"Background match error for project {project_id}: {e}") try: async with async_session() as db2: result = await db2.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if project: project.status = ProjectStatus.REVIEW await db2.commit() except Exception: pass @router.post("/{project_id}/match") async def run_matching( project_id: int, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db), ): """Trigger AI matching for all client assets in this project.""" project = await _get_project(project_id, db) # Get client assets result = await db.execute( select(ClientAsset).where(ClientAsset.project_id == project_id).order_by(ClientAsset.sort_order) ) client_assets = result.scalars().all() if not client_assets: raise HTTPException(status_code=400, detail="No client assets to match. Upload a document first.") # Snapshot IDs before clearing (ORM objects expire after commit) asset_snapshots = [{"id": ca.id} for ca in client_assets] # Clear existing matches for ca in client_assets: matches_result = await db.execute(select(Match).where(Match.client_asset_id == ca.id)) for m in matches_result.scalars().all(): await db.delete(m) project.status = ProjectStatus.MATCHING await db.commit() # Run matching in background — return 202 immediately background_tasks.add_task(_background_match, project_id, asset_snapshots) return { "message": f"Matching started for {len(client_assets)} client assets.", "status": "matching", } @router.post("/{project_id}/match/cancel") async def cancel_matching_endpoint(project_id: int, db: AsyncSession = Depends(get_db)): """Cancel an in-progress matching run.""" from app.services.ai_matching import cancel_matching cancel_matching(project_id) project = await _get_project(project_id, db) project.status = ProjectStatus.REVIEW await db.commit() return {"detail": "Matching cancellation requested"} @router.post("/{project_id}/refine") async def refine_matches_endpoint( project_id: int, data: dict, db: AsyncSession = Depends(get_db), ): """Interpret a natural language instruction to refine matches.""" from app.services.match_refiner import refine_matches instruction = data.get("instruction", "") if not instruction: raise HTTPException(status_code=400, detail="No instruction provided") result = await refine_matches(db, project_id, instruction) # If there are assets to re-match, trigger matching for just those if result.get("rematch_count", 0) > 0: rematch_ids = result["rematch_asset_ids"] ca_result = await db.execute( select(ClientAsset).where(ClientAsset.id.in_(rematch_ids)).order_by(ClientAsset.sort_order) ) client_assets = ca_result.scalars().all() if client_assets: matches = await match_client_assets(db, project_id, client_assets) result["new_matches"] = len(matches) return result @router.post("/{project_id}/matches/{match_id}/feedback") async def submit_match_feedback( project_id: int, match_id: int, data: dict, db: AsyncSession = Depends(get_db), ): """Store feedback on a match (confirm or reject) for the learning system.""" from app.models.feedback import MatchFeedback result = await db.execute(select(Match).where(Match.id == match_id)) match = result.scalar_one_or_none() if not match: raise HTTPException(status_code=404, detail="Match not found") # Get the client asset name for the feedback record ca_result = await db.execute(select(ClientAsset).where(ClientAsset.id == match.client_asset_id)) ca = ca_result.scalar_one_or_none() confirmed = data.get("confirmed", True) comment = data.get("comment", "") feedback = MatchFeedback( client_term=(ca.raw_name or "").strip().lower() if ca else "", client_description=ca.raw_description if ca else None, gmal_asset_id=match.gmal_asset_id, confirmed=confirmed, user_comment=comment, ) db.add(feedback) await db.commit() return {"detail": f"Feedback {'confirmed' if confirmed else 'rejected'} stored"} @router.get("/{project_id}/matches", response_model=list[MatchOut]) async def list_matches(project_id: int, db: AsyncSession = Depends(get_db)): """Get all matches for a project, grouped by client asset.""" # Get client asset IDs for this project ca_result = await db.execute( select(ClientAsset.id).where(ClientAsset.project_id == project_id) ) ca_ids = [r[0] for r in ca_result.all()] if not ca_ids: return [] result = await db.execute( select(Match, GmalAsset) .join(GmalAsset, Match.gmal_asset_id == GmalAsset.id) .where(Match.client_asset_id.in_(ca_ids)) .order_by(Match.client_asset_id, Match.rank) ) matches = [] for match, gmal in result.all(): matches.append(MatchOut( id=match.id, client_asset_id=match.client_asset_id, gmal_asset_id=match.gmal_asset_id, gmal_id=gmal.gmal_id, gmal_name=gmal.asset_name, gmal_unique_name=gmal.unique_name, confidence=match.confidence.value, confidence_score=float(match.confidence_score) if match.confidence_score else None, ai_reasoning=match.ai_reasoning, caveat_text=match.caveat_text, is_selected=match.is_selected, rank=match.rank, )) return matches @router.put("/{project_id}/matches/{match_id}/select") async def select_match( project_id: int, match_id: int, data: MatchSelectRequest, db: AsyncSession = Depends(get_db), ): """Select or deselect a match. Deselects other matches for the same client asset.""" result = await db.execute(select(Match).where(Match.id == match_id)) match = result.scalar_one_or_none() if not match: raise HTTPException(status_code=404, detail="Match not found") if data.is_selected: # Deselect all other matches for this client asset siblings = await db.execute( select(Match).where(Match.client_asset_id == match.client_asset_id) ) for sibling in siblings.scalars().all(): sibling.is_selected = False match.is_selected = data.is_selected await db.commit() return {"detail": "Match updated"} @router.post("/{project_id}/matches/{client_asset_id}/manual") async def manual_match( project_id: int, client_asset_id: int, data: ManualMatchRequest, db: AsyncSession = Depends(get_db), ): """Manually assign a GMAL asset to a client asset.""" # Verify client asset belongs to project ca_result = await db.execute( select(ClientAsset).where(ClientAsset.id == client_asset_id, ClientAsset.project_id == project_id) ) ca = ca_result.scalar_one_or_none() if not ca: raise HTTPException(status_code=404, detail="Client asset not found") # Verify GMAL asset exists gmal_result = await db.execute(select(GmalAsset).where(GmalAsset.id == data.gmal_asset_id)) gmal = gmal_result.scalar_one_or_none() if not gmal: raise HTTPException(status_code=404, detail="GMAL asset not found") # Deselect existing matches existing = await db.execute(select(Match).where(Match.client_asset_id == client_asset_id)) for m in existing.scalars().all(): m.is_selected = False # Create manual match match = Match( client_asset_id=client_asset_id, gmal_asset_id=data.gmal_asset_id, confidence=MatchConfidence.EXACT, confidence_score=1.0, ai_reasoning="Manually assigned by user", caveat_text="", is_selected=True, rank=0, ) db.add(match) await db.commit() return {"detail": f"Manually matched to {gmal.gmal_id}"} async def _get_project(project_id: int, db: AsyncSession) -> Project: result = await db.execute(select(Project).where(Project.id == project_id)) project = result.scalar_one_or_none() if not project: raise HTTPException(status_code=404, detail="Project not found") return project