"""
Azure AD token verification service.

Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
"""

import logging
from typing import Optional
from datetime import datetime, timedelta, timezone

import httpx
from jose import jwt, JWTError

from app.config import settings

logger = logging.getLogger(__name__)

# How long a successfully fetched JWKS document is trusted before refetching.
_JWKS_CACHE_TTL = timedelta(hours=24)

# After a failed fetch that falls back to the stale cache, wait this long before
# trying the network again, so an outage doesn't make every request block on
# the HTTP timeout.
_JWKS_RETRY_BACKOFF = timedelta(minutes=1)

# Minimum spacing between forced refreshes triggered by an unknown key ID. The
# key ID comes from the *unverified* token header, so without a limit any
# client could force an outbound HTTP request per call by sending made-up kids.
_JWKS_FORCED_REFRESH_COOLDOWN = timedelta(minutes=5)

# Cache for JWKS (JSON Web Key Set)
_jwks_cache: dict = {}
# Naive-UTC instant after which the cache must be refetched; datetime.min means
# "expired now".
_jwks_cache_expiry: datetime = datetime.min
# Naive-UTC instant of the last unknown-kid forced refresh.
_jwks_last_forced_refresh: datetime = datetime.min


def _utcnow() -> datetime:
    """
    Return the current UTC time as a naive datetime.

    Kept naive so it stays comparable with the ``datetime.min`` sentinel used
    by the cache timestamps; replaces the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def get_azure_jwks() -> dict:
    """
    Fetch and cache Azure AD's public keys for token verification.

    Keys are cached for 24 hours to minimize network calls. If a fetch fails
    (network error, HTTP error, or a response without a ``keys`` list) and a
    previously fetched key set is cached, the stale set is returned and the
    next network attempt is deferred by ``_JWKS_RETRY_BACKOFF``.

    Returns:
        The JWKS document: a dict whose ``"keys"`` entry is a list of JWKs.

    Raises:
        Exception: whatever the fetch raised, when no cached key set exists.
    """
    global _jwks_cache, _jwks_cache_expiry

    if _utcnow() < _jwks_cache_expiry and _jwks_cache:
        return _jwks_cache

    jwks_url = (
        f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}"
        "/discovery/v2.0/keys"
    )

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()
        # Never cache a document that key lookups can't use.
        if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
            raise ValueError("JWKS response has no 'keys' list")
    except Exception as e:
        logger.error("Failed to fetch JWKS: %s", e)
        if _jwks_cache:
            # Serve the stale key set and back off before the next attempt.
            _jwks_cache_expiry = _utcnow() + _JWKS_RETRY_BACKOFF
            return _jwks_cache
        raise

    _jwks_cache = jwks
    _jwks_cache_expiry = _utcnow() + _JWKS_CACHE_TTL
    logger.info("Successfully fetched Azure AD JWKS")
    return _jwks_cache


def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWK in *jwks* whose ``kid`` equals *kid*, or None if absent."""
    return next((key for key in jwks.get("keys", []) if key.get("kid") == kid), None)


async def _refresh_jwks_for_unknown_kid(kid: str) -> Optional[dict]:
    """
    Force a JWKS refetch after a token names a key ID the cache doesn't have.

    An unknown kid may just mean the cached set predates a key rotation, so one
    refetch is worthwhile; refetches are limited to one per
    ``_JWKS_FORCED_REFRESH_COOLDOWN`` because *kid* is attacker-controlled.

    Returns:
        The (possibly refreshed) key set, or None if a forced refresh already
        ran within the cooldown window.
    """
    global _jwks_cache_expiry, _jwks_last_forced_refresh

    now = _utcnow()
    if now - _jwks_last_forced_refresh < _JWKS_FORCED_REFRESH_COOLDOWN:
        logger.warning("Key ID %r not found in JWKS; forced refresh on cooldown", kid)
        return None

    logger.warning("Key ID %r not found in JWKS, refreshing cache", kid)
    # Record the attempt before awaiting so concurrent requests see the cooldown.
    _jwks_last_forced_refresh = now
    _jwks_cache_expiry = datetime.min
    return await get_azure_jwks()


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify an Azure AD access token and return the claims.

    Args:
        token: The JWT access token from the frontend

    Returns:
        The token claims dict if valid, None if invalid. Failures to obtain
        the signing keys also yield None (fail closed).
    """
    if settings.DISABLE_AUTH:
        logger.warning("Auth disabled - skipping token verification")
        return {
            "sub": "dev-user",
            "name": "Development User",
            "preferred_username": "dev@localhost",
        }

    if not token:
        logger.warning("No token provided")
        return None

    try:
        # Read the key ID from the (still unverified) header first, so malformed
        # tokens are rejected without any network traffic.
        unverified_header = jwt.get_unverified_header(token)
        kid = unverified_header.get("kid")
        if not kid:
            logger.warning("No key ID in token header")
            return None

        # Get Azure AD public keys and find the one that signed this token.
        jwks = await get_azure_jwks()
        rsa_key = _find_signing_key(jwks, kid)

        if rsa_key is None:
            # Try refreshing JWKS in case keys rotated (rate-limited).
            refreshed = await _refresh_jwks_for_unknown_kid(kid)
            if refreshed is not None:
                rsa_key = _find_signing_key(refreshed, kid)

        if rsa_key is None:
            logger.warning("Could not find signing key for key ID %r", kid)
            return None

        # Verify the signature plus audience and issuer; python-jose also checks
        # exp/nbf by default. The algorithm list is fixed server-side so the
        # token header cannot select a different algorithm.
        # NOTE(review): these audience/issuer formats match Azure AD v1.0 access
        # tokens. v2.0 tokens use the bare client ID as audience and
        # https://login.microsoftonline.com/{tenant}/v2.0 as issuer — confirm
        # against the app registration's accessTokenAcceptedVersion.
        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            audience=f"api://{settings.AZURE_CLIENT_ID}",
            issuer=f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/",
        )
        logger.info("Token verified for user: %s", claims.get("name", "unknown"))
        return claims

    except JWTError as e:
        logger.warning("JWT verification failed: %s", e)
        return None
    except Exception as e:
        # Deliberately fail closed: any unexpected error (e.g. JWKS unavailable
        # with nothing cached) is treated as an invalid token.
        logger.exception("Token verification error: %s", e)
        return None