olivas/backend/app/api/endpoints/analysis.py
DJP 92062b254d Add score clarity, AI design score, image format fix, cost tracking
- Replace bare score badge with rich ScoreCard component showing
  color-coded score (green/amber/red), label, and hover tooltip
  explaining what the 0-100 Attention Focus score means
- Add AI Design Effectiveness Score (1-10) from Claude alongside
  qualitative insights, with score_reason explanation
- Fix image/png media type error by converting all images to PNG
  before sending to Claude API
- Save ai_score and ai_score_reason to DB
- Display AI score badge in InsightsPanel with color coding

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:37:03 -05:00

458 lines
16 KiB
Python

import io
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, UploadFile, Form
from fastapi.responses import StreamingResponse
from PIL import Image
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.dependencies import get_user_id
from app.models.analysis import Analysis
from app.models.project import Project
from app.schemas.analysis import AnalysisDetail, AnalysisStatus, AnalysisSummary
from app.services.storage import storage
# Every endpoint in this module is grouped under the "analysis" OpenAPI tag.
router = APIRouter(tags=["analysis"])

# Pillow format names (``Image.format``) accepted for upload.
ALLOWED_FORMATS = {"JPEG", "PNG", "TIFF", "WEBP", "BMP"}

# Largest accepted upload, in bytes (50 MiB).
MAX_FILE_SIZE = 50 * 1024 ** 2
@router.post("/projects/{project_id}/analyses", response_model=AnalysisStatus, status_code=202)
async def create_analysis(
    project_id: str,
    file: UploadFile,
    background_tasks: BackgroundTasks,
    name: str | None = Form(None),
    model: str = Form("deepgaze_iie"),
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Accept an image upload for a project and queue its saliency analysis.

    Checks that the project belongs to the caller, that the upload is at most
    ``MAX_FILE_SIZE`` bytes and decodes as one of ``ALLOWED_FORMATS``. It then
    creates a ``pending`` Analysis row and stores the raw upload bytes plus a
    PNG thumbnail bounded to 400x400. Finally it commits and schedules
    ``run_analysis_pipeline`` as a background task.

    Raises:
        HTTPException(404): project missing or not owned by the caller.
        HTTPException(413): upload larger than ``MAX_FILE_SIZE``.
        HTTPException(400): not a decodable image, or format not allowed.

    Returns:
        AnalysisStatus with the new analysis id and status ``"pending"``.
    """
    user_id = get_user_id(x_user_id)

    # Verify project belongs to user
    stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
    result = await db.execute(stmt)
    project = result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Read and validate image. Reading one byte past the limit is enough to
    # detect an oversized upload without buffering all of it in memory.
    image_data = await file.read(MAX_FILE_SIZE + 1)
    if len(image_data) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large (max 50MB)")
    try:
        image = Image.open(io.BytesIO(image_data))
        image.verify()
        # verify() leaves the image unusable, so re-open it for later use.
        image = Image.open(io.BytesIO(image_data))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    if image.format not in ALLOWED_FORMATS:
        # Sort so the message is stable; set iteration order is not.
        allowed = ", ".join(sorted(ALLOWED_FORMATS))
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {image.format}. Allowed: {allowed}",
        )

    # Create analysis record
    analysis = Analysis(
        project_id=project_id,
        user_id=user_id,
        name=name or file.filename or "Untitled",
        model_used=model,
        status="pending",
        original_filename=file.filename or "upload",
        image_width=image.width,
        image_height=image.height,
        file_format=image.format or "PNG",
        storage_path=str(storage.base_dir),
    )
    db.add(analysis)
    await db.flush()
    await db.refresh(analysis)

    # Save original image (the raw upload bytes, whatever the format)
    await storage.save_bytes(image_data, analysis.id, "original.png")

    # Save thumbnail. PNG cannot store modes such as CMYK (which JPEG and TIFF
    # uploads may use), YCbCr or LAB. Convert those to RGB so the save cannot
    # fail after the row already exists.
    thumb = image.copy()
    thumb.thumbnail((400, 400))
    if thumb.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
        thumb = thumb.convert("RGB")
    thumb_buffer = io.BytesIO()
    thumb.save(thumb_buffer, format="PNG")
    await storage.save_bytes(thumb_buffer.getvalue(), analysis.id, "thumbnail.png")

    analysis_id = analysis.id
    # Commit now so the background thread can see the record
    await db.commit()

    # Queue background processing (sync function runs in threadpool)
    background_tasks.add_task(run_analysis_pipeline, analysis_id, image_data, model)
    return AnalysisStatus(id=analysis_id, status="pending")
def _save_png(img, analysis_id, filename):
    """Encode ``img`` as PNG and write it to ``storage.get_path(analysis_id, filename)``."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    with open(storage.get_path(analysis_id, filename), "wb") as f:
        f.write(buf.getvalue())


def _attention_focus_score(saliency):
    """Return a 0-100 score of how concentrated the saliency map is.

    The map is normalized into a probability distribution. Its Shannon entropy
    is compared with the maximum possible entropy, log2 of the element count
    (a uniform map): 100 means all mass on one element, 0 means uniform. The
    result is clipped to [0, 100] and rounded to one decimal. It is 0.0 when
    the map has no positive mass (sum <= 0 or NaN) or has a single element,
    where the ratio would be 0/0.
    """
    import numpy as np

    total = saliency.sum()
    # Written as `not total > 0` so a NaN sum also yields 0.0.
    if not total > 0 or saliency.size < 2:
        return 0.0
    prob = saliency / total
    prob = prob[prob > 0]  # drop zeros before log2 (0 * log 0 is taken as 0)
    entropy = -np.sum(prob * np.log2(prob))
    max_entropy = np.log2(saliency.size)
    concentration = (1 - entropy / max_entropy) * 100
    return round(float(np.clip(concentration, 0, 100)), 1)


def run_analysis_pipeline(analysis_id: str, image_data: bytes, model_name: str):
    """Background task: full saliency analysis pipeline. Runs sync in threadpool.

    Steps:
      1. Mark the analysis ``processing``.
      2. Predict a saliency map with ``model_name``.
      3. Write the artifacts ``saliency_raw.npy``, ``saliency_gray.png``,
         ``heatmap_overlay.png``, ``heatmap_standalone.png`` and
         ``gaze_sequence.png``.
      4. Store the gaze sequence, hotspots and attention score, and mark the
         analysis ``completed``.

    Any exception during these steps is logged and the analysis is marked
    ``failed`` (best effort). Errors before that point (imports, engine
    creation) propagate to the caller.
    """
    import logging
    import numpy as np
    from app.services.saliency.model_manager import model_manager
    from app.services.image_processing import prepare_for_inference, upscale_saliency
    from app.services.heatmap import generate_heatmap_overlay, generate_standalone_heatmap
    from app.services.gaze_sequence import extract_gaze_sequence
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session
    from app.config import settings

    logger = logging.getLogger("olivas.pipeline")

    # This thread cannot use the app's async engine, so build a synchronous one.
    # Prefer psycopg2; otherwise fall back to the URL's default driver.
    try:
        sync_engine = create_engine(settings.DATABASE_URL.replace("+asyncpg", "+psycopg2"))
    except Exception:
        sync_engine = create_engine(settings.DATABASE_URL.replace("+asyncpg", ""))

    try:
        with Session(sync_engine) as db:
            analysis = db.get(Analysis, analysis_id)
            if analysis is None:
                # The row may have been deleted before this task ran.
                logger.warning("Analysis %s not found; skipping pipeline", analysis_id)
                return
            analysis.status = "processing"
            db.commit()
            logger.info("Starting analysis %s", analysis_id)

            image = Image.open(io.BytesIO(image_data)).convert("RGB")

            # 1. Resize for inference
            resized, _ = prepare_for_inference(image)
            logger.info("Image resized: %s -> %s", image.size, resized.size)

            # 2. Run saliency model
            logger.info("Running %s inference...", model_name)
            saliency = model_manager.predict(resized, model_name)
            logger.info("Inference complete")

            # 3. Upscale to original dimensions
            saliency_full = upscale_saliency(saliency, image.height, image.width)

            # 4. Save raw saliency as .npy
            np.save(str(storage.get_path(analysis_id, "saliency_raw.npy")), saliency_full)

            # 5. Save saliency as grayscale PNG. Clip before the uint8 cast so any
            # value outside [0, 1] saturates instead of wrapping around.
            saliency_uint8 = np.clip(saliency_full * 255, 0, 255).astype(np.uint8)
            _save_png(Image.fromarray(saliency_uint8, mode="L"), analysis_id, "saliency_gray.png")

            # 6. Generate heatmap overlay
            _save_png(generate_heatmap_overlay(image, saliency_full), analysis_id, "heatmap_overlay.png")

            # 7. Generate standalone heatmap
            _save_png(generate_standalone_heatmap(saliency_full), analysis_id, "heatmap_standalone.png")

            # 8. Extract gaze sequence
            gaze_seq = extract_gaze_sequence(saliency_full, num_fixations=5)

            # 9. Compute overall attention score
            overall_score = _attention_focus_score(saliency_full)

            # 10. Extract hotspots
            hotspots = _extract_hotspots(saliency_full, num_hotspots=5)

            # 11. Generate gaze sequence image
            _save_png(_draw_gaze_sequence(image, gaze_seq), analysis_id, "gaze_sequence.png")

            # Update DB
            analysis.status = "completed"
            analysis.gaze_sequence = gaze_seq
            analysis.hotspots = hotspots
            analysis.overall_score = overall_score
            db.commit()
            logger.info("Analysis %s completed (score=%s)", analysis_id, overall_score)
    except Exception as e:
        logger.error("Analysis %s failed: %s", analysis_id, e, exc_info=True)
        try:
            with Session(sync_engine) as db:
                failed = db.get(Analysis, analysis_id)
                if failed is not None:
                    failed.status = "failed"
                    db.commit()
        except Exception:
            # Best effort, but leave a trace; otherwise the row silently stays "processing".
            logger.exception("Could not mark analysis %s as failed", analysis_id)
    finally:
        # Each run builds its own engine (and pool). Dispose it so pooled
        # connections close now rather than whenever the engine is garbage-collected.
        sync_engine.dispose()
def _extract_hotspots(saliency, num_hotspots=5):
    """Pick up to ``num_hotspots`` attention peaks from a 2-D saliency map.

    Each round does three things:
      1. Smooths the working copy with a Gaussian (sigma = 1.5% of the longer
         side).
      2. Takes the global maximum as the next hotspot.
      3. Zeroes a disc of radius 8% of the longer side around it (inhibition
         of return), so the next round finds a different region.

    The first peak is always reported, preserving a non-empty result. Later
    rounds stop once no positive saliency remains, instead of repeating
    zero-intensity peaks at (0, 0).

    Args:
        saliency: 2-D array indexed ``[y, x]``; it is not modified.
        num_hotspots: maximum number of hotspots to return.

    Returns:
        List of dicts, strongest first, with these keys:
          - ``rank``: 1-based position.
          - ``center_x`` / ``center_y``: the peak pixel.
          - ``x`` / ``y`` / ``width`` / ``height``: a box reaching ``radius``
            pixels from the peak, clipped to the image.
          - ``intensity``: the unsmoothed saliency at the peak, rounded to
            4 decimals.
    """
    import numpy as np
    from scipy.ndimage import gaussian_filter

    sal = saliency.copy()
    h, w = sal.shape
    longest = max(h, w)
    radius = int(longest * 0.08)
    sigma = longest * 0.015
    # Coordinate grids do not change between rounds; build them once.
    yy, xx = np.ogrid[:h, :w]

    hotspots = []
    for rank in range(1, num_hotspots + 1):
        smoothed = gaussian_filter(sal, sigma=sigma)
        peak_idx = np.unravel_index(np.argmax(smoothed), smoothed.shape)
        y, x = int(peak_idx[0]), int(peak_idx[1])
        if hotspots and not smoothed[y, x] > 0:
            # Everything left has been suppressed; argmax would just return (0, 0).
            break
        # Bounding box around hotspot
        x1 = max(0, x - radius)
        y1 = max(0, y - radius)
        x2 = min(w, x + radius)
        y2 = min(h, y + radius)
        hotspots.append({
            "rank": rank,
            "center_x": x,
            "center_y": y,
            "x": x1,
            "y": y1,
            "width": x2 - x1,
            "height": y2 - y1,
            "intensity": round(float(saliency[y, x]), 4),
        })
        # Inhibition of return
        sal[(xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2] = 0.0
    return hotspots
def _draw_gaze_sequence(image, gaze_seq):
    """Return a copy of ``image`` annotated with the predicted gaze path.

    Each fixation in ``gaze_seq`` is a dict with ``x``, ``y`` and ``rank`` keys.
    It gets a radius-25 circle labelled with its rank, and consecutive
    fixations are joined by a line. Colours cycle through a fixed five-colour
    palette. The input image is not modified.
    """
    from PIL import ImageDraw, ImageFont

    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.load_default(size=24)
    except TypeError:
        # Pillow < 10.1 has no ``size`` argument; use the fixed-size default font.
        font = ImageFont.load_default()
    palette = ["#FF4444", "#FF8800", "#FFCC00", "#44CC44", "#4488FF"]
    r = 25
    for i, point in enumerate(gaze_seq):
        x, y = point["x"], point["y"]
        color = palette[i % len(palette)]
        draw.ellipse([x - r, y - r, x + r, y + r], outline=color, width=3)
        # Fixed offset so the rank label sits roughly inside the circle.
        draw.text((x - 6, y - 12), str(point["rank"]), fill=color, font=font)
        # Connect to the next fixation, if any
        if i + 1 < len(gaze_seq):
            nxt = gaze_seq[i + 1]
            draw.line([x, y, nxt["x"], nxt["y"]], fill=color, width=2)
    return annotated
@router.get("/analyses/ai-insights-available")
async def check_ai_insights_available():
    """Report whether AI insights can be generated (API key configured).

    Keep this route registered before ``/analyses/{analysis_id}``. Routes match
    in declaration order, so the literal path would otherwise be captured as
    an analysis id.
    """
    from app.services.ai_insights import is_available as ai_configured

    available = ai_configured()
    return {"available": available}
@router.get("/analyses/{analysis_id}", response_model=AnalysisDetail)
async def get_analysis(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Return full details for one of the caller's analyses.

    Rule-based insights are computed here on each request, and only once the
    status is ``"completed"``; otherwise ``insights`` is None. The stored AI
    insight fields are passed through unchanged.

    Raises:
        HTTPException(404): no analysis with that id owned by the caller.
    """
    owner = get_user_id(x_user_id)
    query = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == owner)
    record = (await db.execute(query)).scalar_one_or_none()
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    if record.status == "completed":
        from app.services.insights import generate_insights

        insights = generate_insights(record)
    else:
        insights = None

    return AnalysisDetail(
        id=record.id,
        name=record.name,
        model_used=record.model_used,
        status=record.status,
        original_filename=record.original_filename,
        image_width=record.image_width,
        image_height=record.image_height,
        file_format=record.file_format,
        overall_score=record.overall_score,
        created_at=record.created_at,
        gaze_sequence=record.gaze_sequence,
        hotspots=record.hotspots,
        insights=insights,
        ai_insights=record.ai_insights,
        ai_score=record.ai_score,
        ai_score_reason=record.ai_score_reason,
        ai_cost_usd=record.ai_cost_usd,
    )
@router.get("/analyses/{analysis_id}/status", response_model=AnalysisStatus)
async def get_analysis_status(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Return just the id and processing status of one of the caller's analyses.

    Raises:
        HTTPException(404): no analysis with that id owned by the caller.
    """
    owner = get_user_id(x_user_id)
    rows = await db.execute(
        select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == owner)
    )
    record = rows.scalar_one_or_none()
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")
    return AnalysisStatus(id=record.id, status=record.status)
def _sniff_image_media_type(data: bytes) -> str:
    """Return the MIME type implied by the leading magic bytes of ``data``.

    Recognises the upload formats this module accepts (JPEG, PNG, TIFF, WEBP,
    BMP) and falls back to ``image/png``.
    """
    if data.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if data.startswith((b"II*\x00", b"MM\x00*")):
        return "image/tiff"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data.startswith(b"BM"):
        return "image/bmp"
    return "image/png"


@router.get("/analyses/{analysis_id}/images/{image_type}")
async def get_analysis_image(
    analysis_id: str,
    image_type: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Stream one stored image belonging to the caller's analysis.

    ``image_type`` is one of: original, thumbnail, heatmap, heatmap-standalone,
    saliency-raw, gaze-sequence. The Content-Type is sniffed from the bytes.
    This matters because ``original.png`` holds the raw upload, which may be
    JPEG, TIFF, WEBP or BMP rather than PNG.

    Raises:
        HTTPException(404): analysis not found for the caller, or image not yet written.
        HTTPException(400): unknown ``image_type``.
    """
    user_id = get_user_id(x_user_id)
    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    result = await db.execute(stmt)
    analysis = result.scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    file_map = {
        "original": "original.png",
        "thumbnail": "thumbnail.png",
        "heatmap": "heatmap_overlay.png",
        "heatmap-standalone": "heatmap_standalone.png",
        "saliency-raw": "saliency_gray.png",
        "gaze-sequence": "gaze_sequence.png",
    }
    filename = file_map.get(image_type)
    if not filename:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")
    if not storage.exists(analysis_id, filename):
        raise HTTPException(status_code=404, detail="Image not yet available")
    try:
        data = await storage.load_bytes(analysis_id, filename)
    except FileNotFoundError:
        # Removed between the exists() check and the read (e.g. concurrent delete).
        raise HTTPException(status_code=404, detail="Image not yet available") from None
    return StreamingResponse(io.BytesIO(data), media_type=_sniff_image_media_type(data))
@router.post("/analyses/{analysis_id}/ai-insights")
async def generate_ai_insights_endpoint(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Generate AI-powered insights for a completed analysis using Claude.

    The insights, AI score, score reason and cost are saved to the analysis
    row and committed. Any cached ``report.pdf`` is then removed so the next
    download is regenerated.

    Raises:
        HTTPException(503): AI insights not configured.
        HTTPException(404): analysis not found for the caller, or its images are missing.
        HTTPException(400): analysis not yet completed.
        HTTPException(502): ``generate_ai_insights`` raised RuntimeError.

    Returns:
        Dict with insights, ai_score, score_reason, cost_usd, input_tokens and output_tokens.
    """
    import asyncio
    import contextlib
    import os

    from app.services.ai_insights import generate_ai_insights, is_available

    if not is_available():
        raise HTTPException(status_code=503, detail="AI insights not configured (missing API key)")

    user_id = get_user_id(x_user_id)
    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    analysis = (await db.execute(stmt)).scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")
    if analysis.status != "completed":
        raise HTTPException(status_code=400, detail="Analysis is not yet completed")

    # Load images
    try:
        original_bytes = await storage.load_bytes(analysis_id, "original.png")
        heatmap_bytes = await storage.load_bytes(analysis_id, "heatmap_overlay.png")
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Analysis images not found") from None

    metadata = {
        "overall_score": analysis.overall_score,
        "hotspots": analysis.hotspots or [],
        "gaze_sequence": analysis.gaze_sequence or [],
        "image_width": analysis.image_width,
        "image_height": analysis.image_height,
    }

    # generate_ai_insights is synchronous: its result is indexed directly,
    # never awaited. Run it in a worker thread so the remote API call does not
    # block the event loop for other requests.
    try:
        ai = await asyncio.to_thread(generate_ai_insights, metadata, original_bytes, heatmap_bytes)
    except RuntimeError as e:
        raise HTTPException(status_code=502, detail=str(e)) from e

    # Save to DB. Commit explicitly, as create_analysis and delete_analysis do,
    # so the paid-for result is durable before we respond.
    analysis.ai_insights = ai["insights"]
    analysis.ai_score = ai["ai_score"]
    analysis.ai_score_reason = ai["score_reason"]
    analysis.ai_cost_usd = ai["cost_usd"]
    await db.commit()

    # Invalidate cached PDF so next download includes AI insights
    if storage.exists(analysis_id, "report.pdf"):
        with contextlib.suppress(OSError):
            os.remove(storage.get_path(analysis_id, "report.pdf"))

    return {
        "insights": ai["insights"],
        "ai_score": ai["ai_score"],
        "score_reason": ai["score_reason"],
        "cost_usd": ai["cost_usd"],
        "input_tokens": ai["input_tokens"],
        "output_tokens": ai["output_tokens"],
    }
@router.delete("/analyses/{analysis_id}", status_code=204)
async def delete_analysis(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Delete one of the caller's analyses: its database row, then its stored files.

    The row delete is committed before the files are removed. A failure part
    way through therefore leaves, at worst, orphaned files (logged) rather than
    a live record whose images are already gone.

    Raises:
        HTTPException(404): no analysis with that id owned by the caller.
    """
    import logging

    user_id = get_user_id(x_user_id)
    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    result = await db.execute(stmt)
    analysis = result.scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    await db.delete(analysis)
    await db.commit()

    try:
        await storage.delete_analysis(analysis_id)
    except OSError:
        # The record is already gone; keep the 204 but leave a trace of the leftover files.
        logging.getLogger(__name__).exception(
            "Deleted analysis %s but failed to remove its stored files", analysis_id
        )