Add Azure AD SSO authentication for backend and frontend
Replace X-User-Id header auth with Azure AD JWT token validation. Backend validates tokens via JWKS, frontend uses MSAL for login/token acquisition. Adds logout button, 401 handling, and configurable AZURE_AUTH_ENABLED toggle. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7925264307
commit
f217a5aea6
20 changed files with 366 additions and 78 deletions
10
.env.example
10
.env.example
|
|
@ -16,3 +16,13 @@ BACKEND_PORT=8000
|
|||
|
||||
# AI Design Analysis (optional — leave empty to disable)
|
||||
ANTHROPIC_API_KEY=
|
||||
|
||||
# Azure AD SSO (set AZURE_AUTH_ENABLED=false to disable)
|
||||
AZURE_AUTH_ENABLED=true
|
||||
AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385
|
||||
AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef
|
||||
|
||||
# Frontend Azure AD (Vite env vars)
|
||||
VITE_AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385
|
||||
VITE_AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef
|
||||
VITE_AZURE_REDIRECT_URI=https://ai-sandbox.oliver.solutions/olivas
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import io
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, UploadFile, Form
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, Form
|
||||
from fastapi.responses import StreamingResponse
|
||||
from PIL import Image
|
||||
from sqlalchemy import select
|
||||
|
|
@ -27,10 +27,8 @@ async def create_analysis(
|
|||
name: str | None = Form(None),
|
||||
model: str = Form("deepgaze_iie"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
|
||||
# Verify project belongs to user
|
||||
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
|
|
@ -474,9 +472,8 @@ async def check_ai_insights_available():
|
|||
async def get_analysis(
|
||||
analysis_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
@ -515,9 +512,8 @@ async def get_analysis(
|
|||
async def get_analysis_status(
|
||||
analysis_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
@ -531,9 +527,8 @@ async def get_analysis_image(
|
|||
analysis_id: str,
|
||||
image_type: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
@ -564,7 +559,7 @@ async def get_analysis_image(
|
|||
async def generate_ai_insights_endpoint(
|
||||
analysis_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Generate AI-powered insights for a completed analysis using Claude."""
|
||||
from app.services.ai_insights import generate_ai_insights, is_available
|
||||
|
|
@ -572,7 +567,6 @@ async def generate_ai_insights_endpoint(
|
|||
if not is_available():
|
||||
raise HTTPException(status_code=503, detail="AI insights not configured (missing API key)")
|
||||
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
@ -632,9 +626,8 @@ async def generate_ai_insights_endpoint(
|
|||
async def delete_analysis(
|
||||
analysis_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -19,9 +19,8 @@ async def create_aois(
|
|||
analysis_id: str,
|
||||
body: AOICreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
@ -68,9 +67,8 @@ async def create_aois(
|
|||
async def list_aois(
|
||||
analysis_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
|
@ -88,9 +86,8 @@ async def delete_aoi(
|
|||
analysis_id: str,
|
||||
aoi_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
# Verify analysis ownership
|
||||
stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -11,13 +12,6 @@ from app.models.project import Project
|
|||
router = APIRouter(tags=["comparison"])
|
||||
|
||||
|
||||
class ComparisonCreate:
|
||||
pass
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ComparisonCreateBody(BaseModel):
|
||||
name: str
|
||||
analysis_ids: list[str]
|
||||
|
|
@ -39,10 +33,8 @@ async def create_comparison(
|
|||
project_id: str,
|
||||
body: ComparisonCreateBody,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
|
||||
# Verify project
|
||||
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
|
|
@ -111,9 +103,8 @@ async def create_comparison(
|
|||
async def get_comparison(
|
||||
comparison_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Comparison).where(
|
||||
Comparison.id == comparison_id, Comparison.user_id == user_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
|
@ -16,9 +16,8 @@ router = APIRouter(prefix="/projects", tags=["projects"])
|
|||
async def create_project(
|
||||
body: ProjectCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
project = Project(user_id=user_id, name=body.name, description=body.description)
|
||||
db.add(project)
|
||||
await db.flush()
|
||||
|
|
@ -38,9 +37,8 @@ async def list_projects(
|
|||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
offset = (page - 1) * per_page
|
||||
|
||||
stmt = (
|
||||
|
|
@ -75,9 +73,8 @@ async def list_projects(
|
|||
async def get_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = (
|
||||
select(Project)
|
||||
.options(selectinload(Project.analyses))
|
||||
|
|
@ -95,9 +92,8 @@ async def update_project(
|
|||
project_id: str,
|
||||
body: ProjectUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
project = result.scalar_one_or_none()
|
||||
|
|
@ -130,9 +126,8 @@ async def update_project(
|
|||
async def delete_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id)
|
||||
result = await db.execute(stmt)
|
||||
project = result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import io
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -20,9 +20,8 @@ logger = logging.getLogger("olivas.reports")
|
|||
async def download_report(
|
||||
analysis_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
x_user_id: str | None = Header(None),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
user_id = get_user_id(x_user_id)
|
||||
stmt = (
|
||||
select(Analysis)
|
||||
.options(selectinload(Analysis.aois))
|
||||
|
|
|
|||
97
backend/app/auth.py
Normal file
97
backend/app/auth.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Azure AD JWT validation using JWKS (RS256)."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger("olivas.auth")
|
||||
|
||||
# JWKS cache
|
||||
_jwks_cache: dict = {}
|
||||
_jwks_cache_time: float = 0
|
||||
_JWKS_CACHE_TTL = 3600 # 1 hour
|
||||
|
||||
|
||||
@dataclass
|
||||
class CurrentUser:
|
||||
oid: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
def _get_jwks_uri() -> str:
|
||||
tenant = settings.AZURE_TENANT_ID
|
||||
return f"https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys"
|
||||
|
||||
|
||||
def _get_issuer() -> str:
|
||||
tenant = settings.AZURE_TENANT_ID
|
||||
return f"https://login.microsoftonline.com/{tenant}/v2.0"
|
||||
|
||||
|
||||
def refresh_jwks_cache() -> None:
|
||||
"""Fetch and cache JWKS keys from Azure AD."""
|
||||
global _jwks_cache, _jwks_cache_time
|
||||
uri = _get_jwks_uri()
|
||||
logger.info(f"Fetching JWKS from {uri}")
|
||||
resp = httpx.get(uri, timeout=10)
|
||||
resp.raise_for_status()
|
||||
_jwks_cache = resp.json()
|
||||
_jwks_cache_time = time.time()
|
||||
logger.info(f"JWKS cache refreshed ({len(_jwks_cache.get('keys', []))} keys)")
|
||||
|
||||
|
||||
def _get_signing_key(token: str) -> jwt.algorithms.RSAAlgorithm:
|
||||
"""Get the signing key for the given token from cached JWKS."""
|
||||
global _jwks_cache, _jwks_cache_time
|
||||
|
||||
if not _jwks_cache or (time.time() - _jwks_cache_time > _JWKS_CACHE_TTL):
|
||||
refresh_jwks_cache()
|
||||
|
||||
jwks_client = jwt.PyJWKClient.__new__(jwt.PyJWKClient)
|
||||
jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache)
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
kid = header.get("kid")
|
||||
for key in jwks_client.jwk_set.keys:
|
||||
if key.key_id == kid:
|
||||
return key.key
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Key not found — try refreshing once
|
||||
refresh_jwks_cache()
|
||||
jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache)
|
||||
header = jwt.get_unverified_header(token)
|
||||
kid = header.get("kid")
|
||||
for key in jwks_client.jwk_set.keys:
|
||||
if key.key_id == kid:
|
||||
return key.key
|
||||
|
||||
raise jwt.InvalidTokenError(f"Unable to find signing key with kid={kid}")
|
||||
|
||||
|
||||
def validate_token(token: str) -> CurrentUser:
|
||||
"""Decode and validate an Azure AD JWT token. Returns CurrentUser."""
|
||||
signing_key = _get_signing_key(token)
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
signing_key,
|
||||
algorithms=["RS256"],
|
||||
audience=settings.AZURE_CLIENT_ID,
|
||||
issuer=_get_issuer(),
|
||||
options={"require": ["exp", "iss", "aud", "oid"]},
|
||||
)
|
||||
|
||||
return CurrentUser(
|
||||
oid=payload["oid"],
|
||||
name=payload.get("name", ""),
|
||||
email=payload.get("preferred_username", payload.get("email", "")),
|
||||
)
|
||||
|
|
@ -17,6 +17,11 @@ class Settings(BaseSettings):
|
|||
|
||||
GOOGLE_CLOUD_PROJECT: str = "optical-414516"
|
||||
|
||||
# Azure AD SSO
|
||||
AZURE_TENANT_ID: str = ""
|
||||
AZURE_CLIENT_ID: str = ""
|
||||
AZURE_AUTH_ENABLED: bool = True
|
||||
|
||||
@property
|
||||
def use_cloud_run(self) -> bool:
|
||||
return bool(self.CLOUD_RUN_SALIENCY_URL)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,36 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import logging
|
||||
|
||||
from fastapi import Depends, Header, HTTPException
|
||||
|
||||
from app.auth import CurrentUser, validate_token
|
||||
from app.config import settings
|
||||
from app.db.session import get_db
|
||||
|
||||
# User ID header — placeholder for SSO integration.
|
||||
# When SSO is added, this will extract user_id from the JWT/session token.
|
||||
USER_ID_HEADER = "X-User-Id"
|
||||
DEFAULT_USER_ID = "default"
|
||||
logger = logging.getLogger("olivas.auth")
|
||||
|
||||
_anonymous_user = CurrentUser(oid="default", name="Default User", email="")
|
||||
|
||||
|
||||
def get_user_id(x_user_id: str | None = None) -> str:
|
||||
return x_user_id or DEFAULT_USER_ID
|
||||
async def get_current_user(authorization: str | None = Header(None)) -> CurrentUser:
|
||||
"""Extract and validate the Bearer token from Authorization header."""
|
||||
if not settings.AZURE_AUTH_ENABLED:
|
||||
return _anonymous_user
|
||||
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing Authorization header")
|
||||
|
||||
parts = authorization.split(" ", 1)
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||
raise HTTPException(status_code=401, detail="Invalid Authorization header format")
|
||||
|
||||
token = parts[1]
|
||||
try:
|
||||
return validate_token(token)
|
||||
except Exception as e:
|
||||
logger.warning(f"Token validation failed: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
|
||||
async def get_user_id(user: CurrentUser = Depends(get_current_user)) -> str:
|
||||
"""Return the user's OID from the validated token."""
|
||||
return user.oid
|
||||
|
|
|
|||
|
|
@ -35,6 +35,15 @@ async def lifespan(app: FastAPI):
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to load ML models: {e}. Analysis will fail until models load.")
|
||||
|
||||
# Warm up JWKS cache for Azure AD token validation
|
||||
if settings.AZURE_AUTH_ENABLED and settings.AZURE_TENANT_ID:
|
||||
try:
|
||||
from app.auth import refresh_jwks_cache
|
||||
refresh_jwks_cache()
|
||||
logger.info("Azure AD JWKS cache warmed up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to warm JWKS cache: {e}. Will fetch on first request.")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ dependencies = [
|
|||
"aiofiles>=23.0",
|
||||
"anthropic>=0.40",
|
||||
"httpx>=0.27",
|
||||
"PyJWT[crypto]>=2.8",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -30,6 +30,10 @@ services:
|
|||
CLOUD_RUN_PROCESSING_URL: ${CLOUD_RUN_PROCESSING_URL:-}
|
||||
CLOUD_RUN_SECRET: ${CLOUD_RUN_SECRET:-}
|
||||
GOOGLE_CLOUD_PROJECT: ${GOOGLE_CLOUD_PROJECT:-optical-414516}
|
||||
# Azure AD SSO
|
||||
AZURE_AUTH_ENABLED: ${AZURE_AUTH_ENABLED:-true}
|
||||
AZURE_TENANT_ID: ${AZURE_TENANT_ID:-}
|
||||
AZURE_CLIENT_ID: ${AZURE_CLIENT_ID:-}
|
||||
volumes:
|
||||
- uploads:/app/data/uploads
|
||||
depends_on:
|
||||
|
|
|
|||
36
frontend/package-lock.json
generated
36
frontend/package-lock.json
generated
|
|
@ -8,6 +8,8 @@
|
|||
"name": "frontend",
|
||||
"version": "0.0.0",
|
||||
"dependencies": {
|
||||
"@azure/msal-browser": "^5.4.0",
|
||||
"@azure/msal-react": "^5.0.6",
|
||||
"@tailwindcss/vite": "^4.2.1",
|
||||
"@tanstack/react-query": "^5.90.21",
|
||||
"axios": "^1.13.5",
|
||||
|
|
@ -35,6 +37,40 @@
|
|||
"vite": "^7.3.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@azure/msal-browser": {
|
||||
"version": "5.4.0",
|
||||
"resolved": "https://registry.npmjs.org/@azure/msal-browser/-/msal-browser-5.4.0.tgz",
|
||||
"integrity": "sha512-GvRbLNk26oPOPpnry4Ym8wtXrmdozGm2Sry5EKfui0siwnEuAKWEeMLLyosDo5nVEIIDO1C2t/+HpVzqqCWlfQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@azure/msal-common": "16.2.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@azure/msal-common": {
|
||||
"version": "16.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@azure/msal-common/-/msal-common-16.2.0.tgz",
|
||||
"integrity": "sha512-ge0nGzTLmEE5lg7tSCbTBrYqMGkpFQeQEtqfcKPuGJn/FPFf8Xz51uDfZsm5xpstNZGMYPhHvnYbL8OeNp/aLw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=0.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@azure/msal-react": {
|
||||
"version": "5.0.6",
|
||||
"resolved": "https://registry.npmjs.org/@azure/msal-react/-/msal-react-5.0.6.tgz",
|
||||
"integrity": "sha512-p0YTCdsx+jkZNR/3Awznn12LxlTcKdsH4/FBiIoD53axn6A4y8g+9sng716PtxMvXSGaRIZo9JKM/yfFdz/+oQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=20"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@azure/msal-browser": "^5.4.0",
|
||||
"react": "^19.2.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/code-frame": {
|
||||
"version": "7.29.0",
|
||||
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@azure/msal-browser": "^5.4.0",
|
||||
"@azure/msal-react": "^5.0.6",
|
||||
"@tailwindcss/vite": "^4.2.1",
|
||||
"@tanstack/react-query": "^5.90.21",
|
||||
"axios": "^1.13.5",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import { Routes, Route } from "react-router-dom";
|
||||
import AppLayout from "./components/layout/AppLayout";
|
||||
import RequireAuth from "./auth/RequireAuth";
|
||||
import Dashboard from "./pages/Dashboard";
|
||||
import ProjectDetail from "./pages/ProjectDetail";
|
||||
import NewAnalysis from "./pages/NewAnalysis";
|
||||
|
|
@ -10,17 +11,19 @@ import About from "./pages/About";
|
|||
|
||||
function App() {
|
||||
return (
|
||||
<Routes>
|
||||
<Route element={<AppLayout />}>
|
||||
<Route path="/" element={<Dashboard />} />
|
||||
<Route path="/projects/:projectId" element={<ProjectDetail />} />
|
||||
<Route path="/analyze" element={<NewAnalysis />} />
|
||||
<Route path="/analyze/:analysisId" element={<AnalysisView />} />
|
||||
<Route path="/compare/:comparisonId" element={<ComparisonView />} />
|
||||
<Route path="/help" element={<Help />} />
|
||||
<Route path="/about" element={<About />} />
|
||||
</Route>
|
||||
</Routes>
|
||||
<RequireAuth>
|
||||
<Routes>
|
||||
<Route element={<AppLayout />}>
|
||||
<Route path="/" element={<Dashboard />} />
|
||||
<Route path="/projects/:projectId" element={<ProjectDetail />} />
|
||||
<Route path="/analyze" element={<NewAnalysis />} />
|
||||
<Route path="/analyze/:analysisId" element={<AnalysisView />} />
|
||||
<Route path="/compare/:comparisonId" element={<ComparisonView />} />
|
||||
<Route path="/help" element={<Help />} />
|
||||
<Route path="/about" element={<About />} />
|
||||
</Route>
|
||||
</Routes>
|
||||
</RequireAuth>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import axios from "axios";
|
||||
import { msalInstance, loginRequest } from "../auth/msalConfig";
|
||||
|
||||
const client = axios.create({
|
||||
baseURL: "/api",
|
||||
|
|
@ -7,14 +8,36 @@ const client = axios.create({
|
|||
},
|
||||
});
|
||||
|
||||
client.interceptors.request.use((config) => {
|
||||
config.headers["X-User-Id"] = "default";
|
||||
client.interceptors.request.use(async (config) => {
|
||||
const accounts = msalInstance.getAllAccounts();
|
||||
if (accounts.length > 0) {
|
||||
try {
|
||||
const response = await msalInstance.acquireTokenSilent({
|
||||
...loginRequest,
|
||||
account: accounts[0],
|
||||
});
|
||||
config.headers.Authorization = `Bearer ${response.idToken}`;
|
||||
} catch (e) {
|
||||
console.warn("[API] Silent token acquisition failed, will redirect:", e);
|
||||
await msalInstance.acquireTokenRedirect(loginRequest);
|
||||
}
|
||||
}
|
||||
return config;
|
||||
});
|
||||
|
||||
client.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error) => {
|
||||
async (error) => {
|
||||
if (error.response?.status === 401) {
|
||||
const accounts = msalInstance.getAllAccounts();
|
||||
if (accounts.length > 0) {
|
||||
try {
|
||||
await msalInstance.acquireTokenRedirect(loginRequest);
|
||||
} catch (e) {
|
||||
console.error("[API] Re-auth redirect failed:", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
const message =
|
||||
error.response?.data?.detail || error.message || "An error occurred";
|
||||
console.error("[API Error]", message);
|
||||
|
|
|
|||
36
frontend/src/auth/RequireAuth.tsx
Normal file
36
frontend/src/auth/RequireAuth.tsx
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import { useEffect } from "react";
|
||||
import { useIsAuthenticated, useMsal } from "@azure/msal-react";
|
||||
import { InteractionStatus } from "@azure/msal-browser";
|
||||
import { loginRequest } from "./msalConfig";
|
||||
import LoadingSpinner from "../components/common/LoadingSpinner";
|
||||
|
||||
export default function RequireAuth({ children }: { children: React.ReactNode }) {
|
||||
const isAuthenticated = useIsAuthenticated();
|
||||
const { inProgress, instance } = useMsal();
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAuthenticated && inProgress === InteractionStatus.None) {
|
||||
instance.loginRedirect(loginRequest).catch((e) => {
|
||||
console.error("[Auth] Login redirect failed:", e);
|
||||
});
|
||||
}
|
||||
}, [isAuthenticated, inProgress, instance]);
|
||||
|
||||
if (inProgress !== InteractionStatus.None) {
|
||||
return (
|
||||
<div className="flex items-center justify-center min-h-screen">
|
||||
<LoadingSpinner size="lg" message="Authenticating..." />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return (
|
||||
<div className="flex items-center justify-center min-h-screen">
|
||||
<LoadingSpinner size="lg" message="Redirecting to login..." />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return <>{children}</>;
|
||||
}
|
||||
31
frontend/src/auth/msalConfig.ts
Normal file
31
frontend/src/auth/msalConfig.ts
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import { PublicClientApplication, type Configuration, LogLevel } from "@azure/msal-browser";
|
||||
|
||||
const clientId = import.meta.env.VITE_AZURE_CLIENT_ID ?? "";
|
||||
const tenantId = import.meta.env.VITE_AZURE_TENANT_ID ?? "";
|
||||
const redirectUri = import.meta.env.VITE_AZURE_REDIRECT_URI ?? window.location.origin;
|
||||
|
||||
const msalConfig: Configuration = {
|
||||
auth: {
|
||||
clientId,
|
||||
authority: `https://login.microsoftonline.com/${tenantId}`,
|
||||
redirectUri,
|
||||
postLogoutRedirectUri: redirectUri,
|
||||
},
|
||||
cache: {
|
||||
cacheLocation: "sessionStorage",
|
||||
},
|
||||
system: {
|
||||
loggerOptions: {
|
||||
logLevel: LogLevel.Warning,
|
||||
loggerCallback: (_level, message) => {
|
||||
console.debug("[MSAL]", message);
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const msalInstance = new PublicClientApplication(msalConfig);
|
||||
|
||||
export const loginRequest = {
|
||||
scopes: [`${clientId}/.default`],
|
||||
};
|
||||
|
|
@ -1,8 +1,16 @@
|
|||
import { Link } from "react-router-dom";
|
||||
import { useMsal } from "@azure/msal-react";
|
||||
import { useUIStore } from "../../stores/uiStore";
|
||||
|
||||
export default function Header() {
|
||||
const toggleSidebar = useUIStore((s) => s.toggleSidebar);
|
||||
const { instance, accounts } = useMsal();
|
||||
|
||||
const userName = accounts[0]?.name || accounts[0]?.username || "";
|
||||
|
||||
const handleLogout = () => {
|
||||
instance.logoutRedirect({ postLogoutRedirectUri: window.location.origin });
|
||||
};
|
||||
|
||||
return (
|
||||
<header
|
||||
|
|
@ -71,6 +79,19 @@ export default function Header() {
|
|||
>
|
||||
New Analysis
|
||||
</Link>
|
||||
{userName && (
|
||||
<span className="text-sm text-white/60 ml-2 hidden md:inline">
|
||||
{userName}
|
||||
</span>
|
||||
)}
|
||||
{accounts.length > 0 && (
|
||||
<button
|
||||
onClick={handleLogout}
|
||||
className="text-sm text-white/50 hover:text-white transition-colors px-2 py-1.5 rounded"
|
||||
>
|
||||
Logout
|
||||
</button>
|
||||
)}
|
||||
</nav>
|
||||
</header>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import { StrictMode } from "react";
|
|||
import { createRoot } from "react-dom/client";
|
||||
import { BrowserRouter } from "react-router-dom";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { MsalProvider } from "@azure/msal-react";
|
||||
import { msalInstance } from "./auth/msalConfig";
|
||||
import App from "./App";
|
||||
import "./globals.css";
|
||||
|
||||
|
|
@ -15,12 +17,21 @@ const queryClient = new QueryClient({
|
|||
},
|
||||
});
|
||||
|
||||
createRoot(document.getElementById("root")!).render(
|
||||
<StrictMode>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<BrowserRouter>
|
||||
<App />
|
||||
</BrowserRouter>
|
||||
</QueryClientProvider>
|
||||
</StrictMode>,
|
||||
);
|
||||
// Initialize MSAL — handle redirect promise
|
||||
msalInstance.initialize().then(() => {
|
||||
msalInstance.handleRedirectPromise().catch((e) => {
|
||||
console.error("[MSAL] Redirect error:", e);
|
||||
});
|
||||
|
||||
createRoot(document.getElementById("root")!).render(
|
||||
<StrictMode>
|
||||
<MsalProvider instance={msalInstance}>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<BrowserRouter>
|
||||
<App />
|
||||
</BrowserRouter>
|
||||
</QueryClientProvider>
|
||||
</MsalProvider>
|
||||
</StrictMode>,
|
||||
);
|
||||
});
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue