modcomms/backend/app/services/auth_service.py
michael 321a9ca820 Implement Microsoft MSAL SSO with PKCE flow
Frontend:
- Add @azure/msal-browser and @azure/msal-react packages
- Create authConfig.ts with MSAL configuration for PKCE flow
- Create authService.ts for token acquisition and user info
- Wrap App with MsalProvider in index.tsx
- Replace dummy login with real MSAL loginPopup() in Login.tsx
- Update App.tsx to use useIsAuthenticated/useMsal hooks
- Update Profile.tsx to display real user data from claims
- Update geminiService.ts to include access_token in WebSocket messages
- Update WIPReviewer.tsx to pass msalInstance for auth

Backend:
- Add python-jose and httpx dependencies for JWT verification
- Create auth_service.py with Azure AD JWKS fetching and token verification
- Create auth.py FastAPI dependency for protected REST endpoints
- Update main.py to verify tokens on WebSocket and protect /info endpoint
- Add AZURE_TENANT_ID, AZURE_CLIENT_ID, DISABLE_AUTH to config

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-16 08:43:30 -06:00

118 lines
3.6 KiB
Python

"""
Azure AD token verification service.
Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
"""
import logging
from typing import Optional
from datetime import datetime, timedelta
import httpx
from jose import jwt, JWTError
from app.config import settings
logger = logging.getLogger(__name__)

# In-process cache of Azure AD's JSON Web Key Set (JWKS). The expiry starts at
# the earliest representable (naive) datetime, so the first lookup always fetches.
_jwks_cache_expiry: datetime = datetime.min
_jwks_cache: dict = dict()
# How long a successfully fetched key set is trusted before it is re-fetched.
_JWKS_CACHE_TTL = timedelta(hours=24)
# After a failed refresh that falls back to a stale key set, wait this long
# before touching the network again. This stops an outage from adding a network
# round trip (up to the 10s timeout) to every single token verification.
_JWKS_RETRY_AFTER_FAILURE = timedelta(minutes=5)


async def get_azure_jwks() -> dict:
    """
    Fetch and cache Azure AD's public signing keys (JWKS) for token verification.

    A successfully fetched key set is cached for 24 hours to minimize network
    calls. If a refresh fails while an older key set is cached, the stale set is
    returned and the next network attempt is deferred by a few minutes.

    Returns:
        The JWKS document from the tenant's v2.0 discovery keys endpoint,
        validated to be a dict containing a "keys" list.

    Raises:
        Exception: Whatever the HTTP request or JSON decoding raised, or a
            ValueError for a payload without a "keys" list, when no previously
            cached key set is available to fall back on.
    """
    global _jwks_cache, _jwks_cache_expiry

    if _jwks_cache and datetime.utcnow() < _jwks_cache_expiry:
        return _jwks_cache

    jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()
        # Never cache (for 24h) a payload that cannot possibly verify a token.
        if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
            raise ValueError("JWKS response has no 'keys' list")
    except Exception as e:
        logger.error("Failed to fetch JWKS: %s", e)
        if _jwks_cache:  # Return stale cache if available
            # Back off instead of retrying the network on every call.
            _jwks_cache_expiry = datetime.utcnow() + _JWKS_RETRY_AFTER_FAILURE
            return _jwks_cache
        raise

    _jwks_cache = jwks
    _jwks_cache_expiry = datetime.utcnow() + _JWKS_CACHE_TTL
    logger.info("Successfully fetched Azure AD JWKS")
    return _jwks_cache
# Minimum spacing between JWKS refreshes forced by an unknown key ID. Without
# this, every request carrying a bogus "kid" would trigger an outbound fetch.
_JWKS_FORCED_REFRESH_MIN_INTERVAL = timedelta(minutes=5)
# When the last forced refresh happened (naive UTC, like the cache expiry).
_last_forced_jwks_refresh: datetime = datetime.min


def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWK in ``jwks["keys"]`` whose ``kid`` equals *kid*, or None."""
    return next((key for key in jwks.get("keys", []) if key.get("kid") == kid), None)


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify an Azure AD access token and return the claims.

    The signature is checked against the tenant's published keys (RS256 only).
    The audience, issuer and expiry are validated, and tokens that omit the
    "aud" or "exp" claim are rejected.

    NOTE(review): the expected issuer (sts.windows.net) and audience (api://...)
    match v1.0 access tokens; v2.0 tokens would be rejected. Confirm this
    against the API app registration's accessTokenAcceptedVersion.

    Args:
        token: The JWT access token from the frontend

    Returns:
        The token claims dict if valid. None if the token is invalid or if
        verification could not be completed (e.g. the key set could not be
        fetched). When ``settings.DISABLE_AUTH`` is set, a fixed
        development-user claims dict is returned without inspecting the token.
    """
    global _jwks_cache_expiry, _last_forced_jwks_refresh

    if settings.DISABLE_AUTH:
        logger.warning("Auth disabled - skipping token verification")
        return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}

    if not token:
        logger.warning("No token provided")
        return None

    try:
        # Get Azure AD public keys
        jwks = await get_azure_jwks()

        # The header is read unverified only to learn which key signed the token;
        # the signature itself is verified by jwt.decode below.
        kid = jwt.get_unverified_header(token).get("kid")
        if not kid:
            logger.warning("No key ID in token header")
            return None

        rsa_key = _find_signing_key(jwks, kid)
        if not rsa_key:
            now = datetime.utcnow()
            if now - _last_forced_jwks_refresh < _JWKS_FORCED_REFRESH_MIN_INTERVAL:
                # A refresh was forced recently; don't let unknown kids hammer Azure AD.
                logger.warning("Key ID %s not found in JWKS; forced refresh throttled", kid)
                return None

            logger.warning("Key ID %s not found in JWKS, refreshing cache", kid)
            _last_forced_jwks_refresh = now
            # Expire the cache so get_azure_jwks() re-fetches, in case keys rotated.
            _jwks_cache_expiry = datetime.min
            jwks = await get_azure_jwks()
            rsa_key = _find_signing_key(jwks, kid)
            if not rsa_key:
                logger.error("Could not find matching key after refresh")
                return None

        # Verify and decode the token
        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],
            audience=f"api://{settings.AZURE_CLIENT_ID}",
            issuer=f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/",
            # python-jose skips the aud/exp checks entirely when those claims are
            # absent; require them so such tokens are rejected instead.
            options={"require_aud": True, "require_exp": True},
        )
    except JWTError as e:
        logger.warning("JWT verification failed: %s", e)
        return None
    except Exception as e:
        # Unexpected failure (e.g. key set unavailable); keep the traceback.
        logger.exception("Token verification error: %s", e)
        return None

    logger.info("Token verified for user: %s", claims.get("name", "unknown"))
    return claims