cohorta/backend/app/services/persona_modification_service.py
Vadym Samoilenko e01569c412
All checks were successful
Deploy to Production / deploy (push) Successful in 2m23s
feat: commit all app changes — billing API, new auth, design overhaul
Includes frontend redesign (Navigation, billingApi), backend updates
(auth routes, admin routes, LLM service refactor), MSAL removal,
and dependency updates.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-23 19:04:43 +01:00

238 lines
No EOL
10 KiB
Python
Executable file

"""
Persona Modification Service
This service handles AI-powered modification of existing personas using natural language instructions.
It integrates with the LLM service to process modification requests while maintaining data integrity
and internal consistency of persona attributes.
"""
import json
import logging
from typing import Dict, Any, Optional
from datetime import datetime, timezone
from .llm_service import LLMService, LLMServiceError
from app.utils.prompt_loader import load_prompt, PromptLoaderError
from app.models.persona import Persona
from bson import ObjectId
logger = logging.getLogger(__name__)
class PersonaModificationError(Exception):
    """Raised when an AI-driven persona modification cannot be completed.

    PersonaModificationService uses this single exception type for every
    failure it reports: a persona that cannot be found, a prompt template
    that fails to load, LLM errors or unusable LLM output, and a failed
    database update.
    """
class PersonaModificationService:
    """Service for modifying personas using AI.

    Every method is static. The flow is: load persona -> sanitize for JSON ->
    render prompt -> call the LLM (with retries) -> parse and validate the
    reply -> restore read-only fields -> optionally persist.
    """

    # Fields the LLM must never change. They are copied back from the original
    # persona (or removed if the original never had them).
    _PROTECTED_FIELDS = ('id', '_id', 'created_at', 'created_by')

    # Fields that must be present and non-None in every modified persona.
    _REQUIRED_FIELDS = ('name', 'age', 'gender', 'occupation', 'location', 'personality')

    # Optional numeric attributes and their inclusive valid ranges.
    _NUMERIC_FIELD_RANGES = {
        'techSavviness': (0, 100),
        'brandLoyalty': (0, 100),
        'priceConsciousness': (0, 100),
        'environmentalConcern': (0, 100),
    }

    @staticmethod
    def _sanitize_value(value: Any) -> Any:
        """Return a JSON-friendly copy of ``value``.

        ObjectIds become strings and datetimes become ISO-8601 strings. Dicts
        and lists are processed recursively, including lists nested inside
        lists. Any other value is returned unchanged.
        """
        if isinstance(value, ObjectId):
            return str(value)
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, dict):
            return {key: PersonaModificationService._sanitize_value(item) for key, item in value.items()}
        if isinstance(value, list):
            return [PersonaModificationService._sanitize_value(item) for item in value]
        return value

    @staticmethod
    def _sanitize_persona_for_json(persona_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Sanitize persona data to make it JSON serializable for the LLM prompt.

        Args:
            persona_data: The persona data dictionary that may contain non-serializable
                objects (ObjectId, datetime) at any nesting depth.

        Returns:
            A new dictionary that can be passed to ``json.dumps``. The input is not mutated.
        """
        return {
            key: PersonaModificationService._sanitize_value(value)
            for key, value in persona_data.items()
        }

    @staticmethod
    def _protect_readonly_fields(original_persona: Dict[str, Any], modified_persona: Dict[str, Any]) -> Dict[str, Any]:
        """
        Protect readonly fields from being modified by the LLM.

        Each protected field is restored from ``original_persona``. If the
        original does not have that field, it is removed from the LLM output,
        so the model cannot invent identity or provenance metadata.
        ``updated_at`` is always set to the current UTC time.

        Args:
            original_persona: The original persona data.
            modified_persona: The LLM-modified persona data. It is mutated in place.

        Returns:
            ``modified_persona``, with readonly fields restored and ``updated_at`` refreshed.
        """
        for field in PersonaModificationService._PROTECTED_FIELDS:
            if field in original_persona:
                modified_persona[field] = original_persona[field]
            else:
                modified_persona.pop(field, None)
        modified_persona['updated_at'] = datetime.now(timezone.utc).isoformat()
        return modified_persona

    @staticmethod
    def _validate_persona_structure(persona_data: Any) -> bool:
        """
        Validate that the modified persona contains all required fields.

        Args:
            persona_data: The parsed LLM output to validate. It may be any JSON
                value, so non-objects are rejected rather than crashing.

        Returns:
            True if valid. False otherwise (the reason is logged).
        """
        if not isinstance(persona_data, dict):
            logger.error("Modified persona is not a JSON object (got %s)", type(persona_data).__name__)
            return False

        for field in PersonaModificationService._REQUIRED_FIELDS:
            if persona_data.get(field) is None:
                logger.error("Missing required field: %s", field)
                return False

        for field, (min_val, max_val) in PersonaModificationService._NUMERIC_FIELD_RANGES.items():
            if field not in persona_data:
                continue
            try:
                # float() rather than int(): int() truncates, so 100.7 would
                # slip past the upper bound.
                value = float(persona_data[field])
            except (ValueError, TypeError):
                logger.error("Field %s is not a valid number", field)
                return False
            # Chained comparison is False for NaN, so NaN is rejected as well.
            if not min_val <= value <= max_val:
                logger.error("Field %s value %s out of range [%s, %s]", field, value, min_val, max_val)
                return False
        return True

    @staticmethod
    def _parse_llm_json(raw_response: Any) -> Any:
        """Parse an LLM text response as JSON.

        Surrounding whitespace is ignored. So is one Markdown code fence
        (```json ... ``` or ``` ... ```) around the payload, because models
        often add one even when asked for raw JSON.

        Raises:
            ValueError: If the response is not a string or is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        if not isinstance(raw_response, str):
            raise ValueError(f"expected a text response, got {type(raw_response).__name__}")
        text = raw_response.strip()
        if text.startswith("```"):
            # Drop the opening fence line (it may carry a language tag such as "json").
            _, _, text = text.partition("\n")
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
        return json.loads(text)

    @staticmethod
    async def modify_persona(
        persona_id: str,
        modification_prompt: str,
        llm_model: str = 'gpt-5.4',
        reasoning_effort: str = 'medium',
        verbosity: str = 'medium',
        max_retries: int = 3,
        preview_only: bool = False
    ) -> Dict[str, Any]:
        """
        Modify a persona using AI based on natural language instructions.

        Args:
            persona_id: The ID of the persona to modify.
            modification_prompt: Natural language description of desired changes.
            llm_model: The LLM model to use for modification.
            reasoning_effort: Reasoning effort for GPT-5 (minimal, low, medium, high).
            verbosity: Response verbosity for GPT-5 (low, medium, high).
            max_retries: Maximum number of LLM attempts. Values below 1 are
                treated as a single attempt.
            preview_only: If True, returns modified data without saving to database.

        Returns:
            Dictionary containing the modified persona data.

        Raises:
            PersonaModificationError: If the persona is missing, the prompt cannot
                be loaded, the LLM fails or returns unusable output on every
                attempt, or the database update fails. Unexpected exceptions are
                wrapped in this type and chained as ``__cause__``.
        """
        # Always make at least one attempt. Otherwise the loop below would be
        # skipped and the coroutine would implicitly return None.
        attempts = max(1, max_retries)
        try:
            from app.services.llm_usage_context import set_llm_context
            set_llm_context(feature="persona_modify", persona_id=persona_id)

            original_persona = await Persona.find_by_id(persona_id)
            if not original_persona:
                raise PersonaModificationError(f"Persona with ID {persona_id} not found")

            # NOTE(review): objects exposing `_data` are presumably model
            # wrappers that support dict() conversion; confirm against Persona.
            original_persona_dict = dict(original_persona) if hasattr(original_persona, '_data') else original_persona
            sanitized_persona = PersonaModificationService._sanitize_persona_for_json(original_persona_dict)

            try:
                final_prompt = load_prompt('persona-modification', {
                    'original_persona_json': json.dumps(sanitized_persona, indent=2),
                    'modification_prompt': modification_prompt,
                })
            except PromptLoaderError as e:
                logger.error("Failed to load persona modification prompt: %s", e)
                raise PersonaModificationError(f"Failed to load modification prompt: {e}") from e

            for attempt in range(1, attempts + 1):
                is_last_attempt = attempt == attempts
                logger.info("Attempting persona modification (attempt %d/%d)", attempt, attempts)

                try:
                    llm_response = await LLMService.generate_content(
                        prompt=final_prompt,
                        temperature=0.3,  # Lower temperature for consistent modifications
                        model_name=llm_model,
                        reasoning_effort=reasoning_effort,
                        verbosity=verbosity
                    )
                except LLMServiceError as e:
                    logger.error("LLM service error on attempt %d: %s", attempt, e)
                    if is_last_attempt:
                        raise PersonaModificationError(f"LLM service failed after {attempts} attempts: {e}") from e
                    continue

                try:
                    modified_persona_data = PersonaModificationService._parse_llm_json(llm_response)
                except ValueError as e:
                    logger.warning("Invalid JSON response on attempt %d: %s", attempt, e)
                    if is_last_attempt:
                        raise PersonaModificationError(f"LLM returned invalid JSON after {attempts} attempts") from e
                    continue

                if not PersonaModificationService._validate_persona_structure(modified_persona_data):
                    logger.warning("Invalid persona structure on attempt %d", attempt)
                    if is_last_attempt:
                        raise PersonaModificationError(f"LLM returned invalid persona structure after {attempts} attempts")
                    continue

                # NOTE(review): readonly fields are restored from the *sanitized*
                # persona, so `_id`/`created_at` (if present) are strings here and
                # reach Persona.update as strings. Confirm that Persona.update
                # strips or converts them.
                modified_persona_data = PersonaModificationService._protect_readonly_fields(
                    sanitized_persona, modified_persona_data
                )

                if preview_only:
                    logger.info("Generated preview for persona %s (not saved to database)", persona_id)
                    return modified_persona_data

                success = await Persona.update(persona_id, modified_persona_data)
                if not success:
                    raise PersonaModificationError("Failed to update persona in database")
                logger.info("Successfully modified persona %s", persona_id)
                return modified_persona_data

            # Defensive: every final attempt above either returns or raises.
            raise PersonaModificationError("Persona modification did not complete")
        except PersonaModificationError:
            # Already a meaningful domain error: propagate it unchanged, not re-wrapped.
            raise
        except Exception as e:
            logger.exception("Unexpected error during persona modification: %s", e)
            raise PersonaModificationError(f"Persona modification failed: {e}") from e