major refactor of entire application - migrate sync -> async including pymongo -> motor, flask -> quart, google-generativeai -> google-genai

This commit is contained in:
michael 2025-08-27 15:20:56 -05:00
parent 6fa8d5ec55
commit fe9b146375
69 changed files with 2379 additions and 1078 deletions

View file

@ -18,7 +18,11 @@
"Bash(find:*)",
"Bash(npx tsc:*)",
"WebFetch(domain:platform.openai.com)",
"WebFetch(domain:cookbook.openai.com)"
"WebFetch(domain:cookbook.openai.com)",
"Bash(pip uninstall:*)",
"Bash(pip install:*)",
"mcp__gpt5-bridge__call_gpt5",
"WebSearch"
],
"deny": []
},

1
.gitignore vendored
View file

@ -4,3 +4,4 @@ dist/
# Ignore Python cache files
__pycache__/
*.py[cod]
*pycache*

View file

@ -1,10 +1,10 @@
from flask import Flask
from flask_cors import CORS
from flask_jwt_extended import JWTManager
from flask_socketio import SocketIO
from quart import Quart
from quart_cors import cors
# No longer using Flask-JWT-Extended - replaced with Quart-compatible JWT
from dotenv import load_dotenv
import os
import tempfile
import asyncio
load_dotenv()
@ -43,10 +43,10 @@ def setup_temp_directories():
return temp_dir, upload_dir
def create_app():
# Set up temp directories BEFORE creating Flask app
# Set up temp directories BEFORE creating Quart app
temp_dir, upload_dir = setup_temp_directories()
app = Flask(__name__)
app = Quart(__name__)
# Setup custom logging configuration
try:
@ -67,14 +67,14 @@ def create_app():
app.config['TIMEOUT'] = 300 # 5 minutes
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16 MB max upload
# Configure Flask/Werkzeug file upload settings
# Configure Quart/Werkzeug file upload settings
app.config['UPLOAD_FOLDER'] = upload_dir
app.config['UPLOAD_EXTENSIONS'] = ['.jpg', '.jpeg', '.png']
# Configure temp directory for Flask/Werkzeug
# Configure temp directory for Quart/Werkzeug
if temp_dir and os.path.isdir(temp_dir):
app.config['TEMP_FOLDER'] = temp_dir
print(f"Flask configured with temp directory: {temp_dir}")
print(f"Quart configured with temp directory: {temp_dir}")
# Additional Werkzeug configuration for multipart form handling
app.config['MAX_CONTENT_PATH'] = None # Don't limit content path
@ -83,19 +83,20 @@ def create_app():
app.config['MAX_FORM_MEMORY_SIZE'] = 16 * 1024 * 1024 # Keep small uploads in memory
# Initialize extensions
CORS(app, resources={r"/api/*": {"origins": "*"}})
jwt = JWTManager(app)
app = cors(app, allow_origin="*", allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
# Initialize SocketIO using singleton pattern (GPT-5 fix for AI mode WebSocket issues)
from .extensions import socketio
socketio.init_app(app) # Bind the singleton SocketIO instance to this Flask app
# JWT is now handled by custom Quart-compatible auth system
# No longer using JWTManager(app) due to Flask/Quart incompatibility
# Initialize AsyncServer WebSocket functionality
from .extensions import socketio_server
# Store socketio reference on app for backward compatibility
app.socketio = socketio
app.socketio = socketio_server
# Initialize WebSocket manager with singleton SocketIO
from app.websocket_manager import init_websocket_manager
websocket_manager = init_websocket_manager() # No parameter needed - uses singleton
# Initialize async WebSocket manager
from app.websocket_manager_async import init_async_websocket_manager
websocket_manager = init_async_websocket_manager()
# Debug tap removed - using simpler GPT-5 diagnostic logging instead
@ -103,8 +104,12 @@ def create_app():
import threading
main_process_id = os.getpid()
main_thread_id = threading.get_ident()
print(f"🔌 PROCESS DEBUG - Flask app initialized with WebSocket manager")
print(f"🔌 PROCESS DEBUG - Main Flask PID: {main_process_id}, Thread: {main_thread_id}")
print(f"🔌 PROCESS DEBUG - Quart app initialized with WebSocket manager")
print(f"🔌 PROCESS DEBUG - Main Quart PID: {main_process_id}, Thread: {main_thread_id}")
# Initialize AI Runner service for autonomous conversations
from app.services.ai_runner_service import init_ai_runner
init_ai_runner()
# Register blueprints
from app.routes.auth import auth_bp
@ -126,7 +131,11 @@ def create_app():
def health_check():
return {'status': 'ok', 'message': 'Backend is running'}, 200
# Store socketio reference on app for access in routes
app.socketio = socketio
# Create ASGI app with SocketIO integration
import socketio as socketio_pkg
asgi_app = socketio_pkg.ASGIApp(socketio_server, app)
return app
# Store reference to the original Quart app for access in routes
asgi_app.quart_app = app
return asgi_app

View file

@ -0,0 +1,9 @@
"""
Quart Authentication Module
Provides JWT authentication functionality compatible with Quart ASGI applications.
"""
from .quart_jwt import jwt_required, get_jwt_identity, create_access_token, decode_token
__all__ = ['jwt_required', 'get_jwt_identity', 'create_access_token', 'decode_token']

View file

@ -0,0 +1,204 @@
"""
Quart-compatible JWT Authentication
Replacement for Flask-JWT-Extended to work with Quart ASGI applications.
Provides jwt_required decorator and token management functions.
"""
import os
import jwt
import functools
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from quart import request, g, current_app, jsonify
# JWT Configuration - ensure compatibility with Flask-JWT-Extended
JWT_SECRET_KEY = os.environ.get('SECRET_KEY', 'your-secret-key-for-sessions-and-tokens')
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=24)
JWT_ALGORITHM = 'HS256'
class QuartJWTError(Exception):
"""Base exception for JWT errors in Quart."""
pass
def create_access_token(identity: str, expires_delta: Optional[timedelta] = None) -> str:
"""
Create a JWT access token.
Args:
identity: User identifier (usually user ID)
expires_delta: Optional expiration time override
Returns:
JWT token string
"""
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + JWT_ACCESS_TOKEN_EXPIRES
payload = {
'sub': identity, # Subject (user ID)
'exp': expire,
'iat': datetime.utcnow(),
'type': 'access'
}
return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def decode_token(token: str) -> Dict[str, Any]:
"""
Decode and validate a JWT token.
Args:
token: JWT token string
Returns:
Decoded token payload
Raises:
QuartJWTError: If token is invalid
"""
try:
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
return payload
except jwt.ExpiredSignatureError:
raise QuartJWTError("Token has expired")
except jwt.InvalidTokenError as e:
raise QuartJWTError(f"Invalid token: {str(e)}")
def get_jwt_identity() -> Optional[str]:
"""
Get the identity (user ID) from the current JWT token.
Returns:
User ID from token, or None if no valid token
"""
try:
return getattr(g, 'current_user_id', None)
except Exception:
return None
def jwt_required(optional: bool = False):
"""
Decorator to require valid JWT token for route access.
Args:
optional: If True, allow access without token (but still decode if present)
"""
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
# Get token from Authorization header
auth_header = request.headers.get('Authorization')
token = None
if auth_header:
# Expected format: "Bearer <token>"
parts = auth_header.split()
if len(parts) == 2 and parts[0].lower() == 'bearer':
token = parts[1]
if not token:
if optional:
# No token provided but optional - allow access
g.current_user_id = None
result = await func(*args, **kwargs)
# Handle tuple returns
if isinstance(result, tuple) and len(result) == 2:
response, status_code = result
if hasattr(response, 'status_code'):
response.status_code = status_code
return response
else:
from quart import make_response
return make_response(response, status_code)
else:
return result
else:
# Token required but not provided
from quart import make_response
return make_response(jsonify({'error': 'Missing authorization token'}), 401)
# Validate token
try:
payload = decode_token(token)
user_id = payload.get('sub')
if not user_id:
raise QuartJWTError("No user ID in token")
# Store user ID in request context
g.current_user_id = user_id
# Call the actual route function and handle tuple returns
result = await func(*args, **kwargs)
# Handle tuple returns (response, status_code)
if isinstance(result, tuple) and len(result) == 2:
response, status_code = result
if hasattr(response, 'status_code'):
response.status_code = status_code
return response
else:
# Use make_response for non-response objects
from quart import make_response
return make_response(response, status_code)
else:
return result
except QuartJWTError as e:
if optional:
# Invalid token but optional - allow access without user ID
g.current_user_id = None
result = await func(*args, **kwargs)
# Handle tuple returns here too
if isinstance(result, tuple) and len(result) == 2:
response, status_code = result
if hasattr(response, 'status_code'):
response.status_code = status_code
return response
else:
from quart import make_response
return make_response(response, status_code)
else:
return result
else:
# Invalid token and required
from quart import make_response
return make_response(jsonify({'error': f'Invalid token: {str(e)}'}), 401)
except Exception as e:
current_app.logger.error(f"JWT validation error: {e}")
if optional:
g.current_user_id = None
result = await func(*args, **kwargs)
# Handle tuple returns here too
if isinstance(result, tuple) and len(result) == 2:
response, status_code = result
if hasattr(response, 'status_code'):
response.status_code = status_code
return response
else:
from quart import make_response
return make_response(response, status_code)
else:
return result
else:
from quart import make_response
return make_response(jsonify({'error': 'Authentication error'}), 500)
return wrapper
return decorator
# For backward compatibility, provide the same function names as Flask-JWT-Extended
def get_current_user():
"""Get current user ID (alias for get_jwt_identity)."""
return get_jwt_identity()

View file

@ -1,9 +1,27 @@
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
import os
import logging
# MongoDB connection
def get_db():
# Global Motor client singleton - per event loop
_motor_clients = {} # event_loop_id -> (client, database)
async def get_db():
"""Get database connection using singleton Motor client per event loop."""
import asyncio
# Get current event loop to ensure Motor client affinity
try:
current_loop = asyncio.get_running_loop()
loop_id = id(current_loop)
except RuntimeError:
raise RuntimeError("get_db() must be called from within an async context")
# Return cached database for this event loop if available
if loop_id in _motor_clients:
client, database = _motor_clients[loop_id]
return database
# Try to read environment variables for MongoDB credentials
mongo_user = os.environ.get('MONGO_USER')
mongo_pass = os.environ.get('MONGO_PASS')
@ -22,28 +40,33 @@ def get_db():
for creds in standard_credentials:
try:
uri = f"mongodb://{creds['user']}:{creds['pass']}@{mongo_host}:{mongo_port}/semblance_db?authSource={creds['db']}"
client = MongoClient(uri, serverSelectionTimeoutMS=2000)
db = client.semblance_db
motor_client = AsyncIOMotorClient(uri, serverSelectionTimeoutMS=2000)
database = motor_client.semblance_db
# Test the connection with a simple command
db.command('ping')
await database.command('ping')
logging.debug(f"Successfully connected to MongoDB with standard credentials ({creds['user']})")
return db
# Cache for this event loop
_motor_clients[loop_id] = (motor_client, database)
return database
except Exception as e:
# Continue trying other credentials
pass
# Try to connect without authentication if standard credentials don't work
try:
client = MongoClient(f'mongodb://{mongo_host}:{mongo_port}', serverSelectionTimeoutMS=5000)
# Simply use the database as is - MongoDB will allow this if auth is not required
db = client.semblance_db
motor_client = AsyncIOMotorClient(f'mongodb://{mongo_host}:{mongo_port}', serverSelectionTimeoutMS=5000)
database = motor_client.semblance_db
# Test the connection with a simple command
db.command('ping')
await database.command('ping')
# Try a write operation to verify we have proper access
test_result = db.test_collection.insert_one({"test": "auth_test"})
db.test_collection.delete_one({"_id": test_result.inserted_id})
test_result = await database.test_collection.insert_one({"test": "auth_test"})
await database.test_collection.delete_one({"_id": test_result.inserted_id})
logging.debug("Successfully connected to MongoDB without authentication")
return db
# Cache for this event loop
_motor_clients[loop_id] = (motor_client, database)
return database
except Exception as e:
logging.debug(f"Could not connect without auth: {e}")
@ -51,11 +74,14 @@ def get_db():
if mongo_user and mongo_pass:
try:
uri = f"mongodb://{mongo_user}:{mongo_pass}@{mongo_host}:{mongo_port}/semblance_db?authSource=admin"
client = MongoClient(uri, serverSelectionTimeoutMS=5000)
db = client.semblance_db
db.command('ping') # Test the connection
motor_client = AsyncIOMotorClient(uri, serverSelectionTimeoutMS=5000)
database = motor_client.semblance_db
await database.command('ping') # Test the connection
logging.debug(f"Successfully connected to MongoDB with credentials for user: {mongo_user}")
return db
# Cache for this event loop
_motor_clients[loop_id] = (motor_client, database)
return database
except Exception as e:
logging.warning(f"Failed to connect with environment credentials: {e}")
@ -63,5 +89,27 @@ def get_db():
logging.warning("Could not authenticate with MongoDB. If authentication is required, operations will fail.")
logging.warning("To fix this: Set MONGO_USER and MONGO_PASS environment variables.")
# Return a client that will likely fail when operations are performed, but the app will start
client = MongoClient(f'mongodb://{mongo_host}:{mongo_port}', serverSelectionTimeoutMS=5000)
return client.semblance_db
motor_client = AsyncIOMotorClient(f'mongodb://{mongo_host}:{mongo_port}', serverSelectionTimeoutMS=5000)
database = motor_client.semblance_db
# Cache for this event loop
_motor_clients[loop_id] = (motor_client, database)
return database
def close_db_connections():
"""Close all Motor clients and their PyMongo background threads."""
global _motor_clients
closed_count = 0
for loop_id, (client, database) in _motor_clients.items():
try:
client.close()
closed_count += 1
except Exception as e:
logging.warning(f"Error closing Motor client for loop {loop_id}: {e}")
if closed_count > 0:
logging.info(f"🗄️ Closed {closed_count} Motor clients - PyMongo threads should stop")
_motor_clients.clear()

View file

@ -1,20 +1,24 @@
"""
Flask Extensions Module
Provides singleton instances of Flask extensions to ensure consistency across the application.
This fixes the WebSocket AI mode issue by ensuring all parts of the app use the same SocketIO instance.
Quart Extensions Module
Provides singleton instances of Quart extensions to ensure consistency across the application.
Uses python-socketio AsyncServer for native Quart/ASGI compatibility.
"""
from flask_socketio import SocketIO
import socketio
import logging
# Create the SINGLE SocketIO instance that will be used throughout the application
# This is the singleton pattern recommended by GPT-5 to fix AI mode WebSocket issues
socketio = SocketIO(
# Set up logging for socketio
socketio_logger = logging.getLogger('socketio')
socketio_logger.setLevel(logging.WARNING) # Reduce socketio log noise
# Create the AsyncServer instance for Quart/ASGI compatibility
socketio_server = socketio.AsyncServer(
async_mode='asgi',
cors_allowed_origins="*",
async_mode="eventlet",
ping_timeout=120, # 2 minutes timeout for ping response
ping_interval=45, # Send ping every 45 seconds
logger=False, # Disable verbose socketio logging (reduces log noise)
engineio_logger=False # Disable verbose engineio logging (reduces PING/PONG spam)
ping_timeout=120, # 2 minutes timeout for ping response
ping_interval=45, # Send ping every 45 seconds
logger=False, # Disable verbose socketio logging
engineio_logger=False # Disable verbose engineio logging
)
# Note: The app will be bound to this instance using socketio.init_app(app) in create_app()
# Note: This will be wrapped with socketio.ASGIApp in create_app() to integrate with Quart

View file

@ -7,9 +7,9 @@ import os
import threading
import eventlet
def emit_websocket_event(event_name: str, focus_group_id: str, data: dict):
"""Helper function to emit WebSocket events using queue-based emitter (GPT-5 fix)."""
from app.websocket_manager import emit_websocket_event as queue_emit
async def emit_websocket_event(event_name: str, focus_group_id: str, data: dict):
"""Helper function to emit WebSocket events using async WebSocket manager."""
from app.websocket_manager_async import emit_websocket_event as async_emit
process_id = os.getpid()
thread_id = threading.get_ident()
@ -38,9 +38,9 @@ def emit_websocket_event(event_name: str, focus_group_id: str, data: dict):
**data
}
# Emit to the specific focus group room using the queue-based system
queue_emit(event_name, event_data, focus_group_id)
print(f"🔔 Successfully queued {event_name} for focus group {focus_group_id}")
# Emit to the specific focus group room using the async system
await async_emit(event_name, event_data, focus_group_id)
print(f"🔔 Successfully emitted {event_name} for focus group {focus_group_id}")
except Exception as e:
print(f"🔔 ERROR emitting WebSocket event {event_name}: {e}")
@ -65,8 +65,8 @@ def emit_with_ack(event_name: str, focus_group_id: str, payload: dict):
class FocusGroup:
@staticmethod
def create(focus_group_data, user_id):
db = get_db()
async def create(focus_group_data, user_id):
db = await get_db()
# Add metadata
focus_group_data["created_at"] = datetime.utcnow()
@ -87,14 +87,14 @@ class FocusGroup:
if "verbosity" not in focus_group_data:
focus_group_data["verbosity"] = "medium"
result = db.focus_groups.insert_one(focus_group_data)
result = await db.focus_groups.insert_one(focus_group_data)
return str(result.inserted_id)
@staticmethod
def find_by_id(focus_group_id):
db = get_db()
async def find_by_id(focus_group_id):
db = await get_db()
try:
focus_group = db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
focus_group = await db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
if focus_group:
focus_group["_id"] = str(focus_group["_id"])
return focus_group
@ -102,9 +102,10 @@ class FocusGroup:
return None
@staticmethod
def find_by_user(user_id, limit=50):
db = get_db()
focus_groups = db.focus_groups.find({"created_by": user_id}).sort("created_at", -1).limit(limit)
async def find_by_user(user_id, limit=50):
db = await get_db()
cursor = db.focus_groups.find({"created_by": user_id}).sort("created_at", -1).limit(limit)
focus_groups = await cursor.to_list(length=limit)
result = []
for group in focus_groups:
@ -114,21 +115,22 @@ class FocusGroup:
return result
@staticmethod
def get_all(limit=50):
async def get_all(limit=50):
import logging
logger = logging.getLogger('app.focus_group_model')
try:
logger.debug(f"=== FocusGroup.get_all() called with limit={limit} ===")
db = get_db()
db = await get_db()
logger.debug(f"Database connection obtained: {db}")
# Check if collection exists and has data
collection = db.focus_groups
total_count = collection.count_documents({})
total_count = await collection.count_documents({})
logger.debug(f"Total focus groups in database: {total_count}")
focus_groups = list(db.focus_groups.find().sort("created_at", -1).limit(limit))
cursor = db.focus_groups.find().sort("created_at", -1).limit(limit)
focus_groups = await cursor.to_list(length=limit)
logger.debug(f"Query returned {len(focus_groups)} focus groups")
result = []
@ -147,8 +149,8 @@ class FocusGroup:
return []
@staticmethod
def update(focus_group_id, data):
db = get_db()
async def update(focus_group_id, data):
db = await get_db()
# Create a copy of the data to avoid modifying the original
filtered_data = data.copy()
@ -177,7 +179,7 @@ class FocusGroup:
except:
pass
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{"$set": filtered_data}
)
@ -186,7 +188,7 @@ class FocusGroup:
if result.modified_count > 0:
# Emit status change event if status was updated
if 'status' in filtered_data:
emit_websocket_event('ai_status_update', focus_group_id, {
await emit_websocket_event('ai_status_update', focus_group_id, {
'status': {
'status': filtered_data['status'], # Frontend expects nested structure
'updated_at': filtered_data["updated_at"].isoformat()
@ -195,7 +197,7 @@ class FocusGroup:
# Emit model change event if LLM model was updated
if 'llm_model' in filtered_data:
emit_websocket_event('focus_group_update', focus_group_id, {
await emit_websocket_event('focus_group_update', focus_group_id, {
'llm_model': filtered_data['llm_model'],
'reasoning_effort': filtered_data.get('reasoning_effort'),
'verbosity': filtered_data.get('verbosity'),
@ -206,7 +208,7 @@ class FocusGroup:
if 'llm_model' in filtered_data and result.modified_count > 0:
try:
# Re-read the document to verify the update
updated_doc = db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
updated_doc = await db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
actual_model = updated_doc.get('llm_model') if updated_doc else None
log_msg = f"🔍 [{datetime.utcnow()}] POST-UPDATE VERIFICATION: Expected '{filtered_data['llm_model']}', got '{actual_model}' for {focus_group_id}\n"
with open('/tmp/focus_group_debug.log', 'a') as f:
@ -283,9 +285,9 @@ class FocusGroup:
return cleaned_files, failed_files
@staticmethod
def _cleanup_focus_group_collections(focus_group_id):
async def _cleanup_focus_group_collections(focus_group_id):
"""Clean up all related collection documents for a focus group."""
db = get_db()
db = await get_db()
cleaned_collections = []
failed_collections = []
@ -301,7 +303,7 @@ class FocusGroup:
for collection_name, field_name in collections_to_clean:
try:
collection = getattr(db, collection_name)
result = collection.delete_many({field_name: focus_group_id})
result = await collection.delete_many({field_name: focus_group_id})
if result.deleted_count > 0:
cleaned_collections.append(f"{collection_name}: {result.deleted_count} documents")
print(f"Cleaned up {result.deleted_count} documents from {collection_name}")
@ -314,13 +316,13 @@ class FocusGroup:
return cleaned_collections, failed_collections
@staticmethod
def delete(focus_group_id):
async def delete(focus_group_id):
"""Delete a focus group and all its associated data including creative assets."""
db = get_db()
db = await get_db()
try:
# First, get the focus group data to access uploaded assets
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
print(f"Focus group {focus_group_id} not found")
return False
@ -331,10 +333,10 @@ class FocusGroup:
cleaned_files, failed_files = FocusGroup._cleanup_focus_group_assets(focus_group_id, uploaded_assets)
# Clean up related collections
cleaned_collections, failed_collections = FocusGroup._cleanup_focus_group_collections(focus_group_id)
cleaned_collections, failed_collections = await FocusGroup._cleanup_focus_group_collections(focus_group_id)
# Finally, delete the main focus group document
result = db.focus_groups.delete_one({"_id": ObjectId(focus_group_id)})
result = await db.focus_groups.delete_one({"_id": ObjectId(focus_group_id)})
if result.deleted_count > 0:
print(f"Successfully deleted focus group {focus_group_id}")
@ -357,32 +359,33 @@ class FocusGroup:
return False
@staticmethod
def add_participant(focus_group_id, persona_id):
db = get_db()
result = db.focus_groups.update_one(
async def add_participant(focus_group_id, persona_id):
db = await get_db()
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{"$addToSet": {"participants": persona_id}}
)
return result.modified_count > 0
@staticmethod
def remove_participant(focus_group_id, persona_id):
db = get_db()
result = db.focus_groups.update_one(
async def remove_participant(focus_group_id, persona_id):
db = await get_db()
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{"$pull": {"participants": persona_id}}
)
return result.modified_count > 0
@staticmethod
def get_messages(focus_group_id, limit=100):
async def get_messages(focus_group_id, limit=100):
"""Get all messages for a focus group."""
db = get_db()
db = await get_db()
try:
# Get all messages and sort chronologically
messages = list(db.focus_group_messages.find(
cursor = db.focus_group_messages.find(
{"focus_group_id": focus_group_id}
).sort("created_at", 1))
).sort("created_at", 1)
messages = await cursor.to_list(length=None)
# Convert ObjectId to strings
for message in messages:
@ -396,12 +399,12 @@ class FocusGroup:
return []
@staticmethod
def add_message(focus_group_id, message_data):
async def add_message(focus_group_id, message_data):
"""Add a new message to a focus group."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return None
@ -419,7 +422,7 @@ class FocusGroup:
}
# Insert the message
result = db.focus_group_messages.insert_one(message)
result = await db.focus_group_messages.insert_one(message)
if result.inserted_id:
message_id = str(result.inserted_id)
@ -427,7 +430,7 @@ class FocusGroup:
# If this message activates visual context, update the focus group's active visual context
if message.get("activates_visual_context") and message.get("attached_assets"):
FocusGroup._activate_visual_assets(focus_group_id, message.get("attached_assets"), message_id)
await FocusGroup._activate_visual_assets(focus_group_id, message.get("attached_assets"), message_id)
# Emit WebSocket event for new message
message_for_websocket = {
@ -443,7 +446,7 @@ class FocusGroup:
}
print(f"🔔 EMITTING WEBSOCKET EVENT: message_update for focus group {focus_group_id}")
print(f"🔔 Message data: sender={message_for_websocket['senderId']}, type={message_for_websocket['type']}")
emit_websocket_event('message_update', focus_group_id, message_for_websocket)
await emit_websocket_event('message_update', focus_group_id, message_for_websocket)
return message_id
else:
@ -455,17 +458,17 @@ class FocusGroup:
return None
@staticmethod
def update_message_highlight(focus_group_id, message_id, highlighted):
async def update_message_highlight(focus_group_id, message_id, highlighted):
"""Update the highlighted status of a message."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return False
# Update the message
result = db.focus_group_messages.update_one(
result = await db.focus_group_messages.update_one(
{"_id": ObjectId(message_id), "focus_group_id": focus_group_id},
{"$set": {"highlighted": highlighted, "updated_at": datetime.utcnow()}}
)
@ -477,14 +480,15 @@ class FocusGroup:
return False
@staticmethod
def get_generated_themes(focus_group_id):
async def get_generated_themes(focus_group_id):
"""Get all generated themes for a focus group."""
db = get_db()
db = await get_db()
try:
# Get themes associated with this focus group
themes = list(db.focus_group_themes.find(
cursor = db.focus_group_themes.find(
{"focus_group_id": focus_group_id}
).sort("created_at", -1))
).sort("created_at", -1)
themes = await cursor.to_list(length=None)
# Convert ObjectId to strings
for theme in themes:
@ -498,12 +502,12 @@ class FocusGroup:
return []
@staticmethod
def add_generated_theme(focus_group_id, theme_data):
async def add_generated_theme(focus_group_id, theme_data):
"""Add a new generated theme to a focus group."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return None
@ -519,13 +523,13 @@ class FocusGroup:
}
# Insert the theme
result = db.focus_group_themes.insert_one(theme)
result = await db.focus_group_themes.insert_one(theme)
if result.inserted_id:
theme["_id"] = str(result.inserted_id)
# Emit WebSocket event for new theme
emit_websocket_event('theme_update', focus_group_id, {
await emit_websocket_event('theme_update', focus_group_id, {
'theme': {
'id': theme["id"],
'title': theme["title"],
@ -545,12 +549,12 @@ class FocusGroup:
return None
@staticmethod
def add_generated_themes(focus_group_id, themes_data):
async def add_generated_themes(focus_group_id, themes_data):
"""Add multiple generated themes to a focus group."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return None
@ -574,11 +578,11 @@ class FocusGroup:
# Insert the themes
if themes:
result = db.focus_group_themes.insert_many(themes)
result = await db.focus_group_themes.insert_many(themes)
# Emit WebSocket events for all new themes
for theme in themes:
emit_websocket_event('theme_update', focus_group_id, {
await emit_websocket_event('theme_update', focus_group_id, {
'theme': {
'id': theme["id"],
'title': theme["title"],
@ -600,12 +604,12 @@ class FocusGroup:
return []
@staticmethod
def delete_generated_theme(focus_group_id, theme_id):
async def delete_generated_theme(focus_group_id, theme_id):
"""Delete a generated theme from a focus group."""
db = get_db()
db = await get_db()
try:
# Delete the theme
result = db.focus_group_themes.delete_one(
result = await db.focus_group_themes.delete_one(
{"focus_group_id": focus_group_id, "id": theme_id}
)
@ -616,14 +620,15 @@ class FocusGroup:
return False
@staticmethod
def get_reasoning_history(focus_group_id, limit=50):
async def get_reasoning_history(focus_group_id, limit=50):
"""Get reasoning history for a focus group."""
db = get_db()
db = await get_db()
try:
# Get reasoning entries associated with this focus group
reasoning_entries = list(db.focus_group_reasoning.find(
cursor = db.focus_group_reasoning.find(
{"focus_group_id": focus_group_id}
).sort("timestamp", -1).limit(limit))
).sort("timestamp", -1).limit(limit)
reasoning_entries = await cursor.to_list(length=limit)
# Convert ObjectId to strings and format timestamps
for entry in reasoning_entries:
@ -640,12 +645,12 @@ class FocusGroup:
return []
@staticmethod
def add_reasoning_entry(focus_group_id, reasoning_data):
async def add_reasoning_entry(focus_group_id, reasoning_data):
"""Add a reasoning entry to a focus group."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return None
@ -669,7 +674,7 @@ class FocusGroup:
reasoning_entry["timestamp"] = datetime.utcnow()
# Insert the reasoning entry
result = db.focus_group_reasoning.insert_one(reasoning_entry)
result = await db.focus_group_reasoning.insert_one(reasoning_entry)
# Return the id of the new entry
return str(result.inserted_id)
@ -679,12 +684,12 @@ class FocusGroup:
return None
@staticmethod
def update_reasoning_execution(focus_group_id, reasoning_id, execution_result):
async def update_reasoning_execution(focus_group_id, reasoning_id, execution_result):
"""Update the execution result of a reasoning entry."""
db = get_db()
db = await get_db()
try:
# Update the reasoning entry
result = db.focus_group_reasoning.update_one(
result = await db.focus_group_reasoning.update_one(
{"_id": ObjectId(reasoning_id), "focus_group_id": focus_group_id},
{"$set": {
"execution_status": "success" if not execution_result.get("error") else "error",
@ -700,14 +705,15 @@ class FocusGroup:
return False
@staticmethod
def get_notes(focus_group_id, limit=100):
async def get_notes(focus_group_id, limit=100):
"""Get all notes for a focus group."""
db = get_db()
db = await get_db()
try:
# Look for a notes collection associated with this focus group
notes = list(db.focus_group_notes.find(
cursor = db.focus_group_notes.find(
{"focus_group_id": focus_group_id}
).sort("created_at", -1).limit(limit))
).sort("created_at", -1).limit(limit)
notes = await cursor.to_list(length=limit)
# Convert ObjectId to strings
for note in notes:
@ -723,12 +729,12 @@ class FocusGroup:
return []
@staticmethod
def add_note(focus_group_id, note_data):
async def add_note(focus_group_id, note_data):
"""Add a new note to a focus group."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return None
@ -752,7 +758,7 @@ class FocusGroup:
note["timestamp"] = datetime.utcnow()
# Insert the note
result = db.focus_group_notes.insert_one(note)
result = await db.focus_group_notes.insert_one(note)
# Return the id of the new note
return str(result.inserted_id)
@ -762,12 +768,12 @@ class FocusGroup:
return None
@staticmethod
def delete_note(focus_group_id, note_id):
async def delete_note(focus_group_id, note_id):
"""Delete a note from a focus group."""
db = get_db()
db = await get_db()
try:
# Delete the note
result = db.focus_group_notes.delete_one(
result = await db.focus_group_notes.delete_one(
{"_id": ObjectId(note_id), "focus_group_id": focus_group_id}
)
@ -778,12 +784,12 @@ class FocusGroup:
return False
@staticmethod
def add_mode_event(focus_group_id, event_type, user_id=None):
async def add_mode_event(focus_group_id, event_type, user_id=None):
"""Add a mode switch event to a focus group."""
db = get_db()
db = await get_db()
try:
# Ensure the focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return None
@ -797,7 +803,7 @@ class FocusGroup:
}
# Insert the mode event
result = db.focus_group_mode_events.insert_one(mode_event)
result = await db.focus_group_mode_events.insert_one(mode_event)
if result.inserted_id:
mode_event_id = str(result.inserted_id)
@ -814,7 +820,7 @@ class FocusGroup:
}
print(f"🔔 EMITTING WEBSOCKET EVENT: mode_event_update for focus group {focus_group_id}")
print(f"🔔 Mode event data: event_type={event_type}, timestamp={mode_event['timestamp'].isoformat()}")
emit_websocket_event('mode_event_update', focus_group_id, mode_event_for_websocket)
await emit_websocket_event('mode_event_update', focus_group_id, mode_event_for_websocket)
return mode_event_id
@ -826,14 +832,15 @@ class FocusGroup:
return None
@staticmethod
def get_mode_events(focus_group_id, limit=100):
async def get_mode_events(focus_group_id, limit=100):
"""Get all mode events for a focus group."""
db = get_db()
db = await get_db()
try:
# Look for mode events associated with this focus group
mode_events = list(db.focus_group_mode_events.find(
cursor = db.focus_group_mode_events.find(
{"focus_group_id": focus_group_id}
).sort("timestamp", 1).limit(limit))
).sort("timestamp", 1).limit(limit)
mode_events = await cursor.to_list(length=limit)
# Convert ObjectId to strings
for event in mode_events:
@ -849,9 +856,9 @@ class FocusGroup:
return []
@staticmethod
def add_uploaded_assets(focus_group_id, assets_metadata):
async def add_uploaded_assets(focus_group_id, assets_metadata):
"""Add uploaded asset metadata to a focus group."""
db = get_db()
db = await get_db()
try:
# Clean the metadata to remove file_path before storing in DB
cleaned_assets = []
@ -867,7 +874,7 @@ class FocusGroup:
cleaned_assets.append(cleaned_asset)
# Add assets to the focus group
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{
"$push": {"uploaded_assets": {"$each": cleaned_assets}},
@ -882,12 +889,12 @@ class FocusGroup:
return False
@staticmethod
def remove_uploaded_asset(focus_group_id, filename):
async def remove_uploaded_asset(focus_group_id, filename):
"""Remove an uploaded asset metadata from a focus group."""
db = get_db()
db = await get_db()
try:
# Remove asset from the focus group
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{
"$pull": {"uploaded_assets": {"filename": filename}},
@ -902,11 +909,11 @@ class FocusGroup:
return False
@staticmethod
def get_uploaded_assets(focus_group_id):
async def get_uploaded_assets(focus_group_id):
"""Get uploaded assets for a focus group."""
db = get_db()
db = await get_db()
try:
focus_group = db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
focus_group = await db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
if focus_group:
return focus_group.get('uploaded_assets', [])
return []
@ -916,11 +923,11 @@ class FocusGroup:
return []
@staticmethod
def update_asset_name(focus_group_id, filename, user_assigned_name):
async def update_asset_name(focus_group_id, filename, user_assigned_name):
"""Update the user assigned name for an uploaded asset."""
db = get_db()
db = await get_db()
try:
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id), "uploaded_assets.filename": filename},
{
"$set": {
@ -937,11 +944,11 @@ class FocusGroup:
return False
@staticmethod
def clear_uploaded_assets(focus_group_id):
async def clear_uploaded_assets(focus_group_id):
"""Clear all uploaded assets for a focus group from database."""
db = get_db()
db = await get_db()
try:
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{
"$unset": {"uploaded_assets": ""},
@ -956,15 +963,15 @@ class FocusGroup:
return False
@staticmethod
def _activate_visual_assets(focus_group_id, asset_filenames, message_id):
async def _activate_visual_assets(focus_group_id, asset_filenames, message_id):
"""Internal method to activate visual assets in conversation context."""
db = get_db()
db = await get_db()
try:
# Get current message count to determine sequence number
message_count = db.focus_group_messages.count_documents({"focus_group_id": focus_group_id})
message_count = await db.focus_group_messages.count_documents({"focus_group_id": focus_group_id})
# Get existing visual context to check for duplicate assets
focus_group = db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
focus_group = await db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
existing_context = focus_group.get("active_visual_context", []) if focus_group else []
# Track which assets are new vs existing
@ -1009,7 +1016,7 @@ class FocusGroup:
# First, update existing assets to current sequence
for filename in updated_filenames:
db.focus_groups.update_one(
await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id), "active_visual_context.filename": filename},
{
"$set": {
@ -1024,7 +1031,7 @@ class FocusGroup:
# Then, add any new assets
result = None
if new_records:
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{
"$push": {"active_visual_context": {"$each": new_records}},
@ -1034,7 +1041,7 @@ class FocusGroup:
)
else:
# If we only updated existing assets, just set the updated_at timestamp
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{"$set": {"updated_at": datetime.utcnow()}}
)
@ -1048,11 +1055,11 @@ class FocusGroup:
return False
@staticmethod
def get_active_visual_context(focus_group_id):
async def get_active_visual_context(focus_group_id):
"""Get all images that are active in conversation context for this focus group."""
db = get_db()
db = await get_db()
try:
focus_group = db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
focus_group = await db.focus_groups.find_one({"_id": ObjectId(focus_group_id)})
if focus_group:
return focus_group.get('active_visual_context', [])
return []
@ -1062,14 +1069,15 @@ class FocusGroup:
return []
@staticmethod
def get_messages_with_visual_context(focus_group_id, limit=100):
async def get_messages_with_visual_context(focus_group_id, limit=100):
"""Get messages with enhanced visual context information."""
db = get_db()
db = await get_db()
try:
# Get all messages
messages = list(db.focus_group_messages.find(
cursor = db.focus_group_messages.find(
{"focus_group_id": focus_group_id}
).sort("created_at", 1))
).sort("created_at", 1)
messages = await cursor.to_list(length=None)
# Convert ObjectId to strings and add sequence numbers
for i, message in enumerate(messages):
@ -1078,7 +1086,7 @@ class FocusGroup:
message["sequence"] = i + 1
# Add flag indicating if this message has visual context available
active_context = FocusGroup.get_active_visual_context(focus_group_id)
active_context = await FocusGroup.get_active_visual_context(focus_group_id)
message["has_visual_context"] = any(
asset["activated_at_sequence"] <= message["sequence"]
for asset in active_context
@ -1091,11 +1099,11 @@ class FocusGroup:
return []
@staticmethod
def clear_visual_context(focus_group_id):
async def clear_visual_context(focus_group_id):
"""Clear all active visual context for a focus group (useful for testing)."""
db = get_db()
db = await get_db()
try:
result = db.focus_groups.update_one(
result = await db.focus_groups.update_one(
{"_id": ObjectId(focus_group_id)},
{
"$unset": {"active_visual_context": ""},

View file

@ -5,9 +5,9 @@ from datetime import datetime
class Folder:
@staticmethod
def create(folder_data, user_id):
async def create(folder_data, user_id):
"""Create a new folder."""
db = get_db()
db = await get_db()
# Add metadata
folder_data["created_at"] = datetime.utcnow()
@ -15,15 +15,15 @@ class Folder:
# Note: No longer storing persona_ids in folders - using persona-centric storage
result = db.folders.insert_one(folder_data)
result = await db.folders.insert_one(folder_data)
return str(result.inserted_id)
@staticmethod
def find_by_id(folder_id):
async def find_by_id(folder_id):
"""Find a folder by its ID."""
db = get_db()
db = await get_db()
try:
folder = db.folders.find_one({"_id": ObjectId(folder_id)})
folder = await db.folders.find_one({"_id": ObjectId(folder_id)})
if folder:
folder["_id"] = str(folder["_id"])
return folder
@ -32,10 +32,11 @@ class Folder:
return None
@staticmethod
def find_by_user(user_id, limit=100):
async def find_by_user(user_id, limit=100):
"""Find all folders created by a specific user."""
db = get_db()
folders = db.folders.find({"created_by": user_id}).sort("created_at", -1).limit(limit)
db = await get_db()
cursor = db.folders.find({"created_by": user_id}).sort("created_at", -1).limit(limit)
folders = await cursor.to_list(length=limit)
result = []
for folder in folders:
@ -45,11 +46,12 @@ class Folder:
return result
@staticmethod
def get_all(limit=100):
async def get_all(limit=100):
"""Get all folders (for debugging/admin purposes)."""
try:
db = get_db()
folders = list(db.folders.find().sort("created_at", -1).limit(limit))
db = await get_db()
cursor = db.folders.find().sort("created_at", -1).limit(limit)
folders = await cursor.to_list(length=limit)
result = []
for folder in folders:
@ -62,9 +64,9 @@ class Folder:
return []
@staticmethod
def update(folder_id, data):
async def update(folder_id, data):
"""Update a folder."""
db = get_db()
db = await get_db()
# Create a copy of the data to avoid modifying the original
filtered_data = data.copy()
@ -82,7 +84,7 @@ class Folder:
# Set the updated timestamp
filtered_data["updated_at"] = datetime.utcnow()
result = db.folders.update_one(
result = await db.folders.update_one(
{"_id": ObjectId(folder_id)},
{"$set": filtered_data}
)
@ -90,26 +92,26 @@ class Folder:
return result.modified_count > 0
@staticmethod
def delete(folder_id):
async def delete(folder_id):
"""Delete a folder."""
db = get_db()
db = await get_db()
try:
result = db.folders.delete_one({"_id": ObjectId(folder_id)})
result = await db.folders.delete_one({"_id": ObjectId(folder_id)})
return result.deleted_count > 0
except Exception as e:
print(f"Error in delete: {e}")
return False
@staticmethod
def add_persona(folder_id, persona_id):
async def add_persona(folder_id, persona_id):
"""Add a persona to a folder (persona-centric storage)."""
db = get_db()
db = await get_db()
try:
print(f"🔧 FOLDER ADD_PERSONA: folder_id={folder_id}, persona_id={persona_id}")
# Check if persona exists
persona = db.personas.find_one({"_id": ObjectId(persona_id)})
persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
if not persona:
print(f"❌ FOLDER ADD_PERSONA: Persona {persona_id} not found")
return False
@ -118,7 +120,7 @@ class Folder:
print(f"📋 FOLDER ADD_PERSONA: Current folder_ids: {persona.get('folder_ids', 'None')}")
# Only update the persona's folder_ids - single source of truth
persona_result = db.personas.update_one(
persona_result = await db.personas.update_one(
{"_id": ObjectId(persona_id)},
{"$addToSet": {"folder_ids": folder_id}, "$set": {"updated_at": datetime.utcnow()}}
)
@ -126,11 +128,11 @@ class Folder:
print(f"📝 FOLDER ADD_PERSONA: Update result - modified_count: {persona_result.modified_count}, matched_count: {persona_result.matched_count}")
# Verify the update
updated_persona = db.personas.find_one({"_id": ObjectId(persona_id)})
updated_persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
print(f"✅ FOLDER ADD_PERSONA: Updated folder_ids: {updated_persona.get('folder_ids', 'None')}")
# Update folder's updated_at timestamp
db.folders.update_one(
await db.folders.update_one(
{"_id": ObjectId(folder_id)},
{"$set": {"updated_at": datetime.utcnow()}}
)
@ -143,15 +145,15 @@ class Folder:
return False
@staticmethod
def remove_persona(folder_id, persona_id):
async def remove_persona(folder_id, persona_id):
"""Remove a persona from a folder (persona-centric storage)."""
db = get_db()
db = await get_db()
try:
print(f"🔧 FOLDER REMOVE_PERSONA: folder_id={folder_id}, persona_id={persona_id}")
# Check if persona exists
persona = db.personas.find_one({"_id": ObjectId(persona_id)})
persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
if not persona:
print(f"❌ FOLDER REMOVE_PERSONA: Persona {persona_id} not found")
return False
@ -160,7 +162,7 @@ class Folder:
print(f"📋 FOLDER REMOVE_PERSONA: Current folder_ids: {persona.get('folder_ids', 'None')}")
# Only update the persona's folder_ids - single source of truth
persona_result = db.personas.update_one(
persona_result = await db.personas.update_one(
{"_id": ObjectId(persona_id)},
{"$pull": {"folder_ids": folder_id}, "$set": {"updated_at": datetime.utcnow()}}
)
@ -168,11 +170,11 @@ class Folder:
print(f"📝 FOLDER REMOVE_PERSONA: Update result - modified_count: {persona_result.modified_count}, matched_count: {persona_result.matched_count}")
# Verify the update
updated_persona = db.personas.find_one({"_id": ObjectId(persona_id)})
updated_persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
print(f"✅ FOLDER REMOVE_PERSONA: Updated folder_ids: {updated_persona.get('folder_ids', 'None')}")
# Update folder's updated_at timestamp
db.folders.update_one(
await db.folders.update_one(
{"_id": ObjectId(folder_id)},
{"$set": {"updated_at": datetime.utcnow()}}
)
@ -185,9 +187,9 @@ class Folder:
return False
@staticmethod
def add_personas_batch(folder_id, persona_ids):
async def add_personas_batch(folder_id, persona_ids):
"""Add multiple personas to a folder (persona-centric storage)."""
db = get_db()
db = await get_db()
try:
print(f"🔧 FOLDER ADD_PERSONAS_BATCH: folder_id={folder_id}, persona_ids={persona_ids}")
@ -199,7 +201,7 @@ class Folder:
print(f"🔧 FOLDER BATCH: Processing persona {persona_id}")
# Check if persona exists
persona = db.personas.find_one({"_id": ObjectId(persona_id)})
persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
if not persona:
print(f"❌ FOLDER BATCH: Persona {persona_id} not found")
persona_results.append(False)
@ -208,7 +210,7 @@ class Folder:
print(f"✅ FOLDER BATCH: Found persona {persona.get('name', 'Unknown')}")
print(f"📋 FOLDER BATCH: Current folder_ids: {persona.get('folder_ids', 'None')}")
result = db.personas.update_one(
result = await db.personas.update_one(
{"_id": ObjectId(persona_id)},
{"$addToSet": {"folder_ids": folder_id}, "$set": {"updated_at": datetime.utcnow()}}
)
@ -221,7 +223,7 @@ class Folder:
persona_results.append(False)
# Update folder's updated_at timestamp
db.folders.update_one(
await db.folders.update_one(
{"_id": ObjectId(folder_id)},
{"$set": {"updated_at": datetime.utcnow()}}
)
@ -237,9 +239,9 @@ class Folder:
return False
@staticmethod
def remove_personas_batch(folder_id, persona_ids):
async def remove_personas_batch(folder_id, persona_ids):
"""Remove multiple personas from a folder (persona-centric storage)."""
db = get_db()
db = await get_db()
try:
print(f"🔧 FOLDER REMOVE_PERSONAS_BATCH: folder_id={folder_id}, persona_ids={persona_ids}")
@ -251,7 +253,7 @@ class Folder:
print(f"🔧 FOLDER REMOVE_BATCH: Processing persona {persona_id}")
# Check if persona exists
persona = db.personas.find_one({"_id": ObjectId(persona_id)})
persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
if not persona:
print(f"❌ FOLDER REMOVE_BATCH: Persona {persona_id} not found")
persona_results.append(False)
@ -260,7 +262,7 @@ class Folder:
print(f"✅ FOLDER REMOVE_BATCH: Found persona {persona.get('name', 'Unknown')}")
print(f"📋 FOLDER REMOVE_BATCH: Current folder_ids: {persona.get('folder_ids', 'None')}")
result = db.personas.update_one(
result = await db.personas.update_one(
{"_id": ObjectId(persona_id)},
{"$pull": {"folder_ids": folder_id}, "$set": {"updated_at": datetime.utcnow()}}
)
@ -273,7 +275,7 @@ class Folder:
persona_results.append(False)
# Update folder's updated_at timestamp
db.folders.update_one(
await db.folders.update_one(
{"_id": ObjectId(folder_id)},
{"$set": {"updated_at": datetime.utcnow()}}
)
@ -289,13 +291,13 @@ class Folder:
return False
@staticmethod
def get_folders_containing_persona(persona_id, user_id=None):
async def get_folders_containing_persona(persona_id, user_id=None):
"""Find all folders that contain a specific persona (persona-centric storage)."""
db = get_db()
db = await get_db()
try:
# Get the persona to see which folders it belongs to
persona = db.personas.find_one({"_id": ObjectId(persona_id)})
persona = await db.personas.find_one({"_id": ObjectId(persona_id)})
if not persona or not persona.get("folder_ids"):
return []
@ -307,7 +309,8 @@ class Folder:
if user_id:
query["created_by"] = user_id
folders = list(db.folders.find(query))
cursor = db.folders.find(query)
folders = await cursor.to_list(length=None)
result = []
for folder in folders:

View file

@ -4,8 +4,8 @@ from datetime import datetime
class Persona:
@staticmethod
def create(persona_data, user_id=None):
db = get_db()
async def create(persona_data, user_id=None):
db = await get_db()
# Add metadata
persona_data["created_at"] = datetime.utcnow()
@ -15,13 +15,13 @@ class Persona:
if "folder_ids" not in persona_data:
persona_data["folder_ids"] = []
result = db.personas.insert_one(persona_data)
result = await db.personas.insert_one(persona_data)
print(f"✅ PERSONA CREATED: {persona_data.get('name', 'Unknown')} with folder_ids: {persona_data['folder_ids']}")
return str(result.inserted_id)
@staticmethod
def find_by_id(persona_id):
db = get_db()
async def find_by_id(persona_id):
db = await get_db()
try:
# If persona_id is already an ObjectId, use it directly
if isinstance(persona_id, ObjectId):
@ -33,14 +33,14 @@ class Persona:
except Exception as e:
print(f"Invalid ObjectId format: {persona_id}, error: {e}")
# Try lookup by string ID as fallback
persona = db.personas.find_one({"id": persona_id})
persona = await db.personas.find_one({"id": persona_id})
if persona:
persona["_id"] = str(persona["_id"])
return persona
return None
# Lookup by ObjectId
persona = db.personas.find_one({"_id": object_id})
persona = await db.personas.find_one({"_id": object_id})
if persona:
persona["_id"] = str(persona["_id"])
return persona
@ -49,25 +49,25 @@ class Persona:
return None
@staticmethod
def find_by_user(user_id, limit=100):
db = get_db()
async def find_by_user(user_id, limit=100):
db = await get_db()
personas = db.personas.find({"created_by": user_id}).sort("created_at", -1).limit(limit)
result = []
for persona in personas:
async for persona in personas:
persona["_id"] = str(persona["_id"])
result.append(persona)
return result
@staticmethod
def get_all(limit=100):
async def get_all(limit=100):
try:
db = get_db()
personas = list(db.personas.find().sort("created_at", -1).limit(limit))
db = await get_db()
personas = db.personas.find().sort("created_at", -1).limit(limit)
result = []
for persona in personas:
async for persona in personas:
persona["_id"] = str(persona["_id"])
result.append(persona)
@ -77,8 +77,8 @@ class Persona:
return []
@staticmethod
def update(persona_id, data):
db = get_db()
async def update(persona_id, data):
db = await get_db()
# Create a copy of the data to avoid modifying the original
filtered_data = data.copy()
@ -96,7 +96,7 @@ class Persona:
# Set the updated timestamp
filtered_data["updated_at"] = datetime.utcnow()
result = db.personas.update_one(
result = await db.personas.update_one(
{"_id": ObjectId(persona_id)},
{"$set": filtered_data}
)
@ -104,8 +104,8 @@ class Persona:
return result.modified_count > 0
@staticmethod
def delete(persona_id):
db = get_db()
async def delete(persona_id):
db = await get_db()
try:
# Convert to ObjectId if needed
if isinstance(persona_id, ObjectId):
@ -119,7 +119,7 @@ class Persona:
except Exception as e:
print(f"Invalid ObjectId format for delete: {persona_id}, error: {e}")
# Try delete by string ID as fallback
result = db.personas.delete_one({"id": persona_id})
result = await db.personas.delete_one({"id": persona_id})
# Note: No folder cleanup needed - using persona-centric storage
return result.deleted_count > 0
@ -127,7 +127,7 @@ class Persona:
# Folder membership is only stored in persona.folder_ids, which gets deleted with the persona
# Delete by ObjectId
result = db.personas.delete_one({"_id": object_id})
result = await db.personas.delete_one({"_id": object_id})
return result.deleted_count > 0
except Exception as e:
print(f"Error in delete: {e}, persona_id: {persona_id}")

View file

@ -22,33 +22,33 @@ class User:
return bcrypt.checkpw(password.encode('utf-8'), password_hash.encode('utf-8'))
@staticmethod
def find_by_username(username):
db = get_db()
user_data = db.users.find_one({"username": username})
async def find_by_username(username):
db = await get_db()
user_data = await db.users.find_one({"username": username})
return user_data
@staticmethod
def find_by_email(email):
db = get_db()
user_data = db.users.find_one({"email": email})
async def find_by_email(email):
db = await get_db()
user_data = await db.users.find_one({"email": email})
return user_data
@staticmethod
def find_by_id(user_id):
db = get_db()
user_data = db.users.find_one({"_id": ObjectId(user_id)})
async def find_by_id(user_id):
db = await get_db()
user_data = await db.users.find_one({"_id": ObjectId(user_id)})
return user_data
@staticmethod
def find_by_microsoft_id(microsoft_id):
db = get_db()
user_data = db.users.find_one({"microsoft_id": microsoft_id})
async def find_by_microsoft_id(microsoft_id):
db = await get_db()
user_data = await db.users.find_one({"microsoft_id": microsoft_id})
return user_data
@staticmethod
def update_microsoft_id(user_id, microsoft_id):
db = get_db()
result = db.users.update_one(
async def update_microsoft_id(user_id, microsoft_id):
db = await get_db()
result = await db.users.update_one(
{"_id": ObjectId(user_id)},
{"$set": {"microsoft_id": microsoft_id, "auth_type": "microsoft"}}
)
@ -63,8 +63,8 @@ class User:
"microsoft_id": self.microsoft_id
}
def save(self):
db = get_db()
async def save(self):
db = await get_db()
user_data = {
"username": self.username,
"email": self.email,
@ -73,23 +73,23 @@ class User:
"auth_type": self.auth_type,
"microsoft_id": self.microsoft_id
}
result = db.users.insert_one(user_data)
result = await db.users.insert_one(user_data)
return result.inserted_id
@staticmethod
def create_default_user():
async def create_default_user():
try:
db = get_db()
db = await get_db()
# First check if users collection exists
collections = db.list_collection_names()
collections = await db.list_collection_names()
if "users" not in collections:
print("Creating users collection")
db.create_collection("users")
await db.create_collection("users")
# Safely check if user exists, handling potential auth errors
try:
user_exists = db.users.count_documents({"username": "user"}) > 0
user_exists = await db.users.count_documents({"username": "user"}) > 0
except Exception as e:
print(f"Error checking for default user: {e}")
# If we can't query, assume we need to create the user
@ -102,7 +102,7 @@ class User:
password_hash=User.hash_password("pass"),
role="admin"
)
default_user.save()
await default_user.save()
print("Default user created successfully")
else:
print("Default user already exists")

View file

@ -3,10 +3,10 @@ AI Persona Generation Routes.
These endpoints handle the generation of synthetic personas using AI models.
"""
from flask import Blueprint, request, jsonify, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from quart import Blueprint, request, jsonify, current_app, make_response
from app.auth.quart_jwt import jwt_required, get_jwt_identity
import time
import concurrent.futures
import asyncio
from werkzeug.serving import is_running_from_reloader
from app.services.ai_persona_service import (
@ -29,7 +29,7 @@ ai_personas_bp = Blueprint('ai_personas', __name__)
@ai_personas_bp.route('/generate-basic-profiles', methods=['POST'])
@jwt_required()
def generate_basic_profiles():
async def generate_basic_profiles():
"""
First stage of the two-stage persona generation process.
@ -48,7 +48,7 @@ def generate_basic_profiles():
A JSON object containing an array of basic persona profiles
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
# Extract parameters
audience_brief = data.get('audience_brief')
@ -74,7 +74,7 @@ def generate_basic_profiles():
current_app.logger.info(f"Generating {count} basic profiles using model: {llm_model}")
# Generate basic profiles
basic_profiles = generate_basic_personas(
basic_profiles = await generate_basic_personas(
audience_brief=audience_brief,
research_objective=research_objective,
count=count,
@ -101,7 +101,7 @@ def generate_basic_profiles():
@ai_personas_bp.route('/complete-persona', methods=['POST'])
@jwt_required()
def complete_persona():
async def complete_persona():
"""
Second stage of the two-stage persona generation process.
@ -122,7 +122,7 @@ def complete_persona():
A JSON object containing the complete persona
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
# Extract parameters
basic_profile = data.get('basic_profile')
@ -135,7 +135,7 @@ def complete_persona():
try:
# Complete the persona
complete_persona_data = generate_persona(
complete_persona_data = await generate_persona(
basic_persona=basic_profile,
temperature=temperature
)
@ -155,7 +155,7 @@ def complete_persona():
@ai_personas_bp.route('/complete-and-save-persona', methods=['POST'])
@jwt_required()
def complete_and_save_persona():
async def complete_and_save_persona():
"""
Second stage of the two-stage persona generation process that also saves the
persona to the database.
@ -178,7 +178,7 @@ def complete_and_save_persona():
A JSON object containing the complete persona with its database ID
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
# Extract parameters
basic_profile = data.get('basic_profile')
@ -213,7 +213,7 @@ def complete_and_save_persona():
current_app.logger.info(f"Completing persona '{persona_name}' using model: {llm_model}")
# Complete the persona
complete_persona_data = generate_persona(
complete_persona_data = await generate_persona(
basic_persona=basic_profile,
temperature=temperature,
customer_data_session_id=customer_data_session_id,
@ -222,7 +222,7 @@ def complete_and_save_persona():
# Generate AI summary for the persona
try:
summary_data = generate_persona_summary(
summary_data = await generate_persona_summary(
persona_data=complete_persona_data,
temperature=temperature,
llm_model=llm_model
@ -253,7 +253,7 @@ def complete_and_save_persona():
print(f"📝 Backend: Added research_objective to persona data (~{len(research_objective)} chars)")
# Save to database
persona_id = Persona.create(complete_persona_data, user_id)
persona_id = await Persona.create(complete_persona_data, user_id)
# Add the database ID to the response
complete_persona_data['_id'] = str(persona_id)
@ -277,7 +277,7 @@ def complete_and_save_persona():
@ai_personas_bp.route('/generate', methods=['POST'])
@jwt_required()
def generate_ai_persona():
async def generate_ai_persona():
"""
Generate a synthetic persona using AI and return it without saving.
@ -297,7 +297,7 @@ def generate_ai_persona():
A JSON object containing the generated persona
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
try:
# Extract customization parameters
@ -318,7 +318,7 @@ def generate_ai_persona():
temperature = 0.7
# Generate the persona
persona_data = generate_persona(
persona_data = await generate_persona(
prompt_customization=customization,
temperature=temperature
)
@ -335,7 +335,7 @@ def generate_ai_persona():
@ai_personas_bp.route('/generate-and-save', methods=['POST'])
@jwt_required()
def generate_and_save_persona():
async def generate_and_save_persona():
"""
Generate a synthetic persona using AI and save it to the database.
@ -345,7 +345,7 @@ def generate_and_save_persona():
A JSON object containing the generated and saved persona, including its database ID
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
try:
# Extract customization parameters
@ -366,14 +366,14 @@ def generate_and_save_persona():
temperature = 0.7
# Generate the persona
persona_data = generate_persona(
persona_data = await generate_persona(
prompt_customization=customization,
temperature=temperature
)
# Generate AI summary for the persona
try:
summary_data = generate_persona_summary(
summary_data = await generate_persona_summary(
persona_data=persona_data,
temperature=temperature
)
@ -395,7 +395,7 @@ def generate_and_save_persona():
del persona_data['id']
# Save to database
persona_id = Persona.create(persona_data, user_id)
persona_id = await Persona.create(persona_data, user_id)
# Add the database ID to the response
persona_data['_id'] = str(persona_id)
@ -416,7 +416,7 @@ def generate_and_save_persona():
@ai_personas_bp.route('/batch-generate', methods=['POST'])
@jwt_required()
def batch_generate_personas():
async def batch_generate_personas():
"""
Generate multiple synthetic personas using AI.
@ -438,7 +438,7 @@ def batch_generate_personas():
A JSON object containing an array of generated personas
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
count = data.get('count', 1)
if count < 1 or count > 10: # Limit the number for performance reasons
@ -472,27 +472,22 @@ def batch_generate_personas():
'temperature': temperature
})
# Generate personas
personas = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(count, 4)) as executor:
# Start the generation tasks
future_to_task = {
executor.submit(
generate_persona,
task['prompt_customization'],
# Generate personas using asyncio.gather for concurrent async execution
try:
generation_coroutines = [
generate_persona(
task['prompt_customization'],
None, # No basic_persona for this endpoint
task['temperature']
): i for i, task in enumerate(generation_tasks)
}
) for task in generation_tasks
]
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_task):
try:
persona_data = future.result()
personas.append(persona_data)
except Exception as exc:
current_app.logger.error(f"Persona generation task failed with error: {exc}")
raise PersonaGenerationError(f"Failed to generate one of the personas: {str(exc)}")
# Execute all persona generations concurrently
personas = await asyncio.gather(*generation_coroutines)
except Exception as exc:
current_app.logger.error(f"Persona generation task failed with error: {exc}")
raise PersonaGenerationError(f"Failed to generate one of the personas: {str(exc)}")
return jsonify({
"message": f"Successfully generated {len(personas)} personas",
@ -509,7 +504,7 @@ def batch_generate_personas():
@ai_personas_bp.route('/batch-generate-and-save', methods=['POST'])
@jwt_required()
def batch_generate_and_save_personas():
async def batch_generate_and_save_personas():
"""
Generate multiple synthetic personas using AI and save them to the database.
@ -519,7 +514,7 @@ def batch_generate_and_save_personas():
A JSON object containing the array of generated and saved personas with their IDs
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
count = data.get('count', 1)
if count < 1 or count > 10: # Limit for performance
@ -553,27 +548,22 @@ def batch_generate_and_save_personas():
'temperature': temperature
})
# Generate personas
generated_personas = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(count, 4)) as executor:
# Start the generation tasks
future_to_task = {
executor.submit(
generate_persona,
task['prompt_customization'],
# Generate personas using asyncio.gather for concurrent async execution
try:
generation_coroutines = [
generate_persona(
task['prompt_customization'],
None, # No basic_persona for this endpoint
task['temperature']
): i for i, task in enumerate(generation_tasks)
}
) for task in generation_tasks
]
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_task):
try:
persona_data = future.result()
generated_personas.append(persona_data)
except Exception as exc:
current_app.logger.error(f"Persona generation task failed with error: {exc}")
raise PersonaGenerationError(f"Failed to generate one of the personas: {str(exc)}")
# Execute all persona generations concurrently
generated_personas = await asyncio.gather(*generation_coroutines)
except Exception as exc:
current_app.logger.error(f"Persona generation task failed with error: {exc}")
raise PersonaGenerationError(f"Failed to generate one of the personas: {str(exc)}")
# Save all generated personas to the database
personas = []
@ -581,7 +571,7 @@ def batch_generate_and_save_personas():
for persona_data in generated_personas:
# Generate AI summary for each persona
try:
summary_data = generate_persona_summary(
summary_data = await generate_persona_summary(
persona_data=persona_data,
temperature=temperature
)
@ -603,7 +593,7 @@ def batch_generate_and_save_personas():
del persona_data['id']
# Save to database
persona_id = Persona.create(persona_data, user_id)
persona_id = await Persona.create(persona_data, user_id)
# Add database ID to the response
persona_data['_id'] = str(persona_id)
@ -626,7 +616,7 @@ def batch_generate_and_save_personas():
@ai_personas_bp.route('/generate-persona-summary', methods=['POST'])
@jwt_required()
def generate_summary_for_persona():
async def generate_summary_for_persona():
"""
Generate an AI-synthesized summary for an existing persona.
@ -649,7 +639,7 @@ def generate_summary_for_persona():
A JSON object containing the generated summary data
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
# Extract parameters
persona_data = data.get('persona_data')
@ -671,7 +661,7 @@ def generate_summary_for_persona():
try:
# Generate the summary
summary_data = generate_persona_summary(
summary_data = await generate_persona_summary(
persona_data=persona_data,
temperature=temperature
)
@ -691,7 +681,7 @@ def generate_summary_for_persona():
@ai_personas_bp.route('/enhance-audience-brief', methods=['POST'])
@jwt_required()
def enhance_audience_brief_endpoint():
async def enhance_audience_brief_endpoint():
"""
Generate suggestions to improve an audience brief for better persona generation.
@ -710,7 +700,7 @@ def enhance_audience_brief_endpoint():
A JSON object containing separate suggestion arrays for audience_brief and research_objective
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
# Extract parameters
audience_brief = data.get('audience_brief')
@ -733,7 +723,7 @@ def enhance_audience_brief_endpoint():
try:
# Generate enhancement suggestions
suggestions = enhance_audience_brief(
suggestions = await enhance_audience_brief(
audience_brief=audience_brief.strip(),
research_objective=research_objective.strip(),
temperature=temperature
@ -755,7 +745,7 @@ def enhance_audience_brief_endpoint():
@ai_personas_bp.route('/batch-generate-summaries', methods=['POST'])
@jwt_required()
def batch_generate_summaries():
async def batch_generate_summaries():
"""
Generate comprehensive markdown summaries for multiple personas for download/client review.
@ -773,7 +763,7 @@ def batch_generate_summaries():
A JSON object containing the generated summaries and any errors encountered
"""
user_id = get_jwt_identity()
data = request.get_json() or {}
data = await request.get_json() or {}
# Extract parameters
persona_ids = data.get('persona_ids', [])
@ -802,7 +792,7 @@ def batch_generate_summaries():
for persona_id in persona_ids:
try:
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if persona:
personas_data.append(persona)
else:
@ -822,13 +812,13 @@ def batch_generate_summaries():
successful_summaries = []
failed_summaries = []
def process_persona_summary(persona_data):
async def process_persona_summary(persona_data):
"""Helper function to process a single persona summary"""
try:
persona_name = persona_data.get('name', 'Unknown')
print(f"✅ Backend: Successfully generated summary for '{persona_name}' using model: {llm_model}")
summary = generate_persona_download_summary(
summary = await generate_persona_download_summary(
persona_data=persona_data,
temperature=temperature,
llm_model=llm_model
@ -852,22 +842,24 @@ def batch_generate_summaries():
batch = personas_data[i:i + batch_size]
current_app.logger.info(f"Processing batch {i//batch_size + 1}: {len(batch)} personas")
# Process this batch
with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
# Submit all tasks for this batch
future_to_persona = {
executor.submit(process_persona_summary, persona): persona
for persona in batch
}
# Collect results as they complete
for future in concurrent.futures.as_completed(future_to_persona):
result = future.result()
if result['success']:
successful_summaries.append(result)
else:
failed_summaries.append(result)
current_app.logger.error(f"Failed to generate summary for persona {result['persona_name']}: {result['error']}")
# Process this batch using asyncio
batch_tasks = [process_persona_summary(persona) for persona in batch]
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# Collect results
for result in batch_results:
if isinstance(result, Exception):
failed_summaries.append({
'success': False,
'error': str(result),
'persona_name': 'Unknown'
})
current_app.logger.error(f"Failed to generate summary: {result}")
elif result['success']:
successful_summaries.append(result)
else:
failed_summaries.append(result)
current_app.logger.error(f"Failed to generate summary for persona {result['persona_name']}: {result['error']}")
# Prepare response
total_requested = len(persona_ids)
@ -916,7 +908,7 @@ def batch_generate_summaries():
@ai_personas_bp.route('/upload-customer-data', methods=['POST'])
@jwt_required()
def upload_customer_data():
async def upload_customer_data():
"""
Upload customer data files and parse them using LlamaParse.
@ -930,12 +922,13 @@ def upload_customer_data():
try:
current_app.logger.debug(f"=== UPLOAD CUSTOMER DATA API called for user {user_id} ===")
# Check if files were provided
if 'files' not in request.files:
# Check if files were provided (Quart async files access)
files_dict = await request.files
if 'files' not in files_dict:
current_app.logger.warning("No 'files' key in request.files")
return jsonify({"error": "No files provided"}), 400
files = request.files.getlist('files')
files = files_dict.getlist('files')
if not files or all(f.filename == '' for f in files):
current_app.logger.warning("No files selected")
return jsonify({"error": "No files selected"}), 400
@ -943,7 +936,7 @@ def upload_customer_data():
current_app.logger.info(f"Processing {len(files)} customer data files")
# Upload and parse files using customer data service
session_id = customer_data_service.upload_and_parse_files(files)
session_id = await customer_data_service.upload_and_parse_files(files)
current_app.logger.info(f"Successfully processed customer data files with session_id: {session_id}")
@ -963,7 +956,7 @@ def upload_customer_data():
@ai_personas_bp.route('/cleanup-customer-data/<session_id>', methods=['DELETE'])
@jwt_required()
def cleanup_customer_data(session_id):
async def cleanup_customer_data(session_id):
"""
Clean up customer data files for a specific session.

View file

@ -1,13 +1,13 @@
from flask import Blueprint, request, jsonify
from flask_jwt_extended import create_access_token, jwt_required, get_jwt_identity
from quart import Blueprint, request, jsonify
from app.auth.quart_jwt import create_access_token, jwt_required, get_jwt_identity
from app.models.user import User
from app.services.msal_service import MSALService
auth_bp = Blueprint('auth', __name__)
@auth_bp.route('/register', methods=['POST'])
def register():
data = request.get_json()
async def register():
data = await request.get_json()
if not data or not data.get('username') or not data.get('email') or not data.get('password'):
return jsonify({"message": "Missing required fields"}), 400
@ -17,15 +17,15 @@ def register():
password = data.get('password')
# Check if user already exists
if User.find_by_username(username):
if await User.find_by_username(username):
return jsonify({"message": "Username already taken"}), 409
if User.find_by_email(email):
if await User.find_by_email(email):
return jsonify({"message": "Email already registered"}), 409
# Create new user
hashed_password = User.hash_password(password)
new_user = User(username=username, email=email, password_hash=hashed_password)
user_id = new_user.save()
user_id = await new_user.save()
# Generate access token
access_token = create_access_token(identity=str(user_id))
@ -37,9 +37,9 @@ def register():
}), 201
@auth_bp.route('/login', methods=['POST'])
def login():
async def login():
try:
data = request.get_json()
data = await request.get_json()
if not data or not data.get('username') or not data.get('password'):
return jsonify({"message": "Missing username or password"}), 400
@ -76,7 +76,7 @@ def login():
# Try to find user in database
try:
# Find user by username
user_data = User.find_by_username(username)
user_data = await User.find_by_username(username)
if not user_data:
return jsonify({"message": "Invalid username or password"}), 401
@ -111,7 +111,7 @@ def login():
@auth_bp.route('/me', methods=['GET'])
@jwt_required()
def get_profile():
async def get_profile():
user_id = get_jwt_identity()
# Handle the default_id case specially
@ -124,7 +124,7 @@ def get_profile():
}), 200
try:
user_data = User.find_by_id(user_id)
user_data = await User.find_by_id(user_id)
if not user_data:
return jsonify({"message": "User not found"}), 404
@ -144,10 +144,10 @@ def get_profile():
}), 200
@auth_bp.route('/microsoft', methods=['POST'])
def microsoft_login():
async def microsoft_login():
"""Handle Microsoft OAuth authentication."""
try:
data = request.get_json()
data = await request.get_json()
if not data or not data.get('id_token'):
return jsonify({"message": "Missing Microsoft ID token"}), 400
@ -171,15 +171,15 @@ def microsoft_login():
existing_user = None
try:
# First try to find by Microsoft ID
existing_user = User.find_by_microsoft_id(microsoft_id)
existing_user = await User.find_by_microsoft_id(microsoft_id)
# If not found by Microsoft ID, try by email
if not existing_user:
existing_user = User.find_by_email(email)
existing_user = await User.find_by_email(email)
# If found by email but no Microsoft ID, update the user to link Microsoft account
if existing_user and not existing_user.get('microsoft_id'):
User.update_microsoft_id(existing_user['_id'], microsoft_id)
await User.update_microsoft_id(existing_user['_id'], microsoft_id)
existing_user['microsoft_id'] = microsoft_id
existing_user['auth_type'] = 'microsoft'
@ -192,7 +192,7 @@ def microsoft_login():
try:
user_data = msal_service.create_user_data(microsoft_user_info)
new_user = User(**user_data)
user_id = new_user.save()
user_id = await new_user.save()
existing_user = {
"_id": user_id,
@ -226,4 +226,25 @@ def microsoft_login():
except Exception as e:
print(f"Unexpected error in Microsoft login route: {e}")
return jsonify({"message": "Internal server error"}), 500
return jsonify({"message": "Internal server error"}), 500
@auth_bp.route('/refresh-token', methods=['POST'])
async def refresh_token():
"""Generate a new token for testing during JWT system migration."""
try:
data = (await request.get_json()) or {}
user_id = data.get('user_id', 'default_user')
# Create a new token with our Quart-JWT system
access_token = create_access_token(user_id)
return jsonify({
"message": "Token refreshed successfully",
"access_token": access_token,
"user_id": user_id,
"system": "quart-jwt"
}), 200
except Exception as e:
return jsonify({"message": f"Token refresh failed: {str(e)}"}), 500

View file

@ -4,10 +4,11 @@ These endpoints handle AI-assisted focus group functionality, including persona
and key theme generation.
"""
from flask import Blueprint, request, jsonify, current_app
from flask_jwt_extended import jwt_required, get_jwt_identity
from quart import Blueprint, request, jsonify, current_app
from app.auth.quart_jwt import jwt_required, get_jwt_identity
from typing import Dict, List, Any
import time
import concurrent.futures
from app.services.focus_group_response_service import (
generate_persona_response,
@ -23,6 +24,7 @@ from app.services.ai_moderator_service import AIModeratorService
from app.services.autonomous_conversation_controller import AutonomousConversationController
from app.services.conversation_decision_service import ConversationDecisionService, ConversationDecisionError
from app.services.conversation_state_manager import ConversationStateManager
from app.services.ai_runner_service import get_ai_runner
from app.services.image_description_service import ImageDescriptionService, ImageDescriptionError
from app.models.focus_group import FocusGroup
from app.models.persona import Persona
@ -32,7 +34,7 @@ focus_group_ai_bp = Blueprint('focus_group_ai', __name__)
@focus_group_ai_bp.route('/generate-response', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def generate_ai_response():
async def generate_ai_response():
"""
Generate a response from a persona in a focus group discussion.
@ -49,7 +51,7 @@ def generate_ai_response():
A JSON object containing the generated response
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
# Validate required fields
required_fields = ['focus_group_id', 'persona_id', 'current_topic']
@ -92,7 +94,7 @@ def generate_ai_response():
current_app.logger.info(f"🤖 Generating AI response using model: {llm_model or 'default (gemini-2.5-pro)'} for focus group {focus_group_id}")
# Validate persona exists
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if not persona:
return jsonify({"error": "Persona not found"}), 404
@ -114,13 +116,13 @@ def generate_ai_response():
# This is the new approach - use persistent conversation context instead of detection
print(f"🎨 Checking for active visual context in focus group {focus_group_id}")
from app.services.conversation_context_service import ConversationContextService
has_visual_context = ConversationContextService.has_visual_context(focus_group_id)
has_visual_context = await ConversationContextService.has_visual_context(focus_group_id)
print(f"🎨 Focus group has active visual context: {has_visual_context}")
# Build multimodal conversation context
try:
multimodal_context = ConversationContextService.build_multimodal_context(focus_group_id, recent_messages)
multimodal_context = await ConversationContextService.build_multimodal_context(focus_group_id, recent_messages)
print(f"✅ Built multimodal context with {multimodal_context['total_visual_assets']} visual assets")
except Exception as e:
print(f"❌ Error building multimodal context: {e}")
@ -180,7 +182,7 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
})
# Generate response using contextual conversation method
response_text = LLMService.generate_contextual_response(
response_text = await LLMService.generate_contextual_response(
prompt=prompt,
conversation_context=multimodal_context['conversation_context'],
temperature=temperature,
@ -192,7 +194,7 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
print(f"💬 Using standard response generation (no visual context)")
current_app.logger.info(f"Generating standard response")
response_text = generate_persona_response(
response_text = await generate_persona_response(
persona=persona,
current_topic=current_topic,
previous_messages=recent_messages,
@ -267,7 +269,7 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
@focus_group_ai_bp.route('/generate-key-themes', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def generate_key_themes():
async def generate_key_themes():
"""
Generate key themes from a focus group discussion.
@ -281,7 +283,7 @@ def generate_key_themes():
A JSON object containing the generated key themes
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
# Validate required fields
if 'focus_group_id' not in data:
@ -293,7 +295,7 @@ def generate_key_themes():
temperature = data.get('temperature', 0.7)
# Validate focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
@ -302,7 +304,7 @@ def generate_key_themes():
# Generate key themes
try:
themes = KeyThemeService.generate_key_themes(
themes = await KeyThemeService.generate_key_themes(
focus_group_id=focus_group_id,
temperature=temperature,
llm_model=llm_model
@ -312,7 +314,7 @@ def generate_key_themes():
current_app.logger.info(f"Generated {len(themes)} key themes for focus group {focus_group_id}")
# Save themes to database
theme_ids = FocusGroup.add_generated_themes(focus_group_id, themes)
theme_ids = await FocusGroup.add_generated_themes(focus_group_id, themes)
if not theme_ids:
current_app.logger.error("Failed to save themes to database")
@ -355,7 +357,7 @@ def generate_key_themes():
@focus_group_ai_bp.route('/key-themes/<focus_group_id>', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_key_themes(focus_group_id):
async def get_key_themes(focus_group_id):
"""
Get all generated key themes for a focus group.
@ -364,12 +366,12 @@ def get_key_themes(focus_group_id):
"""
try:
# Validate focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
# Get themes
themes = FocusGroup.get_generated_themes(focus_group_id)
themes = await FocusGroup.get_generated_themes(focus_group_id)
# Format themes for response
formatted_themes = []
@ -397,7 +399,7 @@ def get_key_themes(focus_group_id):
@focus_group_ai_bp.route('/key-themes/<focus_group_id>/<theme_id>', methods=['DELETE'])
@jwt_required(optional=True) # Make JWT optional for development
def delete_key_theme(focus_group_id, theme_id):
async def delete_key_theme(focus_group_id, theme_id):
"""
Delete a key theme from a focus group.
@ -406,12 +408,12 @@ def delete_key_theme(focus_group_id, theme_id):
"""
try:
# Validate focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
# Delete theme
success = FocusGroup.delete_generated_theme(focus_group_id, theme_id)
success = await FocusGroup.delete_generated_theme(focus_group_id, theme_id)
if not success:
return jsonify({
@ -434,7 +436,7 @@ def delete_key_theme(focus_group_id, theme_id):
@focus_group_ai_bp.route('/moderator/status/<focus_group_id>', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_moderator_status(focus_group_id):
async def get_moderator_status(focus_group_id):
"""
Get the current moderator status for a focus group.
@ -442,7 +444,7 @@ def get_moderator_status(focus_group_id):
A JSON object containing the current moderator status
"""
try:
status = AIModeratorService.get_moderator_status(focus_group_id)
status = await AIModeratorService.get_moderator_status(focus_group_id)
if "error" in status:
return jsonify(status), 404 if "not found" in status["error"] else 500
@ -462,7 +464,7 @@ def get_moderator_status(focus_group_id):
@focus_group_ai_bp.route('/moderator/advance/<focus_group_id>', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def advance_moderator_discussion(focus_group_id):
async def advance_moderator_discussion(focus_group_id):
"""
Advance the moderator to the next item in the discussion guide.
For manual mode, also use AI to decide which participant should respond next.
@ -477,7 +479,7 @@ def advance_moderator_discussion(focus_group_id):
A JSON object containing the moderator response, updated position, and optionally participant response
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
temperature = data.get('temperature', 0.7)
# Check if focus group is in autonomous mode
@ -493,7 +495,7 @@ def advance_moderator_discussion(focus_group_id):
# Default: generate participant response for manual mode, not for autonomous mode
generate_participant_response = data.get('generate_participant_response', not is_autonomous_mode)
result = AIModeratorService.advance_discussion(focus_group_id)
result = await AIModeratorService.advance_discussion(focus_group_id)
if "error" in result:
return jsonify(result), 404 if "not found" in result["error"] else 500
@ -534,7 +536,7 @@ def advance_moderator_discussion(focus_group_id):
# Generate AI description and enhance the moderator response
try:
print(f"🎨 AI MODE: Generating description for {asset_filename}")
description = ImageDescriptionService.generate_description(focus_group_id, asset_filename)
description = await ImageDescriptionService.generate_description(focus_group_id, asset_filename)
# Enhance the moderator response with the description using display reference if available
if display_reference:
@ -597,21 +599,21 @@ def advance_moderator_discussion(focus_group_id):
if generate_participant_response and not result.get("completed", False):
try:
# Use AI to decide which participant should respond next
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature, 'ai')
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature, 'ai')
if decision.get('action') == 'participant_respond':
participant_id = decision['details']['participant_id']
topic_context = decision['details']['topic_context']
# Get participant data
persona = Persona.find_by_id(participant_id)
persona = await Persona.find_by_id(participant_id)
if persona:
# Get recent messages for context
messages = FocusGroup.get_messages(focus_group_id)
recent_messages = messages[-20:] if len(messages) > 20 else messages
# Generate participant response
response_text = generate_persona_response(
response_text = await generate_persona_response(
persona=persona,
current_topic=topic_context,
previous_messages=recent_messages,
@ -661,7 +663,7 @@ def advance_moderator_discussion(focus_group_id):
@focus_group_ai_bp.route('/moderator/position/<focus_group_id>', methods=['PUT'])
@jwt_required(optional=True) # Make JWT optional for development
def set_moderator_position(focus_group_id):
async def set_moderator_position(focus_group_id):
"""
Set the moderator position to a specific section and item.
@ -675,7 +677,7 @@ def set_moderator_position(focus_group_id):
A JSON object confirming the position change
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
# Validate required fields
if 'section_id' not in data:
@ -686,7 +688,7 @@ def set_moderator_position(focus_group_id):
section_id = data['section_id']
item_id = data.get('item_id')
result = AIModeratorService.set_moderator_position(focus_group_id, section_id, item_id)
result = await AIModeratorService.set_moderator_position(focus_group_id, section_id, item_id)
if "error" in result:
return jsonify(result), 404 if "not found" in result["error"] else 400
@ -705,7 +707,7 @@ def set_moderator_position(focus_group_id):
@focus_group_ai_bp.route('/autonomous/start/<focus_group_id>', methods=['POST'])
@jwt_required(optional=True)
def start_autonomous_conversation(focus_group_id):
async def start_autonomous_conversation(focus_group_id):
"""
Start autonomous conversation for a focus group.
@ -719,7 +721,7 @@ def start_autonomous_conversation(focus_group_id):
"""
try:
current_app.logger.info(f"=== START AUTONOMOUS CONVERSATION API called for focus group {focus_group_id} ===")
data = request.get_json() or {}
data = (await request.get_json()) or {}
initial_prompt = data.get('initial_prompt')
current_app.logger.info(f"Request data: {data}")
@ -728,49 +730,30 @@ def start_autonomous_conversation(focus_group_id):
controller = AutonomousConversationController(focus_group_id, current_app.logger)
current_app.logger.info("Controller created successfully")
# Start the conversation (this will run asynchronously)
import asyncio
current_app.logger.info("Preparing to submit conversation to AI Runner...")
current_app.logger.info("Setting up asyncio loop...")
# For now, we'll run this synchronously. In production, you might want to use a task queue
# Get the AI Runner service and submit the conversation
ai_runner = get_ai_runner()
if not ai_runner.is_running:
current_app.logger.error("AI Runner service is not running")
return jsonify({"error": "AI Runner service is not available"}), 503
# Submit the conversation to the AI Runner (non-blocking)
current_app.logger.info("Submitting conversation to AI Runner...")
try:
# Create a new event loop if one doesn't exist
loop = asyncio.get_event_loop()
current_app.logger.info("Using existing event loop")
except RuntimeError:
current_app.logger.info("Creating new event loop")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
current_app.logger.info("Calling controller.start_autonomous_conversation...")
# GPT-5 fix: Start the conversation in Socket.IO background task instead of threading
# This ensures the AI loop runs in the correct eventlet greenlet
from app.extensions import socketio
from flask import copy_current_request_context
@copy_current_request_context
def start_conversation_in_socketio_greenlet():
"""Run the autonomous conversation in the eventlet greenlet context."""
try:
with current_app.app_context():
# Run the async conversation in this greenlet using asyncio
import asyncio
result = asyncio.run(controller.start_autonomous_conversation(initial_prompt))
current_app.logger.info(f"Background conversation result: {result}")
except Exception as e:
try:
current_app.logger.error(f"Background conversation error: {e}")
except:
print(f"Background conversation error: {e}") # Fallback if logger fails
# Use socketio.start_background_task instead of threading
socketio.start_background_task(start_conversation_in_socketio_greenlet)
future = ai_runner.submit_conversation(
focus_group_id,
controller.start_autonomous_conversation(initial_prompt)
)
current_app.logger.info("Conversation submitted to AI Runner successfully")
except Exception as e:
current_app.logger.error(f"Failed to submit conversation to AI Runner: {e}")
return jsonify({"error": f"Failed to start AI conversation: {str(e)}"}), 500
# Log the AI mode start event
try:
user_id = get_jwt_identity() # Get user ID if available
mode_event_id = FocusGroup.add_mode_event(focus_group_id, 'ai_mode_started', user_id)
mode_event_id = await FocusGroup.add_mode_event(focus_group_id, 'ai_mode_started', user_id)
current_app.logger.info(f"Logged AI mode start event: {mode_event_id}")
except Exception as e:
current_app.logger.warning(f"Failed to log AI mode start event: {e}")
@ -780,7 +763,8 @@ def start_autonomous_conversation(focus_group_id):
"message": "Autonomous conversation started",
"focus_group_id": focus_group_id,
"state": "starting",
"background": True
"background": True,
"ai_runner_active": True
}
current_app.logger.info(f"Controller returned result: {result}")
@ -803,7 +787,7 @@ def start_autonomous_conversation(focus_group_id):
@focus_group_ai_bp.route('/autonomous/stop/<focus_group_id>', methods=['POST'])
@jwt_required(optional=True)
def stop_autonomous_conversation(focus_group_id):
async def stop_autonomous_conversation(focus_group_id):
"""
Stop autonomous conversation for a focus group.
@ -816,25 +800,30 @@ def stop_autonomous_conversation(focus_group_id):
A JSON object containing the stop result
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
reason = data.get('reason', 'manual_stop')
current_app.logger.info(f"=== STOP AUTONOMOUS CONVERSATION API called for focus group {focus_group_id} ===")
current_app.logger.info(f"Stop reason: {reason}")
# Create autonomous conversation controller
controller = AutonomousConversationController(focus_group_id, current_app.logger)
# Use AI Runner to stop the conversation
current_app.logger.info("Requesting AI Runner to stop conversation...")
ai_runner = get_ai_runner()
# Signal the running conversation loop to stop gracefully
# No need for asyncio.run() or background task - just set flags
from datetime import datetime
controller.is_running = False
controller.conversation_state = "completed"
if ai_runner.is_running:
success = ai_runner.stop_conversation(focus_group_id)
if success:
current_app.logger.info(f"Successfully requested stop for focus group {focus_group_id}")
else:
current_app.logger.warning(f"No active conversation found for focus group {focus_group_id}")
else:
current_app.logger.warning("AI Runner is not running, cannot stop conversation")
# Update focus group status in database
from datetime import datetime
status = 'completed' if reason in ['completed', 'discussion_guide_completed', 'natural_completion'] else 'active'
FocusGroup.update(focus_group_id, {
await FocusGroup.update(focus_group_id, {
'status': status,
'autonomous_ended_at': datetime.utcnow(),
'completion_reason': reason
@ -842,8 +831,13 @@ def stop_autonomous_conversation(focus_group_id):
current_app.logger.info(f"Signaled autonomous conversation to stop for focus group {focus_group_id}: {reason}")
# Mode events are now handled by AIModeratorService.end_session_with_concluding_statement()
# to prevent duplicate mode event indicators
# Add mode event for AI session concluded (regardless of reason)
user_id = get_jwt_identity() if get_jwt_identity() else None
mode_event_id = await FocusGroup.add_mode_event(focus_group_id, 'ai_session_concluded', user_id)
if mode_event_id:
current_app.logger.info(f"🎯 Added AI session concluded mode event: {mode_event_id}")
else:
current_app.logger.warning(f"Failed to add AI session concluded mode event for focus group {focus_group_id}")
# Return immediately with a success response like start_autonomous_conversation
result = {
@ -865,7 +859,7 @@ def stop_autonomous_conversation(focus_group_id):
@focus_group_ai_bp.route('/autonomous/status/<focus_group_id>', methods=['GET'])
@jwt_required(optional=True)
def get_autonomous_conversation_status(focus_group_id):
async def get_autonomous_conversation_status(focus_group_id):
"""
Get the status of autonomous conversation for a focus group.
@ -877,7 +871,7 @@ def get_autonomous_conversation_status(focus_group_id):
controller = AutonomousConversationController(focus_group_id, current_app.logger)
# Get status
status = controller.get_conversation_status()
status = await controller.get_conversation_status()
return jsonify({
"message": "Autonomous conversation status retrieved",
@ -958,7 +952,7 @@ def get_conversation_analytics(focus_group_id):
@focus_group_ai_bp.route('/conversation/decision/<focus_group_id>', methods=['POST'])
@jwt_required(optional=True)
def make_conversation_decision(focus_group_id):
async def make_conversation_decision(focus_group_id):
"""
Make a conversation decision using the LLM decision engine.
@ -972,12 +966,12 @@ def make_conversation_decision(focus_group_id):
A JSON object containing the decision
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
temperature = data.get('temperature', 0.7)
mode = data.get('mode', 'ai') # Default to 'ai' mode for backward compatibility
# Make decision
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature, mode)
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature, mode)
response_data = {
"message": "Conversation decision made",
@ -1002,7 +996,7 @@ def make_conversation_decision(focus_group_id):
@focus_group_ai_bp.route('/conversation/insights/<focus_group_id>', methods=['GET'])
@jwt_required(optional=True)
def get_conversation_insights(focus_group_id):
async def get_conversation_insights(focus_group_id):
"""
Get LLM-generated insights about the conversation.
@ -1011,7 +1005,7 @@ def get_conversation_insights(focus_group_id):
"""
try:
# Get insights
insights = ConversationDecisionService.get_conversation_insights(focus_group_id)
insights = await ConversationDecisionService.get_conversation_insights(focus_group_id)
return jsonify({
"message": "Conversation insights generated",
@ -1034,7 +1028,7 @@ def get_conversation_insights(focus_group_id):
@focus_group_ai_bp.route('/conversation/intervene/<focus_group_id>', methods=['POST'])
@jwt_required(optional=True)
def manual_intervention(focus_group_id):
async def manual_intervention(focus_group_id):
"""
Manually intervene in autonomous conversation.
@ -1049,7 +1043,7 @@ def manual_intervention(focus_group_id):
A JSON object containing the intervention result
"""
try:
data = request.get_json() or {}
data = (await request.get_json()) or {}
action = data.get('action', 'pause')
message = data.get('message')
participant_id = data.get('participant_id')
@ -1073,7 +1067,7 @@ def manual_intervention(focus_group_id):
result = {"message": "Moderator message added"}
elif action == 'call_participant' and participant_id:
# Add moderator message calling on specific participant
persona = Persona.find_by_id(participant_id)
persona = await Persona.find_by_id(participant_id)
if persona:
call_message = f"{persona.get('name', 'Participant')}, what are your thoughts on this?"
FocusGroup.add_message(focus_group_id, {
@ -1106,7 +1100,7 @@ def manual_intervention(focus_group_id):
@focus_group_ai_bp.route('/conversation/reasoning-history/<focus_group_id>', methods=['GET'])
@jwt_required(optional=True)
def get_reasoning_history(focus_group_id):
async def get_reasoning_history(focus_group_id):
"""
Get the AI reasoning history for an autonomous conversation.
@ -1116,7 +1110,7 @@ def get_reasoning_history(focus_group_id):
try:
# Create autonomous conversation controller to get reasoning history
controller = AutonomousConversationController(focus_group_id)
status = controller.get_conversation_status()
status = await controller.get_conversation_status()
reasoning_history = status.get('reasoning_history', [])
@ -1136,7 +1130,7 @@ def get_reasoning_history(focus_group_id):
@focus_group_ai_bp.route('/moderator/end-session/<focus_group_id>', methods=['POST'])
@jwt_required(optional=True)
def end_focus_group_session(focus_group_id):
async def end_focus_group_session(focus_group_id):
"""
End a focus group session with a concluding moderator statement.
@ -1151,7 +1145,7 @@ def end_focus_group_session(focus_group_id):
try:
current_app.logger.info(f"=== END FOCUS GROUP SESSION API called for focus group {focus_group_id} ===")
data = request.get_json() or {}
data = (await request.get_json()) or {}
reason = data.get('reason', 'session_ended')
current_app.logger.info(f"Session ending reason: {reason}")
@ -1165,7 +1159,7 @@ def end_focus_group_session(focus_group_id):
current_app.logger.info(f"Focus group found: {focus_group.get('name', 'Unnamed')}")
# End the session with concluding statement
result = AIModeratorService.end_session_with_concluding_statement(focus_group_id, reason)
result = await AIModeratorService.end_session_with_concluding_statement(focus_group_id, reason)
if "error" in result:
current_app.logger.error(f"Error ending session: {result['error']}")

View file

@ -1,5 +1,5 @@
from flask import Blueprint, request, jsonify, Response, send_file
from flask_jwt_extended import jwt_required, get_jwt_identity
from quart import Blueprint, request, jsonify, Response, send_file
from app.auth.quart_jwt import jwt_required, get_jwt_identity
from app.models.focus_group import FocusGroup
from app.models.persona import Persona
from app.services.focus_group_service import FocusGroupService
@ -250,39 +250,49 @@ except Exception as e:
# Request data cache for direct processing
request_data_cache = {}
@focus_groups_bp.before_request
# Temporarily disable this before_request handler due to Quart ASGI context issues
# @focus_groups_bp.before_request
def cache_multipart_data():
"""Cache multipart request data only when temp directories are unavailable."""
from flask import request, g
# Only cache for asset upload endpoints when temp directory issues are expected
if (request.endpoint and 'upload_assets' in request.endpoint and
request.method == 'POST' and
request.content_type and 'multipart/form-data' in request.content_type):
try:
from quart import request, g
# Check if temp directory is available - if so, let Flask handle normally
temp_dir = setup_temp_directory()
if temp_dir:
# Temp directory is available, skip caching to allow normal Flask processing
# Safely check if we have an active request context
if not request:
return
# Safely check request properties - handle Quart/Flask differences
endpoint = getattr(request, 'endpoint', None)
method = getattr(request, 'method', None)
content_type = getattr(request, 'content_type', None)
try:
# Only cache when temp directories are unavailable
raw_data = request.stream.read()
if raw_data:
# Store in flask g object for this request
g.cached_request_data = raw_data
# Create a new stream from the cached data
from io import BytesIO
request.stream = BytesIO(raw_data)
except Exception as e:
# If caching fails, continue normally
pass
# Only cache for asset upload endpoints when temp directory issues are expected
if (endpoint and 'upload_assets' in str(endpoint) and
method == 'POST' and
content_type and 'multipart/form-data' in content_type):
# Check if temp directory is available - if so, let Quart handle normally
temp_dir = setup_temp_directory()
if temp_dir:
# Temp directory is available, skip caching to allow normal processing
return
# Enable the rest of the caching logic if needed
# For now, just return to prevent context errors
return
else:
# Not an upload endpoint, skip processing
return
except (RuntimeError, AttributeError, Exception) as e:
# Handle "Working outside of request context" gracefully
# This can happen during startup or shutdown with ASGI
return
@focus_groups_bp.route('', methods=['GET'])
@focus_groups_bp.route('/', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_focus_groups():
async def get_focus_groups():
import logging
logger = logging.getLogger('app.focus_groups')
@ -293,7 +303,7 @@ def get_focus_groups():
# Always return all focus groups for now
logger.debug("Calling FocusGroup.get_all() to show all focus groups")
focus_groups = FocusGroup.get_all()
focus_groups = await FocusGroup.get_all()
logger.debug(f"Found {len(focus_groups)} total focus groups")
# Make focus groups serializable
@ -310,9 +320,9 @@ def get_focus_groups():
@focus_groups_bp.route('/all', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_all_focus_groups():
async def get_all_focus_groups():
try:
focus_groups = FocusGroup.get_all()
focus_groups = await FocusGroup.get_all()
# Make focus groups serializable
serializable_groups = make_serializable(focus_groups)
return jsonify(serializable_groups), 200
@ -322,9 +332,9 @@ def get_all_focus_groups():
@focus_groups_bp.route('/<focus_group_id>', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_focus_group(focus_group_id):
async def get_focus_group(focus_group_id):
try:
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
@ -337,7 +347,7 @@ def get_focus_group(focus_group_id):
participants_data = []
for persona_id in focus_group['participants']:
try:
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if persona:
participants_data.append(persona)
except Exception as e:
@ -354,14 +364,14 @@ def get_focus_group(focus_group_id):
@focus_groups_bp.route('', methods=['POST'])
@focus_groups_bp.route('/', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def create_focus_group():
async def create_focus_group():
try:
user_id = get_jwt_identity()
# Use default user ID if not authenticated
if not user_id:
user_id = 'default_id'
data = request.get_json()
data = await request.get_json()
if not data or not data.get('name'):
return jsonify({"message": "Missing required fields"}), 400
@ -378,10 +388,10 @@ def create_focus_group():
if 'participants_count' not in data:
data['participants_count'] = len(data['participants'])
focus_group_id = FocusGroup.create(data, user_id)
focus_group_id = await FocusGroup.create(data, user_id)
# Get the created focus group to return
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
return jsonify({
"message": "Focus group created successfully",
@ -402,7 +412,7 @@ def test_logging_endpoint(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>', methods=['PUT'])
@jwt_required()
def update_focus_group(focus_group_id):
async def update_focus_group(focus_group_id):
import datetime
import os
@ -416,7 +426,7 @@ def update_focus_group(focus_group_id):
except:
pass # Don't let logging errors break the endpoint
data = request.get_json()
data = await request.get_json()
try:
log_msg = f"🔧 [{datetime.datetime.now()}] UPDATE DATA: {data}\n"
@ -442,11 +452,11 @@ def update_focus_group(focus_group_id):
if not data:
return jsonify({"message": "No data provided"}), 400
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
success = FocusGroup.update(focus_group_id, data)
success = await FocusGroup.update(focus_group_id, data)
if success:
return jsonify({"message": "Focus group updated successfully"}), 200
@ -455,12 +465,12 @@ def update_focus_group(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>', methods=['DELETE'])
@jwt_required()
def delete_focus_group(focus_group_id):
focus_group = FocusGroup.find_by_id(focus_group_id)
async def delete_focus_group(focus_group_id):
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
success = FocusGroup.delete(focus_group_id)
success = await FocusGroup.delete(focus_group_id)
if success:
return jsonify({"message": "Focus group deleted successfully"}), 200
@ -469,8 +479,8 @@ def delete_focus_group(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/participants', methods=['POST'])
@jwt_required()
def add_participant(focus_group_id):
data = request.get_json()
async def add_participant(focus_group_id):
data = await request.get_json()
if not data or not data.get('persona_id'):
return jsonify({"message": "Missing persona_id"}), 400
@ -478,16 +488,16 @@ def add_participant(focus_group_id):
persona_id = data.get('persona_id')
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
# Verify persona exists
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if not persona:
return jsonify({"message": "Persona not found"}), 404
success = FocusGroup.add_participant(focus_group_id, persona_id)
success = await FocusGroup.add_participant(focus_group_id, persona_id)
if success:
return jsonify({"message": "Participant added successfully"}), 200
@ -496,13 +506,13 @@ def add_participant(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/participants/<persona_id>', methods=['DELETE'])
@jwt_required()
def remove_participant(focus_group_id, persona_id):
async def remove_participant(focus_group_id, persona_id):
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
success = FocusGroup.remove_participant(focus_group_id, persona_id)
success = await FocusGroup.remove_participant(focus_group_id, persona_id)
if success:
return jsonify({"message": "Participant removed successfully"}), 200
@ -511,17 +521,17 @@ def remove_participant(focus_group_id, persona_id):
@focus_groups_bp.route('/<focus_group_id>/messages', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_focus_group_messages(focus_group_id):
async def get_focus_group_messages(focus_group_id):
"""Get all messages for a focus group, including mode events."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
# Get messages and mode events
messages = FocusGroup.get_messages(focus_group_id)
mode_events = FocusGroup.get_mode_events(focus_group_id)
messages = await FocusGroup.get_messages(focus_group_id)
mode_events = await FocusGroup.get_mode_events(focus_group_id)
# Make messages and events serializable and convert field names for frontend compatibility
serializable_messages = make_serializable(messages)
@ -544,17 +554,17 @@ def get_focus_group_messages(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/messages', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def add_focus_group_message(focus_group_id):
async def add_focus_group_message(focus_group_id):
"""Add a new message to a focus group."""
try:
# Get message data from request
data = request.get_json()
data = await request.get_json()
if not data or not data.get('text'):
return jsonify({"message": "Missing required fields"}), 400
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
@ -577,7 +587,7 @@ def add_focus_group_message(focus_group_id):
# Activate visual assets in the focus group for LLM context
try:
success = FocusGroup._activate_visual_assets(focus_group_id, [filename], None)
success = await FocusGroup._activate_visual_assets(focus_group_id, [filename], None)
if success:
print(f"✅ VISUAL CONTEXT ACTIVATED: {filename} ({visual_asset.get('displayReference')})")
else:
@ -604,7 +614,7 @@ def add_focus_group_message(focus_group_id):
# Activate visual assets in the focus group for LLM context
try:
success = FocusGroup._activate_visual_assets(focus_group_id, [asset_filename], None)
success = await FocusGroup._activate_visual_assets(focus_group_id, [asset_filename], None)
if success:
print(f"✅ VISUAL CONTEXT ACTIVATED: {asset_filename}")
else:
@ -623,7 +633,7 @@ def add_focus_group_message(focus_group_id):
print(f" - Activates visual context: {data.get('activates_visual_context', False)}")
# Add message
message_id = FocusGroup.add_message(focus_group_id, data)
message_id = await FocusGroup.add_message(focus_group_id, data)
if not message_id:
return jsonify({"message": "Failed to add message"}), 500
@ -638,22 +648,22 @@ def add_focus_group_message(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/messages/<message_id>', methods=['PATCH'])
@jwt_required(optional=True) # Make JWT optional for development
def update_focus_group_message(focus_group_id, message_id):
async def update_focus_group_message(focus_group_id, message_id):
"""Update a message in a focus group, currently only for highlighted status."""
try:
# Get message data from request
data = request.get_json()
data = await request.get_json()
if data is None or 'highlighted' not in data:
return jsonify({"message": "Missing highlighted field"}), 400
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
# Update message highlight status
success = FocusGroup.update_message_highlight(
success = await FocusGroup.update_message_highlight(
focus_group_id,
message_id,
data['highlighted']
@ -671,16 +681,16 @@ def update_focus_group_message(focus_group_id, message_id):
@focus_groups_bp.route('/<focus_group_id>/notes', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_focus_group_notes(focus_group_id):
async def get_focus_group_notes(focus_group_id):
"""Get all notes for a focus group."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
# Get notes
notes = FocusGroup.get_notes(focus_group_id)
notes = await FocusGroup.get_notes(focus_group_id)
# Make notes serializable
serializable_notes = make_serializable(notes)
@ -691,28 +701,28 @@ def get_focus_group_notes(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/notes', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def add_focus_group_note(focus_group_id):
async def add_focus_group_note(focus_group_id):
"""Add a new note to a focus group."""
try:
# Get note data from request
data = request.get_json()
data = await request.get_json()
if not data or not data.get('content'):
return jsonify({"message": "Missing required fields"}), 400
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
# Add note
note_id = FocusGroup.add_note(focus_group_id, data)
note_id = await FocusGroup.add_note(focus_group_id, data)
if not note_id:
return jsonify({"message": "Failed to add note"}), 500
# Get the created note to return
notes = FocusGroup.get_notes(focus_group_id)
notes = await FocusGroup.get_notes(focus_group_id)
created_note = None
for note in notes:
if str(note.get('_id', '')) == str(note_id):
@ -730,16 +740,16 @@ def add_focus_group_note(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/notes/<note_id>', methods=['DELETE'])
@jwt_required(optional=True) # Make JWT optional for development
def delete_focus_group_note(focus_group_id, note_id):
async def delete_focus_group_note(focus_group_id, note_id):
"""Delete a note from a focus group."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"message": "Focus group not found"}), 404
# Delete note
success = FocusGroup.delete_note(focus_group_id, note_id)
success = await FocusGroup.delete_note(focus_group_id, note_id)
if not success:
return jsonify({"message": "Failed to delete note"}), 500
@ -754,7 +764,7 @@ def delete_focus_group_note(focus_group_id, note_id):
@focus_groups_bp.route('/generate-discussion-guide', methods=['POST'])
@focus_groups_bp.route('/<focus_group_id>/generate-discussion-guide', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def generate_discussion_guide(focus_group_id=None):
async def generate_discussion_guide(focus_group_id=None):
"""Generate a discussion guide for a focus group using the LLM service."""
import logging
logger = logging.getLogger(__name__)
@ -764,7 +774,7 @@ def generate_discussion_guide(focus_group_id=None):
try:
# Get request data
data = request.get_json()
data = await request.get_json()
if not data:
logger.warning("Discussion guide generation failed: Missing request data")
@ -814,7 +824,7 @@ def generate_discussion_guide(focus_group_id=None):
llm_model = None
if focus_group_id:
try:
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if focus_group:
llm_model = focus_group.get('llm_model')
logger.info(f"Using LLM model for focus group {focus_group_id}: {llm_model}")
@ -826,7 +836,7 @@ def generate_discussion_guide(focus_group_id=None):
llm_model = data.get('llm_model')
# Generate the discussion guide
discussion_guide = FocusGroupService.generate_discussion_guide(
discussion_guide = await FocusGroupService.generate_discussion_guide(
focus_group_name=focus_group_name,
research_brief=research_brief,
discussion_topics=formatted_topic,
@ -1038,7 +1048,7 @@ def generate_discussion_guide_filename(focus_group_name=None, guide_title=None):
@focus_groups_bp.route('/<focus_group_id>/discussion-guide/download', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def download_discussion_guide(focus_group_id):
async def download_discussion_guide(focus_group_id):
"""
Download the discussion guide for a focus group as a markdown file.
@ -1052,7 +1062,7 @@ def download_discussion_guide(focus_group_id):
logger.debug(f"=== DOWNLOAD DISCUSSION GUIDE API called for focus group {focus_group_id} ===")
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
logger.warning(f"Focus group not found: {focus_group_id}")
return jsonify({"error": "Focus group not found"}), 404
@ -1104,7 +1114,8 @@ def download_discussion_guide(focus_group_id):
# Additional asset upload utility functions
def get_upload_folder(focus_group_id):
"""Get the upload folder path for a focus group."""
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) # Go up to backend/
# Use absolute path to avoid working directory issues
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) # Go up to backend/
upload_dir = os.path.join(base_dir, 'uploads', f'focus-group-{focus_group_id}')
return upload_dir
@ -1189,7 +1200,7 @@ def save_uploaded_file_directly(file, file_path):
@focus_groups_bp.route('/<focus_group_id>/assets', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def upload_assets(focus_group_id):
async def upload_assets(focus_group_id):
"""Upload creative assets for a focus group."""
import logging
logger = logging.getLogger('app.focus_groups')
@ -1197,8 +1208,9 @@ def upload_assets(focus_group_id):
try:
logger.debug(f"=== UPLOAD ASSETS API called for focus group {focus_group_id} ===")
# Check for replace flag
replace_existing = request.form.get('replace', '').lower() == 'true'
# Check for replace flag (Quart async form access)
form_data = await request.form
replace_existing = form_data.get('replace', '').lower() == 'true'
logger.info(f"Replace existing assets flag: {replace_existing}")
# Set up temporary directory for file processing (optional)
@ -1209,7 +1221,7 @@ def upload_assets(focus_group_id):
logger.info("No temp directory available, processing files directly")
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
logger.warning(f"Focus group not found: {focus_group_id}")
return jsonify({"error": "Focus group not found"}), 404
@ -1256,7 +1268,7 @@ def upload_assets(focus_group_id):
logger.warning(f"Could not find or delete existing asset file: {filename}")
# Clear assets from database
success = FocusGroup.clear_uploaded_assets(focus_group_id)
success = await FocusGroup.clear_uploaded_assets(focus_group_id)
if success:
logger.info("Successfully cleared existing assets from database")
else:
@ -1271,14 +1283,17 @@ def upload_assets(focus_group_id):
logger.info(f"Request content type: {request.content_type}")
logger.info(f"Request content length: {request.content_length}")
logger.info(f"Request method: {request.method}")
logger.info(f"Request files keys: {list(request.files.keys()) if hasattr(request, 'files') else 'No files attribute'}")
# Get files using Quart async pattern
files_data = await request.files
logger.info(f"Request files keys: {list(files_data.keys())}")
try:
if 'assets' not in request.files:
logger.warning(f"No 'assets' key in request.files. Available keys: {list(request.files.keys())}")
if 'assets' not in files_data:
logger.warning(f"No 'assets' key in request.files. Available keys: {list(files_data.keys())}")
return jsonify({"error": "No files provided"}), 400
files = request.files.getlist('assets')
files = files_data.getlist('assets')
if not files or all(f.filename == '' for f in files):
logger.warning("No files selected")
return jsonify({"error": "No files selected"}), 400
@ -1338,9 +1353,9 @@ def upload_assets(focus_group_id):
# Try direct save first
if not save_uploaded_file_directly(file, file_path):
# Fallback to standard save method
# Fallback to standard save method (Quart async version)
try:
file.save(file_path)
await file.save(file_path)
except Exception as save_error:
logger.error(f"Both direct and standard file save failed: {save_error}")
errors.append(f"{file.filename}: Save failed - {str(save_error)}")
@ -1378,7 +1393,7 @@ def upload_assets(focus_group_id):
logger.info(f"Updating focus group {focus_group_id} with {len(uploaded_assets)} assets")
logger.info(f"Asset metadata to save: {uploaded_assets}")
success = FocusGroup.add_uploaded_assets(focus_group_id, uploaded_assets)
success = await FocusGroup.add_uploaded_assets(focus_group_id, uploaded_assets)
logger.info(f"Database update success: {success}")
if not success:
@ -1396,7 +1411,7 @@ def upload_assets(focus_group_id):
# DEBUG: Verify the data was saved by reading it back
try:
verification_assets = FocusGroup.get_uploaded_assets(focus_group_id)
verification_assets = await FocusGroup.get_uploaded_assets(focus_group_id)
logger.info(f"Verification: Found {len(verification_assets)} assets after save")
logger.info(f"Verification asset data: {verification_assets}")
except Exception as verify_error:
@ -1433,11 +1448,11 @@ def upload_assets(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/assets', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_assets(focus_group_id):
async def get_assets(focus_group_id):
"""Get list of uploaded assets for a focus group."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
@ -1468,11 +1483,11 @@ def get_assets(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/assets/<filename>', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def serve_asset(focus_group_id, filename):
async def serve_asset(focus_group_id, filename):
"""Serve an uploaded asset file."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
@ -1493,7 +1508,7 @@ def serve_asset(focus_group_id, filename):
file_path = subdirectory_path
else:
# Try flat storage location (main uploads directory)
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
main_upload_dir = os.path.join(base_dir, 'uploads')
flat_path = os.path.join(main_upload_dir, filename)
if os.path.isfile(flat_path):
@ -1503,12 +1518,12 @@ def serve_asset(focus_group_id, filename):
if not file_path or not os.path.exists(file_path):
return jsonify({"error": "Asset file not found on disk"}), 404
# Serve the file
return send_file(
# Serve the file (Quart uses attachment_filename instead of download_name)
return await send_file(
file_path,
mimetype=asset.get('mime_type', 'image/jpeg'),
as_attachment=False,
download_name=asset.get('original_name', filename)
attachment_filename=asset.get('original_name', filename)
)
except Exception as e:
@ -1517,16 +1532,16 @@ def serve_asset(focus_group_id, filename):
@focus_groups_bp.route('/<focus_group_id>/assets/<filename>', methods=['DELETE'])
@jwt_required(optional=True) # Make JWT optional for development
def delete_asset(focus_group_id, filename):
async def delete_asset(focus_group_id, filename):
"""Delete an uploaded asset."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
# Remove asset from focus group metadata
success = FocusGroup.remove_uploaded_asset(focus_group_id, filename)
success = await FocusGroup.remove_uploaded_asset(focus_group_id, filename)
if not success:
return jsonify({"error": "Failed to update focus group metadata"}), 500
@ -1557,16 +1572,16 @@ def delete_asset(focus_group_id, filename):
@focus_groups_bp.route('/<focus_group_id>/assets/<filename>', methods=['PATCH'])
@jwt_required(optional=True) # Make JWT optional for development
def update_asset_name(focus_group_id, filename):
async def update_asset_name(focus_group_id, filename):
"""Update the user assigned name for an uploaded asset."""
try:
# Verify focus group exists
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return jsonify({"error": "Focus group not found"}), 404
# Get request data
data = request.get_json()
data = await request.get_json()
if not data or 'user_assigned_name' not in data:
return jsonify({"error": "Missing user_assigned_name field"}), 400
@ -1579,7 +1594,7 @@ def update_asset_name(focus_group_id, filename):
return jsonify({"error": "Asset not found"}), 404
# Update the asset name
success = FocusGroup.update_asset_name(focus_group_id, filename, user_assigned_name)
success = await FocusGroup.update_asset_name(focus_group_id, filename, user_assigned_name)
if not success:
return jsonify({"error": "Failed to update asset name"}), 500
@ -1624,13 +1639,13 @@ def test_websocket_emission(focus_group_id):
@focus_groups_bp.route('/<focus_group_id>/describe-asset', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def describe_asset(focus_group_id):
async def describe_asset(focus_group_id):
"""Generate AI description of an asset for enhanced creative review questions."""
print(f"🔍 API ENDPOINT: describe-asset called for focus group {focus_group_id}")
try:
# Verify focus group exists
print(f"🔍 API: Looking up focus group {focus_group_id}")
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
print(f"❌ API: Focus group {focus_group_id} not found")
return jsonify({"error": "Focus group not found"}), 404
@ -1638,7 +1653,7 @@ def describe_asset(focus_group_id):
print(f"✅ API: Focus group {focus_group_id} found")
# Get asset filename from request
data = request.get_json()
data = await request.get_json()
print(f"🔍 API: Request data: {data}")
if not data or 'asset_filename' not in data:
print(f"❌ API: Missing asset_filename in request")
@ -1651,7 +1666,7 @@ def describe_asset(focus_group_id):
# Generate AI description
try:
description = ImageDescriptionService.generate_description(focus_group_id, asset_filename)
description = await ImageDescriptionService.generate_description(focus_group_id, asset_filename)
return jsonify({
"message": "Asset description generated successfully",

View file

@ -1,5 +1,5 @@
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from quart import Blueprint, request, jsonify
from app.auth.quart_jwt import jwt_required, get_jwt_identity
from app.models.folder import Folder
from bson import ObjectId
import datetime
@ -22,11 +22,11 @@ folders_bp = Blueprint('folders', __name__)
@folders_bp.route('', methods=['GET'])
@folders_bp.route('/', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_folders():
async def get_folders():
"""Get all folders - shared across all users."""
try:
# Always return all folders - this is a shared persona system
folders = Folder.get_all()
folders = await Folder.get_all()
# Make folders serializable
serializable_folders = make_serializable(folders)
@ -37,10 +37,10 @@ def get_folders():
@folders_bp.route('/<folder_id>', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_folder(folder_id):
async def get_folder(folder_id):
"""Get a specific folder by ID."""
try:
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
@ -54,10 +54,10 @@ def get_folder(folder_id):
@folders_bp.route('', methods=['POST'])
@folders_bp.route('/', methods=['POST'])
@jwt_required()
def create_folder():
async def create_folder():
"""Create a new folder."""
user_id = get_jwt_identity()
data = request.get_json()
data = await request.get_json()
if not data:
return jsonify({"message": "No data provided"}), 400
@ -65,7 +65,7 @@ def create_folder():
if not data.get('name'):
return jsonify({"message": "Folder name is required"}), 400
folder_id = Folder.create(data, user_id)
folder_id = await Folder.create(data, user_id)
return jsonify({
"message": "Folder created successfully",
@ -74,15 +74,15 @@ def create_folder():
@folders_bp.route('/<folder_id>', methods=['PUT'])
@jwt_required()
def update_folder(folder_id):
async def update_folder(folder_id):
"""Update a folder."""
try:
data = request.get_json()
data = await request.get_json()
if not data:
return jsonify({"message": "No data provided"}), 400
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
@ -96,11 +96,11 @@ def update_folder(folder_id):
if 'id' in data:
del data['id']
success = Folder.update(folder_id, data)
success = await Folder.update(folder_id, data)
if success:
# Get the updated folder and return it
updated_folder = Folder.find_by_id(folder_id)
updated_folder = await Folder.find_by_id(folder_id)
return jsonify({
"message": "Folder updated successfully",
"folder": make_serializable(updated_folder)
@ -113,9 +113,9 @@ def update_folder(folder_id):
@folders_bp.route('/<folder_id>', methods=['DELETE'])
@jwt_required()
def delete_folder(folder_id):
async def delete_folder(folder_id):
"""Delete a folder."""
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
@ -124,7 +124,7 @@ def delete_folder(folder_id):
if folder.get('created_by') != user_id:
return jsonify({"message": "Unauthorized"}), 403
success = Folder.delete(folder_id)
success = await Folder.delete(folder_id)
if success:
return jsonify({"message": "Folder deleted successfully"}), 200
@ -133,22 +133,22 @@ def delete_folder(folder_id):
@folders_bp.route('/<folder_id>/personas', methods=['POST'])
@jwt_required()
def add_persona_to_folder(folder_id):
async def add_persona_to_folder(folder_id):
"""Add a persona to a folder (supports multiple folders per persona)."""
try:
data = request.get_json()
data = await request.get_json()
if not data or not data.get('persona_id'):
return jsonify({"message": "Persona ID is required"}), 400
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
# Folder operations are shared across all users in this system
persona_id = data['persona_id']
success = Folder.add_persona(folder_id, persona_id)
success = await Folder.add_persona(folder_id, persona_id)
if success:
return jsonify({"message": "Persona added to folder successfully"}), 200
@ -160,16 +160,16 @@ def add_persona_to_folder(folder_id):
@folders_bp.route('/<folder_id>/personas/<persona_id>', methods=['DELETE'])
@jwt_required()
def remove_persona_from_folder(folder_id, persona_id):
async def remove_persona_from_folder(folder_id, persona_id):
"""Remove a persona from a folder (persona can remain in other folders)."""
try:
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
# Folder operations are shared across all users in this system
success = Folder.remove_persona(folder_id, persona_id)
success = await Folder.remove_persona(folder_id, persona_id)
if success:
return jsonify({"message": "Persona removed from folder successfully"}), 200
@ -181,15 +181,15 @@ def remove_persona_from_folder(folder_id, persona_id):
@folders_bp.route('/<folder_id>/personas/batch', methods=['POST'])
@jwt_required()
def add_personas_to_folder_batch(folder_id):
async def add_personas_to_folder_batch(folder_id):
"""Add multiple personas to a folder (personas can be in multiple folders)."""
try:
data = request.get_json()
data = await request.get_json()
if not data or not data.get('persona_ids'):
return jsonify({"message": "Persona IDs are required"}), 400
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
@ -199,7 +199,7 @@ def add_personas_to_folder_batch(folder_id):
if not isinstance(persona_ids, list):
return jsonify({"message": "persona_ids must be a list"}), 400
success = Folder.add_personas_batch(folder_id, persona_ids)
success = await Folder.add_personas_batch(folder_id, persona_ids)
if success:
return jsonify({"message": f"Successfully added {len(persona_ids)} personas to folder"}), 200
@ -211,11 +211,11 @@ def add_personas_to_folder_batch(folder_id):
@folders_bp.route('/<folder_id>/personas/remove-batch', methods=['POST'])
@jwt_required()
def remove_personas_from_folder_batch(folder_id):
async def remove_personas_from_folder_batch(folder_id):
"""Remove multiple personas from a folder (personas remain in other folders)."""
print(f"🌐 BACKEND: POST /folders/{folder_id}/personas/remove-batch endpoint hit")
try:
data = request.get_json()
data = await request.get_json()
print(f"🌐 BACKEND: Raw request data: {data}")
print(f"🌐 BACKEND: Request content type: {request.content_type}")
print(f"🌐 BACKEND: Request method: {request.method}")
@ -224,7 +224,7 @@ def remove_personas_from_folder_batch(folder_id):
print(f"❌ BACKEND: Missing persona_ids in data: {data}")
return jsonify({"message": "Persona IDs are required"}), 400
folder = Folder.find_by_id(folder_id)
folder = await Folder.find_by_id(folder_id)
if not folder:
return jsonify({"message": "Folder not found"}), 404
@ -234,7 +234,7 @@ def remove_personas_from_folder_batch(folder_id):
if not isinstance(persona_ids, list):
return jsonify({"message": "persona_ids must be a list"}), 400
success = Folder.remove_personas_batch(folder_id, persona_ids)
success = await Folder.remove_personas_batch(folder_id, persona_ids)
if success:
return jsonify({"message": f"Successfully removed {len(persona_ids)} personas from folder"}), 200

View file

@ -1,5 +1,5 @@
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from quart import Blueprint, request, jsonify
from app.auth.quart_jwt import jwt_required, get_jwt_identity
from app.models.persona import Persona
from app.services.persona_export_service import PersonaExportService
from app.services.persona_modification_service import PersonaModificationService, PersonaModificationError
@ -24,15 +24,15 @@ personas_bp = Blueprint('personas', __name__)
@personas_bp.route('', methods=['GET'])
@personas_bp.route('/', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_personas():
async def get_personas():
try:
user_id = get_jwt_identity()
if user_id:
# If authenticated, get user's personas
personas = Persona.find_by_user(user_id)
personas = await Persona.find_by_user(user_id)
else:
# For development, return all personas if not authenticated
personas = Persona.get_all()
personas = await Persona.get_all()
# Make personas serializable
serializable_personas = make_serializable(personas)
@ -43,9 +43,9 @@ def get_personas():
@personas_bp.route('/all', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_all_personas():
async def get_all_personas():
try:
personas = Persona.get_all()
personas = await Persona.get_all()
# Make personas serializable
serializable_personas = make_serializable(personas)
return jsonify(serializable_personas), 200
@ -55,9 +55,9 @@ def get_all_personas():
@personas_bp.route('/<persona_id>', methods=['GET'])
@jwt_required(optional=True) # Make JWT optional for development
def get_persona(persona_id):
async def get_persona(persona_id):
try:
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if not persona:
return jsonify({"message": "Persona not found"}), 404
@ -71,14 +71,14 @@ def get_persona(persona_id):
@personas_bp.route('', methods=['POST'])
@personas_bp.route('/', methods=['POST'])
@jwt_required()
def create_persona():
async def create_persona():
user_id = get_jwt_identity()
data = request.get_json()
data = await request.get_json()
if not data:
return jsonify({"message": "No data provided"}), 400
persona_id = Persona.create(data, user_id)
persona_id = await Persona.create(data, user_id)
return jsonify({
"message": "Persona created successfully",
@ -87,14 +87,14 @@ def create_persona():
@personas_bp.route('/<persona_id>', methods=['PUT'])
@jwt_required()
def update_persona(persona_id):
async def update_persona(persona_id):
try:
data = request.get_json()
data = await request.get_json()
if not data:
return jsonify({"message": "No data provided"}), 400
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if not persona:
return jsonify({"message": "Persona not found"}), 404
@ -106,11 +106,11 @@ def update_persona(persona_id):
if 'id' in data:
del data['id']
success = Persona.update(persona_id, data)
success = await Persona.update(persona_id, data)
if success:
# Get the updated persona and return it
updated_persona = Persona.find_by_id(persona_id)
updated_persona = await Persona.find_by_id(persona_id)
return jsonify({
"message": "Persona updated successfully",
"persona": make_serializable(updated_persona)
@ -123,12 +123,12 @@ def update_persona(persona_id):
@personas_bp.route('/<persona_id>', methods=['DELETE'])
@jwt_required()
def delete_persona(persona_id):
persona = Persona.find_by_id(persona_id)
async def delete_persona(persona_id):
persona = await Persona.find_by_id(persona_id)
if not persona:
return jsonify({"message": "Persona not found"}), 404
success = Persona.delete(persona_id)
success = await Persona.delete(persona_id)
if success:
return jsonify({"message": "Persona deleted successfully"}), 200
@ -137,16 +137,16 @@ def delete_persona(persona_id):
@personas_bp.route('/batch', methods=['POST'])
@jwt_required()
def create_multiple_personas():
async def create_multiple_personas():
user_id = get_jwt_identity()
data = request.get_json()
data = await request.get_json()
if not data or not isinstance(data, list):
return jsonify({"message": "Invalid data format. Expected list of personas"}), 400
persona_ids = []
for persona_data in data:
persona_id = Persona.create(persona_data, user_id)
persona_id = await Persona.create(persona_data, user_id)
persona_ids.append(persona_id)
return jsonify({
@ -156,7 +156,7 @@ def create_multiple_personas():
@personas_bp.route('/<persona_id>/modify-with-ai', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def modify_persona_with_ai(persona_id):
async def modify_persona_with_ai(persona_id):
"""
Modify a persona using AI based on natural language instructions.
@ -168,7 +168,7 @@ def modify_persona_with_ai(persona_id):
"""
try:
# Get request data
request_data = request.get_json()
request_data = await request.get_json()
if not request_data:
return jsonify({"error": "No request data provided"}), 400
@ -184,7 +184,7 @@ def modify_persona_with_ai(persona_id):
print(f"📝 Modification prompt: {modification_prompt[:100]}...")
# Call the modification service
modified_persona_data = PersonaModificationService.modify_persona(
modified_persona_data = await PersonaModificationService.modify_persona(
persona_id=persona_id,
modification_prompt=modification_prompt,
llm_model=llm_model,
@ -207,7 +207,7 @@ def modify_persona_with_ai(persona_id):
@personas_bp.route('/<persona_id>/export-profile', methods=['POST'])
@jwt_required(optional=True) # Make JWT optional for development
def export_persona_profile(persona_id):
async def export_persona_profile(persona_id):
"""
Export a persona profile as beautifully formatted markdown.
@ -217,12 +217,12 @@ def export_persona_profile(persona_id):
"""
try:
# Get the persona data
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if not persona:
return jsonify({"error": "Persona not found"}), 404
# Get optional parameters from request
request_data = request.get_json() or {}
request_data = await request.get_json() or {}
llm_model = request_data.get('llm_model', 'gpt-4.1')
temperature = request_data.get('temperature', 0.3)
@ -235,7 +235,7 @@ def export_persona_profile(persona_id):
print(f"🤖 Backend: Exporting profile for persona {persona_data.get('name', persona_id)} using {llm_model}")
# Generate the markdown profile
result = export_service.generate_profile_markdown(
result = await export_service.generate_profile_markdown(
persona_data=persona_data,
llm_model=llm_model,
temperature=temperature

View file

@ -5,7 +5,7 @@ including sequential navigation through structured discussion guides.
"""
from typing import Dict, List, Any, Optional, Tuple
from flask import current_app
import logging
from app.models.focus_group import FocusGroup, emit_websocket_event
from app.services.llm_service import LLMService, LLMServiceError
from app.utils.prompt_loader import load_prompt, PromptLoaderError
@ -16,6 +16,8 @@ import json
class AIModeratorService:
"""Service for AI-powered focus group moderation."""
logger = logging.getLogger(__name__)
@staticmethod
def _count_total_items(sections: List[Dict[str, Any]]) -> int:
"""
@ -143,7 +145,7 @@ class AIModeratorService:
return completed_count
@staticmethod
def get_moderator_status(focus_group_id: str) -> Dict[str, Any]:
async def get_moderator_status(focus_group_id: str) -> Dict[str, Any]:
"""
Get the current moderator status for a focus group.
@ -154,7 +156,7 @@ class AIModeratorService:
Dictionary containing current moderator status
"""
try:
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
@ -171,7 +173,7 @@ class AIModeratorService:
# Save the initial position to the database
try:
FocusGroup.update(focus_group_id, {
await FocusGroup.update(focus_group_id, {
'moderator_position': moderator_position
})
print(f"Initialized moderator position for focus group {focus_group_id}")
@ -245,7 +247,7 @@ class AIModeratorService:
return {"error": f"Error getting moderator status: {str(e)}"}
@staticmethod
def advance_discussion(focus_group_id: str) -> Dict[str, Any]:
async def advance_discussion(focus_group_id: str) -> Dict[str, Any]:
"""
Advance the discussion to the next item in the guide and generate appropriate moderator response.
@ -256,7 +258,7 @@ class AIModeratorService:
Dictionary containing the moderator response and updated position
"""
try:
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
@ -265,7 +267,7 @@ class AIModeratorService:
# Handle legacy markdown format
if isinstance(discussion_guide, str):
return AIModeratorService._handle_legacy_advance(focus_group_id, discussion_guide)
return await AIModeratorService._handle_legacy_advance(focus_group_id, discussion_guide)
# Handle structured JSON format
if not discussion_guide or 'sections' not in discussion_guide:
@ -292,13 +294,13 @@ class AIModeratorService:
}
# Generate moderator response based on the next item
moderator_response = AIModeratorService._generate_moderator_response(
moderator_response = await AIModeratorService._generate_moderator_response(
focus_group_id, next_item, section_info, new_position
)
# Update focus group with new position
print(f"🎯 Advancing moderator position for focus group {focus_group_id}: {new_position}")
update_success = FocusGroup.update(focus_group_id, {
update_success = await FocusGroup.update(focus_group_id, {
'moderator_position': new_position
})
@ -307,14 +309,14 @@ class AIModeratorService:
# Emit WebSocket event for moderator position change (same pattern as FocusGroup.add_message)
try:
moderator_status = AIModeratorService.get_moderator_status(focus_group_id)
moderator_status = await AIModeratorService.get_moderator_status(focus_group_id)
if "error" not in moderator_status:
emit_websocket_event('moderator_status_update', focus_group_id, {
await emit_websocket_event('moderator_status_update', focus_group_id, {
'moderator_status': moderator_status
})
current_app.logger.debug(f"🔔 Emitted moderator_status_update websocket event for focus group {focus_group_id}")
AIModeratorService.logger.debug(f"🔔 Emitted moderator_status_update websocket event for focus group {focus_group_id}")
except Exception as e:
current_app.logger.warning(f"Failed to emit moderator position websocket event: {str(e)}")
AIModeratorService.logger.warning(f"Failed to emit moderator position websocket event: {str(e)}")
else:
print(f"❌ Failed to update moderator position in database")
@ -331,7 +333,7 @@ class AIModeratorService:
return {"error": f"Error advancing discussion: {str(e)}"}
@staticmethod
def set_moderator_position(focus_group_id: str, section_id: str, item_id: Optional[str] = None) -> Dict[str, Any]:
async def set_moderator_position(focus_group_id: str, section_id: str, item_id: Optional[str] = None) -> Dict[str, Any]:
"""
Set the moderator position to a specific section and item.
@ -344,7 +346,7 @@ class AIModeratorService:
Dictionary containing the result and new position
"""
try:
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
@ -408,7 +410,7 @@ class AIModeratorService:
item_index = i
item_type = 'activity'
found = True
current_app.logger.info(f"📍 Found item '{item_id}' in subsection {subsection_idx} activity {i}, using subsection_index={subsection_index}, item_index={item_index}")
AIModeratorService.logger.info(f"📍 Found item '{item_id}' in subsection {subsection_idx} activity {i}, using subsection_index={subsection_index}, item_index={item_index}")
break
if found:
break
@ -421,7 +423,7 @@ class AIModeratorService:
item_index = i
item_type = 'question'
found = True
current_app.logger.info(f"📍 Found item '{item_id}' in subsection {subsection_idx} question {i}, using subsection_index={subsection_index}, item_index={item_index}")
AIModeratorService.logger.info(f"📍 Found item '{item_id}' in subsection {subsection_idx} question {i}, using subsection_index={subsection_index}, item_index={item_index}")
break
if found:
@ -443,25 +445,25 @@ class AIModeratorService:
# Log detailed position information for debugging
if subsection_index is not None:
current_app.logger.info(f"🎯 Setting moderator position: section_index={section_index}, subsection_index={subsection_index}, item_index={item_index}, item_type={item_type}")
AIModeratorService.logger.info(f"🎯 Setting moderator position: section_index={section_index}, subsection_index={subsection_index}, item_index={item_index}, item_type={item_type}")
else:
current_app.logger.info(f"🎯 Setting moderator position: section_index={section_index}, item_index={item_index}, item_type={item_type}")
AIModeratorService.logger.info(f"🎯 Setting moderator position: section_index={section_index}, item_index={item_index}, item_type={item_type}")
# Update focus group
FocusGroup.update(focus_group_id, {
await FocusGroup.update(focus_group_id, {
'moderator_position': new_position
})
# Emit WebSocket event for moderator position change (same pattern as FocusGroup.add_message)
try:
moderator_status = AIModeratorService.get_moderator_status(focus_group_id)
moderator_status = await AIModeratorService.get_moderator_status(focus_group_id)
if "error" not in moderator_status:
emit_websocket_event('moderator_status_update', focus_group_id, {
await emit_websocket_event('moderator_status_update', focus_group_id, {
'moderator_status': moderator_status
})
current_app.logger.debug(f"🔔 Emitted moderator_status_update websocket event for focus group {focus_group_id}")
AIModeratorService.logger.debug(f"🔔 Emitted moderator_status_update websocket event for focus group {focus_group_id}")
except Exception as e:
current_app.logger.warning(f"Failed to emit moderator position websocket event: {str(e)}")
AIModeratorService.logger.warning(f"Failed to emit moderator position websocket event: {str(e)}")
return {
"message": "Moderator position updated successfully",
@ -569,7 +571,7 @@ class AIModeratorService:
})
@staticmethod
def _generate_moderator_response(focus_group_id: str, item: Dict[str, Any], section_info: Dict[str, Any], position: Dict[str, Any]) -> str:
async def _generate_moderator_response(focus_group_id: str, item: Dict[str, Any], section_info: Dict[str, Any], position: Dict[str, Any]) -> str:
"""
Generate an appropriate moderator response for the current item.
@ -584,7 +586,7 @@ class AIModeratorService:
"""
try:
# Get previous messages for context
messages = FocusGroup.get_messages(focus_group_id)
messages = await FocusGroup.get_messages(focus_group_id)
recent_messages = messages[-10:] if messages else [] # Last 10 messages
# Format context
@ -602,11 +604,11 @@ class AIModeratorService:
prompt = load_prompt('ai-moderator-system', context)
# Get LLM model for this focus group
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
llm_model = focus_group.get('llm_model') if focus_group else None
# Generate response
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=prompt,
temperature=0.7,
model_name=llm_model
@ -640,13 +642,13 @@ class AIModeratorService:
return "\n".join(formatted)
@staticmethod
def _handle_legacy_advance(focus_group_id: str, discussion_guide: str) -> Dict[str, Any]:
async def _handle_legacy_advance(focus_group_id: str, discussion_guide: str) -> Dict[str, Any]:
"""Handle advancement for legacy markdown format guides."""
# For legacy format, we'll generate a generic moderator response
# This is a fallback for older discussion guides
try:
# Get recent messages for context
messages = FocusGroup.get_messages(focus_group_id)
messages = await FocusGroup.get_messages(focus_group_id)
recent_messages = messages[-5:] if messages else []
# Create a simple context
@ -660,10 +662,10 @@ class AIModeratorService:
prompt = load_prompt('ai-moderator-system', context)
# Get LLM model for this focus group
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
llm_model = focus_group.get('llm_model') if focus_group else None
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=prompt,
temperature=0.7,
model_name=llm_model
@ -684,7 +686,7 @@ class AIModeratorService:
}
@staticmethod
def end_session_with_concluding_statement(focus_group_id: str, reason: str = 'session_ended') -> Dict[str, Any]:
async def end_session_with_concluding_statement(focus_group_id: str, reason: str = 'session_ended') -> Dict[str, Any]:
"""
End a focus group session with a concluding moderator statement.
@ -696,12 +698,12 @@ class AIModeratorService:
Dictionary containing the concluding statement and session end confirmation
"""
try:
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
# Generate concluding statement
concluding_message = AIModeratorService._generate_concluding_statement(
concluding_message = await AIModeratorService._generate_concluding_statement(
focus_group_id, reason
)
@ -712,19 +714,19 @@ class AIModeratorService:
"senderId": "moderator"
}
message_id = FocusGroup.add_message(focus_group_id, message_data)
message_id = await FocusGroup.add_message(focus_group_id, message_data)
if not message_id:
print(f"Warning: Failed to save concluding message for focus group {focus_group_id}")
# Update focus group status to completed
FocusGroup.update(focus_group_id, {
await FocusGroup.update(focus_group_id, {
'status': 'completed'
})
# Add mode event for all AI session conclusions
# This includes auto_complete, natural_completion, discussion_guide_completed, manual_stop, etc.
mode_event_id = FocusGroup.add_mode_event(
mode_event_id = await FocusGroup.add_mode_event(
focus_group_id=focus_group_id,
event_type='ai_session_concluded'
)
@ -749,7 +751,7 @@ class AIModeratorService:
return {"error": f"Error ending session: {str(e)}"}
@staticmethod
def _generate_concluding_statement(focus_group_id: str, reason: str) -> str:
async def _generate_concluding_statement(focus_group_id: str, reason: str) -> str:
"""
Generate an appropriate concluding statement for the session.
@ -762,12 +764,12 @@ class AIModeratorService:
"""
try:
# Get focus group details for context
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
return AIModeratorService._get_fallback_concluding_message(reason)
# Get recent messages for context
messages = FocusGroup.get_messages(focus_group_id)
messages = await FocusGroup.get_messages(focus_group_id)
recent_messages = messages[-5:] if messages else []
# Create context for LLM
@ -783,10 +785,10 @@ class AIModeratorService:
prompt = load_prompt('ai-moderator-system', context)
# Get LLM model for this focus group
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
llm_model = focus_group.get('llm_model') if focus_group else None
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=prompt,
temperature=0.5, # Lower temperature for more consistent, professional responses
model_name=llm_model

View file

@ -58,7 +58,7 @@ def _sanitize_persona_data_for_json(persona_data: Dict[str, Any]) -> Dict[str, A
return sanitized
def generate_basic_personas(
async def generate_basic_personas(
audience_brief: str,
research_objective: Optional[str] = None,
count: int = 5,
@ -117,7 +117,7 @@ def generate_basic_personas(
# Log the LLM API call
print(f"🤖 Backend: Making LLM API call to {llm_model or 'gemini-2.5-pro'} for basic persona generation")
raw_response = LLMService.generate_content(
raw_response = await LLMService.generate_content(
prompt=final_prompt,
temperature=temperature,
system_prompt=system_prompt,
@ -187,7 +187,7 @@ def generate_basic_personas(
raise PersonaGenerationError(f"Error generating basic personas: {str(e)}")
def generate_persona(
async def generate_persona(
prompt_customization: Optional[str] = None,
basic_persona: Optional[Dict[str, Any]] = None,
temperature: float = 0.7,
@ -254,7 +254,7 @@ def generate_persona(
persona_name = basic_persona.get('name', 'Unknown') if basic_persona else 'New Persona'
print(f"🤖 Backend: Making LLM API call to {llm_model or 'gemini-2.5-pro'} for detailed persona generation of '{persona_name}'")
persona_data = LLMService.generate_structured_response(
persona_data = await LLMService.generate_structured_response(
prompt=final_prompt,
temperature=temperature,
system_prompt=system_prompt,
@ -283,7 +283,7 @@ def generate_persona(
raise PersonaGenerationError(f"Error generating persona: {str(e)}")
def generate_persona_summary(
async def generate_persona_summary(
persona_data: Dict[str, Any],
temperature: float = 0.7,
llm_model: Optional[str] = None
@ -325,7 +325,7 @@ def generate_persona_summary(
persona_name = persona_data.get('name', 'Unknown')
print(f"🤖 Backend: Making LLM API call to {llm_model or 'gemini-2.5-pro'} for summary generation of '{persona_name}'")
raw_response = LLMService.generate_content(
raw_response = await LLMService.generate_content(
prompt=final_prompt,
temperature=temperature,
system_prompt=system_prompt,
@ -388,7 +388,7 @@ def generate_persona_summary(
raise PersonaGenerationError(f"Error generating persona summary: {str(e)}")
def generate_persona_download_summary(
async def generate_persona_download_summary(
persona_data: Dict[str, Any],
temperature: float = 0.7,
llm_model: Optional[str] = None
@ -431,7 +431,7 @@ def generate_persona_download_summary(
print(f"🤖 Backend: Making LLM API call to {llm_model or 'gemini-2.5-pro'} for download summary of '{persona_name}'")
# Generate the markdown content directly
markdown_response = LLMService.generate_content(
markdown_response = await LLMService.generate_content(
prompt=final_prompt,
temperature=temperature,
system_prompt=system_prompt,
@ -544,7 +544,7 @@ LIFE SCENARIOS REQUIREMENTS:
return "Create a persona with these characteristics: " + "; ".join(customizations)
def enhance_audience_brief(
async def enhance_audience_brief(
audience_brief: str,
research_objective: str,
temperature: float = 0.7
@ -575,7 +575,7 @@ def enhance_audience_brief(
# Generate suggestions using the LLM service
try:
raw_response = LLMService.generate_content(
raw_response = await LLMService.generate_content(
prompt=final_prompt,
temperature=temperature
)

View file

@ -0,0 +1,376 @@
"""
AI Runner Service
Provides a single dedicated thread with an asyncio event loop for all AI conversations.
This fixes Motor loop affinity issues and improves scalability by avoiding one-thread-per-conversation.
Based on GPT-5 recommendations for clean async/threading architecture.
"""
import asyncio
import threading
import logging
from typing import Dict, Any, Optional, Callable, Awaitable
from datetime import datetime
from concurrent.futures import Future
import weakref
from app.db import get_db
from motor.motor_asyncio import AsyncIOMotorClient
class AIRunnerService:
"""Singleton service that runs all AI conversations in a dedicated thread with single event loop."""
_instance: Optional['AIRunnerService'] = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, '_initialized'):
return
self.logger = logging.getLogger(__name__)
self._thread: Optional[threading.Thread] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._running = False
self._stopping = False
# Task registry for tracking and cancellation
self._active_conversations: Dict[str, asyncio.Task] = {} # focus_group_id -> Task
self._task_registry_lock = asyncio.Lock() # Will be created on the AI loop
# Database client for AI operations (will be created on AI loop)
self._db_client: Optional[AsyncIOMotorClient] = None
self._db = None
self._initialized = True
def start(self) -> None:
"""Start the AI runner thread and event loop."""
if self._running:
self.logger.warning("AI Runner already running")
return
self.logger.info("Starting AI Runner service...")
self._stopping = False
self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
self._thread.start()
# Wait for loop to be ready
while self._loop is None and not self._stopping:
threading.Event().wait(0.01)
if self._loop:
self.logger.info("AI Runner service started successfully")
else:
self.logger.error("Failed to start AI Runner service")
def stop(self) -> None:
"""Stop the AI runner service gracefully (idempotent)."""
self.logger.info("Stopping AI Runner service...")
self._stopping = True
# Get references (they might change during shutdown)
thread = self._thread
loop = self._loop
if loop is not None:
# Cancel all active conversations
try:
future = asyncio.run_coroutine_threadsafe(self._cancel_all_conversations(), loop)
future.result(timeout=3.0)
self.logger.info("All AI conversations cancelled")
except Exception as e:
self.logger.warning(f"Error cancelling conversations (continuing shutdown): {e}")
# Stop the event loop
try:
loop.call_soon_threadsafe(loop.stop)
self.logger.info("AI Runner event loop stop requested")
except Exception as e:
self.logger.warning(f"Error stopping event loop (continuing): {e}")
# Always try to join thread if it exists
if thread and thread.is_alive():
self.logger.info("Waiting for AI Runner thread to finish...")
thread.join(timeout=5.0)
if thread.is_alive():
self.logger.warning("AI Runner thread did not shut down gracefully")
else:
self.logger.info("AI Runner thread joined successfully")
self._running = False
self.logger.info("AI Runner service shutdown complete")
def _run_event_loop(self) -> None:
"""Main thread function that runs the asyncio event loop."""
try:
# Create new event loop for this thread
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
# Initialize async resources on this loop
self._loop.run_until_complete(self._initialize_async_resources())
self._running = True
self.logger.info("AI Runner event loop started")
# Run the event loop
self._loop.run_forever()
# Cleanup phase after loop.stop()
self.logger.info("AI Runner event loop stopped, cleaning up...")
self._loop.run_until_complete(self._cleanup_async_resources())
# Cancel any remaining tasks
pending = [t for t in asyncio.all_tasks(loop=self._loop) if t is not asyncio.current_task(self._loop)]
for t in pending:
t.cancel()
if pending:
self._loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
self.logger.info(f"Cancelled {len(pending)} remaining tasks")
# Shutdown async generators and default executor (Python 3.9+)
try:
self._loop.run_until_complete(self._loop.shutdown_asyncgens())
except Exception as e:
self.logger.debug(f"Error shutting down async generators: {e}")
try:
self._loop.run_until_complete(self._loop.shutdown_default_executor())
except Exception as e:
self.logger.debug(f"Error shutting down default executor: {e}")
except Exception as e:
self.logger.error(f"AI Runner event loop error: {e}")
finally:
# Close the loop
try:
if self._loop:
self._loop.close()
self.logger.info("AI Runner event loop closed")
except Exception as e:
self.logger.debug(f"Error closing loop: {e}")
self._running = False
self.logger.info("AI Runner thread cleanup complete")
async def _initialize_async_resources(self) -> None:
"""Initialize async resources (database, HTTP clients) on the AI loop."""
try:
# Initialize database connection
self._db = await get_db()
self.logger.info("AI Runner: Database connection initialized")
# Create task registry lock on this loop
self._task_registry_lock = asyncio.Lock()
self.logger.info("AI Runner: Async resources initialized")
except Exception as e:
self.logger.error(f"Failed to initialize AI Runner async resources: {e}")
raise
async def _cleanup_async_resources(self) -> None:
"""Cleanup async resources."""
try:
# Cancel any remaining tasks
await self._cancel_all_conversations()
# Close database connections
if self._db_client:
self._db_client.close()
self.logger.info("AI Runner: Async resources cleaned up")
except Exception as e:
self.logger.error(f"Error cleaning up AI Runner resources: {e}")
async def _cancel_all_conversations(self) -> None:
"""Cancel all active conversation tasks."""
if not self._task_registry_lock:
return
async with self._task_registry_lock:
tasks_to_cancel = list(self._active_conversations.values())
self._active_conversations.clear()
if tasks_to_cancel:
self.logger.info(f"Cancelling {len(tasks_to_cancel)} active conversations")
for task in tasks_to_cancel:
if not task.done():
task.cancel()
# Wait for cancellation to complete
if tasks_to_cancel:
try:
await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
except Exception as e:
self.logger.debug(f"Expected cancellation errors: {e}")
def submit_conversation(self, focus_group_id: str, coro: Awaitable[Any]) -> Future:
"""
Submit a conversation coroutine to run on the AI event loop.
Args:
focus_group_id: The focus group ID for tracking
coro: The coroutine to execute
Returns:
Future that will contain the result
"""
if not self._running or not self._loop:
raise RuntimeError("AI Runner is not running")
# Use run_coroutine_threadsafe to schedule on the AI loop
future = asyncio.run_coroutine_threadsafe(
self._run_conversation_with_tracking(focus_group_id, coro),
self._loop
)
return future
async def _run_conversation_with_tracking(self, focus_group_id: str, coro: Awaitable[Any]) -> Any:
"""Run a conversation coroutine with proper task tracking."""
# Create task for this conversation
task = asyncio.create_task(coro)
# Register the task
async with self._task_registry_lock:
# Cancel existing conversation for this focus group if any
existing_task = self._active_conversations.get(focus_group_id)
if existing_task and not existing_task.done():
self.logger.info(f"Cancelling existing conversation for focus group {focus_group_id}")
existing_task.cancel()
try:
await existing_task
except asyncio.CancelledError:
pass
self._active_conversations[focus_group_id] = task
try:
# Run the conversation
self.logger.info(f"Starting AI conversation for focus group {focus_group_id}")
result = await task
self.logger.info(f"AI conversation completed for focus group {focus_group_id}")
return result
except asyncio.CancelledError:
self.logger.info(f"AI conversation cancelled for focus group {focus_group_id}")
raise
except Exception as e:
self.logger.error(f"AI conversation error for focus group {focus_group_id}: {e}")
raise
finally:
# Unregister the task
async with self._task_registry_lock:
if self._active_conversations.get(focus_group_id) is task:
del self._active_conversations[focus_group_id]
def stop_conversation(self, focus_group_id: str) -> bool:
"""
Stop a specific conversation.
Args:
focus_group_id: The focus group ID to stop
Returns:
True if conversation was found and cancelled, False otherwise
"""
if not self._running or not self._loop:
return False
# Schedule cancellation on the AI loop
future = asyncio.run_coroutine_threadsafe(
self._cancel_conversation(focus_group_id),
self._loop
)
try:
return future.result(timeout=5.0)
except Exception as e:
self.logger.error(f"Error stopping conversation {focus_group_id}: {e}")
return False
async def _cancel_conversation(self, focus_group_id: str) -> bool:
"""Cancel a specific conversation task."""
async with self._task_registry_lock:
task = self._active_conversations.get(focus_group_id)
if task and not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
return True
return False
def get_active_conversations(self) -> Dict[str, Dict[str, Any]]:
"""Get information about active conversations."""
if not self._running or not self._loop:
return {}
future = asyncio.run_coroutine_threadsafe(
self._get_conversation_info(),
self._loop
)
try:
return future.result(timeout=2.0)
except Exception as e:
self.logger.error(f"Error getting conversation info: {e}")
return {}
async def _get_conversation_info(self) -> Dict[str, Dict[str, Any]]:
"""Get conversation information from the AI loop."""
async with self._task_registry_lock:
info = {}
for focus_group_id, task in self._active_conversations.items():
info[focus_group_id] = {
'status': 'running' if not task.done() else 'completed',
'cancelled': task.cancelled() if task.done() else False,
'exception': str(task.exception()) if task.done() and task.exception() else None
}
return info
@property
def is_running(self) -> bool:
"""Check if the AI Runner is running."""
return self._running
@property
def active_conversation_count(self) -> int:
"""Get count of active conversations."""
return len(self._active_conversations) if self._active_conversations else 0
# Global AI Runner instance
_ai_runner: Optional[AIRunnerService] = None
def get_ai_runner() -> AIRunnerService:
"""Get the global AI Runner instance."""
global _ai_runner
if _ai_runner is None:
_ai_runner = AIRunnerService()
return _ai_runner
def init_ai_runner() -> None:
"""Initialize and start the AI Runner service."""
ai_runner = get_ai_runner()
if not ai_runner.is_running:
ai_runner.start()
def shutdown_ai_runner() -> None:
"""Shutdown the AI Runner service."""
global _ai_runner
if _ai_runner and _ai_runner.is_running:
_ai_runner.stop()
_ai_runner = None

View file

@ -12,7 +12,7 @@ import logging
from app.services.conversation_decision_service import ConversationDecisionService, ConversationDecisionError
from app.services.focus_group_response_service import generate_persona_response, FocusGroupResponseError
from app.services.ai_moderator_service import AIModeratorService
from app.models.focus_group import FocusGroup
from app.models.focus_group import FocusGroup # Now fully async
from app.models.persona import Persona
@ -47,27 +47,11 @@ class AutonomousConversationController:
self._initialize_state_from_database()
def _initialize_state_from_database(self):
"""Initialize the controller's state from the database."""
try:
focus_group = FocusGroup.find_by_id(self.focus_group_id)
if focus_group:
db_status = focus_group.get('status', 'unknown')
# Set initial state based on database status
if db_status == 'ai_mode':
self.is_running = True
self.conversation_state = "running"
else:
self.is_running = False
self.conversation_state = "idle"
self.logger.debug(f"Initialized controller state from DB - status: {db_status}, is_running: {self.is_running}")
else:
self.logger.warning(f"Focus group {self.focus_group_id} not found during initialization")
except Exception as e:
self.logger.error(f"Error initializing state from database: {str(e)}")
# Keep default values if database check fails
"""Initialize the controller's state with defaults (database check happens during start)."""
# Set default state - actual database check will happen when start_autonomous_conversation is called
self.is_running = False
self.conversation_state = "idle"
self.logger.debug(f"Initialized controller with default state - database state will be checked on start")
async def start_autonomous_conversation(self, initial_prompt: Optional[str] = None) -> Dict[str, Any]:
"""
@ -88,8 +72,8 @@ class AutonomousConversationController:
self.is_running = False
self.conversation_state = "stopped"
# Validate focus group exists
focus_group = FocusGroup.find_by_id(self.focus_group_id)
# Validate focus group exists (using async model)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
@ -98,8 +82,8 @@ class AutonomousConversationController:
if not participants:
return {"error": "Focus group has no participants"}
# Update focus group status
FocusGroup.update(self.focus_group_id, {
# Update focus group status (using async model)
await FocusGroup.update(self.focus_group_id, {
'status': 'ai_mode',
'autonomous_started_at': datetime.utcnow()
})
@ -145,9 +129,9 @@ class AutonomousConversationController:
self.is_running = False
self.conversation_state = "completed"
# Update focus group status
# Update focus group status (using async model)
status = 'completed' if reason in ['completed', 'discussion_guide_completed', 'natural_completion'] else 'active'
FocusGroup.update(self.focus_group_id, {
await FocusGroup.update(self.focus_group_id, {
'status': status,
'autonomous_ended_at': datetime.utcnow(),
'completion_reason': reason
@ -174,7 +158,7 @@ class AutonomousConversationController:
# Use the AI moderator service to properly end the session with mode events
from app.services.ai_moderator_service import AIModeratorService
ending_result = AIModeratorService.end_session_with_concluding_statement(
ending_result = await AIModeratorService.end_session_with_concluding_statement(
self.focus_group_id, reason
)
@ -185,6 +169,16 @@ class AutonomousConversationController:
await self._add_moderator_message(completion_message, "system")
else:
self.logger.info(f"Successfully ended session with concluding statement: {ending_result.get('concluding_statement', '')[:100]}...")
elif reason == "manual_stop":
# For manual stops, add a mode event to indicate AI session concluded
mode_event_id = await FocusGroup.add_mode_event(
focus_group_id=self.focus_group_id,
event_type='ai_session_concluded'
)
if mode_event_id:
self.logger.info(f"🎯 Added AI session concluded mode event for manual stop: {mode_event_id}")
else:
self.logger.warning(f"Failed to add AI session concluded mode event for manual stop")
# For discussion guide completion, ensure all items are marked as completed (100% progress)
if reason in ["discussion_guide_completed", "natural_completion"]:
@ -231,7 +225,7 @@ class AutonomousConversationController:
# Update reasoning history with execution result
reasoning_id = decision.get('reasoning_id')
self._update_reasoning_execution(reasoning_id, result)
await self._update_reasoning_execution(reasoning_id, result)
if result.get("error"):
self.logger.error(f"Error executing decision: {result['error']}")
@ -272,8 +266,8 @@ class AutonomousConversationController:
async def _should_continue_conversation(self) -> bool:
"""Check if the conversation should continue based on various conditions."""
try:
# Check focus group status
focus_group = FocusGroup.find_by_id(self.focus_group_id)
# Check focus group status (using async model)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
return False
@ -368,7 +362,7 @@ class AutonomousConversationController:
Dictionary containing the decision (with reasoning_id added), or None if no decision could be made
"""
try:
decision = ConversationDecisionService.decide_next_action(
decision = await ConversationDecisionService.decide_next_action(
self.focus_group_id,
temperature=0.7,
mode='ai'
@ -377,7 +371,7 @@ class AutonomousConversationController:
self.logger.info(f"LLM Decision: {decision['action']} - {decision['reasoning']}")
# Store reasoning in history for UI display and get the database ID
reasoning_id = self._store_reasoning(decision)
reasoning_id = await self._store_reasoning(decision)
# Add the reasoning_id to the decision for later use
decision['reasoning_id'] = reasoning_id
@ -391,7 +385,7 @@ class AutonomousConversationController:
self.logger.error(f"Unexpected error in decision making: {str(e)}")
return None
def _store_reasoning(self, decision: Dict[str, Any]) -> Optional[str]:
async def _store_reasoning(self, decision: Dict[str, Any]) -> Optional[str]:
"""
Store reasoning from AI decision for UI display.
@ -412,7 +406,7 @@ class AutonomousConversationController:
}
# Store to database for persistence
reasoning_id = FocusGroup.add_reasoning_entry(self.focus_group_id, reasoning_entry)
reasoning_id = await FocusGroup.add_reasoning_entry(self.focus_group_id, reasoning_entry)
# Also keep in memory for quick access during active session
reasoning_entry['_id'] = reasoning_id # Add the database ID
@ -428,7 +422,7 @@ class AutonomousConversationController:
self.logger.error(f"Error storing reasoning: {str(e)}")
return None
def _update_reasoning_execution(self, reasoning_id: Optional[str], execution_result: Dict[str, Any]) -> None:
async def _update_reasoning_execution(self, reasoning_id: Optional[str], execution_result: Dict[str, Any]) -> None:
"""
Update the reasoning entry with execution results.
@ -439,7 +433,7 @@ class AutonomousConversationController:
try:
# Update the database record
if reasoning_id:
FocusGroup.update_reasoning_execution(self.focus_group_id, reasoning_id, execution_result)
await FocusGroup.update_reasoning_execution(self.focus_group_id, reasoning_id, execution_result)
# Also update in memory for quick access during active session
if self.reasoning_history:
@ -624,7 +618,7 @@ class AutonomousConversationController:
# Advance position past current item so final question shows as completed
try:
advance_result = AIModeratorService.advance_discussion(self.focus_group_id)
advance_result = await AIModeratorService.advance_discussion(self.focus_group_id)
if advance_result.get('error'):
self.logger.info(f"Could not advance past final item when ending session: {advance_result['error']}")
else:
@ -648,7 +642,7 @@ class AutonomousConversationController:
try:
# Get participant data
persona = Persona.find_by_id(participant_id)
persona = await Persona.find_by_id(participant_id)
if not persona:
error_msg = f"Participant {participant_id} not found"
self.logger.error(error_msg)
@ -656,7 +650,7 @@ class AutonomousConversationController:
# Get focus group data
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
error_msg = "Focus group not found"
self.logger.error(error_msg)
@ -673,12 +667,12 @@ class AutonomousConversationController:
self.logger.info(f"🤖 Autonomous conversation using model: {llm_model or 'default (gemini-2.5-pro)'} for focus group {self.focus_group_id}")
# Get recent messages
messages = FocusGroup.get_messages(self.focus_group_id)
messages = await FocusGroup.get_messages(self.focus_group_id)
recent_messages = messages[-20:] if len(messages) > 20 else messages
# Generate response
try:
response_text = generate_persona_response(
response_text = await generate_persona_response(
persona=persona,
current_topic=topic,
previous_messages=recent_messages,
@ -701,7 +695,7 @@ class AutonomousConversationController:
"senderId": participant_id
}
message_id = FocusGroup.add_message(self.focus_group_id, message_data)
message_id = await FocusGroup.add_message(self.focus_group_id, message_data)
# GPT-5 fix: Yield after database write to flush WebSocket events
await self._yield_to_eventlet()
@ -733,7 +727,7 @@ class AutonomousConversationController:
async def _get_item_by_position_id(self, position_id: str) -> Optional[Dict[str, Any]]:
"""Get a discussion guide item by its position ID."""
try:
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
return None
@ -788,11 +782,11 @@ class AutonomousConversationController:
print(f"🔍 Checking current discussion guide item for image attachments")
moderator_status = AIModeratorService.get_moderator_status(self.focus_group_id)
moderator_status = await AIModeratorService.get_moderator_status(self.focus_group_id)
current_item = None
if moderator_status and 'moderator_position' in moderator_status:
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if focus_group:
discussion_guide = focus_group.get('discussionGuide')
if discussion_guide and isinstance(discussion_guide, dict):
@ -842,7 +836,7 @@ class AutonomousConversationController:
from app.services.image_description_service import ImageDescriptionService, ImageDescriptionError
print(f"🎨 AI MODE: Generating description for {asset_filename}")
description = ImageDescriptionService.generate_description(self.focus_group_id, asset_filename)
description = await ImageDescriptionService.generate_description(self.focus_group_id, asset_filename)
# Enhance the content with the description using display reference if available
if display_reference:
@ -902,7 +896,7 @@ class AutonomousConversationController:
"visual_asset": visual_asset_metadata # Frontend needs this for image display
}
message_id = FocusGroup.add_message(self.focus_group_id, message_data)
message_id = await FocusGroup.add_message(self.focus_group_id, message_data)
# GPT-5 fix: Yield after database write to flush WebSocket events
await self._yield_to_eventlet()
@ -932,7 +926,7 @@ class AutonomousConversationController:
if section_id and item_id:
# LLM specified valid position - set it precisely
try:
result = AIModeratorService.set_moderator_position(self.focus_group_id, section_id, item_id)
result = await AIModeratorService.set_moderator_position(self.focus_group_id, section_id, item_id)
position_desc = f"position '{position_id}'"
if result.get('error'):
@ -958,7 +952,7 @@ class AutonomousConversationController:
async def _validate_position_id(self, position_id: str) -> tuple[Optional[str], Optional[str]]:
"""Validate that the position ID exists and return the section_id and item_id it maps to."""
try:
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
return None, None
@ -1034,7 +1028,7 @@ class AutonomousConversationController:
async def _fallback_advance_position(self) -> None:
"""Fallback method to advance position sequentially."""
try:
advance_result = AIModeratorService.advance_discussion(self.focus_group_id)
advance_result = await AIModeratorService.advance_discussion(self.focus_group_id)
if advance_result.get('error'):
self.logger.warning(f"Sequential advancement failed: {advance_result['error']}")
else:
@ -1043,12 +1037,12 @@ class AutonomousConversationController:
self.logger.warning(f"Failed to advance moderator position sequentially: {str(e)}")
async def _yield_to_eventlet(self):
"""GPT-5 fix: Yield to the eventlet hub to flush WebSocket frames."""
"""Yield control to allow other tasks to run and flush WebSocket frames."""
try:
from app.extensions import socketio
socketio.sleep(0) # Cooperative yielding for eventlet
# Use asyncio sleep instead of socketio.sleep since we're in async context
await asyncio.sleep(0) # Yield to other tasks
except Exception as e:
self.logger.warning(f"Could not yield to eventlet: {e}")
self.logger.warning(f"Could not yield to event loop: {e}")
async def _wait_between_actions(self):
"""Wait an appropriate amount of time between actions."""
@ -1058,11 +1052,11 @@ class AutonomousConversationController:
delay = random.uniform(self.min_delay_between_actions, self.max_delay_between_actions)
await asyncio.sleep(delay)
def get_conversation_status(self) -> Dict[str, Any]:
async def get_conversation_status(self) -> Dict[str, Any]:
"""Get the current status of the autonomous conversation."""
try:
# Check the actual database state to determine if autonomous mode is truly running
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if focus_group:
db_status = focus_group.get('status', 'unknown')
@ -1086,7 +1080,7 @@ class AutonomousConversationController:
# Keep existing instance state if database check fails
# Load reasoning history from database to ensure it persists across controller instances
reasoning_history = FocusGroup.get_reasoning_history(self.focus_group_id, self.max_reasoning_history)
reasoning_history = await FocusGroup.get_reasoning_history(self.focus_group_id, self.max_reasoning_history)
return {
"focus_group_id": self.focus_group_id,
@ -1105,7 +1099,7 @@ class AutonomousConversationController:
the moderator position to indicate 100% completion.
"""
try:
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
self.logger.error("Focus group not found when marking all questions completed")
return
@ -1161,7 +1155,7 @@ class AutonomousConversationController:
}
# Update the focus group with the completion position
FocusGroup.update(self.focus_group_id, {
await FocusGroup.update(self.focus_group_id, {
'moderator_position': completion_position
})

View file

@ -18,7 +18,7 @@ class ConversationContextService:
"""Service for aggregating conversation context for LLM decision making."""
@staticmethod
def get_full_context(focus_group_id: str) -> Dict[str, Any]:
async def get_full_context(focus_group_id: str) -> Dict[str, Any]:
"""
Get complete conversation context for LLM decision making.
@ -30,15 +30,15 @@ class ConversationContextService:
"""
try:
# Get focus group data
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
raise ValueError(f"Focus group {focus_group_id} not found")
# Get all participants
participants = ConversationContextService._get_participants_context(focus_group)
participants = await ConversationContextService._get_participants_context(focus_group)
# Get conversation history
messages = FocusGroup.get_messages(focus_group_id)
messages = await FocusGroup.get_messages(focus_group_id)
conversation_history = ConversationContextService._format_conversation_history(messages)
# Get conversation analytics
@ -69,13 +69,13 @@ class ConversationContextService:
raise Exception(f"Error getting conversation context: {str(e)}")
@staticmethod
def _get_participants_context(focus_group: Dict[str, Any]) -> List[Dict[str, Any]]:
async def _get_participants_context(focus_group: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Get formatted participant context with OCEAN traits and participation stats."""
participants = []
participant_ids = focus_group.get('participants', [])
for participant_id in participant_ids:
persona = Persona.find_by_id(participant_id)
persona = await Persona.find_by_id(participant_id)
if persona:
participant_context = {
'id': participant_id,
@ -497,7 +497,7 @@ class ConversationContextService:
# ================== MULTIMODAL CONVERSATION CONTEXT METHODS ==================
@staticmethod
def build_multimodal_context(focus_group_id: str, messages: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
async def build_multimodal_context(focus_group_id: str, messages: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
"""
Build complete multimodal conversation context including text and images in proper sequence.
@ -513,10 +513,10 @@ class ConversationContextService:
# Get messages with visual context if not provided
if messages is None:
messages = FocusGroup.get_messages_with_visual_context(focus_group_id)
messages = await FocusGroup.get_messages_with_visual_context(focus_group_id)
# Get active visual context
active_visual_context = FocusGroup.get_active_visual_context(focus_group_id)
active_visual_context = await FocusGroup.get_active_visual_context(focus_group_id)
print(f" - Total messages: {len(messages)}")
print(f" - Active visual assets: {len(active_visual_context)}")
@ -687,7 +687,7 @@ class ConversationContextService:
return "\n".join(formatted)
@staticmethod
def get_current_visual_assets(focus_group_id: str) -> List[str]:
async def get_current_visual_assets(focus_group_id: str) -> List[str]:
"""
Get list of asset paths that are currently active in conversation context.
@ -698,7 +698,7 @@ class ConversationContextService:
List of full paths to currently active visual assets
"""
try:
active_context = FocusGroup.get_active_visual_context(focus_group_id)
active_context = await FocusGroup.get_active_visual_context(focus_group_id)
asset_paths = []
for asset in active_context:
@ -715,7 +715,7 @@ class ConversationContextService:
return []
@staticmethod
def has_visual_context(focus_group_id: str) -> bool:
async def has_visual_context(focus_group_id: str) -> bool:
"""
Check if a focus group currently has any active visual context.
@ -726,7 +726,7 @@ class ConversationContextService:
True if there are active visual assets, False otherwise
"""
try:
active_context = FocusGroup.get_active_visual_context(focus_group_id)
active_context = await FocusGroup.get_active_visual_context(focus_group_id)
return len(active_context) > 0
except Exception as e:
print(f"❌ Error checking visual context: {e}")

View file

@ -19,7 +19,7 @@ class ConversationDecisionService:
"""Service for making LLM-based conversation decisions."""
@staticmethod
def decide_next_action(focus_group_id: str, temperature: float = 0.7, mode: str = "ai") -> Dict[str, Any]:
async def decide_next_action(focus_group_id: str, temperature: float = 0.7, mode: str = "ai") -> Dict[str, Any]:
"""
Use LLM to decide the next action in the conversation.
@ -38,7 +38,7 @@ class ConversationDecisionService:
try:
# Get full conversation context
context = ConversationContextService.get_full_context(focus_group_id)
context = await ConversationContextService.get_full_context(focus_group_id)
formatted_context = ConversationContextService.format_context_for_llm(context)
# Load the appropriate prompt based on mode
@ -55,12 +55,12 @@ class ConversationDecisionService:
# Get LLM model for this focus group
from app.models.focus_group import FocusGroup
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
llm_model = focus_group.get('llm_model') if focus_group else None
# Get LLM decision
try:
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=prompt,
temperature=temperature,
model_name=llm_model
@ -160,7 +160,7 @@ class ConversationDecisionService:
return True
@staticmethod
def select_next_participant(focus_group_id: str, current_topic: str, temperature: float = 0.7) -> Dict[str, Any]:
async def select_next_participant(focus_group_id: str, current_topic: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
Use LLM to select the next participant to respond.
@ -173,7 +173,7 @@ class ConversationDecisionService:
Dictionary containing participant selection details
"""
try:
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature)
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature)
if decision['action'] == 'participant_respond':
return {
@ -195,7 +195,7 @@ class ConversationDecisionService:
raise ConversationDecisionError(f"Error selecting participant: {str(e)}")
@staticmethod
def detect_probe_triggers(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
async def detect_probe_triggers(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
Use LLM to detect if probe triggers are needed.
@ -207,7 +207,7 @@ class ConversationDecisionService:
Dictionary containing probe trigger information
"""
try:
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature)
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature)
if decision['action'] == 'probe_trigger':
return {
@ -230,7 +230,7 @@ class ConversationDecisionService:
raise ConversationDecisionError(f"Error detecting probe triggers: {str(e)}")
@staticmethod
def generate_moderator_response(focus_group_id: str, context: str, temperature: float = 0.7) -> Dict[str, Any]:
async def generate_moderator_response(focus_group_id: str, context: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
Use LLM to generate appropriate moderator response.
@ -243,7 +243,7 @@ class ConversationDecisionService:
Dictionary containing moderator response details
"""
try:
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature)
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature)
if decision['action'] == 'moderator_speak':
return {
@ -264,7 +264,7 @@ class ConversationDecisionService:
raise ConversationDecisionError(f"Error generating moderator response: {str(e)}")
@staticmethod
def detect_persona_interactions(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
async def detect_persona_interactions(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
Use LLM to detect when personas should interact directly.
@ -276,7 +276,7 @@ class ConversationDecisionService:
Dictionary containing persona interaction details
"""
try:
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature)
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature)
if decision['action'] == 'participant_interaction':
return {
@ -299,7 +299,7 @@ class ConversationDecisionService:
raise ConversationDecisionError(f"Error detecting persona interactions: {str(e)}")
@staticmethod
def should_end_session(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
async def should_end_session(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
Use LLM to determine if the session should end.
@ -311,7 +311,7 @@ class ConversationDecisionService:
Dictionary containing session ending decision
"""
try:
decision = ConversationDecisionService.decide_next_action(focus_group_id, temperature)
decision = await ConversationDecisionService.decide_next_action(focus_group_id, temperature)
if decision['action'] == 'end_session':
return {
@ -333,7 +333,7 @@ class ConversationDecisionService:
raise ConversationDecisionError(f"Error determining session end: {str(e)}")
@staticmethod
def get_conversation_insights(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
async def get_conversation_insights(focus_group_id: str, temperature: float = 0.7) -> Dict[str, Any]:
"""
Use LLM to generate insights about the current conversation state.
@ -346,7 +346,7 @@ class ConversationDecisionService:
"""
try:
# Get conversation context
context = ConversationContextService.get_full_context(focus_group_id)
context = await ConversationContextService.get_full_context(focus_group_id)
# Create a specialized prompt for insights
insight_prompt = f"""
@ -368,10 +368,10 @@ class ConversationDecisionService:
# Get LLM model for this focus group
from app.models.focus_group import FocusGroup
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
llm_model = focus_group.get('llm_model') if focus_group else None
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=insight_prompt,
temperature=temperature,
model_name=llm_model

View file

@ -21,7 +21,7 @@ class ConversationStateManager:
self.cache_ttl = 60 # seconds
self.last_cache_update = None
def get_conversation_state(self) -> Dict[str, Any]:
async def get_conversation_state(self) -> Dict[str, Any]:
"""
Get the current conversation state.
@ -33,12 +33,12 @@ class ConversationStateManager:
if self._is_cache_valid():
return self.state_cache
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
# Get messages
messages = FocusGroup.get_messages(self.focus_group_id)
messages = await FocusGroup.get_messages(self.focus_group_id)
# Calculate conversation state
state = {
@ -71,7 +71,7 @@ class ConversationStateManager:
except Exception as e:
return {"error": f"Error getting conversation state: {str(e)}"}
def get_conversation_analytics(self) -> Dict[str, Any]:
async def get_conversation_analytics(self) -> Dict[str, Any]:
"""
Get detailed conversation analytics.
@ -83,11 +83,11 @@ class ConversationStateManager:
if self._is_analytics_cache_valid():
return self.analytics_cache
focus_group = FocusGroup.find_by_id(self.focus_group_id)
focus_group = await FocusGroup.find_by_id(self.focus_group_id)
if not focus_group:
return {"error": "Focus group not found"}
messages = FocusGroup.get_messages(self.focus_group_id)
messages = await FocusGroup.get_messages(self.focus_group_id)
participants = focus_group.get('participants', [])
analytics = {
@ -111,7 +111,7 @@ class ConversationStateManager:
except Exception as e:
return {"error": f"Error getting conversation analytics: {str(e)}"}
def update_conversation_state(self, updates: Dict[str, Any]) -> Dict[str, Any]:
async def update_conversation_state(self, updates: Dict[str, Any]) -> Dict[str, Any]:
"""
Update conversation state.
@ -123,7 +123,7 @@ class ConversationStateManager:
"""
try:
# Update focus group
success = FocusGroup.update(self.focus_group_id, updates)
success = await FocusGroup.update(self.focus_group_id, updates)
if success:
# Clear cache to force refresh
@ -140,22 +140,22 @@ class ConversationStateManager:
except Exception as e:
return {"error": f"Error updating conversation state: {str(e)}"}
def start_autonomous_mode(self) -> Dict[str, Any]:
async def start_autonomous_mode(self) -> Dict[str, Any]:
"""Start autonomous conversation mode."""
return self.update_conversation_state({
return await self.update_conversation_state({
'status': 'ai_mode',
'autonomous_started_at': datetime.utcnow()
})
def end_autonomous_mode(self, reason: str = "completed") -> Dict[str, Any]:
async def end_autonomous_mode(self, reason: str = "completed") -> Dict[str, Any]:
"""End autonomous conversation mode."""
if reason == "completed":
status = 'completed'
else:
status = 'active'
return self.update_conversation_state({
return await self.update_conversation_state({
'status': status,
'autonomous_ended_at': datetime.utcnow(),
'completion_reason': reason

View file

@ -30,7 +30,8 @@ class CustomerDataService:
raise CustomerDataServiceError("llama-cloud-services package not installed")
self.api_key = api_key
self.base_dir = os.path.join(os.path.dirname(__file__), "..", "..", "persona_data")
# Resolve to absolute path to avoid working directory issues
self.base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "persona_data"))
# Ensure base directory exists
os.makedirs(self.base_dir, exist_ok=True)
@ -50,7 +51,7 @@ class CustomerDataService:
"""Generate a unique session ID for this upload session."""
return str(uuid.uuid4())
def upload_and_parse_files(self, files: List[FileStorage]) -> str:
async def upload_and_parse_files(self, files: List[FileStorage]) -> str:
"""
Upload files and parse them using LlamaParse.
@ -80,14 +81,40 @@ class CustomerDataService:
# Secure filename
filename = f"{session_id}_{file.filename}"
file_path = os.path.join(session_dir, filename)
file.save(file_path)
uploaded_files.append(file_path)
try:
# Save file and verify it exists (Quart async version)
await file.save(file_path)
if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
uploaded_files.append(file_path)
print(f"✅ Successfully saved file: {file_path} ({os.path.getsize(file_path)} bytes)")
else:
raise CustomerDataServiceError(f"Failed to save file: {file.filename}")
except CustomerDataServiceError:
raise # Re-raise our own errors
except Exception as e:
raise CustomerDataServiceError(f"Failed to save file {file.filename}: {str(e)}")
if not uploaded_files:
raise CustomerDataServiceError("No valid files uploaded")
# Parse files using LlamaParse
parsed_documents = self.parser.load_data(uploaded_files)
print(f"🔄 Starting LlamaParse for {len(uploaded_files)} files...")
for file_path in uploaded_files:
print(f"📄 File to parse: {file_path} (exists: {os.path.exists(file_path)})")
try:
parsed_documents = self.parser.load_data(uploaded_files)
print(f"✅ LlamaParse completed successfully. Generated {len(parsed_documents)} documents.")
except Exception as parse_error:
print(f"❌ LlamaParse failed: {str(parse_error)}")
# Check which files still exist before the error
for file_path in uploaded_files:
exists = os.path.exists(file_path)
size = os.path.getsize(file_path) if exists else 0
print(f"📄 File status: {file_path} - exists: {exists}, size: {size}")
raise CustomerDataServiceError(f"LlamaParse failed: {str(parse_error)}")
# Save parsed markdown files
for i, document in enumerate(parsed_documents):

View file

@ -15,7 +15,7 @@ class FocusGroupResponseError(Exception):
pass
def generate_persona_response(
async def generate_persona_response(
persona: Dict[str, Any],
current_topic: str,
previous_messages: List[Dict[str, Any]],
@ -64,11 +64,11 @@ def generate_persona_response(
if focus_group_id:
try:
from app.services.conversation_context_service import ConversationContextService
has_visual_context = ConversationContextService.has_visual_context(focus_group_id)
has_visual_context = await ConversationContextService.has_visual_context(focus_group_id)
if has_visual_context:
print(f"🎨 Visual context detected, building multimodal context...")
multimodal_context = ConversationContextService.build_multimodal_context(
multimodal_context = await ConversationContextService.build_multimodal_context(
focus_group_id, previous_messages
)
print(f"🎨 Built context with {multimodal_context['total_visual_assets']} visual assets")
@ -121,7 +121,7 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
raise FocusGroupResponseError(f"Error loading contextual response prompt: {str(e)}")
# Generate response using contextual conversation method
response = LLMService.generate_contextual_response(
response = await LLMService.generate_contextual_response(
prompt=prompt,
conversation_context=multimodal_context['conversation_context'],
temperature=temperature,
@ -150,7 +150,7 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
raise FocusGroupResponseError(f"Error loading response prompt: {str(e)}")
# Generate the standard response
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=prompt,
temperature=temperature,
model_name=llm_model,
@ -391,7 +391,7 @@ This is your chance to provide more detailed insights and personal anecdotes.
"""
def generate_creative_review_response(
async def generate_creative_review_response(
persona: Dict[str, Any],
current_topic: str,
creative_asset_path: str,
@ -493,7 +493,7 @@ Be genuine and specific in your feedback, drawing on your personal experiences a
print(f" - image_paths: {[full_asset_path]}")
print(f" - temperature: {temperature}")
response = LLMService.generate_multimodal_content(
response = await LLMService.generate_multimodal_content(
prompt=prompt,
image_paths=[full_asset_path],
temperature=temperature

View file

@ -22,7 +22,7 @@ class FocusGroupService:
"""Service for focus group operations."""
@staticmethod
def generate_discussion_guide(
async def generate_discussion_guide(
focus_group_name: str,
research_brief: str,
discussion_topics: str,
@ -94,7 +94,7 @@ class FocusGroupService:
uploaded_assets = []
if focus_group_id:
try:
uploaded_assets = FocusGroup.get_uploaded_assets(focus_group_id)
uploaded_assets = await FocusGroup.get_uploaded_assets(focus_group_id)
if uploaded_assets:
logger.info(f"Retrieved {len(uploaded_assets)} assets for focus group {focus_group_id}")
except Exception as e:
@ -155,7 +155,7 @@ class FocusGroupService:
enhanced_prompt = asset_emphasis + prompt
# Generate content using LLM
response = LLMService.generate_content(
response = await LLMService.generate_content(
prompt=enhanced_prompt,
temperature=temperature,
max_tokens=16000, # Use a much higher token limit to avoid truncation

View file

@ -22,7 +22,7 @@ class ImageDescriptionService:
"""Service for generating AI-powered descriptions of creative assets."""
@staticmethod
def generate_description(focus_group_id: str, asset_filename: str) -> str:
async def generate_description(focus_group_id: str, asset_filename: str) -> str:
"""
Generate a detailed AI description of a creative asset image.
@ -76,10 +76,10 @@ class ImageDescriptionService:
# Get LLM model for this focus group
from app.models.focus_group import FocusGroup
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
llm_model = focus_group.get('llm_model') if focus_group else None
description = LLMService.generate_multimodal_content(
description = await LLMService.generate_multimodal_content(
prompt=prompt,
image_paths=[asset_path],
temperature=0.7,

View file

@ -20,7 +20,7 @@ class KeyThemeService:
"""Service for generating key themes from focus group discussions."""
@staticmethod
def generate_key_themes(
async def generate_key_themes(
focus_group_id: str,
temperature: float = 0.7,
llm_model: Optional[str] = None
@ -45,12 +45,12 @@ class KeyThemeService:
try:
# Get the focus group
focus_group = FocusGroup.find_by_id(focus_group_id)
focus_group = await FocusGroup.find_by_id(focus_group_id)
if not focus_group:
raise KeyThemeServiceError(f"Focus group not found with ID: {focus_group_id}")
# Get all messages from the focus group
messages = FocusGroup.get_messages(focus_group_id)
messages = await FocusGroup.get_messages(focus_group_id)
if not messages:
raise KeyThemeServiceError("No messages found in this focus group")
@ -61,14 +61,14 @@ class KeyThemeService:
if 'participants' in focus_group and focus_group['participants']:
for persona_id in focus_group['participants']:
try:
persona = Persona.find_by_id(persona_id)
persona = await Persona.find_by_id(persona_id)
if persona:
participants_data.append(persona)
except Exception as e:
print(f"Error fetching participant {persona_id}: {e}")
# Generate key themes using LLM
return KeyThemeService._extract_themes_from_discussion(
return await KeyThemeService._extract_themes_from_discussion(
messages=messages,
participants=participants_data,
discussion_guide=focus_group.get('discussionGuide', ''),
@ -80,7 +80,7 @@ class KeyThemeService:
raise KeyThemeServiceError(f"Error generating key themes: {str(e)}")
@staticmethod
def _extract_themes_from_discussion(
async def _extract_themes_from_discussion(
messages: List[Dict[str, Any]],
participants: List[Dict[str, Any]],
discussion_guide: str,
@ -138,7 +138,7 @@ class KeyThemeService:
logger.info(f"Attempt {attempt_num}/{max_retries}: Calling LLM ({llm_model or 'gemini-2.5-pro'}) for theme generation")
try:
themes = LLMService.generate_structured_array(
themes = await LLMService.generate_structured_array(
prompt=prompt,
temperature=temperature,
system_prompt=system_prompt,

View file

@ -7,21 +7,23 @@ different application features.
import os
import json
import time
import asyncio
import logging
import google.generativeai as genai
from openai import OpenAI
import base64
from google import genai
from openai import AsyncOpenAI
import httpx
from typing import Dict, Any, Optional, Union, List
from PIL import Image
import io
# Set up the Gemini API key
# Set up the Gemini API key and client
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', 'AIzaSyAc50jzC3k9K1PmKT1vGFi0sCdhhnqsvl0')
genai.configure(api_key=GEMINI_API_KEY)
gemini_client = genai.Client(api_key=GEMINI_API_KEY)
# Set up OpenAI API key
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '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')
openai_client = OpenAI(api_key=OPENAI_API_KEY)
openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
# The default model we're using
DEFAULT_MODEL = "gemini-2.5-pro"
@ -85,26 +87,14 @@ class LLMService:
actual_model = model_name or DEFAULT_MODEL
return SUPPORTED_MODELS.get(actual_model, 'gemini')
@staticmethod
def get_model(model_name: Optional[str] = None) -> genai.GenerativeModel:
"""
Get a configured Gemini model.
Args:
model_name: Optional model name to use. Defaults to the default model.
Returns:
A configured Gemini generative model
"""
return genai.GenerativeModel(model_name or DEFAULT_MODEL)
@staticmethod
def _extract_text_from_response(response) -> str:
def _extract_text_from_new_genai_response(response) -> str:
"""
Extract text from a Gemini API response, handling both simple and multi-part responses.
Extract text from a new Google GenAI SDK response.
Args:
response: The response object from the Gemini API
response: The response object from the new Google GenAI SDK
Returns:
The extracted text content
@ -113,56 +103,35 @@ class LLMService:
LLMServiceError: If no text content can be extracted
"""
try:
# Try the simple text accessor first
return response.text.strip()
except Exception:
# If that fails, try to extract from parts using the recommended approach
try:
text_parts = []
# Check if response has direct parts attribute (as suggested in error message)
if hasattr(response, 'parts') and response.parts:
for part in response.parts:
if hasattr(part, 'text'):
text_parts.append(part.text)
# If that didn't work, try the candidates approach
if not text_parts and hasattr(response, 'candidates') and response.candidates:
for candidate in response.candidates:
# Check if finish reason indicates blocking
if candidate.finish_reason == 3:
raise LLMServiceError("Response was blocked for safety reasons")
elif candidate.finish_reason == 4:
raise LLMServiceError("Response was blocked for recitation reasons")
elif candidate.finish_reason == 2:
raise LLMServiceError("Response was cut off due to length limit - try reducing max_tokens or removing the limit")
if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
# New SDK has a simpler text attribute
if hasattr(response, 'text') and response.text:
return response.text.strip()
# If that doesn't work, check for candidates structure
if hasattr(response, 'candidates') and response.candidates:
for candidate in response.candidates:
if hasattr(candidate, 'content') and candidate.content:
if hasattr(candidate.content, 'parts') and candidate.content.parts:
text_parts = []
for part in candidate.content.parts:
if hasattr(part, 'text'):
if hasattr(part, 'text') and part.text:
text_parts.append(part.text)
if text_parts:
return ''.join(text_parts).strip()
# If no text found, check if the response object has direct text content
if hasattr(response, 'content') and response.content:
return str(response.content).strip()
raise LLMServiceError("Unable to extract text from new GenAI SDK response")
# Join all text parts if we found any
if text_parts:
return ''.join(text_parts).strip()
# If we still can't extract text, it might be a safety/blocking issue
if hasattr(response, 'candidates') and response.candidates:
finish_reason = response.candidates[0].finish_reason
if finish_reason == 3:
raise LLMServiceError("Response was blocked for safety reasons")
elif finish_reason == 4:
raise LLMServiceError("Response was blocked for recitation reasons")
elif finish_reason == 2:
raise LLMServiceError("Response was cut off due to length limit - try reducing max_tokens or removing the limit")
raise LLMServiceError("Unable to extract text from response parts")
except Exception as e:
raise LLMServiceError(f"Error extracting text from multi-part response: {str(e)}")
except Exception as e:
if isinstance(e, LLMServiceError):
raise
raise LLMServiceError(f"Error extracting text from new GenAI SDK response: {str(e)}")
@staticmethod
def generate_content(
async def generate_content(
prompt: str,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
@ -233,7 +202,7 @@ class LLMService:
# Note: GPT-5 Responses API does not support max_tokens parameter
response = openai_client.responses.create(**kwargs)
response = await openai_client.responses.create(**kwargs)
result = LLMService._extract_responses_api_content(response)
else:
@ -252,39 +221,33 @@ class LLMService:
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = openai_client.chat.completions.create(**kwargs)
response = await openai_client.chat.completions.create(**kwargs)
result = response.choices[0].message.content.strip()
else:
# Gemini API call (existing logic)
model = LLMService.get_model(model_name)
generation_config = {
"temperature": temperature,
}
# New Google GenAI SDK - async call
config = genai.types.GenerateContentConfig(
temperature=temperature,
)
if max_tokens:
generation_config["max_output_tokens"] = max_tokens
config.max_output_tokens = max_tokens
# If system prompt is provided, use it to create a structured chat
# Prepare the prompt - combine system prompt with user prompt if needed
if system_prompt:
# For Gemini models, system prompts need to be passed as part of the user prompt
# as Gemini API doesn't support 'system' role directly
response = model.generate_content(
[
{"role": "user", "parts": [f"System: {system_prompt}\n\nUser: {prompt}"]}
],
generation_config=genai.types.GenerationConfig(**generation_config)
)
combined_prompt = f"System: {system_prompt}\n\nUser: {prompt}"
else:
# Otherwise use the standard prompt-only approach
response = model.generate_content(
prompt,
generation_config=genai.types.GenerationConfig(**generation_config)
)
combined_prompt = prompt
# If successful, extract and return the response
result = LLMService._extract_text_from_response(response)
# Make async call to new GenAI SDK
response = await gemini_client.aio.models.generate_content(
model=actual_model,
contents=combined_prompt,
config=config
)
# Extract text from new SDK response
result = LLMService._extract_text_from_new_genai_response(response)
if attempt > 0:
logger.info(f"LLM content generation succeeded on attempt {attempt_num}/{max_retries}")
@ -308,7 +271,7 @@ class LLMService:
# Wait before retrying (exponential backoff)
wait_time = 2 ** attempt # 1s, 2s, 4s
logger.info(f"Retryable error detected. Waiting {wait_time} seconds before retry {attempt_num + 1}/{max_retries}")
time.sleep(wait_time)
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"Retryable error detected but max retries ({max_retries}) reached")
@ -353,7 +316,7 @@ class LLMService:
raise LLMServiceError(error_msg)
@staticmethod
def generate_structured_response(
async def generate_structured_response(
prompt: str,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
@ -380,7 +343,7 @@ class LLMService:
Raises:
LLMServiceError: If there's an issue with generation or parsing
"""
response_text = LLMService.generate_content(
response_text = await LLMService.generate_content(
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
@ -393,7 +356,7 @@ class LLMService:
return LLMService.parse_json_response(response_text)
@staticmethod
def generate_structured_array(
async def generate_structured_array(
prompt: str,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
@ -420,7 +383,7 @@ class LLMService:
Raises:
LLMServiceError: If there's an issue with generation or parsing
"""
response_text = LLMService.generate_content(
response_text = await LLMService.generate_content(
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
@ -439,7 +402,7 @@ class LLMService:
return result
@staticmethod
def generate_multimodal_content(
async def generate_multimodal_content(
prompt: str,
image_paths: List[str],
temperature: float = 0.7,
@ -526,7 +489,7 @@ class LLMService:
# Note: GPT-5 Responses API does not support max_tokens parameter
response = openai_client.responses.create(**kwargs)
response = await openai_client.responses.create(**kwargs)
result = LLMService._extract_responses_api_content(response)
else:
@ -543,50 +506,59 @@ class LLMService:
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = openai_client.chat.completions.create(**kwargs)
response = await openai_client.chat.completions.create(**kwargs)
result = response.choices[0].message.content.strip()
else:
# Gemini multimodal API call (existing logic)
# Load and validate images
images = []
# New Google GenAI SDK - multimodal async call
config = genai.types.GenerateContentConfig(
temperature=temperature,
)
if max_tokens:
config.max_output_tokens = max_tokens
# Prepare multimodal content for new SDK
content_parts = []
# Add text prompt
content_parts.append(genai.types.Part.from_text(prompt))
# Add images
for image_path in image_paths:
try:
if not os.path.exists(image_path):
raise LLMServiceError(f"Image file not found: {image_path}")
# Load image using PIL
with Image.open(image_path) as img:
# Convert to RGB if necessary
if img.mode != 'RGB':
img = img.convert('RGB')
images.append(img.copy())
logger.debug(f"Successfully loaded image for Gemini: {image_path}")
# Read image data for new SDK
with open(image_path, 'rb') as img_file:
image_data = img_file.read()
# Determine MIME type from file extension
ext = os.path.splitext(image_path)[1].lower()
mime_type = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp'
}.get(ext, 'image/jpeg') # Default to JPEG
content_parts.append(genai.types.Part.from_bytes(image_data, mime_type=mime_type))
logger.debug(f"Successfully loaded image for new GenAI SDK: {image_path}")
except Exception as e:
raise LLMServiceError(f"Failed to load image {image_path}: {str(e)}")
model = LLMService.get_model(model_name)
generation_config = {
"temperature": temperature,
}
if max_tokens:
generation_config["max_output_tokens"] = max_tokens
# Create multimodal input - combine text prompt with images
content_parts = [prompt]
content_parts.extend(images)
response = model.generate_content(
content_parts,
generation_config=genai.types.GenerationConfig(**generation_config)
# Make async call to new GenAI SDK with multimodal content
response = await gemini_client.aio.models.generate_content(
model=actual_model,
contents=content_parts,
config=config
)
# Extract and return the response
result = LLMService._extract_text_from_response(response)
# Extract text from new SDK response
result = LLMService._extract_text_from_new_genai_response(response)
if attempt > 0:
logger.info(f"Multimodal content generation succeeded on attempt {attempt_num}/{max_retries}")
@ -610,7 +582,7 @@ class LLMService:
# Wait before retrying (exponential backoff)
wait_time = 2 ** attempt # 1s, 2s, 4s
logger.info(f"Retryable error detected. Waiting {wait_time} seconds before retry {attempt_num + 1}/{max_retries}")
time.sleep(wait_time)
await asyncio.sleep(wait_time)
continue
else:
logger.error(f"Retryable error detected but max retries ({max_retries}) reached")
@ -623,7 +595,7 @@ class LLMService:
raise LLMServiceError(f"Error generating multimodal content: {str(last_error)}")
@staticmethod
def generate_contextual_response(
async def generate_contextual_response(
prompt: str,
conversation_context: List[Dict[str, Any]],
temperature: float = 0.7,
@ -704,7 +676,6 @@ class LLMService:
try:
if provider == 'openai':
# OpenAI contextual multimodal API call
import base64
# Convert PIL images to base64 for OpenAI API
image_content = []
@ -743,7 +714,7 @@ class LLMService:
# Note: GPT-5 Responses API does not support max_tokens parameter
response = openai_client.responses.create(**kwargs)
response = await openai_client.responses.create(**kwargs)
result = LLMService._extract_responses_api_content(response)
else:
@ -760,30 +731,42 @@ class LLMService:
if max_tokens:
kwargs["max_tokens"] = max_tokens
response = openai_client.chat.completions.create(**kwargs)
response = await openai_client.chat.completions.create(**kwargs)
result = response.choices[0].message.content.strip()
else:
# Gemini contextual multimodal API call (existing logic)
# Create content parts with text and images
content_parts = [full_prompt]
content_parts.extend(image_parts)
model = LLMService.get_model(model_name)
generation_config = {
"temperature": temperature,
}
if max_tokens:
generation_config["max_output_tokens"] = max_tokens
response = model.generate_content(
content_parts,
generation_config=genai.types.GenerationConfig(**generation_config)
# New Google GenAI SDK - contextual multimodal async call
config = genai.types.GenerateContentConfig(
temperature=temperature,
)
result = LLMService._extract_text_from_response(response)
if max_tokens:
config.max_output_tokens = max_tokens
# Prepare content parts for new SDK
new_content_parts = []
# Add text prompt
new_content_parts.append(genai.types.Part.from_text(full_prompt))
# Convert PIL image parts to new SDK format
for img in image_parts:
# Convert PIL image to bytes
buffer = io.BytesIO()
img.save(buffer, format='PNG')
image_data = buffer.getvalue()
# Add as image part in new SDK format
new_content_parts.append(genai.types.Part.from_bytes(image_data, mime_type='image/png'))
# Make async call to new GenAI SDK
response = await gemini_client.aio.models.generate_content(
model=actual_model,
contents=new_content_parts,
config=config
)
result = LLMService._extract_text_from_new_genai_response(response)
if attempt > 0:
logger.info(f"Contextual multimodal generation succeeded on attempt {attempt_num}/{max_retries}")
@ -828,7 +811,7 @@ class LLMService:
else:
# No images, use standard text generation
print(f"📝 Using text-only generation (no visual context)")
return LLMService.generate_content(
return await LLMService.generate_content(
prompt=full_prompt,
temperature=temperature,
max_tokens=max_tokens,

View file

@ -44,7 +44,7 @@ class PersonaExportService:
demographics, goals, personality traits, scenarios, and additional data.
"""
def generate_profile_markdown(
async def generate_profile_markdown(
self,
persona_data: Dict[str, Any],
llm_model: str = "gpt-4.1",
@ -83,7 +83,7 @@ class PersonaExportService:
full_prompt = f"{self.prompt_template}\n\n## Persona Data\n```json\n{persona_json}\n```"
# Generate markdown using LLM
markdown_content = self.llm_service.generate_content(
markdown_content = await LLMService.generate_content(
prompt=full_prompt,
model_name=llm_model,
temperature=temperature,

View file

@ -131,7 +131,7 @@ class PersonaModificationService:
return True
@staticmethod
def modify_persona(
async def modify_persona(
persona_id: str,
modification_prompt: str,
llm_model: str = 'gemini-2.5-pro',
@ -158,7 +158,7 @@ class PersonaModificationService:
"""
try:
# Fetch the original persona
original_persona = Persona.find_by_id(persona_id)
original_persona = await Persona.find_by_id(persona_id)
if not original_persona:
raise PersonaModificationError(f"Persona with ID {persona_id} not found")
@ -182,7 +182,7 @@ class PersonaModificationService:
logger.info(f"Attempting persona modification (attempt {attempt + 1}/{max_retries})")
# Call LLM service
llm_response = LLMService.generate_content(
llm_response = await LLMService.generate_content(
prompt=final_prompt,
temperature=0.3, # Lower temperature for consistent modifications
model_name=llm_model,
@ -212,7 +212,7 @@ class PersonaModificationService:
)
# Update the persona in the database
success = Persona.update(persona_id, modified_persona_data)
success = await Persona.update(persona_id, modified_persona_data)
if not success:
raise PersonaModificationError("Failed to update persona in database")

View file

@ -13,7 +13,7 @@ from typing import Dict, Set, Any, Optional
from datetime import datetime
from flask import request, current_app
from flask_socketio import emit, join_room, leave_room, disconnect
from .extensions import socketio # Import singleton SocketIO instance
from .extensions import socketio_server as socketio # Import singleton SocketIO instance
from flask_jwt_extended import decode_token
from functools import wraps
import json

View file

@ -0,0 +1,398 @@
"""
Async WebSocket Manager for Synthetic Society
Handles WebSocket connections, room management, and real-time event broadcasting.
Uses python-socketio AsyncServer for native Quart/ASGI compatibility.
"""
import logging
import os
import threading
import asyncio
from typing import Dict, Set, Any, Optional
from datetime import datetime
from .extensions import socketio_server as sio
from flask_jwt_extended import decode_token
# Set up logging
logger = logging.getLogger(__name__)
class AsyncWebSocketManager:
"""Manages WebSocket connections and rooms for focus group sessions using AsyncServer."""
def __init__(self):
# Use singleton SocketIO AsyncServer instance
self.sio = sio
self.focus_group_rooms: Dict[str, Set[str]] = {} # focus_group_id -> set of session_ids
self.user_sessions: Dict[str, Dict[str, Any]] = {} # session_id -> user info
# Register SocketIO event handlers
self._register_handlers()
def _register_handlers(self):
"""Register all WebSocket event handlers."""
@self.sio.event
async def connect(sid, environ, auth):
"""Handle WebSocket connection."""
import os
import threading
process_id = os.getpid()
thread_id = threading.get_ident()
print(f"🔌 ASYNC PROCESS DEBUG - WebSocket connection attempt from {sid}")
print(f"🔌 ASYNC PROCESS DEBUG - Connection handler PID: {process_id}, Thread: {thread_id}")
logger.info(f"WebSocket connection attempt from {sid}")
# Validate JWT token from auth data (temporarily allow without token for testing)
if not auth or 'token' not in auth:
logger.warning(f"WebSocket connection without auth token - allowing for testing")
# Temporarily allow connections without tokens for testing
self.user_sessions[sid] = {
'user_id': 'test_user', # Default user for testing
'connected_at': datetime.utcnow(),
'focus_groups': set()
}
logger.info(f"WebSocket connected without auth - Session: {sid}")
await self.sio.emit('connected', {'status': 'success', 'session_id': sid}, to=sid)
return True
try:
# Decode and validate JWT token with better error handling
token = auth['token']
import jwt
import os
# Validate token format first
if not token or not isinstance(token, str):
raise ValueError("Invalid token format")
# Check if token has the right number of segments (should have 3 parts separated by dots)
token_parts = token.split('.')
if len(token_parts) != 3:
raise ValueError(f"Invalid JWT format: expected 3 segments, got {len(token_parts)}")
# Get JWT secret from environment (same as our Quart JWT system)
jwt_secret = os.environ.get('SECRET_KEY', 'your-secret-key-for-sessions-and-tokens')
try:
decoded_token = jwt.decode(token, jwt_secret, algorithms=['HS256'])
user_id = decoded_token.get('sub')
if not user_id:
raise ValueError("No user ID in token")
except Exception as jwt_error:
logger.warning(f"JWT decode failed: {jwt_error}")
logger.warning(f"Token format: {len(token_parts)} segments, first 20 chars: {token[:20]}...")
# During migration, allow connection with test user instead of disconnecting
logger.info(f"Allowing WebSocket connection with test user due to JWT transition")
self.user_sessions[sid] = {
'user_id': 'test_user', # Default user during JWT transition
'connected_at': datetime.utcnow(),
'focus_groups': set()
}
await self.sio.emit('connected', {'status': 'success', 'session_id': sid, 'auth': 'fallback'}, to=sid)
return True
# Store user session info
self.user_sessions[sid] = {
'user_id': user_id,
'connected_at': datetime.utcnow(),
'focus_groups': set()
}
logger.info(f"WebSocket connected - Session: {sid}, User: {user_id}")
# Emit connection success
await self.sio.emit('connected', {'status': 'success', 'session_id': sid}, to=sid)
except Exception as e:
logger.error(f"Connection authentication failed: {e}")
await self.sio.disconnect(sid)
return False
@self.sio.event
async def disconnect(sid):
"""Handle WebSocket disconnection."""
session_id = sid
if session_id in self.user_sessions:
user_info = self.user_sessions[session_id]
user_id = user_info['user_id']
# Leave all focus group rooms
for focus_group_id in user_info['focus_groups'].copy():
await self._leave_focus_group_room(session_id, focus_group_id)
# Clean up session
del self.user_sessions[session_id]
logger.info(f"WebSocket disconnected - Session: {session_id}, User: {user_id}")
@self.sio.event
async def join_focus_group(sid, data):
"""Handle joining a focus group room."""
session_id = sid
if session_id not in self.user_sessions:
await self.sio.emit('error', {'message': 'Session not authenticated'}, to=sid)
return
focus_group_id = data.get('focus_group_id')
if not focus_group_id:
await self.sio.emit('error', {'message': 'Focus group ID required'}, to=sid)
return
# Join the room
success = await self._join_focus_group_room(session_id, focus_group_id)
if success:
await self.sio.emit('joined_focus_group', {
'focus_group_id': focus_group_id,
'status': 'success'
}, to=sid)
logger.info(f"User joined focus group room - Session: {session_id}, Group: {focus_group_id}")
else:
await self.sio.emit('error', {'message': 'Failed to join focus group'}, to=sid)
@self.sio.event
async def leave_focus_group(sid, data):
"""Handle leaving a focus group room."""
session_id = sid
if session_id not in self.user_sessions:
await self.sio.emit('error', {'message': 'Session not authenticated'}, to=sid)
return
focus_group_id = data.get('focus_group_id')
if not focus_group_id:
await self.sio.emit('error', {'message': 'Focus group ID required'}, to=sid)
return
# Leave the room
success = await self._leave_focus_group_room(session_id, focus_group_id)
if success:
await self.sio.emit('left_focus_group', {
'focus_group_id': focus_group_id,
'status': 'success'
}, to=sid)
logger.info(f"User left focus group room - Session: {session_id}, Group: {focus_group_id}")
async def _join_focus_group_room(self, session_id: str, focus_group_id: str) -> bool:
"""Join a user session to a focus group room."""
try:
# Add to SocketIO room
await self.sio.enter_room(session_id, focus_group_id)
# Track in our data structures
if focus_group_id not in self.focus_group_rooms:
self.focus_group_rooms[focus_group_id] = set()
self.focus_group_rooms[focus_group_id].add(session_id)
self.user_sessions[session_id]['focus_groups'].add(focus_group_id)
return True
except Exception as e:
logger.error(f"Failed to join focus group room: {e}")
return False
async def _leave_focus_group_room(self, session_id: str, focus_group_id: str) -> bool:
"""Remove a user session from a focus group room."""
try:
# Leave SocketIO room
await self.sio.leave_room(session_id, focus_group_id)
# Clean up tracking
if focus_group_id in self.focus_group_rooms:
self.focus_group_rooms[focus_group_id].discard(session_id)
# Remove room if empty
if not self.focus_group_rooms[focus_group_id]:
del self.focus_group_rooms[focus_group_id]
if session_id in self.user_sessions:
self.user_sessions[session_id]['focus_groups'].discard(focus_group_id)
return True
except Exception as e:
logger.error(f"Failed to leave focus group room: {e}")
return False
async def emit_to_focus_group(self, focus_group_id: str, event: str, data: Any, include_sender: bool = True, sender_session_id: Optional[str] = None):
"""Emit an event to all users in a focus group room."""
process_id = os.getpid()
thread_id = threading.get_ident()
print(f"🔔 ASYNC PROCESS DEBUG - emit_to_focus_group called: {event} for focus group {focus_group_id}")
print(f"🔔 ASYNC PROCESS DEBUG - PID: {process_id}, Thread: {thread_id}")
print(f"🔔 Focus group rooms: {list(self.focus_group_rooms.keys())}")
try:
if focus_group_id not in self.focus_group_rooms:
print(f"🔔 ASYNC ERROR: No active sessions for focus group {focus_group_id}")
logger.debug(f"No active sessions for focus group {focus_group_id}")
return
room_name = focus_group_id
room_sessions = self.focus_group_rooms[focus_group_id].copy()
print(f"🔔 ASYNC Room {focus_group_id} has {len(room_sessions)} tracked sessions: {list(room_sessions)}")
# Clean up stale sessions
active_sessions = []
stale_sessions = []
for session_id in room_sessions:
if session_id in self.user_sessions:
active_sessions.append(session_id)
else:
stale_sessions.append(session_id)
self.focus_group_rooms[focus_group_id].discard(session_id)
if stale_sessions:
print(f"🔔 ASYNC Cleaned up {len(stale_sessions)} stale sessions: {stale_sessions}")
print(f"🔔 ASYNC Room {focus_group_id} has {len(active_sessions)} ACTIVE sessions: {active_sessions}")
if not active_sessions:
print(f"🔔 ASYNC ERROR: No active sessions remaining for focus group {focus_group_id} after cleanup")
return
# Prepare the event data
event_data = {
'focus_group_id': focus_group_id,
'timestamp': datetime.utcnow().isoformat(),
**data
}
if include_sender or not sender_session_id:
# Send to all users in the room using AsyncServer
print(f"🔔 ASYNC Emitting '{event}' to room {room_name} with data keys: {list(event_data.keys())}")
await self.sio.emit(event, event_data, room=room_name)
print(f"🔔 ASYNC Successfully emitted '{event}' to focus group {focus_group_id} ({len(active_sessions)} active users)")
logger.debug(f"Emitted '{event}' to focus group {focus_group_id} ({len(active_sessions)} active users)")
else:
# Send to all users except the sender
for session_id in active_sessions:
if session_id != sender_session_id:
await self.sio.emit(event, event_data, to=session_id)
logger.debug(f"Emitted '{event}' to focus group {focus_group_id} (excluding sender)")
except Exception as e:
logger.error(f"Failed to emit to focus group {focus_group_id}: {e}")
print(f"🔔 ASYNC ERROR: Failed to emit to focus group {focus_group_id}: {e}")
async def emit_message_update(self, focus_group_id: str, message_data: Dict[str, Any], sender_session_id: Optional[str] = None):
"""Emit a new message to focus group participants."""
await self.emit_to_focus_group(
focus_group_id,
'message_update',
{'message': message_data},
include_sender=True,
sender_session_id=sender_session_id
)
async def emit_ai_status_update(self, focus_group_id: str, status_data: Dict[str, Any]):
"""Emit AI status change to focus group participants."""
await self.emit_to_focus_group(
focus_group_id,
'ai_status_update',
{'status': status_data}
)
async def emit_moderator_status_update(self, focus_group_id: str, moderator_status: Dict[str, Any]):
"""Emit moderator status change to focus group participants."""
await self.emit_to_focus_group(
focus_group_id,
'moderator_status_update',
{'moderator_status': moderator_status}
)
async def emit_theme_update(self, focus_group_id: str, theme_data: Dict[str, Any], action: str = 'added'):
"""Emit theme update to focus group participants."""
await self.emit_to_focus_group(
focus_group_id,
'theme_update',
{'theme': theme_data, 'action': action}
)
async def emit_analytics_update(self, focus_group_id: str, analytics_data: Dict[str, Any]):
"""Emit analytics update to focus group participants."""
await self.emit_to_focus_group(
focus_group_id,
'analytics_update',
{'analytics': analytics_data}
)
async def emit_conversation_state_update(self, focus_group_id: str, state_data: Dict[str, Any]):
"""Emit conversation state update to focus group participants."""
await self.emit_to_focus_group(
focus_group_id,
'conversation_state_update',
{'state': state_data}
)
def get_room_info(self, focus_group_id: str) -> Dict[str, Any]:
"""Get information about a focus group room."""
if focus_group_id not in self.focus_group_rooms:
return {'active_sessions': 0, 'users': []}
sessions = self.focus_group_rooms[focus_group_id]
users = []
for session_id in sessions:
if session_id in self.user_sessions:
user_info = self.user_sessions[session_id]
users.append({
'session_id': session_id,
'user_id': user_info['user_id'],
'connected_at': user_info['connected_at'].isoformat()
})
return {
'active_sessions': len(sessions),
'users': users
}
def get_connection_stats(self) -> Dict[str, Any]:
"""Get overall connection statistics."""
return {
'total_sessions': len(self.user_sessions),
'total_focus_groups': len(self.focus_group_rooms),
'focus_group_details': {
fg_id: len(sessions) for fg_id, sessions in self.focus_group_rooms.items()
}
}
# Global WebSocket manager instance
websocket_manager: Optional[AsyncWebSocketManager] = None
def init_async_websocket_manager() -> AsyncWebSocketManager:
"""Initialize the global async WebSocket manager."""
global websocket_manager
websocket_manager = AsyncWebSocketManager()
return websocket_manager
def get_async_websocket_manager() -> Optional[AsyncWebSocketManager]:
"""Get the global async WebSocket manager instance."""
return websocket_manager
# Async emit function for use throughout the codebase
async def emit_websocket_event(event: str, data: dict, room: str | None = None) -> None:
"""
Async WebSocket event emission function.
Args:
event: Event name
data: Event data
room: Room to emit to (focus group ID)
"""
try:
if websocket_manager and room:
await websocket_manager.emit_to_focus_group(room, event, data)
print(f"🔔 ASYNC - Successfully emitted {event} to room {room}")
else:
print(f"🔔 ASYNC - WebSocket manager not available or no room specified")
except Exception as e:
print(f"🔔 ASYNC ERROR emitting WebSocket event {event}: {e}")
logger.exception(f"Error emitting WebSocket event {event}: {e}")

View file

@ -7,7 +7,7 @@ flask-jwt-extended
bcrypt
pydantic
hypercorn
google-generativeai
google-genai
openai
requests
llama-cloud-services

View file

@ -2,18 +2,13 @@ import os
import subprocess
import sys
import tempfile
# GPT-5 fix: Monkey patch BEFORE any other imports that might use threads/sockets
try:
import eventlet
eventlet.monkey_patch()
print("✅ GPT-5 FIX: Early eventlet monkey patching applied")
except ImportError:
print("⚠️ Eventlet not available for early monkey patching")
import asyncio
import signal
import threading
# Set up temp directories FIRST, before any imports that might use temp files
def setup_early_temp_directories():
"""Set up temp directories before Flask imports."""
"""Set up temp directories before Quart imports."""
backend_dir = os.path.dirname(os.path.abspath(__file__))
temp_dir = os.path.join(backend_dir, 'temp')
@ -45,59 +40,266 @@ setup_early_temp_directories()
from app import create_app
from app.models.user import User
# Create the Flask app
flask_app = create_app()
# Create the ASGI app (which wraps Quart + SocketIO)
asgi_app = create_app()
# Extract the Quart app from the ASGI wrapper
quart_app = asgi_app.quart_app
# Initialize database on startup
def initialize_database():
async def initialize_database():
# Create default user if it doesn't exist
User.create_default_user()
await User.create_default_user()
# Call initialization immediately
with flask_app.app_context():
initialize_database()
# Use the ASGI app for the server
app = asgi_app
# For SocketIO, we need to use the socketio app directly
app = flask_app
async def startup():
"""Initialize the application on startup."""
await initialize_database()
# Signal handlers are now managed within the asyncio event loop for proper async shutdown
async def run_server():
"""Run the server with enhanced shutdown pattern and diagnostics."""
import hypercorn.asyncio
from hypercorn import Config
import faulthandler
import threading
# Enable fault handler for debugging
faulthandler.register(signal.SIGUSR1 if hasattr(signal, 'SIGUSR1') else signal.SIGBREAK)
print("Starting Quart + SocketIO app with hypercorn ASGI server...")
print("📡 WebSocket functionality enabled")
print("🤖 AI Runner active for autonomous conversations")
print("⚡ All operations async and non-blocking")
print("🛑 Use Ctrl-C for graceful shutdown")
print("🔍 Debug: Send SIGUSR1 for stack dump if it hangs")
print("Started Semblance back end service")
# Create hypercorn config with debug settings
config = Config()
config.bind = ["0.0.0.0:5137"]
config.debug = False
config.use_reloader = False
config.loglevel = "info" # Enable more logging
config.lifespan = "on" # Ensure lifespan is enabled
config.startup_timeout = 60
config.shutdown_timeout = 5 # Shorter for faster shutdown
config.graceful_timeout = 3 # Shorter for faster shutdown
config.keep_alive_timeout = 2 # Cut lingering connections faster
# Add startup tasks to Quart app
@quart_app.before_serving
async def startup_task():
await startup()
# Add shutdown tasks to Quart app with timeout protection
@quart_app.after_serving
async def shutdown_task():
print("🛑 Quart app shutting down...")
# 1. Shutdown SocketIO first to stop WebSocket tasks
try:
print("🔌 Shutting down SocketIO...")
from app.extensions import socketio_server
# First disconnect all clients
print("🔌 Disconnecting all SocketIO clients...")
try:
# Get all sessions and disconnect them
for sid in list(socketio_server.manager.get_participants('/', None) or []):
try:
await socketio_server.disconnect(sid, namespace='/')
except Exception as disconnect_err:
print(f"Error disconnecting {sid}: {disconnect_err}")
print("✅ All SocketIO clients disconnected")
except Exception as disconnect_err:
print(f"Error during client disconnection: {disconnect_err}")
# Then shutdown the server
try:
if hasattr(socketio_server, 'shutdown'):
await socketio_server.shutdown()
print("✅ SocketIO server shutdown complete")
else:
# Try shutting down the underlying engine.io server
if hasattr(socketio_server, 'eio') and hasattr(socketio_server.eio, 'shutdown'):
await socketio_server.eio.shutdown()
print("✅ EngineIO shutdown complete")
else:
print("⚠️ No shutdown method available - relying on task cancellation")
except Exception as shutdown_err:
print(f"SocketIO shutdown error: {shutdown_err}")
except Exception as e:
print(f"⚠️ Overall SocketIO shutdown error: {e}")
# 2. Shutdown AI Runner
try:
await asyncio.wait_for(
asyncio.to_thread(shutdown_ai_runner_safe),
timeout=3.0 # Shorter timeout
)
print("✅ AI Runner shutdown complete")
except asyncio.TimeoutError:
print("⏱️ AI Runner shutdown timed out; continuing")
except Exception as e:
print(f"⚠️ AI Runner shutdown error: {e}")
# 3. Close database connections
try:
print("🗄️ Closing database connections...")
await close_database_connections()
print("✅ Database connections closed")
except Exception as e:
print(f"⚠️ Database close error: {e}")
# 4. Cancel any remaining engineio/socketio tasks
await cancel_socketio_tasks()
def shutdown_ai_runner_safe():
"""Safe AI Runner shutdown that doesn't block the main event loop."""
try:
from app.services.ai_runner_service import shutdown_ai_runner
shutdown_ai_runner()
except Exception as e:
print(f"Error in AI Runner shutdown: {e}")
async def close_database_connections():
"""Close all database connections to stop PyMongo background threads."""
try:
# Close the global Motor client singleton
from app.db import close_db_connections
await asyncio.to_thread(close_db_connections)
except Exception as e:
print(f"Database connection close error: {e}")
async def cancel_socketio_tasks():
"""Cancel any remaining SocketIO/EngineIO tasks."""
try:
print("🧹 Cancelling remaining SocketIO tasks...")
cancelled_count = 0
for task in list(asyncio.all_tasks()):
if task.done() or task is asyncio.current_task():
continue
# Check if this is a SocketIO/EngineIO related task
try:
coro = task.get_coro()
module = getattr(coro, '__module__', '')
qualname = getattr(coro, '__qualname__', '')
full_name = f'{module}.{qualname}'
socketio_patterns = [
'engineio.async_socket',
'engineio.async_server',
'socketio.async_server',
'AsyncSocket._websocket_handler',
'AsyncSocket._send_ping',
'AsyncServer._service_task'
]
if any(pattern in full_name for pattern in socketio_patterns):
task.cancel()
cancelled_count += 1
except Exception:
pass # Ignore task inspection errors
if cancelled_count > 0:
print(f"🧹 Cancelled {cancelled_count} SocketIO tasks")
# Give cancellations time to complete
await asyncio.sleep(0.5)
except Exception as e:
print(f"Task cancellation error: {e}")
# Create shutdown event for Hypercorn
shutdown_event = asyncio.Event()
_second_signal = False # For double Ctrl-C force exit
def _raise_shutdown():
"""Signal handler with force-exit on second Ctrl-C."""
nonlocal _second_signal
if _second_signal:
print("🔥 Second Ctrl-C - forcing exit!")
import os
os._exit(1)
_second_signal = True
print("\\n🛑 Shutdown signal received...")
print("📊 Active threads:", [t.name for t in threading.enumerate()])
shutdown_event.set()
# Watchdog to ensure serve() returns
async def shutdown_watchdog():
await asyncio.sleep(4) # Wait 4 seconds (should be enough with enhanced cleanup)
if not shutdown_event.is_set():
return # Already completed
print("⏱️ Shutdown taking too long - dumping diagnostics...")
# Dump asyncio tasks
print("\\n=== Active asyncio tasks ===")
for task in asyncio.all_tasks():
if not task.done():
print(f"- {task.get_name()}: {task}")
# Dump threads
print("\\n=== Active threads ===")
for thread in threading.enumerate():
print(f"- {thread.name}: daemon={thread.daemon}, alive={thread.is_alive()}")
print("🔥 Use second Ctrl-C to force exit if needed")
asyncio.create_task(shutdown_watchdog(), name="shutdown-watchdog")
# Register signal handlers
loop = asyncio.get_running_loop()
try:
loop.add_signal_handler(signal.SIGINT, _raise_shutdown)
loop.add_signal_handler(signal.SIGTERM, _raise_shutdown)
except NotImplementedError:
# Windows fallback
signal.signal(signal.SIGINT, lambda *_: loop.call_soon_threadsafe(_raise_shutdown))
signal.signal(signal.SIGTERM, lambda *_: loop.call_soon_threadsafe(_raise_shutdown))
# Let Hypercorn handle shutdown gracefully with the trigger
print("🔍 Debug: Starting Hypercorn serve with shutdown_trigger...")
await hypercorn.asyncio.serve(
asgi_app,
config,
shutdown_trigger=shutdown_event.wait,
)
print("🛑 Hypercorn server stopped")
if __name__ == '__main__':
# Check if we have SocketIO support and run with eventlet
try:
import eventlet
print("Starting Flask-SocketIO app with eventlet...")
print("Started Semblance back end service")
# Run with SocketIO support - use the socketio instance to run the app
flask_app.socketio.run(
flask_app,
host='0.0.0.0',
port=5137,
debug=False,
use_reloader=False,
allow_unsafe_werkzeug=True
)
except ImportError as e:
print("Eventlet not found. Installing it...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "eventlet", "flask-socketio"])
# Try again
# Check if hypercorn is available
try:
import eventlet
print("Started Semblance back end service")
flask_app.socketio.run(
flask_app,
host='0.0.0.0',
port=5137,
debug=False,
use_reloader=False,
allow_unsafe_werkzeug=True
)
except Exception as e:
print(f"Failed to start with SocketIO: {e}")
print("Falling back to regular Flask...")
print("Started Semblance back end service")
flask_app.run(host='0.0.0.0', port=5137, debug=False)
import hypercorn.asyncio
from hypercorn import Config
except ImportError as e:
print("❌ Error: hypercorn is required for WebSocket functionality!")
print("Please install hypercorn: pip install hypercorn")
print(f"ImportError: {e}")
sys.exit(1)
# Run the server with proper shutdown handling
asyncio.run(run_server())
except KeyboardInterrupt:
print("\nShutting down...")
sys.exit(0)
print("\\n🛑 Keyboard interrupt - shutting down...")
except Exception as e:
print(f"❌ Unexpected error: {e}")
try:
from app.services.ai_runner_service import shutdown_ai_runner
shutdown_ai_runner()
except:
pass
sys.exit(1)
print("👋 Server stopped")

View file

@ -1689,18 +1689,24 @@ const FocusGroupSession = () => {
const response = await focusGroupAiApi.generateKeyThemes(id);
if (response.data && response.data.themes) {
setThemeGenerationComplete(true);
toastService.success(`Generated ${response.data.themes.length} key themes`, {
description: "New themes have been added to the analysis."
});
// Update themes state
// Update themes state immediately
setThemes(prevThemes => [...prevThemes, ...response.data.themes]);
// Allow progress bar to animate for at least 3 seconds before completing
setTimeout(() => {
setThemeGenerationComplete(true);
toastService.success(`Generated ${response.data.themes.length} key themes`, {
description: "New themes have been added to the analysis."
});
}, 3000);
} else {
setThemeGenerationComplete(true);
toastService.warning("No new themes were generated", {
description: "Try again when the discussion has more content."
});
// Allow progress bar to animate for at least 3 seconds before completing
setTimeout(() => {
setThemeGenerationComplete(true);
toastService.warning("No new themes were generated", {
description: "Try again when the discussion has more content."
});
}, 3000);
}
} catch (error) {
console.error('Error generating key themes:', error);