import json
import logging
from typing import Any

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manage WebSocket connections grouped by job_id."""

    def __init__(self) -> None:
        # job_id -> websockets currently registered for that job.
        self._connections: dict[str, list[WebSocket]] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """Accept a websocket connection and register it for a job.

        Args:
            job_id: Identifier of the job the connection subscribes to.
            websocket: The incoming websocket; ``accept()`` is awaited on it
                before it is registered.
        """
        await websocket.accept()
        self._connections.setdefault(job_id, []).append(websocket)

    def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Remove a websocket connection from a job's connection list.

        Unknown job ids and websockets that are not registered are ignored.
        The job's entry is removed entirely once no connections remain.

        Args:
            job_id: Identifier of the job the connection was registered for.
            websocket: The websocket to remove (matched by identity).
        """
        connections = self._connections.get(job_id)
        if connections is None:
            return
        # Match by identity: connection objects may define value-based
        # equality (Starlette connections are Mappings), which must not cause
        # an unrelated socket to be removed.
        remaining = [ws for ws in connections if ws is not websocket]
        if remaining:
            self._connections[job_id] = remaining
        else:
            del self._connections[job_id]

    async def broadcast(self, job_id: str, message: dict[str, Any]) -> None:
        """Broadcast a JSON message to all connections for a job.

        Connections whose send raises are treated as dead and unregistered;
        a failing peer never aborts delivery to the remaining connections.

        Args:
            job_id: Identifier of the job whose connections receive the message.
            message: JSON-serializable payload.

        Raises:
            TypeError: If ``message`` is not JSON-serializable (raised by
                ``json.dumps`` before anything is sent).
        """
        connections = self._connections.get(job_id)
        if not connections:
            return

        payload = json.dumps(message)
        disconnected: list[WebSocket] = []
        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # where connect()/disconnect() may mutate the registered list.
        for websocket in list(connections):
            try:
                await websocket.send_text(payload)
            except Exception:
                # Best-effort delivery: a failed send means the peer is gone.
                # Record it for cleanup instead of failing the whole broadcast.
                logger.debug(
                    "Dropping websocket for job %s after failed send",
                    job_id,
                    exc_info=True,
                )
                disconnected.append(websocket)

        # Clean up dead connections.
        for ws in disconnected:
            self.disconnect(job_id, ws)

    async def send_personal(
        self, websocket: WebSocket, message: dict[str, Any]
    ) -> None:
        """Send a JSON message to a specific websocket.

        Any exception raised by ``json.dumps`` or ``send_text`` propagates
        to the caller.
        """
        await websocket.send_text(json.dumps(message))

    def get_connection_count(self, job_id: str) -> int:
        """Return the number of registered connections for a job (0 if none)."""
        return len(self._connections.get(job_id, ()))


# Singleton instance
manager = ConnectionManager()