Replace X-User-Id header auth with Azure AD JWT token validation. Backend validates tokens via JWKS, frontend uses MSAL for login/token acquisition. Adds logout button, 401 handling, and configurable AZURE_AUTH_ENABLED toggle. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
97 lines
2.6 KiB
Python
97 lines
2.6 KiB
Python
"""Azure AD JWT validation using JWKS (RS256)."""
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass
|
|
|
|
import httpx
|
|
import jwt
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger("olivas.auth")

# --- JWKS cache -------------------------------------------------------------
# Filled by refresh_jwks_cache(); an empty dict means "not fetched yet", which
# _get_signing_key() treats as a reason to fetch.
_JWKS_CACHE_TTL = 3600  # seconds (1 hour) before cached keys are re-fetched
_jwks_cache: dict = {}  # last JWKS document fetched from Azure AD
_jwks_cache_time: float = 0  # time.time() of the last completed fetch
|
|
|
|
|
|
@dataclass
class CurrentUser:
    """Authenticated caller, as extracted from a validated token's claims.

    Instances are produced by ``validate_token``.
    """

    # Value of the token's ``oid`` claim (required by validate_token).
    oid: str
    # Value of the ``name`` claim, or "" when the claim is absent.
    name: str
    # ``preferred_username`` claim, else ``email`` claim, else "".
    email: str
|
|
|
|
|
|
def _get_jwks_uri() -> str:
    """Return the JWKS (public signing keys) URL for the configured tenant.

    Built from ``settings.AZURE_TENANT_ID`` against the v2.0 discovery path.
    """
    authority = "https://login.microsoftonline.com"
    return f"{authority}/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"
|
|
|
|
|
|
def _get_issuer() -> str:
    """Return the exact ``iss`` value accepted by ``validate_token``.

    Built from ``settings.AZURE_TENANT_ID``; tokens from any other issuer
    (including other tenants) fail issuer validation.
    """
    authority = "https://login.microsoftonline.com"
    return f"{authority}/{settings.AZURE_TENANT_ID}/v2.0"
|
|
|
|
|
|
def refresh_jwks_cache() -> None:
    """Fetch the tenant's JWKS document and replace the module-level cache.

    The response is validated *before* the globals are touched, so a bad
    response leaves any previously cached keys (and their timestamp) intact.

    Raises:
        httpx.HTTPError: if the request fails or returns a non-2xx status.
        ValueError: if the body is not JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or does not contain a non-empty ``keys`` list.
    """
    global _jwks_cache, _jwks_cache_time

    uri = _get_jwks_uri()
    logger.info("Fetching JWKS from %s", uri)
    resp = httpx.get(uri, timeout=10)
    resp.raise_for_status()

    jwks = resp.json()
    keys = jwks.get("keys") if isinstance(jwks, dict) else None
    # Storing a malformed document would also reset the TTL clock, leaving
    # every token lookup failing until the TTL expired; reject it instead.
    if not isinstance(keys, list) or not keys:
        raise ValueError(f"Malformed JWKS response from {uri}: no 'keys' list")

    _jwks_cache = jwks
    _jwks_cache_time = time.time()
    logger.info("JWKS cache refreshed (%d keys)", len(keys))
|
|
|
|
|
|
# Minimum age (seconds) the JWKS cache must reach before an unknown ``kid``
# may force another fetch. Without this, every token carrying a bogus key id
# would trigger an outbound request to Azure AD.
_JWKS_FORCED_REFRESH_INTERVAL = 60


def _get_signing_key(token: str) -> object:
    """Return the public key that should verify ``token``, chosen by its ``kid``.

    Keys come from the module-level JWKS cache, which is (re)fetched when empty
    or older than ``_JWKS_CACHE_TTL``. If no cached key matches the token's
    ``kid`` (e.g. keys were rotated), the cache is fetched once more and the
    lookup retried, but only if the cache is at least
    ``_JWKS_FORCED_REFRESH_INTERVAL`` seconds old.

    Args:
        token: The raw JWT; its signature is NOT verified here.

    Returns:
        The matching JWK's ``key`` object (``PyJWK.key``), to be passed to
        ``jwt.decode``.

    Raises:
        jwt.DecodeError: if the token header cannot be parsed.
        jwt.InvalidTokenError: if the header has no ``kid`` or no key matches.
        httpx.HTTPError, ValueError: propagated from ``refresh_jwks_cache``.
    """
    # Parse the header once, up front: a malformed token is rejected here
    # instead of being swallowed and causing a needless JWKS refetch.
    header = jwt.get_unverified_header(token)
    kid = header.get("kid")
    if not kid:
        raise jwt.InvalidTokenError("Token header has no 'kid'")

    if not _jwks_cache or (time.time() - _jwks_cache_time > _JWKS_CACHE_TTL):
        refresh_jwks_cache()

    key = _find_cached_key(kid)
    # Key not found: refresh and retry once, unless the cache was just fetched
    # (which also avoids a second fetch right after the TTL refresh above).
    if key is None and time.time() - _jwks_cache_time >= _JWKS_FORCED_REFRESH_INTERVAL:
        refresh_jwks_cache()
        key = _find_cached_key(kid)

    if key is None:
        raise jwt.InvalidTokenError(f"Unable to find signing key with kid={kid}")
    return key


def _find_cached_key(kid: str) -> object:
    """Return the key for ``kid`` from the cached JWKS, or ``None`` if absent."""
    for jwk in jwt.PyJWKSet.from_dict(_jwks_cache).keys:
        if jwk.key_id == kid:
            return jwk.key
    return None
|
|
|
|
|
|
def validate_token(token: str) -> CurrentUser:
    """Verify an Azure AD JWT and build a ``CurrentUser`` from its claims.

    Checks performed by ``jwt.decode``:
      * RS256 signature against the key returned by ``_get_signing_key``;
      * ``aud`` equals ``settings.AZURE_CLIENT_ID``;
      * ``iss`` equals ``_get_issuer()``;
      * ``exp``, ``iss``, ``aud`` and ``oid`` are all present (plus PyJWT's
        standard expiry checks).

    Raises:
        jwt.InvalidTokenError (or a subclass) when any check fails; errors from
        fetching signing keys propagate from ``_get_signing_key``.
    """
    claims = jwt.decode(
        token,
        _get_signing_key(token),
        algorithms=["RS256"],
        audience=settings.AZURE_CLIENT_ID,
        issuer=_get_issuer(),
        options={"require": ["exp", "iss", "aud", "oid"]},
    )

    # Prefer ``preferred_username``; fall back to ``email``, then "".
    fallback_email = claims.get("email", "")
    return CurrentUser(
        oid=claims["oid"],
        name=claims.get("name", ""),
        email=claims.get("preferred_username", fallback_email),
    )
|