diff --git a/backend/.env.example b/backend/.env.example
index 79342ad..4ae8591 100644
--- a/backend/.env.example
+++ b/backend/.env.example
@@ -13,3 +13,11 @@ CORS_ORIGINS=http://localhost:3000
# Server Configuration
HOST=0.0.0.0
PORT=8000
+
+# Azure AD Configuration (for Microsoft SSO token verification)
+# Get these from your Azure AD app registration
+AZURE_TENANT_ID=your_azure_tenant_id_here
+AZURE_CLIENT_ID=your_azure_client_id_here
+
+# Development only - set to "true" to disable authentication (NOT for production)
+DISABLE_AUTH=false
diff --git a/backend/app/config.py b/backend/app/config.py
index f7775bb..43ced75 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -18,10 +18,23 @@ class Settings:
_default_ref_docs = Path(__file__).parent.parent.parent / "reference_docs"
REFERENCE_DOCS_PATH: str = os.getenv("REFERENCE_DOCS_PATH", str(_default_ref_docs))
+ # Azure AD Configuration for token verification
+ AZURE_TENANT_ID: str = os.getenv("AZURE_TENANT_ID", "")
+ AZURE_CLIENT_ID: str = os.getenv("AZURE_CLIENT_ID", "")
+
+ # Auth bypass for development (set to "true" to skip auth)
+ DISABLE_AUTH: bool = os.getenv("DISABLE_AUTH", "false").lower() == "true"
+
def validate(self) -> None:
"""Validate required settings are present."""
if not self.GEMINI_API_KEY:
raise ValueError("GEMINI_API_KEY environment variable is required")
+ if not self.DISABLE_AUTH:
+ if not self.AZURE_TENANT_ID:
+ raise ValueError("AZURE_TENANT_ID environment variable is required (or set DISABLE_AUTH=true)")
+ if not self.AZURE_CLIENT_ID:
+ raise ValueError("AZURE_CLIENT_ID environment variable is required (or set DISABLE_AUTH=true)")
+
settings = Settings()
diff --git a/backend/app/dependencies/__init__.py b/backend/app/dependencies/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/app/dependencies/auth.py b/backend/app/dependencies/auth.py
new file mode 100644
index 0000000..b88ec38
--- /dev/null
+++ b/backend/app/dependencies/auth.py
@@ -0,0 +1,56 @@
+"""
+FastAPI authentication dependencies.
+
+Provides dependency functions for securing REST endpoints with Azure AD token verification.
+"""
+from typing import Optional
+from fastapi import Header, HTTPException, status
+
+from app.services.auth_service import verify_access_token
+
+
+async def get_current_user(authorization: Optional[str] = Header(None)) -> dict:
+ """
+ FastAPI dependency to verify the access token and return user claims.
+
+ Use as a dependency on protected endpoints:
+ @app.get("/protected")
+ async def protected_route(user: dict = Depends(get_current_user)):
+ return {"message": f"Hello {user.get('name')}"}
+
+ Args:
+ authorization: The Authorization header value (Bearer )
+
+ Returns:
+ The token claims dict containing user information
+
+ Raises:
+ HTTPException: 401 if token is missing or invalid
+ """
+ if not authorization:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Missing authorization header",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ # Extract token from "Bearer " format
+ parts = authorization.split()
+ if len(parts) != 2 or parts[0].lower() != "bearer":
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid authorization header format. Expected: Bearer ",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ token = parts[1]
+ claims = await verify_access_token(token)
+
+ if not claims:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or expired token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ return claims
diff --git a/backend/app/main.py b/backend/app/main.py
index 9e70d6f..d70ba76 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -2,10 +2,12 @@ import logging
import uuid
from contextlib import asynccontextmanager
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
+from app.services.auth_service import verify_access_token
+from app.dependencies.auth import get_current_user
# Configure logging
logging.basicConfig(
@@ -86,17 +88,18 @@ async def health_check():
@app.get("/info")
-async def info():
- """Get backend information."""
+async def info(user: dict = Depends(get_current_user)):
+ """Get backend information. Requires authentication."""
if analysis_service:
ref_docs = analysis_service.reference_docs
doc_summary = ref_docs.get_context_summary()
return {
"status": "ready",
+ "user": user.get("name", "Unknown"),
"agents": ["Legal Agent", "Brand Agent", "Tone Agent", "Channel Agent"],
"reference_docs": doc_summary,
}
- return {"status": "initializing"}
+ return {"status": "initializing", "user": user.get("name", "Unknown")}
@app.websocket("/ws/analyze")
@@ -105,7 +108,8 @@ async def websocket_analyze(websocket: WebSocket):
WebSocket endpoint for proof analysis with real-time updates.
Protocol:
- - Client sends: {"type": "analyze", "file_data": "", "file_type": "image/png", "is_wip": false}
+ - Client sends: {"type": "analyze", "file_data": "", "file_type": "image/png", "is_wip": false, "access_token": ""}
+ - Server verifies token before processing
- Server sends: {"type": "agent_started", "agent_name": "..."}
- Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}}
- Server sends: {"type": "complete", "result": {...}}
@@ -122,6 +126,20 @@ async def websocket_analyze(websocket: WebSocket):
logger.info(f"[MAIN] Received message from client {client_id} - type: {data.get('type')}")
if data.get("type") == "analyze":
+ # Verify access token from message
+ access_token = data.get("access_token")
+ user_claims = await verify_access_token(access_token)
+
+ if not user_claims:
+ logger.warning(f"[MAIN] Authentication failed for client {client_id}")
+ await manager.send_message(client_id, {
+ "type": "error",
+ "message": "Authentication failed. Please sign in again."
+ })
+ continue
+
+ logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")
+
if analysis_service is None:
logger.error("[MAIN] Analysis service not ready")
await manager.send_message(client_id, {
diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py
new file mode 100644
index 0000000..b2e3d62
--- /dev/null
+++ b/backend/app/services/auth_service.py
@@ -0,0 +1,118 @@
+"""
+Azure AD token verification service.
+
+Validates JWT access tokens from the frontend using Azure AD's public keys (JWKS).
+"""
+import logging
+from typing import Optional
+from datetime import datetime, timedelta
+
+import httpx
+from jose import jwt, JWTError
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+# Cache for JWKS (JSON Web Key Set)
+_jwks_cache: dict = {}
+_jwks_cache_expiry: datetime = datetime.min
+
+
+async def get_azure_jwks() -> dict:
+ """
+ Fetch and cache Azure AD's public keys for token verification.
+ Keys are cached for 24 hours to minimize network calls.
+ """
+ global _jwks_cache, _jwks_cache_expiry
+
+ if datetime.utcnow() < _jwks_cache_expiry and _jwks_cache:
+ return _jwks_cache
+
+ jwks_url = f"https://login.microsoftonline.com/{settings.AZURE_TENANT_ID}/discovery/v2.0/keys"
+
+ try:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(jwks_url, timeout=10.0)
+ response.raise_for_status()
+ _jwks_cache = response.json()
+ _jwks_cache_expiry = datetime.utcnow() + timedelta(hours=24)
+ logger.info("Successfully fetched Azure AD JWKS")
+ return _jwks_cache
+ except Exception as e:
+ logger.error(f"Failed to fetch JWKS: {e}")
+ if _jwks_cache: # Return stale cache if available
+ return _jwks_cache
+ raise
+
+
+async def verify_access_token(token: str) -> Optional[dict]:
+ """
+ Verify an Azure AD access token and return the claims.
+
+ Args:
+ token: The JWT access token from the frontend
+
+ Returns:
+ The token claims dict if valid, None if invalid
+ """
+ if settings.DISABLE_AUTH:
+ logger.warning("Auth disabled - skipping token verification")
+ return {"sub": "dev-user", "name": "Development User", "preferred_username": "dev@localhost"}
+
+ if not token:
+ logger.warning("No token provided")
+ return None
+
+ try:
+ # Get Azure AD public keys
+ jwks = await get_azure_jwks()
+
+ # Decode without verification first to get the key ID
+ unverified_header = jwt.get_unverified_header(token)
+ kid = unverified_header.get("kid")
+
+ if not kid:
+ logger.warning("No key ID in token header")
+ return None
+
+ # Find the matching key
+ rsa_key = None
+ for key in jwks.get("keys", []):
+ if key.get("kid") == kid:
+ rsa_key = key
+ break
+
+ if not rsa_key:
+ logger.warning(f"Key ID {kid} not found in JWKS, refreshing cache")
+ # Try refreshing JWKS in case keys rotated
+ global _jwks_cache_expiry
+ _jwks_cache_expiry = datetime.min
+ jwks = await get_azure_jwks()
+ for key in jwks.get("keys", []):
+ if key.get("kid") == kid:
+ rsa_key = key
+ break
+
+ if not rsa_key:
+ logger.error("Could not find matching key after refresh")
+ return None
+
+ # Verify and decode the token
+ claims = jwt.decode(
+ token,
+ rsa_key,
+ algorithms=["RS256"],
+ audience=f"api://{settings.AZURE_CLIENT_ID}",
+ issuer=f"https://sts.windows.net/{settings.AZURE_TENANT_ID}/",
+ )
+
+ logger.info(f"Token verified for user: {claims.get('name', 'unknown')}")
+ return claims
+
+ except JWTError as e:
+ logger.warning(f"JWT verification failed: {e}")
+ return None
+ except Exception as e:
+ logger.error(f"Token verification error: {e}")
+ return None
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 8e716f3..d15d752 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -6,3 +6,5 @@ pydantic>=2.5.0
python-multipart>=0.0.9
aiofiles>=23.2.1
websockets>=12.0
+python-jose[cryptography]>=3.3.0
+httpx>=0.26.0
diff --git a/frontend/App.tsx b/frontend/App.tsx
index 185170d..1fcccc1 100644
--- a/frontend/App.tsx
+++ b/frontend/App.tsx
@@ -1,5 +1,7 @@
import React, { useState, useEffect } from 'react';
+import { useIsAuthenticated, useMsal } from '@azure/msal-react';
+import { InteractionStatus } from '@azure/msal-browser';
import { Hero } from './components/Hero';
import { analyzeProof } from './services/geminiService';
import type { AgentReview, AgentName, FlaggedItem, ResolvedItem, ErrorItem } from './types';
@@ -24,7 +26,10 @@ export interface DropdownOptions {
}
const App: React.FC = () => {
- const [isLoggedIn, setIsLoggedIn] = useState(false);
+ // MSAL authentication state
+ const isAuthenticated = useIsAuthenticated();
+ const { instance: msalInstance, inProgress } = useMsal();
+
const [currentView, setCurrentView] = useState('Home');
const [selectedCampaign, setSelectedCampaign] = useState(null);
const [selectedProof, setSelectedProof] = useState(null);
@@ -287,7 +292,7 @@ const App: React.FC = () => {
};
try {
- const feedback = await analyzeProof(file, handleAgentUpdate);
+ const feedback = await analyzeProof(file, handleAgentUpdate, msalInstance);
const previewUrl = await fileToDataUrl(file);
if (feedback.overallStatus === 'Analysis Error') {
@@ -435,7 +440,7 @@ const App: React.FC = () => {
};
try {
- const feedback = await analyzeProof(file, handleAgentUpdateForRetry);
+ const feedback = await analyzeProof(file, handleAgentUpdateForRetry, msalInstance);
const previewUrl = await fileToDataUrl(file);
const newWorkfrontId = `#WF_${Math.floor(10000 + Math.random() * 90000)}`;
@@ -723,12 +728,14 @@ const App: React.FC = () => {
}
};
- const handleLogin = () => {
- setIsLoggedIn(true);
- };
-
- const handleLogout = () => {
- setIsLoggedIn(false);
+ const handleLogout = async () => {
+ try {
+ await msalInstance.logoutPopup({
+ postLogoutRedirectUri: window.location.origin,
+ });
+ } catch (error) {
+ console.error('Logout failed:', error);
+ }
};
const renderContent = () => {
@@ -736,7 +743,7 @@ const App: React.FC = () => {
case 'Analytics':
return ;
case 'Profile':
- return ;
+ return ;
case 'CopyGenAI':
return ;
case 'Campaigns':
@@ -759,7 +766,7 @@ const App: React.FC = () => {
onResolveSubmit={handleResolveSubmit}
/>;
case 'WIP Reviewer':
- return ;
+ return ;
case 'Auditing':
return {
}
};
- if (!isLoggedIn) {
- return ;
+ // Show loading spinner during MSAL authentication interactions
+ if (inProgress !== InteractionStatus.None) {
+ return (
+
+
+
+
Authenticating...
+
+
+ );
+ }
+
+ if (!isAuthenticated) {
+ return ;
}
// Determine background color based on view to avoid grey bar on Home view
diff --git a/frontend/components/Login.tsx b/frontend/components/Login.tsx
index 330a4eb..a1d88f0 100644
--- a/frontend/components/Login.tsx
+++ b/frontend/components/Login.tsx
@@ -1,12 +1,10 @@
import React, { useState } from 'react';
+import { useMsal } from '@azure/msal-react';
import { BarclaysLogo } from './icons/BarclaysLogo';
import { XIcon } from './icons/XIcon';
import { MicrosoftLogo } from './icons/MicrosoftLogo';
-
-interface LoginProps {
- onLogin: () => void;
-}
+import { loginRequest } from '../services/authConfig';
const SupportModal: React.FC<{
isOpen: boolean;
@@ -72,9 +70,11 @@ const SupportModal: React.FC<{
);
};
-export const Login: React.FC = ({ onLogin }) => {
+export const Login: React.FC = () => {
+ const { instance } = useMsal();
const [isSupportModalOpen, setIsSupportModalOpen] = useState(false);
const [isLoggingIn, setIsLoggingIn] = useState(false);
+ const [loginError, setLoginError] = useState(null);
const handleSupportSubmit = (query: string) => {
console.log("Support query submitted:", query);
@@ -82,12 +82,26 @@ export const Login: React.FC = ({ onLogin }) => {
setIsSupportModalOpen(false);
};
- const handleMicrosoftLogin = () => {
+ const handleMicrosoftLogin = async () => {
setIsLoggingIn(true);
- // Simulate redirect/auth delay
- setTimeout(() => {
- onLogin();
- }, 1500);
+ setLoginError(null);
+
+ try {
+ await instance.loginPopup(loginRequest);
+ // Success - MSAL Provider will detect the login and re-render App
+ } catch (error: unknown) {
+ console.error('Login failed:', error);
+ if (error instanceof Error) {
+ // Handle user cancellation differently from errors
+ if (error.message.includes('user_cancelled')) {
+ setLoginError(null); // Don't show error for cancellation
+ } else {
+ setLoginError('Login failed. Please try again or contact support.');
+ }
+ }
+ } finally {
+ setIsLoggingIn(false);
+ }
};
return (
@@ -130,6 +144,12 @@ export const Login: React.FC = ({ onLogin }) => {
+ {loginError && (
+
+ {loginError}
+
+ )}
+