"""
Azure AD token verification service.

Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
"""

import logging
from typing import Optional
from datetime import datetime, timedelta, timezone

import httpx
from jose import jwt, JWTError

from app.config import settings

logger = logging.getLogger(__name__)

# How long a successfully fetched key set is trusted before it is re-fetched.
_JWKS_CACHE_TTL = timedelta(hours=24)
# Network timeout (seconds) for the JWKS HTTP request.
_JWKS_FETCH_TIMEOUT_SECONDS = 10.0
# Minimum spacing between *forced* refreshes. Tokens are untrusted input, so a
# caller sending random `kid` values must not trigger one fetch per request.
_JWKS_MIN_FORCED_REFRESH_INTERVAL = timedelta(minutes=1)

# Cache for JWKS (JSON Web Key Set). Timestamps are naive UTC datetimes.
_jwks_cache: dict = {}
_jwks_cache_expiry: datetime = datetime.min
_jwks_last_forced_refresh: datetime = datetime.min


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Equivalent to the deprecated ``datetime.utcnow()``; the result stays naive
    so it remains comparable with the ``datetime.min`` cache sentinels.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _is_valid_jwks(payload: object) -> bool:
    """Return True if *payload* is a dict holding a non-empty list of key dicts."""
    if not isinstance(payload, dict):
        return False
    keys = payload.get("keys")
    if not isinstance(keys, list) or not keys:
        return False
    return all(isinstance(key, dict) for key in keys)


def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWKS entry whose ``kid`` equals *kid*, or None if absent."""
    return next(
        (key for key in jwks.get("keys", []) if key.get("kid") == kid),
        None,
    )


def _log_unverified_claims(token: str) -> None:
    """Log the token's unverified aud/iss/azp claims for troubleshooting.

    Purely diagnostic: any failure is logged and swallowed so it can never
    influence the verification result.
    """
    try:
        unverified_claims = jwt.get_unverified_claims(token)
        logger.debug("[MSAL Backend] Token aud (unverified): %s", unverified_claims.get("aud"))
        logger.debug("[MSAL Backend] Token iss (unverified): %s", unverified_claims.get("iss"))
        logger.debug("[MSAL Backend] Token azp (unverified): %s", unverified_claims.get("azp"))
    except Exception as e:  # best-effort diagnostics only
        logger.warning("[MSAL Backend] Could not decode unverified claims: %s", e)


async def get_azure_jwks(force_refresh: bool = False) -> dict:
    """
    Fetch and cache Azure AD's public keys for token verification.

    Keys are cached for 24 hours to minimize network calls.

    Args:
        force_refresh: Bypass a still-valid cache and re-fetch, e.g. after a
            token referenced an unknown key ID. Forced refreshes are throttled
            to one per minute; a throttled request is served from the cache.

    Returns:
        The JWKS document (a dict with a non-empty "keys" list). If the fetch
        fails or returns a malformed document, the previously cached key set
        is returned when one exists.

    Raises:
        Exception: Whatever the fetch/validation raised, when it fails and
            there is no cached key set to fall back on.
    """
    global _jwks_cache, _jwks_cache_expiry, _jwks_last_forced_refresh

    now = _utcnow()

    if force_refresh and _jwks_cache:
        if now - _jwks_last_forced_refresh < _JWKS_MIN_FORCED_REFRESH_INTERVAL:
            logger.debug("Forced JWKS refresh throttled; serving cached keys")
            force_refresh = False
        else:
            _jwks_last_forced_refresh = now

    if not force_refresh and _jwks_cache and now < _jwks_cache_expiry:
        return _jwks_cache

    jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=_JWKS_FETCH_TIMEOUT_SECONDS)
            response.raise_for_status()
            payload = response.json()
        # Never cache a malformed document: it would otherwise break
        # verification until the cache entry expired.
        if not _is_valid_jwks(payload):
            raise ValueError("JWKS response does not contain a non-empty 'keys' list")
    except Exception as e:
        logger.error("Failed to fetch JWKS: %s", e)
        if _jwks_cache:
            # Return stale cache if available
            return _jwks_cache
        raise

    _jwks_cache = payload
    _jwks_cache_expiry = _utcnow() + _JWKS_CACHE_TTL
    logger.debug("Successfully fetched Azure AD JWKS")
    return _jwks_cache


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify an Azure AD access token and return the claims.

    The token must be RS256-signed by a key published in the tenant's JWKS,
    carry ``settings.AZURE_CLIENT_ID`` as audience, and be issued by the
    tenant's v1.0 or v2.0 issuer. When ``settings.DISABLE_AUTH`` is set,
    verification is skipped and fixed development claims are returned.

    Args:
        token: The JWT access token from the frontend

    Returns:
        The token claims dict if valid, None if invalid
    """
    logger.debug("[MSAL Backend] verify_access_token called")

    if settings.DISABLE_AUTH:
        logger.warning("[MSAL Backend] Auth disabled - skipping token verification")
        return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}

    if not token:
        logger.warning("[MSAL Backend] No token provided")
        return None

    # Log token preview (first/last chars only for security)
    token_preview = f"{token[:20]}...{token[-10:]}" if len(token) > 30 else "[short token]"
    logger.debug("[MSAL Backend] Verifying token: %s", token_preview)

    try:
        # Read the (unverified) header first: a malformed token is rejected
        # here without costing a JWKS network round-trip.
        unverified_header = jwt.get_unverified_header(token)
        kid = unverified_header.get("kid")
        alg = unverified_header.get("alg")
        logger.debug("[MSAL Backend] Token header - kid: %s, alg: %s", kid, alg)

        # Log unverified claims to see what we're receiving
        _log_unverified_claims(token)

        if not kid:
            logger.warning("[MSAL Backend] No key ID in token header")
            return None

        # Get Azure AD public keys
        logger.debug("[MSAL Backend] Fetching Azure AD JWKS...")
        jwks = await get_azure_jwks()
        logger.debug("[MSAL Backend] JWKS contains %d keys", len(jwks.get("keys", [])))

        # Find the matching key
        signing_key = _find_signing_key(jwks, kid)
        if signing_key is None:
            logger.warning("[MSAL Backend] Key ID %s not found in JWKS, refreshing cache", kid)
            # Try refreshing JWKS in case keys rotated
            jwks = await get_azure_jwks(force_refresh=True)
            signing_key = _find_signing_key(jwks, kid)
            if signing_key is None:
                logger.error("[MSAL Backend] Could not find matching key after refresh")
                return None

        logger.debug("[MSAL Backend] Found matching RSA key for kid: %s", kid)

        # Azure AD can issue tokens with either v1.0 or v2.0 issuer format
        # depending on the app registration's accessTokenAcceptedVersion setting
        v1_issuer = f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/"
        v2_issuer = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/v2.0"

        logger.debug("[MSAL Backend] Verifying with audience: %s", settings.AZURE_CLIENT_ID)
        logger.debug("[MSAL Backend] Accepting issuers: %s OR %s", v1_issuer, v2_issuer)

        # Verify signature, expiry and audience in a single pass. The issuer is
        # checked explicitly below so that accepting two issuer formats does
        # not depend on parsing the library's error message text.
        claims = jwt.decode(
            token,
            signing_key,
            algorithms=["RS256"],
            audience=settings.AZURE_CLIENT_ID,
        )

        token_issuer = claims.get("iss")
        if token_issuer not in (v1_issuer, v2_issuer):
            logger.warning("[MSAL Backend] Token issuer doesn't match any expected format")
            return None
        logger.debug("[MSAL Backend] Token verified with issuer: %s", token_issuer)

        logger.debug("[MSAL Backend] Token verified successfully!")
        logger.debug(
            "[MSAL Backend] User: %s (%s)",
            claims.get("name", "unknown"),
            claims.get("preferred_username", "unknown"),
        )
        logger.debug("[MSAL Backend] Token exp: %s, iat: %s", claims.get("exp"), claims.get("iat"))

        return claims

    except JWTError as e:
        logger.warning("[MSAL Backend] JWT verification failed: %s", e)
        return None
    except Exception as e:
        # Top-level boundary: callers only ever see None for a failed
        # verification, so keep the traceback in the log.
        logger.exception("[MSAL Backend] Token verification error: %s", e)
        return None