""" WebSocket Connection Manager for Real-time Job Status Updates This module provides WebSocket support for broadcasting job status changes in real-time to connected clients. It uses Redis pub/sub for scalable message broadcasting across multiple worker processes. """ import asyncio import json import logging from datetime import datetime from typing import Any, Optional import redis.asyncio as redis from fastapi import WebSocket from pydantic import BaseModel from ..core.config import settings logger = logging.getLogger(__name__) class JobStatusUpdate(BaseModel): """Schema for job status update messages""" job_id: str status: str updated_at: datetime job_title: Optional[str] = None # Job title for better user experience message: Optional[str] = None progress: Optional[int] = None # 0-100 percentage metadata: Optional[dict[str, Any]] = None eligible_users: Optional[set[str]] = None # Pre-computed eligible users class ConnectionManager: """Manages WebSocket connections and Redis pub/sub for job status updates""" def __init__(self): # WebSocket connections by user_id self.user_ws: dict[str, set[WebSocket]] = {} # WebSocket metadata: websocket -> {user_id, jobs, scopes} self.ws_meta: dict[WebSocket, dict[str, Any]] = {} # Job subscriptions: job_id -> set of websockets self.job_ws: dict[str, set[WebSocket]] = {} # Lock for thread safety self.lock = asyncio.Lock() # Redis client for pub/sub self.redis_client: Optional[redis.Redis] = None self.pubsub: Optional[redis.client.PubSub] = None self.subscriber_task: Optional[asyncio.Task] = None async def start(self): """Initialize Redis pub/sub subscriber""" try: self.redis_client = redis.from_url( settings.redis_url, encoding="utf-8", decode_responses=True ) self.pubsub = self.redis_client.pubsub() # Subscribe to job status channels await self.pubsub.subscribe("job_status_updates") # Global channel await self.pubsub.psubscribe("job_status_updates:*") # Pattern for individual job channels # Start background task to handle Redis messages self.subscriber_task = asyncio.create_task(self._redis_subscriber()) logger.info("WebSocket connection manager started") except Exception as e: logger.error(f"Failed to start WebSocket connection manager: {e}") raise async def stop(self): """Cleanup Redis connections""" if self.subscriber_task: self.subscriber_task.cancel() try: await self.subscriber_task except asyncio.CancelledError: pass if self.pubsub: await self.pubsub.unsubscribe() await self.pubsub.punsubscribe() await self.pubsub.aclose() if self.redis_client: await self.redis_client.aclose() logger.info("WebSocket connection manager stopped") async def connect_job_status(self, websocket: WebSocket, user_id: str, job_id: str): """Connect a WebSocket for specific job status updates""" await websocket.accept() async with self.lock: # Add to user connections if user_id not in self.user_ws: self.user_ws[user_id] = set() self.user_ws[user_id].add(websocket) # Initialize/update websocket metadata if websocket not in self.ws_meta: self.ws_meta[websocket] = { "user_id": user_id, "jobs": set(), "scopes": set() } self.ws_meta[websocket]["jobs"].add(job_id) # Add to job subscriptions if job_id not in self.job_ws: self.job_ws[job_id] = set() self.job_ws[job_id].add(websocket) logger.info(f"User {user_id} connected for job {job_id} status updates") # Send initial connection confirmation await self._send_to_websocket(websocket, { "type": "connection_established", "job_id": job_id, "timestamp": datetime.utcnow().isoformat() }) async def connect_job_list(self, websocket: WebSocket, user_id: str): """Connect a WebSocket for job list updates (all jobs for a user)""" await websocket.accept() async with self.lock: # Add to user connections if user_id not in self.user_ws: self.user_ws[user_id] = set() self.user_ws[user_id].add(websocket) # Initialize/update websocket metadata if websocket not in self.ws_meta: self.ws_meta[websocket] = { "user_id": user_id, "jobs": set(), "scopes": set() } self.ws_meta[websocket]["scopes"].add("job_list") logger.info(f"User {user_id} connected for job list updates") # Send initial connection confirmation await self._send_to_websocket(websocket, { "type": "connection_established", "scope": "job_list", "timestamp": datetime.utcnow().isoformat() }) async def disconnect(self, websocket: WebSocket, user_id: str): """Disconnect a WebSocket and clean up subscriptions""" async with self.lock: # Get websocket metadata meta = self.ws_meta.pop(websocket, None) if not meta: return # Remove from job subscriptions for job_id in meta.get("jobs", set()): if job_id in self.job_ws: self.job_ws[job_id].discard(websocket) if not self.job_ws[job_id]: del self.job_ws[job_id] # Remove from user connections if user_id in self.user_ws: self.user_ws[user_id].discard(websocket) if not self.user_ws[user_id]: del self.user_ws[user_id] logger.info(f"User {user_id} disconnected from WebSocket") async def broadcast_job_status_update( self, job_id: str, status: str, job_title: Optional[str] = None, message: Optional[str] = None, progress: Optional[int] = None, metadata: Optional[dict[str, Any]] = None ): """ Async wrapper for broadcasting job status updates from API routes For Celery workers, use websocket_publisher.publish_job_update_with_eligibility() directly """ import asyncio from concurrent.futures import ThreadPoolExecutor from .websocket_publisher import publish_job_update_with_eligibility # Run the sync publisher in a thread pool loop = asyncio.get_event_loop() with ThreadPoolExecutor(max_workers=1) as executor: await loop.run_in_executor( executor, publish_job_update_with_eligibility, job_id, status, job_title, message, progress, metadata ) async def _redis_subscriber(self): """Background task to handle Redis pub/sub messages with reconnection logic""" delay = 1 # Start with 1 second delay max_delay = 30 # Maximum delay of 30 seconds while True: try: # (Re)create pubsub connection if self.pubsub: try: await self.pubsub.aclose() except Exception: pass self.pubsub = self.redis_client.pubsub() # Subscribe to channels await self.pubsub.subscribe("job_status_updates") await self.pubsub.psubscribe("job_status_updates:*") logger.info("Redis subscriber connected and subscribed") delay = 1 # Reset delay on successful connection # Listen for messages async for message in self.pubsub.listen(): if message["type"] in ("message", "pmessage"): await self._handle_redis_message(message) except asyncio.CancelledError: logger.info("Redis subscriber task cancelled") break except Exception as e: logger.error(f"Redis subscriber error, retrying in {delay}s: {e}") await asyncio.sleep(delay) delay = min(delay * 2, max_delay) # Exponential backoff async def _handle_redis_message(self, message: dict[str, Any]): """Handle incoming Redis pub/sub message""" try: # For pattern messages, the channel is in the "channel" field # For regular messages, it's also in the "channel" field channel = message["channel"] data = json.loads(message["data"]) update = JobStatusUpdate(**data) logger.debug(f"Received Redis message on channel '{channel}': {data}") # Send to specific job subscribers if channel.startswith("job_status_updates:"): job_id = channel.split(":", 1)[1] logger.debug(f"Sending job status update for job {job_id} to subscribers") await self._send_job_status_to_subscribers(job_id, update) # Send to global subscribers (job list updates) elif channel == "job_status_updates": logger.debug("Sending global job status update to subscribers") await self._send_job_status_to_global_subscribers(update) except Exception as e: logger.error(f"Failed to handle Redis message: {e}") async def _send_job_status_to_subscribers(self, job_id: str, update: JobStatusUpdate): """Send job status update to specific job subscribers""" async with self.lock: target_websockets = list(self.job_ws.get(job_id, set())) if not target_websockets: return # Convert to JSON-serializable dict message = { "type": "job_status_update", "data": update.model_dump(mode="json") } await self._send_to_websockets(target_websockets, message) async def _send_job_status_to_global_subscribers(self, update: JobStatusUpdate): """Send job status update to global (job list) subscribers with user filtering""" # Convert to JSON-serializable dict message_data = update.model_dump(mode="json") # Remove eligible_users from the client message message_data.pop("eligible_users", None) message = { "type": "job_list_update", "data": message_data } # Use pre-computed eligible users if available, otherwise compute them eligible_users = getattr(update, 'eligible_users', None) if eligible_users is None: eligible_users = await self._get_job_related_users(update.job_id) # Find websockets for eligible users that have job_list scope target_websockets = [] async with self.lock: for user_id in eligible_users: for websocket in self.user_ws.get(user_id, set()): meta = self.ws_meta.get(websocket, {}) if "job_list" in meta.get("scopes", set()): target_websockets.append(websocket) await self._send_to_websockets(target_websockets, message) async def _get_job_related_users(self, job_id: str) -> set[str]: """ Get all users who should receive notifications for a specific job. Returns set of user IDs for: - Job creator (client_id) - Reviewers who worked on the job - Admin users (see all jobs) """ eligible_users = set() try: # Import database connection from ..core.database import get_database db = await get_database() # Get the job job = await db["jobs"].find_one({"_id": job_id}) if not job: logger.warning(f"Job {job_id} not found for notification filtering") return eligible_users # Add job creator if job.get("client_id"): eligible_users.add(job["client_id"]) # Add reviewers from review history review = job.get("review", {}) if review.get("reviewer_id"): eligible_users.add(review["reviewer_id"]) # Add reviewers from history for history_item in review.get("history", []): if history_item.get("by"): eligible_users.add(history_item["by"]) # Add all admin users (they can see all jobs) admin_users = db["users"].find({"role": "admin"}) async for admin_user in admin_users: user_id = str(admin_user["_id"]) eligible_users.add(user_id) logger.debug(f"Job {job_id} notification eligible users: {len(eligible_users)}") except Exception as e: logger.error(f"Error getting job related users for {job_id}: {e}") return eligible_users async def _send_to_websockets(self, websockets: list[WebSocket], message: dict[str, Any]): """Send message to a list of WebSocket connections""" disconnected_websockets = [] for websocket in websockets: try: await self._send_to_websocket(websocket, message) except Exception as e: logger.warning(f"Failed to send to websocket: {e}") disconnected_websockets.append(websocket) # Clean up disconnected connections for websocket in disconnected_websockets: # Get user_id from metadata before disconnecting async with self.lock: meta = self.ws_meta.get(websocket, {}) user_id = meta.get("user_id") if user_id: await self.disconnect(websocket, user_id) async def _send_to_websocket(self, websocket: WebSocket, message: dict[str, Any]): """Send message to a specific WebSocket connection""" try: await websocket.send_json(message) except Exception as e: logger.warning(f"WebSocket send failed: {e}") raise # Global connection manager instance connection_manager = ConnectionManager() async def authenticate_websocket(websocket: WebSocket, token: Optional[str]) -> Optional[str]: """ Authenticate a WebSocket connection using a JWT token Returns user_id if valid, None if invalid """ try: if not token: logger.warning("WebSocket authentication failed: Missing token") await websocket.close(code=4001, reason="Missing authentication token") return None # Import JWT decode function from ..core.security import decode_token # Decode JWT token - this may raise HTTPException try: payload = decode_token(token) if not payload or "sub" not in payload: logger.warning("WebSocket authentication failed: Invalid token payload") await websocket.close(code=4001, reason="Invalid authentication token") return None user_id = payload["sub"] logger.info(f"WebSocket authentication successful for user: {user_id}") return user_id except Exception as jwt_error: logger.warning(f"WebSocket authentication failed: JWT decode error: {jwt_error}") await websocket.close(code=4001, reason="Invalid authentication token") return None except Exception as e: logger.error(f"WebSocket authentication failed with unexpected error: {e}") await websocket.close(code=4001, reason="Authentication failed") return None async def get_connection_manager() -> ConnectionManager: """Dependency to get the connection manager""" return connection_manager