presenton/servers/fastapi/services/mem0_presentation_memory_service.py
sudipnext 4e87dc8b70 refactor: Update database session management and enhance chat memory services
- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks.
- Refactored chat memory services to utilize a shared `mem0` client for better memory management.
- Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers.
- Enhanced error handling and response management in chat-related services.
- Cleaned up unused code and improved overall structure for maintainability.
2026-04-25 19:10:39 +05:45

221 lines
6.8 KiB
Python

import asyncio
import json
import logging
import os
from typing import Any, Optional
from uuid import UUID
from services.mem0_oss_memory import get_shared_mem0_client
LOGGER = logging.getLogger(__name__)
class Mem0PresentationMemoryService:
def __init__(self):
self._enabled = self._to_bool(os.getenv("MEM0_ENABLED"), default=True)
self._top_k = self._to_int(os.getenv("MEM0_TOP_K"), default=8)
self._max_context_chars = self._to_int(
os.getenv("MEM0_MAX_CONTEXT_CHARS"), default=6000
)
self._namespace_prefix = (
os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation"
).strip() or "presentation"
@staticmethod
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
return default
return str(value).strip().lower() in {"1", "true", "yes", "on"}
@staticmethod
def _to_int(value: Optional[str], default: int) -> int:
try:
parsed = int(value) if value is not None else default
return max(1, parsed)
except Exception:
return default
def _scope_user_id(self, presentation_id: UUID) -> str:
return f"{self._namespace_prefix}:{presentation_id}"
def _truncate(self, text: str, limit: int = 20000) -> str:
if len(text) <= limit:
return text
return f"{text[:limit]}\n\n[TRUNCATED]"
@staticmethod
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
return isinstance(exc, (Exception, SystemExit))
async def _get_client(self):
if not self._enabled:
return None
return get_shared_mem0_client()
async def _add_message(self, presentation_id: UUID, message: str):
client = await self._get_client()
if client is None or not message.strip():
return
scoped_user_id = self._scope_user_id(presentation_id)
payload = [{"role": "user", "content": self._truncate(message)}]
def _add():
try:
return client.add(payload, user_id=scoped_user_id, infer=False)
except TypeError:
return client.add(
messages=payload,
user_id=scoped_user_id,
infer=False,
)
try:
await asyncio.to_thread(_add)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception(
"Failed to add mem0 memory for presentation_id=%s", presentation_id
)
async def store_generation_context(
self,
presentation_id: UUID,
system_prompt: Optional[str],
user_prompt: Optional[str],
extracted_document_text: Optional[str],
source_content: Optional[str],
instructions: Optional[str],
):
if source_content:
await self._add_message(
presentation_id,
"[presentation_source_prompt]\n" + source_content,
)
if instructions:
await self._add_message(
presentation_id,
"[presentation_generation_instructions]\n" + instructions,
)
if system_prompt:
await self._add_message(
presentation_id,
"[outline_system_prompt]\n" + system_prompt,
)
if user_prompt:
await self._add_message(
presentation_id,
"[outline_user_prompt]\n" + user_prompt,
)
if extracted_document_text:
await self._add_message(
presentation_id,
"[document_extracted_text]\n" + extracted_document_text,
)
async def store_generated_outlines(self, presentation_id: UUID, outlines: Any):
if outlines is None:
return
try:
outlines_text = (
outlines
if isinstance(outlines, str)
else json.dumps(outlines, ensure_ascii=False)
)
except Exception:
outlines_text = str(outlines)
await self._add_message(
presentation_id,
"[generated_outlines]\n" + outlines_text,
)
async def store_slide_edit(
self,
presentation_id: UUID,
slide_index: Optional[int],
edit_prompt: str,
edited_slide_content: Any,
):
try:
edited_text = (
edited_slide_content
if isinstance(edited_slide_content, str)
else json.dumps(edited_slide_content, ensure_ascii=False)
)
except Exception:
edited_text = str(edited_slide_content)
index_text = f"{slide_index}" if slide_index is not None else "unknown"
message = (
f"[slide_edit]\n"
f"slide_index={index_text}\n"
f"user_edit_prompt={edit_prompt}\n"
f"edited_slide_content={edited_text}"
)
await self._add_message(presentation_id, message)
async def retrieve_context(self, presentation_id: UUID, query: str) -> str:
client = await self._get_client()
if client is None:
return ""
scoped_user_id = self._scope_user_id(presentation_id)
def _search():
try:
return client.search(
query,
filters={"user_id": scoped_user_id},
top_k=self._top_k,
)
except TypeError:
return client.search(
query,
user_id=scoped_user_id,
top_k=self._top_k,
)
try:
response = await asyncio.to_thread(_search)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception(
"Failed to search mem0 context for presentation_id=%s", presentation_id
)
return ""
results = []
if isinstance(response, dict):
results = response.get("results") or []
elif isinstance(response, list):
results = response
memories: list[str] = []
for item in results:
if not isinstance(item, dict):
continue
memory_text = item.get("memory") or item.get("text") or item.get("data")
if not memory_text:
continue
normalized = str(memory_text).strip()
if normalized:
memories.append(normalized)
if not memories:
return ""
deduped_memories = list(dict.fromkeys(memories))
context = "\n\n".join(deduped_memories)
return self._truncate(context, self._max_context_chars)
MEM0_PRESENTATION_MEMORY_SERVICE = Mem0PresentationMemoryService()