modcomms/backend/app/services/auth_service.py
michael ba9c0ebde3 Reduce auth logging verbosity: INFO → DEBUG
All routine MSAL token verification logs now use DEBUG level so they
don't flood the console on every polling request.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 16:13:25 -06:00

163 lines
6.1 KiB
Python
Executable file

"""
Azure AD token verification service.
Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
"""
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

import httpx
from jose import JWTError, jwt

from app.config import settings
logger = logging.getLogger(__name__)

# --- JWKS cache --------------------------------------------------------------
# Process-wide cache of Azure AD's JSON Web Key Set (the public signing keys);
# empty until the first successful fetch.
_jwks_cache: dict = dict()
# Naive-UTC instant after which the cached key set must be refetched.
# datetime.min marks the cache as already expired (forces a refetch).
_jwks_cache_expiry: datetime = datetime.min
# How long a successfully fetched JWKS is trusted before it is refetched.
_JWKS_CACHE_TTL = timedelta(hours=24)
# Minimum gap between network fetch attempts while a cached key set exists.
# Without it, a forced refresh (triggered by an unknown "kid", which an
# unauthenticated caller fully controls) or an unreachable endpoint would let
# every incoming request cause an outbound HTTP call.
_JWKS_MIN_REFRESH_INTERVAL = timedelta(minutes=5)
# Naive-UTC time of the most recent fetch attempt, successful or not.
_jwks_last_fetch_attempt: datetime = datetime.min


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Replaces the deprecated ``datetime.utcnow()`` while keeping the naive-UTC
    representation, so comparisons with naive ``datetime.min`` sentinels work.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def get_azure_jwks() -> dict:
    """
    Fetch and cache Azure AD's public keys for token verification.

    Keys are cached for 24 hours to minimize network calls. While a cached key
    set exists, refetches (including forced ones made by resetting
    ``_jwks_cache_expiry`` to ``datetime.min``) are limited to one attempt per
    ``_JWKS_MIN_REFRESH_INTERVAL``. Inside that window the cached, possibly
    stale, keys are returned.

    Returns:
        The JWKS document: a dict containing a ``"keys"`` list.

    Raises:
        Exception: whatever the fetch or validation raised (an ``httpx`` error,
            or ``ValueError`` for a document without a ``"keys"`` list). This
            only happens when there is no cached key set to fall back on.
    """
    global _jwks_cache, _jwks_cache_expiry, _jwks_last_fetch_attempt

    now = _utcnow_naive()
    if _jwks_cache:
        if now < _jwks_cache_expiry:
            return _jwks_cache
        if now - _jwks_last_fetch_attempt < _JWKS_MIN_REFRESH_INTERVAL:
            logger.debug("JWKS refresh throttled; serving cached keys")
            return _jwks_cache

    # Record the attempt before awaiting. When a cached key set exists,
    # concurrent callers arriving while this fetch is in flight are then
    # throttled instead of launching their own fetches.
    _jwks_last_fetch_attempt = now
    jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=10.0)
            response.raise_for_status()
            data = response.json()
        # Validate before caching so a malformed document cannot displace
        # good keys for the whole TTL.
        if not isinstance(data, dict) or not isinstance(data.get("keys"), list):
            raise ValueError("JWKS response does not contain a 'keys' list")
        _jwks_cache = data
        _jwks_cache_expiry = _utcnow_naive() + _JWKS_CACHE_TTL
        logger.debug("Successfully fetched Azure AD JWKS")
        return _jwks_cache
    except Exception as e:
        logger.error("Failed to fetch JWKS: %s", e)
        if _jwks_cache:  # Return stale cache if available
            return _jwks_cache
        raise
def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWK from ``jwks["keys"]`` whose ``kid`` equals *kid*, or None."""
    return next((key for key in jwks.get("keys", []) if key.get("kid") == kid), None)


def _log_unverified_claims(token: str) -> None:
    """Debug-log the token's aud/iss/azp claims WITHOUT verifying the signature.

    Purely diagnostic. The caller invokes it only when DEBUG logging is
    enabled, so the extra payload decode is skipped on the normal path.
    """
    try:
        unverified_claims = jwt.get_unverified_claims(token)
    except Exception as e:
        logger.warning("[MSAL Backend] Could not decode unverified claims: %s", e)
        return
    for claim in ("aud", "iss", "azp"):
        logger.debug("[MSAL Backend] Token %s (unverified): %s", claim, unverified_claims.get(claim))


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify an Azure AD access token and return the claims.

    The token must satisfy all of the following:
    - its RS256 signature verifies against the tenant JWKS key named by the
      header's ``kid``;
    - its audience is ``settings.AZURE_CLIENT_ID``, and it passes
      python-jose's default claim checks (such as expiry);
    - its issuer is the v1.0 or v2.0 issuer for ``settings.AZURE_TENANT_ID``.

    When ``settings.DISABLE_AUTH`` is set, a fixed development identity is
    returned without inspecting the token.

    Args:
        token: The JWT access token from the frontend

    Returns:
        The token claims dict if valid. None if the token is invalid, or if
        verification could not be performed (e.g. the JWKS fetch failed).
    """
    # Assigned below to force a JWKS refetch when the token's key id is unknown.
    global _jwks_cache_expiry

    logger.debug("[MSAL Backend] verify_access_token called")
    if settings.DISABLE_AUTH:
        logger.warning("[MSAL Backend] Auth disabled - skipping token verification")
        return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}
    if not token:
        logger.warning("[MSAL Backend] No token provided")
        return None

    # This runs on every (polling) request. Skip purely diagnostic work (token
    # preview, unverified payload decode) unless DEBUG logging is on.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        # Log token preview (first/last chars only for security)
        token_preview = f"{token[:20]}...{token[-10:]}" if len(token) > 30 else "[short token]"
        logger.debug("[MSAL Backend] Verifying token: %s", token_preview)

    try:
        # Get Azure AD public keys
        logger.debug("[MSAL Backend] Fetching Azure AD JWKS...")
        jwks = await get_azure_jwks()
        logger.debug("[MSAL Backend] JWKS contains %d keys", len(jwks.get("keys", [])))

        # Read the unverified header to learn which signing key was used
        unverified_header = jwt.get_unverified_header(token)
        kid = unverified_header.get("kid")
        logger.debug("[MSAL Backend] Token header - kid: %s, alg: %s", kid, unverified_header.get("alg"))
        if debug_enabled:
            _log_unverified_claims(token)

        if not kid:
            logger.warning("[MSAL Backend] No key ID in token header")
            return None

        rsa_key = _find_signing_key(jwks, kid)
        if not rsa_key:
            logger.warning("[MSAL Backend] Key ID %s not found in JWKS, refreshing cache", kid)
            # Keys may have rotated: expire the cache and refetch once
            _jwks_cache_expiry = datetime.min
            jwks = await get_azure_jwks()
            rsa_key = _find_signing_key(jwks, kid)
            if not rsa_key:
                logger.error("[MSAL Backend] Could not find matching key after refresh")
                return None
        logger.debug("[MSAL Backend] Found matching RSA key for kid: %s", kid)

        # Azure AD can issue tokens with either v1.0 or v2.0 issuer format
        # depending on the app registration's accessTokenAcceptedVersion setting
        v1_issuer = f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/"
        v2_issuer = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/v2.0"
        logger.debug("[MSAL Backend] Verifying with audience: %s", settings.AZURE_CLIENT_ID)
        logger.debug("[MSAL Backend] Accepting issuers: %s OR %s", v1_issuer, v2_issuer)

        # One decode verifies the signature and audience. Only RS256 is
        # accepted, whatever the header's "alg" says. No ``issuer`` is passed,
        # so the issuer is compared against both accepted formats explicitly
        # below. The alternative (decoding once per issuer and classifying
        # failures by exception text) is fragile.
        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            audience=settings.AZURE_CLIENT_ID,
        )
        issuer = claims.get("iss")
        if issuer not in (v1_issuer, v2_issuer):
            logger.warning("[MSAL Backend] Token issuer doesn't match any expected format")
            return None
        logger.debug("[MSAL Backend] Token verified with issuer: %s", issuer)

        logger.debug("[MSAL Backend] Token verified successfully!")
        logger.debug(
            "[MSAL Backend] User: %s (%s)",
            claims.get("name", "unknown"),
            claims.get("preferred_username", "unknown"),
        )
        logger.debug("[MSAL Backend] Token exp: %s, iat: %s", claims.get("exp"), claims.get("iat"))
        return claims
    except JWTError as e:
        logger.warning("[MSAL Backend] JWT verification failed: %s", e)
        return None
    except Exception as e:
        logger.error("[MSAL Backend] Token verification error: %s", e)
        return None