brief-extractor/backend/server/auth/msal_auth.py
2026-03-06 18:42:46 +00:00

184 lines
No EOL
6.5 KiB
Python
Executable file

"""
MSAL authentication handler for Azure AD integration
"""
import asyncio
import logging
from typing import Optional, Dict, Any
from urllib.parse import urlencode

import jwt
from msal import ConfidentialClientApplication, PublicClientApplication

from ..config_runtime import server_config
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class MSALAuthenticator:
    """
    Validates Azure AD tokens obtained by the frontend through the MSAL PKCE flow.

    The backend never acquires tokens itself; it only verifies the JWTs the
    frontend presents. When ``server_config.DEV_MODE`` is set, verification is
    bypassed and a fixed mock user is returned instead.
    """

    # Azure AD signs its tokens with RS256. Pinning the accepted algorithm
    # prevents a token from choosing a weaker one through its own header.
    _ALLOWED_ALGORITHMS = ["RS256"]

    def __init__(self):
        """Load and check the MSAL configuration unless DEV_MODE is active."""
        # Never assigned a real client in this class (token checks are plain
        # JWT verification); kept so existing readers of the attribute work.
        self.client_app: Optional[ConfidentialClientApplication] = None
        # Created lazily on first validation; see _get_jwks_client().
        self._jwks_client = None
        if not server_config.DEV_MODE:
            self._initialize_msal_client()
        else:
            logger.info("Running in DEV_MODE - MSAL authentication bypassed")

    def _initialize_msal_client(self):
        """
        Check that the MSAL configuration is complete and log it.

        Raises:
            ValueError: if ``server_config.validate_auth_config()`` is falsy
        """
        if not server_config.validate_auth_config():
            logger.error("MSAL configuration is incomplete")
            raise ValueError("MSAL configuration missing required fields")
        try:
            # For the PKCE flow no MSAL client object is needed here: the
            # backend only validates tokens that Azure AD issued to the
            # frontend, which is done through JWT verification.
            logger.info("MSAL configuration loaded for PKCE flow (public client)")
            logger.info(f"Client ID: {server_config.MSAL_CLIENT_ID}")
            logger.info(f"Tenant ID: {server_config.MSAL_TENANT_ID}")
            logger.info(f"Authority: {server_config.MSAL_AUTHORITY}")
        except Exception as e:
            logger.error(f"Failed to initialize MSAL configuration: {e}")
            raise

    def _get_jwks_client(self) -> "jwt.PyJWKClient":
        """Return the cached JWKS client for the configured authority, creating it on first use."""
        client = getattr(self, "_jwks_client", None)
        if client is None:
            # Same "<authority>/<path>" convention that get_logout_url() uses.
            jwks_uri = server_config.MSAL_AUTHORITY.rstrip("/") + "/discovery/v2.0/keys"
            client = jwt.PyJWKClient(jwks_uri)
            self._jwks_client = client
        return client

    def _decode_verified(self, access_token: str) -> Dict[str, Any]:
        """
        Verify signature, expiry and audience of a token and return its claims.

        Synchronous because resolving the signing key may perform an HTTP
        request to the JWKS endpoint; call it from a worker thread.

        Raises:
            jwt.PyJWTError: if the token is invalid or its key cannot be resolved
        """
        signing_key = self._get_jwks_client().get_signing_key_from_jwt(access_token)
        client_id = server_config.MSAL_CLIENT_ID
        # NOTE(review): "api://<client id>" is the default Application ID URI;
        # if the app registration uses a custom one it must be added here.
        # NOTE(review): issuer/tenant pinning is not enforced because the
        # format of MSAL_TENANT_ID is not visible from this file - consider
        # adding it for defense in depth.
        return jwt.decode(
            access_token,
            signing_key.key,
            algorithms=self._ALLOWED_ALGORITHMS,
            audience=[client_id, f"api://{client_id}"],
            # A token without "exp" is rejected, as it was before.
            options={"require": ["exp"]},
        )

    async def validate_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """
        Validate an access token from the frontend

        The token's signature is checked against the authority's published
        signing keys, and its expiry and audience are enforced. Claims are
        only read after that verification has succeeded.

        NOTE(review): tokens whose audience is Microsoft Graph are no longer
        accepted. They are meant to be consumed by Graph only, so the frontend
        must present a token issued for this application.

        Args:
            access_token: JWT access token from MSAL
        Returns:
            User information if token is valid, None otherwise
        """
        if server_config.DEV_MODE:
            # Return mock user in dev mode
            return {
                'oid': 'dev-user-id',
                'preferred_username': 'dev@localhost',
                'name': 'Development User',
                'roles': ['user']
            }
        if not access_token:
            logger.warning("No access token provided")
            return None

        try:
            logger.info("Validating access token...")
            # Key resolution can block on network I/O, so keep it off the
            # event loop.
            loop = asyncio.get_running_loop()
            claims = await loop.run_in_executor(
                None, self._decode_verified, access_token
            )
        except jwt.InvalidTokenError as e:
            # Bad signature, expired, wrong audience, malformed token, ...
            logger.warning(f"Invalid JWT token: {e}")
            return None
        except jwt.PyJWTError as e:
            # Signing key could not be fetched or matched (JWKS problems).
            logger.error(f"Unable to resolve token signing key: {e}")
            return None
        except Exception as e:
            logger.error(f"Token validation error: {e}", exc_info=True)
            return None

        logger.info(f"Token claims: aud={claims.get('aud')}, iss={claims.get('iss')}")
        logger.info(f"Token user: oid={claims.get('oid')}, upn={claims.get('preferred_username')}")

        user_id = claims.get('oid')  # Object ID from Azure AD
        username = claims.get('preferred_username')
        name = claims.get('name')
        if not user_id:
            logger.warning("Token missing required 'oid' field")
            logger.info(f"Available token fields: {list(claims.keys())}")
            return None

        logger.info(f"Token validation successful for user: {username}")
        return {
            'oid': user_id,
            'preferred_username': username,
            'name': name,
            'roles': ['user'],  # Default role
            'token_claims': claims
        }

    async def get_logout_url(self, post_logout_redirect_uri: Optional[str] = None) -> str:
        """
        Generate logout URL for Azure AD

        Args:
            post_logout_redirect_uri: Where to redirect after logout
        Returns:
            Logout URL. In DEV_MODE this is the redirect URI itself, or
            "http://localhost:3000" when none is given.
        """
        if server_config.DEV_MODE:
            return post_logout_redirect_uri or "http://localhost:3000"

        base_url = server_config.MSAL_AUTHORITY + "/oauth2/v2.0/logout"
        if not post_logout_redirect_uri:
            return base_url
        # Percent-encode the value so characters such as "?", "&" or "#" in
        # the redirect URI cannot break or extend the logout query string.
        query = urlencode({"post_logout_redirect_uri": post_logout_redirect_uri})
        return f"{base_url}?{query}"

    def get_client_config(self) -> Dict[str, Any]:
        """
        Get client configuration for frontend MSAL setup

        Returns:
            Configuration dictionary for frontend with the keys ``clientId``,
            ``authority``, ``redirectUri`` and ``devMode``. In DEV_MODE the
            client ID and authority are fixed placeholder values.
        """
        if server_config.DEV_MODE:
            return {
                'clientId': 'dev-client-id',
                'authority': 'https://login.microsoftonline.com/common',
                'redirectUri': server_config.MSAL_REDIRECT_URI,
                'devMode': True
            }
        return {
            'clientId': server_config.MSAL_CLIENT_ID,
            'authority': server_config.MSAL_AUTHORITY,
            'redirectUri': server_config.MSAL_REDIRECT_URI,
            'devMode': False
        }

    def is_dev_mode(self) -> bool:
        """Check if running in development mode"""
        return server_config.DEV_MODE
# Global instance shared by importers of this module. It is constructed at
# import time, so when DEV_MODE is off an incomplete MSAL configuration makes
# the constructor raise ValueError while this module is being imported.
msal_auth = MSALAuthenticator()