- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks. - Refactored chat memory services to utilize a shared `mem0` client for better memory management. - Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers. - Enhanced error handling and response management in chat-related services. - Cleaned up unused code and improved overall structure for maintainability.
324 lines
11 KiB
Python
324 lines
11 KiB
Python
import asyncio
|
|
from datetime import datetime, timezone
|
|
import logging
|
|
import os
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from services.mem0_oss_memory import get_shared_mem0_client
|
|
|
|
LOGGER = logging.getLogger(__name__)
|
|
CHAT_TURN_TAG = "[chat_turn]"
|
|
DEFAULT_MAX_STORED_TURNS = 20
|
|
|
|
|
|
class ChatMemoryStore:
|
|
def __init__(self):
|
|
self._enabled = self._to_bool(os.getenv("MEM0_ENABLED"), default=True)
|
|
self._top_k = self._to_int(os.getenv("MEM0_TOP_K"), default=8)
|
|
self._max_context_chars = self._to_int(
|
|
os.getenv("MEM0_MAX_CONTEXT_CHARS"), default=6000
|
|
)
|
|
self._max_stored_turns = self._to_int(
|
|
os.getenv("CHAT_MAX_STORED_TURNS"), default=DEFAULT_MAX_STORED_TURNS
|
|
)
|
|
self._namespace_prefix = (
|
|
os.getenv("MEM0_CHAT_NAMESPACE_PREFIX")
|
|
or os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX")
|
|
or "presentation"
|
|
).strip() or "presentation"
|
|
|
|
@staticmethod
|
|
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
|
if value is None:
|
|
return default
|
|
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
|
|
|
@staticmethod
|
|
def _to_int(value: Optional[str], default: int) -> int:
|
|
try:
|
|
parsed = int(value) if value is not None else default
|
|
return max(1, parsed)
|
|
except Exception:
|
|
return default
|
|
|
|
@staticmethod
|
|
def _normalize(value: str) -> str:
|
|
return " ".join((value or "").split())
|
|
|
|
@staticmethod
|
|
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
|
|
return isinstance(exc, (Exception, SystemExit))
|
|
|
|
def _scope_user_id(self, presentation_id: UUID, conversation_id: UUID) -> str:
|
|
return (
|
|
f"{self._namespace_prefix}:{presentation_id}:"
|
|
f"conversation:{conversation_id}"
|
|
)
|
|
|
|
def _truncate(self, text: str, limit: int = 20000) -> str:
|
|
if len(text) <= limit:
|
|
return text
|
|
return f"{text[:limit]}\n\n[TRUNCATED]"
|
|
|
|
async def _get_client(self):
|
|
if not self._enabled:
|
|
return None
|
|
return get_shared_mem0_client()
|
|
|
|
def _build_turn_payload(self, *, user_text: str, assistant_text: str) -> str:
|
|
memory_lines = [
|
|
CHAT_TURN_TAG,
|
|
f"turn_created_at={datetime.now(timezone.utc).isoformat()}",
|
|
]
|
|
if user_text:
|
|
memory_lines.append(f"user={user_text}")
|
|
if assistant_text:
|
|
memory_lines.append(f"assistant={assistant_text}")
|
|
return "\n".join(memory_lines)
|
|
|
|
@staticmethod
|
|
def _extract_text_field(item: dict[str, Any]) -> str:
|
|
memory_text = item.get("memory") or item.get("text") or item.get("data")
|
|
return str(memory_text).strip() if memory_text is not None else ""
|
|
|
|
def _collect_results(self, response: Any) -> list[dict[str, Any]]:
|
|
if isinstance(response, dict):
|
|
raw_results = (
|
|
response.get("results")
|
|
or response.get("memories")
|
|
or response.get("items")
|
|
or []
|
|
)
|
|
if isinstance(raw_results, list):
|
|
return [item for item in raw_results if isinstance(item, dict)]
|
|
return []
|
|
if isinstance(response, list):
|
|
return [item for item in response if isinstance(item, dict)]
|
|
return []
|
|
|
|
@staticmethod
|
|
def _safe_parse_datetime(raw_value: Any) -> datetime | None:
|
|
if not isinstance(raw_value, str) or not raw_value.strip():
|
|
return None
|
|
value = raw_value.strip().replace("Z", "+00:00")
|
|
try:
|
|
parsed = datetime.fromisoformat(value)
|
|
if parsed.tzinfo is None:
|
|
return parsed.replace(tzinfo=timezone.utc)
|
|
return parsed
|
|
except Exception:
|
|
return None
|
|
|
|
@staticmethod
|
|
def _extract_chat_turn_fields(text: str) -> tuple[str | None, str | None, datetime | None]:
|
|
if CHAT_TURN_TAG not in text:
|
|
return None, None, None
|
|
|
|
user_text: str | None = None
|
|
assistant_text: str | None = None
|
|
turn_created_at: datetime | None = None
|
|
for line in text.splitlines():
|
|
if line.startswith("user="):
|
|
user_text = line[len("user=") :].strip()
|
|
elif line.startswith("assistant="):
|
|
assistant_text = line[len("assistant=") :].strip()
|
|
elif line.startswith("turn_created_at="):
|
|
turn_created_at = ChatMemoryStore._safe_parse_datetime(
|
|
line[len("turn_created_at=") :].strip()
|
|
)
|
|
return user_text, assistant_text, turn_created_at
|
|
|
|
async def store_chat_turn(
|
|
self,
|
|
*,
|
|
presentation_id: UUID,
|
|
conversation_id: UUID,
|
|
user_message: str,
|
|
assistant_message: str,
|
|
) -> None:
|
|
client = await self._get_client()
|
|
if client is None:
|
|
return
|
|
|
|
user_text = self._normalize(user_message)
|
|
assistant_text = self._normalize(assistant_message)
|
|
if not user_text and not assistant_text:
|
|
return
|
|
|
|
payload = [
|
|
{
|
|
"role": "user",
|
|
"content": self._truncate(
|
|
self._build_turn_payload(
|
|
user_text=user_text,
|
|
assistant_text=assistant_text,
|
|
)
|
|
),
|
|
}
|
|
]
|
|
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
|
|
|
def _add():
|
|
try:
|
|
return client.add(payload, user_id=scoped_user_id, infer=False)
|
|
except TypeError:
|
|
return client.add(
|
|
messages=payload,
|
|
user_id=scoped_user_id,
|
|
infer=False,
|
|
)
|
|
|
|
try:
|
|
await asyncio.to_thread(_add)
|
|
except BaseException as exc:
|
|
if not self._is_nonfatal_mem0_error(exc):
|
|
raise
|
|
LOGGER.exception(
|
|
(
|
|
"Failed to add chat mem0 memory "
|
|
"(presentation_id=%s, conversation_id=%s)"
|
|
),
|
|
presentation_id,
|
|
conversation_id,
|
|
)
|
|
|
|
async def retrieve_context(
|
|
self,
|
|
*,
|
|
presentation_id: UUID,
|
|
conversation_id: UUID,
|
|
query: str,
|
|
) -> str:
|
|
client = await self._get_client()
|
|
if client is None:
|
|
return ""
|
|
|
|
trimmed_query = (query or "").strip()
|
|
if not trimmed_query:
|
|
return ""
|
|
|
|
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
|
|
|
def _search():
|
|
try:
|
|
return client.search(
|
|
trimmed_query,
|
|
filters={"user_id": scoped_user_id},
|
|
top_k=self._top_k,
|
|
)
|
|
except TypeError:
|
|
return client.search(
|
|
trimmed_query,
|
|
user_id=scoped_user_id,
|
|
top_k=self._top_k,
|
|
)
|
|
|
|
try:
|
|
response = await asyncio.to_thread(_search)
|
|
except BaseException as exc:
|
|
if not self._is_nonfatal_mem0_error(exc):
|
|
raise
|
|
LOGGER.exception(
|
|
(
|
|
"Failed to search chat mem0 memory "
|
|
"(presentation_id=%s, conversation_id=%s)"
|
|
),
|
|
presentation_id,
|
|
conversation_id,
|
|
)
|
|
return ""
|
|
|
|
results = self._collect_results(response)
|
|
memories: list[str] = []
|
|
for item in results:
|
|
normalized = self._extract_text_field(item)
|
|
if normalized:
|
|
memories.append(normalized)
|
|
|
|
if not memories:
|
|
return ""
|
|
|
|
deduped = list(dict.fromkeys(memories))
|
|
return self._truncate("\n\n".join(deduped), self._max_context_chars)
|
|
|
|
async def load_history(
|
|
self,
|
|
*,
|
|
presentation_id: UUID,
|
|
conversation_id: UUID,
|
|
) -> list[dict[str, str]]:
|
|
client = await self._get_client()
|
|
if client is None:
|
|
return []
|
|
|
|
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
|
|
|
def _get_all():
|
|
try:
|
|
return client.get_all(
|
|
filters={"user_id": scoped_user_id},
|
|
limit=max(10, self._max_stored_turns * 4),
|
|
)
|
|
except TypeError:
|
|
try:
|
|
return client.get_all(
|
|
user_id=scoped_user_id,
|
|
limit=max(10, self._max_stored_turns * 4),
|
|
)
|
|
except TypeError:
|
|
try:
|
|
return client.get_all(filters={"user_id": scoped_user_id})
|
|
except TypeError:
|
|
return client.get_all(user_id=scoped_user_id)
|
|
|
|
try:
|
|
response = await asyncio.to_thread(_get_all)
|
|
except BaseException as exc:
|
|
if not self._is_nonfatal_mem0_error(exc):
|
|
raise
|
|
LOGGER.exception(
|
|
(
|
|
"Failed to load chat mem0 history "
|
|
"(presentation_id=%s, conversation_id=%s)"
|
|
),
|
|
presentation_id,
|
|
conversation_id,
|
|
)
|
|
return []
|
|
|
|
results = self._collect_results(response)
|
|
ordered_turns: list[tuple[datetime, str, str]] = []
|
|
for index, item in enumerate(results):
|
|
text_value = self._extract_text_field(item)
|
|
if not text_value:
|
|
continue
|
|
user_text, assistant_text, embedded_timestamp = self._extract_chat_turn_fields(
|
|
text_value
|
|
)
|
|
if not user_text and not assistant_text:
|
|
continue
|
|
|
|
item_created_at = (
|
|
self._safe_parse_datetime(item.get("created_at"))
|
|
or self._safe_parse_datetime(item.get("updated_at"))
|
|
or self._safe_parse_datetime(item.get("event_at"))
|
|
)
|
|
timestamp = embedded_timestamp or item_created_at or datetime.fromtimestamp(
|
|
index, tz=timezone.utc
|
|
)
|
|
ordered_turns.append((timestamp, user_text or "", assistant_text or ""))
|
|
|
|
ordered_turns.sort(key=lambda turn: turn[0])
|
|
recent_turns = ordered_turns[-self._max_stored_turns :]
|
|
|
|
history: list[dict[str, str]] = []
|
|
for _, user_text, assistant_text in recent_turns:
|
|
if user_text:
|
|
history.append({"role": "user", "content": user_text})
|
|
if assistant_text:
|
|
history.append({"role": "assistant", "content": assistant_text})
|
|
return history
|
|
|
|
|
|
CHAT_MEMORY_STORE = ChatMemoryStore()
|