gmal-scope-builder/backend/app/api/matching.py
DJP 68d342575e Fix tier extraction: separate entry per tier + user context box
Tier fix (reverses previous "extract once" mistake):
- SEPARATE entry for EACH tier where volume > 0
- "KV 360" Tier A=No/0, Tier B=Yes/1, Tier C=Yes/1 → TWO entries
- Tier field matches column header exactly ("Tier B", "Tier C")
- Tiers with volume=0 or status=No are skipped
- Applied to both normal and deep extraction prompts

User context box (new Step 3 on Upload tab):
- Textarea where users give hints before extraction runs
- Examples: "Focus on Toolbox sheet", "Tier columns are D/F/H"
- Context prepended to Claude prompt in both normal and deep modes
- Passed through upload endpoint → background parse → AI calls

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 09:24:07 -04:00

551 lines
21 KiB
Python

"""Client document upload and AI matching endpoints."""
import logging
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db, async_session
from app.models.gmal import GmalAsset
from app.models.project import Project, ClientAsset, Match, ProjectStatus, MatchConfidence
from app.schemas.project import ClientAssetOut, ClientAssetUpdate, MatchOut, MatchSelectRequest, ManualMatchRequest
from app.services.doc_parser import extract_text_from_file, parse_text_with_ai, deep_pass1_structure_analysis, deep_pass2_guided_extraction, SYSTEM_PROMPT, EXTRACT_TOOLS
from app.services.ai_matching import match_client_assets
# Router that the endpoint decorators in this module register their routes on.
router = APIRouter()
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
async def _background_parse(project_id: int, filename: str, text: str, metadata: dict, mode: str = "normal", user_context: str = ""):
    """Run AI parsing and save results in the background (own DB session).

    Args:
        project_id: ID of the project whose client assets are being extracted.
        filename: Name(s) of the uploaded file(s); not read by this function.
        text: Extracted document text that is handed to the AI parser.
        metadata: Extraction metadata from the upload step; not read here.
        mode: ``"deep"`` runs the two-pass extraction; any other value runs
            the single-pass parser.
        user_context: Free-text hints forwarded to every AI call.

    Failures are logged rather than raised. On failure the project is put
    back into DRAFT with an explanatory ``parse_stage`` so it does not stay
    stuck in PARSING.
    """
    async with async_session() as db:
        try:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if not project:
                return

            # Stage 3: AI parsing (normal or deep)
            # NOTE(review): the AI helpers below are called without `await`, so
            # they appear to be synchronous and would hold the event loop while
            # they run — confirm, and consider asyncio.to_thread if so.
            try:
                if mode == "deep":
                    # Pass 1: Structure analysis
                    import time
                    start = time.monotonic()
                    project.parse_stage = "Deep extraction Pass 1/2: Analyzing spreadsheet structure... (this takes 20-40 seconds)"
                    await db.commit()
                    structure_analysis, usage1 = deep_pass1_structure_analysis(text, user_context)
                    elapsed1 = int(time.monotonic() - start)
                    # Persist pass-1 usage immediately so it is not lost if pass 2 fails.
                    project.ai_input_tokens = (project.ai_input_tokens or 0) + usage1.get("input_tokens", 0)
                    project.ai_output_tokens = (project.ai_output_tokens or 0) + usage1.get("output_tokens", 0)
                    project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage1.get("cost_usd", 0)
                    project.ai_call_count = (project.ai_call_count or 0) + 1
                    project.parse_stage = f"Pass 1 complete ({elapsed1}s). Pass 2/2: Extracting assets using structure analysis..."
                    await db.commit()

                    # Pass 2: Guided extraction
                    extracted, usage2 = deep_pass2_guided_extraction(text, structure_analysis, user_context)
                    elapsed2 = int(time.monotonic() - start)
                    # Pass-1 usage is already on the project (above); only pass 2
                    # is left for the shared "Save AI costs" step, otherwise
                    # pass-1 tokens and cost would be counted twice.
                    usage_info = usage2
                    project.parse_stage = f"Deep extraction complete ({elapsed2}s total). Found {len(extracted)} assets."
                    await db.commit()
                else:
                    extracted, usage_info = parse_text_with_ai(text, user_context)
            except Exception as e:
                logger.exception(f"AI parsing failed for project {project_id}: {e}")
                project.status = ProjectStatus.DRAFT
                project.parse_stage = f"AI parsing failed: {str(e)}"
                await db.commit()
                return

            # Save AI costs (single-pass call, or pass 2 of deep mode)
            project.ai_input_tokens = (project.ai_input_tokens or 0) + usage_info.get("input_tokens", 0)
            project.ai_output_tokens = (project.ai_output_tokens or 0) + usage_info.get("output_tokens", 0)
            project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage_info.get("cost_usd", 0)
            project.ai_call_count = (project.ai_call_count or 0) + 1

            # Stage 4: Saving results
            project.parse_stage = f"AI found {len(extracted)} assets. Saving..."
            await db.commit()

            # Clear existing client assets
            existing = await db.execute(
                select(ClientAsset).where(ClientAsset.project_id == project_id)
            )
            for ca in existing.scalars().all():
                await db.delete(ca)

            # Create client asset records (skip zero-quantity assets)
            assets = []
            for idx, item in enumerate(extracted):
                volume = item.get("volume", 1)
                if volume is None:
                    # An explicit null is treated like a missing volume.
                    volume = 1
                elif not isinstance(volume, (int, float)):
                    # AI output is not guaranteed to be numeric (e.g. "3").
                    try:
                        volume = int(str(volume).strip())
                    except ValueError:
                        logger.warning(
                            f"Unparseable volume {volume!r} for asset {item.get('name', 'Unknown')!r}; defaulting to 1"
                        )
                        volume = 1
                if volume <= 0:
                    continue
                ca = ClientAsset(
                    project_id=project_id,
                    raw_name=item.get("name", "Unknown"),
                    raw_description=item.get("description", ""),
                    client_tier=item.get("tier", "") or None,
                    volume=volume,
                    sort_order=idx + 1,
                )
                db.add(ca)
                assets.append(ca)

            project.status = ProjectStatus.REVIEW
            project.parse_stage = f"Done! {len(assets)} assets extracted."
            await db.commit()
            logger.info(f"Background parse complete for project {project_id}: {len(assets)} assets")
        except Exception as e:
            logger.exception(f"Background parse error for project {project_id}: {e}")
            # Best effort: take the project out of PARSING so the user can retry.
            try:
                await db.rollback()
                result = await db.execute(select(Project).where(Project.id == project_id))
                project = result.scalar_one_or_none()
                if project:
                    project.status = ProjectStatus.DRAFT
                    project.parse_stage = f"Parsing failed: {e}"
                    await db.commit()
            except Exception:
                logger.exception(f"Could not reset status after failed parse for project {project_id}")
@router.post("/{project_id}/upload")
async def upload_client_document(
    project_id: int,
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(...),
    mode: str = "normal",
    user_context: str = "",
    db: AsyncSession = Depends(get_db),
):
    """Upload one or more client documents and extract assets using AI.

    Each file is written to ``settings.data_dir``, its text is extracted, and
    the combined text is handed to ``_background_parse`` as a background task.

    Args:
        project_id: Project the documents belong to.
        background_tasks: FastAPI task queue used to schedule the AI parsing.
        files: Uploaded documents.
        mode: Passed through to ``_background_parse`` ("normal" or "deep").
        user_context: Free-text hints passed through to ``_background_parse``.
        db: Request-scoped database session.

    Returns:
        A dict with a human-readable ``message`` and ``status`` "parsing".

    Raises:
        HTTPException: 404 if the project does not exist (via ``_get_project``);
            400 if no text could be extracted from any uploaded file.
    """
    project = await _get_project(project_id, db)
    import os
    from app.config import settings

    filenames = []
    all_text_parts = []
    total_chars = 0
    total_sheets = 0
    project.status = ProjectStatus.PARSING
    project.parse_stage = f"Uploading {len(files)} file(s)..."
    await db.commit()

    # Stage 1+2: Read and extract text from each file
    for file in files:
        # The filename is supplied by the client. Keep only its final path
        # component so "../x" or an absolute path cannot escape data_dir.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if safe_name in ("", ".", "..") or "\x00" in safe_name:
            logger.warning(f"Skipping upload with unusable filename: {file.filename!r}")
            continue

        content = await file.read()
        save_path = os.path.join(settings.data_dir, safe_name)
        with open(save_path, "wb") as f:
            f.write(content)
        filenames.append(safe_name)
        project.parse_stage = f"Extracting text from {safe_name}..."
        await db.commit()
        try:
            text, metadata = extract_text_from_file(content, safe_name)
            all_text_parts.append(f"\n{'='*60}\nFILE: {safe_name}\n{'='*60}\n{text}")
            total_chars += metadata["char_count"]
            total_sheets += metadata.get("sheet_count", 0)
        except Exception as e:
            logger.warning(f"Failed to extract text from {safe_name}: {e}")
            continue

    if not all_text_parts:
        project.status = ProjectStatus.DRAFT
        project.parse_stage = None
        await db.commit()
        raise HTTPException(status_code=400, detail="Failed to extract text from any uploaded file.")

    combined_text = "\n".join(all_text_parts)
    project.source_filename = ", ".join(filenames)
    sheets_info = f" ({total_sheets} sheets)" if total_sheets else ""
    project.parse_stage = f"Extracted {total_chars:,} characters from {len(filenames)} file(s){sheets_info}. Sending to AI..."
    await db.commit()

    # Stage 3+4: AI parsing runs in the background — respond right away
    background_tasks.add_task(_background_parse, project_id, ", ".join(filenames), combined_text, {"char_count": total_chars, "sheet_count": total_sheets}, mode, user_context)
    return {
        "message": f"{len(filenames)} file(s) received. AI parsing started.",
        "status": "parsing",
    }
@router.get("/{project_id}/tier-mapping")
async def get_tier_mapping(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return the project's stored tier mapping, or an empty one when unset.

    The mapping is stored on the project as a JSON string and decoded here.
    """
    import json

    project = await _get_project(project_id, db)
    stored = project.tier_mapping
    if not stored:
        return {"tiers": []}
    return json.loads(stored)
@router.put("/{project_id}/tier-mapping")
async def set_tier_mapping(project_id: int, data: dict, db: AsyncSession = Depends(get_db)):
    """Store the request body as the project's tier mapping and echo it back.

    The body is serialized to a JSON string before being saved on the project.
    """
    import json

    project = await _get_project(project_id, db)
    serialized = json.dumps(data)
    project.tier_mapping = serialized
    await db.commit()
    return data
@router.post("/{project_id}/expand-tiers")
async def expand_tiers_endpoint(project_id: int, db: AsyncSession = Depends(get_db)):
    """Expand matched assets into complexity tier variants.

    Loads the project (404 if missing) and returns whatever
    ``expand_to_tiers`` produces for it.
    """
    from app.services.tier_expander import expand_to_tiers

    project = await _get_project(project_id, db)
    return await expand_to_tiers(db, project)
@router.get("/{project_id}/brief-analysis")
async def get_brief_analysis(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return the stored brief analysis for a project.

    The response ``status`` is "not_analyzed" when nothing is stored,
    "analyzed" with the decoded JSON when it parses, and "error" when the
    stored value is not valid JSON.
    """
    import json

    project = await _get_project(project_id, db)
    raw = project.brief_analysis
    if not raw:
        return {"status": "not_analyzed"}

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {"status": "error", "analysis": None}
    return {"status": "analyzed", "analysis": parsed}
@router.post("/{project_id}/analyze-brief")
async def analyze_brief_endpoint(project_id: int, data: dict | None = None, db: AsyncSession = Depends(get_db)):
    """Run AI analysis on uploaded document or pasted text.

    The brief text comes from ``data["text"]`` when provided, otherwise from
    the file(s) recorded in ``project.source_filename``. The upload endpoint
    stores several names as one ", "-joined string, so that string is split
    when it does not name a single file on disk.

    Returns:
        ``{"status": "analyzed", "analysis": ...}`` with the result of
        ``analyze_brief``.

    Raises:
        HTTPException: 404 if the project does not exist; 400 if there is no
            text source, no source file is found on disk, or the text is
            shorter than 20 characters after stripping.
    """
    from app.services.rfp_analysis import analyze_brief
    from app.services.doc_parser import extract_text_from_file

    project = await _get_project(project_id, db)
    body = data or {}

    # Option 1: Pasted text
    if body.get("text"):
        text = body["text"]
    # Option 2: Read from uploaded file(s)
    elif project.source_filename:
        import os
        from app.config import settings

        stored = project.source_filename
        if os.path.isfile(os.path.join(settings.data_dir, os.path.basename(stored))):
            # Single upload, or a file whose own name contains ", ".
            names = [stored]
        else:
            names = [part for part in stored.split(", ") if part]

        parts = []
        for name in names:
            # Names originate from client uploads; never follow directory parts.
            filepath = os.path.join(settings.data_dir, os.path.basename(name))
            if not os.path.isfile(filepath):
                logger.warning(f"Source file {name} for project {project_id} not found on disk")
                continue
            with open(filepath, "rb") as f:
                content = f.read()
            file_text, _ = extract_text_from_file(content, name)
            parts.append(file_text)

        if not parts:
            raise HTTPException(status_code=400, detail="Source file not found on disk. Re-upload or paste the brief text.")
        text = "\n\n".join(parts)
    else:
        raise HTTPException(status_code=400, detail="No document uploaded and no text provided. Upload a file or paste the brief.")

    if len(text.strip()) < 20:
        raise HTTPException(status_code=400, detail="Brief text is too short to analyze.")

    analysis = await analyze_brief(db, project, text)
    return {"status": "analyzed", "analysis": analysis}
@router.get("/{project_id}/client-assets", response_model=list[ClientAssetOut])
async def list_client_assets(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return the project's client assets ordered by ``sort_order``."""
    query = (
        select(ClientAsset)
        .where(ClientAsset.project_id == project_id)
        .order_by(ClientAsset.sort_order)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
@router.put("/{project_id}/client-assets/{asset_id}", response_model=ClientAssetOut)
async def update_client_asset(
    project_id: int,
    asset_id: int,
    data: ClientAssetUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Apply a partial update to one client asset of the given project.

    Only ``raw_name``, ``raw_description`` and ``volume`` are copied, and only
    when the request supplies a non-None value for them.

    Raises:
        HTTPException: 404 if no asset with this ID belongs to the project.
    """
    lookup = select(ClientAsset).where(ClientAsset.id == asset_id, ClientAsset.project_id == project_id)
    asset = (await db.execute(lookup)).scalar_one_or_none()
    if not asset:
        raise HTTPException(status_code=404, detail="Client asset not found")

    for field in ("raw_name", "raw_description", "volume"):
        new_value = getattr(data, field)
        if new_value is not None:
            setattr(asset, field, new_value)

    await db.commit()
    await db.refresh(asset)
    return asset
async def _background_match(project_id: int, asset_snapshots: list):
    """Run AI matching in the background (own DB session).

    Args:
        project_id: Project whose client assets are being matched.
        asset_snapshots: Dicts carrying the ``"id"`` of each client asset to
            match; the assets are re-loaded from the database by these IDs.

    The project is set back to REVIEW when matching finishes. If matching
    fails, the error is logged and a fresh session is used to make a
    best-effort reset to REVIEW; a failure of that reset is logged as well.
    """
    async with async_session() as db:
        try:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if not project:
                return

            # Re-load the client assets named by the snapshots in this session
            ca_result = await db.execute(
                select(ClientAsset).where(ClientAsset.id.in_([s["id"] for s in asset_snapshots]))
                .order_by(ClientAsset.sort_order)
            )
            client_assets = ca_result.scalars().all()

            matches = await match_client_assets(db, project_id, client_assets)

            await db.refresh(project)
            project.status = ProjectStatus.REVIEW
            await db.commit()
            logger.info(f"Background match complete for project {project_id}: {len(matches)} matches")
        except Exception as e:
            logger.exception(f"Background match error for project {project_id}: {e}")
            # The failed session may be unusable, so reset the status in a new one.
            try:
                async with async_session() as db2:
                    result = await db2.execute(select(Project).where(Project.id == project_id))
                    project = result.scalar_one_or_none()
                    if project:
                        project.status = ProjectStatus.REVIEW
                        await db2.commit()
            except Exception:
                # Still best effort, but leave a trace instead of failing silently.
                logger.exception(f"Could not reset status after failed match for project {project_id}")
@router.post("/{project_id}/match")
async def run_matching(
    project_id: int,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Trigger AI matching for all client assets in this project.

    Existing matches for the project's client assets are deleted, the project
    is moved to MATCHING, and ``_background_match`` is scheduled.

    Returns:
        A dict with a human-readable ``message`` and ``status`` "matching".

    Raises:
        HTTPException: 404 if the project does not exist; 400 if it has no
            client assets.
    """
    project = await _get_project(project_id, db)

    # Get client assets
    result = await db.execute(
        select(ClientAsset).where(ClientAsset.project_id == project_id).order_by(ClientAsset.sort_order)
    )
    client_assets = result.scalars().all()
    if not client_assets:
        raise HTTPException(status_code=400, detail="No client assets to match. Upload a document first.")

    # Snapshot IDs before clearing (ORM objects expire after commit)
    asset_ids = [ca.id for ca in client_assets]
    asset_snapshots = [{"id": asset_id} for asset_id in asset_ids]

    # Clear existing matches with one query instead of one query per asset
    stale_matches = await db.execute(select(Match).where(Match.client_asset_id.in_(asset_ids)))
    for m in stale_matches.scalars().all():
        await db.delete(m)

    project.status = ProjectStatus.MATCHING
    await db.commit()

    # Run matching in the background — respond right away
    background_tasks.add_task(_background_match, project_id, asset_snapshots)
    return {
        "message": f"Matching started for {len(client_assets)} client assets.",
        "status": "matching",
    }
@router.post("/{project_id}/match/cancel")
async def cancel_matching_endpoint(project_id: int, db: AsyncSession = Depends(get_db)):
    """Request cancellation of a matching run and put the project in REVIEW.

    ``cancel_matching`` is invoked before the project is looked up, so it runs
    even when the lookup then answers 404.
    """
    from app.services.ai_matching import cancel_matching

    cancel_matching(project_id)

    project = await _get_project(project_id, db)
    project.status = ProjectStatus.REVIEW
    await db.commit()

    response = {"detail": "Matching cancellation requested"}
    return response
@router.post("/{project_id}/refine")
async def refine_matches_endpoint(
    project_id: int,
    data: dict,
    db: AsyncSession = Depends(get_db),
):
    """Apply a natural-language refinement instruction to a project's matches.

    The instruction is passed to ``refine_matches``. When its result reports a
    positive ``rematch_count``, the assets listed in ``rematch_asset_ids`` are
    matched again and the number of new matches is added to the result as
    ``new_matches``.

    Raises:
        HTTPException: 400 if the body carries no ``instruction``.
    """
    from app.services.match_refiner import refine_matches

    instruction = data.get("instruction", "")
    if not instruction:
        raise HTTPException(status_code=400, detail="No instruction provided")

    result = await refine_matches(db, project_id, instruction)

    # Nothing flagged for re-matching: hand the refiner's result back untouched
    if not (result.get("rematch_count", 0) > 0):
        return result

    rematch_ids = result["rematch_asset_ids"]
    rematch_query = (
        select(ClientAsset)
        .where(ClientAsset.id.in_(rematch_ids))
        .order_by(ClientAsset.sort_order)
    )
    assets_to_rematch = (await db.execute(rematch_query)).scalars().all()
    if assets_to_rematch:
        new_matches = await match_client_assets(db, project_id, assets_to_rematch)
        result["new_matches"] = len(new_matches)
    return result
@router.post("/{project_id}/matches/{match_id}/feedback")
async def submit_match_feedback(
    project_id: int,
    match_id: int,
    data: dict,
    db: AsyncSession = Depends(get_db),
):
    """Store feedback on a match (confirm or reject) for the learning system.

    Args:
        project_id: Project named in the URL; the match must belong to it.
        match_id: Match the feedback is about.
        data: Body with optional ``confirmed`` (default True) and ``comment``
            (default "").
        db: Request-scoped database session.

    Raises:
        HTTPException: 404 if the match does not exist or its client asset
            belongs to a different project.
    """
    from app.models.feedback import MatchFeedback

    result = await db.execute(select(Match).where(Match.id == match_id))
    match = result.scalar_one_or_none()
    if not match:
        raise HTTPException(status_code=404, detail="Match not found")

    # Get the client asset name for the feedback record
    ca_result = await db.execute(select(ClientAsset).where(ClientAsset.id == match.client_asset_id))
    ca = ca_result.scalar_one_or_none()

    # Scope the match to the project in the URL, as the other asset endpoints
    # do. A match whose client asset is gone is still accepted, as before.
    if ca is not None and ca.project_id != project_id:
        raise HTTPException(status_code=404, detail="Match not found")

    confirmed = data.get("confirmed", True)
    comment = data.get("comment", "")
    feedback = MatchFeedback(
        client_term=(ca.raw_name or "").strip().lower() if ca else "",
        client_description=ca.raw_description if ca else None,
        gmal_asset_id=match.gmal_asset_id,
        confirmed=confirmed,
        user_comment=comment,
    )
    db.add(feedback)
    await db.commit()
    return {"detail": f"Feedback {'confirmed' if confirmed else 'rejected'} stored"}
@router.get("/{project_id}/matches", response_model=list[MatchOut])
async def list_matches(project_id: int, db: AsyncSession = Depends(get_db)):
    """Get all matches for a project as a flat list.

    Matches are ordered by client asset ID and then by rank, so the matches of
    one client asset are adjacent. Returns an empty list when the project has
    no client assets.
    """
    # Get client asset IDs for this project
    ca_result = await db.execute(
        select(ClientAsset.id).where(ClientAsset.project_id == project_id)
    )
    ca_ids = [r[0] for r in ca_result.all()]
    if not ca_ids:
        return []

    result = await db.execute(
        select(Match, GmalAsset)
        .join(GmalAsset, Match.gmal_asset_id == GmalAsset.id)
        .where(Match.client_asset_id.in_(ca_ids))
        .order_by(Match.client_asset_id, Match.rank)
    )
    return [
        MatchOut(
            id=match.id,
            client_asset_id=match.client_asset_id,
            gmal_asset_id=match.gmal_asset_id,
            gmal_id=gmal.gmal_id,
            gmal_name=gmal.asset_name,
            gmal_unique_name=gmal.unique_name,
            confidence=match.confidence.value,
            # Compare against None: a score of 0 is a real value, not "missing".
            confidence_score=float(match.confidence_score) if match.confidence_score is not None else None,
            ai_reasoning=match.ai_reasoning,
            caveat_text=match.caveat_text,
            is_selected=match.is_selected,
            rank=match.rank,
        )
        for match, gmal in result.all()
    ]
@router.put("/{project_id}/matches/{match_id}/select")
async def select_match(
    project_id: int,
    match_id: int,
    data: MatchSelectRequest,
    db: AsyncSession = Depends(get_db),
):
    """Select or deselect a match. Deselects other matches for the same client asset.

    Args:
        project_id: Project named in the URL; the match must belong to it.
        match_id: Match to update.
        data: Body carrying the desired ``is_selected`` flag.
        db: Request-scoped database session.

    Raises:
        HTTPException: 404 if the match does not exist or its client asset
            belongs to a different project.
    """
    result = await db.execute(select(Match).where(Match.id == match_id))
    match = result.scalar_one_or_none()
    if not match:
        raise HTTPException(status_code=404, detail="Match not found")

    # Scope the match to the project in the URL, as the other asset endpoints
    # do, so a match of another project cannot be changed through this route.
    owner = await db.execute(
        select(ClientAsset.project_id).where(ClientAsset.id == match.client_asset_id)
    )
    owner_project_id = owner.scalar_one_or_none()
    if owner_project_id is not None and owner_project_id != project_id:
        raise HTTPException(status_code=404, detail="Match not found")

    if data.is_selected:
        # Deselect every match for this client asset (this one is re-set below)
        siblings = await db.execute(
            select(Match).where(Match.client_asset_id == match.client_asset_id)
        )
        for sibling in siblings.scalars().all():
            sibling.is_selected = False

    match.is_selected = data.is_selected
    await db.commit()
    return {"detail": "Match updated"}
@router.post("/{project_id}/matches/{client_asset_id}/manual")
async def manual_match(
    project_id: int,
    client_asset_id: int,
    data: ManualMatchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Manually assign a GMAL asset to a client asset.

    All existing matches of the client asset are deselected and a new selected
    match with EXACT confidence, score 1.0 and rank 0 is created.

    Raises:
        HTTPException: 404 if the client asset is not part of the project or
            the GMAL asset does not exist.
    """
    # The client asset has to belong to the project named in the URL
    asset_query = select(ClientAsset).where(
        ClientAsset.id == client_asset_id, ClientAsset.project_id == project_id
    )
    client_asset = (await db.execute(asset_query)).scalar_one_or_none()
    if not client_asset:
        raise HTTPException(status_code=404, detail="Client asset not found")

    # The requested GMAL asset has to exist
    gmal_query = select(GmalAsset).where(GmalAsset.id == data.gmal_asset_id)
    gmal = (await db.execute(gmal_query)).scalar_one_or_none()
    if not gmal:
        raise HTTPException(status_code=404, detail="GMAL asset not found")

    # Drop the selection from whatever was matched before
    previous = await db.execute(select(Match).where(Match.client_asset_id == client_asset_id))
    for old_match in previous.scalars().all():
        old_match.is_selected = False

    # Record the user's choice as the selected match
    db.add(
        Match(
            client_asset_id=client_asset_id,
            gmal_asset_id=data.gmal_asset_id,
            confidence=MatchConfidence.EXACT,
            confidence_score=1.0,
            ai_reasoning="Manually assigned by user",
            caveat_text="",
            is_selected=True,
            rank=0,
        )
    )
    await db.commit()
    return {"detail": f"Manually matched to {gmal.gmal_id}"}
async def _get_project(project_id: int, db: AsyncSession) -> Project:
    """Load a project by ID, answering 404 when it does not exist.

    Raises:
        HTTPException: 404 with detail "Project not found".
    """
    lookup = select(Project).where(Project.id == project_id)
    found = (await db.execute(lookup)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Project not found")