modcomms/backend/app/main.py
Vadym Samoilenko 5c338c31fb Fix WebSocket connection dropped during long proof analysis
- Add 25s heartbeat ping from backend to prevent Apache/proxy idle-timeout
  killing the connection during 1-3 min analysis runs
- Handle heartbeat silently in both analyzeProof and analyzeWIPProof frontend handlers
- Run PDF rasterization via asyncio.to_thread so heartbeats aren't blocked
- Wrap analyze_proof with asyncio.wait_for(timeout=300) for a hard 5-min cap
- Log dropped send_message calls in ConnectionManager instead of swallowing silently
- cloudrun.yaml: add sessionAffinity, startup probe, raise containerConcurrency 4→10,
  document DISABLE_AUTH option

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-18 11:23:59 +00:00

269 lines
10 KiB
Python
Executable file

import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.services.auth_service import verify_access_token
from app.dependencies.auth import get_current_user
from app.models.database import init_db, close_db, async_session_factory as _session_factory
from app.repositories.user_repository import UserRepository
from app.api import router as api_router, kb_router
# Configure logging at import time: INFO level, with every record prefixed
# by a timestamp, the logger name, and the level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
class HealthCheckFilter(logging.Filter):
    """Filter out health check endpoint logs from uvicorn access log.

    Only requests for the exact ``/health`` path (optionally followed by a
    query string) are suppressed. A plain substring test would also hide
    unrelated paths that merely start with ``/health`` (for example
    ``/healthcare``), so the character following the path is checked.
    """

    # Method + path that identifies a health-check request line.
    _REQUEST_PREFIX = "GET /health"
    # Characters that may directly follow the path in a request line:
    # a space (before "HTTP/x"), a query-string marker, or a closing quote.
    _PATH_TERMINATORS = (" ", "?", '"')

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False (drop the record) for health-check requests, else True."""
        message = record.getMessage()
        prefix = self._REQUEST_PREFIX
        start = message.find(prefix)
        while start != -1:
            end = start + len(prefix)
            # The path must end here: either the message ends or a
            # terminator follows. Otherwise it is a longer, different path.
            if end == len(message) or message[end] in self._PATH_TERMINATORS:
                return False
            start = message.find(prefix, end)
        return True
# Filter out health check logs from uvicorn access log by attaching the
# filter to the "uvicorn.access" logger at import time.
logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter())
from app.websocket.manager import ConnectionManager
from app.websocket.handlers import handle_analyze_message
from app.services.gemini_service import GeminiService
from app.services.reference_docs import ReferenceDocsService
from app.services.analysis_service import AnalysisService
from app.services.knowledge_base_service import KnowledgeBaseService
# Global services - initialized at startup.
# `manager` is created at import time; the two service handles stay None
# until `lifespan` assigns them (`knowledge_base_service` only when
# LLAMA_CLOUD_API_KEY is set).
manager = ConnectionManager()
analysis_service: AnalysisService | None = None
knowledge_base_service: KnowledgeBaseService | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Initialize services on startup and clean up on shutdown.

    Startup sequence:
      1. Validate settings.
      2. Try to initialize the database; on failure a warning is logged and
         startup continues without it.
      3. Build the reference-docs service and, when the database came up,
         load specs from it (a failure here is logged, not raised).
      4. Create the Gemini and analysis services and publish the latter via
         the module-level ``analysis_service`` global.
      5. When ``LLAMA_CLOUD_API_KEY`` is set, also create the LlamaParse and
         Knowledge Base services and publish ``knowledge_base_service``.

    Shutdown:
      ``close_db()`` is awaited in a ``finally`` clause so it still runs when
      an exception is thrown into this generator at the ``yield``.
    """
    global analysis_service, knowledge_base_service

    # Validate settings
    settings.validate()

    # Initialize database
    print("Initializing database connection...")
    db_available = False
    try:
        await init_db()
        print("Database initialized successfully")
        db_available = True
    except Exception as e:
        # Deliberate best-effort: the app keeps running without a database.
        logger.warning(f"Database initialization failed (may not be available): {e}")
        print("Warning: Database not available - running in stateless mode")

    # Initialize services
    print("Loading reference documents...")
    reference_docs = ReferenceDocsService(settings.REFERENCE_DOCS_PATH)

    # Load specs from DB if database is available
    if db_available:
        try:
            # Imported here, after init_db(), instead of reusing a module-level
            # alias. NOTE(review): presumably init_db() (re)binds the factory;
            # confirm before hoisting this import.
            from app.models.database import async_session_factory
            async with async_session_factory() as session:
                print("Loading specs from database...")
                await reference_docs.load_specs_from_db(session)
        except Exception as e:
            logger.warning(f"Failed to load specs from DB (falling back to files): {e}")

    # Log document info
    doc_summary = reference_docs.get_context_summary()
    print(f" Brand documents: {len(doc_summary['brand_files'])} files ({doc_summary['brand_context_length']} chars)")
    print(f" Channel documents: {len(doc_summary['channel_files'])} files ({doc_summary['channel_context_length']} chars)")

    print("Initializing Gemini service...")
    gemini_service = GeminiService(settings.GEMINI_API_KEY)

    print("Initializing analysis service...")
    analysis_service = AnalysisService(gemini_service, reference_docs)

    # Initialize Knowledge Base service (requires LlamaParse API key)
    if settings.LLAMA_CLOUD_API_KEY:
        from app.services.llamaparse_service import LlamaParseService
        print("Initializing LlamaParse service...")
        llamaparse_service = LlamaParseService(settings.LLAMA_CLOUD_API_KEY, settings.LLAMA_CLOUD_BASE_URL)
        knowledge_base_service = KnowledgeBaseService(llamaparse_service, gemini_service, reference_docs)
        print("Knowledge Base pipeline ready!")
    else:
        print("LLAMA_CLOUD_API_KEY not set - Knowledge Base processing pipeline disabled")

    print("Backend ready!")
    try:
        yield
    finally:
        # Cleanup on shutdown; runs even if the application errored while
        # this context manager was suspended at the yield.
        print("Shutting down...")
        await close_db()
# Create FastAPI app; `lifespan` runs service startup and shutdown.
app = FastAPI(
    title="ModComms Proof Review API",
    description="AI-powered proof review backend for Barclays marketing materials",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS middleware - allow frontend to connect.
# CORS_ORIGINS is a comma-separated string. Each entry is stripped and empty
# entries are dropped, so a value such as "https://a.example, https://b.example"
# does not produce an origin with a leading space that can never match.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in settings.CORS_ORIGINS.split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Include API routes: both routers are mounted under the shared "/api" prefix.
app.include_router(api_router, prefix="/api")
app.include_router(kb_router, prefix="/api")
@app.get("/health")
async def health_check():
    """Health check endpoint: return a static status payload for this service."""
    payload = {"status": "healthy", "service": "modcomms-backend"}
    return payload
@app.get("/info")
async def info(user: dict = Depends(get_current_user)):
    """
    Report backend status for an authenticated caller.

    Returns an "initializing" payload while the analysis service has not been
    set up yet; otherwise a "ready" payload with the agent names and the
    reference-document summary. Both payloads echo the caller's display name.
    """
    # Guard clause: nothing more to report until startup has published the service.
    if not analysis_service:
        return {"status": "initializing", "user": user.get("name", "Unknown")}

    context_summary = analysis_service.reference_docs.get_context_summary()
    return {
        "status": "ready",
        "user": user.get("name", "Unknown"),
        "agents": ["Legal Agent", "Brand Agent", "Channel Best Practices Agent", "Channel Tech Specs Agent"],
        "reference_docs": context_summary,
    }
@app.websocket("/ws/analyze")
async def websocket_analyze(websocket: WebSocket):
    """
    WebSocket endpoint for proof analysis with real-time updates.

    Protocol:
    - Client sends: {"type": "analyze", "file_data": "<base64>", "file_type": "image/png", "is_wip": false, "access_token": "<jwt>"}
    - Server verifies token before processing
    - While an analysis is running, server sends {"type": "heartbeat"} every 25s
    - Server sends: {"type": "agent_started", "agent_name": "..."}
    - Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}}
    - Server sends: {"type": "complete", "result": {...}}
    - On error: {"type": "error", "message": "..."}

    The connection is registered with the module-level ``manager`` on entry
    and is always unregistered on exit, whichever way the receive loop ends.
    """
    client_id = str(uuid.uuid4())
    logger.info(f"[MAIN] WebSocket connection established - client_id: {client_id}")
    await manager.connect(websocket, client_id)

    # Keepalive interval in seconds. Proxies (Apache/nginx/Cloud Run) default
    # to a 60s idle timeout; pinging every 25s keeps the connection alive.
    heartbeat_interval_s = 25

    async def _heartbeat(ws: WebSocket) -> None:
        """Send a heartbeat frame periodically until cancelled or the send fails."""
        try:
            while True:
                await asyncio.sleep(heartbeat_interval_s)
                await ws.send_json({"type": "heartbeat"})
                logger.debug(f"[MAIN] Heartbeat sent to client {client_id}")
        except asyncio.CancelledError:
            # Normal shutdown path: the analysis finished or failed.
            pass
        except Exception as hb_err:
            logger.debug(f"[MAIN] Heartbeat stopped for client {client_id}: {hb_err}")

    try:
        while True:
            # Wait for a message from the client
            data = await websocket.receive_json()

            # Valid JSON that is not an object (e.g. a list or a number) has
            # no .get(); reject it explicitly instead of crashing the loop.
            if not isinstance(data, dict):
                logger.warning(f"[MAIN] Non-object message from client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Invalid message format: expected a JSON object."
                })
                continue

            logger.info(f"[MAIN] Received message from client {client_id} - type: {data.get('type')}")

            if data.get("type") != "analyze":
                logger.warning(f"[MAIN] Unknown message type: {data.get('type')}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": f"Unknown message type: {data.get('type')}"
                })
                continue

            # Verify access token from message
            access_token = data.get("access_token")
            user_claims = await verify_access_token(access_token)
            if not user_claims:
                logger.warning(f"[MAIN] Authentication failed for client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Authentication failed. Please sign in again."
                })
                continue
            logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")

            # Resolve the application user ID for the authenticated caller.
            # Best-effort: on any failure the analysis proceeds with no user ID.
            # NOTE(review): an earlier comment described an oversight_admin
            # role restriction here, but no role is inspected in this block -
            # confirm whether that is enforced inside handle_analyze_message.
            # NOTE(review): `_session_factory` was bound at import time; if
            # init_db() rebinds the factory this alias may be stale - confirm.
            current_user_id: uuid.UUID | None = None
            try:
                async with _session_factory() as ws_session:
                    ws_user_repo = UserRepository(ws_session)
                    azure_oid = user_claims.get("oid") or user_claims.get("sub")
                    ws_user = await ws_user_repo.get_by_azure_oid(azure_oid) if azure_oid else None
                    current_user_id = ws_user.id if ws_user else None
            except Exception as role_err:
                logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}")

            if analysis_service is None:
                logger.error("[MAIN] Analysis service not ready")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Backend not ready. Please wait for initialization."
                })
                continue

            # Start keepalive heartbeat to prevent proxy idle-timeout for the
            # duration of this analysis only.
            heartbeat_task = asyncio.create_task(_heartbeat(websocket))
            try:
                # Handle the analysis request
                await handle_analyze_message(
                    websocket=websocket,
                    client_id=client_id,
                    data=data,
                    manager=manager,
                    analysis_service=analysis_service,
                    current_user_id=current_user_id,
                )
            finally:
                heartbeat_task.cancel()
    except WebSocketDisconnect:
        logger.info(f"[MAIN] Client {client_id} disconnected")
    except RuntimeError as e:
        # Client disconnected mid-analysis (e.g. navigated away before result arrived)
        reason = str(e).lower()
        if "not connected" in reason or "websocket" in reason:
            logger.info(f"[MAIN] Client {client_id} disconnected before result was sent")
        else:
            logger.error(f"[MAIN] RuntimeError for client {client_id}: {str(e)}")
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"[MAIN] Error for client {client_id}: {str(e)}")
        try:
            await manager.send_message(client_id, {
                "type": "error",
                "message": str(e)
            })
        except Exception as send_err:
            # Best-effort notification: the socket may already be gone.
            logger.debug(f"[MAIN] Could not report error to client {client_id}: {send_err}")
    finally:
        # Always unregister the connection, including on task cancellation,
        # which none of the handlers above catch.
        manager.disconnect(client_id)