diff --git a/.env.example b/.env.example index e872649..c11d0a6 100644 --- a/.env.example +++ b/.env.example @@ -16,3 +16,13 @@ BACKEND_PORT=8000 # AI Design Analysis (optional — leave empty to disable) ANTHROPIC_API_KEY= + +# Azure AD SSO (set AZURE_AUTH_ENABLED=false to disable) +AZURE_AUTH_ENABLED=true +AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385 +AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef + +# Frontend Azure AD (Vite env vars) +VITE_AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385 +VITE_AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef +VITE_AZURE_REDIRECT_URI=https://ai-sandbox.oliver.solutions/olivas diff --git a/backend/app/api/endpoints/analysis.py b/backend/app/api/endpoints/analysis.py index 2bf23d7..7f76fc0 100644 --- a/backend/app/api/endpoints/analysis.py +++ b/backend/app/api/endpoints/analysis.py @@ -1,6 +1,6 @@ import io -from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, UploadFile, Form +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, Form from fastapi.responses import StreamingResponse from PIL import Image from sqlalchemy import select @@ -27,10 +27,8 @@ async def create_analysis( name: str | None = Form(None), model: str = Form("deepgaze_iie"), db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) - # Verify project belongs to user stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id) result = await db.execute(stmt) @@ -474,9 +472,8 @@ async def check_ai_insights_available(): async def get_analysis( analysis_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() @@ -515,9 +512,8 @@ async def get_analysis( async def get_analysis_status( analysis_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() @@ -531,9 +527,8 @@ async def get_analysis_image( analysis_id: str, image_type: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() @@ -564,7 +559,7 @@ async def get_analysis_image( async def generate_ai_insights_endpoint( analysis_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): """Generate AI-powered insights for a completed analysis using Claude.""" from app.services.ai_insights import generate_ai_insights, is_available @@ -572,7 +567,6 @@ async def generate_ai_insights_endpoint( if not is_available(): raise HTTPException(status_code=503, detail="AI insights not configured (missing API key)") - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() @@ -632,9 +626,8 @@ async def generate_ai_insights_endpoint( async def delete_analysis( analysis_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() diff --git a/backend/app/api/endpoints/aoi.py b/backend/app/api/endpoints/aoi.py index a7b598a..b8307a6 100644 --- a/backend/app/api/endpoints/aoi.py +++ b/backend/app/api/endpoints/aoi.py @@ -1,5 +1,5 @@ import numpy as np -from fastapi import APIRouter, Depends, Header, HTTPException +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -19,9 +19,8 @@ async def create_aois( analysis_id: str, body: AOICreate, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() @@ -68,9 +67,8 @@ async def create_aois( async def list_aois( analysis_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) analysis = result.scalar_one_or_none() @@ -88,9 +86,8 @@ async def delete_aoi( analysis_id: str, aoi_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) # Verify analysis ownership stmt = select(Analysis).where(Analysis.id == analysis_id, Analysis.user_id == user_id) result = await db.execute(stmt) diff --git a/backend/app/api/endpoints/comparison.py b/backend/app/api/endpoints/comparison.py index 94be9ef..70836a8 100644 --- a/backend/app/api/endpoints/comparison.py +++ b/backend/app/api/endpoints/comparison.py @@ -1,4 +1,5 @@ -from fastapi import APIRouter, Depends, Header, HTTPException +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -11,13 +12,6 @@ from app.models.project import Project router = APIRouter(tags=["comparison"]) -class ComparisonCreate: - pass - - -from pydantic import BaseModel - - class ComparisonCreateBody(BaseModel): name: str analysis_ids: list[str] @@ -39,10 +33,8 @@ async def create_comparison( project_id: str, body: ComparisonCreateBody, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) - # Verify project stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id) result = await db.execute(stmt) @@ -111,9 +103,8 @@ async def create_comparison( async def get_comparison( comparison_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Comparison).where( Comparison.id == comparison_id, Comparison.user_id == user_id ) diff --git a/backend/app/api/endpoints/projects.py b/backend/app/api/endpoints/projects.py index a4fc915..f583db6 100644 --- a/backend/app/api/endpoints/projects.py +++ b/backend/app/api/endpoints/projects.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, Header, HTTPException +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -16,9 +16,8 @@ router = APIRouter(prefix="/projects", tags=["projects"]) async def create_project( body: ProjectCreate, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) project = Project(user_id=user_id, name=body.name, description=body.description) db.add(project) await db.flush() @@ -38,9 +37,8 @@ async def list_projects( page: int = 1, per_page: int = 20, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) offset = (page - 1) * per_page stmt = ( @@ -75,9 +73,8 @@ async def list_projects( async def get_project( project_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = ( select(Project) .options(selectinload(Project.analyses)) @@ -95,9 +92,8 @@ async def update_project( project_id: str, body: ProjectUpdate, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id) result = await db.execute(stmt) project = result.scalar_one_or_none() @@ -130,9 +126,8 @@ async def update_project( async def delete_project( project_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = select(Project).where(Project.id == project_id, Project.user_id == user_id) result = await db.execute(stmt) project = result.scalar_one_or_none() diff --git a/backend/app/api/endpoints/reports.py b/backend/app/api/endpoints/reports.py index ecf8b10..08f0368 100644 --- a/backend/app/api/endpoints/reports.py +++ b/backend/app/api/endpoints/reports.py @@ -1,7 +1,7 @@ import io import logging -from fastapi import APIRouter, Depends, Header, HTTPException +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -20,9 +20,8 @@ logger = logging.getLogger("olivas.reports") async def download_report( analysis_id: str, db: AsyncSession = Depends(get_db), - x_user_id: str | None = Header(None), + user_id: str = Depends(get_user_id), ): - user_id = get_user_id(x_user_id) stmt = ( select(Analysis) .options(selectinload(Analysis.aois)) diff --git a/backend/app/auth.py b/backend/app/auth.py new file mode 100644 index 0000000..7ffaa97 --- /dev/null +++ b/backend/app/auth.py @@ -0,0 +1,97 @@ +"""Azure AD JWT validation using JWKS (RS256).""" + +import logging +import time +from dataclasses import dataclass + +import httpx +import jwt + +from app.config import settings + +logger = logging.getLogger("olivas.auth") + +# JWKS cache +_jwks_cache: dict = {} +_jwks_cache_time: float = 0 +_JWKS_CACHE_TTL = 3600 # 1 hour + + +@dataclass +class CurrentUser: + oid: str + name: str + email: str + + +def _get_jwks_uri() -> str: + tenant = settings.AZURE_TENANT_ID + return f"https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys" + + +def _get_issuer() -> str: + tenant = settings.AZURE_TENANT_ID + return f"https://login.microsoftonline.com/{tenant}/v2.0" + + +def refresh_jwks_cache() -> None: + """Fetch and cache JWKS keys from Azure AD.""" + global _jwks_cache, _jwks_cache_time + uri = _get_jwks_uri() + logger.info(f"Fetching JWKS from {uri}") + resp = httpx.get(uri, timeout=10) + resp.raise_for_status() + _jwks_cache = resp.json() + _jwks_cache_time = time.time() + logger.info(f"JWKS cache refreshed ({len(_jwks_cache.get('keys', []))} keys)") + + +def _get_signing_key(token: str) -> jwt.algorithms.RSAAlgorithm: + """Get the signing key for the given token from cached JWKS.""" + global _jwks_cache, _jwks_cache_time + + if not _jwks_cache or (time.time() - _jwks_cache_time > _JWKS_CACHE_TTL): + refresh_jwks_cache() + + jwks_client = jwt.PyJWKClient.__new__(jwt.PyJWKClient) + jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache) + + try: + header = jwt.get_unverified_header(token) + kid = header.get("kid") + for key in jwks_client.jwk_set.keys: + if key.key_id == kid: + return key.key + except Exception: + pass + + # Key not found — try refreshing once + refresh_jwks_cache() + jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache) + header = jwt.get_unverified_header(token) + kid = header.get("kid") + for key in jwks_client.jwk_set.keys: + if key.key_id == kid: + return key.key + + raise jwt.InvalidTokenError(f"Unable to find signing key with kid={kid}") + + +def validate_token(token: str) -> CurrentUser: + """Decode and validate an Azure AD JWT token. Returns CurrentUser.""" + signing_key = _get_signing_key(token) + + payload = jwt.decode( + token, + signing_key, + algorithms=["RS256"], + audience=settings.AZURE_CLIENT_ID, + issuer=_get_issuer(), + options={"require": ["exp", "iss", "aud", "oid"]}, + ) + + return CurrentUser( + oid=payload["oid"], + name=payload.get("name", ""), + email=payload.get("preferred_username", payload.get("email", "")), + ) diff --git a/backend/app/config.py b/backend/app/config.py index 6b2e250..2c78cf3 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -17,6 +17,11 @@ class Settings(BaseSettings): GOOGLE_CLOUD_PROJECT: str = "optical-414516" + # Azure AD SSO + AZURE_TENANT_ID: str = "" + AZURE_CLIENT_ID: str = "" + AZURE_AUTH_ENABLED: bool = True + @property def use_cloud_run(self) -> bool: return bool(self.CLOUD_RUN_SALIENCY_URL) diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index b4b00ce..e4f8086 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -1,12 +1,36 @@ -from sqlalchemy.ext.asyncio import AsyncSession +import logging +from fastapi import Depends, Header, HTTPException + +from app.auth import CurrentUser, validate_token +from app.config import settings from app.db.session import get_db -# User ID header — placeholder for SSO integration. -# When SSO is added, this will extract user_id from the JWT/session token. -USER_ID_HEADER = "X-User-Id" -DEFAULT_USER_ID = "default" +logger = logging.getLogger("olivas.auth") + +_anonymous_user = CurrentUser(oid="default", name="Default User", email="") -def get_user_id(x_user_id: str | None = None) -> str: - return x_user_id or DEFAULT_USER_ID +async def get_current_user(authorization: str | None = Header(None)) -> CurrentUser: + """Extract and validate the Bearer token from Authorization header.""" + if not settings.AZURE_AUTH_ENABLED: + return _anonymous_user + + if not authorization: + raise HTTPException(status_code=401, detail="Missing Authorization header") + + parts = authorization.split(" ", 1) + if len(parts) != 2 or parts[0].lower() != "bearer": + raise HTTPException(status_code=401, detail="Invalid Authorization header format") + + token = parts[1] + try: + return validate_token(token) + except Exception as e: + logger.warning(f"Token validation failed: {e}") + raise HTTPException(status_code=401, detail="Invalid or expired token") + + +async def get_user_id(user: CurrentUser = Depends(get_current_user)) -> str: + """Return the user's OID from the validated token.""" + return user.oid diff --git a/backend/app/main.py b/backend/app/main.py index 3edb837..31f79f8 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -35,6 +35,15 @@ async def lifespan(app: FastAPI): except Exception as e: logger.warning(f"Failed to load ML models: {e}. Analysis will fail until models load.") + # Warm up JWKS cache for Azure AD token validation + if settings.AZURE_AUTH_ENABLED and settings.AZURE_TENANT_ID: + try: + from app.auth import refresh_jwks_cache + refresh_jwks_cache() + logger.info("Azure AD JWKS cache warmed up") + except Exception as e: + logger.warning(f"Failed to warm JWKS cache: {e}. Will fetch on first request.") + yield # Shutdown diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9c4e383..77e2449 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "aiofiles>=23.0", "anthropic>=0.40", "httpx>=0.27", + "PyJWT[crypto]>=2.8", ] [project.optional-dependencies] diff --git a/docker-compose.yml b/docker-compose.yml index 6f476dd..9128279 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,6 +30,10 @@ services: CLOUD_RUN_PROCESSING_URL: ${CLOUD_RUN_PROCESSING_URL:-} CLOUD_RUN_SECRET: ${CLOUD_RUN_SECRET:-} GOOGLE_CLOUD_PROJECT: ${GOOGLE_CLOUD_PROJECT:-optical-414516} + # Azure AD SSO + AZURE_AUTH_ENABLED: ${AZURE_AUTH_ENABLED:-true} + AZURE_TENANT_ID: ${AZURE_TENANT_ID:-} + AZURE_CLIENT_ID: ${AZURE_CLIENT_ID:-} volumes: - uploads:/app/data/uploads depends_on: diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 2502453..a0e0df7 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,6 +8,8 @@ "name": "frontend", "version": "0.0.0", "dependencies": { + "@azure/msal-browser": "^5.4.0", + "@azure/msal-react": "^5.0.6", "@tailwindcss/vite": "^4.2.1", "@tanstack/react-query": "^5.90.21", "axios": "^1.13.5", @@ -35,6 +37,40 @@ "vite": "^7.3.1" } }, + "node_modules/@azure/msal-browser": { + "version": "5.4.0", + "resolved": "https://registry.npmjs.org/@azure/msal-browser/-/msal-browser-5.4.0.tgz", + "integrity": "sha512-GvRbLNk26oPOPpnry4Ym8wtXrmdozGm2Sry5EKfui0siwnEuAKWEeMLLyosDo5nVEIIDO1C2t/+HpVzqqCWlfQ==", + "license": "MIT", + "dependencies": { + "@azure/msal-common": "16.2.0" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/@azure/msal-common": { + "version": "16.2.0", + "resolved": "https://registry.npmjs.org/@azure/msal-common/-/msal-common-16.2.0.tgz", + "integrity": "sha512-ge0nGzTLmEE5lg7tSCbTBrYqMGkpFQeQEtqfcKPuGJn/FPFf8Xz51uDfZsm5xpstNZGMYPhHvnYbL8OeNp/aLw==", + "license": "MIT", + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/@azure/msal-react": { + "version": "5.0.6", + "resolved": "https://registry.npmjs.org/@azure/msal-react/-/msal-react-5.0.6.tgz", + "integrity": "sha512-p0YTCdsx+jkZNR/3Awznn12LxlTcKdsH4/FBiIoD53axn6A4y8g+9sng716PtxMvXSGaRIZo9JKM/yfFdz/+oQ==", + "license": "MIT", + "engines": { + "node": ">=20" + }, + "peerDependencies": { + "@azure/msal-browser": "^5.4.0", + "react": "^19.2.1" + } + }, "node_modules/@babel/code-frame": { "version": "7.29.0", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index 1683c85..acb0283 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,6 +10,8 @@ "preview": "vite preview" }, "dependencies": { + "@azure/msal-browser": "^5.4.0", + "@azure/msal-react": "^5.0.6", "@tailwindcss/vite": "^4.2.1", "@tanstack/react-query": "^5.90.21", "axios": "^1.13.5", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e9e3108..1d730a9 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,6 @@ import { Routes, Route } from "react-router-dom"; import AppLayout from "./components/layout/AppLayout"; +import RequireAuth from "./auth/RequireAuth"; import Dashboard from "./pages/Dashboard"; import ProjectDetail from "./pages/ProjectDetail"; import NewAnalysis from "./pages/NewAnalysis"; @@ -10,17 +11,19 @@ import About from "./pages/About"; function App() { return ( - - }> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - - + + + }> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + + + ); } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 902346c..8902e86 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,4 +1,5 @@ import axios from "axios"; +import { msalInstance, loginRequest } from "../auth/msalConfig"; const client = axios.create({ baseURL: "/api", @@ -7,14 +8,36 @@ const client = axios.create({ }, }); -client.interceptors.request.use((config) => { - config.headers["X-User-Id"] = "default"; +client.interceptors.request.use(async (config) => { + const accounts = msalInstance.getAllAccounts(); + if (accounts.length > 0) { + try { + const response = await msalInstance.acquireTokenSilent({ + ...loginRequest, + account: accounts[0], + }); + config.headers.Authorization = `Bearer ${response.idToken}`; + } catch (e) { + console.warn("[API] Silent token acquisition failed, will redirect:", e); + await msalInstance.acquireTokenRedirect(loginRequest); + } + } return config; }); client.interceptors.response.use( (response) => response, - (error) => { + async (error) => { + if (error.response?.status === 401) { + const accounts = msalInstance.getAllAccounts(); + if (accounts.length > 0) { + try { + await msalInstance.acquireTokenRedirect(loginRequest); + } catch (e) { + console.error("[API] Re-auth redirect failed:", e); + } + } + } const message = error.response?.data?.detail || error.message || "An error occurred"; console.error("[API Error]", message); diff --git a/frontend/src/auth/RequireAuth.tsx b/frontend/src/auth/RequireAuth.tsx new file mode 100644 index 0000000..6fe1402 --- /dev/null +++ b/frontend/src/auth/RequireAuth.tsx @@ -0,0 +1,36 @@ +import { useEffect } from "react"; +import { useIsAuthenticated, useMsal } from "@azure/msal-react"; +import { InteractionStatus } from "@azure/msal-browser"; +import { loginRequest } from "./msalConfig"; +import LoadingSpinner from "../components/common/LoadingSpinner"; + +export default function RequireAuth({ children }: { children: React.ReactNode }) { + const isAuthenticated = useIsAuthenticated(); + const { inProgress, instance } = useMsal(); + + useEffect(() => { + if (!isAuthenticated && inProgress === InteractionStatus.None) { + instance.loginRedirect(loginRequest).catch((e) => { + console.error("[Auth] Login redirect failed:", e); + }); + } + }, [isAuthenticated, inProgress, instance]); + + if (inProgress !== InteractionStatus.None) { + return ( +
+ +
+ ); + } + + if (!isAuthenticated) { + return ( +
+ +
+ ); + } + + return <>{children}; +} diff --git a/frontend/src/auth/msalConfig.ts b/frontend/src/auth/msalConfig.ts new file mode 100644 index 0000000..af777c5 --- /dev/null +++ b/frontend/src/auth/msalConfig.ts @@ -0,0 +1,31 @@ +import { PublicClientApplication, type Configuration, LogLevel } from "@azure/msal-browser"; + +const clientId = import.meta.env.VITE_AZURE_CLIENT_ID ?? ""; +const tenantId = import.meta.env.VITE_AZURE_TENANT_ID ?? ""; +const redirectUri = import.meta.env.VITE_AZURE_REDIRECT_URI ?? window.location.origin; + +const msalConfig: Configuration = { + auth: { + clientId, + authority: `https://login.microsoftonline.com/${tenantId}`, + redirectUri, + postLogoutRedirectUri: redirectUri, + }, + cache: { + cacheLocation: "sessionStorage", + }, + system: { + loggerOptions: { + logLevel: LogLevel.Warning, + loggerCallback: (_level, message) => { + console.debug("[MSAL]", message); + }, + }, + }, +}; + +export const msalInstance = new PublicClientApplication(msalConfig); + +export const loginRequest = { + scopes: [`${clientId}/.default`], +}; diff --git a/frontend/src/components/layout/Header.tsx b/frontend/src/components/layout/Header.tsx index 813f3e2..6d69b50 100644 --- a/frontend/src/components/layout/Header.tsx +++ b/frontend/src/components/layout/Header.tsx @@ -1,8 +1,16 @@ import { Link } from "react-router-dom"; +import { useMsal } from "@azure/msal-react"; import { useUIStore } from "../../stores/uiStore"; export default function Header() { const toggleSidebar = useUIStore((s) => s.toggleSidebar); + const { instance, accounts } = useMsal(); + + const userName = accounts[0]?.name || accounts[0]?.username || ""; + + const handleLogout = () => { + instance.logoutRedirect({ postLogoutRedirectUri: window.location.origin }); + }; return (
New Analysis + {userName && ( + + {userName} + + )} + {accounts.length > 0 && ( + + )}
); diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index a601f65..41ab315 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -2,6 +2,8 @@ import { StrictMode } from "react"; import { createRoot } from "react-dom/client"; import { BrowserRouter } from "react-router-dom"; import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { MsalProvider } from "@azure/msal-react"; +import { msalInstance } from "./auth/msalConfig"; import App from "./App"; import "./globals.css"; @@ -15,12 +17,21 @@ const queryClient = new QueryClient({ }, }); -createRoot(document.getElementById("root")!).render( - - - - - - - , -); +// Initialize MSAL — handle redirect promise +msalInstance.initialize().then(() => { + msalInstance.handleRedirectPromise().catch((e) => { + console.error("[MSAL] Redirect error:", e); + }); + + createRoot(document.getElementById("root")!).render( + + + + + + + + + , + ); +});