- msal_auth.py: replace `verify_signature=False` with real JWKS verification using `PyJWKClient`; validates the RS256 signature, `aud` = clientId, and the v2.0 issuer.
- App.tsx: split the DEV bypass from the empty-accounts case — in production, `accounts.length === 0` now correctly triggers `loginRedirect` instead of calling `fetchMe` without a token.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
"""
|
|
MSAL / Azure AD token validator (SPA PKCE flow).
|
|
Backend only validates incoming Bearer JWTs — no server-side MSAL client needed.
|
|
Frontend sends the MSAL idToken (aud = clientId) for user identification.
|
|
"""
|
|
|
|
import asyncio
import logging
from typing import Optional, Dict, Any
from urllib.parse import urlencode

import jwt
from jwt import PyJWKClient

from ..config_runtime import server_config
|
|
|
|
logger = logging.getLogger(name=__name__)

# Process-wide JWKS client. It is deliberately NOT created at import time:
# _get_jwks_client() builds it on first use and then reuses it, so the keys it
# caches (cache_keys=True) survive across token validations.
_jwks_client: Optional[PyJWKClient] = None
|
|
|
|
|
|
def _get_jwks_client() -> PyJWKClient:
    """Return the shared PyJWKClient, constructing it on the first call.

    The client targets the tenant's v2.0 JWKS endpoint
    (``https://login.microsoftonline.com/<MSAL_TENANT_ID>/discovery/v2.0/keys``)
    and is stored in the module-level ``_jwks_client`` so every later call
    returns the same instance and its key cache.
    """
    global _jwks_client
    if _jwks_client is not None:
        return _jwks_client

    tenant = server_config.MSAL_TENANT_ID
    keys_endpoint = f"https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys"
    _jwks_client = PyJWKClient(keys_endpoint, cache_keys=True)
    return _jwks_client
|
|
|
|
|
|
class MSALAuthenticator:
    """Validate Microsoft identity platform (Azure AD) Bearer JWTs.

    When ``server_config.DEV_MODE`` is set, token validation is bypassed and a
    fixed development identity read from ``server_config`` is returned instead.
    """

    def __init__(self):
        # Make an authentication bypass visible in the logs at construction time.
        if server_config.DEV_MODE:
            logger.info("Running in DEV_MODE — MSAL authentication bypassed")

    async def validate_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Verify a Bearer JWT and return the caller's identity.

        The RS256 signature is checked against the tenant's JWKS keys. The
        claims ``aud`` (must equal ``MSAL_CLIENT_ID``), ``iss`` (the tenant's
        v2.0 issuer), ``exp`` and ``iat`` must all be present and valid.

        Args:
            access_token: Raw JWT string taken from the Bearer header.

        Returns:
            A dict with ``oid``, ``preferred_username`` (falling back to
            ``upn``, then ``''``) and ``name`` for a valid token. Returns
            ``None`` when the token is empty, expired, malformed, fails
            signature or claim validation, cannot be matched to a signing key,
            or lacks an ``oid`` claim. In DEV_MODE the configured development
            identity is returned without looking at the token.
        """
        if server_config.DEV_MODE:
            return {
                'oid': server_config.DEV_USER_ID,
                'preferred_username': server_config.DEV_USER_EMAIL,
                'name': server_config.DEV_USER_NAME,
            }

        if not access_token:
            return None

        try:
            jwks_client = _get_jwks_client()
            # get_signing_key_from_jwt may do a blocking HTTP fetch of the JWKS
            # document (first use, cache expiry, or a token whose 'kid' is not
            # cached). Run it off the event loop so one slow fetch cannot stall
            # every other request handled by this process.
            loop = asyncio.get_running_loop()
            signing_key = await loop.run_in_executor(
                None, jwks_client.get_signing_key_from_jwt, access_token
            )
            # NOTE(review): Azure puts the tenant GUID in 'iss', so this assumes
            # MSAL_TENANT_ID is that GUID (not 'common'/'organizations' or a
            # domain name) — confirm in config.
            claims = jwt.decode(
                access_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=server_config.MSAL_CLIENT_ID,
                issuer=f"https://login.microsoftonline.com/{server_config.MSAL_TENANT_ID}/v2.0",
                # PyJWT validates exp/iat only when present; requiring them
                # rejects a token that carries no expiry instead of accepting it
                # indefinitely.
                options={"require": ["exp", "iat", "iss", "aud"]},
            )
        except jwt.ExpiredSignatureError:
            logger.warning("Token expired")
            return None
        except jwt.InvalidTokenError as e:
            logger.warning("Invalid JWT: %s", e)
            return None
        except Exception as e:
            # Everything else (e.g. JWKS endpoint unreachable, or no published
            # key matches the token) is logged loudly but still fails closed.
            logger.error("Token validation error: %s", e, exc_info=True)
            return None

        user_id = claims.get('oid')
        if not user_id:
            logger.warning("Token missing 'oid' claim")
            return None

        return {
            'oid': user_id,
            'preferred_username': claims.get('preferred_username') or claims.get('upn', ''),
            'name': claims.get('name', ''),
        }

    async def get_logout_url(self, post_logout_redirect_uri: Optional[str] = None) -> str:
        """Build the v2.0 logout URL under ``server_config.MSAL_AUTHORITY``.

        Args:
            post_logout_redirect_uri: Optional URL to return to after sign-out.
                It is percent-encoded into the ``post_logout_redirect_uri``
                query parameter.

        Returns:
            The logout URL. In DEV_MODE the redirect target itself is returned
            (default ``'http://localhost:5173'``).
        """
        if server_config.DEV_MODE:
            return post_logout_redirect_uri or 'http://localhost:5173'

        base = f"{server_config.MSAL_AUTHORITY}/oauth2/v2.0/logout"
        if post_logout_redirect_uri:
            # Encode the value so characters such as '&', '#' or '?' inside the
            # redirect URL cannot truncate it or inject extra query parameters.
            query = urlencode({'post_logout_redirect_uri': post_logout_redirect_uri})
            return f"{base}?{query}"
        return base

    def get_client_config(self) -> Dict[str, Any]:
        """Return the MSAL client settings plus a ``devMode`` flag.

        The settings are ``clientId``, ``authority`` and ``redirectUri``, taken
        from ``server_config``. ``devMode`` is ``server_config.DEV_MODE``
        coerced to a real bool.
        """
        return {
            'clientId': server_config.MSAL_CLIENT_ID,
            'authority': server_config.MSAL_AUTHORITY,
            'redirectUri': server_config.MSAL_REDIRECT_URI,
            'devMode': bool(server_config.DEV_MODE),
        }

    def is_dev_mode(self) -> bool:
        """Return ``server_config.DEV_MODE`` (True when auth is bypassed)."""
        return server_config.DEV_MODE
|
|
|
|
|
|
# Module-level shared instance (presumably what other modules import instead of
# constructing their own — confirm). Constructing it logs a notice when
# server_config.DEV_MODE is set.
msal_auth = MSALAuthenticator()
|