- Replace bare score badge with rich ScoreCard component showing color-coded score (green/amber/red), label, and hover tooltip explaining what the 0-100 Attention Focus score means
- Add AI Design Effectiveness Score (1-10) from Claude alongside qualitative insights, with score_reason explanation
- Fix image/png media type error by converting all images to PNG before sending to Claude API
- Save ai_score and ai_score_reason to DB
- Display AI score badge in InsightsPanel with color coding

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
458 lines
16 KiB
Python
458 lines
16 KiB
Python
import io
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, UploadFile, Form
|
|
from fastapi.responses import StreamingResponse
|
|
from PIL import Image
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.session import get_db
|
|
from app.dependencies import get_user_id
|
|
from app.models.analysis import Analysis
|
|
from app.models.project import Project
|
|
from app.schemas.analysis import AnalysisDetail, AnalysisStatus, AnalysisSummary
|
|
from app.services.storage import storage
|
|
|
|
# Router collecting every analysis-related endpoint defined in this module.
router = APIRouter(tags=["analysis"])

# Pillow format names (as reported by ``Image.format``) accepted for upload.
ALLOWED_FORMATS = {"JPEG", "PNG", "TIFF", "WEBP", "BMP"}
# Upper bound on the uploaded file size, in bytes.
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
|
|
|
|
|
@router.post("/projects/{project_id}/analyses", response_model=AnalysisStatus, status_code=202)
async def create_analysis(
    project_id: str,
    file: UploadFile,
    background_tasks: BackgroundTasks,
    name: str | None = Form(None),
    model: str = Form("deepgaze_iie"),
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Accept an image upload, persist it, and queue the saliency pipeline.

    The uploaded bytes and a 400x400-bounded PNG thumbnail are written to
    storage, an ``Analysis`` row is created with status ``"pending"`` and
    committed, and ``run_analysis_pipeline`` is scheduled as a background task.

    Raises:
        HTTPException: 404 if the project does not exist for this user,
            413 if the upload exceeds ``MAX_FILE_SIZE``, 400 if the bytes are
            not a valid image or the format is not in ``ALLOWED_FORMATS``.

    Returns:
        ``AnalysisStatus`` carrying the new analysis id and ``"pending"``.
    """
    user_id = get_user_id(x_user_id)

    # Verify project belongs to user
    stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
    result = await db.execute(stmt)
    project = result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Read at most one byte past the limit: that is enough to detect an
    # oversized upload without pulling an arbitrarily large file into memory.
    image_data = await file.read(MAX_FILE_SIZE + 1)
    if len(image_data) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large (max 50MB)")

    try:
        image = Image.open(io.BytesIO(image_data))
        image.verify()
        image = Image.open(io.BytesIO(image_data))  # re-open after verify
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    if image.format not in ALLOWED_FORMATS:
        # Sort so the message is stable (set iteration order is not).
        allowed = ", ".join(sorted(ALLOWED_FORMATS))
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {image.format}. Allowed: {allowed}",
        )

    # Create analysis record
    analysis = Analysis(
        project_id=project_id,
        user_id=user_id,
        name=name or file.filename or "Untitled",
        model_used=model,
        status="pending",
        original_filename=file.filename or "upload",
        image_width=image.width,
        image_height=image.height,
        file_format=image.format or "PNG",
        storage_path=str(storage.base_dir),
    )
    db.add(analysis)
    await db.flush()
    await db.refresh(analysis)

    # Save original image (raw uploaded bytes, stored under a fixed name)
    await storage.save_bytes(image_data, analysis.id, "original.png")

    # Save thumbnail
    thumb = image.copy()
    thumb.thumbnail((400, 400))
    # The PNG encoder rejects modes such as CMYK / YCbCr / F, which JPEG and
    # TIFF uploads can legitimately use; normalise those to RGB first.
    if thumb.mode not in {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}:
        thumb = thumb.convert("RGB")
    thumb_buffer = io.BytesIO()
    thumb.save(thumb_buffer, format="PNG")
    await storage.save_bytes(thumb_buffer.getvalue(), analysis.id, "thumbnail.png")

    analysis_id = analysis.id

    # Commit now so the background thread can see the record
    await db.commit()

    # Queue background processing (sync function runs in threadpool)
    background_tasks.add_task(run_analysis_pipeline, analysis_id, image_data, model)

    return AnalysisStatus(id=analysis_id, status="pending")
|
|
|
|
|
|
def run_analysis_pipeline(analysis_id: str, image_data: bytes, model_name: str):
    """Background task: full saliency analysis pipeline. Runs sync in threadpool.

    Marks the analysis ``"processing"``, runs the saliency model on the image,
    writes ``saliency_raw.npy``, ``saliency_gray.png``, ``heatmap_overlay.png``,
    ``heatmap_standalone.png`` and ``gaze_sequence.png`` to storage, then stores
    the gaze sequence, hotspots and overall score on the record and marks it
    ``"completed"``. Any exception is logged and the record is marked
    ``"failed"``; nothing is re-raised to the caller.

    Args:
        analysis_id: Primary key of the ``Analysis`` row to update.
        image_data: Raw bytes of the uploaded image.
        model_name: Saliency model identifier passed to ``model_manager.predict``.
    """
    import logging
    import numpy as np
    from app.services.saliency.model_manager import model_manager
    from app.services.image_processing import prepare_for_inference, upscale_saliency
    from app.services.heatmap import generate_heatmap_overlay, generate_standalone_heatmap
    from app.services.gaze_sequence import extract_gaze_sequence

    logger = logging.getLogger("olivas.pipeline")

    # Use sync DB connection for background thread
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session
    from app.config import settings

    # Use psycopg2 if available, otherwise fallback
    try:
        sync_engine = create_engine(settings.DATABASE_URL.replace("+asyncpg", "+psycopg2"))
    except Exception:
        sync_engine = create_engine(settings.DATABASE_URL.replace("+asyncpg", ""))

    def save_png(img, filename):
        """Encode *img* as PNG and write it to this analysis' storage path."""
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        with open(storage.get_path(analysis_id, filename), "wb") as f:
            f.write(buf.getvalue())

    try:
        with Session(sync_engine) as db:
            analysis = db.get(Analysis, analysis_id)
            if analysis is None:
                # Record vanished (e.g. deleted) before the task started.
                logger.error(f"Analysis {analysis_id} not found; skipping pipeline")
                return
            analysis.status = "processing"
            db.commit()

            logger.info(f"Starting analysis {analysis_id}")
            image = Image.open(io.BytesIO(image_data)).convert("RGB")

            # 1. Resize for inference
            resized, scale = prepare_for_inference(image)
            logger.info(f"Image resized: {image.size} -> {resized.size}")

            # 2. Run saliency model
            logger.info(f"Running {model_name} inference...")
            saliency = model_manager.predict(resized, model_name)
            logger.info("Inference complete")

            # 3. Upscale to original dimensions
            saliency_full = upscale_saliency(saliency, image.height, image.width)

            # 4. Save raw saliency as .npy
            np.save(str(storage.get_path(analysis_id, "saliency_raw.npy")), saliency_full)

            # 5. Save saliency as grayscale PNG
            saliency_uint8 = (saliency_full * 255).astype(np.uint8)
            save_png(Image.fromarray(saliency_uint8, mode="L"), "saliency_gray.png")

            # 6. Generate heatmap overlay
            save_png(generate_heatmap_overlay(image, saliency_full), "heatmap_overlay.png")

            # 7. Generate standalone heatmap
            save_png(generate_standalone_heatmap(saliency_full), "heatmap_standalone.png")

            # 8. Extract gaze sequence
            gaze_seq = extract_gaze_sequence(saliency_full, num_fixations=5)

            # 9. Compute overall attention score:
            #    100 * (1 - entropy / max_entropy), clipped to [0, 100].
            sal_sum = saliency_full.sum()
            max_entropy = np.log2(saliency_full.size)
            # max_entropy is 0 for a single-pixel map; guard the division so
            # the stored score can never be NaN.
            if sal_sum > 0 and max_entropy > 0:
                # Normalize saliency to a proper probability distribution
                prob_dist = saliency_full / sal_sum
                prob_dist = prob_dist[prob_dist > 0]  # remove zeros for log
                entropy = -np.sum(prob_dist * np.log2(prob_dist))
                concentration = (1 - entropy / max_entropy) * 100
                overall_score = round(float(np.clip(concentration, 0, 100)), 1)
            else:
                overall_score = 0.0

            # 10. Extract hotspots
            hotspots = _extract_hotspots(saliency_full, num_hotspots=5)

            # 11. Generate gaze sequence image
            save_png(_draw_gaze_sequence(image, gaze_seq), "gaze_sequence.png")

            # Update DB
            analysis.status = "completed"
            analysis.gaze_sequence = gaze_seq
            analysis.hotspots = hotspots
            analysis.overall_score = overall_score
            db.commit()
            logger.info(f"Analysis {analysis_id} completed (score={overall_score})")

    except Exception as e:
        logger.error(f"Analysis {analysis_id} failed: {e}", exc_info=True)
        # Best effort: record the failure in a fresh session. A second failure
        # here is logged rather than silently discarded.
        try:
            with Session(sync_engine) as db:
                analysis = db.get(Analysis, analysis_id)
                if analysis:
                    analysis.status = "failed"
                    db.commit()
        except Exception:
            logger.exception(f"Could not mark analysis {analysis_id} as failed")
    finally:
        # The engine is created per task; release its connection pool so
        # repeated analyses do not accumulate open DB connections.
        sync_engine.dispose()
|
|
|
|
|
|
def _extract_hotspots(saliency, num_hotspots=5):
|
|
import numpy as np
|
|
from scipy.ndimage import gaussian_filter
|
|
|
|
sal = saliency.copy()
|
|
h, w = sal.shape
|
|
hotspots = []
|
|
radius = int(max(h, w) * 0.08)
|
|
|
|
for i in range(num_hotspots):
|
|
smoothed = gaussian_filter(sal, sigma=max(h, w) * 0.015)
|
|
peak_idx = np.unravel_index(np.argmax(smoothed), smoothed.shape)
|
|
y, x = int(peak_idx[0]), int(peak_idx[1])
|
|
intensity = float(saliency[y, x])
|
|
|
|
# Bounding box around hotspot
|
|
x1 = max(0, x - radius)
|
|
y1 = max(0, y - radius)
|
|
x2 = min(w, x + radius)
|
|
y2 = min(h, y + radius)
|
|
|
|
hotspots.append({
|
|
"rank": i + 1,
|
|
"center_x": x,
|
|
"center_y": y,
|
|
"x": x1,
|
|
"y": y1,
|
|
"width": x2 - x1,
|
|
"height": y2 - y1,
|
|
"intensity": round(intensity, 4),
|
|
})
|
|
|
|
# Inhibition of return
|
|
yy, xx = np.ogrid[:h, :w]
|
|
mask = (xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2
|
|
sal[mask] = 0.0
|
|
|
|
return hotspots
|
|
|
|
|
|
def _draw_gaze_sequence(image, gaze_seq):
    """Return a copy of ``image`` annotated with the gaze fixation path.

    Every entry of ``gaze_seq`` (a dict read for ``"x"``, ``"y"`` and
    ``"rank"``) is drawn as an outlined circle labelled with its rank, and
    consecutive entries are joined by a line. Colours cycle through a fixed
    five-entry palette. The input image is left untouched.
    """
    from PIL import ImageDraw, ImageFont

    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    label_font = ImageFont.load_default(size=24)

    palette = ["#FF4444", "#FF8800", "#FFCC00", "#44CC44", "#4488FF"]
    marker_radius = 25
    last_index = len(gaze_seq) - 1

    for idx, fixation in enumerate(gaze_seq):
        cx, cy = fixation["x"], fixation["y"]
        stroke = palette[idx % len(palette)]

        # Circle marker with the rank label inside it
        pen.ellipse(
            [cx - marker_radius, cy - marker_radius, cx + marker_radius, cy + marker_radius],
            outline=stroke,
            width=3,
        )
        pen.text((cx - 6, cy - 12), str(fixation["rank"]), fill=stroke, font=label_font)

        # Connector to the following fixation, if there is one
        if idx < last_index:
            following = gaze_seq[idx + 1]
            pen.line([cx, cy, following["x"], following["y"]], fill=stroke, width=2)

    return canvas
|
|
|
|
|
|
@router.get("/analyses/ai-insights-available")
async def check_ai_insights_available():
    """Report whether the AI insights service says it is available."""
    from app.services.ai_insights import is_available as ai_insights_ready

    available = ai_insights_ready()
    return {"available": available}
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}", response_model=AnalysisDetail)
async def get_analysis(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Return the full detail of one analysis owned by the requesting user.

    Responds 404 when no analysis with this id exists for the user. Rule-based
    insights are generated on the fly, and only for completed analyses.
    """
    user_id = get_user_id(x_user_id)
    query = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    analysis = (await db.execute(query)).scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Generate insights for completed analyses
    if analysis.status == "completed":
        from app.services.insights import generate_insights

        insights = generate_insights(analysis)
    else:
        insights = None

    # Attributes copied verbatim from the ORM row onto the response schema.
    copied_attrs = (
        "id",
        "name",
        "model_used",
        "status",
        "original_filename",
        "image_width",
        "image_height",
        "file_format",
        "overall_score",
        "created_at",
        "gaze_sequence",
        "hotspots",
        "ai_insights",
        "ai_score",
        "ai_score_reason",
        "ai_cost_usd",
    )
    payload = {attr: getattr(analysis, attr) for attr in copied_attrs}
    return AnalysisDetail(insights=insights, **payload)
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}/status", response_model=AnalysisStatus)
async def get_analysis_status(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Return the id and current status of one of the user's analyses (404 if absent)."""
    owner_id = get_user_id(x_user_id)
    query = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == owner_id)
    record = (await db.execute(query)).scalar_one_or_none()
    if record:
        return AnalysisStatus(id=record.id, status=record.status)
    raise HTTPException(status_code=404, detail="Analysis not found")
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}/images/{image_type}")
async def get_analysis_image(
    analysis_id: str,
    image_type: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Stream one stored image of an analysis owned by the requesting user.

    ``image_type`` is a public alias mapped to a stored filename. Responds 404
    if the analysis is not found for this user or the file does not exist yet,
    and 400 for an unrecognised ``image_type``.
    """
    owner_id = get_user_id(x_user_id)
    query = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == owner_id)
    record = (await db.execute(query)).scalar_one_or_none()
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Public image-type alias -> filename inside the analysis' storage folder
    stored_names = {
        "original": "original.png",
        "thumbnail": "thumbnail.png",
        "heatmap": "heatmap_overlay.png",
        "heatmap-standalone": "heatmap_standalone.png",
        "saliency-raw": "saliency_gray.png",
        "gaze-sequence": "gaze_sequence.png",
    }

    try:
        stored_name = stored_names[image_type]
    except KeyError:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if not storage.exists(analysis_id, stored_name):
        raise HTTPException(status_code=404, detail=f"Image not yet available")

    payload = await storage.load_bytes(analysis_id, stored_name)
    return StreamingResponse(io.BytesIO(payload), media_type="image/png")
|
|
|
|
|
|
@router.post("/analyses/{analysis_id}/ai-insights")
async def generate_ai_insights_endpoint(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Generate AI-powered insights for a completed analysis using Claude.

    Loads the stored original image and heatmap overlay, passes them with the
    analysis metadata to ``generate_ai_insights``, stores the returned
    insights, score, score reason and cost on the analysis, and removes any
    cached ``report.pdf`` so the next report includes the new insights.

    Raises:
        HTTPException: 503 if AI insights are not available, 404 if the
            analysis or its images are missing, 400 if the analysis is not
            completed, 502 if insight generation raises ``RuntimeError``.

    Returns:
        A dict with ``insights``, ``ai_score``, ``score_reason``, ``cost_usd``,
        ``input_tokens`` and ``output_tokens``.
    """
    import asyncio
    import contextlib
    import os

    from app.services.ai_insights import generate_ai_insights, is_available

    if not is_available():
        raise HTTPException(status_code=503, detail="AI insights not configured (missing API key)")

    user_id = get_user_id(x_user_id)
    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    result = await db.execute(stmt)
    analysis = result.scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    if analysis.status != "completed":
        raise HTTPException(status_code=400, detail="Analysis is not yet completed")

    # Load images
    try:
        original_bytes = await storage.load_bytes(analysis_id, "original.png")
        heatmap_bytes = await storage.load_bytes(analysis_id, "heatmap_overlay.png")
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail="Analysis images not found") from exc

    metadata = {
        "overall_score": analysis.overall_score,
        "hotspots": analysis.hotspots or [],
        "gaze_sequence": analysis.gaze_sequence or [],
        "image_width": analysis.image_width,
        "image_height": analysis.image_height,
    }

    # generate_ai_insights is a plain synchronous call; run it in a worker
    # thread so the event loop keeps serving other requests while it works.
    try:
        ai_result = await asyncio.to_thread(
            generate_ai_insights, metadata, original_bytes, heatmap_bytes
        )
    except RuntimeError as e:
        raise HTTPException(status_code=502, detail=str(e)) from e

    # Save to DB
    analysis.ai_insights = ai_result["insights"]
    analysis.ai_score = ai_result["ai_score"]
    analysis.ai_score_reason = ai_result["score_reason"]
    analysis.ai_cost_usd = ai_result["cost_usd"]
    await db.flush()

    # Invalidate cached PDF so next download includes AI insights
    if storage.exists(analysis_id, "report.pdf"):
        # Best effort: a failed removal must not fail the request.
        with contextlib.suppress(OSError):
            os.remove(storage.get_path(analysis_id, "report.pdf"))

    return {
        "insights": ai_result["insights"],
        "ai_score": ai_result["ai_score"],
        "score_reason": ai_result["score_reason"],
        "cost_usd": ai_result["cost_usd"],
        "input_tokens": ai_result["input_tokens"],
        "output_tokens": ai_result["output_tokens"],
    }
|
|
|
|
|
|
@router.delete("/analyses/{analysis_id}", status_code=204)
async def delete_analysis(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    x_user_id: str | None = Header(None),
):
    """Delete one of the user's analyses: stored files first, then the DB row.

    Responds 404 when no analysis with this id exists for the user.
    """
    owner_id = get_user_id(x_user_id)
    query = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == owner_id)
    record = (await db.execute(query)).scalar_one_or_none()
    if not record:
        raise HTTPException(status_code=404, detail="Analysis not found")

    # Remove stored artefacts, then the row, and persist the deletion.
    await storage.delete_analysis(analysis_id)
    await db.delete(record)
    await db.commit()
|