""" WebSocket connection and message management """ import asyncio import json import logging from datetime import datetime from typing import Dict, Set, Any, Optional import uuid from weakref import WeakSet from quart import websocket from ..config_runtime import server_config logger = logging.getLogger(__name__) class WebSocketClient: """Represents a connected WebSocket client""" def __init__(self, client_id: str, user_id: Optional[str] = None): self.client_id = client_id self.user_id = user_id or 'anonymous' self.connected_at = datetime.utcnow() self.last_ping = datetime.utcnow() self.websocket = websocket._get_current_object() async def send(self, message: Dict[str, Any]): """Send a message to this client""" try: await self.websocket.send(json.dumps(message)) except Exception as e: logger.warning(f"Failed to send message to client {self.client_id}: {e}") raise async def ping(self): """Send ping to client""" try: await self.send({'type': 'ping', 'timestamp': datetime.utcnow().isoformat()}) self.last_ping = datetime.utcnow() except Exception as e: logger.warning(f"Failed to ping client {self.client_id}: {e}") raise class WebSocketManager: """ Manages WebSocket connections and broadcasts Singleton for coordinating real-time updates """ _instance: Optional['WebSocketManager'] = None def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): if hasattr(self, '_initialized'): return self._initialized = True self.clients: Dict[str, WebSocketClient] = {} self._lock = asyncio.Lock() # Start background tasks self.ping_task = None self.cleanup_task = None logger.info("WebSocketManager initialized") async def start_background_tasks(self): """Start background maintenance tasks""" if not self.ping_task: self.ping_task = asyncio.create_task(self._ping_clients_loop()) if not self.cleanup_task: self.cleanup_task = asyncio.create_task(self._cleanup_disconnected_loop()) async def stop_background_tasks(self): """Stop background maintenance tasks""" if self.ping_task: self.ping_task.cancel() try: await self.ping_task except asyncio.CancelledError: pass if self.cleanup_task: self.cleanup_task.cancel() try: await self.cleanup_task except asyncio.CancelledError: pass async def register_client(self, user_id: Optional[str] = None) -> WebSocketClient: """ Register a new WebSocket client Args: user_id: User identifier (optional for dev mode) Returns: WebSocketClient instance """ client_id = str(uuid.uuid4()) client = WebSocketClient(client_id, user_id) async with self._lock: self.clients[client_id] = client logger.info(f"Registered WebSocket client {client_id} for user {user_id}") # Send initial connection acknowledgment await client.send({ 'type': 'connection.established', 'clientId': client_id, 'userId': user_id, 'connectedAt': client.connected_at.isoformat() }) return client async def unregister_client(self, client_id: str): """ Unregister a WebSocket client Args: client_id: Client identifier """ async with self._lock: if client_id in self.clients: client = self.clients.pop(client_id) logger.info(f"Unregistered WebSocket client {client_id} for user {client.user_id}") async def broadcast_to_all(self, message: Dict[str, Any]): """ Broadcast message to all connected clients Args: message: Message to broadcast """ if not self.clients: return # Add timestamp to message message['timestamp'] = datetime.utcnow().isoformat() async with self._lock: clients_to_remove = [] for client_id, client in self.clients.items(): try: await client.send(message) except Exception as e: logger.warning(f"Failed to send to client {client_id}: {e}") clients_to_remove.append(client_id) # Remove failed clients for client_id in clients_to_remove: self.clients.pop(client_id, None) async def broadcast_to_user(self, user_id: str, message: Dict[str, Any]): """ Broadcast message to all connections for a specific user Args: user_id: User identifier message: Message to broadcast """ if not self.clients: return # Add timestamp to message message['timestamp'] = datetime.utcnow().isoformat() async with self._lock: clients_to_remove = [] sent_count = 0 for client_id, client in self.clients.items(): if client.user_id == user_id: try: await client.send(message) sent_count += 1 except Exception as e: logger.warning(f"Failed to send to client {client_id}: {e}") clients_to_remove.append(client_id) # Remove failed clients for client_id in clients_to_remove: self.clients.pop(client_id, None) if sent_count > 0: logger.debug(f"Broadcast message to {sent_count} clients for user {user_id}") async def broadcast_job_update(self, job_id: str, message: Dict[str, Any]): """ Broadcast job-specific update Args: job_id: Job identifier message: Message to broadcast """ # For now, broadcast to all clients # In the future, we could implement job-specific subscriptions message['jobId'] = job_id await self.broadcast_to_all(message) async def send_queue_snapshot(self, client: WebSocketClient, jobs_data: list): """ Send initial queue snapshot to a client Args: client: WebSocket client jobs_data: Serialized jobs data """ try: await client.send({ 'type': 'queue.snapshot', 'jobs': jobs_data }) logger.debug(f"Sent queue snapshot to client {client.client_id}") except Exception as e: logger.error(f"Failed to send queue snapshot to {client.client_id}: {e}") raise async def get_connection_stats(self) -> Dict[str, Any]: """ Get WebSocket connection statistics Returns: Statistics dictionary """ async with self._lock: user_counts = {} for client in self.clients.values(): user_counts[client.user_id] = user_counts.get(client.user_id, 0) + 1 return { 'total_connections': len(self.clients), 'unique_users': len(user_counts), 'connections_per_user': user_counts, 'uptime_seconds': (datetime.utcnow() - min((c.connected_at for c in self.clients.values()), default=datetime.utcnow())).total_seconds() } async def _ping_clients_loop(self): """Background task to ping clients periodically""" while True: try: await asyncio.sleep(server_config.WS_PING_INTERVAL_SECONDS) async with self._lock: clients_to_remove = [] for client_id, client in self.clients.items(): try: await client.ping() except Exception: clients_to_remove.append(client_id) # Remove failed clients for client_id in clients_to_remove: self.clients.pop(client_id, None) logger.debug(f"Removed unresponsive client {client_id}") except asyncio.CancelledError: break except Exception as e: logger.error(f"Error in ping loop: {e}") async def _cleanup_disconnected_loop(self): """Background task to clean up disconnected clients""" while True: try: await asyncio.sleep(60) # Check every minute async with self._lock: # Clean up clients that haven't been pinged recently cutoff = datetime.utcnow().timestamp() - (server_config.WS_PING_INTERVAL_SECONDS * 3) clients_to_remove = [] for client_id, client in self.clients.items(): if client.last_ping.timestamp() < cutoff: clients_to_remove.append(client_id) for client_id in clients_to_remove: self.clients.pop(client_id, None) logger.debug(f"Cleaned up stale client {client_id}") except asyncio.CancelledError: break except Exception as e: logger.error(f"Error in cleanup loop: {e}") # Global instance ws_manager = WebSocketManager()