agent_tracker/msal_auth.py
2025-09-05 13:41:33 -05:00

191 lines
No EOL
6.6 KiB
Python

import os
import secrets
import base64
import hashlib
from typing import Optional, Dict, Any
import config
# msal is an optional dependency: import it only when MSAL auth is enabled.
# If MSAL is enabled but the package is missing, leave ``msal = None`` so that
# MSALAuth.__init__ raises its descriptive ImportError, which the module-level
# initializer reports as a warning instead of crashing this module's import.
if config.is_msal_enabled():
    try:
        import msal
    except ImportError:
        msal = None
else:
    msal = None

# OpenID Connect scopes requested by default (sign-in plus basic profile/email).
SCOPES = ["openid", "profile", "email"]
class MSALAuth:
    """Azure AD (Microsoft identity platform) sign-in helper for a SPA-style flow.

    Builds authorization URLs carrying a PKCE challenge, redeems authorization
    codes at the v2.0 token endpoint, and extracts a user profile from the
    result. Settings come from ``config.get_msal_config()``. When
    ``config.is_msal_enabled()`` is false the instance is inert: ``enabled``
    is False and the flow methods raise ``RuntimeError``.
    """

    def __init__(self):
        if not config.is_msal_enabled():
            self.enabled = False
            self.client_id = None
            self.authority = None
            self.redirect_uri = None
            return
        if msal is None:
            raise ImportError("MSAL library not available but MSAL is enabled")
        msal_config = config.get_msal_config()
        self.enabled = True
        self.client_id = msal_config["client_id"]
        self.authority = msal_config["authority"]
        self.redirect_uri = msal_config["redirect_uri"]
        if not all([self.client_id, self.authority, self.redirect_uri]):
            raise ValueError("Missing Azure AD configuration. Check environment variables.")

    def _build_msal_app(self, cache=None):
        """Create an MSAL public client application for this client/authority.

        The authorization-code methods of this class talk to the endpoints
        directly and do not use it.
        """
        return msal.PublicClientApplication(
            self.client_id,
            authority=self.authority,
            token_cache=cache,
        )

    def _endpoint(self, path: str) -> str:
        """Join ``path`` onto the configured authority without doubling '/'."""
        return f"{self.authority.rstrip('/')}/{path}"

    def _generate_pkce_challenge(self, code_verifier: str) -> str:
        """Return the RFC 7636 S256 code challenge for ``code_verifier``.

        Computed as BASE64URL(SHA256(code_verifier)) with '=' padding removed.
        """
        digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")

    def get_auth_url(self, session_state: Optional[str] = None,
                     scopes: Optional[list] = None) -> Dict[str, Any]:
        """Build the authorization URL for the auth-code + PKCE flow.

        Args:
            session_state: State value to round-trip through the redirect.
                A fresh random value is generated when omitted or empty.
            scopes: Scopes to request; defaults to the module's ``SCOPES``.

        Returns:
            dict with ``auth_url``, ``state`` and ``code_verifier``. The caller
            must keep ``state`` and ``code_verifier`` (e.g. in the server-side
            session) for ``validate_state`` and ``acquire_token_by_auth_code``.

        Raises:
            RuntimeError: if MSAL is disabled.
        """
        if not self.enabled:
            raise RuntimeError("MSAL is disabled")
        if scopes is None:
            scopes = SCOPES
        # token_urlsafe(96) yields 128 URL-safe characters, the maximum
        # verifier length RFC 7636 allows (43-128).
        code_verifier = secrets.token_urlsafe(96)
        # Random state protects the redirect against CSRF.
        state = session_state or secrets.token_urlsafe(32)
        code_challenge = self._generate_pkce_challenge(code_verifier)
        # Build the URL by hand so the PKCE parameters are guaranteed present.
        import urllib.parse
        auth_params = {
            "client_id": self.client_id,
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "scope": " ".join(scopes),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "response_mode": "query",
        }
        auth_url = (self._endpoint("oauth2/v2.0/authorize") + "?"
                    + urllib.parse.urlencode(auth_params))
        return {
            "auth_url": auth_url,
            "state": state,
            "code_verifier": code_verifier,
        }

    def acquire_token_by_auth_code(self,
                                   auth_code: str,
                                   code_verifier: str,
                                   scopes: Optional[list] = None,
                                   *,
                                   timeout: float = 30.0) -> Optional[Dict[str, Any]]:
        """Redeem an authorization code at the v2.0 token endpoint using PKCE.

        Args:
            auth_code: The ``code`` query parameter from the redirect.
            code_verifier: The verifier ``get_auth_url`` returned for this login.
            scopes: Scopes to request; defaults to the module's ``SCOPES``.
            timeout: Seconds to wait for the token endpoint before giving up.

        Returns:
            The decoded JSON token response (raw endpoint shape, e.g. with an
            ``id_token`` JWT string), or None on any failure. The failure
            reason is printed.

        Raises:
            RuntimeError: if MSAL is disabled.
        """
        if not self.enabled:
            raise RuntimeError("MSAL is disabled")
        if not code_verifier:
            print("MSAL Error: code_verifier is required for SPA")
            return None
        if scopes is None:
            scopes = SCOPES
        # Request the token directly (not via msal) so code_verifier is sent.
        import requests
        token_data = {
            "client_id": self.client_id,
            "scope": " ".join(scopes),
            "code": auth_code,
            "redirect_uri": self.redirect_uri,
            "grant_type": "authorization_code",
            "code_verifier": code_verifier,
        }
        try:
            response = requests.post(self._endpoint("oauth2/v2.0/token"),
                                     data=token_data, timeout=timeout)
            result = response.json()
        except (requests.RequestException, ValueError) as e:
            # Network/timeout errors, or a body that is not valid JSON.
            print(f"MSAL Error: Manual token request failed: {e}")
            return None
        if not isinstance(result, dict):
            print("MSAL Error: Unexpected token response format")
            return None
        if response.status_code != 200:
            result.setdefault("error", "token_request_failed")
        if "error" in result:
            print(f"MSAL Error: {result.get('error_description', result.get('error'))}")
            return None
        return result

    @staticmethod
    def _decode_id_token_claims(id_token: Any) -> Dict[str, Any]:
        """Return the payload claims of a compact JWS ID token, or {} if unparsable.

        The signature is NOT verified. That is acceptable only for a token
        received directly from the token endpoint over TLS (OpenID Connect
        Core 1.0, section 3.1.3.7). Do not use it on tokens from other sources.
        """
        import json
        if not isinstance(id_token, str):
            return {}
        parts = id_token.split(".")
        if len(parts) != 3:
            return {}
        payload = parts[1]
        # base64url in JWTs omits padding; restore it before decoding.
        payload += "=" * (-len(payload) % 4)
        try:
            claims = json.loads(base64.urlsafe_b64decode(payload))
        except (ValueError, TypeError):
            # binascii.Error, UnicodeDecodeError and JSONDecodeError are all ValueErrors.
            return {}
        return claims if isinstance(claims, dict) else {}

    def get_user_profile(self, token_result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build a user profile dict from a token result's ID-token claims.

        Uses ``id_token_claims`` when present (as in msal-produced results).
        Otherwise it decodes the ``id_token`` JWT from the raw token-endpoint
        response, which is what ``acquire_token_by_auth_code`` returns.

        Returns:
            dict with azure_ad_id, email, full_name, first_name, last_name and
            tenant_id, or None when there is no access token or no usable
            claims.
        """
        if not token_result or "access_token" not in token_result:
            return None
        claims = (token_result.get("id_token_claims")
                  or self._decode_id_token_claims(token_result.get("id_token")))
        if not claims:
            return None
        return {
            "azure_ad_id": claims.get("oid"),  # Object ID - unique per user in the tenant
            "email": claims.get("email") or claims.get("preferred_username"),
            "full_name": claims.get("name"),
            "first_name": claims.get("given_name"),
            "last_name": claims.get("family_name"),
            "tenant_id": claims.get("tid"),
        }

    def validate_state(self, received_state: str, session_state: str) -> bool:
        """Check the redirect's ``state`` against the one stored for the session.

        Returns False when either value is missing, empty or not a string, so a
        request without state can never match a session that never stored one.
        The comparison is constant-time.
        """
        if not isinstance(received_state, str) or not isinstance(session_state, str):
            return False
        if not received_state or not session_state:
            return False
        return secrets.compare_digest(received_state.encode("utf-8"),
                                      session_state.encode("utf-8"))
# Process-wide MSALAuth singleton. It is created only when MSAL is enabled and
# stays None if construction fails (missing library or configuration).
msal_auth: Optional[MSALAuth] = None
if config.is_msal_enabled():
    try:
        msal_auth = MSALAuth()
    except (ImportError, ValueError) as exc:
        # Degrade gracefully: report the problem and leave the singleton unset.
        print(f"Warning: Could not initialize MSAL: {exc}")
def get_msal_instance() -> Optional[MSALAuth]:
    """Return the shared MSALAuth singleton, or None if it was never created."""
    instance = msal_auth
    return instance
def is_msal_available() -> bool:
    """Report whether the MSAL singleton exists and is enabled."""
    if msal_auth is None:
        return False
    return msal_auth.enabled