- Move AI parsing and matching into BackgroundTasks so both endpoints return immediately instead of blocking until Claude finishes (~60s+).
- Frontend now polls project status after the upload/match POST returns, so the spinner/progress UI keeps working as before.
- Replace <a href> export links with programmatic Axios downloads to fix the missing /gsb base path and the missing auth token (401 in production).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
357 lines
13 KiB
Python
357 lines
13 KiB
Python
"""Client document upload and AI matching endpoints."""
|
|
|
|
import asyncio
import logging

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db, async_session
from app.models.gmal import GmalAsset
from app.models.project import Project, ClientAsset, Match, ProjectStatus, MatchConfidence
from app.schemas.project import ClientAssetOut, ClientAssetUpdate, MatchOut, MatchSelectRequest, ManualMatchRequest
from app.services.doc_parser import extract_text_from_file, parse_text_with_ai
from app.services.ai_matching import match_client_assets
|
|
|
|
logger = logging.getLogger(__name__)

# Every route in this module is keyed by ``project_id``: client-document upload,
# client-asset listing/editing, and AI/manual GMAL matching.
router = APIRouter()
|
|
|
|
|
|
async def _background_parse(project_id: int, filename: str, text: str, metadata: dict):
    """Run AI parsing on extracted document text and save the results.

    Scheduled as a FastAPI background task by ``upload_client_document``, so it
    opens its own DB session instead of reusing the request-scoped one.

    Args:
        project_id: Project whose client assets are being (re)populated.
        filename: Original upload filename. Not used here; part of the task
            signature.
        text: Plain text already extracted from the uploaded document.
        metadata: Extraction metadata. Not used here; part of the task
            signature.

    On success, the project's previous client assets are replaced by the
    AI-extracted ones and the project moves to REVIEW. If the AI call fails,
    the project goes back to DRAFT with the error in ``parse_stage``. Any other
    unexpected error also resets the project to DRAFT, using a fresh session.
    This way the polling frontend never sees a project stuck in PARSING.
    """
    unexpected_error = None

    async with async_session() as db:
        try:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if not project:
                return

            # Stage 3: AI parsing. parse_text_with_ai is synchronous and slow;
            # calling it directly in this coroutine would block the event loop,
            # and with it every other request, including status polling.
            try:
                extracted, usage_info = await asyncio.to_thread(parse_text_with_ai, text)
            except Exception as e:
                logger.exception(f"AI parsing failed for project {project_id}: {e}")
                project.status = ProjectStatus.DRAFT
                project.parse_stage = f"AI parsing failed: {str(e)}"
                await db.commit()
                return

            # Save AI costs (accumulated across calls on the project)
            project.ai_input_tokens = (project.ai_input_tokens or 0) + usage_info.get("input_tokens", 0)
            project.ai_output_tokens = (project.ai_output_tokens or 0) + usage_info.get("output_tokens", 0)
            project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage_info.get("cost_usd", 0)
            project.ai_call_count = (project.ai_call_count or 0) + 1

            # Stage 4: Saving results
            project.parse_stage = f"AI found {len(extracted)} assets. Saving..."
            await db.commit()

            # Clear existing client assets before inserting the new extraction
            existing = await db.execute(
                select(ClientAsset).where(ClientAsset.project_id == project_id)
            )
            for ca in existing.scalars().all():
                await db.delete(ca)

            # Create client asset records. `or` also covers keys that the AI
            # returned with an explicit null value.
            assets = []
            for idx, item in enumerate(extracted, start=1):
                ca = ClientAsset(
                    project_id=project_id,
                    raw_name=item.get("name") or "Unknown",
                    raw_description=item.get("description") or "",
                    volume=item.get("volume", 1),
                    sort_order=idx,
                )
                db.add(ca)
                assets.append(ca)

            project.status = ProjectStatus.REVIEW
            project.parse_stage = f"Done! {len(assets)} assets extracted."
            await db.commit()
            logger.info(f"Background parse complete for project {project_id}: {len(assets)} assets")

        except Exception as e:
            logger.exception(f"Background parse error for project {project_id}: {e}")
            unexpected_error = e

    if unexpected_error is None:
        return

    # The session above may be unusable after the failure (and has rolled back),
    # so reset the project from a fresh session rather than leaving it PARSING.
    try:
        async with async_session() as db:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if project:
                project.status = ProjectStatus.DRAFT
                project.parse_stage = f"Parsing failed: {unexpected_error}"
                await db.commit()
    except Exception:
        logger.exception(f"Could not reset project {project_id} after parse error")
|
|
|
|
|
|
@router.post("/{project_id}/upload")
async def upload_client_document(
    project_id: int,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    db: AsyncSession = Depends(get_db),
):
    """Upload a client document and extract assets using AI.

    Text extraction happens within the request (400 if it fails). The slow AI
    parsing is handed to a background task, and the endpoint returns right away
    with ``status: "parsing"``. Progress is reported through the project's
    ``status`` and ``parse_stage`` fields.
    """
    project = await _get_project(project_id, db)

    # Stage 1: Read file
    content = await file.read()
    project.source_filename = file.filename
    project.status = ProjectStatus.PARSING
    project.parse_stage = f"Uploading {file.filename}..."
    await db.commit()

    # Stage 2: Extract text. This is synchronous work; run it in a worker thread
    # so a large document does not stall the event loop for other requests.
    project.parse_stage = "Extracting text from document..."
    await db.commit()

    try:
        text, metadata = await asyncio.to_thread(extract_text_from_file, content, file.filename)
    except Exception as e:
        project.status = ProjectStatus.DRAFT
        project.parse_stage = None
        await db.commit()
        raise HTTPException(status_code=400, detail=f"Failed to extract text: {str(e)}") from e

    # Use .get(): a KeyError here would surface as a 500 and leave the project
    # stuck in PARSING, because no background task has been scheduled yet.
    sheet_count = metadata.get("sheet_count")
    char_count = metadata.get("char_count", len(text))
    sheets_info = f" ({sheet_count} sheets)" if sheet_count else ""
    project.parse_stage = f"Extracted {char_count:,} characters{sheets_info}. Sending to AI..."
    await db.commit()

    # Stage 3+4: AI parsing runs in a background task after the response is
    # sent, so this endpoint returns immediately (HTTP 200).
    background_tasks.add_task(_background_parse, project_id, file.filename, text, metadata)

    return {
        "message": f"Document received. AI parsing started for {file.filename}.",
        "status": "parsing",
    }
|
|
|
|
|
|
@router.get("/{project_id}/client-assets", response_model=list[ClientAssetOut])
async def list_client_assets(project_id: int, db: AsyncSession = Depends(get_db)):
    """Return the project's client assets ordered by their ``sort_order``."""
    stmt = (
        select(ClientAsset)
        .where(ClientAsset.project_id == project_id)
        .order_by(ClientAsset.sort_order)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.put("/{project_id}/client-assets/{asset_id}", response_model=ClientAssetOut)
async def update_client_asset(
    project_id: int,
    asset_id: int,
    data: ClientAssetUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Partially update a client asset belonging to the given project.

    Only fields supplied with a non-None value (``raw_name``, ``raw_description``,
    ``volume``) are applied. Raises 404 when the asset does not exist or belongs
    to a different project.
    """
    lookup = await db.execute(
        select(ClientAsset).where(ClientAsset.id == asset_id, ClientAsset.project_id == project_id)
    )
    asset = lookup.scalar_one_or_none()
    if asset is None:
        raise HTTPException(status_code=404, detail="Client asset not found")

    for field_name in ("raw_name", "raw_description", "volume"):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(asset, field_name, new_value)

    await db.commit()
    await db.refresh(asset)
    return asset
|
|
|
|
|
|
async def _background_match(project_id: int, asset_snapshots: list):
    """Run AI matching for a project's client assets in the background.

    Scheduled as a FastAPI background task by ``run_matching``, so it opens its
    own DB session.

    Args:
        project_id: Project being matched.
        asset_snapshots: ``[{"id": <client asset id>}, ...]`` captured by the
            endpoint; the rows are re-loaded here by ID.

    Whether matching succeeds or fails, the project is left in REVIEW so the
    polling frontend stops waiting. Errors are logged with a traceback.
    """
    failed = False

    async with async_session() as db:
        try:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if not project:
                return

            # Re-load the client assets by ID in this task's own session rather
            # than reusing ORM objects from the request's session.
            asset_ids = [s["id"] for s in asset_snapshots]
            ca_result = await db.execute(
                select(ClientAsset)
                .where(ClientAsset.id.in_(asset_ids))
                .order_by(ClientAsset.sort_order)
            )
            client_assets = ca_result.scalars().all()

            matches = await match_client_assets(db, project_id, client_assets)

            # Reload the project before the final status update (presumably to
            # pick up changes made by match_client_assets - verify).
            await db.refresh(project)
            project.status = ProjectStatus.REVIEW
            await db.commit()
            logger.info(f"Background match complete for project {project_id}: {len(matches)} matches")

        except Exception as e:
            logger.exception(f"Background match error for project {project_id}: {e}")
            failed = True

    if not failed:
        return

    # Still leave the project in REVIEW. Use a fresh session, since the one
    # above may be in a failed state.
    try:
        async with async_session() as db:
            result = await db.execute(select(Project).where(Project.id == project_id))
            project = result.scalar_one_or_none()
            if project:
                project.status = ProjectStatus.REVIEW
                await db.commit()
    except Exception:
        logger.exception(f"Could not reset status for project {project_id} after match error")
|
|
|
|
|
|
@router.post("/{project_id}/match")
async def run_matching(
    project_id: int,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
):
    """Trigger AI matching for all client assets in this project.

    Existing matches are deleted and the project is set to MATCHING. The AI work
    runs in a background task, and the endpoint returns right away with
    ``status: "matching"``. Raises 400 if the project has no client assets.
    """
    project = await _get_project(project_id, db)

    # Get client assets
    result = await db.execute(
        select(ClientAsset).where(ClientAsset.project_id == project_id).order_by(ClientAsset.sort_order)
    )
    client_assets = result.scalars().all()

    if not client_assets:
        raise HTTPException(status_code=400, detail="No client assets to match. Upload a document first.")

    # Snapshot plain IDs before committing (ORM objects expire after commit)
    asset_ids = [ca.id for ca in client_assets]
    asset_snapshots = [{"id": asset_id} for asset_id in asset_ids]

    # Clear existing matches with a single query instead of one per client asset
    matches_result = await db.execute(select(Match).where(Match.client_asset_id.in_(asset_ids)))
    for m in matches_result.scalars().all():
        await db.delete(m)

    project.status = ProjectStatus.MATCHING
    await db.commit()

    # Run matching in a background task after the response is sent, so this
    # endpoint returns immediately (HTTP 200).
    background_tasks.add_task(_background_match, project_id, asset_snapshots)

    return {
        "message": f"Matching started for {len(client_assets)} client assets.",
        "status": "matching",
    }
|
|
|
|
|
|
@router.post("/{project_id}/match/cancel")
async def cancel_matching_endpoint(project_id: int, db: AsyncSession = Depends(get_db)):
    """Cancel an in-progress matching run.

    Returns 404 for an unknown project, before anything is cancelled. The cancel
    request is forwarded to the matching service. The project is moved back to
    REVIEW only if it is currently MATCHING, so a stray cancel cannot overwrite
    another state (e.g. a document upload still in PARSING).
    """
    from app.services.ai_matching import cancel_matching

    project = await _get_project(project_id, db)
    cancel_matching(project_id)

    if project.status == ProjectStatus.MATCHING:
        project.status = ProjectStatus.REVIEW
        await db.commit()

    return {"detail": "Matching cancellation requested"}
|
|
|
|
|
|
@router.get("/{project_id}/matches", response_model=list[MatchOut])
async def list_matches(project_id: int, db: AsyncSession = Depends(get_db)):
    """Get all matches for a project, ordered by client asset and then by rank.

    Each match is returned with the identifying fields of its GMAL asset. A
    project with no client assets (or an unknown project) yields an empty list.
    """
    # IDs of this project's client assets, used as an IN-subquery so the whole
    # lookup is a single round-trip.
    project_asset_ids = select(ClientAsset.id).where(ClientAsset.project_id == project_id)

    result = await db.execute(
        select(Match, GmalAsset)
        .join(GmalAsset, Match.gmal_asset_id == GmalAsset.id)
        .where(Match.client_asset_id.in_(project_asset_ids))
        .order_by(Match.client_asset_id, Match.rank)
    )

    return [
        MatchOut(
            id=match.id,
            client_asset_id=match.client_asset_id,
            gmal_asset_id=match.gmal_asset_id,
            gmal_id=gmal.gmal_id,
            gmal_name=gmal.asset_name,
            gmal_unique_name=gmal.unique_name,
            confidence=match.confidence.value,
            # `is not None` so a legitimate score of 0 is not reported as missing
            confidence_score=float(match.confidence_score) if match.confidence_score is not None else None,
            ai_reasoning=match.ai_reasoning,
            caveat_text=match.caveat_text,
            is_selected=match.is_selected,
            rank=match.rank,
        )
        for match, gmal in result.all()
    ]
|
|
|
|
|
|
@router.put("/{project_id}/matches/{match_id}/select")
async def select_match(
    project_id: int,
    match_id: int,
    data: MatchSelectRequest,
    db: AsyncSession = Depends(get_db),
):
    """Select or deselect a match. Deselects other matches for the same client asset.

    The match must belong, via its client asset, to ``project_id``. Otherwise
    404 is returned, so a match cannot be modified through another project's URL.
    """
    result = await db.execute(
        select(Match)
        .join(ClientAsset, Match.client_asset_id == ClientAsset.id)
        .where(Match.id == match_id, ClientAsset.project_id == project_id)
    )
    match = result.scalar_one_or_none()
    if not match:
        raise HTTPException(status_code=404, detail="Match not found")

    if data.is_selected:
        # At most one selected match per client asset: clear all siblings first
        siblings = await db.execute(
            select(Match).where(Match.client_asset_id == match.client_asset_id)
        )
        for sibling in siblings.scalars().all():
            sibling.is_selected = False

    match.is_selected = data.is_selected
    await db.commit()

    return {"detail": "Match updated"}
|
|
|
|
|
|
@router.post("/{project_id}/matches/{client_asset_id}/manual")
async def manual_match(
    project_id: int,
    client_asset_id: int,
    data: ManualMatchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Manually assign a GMAL asset to a client asset.

    Checks that the client asset belongs to ``project_id`` and that the GMAL
    asset exists (404 otherwise). It then deselects every existing match of the
    client asset and adds a new, selected, rank-0 match marked as manually
    assigned.
    """
    # Verify client asset belongs to project
    ca_result = await db.execute(
        select(ClientAsset).where(ClientAsset.id == client_asset_id, ClientAsset.project_id == project_id)
    )
    if not ca_result.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="Client asset not found")

    # Verify GMAL asset exists
    gmal_result = await db.execute(select(GmalAsset).where(GmalAsset.id == data.gmal_asset_id))
    gmal = gmal_result.scalar_one_or_none()
    if not gmal:
        raise HTTPException(status_code=404, detail="GMAL asset not found")
    # Read this before commit: if the session expires instances on commit,
    # touching `gmal` afterwards would need an implicit lazy load, which
    # AsyncSession cannot perform.
    gmal_code = gmal.gmal_id

    # Deselect existing matches
    existing = await db.execute(select(Match).where(Match.client_asset_id == client_asset_id))
    for m in existing.scalars().all():
        m.is_selected = False

    # Create manual match
    match = Match(
        client_asset_id=client_asset_id,
        gmal_asset_id=data.gmal_asset_id,
        confidence=MatchConfidence.EXACT,
        confidence_score=1.0,
        ai_reasoning="Manually assigned by user",
        caveat_text="",
        is_selected=True,
        rank=0,
    )
    db.add(match)
    await db.commit()

    return {"detail": f"Manually matched to {gmal_code}"}
|
|
|
|
|
|
async def _get_project(project_id: int, db: AsyncSession) -> Project:
    """Fetch a project by ID, raising HTTP 404 when it does not exist.

    Args:
        project_id: Primary key of the project.
        db: Session used for the lookup.

    Returns:
        The loaded ``Project``.
    """
    lookup = await db.execute(select(Project).where(Project.id == project_id))
    found = lookup.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Project not found")
    return found
|