import logging
import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware

from app.config import settings
from app.services.auth_service import verify_access_token
from app.dependencies.auth import get_current_user
from app.models.database import init_db, close_db, async_session_factory as _session_factory
from app.repositories.user_repository import UserRepository
from app.api import router as api_router, kb_router

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


class HealthCheckFilter(logging.Filter):
    """Filter out health check endpoint logs from uvicorn access log."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False (drop the record) when it logs a ``GET /health`` request."""
        return "GET /health" not in record.getMessage()


# Filter out health check logs from uvicorn access log
logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())

# NOTE(review): these imports sit below the logging setup; kept in place in case
# import order matters for these modules — confirm before hoisting them.
from app.websocket.manager import ConnectionManager  # noqa: E402
from app.websocket.handlers import handle_analyze_message  # noqa: E402
from app.services.gemini_service import GeminiService  # noqa: E402
from app.services.reference_docs import ReferenceDocsService  # noqa: E402
from app.services.analysis_service import AnalysisService  # noqa: E402
from app.services.knowledge_base_service import KnowledgeBaseService  # noqa: E402

# Global services - initialized at startup by ``lifespan``
manager = ConnectionManager()
analysis_service: AnalysisService | None = None
knowledge_base_service: KnowledgeBaseService | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Initialize services on startup and cleanup on shutdown.

    Loads reference documents and initializes the analysis service.
    The database is optional: if ``init_db()`` fails the backend keeps running
    in stateless mode and reference specs come from files only. The Knowledge
    Base pipeline is only built when ``LLAMA_CLOUD_API_KEY`` is set.

    Shutdown cleanup runs in a ``finally`` so ``close_db()`` is awaited even if
    the application exits with an error.
    """
    global analysis_service, knowledge_base_service

    # Validate settings
    settings.validate()

    # Initialize database
    print("Initializing database connection...")
    db_available = False
    try:
        await init_db()
        print("Database initialized successfully")
        db_available = True
    except Exception as e:
        logger.warning(f"Database initialization failed (may not be available): {e}")
        print("Warning: Database not available - running in stateless mode")

    # Initialize services
    print("Loading reference documents...")
    reference_docs = ReferenceDocsService(settings.REFERENCE_DOCS_PATH)

    # Load specs from DB if database is available
    if db_available:
        try:
            # Imported here, after init_db(), rather than using the module-level
            # alias — presumably so a factory rebound by init_db() is picked up.
            from app.models.database import async_session_factory

            async with async_session_factory() as session:
                print("Loading specs from database...")
                await reference_docs.load_specs_from_db(session)
        except Exception as e:
            logger.warning(f"Failed to load specs from DB (falling back to files): {e}")

    # Log document info
    doc_summary = reference_docs.get_context_summary()
    print(f" Brand documents: {len(doc_summary['brand_files'])} files ({doc_summary['brand_context_length']} chars)")
    print(f" Channel documents: {len(doc_summary['channel_files'])} files ({doc_summary['channel_context_length']} chars)")

    print("Initializing Gemini service...")
    gemini_service = GeminiService(settings.GEMINI_API_KEY)

    print("Initializing analysis service...")
    analysis_service = AnalysisService(gemini_service, reference_docs)

    # Initialize Knowledge Base service (requires LlamaParse API key)
    if settings.LLAMA_CLOUD_API_KEY:
        from app.services.llamaparse_service import LlamaParseService

        print("Initializing LlamaParse service...")
        llamaparse_service = LlamaParseService(settings.LLAMA_CLOUD_API_KEY)
        knowledge_base_service = KnowledgeBaseService(llamaparse_service, gemini_service, reference_docs)
        print("Knowledge Base pipeline ready!")
    else:
        print("LLAMA_CLOUD_API_KEY not set - Knowledge Base processing pipeline disabled")

    print("Backend ready!")
    try:
        yield
    finally:
        # Cleanup on shutdown
        print("Shutting down...")
        await close_db()


# Create FastAPI app
app = FastAPI(
    title="ModComms Proof Review API",
    description="AI-powered proof review backend for Barclays marketing materials",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware - allow frontend to connect.
# Entries are whitespace-trimmed and blanks dropped so values like "a, b" or a
# trailing comma still produce exact origin strings for matching.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[origin.strip() for origin in settings.CORS_ORIGINS.split(",") if origin.strip()],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include API routes
app.include_router(api_router, prefix="/api")
app.include_router(kb_router, prefix="/api")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "service": "modcomms-backend"}


@app.get("/info")
async def info(user: dict = Depends(get_current_user)):
    """Get backend information. Requires authentication.

    Returns ``status: "ready"`` with agent and reference-doc details once
    ``lifespan`` has created the analysis service, otherwise
    ``status: "initializing"``.
    """
    if analysis_service:
        ref_docs = analysis_service.reference_docs
        doc_summary = ref_docs.get_context_summary()
        return {
            "status": "ready",
            "user": user.get("name", "Unknown"),
            "agents": ["Legal Agent", "Brand Agent", "Channel Best Practices Agent", "Channel Tech Specs Agent"],
            "reference_docs": doc_summary,
        }
    return {"status": "initializing", "user": user.get("name", "Unknown")}


@app.websocket("/ws/analyze")
async def websocket_analyze(websocket: WebSocket):
    """
    WebSocket endpoint for proof analysis with real-time updates.

    Protocol:
    - Client sends: {"type": "analyze", "file_data": "", "file_type": "image/png", "is_wip": false, "access_token": ""}
    - Server verifies token before processing
    - Server sends: {"type": "agent_started", "agent_name": "..."}
    - Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}}
    - Server sends: {"type": "complete", "result": {...}}
    - On error: {"type": "error", "message": "..."}

    The connection is always removed from ``manager`` when this handler exits,
    whether by client disconnect, an error, or task cancellation.
    """
    client_id = str(uuid.uuid4())
    logger.info(f"[MAIN] WebSocket connection established - client_id: {client_id}")
    await manager.connect(websocket, client_id)

    try:
        while True:
            # Wait for a message from the client
            data = await websocket.receive_json()

            # receive_json accepts any JSON value; only objects carry a "type".
            if not isinstance(data, dict):
                logger.warning(f"[MAIN] Non-object message from client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Invalid message format: expected a JSON object.",
                })
                continue

            msg_type = data.get("type")
            logger.info(f"[MAIN] Received message from client {client_id} - type: {msg_type}")

            if msg_type != "analyze":
                logger.warning(f"[MAIN] Unknown message type: {msg_type}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}",
                })
                continue

            # Verify access token from message; a missing/empty token is an
            # authentication failure and is not passed to the verifier.
            access_token = data.get("access_token")
            user_claims = await verify_access_token(access_token) if access_token else None
            if not user_claims:
                logger.warning(f"[MAIN] Authentication failed for client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Authentication failed. Please sign in again."
                })
                continue

            logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")

            # Resolve the application user ID for the token's Azure object id
            # ("oid", falling back to "sub"). Best-effort: any failure (e.g. the
            # database is unavailable) leaves current_user_id as None.
            # NOTE(review): an earlier comment said "oversight_admin cannot
            # upload/analyze proofs", but no role is enforced here — confirm
            # whether a role check is missing.
            current_user_id: uuid.UUID | None = None
            try:
                # Resolved at call time (as lifespan does) in case init_db()
                # rebinds the module-level factory after import.
                from app.models.database import async_session_factory

                async with async_session_factory() as ws_session:
                    ws_user_repo = UserRepository(ws_session)
                    azure_oid = user_claims.get("oid") or user_claims.get("sub")
                    ws_user = await ws_user_repo.get_by_azure_oid(azure_oid) if azure_oid else None
                    current_user_id = ws_user.id if ws_user else None
            except Exception as role_err:
                logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}")

            if analysis_service is None:
                logger.error("[MAIN] Analysis service not ready")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Backend not ready. Please wait for initialization."
                })
                continue

            # Handle the analysis request
            await handle_analyze_message(
                websocket=websocket,
                client_id=client_id,
                data=data,
                manager=manager,
                analysis_service=analysis_service,
                current_user_id=current_user_id,
            )

    except WebSocketDisconnect:
        logger.info(f"[MAIN] Client {client_id} disconnected")
    except RuntimeError as e:
        # Client disconnected mid-analysis (e.g. navigated away before result arrived)
        err_text = str(e).lower()
        if "not connected" in err_text or "websocket" in err_text:
            logger.info(f"[MAIN] Client {client_id} disconnected before result was sent")
        else:
            logger.exception(f"[MAIN] RuntimeError for client {client_id}: {e}")
    except Exception as e:
        # Full details go to the server log; the client gets a generic message so
        # internal exception text (DB errors, paths, ...) is not disclosed.
        logger.exception(f"[MAIN] Error for client {client_id}: {e}")
        try:
            await manager.send_message(client_id, {
                "type": "error",
                "message": "An unexpected error occurred. Please try again."
            })
        except Exception:
            # Best-effort notification: the socket may already be closed.
            logger.debug(f"[MAIN] Could not deliver error to client {client_id}")
    finally:
        manager.disconnect(client_id)