gmal-scope-builder/backend/app/middleware/auth.py
Vadym Samoilenko b7db37828b Fix 401: send ID token instead of Graph access token
Access tokens for User.Read scope have audience=graph.microsoft.com,
but the backend validates audience=CLIENT_ID. ID tokens always have
audience=CLIENT_ID so they validate correctly.

Also add upn claim fallback for email extraction from ID token.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-30 11:16:44 +01:00

78 lines
2.6 KiB
Python

import os
import httpx
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError
# Azure AD tenant and app-registration identifiers, read once at import time.
# Both fall back to "" when the environment variable is unset; no further
# validation is done here.
TENANT_ID: str = os.getenv("AZURE_TENANT_ID", "")
CLIENT_ID: str = os.getenv("AZURE_CLIENT_ID", "")

# Tenant-specific v2.0 endpoints: where the signing keys (JWKS) are published,
# and the value the token's `iss` claim must equal.
JWKS_URL: str = f"https://login.microsoftonline.com/{TENANT_ID}/discovery/v2.0/keys"
ISSUER: str = f"https://login.microsoftonline.com/{TENANT_ID}/v2.0"

# auto_error=False: a missing Authorization header yields None rather than an
# automatic error, so get_current_user decides how to respond.
bearer_scheme = HTTPBearer(auto_error=False)

# Process-wide JWKS cache; stays None until the first successful fetch and is
# filled by an async HTTP call, so the event loop is never blocked.
_jwks_cache: dict | None = None
async def _get_jwks() -> dict:
    """Return the tenant's JSON Web Key Set, downloading it only on first use.

    The parsed JSON document is memoised in the module-level ``_jwks_cache``
    for the lifetime of the process; later calls return it with no network
    I/O. The download goes through ``httpx.AsyncClient`` (10 s timeout), so it
    does not block the event loop. A non-success HTTP status raises from
    ``raise_for_status()`` before anything is cached, so the next call retries.
    """
    global _jwks_cache
    if _jwks_cache is None:
        async with httpx.AsyncClient(timeout=10) as http:
            resp = await http.get(JWKS_URL)
            resp.raise_for_status()
        # Body was fully read by .get(), so parsing after the client closes is fine.
        _jwks_cache = resp.json()
    return _jwks_cache
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the ``WWW-Authenticate: Bearer`` challenge.

    RFC 6750 requires a 401 from a bearer-protected resource to include this
    header; FastAPI passes ``headers`` through to the response.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _find_signing_key(jwks: dict, kid: str) -> dict | None:
    """Return the JWK in *jwks* whose ``kid`` equals *kid*, or None if absent."""
    # .get(): a JWKS document without "keys" means "no matching key" (-> 401)
    # instead of a KeyError surfacing as a 500.
    return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


async def _load_jwks(force_refresh: bool = False) -> dict:
    """Return the JWKS via ``_get_jwks``, optionally discarding the cache first.

    Failures while fetching or decoding the key set (transport errors,
    non-2xx status, invalid JSON) are an upstream outage rather than a client
    error, so they are reported as 503 instead of escaping as a 500.
    """
    global _jwks_cache
    if force_refresh:
        _jwks_cache = None
    try:
        return await _get_jwks()
    except (httpx.HTTPError, ValueError) as e:  # ValueError covers JSONDecodeError
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to fetch token signing keys",
        ) from e


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> dict:
    """FastAPI dependency: authenticate the caller from its bearer token.

    The token is verified as an Azure AD v2.0 ID token: RS256 signature
    against the tenant's published JWKS, ``aud == CLIENT_ID`` and
    ``iss == ISSUER``. ID tokens are expected here because Graph access tokens
    carry a Graph audience and would fail the audience check.

    Returns:
        ``{"oid": ..., "name": ..., "email": ...}`` built from the token's
        claims; ``email`` falls back ``preferred_username`` -> ``upn`` ->
        ``email``.

    Raises:
        HTTPException: 401 when credentials are missing, the token names no
            or an unknown signing key, or signature/claim validation fails;
            503 when the signing keys cannot be fetched.
    """
    # Development escape hatch: skips ALL token verification.
    if os.environ.get("DEV_AUTH_BYPASS", "").lower() in ("1", "true", "yes"):
        return {"oid": "dev-user", "name": "Dev User", "email": "dev@localhost"}
    if credentials is None:
        raise _unauthorized("Not authenticated")
    token = credentials.credentials

    try:
        # Parse the (unverified) header first so malformed tokens are rejected
        # without touching the network.
        kid = jwt.get_unverified_header(token).get("kid")
        if not kid:
            raise _unauthorized("Unknown signing key")

        key = _find_signing_key(await _load_jwks(), kid)
        if key is None:
            # Key not in cache — fetch fresh JWKS once (keys can rotate).
            # NOTE(review): `kid` is attacker-controlled, so any caller can
            # force this refetch; consider a minimum refresh interval.
            key = _find_signing_key(await _load_jwks(force_refresh=True), kid)
        if key is None:
            raise _unauthorized("Unknown signing key")

        payload = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            audience=CLIENT_ID,
            issuer=ISSUER,
            # An ID token may carry `at_hash`, but no access token is passed
            # here to compare it against, so that check is disabled.
            options={"verify_at_hash": False},
        )
    except JWTError as e:
        raise _unauthorized(f"Invalid token: {e}") from e

    return {
        "oid": payload.get("oid"),
        "name": payload.get("name"),
        "email": (
            payload.get("preferred_username")
            or payload.get("upn")
            or payload.get("email")
        ),
    }