"""
JWT Token Validator for Azure AD authentication.

Python equivalent of JWTValidator.php from the MSAL specification.
"""

import jwt
import requests
import json
import time
from datetime import datetime, timezone
from typing import Dict, Optional, Any, Iterable, List
from functools import lru_cache


class JWTValidator:
    """Validate Azure AD v2.0 JWTs server-side against the tenant's published JWKS.

    Signing keys are downloaded from the tenant's ``discovery/v2.0/keys`` endpoint
    and cached for ``jwks_cache_duration`` seconds. Every validation failure is
    surfaced as a plain ``Exception`` with a descriptive message; the underlying
    error is chained as ``__cause__``.
    """

    # Minimum number of seconds between forced JWKS re-downloads triggered by a token
    # whose ``kid`` is not in the cached key set. This bounds how often a caller
    # presenting arbitrary key IDs can make the validator hit the network.
    JWKS_MIN_REFRESH_INTERVAL = 300

    # Clock skew tolerated when checking that ``iat`` is not in the future.
    IAT_CLOCK_SKEW_SECONDS = 300

    def __init__(self, tenant_id: str, client_id: str,
                 audiences: Optional[Iterable[str]] = None):
        """
        Args:
            tenant_id: Azure AD tenant ID. Used to build the authority URL, the JWKS URI
                and the expected issuer, and compared against the token's ``tid`` claim.
            client_id: Application (client) ID of this application.
            audiences: Accepted ``aud`` values. Defaults to ``client_id`` (ID tokens)
                and ``api://<client_id>`` (access tokens for this app's exposed API).
        """
        self.tenant_id = tenant_id
        self.client_id = client_id
        self.authority = f"https://login.microsoftonline.com/{tenant_id}"
        self.jwks_uri = f"{self.authority}/discovery/v2.0/keys"
        self.issuer = f"https://login.microsoftonline.com/{tenant_id}/v2.0"

        if audiences is None:
            # Microsoft Graph ("https://graph.microsoft.com") is deliberately NOT accepted
            # by default: a token minted for another resource must not authenticate here.
            audiences = [client_id, f"api://{client_id}"]
        elif isinstance(audiences, str):
            # Guard against list("abc") splitting a single audience into characters.
            audiences = [audiences]
        self.audiences: List[str] = list(audiences)

        self._jwks_cache: Dict[str, Any] = {}
        self._jwks_cache_time = 0
        self.jwks_cache_duration = 3600  # Cache JWKS for 1 hour
        self._openid_config: Optional[Dict[str, Any]] = None

    def _get_openid_config(self) -> Dict[str, Any]:
        """Return the tenant's OpenID Connect discovery document, fetched once per instance.

        Raises:
            Exception: If the document cannot be downloaded or parsed as JSON.
        """
        # Per-instance cache rather than @lru_cache on the method: lru_cache keys on
        # ``self``, keeps instances alive, and with maxsize=1 was shared by all validators.
        if self._openid_config is None:
            config_url = f"{self.authority}/v2.0/.well-known/openid-configuration"
            try:
                response = requests.get(config_url, timeout=10)
                response.raise_for_status()
                self._openid_config = response.json()
            except (requests.RequestException, ValueError) as e:
                raise Exception(f"Failed to retrieve OpenID configuration: {str(e)}") from e
        return self._openid_config

    def _get_jwks(self, force_refresh: bool = False) -> Dict[str, Any]:
        """Return the tenant's JWKS (JSON Web Key Set), served from cache while fresh.

        Args:
            force_refresh: Skip the freshness check and download again. Used when a
                token references a ``kid`` the cached key set does not contain.

        Returns:
            The parsed JWKS document.

        Raises:
            Exception: If the download fails and no previously cached JWKS exists.
        """
        current_time = time.time()
        if (not force_refresh and self._jwks_cache
                and current_time - self._jwks_cache_time < self.jwks_cache_duration):
            return self._jwks_cache

        try:
            response = requests.get(self.jwks_uri, timeout=10)
            response.raise_for_status()
            jwks = response.json()
        except (requests.RequestException, ValueError) as e:
            # Best effort: keep serving the last known key set if Azure AD is unreachable.
            if self._jwks_cache:
                return self._jwks_cache
            raise Exception(f"Failed to retrieve JWKS: {str(e)}") from e

        self._jwks_cache = jwks
        self._jwks_cache_time = current_time
        return jwks

    @staticmethod
    def _find_jwk(jwks: Dict[str, Any], kid: str) -> Optional[Dict[str, Any]]:
        """Return the JWK entry in ``jwks['keys']`` whose ``kid`` matches, or ``None``."""
        return next((key for key in jwks.get('keys', []) if key.get('kid') == kid), None)

    def _get_signing_key(self, kid: str) -> Any:
        """Return the RSA public key object for ``kid``, suitable for ``jwt.decode``.

        If ``kid`` is missing from the cached key set (for example because Azure AD
        rotated its signing keys since the last download), the JWKS is re-downloaded
        once. This happens at most every ``JWKS_MIN_REFRESH_INTERVAL`` seconds.

        Raises:
            Exception: If no key with that ``kid`` can be found.
        """
        key = self._find_jwk(self._get_jwks(), kid)
        if key is None and time.time() - self._jwks_cache_time >= self.JWKS_MIN_REFRESH_INTERVAL:
            key = self._find_jwk(self._get_jwks(force_refresh=True), kid)
        if key is None:
            raise Exception(f"Unable to find signing key with kid: {kid}")
        # PyJWT needs a key object, not the raw JWK dict.
        return jwt.algorithms.RSAAlgorithm.from_jwk(key)

    def validate_token(self, token: str) -> Dict[str, Any]:
        """
        Validate an Azure AD JWT and return its claims.

        Checks performed:
            - RS256 signature against the tenant JWKS
            - ``aud`` is one of ``self.audiences``
            - ``iss`` equals the tenant's v2.0 issuer
            - ``exp``/``nbf`` are satisfied
            - ``exp``, ``nbf``, ``iat``, ``aud`` and ``iss`` are all present
            - the additional checks in ``_validate_custom_claims``

        Args:
            token: The encoded JWT to validate.

        Returns:
            Dict containing the validated token claims.

        Raises:
            Exception: If validation fails for any reason; the original error is chained.
        """
        try:
            # The header is read unverified only to select the signing key.
            unverified_header = jwt.get_unverified_header(token)
            kid = unverified_header.get('kid')
            if not kid:
                raise Exception("Token header missing 'kid' field")

            signing_key = self._get_signing_key(kid)

            # PyJWT accepts a list of audiences and passes if any of them matches the
            # token's ``aud``, so one decode call covers every accepted audience.
            payload = jwt.decode(
                token,
                signing_key,
                algorithms=['RS256'],
                audience=self.audiences,
                issuer=self.issuer,
                options={
                    'verify_exp': True,
                    'verify_nbf': True,
                    'verify_aud': True,
                    'verify_iss': True,
                    'require': ['exp', 'nbf', 'iat', 'aud', 'iss'],
                },
            )

            self._validate_custom_claims(payload)
            return payload

        # Specific PyJWT subclasses must precede their base InvalidTokenError.
        except jwt.ExpiredSignatureError as e:
            raise Exception("Token has expired") from e
        except jwt.InvalidAudienceError as e:
            raise Exception(
                f"Token validation failed for all expected audiences: {str(e)}"
            ) from e
        except jwt.InvalidTokenError as e:
            raise Exception(f"Invalid token: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Token validation failed: {str(e)}") from e

    def _validate_custom_claims(self, payload: Dict[str, Any]) -> None:
        """Apply checks beyond what ``jwt.decode`` enforces.

        Re-checks ``exp``/``nbf`` and rejects tokens whose ``iat`` is more than
        ``IAT_CLOCK_SKEW_SECONDS`` in the future. Also rejects tokens whose ``tid``
        (when present) differs from this tenant, and tokens whose ``ver`` is not ``'2.0'``.

        Raises:
            Exception: On the first failed check.
        """
        current_time = datetime.now(timezone.utc).timestamp()

        exp = payload.get('exp')
        nbf = payload.get('nbf', 0)
        iat = payload.get('iat')

        if exp and current_time >= exp:
            raise Exception("Token has expired")
        if nbf and current_time < nbf:
            raise Exception("Token is not yet valid (nbf)")
        if iat and current_time < iat - self.IAT_CLOCK_SKEW_SECONDS:
            raise Exception("Token issued in the future")

        # The issuer is already tenant-bound; ``tid`` is an additional cross-check.
        tid = payload.get('tid')
        if tid and tid != self.tenant_id:
            raise Exception(f"Token from wrong tenant: {tid}")

        # Only v2.0 tokens match the v2.0 issuer configured above.
        ver = payload.get('ver')
        if ver != '2.0':
            raise Exception(f"Unsupported token version: {ver}")

    def get_user_info(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Extract user information from an already validated token payload.

        ``app_id`` prefers ``appid`` (v1.0 tokens), then ``azp`` (the v2.0 equivalent),
        and finally falls back to ``aud``.
        """
        return {
            'user_id': payload.get('oid') or payload.get('sub'),
            'email': payload.get('email') or payload.get('preferred_username'),
            'name': payload.get('name'),
            'given_name': payload.get('given_name'),
            'family_name': payload.get('family_name'),
            'tenant_id': payload.get('tid'),
            'app_id': payload.get('appid') or payload.get('azp') or payload.get('aud'),
            'expires_at': payload.get('exp'),
            'issued_at': payload.get('iat'),
            'roles': payload.get('roles', []),
            'groups': payload.get('groups', []),
        }

    def is_token_expired(self, payload: Dict[str, Any]) -> bool:
        """Return True if ``payload['exp']`` is missing/falsy or not in the future."""
        exp = payload.get('exp')
        if not exp:
            return True
        current_time = datetime.now(timezone.utc).timestamp()
        return current_time >= exp