417 lines
No EOL
16 KiB
Python
417 lines
No EOL
16 KiB
Python
"""
|
|
WebSocket Connection Manager for Real-time Job Status Updates
|
|
|
|
This module provides WebSocket support for broadcasting job status changes
|
|
in real-time to connected clients. It uses Redis pub/sub for scalable
|
|
message broadcasting across multiple worker processes.
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Set, Optional, Any
|
|
from datetime import datetime
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
import redis.asyncio as redis
|
|
import redis as sync_redis
|
|
from pydantic import BaseModel
|
|
|
|
from ..core.redis import get_redis_client
|
|
from ..core.security import decode_token
|
|
from ..core.config import settings
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class JobStatusUpdate(BaseModel):
    """Schema for job status update messages.

    Instances are serialised with ``model_dump_json()`` when published to
    Redis and rebuilt with ``JobStatusUpdate(**data)`` when received, so this
    field set is the wire format shared by publishers and the WebSocket layer.
    """

    # Identifier of the job the update refers to.
    job_id: str
    # New status value (free-form string; no enumeration is enforced here).
    status: str
    # Time the update was created (ConnectionManager fills it with datetime.utcnow()).
    updated_at: datetime
    job_title: Optional[str] = None  # Job title for better user experience
    # Optional human-readable message accompanying the status change.
    message: Optional[str] = None
    progress: Optional[int] = None  # 0-100 percentage (range is not validated by this model)
    # Optional free-form extra data forwarded to clients as-is.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class ConnectionManager:
    """Manage WebSocket clients and relay job status updates from Redis pub/sub.

    Two kinds of connection are supported:

    * job-status sockets (``connect_job_status``) receive ``job_status_update``
      messages for the single job they were opened for;
    * job-list sockets (``connect_job_list``) receive ``job_list_update``
      messages for every job the user is eligible to see
      (see ``_get_job_related_users``).

    Updates are published to Redis by ``broadcast_job_status_update`` and read
    back by the background task started in ``start()``; every process running
    a manager delivers them to the sockets it holds.
    """

    # Delay before re-entering the Redis listen loop after an unexpected error,
    # so a persistent failure does not turn into a tight log/retry loop.
    _RESUBSCRIBE_DELAY_SECONDS = 1.0

    def __init__(self):
        # --- User-level views (kept for any existing readers) ---
        # user_id -> every open socket of that user, of either kind.
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # job_id -> user_ids with at least one socket watching that job.
        self.job_subscriptions: Dict[str, Set[str]] = {}
        # user_ids with at least one job-list socket open.
        self.global_subscriptions: Set[str] = set()

        # --- Socket-level routing ---
        # A user can hold several sockets at once, so delivery and cleanup
        # must know which socket asked for what; the user-level views above
        # cannot tell a user's sockets apart.
        # job_id -> job-status sockets opened for that job.
        self._job_sockets: Dict[str, Set[WebSocket]] = {}
        # job-status socket -> the job_id it was opened for.
        self._socket_jobs: Dict[WebSocket, str] = {}
        # user_id -> that user's job-list sockets.
        self._job_list_sockets: Dict[str, Set[WebSocket]] = {}

        # Redis resources: created in start(), released in stop().
        self.redis_client: Optional[redis.Redis] = None
        self.pubsub: Optional[redis.client.PubSub] = None
        self.subscriber_task: Optional[asyncio.Task] = None

    async def start(self):
        """Connect to Redis, subscribe to job status channels, start the listener.

        Subscribes to the global ``job_status_updates`` channel and to the
        ``job_status_updates:*`` pattern (one channel per job).

        Raises:
            Exception: whatever the Redis client raises while connecting or
                subscribing; it is logged and re-raised.
        """
        try:
            self.redis_client = await redis.from_url(
                settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            self.pubsub = self.redis_client.pubsub()

            await self.pubsub.subscribe("job_status_updates")  # Global channel
            await self.pubsub.psubscribe("job_status_updates:*")  # Per-job channels

            self.subscriber_task = asyncio.create_task(self._redis_subscriber())
            logger.info("WebSocket connection manager started")

        except Exception as e:
            logger.error(f"Failed to start WebSocket connection manager: {e}")
            raise

    async def stop(self):
        """Cancel the listener task and release the Redis pub/sub and client.

        The Redis client is closed even if tearing down the pub/sub object
        fails (that failure still propagates). Attributes are reset to
        ``None`` so a later call does not touch already-released resources.
        """
        if self.subscriber_task:
            self.subscriber_task.cancel()
            try:
                await self.subscriber_task
            except asyncio.CancelledError:
                pass
            self.subscriber_task = None

        try:
            if self.pubsub:
                await self.pubsub.unsubscribe()
                await self.pubsub.punsubscribe()
                await self.pubsub.aclose()
        finally:
            self.pubsub = None
            if self.redis_client:
                await self.redis_client.aclose()
                self.redis_client = None

        logger.info("WebSocket connection manager stopped")

    async def connect_job_status(self, websocket: WebSocket, user_id: str, job_id: str):
        """Accept ``websocket`` and register it for updates about ``job_id`` only.

        Authentication is not checked here; the caller is expected to have
        resolved ``user_id`` already. A ``connection_established`` message is
        sent once the socket is registered.

        Args:
            websocket: Socket to accept and register.
            user_id: Owner of the socket.
            job_id: Job whose ``job_status_update`` messages this socket receives.
        """
        await websocket.accept()

        self.active_connections.setdefault(user_id, set()).add(websocket)
        self.job_subscriptions.setdefault(job_id, set()).add(user_id)
        self._job_sockets.setdefault(job_id, set()).add(websocket)
        self._socket_jobs[websocket] = job_id

        logger.info(f"User {user_id} connected for job {job_id} status updates")

        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "job_id": job_id,
            "timestamp": datetime.utcnow().isoformat(),
        })

    async def connect_job_list(self, websocket: WebSocket, user_id: str):
        """Accept ``websocket`` and register it for job-list updates of ``user_id``.

        The socket receives ``job_list_update`` messages for every job the
        user is eligible to see. Authentication is not checked here. A
        ``connection_established`` message is sent once registered.
        """
        await websocket.accept()

        self.active_connections.setdefault(user_id, set()).add(websocket)
        self._job_list_sockets.setdefault(user_id, set()).add(websocket)
        self.global_subscriptions.add(user_id)

        logger.info(f"User {user_id} connected for job list updates")

        await self._send_to_websocket(websocket, {
            "type": "connection_established",
            "scope": "job_list",
            "timestamp": datetime.utcnow().isoformat(),
        })

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Forget ``websocket`` (owned by ``user_id``) and update subscriptions.

        Only state belonging to this socket is removed: the user stays
        subscribed to a job while another of their sockets still watches it,
        and stays in ``global_subscriptions`` while another of their job-list
        sockets is open. Calling it again for the same socket is harmless.
        """
        user_sockets = self.active_connections.get(user_id)
        if user_sockets is not None:
            user_sockets.discard(websocket)
            if not user_sockets:
                del self.active_connections[user_id]

        # Job-status socket: detach it from its job.
        job_id = self._socket_jobs.pop(websocket, None)
        if job_id is not None:
            job_sockets = self._job_sockets.get(job_id)
            if job_sockets is not None:
                job_sockets.discard(websocket)
                if not job_sockets:
                    del self._job_sockets[job_id]

            still_watching = any(
                self._socket_jobs.get(other) == job_id
                for other in self.active_connections.get(user_id, ())
            )
            if not still_watching:
                job_users = self.job_subscriptions.get(job_id)
                if job_users is not None:
                    job_users.discard(user_id)
                    if not job_users:
                        del self.job_subscriptions[job_id]

        # Job-list socket: drop the global subscription with the user's last one.
        list_sockets = self._job_list_sockets.get(user_id)
        if list_sockets is not None:
            list_sockets.discard(websocket)
            if not list_sockets:
                del self._job_list_sockets[user_id]
                self.global_subscriptions.discard(user_id)

        logger.info(f"User {user_id} disconnected from WebSocket")

    async def broadcast_job_status_update(
        self,
        job_id: str,
        status: str,
        job_title: Optional[str] = None,
        user_id: Optional[str] = None,
        message: Optional[str] = None,
        progress: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Publish a job status update to Redis (global and per-job channels).

        This will be called from Celery workers. It uses a short-lived
        synchronous Redis client, so when awaited inside a running event loop
        the publish round-trips block that loop. Publishing is best-effort:
        Redis and serialisation errors are logged, not raised.

        Args:
            job_id: Job the update refers to; also selects the per-job channel.
            status: New status value.
            job_title: Optional human-readable job title.
            user_id: Accepted for API compatibility but not used; recipients
                are resolved by the subscribing side.
            message: Optional free-text message.
            progress: Optional completion percentage (0-100 by convention).
            metadata: Optional extra data included in the payload.

        Raises:
            pydantic.ValidationError: if the arguments do not fit
                ``JobStatusUpdate`` (built before the best-effort section).
        """
        update = JobStatusUpdate(
            job_id=job_id,
            status=status,
            updated_at=datetime.utcnow(),
            job_title=job_title,
            message=message,
            progress=progress,
            metadata=metadata,
        )

        try:
            # Serialise once; the same payload goes to both channels.
            payload = update.model_dump_json()

            # Create a synchronous Redis client for Celery workers
            redis_client = sync_redis.Redis.from_url(
                settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            try:
                redis_client.publish("job_status_updates", payload)
                redis_client.publish(f"job_status_updates:{job_id}", payload)
            finally:
                # Close even if a publish fails so the connection is not leaked.
                redis_client.close()

            logger.debug(f"Broadcasted status update for job {job_id}: {status}")

        except Exception as e:
            logger.error(f"Failed to broadcast job status update: {e}")

    async def _redis_subscriber(self):
        """Background loop dispatching pub/sub messages to ``_handle_redis_message``.

        An unexpected error is logged and listening resumes after
        ``_RESUBSCRIBE_DELAY_SECONDS`` instead of the task ending (which would
        stop real-time updates for the rest of the process's lifetime).
        The loop ends when the ``listen()`` iterator is exhausted, when no
        pub/sub object exists, or when the task is cancelled.
        """
        while True:
            pubsub = self.pubsub
            if pubsub is None:
                return
            try:
                async for message in pubsub.listen():
                    # Plain-channel ("message") and pattern ("pmessage")
                    # deliveries are handled; other message types are ignored.
                    if message["type"] in ("message", "pmessage"):
                        await self._handle_redis_message(message)
                return
            except asyncio.CancelledError:
                logger.info("Redis subscriber task cancelled")
                raise
            except Exception as e:
                logger.error(f"Redis subscriber error: {e}")
                # NOTE(review): assumes redis-py's PubSub reconnects and
                # re-subscribes on the next read; confirm for the pinned version.
                await asyncio.sleep(self._RESUBSCRIBE_DELAY_SECONDS)

    async def _handle_redis_message(self, message: Dict[str, Any]):
        """Decode one pub/sub message and route it by channel name.

        ``job_status_updates:<job_id>`` goes to that job's sockets and
        ``job_status_updates`` goes to job-list sockets. Any error is logged
        and the message is dropped.
        """
        try:
            # For both "message" and "pmessage" the concrete channel is in "channel".
            channel = message["channel"]
            data = json.loads(message["data"])
            update = JobStatusUpdate(**data)

            logger.debug(f"Received Redis message on channel '{channel}': {data}")

            if channel.startswith("job_status_updates:"):
                job_id = channel.split(":", 1)[1]
                logger.debug(f"Sending job status update for job {job_id} to subscribers")
                await self._send_job_status_to_subscribers(job_id, update)

            elif channel == "job_status_updates":
                logger.debug("Sending global job status update to subscribers")
                await self._send_job_status_to_global_subscribers(update)

        except Exception as e:
            logger.error(f"Failed to handle Redis message: {e}")

    async def _send_job_status_to_subscribers(self, job_id: str, update: JobStatusUpdate):
        """Send ``update`` as ``job_status_update`` to the sockets opened for ``job_id``."""
        job_sockets = self._job_sockets.get(job_id)
        if not job_sockets:
            return

        message = {
            "type": "job_status_update",
            # Round-trip through JSON to get a JSON-safe dict (e.g. datetimes as strings).
            "data": json.loads(update.model_dump_json()),
        }

        for user_id in list(self.job_subscriptions.get(job_id, ())):
            await self._send_to_user(user_id, message, sockets=job_sockets)

    async def _send_job_status_to_global_subscribers(self, update: JobStatusUpdate):
        """Send ``update`` as ``job_list_update`` to job-list sockets of eligible users."""
        # Skip the database lookup entirely when nobody has a job-list socket open.
        if not self.global_subscriptions:
            return

        message = {
            "type": "job_list_update",
            "data": json.loads(update.model_dump_json()),
        }

        eligible_users = await self._get_job_related_users(update.job_id)

        # Only users who are both subscribed and eligible for this job.
        for user_id in list(self.global_subscriptions):
            if user_id not in eligible_users:
                continue
            list_sockets = self._job_list_sockets.get(user_id)
            if list_sockets:
                await self._send_to_user(user_id, message, sockets=list_sockets)

    async def _get_job_related_users(self, job_id: str) -> Set[str]:
        """
        Return the ids of users who should receive list notifications for ``job_id``.

        Includes the job's ``client_id``, ``review.reviewer_id``, every
        ``review.history[].by`` and the ``_id`` of each user with role
        ``"admin"``. Every id is passed through ``str()`` so it compares equal
        to the string user ids used as connection keys.

        Best-effort: an unknown job yields an empty set, and on any error the
        error is logged and the ids collected so far are returned.
        """
        eligible_users: Set[str] = set()

        try:
            # Imported here rather than at module top (presumably to avoid an
            # import cycle -- confirm before moving it).
            from ..core.database import get_database
            db = await get_database()

            job = await db["jobs"].find_one({"_id": job_id})
            if not job:
                logger.warning(f"Job {job_id} not found for notification filtering")
                return eligible_users

            # Job creator
            if job.get("client_id"):
                eligible_users.add(str(job["client_id"]))

            # "or" also covers fields stored as null, not only missing ones.
            review = job.get("review") or {}
            if review.get("reviewer_id"):
                eligible_users.add(str(review["reviewer_id"]))

            # Everyone who acted in the review history
            for history_item in review.get("history") or []:
                if history_item.get("by"):
                    eligible_users.add(str(history_item["by"]))

            # Admins can see all jobs; only their ids are needed.
            async for admin_user in db["users"].find({"role": "admin"}, {"_id": 1}):
                eligible_users.add(str(admin_user["_id"]))

            logger.debug(f"Job {job_id} notification eligible users: {len(eligible_users)}")

        except Exception as e:
            logger.error(f"Error getting job related users for {job_id}: {e}")

        return eligible_users

    async def _send_to_user(
        self,
        user_id: str,
        message: Dict[str, Any],
        sockets: Optional[Set[WebSocket]] = None,
    ):
        """Send ``message`` to ``user_id``'s open sockets.

        Args:
            user_id: Recipient.
            message: JSON-serialisable payload.
            sockets: When given, only the user's sockets that are also in this
                set receive the message; ``None`` means all of them.

        Sockets whose send fails are removed via ``disconnect``.
        """
        user_sockets = self.active_connections.get(user_id)
        if not user_sockets:
            return

        failed = []
        # Snapshot: each await lets other tasks (dis)connect sockets meanwhile.
        for websocket in list(user_sockets):
            if sockets is not None and websocket not in sockets:
                continue
            try:
                await self._send_to_websocket(websocket, message)
            except Exception as e:
                logger.warning(f"Failed to send to websocket for user {user_id}: {e}")
                failed.append(websocket)

        for websocket in failed:
            self.disconnect(websocket, user_id)

    async def _send_to_websocket(self, websocket: WebSocket, message: Dict[str, Any]):
        """Send ``message`` as JSON on one socket; failures are logged and re-raised."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"WebSocket send failed: {e}")
            raise
|
|
|
|
|
# Single shared manager for this process; get_connection_manager() returns it.
connection_manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
async def authenticate_websocket(websocket: WebSocket, token: Optional[str]) -> Optional[str]:
    """
    Authenticate a WebSocket connection using a JWT token.

    On every failure path the socket is closed with application close code
    4001 and ``None`` is returned, so callers only need to check the result.

    Args:
        websocket: Incoming connection; closed here when authentication fails.
        token: Raw JWT; may be ``None`` or empty when the client sent none.

    Returns:
        The token's ``sub`` claim (used as the user id) on success, else ``None``.
    """
    try:
        if not token:
            logger.warning("WebSocket authentication failed: Missing token")
            await websocket.close(code=4001, reason="Missing authentication token")
            return None

        # decode_token may raise (e.g. HTTPException) for malformed or expired
        # tokens; only decoding and payload inspection are guarded here so a
        # failing close() is not misreported as a JWT error.
        try:
            payload = decode_token(token)
            has_subject = bool(payload) and "sub" in payload
        except Exception as jwt_error:
            logger.warning(f"WebSocket authentication failed: JWT decode error: {jwt_error}")
            await websocket.close(code=4001, reason="Invalid authentication token")
            return None

        if not has_subject:
            logger.warning("WebSocket authentication failed: Invalid token payload")
            await websocket.close(code=4001, reason="Invalid authentication token")
            return None

        user_id = payload["sub"]
        logger.info(f"WebSocket authentication successful for user: {user_id}")
        return user_id

    except Exception as e:
        logger.error(f"WebSocket authentication failed with unexpected error: {e}")
        await websocket.close(code=4001, reason="Authentication failed")
        return None
|
|
|
|
|
|
async def get_connection_manager() -> ConnectionManager:
    """Dependency provider returning the module-level ``connection_manager``.

    Every caller receives the same instance, so connection state is shared
    by all endpoints in this process.
    """
    shared_manager = connection_manager
    return shared_manager