When a client disconnects (navigates away, closes tab) while analysis is still running, the result send raises RuntimeError "WebSocket is not connected". Catch this specifically as INFO rather than ERROR, and guard the fallback send_message in the general Exception handler so it doesn't raise a second uncaught error. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
257 lines
9.9 KiB
Python
Executable file
257 lines
9.9 KiB
Python
Executable file
import logging
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from app.config import settings
|
|
from app.services.auth_service import verify_access_token
|
|
from app.dependencies.auth import get_current_user
|
|
from app.models.database import init_db, close_db, async_session_factory as _session_factory
|
|
from app.repositories.user_repository import UserRepository
|
|
from app.api import router as api_router, kb_router
|
|
|
|
# Configure logging for the whole process at import time: INFO and above,
# rendered as "<timestamp> - <logger name> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Logger used by everything defined in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HealthCheckFilter(logging.Filter):
    """Logging filter that drops records produced by health check requests."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Decide whether *record* should be emitted.

        Returns False (suppress) when the rendered message contains
        ``GET /health``, and True (keep) for every other record.
        """
        return "GET /health" not in record.getMessage()
|
|
|
|
|
|
# Keep health check requests out of the uvicorn access log by attaching the
# filter to the "uvicorn.access" logger.
_uvicorn_access_logger = logging.getLogger("uvicorn.access")
_uvicorn_access_logger.addFilter(HealthCheckFilter())
|
|
from app.websocket.manager import ConnectionManager
|
|
from app.websocket.handlers import handle_analyze_message
|
|
from app.services.gemini_service import GeminiService
|
|
from app.services.reference_docs import ReferenceDocsService
|
|
from app.services.analysis_service import AnalysisService
|
|
from app.services.knowledge_base_service import KnowledgeBaseService
|
|
|
|
|
|
# Global services shared by the HTTP and WebSocket handlers in this module.
# The two service slots start out as None and are assigned by lifespan()
# during startup; knowledge_base_service is only assigned when
# LLAMA_CLOUD_API_KEY is configured.
analysis_service: AnalysisService | None = None
knowledge_base_service: KnowledgeBaseService | None = None

# Registry of active WebSocket connections, created eagerly at import time.
manager = ConnectionManager()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Initialize services on startup and clean up on shutdown.

    Startup validates settings, attempts to initialize the database (and
    carries on without it if that fails), loads reference documents, and
    builds the Gemini, analysis and, when LLAMA_CLOUD_API_KEY is set,
    Knowledge Base services. The analysis and Knowledge Base services are
    published through the module-level globals.

    Shutdown closes the database. The cleanup runs in a ``finally`` block so
    it still happens when the application leaves the lifespan context through
    an exception or a cancellation.
    """
    global analysis_service, knowledge_base_service

    # Validate settings
    settings.validate()

    # Initialize database; a failure here is tolerated and only disables the
    # database-backed spec loading below.
    print("Initializing database connection...")
    db_available = False
    try:
        await init_db()
        print("Database initialized successfully")
        db_available = True
    except Exception as e:
        logger.warning(f"Database initialization failed (may not be available): {e}")
        print("Warning: Database not available - running in stateless mode")

    # Initialize services
    print("Loading reference documents...")
    reference_docs = ReferenceDocsService(settings.REFERENCE_DOCS_PATH)

    # Load specs from DB if database is available
    if db_available:
        try:
            # NOTE(review): imported here instead of reusing the module-level
            # alias - presumably so the name is resolved after init_db() has
            # run; confirm before consolidating.
            from app.models.database import async_session_factory

            async with async_session_factory() as session:
                print("Loading specs from database...")
                await reference_docs.load_specs_from_db(session)
        except Exception as e:
            logger.warning(f"Failed to load specs from DB (falling back to files): {e}")

    # Log document info
    doc_summary = reference_docs.get_context_summary()
    print(f" Brand documents: {len(doc_summary['brand_files'])} files ({doc_summary['brand_context_length']} chars)")
    print(f" Channel documents: {len(doc_summary['channel_files'])} files ({doc_summary['channel_context_length']} chars)")

    print("Initializing Gemini service...")
    gemini_service = GeminiService(settings.GEMINI_API_KEY)

    print("Initializing analysis service...")
    analysis_service = AnalysisService(gemini_service, reference_docs)

    # Initialize Knowledge Base service (requires LlamaParse API key)
    if settings.LLAMA_CLOUD_API_KEY:
        from app.services.llamaparse_service import LlamaParseService

        print("Initializing LlamaParse service...")
        llamaparse_service = LlamaParseService(settings.LLAMA_CLOUD_API_KEY)
        knowledge_base_service = KnowledgeBaseService(llamaparse_service, gemini_service, reference_docs)
        print("Knowledge Base pipeline ready!")
    else:
        print("LLAMA_CLOUD_API_KEY not set - Knowledge Base processing pipeline disabled")

    print("Backend ready!")

    try:
        yield
    finally:
        # Cleanup on shutdown - must run even if an exception or cancellation
        # is thrown into the generator at the yield above.
        print("Shutting down...")
        await close_db()
|
|
|
|
|
|
# Create FastAPI app
app = FastAPI(
    title="ModComms Proof Review API",
    description="AI-powered proof review backend for Barclays marketing materials",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware - allow frontend to connect.
# CORS_ORIGINS is a comma-separated string. Entries are stripped and empty
# ones dropped, so a value such as "https://a.example, https://b.example"
# does not produce an origin with a leading space that can never match.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in settings.CORS_ORIGINS.split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include API routes
app.include_router(api_router, prefix="/api")
app.include_router(kb_router, prefix="/api")
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint: returns a fixed status payload."""
    payload = dict(status="healthy", service="modcomms-backend")
    return payload
|
|
|
|
|
|
@app.get("/info")
async def info(user: dict = Depends(get_current_user)):
    """
    Get backend information. Requires authentication.

    Returns an "initializing" status while the analysis service has not been
    set, otherwise a "ready" status with the agent list and a summary of the
    loaded reference documents. Both responses include the caller's name.
    """
    display_name = user.get("name", "Unknown")

    # The analysis service is still unset until startup has assigned it.
    if not analysis_service:
        return {"status": "initializing", "user": display_name}

    return {
        "status": "ready",
        "user": display_name,
        "agents": ["Legal Agent", "Brand Agent", "Channel Best Practices Agent", "Channel Tech Specs Agent"],
        "reference_docs": analysis_service.reference_docs.get_context_summary(),
    }
|
|
|
|
|
|
@app.websocket("/ws/analyze")
async def websocket_analyze(websocket: WebSocket):
    """
    WebSocket endpoint for proof analysis with real-time updates.

    Protocol:
    - Client sends: {"type": "analyze", "file_data": "<base64>", "file_type": "image/png", "is_wip": false, "access_token": "<jwt>"}
    - Server verifies token before processing
    - Server sends: {"type": "agent_started", "agent_name": "..."}
    - Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}}
    - Server sends: {"type": "complete", "result": {...}}
    - On error: {"type": "error", "message": "..."}

    The connection is registered with the connection manager before the
    receive loop starts and is deregistered in a ``finally`` block, so it is
    released on client disconnect, on any error, and on task cancellation.
    """
    client_id = str(uuid.uuid4())
    logger.info(f"[MAIN] WebSocket connection established - client_id: {client_id}")
    await manager.connect(websocket, client_id)

    try:
        while True:
            # Wait for a message from the client
            data = await websocket.receive_json()

            # Valid JSON is not necessarily an object; reject anything else
            # here instead of failing on data.get() and dropping the socket.
            if not isinstance(data, dict):
                logger.warning(f"[MAIN] Non-object message from client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Invalid message: expected a JSON object."
                })
                continue

            message_type = data.get("type")
            logger.info(f"[MAIN] Received message from client {client_id} - type: {message_type}")

            if message_type != "analyze":
                logger.warning(f"[MAIN] Unknown message type: {message_type}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": f"Unknown message type: {message_type}"
                })
                continue

            # Verify access token from message
            access_token = data.get("access_token")
            user_claims = await verify_access_token(access_token)

            if not user_claims:
                logger.warning(f"[MAIN] Authentication failed for client {client_id}")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Authentication failed. Please sign in again."
                })
                continue

            logger.info(f"[MAIN] Authenticated user: {user_claims.get('name', 'unknown')}")

            # Check role: oversight_admin cannot upload/analyze proofs.
            # Only the lookup runs inside the try block. A failed lookup is
            # logged and the request proceeds without a user id.
            # NOTE(review): this fails open - presumably to match the
            # database-less "stateless mode"; confirm that is intended.
            current_user_id: uuid.UUID | None = None
            is_read_only = False
            try:
                async with _session_factory() as ws_session:
                    ws_user_repo = UserRepository(ws_session)
                    azure_oid = user_claims.get("oid") or user_claims.get("sub")
                    ws_user = await ws_user_repo.get_by_azure_oid(azure_oid) if azure_oid else None
                    if ws_user and ws_user.role == "oversight_admin":
                        is_read_only = True
                    else:
                        current_user_id = ws_user.id if ws_user else None
            except Exception as role_err:
                logger.warning(f"[MAIN] Role check failed for client {client_id}: {role_err}")

            # The rejection is sent outside the try block above so that a
            # failing send cannot be mistaken for a failed role check and
            # let a read-only user fall through to the analysis.
            if is_read_only:
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Oversight Admin has read-only access and cannot analyze proofs."
                })
                continue

            if analysis_service is None:
                logger.error("[MAIN] Analysis service not ready")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Backend not ready. Please wait for initialization."
                })
                continue

            # Handle the analysis request
            await handle_analyze_message(
                websocket=websocket,
                client_id=client_id,
                data=data,
                manager=manager,
                analysis_service=analysis_service,
                current_user_id=current_user_id,
            )

    except WebSocketDisconnect:
        logger.info(f"[MAIN] Client {client_id} disconnected")
    except RuntimeError as e:
        # Client disconnected mid-analysis (e.g. navigated away before result arrived)
        reason = str(e).lower()
        if "not connected" in reason or "websocket" in reason:
            logger.info(f"[MAIN] Client {client_id} disconnected before result was sent")
        else:
            logger.error(f"[MAIN] RuntimeError for client {client_id}: {str(e)}")
    except Exception as e:
        # Include the traceback; the message alone rarely identifies the cause.
        logger.error(f"[MAIN] Error for client {client_id}: {str(e)}", exc_info=True)
        try:
            await manager.send_message(client_id, {
                "type": "error",
                "message": str(e)
            })
        except Exception as send_err:
            # Best effort only: the client may already be gone.
            logger.debug(f"[MAIN] Could not report error to client {client_id}: {send_err}")
    finally:
        # Single exit point: also covers cancellation, which the except
        # clauses above do not catch.
        manager.disconnect(client_id)
|