157 lines
5.1 KiB
Python
157 lines
5.1 KiB
Python
import logging
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from app.config import settings
|
|
|
|
# Configure the root logger before the application modules below are imported
# (presumably so any import-time log output from them uses the same format).
# Note: logging.basicConfig is a no-op if the root logger already has handlers.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    datefmt=_LOG_DATE_FORMAT,
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Module logger; records propagate to the root handlers configured above.
logger = logging.getLogger(__name__)
|
|
from app.websocket.manager import ConnectionManager
|
|
from app.websocket.handlers import handle_analyze_message
|
|
from app.services.gemini_service import GeminiService
|
|
from app.services.reference_docs import ReferenceDocsService
|
|
from app.services.analysis_service import AnalysisService
|
|
|
|
|
|
# Module-level service handles shared by the endpoints in this file.
# `analysis_service` stays None until `lifespan` has finished startup, so
# handlers must check it before use.
analysis_service: AnalysisService | None = None

# Registry of live WebSocket clients, addressed by the per-connection client_id.
manager = ConnectionManager()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize services on startup and clean up on shutdown.

    Validates settings, loads the reference documents, builds the
    GeminiService and the AnalysisService, and publishes the latter through
    the module-level ``analysis_service`` global for request handlers.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used directly here).

    Yields:
        Control back to the server for the lifetime of the application.
    """
    global analysis_service

    # Validate configuration before any expensive initialization
    # (presumably raises on invalid settings — see settings.validate).
    settings.validate()

    # Progress goes through the configured logger rather than print() so it
    # gets the module's timestamp/level formatting and can be filtered.
    logger.info("Loading reference documents...")
    reference_docs = ReferenceDocsService(settings.REFERENCE_DOCS_PATH)

    # Log document info
    doc_summary = reference_docs.get_context_summary()
    logger.info(
        "  Brand documents: %d files (%s chars)",
        len(doc_summary["brand_files"]),
        doc_summary["brand_context_length"],
    )
    logger.info(
        "  Channel documents: %d files (%s chars)",
        len(doc_summary["channel_files"]),
        doc_summary["channel_context_length"],
    )

    logger.info("Initializing Gemini service...")
    gemini_service = GeminiService(settings.GEMINI_API_KEY)

    logger.info("Initializing analysis service...")
    analysis_service = AnalysisService(gemini_service, reference_docs)

    logger.info("Backend ready!")

    try:
        yield
    finally:
        # Runs on normal shutdown and also when the lifespan context is
        # exited with an exception (e.g. task cancellation).
        logger.info("Shutting down...")
|
|
|
|
|
|
# The ASGI application; startup and shutdown are driven by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="ModComms Proof Review API",
    description="AI-powered proof review backend for Barclays marketing materials",
)
|
|
|
|
# CORS middleware - allow frontend to connect.
# CORS_ORIGINS is a comma-separated string. Each entry is stripped and empty
# entries are dropped: Starlette compares the request Origin by exact string
# match, so a value like "http://a.com, http://b.com" would otherwise yield
# " http://b.com" and silently reject the second origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in settings.CORS_ORIGINS.split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe.

    Returns a constant payload; it does not consult any service state.
    """
    return dict(status="healthy", service="modcomms-backend")
|
|
|
|
|
|
@app.get("/info")
async def info():
    """Get backend information.

    Returns:
        ``{"status": "initializing"}`` until ``lifespan`` has created the
        analysis service; afterwards ``{"status": "ready", "agents": [...],
        "reference_docs": <reference-doc context summary>}``.
    """
    # Explicit None check instead of truthiness: a service object that
    # defines __bool__/__len__ must not be misreported as initializing, and
    # this matches the readiness check in the WebSocket endpoint.
    if analysis_service is None:
        return {"status": "initializing"}

    doc_summary = analysis_service.reference_docs.get_context_summary()
    return {
        "status": "ready",
        "agents": ["Legal Agent", "Brand Agent", "Tone Agent", "Channel Agent"],
        "reference_docs": doc_summary,
    }
|
|
|
|
|
|
@app.websocket("/ws/analyze")
async def websocket_analyze(websocket: WebSocket):
    """
    WebSocket endpoint for proof analysis with real-time updates.

    Protocol:
    - Client sends: {"type": "analyze", "file_data": "<base64>", "file_type": "image/png", "is_wip": false}
    - Server sends: {"type": "agent_started", "agent_name": "..."}
    - Server sends: {"type": "agent_completed", "agent_name": "...", "review": {...}}
    - Server sends: {"type": "complete", "result": {...}}
    - On error: {"type": "error", "message": "..."}

    A malformed frame (invalid JSON, or JSON that is not an object) gets an
    error reply and the session continues. Any other exception ends the
    session after a best-effort error reply. The client is always
    deregistered from the connection manager when the session ends.
    """
    client_id = str(uuid.uuid4())
    await manager.connect(websocket, client_id)
    logger.info("[MAIN] WebSocket connection established - client_id: %s", client_id)

    try:
        while True:
            # Wait for a message from the client. Invalid JSON surfaces as a
            # ValueError (json.JSONDecodeError); the socket itself is still
            # usable, so report it and keep serving this client.
            try:
                data = await websocket.receive_json()
            except ValueError:
                logger.warning("[MAIN] Invalid JSON from client %s", client_id)
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Invalid message: expected JSON",
                })
                continue

            if not isinstance(data, dict):
                logger.warning("[MAIN] Non-object message from client %s", client_id)
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Invalid message: expected a JSON object",
                })
                continue

            msg_type = data.get("type")
            logger.info("[MAIN] Received message from client %s - type: %s", client_id, msg_type)

            if msg_type != "analyze":
                logger.warning("[MAIN] Unknown message type: %s", msg_type)
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}"
                })
                continue

            if analysis_service is None:
                logger.error("[MAIN] Analysis service not ready")
                await manager.send_message(client_id, {
                    "type": "error",
                    "message": "Backend not ready. Please wait for initialization."
                })
                continue

            # Handle the analysis request
            await handle_analyze_message(
                websocket=websocket,
                client_id=client_id,
                data=data,
                manager=manager,
                analysis_service=analysis_service,
            )

    except WebSocketDisconnect:
        logger.info("[MAIN] Client %s disconnected", client_id)
    except Exception as e:
        # logger.exception records the traceback, not just str(e).
        logger.exception("[MAIN] Error for client %s: %s", client_id, e)
        try:
            await manager.send_message(client_id, {
                "type": "error",
                "message": str(e)
            })
        except Exception:
            # The socket may already be unusable (possibly the very reason we
            # got here); failing to report must not skip the cleanup below.
            logger.warning(
                "[MAIN] Could not send error to client %s", client_id, exc_info=True
            )
    finally:
        # Deregister on every exit path (disconnect, error, cancellation) so
        # the manager never keeps a stale entry for this client.
        manager.disconnect(client_id)
|