193 lines
6.6 KiB
Python
193 lines
6.6 KiB
Python
"""Microsoft Authentication Service
|
|
|
|
Validates Microsoft ID tokens and extracts user information.
|
|
"""
|
|
import time
|
|
|
|
import httpx
|
|
from jose import JWTError, jwt
|
|
from jose.exceptions import JWKError
|
|
from pydantic import BaseModel, EmailStr
|
|
|
|
from ..core.config import settings
|
|
from ..core.logging import get_logger
|
|
|
|
# Module-level logger named after this module, obtained from the project's
# get_logger helper.
logger = get_logger(__name__)
|
|
|
|
|
|
class MicrosoftUserInfo(BaseModel):
    """User information extracted from Microsoft ID token.

    A pydantic model: values are validated against the annotated types when
    an instance is constructed (``email`` must parse as an email address).
    """

    email: EmailStr  # validated as an email address by pydantic
    name: str
    sub: str  # Microsoft user ID
    tid: str  # Tenant ID
    # Defaults to True when the caller does not pass a value.
    # NOTE(review): an optimistic default for a security-relevant flag;
    # confirm that nothing downstream links accounts on email alone.
    email_verified: bool = True
|
|
|
|
|
|
class MicrosoftAuthError(Exception):
    """Base exception for Microsoft authentication errors.

    Root of this module's exception hierarchy: catching it also catches the
    more specific Microsoft-auth error types derived from it.
    """
|
|
|
|
|
|
class MicrosoftTokenValidationError(MicrosoftAuthError):
    """Raised when token validation fails.

    Derives from MicrosoftAuthError, so a handler for the base class also
    receives token-validation failures.
    """
|
|
|
|
|
|
class MicrosoftAuthService:
    """Service for Microsoft authentication operations.

    Validates Microsoft ID tokens against the signing keys advertised by the
    configured authority's OpenID Connect metadata, keeping those keys in an
    in-memory cache between calls.
    """

    def __init__(self):
        self.client_id = settings.azure_client_id
        self.authority = settings.azure_authority

        # Normalise the authority once so the tenant ID and the metadata URL
        # are derived consistently even if the configured value ends with a
        # trailing slash (which would otherwise yield a "//v2.0" URL).
        authority_base = self.authority.rstrip('/')

        # Extract tenant ID from authority URL
        # Format: https://login.microsoftonline.com/{tenant_id}
        self.tenant_id = authority_base.split('/')[-1]

        # Microsoft's OpenID configuration endpoint
        self.openid_config_url = f"{authority_base}/v2.0/.well-known/openid-configuration"

        # Cache for JWKS (public keys)
        self._jwks_cache: dict | None = None
        self._jwks_cache_time: float = 0
        self._jwks_cache_ttl: int = 3600  # Cache for 1 hour
        # Minimum cache age (seconds) before a *forced* refresh may hit the
        # network again. Without it, every token carrying an unknown 'kid'
        # (trivial for an unauthenticated caller to craft) would trigger two
        # outbound HTTP requests.
        self._jwks_min_refresh_interval: int = 300

    async def _get_openid_config(self) -> dict:
        """Fetch OpenID Connect configuration from Microsoft.

        Returns:
            The decoded JSON document served at ``self.openid_config_url``.

        Raises:
            MicrosoftAuthError: If the request fails, returns an error status,
                or the response body is not valid JSON.
        """
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(self.openid_config_url)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a body that is not valid JSON.
            logger.error(f"Failed to fetch OpenID configuration: {e}")
            raise MicrosoftAuthError("Failed to fetch Microsoft authentication configuration") from e

    async def _get_jwks(self, force_refresh: bool = False) -> dict:
        """Fetch JSON Web Key Set (JWKS) from Microsoft.

        Args:
            force_refresh: Ignore the normal cache TTL. To bound outbound
                traffic, a forced refresh is still answered from cache while
                the cached keys are younger than
                ``self._jwks_min_refresh_interval`` seconds.

        Returns:
            JWKS dictionary with public keys

        Raises:
            MicrosoftAuthError: If the OpenID configuration or the key set
                cannot be fetched, or the configuration has no ``jwks_uri``.
        """
        current_time = time.time()
        cache_age = current_time - self._jwks_cache_time
        max_age = self._jwks_min_refresh_interval if force_refresh else self._jwks_cache_ttl
        if self._jwks_cache and cache_age < max_age:
            return self._jwks_cache

        config = await self._get_openid_config()
        jwks_uri = config.get('jwks_uri')
        if not jwks_uri:
            raise MicrosoftAuthError("JWKS URI not found in OpenID configuration")

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.get(jwks_uri)
                response.raise_for_status()
                jwks = response.json()
        except (httpx.HTTPError, ValueError) as e:
            logger.error(f"Failed to fetch JWKS: {e}")
            raise MicrosoftAuthError("Failed to fetch Microsoft public keys") from e

        self._jwks_cache = jwks
        self._jwks_cache_time = current_time
        return jwks

    @staticmethod
    def _find_key(keys: list, kid: str) -> dict | None:
        """Return the public-key fields of the JWK in *keys* whose ``kid`` matches.

        Only ``kty``, ``kid``, ``use`` and the RSA components ``n``/``e`` are
        copied; a matching entry lacking ``kty``, ``n`` or ``e`` raises
        KeyError. Returns None when no entry matches.
        """
        for key in keys:
            if key.get('kid') == kid:
                return {
                    'kty': key['kty'],
                    'kid': key['kid'],
                    'use': key.get('use'),
                    'n': key['n'],
                    'e': key['e'],
                }
        return None

    async def validate_token(self, id_token: str) -> MicrosoftUserInfo:
        """Validate Microsoft ID token and extract user information.

        Args:
            id_token: Microsoft ID token string

        Returns:
            MicrosoftUserInfo with validated user data

        Raises:
            MicrosoftTokenValidationError: If the token is malformed, its
                signing key is unknown, signature/claim validation fails, or
                a required claim is missing.
            MicrosoftAuthError: If Microsoft's OpenID configuration or public
                keys cannot be fetched.
        """
        try:
            # Parse the header before touching the network so a malformed
            # token is rejected locally (and not logged as "unexpected").
            try:
                unverified_header = jwt.get_unverified_header(id_token)
            except JWTError as e:
                raise MicrosoftTokenValidationError(f"Token validation failed: {e}") from e

            kid = unverified_header.get('kid')
            if not kid:
                raise MicrosoftTokenValidationError("Token header missing 'kid' claim")

            jwks = await self._get_jwks()
            rsa_key = self._find_key(jwks.get('keys', []), kid)
            if not rsa_key:
                # Keys may have rotated since the cache was filled. _get_jwks
                # throttles how often this forced refresh reaches the network.
                logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache")
                jwks = await self._get_jwks(force_refresh=True)
                rsa_key = self._find_key(jwks.get('keys', []), kid)

            if not rsa_key:
                raise MicrosoftTokenValidationError(f"Unable to find key with ID: {kid}")

            # NOTE(review): the expected issuer is pinned to the
            # login.microsoftonline.com host and the tenant parsed from the
            # authority. An authority on another host, or a multi-tenant alias
            # such as "common", would not match a real token's issuer -
            # confirm this single-tenant assumption.
            try:
                payload = jwt.decode(
                    id_token,
                    rsa_key,
                    # Pin the algorithm instead of trusting the token header.
                    algorithms=['RS256'],
                    audience=self.client_id,
                    issuer=f"https://login.microsoftonline.com/{self.tenant_id}/v2.0",
                )
            except JWTError as e:
                raise MicrosoftTokenValidationError(f"Token validation failed: {e}") from e

            # NOTE(review): 'email' / 'preferred_username' may be unverified
            # or mutable for Microsoft accounts; (tid, sub) is the stable
            # identity. Confirm how callers use the email value.
            email = payload.get('email') or payload.get('preferred_username')
            if not email:
                raise MicrosoftTokenValidationError("Token missing email claim")

            name = payload.get('name') or email.split('@')[0]

            sub = payload.get('sub')
            if not sub:
                raise MicrosoftTokenValidationError("Token missing 'sub' claim")

            tid = payload.get('tid')
            if not tid:
                raise MicrosoftTokenValidationError("Token missing 'tid' claim")

            email_verified = payload.get('email_verified', True)

            # May raise a pydantic validation error (e.g. a non-email
            # preferred_username); handled by the generic clause below.
            user_info = MicrosoftUserInfo(
                email=email,
                name=name,
                sub=sub,
                tid=tid,
                email_verified=email_verified,
            )

            logger.info(f"Successfully validated Microsoft token for user: {email}")
            return user_info

        except MicrosoftAuthError:
            # Our own errors (including MicrosoftTokenValidationError) already
            # carry a meaningful message; propagate them unchanged.
            raise
        except JWKError as e:
            logger.error(f"JWK error during token validation: {e}")
            raise MicrosoftTokenValidationError(f"Key processing error: {e}") from e
        except Exception as e:
            logger.error(f"Unexpected error during token validation: {e}")
            raise MicrosoftTokenValidationError(f"Token validation failed: {e}") from e
|
|
|
|
|
|
# Singleton instance
# Built at import time, so MicrosoftAuthService.__init__ reads the Azure
# settings as soon as this module is imported. Code that uses this shared
# object also shares its in-memory JWKS cache.
microsoft_auth_service = MicrosoftAuthService()
|
|
|
|
|
|
def get_microsoft_auth_service() -> MicrosoftAuthService:
    """Return the shared Microsoft authentication service.

    Always hands back the module-level ``microsoft_auth_service`` object
    instead of constructing a new one, so its JWKS cache is reused.
    """
    service = microsoft_auth_service
    return service
|