video-accessibility/backend/app/services/websocket.py

432 lines
16 KiB
Python

"""
WebSocket Connection Manager for Real-time Job Status Updates
This module provides WebSocket support for broadcasting job status changes
in real-time to connected clients. It uses Redis pub/sub for scalable
message broadcasting across multiple worker processes.
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Optional
import redis.asyncio as redis
from fastapi import WebSocket
from pydantic import BaseModel
from ..core.config import settings
# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)
class JobStatusUpdate(BaseModel):
    """Schema for job status update messages.

    Instances are built from the JSON payload of Redis pub/sub messages and
    serialized with ``model_dump(mode="json")`` into the ``data`` field of the
    WebSocket messages sent to clients. Field declaration order is the key
    order of that serialized output.
    """
    job_id: str
    status: str
    updated_at: datetime
    job_title: Optional[str] = None  # Job title for better user experience
    message: Optional[str] = None  # Optional free-text message accompanying the status
    progress: Optional[int] = None  # 0-100 percentage (range is not validated here)
    metadata: Optional[dict[str, Any]] = None
    # Pre-computed set of user ids that may receive this update via the global
    # (job list) channel. When None, the connection manager looks the users up
    # itself. This is server-side routing data; the job-list broadcaster removes
    # it before sending the payload to clients.
    eligible_users: Optional[set[str]] = None
class ConnectionManager:
    """
    Manage WebSocket connections and fan out job status updates from Redis pub/sub.

    Sockets subscribe in one or both of two ways:

    * per job (``connect_job_status``): they receive updates published on the
      ``job_status_updates:<job_id>`` channel for that job;
    * job list (``connect_job_list``): they receive updates published on the
      global ``job_status_updates`` channel, filtered to users eligible for the job.

    The registries (``user_ws``, ``ws_meta``, ``job_ws``) are guarded by an
    ``asyncio.Lock``, which serialises coroutines on one event loop only; it
    does not make the manager safe to use from several OS threads.
    """

    # Global channel carrying updates for job list views.
    GLOBAL_CHANNEL = "job_status_updates"
    # Per-job channels are named "job_status_updates:<job_id>".
    JOB_CHANNEL_PREFIX = "job_status_updates:"
    # Redis glob pattern matching every per-job channel.
    JOB_CHANNEL_PATTERN = "job_status_updates:*"

    def __init__(self):
        # user_id -> every WebSocket that user currently has registered
        self.user_ws: dict[str, set[WebSocket]] = {}
        # websocket -> {"user_id": str, "jobs": set of job ids, "scopes": set of scope names}
        self.ws_meta: dict[WebSocket, dict[str, Any]] = {}
        # job_id -> WebSockets subscribed to that specific job
        self.job_ws: dict[str, set[WebSocket]] = {}
        # Serialises coroutine access to the three registries above
        # (an asyncio lock: no protection across OS threads).
        self.lock = asyncio.Lock()
        # Redis client, pub/sub handle and listener task; created by start()
        self.redis_client: Optional[redis.Redis] = None
        self.pubsub: Optional[redis.client.PubSub] = None
        self.subscriber_task: Optional[asyncio.Task] = None

    async def start(self):
        """Connect to Redis, subscribe to the job status channels and start the listener.

        Re-raises any error from the initial connection/subscription, so
        startup fails fast when Redis is unreachable.
        """
        try:
            self.redis_client = redis.from_url(
                settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            await self._open_pubsub()
            # Background task consuming pub/sub messages (reconnects on failure).
            self.subscriber_task = asyncio.create_task(self._redis_subscriber())
            logger.info("WebSocket connection manager started")
        except Exception as e:
            logger.error(f"Failed to start WebSocket connection manager: {e}")
            raise

    async def stop(self):
        """Cancel the listener task and close Redis connections.

        Cleanup is best effort: a failure in one step (e.g. Redis already gone)
        is logged and does not prevent the remaining resources from being closed.
        """
        if self.subscriber_task:
            self.subscriber_task.cancel()
            try:
                await self.subscriber_task
            except asyncio.CancelledError:
                pass
            self.subscriber_task = None
        if self.pubsub:
            try:
                await self.pubsub.unsubscribe()
                await self.pubsub.punsubscribe()
            except Exception as e:
                logger.warning(f"Error while unsubscribing from Redis: {e}")
            try:
                await self.pubsub.aclose()
            except Exception as e:
                logger.warning(f"Error while closing Redis pubsub: {e}")
            self.pubsub = None
        if self.redis_client:
            try:
                await self.redis_client.aclose()
            except Exception as e:
                logger.warning(f"Error while closing Redis client: {e}")
            self.redis_client = None
        logger.info("WebSocket connection manager stopped")

    def _register_locked(self, websocket: WebSocket, user_id: str) -> dict[str, Any]:
        """Add *websocket* to the user registry and return its metadata entry.

        Creates the metadata entry on first registration. Must be called with
        ``self.lock`` held.
        """
        self.user_ws.setdefault(user_id, set()).add(websocket)
        if websocket not in self.ws_meta:
            self.ws_meta[websocket] = {
                "user_id": user_id,
                "jobs": set(),
                "scopes": set(),
            }
        return self.ws_meta[websocket]

    async def connect_job_status(self, websocket: WebSocket, user_id: str, job_id: str):
        """Accept *websocket* and subscribe it to status updates for one job.

        Sends a ``connection_established`` message once registered; a failure
        to send it propagates to the caller.
        """
        await websocket.accept()
        async with self.lock:
            meta = self._register_locked(websocket, user_id)
            meta["jobs"].add(job_id)
            self.job_ws.setdefault(job_id, set()).add(websocket)
        logger.info(f"User {user_id} connected for job {job_id} status updates")
        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
        })

    async def connect_job_list(self, websocket: WebSocket, user_id: str):
        """Accept *websocket* and subscribe it to job list updates for its user.

        Sends a ``connection_established`` message once registered; a failure
        to send it propagates to the caller.
        """
        await websocket.accept()
        async with self.lock:
            meta = self._register_locked(websocket, user_id)
            meta["scopes"].add("job_list")
        logger.info(f"User {user_id} connected for job list updates")
        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "scope": "job_list",
            "timestamp": datetime.utcnow().isoformat(),
        })

    async def disconnect(self, websocket: WebSocket, user_id: str):
        """Remove *websocket* from every registry. No-op if it is not registered.

        The user id recorded at connect time is authoritative; *user_id* is only
        used as a fallback, so a mismatched caller value cannot leave a stale entry.
        """
        async with self.lock:
            meta = self.ws_meta.pop(websocket, None)
            if not meta:
                return
            for job_id in meta.get("jobs", set()):
                subscribers = self.job_ws.get(job_id)
                if subscribers is not None:
                    subscribers.discard(websocket)
                    if not subscribers:
                        del self.job_ws[job_id]
            owner = meta.get("user_id") or user_id
            sockets = self.user_ws.get(owner)
            if sockets is not None:
                sockets.discard(websocket)
                if not sockets:
                    del self.user_ws[owner]
        logger.info(f"User {owner} disconnected from WebSocket")

    async def broadcast_job_status_update(
        self,
        job_id: str,
        status: str,
        job_title: Optional[str] = None,
        message: Optional[str] = None,
        progress: Optional[int] = None,
        metadata: Optional[dict[str, Any]] = None,
    ):
        """
        Async wrapper for broadcasting job status updates from API routes.

        For Celery workers, use websocket_publisher.publish_job_update_with_eligibility() directly.
        """
        from .websocket_publisher import publish_job_update_with_eligibility

        # The publisher is synchronous; run it on the loop's default executor
        # instead of creating (and tearing down) a thread pool on every call.
        await asyncio.to_thread(
            publish_job_update_with_eligibility,
            job_id,
            status,
            job_title,
            message,
            progress,
            metadata,
        )

    async def _open_pubsub(self):
        """Create a fresh pub/sub handle subscribed to the job status channels.

        Any previous handle is closed first; errors while closing it are ignored
        because it may already be dead. Requires ``self.redis_client``.
        """
        if self.pubsub is not None:
            try:
                await self.pubsub.aclose()
            except Exception as e:
                logger.debug(f"Ignoring error while closing stale Redis pubsub: {e}")
        self.pubsub = self.redis_client.pubsub()
        await self.pubsub.subscribe(self.GLOBAL_CHANNEL)
        await self.pubsub.psubscribe(self.JOB_CHANNEL_PATTERN)

    async def _redis_subscriber(self):
        """Consume Redis pub/sub messages forever, reconnecting with exponential backoff."""
        delay = 1  # seconds before the next reconnect attempt
        max_delay = 30
        # start() leaves an already-subscribed handle, so the first pass reuses it.
        need_connect = self.pubsub is None
        while True:
            try:
                if need_connect:
                    await self._open_pubsub()
                    logger.info("Redis subscriber connected and subscribed")
                # Any later pass means the previous stream ended or failed.
                need_connect = True
                delay = 1  # reset backoff after a good connection
                async for message in self.pubsub.listen():
                    if message["type"] in ("message", "pmessage"):
                        await self._handle_redis_message(message)
                logger.warning("Redis subscriber stream ended, reconnecting")
            except asyncio.CancelledError:
                logger.info("Redis subscriber task cancelled")
                break
            except Exception as e:
                logger.error(f"Redis subscriber error, retrying in {delay}s: {e}")
                await asyncio.sleep(delay)
                delay = min(delay * 2, max_delay)  # exponential backoff

    async def _handle_redis_message(self, message: dict[str, Any]):
        """Parse one pub/sub message and route it to per-job or job-list subscribers.

        Any error (bad JSON, schema mismatch, send failure) is logged, never raised,
        so one bad message cannot kill the listener loop.
        """
        try:
            # Both "message" and "pmessage" carry the concrete channel name here.
            channel = message["channel"]
            data = json.loads(message["data"])
            update = JobStatusUpdate(**data)
            logger.debug(f"Received Redis message on channel '{channel}': {data}")
            if channel.startswith(self.JOB_CHANNEL_PREFIX):
                job_id = channel.split(":", 1)[1]
                logger.debug(f"Sending job status update for job {job_id} to subscribers")
                await self._send_job_status_to_subscribers(job_id, update)
            elif channel == self.GLOBAL_CHANNEL:
                logger.debug("Sending global job status update to subscribers")
                await self._send_job_status_to_global_subscribers(update)
        except Exception as e:
            logger.error(f"Failed to handle Redis message: {e}")

    async def _send_job_status_to_subscribers(self, job_id: str, update: JobStatusUpdate):
        """Send *update* to every socket subscribed to *job_id*."""
        async with self.lock:
            target_websockets = list(self.job_ws.get(job_id, set()))
        if not target_websockets:
            return
        message = {
            "type": "job_status_update",
            # eligible_users is server-side routing data (creator, reviewers and
            # all admin ids); never expose it to clients.
            "data": update.model_dump(mode="json", exclude={"eligible_users"}),
        }
        await self._send_to_websockets(target_websockets, message)

    async def _send_job_status_to_global_subscribers(self, update: JobStatusUpdate):
        """Send *update* to the ``job_list`` sockets of users eligible to see the job."""
        message = {
            "type": "job_list_update",
            # eligible_users is server-side routing data; never expose it to clients.
            "data": update.model_dump(mode="json", exclude={"eligible_users"}),
        }
        # Prefer the publisher's pre-computed audience; otherwise look it up.
        eligible_users = update.eligible_users
        if eligible_users is None:
            eligible_users = await self._get_job_related_users(update.job_id)
        target_websockets = []
        async with self.lock:
            for user_id in eligible_users:
                for websocket in self.user_ws.get(user_id, set()):
                    if "job_list" in self.ws_meta.get(websocket, {}).get("scopes", set()):
                        target_websockets.append(websocket)
        await self._send_to_websockets(target_websockets, message)

    async def _get_job_related_users(self, job_id: str) -> set[str]:
        """
        Return the ids of all users who should be notified about *job_id*.

        Includes the job creator (``client_id``), the current reviewer
        (``review.reviewer_id``), every ``review.history[*].by`` and all users
        with role ``admin``. Ids are converted with ``str`` because sockets are
        registered under string user ids. On error the problem is logged and the
        ids collected so far are returned (possibly an empty set).
        """
        eligible_users: set[str] = set()
        try:
            from ..core.database import get_database

            db = await get_database()
            job = await db["jobs"].find_one({"_id": job_id})
            if not job:
                logger.warning(f"Job {job_id} not found for notification filtering")
                return eligible_users
            if job.get("client_id"):
                eligible_users.add(str(job["client_id"]))
            # `or` also covers fields stored as explicit null, which would
            # otherwise raise and abort the lookup before admins are added.
            review = job.get("review") or {}
            if review.get("reviewer_id"):
                eligible_users.add(str(review["reviewer_id"]))
            for history_item in review.get("history") or []:
                if history_item.get("by"):
                    eligible_users.add(str(history_item["by"]))
            # Admins can see every job.
            async for admin_user in db["users"].find({"role": "admin"}):
                eligible_users.add(str(admin_user["_id"]))
            logger.debug(f"Job {job_id} notification eligible users: {len(eligible_users)}")
        except Exception as e:
            logger.error(f"Error getting job related users for {job_id}: {e}")
        return eligible_users

    async def _send_to_websockets(self, websockets: list[WebSocket], message: dict[str, Any]):
        """Send *message* to each socket in turn and disconnect those whose send failed."""
        disconnected_websockets = []
        for websocket in websockets:
            try:
                await self._send_to_websocket(websocket, message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket: {e}")
                disconnected_websockets.append(websocket)
        for websocket in disconnected_websockets:
            async with self.lock:
                user_id = self.ws_meta.get(websocket, {}).get("user_id")
            if user_id:
                await self.disconnect(websocket, user_id)

    async def _send_to_websocket(self, websocket: WebSocket, message: dict[str, Any]):
        """Send *message* as JSON to one socket; log and re-raise on failure."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket send failed: {e}")
            raise
# Global connection manager instance, shared by every importer of this module.
# Its Redis listener only runs after `await connection_manager.start()` has been
# called (presumably from application startup; confirm in the app wiring).
connection_manager = ConnectionManager()
async def authenticate_websocket(websocket: WebSocket, token: Optional[str]) -> Optional[str]:
    """
    Authenticate a WebSocket connection using a JWT token.

    On success the token's ``sub`` claim is returned as the user id and the
    socket is left open. On any failure the socket is closed with code 4001
    and a short reason, and None is returned.

    Args:
        websocket: Connection being authenticated; closed on failure.
        token: Raw JWT string. None or an empty string counts as missing.

    Returns:
        The non-empty ``sub`` claim of a valid token, or None if authentication failed.
    """

    async def reject(reason: str) -> None:
        # Close exactly once. If the close itself fails (e.g. the client is
        # already gone), log it instead of letting an auth failure escape as
        # an unhandled exception.
        try:
            await websocket.close(code=4001, reason=reason)
        except Exception as close_error:
            logger.warning(f"Failed to close unauthenticated WebSocket: {close_error}")

    if not token:
        logger.warning("WebSocket authentication failed: Missing token")
        await reject("Missing authentication token")
        return None

    try:
        # Imported at call time rather than at module import time.
        from ..core.security import decode_token
    except Exception as e:
        logger.error(f"WebSocket authentication failed with unexpected error: {e}")
        await reject("Authentication failed")
        return None

    try:
        # decode_token may raise (e.g. HTTPException) for an invalid token.
        payload = decode_token(token)
        if not payload or "sub" not in payload:
            logger.warning("WebSocket authentication failed: Invalid token payload")
            await reject("Invalid authentication token")
            return None
        user_id = payload["sub"]
    except Exception as jwt_error:
        logger.warning(f"WebSocket authentication failed: JWT decode error: {jwt_error}")
        await reject("Invalid authentication token")
        return None

    # An empty or null subject is not a usable user id.
    if not user_id:
        logger.warning("WebSocket authentication failed: Invalid token payload")
        await reject("Invalid authentication token")
        return None

    logger.info(f"WebSocket authentication successful for user: {user_id}")
    return user_id
async def get_connection_manager() -> ConnectionManager:
    """Return the shared module-level ``ConnectionManager``.

    Intended as a FastAPI dependency, so every route receives the same
    ``connection_manager`` instance defined in this module.
    """
    manager = connection_manager
    return manager