From 321a9ca8200d753c27c64965d1939e67f7953eda Mon Sep 17 00:00:00 2001 From: michael Date: Tue, 16 Dec 2025 08:43:30 -0600 Subject: [PATCH] Implement Microsoft MSAL SSO with PKCE flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Frontend: - Add @azure/msal-browser and @azure/msal-react packages - Create authConfig.ts with MSAL configuration for PKCE flow - Create authService.ts for token acquisition and user info - Wrap App with MsalProvider in index.tsx - Replace dummy login with real MSAL loginPopup() in Login.tsx - Update App.tsx to use useIsAuthenticated/useMsal hooks - Update Profile.tsx to display real user data from claims - Update geminiService.ts to include access_token in WebSocket messages - Update WIPReviewer.tsx to pass msalInstance for auth Backend: - Add python-jose and httpx dependencies for JWT verification - Create auth_service.py with Azure AD JWKS fetching and token verification - Create auth.py FastAPI dependency for protected REST endpoints - Update main.py to verify tokens on WebSocket and protect /info endpoint - Add AZURE_TENANT_ID, AZURE_CLIENT_ID, DISABLE_AUTH to config 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/.env.example | 8 ++ backend/app/config.py | 13 +++ backend/app/dependencies/__init__.py | 0 backend/app/dependencies/auth.py | 56 +++++++++++++ backend/app/main.py | 28 +++++-- backend/app/services/auth_service.py | 118 +++++++++++++++++++++++++++ backend/requirements.txt | 2 + frontend/App.tsx | 48 ++++++++--- frontend/components/Login.tsx | 40 ++++++--- frontend/components/Profile.tsx | 29 +++++-- frontend/components/WIPReviewer.tsx | 10 ++- frontend/index.tsx | 45 +++++++--- frontend/package-lock.json | 42 +++++++++- frontend/package.json | 8 +- frontend/services/authConfig.ts | 52 ++++++++++++ frontend/services/authService.ts | 68 +++++++++++++++ frontend/services/geminiService.ts | 31 +++++-- 17 files changed, 538 insertions(+), 60 deletions(-) create mode 100644 backend/app/dependencies/__init__.py create mode 100644 backend/app/dependencies/auth.py create mode 100644 backend/app/services/auth_service.py create mode 100644 frontend/services/authConfig.ts create mode 100644 frontend/services/authService.ts diff --git a/backend/.env.example b/backend/.env.example index 79342ad..4ae8591 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -13,3 +13,11 @@ CORS_ORIGINS=http://localhost:3000 # Server Configuration HOST=0.0.0.0 PORT=8000 + +# Azure AD Configuration (for Microsoft SSO token verification) +# Get these from your Azure AD app registration +AZURE_TENANT_ID=your_azure_tenant_id_here +AZURE_CLIENT_ID=your_azure_client_id_here + +# Development only - set to "true" to disable authentication (NOT for production) +DISABLE_AUTH=false diff --git a/backend/app/config.py b/backend/app/config.py index f7775bb..43ced75 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -18,10 +18,23 @@ class Settings: _default_ref_docs = Path(__file__).parent.parent.parent / "reference_docs" REFERENCE_DOCS_PATH: str = os.getenv("REFERENCE_DOCS_PATH", str(_default_ref_docs)) + # Azure AD Configuration for token verification + AZURE_TENANT_ID: str = os.getenv("AZURE_TENANT_ID", "") + AZURE_CLIENT_ID: str = os.getenv("AZURE_CLIENT_ID", "") + + # Auth bypass for development (set to "true" to skip auth) + DISABLE_AUTH: bool = os.getenv("DISABLE_AUTH", "false").lower() == "true" + def validate(self) -> None: """Validate required settings are present.""" if not self.GEMINI_API_KEY: raise ValueError("GEMINI_API_KEY environment variable is required") + if not self.DISABLE_AUTH: + if not self.AZURE_TENANT_ID: + raise ValueError("AZURE_TENANT_ID environment variable is required (or set DISABLE_AUTH=true)") + if not self.AZURE_CLIENT_ID: + raise ValueError("AZURE_CLIENT_ID environment variable is required (or set DISABLE_AUTH=true)") + settings = Settings() diff --git a/backend/app/dependencies/__init__.py b/backend/app/dependencies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/dependencies/auth.py b/backend/app/dependencies/auth.py new file mode 100644 index 0000000..b88ec38 --- /dev/null +++ b/backend/app/dependencies/auth.py @@ -0,0 +1,56 @@ +""" +FastAPI authentication dependencies. + +Provides dependency functions for securing REST endpoints with Azure AD token verification. +""" +from typing import Optional +from fastapi import Header, HTTPException, status + +from app.services.auth_service import verify_access_token + + +async def get_current_user(authorization: Optional[str] = Header(None)) -> dict: + """ + FastAPI dependency to verify the access token and return user claims. + + Use as a dependency on protected endpoints: + @app.get("/protected") + async def protected_route(user: dict = Depends(get_current_user)): + return {"message": f"Hello {user.get('name')}"} + + Args: + authorization: The Authorization header value (Bearer ) + + Returns: + The token claims dict containing user information + + Raises: + HTTPException: 401 if token is missing or invalid + """ + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authorization header", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Extract token from "Bearer " format + parts = authorization.split() + if len(parts) != 2 or parts[0].lower() != "bearer": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authorization header format. Expected: Bearer ", + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = parts[1] + claims = await verify_access_token(token) + + if not claims: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return claims diff --git a/backend/app/main.py b/backend/app/main.py index 9e70d6f..d70ba76 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,10 +2,12 @@ import logging import uuid from contextlib import asynccontextmanager -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends from fastapi.middleware.cors import CORSMiddleware from app.config import settings +from app.services.auth_service import verify_access_token +from app.dependencies.auth import get_current_user # Configure logging logging.basicConfig( @@ -86,17 +88,18 @@ async def health_check(): @app.get("/info") -async def info(): - """Get backend information.""" +async def info(user: dict = Depends(get_current_user)): + """Get backend information. Requires authentication.""" if analysis_service: ref_docs = analysis_service.reference_docs doc_summary = ref_docs.get_context_summary() return { "status": "ready", + "user": user.get("name", "Unknown"), "agents": ["Legal Agent", "Brand Agent", "Tone Agent", "Channel Agent"], "reference_docs": doc_summary, } - return {"status": "initializing"} + return {"status": "initializing", "user": user.get("name", "Unknown")} @app.websocket("/ws/analyze") @@ -105,7 +108,8 @@ async def websocket_analyze(websocket: WebSocket): WebSocket endpoint for proof analysis with real-time updates. Protocol: - - Client sends: {"type": "analyze", "file_data": "", "file_type": "image/png", "is_wip": false} + - Client sends: {"type": "analyze", "file_data": "", "file_type": "image/png", "is_wip": false, "access_token": ""} + - Server verifies token before processing - Server sends: {"type": "agent_started", "agent_name": "..."} - Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}} - Server sends: {"type": "complete", "result": {...}} @@ -122,6 +126,20 @@ async def websocket_analyze(websocket: WebSocket): logger.info(f"[MAIN] Received message from client {client_id} - type: {data.get('type')}") if data.get("type") == "analyze": + # Verify access token from message + access_token = data.get("access_token") + user_claims = await verify_access_token(access_token) + + if not user_claims: + logger.warning(f"[MAIN] Authentication failed for client {client_id}") + await manager.send_message(client_id, { + "type": "error", + "message": "Authentication failed. Please sign in again." + }) + continue + + logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}") + if analysis_service is None: logger.error("[MAIN] Analysis service not ready") await manager.send_message(client_id, { diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..b2e3d62 --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,118 @@ +""" +Azure AD token verification service. + +Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS). +""" +import logging +from typing import Optional +from datetime import datetime, timedelta + +import httpx +from jose import jwt, JWTError + +from app.config import settings + +logger = logging.getLogger(__name__) + +# Cache for JWKS (JSON Web Key Set) +_jwks_cache: dict = {} +_jwks_cache_expiry: datetime = datetime.min + + +async def get_azure_jwks() -> dict: + """ + Fetch and cache Azure AD's public keys for token verification. + Keys are cached for 24 hours to minimize network calls. + """ + global _jwks_cache, _jwks_cache_expiry + + if datetime.utcnow() < _jwks_cache_expiry and _jwks_cache: + return _jwks_cache + + jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys" + + try: + async with httpx.AsyncClient() as client: + response = await client.get(jwks_url, timeout=10.0) + response.raise_for_status() + _jwks_cache = response.json() + _jwks_cache_expiry = datetime.utcnow() + timedelta(hours=24) + logger.info("Successfully fetched Azure AD JWKS") + return _jwks_cache + except Exception as e: + logger.error(f"Failed to fetch JWKS: {e}") + if _jwks_cache: # Return stale cache if available + return _jwks_cache + raise + + +async def verify_access_token(token: str) -> Optional[dict]: + """ + Verify an Azure AD access token and return the claims. + + Args: + token: The JWT access token from the frontend + + Returns: + The token claims dict if valid, None if invalid + """ + if settings.DISABLE_AUTH: + logger.warning("Auth disabled - skipping token verification") + return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"} + + if not token: + logger.warning("No token provided") + return None + + try: + # Get Azure AD public keys + jwks = await get_azure_jwks() + + # Decode without verification first to get the key ID + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + + if not kid: + logger.warning("No key ID in token header") + return None + + # Find the matching key + rsa_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + rsa_key = key + break + + if not rsa_key: + logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache") + # Try refreshing JWKS in case keys rotated + global _jwks_cache_expiry + _jwks_cache_expiry = datetime.min + jwks = await get_azure_jwks() + for key in jwks.get("keys", []): + if key.get("kid") == kid: + rsa_key = key + break + + if not rsa_key: + logger.error("Could not find matching key after refresh") + return None + + # Verify and decode the token + claims = jwt.decode( + token, + rsa_key, + algorithms=["RS256"], + audience=f"api://{settings.AZURE_CLIENT_ID}", + issuer=f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/", + ) + + logger.info(f"Token verified for user: {claims.get('name', 'unknown')}") + return claims + + except JWTError as e: + logger.warning(f"JWT verification failed: {e}") + return None + except Exception as e: + logger.error(f"Token verification error: {e}") + return None diff --git a/backend/requirements.txt b/backend/requirements.txt index 8e716f3..d15d752 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,3 +6,5 @@ pydantic>=2.5.0 python-multipart>=0.0.9 aiofiles>=23.2.1 websockets>=12.0 +python-jose[cryptography]>=3.3.0 +httpx>=0.26.0 diff --git a/frontend/App.tsx b/frontend/App.tsx index 185170d..1fcccc1 100644 --- a/frontend/App.tsx +++ b/frontend/App.tsx @@ -1,5 +1,7 @@ import React, { useState, useEffect } from 'react'; +import { useIsAuthenticated, useMsal } from '@azure/msal-react'; +import { InteractionStatus } from '@azure/msal-browser'; import { Hero } from './components/Hero'; import { analyzeProof } from './services/geminiService'; import type { AgentReview, AgentName, FlaggedItem, ResolvedItem, ErrorItem } from './types'; @@ -24,7 +26,10 @@ export interface DropdownOptions { } const App: React.FC = () => { - const [isLoggedIn, setIsLoggedIn] = useState(false); + // MSAL authentication state + const isAuthenticated = useIsAuthenticated(); + const { instance: msalInstance, inProgress } = useMsal(); + const [currentView, setCurrentView] = useState('Home'); const [selectedCampaign, setSelectedCampaign] = useState(null); const [selectedProof, setSelectedProof] = useState(null); @@ -287,7 +292,7 @@ const App: React.FC = () => { }; try { - const feedback = await analyzeProof(file, handleAgentUpdate); + const feedback = await analyzeProof(file, handleAgentUpdate, msalInstance); const previewUrl = await fileToDataUrl(file); if (feedback.overallStatus === 'Analysis Error') { @@ -435,7 +440,7 @@ const App: React.FC = () => { }; try { - const feedback = await analyzeProof(file, handleAgentUpdateForRetry); + const feedback = await analyzeProof(file, handleAgentUpdateForRetry, msalInstance); const previewUrl = await fileToDataUrl(file); const newWorkfrontId = `#WF_${Math.floor(10000 + Math.random() * 90000)}`; @@ -723,12 +728,14 @@ const App: React.FC = () => { } }; - const handleLogin = () => { - setIsLoggedIn(true); - }; - - const handleLogout = () => { - setIsLoggedIn(false); + const handleLogout = async () => { + try { + await msalInstance.logoutPopup({ + postLogoutRedirectUri: window.location.origin, + }); + } catch (error) { + console.error('Logout failed:', error); + } }; const renderContent = () => { @@ -736,7 +743,7 @@ const App: React.FC = () => { case 'Analytics': return ; case 'Profile': - return ; + return ; case 'CopyGenAI': return ; case 'Campaigns': @@ -759,7 +766,7 @@ const App: React.FC = () => { onResolveSubmit={handleResolveSubmit} />; case 'WIP Reviewer': - return ; + return ; case 'Auditing': return { } }; - if (!isLoggedIn) { - return ; + // Show loading spinner during MSAL authentication interactions + if (inProgress !== InteractionStatus.None) { + return ( +
+
+ + + + +

Authenticating...

+
+
+ ); + } + + if (!isAuthenticated) { + return ; } // Determine background color based on view to avoid grey bar on Home view diff --git a/frontend/components/Login.tsx b/frontend/components/Login.tsx index 330a4eb..a1d88f0 100644 --- a/frontend/components/Login.tsx +++ b/frontend/components/Login.tsx @@ -1,12 +1,10 @@ import React, { useState } from 'react'; +import { useMsal } from '@azure/msal-react'; import { BarclaysLogo } from './icons/BarclaysLogo'; import { XIcon } from './icons/XIcon'; import { MicrosoftLogo } from './icons/MicrosoftLogo'; - -interface LoginProps { - onLogin: () => void; -} +import { loginRequest } from '../services/authConfig'; const SupportModal: React.FC<{ isOpen: boolean; @@ -72,9 +70,11 @@ const SupportModal: React.FC<{ ); }; -export const Login: React.FC = ({ onLogin }) => { +export const Login: React.FC = () => { + const { instance } = useMsal(); const [isSupportModalOpen, setIsSupportModalOpen] = useState(false); const [isLoggingIn, setIsLoggingIn] = useState(false); + const [loginError, setLoginError] = useState(null); const handleSupportSubmit = (query: string) => { console.log("Support query submitted:", query); @@ -82,12 +82,26 @@ export const Login: React.FC = ({ onLogin }) => { setIsSupportModalOpen(false); }; - const handleMicrosoftLogin = () => { + const handleMicrosoftLogin = async () => { setIsLoggingIn(true); - // Simulate redirect/auth delay - setTimeout(() => { - onLogin(); - }, 1500); + setLoginError(null); + + try { + await instance.loginPopup(loginRequest); + // Success - MSAL Provider will detect the login and re-render App + } catch (error: unknown) { + console.error('Login failed:', error); + if (error instanceof Error) { + // Handle user cancellation differently from errors + if (error.message.includes('user_cancelled')) { + setLoginError(null); // Don't show error for cancellation + } else { + setLoginError('Login failed. Please try again or contact support.'); + } + } + } finally { + setIsLoggingIn(false); + } }; return ( @@ -130,6 +144,12 @@ export const Login: React.FC = ({ onLogin }) => {

+ {loginError && ( +
+ {loginError} +
+ )} +