semblance-dev/backend/app/websocket_manager_async.py
Vadym Samoilenko b4978989a5 Fix AI autonomous mode: cross-loop WebSocket emit + polling fallback
The AI Runner runs on a dedicated background thread with its own asyncio
event loop. When it emitted WebSocket events via sio.emit(), the call
happened on the wrong loop (AI Runner's vs ASGI/Quart's), causing silent
failures — messages were saved to MongoDB but never reached the frontend.

Additionally, the frontend HTTP polling fallback was never enabled when
WebSocket appeared connected, leaving no way to discover missed messages.

- websocket_manager_async.py: store ASGI main loop reference; detect
  cross-loop calls in emit_to_focus_group and use run_coroutine_threadsafe
  to schedule emits on the correct loop
- __init__.py: register the ASGI event loop with the WebSocket manager
  in before_serving hook
- FocusGroupSession.tsx: always poll fetchMessages every 3s during AI mode
  as a reliability fallback regardless of WebSocket status

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-23 18:22:24 +00:00

510 lines
No EOL
22 KiB
Python
Executable file

"""
Async WebSocket Manager for Synthetic Society
Handles WebSocket connections, room management, and real-time event broadcasting.
Uses python-socketio AsyncServer for native Quart/ASGI compatibility.
"""
import logging
import os
import threading
import asyncio
from typing import Dict, Set, Any, Optional
from datetime import datetime, timezone
from .extensions import socketio_server as sio
from app.auth.quart_jwt import decode_token
# Module-level logger, named after this module via __name__
logger = logging.getLogger(__name__)
class AsyncWebSocketManager:
    """Manage WebSocket sessions and focus-group rooms on the shared AsyncServer.

    The manager keeps two in-memory indexes next to the Socket.IO server's own
    room bookkeeping:

    * ``focus_group_rooms`` maps a focus group id to the set of session ids
      (sids) that joined its room.
    * ``user_sessions`` maps a sid to ``{'user_id', 'connected_at',
      'focus_groups'}`` for every authenticated connection.

    Emits issued from a coroutine that runs on an event loop other than the
    registered main ASGI loop are re-scheduled onto the main loop (see
    ``emit_to_focus_group``).
    """

    # Seconds to wait for an emit that was handed over to the main ASGI loop.
    CROSS_LOOP_EMIT_TIMEOUT = 5

    def __init__(self):
        # Use singleton SocketIO AsyncServer instance
        self.sio = sio
        # focus_group_id -> set of session_ids
        self.focus_group_rooms: Dict[str, Set[str]] = {}
        # session_id -> user info
        self.user_sessions: Dict[str, Dict[str, Any]] = {}
        # Main ASGI event loop reference (set via set_main_loop during app startup)
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        # Register SocketIO event handlers
        self._register_handlers()

    def set_main_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Store the main ASGI event loop for cross-thread emission."""
        self._main_loop = loop
        logger.info("AsyncWebSocketManager: main event loop registered")

    def _register_handlers(self):
        """Register all WebSocket event handlers on the AsyncServer.

        Each nested coroutine is registered through ``@self.sio.event``, so the
        function name is the Socket.IO event name; do not rename them.
        """

        @self.sio.event
        async def cancel_task(sid, data):
            """Handle task cancellation requests via WebSocket.

            Expects ``data`` to carry a ``task_id``. Replies with an ``error``
            event to the caller on any failure, or a ``task_cancelled`` event to
            every session of the requesting user on success.
            """
            try:
                task_id = data.get('task_id')
                if not task_id:
                    await self.sio.emit('error', {'message': 'Missing task_id'}, to=sid)
                    return

                # Get user ID from session
                session_info = self.user_sessions.get(sid)
                if not session_info:
                    await self.sio.emit('error', {'message': 'Invalid session'}, to=sid)
                    return

                user_id = session_info.get('user_id')
                logger.info(f"WebSocket cancellation request for task {task_id} from user {user_id}")
                print(f"🔥 WebSocket: Received cancellation request for task {task_id}")

                # Cancel the task using task manager (imported inside the
                # handler, as in the original code).
                from app.services.task_manager import get_task_manager
                task_manager = get_task_manager()
                success = await task_manager.cancel_task(task_id)

                if success:
                    # Broadcast cancellation success to user
                    await self.emit_to_user(
                        user_id,
                        'task_cancelled',
                        {
                            'task_id': task_id,
                            'message': 'Task cancelled successfully'
                        }
                    )
                    print(f"✅ WebSocket: Successfully cancelled task {task_id}")
                    logger.info(f"Successfully cancelled task {task_id} via WebSocket")
                else:
                    await self.sio.emit('error', {
                        'message': 'Task not found or already completed',
                        'task_id': task_id
                    }, to=sid)
                    print(f"❌ WebSocket: Task {task_id} not found or already completed")
            except Exception as e:
                logger.error(f"Error in WebSocket task cancellation: {e}")
                print(f"❌ WebSocket: Cancellation error: {e}")
                await self.sio.emit('error', {'message': 'Cancellation failed'}, to=sid)

        @self.sio.event
        async def connect(sid, environ, auth):
            """Handle WebSocket connection.

            Requires ``auth['token']`` to be an HS256 JWT signed with the
            ``SECRET_KEY`` environment variable and carrying a ``sub`` claim.
            On any failure an ``auth_error`` event is emitted, the session is
            disconnected and ``False`` is returned (python-socketio treats a
            ``False`` return from ``connect`` as a rejection).
            """
            process_id = os.getpid()
            thread_id = threading.get_ident()
            print(f"🔌 ASYNC PROCESS DEBUG - WebSocket connection attempt from {sid}")
            print(f"🔌 ASYNC PROCESS DEBUG - Connection handler PID: {process_id}, Thread: {thread_id}")
            logger.info(f"WebSocket connection attempt from {sid}")

            # Validate JWT token from auth data - require authentication
            if not auth or 'token' not in auth:
                logger.warning("WebSocket connection without auth token - rejecting")
                await self.sio.emit('auth_error', {
                    'error': 'missing_token',
                    'message': 'Authentication token required'
                }, to=sid)
                await self.sio.disconnect(sid)
                return False

            try:
                # Decode and validate JWT token with better error handling
                token = auth['token']
                # PyJWT is imported here because the module does not import it
                # at the top of the file.
                import jwt

                # Validate token format first
                if not token or not isinstance(token, str):
                    raise ValueError("Invalid token format")

                # Check if token has the right number of segments
                # (should have 3 parts separated by dots)
                token_parts = token.split('.')
                if len(token_parts) != 3:
                    raise ValueError(f"Invalid JWT format: expected 3 segments, got {len(token_parts)}")

                # Get JWT secret from environment (same as our Quart JWT system)
                jwt_secret = os.environ.get('SECRET_KEY', '')
                if not jwt_secret:
                    raise ValueError("SECRET_KEY not configured")

                try:
                    decoded_token = jwt.decode(token, jwt_secret, algorithms=['HS256'])
                    user_id = decoded_token.get('sub')
                    if not user_id:
                        raise ValueError("No user ID in token")
                except Exception as jwt_error:
                    logger.warning(f"JWT decode failed: {jwt_error}")
                    logger.warning(f"Token format: {len(token_parts)} segments, first 20 chars: {token[:20]}...")
                    # Reject connection for expired or invalid tokens. The
                    # 'expired' flag is derived from the error text.
                    await self.sio.emit('auth_error', {
                        'error': 'invalid_token',
                        'message': f'Token validation failed: {str(jwt_error)}',
                        'expired': 'expired' in str(jwt_error).lower()
                    }, to=sid)
                    await self.sio.disconnect(sid)
                    return False

                # Store user session info
                self.user_sessions[sid] = {
                    'user_id': user_id,
                    'connected_at': datetime.now(timezone.utc),
                    'focus_groups': set()
                }
                logger.info(f"WebSocket connected - Session: {sid}, User: {user_id}")

                # Emit connection success
                await self.sio.emit('connected', {'status': 'success', 'session_id': sid}, to=sid)
            except Exception as e:
                # Format/config errors raised above, and any failure while
                # reporting an invalid token, end up here.
                logger.error(f"Connection authentication failed: {e}")
                await self.sio.emit('auth_error', {
                    'error': 'authentication_failed',
                    'message': 'Authentication error occurred'
                }, to=sid)
                await self.sio.disconnect(sid)
                return False

        @self.sio.event
        async def disconnect(sid):
            """Handle WebSocket disconnection: leave all rooms, drop the session."""
            session_id = sid
            if session_id in self.user_sessions:
                user_info = self.user_sessions[session_id]
                user_id = user_info['user_id']

                # Leave all focus group rooms. Iterate over a copy because
                # _leave_focus_group_room mutates the same set.
                for focus_group_id in user_info['focus_groups'].copy():
                    await self._leave_focus_group_room(session_id, focus_group_id)

                # Clean up session
                del self.user_sessions[session_id]
                logger.info(f"WebSocket disconnected - Session: {session_id}, User: {user_id}")

        @self.sio.event
        async def join_focus_group(sid, data):
            """Handle joining a focus group room (``data['focus_group_id']``)."""
            session_id = sid
            if session_id not in self.user_sessions:
                await self.sio.emit('error', {'message': 'Session not authenticated'}, to=sid)
                return

            focus_group_id = data.get('focus_group_id')
            if not focus_group_id:
                await self.sio.emit('error', {'message': 'Focus group ID required'}, to=sid)
                return

            # Join the room
            success = await self._join_focus_group_room(session_id, focus_group_id)
            if success:
                await self.sio.emit('joined_focus_group', {
                    'focus_group_id': focus_group_id,
                    'status': 'success'
                }, to=sid)
                logger.info(f"User joined focus group room - Session: {session_id}, Group: {focus_group_id}")
            else:
                await self.sio.emit('error', {'message': 'Failed to join focus group'}, to=sid)

        @self.sio.event
        async def leave_focus_group(sid, data):
            """Handle leaving a focus group room (``data['focus_group_id']``)."""
            session_id = sid
            if session_id not in self.user_sessions:
                await self.sio.emit('error', {'message': 'Session not authenticated'}, to=sid)
                return

            focus_group_id = data.get('focus_group_id')
            if not focus_group_id:
                await self.sio.emit('error', {'message': 'Focus group ID required'}, to=sid)
                return

            # Leave the room
            success = await self._leave_focus_group_room(session_id, focus_group_id)
            if success:
                await self.sio.emit('left_focus_group', {
                    'focus_group_id': focus_group_id,
                    'status': 'success'
                }, to=sid)
                logger.info(f"User left focus group room - Session: {session_id}, Group: {focus_group_id}")

    async def _join_focus_group_room(self, session_id: str, focus_group_id: str) -> bool:
        """Join a user session to a focus group room.

        Returns True on success; any exception is logged and reported as False.
        """
        try:
            # Add to SocketIO room
            await self.sio.enter_room(session_id, focus_group_id)

            # Track in our data structures
            if focus_group_id not in self.focus_group_rooms:
                self.focus_group_rooms[focus_group_id] = set()
            self.focus_group_rooms[focus_group_id].add(session_id)
            self.user_sessions[session_id]['focus_groups'].add(focus_group_id)
            return True
        except Exception as e:
            logger.error(f"Failed to join focus group room: {e}")
            return False

    async def _leave_focus_group_room(self, session_id: str, focus_group_id: str) -> bool:
        """Remove a user session from a focus group room.

        Returns True on success; any exception is logged and reported as False.
        """
        try:
            # Leave SocketIO room
            await self.sio.leave_room(session_id, focus_group_id)

            # Clean up tracking
            if focus_group_id in self.focus_group_rooms:
                self.focus_group_rooms[focus_group_id].discard(session_id)
                # Remove room if empty
                if not self.focus_group_rooms[focus_group_id]:
                    del self.focus_group_rooms[focus_group_id]

            if session_id in self.user_sessions:
                self.user_sessions[session_id]['focus_groups'].discard(focus_group_id)
            return True
        except Exception as e:
            logger.error(f"Failed to leave focus group room: {e}")
            return False

    async def emit_to_user(self, user_id: str, event: str, data: Any):
        """Emit an event to a specific user across all their sessions.

        ``data`` is unpacked into the payload with ``**``, so it must be a
        mapping; ``user_id`` and an ISO-8601 UTC ``timestamp`` are added unless
        ``data`` overrides them. Failures are logged, never raised.
        """
        try:
            # Find all sessions for this user
            user_sessions = [
                session_id
                for session_id, session_info in self.user_sessions.items()
                if session_info.get('user_id') == user_id
            ]
            if not user_sessions:
                logger.debug(f"No active sessions found for user {user_id}")
                return

            # Prepare the event data
            event_data = {
                'user_id': user_id,
                'timestamp': datetime.now(timezone.utc).isoformat(),
                **data
            }

            # Send to all user sessions
            for session_id in user_sessions:
                await self.sio.emit(event, event_data, to=session_id)
            logger.debug(f"Emitted '{event}' to user {user_id} ({len(user_sessions)} sessions)")
        except Exception as e:
            logger.error(f"Failed to emit to user {user_id}: {e}")

    def _is_cross_loop(self) -> bool:
        """Return True when the caller is not running on the registered main loop.

        Always False while no main loop has been registered. A caller with no
        running loop at all is treated as cross-loop.
        """
        if self._main_loop is None:
            return False
        try:
            return asyncio.get_running_loop() is not self._main_loop
        except RuntimeError:
            return True  # No running loop on this thread

    async def _dispatch_emit(self, cross_loop: bool, event: str, payload: Any, **target: Any) -> None:
        """Run ``sio.emit(event, payload, **target)`` on the appropriate loop.

        Args:
            cross_loop: Result of ``_is_cross_loop()`` for the current caller.
            event: Event name.
            payload: Event payload.
            **target: Addressing keywords forwarded to ``sio.emit``
                (``room=...`` or ``to=...``).

        Raises:
            Whatever ``sio.emit`` raises, or a timeout error when a cross-loop
            emit does not finish within ``CROSS_LOOP_EMIT_TIMEOUT`` seconds.
        """
        if not cross_loop:
            await self.sio.emit(event, payload, **target)
            return

        future = asyncio.run_coroutine_threadsafe(
            self.sio.emit(event, payload, **target),
            self._main_loop,
        )
        try:
            caller_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop to await on: block this thread until the emit finishes.
            future.result(timeout=self.CROSS_LOOP_EMIT_TIMEOUT)
            return

        # Await the hand-off instead of calling the blocking future.result():
        # blocking here would freeze every other task on the caller's loop for
        # up to the full timeout. shield() keeps a timeout from cancelling the
        # emit that is already scheduled on the main loop.
        await asyncio.wait_for(
            asyncio.shield(asyncio.wrap_future(future, loop=caller_loop)),
            timeout=self.CROSS_LOOP_EMIT_TIMEOUT,
        )

    async def emit_to_focus_group(self, focus_group_id: str, event: str, data: Any, include_sender: bool = True, sender_session_id: Optional[str] = None):
        """Emit an event to all users in a focus group room.

        Args:
            focus_group_id: Room / focus group to address.
            event: Event name.
            data: Mapping merged into the payload after ``focus_group_id`` and
                an ISO-8601 UTC ``timestamp``.
            include_sender: When True (default) the whole room is addressed.
            sender_session_id: Session to skip; only honoured when
                ``include_sender`` is False.

        Tracked sessions that no longer appear in ``user_sessions`` are dropped
        from the room index before emitting. When called from a loop other than
        the registered main loop, the emit is scheduled on the main loop and
        awaited without blocking the caller's loop. All failures are logged and
        swallowed.
        """
        process_id = os.getpid()
        thread_id = threading.get_ident()
        print(f"🔔 ASYNC PROCESS DEBUG - emit_to_focus_group called: {event} for focus group {focus_group_id}")
        print(f"🔔 ASYNC PROCESS DEBUG - PID: {process_id}, Thread: {thread_id}")
        print(f"🔔 Focus group rooms: {list(self.focus_group_rooms.keys())}")
        try:
            if focus_group_id not in self.focus_group_rooms:
                print(f"🔔 ASYNC ERROR: No active sessions for focus group {focus_group_id}")
                logger.debug(f"No active sessions for focus group {focus_group_id}")
                return

            room_name = focus_group_id
            room_sessions = self.focus_group_rooms[focus_group_id].copy()
            print(f"🔔 ASYNC Room {focus_group_id} has {len(room_sessions)} tracked sessions: {list(room_sessions)}")

            # Clean up stale sessions
            active_sessions = []
            stale_sessions = []
            for session_id in room_sessions:
                if session_id in self.user_sessions:
                    active_sessions.append(session_id)
                else:
                    stale_sessions.append(session_id)
                    self.focus_group_rooms[focus_group_id].discard(session_id)
            if stale_sessions:
                print(f"🔔 ASYNC Cleaned up {len(stale_sessions)} stale sessions: {stale_sessions}")
            print(f"🔔 ASYNC Room {focus_group_id} has {len(active_sessions)} ACTIVE sessions: {active_sessions}")
            if not active_sessions:
                print(f"🔔 ASYNC ERROR: No active sessions remaining for focus group {focus_group_id} after cleanup")
                return

            # Prepare the event data
            event_data = {
                'focus_group_id': focus_group_id,
                'timestamp': datetime.now(timezone.utc).isoformat(),
                **data
            }

            # Detect if we're running on a different event loop than the ASGI
            # server (happens when called from the AI Runner's background thread)
            cross_loop = self._is_cross_loop()

            if include_sender or not sender_session_id:
                # Send to all users in the room using AsyncServer
                print(f"🔔 ASYNC Emitting '{event}' to room {room_name} (cross_loop={cross_loop})")
                await self._dispatch_emit(cross_loop, event, event_data, room=room_name)
                print(f"🔔 ASYNC Successfully emitted '{event}' to focus group {focus_group_id} ({len(active_sessions)} active users)")
                logger.debug(f"Emitted '{event}' to focus group {focus_group_id} ({len(active_sessions)} active users)")
            else:
                # Send to all users except the sender
                for session_id in active_sessions:
                    if session_id != sender_session_id:
                        await self._dispatch_emit(cross_loop, event, event_data, to=session_id)
                logger.debug(f"Emitted '{event}' to focus group {focus_group_id} (excluding sender)")
        except Exception as e:
            logger.error(f"Failed to emit to focus group {focus_group_id}: {e}")
            print(f"🔔 ASYNC ERROR: Failed to emit to focus group {focus_group_id}: {e}")

    async def emit_message_update(self, focus_group_id: str, message_data: Dict[str, Any], sender_session_id: Optional[str] = None):
        """Emit a new message (``message_update``) to focus group participants.

        ``include_sender`` is always True here, so ``sender_session_id`` is
        forwarded but does not exclude anyone.
        """
        await self.emit_to_focus_group(
            focus_group_id,
            'message_update',
            {'message': message_data},
            include_sender=True,
            sender_session_id=sender_session_id
        )

    async def emit_ai_status_update(self, focus_group_id: str, status_data: Dict[str, Any]):
        """Emit AI status change (``ai_status_update``) to focus group participants."""
        await self.emit_to_focus_group(
            focus_group_id,
            'ai_status_update',
            {'status': status_data}
        )

    async def emit_moderator_status_update(self, focus_group_id: str, moderator_status: Dict[str, Any]):
        """Emit moderator status change (``moderator_status_update``) to participants."""
        await self.emit_to_focus_group(
            focus_group_id,
            'moderator_status_update',
            {'moderator_status': moderator_status}
        )

    async def emit_theme_update(self, focus_group_id: str, theme_data: Dict[str, Any], action: str = 'added'):
        """Emit theme update (``theme_update``) with its ``action`` to participants."""
        await self.emit_to_focus_group(
            focus_group_id,
            'theme_update',
            {'theme': theme_data, 'action': action}
        )

    async def emit_analytics_update(self, focus_group_id: str, analytics_data: Dict[str, Any]):
        """Emit analytics update (``analytics_update``) to focus group participants."""
        await self.emit_to_focus_group(
            focus_group_id,
            'analytics_update',
            {'analytics': analytics_data}
        )

    async def emit_conversation_state_update(self, focus_group_id: str, state_data: Dict[str, Any]):
        """Emit conversation state update (``conversation_state_update``) to participants."""
        await self.emit_to_focus_group(
            focus_group_id,
            'conversation_state_update',
            {'state': state_data}
        )

    def get_room_info(self, focus_group_id: str) -> Dict[str, Any]:
        """Get information about a focus group room.

        ``active_sessions`` counts every tracked sid in the room, while
        ``users`` lists only those sids still present in ``user_sessions``.
        """
        if focus_group_id not in self.focus_group_rooms:
            return {'active_sessions': 0, 'users': []}

        sessions = self.focus_group_rooms[focus_group_id]
        users = []
        for session_id in sessions:
            if session_id in self.user_sessions:
                user_info = self.user_sessions[session_id]
                users.append({
                    'session_id': session_id,
                    'user_id': user_info['user_id'],
                    'connected_at': user_info['connected_at'].isoformat()
                })
        return {
            'active_sessions': len(sessions),
            'users': users
        }

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get overall connection statistics (session, room and per-room counts)."""
        return {
            'total_sessions': len(self.user_sessions),
            'total_focus_groups': len(self.focus_group_rooms),
            'focus_group_details': {
                fg_id: len(sessions) for fg_id, sessions in self.focus_group_rooms.items()
            }
        }
# Global WebSocket manager instance; stays None until
# init_async_websocket_manager() assigns it.
websocket_manager: Optional[AsyncWebSocketManager] = None
def init_async_websocket_manager() -> AsyncWebSocketManager:
    """Create a fresh manager, publish it as the module-level instance, and return it."""
    global websocket_manager
    manager = AsyncWebSocketManager()
    websocket_manager = manager
    return manager
def get_async_websocket_manager() -> Optional[AsyncWebSocketManager]:
    """Return the module-level manager, or None when it has not been initialized."""
    current = websocket_manager
    return current
# Module-level async emit helper for callers that do not hold the manager
async def emit_websocket_event(event: str, data: dict, room: str | None = None) -> None:
    """
    Emit a WebSocket event to a focus group room through the global manager.

    Nothing is emitted when the manager is not initialized or no room is
    given. Exceptions are printed and logged, never propagated.

    Args:
        event: Event name
        data: Event data
        room: Room to emit to (focus group ID)
    """
    try:
        if not (websocket_manager and room):
            print("🔔 ASYNC - WebSocket manager not available or no room specified")
            return
        await websocket_manager.emit_to_focus_group(room, event, data)
        print(f"🔔 ASYNC - Successfully emitted {event} to room {room}")
    except Exception as exc:
        print(f"🔔 ASYNC ERROR emitting WebSocket event {event}: {exc}")
        logger.exception(f"Error emitting WebSocket event {event}: {exc}")