68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
import os
|
|
import httpx
|
|
from functools import lru_cache
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from jose import jwt, JWTError
|
|
|
|
# Azure AD tenant and application (client) ID, read once at import time.
# Both fall back to an empty string when the environment variable is unset;
# with empty values the endpoint URLs below lack a tenant segment and the
# expected audience is "", so real token validation will not succeed.
TENANT_ID = os.getenv("AZURE_TENANT_ID", "")
CLIENT_ID = os.getenv("AZURE_CLIENT_ID", "")

# Tenant-specific v2.0 endpoints: where the signing keys (JWKS) are published,
# and the value the token's ``iss`` claim must match.
_AUTHORITY = f"https://login.microsoftonline.com/{TENANT_ID}"
JWKS_URL = f"{_AUTHORITY}/discovery/v2.0/keys"
ISSUER = f"{_AUTHORITY}/v2.0"
|
|
|
|
# auto_error=False: a request without a bearer Authorization header reaches
# get_current_user with ``credentials=None`` instead of being rejected by the
# scheme itself, so the DEV_AUTH_BYPASS check there runs before the 401.
bearer_scheme: HTTPBearer = HTTPBearer(
    auto_error=False,
)
|
|
|
|
|
|
@lru_cache(maxsize=1)
def _fetch_jwks() -> dict:
    """Fetch the tenant's JSON Web Key Set (JWKS) from Azure AD.

    The result is memoised in process memory by ``lru_cache``; call
    ``_fetch_jwks.cache_clear()`` (or restart the process) to force a refresh.
    A call that raises is not memoised, so a failed fetch is retried on the
    next call.

    Returns:
        The decoded JWKS document, guaranteed to be a dict with a ``keys`` list.

    Raises:
        httpx.HTTPError: on a transport failure, a timeout, or a non-2xx status.
        ValueError: if the body is not JSON, or is JSON without a ``keys`` list.
    """
    response = httpx.get(JWKS_URL, timeout=10)
    response.raise_for_status()
    jwks = response.json()  # json.JSONDecodeError is a ValueError subclass
    # Whatever is returned here stays cached for the life of the process, so
    # reject a malformed document rather than pinning it until restart.
    if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
        raise ValueError(f"Malformed JWKS response from {JWKS_URL}")
    return jwks
|
|
|
|
|
|
def _get_jwks() -> dict:
    """Return the cached JWKS, retrying the fetch once on a transient failure.

    ``lru_cache`` never stores a call that raised, so after a failure there is
    no stale entry to evict. The second call simply goes back to the network.

    Raises:
        Whatever the second ``_fetch_jwks()`` call raises. Errors other than
        ``httpx.HTTPError`` / ``ValueError`` on the first attempt propagate
        immediately without a retry.
    """
    try:
        return _fetch_jwks()
    except (httpx.HTTPError, ValueError):
        # Network/HTTP error or an undecodable body: worth one more attempt.
        return _fetch_jwks()
|
|
|
|
|
|
# Minimum interval between forced JWKS refreshes triggered by an unknown ``kid``.
_JWKS_REFRESH_COOLDOWN_SECONDS = 60.0
# time.monotonic() of the last forced refresh; -inf so the first miss may refresh.
_last_jwks_refresh = float("-inf")


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 carrying the ``WWW-Authenticate: Bearer`` challenge.

    RFC 6750 section 3 requires this header whenever a bearer-protected
    resource rejects a request for missing or invalid credentials.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _find_key(jwks: dict, kid: str | None) -> dict | None:
    """Return the JWK in *jwks* whose ``kid`` equals *kid*, or ``None``."""
    return next((k for k in (jwks.get("keys") or []) if k.get("kid") == kid), None)


async def _get_signing_key(kid: str | None) -> dict | None:
    """Look up the signing key for *kid*, refreshing the cached JWKS once on a miss.

    Azure AD rotates its signing keys, so a token signed with a key we have not
    seen may just mean the in-process JWKS cache is stale. Forced refreshes are
    limited to one per ``_JWKS_REFRESH_COOLDOWN_SECONDS``. This stops tokens with
    bogus ``kid`` values from turning every request into a network call.

    Raises:
        Whatever ``_get_jwks()`` raises when the key set cannot be fetched.
    """
    import asyncio
    import time

    global _last_jwks_refresh

    # _get_jwks() does a blocking httpx request on a cache miss; run it in a
    # worker thread so it cannot stall the event loop.
    key = _find_key(await asyncio.to_thread(_get_jwks), kid)
    if key is not None:
        return key

    now = time.monotonic()
    if now - _last_jwks_refresh < _JWKS_REFRESH_COOLDOWN_SECONDS:
        return None
    # No await between the check and this assignment, so concurrent requests
    # on the same event loop cannot both pass the cooldown check.
    _last_jwks_refresh = now
    _fetch_jwks.cache_clear()
    return _find_key(await asyncio.to_thread(_get_jwks), kid)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> dict:
    """Validate the request's Azure AD bearer token and return the caller's identity.

    Returns:
        A dict with ``oid``, ``name`` and ``email`` taken from the token claims.
        ``email`` is ``preferred_username``, falling back to ``email`` when that
        is empty. Claims absent from the token come back as ``None``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when no
            credentials are sent, the signing key is unknown, or the token fails
            signature/claim validation; 503 when the JWKS cannot be fetched.
    """
    # NOTE(review): this is re-read on every request and skips all validation,
    # so setting it in a deployed environment disables authentication entirely.
    if os.environ.get("DEV_AUTH_BYPASS", "").lower() in ("1", "true", "yes"):
        return {"oid": "dev-user", "name": "Dev User", "email": "dev@localhost"}

    if credentials is None:
        raise _unauthorized("Not authenticated")

    token = credentials.credentials

    # Parse the header first so malformed tokens are rejected without touching
    # the network.
    try:
        kid = jwt.get_unverified_header(token).get("kid")
    except JWTError as e:
        raise _unauthorized(f"Invalid token: {e}") from e

    try:
        key = await _get_signing_key(kid)
    except (httpx.HTTPError, ValueError) as e:
        # The identity provider is unreachable or returned garbage: an outage
        # on our side, not a bad credential.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to fetch signing keys",
        ) from e
    if key is None:
        raise _unauthorized("Unknown signing key")

    try:
        payload = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            audience=CLIENT_ID,
            issuer=ISSUER,
            options={
                "verify_at_hash": False,
                # python-jose silently skips the aud/exp checks when those
                # claims are absent; make their presence mandatory.
                "require_aud": True,
                "require_exp": True,
            },
        )
    except JWTError as e:
        raise _unauthorized(f"Invalid token: {e}") from e

    return {
        "oid": payload.get("oid"),
        "name": payload.get("name"),
        "email": payload.get("preferred_username") or payload.get("email"),
    }
|