import logging
from datetime import datetime, timezone

from bson import ObjectId

from app.db import get_db

logger = logging.getLogger(__name__)

# Allowed fields for create/update (mass assignment protection)
PERSONA_ALLOWED_FIELDS = {
    "name",
    "age",
    "gender",
    "occupation",
    "education",
    "location",
    "techSavviness",
    "personality",
    "interests",
    "brandLoyalty",
    "priceConsciousness",
    "environmentalConcern",
    "hasPurchasingPower",
    "hasChildren",
    "thinkFeelDo",
    "description",
    "imageUrl",
    "folder_ids",
    "llm_model",
    "traits",
    "background",
    "goals",
    "communication_style",
    "values",
    "demographics",
}


def _filter_allowed(data):
    """Return a new dict holding only the keys of *data* in PERSONA_ALLOWED_FIELDS.

    A falsy *data* (``None`` or ``{}``) yields an empty dict instead of raising.
    """
    if not data:
        return {}
    return {key: value for key, value in data.items() if key in PERSONA_ALLOWED_FIELDS}


def _serialize(persona):
    """Replace the document's ``_id`` with its string form (in place) and return it."""
    persona["_id"] = str(persona["_id"])
    return persona


def _id_filter(persona_id, action):
    """Build the Mongo filter identifying a single persona, or ``None`` if unusable.

    * ObjectId instances and values ``ObjectId`` accepts (24-char hex strings,
      12-byte bytes) match on ``_id``.
    * ``None`` and mappings yield ``None``. ``ObjectId(None)`` would silently
      mint a brand-new id, and a dict such as ``{"$ne": None}`` would be treated
      by MongoDB as a query operator and match arbitrary documents.
    * Anything else falls back to the legacy ``id`` field.

    Args:
        persona_id: Identifier supplied by the caller.
        action: Short label used only in the warning log message.
    """
    if persona_id is None or isinstance(persona_id, dict):
        return None
    if ObjectId.is_valid(persona_id):
        return {"_id": ObjectId(persona_id)}
    logger.warning(
        "Invalid ObjectId format for %s: %r; falling back to legacy 'id' field",
        action,
        persona_id,
    )
    return {"id": persona_id}


async def _fetch_many(db, query, limit):
    """Return up to *limit* personas matching *query*, newest ``created_at`` first."""
    cursor = db.personas.find(query).sort("created_at", -1).limit(limit)
    return [_serialize(persona) async for persona in cursor]


class Persona:
    """Data-access helpers for the ``personas`` collection.

    Every method is a static coroutine that obtains the database via
    ``get_db()``. Returned documents have their ``_id`` converted to ``str``.
    """

    @staticmethod
    async def create(persona_data, user_id=None):
        """Insert a persona built from the allow-listed keys of *persona_data*.

        Adds ``created_at`` (UTC now), ``created_by`` (*user_id*) and an empty
        ``folder_ids`` list when none was supplied.

        Returns:
            The inserted document's id as a string.
        """
        db = await get_db()

        # Apply field allowlist (mass assignment protection)
        safe_data = _filter_allowed(persona_data)

        # Add metadata
        safe_data["created_at"] = datetime.now(timezone.utc)
        safe_data["created_by"] = user_id
        safe_data.setdefault("folder_ids", [])

        result = await db.personas.insert_one(safe_data)
        logger.info("Persona created: %s", safe_data.get("name", "Unknown"))
        return str(result.inserted_id)

    @staticmethod
    async def find_by_id(persona_id):
        """Return the persona identified by *persona_id*, or ``None``.

        Lookup rules are those of ``_id_filter``. Database errors are logged and
        reported as ``None`` (best-effort lookup).
        """
        db = await get_db()
        query = _id_filter(persona_id, "lookup")
        if query is None:
            return None
        try:
            persona = await db.personas.find_one(query)
        except Exception:
            logger.exception("Error in find_by_id, persona_id: %r", persona_id)
            return None
        return _serialize(persona) if persona else None

    @staticmethod
    async def find_by_user(user_id, limit=100):
        """Return up to *limit* personas created by *user_id*, newest first.

        Database errors propagate to the caller.
        """
        db = await get_db()
        return await _fetch_many(db, {"created_by": user_id}, limit)

    @staticmethod
    async def get_all(user_id=None, limit=100):
        """Return up to *limit* personas, newest first.

        When *user_id* is truthy only that user's personas are returned.
        Otherwise the query is unfiltered, covering personas of all users.
        Errors are logged and an empty list is returned.
        """
        try:
            db = await get_db()
            query = {"created_by": user_id} if user_id else {}
            return await _fetch_many(db, query, limit)
        except Exception:
            logger.exception("Error in Persona.get_all")
            return []

    @staticmethod
    async def update(persona_id, data, user_id=None):
        """``$set`` the allow-listed keys of *data* plus ``updated_at`` on one persona.

        The persona is resolved with the same rules as ``find_by_id``/``delete``.
        When *user_id* is truthy the update only applies to personas it created.

        Returns:
            ``True`` if a document was modified, ``False`` otherwise (including
            unusable ids).
        """
        db = await get_db()

        # Build ownership-aware query
        query = _id_filter(persona_id, "update")
        if query is None:
            return False
        if user_id:
            query["created_by"] = user_id

        # Apply field allowlist
        filtered_data = _filter_allowed(data)
        filtered_data["updated_at"] = datetime.now(timezone.utc)

        result = await db.personas.update_one(query, {"$set": filtered_data})
        return result.modified_count > 0

    @staticmethod
    async def delete(persona_id, user_id=None):
        """Delete one persona, scoped to *user_id* when it is truthy.

        Returns:
            ``True`` if a document was deleted. Unusable ids and database
            errors (logged) yield ``False``.
        """
        db = await get_db()
        query = _id_filter(persona_id, "delete")
        if query is None:
            return False
        if user_id:
            query["created_by"] = user_id
        try:
            result = await db.personas.delete_one(query)
        except Exception:
            logger.exception("Error in delete, persona_id: %r", persona_id)
            return False
        return result.deleted_count > 0