168 lines
No EOL
7.4 KiB
Python
168 lines
No EOL
7.4 KiB
Python
"""
Thread-Safe WebSocket Manager

Allows WebSocket events to be emitted from background threads (e.g. AI-mode
processing) to frontend clients.

Solves the cross-thread issue where AI processing runs in daemon threads while
the WebSocket connections exist in the main Flask thread: producers enqueue
events on a thread-safe queue, and a dedicated dispatcher thread forwards them
to the registered WebSocket manager.
"""
|
|
|
|
import threading
|
|
import queue
|
|
import time
|
|
from typing import Dict, Any, Optional
|
|
from datetime import datetime
|
|
|
|
class ThreadSafeWebSocketManager:
    """
    Manages WebSocket events across thread boundaries.

    Producers on any thread (e.g. AI-mode processing threads) call
    ``emit_from_background_thread``, which only enqueues the event on a
    thread-safe ``queue.Queue``. A single daemon dispatcher thread, started by
    ``start_background_processing`` (or implicitly by
    ``set_main_websocket_manager``), drains the queue and forwards each event
    to the attached WebSocket manager's ``emit_*`` methods.

    NOTE: the dispatcher is a dedicated background thread, not the main
    thread; the attached manager's ``emit_*`` methods are invoked from it.
    """

    def __init__(self, max_queue_size: int = 0, poll_interval: float = 1.0):
        """
        Create an idle manager (no dispatcher thread is started here).

        Args:
            max_queue_size: Upper bound on pending events. ``0`` (the default)
                means unbounded, matching the previous behaviour. With a bound,
                ``emit_from_background_thread`` waits up to 5 seconds for room
                and then drops the event.
            poll_interval: Seconds the dispatcher blocks on an empty queue
                before re-checking its stop flag. This bounds how long
                ``stop_background_processing`` waits for an idle dispatcher.

        Raises:
            ValueError: if ``poll_interval`` is not positive (0 would busy-spin,
                negative values make ``Queue.get`` raise inside the worker).
        """
        if poll_interval <= 0:
            raise ValueError("poll_interval must be positive")

        # Thread-safe FIFO handing events from producer threads to the dispatcher
        self.event_queue = queue.Queue(maxsize=max_queue_size)

        # Object exposing emit_* methods; attached via set_main_websocket_manager()
        self.main_websocket_manager: Optional[Any] = None

        # Dispatcher thread of the current run (None until first start)
        self.processing_thread: Optional[threading.Thread] = None

        # Stop flag of the current run. A fresh Event is created on every start
        # so a previous worker that is still winding down keeps its own flag.
        self.should_stop = threading.Event()
        self.is_running = False

        self.poll_interval = poll_interval

    def set_main_websocket_manager(self, websocket_manager: Any) -> None:
        """Attach the manager that performs real emits and ensure the dispatcher runs."""
        self.main_websocket_manager = websocket_manager

        # Start background processing if not already running
        if not self.is_running:
            self.start_background_processing()

    def start_background_processing(self) -> None:
        """Start the daemon dispatcher thread; no-op if one is already running."""
        if self.is_running:
            return

        # Per-run stop event: clearing a shared flag here would revive a worker
        # whose earlier stop timed out, leaving two dispatchers running.
        stop_event = threading.Event()
        self.should_stop = stop_event
        self.is_running = True

        # Assigned before start() so the worker can recognise itself as current.
        self.processing_thread = threading.Thread(
            target=self._run_event_loop, args=(stop_event,), daemon=True
        )
        self.processing_thread.start()

    def _run_event_loop(self, stop_event: threading.Event) -> None:
        """Dispatcher thread body: drain ``event_queue`` until *stop_event* is set."""
        print(f"🔄 ThreadSafeWebSocketManager: Background processing started in thread {threading.get_ident()}")

        while not stop_event.is_set():
            # Blocking get with timeout so the stop flag is re-checked periodically
            try:
                event_data = self.event_queue.get(timeout=self.poll_interval)
            except queue.Empty:
                continue

            try:
                if not event_data:
                    pass  # nothing to emit
                elif self.main_websocket_manager:
                    self._process_websocket_event(event_data)
                else:
                    # Previously dropped silently; make the loss visible.
                    print(f"⚠️ ThreadSafeWebSocketManager: No main WebSocket manager attached, dropping {event_data.get('event_type')}")
            except Exception as e:
                print(f"🔄 ThreadSafeWebSocketManager: Error processing event: {e}")
                import traceback
                traceback.print_exc()
            finally:
                # Balance every successful get() so event_queue.join() never hangs.
                self.event_queue.task_done()

        print("🔄 ThreadSafeWebSocketManager: Background processing stopped")

        # Only the current worker may mark the manager stopped; a superseded
        # worker exiting late must not clobber a newer run's state.
        if self.processing_thread is threading.current_thread():
            self.is_running = False

    def stop_background_processing(self) -> None:
        """Signal the dispatcher to stop and wait up to 5 seconds for it to exit."""
        if not self.is_running:
            return

        self.should_stop.set()

        worker = self.processing_thread
        # Thread.join() on the calling thread raises RuntimeError, so skip the
        # wait when stop is requested from inside the dispatcher itself.
        if worker is not None and worker is not threading.current_thread():
            worker.join(timeout=5.0)

        self.is_running = False

    def _process_websocket_event(self, event_data: Dict[str, Any]) -> None:
        """
        Forward one queued event to the attached WebSocket manager.

        Called from the dispatcher thread (see ``_run_event_loop``). Known
        ``event_type`` values are routed to dedicated ``emit_*`` methods;
        anything else goes to ``emit_to_focus_group``. Exceptions are printed,
        not raised.
        """
        try:
            event_type = event_data.get('event_type')
            focus_group_id = event_data.get('focus_group_id')
            data = event_data.get('data', {})

            current_thread = threading.get_ident()
            print(f"🔄 ThreadSafeWebSocketManager: Processing {event_type} in thread {current_thread}")

            # Diagnostic only: when invoked by the dispatcher this is expected to
            # report False, since the dispatcher is not the main thread.
            main_thread = threading.main_thread()
            is_main_thread = threading.current_thread() is main_thread
            print(f"🔄 ThreadSafeWebSocketManager: Current thread is main thread: {is_main_thread}")
            print(f"🔄 ThreadSafeWebSocketManager: Main thread ID: {main_thread.ident}, Current: {current_thread}")

            manager = self.main_websocket_manager

            # Route to appropriate emission method
            if event_type == 'message_update':
                manager.emit_message_update(focus_group_id, data)
            elif event_type == 'ai_status_update':
                manager.emit_ai_status_update(focus_group_id, data)
            elif event_type == 'theme_update':
                theme_data = data.get('theme', {})
                action = data.get('action', 'added')
                manager.emit_theme_update(focus_group_id, theme_data, action)
            elif event_type == 'moderator_status_update':
                manager.emit_moderator_status_update(focus_group_id, data)
            elif event_type == 'analytics_update':
                manager.emit_analytics_update(focus_group_id, data)
            elif event_type == 'conversation_state_update':
                manager.emit_conversation_state_update(focus_group_id, data)
            else:
                # Generic emission
                manager.emit_to_focus_group(focus_group_id, event_type, data)

            print(f"✅ ThreadSafeWebSocketManager: Successfully processed {event_type} for focus group {focus_group_id}")

        except Exception as e:
            print(f"❌ ThreadSafeWebSocketManager: Error processing event: {e}")
            import traceback
            traceback.print_exc()

    def emit_from_background_thread(self, event_type: str, focus_group_id: str, data: Dict[str, Any]) -> None:
        """
        Queue a WebSocket event for delivery; safe to call from any thread.

        The event is only enqueued here. The dispatcher thread performs the
        actual emit later. If the queue is bounded (``max_queue_size``) and
        stays full for 5 seconds, the event is dropped and a message printed.
        """
        current_thread = threading.get_ident()
        print(f"🔄 ThreadSafeWebSocketManager: Queueing {event_type} from thread {current_thread}")

        event_data = {
            'event_type': event_type,
            'focus_group_id': focus_group_id,
            'data': data,
            'timestamp': datetime.utcnow().isoformat(),
            'source_thread': current_thread
        }

        try:
            self.event_queue.put(event_data, timeout=5.0)  # 5 second timeout
            print(f"✅ ThreadSafeWebSocketManager: Queued {event_type} for focus group {focus_group_id}")
        except queue.Full:
            print(f"❌ ThreadSafeWebSocketManager: Event queue is full, dropping {event_type}")
        except Exception as e:
            print(f"❌ ThreadSafeWebSocketManager: Error queueing event: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of run state, queue depth and manager attachment."""
        return {
            'is_running': self.is_running,
            'queue_size': self.event_queue.qsize(),
            'has_main_manager': self.main_websocket_manager is not None,
            'processing_thread_alive': self.processing_thread.is_alive() if self.processing_thread else False
        }
|
|
|
|
# Process-wide shared instance: every caller of
# get_thread_safe_websocket_manager() receives this same object, and therefore
# the same event queue and dispatcher thread.
_thread_safe_manager: ThreadSafeWebSocketManager = ThreadSafeWebSocketManager()


def get_thread_safe_websocket_manager() -> ThreadSafeWebSocketManager:
    """Return the module-level ThreadSafeWebSocketManager shared by all callers."""
    shared_manager = _thread_safe_manager
    return shared_manager