Frontend:
- Add @azure/msal-browser and @azure/msal-react packages
- Create authConfig.ts with MSAL configuration for PKCE flow
- Create authService.ts for token acquisition and user info
- Wrap App with MsalProvider in index.tsx
- Replace dummy login with real MSAL loginPopup() in Login.tsx
- Update App.tsx to use useIsAuthenticated/useMsal hooks
- Update Profile.tsx to display real user data from claims
- Update geminiService.ts to include access_token in WebSocket messages
- Update WIPReviewer.tsx to pass msalInstance for auth

Backend:
- Add python-jose and httpx dependencies for JWT verification
- Create auth_service.py with Azure AD JWKS fetching and token verification
- Create auth.py FastAPI dependency for protected REST endpoints
- Update main.py to verify tokens on WebSocket and protect /info endpoint
- Add AZURE_TENANT_ID, AZURE_CLIENT_ID, DISABLE_AUTH to config

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
118 lines
3.6 KiB
Python
118 lines
3.6 KiB
Python
"""
Azure AD token verification service.

Validates JWT access tokens sent by the frontend against Azure AD's public
signing keys (JWKS).
"""
|
|
import logging
|
|
from typing import Optional
|
|
from datetime import datetime, timedelta
|
|
|
|
import httpx
|
|
from jose import jwt, JWTError
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(name=__name__)

# Process-wide JWKS cache: the last key set fetched from Azure AD, and the
# naive-UTC instant (compared against datetime.utcnow()) after which it must
# be re-fetched. An empty dict / datetime.min both mean "nothing cached".
_jwks_cache: dict = dict()
_jwks_cache_expiry: datetime = datetime.min
|
|
|
|
|
|
# How long a successfully fetched key set is trusted before re-fetching.
_JWKS_CACHE_TTL = timedelta(hours=24)
# After a failed fetch we keep serving the stale key set and wait this long
# before contacting Azure AD again, so an outage does not add a network
# timeout to every token verification.
_JWKS_RETRY_BACKOFF = timedelta(minutes=5)


async def get_azure_jwks() -> dict:
    """
    Fetch and cache Azure AD's public signing keys (JWKS) for token verification.

    A successful fetch is cached for 24 hours to minimize network calls. If a
    fetch fails while a previously fetched key set is cached, that stale set
    is returned and the next fetch attempt is deferred by a short backoff.

    Returns:
        The JWKS document: a dict containing a "keys" list.

    Raises:
        Exception: the error raised by the fetch (network/HTTP error, or
            ValueError for a malformed response) when nothing is cached yet.
    """
    global _jwks_cache, _jwks_cache_expiry

    if _jwks_cache and datetime.utcnow() < _jwks_cache_expiry:
        return _jwks_cache

    jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(jwks_url, timeout=10.0)
            response.raise_for_status()
            jwks = response.json()
        # Validate before touching the cache so a bad response never replaces
        # a good key set.
        if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
            raise ValueError("JWKS response has no 'keys' list")
    except Exception as e:
        logger.error("Failed to fetch JWKS: %s", e)
        if _jwks_cache:  # Return stale cache if available
            # Back off instead of retrying (and possibly timing out) on
            # every subsequent call.
            _jwks_cache_expiry = datetime.utcnow() + _JWKS_RETRY_BACKOFF
            return _jwks_cache
        raise

    _jwks_cache = jwks
    _jwks_cache_expiry = datetime.utcnow() + _JWKS_CACHE_TTL
    logger.info("Successfully fetched Azure AD JWKS")
    return _jwks_cache
|
|
|
|
|
|
# Minimum spacing between JWKS refreshes triggered by an unknown "kid". The
# kid is read from the *unverified* token header, so it is caller-controlled;
# without this limit every forged token would cause an outbound request to
# Azure AD.
_JWKS_FORCED_REFRESH_INTERVAL = timedelta(minutes=5)

# Naive-UTC time of the last kid-triggered refresh (datetime.min = never).
_jwks_last_forced_refresh: datetime = datetime.min


def _find_signing_key(jwks: dict, kid: str) -> Optional[dict]:
    """Return the JWK in ``jwks["keys"]`` whose ``kid`` equals *kid*, or None."""
    return next(
        (key for key in jwks.get("keys", []) if key.get("kid") == kid),
        None,
    )


def _accepted_issuers() -> list:
    """Return the issuer values Azure AD uses for this tenant.

    v1.0 access tokens carry ``https://sts.windows.net/{tenant}/`` and v2.0
    tokens ``https://login.microsoftonline.com/{tenant}/v2.0``. Which format
    the API receives is decided by the app registration's
    ``accessTokenAcceptedVersion`` setting, not by the MSAL endpoint used.
    """
    tenant = settings.AZURE_TENANT_ID
    return [
        f"https://sts.windows.net/{tenant}/",
        f"https://login.microsoftonline.com/{tenant}/v2.0",
    ]


def _has_accepted_audience(claims: dict) -> bool:
    """Return True if the token's ``aud`` claim names this API.

    v1.0 tokens use the App ID URI (``api://{client_id}``), v2.0 tokens the
    bare client ID. A missing or malformed ``aud`` is rejected (python-jose
    itself lets a token without ``aud`` pass).
    """
    client_id = settings.AZURE_CLIENT_ID
    accepted = {f"api://{client_id}", client_id}
    aud = claims.get("aud")
    if isinstance(aud, str):
        return aud in accepted
    if isinstance(aud, list):
        return any(isinstance(item, str) and item in accepted for item in aud)
    return False


async def verify_access_token(token: str) -> Optional[dict]:
    """
    Verify an Azure AD access token and return the claims.

    The token must be RS256-signed by a key from the tenant's JWKS, be
    unexpired, and carry one of this tenant's issuers and this API's
    audience; both v1.0 and v2.0 token formats are accepted. If the token's
    key ID is missing from the cached JWKS, the key set is re-fetched (at
    most once per _JWKS_FORCED_REFRESH_INTERVAL) to pick up rotated keys.

    Args:
        token: The JWT access token from the frontend

    Returns:
        The token claims dict if valid, None if invalid. When
        settings.DISABLE_AUTH is set, fixed development-user claims are
        returned without inspecting the token.
    """
    global _jwks_cache_expiry, _jwks_last_forced_refresh

    if settings.DISABLE_AUTH:
        logger.warning("Auth disabled - skipping token verification")
        return {
            "sub": "dev-user",
            "name": "Development User",
            "preferred_username": "dev@localhost",
        }

    if not token:
        logger.warning("No token provided")
        return None

    try:
        jwks = await get_azure_jwks()

        # The header is read unverified only to choose the signing key; the
        # signature itself is verified by jwt.decode below.
        kid = jwt.get_unverified_header(token).get("kid")
        if not kid:
            logger.warning("No key ID in token header")
            return None

        rsa_key = _find_signing_key(jwks, kid)
        if rsa_key is None:
            now = datetime.utcnow()
            if now - _jwks_last_forced_refresh < _JWKS_FORCED_REFRESH_INTERVAL:
                logger.warning("Key ID %s not found in JWKS; refresh rate-limited", kid)
                return None
            logger.warning("Key ID %s not found in JWKS, refreshing cache", kid)
            # Keys may have rotated: expire the cache so get_azure_jwks
            # fetches a fresh key set.
            _jwks_last_forced_refresh = now
            _jwks_cache_expiry = datetime.min
            jwks = await get_azure_jwks()
            rsa_key = _find_signing_key(jwks, kid)
            if rsa_key is None:
                logger.error("Could not find matching key after refresh")
                return None

        claims = jwt.decode(
            token,
            rsa_key,
            algorithms=["RS256"],  # pinned; never trust the header's "alg"
            issuer=_accepted_issuers(),
            # python-jose accepts only a single audience string, so the
            # audience is checked explicitly below for both token formats.
            options={"verify_aud": False},
        )

        if not _has_accepted_audience(claims):
            logger.warning("JWT verification failed: Invalid audience")
            return None

        # Debug level: this may run on every request/message and logs a
        # user name.
        logger.debug("Token verified for user: %s", claims.get("name", "unknown"))
        return claims

    except JWTError as e:
        logger.warning("JWT verification failed: %s", e)
        return None
    except Exception as e:
        logger.exception("Token verification error: %s", e)
        return None
|