modcomms/backend/app/services/auth_service.py
michael 5b9e824da9 Use OpenID scopes instead of custom API scopes
- Change frontend scopes from api://{client_id}/.default to
  openid, profile, email for simpler authentication
- Update backend token validation to expect ID token format:
  - Audience: client_id (not api://{client_id})
  - Issuer: v2.0 endpoint

This avoids requiring Application ID URI setup in Azure AD.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-18 10:50:02 -06:00

120 lines
3.7 KiB
Python

"""
Azure AD token verification service.
Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
"""
import logging
from typing import Optional
from datetime import datetime, timedelta
import httpx
from jose import jwt, JWTError
from app.config import settings
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Process-wide cache of Azure AD's JSON Web Key Set (JWKS) together with the
# naive-UTC instant after which it is considered stale. ``datetime.min`` marks
# the cache as already expired, so the very first lookup always fetches.
_jwks_cache: dict = dict()
_jwks_cache_expiry: datetime = datetime.min
# How long a successfully fetched JWKS is trusted before it is re-fetched.
_JWKS_CACHE_TTL = timedelta(hours=24)
# After a failed fetch that falls back to a stale cache, wait this long before
# hitting the network again, so an Azure AD outage does not turn every request
# into a blocking HTTP call (up to the 10 s timeout below).
_JWKS_RETRY_BACKOFF = timedelta(minutes=5)


async def get_azure_jwks() -> dict:
    """
    Fetch and cache Azure AD's public signing keys (JWKS) for token verification.

    A successfully fetched key set is cached in-process for 24 hours to
    minimize network calls. If a refresh fails and a previously fetched key
    set exists, the stale copy is returned and the next refresh attempt is
    deferred by a short backoff.

    Returns:
        The JWKS document as a dict containing a ``"keys"`` list.

    Raises:
        Exception: Whatever the fetch raised (network/HTTP error, invalid JSON,
            or a ``ValueError`` for a body without a ``"keys"`` list), but only
            when no cached key set is available to fall back on.
    """
    global _jwks_cache, _jwks_cache_expiry

    if _jwks_cache and datetime.utcnow() < _jwks_cache_expiry:
        return _jwks_cache

    jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()
        # Validate before replacing the cache: a malformed 200 response must
        # not evict a good key set for the whole TTL.
        if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
            raise ValueError("JWKS response does not contain a 'keys' list")
    except Exception as e:
        logger.error(f"Failed to fetch JWKS: {e}")
        if _jwks_cache:  # Return stale cache if available
            # Back off instead of retrying the network on every call.
            _jwks_cache_expiry = datetime.utcnow() + _JWKS_RETRY_BACKOFF
            return _jwks_cache
        raise

    _jwks_cache = jwks
    _jwks_cache_expiry = datetime.utcnow() + _JWKS_CACHE_TTL
    logger.info("Successfully fetched Azure AD JWKS")
    return _jwks_cache
# Minimum spacing between JWKS refreshes forced by an unknown key ID. The
# ``kid`` header comes from the untrusted token, so without this limit every
# request carrying a made-up kid would trigger an outbound fetch to Azure AD.
# Trade-off: a genuinely new key can be rejected until the interval elapses if
# another forced refresh happened just before it was published.
_JWKS_FORCED_REFRESH_INTERVAL = timedelta(minutes=5)
_jwks_last_forced_refresh: datetime = datetime.min


def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWK in *jwks* whose ``kid`` equals *kid*, or None."""
    return next(
        (key for key in jwks.get("keys", []) if key.get("kid") == kid),
        None,
    )


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify a JWT from the frontend against Azure AD and return its claims.

    The audience/issuer checks expect an Azure AD v2.0 ID token: audience is
    this app's client ID and the issuer is the tenant's v2.0 endpoint. The
    signing key is looked up in Azure AD's JWKS by the token's ``kid``. If the
    key is unknown, the JWKS is refreshed once, rate-limited, in case keys
    rotated.

    Args:
        token: The raw JWT sent by the frontend.

    Returns:
        The token claims dict if valid, None if invalid. When
        ``settings.DISABLE_AUTH`` is set, a fixed development-user dict is
        returned without any verification.
    """
    global _jwks_cache_expiry, _jwks_last_forced_refresh

    if settings.DISABLE_AUTH:
        logger.warning("Auth disabled - skipping token verification")
        return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}

    if not token:
        logger.warning("No token provided")
        return None

    try:
        jwks = await get_azure_jwks()

        # The header is read unverified only to learn which key signed the token.
        kid = jwt.get_unverified_header(token).get("kid")
        if not kid:
            logger.warning("No key ID in token header")
            return None

        rsa_key = _find_signing_key(jwks, kid)
        if not rsa_key:
            now = datetime.utcnow()
            if now - _jwks_last_forced_refresh < _JWKS_FORCED_REFRESH_INTERVAL:
                logger.warning(f"Key ID {kid} not found in JWKS; forced refresh rate-limited")
                return None
            logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache")
            # Keys may have rotated: expire the cache so the next call re-fetches.
            _jwks_last_forced_refresh = now
            _jwks_cache_expiry = datetime.min
            jwks = await get_azure_jwks()
            rsa_key = _find_signing_key(jwks, kid)
            if not rsa_key:
                logger.error("Could not find matching key after refresh")
                return None

        # For ID tokens with OpenID scopes, audience is the client ID
        # and issuer uses the v2.0 endpoint.
        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            audience=settings.AZURE_CLIENT_ID,
            issuer=f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/v2.0",
            # ID tokens may carry an ``at_hash`` claim, and python-jose rejects
            # those unless the matching access token is passed, which this
            # backend never receives. Signature, audience, issuer and expiry
            # are still verified.
            options={"verify_at_hash": False},
        )
        logger.info(f"Token verified for user: {claims.get('name', 'unknown')}")
        return claims
    except JWTError as e:
        logger.warning(f"JWT verification failed: {e}")
        return None
    except Exception as e:
        # Unexpected failure (e.g. JWKS unavailable): keep the traceback.
        logger.exception(f"Token verification error: {e}")
        return None