"""Routes for creating and reading comparisons between completed analyses."""

from collections.abc import Iterable

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.session import get_db
from app.dependencies import get_user_id
from app.models.analysis import Analysis
from app.models.comparison import Comparison
from app.models.project import Project

router = APIRouter(tags=["comparison"])


class ComparisonCreateBody(BaseModel):
    """Request body for ``POST /projects/{project_id}/comparisons``."""

    name: str
    # IDs of the analyses to compare; must be distinct and contain at least two.
    analysis_ids: list[str]


class ComparisonResponse(BaseModel):
    """A comparison as returned by the API."""

    id: str
    name: str
    # One summary dict per analysis (see ``_summarize_analysis``), in the
    # order the analysis IDs were requested / stored.
    analyses: list[dict]
    # Metrics computed at creation time: ``winner`` and ``score_delta``.
    comparison_data: dict | None = None

    model_config = {"from_attributes": True}


def _summarize_analysis(analysis: Analysis) -> dict:
    """Return the per-analysis summary dict included in comparison responses."""
    return {
        "analysis_id": analysis.id,
        "name": analysis.name,
        "overall_score": analysis.overall_score,
        # First element of the gaze sequence, or None when it is empty/missing.
        "top_fixation": analysis.gaze_sequence[0] if analysis.gaze_sequence else None,
    }


def _order_by_ids(analyses: Iterable[Analysis], ids: Iterable[str]) -> list[Analysis]:
    """Return *analyses* ordered like *ids*, skipping IDs with no matching row.

    The database does not guarantee row order for an ``IN`` query, so this
    makes responses (and the winner tie-break) deterministic.
    """
    by_id = {a.id: a for a in analyses}
    return [by_id[analysis_id] for analysis_id in ids if analysis_id in by_id]


def _build_comparison_data(analyses: list[Analysis]) -> dict:
    """Compute the stored comparison metrics for *analyses*.

    Only analyses whose ``overall_score`` is not None take part, so a score of
    ``0`` counts as a real score and missing scores are not treated as ``0``.

    Returns:
        ``{"winner": id | None, "score_delta": float | None}``. ``winner`` is
        the highest-scoring analysis (first one in *analyses* order on ties),
        or None when no analysis has a score. ``score_delta`` is the spread
        between the highest and lowest score, rounded to one decimal, or None
        when fewer than two analyses have a score.
    """
    scored = [a for a in analyses if a.overall_score is not None]
    if not scored:
        return {"winner": None, "score_delta": None}

    # max() returns the first maximal element, giving a stable tie-break.
    best = max(scored, key=lambda a: a.overall_score)
    lowest = min(a.overall_score for a in scored)
    score_delta = round(best.overall_score - lowest, 1) if len(scored) >= 2 else None
    return {"winner": best.id, "score_delta": score_delta}


@router.post(
    "/projects/{project_id}/comparisons", response_model=ComparisonResponse, status_code=201
)
async def create_comparison(
    project_id: str,
    body: ComparisonCreateBody,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Create a comparison of two or more of the caller's completed analyses.

    Raises:
        HTTPException 404: the project does not exist or is not the caller's.
        HTTPException 400: fewer than two IDs, duplicate IDs, or some IDs do
            not name a completed analysis owned by the caller.
    """
    # Verify the project exists and belongs to the caller.
    stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
    result = await db.execute(stmt)
    if not result.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="Project not found")

    if len(body.analysis_ids) < 2:
        raise HTTPException(status_code=400, detail="Need at least 2 analyses to compare")
    # Duplicates would otherwise fall through to the row-count check below and
    # be reported misleadingly as "not found or not completed".
    if len(set(body.analysis_ids)) != len(body.analysis_ids):
        raise HTTPException(status_code=400, detail="Duplicate analysis IDs in request")

    # NOTE(review): analyses are not restricted to ``project_id`` here; if
    # Analysis has a project FK, consider filtering on it as well.
    analyses_stmt = select(Analysis).where(
        Analysis.id.in_(body.analysis_ids),
        Analysis.user_id == user_id,
        Analysis.status == "completed",
    )
    analyses_result = await db.execute(analyses_stmt)
    found = analyses_result.scalars().all()
    if len(found) != len(body.analysis_ids):
        raise HTTPException(
            status_code=400, detail="Some analyses not found or not completed"
        )

    analyses = _order_by_ids(found, body.analysis_ids)
    analyses_data = [_summarize_analysis(a) for a in analyses]
    comparison_data = _build_comparison_data(analyses)

    comparison = Comparison(
        project_id=project_id,
        user_id=user_id,
        name=body.name,
        analysis_ids=body.analysis_ids,
        comparison_data=comparison_data,
    )
    db.add(comparison)
    await db.flush()
    await db.refresh(comparison)

    return ComparisonResponse(
        id=comparison.id,
        name=comparison.name,
        analyses=analyses_data,
        comparison_data=comparison_data,
    )


@router.get("/comparisons/{comparison_id}", response_model=ComparisonResponse)
async def get_comparison(
    comparison_id: str,
    db: AsyncSession = Depends(get_db),
    user_id: str = Depends(get_user_id),
):
    """Return one of the caller's comparisons with current analysis summaries.

    ``comparison_data`` is returned as stored at creation time. Analyses that
    can no longer be found are omitted from ``analyses``.

    Raises:
        HTTPException 404: the comparison does not exist or is not the caller's.
    """
    stmt = select(Comparison).where(
        Comparison.id == comparison_id, Comparison.user_id == user_id
    )
    result = await db.execute(stmt)
    comparison = result.scalar_one_or_none()
    if not comparison:
        raise HTTPException(status_code=404, detail="Comparison not found")

    analysis_ids = comparison.analysis_ids or []
    # Restrict to the caller's analyses, matching the check done at creation.
    analyses_stmt = select(Analysis).where(
        Analysis.id.in_(analysis_ids),
        Analysis.user_id == user_id,
    )
    analyses_result = await db.execute(analyses_stmt)
    analyses = _order_by_ids(analyses_result.scalars().all(), analysis_ids)

    return ComparisonResponse(
        id=comparison.id,
        name=comparison.name,
        analyses=[_summarize_analysis(a) for a in analyses],
        comparison_data=comparison.comparison_data,
    )