modcomms/backend/app/main.py
michael d21036a0de Add 4-tier RBAC backend: auth dependencies, role enforcement, agency filtering
- Add CHECK constraint migration for users.role (super_admin, oversight_admin, agency_admin, basic_user)
- Add get_current_db_user dependency resolving Azure claims to User ORM with agency
- Add require_role() factory and require_write_access() dependency
- Auto-promote dev user to super_admin when DISABLE_AUTH=true
- Add /api/me, PUT /api/users/{id}, POST /api/agencies endpoints
- Apply agency-based data filtering on campaigns, analytics, audit routes
- Block oversight_admin from all mutation routes (campaigns, proofs, flags, resolves)
- Restrict dropdown option mutations to super_admin only
- Add role check in WebSocket handler to block oversight_admin from analysis
- Add CurrentUserResponse, UserUpdate, AgencyCreate schemas

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 08:28:23 -06:00

244 lines
9.3 KiB
Python
Executable file

import logging
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.services.auth_service import verify_access_token
from app.dependencies.auth import get_current_user
from app.models.database import init_db, close_db, async_session_factory as _session_factory
from app.repositories.user_repository import UserRepository
from app.api import router as api_router, kb_router
# Root logging setup: INFO and above, one "time - logger - level - message" line per record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
class HealthCheckFilter(logging.Filter):
    """Drop access-log records produced by requests to ``GET /health``.

    Only the health endpoint itself is matched (optionally with a query
    string). Other paths that merely begin with ``/health`` keep their
    log lines.
    """

    # Request-line fragments identifying the health endpoint: the path is
    # followed either by the space before the HTTP version or by a query string.
    _HEALTH_MARKERS = ("GET /health ", "GET /health?")

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False to suppress a health-check record, True to keep it.

        Args:
            record: The log record under consideration.

        Returns:
            False when the formatted message refers to ``GET /health``,
            True for every other record.
        """
        message = record.getMessage()
        # A message that ends right after the path has no trailing delimiter.
        if message.endswith("GET /health"):
            return False
        return not any(marker in message for marker in self._HEALTH_MARKERS)


# Filter out health check logs from uvicorn access log
logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())
from app.websocket.manager import ConnectionManager
from app.websocket.handlers import handle_analyze_message
from app.services.gemini_service import GeminiService
from app.services.reference_docs import ReferenceDocsService
from app.services.analysis_service import AnalysisService
from app.services.knowledge_base_service import KnowledgeBaseService
# Global services - initialized at startup
# Connection manager used by the /ws/analyze endpoint to register clients,
# send them messages and drop them on disconnect. Created at import time.
manager = ConnectionManager()
# None until lifespan() assigns it during application startup; endpoints
# treat None as "backend still initializing".
analysis_service: AnalysisService | None = None
# Assigned by lifespan() only when settings.LLAMA_CLOUD_API_KEY is set;
# otherwise it stays None.
knowledge_base_service: KnowledgeBaseService | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Initialize services on startup and clean up on shutdown.

    Startup, in order:
      1. Validate settings.
      2. Try to initialize the database. A failure is logged as a warning and
         startup continues without it ("stateless mode").
      3. Build the reference-docs service and, when the database came up,
         load specs from it. A failure here is logged and startup continues.
      4. Create the Gemini and analysis services and publish the module-level
         ``analysis_service``.
      5. Create the module-level ``knowledge_base_service`` only when
         ``settings.LLAMA_CLOUD_API_KEY`` is set.

    Shutdown: ``close_db()`` is awaited in a ``finally`` block, so it also
    runs when an exception is thrown into the context at the ``yield``.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    global analysis_service, knowledge_base_service

    # Validate settings
    settings.validate()

    # Initialize database
    print("Initializing database connection...")
    db_available = False
    try:
        await init_db()
        print("Database initialized successfully")
        db_available = True
    except Exception as e:
        # Deliberately non-fatal: the app keeps running without a database.
        logger.warning(f"Database initialization failed (may not be available): {e}")
        print("Warning: Database not available - running in stateless mode")

    # Initialize services
    print("Loading reference documents...")
    reference_docs = ReferenceDocsService(settings.REFERENCE_DOCS_PATH)

    # Load specs from DB if database is available
    if db_available:
        try:
            # Imported here, after init_db() has run, rather than reusing the
            # module-level binding.
            # NOTE(review): presumably init_db() rebinds this name - confirm.
            from app.models.database import async_session_factory

            async with async_session_factory() as session:
                print("Loading specs from database...")
                await reference_docs.load_specs_from_db(session)
        except Exception as e:
            logger.warning(f"Failed to load specs from DB (falling back to files): {e}")

    # Log document info
    doc_summary = reference_docs.get_context_summary()
    print(f" Brand documents: {len(doc_summary['brand_files'])} files ({doc_summary['brand_context_length']} chars)")
    print(f" Channel documents: {len(doc_summary['channel_files'])} files ({doc_summary['channel_context_length']} chars)")

    print("Initializing Gemini service...")
    gemini_service = GeminiService(settings.GEMINI_API_KEY)

    print("Initializing analysis service...")
    analysis_service = AnalysisService(gemini_service, reference_docs)

    # Initialize Knowledge Base service (requires LlamaParse API key)
    if settings.LLAMA_CLOUD_API_KEY:
        from app.services.llamaparse_service import LlamaParseService

        print("Initializing LlamaParse service...")
        llamaparse_service = LlamaParseService(settings.LLAMA_CLOUD_API_KEY)
        knowledge_base_service = KnowledgeBaseService(llamaparse_service, gemini_service, reference_docs)
        print("Knowledge Base pipeline ready!")
    else:
        print("LLAMA_CLOUD_API_KEY not set - Knowledge Base processing pipeline disabled")

    print("Backend ready!")
    try:
        yield
    finally:
        # Cleanup on shutdown; in ``finally`` so it is not skipped when the
        # context is exited with an exception.
        print("Shutting down...")
        await close_db()
# Create FastAPI app
app = FastAPI(
    title="ModComms Proof Review API",
    description="AI-powered proof review backend for Barclays marketing materials",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware - allow frontend to connect.
# settings.CORS_ORIGINS is split on commas; each entry is stripped and blank
# entries are dropped, so a value such as "http://a, http://b" (or one with a
# trailing comma) still yields origins that can match a request's Origin header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in settings.CORS_ORIGINS.split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include API routes: both routers are mounted under the /api prefix.
app.include_router(api_router, prefix="/api")
app.include_router(kb_router, prefix="/api")
@app.get("/health")
async def health_check():
    """Report a fixed "healthy" status for the backend service."""
    payload = dict(status="healthy", service="modcomms-backend")
    return payload
@app.get("/info")
async def info(user: dict = Depends(get_current_user)):
    """
    Describe the backend for an authenticated caller.

    Returns a "ready" payload (caller name, agent names, reference-document
    summary) once the analysis service exists, otherwise an "initializing"
    payload carrying only the caller name.
    """
    caller_name = user.get("name", "Unknown")

    # Startup has not published the analysis service yet.
    if not analysis_service:
        return {"status": "initializing", "user": caller_name}

    summary = analysis_service.reference_docs.get_context_summary()
    agent_names = [
        "Legal Agent",
        "Brand Agent",
        "Channel Best Practices Agent",
        "Channel Tech Specs Agent",
    ]
    return {
        "status": "ready",
        "user": caller_name,
        "agents": agent_names,
        "reference_docs": summary,
    }
async def _is_oversight_admin(user_claims: dict) -> bool:
    """Return True when the claims map to a stored user with role ``oversight_admin``.

    The user is looked up by the ``oid`` claim, falling back to ``sub``.
    Claims without either identifier, or with no matching user, yield False.
    Lookup errors propagate to the caller.
    """
    azure_oid = user_claims.get("oid") or user_claims.get("sub")
    if not azure_oid:
        return False

    # Resolved at call time, the same way lifespan() does after init_db().
    # NOTE(review): presumably init_db() rebinds this name, in which case the
    # module-level import would hold a stale value - confirm.
    from app.models.database import async_session_factory

    async with async_session_factory() as ws_session:
        ws_user = await UserRepository(ws_session).get_by_azure_oid(azure_oid)
        return bool(ws_user and ws_user.role == "oversight_admin")


@app.websocket("/ws/analyze")
async def websocket_analyze(websocket: WebSocket):
    """
    WebSocket endpoint for proof analysis with real-time updates.

    Protocol:
    - Client sends: {"type": "analyze", "file_data": "<base64>", "file_type": "image/png", "is_wip": false, "access_token": "<jwt>"}
    - Server verifies token before processing
    - Server sends: {"type": "agent_started", "agent_name": "..."}
    - Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}}
    - Server sends: {"type": "complete", "result": {...}}
    - On error: {"type": "error", "message": "..."}

    Every message is handled independently. A rejected message (bad payload,
    unknown type, failed authentication, read-only role, backend not ready)
    produces an error message and the connection stays open. The client is
    removed from the connection manager whenever the handler exits.
    """
    client_id = str(uuid.uuid4())
    logger.info(f"[MAIN] WebSocket connection established - client_id: {client_id}")
    await manager.connect(websocket, client_id)

    try:
        while True:
            # Wait for a message from the client
            data = await websocket.receive_json()

            # Valid JSON need not be an object; reject anything else rather
            # than failing on ``.get`` and dropping the connection.
            if not isinstance(data, dict):
                logger.warning(f"[MAIN] Non-object message from client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Invalid message: expected a JSON object."
                })
                continue

            message_type = data.get("type")
            logger.info(f"[MAIN] Received message from client {client_id} - type: {message_type}")

            if message_type != "analyze":
                logger.warning(f"[MAIN] Unknown message type: {message_type}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": f"Unknown message type: {message_type}"
                })
                continue

            # Verify access token from message
            user_claims = await verify_access_token(data.get("access_token"))
            if not user_claims:
                logger.warning(f"[MAIN] Authentication failed for client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Authentication failed. Please sign in again."
                })
                continue
            logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")

            # Check role: oversight_admin cannot upload/analyze proofs.
            # Only the lookup is guarded. The rejection is sent outside the
            # ``try`` so that a send failure cannot be mistaken for a failed
            # lookup and fall through to running the analysis.
            try:
                read_only = await _is_oversight_admin(user_claims)
            except Exception as role_err:
                # NOTE(review): a failed lookup lets the request proceed
                # (behavior kept from the original, presumably so analysis
                # still works without a database) - confirm this is intended.
                logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}")
                read_only = False

            if read_only:
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Oversight Admin has read-only access and cannot analyze proofs."
                })
                continue

            if analysis_service is None:
                logger.error("[MAIN] Analysis service not ready")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Backend not ready. Please wait for initialization."
                })
                continue

            # Handle the analysis request
            await handle_analyze_message(
                websocket=websocket,
                client_id=client_id,
                data=data,
                manager=manager,
                analysis_service=analysis_service,
            )
    except WebSocketDisconnect:
        logger.info(f"[MAIN] Client {client_id} disconnected")
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"[MAIN] Error for client {client_id}: {str(e)}")
        try:
            # NOTE(review): this forwards the raw exception text to the client.
            await manager.send_message(client_id, {
                "type": "error",
                "message": str(e)
            })
        except Exception as send_err:
            # The socket may already be unusable; cleanup below still runs.
            logger.warning(f"[MAIN] Could not report error to client {client_id}: {send_err}")
    finally:
        # Single cleanup point: runs on disconnect, on errors and on cancellation.
        manager.disconnect(client_id)