Replace X-User-Id header auth with Azure AD JWT token validation. Backend validates tokens via JWKS, frontend uses MSAL for login/token acquisition. Adds logout button, 401 handling, and configurable AZURE_AUTH_ENABLED toggle. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
639 lines
23 KiB
Python
639 lines
23 KiB
Python
import io
|
|
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, Form
|
|
from fastapi.responses import StreamingResponse
|
|
from PIL import Image
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.session import get_db
|
|
from app.dependencies import get_user_id
|
|
from app.models.analysis import Analysis
|
|
from app.models.project import Project
|
|
from app.schemas.analysis import AnalysisDetail, AnalysisStatus, AnalysisSummary
|
|
from app.services.storage import storage
|
|
|
|
router = APIRouter(tags=["analysis"])

# Pillow ``Image.format`` identifiers accepted by the upload endpoint.
ALLOWED_FORMATS = {"JPEG", "PNG", "TIFF", "WEBP", "BMP"}

# Upload size ceiling in bytes (50 MiB == 50 * 1024 * 1024).
MAX_FILE_SIZE = 50 << 20
|
|
|
|
|
|
@router.post("/projects/{project_id}/analyses", response_model=AnalysisStatus, status_code=202)
async def create_analysis(
    project_id: str,
    file: UploadFile,
    background_tasks: BackgroundTasks,
    name: str | None = Form(None),
    model: str = Form("deepgaze_iie"),
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Accept an image upload for a project and queue a saliency analysis.

    The project must belong to the authenticated user. The upload is validated
    (size, decodability, format), an ``Analysis`` row is created, the raw
    upload bytes and a PNG thumbnail (bounded to 400x400) are written to
    storage, and ``run_analysis_pipeline`` is scheduled as a background task.

    Returns:
        ``AnalysisStatus`` with the new analysis id and status ``"pending"``.

    Raises:
        HTTPException 404: project missing or not owned by ``user_id``.
        HTTPException 413: upload larger than ``MAX_FILE_SIZE`` bytes.
        HTTPException 400: undecodable image, or format not in ``ALLOWED_FORMATS``.
    """
    # Verify project belongs to user
    stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
    result = await db.execute(stmt)
    project = result.scalar_one_or_none()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering an arbitrarily large body in memory.
    image_data = await file.read(MAX_FILE_SIZE + 1)
    if len(image_data) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large (max 50MB)")

    try:
        image = Image.open(io.BytesIO(image_data))
        image.verify()
        # verify() leaves the image object unusable; reopen it for real access.
        image = Image.open(io.BytesIO(image_data))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    if image.format not in ALLOWED_FORMATS:
        allowed = ", ".join(sorted(ALLOWED_FORMATS))  # sorted: deterministic message
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {image.format}. Allowed: {allowed}",
        )

    # Render the thumbnail before any DB row or file is created. copy() forces
    # a full pixel decode (verify() does not catch every truncated file), and
    # PNG cannot store modes such as CMYK (common in print-ready JPEGs),
    # YCbCr or LAB, so those are converted first. Any failure here is reported
    # as a bad upload instead of surfacing after files were already stored.
    png_modes = {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}
    try:
        thumb = image.copy()
        thumb.thumbnail((400, 400))
        if thumb.mode not in png_modes:
            thumb = thumb.convert("RGBA" if "A" in thumb.getbands() else "RGB")
        thumb_buffer = io.BytesIO()
        thumb.save(thumb_buffer, format="PNG")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    # Create analysis record
    analysis = Analysis(
        project_id=project_id,
        user_id=user_id,
        name=name or file.filename or "Untitled",
        model_used=model,
        status="pending",
        original_filename=file.filename or "upload",
        image_width=image.width,
        image_height=image.height,
        file_format=image.format or "PNG",
        storage_path=str(storage.base_dir),
    )
    db.add(analysis)
    await db.flush()
    await db.refresh(analysis)
    analysis_id = analysis.id

    # The raw upload bytes are stored unchanged under the name "original.png",
    # whatever the actual upload format is.
    await storage.save_bytes(image_data, analysis_id, "original.png")
    await storage.save_bytes(thumb_buffer.getvalue(), analysis_id, "thumbnail.png")

    # Commit now so the background thread can see the record
    await db.commit()

    # Queue background processing (sync function runs in threadpool)
    background_tasks.add_task(run_analysis_pipeline, analysis_id, image_data, model)

    return AnalysisStatus(id=analysis_id, status="pending")
|
|
|
|
|
|
# Process-wide sync engine shared by background pipelines (created lazily).
_sync_engine = None


def _make_sync_engine():
    """Return the process-wide synchronous SQLAlchemy engine for background tasks.

    Background pipelines run in the threadpool and use a sync ``Session``, so
    they need a sync engine derived from ``settings.DATABASE_URL``: ``+asyncpg``
    is swapped for ``+psycopg2``, falling back to the dialect's default driver
    if creating that engine fails (presumably psycopg2 not installed).

    The engine and its connection pool are created once and reused. Building a
    fresh engine per call meant a new pool (and new DB connections) for every
    analysis, none of which was ever explicitly disposed.
    """
    global _sync_engine
    if _sync_engine is None:
        from sqlalchemy import create_engine

        from app.config import settings

        try:
            engine = create_engine(settings.DATABASE_URL.replace("+asyncpg", "+psycopg2"))
        except Exception:
            engine = create_engine(settings.DATABASE_URL.replace("+asyncpg", ""))
        # Two threads racing here can each build an engine; the last one wins
        # and the other is simply dropped once its caller finishes.
        _sync_engine = engine
    return _sync_engine
|
|
|
|
|
|
def _save_file(path, data: bytes) -> None:
|
|
with open(path, "wb") as f:
|
|
f.write(data)
|
|
|
|
|
|
def run_analysis_pipeline(analysis_id: str, image_data: bytes, model_name: str):
    """Entry point for the background saliency job; runs synchronously in the threadpool.

    Delegates to the Cloud Run pipeline when ``settings.use_cloud_run`` is true
    (per the original note, when CLOUD_RUN_SALIENCY_URL is configured) and to
    the fully local pipeline otherwise. Both receive the same arguments.
    """
    from app.config import settings

    pipeline = _run_pipeline_cloud_run if settings.use_cloud_run else _run_pipeline_local
    pipeline(analysis_id, image_data, model_name)
|
|
|
|
|
|
def _run_pipeline_cloud_run(analysis_id: str, image_data: bytes, model_name: str):
    """Pipeline using Google Cloud Run for saliency + image processing.

    Steps:
      1. Mark the ``Analysis`` row ``processing`` and call the Cloud Run
         saliency service.
      2./3. Persist the returned float32 map as ``saliency_raw.npy`` and a
         grayscale ``saliency_gray.png``.
      4. Render heatmap overlay, standalone heatmap and gaze-sequence images,
         via the Cloud Run processing service when
         ``settings.CLOUD_RUN_PROCESSING_URL`` is set, locally otherwise.
      5. Store scores, hotspots and gaze sequence and mark the row ``completed``.

    Any exception marks the row ``failed``. If the row no longer exists (e.g.
    deleted by the user), the run is skipped or abandoned with a warning.
    """
    import base64
    import logging

    import numpy as np
    from sqlalchemy.orm import Session

    from app.services.cloud_run_client import call_saliency, call_processing
    from app.config import settings

    logger = logging.getLogger("olivas.pipeline")
    sync_engine = _make_sync_engine()

    try:
        with Session(sync_engine) as db:
            analysis = db.get(Analysis, analysis_id)
            if analysis is None:
                # Deleted before the task started; nothing left to process.
                logger.warning("[cloud-run] Analysis %s not found; skipping", analysis_id)
                return
            analysis.status = "processing"
            db.commit()

        logger.info("[cloud-run] Starting analysis %s", analysis_id)

        # 1. Saliency inference via Cloud Run
        sal_result = call_saliency(image_data, model_name)

        saliency_b64: str = sal_result["saliency_b64"]
        shape: list[int] = sal_result["shape"]
        gaze_seq: list[dict] = sal_result["gaze_sequence"]
        hotspots: list[dict] = sal_result["hotspots"]
        overall_score: float = sal_result["overall_score"]
        entropy_score: float = sal_result["entropy_score"]

        # 2. Save raw saliency locally (float32 buffer reshaped to the reported [h, w])
        h, w = shape
        saliency_full = np.frombuffer(
            base64.b64decode(saliency_b64), dtype=np.float32
        ).reshape(h, w)
        np.save(str(storage.get_path(analysis_id, "saliency_raw.npy")), saliency_full)

        # 3. Save saliency grayscale PNG locally. The x255 scaling presumes a
        # map normalised to [0, 1]; clip so slightly out-of-range values
        # saturate instead of overflowing the uint8 cast.
        saliency_uint8 = np.clip(saliency_full * 255, 0, 255).astype(np.uint8)
        saliency_img = Image.fromarray(saliency_uint8, mode="L")
        buf = io.BytesIO()
        saliency_img.save(buf, format="PNG")
        _save_file(storage.get_path(analysis_id, "saliency_gray.png"), buf.getvalue())

        # 4. Image post-processing via Cloud Run (or local fallback)
        if settings.CLOUD_RUN_PROCESSING_URL:
            proc_result = call_processing(image_data, saliency_b64, shape, gaze_seq)
            heatmap_overlay_data = base64.b64decode(proc_result["heatmap_overlay_b64"])
            heatmap_standalone_data = base64.b64decode(proc_result["heatmap_standalone_b64"])
            gaze_img_data = base64.b64decode(proc_result["gaze_sequence_img_b64"])
        else:
            # Local fallback for image processing
            from app.services.heatmap import generate_heatmap_overlay, generate_standalone_heatmap

            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            heatmap_overlay_data = _img_to_png_bytes(generate_heatmap_overlay(image, saliency_full))
            heatmap_standalone_data = _img_to_png_bytes(generate_standalone_heatmap(saliency_full))
            gaze_img_data = _img_to_png_bytes(_draw_gaze_sequence(image, gaze_seq))

        _save_file(storage.get_path(analysis_id, "heatmap_overlay.png"), heatmap_overlay_data)
        _save_file(storage.get_path(analysis_id, "heatmap_standalone.png"), heatmap_standalone_data)
        _save_file(storage.get_path(analysis_id, "gaze_sequence.png"), gaze_img_data)

        # 5. Update DB
        with Session(sync_engine) as db:
            analysis = db.get(Analysis, analysis_id)
            if analysis is None:
                logger.warning("[cloud-run] Analysis %s was deleted during processing", analysis_id)
                return
            analysis.status = "completed"
            analysis.gaze_sequence = gaze_seq
            analysis.hotspots = hotspots
            analysis.overall_score = overall_score
            analysis.entropy_score = entropy_score
            db.commit()

        logger.info("[cloud-run] Analysis %s completed (score=%s)", analysis_id, overall_score)

    except Exception as e:
        logger.error("[cloud-run] Analysis %s failed: %s", analysis_id, e, exc_info=True)
        try:
            with Session(sync_engine) as db:
                analysis = db.get(Analysis, analysis_id)
                if analysis:
                    analysis.status = "failed"
                    db.commit()
        except Exception:
            # Never raise from the error path, but leave a trace: the row may
            # now be stuck in "processing".
            logger.exception("[cloud-run] Could not mark analysis %s as failed", analysis_id)
|
|
|
|
|
|
def _run_pipeline_local(analysis_id: str, image_data: bytes, model_name: str):
    """Pipeline running entirely locally (dev mode / no Cloud Run).

    Marks the ``Analysis`` row ``processing``, runs ``model_name`` on a resized
    copy of the image, upscales the map back to full resolution, and writes
    ``saliency_raw.npy``, ``saliency_gray.png``, ``heatmap_overlay.png``,
    ``heatmap_standalone.png`` and ``gaze_sequence.png``. It then computes the
    gaze sequence, hotspots and scores and stores them with status
    ``completed``. Any exception marks the row ``failed``.

    The final update opens a fresh session, as the Cloud Run pipeline does.
    It does not reuse the object loaded for the "processing" update, so no
    session is held across model inference and the completion write never
    depends on a session whose context has already exited.
    """
    import logging

    import numpy as np
    from sqlalchemy.orm import Session

    from app.services.saliency.model_manager import model_manager
    from app.services.image_processing import prepare_for_inference, upscale_saliency
    from app.services.heatmap import generate_heatmap_overlay, generate_standalone_heatmap
    from app.services.gaze_sequence import extract_gaze_sequence

    logger = logging.getLogger("olivas.pipeline")
    sync_engine = _make_sync_engine()

    try:
        with Session(sync_engine) as db:
            analysis = db.get(Analysis, analysis_id)
            if analysis is None:
                # Deleted before the task started; nothing left to process.
                logger.warning("[local] Analysis %s not found; skipping", analysis_id)
                return
            analysis.status = "processing"
            db.commit()

        logger.info("[local] Starting analysis %s", analysis_id)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        resized, _scale = prepare_for_inference(image)
        logger.info("Image resized: %s -> %s", image.size, resized.size)

        logger.info("Running %s inference...", model_name)
        saliency = model_manager.predict(resized, model_name)
        logger.info("Inference complete")

        saliency_full = upscale_saliency(saliency, image.height, image.width)
        np.save(str(storage.get_path(analysis_id, "saliency_raw.npy")), saliency_full)

        # The x255 scaling presumes a map normalised to [0, 1]; clip so slightly
        # out-of-range values saturate instead of overflowing the uint8 cast.
        saliency_uint8 = np.clip(saliency_full * 255, 0, 255).astype(np.uint8)
        saliency_img = Image.fromarray(saliency_uint8, mode="L")
        buf = io.BytesIO()
        saliency_img.save(buf, format="PNG")
        _save_file(storage.get_path(analysis_id, "saliency_gray.png"), buf.getvalue())

        heatmap_overlay = generate_heatmap_overlay(image, saliency_full)
        _save_file(storage.get_path(analysis_id, "heatmap_overlay.png"), _img_to_png_bytes(heatmap_overlay))

        heatmap_standalone = generate_standalone_heatmap(saliency_full)
        _save_file(storage.get_path(analysis_id, "heatmap_standalone.png"), _img_to_png_bytes(heatmap_standalone))

        gaze_seq = extract_gaze_sequence(saliency_full, num_fixations=5)
        hotspots = _extract_hotspots(saliency_full, num_hotspots=5)
        overall_score, entropy_score = _compute_design_score(saliency_full, hotspots, gaze_seq)

        gaze_img = _draw_gaze_sequence(image, gaze_seq)
        _save_file(storage.get_path(analysis_id, "gaze_sequence.png"), _img_to_png_bytes(gaze_img))

        with Session(sync_engine) as db:
            analysis = db.get(Analysis, analysis_id)
            if analysis is None:
                logger.warning("[local] Analysis %s was deleted during processing", analysis_id)
                return
            analysis.status = "completed"
            analysis.gaze_sequence = gaze_seq
            analysis.hotspots = hotspots
            analysis.overall_score = overall_score
            analysis.entropy_score = entropy_score
            db.commit()

        logger.info(
            "[local] Analysis %s completed (score=%s, entropy=%s)",
            analysis_id, overall_score, entropy_score,
        )

    except Exception as e:
        logger.error("[local] Analysis %s failed: %s", analysis_id, e, exc_info=True)
        try:
            with Session(sync_engine) as db:
                analysis = db.get(Analysis, analysis_id)
                if analysis:
                    analysis.status = "failed"
                    db.commit()
        except Exception:
            # Never raise from the error path, but leave a trace: the row may
            # now be stuck in "processing".
            logger.exception("[local] Could not mark analysis %s as failed", analysis_id)
|
|
|
|
|
|
def _img_to_png_bytes(img: Image.Image) -> bytes:
    """Encode ``img`` as PNG and return the encoded bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        return sink.getvalue()
|
|
|
|
|
|
def _compute_design_score(saliency_full, hotspots, gaze_seq):
|
|
"""Compute composite Design Effectiveness Score (0-100) and raw entropy score.
|
|
|
|
Four components:
|
|
A. Peak Dominance (0.30) — how much stronger the top hotspot is vs rest
|
|
B. Hierarchy Clarity (0.25) — whether intensities decrease monotonically
|
|
C. Gaze Coherence (0.25) — whether gaze follows a smooth spatial path
|
|
D. Entropy Concentration (0.20) — Shannon entropy, softened with sqrt
|
|
"""
|
|
import numpy as np
|
|
|
|
# --- D. Entropy Concentration (also gives us the raw entropy_score) ---
|
|
sal_sum = saliency_full.sum()
|
|
if sal_sum > 0:
|
|
prob_dist = saliency_full / sal_sum
|
|
prob_dist = prob_dist[prob_dist > 0]
|
|
entropy = -np.sum(prob_dist * np.log2(prob_dist))
|
|
max_entropy = np.log2(saliency_full.size)
|
|
raw_concentration = (1 - entropy / max_entropy) * 100
|
|
else:
|
|
raw_concentration = 0.0
|
|
|
|
entropy_score = round(float(np.clip(raw_concentration, 0, 100)), 1)
|
|
entropy_adjusted = float(np.sqrt(max(raw_concentration, 0) / 100)) * 100
|
|
|
|
# --- A. Peak Dominance ---
|
|
if len(hotspots) >= 2:
|
|
top_intensity = hotspots[0]["intensity"]
|
|
rest_intensities = [h["intensity"] for h in hotspots[1:]]
|
|
rest_mean = float(np.mean(rest_intensities)) if rest_intensities else 0.0
|
|
if rest_mean > 0:
|
|
dominance_ratio = top_intensity / rest_mean
|
|
else:
|
|
dominance_ratio = 10.0
|
|
peak_dominance = float(100 * (1 - np.exp(-0.5 * dominance_ratio)))
|
|
elif len(hotspots) == 1:
|
|
peak_dominance = 95.0 # single hotspot = very dominant
|
|
else:
|
|
peak_dominance = 50.0 # no hotspots, neutral
|
|
|
|
# --- B. Hierarchy Clarity ---
|
|
intensities = [h["intensity"] for h in hotspots]
|
|
n = len(intensities)
|
|
if n >= 2:
|
|
concordant = 0
|
|
total_pairs = 0
|
|
for i in range(n):
|
|
for j in range(i + 1, n):
|
|
total_pairs += 1
|
|
if intensities[i] > intensities[j]:
|
|
concordant += 1
|
|
monotonicity = concordant / total_pairs if total_pairs > 0 else 1.0
|
|
|
|
if intensities[0] > 0:
|
|
drop_ratio = 1 - (intensities[-1] / intensities[0])
|
|
else:
|
|
drop_ratio = 0.0
|
|
|
|
hierarchy_clarity = float((0.6 * monotonicity + 0.4 * drop_ratio) * 100)
|
|
else:
|
|
hierarchy_clarity = 70.0 # neutral default
|
|
|
|
# --- C. Gaze Coherence ---
|
|
gaze_points = [(g["x"], g["y"]) for g in gaze_seq] if gaze_seq else []
|
|
ng = len(gaze_points)
|
|
if ng >= 3:
|
|
# Angle smoothness
|
|
angles = []
|
|
for i in range(ng - 2):
|
|
ax = gaze_points[i + 1][0] - gaze_points[i][0]
|
|
ay = gaze_points[i + 1][1] - gaze_points[i][1]
|
|
bx = gaze_points[i + 2][0] - gaze_points[i + 1][0]
|
|
by = gaze_points[i + 2][1] - gaze_points[i + 1][1]
|
|
mag_a = np.sqrt(ax ** 2 + ay ** 2)
|
|
mag_b = np.sqrt(bx ** 2 + by ** 2)
|
|
if mag_a > 0 and mag_b > 0:
|
|
cos_angle = np.clip((ax * bx + ay * by) / (mag_a * mag_b), -1, 1)
|
|
angle = float(np.degrees(np.arccos(cos_angle)))
|
|
angles.append(angle)
|
|
|
|
if angles:
|
|
avg_angle = float(np.mean(angles))
|
|
angle_smoothness = max(0, 100 - (avg_angle / 180) * 100)
|
|
else:
|
|
angle_smoothness = 70.0
|
|
|
|
# Path efficiency
|
|
total_path = sum(
|
|
np.sqrt((gaze_points[i + 1][0] - gaze_points[i][0]) ** 2 +
|
|
(gaze_points[i + 1][1] - gaze_points[i][1]) ** 2)
|
|
for i in range(ng - 1)
|
|
)
|
|
direct_dist = np.sqrt((gaze_points[-1][0] - gaze_points[0][0]) ** 2 +
|
|
(gaze_points[-1][1] - gaze_points[0][1]) ** 2)
|
|
path_efficiency = float(direct_dist / total_path) if total_path > 0 else 1.0
|
|
|
|
gaze_coherence = 0.7 * angle_smoothness + 0.3 * (path_efficiency * 100)
|
|
else:
|
|
gaze_coherence = 70.0 # neutral default for too few points
|
|
|
|
# --- Composite ---
|
|
composite = (
|
|
0.30 * peak_dominance
|
|
+ 0.25 * hierarchy_clarity
|
|
+ 0.25 * gaze_coherence
|
|
+ 0.20 * entropy_adjusted
|
|
)
|
|
overall_score = round(float(np.clip(composite, 0, 100)), 1)
|
|
|
|
return overall_score, entropy_score
|
|
|
|
|
|
def _extract_hotspots(saliency, num_hotspots=5):
|
|
import numpy as np
|
|
from scipy.ndimage import gaussian_filter
|
|
|
|
sal = saliency.copy()
|
|
h, w = sal.shape
|
|
hotspots = []
|
|
radius = int(max(h, w) * 0.08)
|
|
|
|
for i in range(num_hotspots):
|
|
smoothed = gaussian_filter(sal, sigma=max(h, w) * 0.015)
|
|
peak_idx = np.unravel_index(np.argmax(smoothed), smoothed.shape)
|
|
y, x = int(peak_idx[0]), int(peak_idx[1])
|
|
intensity = float(saliency[y, x])
|
|
|
|
# Bounding box around hotspot
|
|
x1 = max(0, x - radius)
|
|
y1 = max(0, y - radius)
|
|
x2 = min(w, x + radius)
|
|
y2 = min(h, y + radius)
|
|
|
|
hotspots.append({
|
|
"rank": i + 1,
|
|
"center_x": x,
|
|
"center_y": y,
|
|
"x": x1,
|
|
"y": y1,
|
|
"width": x2 - x1,
|
|
"height": y2 - y1,
|
|
"intensity": round(intensity, 4),
|
|
})
|
|
|
|
# Inhibition of return
|
|
yy, xx = np.ogrid[:h, :w]
|
|
mask = (xx - x) ** 2 + (yy - y) ** 2 <= radius ** 2
|
|
sal[mask] = 0.0
|
|
|
|
return hotspots
|
|
|
|
|
|
def _draw_gaze_sequence(image, gaze_seq):
    """Return a copy of ``image`` annotated with the predicted gaze path.

    For each fixation (a dict with ``"x"``, ``"y"`` and ``"rank"``), draws a
    circle of radius 25 px and the rank number in a colour cycled from a
    five-colour palette. It then draws a connecting line to the next fixation
    in that same colour. The input image is left untouched.
    """
    from PIL import ImageDraw, ImageFont

    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    label_font = ImageFont.load_default(size=24)
    palette = ["#FF4444", "#FF8800", "#FFCC00", "#44CC44", "#4488FF"]
    radius = 25
    last_idx = len(gaze_seq) - 1

    for idx, fixation in enumerate(gaze_seq):
        cx, cy = fixation["x"], fixation["y"]
        colour = palette[idx % len(palette)]

        pen.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], outline=colour, width=3)
        pen.text((cx - 6, cy - 12), str(fixation["rank"]), fill=colour, font=label_font)

        if idx < last_idx:
            following = gaze_seq[idx + 1]
            pen.line([cx, cy, following["x"], following["y"]], fill=colour, width=2)

    return canvas
|
|
|
|
|
|
@router.get("/analyses/ai-insights-available")
async def check_ai_insights_available():
    """Tell the client whether AI insights can be generated.

    Returns ``{"available": ...}`` from ``ai_insights.is_available()`` (per
    the original note: whether the API key is configured). This route is
    declared before ``/analyses/{analysis_id}`` so the literal path wins.
    """
    from app.services.ai_insights import is_available

    configured = is_available()
    return {"available": configured}
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}", response_model=AnalysisDetail)
async def get_analysis(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Return full details for one of the caller's analyses.

    ``insights`` is computed on the fly with ``generate_insights`` only when
    the analysis status is ``"completed"``; otherwise it is ``None``.

    Raises:
        HTTPException 404: no analysis with this id owned by ``user_id``.
    """
    owned = await db.execute(
        select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    )
    analysis = owned.scalar_one_or_none()
    if analysis is None:
        raise HTTPException(status_code=404, detail="Analysis not found")

    if analysis.status != "completed":
        insights = None
    else:
        from app.services.insights import generate_insights

        insights = generate_insights(analysis)

    return AnalysisDetail(
        id=analysis.id,
        name=analysis.name,
        model_used=analysis.model_used,
        status=analysis.status,
        original_filename=analysis.original_filename,
        image_width=analysis.image_width,
        image_height=analysis.image_height,
        file_format=analysis.file_format,
        overall_score=analysis.overall_score,
        created_at=analysis.created_at,
        gaze_sequence=analysis.gaze_sequence,
        hotspots=analysis.hotspots,
        insights=insights,
        entropy_score=analysis.entropy_score,
        ai_insights=analysis.ai_insights,
        ai_score=analysis.ai_score,
        ai_score_reason=analysis.ai_score_reason,
        ai_cost_usd=analysis.ai_cost_usd,
    )
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}/status", response_model=AnalysisStatus)
async def get_analysis_status(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Return the processing status of one of the caller's analyses.

    Raises:
        HTTPException 404: no analysis with this id owned by ``user_id``.
    """
    query = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    analysis = (await db.execute(query)).scalar_one_or_none()
    if analysis is None:
        raise HTTPException(status_code=404, detail="Analysis not found")
    return AnalysisStatus(id=analysis.id, status=analysis.status)
|
|
|
|
|
|
@router.get("/analyses/{analysis_id}/images/{image_type}")
async def get_analysis_image(
    analysis_id: str,
    image_type: str,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Serve one stored image belonging to the caller's analysis.

    ``image_type`` is one of ``original``, ``thumbnail``, ``heatmap``,
    ``heatmap-standalone``, ``saliency-raw`` (grayscale PNG) or
    ``gaze-sequence``.

    Raises:
        HTTPException 404: analysis not found/not owned, or the file has not
            been produced (yet).
        HTTPException 400: unknown ``image_type``.
    """
    from fastapi.responses import Response

    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    result = await db.execute(stmt)
    analysis = result.scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    file_map = {
        "original": "original.png",
        "thumbnail": "thumbnail.png",
        "heatmap": "heatmap_overlay.png",
        "heatmap-standalone": "heatmap_standalone.png",
        "saliency-raw": "saliency_gray.png",
        "gaze-sequence": "gaze_sequence.png",
    }

    filename = file_map.get(image_type)
    if not filename:
        raise HTTPException(status_code=400, detail=f"Unknown image type: {image_type}")

    if not storage.exists(analysis_id, filename):
        raise HTTPException(status_code=404, detail="Image not yet available")
    try:
        data = await storage.load_bytes(analysis_id, filename)
    except FileNotFoundError:
        # Removed between the exists() check and the read (e.g. a concurrent delete).
        raise HTTPException(status_code=404, detail="Image not yet available")

    # Derived images are served as PNG, but "original.png" holds the raw upload
    # bytes (JPEG, TIFF, ...), so label it with the format recorded at upload.
    media_type = "image/png"
    if image_type == "original":
        media_type = {
            "JPEG": "image/jpeg",
            "PNG": "image/png",
            "TIFF": "image/tiff",
            "WEBP": "image/webp",
            "BMP": "image/bmp",
        }.get(analysis.file_format, "image/png")

    # The bytes are already fully in memory; a plain Response also sets Content-Length.
    return Response(content=data, media_type=media_type)
|
|
|
|
|
|
@router.post("/analyses/{analysis_id}/ai-insights")
async def generate_ai_insights_endpoint(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Generate AI-powered insights for a completed analysis using Claude.

    The endpoint:
      - loads the stored original upload and heatmap overlay;
      - sends them, together with the analysis scores, to
        ``generate_ai_insights``;
      - persists the returned insights, score and cost on the ``Analysis``
        row and commits;
      - deletes any cached ``report.pdf`` so the next download includes the
        new content.

    Returns:
        dict with ``insights``, ``ai_score``, ``score_reason``, ``cost_usd``,
        ``input_tokens`` and ``output_tokens``.

    Raises:
        HTTPException 503: AI insights are not configured.
        HTTPException 404: analysis not found/not owned, or its images are missing.
        HTTPException 400: analysis has not completed yet.
        HTTPException 502: ``generate_ai_insights`` raised ``RuntimeError``.
    """
    import asyncio
    import os

    from app.services.ai_insights import generate_ai_insights, is_available

    if not is_available():
        raise HTTPException(status_code=503, detail="AI insights not configured (missing API key)")

    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    result = await db.execute(stmt)
    analysis = result.scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    if analysis.status != "completed":
        raise HTTPException(status_code=400, detail="Analysis is not yet completed")

    # Load images
    try:
        original_bytes = await storage.load_bytes(analysis_id, "original.png")
        heatmap_bytes = await storage.load_bytes(analysis_id, "heatmap_overlay.png")
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Analysis images not found")

    metadata = {
        "overall_score": analysis.overall_score,
        "entropy_score": analysis.entropy_score,
        "hotspots": analysis.hotspots or [],
        "gaze_sequence": analysis.gaze_sequence or [],
        "image_width": analysis.image_width,
        "image_height": analysis.image_height,
    }

    try:
        # generate_ai_insights is synchronous (its result is used without
        # await) and presumably makes a remote API call; run it in a worker
        # thread so it does not block the event loop for every other request.
        ai_result = await asyncio.to_thread(
            generate_ai_insights, metadata, original_bytes, heatmap_bytes
        )
    except RuntimeError as e:
        raise HTTPException(status_code=502, detail=str(e)) from e

    # Save to DB. Commit explicitly (as the create/delete endpoints do): a bare
    # flush is discarded unless something later commits the session, and the
    # AI call has already been paid for.
    analysis.ai_insights = ai_result["insights"]
    analysis.ai_score = ai_result["ai_score"]
    analysis.ai_score_reason = ai_result["score_reason"]
    analysis.ai_cost_usd = ai_result["cost_usd"]
    await db.commit()

    # Invalidate cached PDF so next download includes AI insights
    if storage.exists(analysis_id, "report.pdf"):
        try:
            os.remove(storage.get_path(analysis_id, "report.pdf"))
        except OSError:
            # Best effort: a stale PDF is preferable to failing the request.
            pass

    return {
        "insights": ai_result["insights"],
        "ai_score": ai_result["ai_score"],
        "score_reason": ai_result["score_reason"],
        "cost_usd": ai_result["cost_usd"],
        "input_tokens": ai_result["input_tokens"],
        "output_tokens": ai_result["output_tokens"],
    }
|
|
|
|
|
|
@router.delete("/analyses/{analysis_id}", status_code=204)
async def delete_analysis(
    analysis_id: str,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Delete one of the caller's analyses together with its stored files.

    The DB row is deleted and committed before the files are removed, so a
    failed commit leaves the analysis fully intact rather than a row pointing
    at deleted images. If removing the files fails after the commit, the
    failure is logged rather than raised, because the analysis itself is
    already gone.

    Raises:
        HTTPException 404: no analysis with this id owned by ``user_id``.
    """
    import logging

    stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
    result = await db.execute(stmt)
    analysis = result.scalar_one_or_none()
    if not analysis:
        raise HTTPException(status_code=404, detail="Analysis not found")

    await db.delete(analysis)
    await db.commit()

    try:
        await storage.delete_analysis(analysis_id)
    except Exception:
        # Orphaned files are recoverable; reporting failure for a committed
        # delete would mislead the client.
        logging.getLogger(__name__).exception(
            "Analysis %s deleted but its stored files could not be removed", analysis_id
        )
|