diff --git a/backend/server/auth/msal_auth.py b/backend/server/auth/msal_auth.py index 933cda5..44e8746 100644 --- a/backend/server/auth/msal_auth.py +++ b/backend/server/auth/msal_auth.py @@ -1,18 +1,33 @@ """ MSAL / Azure AD token validator (SPA PKCE flow). Backend only validates incoming Bearer JWTs — no server-side MSAL client needed. +Frontend sends the MSAL idToken (aud = clientId) for user identification. """ import logging -import time from typing import Optional, Dict, Any import jwt +from jwt import PyJWKClient from ..config_runtime import server_config logger = logging.getLogger(__name__) +# JWKS client caches keys after first fetch +_jwks_client: Optional[PyJWKClient] = None + + +def _get_jwks_client() -> PyJWKClient: + global _jwks_client + if _jwks_client is None: + jwks_uri = ( + f"https://login.microsoftonline.com/" + f"{server_config.MSAL_TENANT_ID}/discovery/v2.0/keys" + ) + _jwks_client = PyJWKClient(jwks_uri, cache_keys=True) + return _jwks_client + class MSALAuthenticator: def __init__(self): @@ -31,29 +46,30 @@ class MSALAuthenticator: return None try: - # Decode without signature verification (PKCE SPA tokens may use - # audience = client_id; full sig verification requires fetching JWKS). - unverified = jwt.decode( + jwks_client = _get_jwks_client() + signing_key = jwks_client.get_signing_key_from_jwt(access_token) + claims = jwt.decode( access_token, - options={"verify_signature": False, "verify_aud": False}, + signing_key.key, + algorithms=["RS256"], + audience=server_config.MSAL_CLIENT_ID, + issuer=f"https://login.microsoftonline.com/{server_config.MSAL_TENANT_ID}/v2.0", ) - user_id = unverified.get('oid') + user_id = claims.get('oid') if not user_id: logger.warning("Token missing 'oid' claim") return None - exp = unverified.get('exp', 0) - if exp < time.time(): - logger.warning("Token expired") - return None - return { 'oid': user_id, - 'preferred_username': unverified.get('preferred_username') or unverified.get('upn', ''), - 'name': unverified.get('name', ''), + 'preferred_username': claims.get('preferred_username') or claims.get('upn', ''), + 'name': claims.get('name', ''), } + except jwt.ExpiredSignatureError: + logger.warning("Token expired") + return None except jwt.InvalidTokenError as e: logger.warning(f"Invalid JWT: {e}") return None diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index eb34b10..d07cdc5 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -22,11 +22,16 @@ function AuthGate({ children }: { children: React.ReactNode }) { if (inProgress !== InteractionStatus.None) return const acquire = async () => { - // Dev mode: skip MSAL, just call /auth/me directly - if (import.meta.env.DEV || accounts.length === 0) { + // Dev mode: skip MSAL, just call /auth/me directly (backend uses DEV_MODE) + if (import.meta.env.DEV) { if (!user) fetchMe() return } + // Not yet authenticated — redirect to Azure AD login + if (accounts.length === 0) { + instance.loginRedirect({ scopes: ['openid', 'profile', 'email'] }) + return + } try { const result = await instance.acquireTokenSilent({ account: accounts[0],