ac-tool/backend/server/auth/msal_auth.py
Vadym Samoilenko 08710e1a16 fix: verify JWT signature via JWKS and fix auth dev bypass condition
- msal_auth.py: replace verify_signature=False with real JWKS verification
  using PyJWKClient; validates RS256 signature, aud=clientId, issuer v2.0
- App.tsx: split DEV bypass from empty-accounts case — in production,
  accounts.length === 0 now correctly triggers loginRedirect instead of
  calling fetchMe without a token

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-23 14:44:22 +00:00

107 lines
3.5 KiB
Python

"""
MSAL / Azure AD token validator (SPA PKCE flow).
Backend only validates incoming Bearer JWTs — no server-side MSAL client needed.
Frontend sends the MSAL idToken (aud = clientId) for user identification.
"""
import logging
from typing import Optional, Dict, Any
from urllib.parse import quote

import jwt
from jwt import PyJWKClient

from ..config_runtime import server_config
# Module-level logger named after this module (__name__).
logger = logging.getLogger(__name__)
# Process-wide JWKS client. It starts as None and is created on the first
# _get_jwks_client() call with cache_keys=True, so signing keys are cached
# after they are first fetched.
_jwks_client: Optional[PyJWKClient] = None
def _get_jwks_client() -> PyJWKClient:
    """Return the shared JWKS client, building it on first use.

    The client targets the v2.0 key-discovery endpoint of the tenant named
    by ``server_config.MSAL_TENANT_ID`` and is stored in the module-level
    ``_jwks_client`` so later calls reuse the same instance.
    """
    global _jwks_client
    # Fast path: the client was already built by an earlier call.
    if _jwks_client is not None:
        return _jwks_client
    tenant_id = server_config.MSAL_TENANT_ID
    keys_endpoint = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
    _jwks_client = PyJWKClient(keys_endpoint, cache_keys=True)
    return _jwks_client
class MSALAuthenticator:
    """Validate Azure AD bearer JWTs and expose MSAL settings to the frontend.

    Every method checks ``server_config.DEV_MODE`` first. When it is set,
    token validation is skipped and a fixed identity built from
    ``server_config.DEV_USER_ID`` / ``DEV_USER_EMAIL`` / ``DEV_USER_NAME``
    is returned instead.
    """

    def __init__(self):
        """Log a notice when the authenticator is created with DEV_MODE on."""
        if server_config.DEV_MODE:
            logger.info("Running in DEV_MODE — MSAL authentication bypassed")

    async def validate_token(self, access_token: str) -> Optional[Dict[str, Any]]:
        """Verify a bearer JWT and return the caller's identity claims.

        The token must carry a valid RS256 signature from a key published at
        the tenant's JWKS endpoint, an ``aud`` equal to
        ``server_config.MSAL_CLIENT_ID``, the tenant's v2.0 issuer, and an
        ``exp`` claim that has not passed.

        Args:
            access_token: Raw JWT string taken from the request. Ignored in
                DEV_MODE.

        Returns:
            A dict with the keys ``oid``, ``preferred_username`` and ``name``
            on success. ``None`` when the token is empty, expired, otherwise
            invalid, lacks an ``oid`` claim, or validation fails for any
            other reason; failures are logged, never raised.
        """
        if server_config.DEV_MODE:
            return {
                'oid': server_config.DEV_USER_ID,
                'preferred_username': server_config.DEV_USER_EMAIL,
                'name': server_config.DEV_USER_NAME,
            }
        if not access_token:
            return None
        try:
            jwks_client = _get_jwks_client()
            # NOTE(review): PyJWKClient downloads the JWKS synchronously, so a
            # key-cache miss blocks the event loop for the duration of the
            # fetch — consider offloading to an executor if this shows up.
            signing_key = jwks_client.get_signing_key_from_jwt(access_token)
            claims = jwt.decode(
                access_token,
                signing_key.key,
                algorithms=["RS256"],
                audience=server_config.MSAL_CLIENT_ID,
                issuer=f"https://login.microsoftonline.com/{server_config.MSAL_TENANT_ID}/v2.0",
                # PyJWT only checks expiry when `exp` is present; require it
                # so a token without an expiry can never be accepted.
                options={"require": ["exp"]},
            )
            user_id = claims.get('oid')
            if not user_id:
                logger.warning("Token missing 'oid' claim")
                return None
            return {
                'oid': user_id,
                # Fall back to `upn` when `preferred_username` is absent/empty.
                'preferred_username': claims.get('preferred_username') or claims.get('upn', ''),
                'name': claims.get('name', ''),
            }
        except jwt.ExpiredSignatureError:
            logger.warning("Token expired")
            return None
        except jwt.InvalidTokenError as exc:
            logger.warning("Invalid JWT: %s", exc)
            return None
        except Exception as exc:
            # Boundary handler: anything unexpected (e.g. a JWKS fetch
            # failure) is logged with a traceback and treated as "not
            # authenticated" instead of propagating to the caller.
            logger.error("Token validation error: %s", exc, exc_info=True)
            return None

    async def get_logout_url(self, post_logout_redirect_uri: Optional[str] = None) -> str:
        """Return the URL the browser should visit to sign the user out.

        Args:
            post_logout_redirect_uri: Optional raw (not yet percent-encoded)
                URI to return to after sign-out.

        Returns:
            In DEV_MODE, the redirect URI itself, or
            ``http://localhost:5173`` when none is given. Otherwise the
            ``/oauth2/v2.0/logout`` endpoint under
            ``server_config.MSAL_AUTHORITY``, with the redirect URI appended
            as a percent-encoded query parameter when one is provided.
        """
        if server_config.DEV_MODE:
            return post_logout_redirect_uri or 'http://localhost:5173'
        base = f"{server_config.MSAL_AUTHORITY}/oauth2/v2.0/logout"
        return self._build_logout_url(base, post_logout_redirect_uri)

    @staticmethod
    def _build_logout_url(base: str, post_logout_redirect_uri: Optional[str] = None) -> str:
        """Append the percent-encoded redirect URI to ``base``, if one is given."""
        if not post_logout_redirect_uri:
            return base
        # safe="" also encodes "/", "?", "&", "=" and "#", so the redirect
        # URI's own path/query/fragment cannot leak into the logout URL's
        # query string.
        encoded_uri = quote(post_logout_redirect_uri, safe="")
        return f"{base}?post_logout_redirect_uri={encoded_uri}"

    def get_client_config(self) -> Dict[str, Any]:
        """Return the MSAL settings handed to the frontend.

        Returns:
            A dict with ``clientId``, ``authority`` and ``redirectUri`` read
            from ``server_config``, plus a boolean ``devMode`` flag that
            mirrors ``server_config.DEV_MODE``.
        """
        return {
            'clientId': server_config.MSAL_CLIENT_ID,
            'authority': server_config.MSAL_AUTHORITY,
            'redirectUri': server_config.MSAL_REDIRECT_URI,
            'devMode': bool(server_config.DEV_MODE),
        }

    def is_dev_mode(self) -> bool:
        """Return ``server_config.DEV_MODE``."""
        return server_config.DEV_MODE
# Shared module-level authenticator instance, created when this module is
# imported (the constructor only logs a notice if DEV_MODE is enabled).
msal_auth = MSALAuthenticator()