"""
WebSocket Connection Manager for Real-time Job Status Updates

This module provides WebSocket support for broadcasting job status changes
in real-time to connected clients. It uses Redis pub/sub for scalable
message broadcasting across multiple worker processes.
"""
|
|
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional

import redis.asyncio as redis
from fastapi import WebSocket
from pydantic import BaseModel

from ..core.config import settings
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class JobStatusUpdate(BaseModel):
    """Payload describing one change in a job's status.

    Instances are built from the JSON body of a Redis pub/sub message and are
    dumped back to JSON before being pushed to WebSocket clients.
    """

    # Identifier of the job the update refers to.
    job_id: str
    # New status value, carried as a free-form string.
    status: str
    # When the status changed.
    updated_at: datetime
    # Job title for better user experience.
    job_title: Optional[str] = None
    # Optional free-text note accompanying the update.
    message: Optional[str] = None
    # Completion percentage, intended to be 0-100 (not validated here).
    progress: Optional[int] = None
    # Arbitrary extra key/value data attached to the update.
    metadata: Optional[dict[str, Any]] = None
    # Pre-computed eligible users; None means the receiver must look them up.
    eligible_users: Optional[set[str]] = None
|
|
|
|
|
|
class ConnectionManager:
    """Manages WebSocket connections and Redis pub/sub for job status updates.

    Three in-memory registries, all guarded by ``self.lock``, record which
    sockets belong to which user, which sockets follow which job, and the
    per-socket metadata.  A background task listens on Redis and forwards
    every received update to the matching sockets.
    """

    # Channel carrying updates for job-list (global) subscribers.
    GLOBAL_CHANNEL = "job_status_updates"
    # Prefix of the per-job channels; the job id follows the colon.
    JOB_CHANNEL_PREFIX = "job_status_updates:"

    def __init__(self):
        """Create empty registries; no Redis connection is opened here."""
        # WebSocket connections by user_id
        self.user_ws: dict[str, set[WebSocket]] = {}
        # WebSocket metadata: websocket -> {user_id, jobs, scopes}
        self.ws_meta: dict[WebSocket, dict[str, Any]] = {}
        # Job subscriptions: job_id -> set of websockets
        self.job_ws: dict[str, set[WebSocket]] = {}
        # Serialises registry access between coroutines on the event loop.
        # asyncio.Lock is neither thread-safe nor re-entrant.
        self.lock = asyncio.Lock()
        # Redis client, pub/sub handle and listener task; populated by start()
        self.redis_client: Optional[redis.Redis] = None
        self.pubsub: Optional[redis.client.PubSub] = None
        self.subscriber_task: Optional[asyncio.Task] = None

    async def start(self):
        """Connect to Redis, subscribe, and launch the listener task.

        Raises:
            Exception: whatever the Redis client raised while connecting or
                subscribing.  Resources opened so far are released first.
        """
        try:
            self.redis_client = redis.from_url(
                settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            self.pubsub = self.redis_client.pubsub()
            await self._subscribe()

            # The listener reuses this already-subscribed pubsub.
            self.subscriber_task = asyncio.create_task(self._redis_subscriber())
            logger.info("WebSocket connection manager started")
        except Exception as e:
            logger.error(f"Failed to start WebSocket connection manager: {e}")
            # Do not leave a half-open client behind on a failed start.
            await self._close_redis(unsubscribe=False)
            raise

    async def stop(self):
        """Cancel the listener task and release all Redis resources.

        Cleanup is best-effort: a failure in one step is logged and the
        remaining steps still run, so the client is always closed.
        """
        task, self.subscriber_task = self.subscriber_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        await self._close_redis(unsubscribe=True)
        logger.info("WebSocket connection manager stopped")

    async def _subscribe(self):
        """Subscribe ``self.pubsub`` to the global channel and per-job pattern."""
        await self.pubsub.subscribe(self.GLOBAL_CHANNEL)
        await self.pubsub.psubscribe(f"{self.JOB_CHANNEL_PREFIX}*")

    async def _close_pubsub(self, unsubscribe: bool):
        """Detach and close the current pubsub, logging instead of raising."""
        pubsub, self.pubsub = self.pubsub, None
        if pubsub is None:
            return
        if unsubscribe:
            try:
                await pubsub.unsubscribe()
                await pubsub.punsubscribe()
            except Exception as e:
                logger.warning(f"Failed to unsubscribe from Redis channels: {e}")
        try:
            await pubsub.aclose()
        except Exception as e:
            logger.warning(f"Failed to close Redis pubsub: {e}")

    async def _close_redis(self, unsubscribe: bool):
        """Close the pubsub and then the Redis client, best-effort."""
        await self._close_pubsub(unsubscribe)
        client, self.redis_client = self.redis_client, None
        if client is None:
            return
        try:
            await client.aclose()
        except Exception as e:
            logger.warning(f"Failed to close Redis client: {e}")

    def _register(self, websocket: WebSocket, user_id: str) -> dict[str, Any]:
        """Record ``websocket`` under ``user_id`` and return its metadata dict.

        The caller must hold ``self.lock``.  Metadata that already exists for
        the socket is returned unchanged.
        """
        self.user_ws.setdefault(user_id, set()).add(websocket)
        return self.ws_meta.setdefault(
            websocket,
            {"user_id": user_id, "jobs": set(), "scopes": set()},
        )

    async def connect_job_status(self, websocket: WebSocket, user_id: str, job_id: str):
        """Accept ``websocket`` and subscribe it to updates for ``job_id``.

        Sends a ``connection_established`` message once registered.  A send
        failure propagates to the caller.
        """
        await websocket.accept()

        async with self.lock:
            meta = self._register(websocket, user_id)
            meta["jobs"].add(job_id)
            self.job_ws.setdefault(job_id, set()).add(websocket)

        logger.info(f"User {user_id} connected for job {job_id} status updates")

        # Send initial connection confirmation
        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "job_id": job_id,
            # Timezone-aware UTC, so clients cannot read it as local time.
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })

    async def connect_job_list(self, websocket: WebSocket, user_id: str):
        """Accept ``websocket`` and give it the ``job_list`` scope.

        Sends a ``connection_established`` message once registered.  A send
        failure propagates to the caller.
        """
        await websocket.accept()

        async with self.lock:
            meta = self._register(websocket, user_id)
            meta["scopes"].add("job_list")

        logger.info(f"User {user_id} connected for job list updates")

        # Send initial connection confirmation
        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "scope": "job_list",
            # Timezone-aware UTC, so clients cannot read it as local time.
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })

    async def disconnect(self, websocket: WebSocket, user_id: str):
        """Remove ``websocket`` from every registry.

        Does nothing if the socket is not registered.  The socket is removed
        from both the ``user_id`` given and the one recorded at registration,
        so a mismatched argument cannot leave a stale entry behind.
        """
        async with self.lock:
            meta = self.ws_meta.pop(websocket, None)
            if not meta:
                return

            # Remove from job subscriptions
            for job_id in meta.get("jobs", ()):
                subscribers = self.job_ws.get(job_id)
                if subscribers is None:
                    continue
                subscribers.discard(websocket)
                if not subscribers:
                    del self.job_ws[job_id]

            # Remove from user connections
            for owner in {user_id, meta.get("user_id")}:
                sockets = self.user_ws.get(owner)
                if sockets is None:
                    continue
                sockets.discard(websocket)
                if not sockets:
                    del self.user_ws[owner]

        logger.info(f"User {user_id} disconnected from WebSocket")

    async def broadcast_job_status_update(
        self,
        job_id: str,
        status: str,
        job_title: Optional[str] = None,
        message: Optional[str] = None,
        progress: Optional[int] = None,
        metadata: Optional[dict[str, Any]] = None
    ):
        """
        Async wrapper for broadcasting job status updates from API routes
        For Celery workers, use websocket_publisher.publish_job_update_with_eligibility() directly

        The synchronous publisher runs in a worker thread so it does not
        block the event loop.
        """
        from .websocket_publisher import publish_job_update_with_eligibility

        await asyncio.to_thread(
            publish_job_update_with_eligibility,
            job_id,
            status,
            job_title,
            message,
            progress,
            metadata,
        )

    async def _redis_subscriber(self):
        """Background task to handle Redis pub/sub messages with reconnection logic.

        Reuses the pubsub prepared by ``start()`` on the first pass and builds
        a new one after any failure, waiting with exponential backoff
        (1s doubling up to 30s) between attempts.  Cancellation is logged and
        re-raised.
        """
        delay = 1  # Start with 1 second delay
        max_delay = 30  # Maximum delay of 30 seconds

        while True:
            try:
                if self.pubsub is None:
                    # Assign before subscribing so a failed subscribe is
                    # still closed by the cleanup below.
                    self.pubsub = self.redis_client.pubsub()
                    await self._subscribe()

                logger.info("Redis subscriber connected and subscribed")
                delay = 1  # Reset delay on successful connection

                async for message in self.pubsub.listen():
                    if message["type"] in ("message", "pmessage"):
                        await self._handle_redis_message(message)

            except asyncio.CancelledError:
                logger.info("Redis subscriber task cancelled")
                raise
            except Exception as e:
                logger.error(f"Redis subscriber error, retrying in {delay}s: {e}")
                await asyncio.sleep(delay)
                delay = min(delay * 2, max_delay)  # Exponential backoff

            # Reached after an error or after listen() ended: drop the old
            # pubsub so the next pass starts from a fresh connection.
            await self._close_pubsub(unsubscribe=False)

    async def _handle_redis_message(self, message: dict[str, Any]):
        """Parse one pub/sub message and route it by channel name.

        Any failure is logged and swallowed, so one bad message cannot stop
        the subscriber loop.
        """
        try:
            # Both plain and pattern messages carry the concrete channel name
            # in the "channel" field.
            channel = message["channel"]
            data = json.loads(message["data"])
            update = JobStatusUpdate(**data)

            logger.debug(f"Received Redis message on channel '{channel}': {data}")

            if channel.startswith(self.JOB_CHANNEL_PREFIX):
                # Send to specific job subscribers
                job_id = channel.split(":", 1)[1]
                logger.debug(f"Sending job status update for job {job_id} to subscribers")
                await self._send_job_status_to_subscribers(job_id, update)
            elif channel == self.GLOBAL_CHANNEL:
                # Send to global subscribers (job list updates)
                logger.debug("Sending global job status update to subscribers")
                await self._send_job_status_to_global_subscribers(update)

        except Exception as e:
            logger.error(f"Failed to handle Redis message: {e}")

    @staticmethod
    def _client_payload(update: JobStatusUpdate) -> dict[str, Any]:
        """Return the JSON-ready form of ``update`` that is safe to send.

        ``eligible_users`` is routing data made of user ids and is never
        forwarded to clients.
        """
        payload = update.model_dump(mode="json")
        payload.pop("eligible_users", None)
        return payload

    async def _send_job_status_to_subscribers(self, job_id: str, update: JobStatusUpdate):
        """Send job status update to the sockets subscribed to ``job_id``."""
        async with self.lock:
            # Snapshot, so sending happens without holding the lock.
            target_websockets = list(self.job_ws.get(job_id, ()))

        if not target_websockets:
            return

        message = {
            "type": "job_status_update",
            "data": self._client_payload(update),
        }

        await self._send_to_websockets(target_websockets, message)

    async def _send_job_status_to_global_subscribers(self, update: JobStatusUpdate):
        """Send job status update to global (job list) subscribers with user filtering.

        Only sockets holding the ``job_list`` scope and belonging to an
        eligible user receive the message.
        """
        message = {
            "type": "job_list_update",
            "data": self._client_payload(update),
        }

        # Use pre-computed eligible users if available, otherwise compute them
        eligible_users = update.eligible_users
        if eligible_users is None:
            eligible_users = await self._get_job_related_users(update.job_id)

        async with self.lock:
            target_websockets = [
                websocket
                for user_id in eligible_users
                for websocket in self.user_ws.get(user_id, ())
                if "job_list" in self.ws_meta.get(websocket, {}).get("scopes", ())
            ]

        await self._send_to_websockets(target_websockets, message)

    async def _get_job_related_users(self, job_id: str) -> set[str]:
        """
        Get all users who should receive notifications for a specific job.

        Returns set of user IDs for:
        - Job creator (client_id)
        - Reviewers who worked on the job
        - Admin users (see all jobs)

        Ids are normalised with ``str()`` so they compare equal to the string
        keys of ``self.user_ws``.  On a database error the ids collected so
        far are returned and the error is logged.
        """
        eligible_users: set[str] = set()

        try:
            from ..core.database import get_database
            db = await get_database()

            # NOTE(review): assumes jobs are keyed by a string ``_id`` - confirm.
            job = await db["jobs"].find_one({"_id": job_id})
            if not job:
                logger.warning(f"Job {job_id} not found for notification filtering")
                return eligible_users

            # Add job creator
            if job.get("client_id"):
                eligible_users.add(str(job["client_id"]))

            # ``or`` guards against the key being present but set to None.
            review = job.get("review") or {}
            if review.get("reviewer_id"):
                eligible_users.add(str(review["reviewer_id"]))

            # Add reviewers from history
            for history_item in review.get("history") or []:
                if history_item.get("by"):
                    eligible_users.add(str(history_item["by"]))

            # Add all admin users (they can see all jobs)
            async for admin_user in db["users"].find({"role": "admin"}):
                eligible_users.add(str(admin_user["_id"]))

            logger.debug(f"Job {job_id} notification eligible users: {len(eligible_users)}")

        except Exception as e:
            logger.error(f"Error getting job related users for {job_id}: {e}")

        return eligible_users

    async def _send_to_websockets(self, websockets: list[WebSocket], message: dict[str, Any]):
        """Send ``message`` to each socket in turn and drop the ones that fail."""
        disconnected_websockets = []

        for websocket in websockets:
            try:
                await self._send_to_websocket(websocket, message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket: {e}")
                disconnected_websockets.append(websocket)

        # Clean up disconnected connections
        for websocket in disconnected_websockets:
            async with self.lock:
                user_id = self.ws_meta.get(websocket, {}).get("user_id")
            # disconnect() takes the lock itself, and asyncio.Lock is not
            # re-entrant, so it must be called after the lock is released.
            if user_id:
                await self.disconnect(websocket, user_id)

    async def _send_to_websocket(self, websocket: WebSocket, message: dict[str, Any]):
        """Send ``message`` as JSON to one socket; log and re-raise on failure."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket send failed: {e}")
            raise
|
|
|
|
|
|
# Global connection manager instance
|
|
# Constructed at import time; the constructor only builds empty registries,
# and no Redis connection exists until start() is awaited.
connection_manager = ConnectionManager()
|
|
|
|
|
|
async def authenticate_websocket(websocket: WebSocket, token: Optional[str]) -> Optional[str]:
    """
    Authenticate a WebSocket connection using a JWT token

    Returns user_id if valid, None if invalid

    On any failure the socket is closed with code 4001 and ``None`` is
    returned.  The close is best-effort: if it fails, the failure is logged
    and ``None`` is still returned rather than an exception being raised.

    Args:
        websocket: The connection being authenticated.
        token: JWT presented by the client; ``None`` or empty is rejected.
    """

    async def reject(reason: str) -> None:
        # Closing can fail, for example if the peer is already gone.  The
        # caller only needs to learn that authentication failed.
        try:
            await websocket.close(code=4001, reason=reason)
        except Exception as close_error:
            logger.debug(f"WebSocket close after failed authentication raised: {close_error}")

    if not token:
        logger.warning("WebSocket authentication failed: Missing token")
        await reject("Missing authentication token")
        return None

    try:
        # Imported here rather than at module level, as in the original.
        from ..core.security import decode_token
    except Exception as e:
        logger.error(f"WebSocket authentication failed with unexpected error: {e}")
        await reject("Authentication failed")
        return None

    try:
        # Decode JWT token - this may raise HTTPException
        payload = decode_token(token)
        has_subject = bool(payload) and "sub" in payload
    except Exception as jwt_error:
        logger.warning(f"WebSocket authentication failed: JWT decode error: {jwt_error}")
        await reject("Invalid authentication token")
        return None

    if not has_subject:
        logger.warning("WebSocket authentication failed: Invalid token payload")
        await reject("Invalid authentication token")
        return None

    user_id = payload["sub"]
    logger.info(f"WebSocket authentication successful for user: {user_id}")
    return user_id
|
|
|
|
|
|
async def get_connection_manager() -> ConnectionManager:
    """Return the module-level ``connection_manager`` instance.

    Written as a coroutine so it can be used as an async dependency.
    """
    manager = connection_manager
    return manager
|