191 lines
No EOL
6.6 KiB
Python
191 lines
No EOL
6.6 KiB
Python
import os
|
|
import secrets
|
|
import base64
|
|
import hashlib
|
|
from typing import Optional, Dict, Any
|
|
import config
|
|
|
|
# The msal package is imported only when MSAL is switched on in config, so
# it does not have to be installed when the feature is off.
if not config.is_msal_enabled():
    msal = None
else:
    import msal

# OpenID Connect scopes sent on the authorize request and used by default
# when redeeming the authorization code.
SCOPES = ["openid", "profile", "email"]
|
|
|
|
class MSALAuth:
    """Azure AD sign-in helper for a SPA using the authorization-code flow with PKCE.

    The authorize URL and the token request are built by hand against
    ``{authority}/oauth2/v2.0/authorize`` and ``{authority}/oauth2/v2.0/token``.
    The msal package must still be importable when MSAL is enabled (checked in
    ``__init__``) and is used by ``_build_msal_app``.

    Attributes:
        enabled: False when MSAL is switched off in ``config``; the flow
            methods then raise RuntimeError.
        client_id, authority, redirect_uri: read from
            ``config.get_msal_config()``; all None when disabled.
    """

    # Seconds to wait for the token endpoint. Without a timeout ``requests``
    # can block indefinitely on a stalled connection.
    _TOKEN_REQUEST_TIMEOUT = 30

    def __init__(self):
        """Load Azure AD settings from ``config``.

        Raises:
            ImportError: MSAL is enabled but the msal package was not imported.
            ValueError: client_id, authority or redirect_uri is missing/empty.
        """
        if not config.is_msal_enabled():
            self.enabled = False
            self.client_id = None
            self.authority = None
            self.redirect_uri = None
            return

        if msal is None:
            raise ImportError("MSAL library not available but MSAL is enabled")

        msal_config = config.get_msal_config()
        self.enabled = True
        self.client_id = msal_config["client_id"]
        self.authority = msal_config["authority"]
        self.redirect_uri = msal_config["redirect_uri"]

        if not all([self.client_id, self.authority, self.redirect_uri]):
            raise ValueError("Missing Azure AD configuration. Check environment variables.")

    def _build_msal_app(self, cache=None):
        """Create MSAL public client application for SPA"""
        return msal.PublicClientApplication(
            self.client_id,
            authority=self.authority,
            token_cache=cache,
        )

    def _endpoint(self, path: str) -> str:
        """Join ``path`` onto the authority URL without producing a double slash."""
        return f"{self.authority.rstrip('/')}/{path}"

    def _generate_pkce_challenge(self, code_verifier: str) -> str:
        """Return the S256 PKCE challenge for ``code_verifier``.

        The challenge is the unpadded base64url encoding of the SHA-256
        digest of the verifier's UTF-8 bytes.
        """
        digest = hashlib.sha256(code_verifier.encode('utf-8')).digest()
        return base64.urlsafe_b64encode(digest).decode('utf-8').rstrip('=')

    @staticmethod
    def _decode_id_token_claims(id_token: Any) -> Dict[str, Any]:
        """Return the payload claims of a JWT ID token, or {} if it cannot be parsed.

        The signature is NOT verified. Only use this on an ID token received
        directly from the token endpoint over TLS (as in
        ``acquire_token_by_auth_code``), never on a token handed in by a
        browser or other untrusted party.
        """
        import json

        if not isinstance(id_token, str):
            return {}
        parts = id_token.split(".")
        if len(parts) < 2:
            return {}
        payload = parts[1]
        payload += "=" * (-len(payload) % 4)  # JWTs strip base64 padding
        try:
            claims = json.loads(base64.urlsafe_b64decode(payload))
        except ValueError:  # binascii.Error, UnicodeDecodeError, JSONDecodeError
            return {}
        return claims if isinstance(claims, dict) else {}

    def get_auth_url(self, session_state: Optional[str] = None) -> Dict[str, Any]:
        """Build the authorization URL for the PKCE code flow.

        Args:
            session_state: state value to embed; a random one is generated
                when omitted or empty.

        Returns:
            dict with "auth_url", "state" and "code_verifier". The caller must
            keep "state" and "code_verifier" in the session for the callback.

        Raises:
            RuntimeError: MSAL is disabled.
        """
        if not self.enabled:
            raise RuntimeError("MSAL is disabled")

        import urllib.parse

        # 96 random bytes -> 128 base64url characters, the upper bound RFC 7636
        # allows for a code_verifier (43-128).
        code_verifier = secrets.token_urlsafe(96)
        # state is echoed back on the redirect; compare it with validate_state()
        # to reject forged callbacks (CSRF protection).
        state = session_state or secrets.token_urlsafe(32)
        code_challenge = self._generate_pkce_challenge(code_verifier)

        # The URL is built by hand so the PKCE parameters are always present;
        # the MSAL library was observed to omit them for SPA registrations.
        auth_params = {
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "scope": " ".join(SCOPES),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "response_mode": "query",
        }
        auth_url = self._endpoint("oauth2/v2.0/authorize") + "?" + urllib.parse.urlencode(auth_params)

        return {
            "auth_url": auth_url,
            "state": state,
            "code_verifier": code_verifier,
        }

    def acquire_token_by_auth_code(self,
                                   auth_code: str,
                                   code_verifier: str,
                                   scopes: Optional[list] = None) -> Optional[Dict[str, Any]]:
        """Redeem an authorization code for tokens using the PKCE verifier.

        Args:
            auth_code: the "code" query parameter from the redirect.
            code_verifier: the verifier returned by ``get_auth_url``.
            scopes: scopes to request; defaults to ``SCOPES``.

        Returns:
            The token endpoint's JSON response as a dict. When it contains an
            "id_token", the decoded claims are added under "id_token_claims".
            Returns None on any failure; the reason is printed.

        Raises:
            RuntimeError: MSAL is disabled.
        """
        if not self.enabled:
            raise RuntimeError("MSAL is disabled")

        if not code_verifier:
            print("MSAL Error: code_verifier is required for SPA")
            return None

        if scopes is None:
            scopes = SCOPES

        # Manual token request so code_verifier is sent exactly as PKCE requires.
        import requests

        token_data = {
            "client_id": self.client_id,
            "scope": " ".join(scopes),
            "code": auth_code,
            "redirect_uri": self.redirect_uri,
            "grant_type": "authorization_code",
            "code_verifier": code_verifier,
        }

        try:
            response = requests.post(
                self._endpoint("oauth2/v2.0/token"),
                data=token_data,
                timeout=self._TOKEN_REQUEST_TIMEOUT,
            )
            result = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body across requests versions.
            print(f"MSAL Error: Manual token request failed: {e}")
            return None

        if not isinstance(result, dict):
            print(f"MSAL Error: Unexpected token response (HTTP {response.status_code})")
            return None

        if response.status_code != 200:
            result.setdefault("error", "token_request_failed")

        if "error" in result:
            print(f"MSAL Error: {result.get('error_description', result.get('error'))}")
            return None

        # get_user_profile() reads "id_token_claims" (the key MSAL's own
        # acquire_token_* results use); a raw token response only has "id_token".
        if "id_token" in result and "id_token_claims" not in result:
            result["id_token_claims"] = self._decode_id_token_claims(result["id_token"])

        return result

    def get_user_profile(self, token_result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Extract basic identity fields from a token result.

        Claims come from "id_token_claims", or, if that is missing/empty, from
        decoding the raw "id_token" in the same result.

        Returns:
            dict with azure_ad_id, email, full_name, first_name, last_name and
            tenant_id (values may be None), or None when token_result is empty,
            has no "access_token", or no ID token claims are available.
        """
        if not token_result or "access_token" not in token_result:
            return None

        claims = token_result.get("id_token_claims") or self._decode_id_token_claims(
            token_result.get("id_token")
        )
        if not claims:
            return None

        return {
            "azure_ad_id": claims.get("oid"),  # Object ID - unique identifier
            "email": claims.get("email") or claims.get("preferred_username"),
            "full_name": claims.get("name"),
            "first_name": claims.get("given_name"),
            "last_name": claims.get("family_name"),
            "tenant_id": claims.get("tid"),
        }

    def validate_state(self, received_state: str, session_state: str) -> bool:
        """Check the callback's state against the one stored in the session (CSRF guard).

        Returns False if either value is missing, empty or not a string, so a
        request without state can never match a session without state. The
        comparison is constant-time to avoid leaking the expected value.
        """
        if not isinstance(received_state, str) or not isinstance(session_state, str):
            return False
        if not received_state or not session_state:
            return False
        return secrets.compare_digest(
            received_state.encode("utf-8"), session_state.encode("utf-8")
        )
|
|
|
|
def _init_global_msal_auth() -> Optional[MSALAuth]:
    """Build the shared MSALAuth instance, or return None.

    None is returned when MSAL is disabled in config, or when construction
    raises ValueError or ImportError. In the error case a warning is printed
    instead of letting the exception abort the import of this module.
    """
    if not config.is_msal_enabled():
        return None
    try:
        return MSALAuth()
    except (ValueError, ImportError) as exc:
        print(f"Warning: Could not initialize MSAL: {exc}")
        return None


# Module-wide instance, created once when this module is first imported.
msal_auth = _init_global_msal_auth()
|
|
|
|
def get_msal_instance() -> Optional[MSALAuth]:
    """Return the module-wide MSALAuth instance, or None if it was not created."""
    instance = msal_auth
    return instance
|
|
|
|
def is_msal_available() -> bool:
    """Report whether the module-wide MSALAuth instance exists and is enabled."""
    instance = msal_auth
    if instance is None:
        return False
    return instance.enabled