"""Azure AD JWT validation using JWKS (RS256).""" import logging import time from dataclasses import dataclass import httpx import jwt from app.config import settings logger = logging.getLogger("olivas.auth") # JWKS cache _jwks_cache: dict = {} _jwks_cache_time: float = 0 _JWKS_CACHE_TTL = 3600 # 1 hour @dataclass class CurrentUser: oid: str name: str email: str def _get_jwks_uri() -> str: tenant = settings.AZURE_TENANT_ID return f"https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys" def _get_issuer() -> str: tenant = settings.AZURE_TENANT_ID return f"https://login.microsoftonline.com/{tenant}/v2.0" def refresh_jwks_cache() -> None: """Fetch and cache JWKS keys from Azure AD.""" global _jwks_cache, _jwks_cache_time uri = _get_jwks_uri() logger.info(f"Fetching JWKS from {uri}") resp = httpx.get(uri, timeout=10) resp.raise_for_status() _jwks_cache = resp.json() _jwks_cache_time = time.time() logger.info(f"JWKS cache refreshed ({len(_jwks_cache.get('keys', []))} keys)") def _get_signing_key(token: str) -> jwt.algorithms.RSAAlgorithm: """Get the signing key for the given token from cached JWKS.""" global _jwks_cache, _jwks_cache_time if not _jwks_cache or (time.time() - _jwks_cache_time > _JWKS_CACHE_TTL): refresh_jwks_cache() jwks_client = jwt.PyJWKClient.__new__(jwt.PyJWKClient) jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache) try: header = jwt.get_unverified_header(token) kid = header.get("kid") for key in jwks_client.jwk_set.keys: if key.key_id == kid: return key.key except Exception: pass # Key not found — try refreshing once refresh_jwks_cache() jwks_client.jwk_set = jwt.PyJWKSet.from_dict(_jwks_cache) header = jwt.get_unverified_header(token) kid = header.get("kid") for key in jwks_client.jwk_set.keys: if key.key_id == kid: return key.key raise jwt.InvalidTokenError(f"Unable to find signing key with kid={kid}") def validate_token(token: str) -> CurrentUser: """Decode and validate an Azure AD JWT token. Returns CurrentUser.""" signing_key = _get_signing_key(token) payload = jwt.decode( token, signing_key, algorithms=["RS256"], audience=settings.AZURE_CLIENT_ID, issuer=_get_issuer(), options={"require": ["exp", "iss", "aud", "oid"]}, ) return CurrentUser( oid=payload["oid"], name=payload.get("name", ""), email=payload.get("preferred_username", payload.get("email", "")), )