Access tokens issued for the User.Read scope have audience (aud) graph.microsoft.com, but the backend validates aud=CLIENT_ID. ID tokens always have aud=CLIENT_ID, so they validate correctly. Also add a upn claim fallback when extracting the email from the ID token.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
import os
|
|
import httpx
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from jose import jwt, JWTError
|
|
|
|
# Azure AD tenant and app registration, read from the environment.
# Both fall back to "" when the variable is unset.
TENANT_ID = os.getenv("AZURE_TENANT_ID", "")
CLIENT_ID = os.getenv("AZURE_CLIENT_ID", "")

# v2.0 endpoints for this tenant: where the signing keys are published, and
# the issuer string that tokens from this tenant must carry.
_authority = f"https://login.microsoftonline.com/{TENANT_ID}"
JWKS_URL = f"{_authority}/discovery/v2.0/keys"
ISSUER = f"{_authority}/v2.0"
del _authority  # only needed to build the two URLs above
|
|
|
|
# Bearer-token extractor with auto_error disabled: a request without bearer
# credentials yields None rather than an automatic error, so get_current_user
# decides how to respond (including the DEV_AUTH_BYPASS path).
bearer_scheme: HTTPBearer = HTTPBearer(auto_error=False)
|
|
|
|
# Module-level cache — populated once per process, never blocks the event loop
|
|
_jwks_cache: dict | None = None
|
|
|
|
|
|
async def _get_jwks(force_refresh: bool = False) -> dict:
    """Return the tenant's JSON Web Key Set, downloading it if not cached.

    The parsed document is kept in the module-level ``_jwks_cache`` and reused
    until it is invalidated, either by setting ``_jwks_cache = None`` or by
    passing ``force_refresh=True``, so rotated keys can be picked up. The
    download uses ``httpx.AsyncClient``, so it does not block the event loop.
    Concurrent callers that miss the cache may each perform a download; there
    is no lock.

    Args:
        force_refresh: Ignore any cached document and download it again.

    Returns:
        The parsed JWKS document, guaranteed to be a dict with a ``"keys"`` list.

    Raises:
        httpx.HTTPError: The request failed or returned a non-2xx status.
        ValueError: The body is not valid JSON, or is not a JWKS document
            (no ``"keys"`` list). Nothing is cached in that case.
    """
    global _jwks_cache

    if _jwks_cache is not None and not force_refresh:
        return _jwks_cache

    async with httpx.AsyncClient(timeout=10) as client:
        response = await client.get(JWKS_URL)
        response.raise_for_status()
        # Invalid JSON raises json.JSONDecodeError, a ValueError subclass.
        document = response.json()

    # Validate before caching: a bad document cached here would never be
    # replaced, because callers index document["keys"] before they ever
    # reach the unknown-kid refresh path.
    if not isinstance(document, dict) or not isinstance(document.get("keys"), list):
        raise ValueError(f"Unexpected JWKS document from {JWKS_URL}")

    _jwks_cache = document
    return _jwks_cache
|
|
|
|
|
|
def _find_signing_key(jwks: dict, kid: str) -> dict | None:
    """Return the entry of ``jwks["keys"]`` whose ``kid`` equals *kid*, or None."""
    return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` challenge (RFC 6750)."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def _resolve_signing_key(kid: str) -> dict:
    """Look up the JWK for *kid*, re-downloading the key set once on a miss.

    Raises:
        HTTPException: 503 if the key set cannot be downloaded or parsed;
            401 if *kid* is still unknown after the refresh.
    """
    global _jwks_cache

    try:
        key = _find_signing_key(await _get_jwks(), kid)
        if key is None:
            # Key not in cache: keys can rotate, so drop the cached set and
            # fetch it once more.
            # NOTE(review): kid comes from the unverified token header, so a
            # client can force one JWKS download per request with random
            # kids; consider rate-limiting refreshes.
            _jwks_cache = None
            key = _find_signing_key(await _get_jwks(), kid)
    except (httpx.HTTPError, ValueError) as exc:
        # Network/HTTP failure or an unparsable key set. The problem is on
        # the server side, not the caller's token.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to fetch signing keys",
        ) from exc

    if key is None:
        raise _unauthorized("Unknown signing key")
    return key


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> dict:
    """FastAPI dependency that authenticates the request's bearer token.

    The token must be an RS256 JWT signed by one of the tenant's published
    keys, with ``aud`` equal to ``CLIENT_ID`` and ``iss`` equal to ``ISSUER``.
    In practice this means an ID token for this app. Access tokens minted for
    other resources (e.g. Microsoft Graph) carry a different audience and are
    rejected.

    Returns:
        A dict with keys ``"oid"``, ``"name"`` and ``"email"`` taken from the
        token claims. ``email`` falls back from ``preferred_username`` to
        ``upn`` to ``email``. Any value may be None if its claims are absent.

    Raises:
        HTTPException: 401 when the token is missing, malformed, names an
            unknown signing key, or fails signature/claim validation; 503
            when the signing keys cannot be fetched.
    """
    # WARNING: a truthy DEV_AUTH_BYPASS disables authentication entirely and
    # returns a fixed identity. It must never be set in production.
    if os.environ.get("DEV_AUTH_BYPASS", "").lower() in ("1", "true", "yes"):
        return {"oid": "dev-user", "name": "Dev User", "email": "dev@localhost"}

    if credentials is None:
        raise _unauthorized("Not authenticated")

    token = credentials.credentials
    try:
        kid = jwt.get_unverified_header(token).get("kid")
        if not kid:
            # Without a key id no published key can match. Reject now
            # instead of triggering a pointless JWKS refresh.
            raise _unauthorized("Unknown signing key")
        key = await _resolve_signing_key(kid)

        payload = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            audience=CLIENT_ID,
            issuer=ISSUER,
            # Presumably disabled because at_hash can only be checked against
            # an access token, and none is passed here.
            options={"verify_at_hash": False},
        )
    except JWTError as e:
        raise _unauthorized(f"Invalid token: {e}") from e

    return {
        "oid": payload.get("oid"),
        "name": payload.get("name"),
        "email": (
            payload.get("preferred_username")
            or payload.get("upn")
            or payload.get("email")
        ),
    }
|