Backend changes:
- Add PyJWT for Azure AD id_token validation
- Add validate_azure_id_token() function in core/auth.py
- Replace /microsoft/login and /microsoft/callback with /microsoft/login POST
- New endpoint validates id_token from frontend (no Graph API calls)
- Support PublicClientApplication (no client secret needed)

Frontend changes:
- Add @azure/msal-browser and @azure/msal-react dependencies
- Create msalConfig.ts with MSAL configuration
- Wrap App with MsalProvider
- Update LoginPage to use useMsal hook and loginPopup
- Remove OAuthCallback handler (MSAL handles redirect)
- Frontend gets id_token from Microsoft, sends to backend

Benefits:
- ✅ Works without AZURE_CLIENT_SECRET (matches apac-ops-bot)
- ✅ More secure (no secret in backend)
- ✅ Simpler backend (just JWT validation)
- ✅ Better UX (MSAL handles popups, silent refresh)

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
311 lines
7.9 KiB
Python
311 lines
7.9 KiB
Python
"""
|
|
JWT Authentication
|
|
Replaces Flask session-based auth with JWT tokens + Redis refresh tokens.
|
|
"""
|
|
|
|
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
|
|
|
|
# Password hashing: passlib context limited to the bcrypt scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT configuration.
# NOTE(review): when the SECRET_KEY environment variable is unset, the
# hard-coded placeholder below becomes the signing key — presumably real
# deployments always set the variable; confirm.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # access-token lifetime, minutes
REFRESH_TOKEN_EXPIRE_DAYS = 7  # refresh-token lifetime, days

# Security scheme used by the FastAPI dependency further down in this module.
security = HTTPBearer()
|
|
|
|
|
|
# ===== Password Hashing =====
|
|
|
|
def hash_password(password: str) -> str:
    """
    Hash a plain-text password with the module-level passlib context.

    Args:
        password: Plain text password

    Returns:
        The hash string produced by ``pwd_context`` (configured for bcrypt)
    """
    hashed = pwd_context.hash(password)
    return hashed
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a previously stored hash.

    The comparison is delegated to the module-level passlib context.

    Args:
        plain_password: Plain text password
        hashed_password: Hashed password from database

    Returns:
        True if password matches, False otherwise
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
|
|
|
# ===== JWT Token Creation =====
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    The claims in ``data`` are copied into a new mapping, then ``exp`` and
    ``type`` ("access") are set on the copy, overriding any same-named keys
    supplied by the caller. ``data`` itself is not mutated.

    Args:
        data: Payload data (typically {"sub": user_id})
        expires_delta: Optional custom lifetime. ``None`` selects the default
            of ACCESS_TOKEN_EXPIRE_MINUTES minutes; any explicit timedelta,
            including a zero-length one, is honoured as given.

    Returns:
        JWT token string
    """
    # Explicit None check: timedelta(0) is falsy, so a truthiness test would
    # silently swap a requested zero lifetime for the default one.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {**data, "exp": expire, "type": "access"}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
def create_refresh_token(user_id: int) -> str:
    """
    Create a signed JWT refresh token valid for REFRESH_TOKEN_EXPIRE_DAYS days.

    This function only builds and signs the token; it performs no storage
    itself. Persisting the token (the module docstring mentions Redis) is
    left to the caller.

    Args:
        user_id: User ID from database; stored as a string in the ``sub`` claim

    Returns:
        JWT refresh token string
    """
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    to_encode = {
        "sub": str(user_id),
        "exp": expire,
        "type": "refresh",
    }
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
|
|
|
# ===== JWT Token Validation =====
|
|
|
|
def decode_token(token: str) -> dict:
    """
    Decode a JWT signed with this module's SECRET_KEY and ALGORITHM.

    Args:
        token: JWT token string

    Returns:
        Decoded payload

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            JWT library rejects the token (any ``JWTError``)
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        # Translate the library error into an HTTP 401 for API callers.
        unauthorized = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {exc!s}",
            headers={"WWW-Authenticate": "Bearer"},
        )
        raise unauthorized
|
|
|
|
|
|
def verify_access_token(token: str) -> int:
    """
    Verify access token and extract user ID.

    Args:
        token: JWT access token

    Returns:
        user_id: User ID taken from the ``sub`` claim, converted to int

    Raises:
        HTTPException: 401 if the token fails decoding, its ``type`` claim is
            not "access", or its ``sub`` claim is missing or not an integer
    """
    payload = decode_token(token)

    # Check token type so a refresh token cannot be used as an access token
    if payload.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Extract user ID. int(None) raises TypeError and a non-numeric string
    # raises ValueError; both mean the payload is unusable and must surface
    # as a 401 rather than an unhandled server error.
    try:
        return int(payload.get("sub"))
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
|
|
|
|
|
|
def verify_refresh_token(token: str) -> int:
    """
    Verify refresh token and extract user ID.

    Args:
        token: JWT refresh token

    Returns:
        user_id: User ID taken from the ``sub`` claim, converted to int

    Raises:
        HTTPException: 401 if the token fails decoding, its ``type`` claim is
            not "refresh", or its ``sub`` claim is missing or not an integer
    """
    payload = decode_token(token)

    # Check token type so an access token cannot be used as a refresh token
    if payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Extract user ID. int(None) raises TypeError and a non-numeric string
    # raises ValueError; both mean the payload is unusable and must surface
    # as a 401 rather than an unhandled server error.
    try:
        return int(payload.get("sub"))
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
|
|
|
|
|
|
# ===== FastAPI Dependencies =====
|
|
|
|
async def get_current_user_id(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> int:
    """
    FastAPI dependency that resolves the caller's user ID from a bearer token.

    Example of protecting an endpoint:
    @router.get("/protected", dependencies=[Depends(get_current_user_id)])

    Args:
        credentials: HTTP Bearer credentials supplied by the ``security`` scheme

    Returns:
        user_id: Current user's ID

    Raises:
        HTTPException: If the access token is rejected by verify_access_token
    """
    # The raw token string is carried in the ``credentials`` attribute.
    return verify_access_token(credentials.credentials)
|
|
|
|
|
|
# ===== Helper Functions =====
|
|
|
|
def create_tokens_response(user_id: int) -> dict:
    """
    Build the token payload returned after a successful login.

    Args:
        user_id: User ID from database

    Returns:
        Dict with access_token, refresh_token, token_type ("bearer") and
        expires_in (access-token lifetime in seconds)
    """
    subject = str(user_id)
    return {
        "access_token": create_access_token({"sub": subject}),
        "refresh_token": create_refresh_token(user_id),
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,  # seconds
    }
|
|
|
|
|
|
# ===== Azure AD ID Token Validation =====
|
|
|
|
# PyJWKClient instances keyed by tenant ID. Reusing one client per tenant lets
# the client's own key caching take effect instead of starting from an empty
# client on every call.
# NOTE(review): assumes tenant_id comes from a small, trusted set (application
# configuration); the dict is unbounded otherwise — confirm.
_AZURE_JWKS_CLIENTS: dict = {}


def validate_azure_id_token(id_token: str, client_id: str, tenant_id: str) -> dict:
    """
    Validate Azure AD id_token (JWT from Microsoft).

    The signing key is looked up from the tenant's published key set at
    login.microsoftonline.com, then the token is decoded with RS256 while
    checking the audience against ``client_id`` and the issuer against the
    tenant's v2.0 issuer URL. The ``exp``, ``aud`` and ``iss`` claims must be
    present.

    Args:
        id_token: ID token JWT string from Azure AD
        client_id: Azure application client ID (audience)
        tenant_id: Azure tenant ID

    Returns:
        Decoded token payload (the token's claims)

    Raises:
        HTTPException: 401 if the token is expired, has the wrong audience or
            issuer, or fails validation for any other reason (including a
            failure to obtain the signing key)
    """
    # Local imports on purpose: PyJWT's package is also called ``jwt`` and
    # must not replace the module-level ``jwt`` imported from jose.
    import jwt
    from jwt import PyJWKClient

    try:
        # Get (or lazily create) the key-set client for this tenant
        jwks_client = _AZURE_JWKS_CLIENTS.get(tenant_id)
        if jwks_client is None:
            jwks_url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
            jwks_client = PyJWKClient(jwks_url)
            _AZURE_JWKS_CLIENTS[tenant_id] = jwks_client

        # Get the signing key referenced by the JWT header
        signing_key = jwks_client.get_signing_key_from_jwt(id_token)

        # Decode and validate the token. "require" makes a token that lacks
        # an expiry (or audience/issuer) fail instead of skipping that check.
        return jwt.decode(
            id_token,
            signing_key.key,
            algorithms=["RS256"],
            audience=client_id,
            issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
            options={"require": ["exp", "aud", "iss"]},
        )

    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="ID token has expired"
        ) from exc
    except jwt.InvalidAudienceError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token audience (client ID mismatch)"
        ) from exc
    except jwt.InvalidIssuerError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token issuer (tenant ID mismatch)"
        ) from exc
    except Exception as exc:
        # Boundary handler: any other failure (malformed token, missing
        # claim, key lookup error) is reported to the client as a 401.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"ID token validation failed: {str(exc)}"
        ) from exc
|