import os
import time

import httpx
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError

# NOTE: both default to "" when unset; presumably every token is then rejected
# (malformed JWKS URL / issuer / audience) — confirm deployments always set them.
TENANT_ID = os.environ.get("AZURE_TENANT_ID", "")
CLIENT_ID = os.environ.get("AZURE_CLIENT_ID", "")
JWKS_URL = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
ISSUER = f"https://login.microsoftonline.com/{TENANT_ID}/v2.0"

# auto_error=False so a missing header reaches get_current_user as None and
# we control the 401 response ourselves.
bearer_scheme = HTTPBearer(auto_error=False)

# Module-level JWKS cache, shared by every request handled by this process.
_jwks_cache: dict | None = None
# time.monotonic() of the last successful JWKS download; 0.0 = never fetched.
_jwks_fetched_at: float = 0.0
# Minimum spacing between forced refreshes. Signing keys rotate, so an unknown
# `kid` triggers a refresh — but without this floor any caller could force an
# outbound request to Azure on every API call just by sending a bogus `kid`.
_JWKS_MIN_REFRESH_SECONDS = 300.0


async def _get_jwks(force_refresh: bool = False) -> dict:
    """Return the tenant's JWKS document, downloading it with async HTTP if needed.

    Args:
        force_refresh: Re-download even though a cached copy exists. Ignored
            when the cache was populated less than
            ``_JWKS_MIN_REFRESH_SECONDS`` ago.

    Returns:
        The parsed JWKS JSON (expected to carry a ``"keys"`` list).

    Raises:
        httpx.HTTPError: The request failed or returned a non-2xx status.
        ValueError: The response body was not valid JSON.
    """
    global _jwks_cache, _jwks_fetched_at
    if _jwks_cache is not None:
        age = time.monotonic() - _jwks_fetched_at
        if not force_refresh or age < _JWKS_MIN_REFRESH_SECONDS:
            return _jwks_cache
    async with httpx.AsyncClient(timeout=10) as client:
        response = await client.get(JWKS_URL)
        response.raise_for_status()
        # Assigned only after a successful parse, so a failed refresh keeps
        # the previously cached (still usable) keys.
        _jwks_cache = response.json()
    _jwks_fetched_at = time.monotonic()
    return _jwks_cache


async def _load_jwks(force_refresh: bool = False) -> dict:
    """Call _get_jwks, translating fetch/parse failures into HTTP 503."""
    try:
        return await _get_jwks(force_refresh=force_refresh)
    except (httpx.HTTPError, ValueError) as e:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to fetch signing keys",
        ) from e


def _find_key(jwks: dict, kid: str) -> dict | None:
    """Return the JWK in *jwks* whose ``kid`` equals *kid*, or None."""
    return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> dict:
    """FastAPI dependency: validate an Azure AD bearer token and describe the caller.

    The token must be RS256-signed by a key in the tenant's JWKS, with
    audience ``CLIENT_ID`` and issuer ``ISSUER``.

    Returns:
        ``{"oid", "name", "email"}`` taken from the token claims; ``email``
        falls back through ``preferred_username`` -> ``upn`` -> ``email``.

    Raises:
        HTTPException(401): No credentials, malformed/invalid token, or a
            signing key id not present in the (refreshed) JWKS.
        HTTPException(503): The JWKS could not be downloaded or parsed.
    """
    # Local-development escape hatch; must never be enabled in a deployed env.
    if os.environ.get("DEV_AUTH_BYPASS", "").lower() in ("1", "true", "yes"):
        return {"oid": "dev-user", "name": "Dev User", "email": "dev@localhost"}

    if credentials is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")

    token = credentials.credentials

    # Parse the header before touching the network so garbage tokens are
    # rejected without a JWKS fetch.
    try:
        header = jwt.get_unverified_header(token)
    except JWTError as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {e}") from e

    kid = header.get("kid")
    if not kid:
        # Without a kid we cannot select a key; matching None against keys
        # lacking a kid would be meaningless.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unknown signing key")

    key = _find_key(await _load_jwks(), kid)
    if key is None:
        # Key not in cache — keys can rotate, so refresh (rate-limited) once.
        key = _find_key(await _load_jwks(force_refresh=True), kid)
    if key is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unknown signing key")

    try:
        payload = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            audience=CLIENT_ID,
            issuer=ISSUER,
            options={"verify_at_hash": False},
        )
    except JWTError as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {e}") from e

    return {
        "oid": payload.get("oid"),
        "name": payload.get("name"),
        "email": (
            payload.get("preferred_username")
            or payload.get("upn")
            or payload.get("email")
        ),
    }