"""Microsoft Authentication Service Validates Microsoft ID tokens and extracts user information. """ import time from typing import Dict, Optional import requests from jose import jwt, JWTError from jose.exceptions import JWKError from pydantic import BaseModel, EmailStr from ..core.config import settings from ..core.logging import get_logger logger = get_logger(__name__) class MicrosoftUserInfo(BaseModel): """User information extracted from Microsoft ID token.""" email: EmailStr name: str sub: str # Microsoft user ID tid: str # Tenant ID email_verified: bool = True class MicrosoftAuthError(Exception): """Base exception for Microsoft authentication errors.""" pass class MicrosoftTokenValidationError(MicrosoftAuthError): """Raised when token validation fails.""" pass class MicrosoftAuthService: """Service for Microsoft authentication operations.""" def __init__(self): self.client_id = settings.azure_client_id self.authority = settings.azure_authority # Extract tenant ID from authority URL # Format: https://login.microsoftonline.com/{tenant_id} self.tenant_id = self.authority.rstrip('/').split('/')[-1] # Microsoft's OpenID configuration endpoint self.openid_config_url = f"{self.authority}/v2.0/.well-known/openid-configuration" # Cache for JWKS (public keys) self._jwks_cache: Optional[Dict] = None self._jwks_cache_time: float = 0 self._jwks_cache_ttl: int = 3600 # Cache for 1 hour def _get_openid_config(self) -> Dict: """Fetch OpenID Connect configuration from Microsoft.""" try: response = requests.get(self.openid_config_url, timeout=10) response.raise_for_status() return response.json() except requests.RequestException as e: logger.error(f"Failed to fetch OpenID configuration: {e}") raise MicrosoftAuthError("Failed to fetch Microsoft authentication configuration") def _get_jwks(self, force_refresh: bool = False) -> Dict: """Fetch JSON Web Key Set (JWKS) from Microsoft. Args: force_refresh: Force refresh even if cache is valid Returns: JWKS dictionary with public keys """ # Check cache current_time = time.time() if (not force_refresh and self._jwks_cache and (current_time - self._jwks_cache_time) < self._jwks_cache_ttl): return self._jwks_cache try: # Get JWKS URI from OpenID configuration config = self._get_openid_config() jwks_uri = config.get('jwks_uri') if not jwks_uri: raise MicrosoftAuthError("JWKS URI not found in OpenID configuration") # Fetch JWKS response = requests.get(jwks_uri, timeout=10) response.raise_for_status() jwks = response.json() # Update cache self._jwks_cache = jwks self._jwks_cache_time = current_time return jwks except requests.RequestException as e: logger.error(f"Failed to fetch JWKS: {e}") raise MicrosoftAuthError("Failed to fetch Microsoft public keys") def validate_token(self, id_token: str) -> MicrosoftUserInfo: """Validate Microsoft ID token and extract user information. Args: id_token: Microsoft ID token string Returns: MicrosoftUserInfo with validated user data Raises: MicrosoftTokenValidationError: If token validation fails """ try: # Get JWKS for signature verification jwks = self._get_jwks() # Decode token header to get key ID (kid) unverified_header = jwt.get_unverified_header(id_token) kid = unverified_header.get('kid') if not kid: raise MicrosoftTokenValidationError("Token header missing 'kid' claim") # Find the matching key in JWKS rsa_key = None for key in jwks.get('keys', []): if key.get('kid') == kid: rsa_key = { 'kty': key['kty'], 'kid': key['kid'], 'use': key.get('use'), 'n': key['n'], 'e': key['e'] } break if not rsa_key: logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache") # Try refreshing JWKS cache (keys might have been rotated) jwks = self._get_jwks(force_refresh=True) for key in jwks.get('keys', []): if key.get('kid') == kid: rsa_key = { 'kty': key['kty'], 'kid': key['kid'], 'use': key.get('use'), 'n': key['n'], 'e': key['e'] } break if not rsa_key: raise MicrosoftTokenValidationError(f"Unable to find key with ID: {kid}") # Validate token signature and claims try: payload = jwt.decode( id_token, rsa_key, algorithms=['RS256'], audience=self.client_id, issuer=f"https://login.microsoftonline.com/{self.tenant_id}/v2.0" ) except JWTError as e: raise MicrosoftTokenValidationError(f"Token validation failed: {str(e)}") # Extract required claims email = payload.get('email') or payload.get('preferred_username') if not email: raise MicrosoftTokenValidationError("Token missing email claim") name = payload.get('name') if not name: # Fallback to email if name not provided name = email.split('@')[0] sub = payload.get('sub') if not sub: raise MicrosoftTokenValidationError("Token missing 'sub' claim") tid = payload.get('tid') if not tid: raise MicrosoftTokenValidationError("Token missing 'tid' claim") # Check if email is verified (Microsoft tokens are considered verified) email_verified = payload.get('email_verified', True) # Create user info object user_info = MicrosoftUserInfo( email=email, name=name, sub=sub, tid=tid, email_verified=email_verified ) logger.info(f"Successfully validated Microsoft token for user: {email}") return user_info except JWKError as e: logger.error(f"JWK error during token validation: {e}") raise MicrosoftTokenValidationError(f"Key processing error: {str(e)}") except Exception as e: if isinstance(e, (MicrosoftAuthError, MicrosoftTokenValidationError)): raise logger.error(f"Unexpected error during token validation: {e}") raise MicrosoftTokenValidationError(f"Token validation failed: {str(e)}") # Singleton instance microsoft_auth_service = MicrosoftAuthService() def get_microsoft_auth_service() -> MicrosoftAuthService: """Get Microsoft authentication service instance.""" return microsoft_auth_service