agent_tracker/msal_auth.py
2025-08-17 07:23:53 -05:00

159 lines
No EOL
5.3 KiB
Python

import os
import secrets
import base64
import hashlib
from typing import Optional, Dict, Any
import config
# MSAL is an optional dependency: import it only when the feature is enabled.
# If the package is missing, fall back to ``msal = None`` so MSALAuth.__init__
# can raise its descriptive ImportError (which callers can catch) instead of
# the whole module failing at import time.
if config.is_msal_enabled():
    try:
        import msal
    except ImportError:
        msal = None
else:
    msal = None

# Default OAuth scopes requested when building the sign-in URL and when
# redeeming the authorization code.
SCOPES = ["User.Read"]
class MSALAuth:
    """Azure AD sign-in helper using MSAL's public-client auth-code flow with PKCE.

    When MSAL is disabled in ``config`` the instance is created in an inert
    state (``enabled = False``, all settings ``None``). In that state the flow
    methods ``get_auth_url`` and ``acquire_token_by_auth_code`` raise
    ``RuntimeError``.
    """

    def __init__(self):
        """Load Azure AD settings from ``config``.

        Raises:
            ImportError: MSAL is enabled but the ``msal`` package is unavailable.
            ValueError: ``client_id``, ``authority`` or ``redirect_uri`` is
                missing or empty.
        """
        if not config.is_msal_enabled():
            self.enabled = False
            self.client_id = None
            self.authority = None
            self.redirect_uri = None
            return

        if msal is None:
            raise ImportError("MSAL library not available but MSAL is enabled")

        msal_config = config.get_msal_config()
        self.enabled = True
        # Use .get() so an absent key surfaces as the ValueError below (which
        # callers catch) rather than an uncaught KeyError.
        # NOTE(review): assumes get_msal_config() returns a dict-like mapping.
        self.client_id = msal_config.get("client_id")
        self.authority = msal_config.get("authority")
        self.redirect_uri = msal_config.get("redirect_uri")

        if not all([self.client_id, self.authority, self.redirect_uri]):
            raise ValueError("Missing Azure AD configuration. Check environment variables.")

    def _build_msal_app(self, cache=None):
        """Create an MSAL *public* client application.

        No client secret is supplied; the authorization-code flow is protected
        with PKCE instead.

        Args:
            cache: Optional MSAL token cache, passed through as ``token_cache``.
        """
        return msal.PublicClientApplication(
            self.client_id,
            authority=self.authority,
            token_cache=cache,
        )

    def _generate_pkce_challenge(self, code_verifier: str) -> str:
        """Return the S256 PKCE code challenge for ``code_verifier``.

        The challenge is base64url(SHA-256(verifier)) with the trailing ``=``
        padding stripped, as required by RFC 7636 section 4.2.
        """
        digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        return base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")

    def get_auth_url(self, session_state: Optional[str] = None) -> Dict[str, Any]:
        """Build the Azure AD authorization URL with a PKCE challenge.

        Args:
            session_state: State value to embed in the request. When omitted
                (or empty), a fresh random value is generated.

        Returns:
            Dict with ``auth_url`` (where to redirect the user), ``state`` and
            ``code_verifier``. The caller must keep ``state`` and
            ``code_verifier`` server-side (e.g. in the session) for the callback.

        Raises:
            RuntimeError: MSAL is disabled.
        """
        if not self.enabled:
            raise RuntimeError("MSAL is disabled")

        app = self._build_msal_app()

        # 96 random bytes -> 128 URL-safe characters, the RFC 7636 maximum.
        code_verifier = secrets.token_urlsafe(96)
        # State parameter for CSRF protection, echoed back on the redirect.
        state = session_state or secrets.token_urlsafe(32)

        # NOTE(review): MSAL recommends initiate_auth_code_flow(), which manages
        # state and PKCE itself; the extra keyword arguments below rely on MSAL
        # forwarding them as query parameters — confirm for the pinned version.
        auth_url = app.get_authorization_request_url(
            scopes=SCOPES,
            redirect_uri=self.redirect_uri,
            state=state,
            code_challenge=self._generate_pkce_challenge(code_verifier),
            code_challenge_method="S256",
        )

        return {
            "auth_url": auth_url,
            "state": state,
            "code_verifier": code_verifier,
        }

    def acquire_token_by_auth_code(self,
                                   auth_code: str,
                                   code_verifier: str,
                                   scopes: Optional[list] = None) -> Optional[Dict[str, Any]]:
        """Redeem an authorization code for tokens, proving possession via PKCE.

        Args:
            auth_code: The ``code`` returned to the redirect URI.
            code_verifier: The verifier produced by ``get_auth_url`` for this flow.
            scopes: Scopes to request; defaults to ``SCOPES``.

        Returns:
            MSAL's token result dict, or ``None`` if the response contains an
            ``error`` (the error description is printed).

        Raises:
            RuntimeError: MSAL is disabled.
        """
        if not self.enabled:
            raise RuntimeError("MSAL is disabled")

        app = self._build_msal_app()
        if scopes is None:
            scopes = SCOPES

        # The PKCE verifier belongs in the token request body. MSAL merges
        # ``data`` into that body, whereas unrecognised keyword arguments are
        # relayed to the underlying HTTP call.
        result = app.acquire_token_by_authorization_code(
            auth_code,
            scopes=scopes,
            redirect_uri=self.redirect_uri,
            data={"code_verifier": code_verifier},
        )

        if "error" in result:
            print(f"MSAL Error: {result.get('error_description', result.get('error'))}")
            return None

        return result

    def get_user_profile(self, token_result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Extract basic user profile fields from the ID-token claims.

        Args:
            token_result: Result of ``acquire_token_by_auth_code``.

        Returns:
            Dict with ``azure_ad_id``, ``email``, ``full_name``, ``first_name``,
            ``last_name`` and ``tenant_id``; or ``None`` when ``token_result`` is
            empty, lacks an ``access_token``, or carries no ID-token claims.
        """
        if not token_result or "access_token" not in token_result:
            return None

        # No Graph API fallback: the ID token is expected to carry the basics.
        claims = token_result.get("id_token_claims") or {}
        if not claims:
            return None

        return {
            "azure_ad_id": claims.get("oid"),  # Object ID - unique identifier
            "email": claims.get("email") or claims.get("preferred_username"),
            "full_name": claims.get("name"),
            "first_name": claims.get("given_name"),
            "last_name": claims.get("family_name"),
            "tenant_id": claims.get("tid"),
        }

    def validate_state(self, received_state: Optional[str], session_state: Optional[str]) -> bool:
        """Check the callback ``state`` against the stored one (CSRF protection).

        Missing or empty values on either side are rejected; otherwise a
        lost session state plus an absent query parameter would compare equal.
        A constant-time comparison avoids leaking the expected value via timing.
        """
        if not received_state or not session_state:
            return False
        return secrets.compare_digest(
            received_state.encode("utf-8"),
            session_state.encode("utf-8"),
        )
def _build_global_msal_auth() -> Optional["MSALAuth"]:
    """Return a configured MSALAuth, or None when MSAL is disabled or misconfigured."""
    if not config.is_msal_enabled():
        return None
    try:
        return MSALAuth()
    except (ValueError, ImportError) as exc:
        print(f"Warning: Could not initialize MSAL: {exc}")
        return None


# Process-wide instance; None when MSAL is disabled or failed to initialize.
msal_auth = _build_global_msal_auth()
def get_msal_instance() -> Optional[MSALAuth]:
    """Return the shared module-level MSALAuth (None if MSAL is disabled or failed to initialize)."""
    instance = msal_auth
    return instance
def is_msal_available() -> bool:
    """Report whether a shared MSALAuth instance exists and is enabled."""
    instance = msal_auth
    if instance is None:
        return False
    return instance.enabled