458 lines
No EOL
15 KiB
Python
458 lines
No EOL
15 KiB
Python
"""
|
|
MongoDB Utilities for Netflix Chatbot
|
|
|
|
This module provides utility functions for interacting with MongoDB in the Netflix chatbot application.
|
|
It includes functions for connecting to MongoDB, and managing users, conversations, and messages.
|
|
"""
|
|
|
|
import pymongo
|
|
import logging
|
|
from datetime import datetime
|
|
import uuid
|
|
from typing import Dict, List, Optional, Any, Union
|
|
from bson.objectid import ObjectId
|
|
import json
|
|
|
|
# Configure logging
# Runs at import time: sets the root logger to INFO and attaches two handlers,
# a stream handler and a file handler writing to 'mongodb.log'.
# NOTE(review): 'mongodb.log' is a relative path, so where the file lands
# depends on the process working directory — confirm that is intended.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('mongodb.log')
    ]
)
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
# MongoDB connection information
# NOTE(review): the URI embeds a username and password directly in source;
# consider loading it from configuration or the environment instead.
MONGO_URI = "mongodb://netflix:netflix@localhost:27017"
DB_NAME = "netflix_chatbot"

# Collection names
USERS_COLLECTION = "users"
CONVERSATIONS_COLLECTION = "conversations"
MESSAGES_COLLECTION = "messages"

# Global MongoDB client
# Both globals start as None and are filled in lazily by get_db().
mongo_client = None
db = None
|
|
|
|
def get_db():
    """Return the shared MongoDB database handle, connecting on first use.

    The first call creates a client for ``MONGO_URI``, verifies it with a
    ``ping`` command and caches both the client and the ``DB_NAME`` database
    in module globals. Later calls return the cached database.

    Returns:
        The cached database object for ``DB_NAME``.

    Raises:
        Exception: Any error raised while creating the client or pinging the
            server is logged and re-raised unchanged.
    """
    global mongo_client, db

    if mongo_client is None:
        client = None
        try:
            client = pymongo.MongoClient(MONGO_URI)
            client.admin.command('ping')  # Test connection
            database = client[DB_NAME]
        except Exception as e:
            logger.error(f"Failed to connect to MongoDB: {e}")
            # Do not leave a half-initialised client behind: release it so the
            # next call retries instead of returning a stale/None database.
            if client is not None:
                client.close()
            raise

        # Publish the globals only after the connection has been verified.
        mongo_client = client
        db = database
        logger.info("Successfully connected to MongoDB")

    return db
|
|
|
|
def close_connection():
    """Close the shared MongoDB client, if one is open.

    Resets both module globals so that no caller keeps using a database
    handle that belongs to a closed client. Does nothing when no client has
    been created.
    """
    global mongo_client, db

    # Explicit None check: only the presence of a client matters here.
    if mongo_client is not None:
        mongo_client.close()
        mongo_client = None
        db = None
        logger.info("MongoDB connection closed")
|
|
|
|
# User functions
|
|
def get_user_by_username(username: str) -> Optional[Dict]:
    """Look up a single user document by its ``username`` field.

    Returns:
        The matching document, or None when no user matches or when the
        lookup raises (the error is logged).
    """
    try:
        users = get_db()[USERS_COLLECTION]
        return users.find_one({"username": username})
    except Exception as err:
        logger.error(f"Error getting user by username: {err}")
        return None
|
|
|
|
def create_or_update_user(username: str, email: Optional[str] = None) -> Optional[str]:
    """Create a user, or refresh ``last_login`` when the user already exists.

    Args:
        username: Value of the ``username`` field used to find the user.
        email: Optional e-mail stored only when a new user is inserted. It is
            omitted from the document when falsy to avoid unique constraint
            issues on the email field.

    Returns:
        The user's ``_id`` as a string, or None when the operation failed.
    """
    try:
        users = get_db()[USERS_COLLECTION]
        now = datetime.utcnow()

        # Check if user exists
        existing_user = users.find_one({"username": username})
        if existing_user:
            # Update last login
            users.update_one(
                {"username": username},
                {"$set": {"last_login": now}}
            )
            return str(existing_user["_id"])

        # Create new user
        new_user = {
            "username": username,
            "created_at": now,
            "last_login": now
        }
        if email:
            new_user["email"] = email

        try:
            result = users.insert_one(new_user)
        except pymongo.errors.DuplicateKeyError as e:
            # Another writer may have inserted the same username between the
            # lookup and the insert; fall back to the document that now
            # exists. A duplicate on a different field (e.g. email) finds no
            # such document and yields None.
            logger.error(f"Error creating or updating user: {e}")
            existing_user = users.find_one({"username": username})
            return str(existing_user["_id"]) if existing_user else None

        return str(result.inserted_id)
    except Exception as e:
        logger.error(f"Error creating or updating user: {e}")
        return None
|
|
|
|
# Conversation functions
|
|
def get_conversation(session_id: str) -> Optional[Dict]:
    """Fetch the conversation document whose ``session_id`` matches.

    Returns:
        The matching document, or None when nothing matches or when the
        lookup raises (the error is logged).
    """
    try:
        conversations = get_db()[CONVERSATIONS_COLLECTION]
        return conversations.find_one({"session_id": session_id})
    except Exception as err:
        logger.error(f"Error getting conversation: {err}")
        return None
|
|
|
|
def get_conversation_by_id(conversation_id: str) -> Optional[Dict]:
    """Get a conversation by its MongoDB ``_id``.

    Args:
        conversation_id: String form of the conversation's ObjectId.

    Returns:
        The matching document, or None when the ID cannot be converted, no
        document matches, or the lookup raises (errors are logged).
    """
    # Convert string ID to ObjectId. Kept in its own try block so that only
    # genuine conversion failures are reported as such.
    try:
        obj_id = ObjectId(conversation_id)
    except Exception as e:
        logger.error(f"Error converting conversation ID to ObjectId: {e}")
        return None

    try:
        return get_db()[CONVERSATIONS_COLLECTION].find_one({"_id": obj_id})
    except Exception as e:
        logger.error(f"Error getting conversation by ID: {e}")
        return None
|
|
|
|
def get_user_conversations(user_id: str) -> List[Dict]:
    """List a user's non-deleted conversations, most recently updated first.

    Returns:
        The matching conversation documents sorted by ``last_updated``
        descending, or an empty list when the query raises (the error is
        logged).
    """
    try:
        # A conversation counts as live when is_deleted is absent or False.
        not_deleted = [
            {"is_deleted": {"$exists": False}},
            {"is_deleted": False},
        ]
        query = {"user_id": user_id, "$or": not_deleted}
        cursor = get_db()[CONVERSATIONS_COLLECTION].find(query)
        return list(cursor.sort("last_updated", pymongo.DESCENDING))
    except Exception as err:
        logger.error(f"Error getting user conversations: {err}")
        return []
|
|
|
|
def create_conversation(session_id: str, user_id: str, title: str = "New conversation") -> Optional[str]:
    """Return the conversation ID for a session, inserting one if needed.

    When a conversation with this ``session_id`` already exists its ID is
    returned and nothing is written.

    Returns:
        The conversation's ``_id`` as a string, or None when the operation
        raises (the error is logged).
    """
    try:
        conversations = get_db()[CONVERSATIONS_COLLECTION]

        # Reuse the conversation already bound to this session, if any.
        found = conversations.find_one({"session_id": session_id})
        if found:
            return str(found["_id"])

        inserted = conversations.insert_one({
            "session_id": session_id,
            "user_id": user_id,
            "title": title,
            "created_at": datetime.utcnow(),
            "last_updated": datetime.utcnow(),
        })
        return str(inserted.inserted_id)
    except Exception as err:
        logger.error(f"Error creating conversation: {err}")
        return None
|
|
|
|
def update_conversation_title(conversation_id: str, title: str) -> bool:
    """Update the title of a conversation and refresh ``last_updated``.

    Args:
        conversation_id: String form of the conversation's ObjectId.
        title: The new title to store.

    Returns:
        True when a conversation with that ID was found and updated; False
        when no conversation matched or the update raised (both are logged).
    """
    try:
        database = get_db()
        result = database[CONVERSATIONS_COLLECTION].update_one(
            {"_id": ObjectId(conversation_id)},
            {"$set": {"title": title, "last_updated": datetime.utcnow()}}
        )
        # update_one succeeds even when nothing matches the filter, so check
        # the match count instead of reporting success unconditionally.
        if result.matched_count == 0:
            logger.warning(f"No conversation found to update title: {conversation_id}")
            return False
        return True
    except Exception as e:
        logger.error(f"Error updating conversation title: {e}")
        return False
|
|
|
|
def update_conversation_timestamp(conversation_id: str) -> bool:
    """Set a conversation's ``last_updated`` field to the current UTC time.

    Args:
        conversation_id: String form of the conversation's ObjectId.

    Returns:
        True when a conversation with that ID was found and updated; False
        when no conversation matched or the update raised (both are logged).
    """
    try:
        database = get_db()
        result = database[CONVERSATIONS_COLLECTION].update_one(
            {"_id": ObjectId(conversation_id)},
            {"$set": {"last_updated": datetime.utcnow()}}
        )
        # update_one succeeds even when nothing matches the filter, so check
        # the match count instead of reporting success unconditionally.
        if result.matched_count == 0:
            logger.warning(f"No conversation found to update timestamp: {conversation_id}")
            return False
        return True
    except Exception as e:
        logger.error(f"Error updating conversation timestamp: {e}")
        return False
|
|
|
|
# Message functions
|
|
def add_message(conversation_id: str, role: str, content: str,
                sources: Optional[List] = None, reasoning: Optional[List] = None,
                images: Optional[List] = None) -> Optional[str]:
    """Store one message for a conversation and bump the conversation timestamp.

    ``sources``, ``reasoning`` and ``images`` are attached only when truthy.
    Each is round-tripped through JSON with ``serialize_custom_objects`` as
    the fallback encoder before being stored.

    Returns:
        The inserted message's ``_id`` as a string, or None when anything
        raises (the error is logged).
    """
    try:
        database = get_db()

        document = {
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow(),
        }

        # Optional payloads, handled uniformly and in a fixed order.
        optional_fields = (
            ("sources", sources),
            ("reasoning", reasoning),
            ("images", images),
        )
        for field_name, value in optional_fields:
            if not value:
                continue
            encoded = json.dumps(value, default=serialize_custom_objects)
            document[field_name] = json.loads(encoded)

        inserted = database[MESSAGES_COLLECTION].insert_one(document)

        # Keep the parent conversation's last_updated in step with its messages.
        update_conversation_timestamp(conversation_id)

        return str(inserted.inserted_id)
    except Exception as err:
        logger.error(f"Error adding message: {err}")
        return None
|
|
|
|
def serialize_custom_objects(obj):
    """Convert an object JSON cannot encode into something it can.

    Intended as the ``default=`` hook for ``json.dumps``. The rules, in order:

    1. Objects with a ``__dict__`` whose class name ends in ``ReasoningStep``
       become a dict with a ``type`` key (the class name) plus whichever of
       ``action``, ``action_input``, ``observation``, ``response`` and
       ``thought`` the object has.
    2. Any other object with a ``__dict__`` becomes a dict of its public,
       non-callable instance attributes.
    3. Objects without a ``__dict__`` but with a ``content`` attribute become
       ``str(obj.content)``.
    4. Everything else becomes ``str(obj)``, or a placeholder string when
       ``str()`` itself raises.

    Args:
        obj: The object to convert.

    Returns:
        A dict or a string, as described above.
    """
    if hasattr(obj, '__dict__'):
        class_name = obj.__class__.__name__

        # For ActionReasoningStep, ObservationReasoningStep, etc.
        if class_name.endswith('ReasoningStep'):
            result = {'type': class_name}
            # Copy only the attributes this particular step type carries.
            step_fields = ('action', 'action_input', 'observation', 'response', 'thought')
            for field in step_fields:
                if hasattr(obj, field):
                    result[field] = getattr(obj, field)
            return result

        # For other objects with __dict__
        return {k: v for k, v in obj.__dict__.items()
                if not k.startswith('_') and not callable(v)}

    # For objects with content property
    if hasattr(obj, 'content'):
        return str(obj.content)

    # For objects with string representation
    try:
        return str(obj)
    except Exception:
        # Only ordinary errors are swallowed; KeyboardInterrupt, SystemExit
        # and other BaseExceptions propagate.
        return f"<Unserializable object of type {type(obj).__name__}>"
|
|
|
|
def get_conversation_messages(conversation_id: str) -> List[Dict]:
    """Return a conversation's messages ordered by ``timestamp``, oldest first.

    Returns:
        The matching message documents, or an empty list when the query
        raises (the error is logged).
    """
    try:
        cursor = get_db()[MESSAGES_COLLECTION].find({"conversation_id": conversation_id})
        ordered = cursor.sort("timestamp", pymongo.ASCENDING)
        return list(ordered)
    except Exception as err:
        logger.error(f"Error getting conversation messages: {err}")
        return []
|
|
|
|
def generate_conversation_title(conversation_id: str, content: List[Dict]) -> Optional[str]:
    """
    Generate a title for a conversation based on its content using AI.

    The first five messages (or fewer) are sent to the LLM. A non-empty
    generated title is saved via ``update_conversation_title`` before being
    returned.

    Args:
        conversation_id: The ID of the conversation
        content: List of messages in the conversation; each message is read
            through its ``'role'`` and ``'content'`` keys

    Returns:
        The generated title. The fallback title "New conversation" is
        returned, and nothing is saved, when there are no messages, the LLM
        returns an empty title, or generation fails (failures are logged).
    """
    fallback_title = "New conversation"

    # Nothing to summarise: skip the LLM call entirely.
    if not content:
        return fallback_title

    try:
        # Imported lazily so the module can be loaded without llama_index.
        from llama_index.llms.openai import OpenAI as LlamaOpenAI

        # Extract text from the conversation (first few messages)
        conversation_text = "\n".join(
            f"{msg['role']}: {msg['content']}"
            for msg in content[:5]  # Use first 5 messages or fewer
        )

        # Create LLM instance
        llm = LlamaOpenAI(
            model="gpt-4.1",
            temperature=0.3,
        )

        # Generate title
        prompt = (
            "Based on the following conversation, generate a short, "
            "descriptive title (max 5 words):\n\n"
            f"{conversation_text}\n\n"
            "Title:"
        )
        response = llm.complete(prompt)

        # Models often wrap the title in quotes; drop them and any padding.
        title = response.text.strip().strip("\"'").strip()
        if not title:
            logger.warning("LLM returned an empty conversation title")
            return fallback_title

        # Update the conversation with the new title
        update_conversation_title(conversation_id, title)

        return title
    except Exception as e:
        logger.error(f"Error generating conversation title: {e}")
        return fallback_title  # Fallback title
|
|
|
|
def delete_conversation(conversation_id: str, hard_delete: bool = False) -> bool:
    """
    Delete a conversation and its messages.

    Args:
        conversation_id: The ID of the conversation to delete
        hard_delete: If True, physically delete the conversation and its
            messages; if False, only set ``is_deleted`` on the conversation

    Returns:
        True if the operation ran without error (including when no
        conversation matched the ID), False otherwise
    """
    try:
        database = get_db()

        # Convert the ID before touching any data, so a malformed ID fails
        # cleanly instead of deleting messages and then raising.
        obj_id = ObjectId(conversation_id)

        if hard_delete:
            # Permanently delete all messages in the conversation
            database[MESSAGES_COLLECTION].delete_many({"conversation_id": conversation_id})

            # Permanently delete the conversation
            database[CONVERSATIONS_COLLECTION].delete_one({"_id": obj_id})
        else:
            # Mark the conversation as deleted
            database[CONVERSATIONS_COLLECTION].update_one(
                {"_id": obj_id},
                {"$set": {"is_deleted": True}}
            )

        return True
    except Exception as e:
        logger.error(f"Error deleting conversation: {e}")
        return False
|
|
|
|
# Session state management
|
|
def get_session_state(session_id: str) -> Optional[Dict]:
    """
    Build a minimal session state from the session's conversation.

    Args:
        session_id: The session ID

    Returns:
        A dict with ``initialized``, ``conversation_id`` and ``user_id``, or
        None when no conversation exists for the session or an error occurs
    """
    try:
        record = get_conversation(session_id)
        if not record:
            return None

        # The state is derived entirely from the conversation document.
        return {
            "initialized": True,
            "conversation_id": str(record["_id"]),
            "user_id": record["user_id"],
        }
    except Exception as err:
        logger.error(f"Error getting session state: {err}")
        return None
|
|
|
|
def create_session_state(session_id: str, user_id: str, conversation_id: Optional[str] = None) -> Optional[Dict]:
    """
    Build the session state for a session, creating a conversation if needed.

    Args:
        session_id: The session ID
        user_id: The user ID
        conversation_id: Optional conversation ID. When falsy, one is obtained
            from ``create_conversation(session_id, user_id)``.

    Returns:
        A dict with ``initialized``, ``conversation_id`` and ``user_id``, or
        None when no conversation ID is available or an error occurs
    """
    try:
        # Use the supplied ID when given; otherwise obtain one for this session.
        active_id = conversation_id or create_conversation(session_id, user_id)
        if not active_id:
            return None

        return {
            "initialized": True,
            "conversation_id": active_id,
            "user_id": user_id,
        }
    except Exception as err:
        logger.error(f"Error creating session state: {err}")
        return None
|
|
|
|
def update_session_state(session_id: str, state: Dict) -> bool:
    """
    Report whether a session has a conversation its state could attach to.

    Nothing is written: ``state`` is accepted but not stored. Storing extra
    session state beyond the conversation would need a separate collection.

    Args:
        session_id: The session ID
        state: The new state (currently unused)

    Returns:
        True if a conversation exists for the session, False if none exists
        or an error occurs
    """
    try:
        return bool(get_conversation(session_id))
    except Exception as err:
        logger.error(f"Error updating session state: {err}")
        return False