Switch to browser-side auth: MSAL.js + JWT validation
- MSAL.js handles full OAuth flow in browser (SPA-compatible) - Server validates ID token signature via Azure AD JWKS endpoint (PyJWT) - Root / serves MSAL shell for unauthenticated users, handles redirect callback - Remove Python MSAL/PKCE server-side logic - Replace msal dependency with PyJWT[crypto] Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
5c8f849151
commit
eff02c145e
7 changed files with 122 additions and 77 deletions
7
.claude/settings.local.json
Normal file
7
.claude/settings.local.json
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(chmod +x /Volumes/SSD/Projects/Oliver/pimco-charts/deploy.sh)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -2,14 +2,12 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
EXEMPT_PATHS = {"/auth/login", "/auth/logout"}
|
||||
EXEMPT_PATHS = {"/", "", "/auth/token", "/auth/logout"}
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
path = request.url.path
|
||||
|
||||
# Strip root_path prefix for matching
|
||||
root_path = request.scope.get("root_path", "")
|
||||
if root_path and path.startswith(root_path):
|
||||
path = path[len(root_path):]
|
||||
|
|
@ -17,18 +15,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||
if path in EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# OAuth callback arrives at "/" with ?code= query param
|
||||
if path in ("/", "") and request.query_params.get("code"):
|
||||
return await call_next(request)
|
||||
|
||||
if not request.session.get("user"):
|
||||
root = request.scope.get("root_path", "")
|
||||
login_url = f"{root}/auth/login"
|
||||
root = root_path
|
||||
if request.headers.get("HX-Request"):
|
||||
return Response(
|
||||
status_code=401,
|
||||
headers={"HX-Redirect": login_url},
|
||||
headers={"HX-Redirect": f"{root}/"},
|
||||
)
|
||||
return RedirectResponse(url=login_url)
|
||||
return RedirectResponse(url=f"{root}/")
|
||||
|
||||
return await call_next(request)
|
||||
|
|
|
|||
|
|
@ -1,37 +1,28 @@
|
|||
import msal
|
||||
import secrets
|
||||
import hashlib
|
||||
import base64
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
|
||||
from app.config import AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_REDIRECT_URI
|
||||
from app.config import AZURE_TENANT_ID, AZURE_CLIENT_ID
|
||||
|
||||
AUTHORITY = f"https://login.microsoftonline.com/{AZURE_TENANT_ID}"
|
||||
SCOPES = ["User.Read"]
|
||||
_jwks_client: PyJWKClient | None = None
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
verifier = secrets.token_urlsafe(48)
|
||||
challenge = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(verifier.encode()).digest()
|
||||
).rstrip(b"=").decode()
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
def get_msal_app():
|
||||
return msal.PublicClientApplication(AZURE_CLIENT_ID, authority=AUTHORITY)
|
||||
|
||||
|
||||
def build_auth_url(state: str, challenge: str) -> str:
|
||||
return get_msal_app().get_authorization_request_url(
|
||||
SCOPES, state=state, redirect_uri=AZURE_REDIRECT_URI,
|
||||
code_challenge=challenge, code_challenge_method="S256"
|
||||
def _get_jwks_client() -> PyJWKClient:
|
||||
global _jwks_client
|
||||
if _jwks_client is None:
|
||||
_jwks_client = PyJWKClient(
|
||||
f"https://login.microsoftonline.com/{AZURE_TENANT_ID}/discovery/v2.0/keys"
|
||||
)
|
||||
return _jwks_client
|
||||
|
||||
|
||||
def exchange_code(code: str, verifier: str) -> dict:
|
||||
result = get_msal_app().acquire_token_by_authorization_code(
|
||||
code, SCOPES, redirect_uri=AZURE_REDIRECT_URI, code_verifier=verifier
|
||||
def validate_id_token(id_token: str) -> dict:
|
||||
client = _get_jwks_client()
|
||||
signing_key = client.get_signing_key_from_jwt(id_token)
|
||||
claims = jwt.decode(
|
||||
id_token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
audience=AZURE_CLIENT_ID,
|
||||
issuer=f"https://login.microsoftonline.com/{AZURE_TENANT_ID}/v2.0",
|
||||
)
|
||||
if "error" in result:
|
||||
raise ValueError(result.get("error_description", "Auth failed"))
|
||||
return result
|
||||
return claims
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import secrets
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import RedirectResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth.msal_client import build_auth_url, generate_pkce_pair
|
||||
from app.auth.msal_client import validate_id_token
|
||||
from app.config import AZURE_TENANT_ID, AZURE_REDIRECT_URI
|
||||
|
||||
router = APIRouter(prefix="/auth")
|
||||
|
|
@ -14,16 +13,25 @@ LOGOUT_URL = (
|
|||
)
|
||||
|
||||
|
||||
@router.get("/login")
|
||||
async def login(request: Request):
|
||||
state = secrets.token_urlsafe(16)
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
request.session["oauth_state"] = state
|
||||
request.session["pkce_verifier"] = verifier
|
||||
auth_url = build_auth_url(state=state, challenge=challenge)
|
||||
return RedirectResponse(url=auth_url, status_code=302)
|
||||
class TokenRequest(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/token")
|
||||
async def token(request: Request, body: TokenRequest):
|
||||
try:
|
||||
claims = validate_id_token(body.token)
|
||||
except Exception:
|
||||
return JSONResponse({"error": "invalid token"}, status_code=401)
|
||||
|
||||
request.session["user"] = {
|
||||
"name": claims.get("name", claims.get("preferred_username", "User")),
|
||||
"email": claims.get("email", claims.get("preferred_username", "")),
|
||||
"oid": claims.get("oid", ""),
|
||||
}
|
||||
root = request.scope.get("root_path", "")
|
||||
return JSONResponse({"redirect": f"{root}/"})
|
||||
|
||||
|
||||
@router.get("/logout")
|
||||
async def logout(request: Request):
|
||||
|
|
|
|||
38
app/main.py
38
app/main.py
|
|
@ -13,7 +13,10 @@ from fastapi.templating import Jinja2Templates
|
|||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
||||
|
||||
from app.config import OUTPUT_DIR, STATIC_DIR, TEMPLATES_DIR, SESSION_SECRET_KEY
|
||||
from app.config import (
|
||||
OUTPUT_DIR, STATIC_DIR, TEMPLATES_DIR, SESSION_SECRET_KEY,
|
||||
AZURE_CLIENT_ID, AZURE_TENANT_ID, AZURE_REDIRECT_URI,
|
||||
)
|
||||
from app.auth.middleware import AuthMiddleware
|
||||
from app.auth.routes import router as auth_router
|
||||
from app.data.loader import load_file
|
||||
|
|
@ -40,30 +43,19 @@ _sessions: dict[str, dict] = {}
|
|||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request, code: str = "", state: str = "", error: str = ""):
|
||||
# OAuth callback: Azure redirects back to root with ?code=&state=
|
||||
if code:
|
||||
stored_state = request.session.pop("oauth_state", None)
|
||||
verifier = request.session.pop("pkce_verifier", None)
|
||||
if stored_state and state == stored_state and verifier:
|
||||
try:
|
||||
from app.auth.msal_client import exchange_code
|
||||
result = exchange_code(code=code, verifier=verifier)
|
||||
claims = result.get("id_token_claims", {})
|
||||
request.session["user"] = {
|
||||
"name": claims.get("name", claims.get("preferred_username", "User")),
|
||||
"email": claims.get("email", claims.get("preferred_username", "")),
|
||||
"oid": claims.get("oid", ""),
|
||||
}
|
||||
except ValueError:
|
||||
pass
|
||||
from fastapi.responses import RedirectResponse as RR
|
||||
root = request.scope.get("root_path", "")
|
||||
return RR(url=f"{root}/")
|
||||
|
||||
async def index(request: Request):
|
||||
user = request.session.get("user")
|
||||
if not user:
|
||||
return templates.TemplateResponse("msal_shell.html", {
|
||||
"request": request,
|
||||
"client_id": AZURE_CLIENT_ID,
|
||||
"tenant_id": AZURE_TENANT_ID,
|
||||
"redirect_uri": AZURE_REDIRECT_URI,
|
||||
"root_path": request.scope.get("root_path", ""),
|
||||
})
|
||||
return templates.TemplateResponse("upload.html", {
|
||||
"request": request,
|
||||
"user": request.session.get("user"),
|
||||
"user": user,
|
||||
})
|
||||
|
||||
|
||||
|
|
|
|||
54
app/templates/msal_shell.html
Normal file
54
app/templates/msal_shell.html
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>PIMCO Chart Generator</title>
|
||||
<script src="https://alcdn.msauth.net/browser/2.38.3/js/msal-browser.min.js"></script>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500&family=Roboto+Condensed:wght@400;700&display=swap" rel="stylesheet">
|
||||
<link rel="stylesheet" href="{{ root_path }}/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1>PIMCO Chart Generator</h1>
|
||||
<p class="subtitle" id="status">Redirecting to Microsoft login...</p>
|
||||
</header>
|
||||
</div>
|
||||
<script>
|
||||
const msalInstance = new msal.PublicClientApplication({
|
||||
auth: {
|
||||
clientId: "{{ client_id }}",
|
||||
authority: "https://login.microsoftonline.com/{{ tenant_id }}",
|
||||
redirectUri: "{{ redirect_uri }}"
|
||||
},
|
||||
cache: { cacheLocation: "sessionStorage" }
|
||||
});
|
||||
|
||||
msalInstance.handleRedirectPromise()
|
||||
.then(async (response) => {
|
||||
if (response && response.idToken) {
|
||||
document.getElementById("status").textContent = "Completing sign-in...";
|
||||
const res = await fetch("{{ root_path }}/auth/token", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ token: response.idToken })
|
||||
});
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
window.location.href = data.redirect;
|
||||
} else {
|
||||
document.getElementById("status").textContent = "Sign-in failed. Please refresh.";
|
||||
}
|
||||
} else {
|
||||
msalInstance.loginRedirect({ scopes: ["User.Read"] });
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(err);
|
||||
document.getElementById("status").textContent = "Sign-in error. Please refresh the page.";
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
annotated-doc==0.0.4
|
||||
msal==1.31.0
|
||||
itsdangerous==2.2.0
|
||||
PyJWT[crypto]==2.8.0
|
||||
annotated-types==0.7.0
|
||||
anthropic==0.84.0
|
||||
anyio==4.12.1
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue