solventum-image-metadata/backend/app/core/auth.py
SamoilenkoVadym 5f5c04471c feat(sso): migrate to client-side MSAL flow without client secret
Backend changes:
- Add PyJWT for Azure AD id_token validation
- Add validate_azure_id_token() function in core/auth.py
- Replace GET /microsoft/login and /microsoft/callback with a single POST /microsoft/login endpoint
- New endpoint validates id_token from frontend (no Graph API calls)
- Support PublicClientApplication (no client secret needed)

Frontend changes:
- Add @azure/msal-browser and @azure/msal-react dependencies
- Create msalConfig.ts with MSAL configuration
- Wrap App with MsalProvider
- Update LoginPage to use useMsal hook and loginPopup
- Remove OAuthCallback handler (MSAL handles redirect)
- Frontend gets id_token from Microsoft, sends to backend

Benefits:
- Works without AZURE_CLIENT_SECRET (matches apac-ops-bot)
- More secure (no client secret stored in the backend)
- Simpler backend (just JWT validation)
- Better UX (MSAL handles popups and silent token refresh)

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
2026-02-09 17:25:34 +00:00

311 lines
7.9 KiB
Python

"""
JWT Authentication
Replaces Flask session-based auth with JWT tokens + Redis refresh tokens.
"""
from datetime import datetime, timedelta, timezone
from functools import lru_cache
import os
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
# --- Password hashing ---
# Single passlib context shared by hash_password / verify_password (bcrypt only).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- JWT configuration (tokens issued by this service) ---
# NOTE(review): the fallback below is a publicly visible placeholder; if the
# SECRET_KEY environment variable is ever missing, tokens are signed with it.
# Confirm every deployment sets SECRET_KEY.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"  # symmetric HMAC: the same SECRET_KEY signs and verifies
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# --- Security scheme ---
# Extracts the "Authorization: Bearer <token>" header for FastAPI dependencies.
security = HTTPBearer()
# ===== Password Hashing =====
def hash_password(password: str) -> str:
    """
    Produce a storable hash of a plain-text password.

    Delegates to the module-level bcrypt ``pwd_context``.

    Args:
        password: The plain-text password supplied by the user.

    Returns:
        The hash string to persist in the database.
    """
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a previously stored hash.

    Delegates to the module-level ``pwd_context``.

    Args:
        plain_password: Password as typed by the user.
        hashed_password: Hash previously produced by ``hash_password``.

    Returns:
        True when the password matches the hash, otherwise False.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
# ===== JWT Token Creation =====
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token (short-lived, 30 minutes by default).

    The payload is a copy of ``data`` plus an ``exp`` claim and
    ``"type": "access"``; the caller's dict is never modified.

    Args:
        data: Payload data (typically ``{"sub": str(user_id)}``).
        expires_delta: Optional custom lifetime. ``None`` selects
            ``ACCESS_TOKEN_EXPIRE_MINUTES``; any other value, including
            ``timedelta(0)``, is used exactly as given.

    Returns:
        Encoded JWT string signed with ``SECRET_KEY`` / ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and a truthiness
    # test would silently replace an explicit zero lifetime with the default.
    if expires_delta is None:
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    else:
        lifetime = expires_delta

    # Timezone-aware "now" (datetime.utcnow() is deprecated since Python 3.12).
    # The UTC instant is unchanged, so the encoded exp is the same value.
    expire = datetime.now(timezone.utc) + lifetime

    # The "type" claim is set last so callers cannot forge a non-access token type.
    to_encode = {**data, "exp": expire, "type": "access"}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(user_id: int) -> str:
    """
    Create a signed JWT refresh token (long-lived, ``REFRESH_TOKEN_EXPIRE_DAYS``).

    Only the token is created here. Persisting it (the module docstring
    mentions Redis) is presumably the caller's job — confirm at call sites.

    Args:
        user_id: User ID from the database; stored as a string ``sub`` claim.

    Returns:
        Encoded JWT string with ``"type": "refresh"``, signed with
        ``SECRET_KEY`` / ``ALGORITHM``.
    """
    # Timezone-aware "now" (datetime.utcnow() is deprecated since Python 3.12).
    issued_at = datetime.now(timezone.utc)
    claims = {
        "sub": str(user_id),
        "exp": issued_at + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
        "type": "refresh",
    }
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
# ===== JWT Token Validation =====
def decode_token(token: str) -> dict:
    """
    Verify one of this service's JWTs and return its claims.

    The token is checked with ``SECRET_KEY`` and only ``ALGORITHM`` is
    accepted.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded payload dict.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header whenever
            python-jose rejects the token (the detail carries jose's message).
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {exc!s}",
            headers={"WWW-Authenticate": "Bearer"},
        )
def _verify_typed_token(token: str, expected_type: str) -> int:
    """
    Decode ``token``, check its ``type`` claim, and return its integer ``sub``.

    This is the shared core of verify_access_token and verify_refresh_token.
    The type check stops an access token from being accepted where a refresh
    token is required, and vice versa.

    Args:
        token: Encoded JWT issued by this module.
        expected_type: Required value of the ``type`` claim
            (``"access"`` or ``"refresh"``).

    Returns:
        The user ID from the ``sub`` claim.

    Raises:
        HTTPException: 401 if the token is invalid/expired (via decode_token),
            has the wrong type, lacks ``sub``, or has a non-integer ``sub``.
    """
    payload = decode_token(token)

    if payload.get("type") != expected_type:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        return int(user_id)
    except (TypeError, ValueError):
        # A correctly signed token can still carry a non-numeric subject
        # (create_access_token accepts arbitrary data). Previously the
        # ValueError escaped and became a 500; it is an auth failure.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
            headers={"WWW-Authenticate": "Bearer"},
        ) from None


def verify_access_token(token: str) -> int:
    """
    Verify an access token and extract the user ID.

    Args:
        token: JWT access token.

    Returns:
        user_id: User ID from the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token is invalid, not an access token, or
            carries a missing/non-integer ``sub``.
    """
    return _verify_typed_token(token, "access")


def verify_refresh_token(token: str) -> int:
    """
    Verify a refresh token and extract the user ID.

    Args:
        token: JWT refresh token.

    Returns:
        user_id: User ID from the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token is invalid, not a refresh token, or
            carries a missing/non-integer ``sub``.
    """
    return _verify_typed_token(token, "refresh")
# ===== FastAPI Dependencies =====
async def get_current_user_id(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> int:
    """
    FastAPI dependency that resolves the caller's user ID from a Bearer token.

    The module-level ``security`` (HTTPBearer) scheme pulls the token from the
    Authorization header. The token must be an access token.

    Example:
        @router.get("/protected", dependencies=[Depends(get_current_user_id)])

    Args:
        credentials: Bearer credentials injected by FastAPI.

    Returns:
        The authenticated user's ID.

    Raises:
        HTTPException: 401 when the access token does not verify.
    """
    return verify_access_token(credentials.credentials)
# ===== Helper Functions =====
def create_tokens_response(user_id: int) -> dict:
    """
    Build the login response body: a fresh access/refresh token pair.

    Args:
        user_id: User ID from the database.

    Returns:
        Dict with the following keys:

        - ``access_token`` and ``refresh_token``
        - ``token_type`` (always ``"bearer"``)
        - ``expires_in``: the access-token lifetime in seconds
    """
    lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60
    return {
        "access_token": create_access_token({"sub": str(user_id)}),
        "refresh_token": create_refresh_token(user_id),
        "token_type": "bearer",
        "expires_in": lifetime_seconds,
    }
# ===== Azure AD ID Token Validation =====
@lru_cache(maxsize=16)
def _get_azure_jwks_client(tenant_id: str):
    """
    Return a PyJWKClient for the tenant's v2.0 signing keys, reused across calls.

    PyJWKClient caches the fetched JWK set on the client instance. Building a
    new client per request therefore re-downloaded Microsoft's keys on every
    login. Keeping one client per tenant lets that cache work. PyJWKClient
    refetches the key set when it meets an unknown ``kid``, so key rotation
    is still picked up.
    """
    # PyJWT; imported locally because the module-level ``jwt`` is python-jose.
    from jwt import PyJWKClient

    jwks_url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
    return PyJWKClient(jwks_url)


def validate_azure_id_token(id_token: str, client_id: str, tenant_id: str) -> dict:
    """
    Validate an Azure AD id_token (JWT issued by Microsoft) and return its claims.

    Checks performed:

    - The RS256 signature, against the tenant's published signing keys.
    - The audience, which must equal ``client_id``.
    - The issuer, which must equal ``https://login.microsoftonline.com/{tenant_id}/v2.0``.
    - ``exp``/``iat`` must be present, and expiry is enforced.

    NOTE(review): because the issuer is built from ``tenant_id``, passing a
    multi-tenant alias such as "common" will likely fail the issuer check.
    Confirm single-tenant use is intended.

    Args:
        id_token: ID token JWT string from Azure AD.
        client_id: Azure application client ID (expected audience).
        tenant_id: Azure tenant ID.

    Returns:
        Decoded token payload with user claims (email, name, etc.).

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token is
            expired, has the wrong audience/issuer, is missing required
            claims, fails signature verification, or the signing key cannot
            be obtained.
    """
    import jwt  # PyJWT (this local name shadows the module-level python-jose ``jwt``)

    auth_headers = {"WWW-Authenticate": "Bearer"}
    try:
        # Pick the signing key matching the token header's ``kid``.
        signing_key = _get_azure_jwks_client(tenant_id).get_signing_key_from_jwt(id_token)
        return jwt.decode(
            id_token,
            signing_key.key,
            algorithms=["RS256"],
            audience=client_id,
            issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
            # Refuse tokens lacking these claims outright; without "require",
            # a signed token with no "exp" would never expire.
            options={"require": ["exp", "iat", "aud", "iss"]},
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="ID token has expired",
            headers=auth_headers,
        )
    except jwt.InvalidAudienceError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token audience (client ID mismatch)",
            headers=auth_headers,
        )
    except jwt.InvalidIssuerError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token issuer (tenant ID mismatch)",
            headers=auth_headers,
        )
    except Exception as e:
        # Boundary handler: any other failure (malformed token, missing claim,
        # bad signature, JWKS fetch error) is reported as 401.
        # NOTE(review): network failures reaching Microsoft also land here.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"ID token validation failed: {str(e)}",
            headers=auth_headers,
        ) from e