video-accessibility/backend/app/services/websocket.py

417 lines
No EOL
16 KiB
Python

"""
WebSocket Connection Manager for Real-time Job Status Updates
This module provides WebSocket support for broadcasting job status changes
in real-time to connected clients. It uses Redis pub/sub for scalable
message broadcasting across multiple worker processes.
"""
import asyncio
import json
import logging
from typing import Dict, List, Set, Optional, Any
from datetime import datetime
from fastapi import WebSocket, WebSocketDisconnect
import redis.asyncio as redis
import redis as sync_redis
from pydantic import BaseModel
from ..core.redis import get_redis_client
from ..core.security import decode_token
from ..core.config import settings
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class JobStatusUpdate(BaseModel):
    """Schema for job status update messages.

    Serialized with ``model_dump_json()`` and published to Redis by
    ``ConnectionManager.broadcast_job_status_update``; parsed back from JSON by
    the Redis listener and forwarded to WebSocket clients under the ``data`` key.
    """
    job_id: str
    status: str
    updated_at: datetime  # set by the publisher from datetime.utcnow() (naive UTC)
    job_title: Optional[str] = None  # Job title for better user experience
    message: Optional[str] = None  # optional human-readable status text
    progress: Optional[int] = None  # 0-100 percentage (range is not validated here)
    metadata: Optional[Dict[str, Any]] = None  # free-form extras; schema not defined in this module
class ConnectionManager:
    """Manage WebSocket connections and Redis pub/sub fan-out of job status updates.

    Producers call :meth:`broadcast_job_status_update`, which publishes each
    update to Redis. Every process that has called :meth:`start` listens on
    those channels and forwards the updates to the WebSocket clients connected
    to *that* process.

    Subscriptions are tracked per user: a message for a user is sent to every
    socket that user currently has open on this process.
    """

    # Channel carrying every job update; delivered to job-list subscribers.
    GLOBAL_CHANNEL = "job_status_updates"
    # Prefix of the per-job channels, ``job_status_updates:<job_id>``.
    JOB_CHANNEL_PREFIX = "job_status_updates:"
    # Cap, in seconds, on the back-off between Redis listener restarts.
    _MAX_RETRY_DELAY = 30.0

    def __init__(self):
        # Active WebSocket connections by user_id
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # Job subscriptions: job_id -> set of user_ids
        self.job_subscriptions: Dict[str, Set[str]] = {}
        # user_ids subscribed to job-list (all jobs) updates
        self.global_subscriptions: Set[str] = set()
        # Why each open socket was connected: its job_id for a job-status
        # socket, None for a job-list socket. disconnect() uses this so that
        # closing one socket does not drop subscriptions that the same user's
        # other open sockets still need.
        self._websocket_scopes: Dict[WebSocket, Optional[str]] = {}
        # Redis client and pub/sub handle; created by start(), released by stop()
        self.redis_client: Optional[redis.Redis] = None
        self.pubsub: Optional[redis.client.PubSub] = None
        self.subscriber_task: Optional[asyncio.Task] = None

    async def start(self):
        """Connect to Redis, subscribe to the job status channels and start the listener task.

        Raises:
            Exception: whatever Redis raised while connecting/subscribing; any
                partially created client is closed before re-raising.
        """
        try:
            self.redis_client = await redis.from_url(
                settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            self.pubsub = self.redis_client.pubsub()
            await self.pubsub.subscribe(self.GLOBAL_CHANNEL)  # Global channel
            await self.pubsub.psubscribe(f"{self.JOB_CHANNEL_PREFIX}*")  # Individual job channels
            # Background task that forwards Redis messages to local sockets
            self.subscriber_task = asyncio.create_task(self._redis_subscriber())
            logger.info("WebSocket connection manager started")
        except Exception as e:
            logger.error(f"Failed to start WebSocket connection manager: {e}")
            await self._close_redis()
            raise

    async def stop(self):
        """Cancel the listener task and release the Redis pub/sub handle and client."""
        if self.subscriber_task:
            self.subscriber_task.cancel()
            try:
                await self.subscriber_task
            except asyncio.CancelledError:
                pass
            self.subscriber_task = None
        await self._close_redis()
        logger.info("WebSocket connection manager stopped")

    async def _close_redis(self):
        """Best-effort release of the pub/sub handle and client; the client is closed even if unsubscribing fails."""
        pubsub, self.pubsub = self.pubsub, None
        client, self.redis_client = self.redis_client, None
        try:
            if pubsub is not None:
                try:
                    await pubsub.unsubscribe()
                    await pubsub.punsubscribe()
                except Exception as e:
                    logger.warning(f"Failed to unsubscribe from Redis channels: {e}")
                finally:
                    await pubsub.aclose()
        finally:
            if client is not None:
                await client.aclose()

    def _register(self, websocket: WebSocket, user_id: str, scope: Optional[str]):
        """Record *websocket* as an open connection of *user_id* (scope: job_id, or None for job list)."""
        self.active_connections.setdefault(user_id, set()).add(websocket)
        self._websocket_scopes[websocket] = scope

    async def connect_job_status(self, websocket: WebSocket, user_id: str, job_id: str):
        """Accept *websocket* and subscribe *user_id* to updates for *job_id*.

        Sends a ``connection_established`` message once registered; a failure
        to send it propagates to the caller.
        """
        await websocket.accept()
        self._register(websocket, user_id, job_id)
        self.job_subscriptions.setdefault(job_id, set()).add(user_id)
        logger.info(f"User {user_id} connected for job {job_id} status updates")
        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat()
        })

    async def connect_job_list(self, websocket: WebSocket, user_id: str):
        """Accept *websocket* and subscribe *user_id* to job-list updates (all jobs they may see).

        Sends a ``connection_established`` message once registered; a failure
        to send it propagates to the caller.
        """
        await websocket.accept()
        self._register(websocket, user_id, None)
        self.global_subscriptions.add(user_id)
        logger.info(f"User {user_id} connected for job list updates")
        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "scope": "job_list",
            "timestamp": datetime.utcnow().isoformat()
        })

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister *websocket* and drop the subscriptions no other open socket of *user_id* needs.

        Safe to call more than once for the same socket (it is called both on
        send failure and, presumably, by the endpoint on disconnect).
        """
        user_sockets = self.active_connections.get(user_id)
        if user_sockets is not None:
            user_sockets.discard(websocket)
            if not user_sockets:
                del self.active_connections[user_id]
        self._websocket_scopes.pop(websocket, None)

        # Scopes still served by the user's remaining sockets (None = job list).
        remaining_scopes = {
            self._websocket_scopes[ws]
            for ws in self.active_connections.get(user_id, ())
            if ws in self._websocket_scopes
        }
        if None not in remaining_scopes:
            self.global_subscriptions.discard(user_id)
        for job_id in list(self.job_subscriptions):
            if job_id in remaining_scopes:
                continue  # another open socket of this user still watches this job
            subscribers = self.job_subscriptions[job_id]
            subscribers.discard(user_id)
            if not subscribers:
                del self.job_subscriptions[job_id]
        logger.info(f"User {user_id} disconnected from WebSocket")

    async def broadcast_job_status_update(
        self,
        job_id: str,
        status: str,
        job_title: Optional[str] = None,
        user_id: Optional[str] = None,
        message: Optional[str] = None,
        progress: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Publish a job status update to Redis so every listening process can forward it.

        The same JSON payload goes to the global channel and to the job's own
        channel. Intended to be callable from Celery workers, so it uses a
        short-lived synchronous client rather than ``self.redis_client`` (which
        only exists after :meth:`start`). Redis errors are logged, not raised;
        an invalid payload raises pydantic's validation error.

        NOTE(review): the synchronous ``publish`` calls block the event loop
        when this is awaited inside the API process.

        Args:
            job_id: job whose status changed.
            status: new status value.
            job_title: optional title shown to clients.
            user_id: accepted for backward compatibility; not used here —
                recipients are resolved by the listening processes.
            message: optional human-readable status text.
            progress: optional 0-100 percentage.
            metadata: optional free-form extra data.
        """
        update = JobStatusUpdate(
            job_id=job_id,
            status=status,
            updated_at=datetime.utcnow(),
            job_title=job_title,
            message=message,
            progress=progress,
            metadata=metadata
        )
        payload = update.model_dump_json()
        redis_client = None
        try:
            redis_client = sync_redis.Redis.from_url(
                settings.redis_url,
                encoding="utf-8",
                decode_responses=True
            )
            redis_client.publish(self.GLOBAL_CHANNEL, payload)
            redis_client.publish(f"{self.JOB_CHANNEL_PREFIX}{job_id}", payload)
            logger.debug(f"Broadcasted status update for job {job_id}: {status}")
        except Exception as e:
            logger.error(f"Failed to broadcast job status update: {e}")
        finally:
            # Close even when publishing failed, so the connection is not leaked.
            if redis_client is not None:
                redis_client.close()

    async def _redis_subscriber(self):
        """Forward Redis pub/sub messages to local sockets until cancelled.

        A Redis error no longer ends the task for good: it is logged and
        listening is retried with exponential back-off capped at
        ``_MAX_RETRY_DELAY`` seconds. Cancellation is re-raised.
        """
        delay = 1.0
        while True:
            pubsub = self.pubsub
            if pubsub is None:
                return  # stopped / never started
            try:
                async for message in pubsub.listen():
                    delay = 1.0  # receiving again, so reset the back-off
                    # Handle both regular messages and pattern messages
                    if message["type"] in ("message", "pmessage"):
                        await self._handle_redis_message(message)
                # listen() finishes when nothing is subscribed any more.
                return
            except asyncio.CancelledError:
                logger.info("Redis subscriber task cancelled")
                raise
            except Exception as e:
                logger.error(f"Redis subscriber error: {e}; retrying in {delay:.0f}s")
                await asyncio.sleep(delay)
                delay = min(delay * 2, self._MAX_RETRY_DELAY)

    async def _handle_redis_message(self, message: Dict[str, Any]):
        """Decode one pub/sub message and route it to job or job-list subscribers.

        Malformed messages are logged and dropped so the listener keeps running.
        """
        try:
            # Both "message" and "pmessage" carry the concrete channel in "channel".
            channel = message["channel"]
            data = json.loads(message["data"])
            update = JobStatusUpdate(**data)
            logger.debug(f"Received Redis message on channel '{channel}': {data}")
            if channel.startswith(self.JOB_CHANNEL_PREFIX):
                job_id = channel[len(self.JOB_CHANNEL_PREFIX):]
                logger.debug(f"Sending job status update for job {job_id} to subscribers")
                await self._send_job_status_to_subscribers(job_id, update)
            elif channel == self.GLOBAL_CHANNEL:
                logger.debug("Sending global job status update to subscribers")
                await self._send_job_status_to_global_subscribers(update)
        except Exception as e:
            logger.error(f"Failed to handle Redis message: {e}")

    async def _send_job_status_to_subscribers(self, job_id: str, update: JobStatusUpdate):
        """Send a ``job_status_update`` message to every user subscribed to *job_id*."""
        subscribers = self.job_subscriptions.get(job_id)
        if not subscribers:
            return
        # Round-trip through JSON so datetimes etc. become JSON-serializable.
        message = {
            "type": "job_status_update",
            "data": json.loads(update.model_dump_json())
        }
        for user_id in list(subscribers):
            await self._send_to_user(user_id, message)

    async def _send_job_status_to_global_subscribers(self, update: JobStatusUpdate):
        """Send a ``job_list_update`` to job-list subscribers who are allowed to see the job."""
        if not self.global_subscriptions:
            return  # nobody listening: skip the database lookup entirely
        message = {
            "type": "job_list_update",
            "data": json.loads(update.model_dump_json())
        }
        eligible_users = await self._get_job_related_users(update.job_id)
        # Snapshot after the await: subscriptions may have changed meanwhile.
        recipients = [uid for uid in self.global_subscriptions if uid in eligible_users]
        for user_id in recipients:
            await self._send_to_user(user_id, message)

    async def _get_job_related_users(self, job_id: str) -> Set[str]:
        """Return the user_ids allowed to receive notifications for *job_id*.

        That is the job's ``client_id``, ``review.reviewer_id``, every
        ``review.history[].by``, and every user whose ``role`` is ``"admin"``.
        All ids are converted with ``str()`` so they compare equal to the
        string user_ids used as connection keys (the ids may be stored as
        non-strings — TODO confirm the stored type). On a lookup error the
        error is logged and whatever was collected so far is returned.
        """
        eligible_users: Set[str] = set()
        try:
            # Imported lazily, presumably to avoid an import cycle at module load.
            from ..core.database import get_database
            db = await get_database()
            # NOTE(review): matches _id against the raw string job_id; confirm jobs are not keyed by ObjectId.
            job = await db["jobs"].find_one({"_id": job_id})
            if not job:
                logger.warning(f"Job {job_id} not found for notification filtering")
                return eligible_users
            # Job creator
            if job.get("client_id"):
                eligible_users.add(str(job["client_id"]))
            # `review` (and its history) may be absent or stored as null.
            review = job.get("review") or {}
            if review.get("reviewer_id"):
                eligible_users.add(str(review["reviewer_id"]))
            for history_item in review.get("history") or []:
                if isinstance(history_item, dict) and history_item.get("by"):
                    eligible_users.add(str(history_item["by"]))
            # Admins can see all jobs; only their ids are needed.
            async for admin_user in db["users"].find({"role": "admin"}, {"_id": 1}):
                eligible_users.add(str(admin_user["_id"]))
            logger.debug(f"Job {job_id} notification eligible users: {len(eligible_users)}")
        except Exception as e:
            logger.error(f"Error getting job related users for {job_id}: {e}")
        return eligible_users

    async def _send_to_user(self, user_id: str, message: Dict[str, Any]):
        """Send *message* to every open socket of *user_id*; unregister sockets whose send fails."""
        sockets = self.active_connections.get(user_id)
        if not sockets:
            return
        failed = []
        for websocket in list(sockets):
            try:
                await self._send_to_websocket(websocket, message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket for user {user_id}: {e}")
                failed.append(websocket)
        for websocket in failed:
            self.disconnect(websocket, user_id)

    async def _send_to_websocket(self, websocket: WebSocket, message: Dict[str, Any]):
        """Send *message* as JSON over *websocket*, logging and re-raising any failure."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket send failed: {e}")
            raise
# Global connection manager instance shared by everything in this process.
# Its start()/stop() are presumably driven by the app's lifecycle hooks (not shown here).
connection_manager = ConnectionManager()
# (A duplicate `authenticate_websocket` definition was removed here: it was dead
# code, since the later definition in this module rebinds the name at import time.)
async def _reject_websocket(websocket: WebSocket, reason: str) -> None:
    """Close *websocket* with code 4001 and *reason*, logging instead of raising if closing fails."""
    try:
        await websocket.close(code=4001, reason=reason)
    except Exception as close_error:
        # e.g. the client already went away; nothing more to do for a rejected socket
        logger.debug(f"Could not close unauthenticated WebSocket: {close_error}")


async def authenticate_websocket(websocket: WebSocket, token: Optional[str]) -> Optional[str]:
    """Authenticate a WebSocket connection using a JWT token.

    On any failure the socket is closed with code 4001 and a reason.
    NOTE(review): if this runs before ``websocket.accept()``, Starlette rejects
    the handshake with HTTP 403 and the client never sees the 4001 code/reason —
    confirm against the endpoint.

    Args:
        websocket: the connection being authenticated.
        token: the raw JWT, or ``None``/empty when the client sent none.

    Returns:
        The token's ``sub`` claim as a ``str`` (the user_id) when valid,
        otherwise ``None``.
    """
    try:
        if not token:
            logger.warning("WebSocket authentication failed: Missing token")
            await _reject_websocket(websocket, "Missing authentication token")
            return None

        try:
            # decode_token may raise (e.g. HTTPException) on a bad or expired token.
            payload = decode_token(token)
        except Exception as jwt_error:
            logger.warning(f"WebSocket authentication failed: JWT decode error: {jwt_error}")
            await _reject_websocket(websocket, "Invalid authentication token")
            return None

        # An absent or empty subject cannot identify a user.
        if not payload or "sub" not in payload or not payload["sub"]:
            logger.warning("WebSocket authentication failed: Invalid token payload")
            await _reject_websocket(websocket, "Invalid authentication token")
            return None
        # Connection bookkeeping keys users by str; normalise non-string subjects.
        user_id = str(payload["sub"])
    except Exception as e:
        logger.error(f"WebSocket authentication failed with unexpected error: {e}")
        await _reject_websocket(websocket, "Authentication failed")
        return None

    logger.info(f"WebSocket authentication successful for user: {user_id}")
    return user_id
async def get_connection_manager() -> ConnectionManager:
    """Hand out the process-wide ``connection_manager`` singleton.

    Meant to be used as a FastAPI dependency.
    """
    manager: ConnectionManager = connection_manager
    return manager