Tier fix (reverts the previous "extract once" mistake):
- Create a SEPARATE entry for EACH tier whose volume is > 0
- e.g. "KV 360" with Tier A=No/0, Tier B=Yes/1, Tier C=Yes/1 → TWO entries (Tier B and Tier C)
- The tier field matches the column header exactly ("Tier B", "Tier C")
- Tiers with volume=0 or status=No are skipped
- Applied to both the normal and deep extraction prompts
User context box (new Step 3 on the Upload tab):
- Textarea where users can give hints before extraction runs
- Examples: "Focus on Toolbox sheet", "Tier columns are D/F/H"
- The context is prepended to the Claude prompt in both normal and deep modes
- Passed through the upload endpoint → background parse → AI calls
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
551 lines
21 KiB
Python
551 lines
21 KiB
Python
"""Client document upload and AI matching endpoints."""
|
|
|
|
import logging

from fastapi import (
    APIRouter,
    BackgroundTasks,
    Depends,
    File,
    HTTPException,
    UploadFile,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import async_session, get_db
from app.models.gmal import GmalAsset
from app.models.project import (
    ClientAsset,
    Match,
    MatchConfidence,
    Project,
    ProjectStatus,
)
from app.schemas.project import (
    ClientAssetOut,
    ClientAssetUpdate,
    ManualMatchRequest,
    MatchOut,
    MatchSelectRequest,
)
from app.services.doc_parser import (
    EXTRACT_TOOLS,
    SYSTEM_PROMPT,
    deep_pass1_structure_analysis,
    deep_pass2_guided_extraction,
    extract_text_from_file,
    parse_text_with_ai,
)
from app.services.ai_matching import match_client_assets

# Every route in this module is registered on this router; the URL prefix is
# applied wherever the router is included (not visible in this file).
router = APIRouter()

# Module-scoped logger shared by the route handlers and background tasks.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _record_ai_usage(project: "Project", usage: dict) -> None:
    """Add one AI call's token and cost usage to the project's running totals.

    ``usage`` is the usage dict returned by the doc_parser helpers; missing
    keys count as zero.
    """
    project.ai_input_tokens = (project.ai_input_tokens or 0) + usage.get("input_tokens", 0)
    project.ai_output_tokens = (project.ai_output_tokens or 0) + usage.get("output_tokens", 0)
    project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage.get("cost_usd", 0)
    project.ai_call_count = (project.ai_call_count or 0) + 1


def _coerce_volume(raw) -> int | float:
    """Best-effort conversion of an AI-reported ``volume`` to a number.

    Numbers pass through unchanged and numeric strings are parsed. ``None`` or
    any other unparseable value falls back to 1 (the same default used when the
    key is absent), so one malformed item cannot abort the whole save with a
    ``TypeError`` from the ``<= 0`` comparison.
    """
    if isinstance(raw, (int, float)):
        return raw
    if isinstance(raw, str):
        cleaned = raw.strip()
        for convert in (int, float):
            try:
                return convert(cleaned)
            except ValueError:
                pass
    if raw is not None:
        logger.warning(f"Unparseable volume {raw!r} in AI output; defaulting to 1")
    return 1


async def _mark_parse_failed(project_id: int, message: str) -> None:
    """Best-effort: move a project out of PARSING after an unexpected error.

    Uses a fresh session because the one that raised may be unusable. Errors
    here are logged rather than raised, since this already runs on an error path.
    """
    try:
        async with async_session() as db:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if project is not None:
                project.status = ProjectStatus.DRAFT
                project.parse_stage = message
                await db.commit()
    except Exception:
        logger.exception(f"Could not reset status for project {project_id} after parse error")


async def _background_parse(project_id: int, filename: str, text: str, metadata: dict, mode: str = "normal", user_context: str = ""):
    """Run AI asset extraction on uploaded text and rebuild the project's client assets.

    Intended to run as a background task with its own DB session. Progress is
    reported through ``project.parse_stage``. On success the project moves to
    REVIEW. On failure it goes back to DRAFT with the error in ``parse_stage``.

    Args:
        project_id: Project whose client assets are replaced.
        filename: Comma-joined source filenames (not used by this function).
        text: Combined extracted text of all uploaded files.
        metadata: Extraction statistics (not used by this function).
        mode: ``"deep"`` runs the two-pass structure-guided extraction; any
            other value uses the single-pass parser.
        user_context: Free-text user hints passed through to the AI calls.
    """
    import asyncio
    import time

    try:
        async with async_session() as db:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if not project:
                return

            # Stage 3: AI parsing (normal or deep).
            # The doc_parser calls are synchronous and can take tens of seconds,
            # so they run in a worker thread. Calling them inline would block the
            # event loop and every other request, including progress polling.
            try:
                if mode == "deep":
                    # Pass 1: Structure analysis
                    start = time.monotonic()
                    project.parse_stage = "Deep extraction Pass 1/2: Analyzing spreadsheet structure... (this takes 20-40 seconds)"
                    await db.commit()

                    structure_analysis, usage1 = await asyncio.to_thread(
                        deep_pass1_structure_analysis, text, user_context
                    )
                    elapsed1 = int(time.monotonic() - start)
                    # Record pass-1 usage right away so its cost is kept even if pass 2 fails.
                    _record_ai_usage(project, usage1)
                    project.parse_stage = f"Pass 1 complete ({elapsed1}s). Pass 2/2: Extracting assets using structure analysis..."
                    await db.commit()

                    # Pass 2: Guided extraction
                    extracted, usage2 = await asyncio.to_thread(
                        deep_pass2_guided_extraction, text, structure_analysis, user_context
                    )
                    elapsed2 = int(time.monotonic() - start)
                    # Pass 1 is already recorded above; only pass 2 remains to be
                    # added (adding usage1 + usage2 here double-counted pass 1).
                    usage_info = usage2
                    project.parse_stage = f"Deep extraction complete ({elapsed2}s total). Found {len(extracted)} assets."
                    await db.commit()
                else:
                    extracted, usage_info = await asyncio.to_thread(parse_text_with_ai, text, user_context)
            except Exception as e:
                logger.exception(f"AI parsing failed for project {project_id}: {e}")
                project.status = ProjectStatus.DRAFT
                project.parse_stage = f"AI parsing failed: {str(e)}"
                await db.commit()
                return

            # Save AI costs (in deep mode this is the pass-2 call only).
            _record_ai_usage(project, usage_info)

            # Stage 4: Saving results
            project.parse_stage = f"AI found {len(extracted)} assets. Saving..."
            await db.commit()

            # Clear existing client assets
            existing = await db.execute(
                select(ClientAsset).where(ClientAsset.project_id == project_id)
            )
            for ca in existing.scalars().all():
                await db.delete(ca)

            # Create client asset records (skip zero-quantity assets).
            # sort_order keeps the AI's original position, so skipped items leave gaps.
            assets = []
            for idx, item in enumerate(extracted):
                volume = _coerce_volume(item.get("volume", 1))
                if volume <= 0:
                    continue
                ca = ClientAsset(
                    project_id=project_id,
                    raw_name=item.get("name", "Unknown"),
                    raw_description=item.get("description", ""),
                    client_tier=item.get("tier", "") or None,
                    volume=volume,
                    sort_order=idx + 1,
                )
                db.add(ca)
                assets.append(ca)

            project.status = ProjectStatus.REVIEW
            project.parse_stage = f"Done! {len(assets)} assets extracted."
            await db.commit()
            logger.info(f"Background parse complete for project {project_id}: {len(assets)} assets")
    except Exception as e:
        logger.exception(f"Background parse error for project {project_id}: {e}")
        # Without this, the project would stay in PARSING with no visible error.
        await _mark_parse_failed(project_id, f"Parsing failed: {e}")
|
|
|
|
|
|
def _safe_upload_name(filename: str | None) -> str:
    """Return a filename that is safe to join onto ``settings.data_dir``.

    ``UploadFile.filename`` is client-controlled. A value such as
    ``../../app/main.py`` or an absolute path would otherwise let
    ``os.path.join`` write outside the data directory. Only the final path
    component is kept, with both ``/`` and ``\\`` treated as separators.
    Empty, ``.`` and ``..`` names fall back to ``"upload"``.
    """
    import os

    name = os.path.basename((filename or "").replace("\\", "/"))
    name = name.replace("\x00", "").strip()
    if name in ("", ".", ".."):
        return "upload"
    return name


@router.post("/{project_id}/upload")
async def upload_client_document(
    project_id: int,
    background_tasks: BackgroundTasks,
    files: list[UploadFile] = File(...),
    mode: str = "normal",
    user_context: str = "",
    db: AsyncSession = Depends(get_db),
):
    """Upload one or more client documents and start AI asset extraction.

    Each file is saved under ``settings.data_dir`` using its sanitized base
    name, and its text is extracted in-process. The combined text is then
    handed to ``_background_parse`` as a background task. The response
    returns immediately with ``"status": "parsing"``.

    Args:
        mode: ``"deep"`` for two-pass extraction; anything else runs the
            normal single-pass parser.
        user_context: Optional user hints, passed through to the AI parsing.

    Raises:
        HTTPException: 404 if the project does not exist; 400 if no text
            could be extracted from any file (the project is reset to DRAFT).
    """
    import os

    from app.config import settings

    project = await _get_project(project_id, db)

    filenames = []
    all_text_parts = []
    total_chars = 0
    total_sheets = 0

    project.status = ProjectStatus.PARSING
    project.parse_stage = f"Uploading {len(files)} file(s)..."
    await db.commit()

    # Stage 1+2: Read and extract text from each file
    for file in files:
        safe_name = _safe_upload_name(file.filename)
        content = await file.read()

        # The on-disk copy is only needed for later re-reads (brief analysis);
        # extraction below works from memory. So a failed save is logged rather
        # than raised, which would leave the project stuck in PARSING.
        try:
            with open(os.path.join(settings.data_dir, safe_name), "wb") as f:
                f.write(content)
        except OSError as e:
            logger.warning(f"Failed to save {safe_name} for project {project_id}: {e}")
        filenames.append(safe_name)

        project.parse_stage = f"Extracting text from {safe_name}..."
        await db.commit()

        try:
            text, metadata = extract_text_from_file(content, safe_name)
        except Exception as e:
            logger.warning(f"Failed to extract text from {safe_name}: {e}")
            continue
        all_text_parts.append(f"\n{'='*60}\nFILE: {safe_name}\n{'='*60}\n{text}")
        total_chars += metadata.get("char_count", len(text))
        total_sheets += metadata.get("sheet_count", 0)

    if not all_text_parts:
        project.status = ProjectStatus.DRAFT
        project.parse_stage = None
        await db.commit()
        raise HTTPException(status_code=400, detail="Failed to extract text from any uploaded file.")

    combined_text = "\n".join(all_text_parts)
    project.source_filename = ", ".join(filenames)

    sheets_info = f" ({total_sheets} sheets)" if total_sheets else ""
    project.parse_stage = f"Extracted {total_chars:,} characters from {len(filenames)} file(s){sheets_info}. Sending to AI..."
    await db.commit()

    # Stage 3+4: AI parsing runs as a background task; respond right away
    # (default 200 status) and let the client poll parse_stage.
    background_tasks.add_task(
        _background_parse,
        project_id,
        ", ".join(filenames),
        combined_text,
        {"char_count": total_chars, "sheet_count": total_sheets},
        mode,
        user_context,
    )

    return {
        "message": f"{len(filenames)} file(s) received. AI parsing started.",
        "status": "parsing",
    }
|
|
|
|
|
|
@router.get("/{project_id}/tier-mapping")
async def get_tier_mapping(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return the tier mapping stored on a project.

    The mapping is persisted as a JSON string. When nothing has been stored
    yet, an empty ``{"tiers": []}`` mapping is returned.
    """
    import json

    project = await _get_project(project_id, db)
    stored = project.tier_mapping
    return json.loads(stored) if stored else {"tiers": []}
|
|
|
|
|
|
@router.put("/{project_id}/tier-mapping")
async def set_tier_mapping(project_id: int, data: dict, db: AsyncSession = Depends(get_db)):
    """Replace a project's tier mapping with the request body and echo it back.

    The body is serialized to JSON and stored as-is; no schema validation
    happens here.
    """
    import json

    project = await _get_project(project_id, db)
    serialized = json.dumps(data)
    project.tier_mapping = serialized
    await db.commit()
    return data
|
|
|
|
|
|
@router.post("/{project_id}/expand-tiers")
async def expand_tiers_endpoint(project_id: int, db: AsyncSession = Depends(get_db)):
    """Expand the project's matched assets into complexity-tier variants.

    Thin wrapper: delegates to ``tier_expander.expand_to_tiers`` and returns
    its result unchanged.
    """
    from app.services.tier_expander import expand_to_tiers

    return await expand_to_tiers(db, await _get_project(project_id, db))
|
|
|
|
|
|
@router.get("/{project_id}/brief-analysis")
async def get_brief_analysis(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return the stored structured brief analysis for a project.

    Responses:
        ``{"status": "not_analyzed"}`` when nothing is stored;
        ``{"status": "analyzed", "analysis": ...}`` with the decoded JSON;
        ``{"status": "error", "analysis": None}`` when the stored value is not valid JSON.
    """
    import json

    project = await _get_project(project_id, db)
    raw = project.brief_analysis
    if not raw:
        return {"status": "not_analyzed"}

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {"status": "error", "analysis": None}
    return {"status": "analyzed", "analysis": parsed}
|
|
|
|
|
|
def _load_uploaded_source_text(source_filename: str) -> str:
    """Re-read the uploaded document(s) recorded in ``project.source_filename``.

    The upload endpoint stores one file's name verbatim and joins several names
    with ``", "``. The verbatim name is tried first, and the joined form is
    split only if that name is not on disk. Only each entry's base name is
    joined onto ``settings.data_dir``, so a stored name cannot point outside
    it. Several files are combined with the same ``FILE:`` banners the upload
    uses; a single file's text is returned unchanged.

    Raises:
        HTTPException: 400 if none of the files is on disk, or if text could
            not be extracted from any of them.
    """
    import os

    from app.config import settings

    def path_for(name: str) -> str:
        return os.path.join(settings.data_dir, os.path.basename(name))

    names = [source_filename]
    if not os.path.exists(path_for(source_filename)):
        names = [part.strip() for part in source_filename.split(", ") if part.strip()]

    on_disk = [name for name in names if os.path.exists(path_for(name))]
    if not on_disk:
        raise HTTPException(status_code=400, detail="Source file not found on disk. Re-upload or paste the brief text.")
    for missing in (name for name in names if name not in on_disk):
        logger.warning(f"Source file {missing} not found on disk; skipping for brief analysis")

    extracted = []
    for name in on_disk:
        with open(path_for(name), "rb") as f:
            content = f.read()
        try:
            text, _ = extract_text_from_file(content, name)
        except Exception as e:
            logger.warning(f"Failed to extract text from {name}: {e}")
            continue
        extracted.append((name, text))

    if not extracted:
        raise HTTPException(status_code=400, detail="Could not extract text from the uploaded file(s). Re-upload or paste the brief text.")
    if len(extracted) == 1:
        return extracted[0][1]
    return "\n".join(f"\n{'='*60}\nFILE: {name}\n{'='*60}\n{text}" for name, text in extracted)


@router.post("/{project_id}/analyze-brief")
async def analyze_brief_endpoint(project_id: int, data: dict | None = None, db: AsyncSession = Depends(get_db)):
    """Run AI brief analysis on pasted text or on the project's uploaded document(s).

    A non-empty ``text`` key in the body takes precedence. Otherwise the files
    recorded in ``project.source_filename`` are re-read from disk, which also
    works for multi-file uploads.

    Raises:
        HTTPException: 400 when there is nothing to analyze, the source files
            are missing or unreadable, or the text is shorter than 20 characters.
    """
    from app.services.rfp_analysis import analyze_brief

    project = await _get_project(project_id, db)
    body = data or {}

    # Option 1: Pasted text
    if body.get("text"):
        text = body["text"]
    # Option 2: Read from uploaded file(s)
    elif project.source_filename:
        text = _load_uploaded_source_text(project.source_filename)
    else:
        raise HTTPException(status_code=400, detail="No document uploaded and no text provided. Upload a file or paste the brief.")

    if len(text.strip()) < 20:
        raise HTTPException(status_code=400, detail="Brief text is too short to analyze.")

    analysis = await analyze_brief(db, project, text)
    return {"status": "analyzed", "analysis": analysis}
|
|
|
|
|
|
@router.get("/{project_id}/client-assets", response_model=list[ClientAssetOut])
async def list_client_assets(project_id: int, db: AsyncSession = Depends(get_db)):
    """List a project's client assets in ``sort_order``.

    The project's existence is not checked: an unknown project yields an
    empty list rather than a 404.
    """
    query = (
        select(ClientAsset)
        .where(ClientAsset.project_id == project_id)
        .order_by(ClientAsset.sort_order)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
|
|
|
|
|
|
@router.put("/{project_id}/client-assets/{asset_id}", response_model=ClientAssetOut)
async def update_client_asset(
    project_id: int,
    asset_id: int,
    data: ClientAssetUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Partially update a client asset that belongs to ``project_id``.

    Only ``raw_name``, ``raw_description`` and ``volume`` are applied, and each
    only when the request supplies a non-``None`` value.

    Raises:
        HTTPException: 404 if the asset does not exist in this project.
    """
    lookup = await db.execute(
        select(ClientAsset).where(ClientAsset.id == asset_id, ClientAsset.project_id == project_id)
    )
    asset = lookup.scalar_one_or_none()
    if asset is None:
        raise HTTPException(status_code=404, detail="Client asset not found")

    for field_name in ("raw_name", "raw_description", "volume"):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(asset, field_name, new_value)

    await db.commit()
    await db.refresh(asset)
    return asset
|
|
|
|
|
|
async def _background_match(project_id: int, asset_snapshots: list):
    """Run AI matching for a set of client assets in the background.

    Uses its own DB session. ``asset_snapshots`` is a list of ``{"id": ...}``
    dicts captured by the caller before its session committed. The project is
    set back to REVIEW after the run, including (best effort) after a failure.
    Nothing happens if the project no longer exists.
    """
    async with async_session() as db:
        try:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if not project:
                return

            # Re-load the snapshotted client assets as live objects in this session.
            ca_result = await db.execute(
                select(ClientAsset)
                .where(ClientAsset.id.in_([s["id"] for s in asset_snapshots]))
                .order_by(ClientAsset.sort_order)
            )
            client_assets = ca_result.scalars().all()

            matches = await match_client_assets(db, project_id, client_assets)

            await db.refresh(project)
            project.status = ProjectStatus.REVIEW
            await db.commit()
            logger.info(f"Background match complete for project {project_id}: {len(matches)} matches")

        except Exception as e:
            logger.exception(f"Background match error for project {project_id}: {e}")
            # Don't leave the project stuck in MATCHING. Use a fresh session,
            # since `db` may be unusable after the failure.
            try:
                async with async_session() as db2:
                    result = await db2.execute(select(Project).where(Project.id == project_id))
                    project = result.scalar_one_or_none()
                    if project:
                        project.status = ProjectStatus.REVIEW
                        await db2.commit()
            except Exception:
                # Still best-effort, but no longer silent.
                logger.exception(f"Could not reset status for project {project_id} after match error")
|
|
|
|
|
|
@router.post("/{project_id}/match")
async def run_matching(
    project_id: int,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Start AI matching for all of a project's client assets.

    Existing matches are deleted and the project is set to MATCHING. The
    matching itself runs as a background task, and the response returns
    immediately.

    Raises:
        HTTPException: 404 if the project does not exist; 400 if it has no
            client assets yet.
    """
    project = await _get_project(project_id, db)

    # Get client assets
    result = await db.execute(
        select(ClientAsset).where(ClientAsset.project_id == project_id).order_by(ClientAsset.sort_order)
    )
    client_assets = result.scalars().all()

    if not client_assets:
        raise HTTPException(status_code=400, detail="No client assets to match. Upload a document first.")

    # Snapshot IDs before committing (ORM objects expire after commit); the
    # background task re-loads them in its own session.
    asset_ids = [ca.id for ca in client_assets]
    asset_snapshots = [{"id": asset_id} for asset_id in asset_ids]

    # Clear existing matches with one query instead of one query per asset.
    stale = await db.execute(select(Match).where(Match.client_asset_id.in_(asset_ids)))
    for m in stale.scalars().all():
        await db.delete(m)

    project.status = ProjectStatus.MATCHING
    await db.commit()

    # Run matching in the background; respond right away (default 200 status).
    background_tasks.add_task(_background_match, project_id, asset_snapshots)

    return {
        "message": f"Matching started for {len(client_assets)} client assets.",
        "status": "matching",
    }
|
|
|
|
|
|
@router.post("/{project_id}/match/cancel")
async def cancel_matching_endpoint(project_id: int, db: AsyncSession = Depends(get_db)):
    """Request cancellation of an in-progress matching run.

    The project is looked up first, so an unknown ID returns 404 without
    signalling the matcher. The status is moved back to REVIEW only when the
    project is actually MATCHING. This prevents a stray cancel from pushing a
    DRAFT or PARSING project into REVIEW.
    """
    from app.services.ai_matching import cancel_matching

    project = await _get_project(project_id, db)
    cancel_matching(project_id)

    if project.status == ProjectStatus.MATCHING:
        project.status = ProjectStatus.REVIEW
        await db.commit()

    return {"detail": "Matching cancellation requested"}
|
|
|
|
|
|
@router.post("/{project_id}/refine")
async def refine_matches_endpoint(
    project_id: int,
    data: dict,
    db: AsyncSession = Depends(get_db),
):
    """Apply a natural-language refinement instruction to the project's matches.

    Delegates to ``match_refiner.refine_matches``. If that reports assets to
    re-match, matching is re-run for just those assets (awaited in-request), and
    the number of new matches is added to the result as ``new_matches``.

    Raises:
        HTTPException: 400 if no ``instruction`` is provided.
    """
    from app.services.match_refiner import refine_matches

    instruction = data.get("instruction", "")
    if not instruction:
        raise HTTPException(status_code=400, detail="No instruction provided")

    result = await refine_matches(db, project_id, instruction)

    # Re-match only the flagged assets. The lookup is restricted to this
    # project, and a count without an ID list is tolerated instead of raising KeyError.
    rematch_ids = result.get("rematch_asset_ids") or []
    if result.get("rematch_count", 0) > 0 and rematch_ids:
        ca_result = await db.execute(
            select(ClientAsset)
            .where(ClientAsset.id.in_(rematch_ids), ClientAsset.project_id == project_id)
            .order_by(ClientAsset.sort_order)
        )
        client_assets = ca_result.scalars().all()
        if client_assets:
            matches = await match_client_assets(db, project_id, client_assets)
            result["new_matches"] = len(matches)

    return result
|
|
|
|
|
|
@router.post("/{project_id}/matches/{match_id}/feedback")
async def submit_match_feedback(
    project_id: int,
    match_id: int,
    data: dict,
    db: AsyncSession = Depends(get_db),
):
    """Store confirm/reject feedback on a match for the learning system.

    Body keys: ``confirmed`` (default True) and ``comment`` (default "").
    The match must belong to a client asset of ``project_id``, so feedback
    cannot be filed against another project's match through this URL.

    Raises:
        HTTPException: 404 if the match is not found in this project.
    """
    from app.models.feedback import MatchFeedback

    result = await db.execute(
        select(Match, ClientAsset)
        .join(ClientAsset, Match.client_asset_id == ClientAsset.id)
        .where(Match.id == match_id, ClientAsset.project_id == project_id)
    )
    row = result.one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail="Match not found")
    match, ca = row

    confirmed = data.get("confirmed", True)
    comment = data.get("comment", "")

    feedback = MatchFeedback(
        client_term=(ca.raw_name or "").strip().lower(),
        client_description=ca.raw_description,
        gmal_asset_id=match.gmal_asset_id,
        confirmed=confirmed,
        user_comment=comment,
    )
    db.add(feedback)
    await db.commit()

    return {"detail": f"Feedback {'confirmed' if confirmed else 'rejected'} stored"}
|
|
|
|
|
|
@router.get("/{project_id}/matches", response_model=list[MatchOut])
async def list_matches(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return all matches for a project's client assets.

    The result is a flat list ordered by client asset, then by rank, with the
    matched GMAL asset's identifiers inlined. A project with no client assets
    (or an unknown project) yields an empty list.
    """
    # One query: filter matches by their client asset's project instead of
    # fetching the client-asset IDs first.
    result = await db.execute(
        select(Match, GmalAsset)
        .join(GmalAsset, Match.gmal_asset_id == GmalAsset.id)
        .join(ClientAsset, Match.client_asset_id == ClientAsset.id)
        .where(ClientAsset.project_id == project_id)
        .order_by(Match.client_asset_id, Match.rank)
    )

    return [
        MatchOut(
            id=match.id,
            client_asset_id=match.client_asset_id,
            gmal_asset_id=match.gmal_asset_id,
            gmal_id=gmal.gmal_id,
            gmal_name=gmal.asset_name,
            gmal_unique_name=gmal.unique_name,
            confidence=match.confidence.value,
            # `is not None`: a genuine score of 0 must not be reported as missing.
            confidence_score=float(match.confidence_score) if match.confidence_score is not None else None,
            ai_reasoning=match.ai_reasoning,
            caveat_text=match.caveat_text,
            is_selected=match.is_selected,
            rank=match.rank,
        )
        for match, gmal in result.all()
    ]
|
|
|
|
|
|
@router.put("/{project_id}/matches/{match_id}/select")
async def select_match(
    project_id: int,
    match_id: int,
    data: MatchSelectRequest,
    db: AsyncSession = Depends(get_db),
):
    """Select or deselect a match; selecting one deselects its siblings.

    The match must belong to a client asset of ``project_id``, so a match from
    another project cannot be toggled through this URL.

    Raises:
        HTTPException: 404 if the match is not found in this project.
    """
    result = await db.execute(
        select(Match)
        .join(ClientAsset, Match.client_asset_id == ClientAsset.id)
        .where(Match.id == match_id, ClientAsset.project_id == project_id)
    )
    match = result.scalar_one_or_none()
    if not match:
        raise HTTPException(status_code=404, detail="Match not found")

    if data.is_selected:
        # Deselect all other matches for this client asset (including this
        # one, which is re-selected just below).
        siblings = await db.execute(
            select(Match).where(Match.client_asset_id == match.client_asset_id)
        )
        for sibling in siblings.scalars().all():
            sibling.is_selected = False

    match.is_selected = data.is_selected
    await db.commit()

    return {"detail": "Match updated"}
|
|
|
|
|
|
@router.post("/{project_id}/matches/{client_asset_id}/manual")
async def manual_match(
    project_id: int,
    client_asset_id: int,
    data: ManualMatchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Attach a user-chosen GMAL asset to a client asset as its selected match.

    All existing matches of the client asset are deselected. A new match
    (confidence EXACT, score 1.0, rank 0) is created and selected.

    Raises:
        HTTPException: 404 if the client asset is not in this project, or if
            the GMAL asset does not exist.
    """
    owned = await db.execute(
        select(ClientAsset).where(
            ClientAsset.id == client_asset_id,
            ClientAsset.project_id == project_id,
        )
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Client asset not found")

    gmal_lookup = await db.execute(select(GmalAsset).where(GmalAsset.id == data.gmal_asset_id))
    gmal = gmal_lookup.scalar_one_or_none()
    if gmal is None:
        raise HTTPException(status_code=404, detail="GMAL asset not found")

    current = await db.execute(select(Match).where(Match.client_asset_id == client_asset_id))
    for previous in current.scalars().all():
        previous.is_selected = False

    db.add(
        Match(
            client_asset_id=client_asset_id,
            gmal_asset_id=data.gmal_asset_id,
            confidence=MatchConfidence.EXACT,
            confidence_score=1.0,
            ai_reasoning="Manually assigned by user",
            caveat_text="",
            is_selected=True,
            rank=0,
        )
    )
    await db.commit()

    return {"detail": f"Manually matched to {gmal.gmal_id}"}
|
|
|
|
|
|
async def _get_project(project_id: int, db: AsyncSession) -> Project:
    """Load a project by primary key, or fail with HTTP 404.

    Shared by the route handlers in this module, so an unknown ``project_id``
    consistently produces ``404 Project not found``.
    """
    lookup = await db.execute(select(Project).where(Project.id == project_id))
    found = lookup.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Project not found")
    return found
|