diff --git a/.env.example b/.env.example index 0971b0e..816913d 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ ANTHROPIC_API_KEY=sk-ant-xxxxx AZURE_TENANT_ID=e519c2e6-bc6d-4fdf-8d9c-923c2f002385 AZURE_CLIENT_ID=9079054c-9620-4757-a256-23413042f1ef -AZURE_REDIRECT_URI=https://ai-sandbox.oliver.solutions/Pimco-charts/auth/callback +AZURE_REDIRECT_URI=https://ai-sandbox.oliver.solutions/Pimco-charts SESSION_SECRET_KEY= diff --git a/app/auth/middleware.py b/app/auth/middleware.py index eedc0eb..c6a11aa 100644 --- a/app/auth/middleware.py +++ b/app/auth/middleware.py @@ -2,7 +2,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import RedirectResponse, Response -EXEMPT_PATHS = {"/auth/login", "/auth/callback", "/auth/logout"} +EXEMPT_PATHS = {"/auth/login", "/auth/logout"} class AuthMiddleware(BaseHTTPMiddleware): @@ -17,6 +17,10 @@ class AuthMiddleware(BaseHTTPMiddleware): if path in EXEMPT_PATHS: return await call_next(request) + # OAuth callback arrives at "/" with ?code= query param + if path in ("/", "") and request.query_params.get("code"): + return await call_next(request) + if not request.session.get("user"): if request.headers.get("HX-Request"): return Response( diff --git a/app/auth/routes.py b/app/auth/routes.py index 5402566..f087e63 100644 --- a/app/auth/routes.py +++ b/app/auth/routes.py @@ -3,14 +3,14 @@ import secrets from fastapi import APIRouter, Request from fastapi.responses import RedirectResponse -from app.auth.msal_client import build_auth_url, exchange_code, generate_pkce_pair +from app.auth.msal_client import build_auth_url, generate_pkce_pair from app.config import AZURE_TENANT_ID, AZURE_REDIRECT_URI router = APIRouter(prefix="/auth") LOGOUT_URL = ( f"https://login.microsoftonline.com/{AZURE_TENANT_ID}/oauth2/v2.0/logout" - f"?post_logout_redirect_uri=https://ai-sandbox.oliver.solutions/Pimco-charts" + f"?post_logout_redirect_uri={AZURE_REDIRECT_URI}" ) @@ -24,30 +24,6 @@ async def login(request: Request): return RedirectResponse(url=auth_url) -@router.get("/callback") -async def callback(request: Request, code: str = "", state: str = "", error: str = ""): - if error: - return RedirectResponse(url="/auth/login") - - stored_state = request.session.pop("oauth_state", None) - verifier = request.session.pop("pkce_verifier", None) - - if not stored_state or state != stored_state or not verifier: - return RedirectResponse(url="/auth/login") - - try: - result = exchange_code(code=code, verifier=verifier) - except ValueError: - return RedirectResponse(url="/auth/login") - - claims = result.get("id_token_claims", {}) - request.session["user"] = { - "name": claims.get("name", claims.get("preferred_username", "User")), - "email": claims.get("email", claims.get("preferred_username", "")), - "oid": claims.get("oid", ""), - } - return RedirectResponse(url="/") - @router.get("/logout") async def logout(request: Request): diff --git a/app/main.py b/app/main.py index 3103783..34d08ad 100644 --- a/app/main.py +++ b/app/main.py @@ -40,7 +40,26 @@ _sessions: dict[str, dict] = {} @app.get("/", response_class=HTMLResponse) -async def index(request: Request): +async def index(request: Request, code: str = "", state: str = "", error: str = ""): + # OAuth callback: Azure redirects back to root with ?code=&state= + if code: + stored_state = request.session.pop("oauth_state", None) + verifier = request.session.pop("pkce_verifier", None) + if stored_state and state == stored_state and verifier: + try: + from app.auth.msal_client import exchange_code + result = exchange_code(code=code, verifier=verifier) + claims = result.get("id_token_claims", {}) + request.session["user"] = { + "name": claims.get("name", claims.get("preferred_username", "User")), + "email": claims.get("email", claims.get("preferred_username", "")), + "oid": claims.get("oid", ""), + } + except ValueError: + pass + from fastapi.responses import RedirectResponse as RR + return RR(url="/") + return templates.TemplateResponse("upload.html", { "request": request, "user": request.session.get("user"),