"""
Async WebSocket Manager for Synthetic Society

Handles WebSocket connections, room management, and real-time event broadcasting.
Uses python-socketio AsyncServer for native Quart/ASGI compatibility.
"""

import logging
import os
import threading
import asyncio
from typing import Dict, Set, Any, Optional
from datetime import datetime, timezone

from .extensions import socketio_server as sio
from app.auth.quart_jwt import decode_token

# Set up logging
logger = logging.getLogger(__name__)


class AsyncWebSocketManager:
    """Manages WebSocket connections and rooms for focus group sessions using AsyncServer.

    Bookkeeping lives in two in-memory dictionaries on the instance:

    * ``focus_group_rooms``: focus group ID -> set of Socket.IO session IDs
      that joined that group's room.
    * ``user_sessions``: Socket.IO session ID -> authenticated session info
      (``user_id``, ``connected_at``, ``focus_groups``).

    NOTE(review): these dicts are per-process; if the app runs several worker
    processes, each manager only knows about its own connections.
    """

    # Upper bound, in seconds, for an emit that has been scheduled onto the
    # main ASGI loop from a different event loop (e.g. a background thread).
    _CROSS_LOOP_EMIT_TIMEOUT = 5

    def __init__(self):
        # Use singleton SocketIO AsyncServer instance
        self.sio = sio
        self.focus_group_rooms: Dict[str, Set[str]] = {}  # focus_group_id -> set of session_ids
        self.user_sessions: Dict[str, Dict[str, Any]] = {}  # session_id -> user info
        # Main ASGI event loop reference (set via set_main_loop during app startup)
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        # Register SocketIO event handlers
        self._register_handlers()

    def set_main_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Store the main ASGI event loop for cross-thread emission."""
        self._main_loop = loop
        logger.info("AsyncWebSocketManager: main event loop registered")

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _field(data: Any, key: str) -> Any:
        """Return ``data[key]`` for dict payloads and ``None`` for anything else.

        Client payloads are untrusted; a missing or non-dict payload must not
        raise ``AttributeError`` inside a handler.
        """
        return data.get(key) if isinstance(data, dict) else None

    async def _reject_connection(self, sid: str, payload: Dict[str, Any]) -> None:
        """Send an ``auth_error`` event with *payload* to *sid*, then disconnect it."""
        await self.sio.emit('auth_error', payload, to=sid)
        await self.sio.disconnect(sid)

    def _is_cross_loop(self) -> bool:
        """Return True when the caller runs on a loop other than the registered main loop.

        This happens when emit helpers are awaited from a background thread's
        own event loop. If no main loop has been registered, emits are always
        performed on the current loop.
        """
        if self._main_loop is None:
            return False
        try:
            return asyncio.get_running_loop() is not self._main_loop
        except RuntimeError:
            # No running loop on this thread
            return True

    async def _emit(self, event: str, data: Dict[str, Any], cross_loop: bool, **target: Any) -> None:
        """Emit *event* through the AsyncServer, hopping to the main loop if needed.

        Args:
            event: Socket.IO event name.
            data: Payload to send.
            cross_loop: Result of ``_is_cross_loop()`` for the current caller.
            **target: Forwarded to ``AsyncServer.emit`` (``room=`` or ``to=``).

        Raises:
            asyncio.TimeoutError: if a cross-loop emit does not finish within
                ``_CROSS_LOOP_EMIT_TIMEOUT`` seconds (the scheduled emit is
                cancelled in that case).
        """
        if not cross_loop:
            await self.sio.emit(event, data, **target)
            return
        future = asyncio.run_coroutine_threadsafe(
            self.sio.emit(event, data, **target), self._main_loop
        )
        # Await the cross-thread future instead of calling future.result():
        # a blocking result() inside a coroutine would stall this thread's
        # event loop for up to the full timeout.
        await asyncio.wait_for(
            asyncio.wrap_future(future), timeout=self._CROSS_LOOP_EMIT_TIMEOUT
        )

    # ------------------------------------------------------------------ #
    # Socket.IO event handlers
    # ------------------------------------------------------------------ #

    def _register_handlers(self):
        """Register all WebSocket event handlers on the shared AsyncServer."""

        @self.sio.event
        async def cancel_task(sid, data):
            """Handle task cancellation requests via WebSocket.

            Expects ``{'task_id': ...}``. On success every session of the
            requesting user receives ``task_cancelled``; otherwise the caller
            receives an ``error`` event.
            """
            try:
                task_id = self._field(data, 'task_id')
                if not task_id:
                    await self.sio.emit('error', {'message': 'Missing task_id'}, to=sid)
                    return

                # Only authenticated sessions may request cancellation
                session_info = self.user_sessions.get(sid)
                if not session_info:
                    await self.sio.emit('error', {'message': 'Invalid session'}, to=sid)
                    return

                user_id = session_info.get('user_id')
                logger.info(f"WebSocket cancellation request for task {task_id} from user {user_id}")

                # Imported lazily, presumably to avoid an import cycle at module load
                from app.services.task_manager import get_task_manager

                task_manager = get_task_manager()
                # NOTE(review): user_id is not passed to cancel_task, so no
                # ownership check happens here — confirm the task manager
                # enforces that a user can only cancel their own tasks.
                success = await task_manager.cancel_task(task_id)

                if success:
                    # Broadcast cancellation success to user
                    await self.emit_to_user(
                        user_id,
                        'task_cancelled',
                        {
                            'task_id': task_id,
                            'message': 'Task cancelled successfully',
                        },
                    )
                    logger.info(f"Successfully cancelled task {task_id} via WebSocket")
                else:
                    await self.sio.emit('error', {
                        'message': 'Task not found or already completed',
                        'task_id': task_id,
                    }, to=sid)
                    logger.info(f"Task {task_id} not found or already completed")
            except Exception as e:
                logger.exception(f"Error in WebSocket task cancellation: {e}")
                await self.sio.emit('error', {'message': 'Cancellation failed'}, to=sid)

        @self.sio.event
        async def connect(sid, environ, auth=None):
            """Authenticate a new connection using the JWT in ``auth['token']``.

            On success the session is registered in ``user_sessions`` and a
            ``connected`` event is sent. On any failure an ``auth_error`` event
            is sent, the client is disconnected, and ``False`` is returned so
            python-socketio rejects the connection.
            """
            logger.info(f"WebSocket connection attempt from {sid}")
            logger.debug(
                f"Connection handler PID: {os.getpid()}, Thread: {threading.get_ident()}"
            )

            # Validate JWT token from auth data - require authentication
            if not isinstance(auth, dict) or 'token' not in auth:
                logger.warning("WebSocket connection without auth token - rejecting")
                await self._reject_connection(sid, {
                    'error': 'missing_token',
                    'message': 'Authentication token required',
                })
                return False

            try:
                token = auth['token']
                import jwt

                # Validate token format first
                if not token or not isinstance(token, str):
                    raise ValueError("Invalid token format")

                # A compact JWT has exactly three dot-separated segments
                token_parts = token.split('.')
                if len(token_parts) != 3:
                    raise ValueError(f"Invalid JWT format: expected 3 segments, got {len(token_parts)}")

                # NOTE(review): decodes with SECRET_KEY/HS256 directly instead of
                # the imported app.auth.quart_jwt.decode_token — confirm the two
                # stay in sync (secret source, algorithm, claim checks).
                jwt_secret = os.environ.get('SECRET_KEY', '')
                if not jwt_secret:
                    raise ValueError("SECRET_KEY not configured")

                try:
                    decoded_token = jwt.decode(token, jwt_secret, algorithms=['HS256'])
                    user_id = decoded_token.get('sub')
                    if not user_id:
                        raise ValueError("No user ID in token")
                except Exception as jwt_error:
                    logger.warning(f"JWT decode failed: {jwt_error}")
                    logger.warning(f"Token format: {len(token_parts)} segments, first 20 chars: {token[:20]}...")
                    # Reject connection for expired or invalid tokens
                    await self._reject_connection(sid, {
                        'error': 'invalid_token',
                        'message': f'Token validation failed: {str(jwt_error)}',
                        'expired': 'expired' in str(jwt_error).lower(),
                    })
                    return False

                # Store user session info
                self.user_sessions[sid] = {
                    'user_id': user_id,
                    'connected_at': datetime.now(timezone.utc),
                    'focus_groups': set(),
                }
                logger.info(f"WebSocket connected - Session: {sid}, User: {user_id}")

                # Emit connection success
                await self.sio.emit('connected', {'status': 'success', 'session_id': sid}, to=sid)
            except Exception as e:
                logger.error(f"Connection authentication failed: {e}")
                # The failure may come after registration (e.g. the 'connected'
                # emit raised); don't keep state for a rejected client.
                self.user_sessions.pop(sid, None)
                await self._reject_connection(sid, {
                    'error': 'authentication_failed',
                    'message': 'Authentication error occurred',
                })
                return False

        @self.sio.event
        async def disconnect(sid, reason=None):
            """Handle WebSocket disconnection.

            Leaves every focus group room the session joined and forgets the
            session. ``reason`` is optional so the handler works whether or not
            the installed python-socketio passes a disconnect reason.
            """
            user_info = self.user_sessions.get(sid)
            if user_info is None:
                return

            user_id = user_info['user_id']
            # Iterate over a snapshot: _leave_focus_group_room mutates the set
            for focus_group_id in list(user_info['focus_groups']):
                await self._leave_focus_group_room(sid, focus_group_id)

            # Clean up session
            self.user_sessions.pop(sid, None)
            logger.info(f"WebSocket disconnected - Session: {sid}, User: {user_id}")

        @self.sio.event
        async def join_focus_group(sid, data):
            """Handle joining a focus group room (payload: ``{'focus_group_id': ...}``)."""
            if sid not in self.user_sessions:
                await self.sio.emit('error', {'message': 'Session not authenticated'}, to=sid)
                return

            focus_group_id = self._field(data, 'focus_group_id')
            if not focus_group_id:
                await self.sio.emit('error', {'message': 'Focus group ID required'}, to=sid)
                return

            if await self._join_focus_group_room(sid, focus_group_id):
                await self.sio.emit('joined_focus_group', {
                    'focus_group_id': focus_group_id,
                    'status': 'success',
                }, to=sid)
                logger.info(f"User joined focus group room - Session: {sid}, Group: {focus_group_id}")
            else:
                await self.sio.emit('error', {'message': 'Failed to join focus group'}, to=sid)

        @self.sio.event
        async def leave_focus_group(sid, data):
            """Handle leaving a focus group room (payload: ``{'focus_group_id': ...}``)."""
            if sid not in self.user_sessions:
                await self.sio.emit('error', {'message': 'Session not authenticated'}, to=sid)
                return

            focus_group_id = self._field(data, 'focus_group_id')
            if not focus_group_id:
                await self.sio.emit('error', {'message': 'Focus group ID required'}, to=sid)
                return

            if await self._leave_focus_group_room(sid, focus_group_id):
                await self.sio.emit('left_focus_group', {
                    'focus_group_id': focus_group_id,
                    'status': 'success',
                }, to=sid)
                logger.info(f"User left focus group room - Session: {sid}, Group: {focus_group_id}")
            else:
                # Mirror join_focus_group: report failure instead of staying silent
                await self.sio.emit('error', {'message': 'Failed to leave focus group'}, to=sid)

    # ------------------------------------------------------------------ #
    # Room bookkeeping
    # ------------------------------------------------------------------ #

    async def _join_focus_group_room(self, session_id: str, focus_group_id: str) -> bool:
        """Join a user session to a focus group room.

        Returns:
            True on success, False if entering the room or tracking it failed
            (the error is logged).
        """
        try:
            # Add to SocketIO room
            await self.sio.enter_room(session_id, focus_group_id)

            # Track in our data structures
            self.focus_group_rooms.setdefault(focus_group_id, set()).add(session_id)
            self.user_sessions[session_id]['focus_groups'].add(focus_group_id)
            return True
        except Exception as e:
            logger.error(f"Failed to join focus group room: {e}")
            return False

    async def _leave_focus_group_room(self, session_id: str, focus_group_id: str) -> bool:
        """Remove a user session from a focus group room.

        Empty rooms are removed from ``focus_group_rooms``.

        Returns:
            True on success, False if leaving the room failed (the error is logged).
        """
        try:
            # Leave SocketIO room
            await self.sio.leave_room(session_id, focus_group_id)

            # Clean up tracking
            members = self.focus_group_rooms.get(focus_group_id)
            if members is not None:
                members.discard(session_id)
                # Remove room if empty
                if not members:
                    del self.focus_group_rooms[focus_group_id]

            session_info = self.user_sessions.get(session_id)
            if session_info is not None:
                session_info['focus_groups'].discard(focus_group_id)

            return True
        except Exception as e:
            logger.error(f"Failed to leave focus group room: {e}")
            return False

    # ------------------------------------------------------------------ #
    # Emission
    # ------------------------------------------------------------------ #

    async def emit_to_user(self, user_id: str, event: str, data: Any):
        """Emit an event to a specific user across all their sessions.

        The payload is ``data`` merged over ``{'user_id', 'timestamp'}``.
        Errors are logged, not raised.
        """
        try:
            # Find all sessions for this user
            user_sessions = [
                session_id
                for session_id, session_info in self.user_sessions.items()
                if session_info.get('user_id') == user_id
            ]

            if not user_sessions:
                logger.debug(f"No active sessions found for user {user_id}")
                return

            # Prepare the event data
            event_data = {
                'user_id': user_id,
                'timestamp': datetime.now(timezone.utc).isoformat(),
                **data,
            }

            # Send to all user sessions
            for session_id in user_sessions:
                await self.sio.emit(event, event_data, to=session_id)

            logger.debug(f"Emitted '{event}' to user {user_id} ({len(user_sessions)} sessions)")
        except Exception as e:
            logger.error(f"Failed to emit to user {user_id}: {e}")

    async def emit_to_focus_group(self, focus_group_id: str, event: str, data: Any,
                                  include_sender: bool = True, sender_session_id: Optional[str] = None):
        """Emit an event to all users in a focus group room.

        Args:
            focus_group_id: Room (focus group ID) to broadcast to.
            event: Socket.IO event name.
            data: Mapping merged over ``{'focus_group_id', 'timestamp'}``; its
                keys win on collision.
            include_sender: When False and ``sender_session_id`` is set, that
                session is skipped.
            sender_session_id: Session that triggered the event, if any.

        Tracked sessions that are no longer in ``user_sessions`` are pruned
        first (and the room entry is dropped if that empties it). When called
        from a loop other than the registered main ASGI loop, the emit is
        scheduled onto the main loop and awaited without blocking. Errors are
        logged, not raised.
        """
        logger.debug(
            f"emit_to_focus_group called: {event} for focus group {focus_group_id} "
            f"(PID: {os.getpid()}, Thread: {threading.get_ident()}, "
            f"rooms: {list(self.focus_group_rooms.keys())})"
        )

        try:
            tracked = self.focus_group_rooms.get(focus_group_id)
            if tracked is None:
                logger.debug(f"No active sessions for focus group {focus_group_id}")
                return

            room_sessions = tracked.copy()
            logger.debug(f"Room {focus_group_id} has {len(room_sessions)} tracked sessions: {list(room_sessions)}")

            # Clean up stale sessions (tracked in the room but no longer connected)
            active_sessions = [s for s in room_sessions if s in self.user_sessions]
            stale_sessions = room_sessions.difference(active_sessions)
            if stale_sessions:
                tracked.difference_update(stale_sessions)
                if not tracked:
                    # Don't leave an empty room entry behind
                    self.focus_group_rooms.pop(focus_group_id, None)
                logger.debug(f"Cleaned up {len(stale_sessions)} stale sessions: {list(stale_sessions)}")

            if not active_sessions:
                logger.debug(f"No active sessions remaining for focus group {focus_group_id} after cleanup")
                return

            # Prepare the event data
            event_data = {
                'focus_group_id': focus_group_id,
                'timestamp': datetime.now(timezone.utc).isoformat(),
                **data,
            }

            # Detect if we're running on a different event loop than the ASGI
            # server (e.g. when called from a background thread's loop)
            cross_loop = self._is_cross_loop()

            if include_sender or not sender_session_id:
                logger.debug(f"Emitting '{event}' to room {focus_group_id} (cross_loop={cross_loop})")
                await self._emit(event, event_data, cross_loop, room=focus_group_id)
                logger.debug(f"Emitted '{event}' to focus group {focus_group_id} ({len(active_sessions)} active users)")
            else:
                # Send to all users except the sender
                for session_id in active_sessions:
                    if session_id != sender_session_id:
                        await self._emit(event, event_data, cross_loop, to=session_id)
                logger.debug(f"Emitted '{event}' to focus group {focus_group_id} (excluding sender)")
        except Exception as e:
            logger.error(f"Failed to emit to focus group {focus_group_id}: {e}")

    async def emit_message_update(self, focus_group_id: str, message_data: Dict[str, Any],
                                  sender_session_id: Optional[str] = None):
        """Emit a new message to focus group participants."""
        await self.emit_to_focus_group(
            focus_group_id,
            'message_update',
            {'message': message_data},
            include_sender=True,
            sender_session_id=sender_session_id,
        )

    async def emit_ai_status_update(self, focus_group_id: str, status_data: Dict[str, Any]):
        """Emit AI status change to focus group participants."""
        await self.emit_to_focus_group(focus_group_id, 'ai_status_update', {'status': status_data})

    async def emit_moderator_status_update(self, focus_group_id: str, moderator_status: Dict[str, Any]):
        """Emit moderator status change to focus group participants."""
        await self.emit_to_focus_group(
            focus_group_id, 'moderator_status_update', {'moderator_status': moderator_status}
        )

    async def emit_theme_update(self, focus_group_id: str, theme_data: Dict[str, Any], action: str = 'added'):
        """Emit theme update to focus group participants."""
        await self.emit_to_focus_group(
            focus_group_id, 'theme_update', {'theme': theme_data, 'action': action}
        )

    async def emit_analytics_update(self, focus_group_id: str, analytics_data: Dict[str, Any]):
        """Emit analytics update to focus group participants."""
        await self.emit_to_focus_group(focus_group_id, 'analytics_update', {'analytics': analytics_data})

    async def emit_conversation_state_update(self, focus_group_id: str, state_data: Dict[str, Any]):
        """Emit conversation state update to focus group participants."""
        await self.emit_to_focus_group(focus_group_id, 'conversation_state_update', {'state': state_data})

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #

    def get_room_info(self, focus_group_id: str) -> Dict[str, Any]:
        """Get information about a focus group room.

        Only sessions that are still registered in ``user_sessions`` are
        reported; ``active_sessions`` equals ``len(users)``.
        """
        sessions = self.focus_group_rooms.get(focus_group_id)
        if not sessions:
            return {'active_sessions': 0, 'users': []}

        users = [
            {
                'session_id': session_id,
                'user_id': self.user_sessions[session_id]['user_id'],
                'connected_at': self.user_sessions[session_id]['connected_at'].isoformat(),
            }
            for session_id in sessions
            if session_id in self.user_sessions
        ]

        # Count only live sessions so the number agrees with the 'users' list
        return {
            'active_sessions': len(users),
            'users': users,
        }

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get overall connection statistics."""
        return {
            'total_sessions': len(self.user_sessions),
            'total_focus_groups': len(self.focus_group_rooms),
            'focus_group_details': {
                fg_id: len(sessions) for fg_id, sessions in self.focus_group_rooms.items()
            },
        }


# Global WebSocket manager instance
websocket_manager: Optional[AsyncWebSocketManager] = None


def init_async_websocket_manager() -> AsyncWebSocketManager:
    """Initialize the global async WebSocket manager."""
    global websocket_manager
    websocket_manager = AsyncWebSocketManager()
    return websocket_manager


def get_async_websocket_manager() -> Optional[AsyncWebSocketManager]:
    """Get the global async WebSocket manager instance."""
    return websocket_manager


# Async emit function for use throughout the codebase
async def emit_websocket_event(event: str, data: dict, room: str | None = None) -> None:
    """
    Async WebSocket event emission function.

    Args:
        event: Event name
        data: Event data
        room: Room to emit to (focus group ID)

    If the global manager has not been initialized or no room is given, the
    event is not sent (a debug message is logged). Delivery failures inside
    ``emit_to_focus_group`` are logged there and not re-raised.
    """
    try:
        if websocket_manager and room:
            await websocket_manager.emit_to_focus_group(room, event, data)
            logger.debug(f"Dispatched {event} to room {room}")
        else:
            logger.debug("WebSocket manager not available or no room specified")
    except Exception as e:
        logger.exception(f"Error emitting WebSocket event {event}: {e}")