diff --git a/.env.example b/.env.example
index e872649..c11d0a6 100644
--- a/.env.example
+++ b/.env.example
@@ -16,3 +16,13 @@ BACKEND_PORT=8000
# AI Design Analysis (optional — leave empty to disable)
ANTHROPIC_API_KEY=
+
+# Azure AD SSO (set AZURE_AUTH_ENABLED=false to disable)
+AZURE_AUTH_ENABLED=true
+AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385
+AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef
+
+# Frontend Azure AD (Vite env vars)
+VITE_AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385
+VITE_AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef
+VITE_AZURE_REDIRECT_URI=https://ai-sandbox.oliver.solutions/olivas
diff --git a/backend/app/api/endpoints/analysis.py b/backend/app/api/endpoints/analysis.py
index 2bf23d7..7f76fc0 100644
--- a/backend/app/api/endpoints/analysis.py
+++ b/backend/app/api/endpoints/analysis.py
@@ -1,6 +1,6 @@
import io
-from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, UploadFile, Form
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, Form
from fastapi.responses import StreamingResponse
from PIL import Image
from sqlalchemy import select
@@ -27,10 +27,8 @@ async def create_analysis(
name: str | None = Form(None),
model: str = Form("deepgaze_iie"),
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
-
# Verify project belongs to user
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
result = await db.execute(stmt)
@@ -474,9 +472,8 @@ async def check_ai_insights_available():
async def get_analysis(
analysis_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
@@ -515,9 +512,8 @@ async def get_analysis(
async def get_analysis_status(
analysis_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
@@ -531,9 +527,8 @@ async def get_analysis_image(
analysis_id: str,
image_type: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
@@ -564,7 +559,7 @@ async def get_analysis_image(
async def generate_ai_insights_endpoint(
analysis_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
"""Generate AI-powered insights for a completed analysis using Claude."""
from app.services.ai_insights import generate_ai_insights, is_available
@@ -572,7 +567,6 @@ async def generate_ai_insights_endpoint(
if not is_available():
raise HTTPException(status_code=503, detail="AI insights not configured (missing API key)")
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
@@ -632,9 +626,8 @@ async def generate_ai_insights_endpoint(
async def delete_analysis(
analysis_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
diff --git a/backend/app/api/endpoints/aoi.py b/backend/app/api/endpoints/aoi.py
index a7b598a..b8307a6 100644
--- a/backend/app/api/endpoints/aoi.py
+++ b/backend/app/api/endpoints/aoi.py
@@ -1,5 +1,5 @@
import numpy as np
-from fastapi import APIRouter, Depends, Header, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -19,9 +19,8 @@ async def create_aois(
analysis_id: str,
body: AOICreate,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
@@ -68,9 +67,8 @@ async def create_aois(
async def list_aois(
analysis_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
analysis = result.scalar_one_or_none()
@@ -88,9 +86,8 @@ async def delete_aoi(
analysis_id: str,
aoi_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
# Verify analysis ownership
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
result = await db.execute(stmt)
diff --git a/backend/app/api/endpoints/comparison.py b/backend/app/api/endpoints/comparison.py
index 94be9ef..70836a8 100644
--- a/backend/app/api/endpoints/comparison.py
+++ b/backend/app/api/endpoints/comparison.py
@@ -1,4 +1,5 @@
-from fastapi import APIRouter, Depends, Header, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -11,13 +12,6 @@ from app.models.project import Project
router = APIRouter(tags=["comparison"])
-class ComparisonCreate:
- pass
-
-
-from pydantic import BaseModel
-
-
class ComparisonCreateBody(BaseModel):
name: str
analysis_ids: list[str]
@@ -39,10 +33,8 @@ async def create_comparison(
project_id: str,
body: ComparisonCreateBody,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
-
# Verify project
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
result = await db.execute(stmt)
@@ -111,9 +103,8 @@ async def create_comparison(
async def get_comparison(
comparison_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Comparison).where(
Comparison.id == comparison_id, Comparison.user_id == user_id
)
diff --git a/backend/app/api/endpoints/projects.py b/backend/app/api/endpoints/projects.py
index a4fc915..f583db6 100644
--- a/backend/app/api/endpoints/projects.py
+++ b/backend/app/api/endpoints/projects.py
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends, Header, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -16,9 +16,8 @@ router = APIRouter(prefix="/projects", tags=["projects"])
async def create_project(
body: ProjectCreate,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
project = Project(user_id=user_id, name=body.name, description=body.description)
db.add(project)
await db.flush()
@@ -38,9 +37,8 @@ async def list_projects(
page: int = 1,
per_page: int = 20,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
offset = (page - 1) * per_page
stmt = (
@@ -75,9 +73,8 @@ async def list_projects(
async def get_project(
project_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = (
select(Project)
.options(selectinload(Project.analyses))
@@ -95,9 +92,8 @@ async def update_project(
project_id: str,
body: ProjectUpdate,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
result = await db.execute(stmt)
project = result.scalar_one_or_none()
@@ -130,9 +126,8 @@ async def update_project(
async def delete_project(
project_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
result = await db.execute(stmt)
project = result.scalar_one_or_none()
diff --git a/backend/app/api/endpoints/reports.py b/backend/app/api/endpoints/reports.py
index ecf8b10..08f0368 100644
--- a/backend/app/api/endpoints/reports.py
+++ b/backend/app/api/endpoints/reports.py
@@ -1,7 +1,7 @@
import io
import logging
-from fastapi import APIRouter, Depends, Header, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -20,9 +20,8 @@ logger = logging.getLogger("olivas.reports")
async def download_report(
analysis_id: str,
db: AsyncSession = Depends(get_db),
- x_user_id: str | None = Header(None),
+ user_id: str = Depends(get_user_id),
):
- user_id = get_user_id(x_user_id)
stmt = (
select(Analysis)
.options(selectinload(Analysis.aois))
diff --git a/backend/app/auth.py b/backend/app/auth.py
new file mode 100644
index 0000000..7ffaa97
--- /dev/null
+++ b/backend/app/auth.py
@@ -0,0 +1,97 @@
+"""Azure AD JWT validation using JWKS (RS256)."""
+
+import logging
+import time
+from dataclasses import dataclass
+
+import httpx
+import jwt
+
+from app.config import settings
+
+logger = logging.getLogger("olivas.auth")
+
+# JWKS cache
+_jwks_cache: dict = {}
+_jwks_cache_time: float = 0
+_JWKS_CACHE_TTL = 3600 # 1 hour
+
+
+@dataclass
+class CurrentUser:
+ oid: str
+ name: str
+ email: str
+
+
+def _get_jwks_uri() -> str:
+ tenant = settings.AZURE_TENANT_ID
+ return f"https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys"
+
+
+def _get_issuer() -> str:
+ tenant = settings.AZURE_TENANT_ID
+ return f"https://login.microsoftonline.com/{tenant}/v2.0"
+
+
+def refresh_jwks_cache() -> None:
+ """Fetch and cache JWKS keys from Azure AD."""
+ global _jwks_cache, _jwks_cache_time
+ uri = _get_jwks_uri()
+ logger.info(f"Fetching JWKS from {uri}")
+ resp = httpx.get(uri, timeout=10)
+ resp.raise_for_status()
+ _jwks_cache = resp.json()
+ _jwks_cache_time = time.time()
+ logger.info(f"JWKS cache refreshed ({len(_jwks_cache.get('keys', []))} keys)")
+
+
+def _get_signing_key(token: str) -> jwt.algorithms.RSAAlgorithm:
+ """Get the signing key for the given token from cached JWKS."""
+ global _jwks_cache, _jwks_cache_time
+
+ if not _jwks_cache or (time.time() - _jwks_cache_time > _JWKS_CACHE_TTL):
+ refresh_jwks_cache()
+
+ jwks_client = jwt.PyJWKClient.__new__(jwt.PyJWKClient)
+ jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache)
+
+ try:
+ header = jwt.get_unverified_header(token)
+ kid = header.get("kid")
+ for key in jwks_client.jwk_set.keys:
+ if key.key_id == kid:
+ return key.key
+ except Exception:
+ pass
+
+ # Key not found — try refreshing once
+ refresh_jwks_cache()
+ jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache)
+ header = jwt.get_unverified_header(token)
+ kid = header.get("kid")
+ for key in jwks_client.jwk_set.keys:
+ if key.key_id == kid:
+ return key.key
+
+ raise jwt.InvalidTokenError(f"Unable to find signing key with kid={kid}")
+
+
+def validate_token(token: str) -> CurrentUser:
+ """Decode and validate an Azure AD JWT token. Returns CurrentUser."""
+ signing_key = _get_signing_key(token)
+
+ payload = jwt.decode(
+ token,
+ signing_key,
+ algorithms=["RS256"],
+ audience=settings.AZURE_CLIENT_ID,
+ issuer=_get_issuer(),
+ options={"require": ["exp", "iss", "aud", "oid"]},
+ )
+
+ return CurrentUser(
+ oid=payload["oid"],
+ name=payload.get("name", ""),
+ email=payload.get("preferred_username", payload.get("email", "")),
+ )
diff --git a/backend/app/config.py b/backend/app/config.py
index 6b2e250..2c78cf3 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -17,6 +17,11 @@ class Settings(BaseSettings):
GOOGLE_CLOUD_PROJECT: str = "optical-414516"
+ # Azure AD SSO
+ AZURE_TENANT_ID: str = ""
+ AZURE_CLIENT_ID: str = ""
+ AZURE_AUTH_ENABLED: bool = True
+
@property
def use_cloud_run(self) -> bool:
return bool(self.CLOUD_RUN_SALIENCY_URL)
diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py
index b4b00ce..e4f8086 100644
--- a/backend/app/dependencies.py
+++ b/backend/app/dependencies.py
@@ -1,12 +1,36 @@
-from sqlalchemy.ext.asyncio import AsyncSession
+import logging
+from fastapi import Depends, Header, HTTPException
+
+from app.auth import CurrentUser, validate_token
+from app.config import settings
from app.db.session import get_db
-# User ID header — placeholder for SSO integration.
-# When SSO is added, this will extract user_id from the JWT/session token.
-USER_ID_HEADER = "X-User-Id"
-DEFAULT_USER_ID = "default"
+logger = logging.getLogger("olivas.auth")
+
+_anonymous_user = CurrentUser(oid="default", name="Default User", email="")
-def get_user_id(x_user_id: str | None = None) -> str:
- return x_user_id or DEFAULT_USER_ID
+async def get_current_user(authorization: str | None = Header(None)) -> CurrentUser:
+ """Extract and validate the Bearer token from Authorization header."""
+ if not settings.AZURE_AUTH_ENABLED:
+ return _anonymous_user
+
+ if not authorization:
+ raise HTTPException(status_code=401, detail="Missing Authorization header")
+
+ parts = authorization.split(" ", 1)
+ if len(parts) != 2 or parts[0].lower() != "bearer":
+ raise HTTPException(status_code=401, detail="Invalid Authorization header format")
+
+ token = parts[1]
+ try:
+ return validate_token(token)
+ except Exception as e:
+ logger.warning(f"Token validation failed: {e}")
+ raise HTTPException(status_code=401, detail="Invalid or expired token")
+
+
+async def get_user_id(user: CurrentUser = Depends(get_current_user)) -> str:
+ """Return the user's OID from the validated token."""
+ return user.oid
diff --git a/backend/app/main.py b/backend/app/main.py
index 3edb837..31f79f8 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -35,6 +35,15 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.warning(f"Failed to load ML models: {e}. Analysis will fail until models load.")
+ # Warm up JWKS cache for Azure AD token validation
+ if settings.AZURE_AUTH_ENABLED and settings.AZURE_TENANT_ID:
+ try:
+ from app.auth import refresh_jwks_cache
+ refresh_jwks_cache()
+ logger.info("Azure AD JWKS cache warmed up")
+ except Exception as e:
+ logger.warning(f"Failed to warm JWKS cache: {e}. Will fetch on first request.")
+
yield
# Shutdown
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 9c4e383..77e2449 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
"aiofiles>=23.0",
"anthropic>=0.40",
"httpx>=0.27",
+ "PyJWT[crypto]>=2.8",
]
[project.optional-dependencies]
diff --git a/docker-compose.yml b/docker-compose.yml
index 6f476dd..9128279 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -30,6 +30,10 @@ services:
CLOUD_RUN_PROCESSING_URL: ${CLOUD_RUN_PROCESSING_URL:-}
CLOUD_RUN_SECRET: ${CLOUD_RUN_SECRET:-}
GOOGLE_CLOUD_PROJECT: ${GOOGLE_CLOUD_PROJECT:-optical-414516}
+ # Azure AD SSO
+ AZURE_AUTH_ENABLED: ${AZURE_AUTH_ENABLED:-true}
+ AZURE_TENANT_ID: ${AZURE_TENANT_ID:-}
+ AZURE_CLIENT_ID: ${AZURE_CLIENT_ID:-}
volumes:
- uploads:/app/data/uploads
depends_on:
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 2502453..a0e0df7 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -8,6 +8,8 @@
"name": "frontend",
"version": "0.0.0",
"dependencies": {
+ "@azure/msal-browser": "^5.4.0",
+ "@azure/msal-react": "^5.0.6",
"@tailwindcss/vite": "^4.2.1",
"@tanstack/react-query": "^5.90.21",
"axios": "^1.13.5",
@@ -35,6 +37,40 @@
"vite": "^7.3.1"
}
},
+ "node_modules/@azure/msal-browser": {
+ "version": "5.4.0",
+ "resolved": "https://registry.npmjs.org/@azure/msal-browser/-/msal-browser-5.4.0.tgz",
+ "integrity": "sha512-GvRbLNk26oPOPpnry4Ym8wtXrmdozGm2Sry5EKfui0siwnEuAKWEeMLLyosDo5nVEIIDO1C2t/+HpVzqqCWlfQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@azure/msal-common": "16.2.0"
+ },
+ "engines": {
+ "node": ">=0.8.0"
+ }
+ },
+ "node_modules/@azure/msal-common": {
+ "version": "16.2.0",
+ "resolved": "https://registry.npmjs.org/@azure/msal-common/-/msal-common-16.2.0.tgz",
+ "integrity": "sha512-ge0nGzTLmEE5lg7tSCbTBrYqMGkpFQeQEtqfcKPuGJn/FPFf8Xz51uDfZsm5xpstNZGMYPhHvnYbL8OeNp/aLw==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=0.8.0"
+ }
+ },
+ "node_modules/@azure/msal-react": {
+ "version": "5.0.6",
+ "resolved": "https://registry.npmjs.org/@azure/msal-react/-/msal-react-5.0.6.tgz",
+ "integrity": "sha512-p0YTCdsx+jkZNR/3Awznn12LxlTcKdsH4/FBiIoD53axn6A4y8g+9sng716PtxMvXSGaRIZo9JKM/yfFdz/+oQ==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=20"
+ },
+ "peerDependencies": {
+ "@azure/msal-browser": "^5.4.0",
+ "react": "^19.2.1"
+ }
+ },
"node_modules/@babel/code-frame": {
"version": "7.29.0",
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz",
diff --git a/frontend/package.json b/frontend/package.json
index 1683c85..acb0283 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -10,6 +10,8 @@
"preview": "vite preview"
},
"dependencies": {
+ "@azure/msal-browser": "^5.4.0",
+ "@azure/msal-react": "^5.0.6",
"@tailwindcss/vite": "^4.2.1",
"@tanstack/react-query": "^5.90.21",
"axios": "^1.13.5",
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index e9e3108..1d730a9 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -1,5 +1,6 @@
import { Routes, Route } from "react-router-dom";
import AppLayout from "./components/layout/AppLayout";
+import RequireAuth from "./auth/RequireAuth";
import Dashboard from "./pages/Dashboard";
import ProjectDetail from "./pages/ProjectDetail";
import NewAnalysis from "./pages/NewAnalysis";
@@ -10,17 +11,19 @@ import About from "./pages/About";
function App() {
return (
-