"""Microsoft Authentication Service.

Validates Microsoft ID tokens and extracts user information.
"""

import time

import httpx
from jose import JWTError, jwt
from jose.exceptions import JWKError
from pydantic import BaseModel, EmailStr

from ..core.config import settings
from ..core.logging import get_logger

logger = get_logger(__name__)


class MicrosoftUserInfo(BaseModel):
    """User information extracted from a Microsoft ID token."""

    email: EmailStr
    name: str
    sub: str  # Microsoft user ID (token 'sub' claim)
    tid: str  # Tenant ID (token 'tid' claim)
    email_verified: bool = True


class MicrosoftAuthError(Exception):
    """Base exception for Microsoft authentication errors.

    Raised directly for service-side failures, e.g. when Microsoft's
    OpenID configuration or key set cannot be fetched or decoded.
    """


class MicrosoftTokenValidationError(MicrosoftAuthError):
    """Raised when token validation fails."""


class MicrosoftAuthService:
    """Service for Microsoft authentication operations."""

    def __init__(self):
        self.client_id = settings.azure_client_id
        self.authority = settings.azure_authority

        # Normalise once so a trailing slash in the configured authority
        # neither breaks tenant extraction nor produces a "//v2.0" URL.
        base_authority = self.authority.rstrip('/')

        # Extract tenant ID from authority URL
        # Format: https://login.microsoftonline.com/{tenant_id}
        self.tenant_id = base_authority.split('/')[-1]

        # Microsoft's OpenID configuration endpoint
        self.openid_config_url = (
            f"{base_authority}/v2.0/.well-known/openid-configuration"
        )

        # Cache for JWKS (public keys). Timestamps come from time.monotonic()
        # so wall-clock adjustments cannot stretch or shrink the lifetime.
        self._jwks_cache: dict | None = None
        self._jwks_cache_time: float = 0
        self._jwks_cache_ttl: int = 3600  # Cache for 1 hour
        # Minimum cache age (seconds) before an unknown 'kid' may force a
        # refresh; stops tokens with bogus key IDs from triggering outbound
        # requests to Microsoft on every call.
        self._jwks_min_refresh_interval: int = 60

    async def _get_openid_config(self) -> dict:
        """Fetch the OpenID Connect configuration document from Microsoft.

        Returns:
            The decoded JSON configuration document.

        Raises:
            MicrosoftAuthError: If the request fails or the response body is
                not valid JSON.
        """
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(self.openid_config_url)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a non-JSON body (json.JSONDecodeError); without
            # it a service outage would be reported as an invalid token.
            logger.error(f"Failed to fetch OpenID configuration: {e}")
            raise MicrosoftAuthError(
                "Failed to fetch Microsoft authentication configuration"
            ) from e

    async def _get_jwks(self, force_refresh: bool = False) -> dict:
        """Fetch the JSON Web Key Set (JWKS) from Microsoft, using a cache.

        Args:
            force_refresh: Force refresh even if cache is valid

        Returns:
            JWKS dictionary with public keys

        Raises:
            MicrosoftAuthError: If the configuration or key set cannot be
                fetched or decoded, or the configuration lacks 'jwks_uri'.
        """
        current_time = time.monotonic()
        if (not force_refresh
                and self._jwks_cache
                and (current_time - self._jwks_cache_time) < self._jwks_cache_ttl):
            return self._jwks_cache

        # _get_openid_config already converts its own failures to
        # MicrosoftAuthError, so it does not need to sit inside the try below.
        config = await self._get_openid_config()
        jwks_uri = config.get('jwks_uri')
        if not jwks_uri:
            raise MicrosoftAuthError("JWKS URI not found in OpenID configuration")

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(jwks_uri)
                response.raise_for_status()
                jwks = response.json()
        except (httpx.HTTPError, ValueError) as e:
            logger.error(f"Failed to fetch JWKS: {e}")
            raise MicrosoftAuthError("Failed to fetch Microsoft public keys") from e

        self._jwks_cache = jwks
        self._jwks_cache_time = time.monotonic()
        return jwks

    @staticmethod
    def _find_key(keys: list, kid: str) -> dict | None:
        """Return the RSA public-key fields of the JWKS entry whose 'kid' matches.

        Returns None when no entry matches. A matching entry missing 'kty',
        'n' or 'e' raises KeyError, which validate_token reports as a
        validation failure.
        """
        for key in keys:
            if key.get('kid') == kid:
                return {
                    'kty': key['kty'],
                    'kid': key['kid'],
                    'use': key.get('use'),
                    'n': key['n'],
                    'e': key['e'],
                }
        return None

    async def _resolve_signing_key(self, id_token: str) -> dict:
        """Look up the public key named by the token's 'kid' header.

        Raises:
            MicrosoftTokenValidationError: If the header has no 'kid' or no
                matching key exists (even after an allowed refresh).
            MicrosoftAuthError: If the key set cannot be fetched.
        """
        # Parse the (cheap, local) header first so malformed tokens never
        # cause network traffic.
        unverified_header = jwt.get_unverified_header(id_token)
        kid = unverified_header.get('kid')
        if not kid:
            raise MicrosoftTokenValidationError("Token header missing 'kid' claim")

        jwks = await self._get_jwks()
        rsa_key = self._find_key(jwks.get('keys', []), kid)

        # Refresh only if the cached set is old enough; a just-fetched set will
        # not contain the key either, and refreshing on demand would let any
        # caller drive unbounded requests to Microsoft.
        cache_age = time.monotonic() - self._jwks_cache_time
        if rsa_key is None and cache_age >= self._jwks_min_refresh_interval:
            logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache")
            jwks = await self._get_jwks(force_refresh=True)
            rsa_key = self._find_key(jwks.get('keys', []), kid)

        if rsa_key is None:
            raise MicrosoftTokenValidationError(f"Unable to find key with ID: {kid}")
        return rsa_key

    @staticmethod
    def _build_user_info(payload: dict) -> MicrosoftUserInfo:
        """Map verified token claims onto a MicrosoftUserInfo.

        Raises:
            MicrosoftTokenValidationError: If a required claim is missing.
        """
        # NOTE(review): Microsoft documents 'preferred_username' as mutable and
        # not guaranteed to be an email; consider keying accounts on (tid, sub).
        email = payload.get('email') or payload.get('preferred_username')
        if not email:
            raise MicrosoftTokenValidationError("Token missing email claim")

        name = payload.get('name') or email.split('@')[0]

        sub = payload.get('sub')
        if not sub:
            raise MicrosoftTokenValidationError("Token missing 'sub' claim")

        tid = payload.get('tid')
        if not tid:
            raise MicrosoftTokenValidationError("Token missing 'tid' claim")

        # NOTE(review): an absent claim is treated as verified; confirm this
        # is acceptable before relying on email_verified for authorization.
        email_verified = payload.get('email_verified', True)

        return MicrosoftUserInfo(
            email=email,
            name=name,
            sub=sub,
            tid=tid,
            email_verified=email_verified,
        )

    async def validate_token(self, id_token: str) -> MicrosoftUserInfo:
        """Validate Microsoft ID token and extract user information.

        The signature (RS256), audience (client ID) and issuer (the configured
        tenant's v2.0 issuer) are verified before any claims are read.

        Args:
            id_token: Microsoft ID token string

        Returns:
            MicrosoftUserInfo with validated user data

        Raises:
            MicrosoftTokenValidationError: If token validation fails
            MicrosoftAuthError: If Microsoft's keys/configuration cannot be
                fetched
        """
        try:
            rsa_key = await self._resolve_signing_key(id_token)

            try:
                payload = jwt.decode(
                    id_token,
                    rsa_key,
                    algorithms=['RS256'],
                    audience=self.client_id,
                    issuer=f"https://login.microsoftonline.com/{self.tenant_id}/v2.0",
                )
            except JWTError as e:
                raise MicrosoftTokenValidationError(
                    f"Token validation failed: {str(e)}"
                ) from e

            user_info = self._build_user_info(payload)
            logger.info(f"Successfully validated Microsoft token for user: {user_info.email}")
            return user_info

        except JWKError as e:
            logger.error(f"JWK error during token validation: {e}")
            raise MicrosoftTokenValidationError(f"Key processing error: {str(e)}") from e
        except MicrosoftAuthError:
            # Already classified (service failure or rejected token); this also
            # covers MicrosoftTokenValidationError, which subclasses it.
            raise
        except Exception as e:
            # Anything else (malformed header, pydantic email validation, an
            # unexpected key shape) is reported as a validation failure.
            logger.error(f"Unexpected error during token validation: {e}")
            raise MicrosoftTokenValidationError(f"Token validation failed: {str(e)}") from e


# Singleton instance
microsoft_auth_service = MicrosoftAuthService()


def get_microsoft_auth_service() -> MicrosoftAuthService:
    """Get Microsoft authentication service instance."""
    return microsoft_auth_service