diff --git a/backend/app/api/matching.py b/backend/app/api/matching.py index be2c5cb..0b26ce8 100644 --- a/backend/app/api/matching.py +++ b/backend/app/api/matching.py @@ -2,11 +2,11 @@ import logging -from fastapi import APIRouter, Depends, HTTPException, UploadFile, File +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.database import get_db +from app.database import get_db, async_session from app.models.gmal import GmalAsset from app.models.project import Project, ClientAsset, Match, ProjectStatus, MatchConfidence from app.schemas.project import ClientAssetOut, ClientAssetUpdate, MatchOut, MatchSelectRequest, ManualMatchRequest @@ -17,23 +17,85 @@ router = APIRouter() logger = logging.getLogger(__name__) +async def _background_parse(project_id: int, filename: str, text: str, metadata: dict): + """Run AI parsing and save results in the background (own DB session).""" + async with async_session() as db: + try: + result = await db.execute(select(Project).where(Project.id == project_id)) + project = result.scalar_one_or_none() + if not project: + return + + # Stage 3: AI parsing + try: + extracted, usage_info = parse_text_with_ai(text) + except Exception as e: + logger.error(f"AI parsing failed for project {project_id}: {e}") + project.status = ProjectStatus.DRAFT + project.parse_stage = f"AI parsing failed: {str(e)}" + await db.commit() + return + + # Save AI costs + project.ai_input_tokens = (project.ai_input_tokens or 0) + usage_info.get("input_tokens", 0) + project.ai_output_tokens = (project.ai_output_tokens or 0) + usage_info.get("output_tokens", 0) + project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage_info.get("cost_usd", 0) + project.ai_call_count = (project.ai_call_count or 0) + 1 + + # Stage 4: Saving results + project.parse_stage = f"AI found {len(extracted)} assets. Saving..." + await db.commit() + + # Clear existing client assets + existing = await db.execute( + select(ClientAsset).where(ClientAsset.project_id == project_id) + ) + for ca in existing.scalars().all(): + await db.delete(ca) + + # Create client asset records (skip zero-quantity assets) + assets = [] + for idx, item in enumerate(extracted): + volume = item.get("volume", 1) + if volume <= 0: + continue + ca = ClientAsset( + project_id=project_id, + raw_name=item.get("name", "Unknown"), + raw_description=item.get("description", ""), + volume=volume, + sort_order=idx + 1, + ) + db.add(ca) + assets.append(ca) + + project.status = ProjectStatus.REVIEW + project.parse_stage = f"Done! {len(assets)} assets extracted." + await db.commit() + logger.info(f"Background parse complete for project {project_id}: {len(assets)} assets") + + except Exception as e: + logger.error(f"Background parse error for project {project_id}: {e}") + + @router.post("/{project_id}/upload") async def upload_client_document( project_id: int, + background_tasks: BackgroundTasks, file: UploadFile = File(...), db: AsyncSession = Depends(get_db), ): """Upload a client document and extract assets using AI.""" project = await _get_project(project_id, db) - # Stage 1: Uploading + # Stage 1: Read file content = await file.read() project.source_filename = file.filename project.status = ProjectStatus.PARSING project.parse_stage = f"Uploading {file.filename}..." await db.commit() - # Stage 2: Extracting text + # Stage 2: Extract text (fast, synchronous) project.parse_stage = "Extracting text from document..." await db.commit() @@ -49,57 +111,12 @@ async def upload_client_document( project.parse_stage = f"Extracted {metadata['char_count']:,} characters{sheets_info}. Sending to AI..." await db.commit() - # Stage 3: AI parsing - try: - extracted, usage_info = parse_text_with_ai(text) - except Exception as e: - logger.error(f"AI parsing failed: {e}") - project.status = ProjectStatus.DRAFT - project.parse_stage = None - await db.commit() - raise HTTPException(status_code=400, detail=f"Failed to parse document: {str(e)}") - - # Save AI costs - project.ai_input_tokens = (project.ai_input_tokens or 0) + usage_info.get("input_tokens", 0) - project.ai_output_tokens = (project.ai_output_tokens or 0) + usage_info.get("output_tokens", 0) - project.ai_cost_usd = float(project.ai_cost_usd or 0) + usage_info.get("cost_usd", 0) - project.ai_call_count = (project.ai_call_count or 0) + 1 - - # Stage 4: Saving results - project.parse_stage = f"AI found {len(extracted)} assets. Saving..." - await db.commit() - - # Clear existing client assets - existing = await db.execute( - select(ClientAsset).where(ClientAsset.project_id == project_id) - ) - for ca in existing.scalars().all(): - await db.delete(ca) - - # Create client asset records - assets = [] - for idx, item in enumerate(extracted): - ca = ClientAsset( - project_id=project_id, - raw_name=item.get("name", "Unknown"), - raw_description=item.get("description", ""), - volume=item.get("volume", 1), - sort_order=idx + 1, - ) - db.add(ca) - assets.append(ca) - - project.status = ProjectStatus.REVIEW - project.parse_stage = f"Done! {len(assets)} assets extracted." - await db.commit() + # Stage 3+4: AI parsing runs in background — return 202 immediately + background_tasks.add_task(_background_parse, project_id, file.filename, text, metadata) return { - "message": f"Extracted {len(assets)} assets from {file.filename}", - "asset_count": len(assets), - "assets": [ - {"name": a.raw_name, "description": a.raw_description, "volume": a.volume} - for a in assets - ], + "message": f"Document received. AI parsing started for {file.filename}.", + "status": "parsing", } @@ -139,8 +156,48 @@ async def update_client_asset( return ca +async def _background_match(project_id: int, asset_snapshots: list): + """Run AI matching in the background (own DB session).""" + async with async_session() as db: + try: + result = await db.execute(select(Project).where(Project.id == project_id)) + project = result.scalar_one_or_none() + if not project: + return + + # Reconstruct ClientAsset-like objects from snapshots for match_client_assets + ca_result = await db.execute( + select(ClientAsset).where(ClientAsset.id.in_([s["id"] for s in asset_snapshots])) + .order_by(ClientAsset.sort_order) + ) + client_assets = ca_result.scalars().all() + + matches = await match_client_assets(db, project_id, client_assets) + + await db.refresh(project) + project.status = ProjectStatus.REVIEW + await db.commit() + logger.info(f"Background match complete for project {project_id}: {len(matches)} matches") + + except Exception as e: + logger.error(f"Background match error for project {project_id}: {e}") + try: + async with async_session() as db2: + result = await db2.execute(select(Project).where(Project.id == project_id)) + project = result.scalar_one_or_none() + if project: + project.status = ProjectStatus.REVIEW + await db2.commit() + except Exception: + pass + + @router.post("/{project_id}/match") -async def run_matching(project_id: int, db: AsyncSession = Depends(get_db)): +async def run_matching( + project_id: int, + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), +): """Trigger AI matching for all client assets in this project.""" project = await _get_project(project_id, db) @@ -153,6 +210,9 @@ async def run_matching(project_id: int, db: AsyncSession = Depends(get_db)): if not client_assets: raise HTTPException(status_code=400, detail="No client assets to match. Upload a document first.") + # Snapshot IDs before clearing (ORM objects expire after commit) + asset_snapshots = [{"id": ca.id} for ca in client_assets] + # Clear existing matches for ca in client_assets: matches_result = await db.execute(select(Match).where(Match.client_asset_id == ca.id)) @@ -162,17 +222,12 @@ async def run_matching(project_id: int, db: AsyncSession = Depends(get_db)): project.status = ProjectStatus.MATCHING await db.commit() - # Run matching (batched, parallel, commits per batch) - matches = await match_client_assets(db, project_id, client_assets) - - # Refresh project and set final status - await db.refresh(project) - project.status = ProjectStatus.REVIEW - await db.commit() + # Run matching in background — return 202 immediately + background_tasks.add_task(_background_match, project_id, asset_snapshots) return { - "message": f"Matched {len(client_assets)} client assets", - "total_matches": len(matches), + "message": f"Matching started for {len(client_assets)} client assets.", + "status": "matching", } diff --git a/backend/app/middleware/auth.py b/backend/app/middleware/auth.py index 77a0e1e..1eb29e0 100644 --- a/backend/app/middleware/auth.py +++ b/backend/app/middleware/auth.py @@ -1,6 +1,5 @@ import os import httpx -from functools import lru_cache from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from jose import jwt, JWTError @@ -13,22 +12,20 @@ ISSUER = f"https://login.microsoftonline.com/{TENANT_ID}/v2.0" bearer_scheme = HTTPBearer(auto_error=False) - -@lru_cache(maxsize=1) -def _fetch_jwks() -> dict: - """Fetch JWKS from Azure. Cached in process memory; restart to refresh.""" - response = httpx.get(JWKS_URL, timeout=10) - response.raise_for_status() - return response.json() +# Module-level cache — populated once per process, never blocks the event loop +_jwks_cache: dict | None = None -def _get_jwks() -> dict: - try: - return _fetch_jwks() - except Exception: - # Clear cache and retry once on failure - _fetch_jwks.cache_clear() - return _fetch_jwks() +async def _get_jwks() -> dict: + """Fetch JWKS from Azure using async HTTP. Cached in process memory.""" + global _jwks_cache + if _jwks_cache is not None: + return _jwks_cache + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(JWKS_URL) + response.raise_for_status() + _jwks_cache = response.json() + return _jwks_cache async def get_current_user( @@ -42,12 +39,21 @@ async def get_current_user( token = credentials.credentials try: - jwks = _get_jwks() + jwks = await _get_jwks() header = jwt.get_unverified_header(token) key = next( (k for k in jwks["keys"] if k.get("kid") == header.get("kid")), None, ) + if key is None: + # Key not in cache — fetch fresh JWKS once (keys can rotate) + global _jwks_cache + _jwks_cache = None + jwks = await _get_jwks() + key = next( + (k for k in jwks["keys"] if k.get("kid") == header.get("kid")), + None, + ) if key is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unknown signing key") @@ -62,7 +68,11 @@ async def get_current_user( return { "oid": payload.get("oid"), "name": payload.get("name"), - "email": payload.get("preferred_username") or payload.get("email"), + "email": ( + payload.get("preferred_username") + or payload.get("upn") + or payload.get("email") + ), } except JWTError as e: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {e}") diff --git a/backend/app/services/export_excel.py b/backend/app/services/export_excel.py index 387e719..6d7db80 100644 --- a/backend/app/services/export_excel.py +++ b/backend/app/services/export_excel.py @@ -64,10 +64,29 @@ async def export_ratecard_excel(db: AsyncSession, project: Project, efficiency_l gmals_result = await db.execute(select(GmalAsset).where(GmalAsset.id.in_(gmal_ids))) gmals = {g.id: g for g in gmals_result.scalars().all()} + # Load selected matches for caveat lookup + matches_result = await db.execute( + select(Match).where( + Match.client_asset_id.in_(asset_ids), + Match.is_selected == True, + ) + ) + selected_matches = matches_result.scalars().all() + + caveat_by_asset = {} + for m in selected_matches: + parts = [] + if m.caveat_text: + parts.append(m.caveat_text) + gmal = gmals.get(m.gmal_asset_id) + if gmal and gmal.caveats: + parts.append(f"GMAL Standard Caveats: {gmal.caveats}") + caveat_by_asset[m.client_asset_id] = "\n\n".join(parts) + # Sheet 1: Ratecard Summary (roles x assets matrix) ws1 = wb.active ws1.title = "Ratecard Summary" - _build_ratecard_sheet(ws1, lines, roles, client_assets, gmals) + _build_ratecard_sheet(ws1, lines, roles, client_assets, gmals, caveat_by_asset) # Sheet 2: Asset Detail ws2 = wb.create_sheet("Asset Detail") @@ -86,8 +105,10 @@ async def export_ratecard_excel(db: AsyncSession, project: Project, efficiency_l return _workbook_to_bytes(wb) -def _build_ratecard_sheet(ws, lines, roles, client_assets, gmals): +def _build_ratecard_sheet(ws, lines, roles, client_assets, gmals, caveats: dict | None = None): """Build the main ratecard matrix: rows=roles, cols=client assets.""" + if caveats is None: + caveats = {} # Get unique sorted client assets and roles asset_ids_ordered = sorted(client_assets.keys()) role_ids_ordered = sorted(roles.keys(), key=lambda rid: (roles[rid].discipline, roles[rid].sort_order or 0)) @@ -127,9 +148,24 @@ def _build_ratecard_sheet(ws, lines, roles, client_assets, gmals): ws.cell(row=1, column=total_col, value="Total Hours").font = HEADER_FONT ws.cell(row=1, column=total_col).fill = HEADER_FILL + # Caveats row (row 2) + CAVEAT_FONT = Font(italic=True, size=9, color="555555") + CAVEAT_FILL = PatternFill(start_color="FFFBF0", end_color="FFFBF0", fill_type="solid") + ws.cell(row=2, column=1, value="").fill = CAVEAT_FILL + ws.cell(row=2, column=2, value="Assumptions / Caveats").font = Font(italic=True, bold=True, size=9, color="92400E") + ws.cell(row=2, column=2).fill = CAVEAT_FILL + for col_idx, asset_id in enumerate(asset_ids_ordered, 3): + caveat = caveats.get(asset_id, "") + cell = ws.cell(row=2, column=col_idx, value=caveat) + cell.font = CAVEAT_FONT + cell.fill = CAVEAT_FILL + cell.alignment = Alignment(wrap_text=True, vertical="top") + ws.cell(row=2, column=total_col).fill = CAVEAT_FILL + ws.row_dimensions[2].height = 60 + # Data rows current_discipline = None - row_idx = 2 + row_idx = 3 for role_id in role_ids_ordered: role = roles[role_id] @@ -181,7 +217,7 @@ def _build_ratecard_sheet(ws, lines, roles, client_assets, gmals): async def _build_asset_detail_sheet(ws, db, project, client_assets, gmals): """Build the asset detail sheet showing matches and caveats.""" - headers = ["Client Asset", "Volume", "Matched GMAL", "GMAL Name", "Confidence", "Score", "Caveats"] + headers = ["Client Asset", "Volume", "Matched GMAL", "GMAL Name", "Confidence", "Score", "Match Caveats", "GMAL Standard Caveats"] for col_idx, header in enumerate(headers, 1): cell = ws.cell(row=1, column=col_idx, value=header) cell.font = HEADER_FONT @@ -213,13 +249,17 @@ async def _build_asset_detail_sheet(ws, db, project, client_assets, gmals): ws.cell(row=row_idx, column=5, value=match.confidence.value) ws.cell(row=row_idx, column=6, value=float(match.confidence_score) if match.confidence_score else 0) ws.cell(row=row_idx, column=7, value=match.caveat_text or "") + ws.cell(row=row_idx, column=7).alignment = Alignment(wrap_text=True, vertical="top") + gmal_caveats = (gmal.caveats or "") if gmal else "" + ws.cell(row=row_idx, column=8, value=gmal_caveats) + ws.cell(row=row_idx, column=8).alignment = Alignment(wrap_text=True, vertical="top") else: ws.cell(row=row_idx, column=3, value="No match") row_idx += 1 # Column widths - widths = [30, 10, 15, 40, 12, 10, 60] + widths = [30, 10, 15, 40, 12, 10, 60, 60] for i, w in enumerate(widths, 1): ws.column_dimensions[get_column_letter(i)].width = w diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 56eb8cf..bff2533 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -182,8 +182,10 @@ function NavBar() { const { instance, accounts } = useMsal(); const user = accounts[0]; - function handleLogout() { - instance.logoutRedirect({ postLogoutRedirectUri: '/gsb' }); + async function handleLogout() { + // Clear local MSAL cache only — does not sign out of the Microsoft account + await instance.clearCache(); + window.location.href = '/gsb'; } return ( diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 90f8eda..5e8f02f 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -13,10 +13,11 @@ api.interceptors.request.use(async (config) => { try { const result = await msalInstance.acquireTokenSilent({ - ...loginRequest, + scopes: ['openid', 'profile', 'email'], account: accounts[0], }); - config.headers.Authorization = `Bearer ${result.accessToken}`; + // ID token has audience=CLIENT_ID so the backend can validate it + config.headers.Authorization = `Bearer ${result.idToken}`; } catch { // Token expired or failed — trigger interactive login await msalInstance.loginRedirect(loginRequest); diff --git a/frontend/src/auth/msalConfig.ts b/frontend/src/auth/msalConfig.ts index fe93485..651dd72 100644 --- a/frontend/src/auth/msalConfig.ts +++ b/frontend/src/auth/msalConfig.ts @@ -4,8 +4,8 @@ export const msalConfig: Configuration = { auth: { clientId: '9079054c-9620-4757-a256-23413042f1ef', authority: 'https://login.microsoftonline.com/e519c2e6-bc6d-4fdf-8d9c-923c2f002385', - redirectUri: 'https://optical-dev.oliver.solutions/gsb', - postLogoutRedirectUri: 'https://optical-dev.oliver.solutions/gsb', + redirectUri: 'https://optical-dev.oliver.solutions/gsb/', + postLogoutRedirectUri: 'https://optical-dev.oliver.solutions/gsb/', }, cache: { cacheLocation: 'localStorage', diff --git a/frontend/src/pages/ProjectView.tsx b/frontend/src/pages/ProjectView.tsx index b688deb..325015c 100644 --- a/frontend/src/pages/ProjectView.tsx +++ b/frontend/src/pages/ProjectView.tsx @@ -91,53 +91,64 @@ export default function ProjectView() { setUploading(true); setUploadStage(`Uploading ${file.name}...`); - // Poll project status for stage updates + try { + const form = new FormData(); + form.append('file', file); + await api.post(`/projects/${id}/upload`, form); + } catch (err: any) { + alert(`Upload failed: ${err.response?.data?.detail || err.message}`); + setUploading(false); + setUploadStage(''); + return; + } + + // Poll until background parsing completes (status leaves 'parsing') const pollInterval = setInterval(async () => { try { const res = await api.get(`/projects/${id}`); if (res.data.parse_stage) { setUploadStage(res.data.parse_stage); } + if (res.data.status !== 'parsing') { + clearInterval(pollInterval); + setUploading(false); + setUploadStage(''); + await loadProject(); + setTab('matches'); + } } catch {} }, 1500); - - try { - const form = new FormData(); - form.append('file', file); - await api.post(`/projects/${id}/upload`, form); - await loadProject(); - setTab('matches'); - } catch (err: any) { - alert(`Upload failed: ${err.response?.data?.detail || err.message}`); - } finally { - clearInterval(pollInterval); - setUploading(false); - setUploadStage(''); - } } async function handleMatch() { setMatching(true); - // Start polling for matches while the request runs - const pollInterval = setInterval(async () => { - try { - const matchRes = await api.get(`/projects/${id}/matches`); - setMatches(matchRes.data); - } catch {} - }, 3000); try { await api.post(`/projects/${id}/match`); - await loadProject(); } catch (err: any) { if (!err.message?.includes('cancel')) { alert(`Matching failed: ${err.response?.data?.detail || err.message}`); } - await loadProject(); - } finally { - clearInterval(pollInterval); setMatching(false); + await loadProject(); + return; } + + // Poll until background matching completes (status leaves 'matching') + const pollInterval = setInterval(async () => { + try { + const [matchRes, projRes] = await Promise.all([ + api.get(`/projects/${id}/matches`), + api.get(`/projects/${id}`), + ]); + setMatches(matchRes.data); + if (projRes.data.status !== 'matching') { + clearInterval(pollInterval); + setMatching(false); + await loadProject(); + } + } catch {} + }, 3000); } async function handleCancelMatch() { @@ -184,10 +195,32 @@ export default function ProjectView() { }); } - function getExcelExportUrl() { + async function downloadFile(url: string, filename: string) { + try { + const response = await api.get(url, { responseType: 'blob' }); + const blob = new Blob([response.data], { type: response.headers['content-type'] }); + const objectUrl = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = objectUrl; + a.download = filename; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(objectUrl); + } catch (err: any) { + alert(`Export failed: ${err.response?.data?.detail || err.message}`); + } + } + + function handleExcelExport() { const levels = Array.from(selectedEfficiencyLevels).sort().join(','); - const base = `/api/projects/${id}/ratecard/export/excel`; - return levels ? `${base}?efficiency_levels=${levels}` : base; + const base = `/projects/${id}/ratecard/export/excel`; + const url = levels ? `${base}?efficiency_levels=${levels}` : base; + downloadFile(url, `${project?.name || 'ratecard'}.xlsx`); + } + + function handlePdfExport() { + downloadFile(`/projects/${id}/ratecard/export/pdf`, `${project?.name || 'caveats'}_caveats.pdf`); } async function handleDelete() { @@ -455,12 +488,12 @@ export default function ProjectView() { {ratecard.total_assets} assets
@@ -575,12 +608,12 @@ export default function ProjectView() {