video-accessibility/backend/app/services/microsoft_auth.py
2025-10-10 09:19:39 -05:00

220 lines
7.5 KiB
Python

"""Microsoft Authentication Service
Validates Microsoft ID tokens and extracts user information.
"""
import time
from typing import Dict, Optional
import requests
from jose import jwt, JWTError
from jose.exceptions import JWKError
from pydantic import BaseModel, EmailStr
from ..core.config import settings
from ..core.logging import get_logger
# Module-level logger, named after this module via the project's get_logger helper.
logger = get_logger(__name__)
class MicrosoftUserInfo(BaseModel):
    """Identity fields taken from a Microsoft ID token.

    Attributes:
        email: The user's e-mail address (validated as ``EmailStr``).
        name: Display name of the user.
        sub: Microsoft user ID (the token's subject).
        tid: Tenant ID the user belongs to.
        email_verified: Whether the e-mail address is considered verified;
            defaults to ``True`` when not supplied.
    """

    email: EmailStr
    name: str
    sub: str
    tid: str
    email_verified: bool = True
class MicrosoftAuthError(Exception):
    """Base exception for Microsoft authentication errors.

    Catching this type also catches every more specific error derived
    from it, so callers can use it as a single failure boundary.
    """
class MicrosoftTokenValidationError(MicrosoftAuthError):
    """Raised when token validation fails.

    Being a subclass of ``MicrosoftAuthError``, it is also caught by
    handlers written for the base authentication error.
    """
class MicrosoftAuthService:
    """Service for Microsoft authentication operations.

    Fetches Microsoft's public signing keys (JWKS) through the OpenID Connect
    configuration document of the configured authority, keeps them in an
    in-memory cache, and uses them to verify ID tokens and extract user
    information.
    """

    def __init__(self):
        """Read client ID and authority from settings and derive endpoints."""
        self.client_id = settings.azure_client_id
        self.authority = settings.azure_authority

        # Strip a trailing slash once and derive everything from the result;
        # otherwise an authority configured as ".../{tenant_id}/" would yield
        # ".../{tenant_id}//v2.0/..." for the configuration URL.
        authority_base = self.authority.rstrip('/')

        # Extract tenant ID from authority URL
        # Format: https://login.microsoftonline.com/{tenant_id}
        self.tenant_id = authority_base.split('/')[-1]

        # Microsoft's OpenID configuration endpoint
        self.openid_config_url = f"{authority_base}/v2.0/.well-known/openid-configuration"

        # Cache for JWKS (public keys)
        self._jwks_cache: Optional[Dict] = None
        self._jwks_cache_time: float = 0
        self._jwks_cache_ttl: int = 3600  # Cache for 1 hour

    def _get_openid_config(self) -> Dict:
        """Fetch OpenID Connect configuration from Microsoft.

        Returns:
            The parsed JSON configuration document.

        Raises:
            MicrosoftAuthError: If the request fails, returns an error status,
                or the body is not valid JSON.
        """
        try:
            response = requests.get(self.openid_config_url, timeout=10)
            response.raise_for_status()
            return response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSON decode error is not a RequestException.
            logger.error(f"Failed to fetch OpenID configuration: {e}")
            raise MicrosoftAuthError("Failed to fetch Microsoft authentication configuration") from e

    def _get_jwks(self, force_refresh: bool = False) -> Dict:
        """Fetch JSON Web Key Set (JWKS) from Microsoft.

        A non-empty cached key set younger than the cache TTL is returned
        without any network access unless ``force_refresh`` is set.

        Args:
            force_refresh: Force refresh even if cache is valid

        Returns:
            JWKS dictionary with public keys

        Raises:
            MicrosoftAuthError: If the configuration or the key set cannot be
                fetched, or the configuration has no ``jwks_uri``.
        """
        # Check cache
        current_time = time.time()
        cache_is_fresh = (current_time - self._jwks_cache_time) < self._jwks_cache_ttl
        if not force_refresh and self._jwks_cache and cache_is_fresh:
            return self._jwks_cache

        try:
            # Get JWKS URI from OpenID configuration
            config = self._get_openid_config()
            jwks_uri = config.get('jwks_uri')
            if not jwks_uri:
                raise MicrosoftAuthError("JWKS URI not found in OpenID configuration")

            # Fetch JWKS
            response = requests.get(jwks_uri, timeout=10)
            response.raise_for_status()
            jwks = response.json()
        except (requests.RequestException, ValueError) as e:
            logger.error(f"Failed to fetch JWKS: {e}")
            raise MicrosoftAuthError("Failed to fetch Microsoft public keys") from e

        # Update cache
        self._jwks_cache = jwks
        self._jwks_cache_time = current_time
        return jwks

    @staticmethod
    def _find_signing_key(jwks: Dict, kid: str) -> Optional[Dict]:
        """Return the RSA key fields for ``kid`` from ``jwks``, or ``None``.

        Only the fields needed for signature verification are copied; any
        other members of the JWKS entry are dropped.
        """
        for key in jwks.get('keys', []):
            if key.get('kid') == kid:
                return {
                    'kty': key['kty'],
                    'kid': key['kid'],
                    'use': key.get('use'),
                    'n': key['n'],
                    'e': key['e'],
                }
        return None

    def _resolve_signing_key(self, kid: str) -> Dict:
        """Look up the key for ``kid``, refreshing the JWKS once on a miss.

        Raises:
            MicrosoftTokenValidationError: If the key is still unknown after
                the forced refresh.
            MicrosoftAuthError: If the key set cannot be fetched.
        """
        rsa_key = self._find_signing_key(self._get_jwks(), kid)
        if rsa_key is None:
            logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache")
            # Try refreshing JWKS cache (keys might have been rotated)
            rsa_key = self._find_signing_key(self._get_jwks(force_refresh=True), kid)
        if rsa_key is None:
            raise MicrosoftTokenValidationError(f"Unable to find key with ID: {kid}")
        return rsa_key

    def validate_token(self, id_token: str) -> "MicrosoftUserInfo":
        """Validate Microsoft ID token and extract user information.

        The signature is checked against Microsoft's published keys with
        RS256, and the audience and issuer claims are verified against the
        configured client ID and tenant.

        Args:
            id_token: Microsoft ID token string

        Returns:
            MicrosoftUserInfo with validated user data

        Raises:
            MicrosoftTokenValidationError: If token validation fails
            MicrosoftAuthError: If Microsoft's configuration or public keys
                cannot be fetched
        """
        try:
            # Decode token header to get key ID (kid). This happens before any
            # key fetch so malformed input is rejected without network access
            # and is reported as a validation failure, not an unexpected error.
            try:
                unverified_header = jwt.get_unverified_header(id_token)
            except JWTError as e:
                raise MicrosoftTokenValidationError(f"Token validation failed: {str(e)}") from e

            kid = unverified_header.get('kid')
            if not kid:
                raise MicrosoftTokenValidationError("Token header missing 'kid' claim")

            # Find the matching key in JWKS (refreshing once if unknown)
            rsa_key = self._resolve_signing_key(kid)

            # Validate token signature and claims
            # NOTE(review): the issuer is pinned to the login.microsoftonline.com
            # host plus the last path segment of the authority; a multi-tenant
            # authority such as ".../common" would not match - confirm that
            # only single-tenant authorities are configured.
            try:
                payload = jwt.decode(
                    id_token,
                    rsa_key,
                    algorithms=['RS256'],
                    audience=self.client_id,
                    issuer=f"https://login.microsoftonline.com/{self.tenant_id}/v2.0"
                )
            except JWTError as e:
                raise MicrosoftTokenValidationError(f"Token validation failed: {str(e)}") from e

            # Extract required claims
            email = payload.get('email') or payload.get('preferred_username')
            if not email:
                raise MicrosoftTokenValidationError("Token missing email claim")

            # Fallback to the local part of the email if name not provided
            name = payload.get('name') or email.split('@')[0]

            sub = payload.get('sub')
            if not sub:
                raise MicrosoftTokenValidationError("Token missing 'sub' claim")

            tid = payload.get('tid')
            if not tid:
                raise MicrosoftTokenValidationError("Token missing 'tid' claim")

            # A missing 'email_verified' claim is treated as verified.
            # NOTE(review): confirm this default is acceptable for every
            # account type that can sign in.
            email_verified = payload.get('email_verified', True)

            # Create user info object
            user_info = MicrosoftUserInfo(
                email=email,
                name=name,
                sub=sub,
                tid=tid,
                email_verified=email_verified
            )
            logger.info(f"Successfully validated Microsoft token for user: {email}")
            return user_info
        except MicrosoftAuthError:
            # Already a domain error (this also covers the
            # MicrosoftTokenValidationError subclass); pass it through as is.
            raise
        except JWKError as e:
            logger.error(f"JWK error during token validation: {e}")
            raise MicrosoftTokenValidationError(f"Key processing error: {str(e)}") from e
        except Exception as e:
            # Boundary handler: anything else (e.g. model validation of the
            # extracted claims) is surfaced as a validation failure.
            logger.error(f"Unexpected error during token validation: {e}")
            raise MicrosoftTokenValidationError(f"Token validation failed: {str(e)}") from e
# Singleton instance, constructed once when this module is imported; the
# settings read by MicrosoftAuthService.__init__ must be available by then.
microsoft_auth_service = MicrosoftAuthService()
def get_microsoft_auth_service() -> MicrosoftAuthService:
    """Get Microsoft authentication service instance.

    Returns:
        The module-level ``microsoft_auth_service`` object; every call
        returns that same instance.
    """
    return microsoft_auth_service