olivas/backend/app/auth.py
Vadym Samoilenko f217a5aea6 Add Azure AD SSO authentication for backend and frontend
Replace X-User-Id header auth with Azure AD JWT token validation.
Backend validates tokens via JWKS, frontend uses MSAL for login/token
acquisition. Adds logout button, 401 handling, and configurable
AZURE_AUTH_ENABLED toggle.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:41:06 +00:00

97 lines
2.6 KiB
Python

"""Azure AD JWT validation using JWKS (RS256)."""
import logging
import time
from dataclasses import dataclass
import httpx
import jwt
from app.config import settings
logger = logging.getLogger("olivas.auth")

# In-process cache of the Azure AD JSON Web Key Set (JWKS).
#   _jwks_cache      -- decoded JSON body last returned by the JWKS endpoint
#                       (empty dict until the first successful fetch).
#   _jwks_cache_time -- time.time() of that fetch (0 until the first fetch).
_JWKS_CACHE_TTL = 60 * 60  # seconds (1 hour) before cached keys are re-fetched
_jwks_cache: dict = {}
_jwks_cache_time: float = 0
@dataclass
class CurrentUser:
    """Identity of the authenticated caller, built from validated token claims."""

    oid: str  # value of the token's ``oid`` claim (required during validation)
    name: str  # value of the ``name`` claim, or "" when absent
    email: str  # ``preferred_username`` claim, else ``email``, else ""
def _get_jwks_uri() -> str:
    """Build the v2.0 JWKS (signing-keys) URL for ``settings.AZURE_TENANT_ID``."""
    return f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"
def _get_issuer() -> str:
    """Build the expected ``iss`` claim value for ``settings.AZURE_TENANT_ID``."""
    return f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/v2.0"
def refresh_jwks_cache() -> None:
    """Download the tenant's JWKS and store it in the module-level cache.

    On success ``_jwks_cache`` is replaced with the decoded JSON body and
    ``_jwks_cache_time`` is stamped with the current ``time.time()``.

    Raises:
        httpx.HTTPError: on transport failure / timeout (10 s), or a
            non-success status via ``raise_for_status``. The cache is left
            untouched in that case.
        ValueError: if the response body is not valid JSON.
    """
    global _jwks_cache, _jwks_cache_time

    jwks_uri = _get_jwks_uri()
    logger.info("Fetching JWKS from %s", jwks_uri)

    response = httpx.get(jwks_uri, timeout=10)
    response.raise_for_status()

    _jwks_cache = response.json()
    _jwks_cache_time = time.time()

    key_count = len(_jwks_cache.get("keys", []))
    logger.info("JWKS cache refreshed (%d keys)", key_count)
# Minimum number of seconds between JWKS refreshes triggered by a token whose
# ``kid`` is not in the cache. Without this floor, every request carrying an
# unknown (possibly attacker-chosen) ``kid`` would force a blocking HTTP round
# trip to Azure AD.
_JWKS_MIN_REFRESH_INTERVAL = 300


def _find_cached_signing_key(kid: str) -> "object | None":
    """Return the key whose ``kid`` matches from the cached JWKS, or None."""
    for jwk in jwt.PyJWKSet.from_dict(_jwks_cache).keys:
        if jwk.key_id == kid:
            return jwk.key
    return None


def _get_signing_key(token: str) -> object:
    """Return the public key that should verify ``token``'s signature.

    The key is selected by the ``kid`` in the token's (unverified) header from
    the cached Azure AD JWKS. The cache is refreshed when it is empty or older
    than ``_JWKS_CACHE_TTL``. On a ``kid`` miss it is refreshed at most once,
    and only if the last fetch is at least ``_JWKS_MIN_REFRESH_INTERVAL``
    seconds old. This picks up rotated keys without letting bogus tokens
    trigger a fetch on every request.

    Args:
        token: The raw encoded JWT.

    Returns:
        The ``PyJWK.key`` object of the matching JWK, suitable as the key
        argument of ``jwt.decode``.

    Raises:
        jwt.DecodeError: if the token header cannot be parsed (no network call
            is made in that case).
        jwt.InvalidTokenError: if the header has no ``kid`` or no cached key
            matches it.
        Errors from ``refresh_jwks_cache`` propagate unchanged.
    """
    # Parse the header first so malformed tokens are rejected without
    # touching the network.
    kid = jwt.get_unverified_header(token).get("kid")
    if not kid:
        raise jwt.InvalidTokenError("Token header is missing 'kid'")

    if not _jwks_cache or time.time() - _jwks_cache_time > _JWKS_CACHE_TTL:
        refresh_jwks_cache()

    key = _find_cached_signing_key(kid)
    if key is None and time.time() - _jwks_cache_time >= _JWKS_MIN_REFRESH_INTERVAL:
        # Key not found and the cache isn't brand new: Azure AD may have
        # rotated its keys since our last fetch, so refresh once.
        refresh_jwks_cache()
        key = _find_cached_signing_key(kid)

    if key is None:
        raise jwt.InvalidTokenError(f"Unable to find signing key with kid={kid}")
    return key
def validate_token(token: str) -> CurrentUser:
    """Verify an Azure AD bearer token and map its claims to a ``CurrentUser``.

    The token must be RS256-signed by a key from the tenant's JWKS, have the
    issuer returned by ``_get_issuer()``, have ``settings.AZURE_CLIENT_ID`` as
    its audience, and carry the ``exp``, ``iss``, ``aud`` and ``oid`` claims.

    Raises:
        jwt.InvalidTokenError (or a subclass): when decoding/verification fails.
        Errors raised while looking up the signing key propagate unchanged.
    """
    claims = jwt.decode(
        token,
        _get_signing_key(token),
        algorithms=["RS256"],
        audience=settings.AZURE_CLIENT_ID,
        issuer=_get_issuer(),
        options={"require": ["exp", "iss", "aud", "oid"]},
    )

    # Prefer ``preferred_username``; fall back to ``email``, then "".
    email = claims.get("preferred_username", claims.get("email", ""))
    return CurrentUser(oid=claims["oid"], name=claims.get("name", ""), email=email)