"""
MSAL / Azure AD token validator (SPA PKCE flow).

The backend only validates incoming Bearer JWTs; no server-side MSAL client
is needed. The frontend sends the MSAL idToken (``aud`` = clientId) so the
backend can identify the user.
"""
import asyncio
import logging
from typing import Any, Dict, Optional
from urllib.parse import urlencode

import jwt
from jwt import PyJWKClient

from ..config_runtime import server_config

logger = logging.getLogger(__name__)

# Lazily created, module-wide JWKS client. It is constructed with
# cache_keys=True so signing keys fetched once are reused across requests.
_jwks_client: Optional[PyJWKClient] = None


def _get_jwks_client() -> PyJWKClient:
    """Return the shared JWKS client, creating it on first use.

    The key-set URI is the Azure AD v2.0 discovery keys endpoint for
    ``server_config.MSAL_TENANT_ID``.
    """
    global _jwks_client
    if _jwks_client is None:
        jwks_uri = (
            f"https://login.microsoftonline.com/"
            f"{server_config.MSAL_TENANT_ID}/discovery/v2.0/keys"
        )
        _jwks_client = PyJWKClient(jwks_uri, cache_keys=True)
    return _jwks_client


class MSALAuthenticator:
    """Validate Azure AD JWTs and expose MSAL settings to the frontend.

    When ``server_config.DEV_MODE`` is truthy, token validation is bypassed
    and a fixed development identity taken from ``server_config`` is used.
    """

    def __init__(self):
        if server_config.DEV_MODE:
            logger.info("Running in DEV_MODE — MSAL authentication bypassed")

    async def validate_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Validate an encoded JWT and return the caller's identity.

        Args:
            access_token: The encoded JWT to validate. Per the module
                docstring, the frontend sends its MSAL idToken here, so the
                expected audience is ``server_config.MSAL_CLIENT_ID``.

        Returns:
            A dict with ``oid``, ``preferred_username`` and ``name`` on
            success. Returns None if the token is empty, fails signature or
            claim validation, or has no ``oid`` claim. In DEV_MODE the
            configured dev identity is returned without inspecting the token.

        Validation failures are logged and never raised to the caller.
        """
        if server_config.DEV_MODE:
            return {
                'oid': server_config.DEV_USER_ID,
                'preferred_username': server_config.DEV_USER_EMAIL,
                'name': server_config.DEV_USER_NAME,
            }
        if not access_token:
            return None

        try:
            jwks_client = _get_jwks_client()
            # On a cache miss, PyJWKClient fetches the key set with a
            # synchronous HTTP request. Running it in the default executor
            # keeps that request from blocking the event loop.
            loop = asyncio.get_running_loop()
            signing_key = await loop.run_in_executor(
                None, jwks_client.get_signing_key_from_jwt, access_token
            )
            claims = jwt.decode(
                access_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=server_config.MSAL_CLIENT_ID,
                issuer=(
                    f"https://login.microsoftonline.com/"
                    f"{server_config.MSAL_TENANT_ID}/v2.0"
                ),
                # PyJWT only validates ``exp`` when the claim is present.
                # Requiring it rejects tokens that would never expire.
                options={"require": ["exp"]},
            )
        except jwt.ExpiredSignatureError:
            logger.warning("Token expired")
            return None
        except jwt.InvalidTokenError as e:
            logger.warning("Invalid JWT: %s", e)
            return None
        except Exception as e:
            # Anything else (e.g. JWKS fetch/lookup failures) is unexpected:
            # log with a traceback but still fail closed.
            logger.error("Token validation error: %s", e, exc_info=True)
            return None

        user_id = claims.get('oid')
        if not user_id:
            logger.warning("Token missing 'oid' claim")
            return None

        return {
            'oid': user_id,
            # Fall back to 'upn' when 'preferred_username' is absent or empty.
            'preferred_username': claims.get('preferred_username') or claims.get('upn', ''),
            'name': claims.get('name', ''),
        }

    async def get_logout_url(self, post_logout_redirect_uri: Optional[str] = None) -> str:
        """Build the Azure AD v2.0 logout URL.

        Args:
            post_logout_redirect_uri: Optional URI to return to after logout.
                It is percent-encoded into the query string.

        Returns:
            The logout endpoint under ``server_config.MSAL_AUTHORITY``, with
            the redirect parameter appended when given. In DEV_MODE, returns
            the redirect URI itself, or 'http://localhost:5173' if none was
            given.
        """
        if server_config.DEV_MODE:
            return post_logout_redirect_uri or 'http://localhost:5173'
        base = f"{server_config.MSAL_AUTHORITY}/oauth2/v2.0/logout"
        if post_logout_redirect_uri:
            # Encode the value so characters such as '&', '?', '#' or spaces
            # cannot break out of (or corrupt) the query parameter.
            query = urlencode({'post_logout_redirect_uri': post_logout_redirect_uri})
            return f"{base}?{query}"
        return base

    def get_client_config(self) -> Dict[str, Any]:
        """Return the MSAL settings the frontend needs to initialise MSAL.

        ``devMode`` is always a real bool, derived from
        ``server_config.DEV_MODE``.
        """
        return {
            'clientId': server_config.MSAL_CLIENT_ID,
            'authority': server_config.MSAL_AUTHORITY,
            'redirectUri': server_config.MSAL_REDIRECT_URI,
            'devMode': bool(server_config.DEV_MODE),
        }

    def is_dev_mode(self) -> bool:
        """Return ``server_config.DEV_MODE`` as configured."""
        return server_config.DEV_MODE


# Module-level singleton used by the rest of the application.
msal_auth = MSALAuthenticator()