All routine MSAL token verification logs now use DEBUG level so they don't flood the console on every polling request. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
163 lines
6.1 KiB
Python
Executable file
163 lines
6.1 KiB
Python
Executable file
"""
|
|
Azure AD token verification service.
|
|
|
|
Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
|
|
"""
|
|
import logging
|
|
from typing import Optional
|
|
from datetime import datetime, timedelta
|
|
|
|
import httpx
|
|
from jose import jwt, JWTError
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)

# Process-wide cache of Azure AD's JWKS (JSON Web Key Set), shared by every
# request handled in this process. It starts out empty, and its expiry
# timestamp (naive UTC, compared against datetime.utcnow()) starts at
# datetime.min so that the very first lookup always goes to the network.
_jwks_cache: dict = dict()
_jwks_cache_expiry: datetime = datetime.min
|
|
|
|
|
|
# How long a successfully fetched key set is trusted before it is refetched.
_JWKS_CACHE_TTL = timedelta(hours=24)
# After a failed refresh, keep serving the stale keys for this long before
# trying the network again, so an outage does not add a fetch timeout to
# every single request.
_JWKS_RETRY_BACKOFF = timedelta(minutes=5)


async def get_azure_jwks() -> dict:
    """
    Fetch and cache Azure AD's public keys (JWKS) for token verification.

    Keys are cached for 24 hours to minimize network calls. If a refresh
    fails while previously fetched keys are cached, the stale keys are
    returned and the next network attempt is deferred by
    ``_JWKS_RETRY_BACKOFF``.

    Returns:
        The JWKS document as a dict containing a ``"keys"`` list.

    Raises:
        Exception: Whatever the HTTP request / JSON decoding raised, or
            ``ValueError`` if the response is not a key set. Only raised
            when there are no cached keys to fall back on.
    """
    global _jwks_cache, _jwks_cache_expiry

    if _jwks_cache and datetime.utcnow() < _jwks_cache_expiry:
        return _jwks_cache

    jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()
        # Never cache (for a whole TTL) something that is not a key set.
        if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
            raise ValueError("JWKS response does not contain a 'keys' list")
    except Exception as e:
        logger.error("Failed to fetch JWKS: %s", e)
        if _jwks_cache:  # Return stale cache if available
            # Back off instead of leaving the cache expired, which would
            # retry the network (and its timeout) on every call.
            _jwks_cache_expiry = datetime.utcnow() + _JWKS_RETRY_BACKOFF
            return _jwks_cache
        raise

    _jwks_cache = jwks
    _jwks_cache_expiry = datetime.utcnow() + _JWKS_CACHE_TTL
    logger.debug("Successfully fetched Azure AD JWKS")
    return _jwks_cache
|
|
|
|
|
|
# Minimum spacing between JWKS refreshes forced by an unknown key ID. The kid
# is read from an *unverified* token header, so without a throttle any client
# could make this server re-download Azure AD's key set on every request.
# Trade-off: a request with a new kid that arrives while a forced refresh is
# still recent is rejected instead of triggering another download.
_JWKS_FORCED_REFRESH_INTERVAL = timedelta(minutes=5)
# When the last kid-triggered refresh was attempted (naive UTC, matching
# datetime.utcnow()); datetime.min means "never".
_jwks_last_forced_refresh: datetime = datetime.min


def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWK in *jwks* whose ``kid`` equals *kid*, or None if absent."""
    return next((key for key in jwks.get("keys", []) if key.get("kid") == kid), None)


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify an Azure AD access token and return the claims.

    The signing key is selected by the token's ``kid`` header from Azure AD's
    published key set. The token is then decoded with RS256, the configured
    client ID as audience, and either the v1.0
    (``https://sts.windows.net/<tenant>/``) or v2.0
    (``https://login.microsoftonline.com/<tenant>/v2.0``) issuer. If the kid
    is unknown, the key set is refreshed at most once per
    ``_JWKS_FORCED_REFRESH_INTERVAL``.

    When ``settings.DISABLE_AUTH`` is set, verification is skipped and a
    fixed development identity is returned.

    Args:
        token: The JWT access token from the frontend

    Returns:
        The token claims dict if valid, None if invalid. Failures inside
        verification, including a failure to fetch the key set, are logged
        and reported as None rather than raised.
    """
    global _jwks_cache_expiry, _jwks_last_forced_refresh

    logger.debug("[MSAL Backend] verify_access_token called")

    if settings.DISABLE_AUTH:
        logger.warning("[MSAL Backend] Auth disabled - skipping token verification")
        return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}

    if not token:
        logger.warning("[MSAL Backend] No token provided")
        return None

    # Diagnostic-only work below is skipped entirely unless DEBUG is enabled.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        # Log token preview (first/last chars only for security)
        token_preview = f"{token[:20]}...{token[-10:]}" if len(token) > 30 else "[short token]"
        logger.debug("[MSAL Backend] Verifying token: %s", token_preview)

    try:
        # Get Azure AD public keys
        logger.debug("[MSAL Backend] Fetching Azure AD JWKS...")
        jwks = await get_azure_jwks()
        logger.debug("[MSAL Backend] JWKS contains %d keys", len(jwks.get("keys", [])))

        # Read the (not yet verified) header to find out which key signed it.
        unverified_header = jwt.get_unverified_header(token)
        kid = unverified_header.get("kid")
        alg = unverified_header.get("alg")
        logger.debug("[MSAL Backend] Token header - kid: %s, alg: %s", kid, alg)

        if debug_enabled:
            # Log unverified claims to see what we're receiving
            try:
                unverified_claims = jwt.get_unverified_claims(token)
                logger.debug("[MSAL Backend] Token aud (unverified): %s", unverified_claims.get("aud"))
                logger.debug("[MSAL Backend] Token iss (unverified): %s", unverified_claims.get("iss"))
                logger.debug("[MSAL Backend] Token azp (unverified): %s", unverified_claims.get("azp"))
            except Exception as e:
                logger.warning("[MSAL Backend] Could not decode unverified claims: %s", e)

        if not kid:
            logger.warning("[MSAL Backend] No key ID in token header")
            return None

        rsa_key = _find_signing_key(jwks, kid)

        if rsa_key is None:
            now = datetime.utcnow()
            since_last_refresh = now - _jwks_last_forced_refresh
            if since_last_refresh < _JWKS_FORCED_REFRESH_INTERVAL:
                logger.warning(
                    "[MSAL Backend] Key ID %s not found in JWKS; refresh throttled (last refresh %s ago)",
                    kid,
                    since_last_refresh,
                )
                return None

            logger.warning("[MSAL Backend] Key ID %s not found in JWKS, refreshing cache", kid)
            # Try refreshing JWKS in case keys rotated. Record the attempt
            # before awaiting so concurrent requests observe the throttle.
            _jwks_last_forced_refresh = now
            _jwks_cache_expiry = datetime.min
            jwks = await get_azure_jwks()
            rsa_key = _find_signing_key(jwks, kid)

            if rsa_key is None:
                logger.error("[MSAL Backend] Could not find matching key after refresh")
                return None

        logger.debug("[MSAL Backend] Found matching RSA key for kid: %s", kid)

        # Azure AD can issue tokens with either v1.0 or v2.0 issuer format
        # depending on the app registration's accessTokenAcceptedVersion setting.
        v1_issuer = f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/"
        v2_issuer = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/v2.0"

        logger.debug("[MSAL Backend] Verifying with audience: %s", settings.AZURE_CLIENT_ID)
        logger.debug("[MSAL Backend] Accepting issuers: %s OR %s", v1_issuer, v2_issuer)

        # jose's `issuer` accepts an iterable of acceptable values, so a single
        # decode checks both formats. The signature is verified only once, and
        # there is no need to inspect exception messages for the word "issuer".
        # A mismatch raises JWTError ("Invalid issuer"), handled below.
        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            audience=settings.AZURE_CLIENT_ID,
            issuer=[v1_issuer, v2_issuer],
        )

        logger.debug("[MSAL Backend] Token verified with issuer: %s", claims.get("iss"))
        logger.debug("[MSAL Backend] Token verified successfully!")
        logger.debug(
            "[MSAL Backend] User: %s (%s)",
            claims.get("name", "unknown"),
            claims.get("preferred_username", "unknown"),
        )
        logger.debug("[MSAL Backend] Token exp: %s, iat: %s", claims.get("exp"), claims.get("iat"))
        return claims

    except JWTError as e:
        logger.warning("[MSAL Backend] JWT verification failed: %s", e)
        return None
    except Exception as e:
        logger.error("[MSAL Backend] Token verification error: %s", e)
        return None
|