refactor: Update database session management and enhance chat memory services

- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks.
- Refactored chat memory services to utilize a shared `mem0` client for better memory management.
- Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers.
- Enhanced error handling and response management in chat-related services.
- Cleaned up unused code and improved overall structure for maintainability.
This commit is contained in:
sudipnext 2026-04-25 19:10:39 +05:45
parent 17ea7d9f95
commit 4e87dc8b70
25 changed files with 1508 additions and 368 deletions

View file

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import async_session_maker
from utils.ollama import pull_ollama_model
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
)
log_event_count = 0
session = await get_container_db_async_session().__anext__()
async with async_session_maker() as session:
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
return
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
return
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if "status" in event:
saved_model_status.status = event["status"]
if "status" in event:
saved_model_status.status = event["status"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await upsert_ollama_pull_status(session, model, saved_model_status)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
await upsert_ollama_pull_status(session, model, saved_model_status)
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
saved_model_status.done = True
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
async def upsert_ollama_pull_status(

View file

@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from models.ollama_model_metadata import OllamaModelMetadata
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import get_async_session
from utils.ollama import list_pulled_ollama_models
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
@ -29,7 +29,7 @@ async def get_available_models():
async def pull_model(
model: str,
background_tasks: BackgroundTasks,
session: AsyncSession = Depends(get_container_db_async_session),
session: AsyncSession = Depends(get_async_session),
):
if model not in SUPPORTED_OLLAMA_MODELS:

View file

@ -54,7 +54,6 @@ def main() -> None:
p = _sqlite_file_path(sync_url)
if p is not None:
paths.append(p)
paths.append(p.parent / "container.db")
seen: set[Path] = set()
for path in paths:

View file

@ -1,5 +1,4 @@
from collections.abc import AsyncGenerator
import os
from sqlalchemy.ext.asyncio import (
AsyncEngine,
create_async_engine,
@ -21,7 +20,6 @@ from models.sql.template_create_info import TemplateCreateInfoModel
from models.sql.slide import SlideModel
from models.sql.webhook_subscription import WebhookSubscription
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
from utils.get_env import get_app_data_directory_env
from utils.get_env import get_migrate_database_on_startup_env
@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
yield session
# Container DB (Lives inside the app data directory)
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
container_db_engine: AsyncEngine = create_async_engine(
container_db_url, connect_args={"check_same_thread": False}
)
container_db_async_session_maker = async_sessionmaker(
container_db_engine, expire_on_commit=False
)
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
async with container_db_async_session_maker() as session:
yield session
# Create Database and Tables
async def create_db_and_tables():
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
@ -76,18 +58,11 @@ async def create_db_and_tables():
TemplateModel.__table__,
WebhookSubscription.__table__,
AsyncPresentationGenerationTaskModel.__table__,
OllamaPullStatus.__table__,
],
)
)
async with container_db_engine.begin() as conn:
await conn.run_sync(
lambda sync_conn: SQLModel.metadata.create_all(
sync_conn,
tables=[OllamaPullStatus.__table__],
)
)
async def dispose_engines():
"""Dispose all engine connection pools.
@ -97,4 +72,3 @@ async def dispose_engines():
database and prevent stale / leaked connections.
"""
await sql_engine.dispose()
await container_db_engine.dispose()

View file

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import async_session_maker
from utils.ollama import pull_ollama_model
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
)
log_event_count = 0
session = await get_container_db_async_session().__anext__()
async with async_session_maker() as session:
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
return
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
return
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if "status" in event:
saved_model_status.status = event["status"]
if "status" in event:
saved_model_status.status = event["status"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await upsert_ollama_pull_status(session, model, saved_model_status)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
await upsert_ollama_pull_status(session, model, saved_model_status)
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
saved_model_status.done = True
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
async def upsert_ollama_pull_status(

View file

@ -1,10 +1,17 @@
import json
import uuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from models.chat import ChatMessageRequest, ChatMessageResponse
from models.chat import (
ChatConversationListItem,
ChatHistoryMessageItem,
ChatHistoryResponse,
ChatMessageRequest,
ChatMessageResponse,
)
from models.sse_response import (
SSECompleteResponse,
SSEErrorResponse,
@ -13,11 +20,57 @@ from models.sse_response import (
SSEResponse,
)
from services.chat import ChatTurnResult, PresentationChatService
from services.chat import sql_chat_history
from services.database import get_async_session
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
@CHAT_ROUTER.get("/conversations", response_model=list[ChatConversationListItem])
async def list_chat_conversations(
presentation_id: uuid.UUID = Query(..., description="Presentation id"),
sql_session: AsyncSession = Depends(get_async_session),
):
raw = await sql_chat_history.list_conversations(
sql_session, presentation_id=presentation_id
)
return [
ChatConversationListItem(
conversation_id=uuid.UUID(str(item["conversation_id"])),
updated_at=item.get("updated_at"),
last_message_preview=item.get("last_message_preview"),
)
for item in raw
]
@CHAT_ROUTER.get("/history", response_model=ChatHistoryResponse)
async def get_chat_history(
presentation_id: uuid.UUID = Query(..., description="Presentation id"),
conversation_id: uuid.UUID = Query(..., description="Conversation thread id"),
sql_session: AsyncSession = Depends(get_async_session),
):
rows = await sql_chat_history.load_messages_with_meta(
sql_session,
presentation_id=presentation_id,
conversation_id=conversation_id,
)
return ChatHistoryResponse(
presentation_id=presentation_id,
conversation_id=conversation_id,
messages=[
ChatHistoryMessageItem(
role=str(m.get("role") or ""),
content=str(m.get("content") or ""),
created_at=m.get("created_at")
if isinstance(m.get("created_at"), str)
else None,
)
for m in rows
],
)
@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse)
async def chat_message(
payload: ChatMessageRequest,

View file

@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from models.ollama_model_metadata import OllamaModelMetadata
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import get_async_session
from utils.ollama import list_pulled_ollama_models
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
@ -29,7 +29,7 @@ async def get_available_models():
async def pull_model(
model: str,
background_tasks: BackgroundTasks,
session: AsyncSession = Depends(get_container_db_async_session),
session: AsyncSession = Depends(get_async_session),
):
if model not in SUPPORTED_OLLAMA_MODELS:

View file

@ -18,3 +18,27 @@ class ChatMessageResponse(BaseModel):
tool_calls: list[str] = Field(default_factory=list)
model_config = ConfigDict(extra="forbid")
class ChatHistoryMessageItem(BaseModel):
role: str
content: str
created_at: Optional[str] = None
model_config = ConfigDict(extra="forbid")
class ChatHistoryResponse(BaseModel):
presentation_id: uuid.UUID
conversation_id: uuid.UUID
messages: list[ChatHistoryMessageItem]
model_config = ConfigDict(extra="forbid")
class ChatConversationListItem(BaseModel):
conversation_id: uuid.UUID
updated_at: Optional[str] = None
last_message_preview: Optional[str] = None
model_config = ConfigDict(extra="forbid")

View file

@ -1,6 +1,8 @@
from services.chat.service import ChatTurnResult, PresentationChatService
from services.chat.presentation_context_store import PresentationContextStore
__all__ = [
"ChatTurnResult",
"PresentationChatService",
"PresentationContextStore",
]

View file

@ -0,0 +1,324 @@
import asyncio
from datetime import datetime, timezone
import logging
import os
from typing import Any, Optional
from uuid import UUID
from services.mem0_oss_memory import get_shared_mem0_client
LOGGER = logging.getLogger(__name__)
CHAT_TURN_TAG = "[chat_turn]"
DEFAULT_MAX_STORED_TURNS = 20
class ChatMemoryStore:
def __init__(self):
self._enabled = self._to_bool(os.getenv("MEM0_ENABLED"), default=True)
self._top_k = self._to_int(os.getenv("MEM0_TOP_K"), default=8)
self._max_context_chars = self._to_int(
os.getenv("MEM0_MAX_CONTEXT_CHARS"), default=6000
)
self._max_stored_turns = self._to_int(
os.getenv("CHAT_MAX_STORED_TURNS"), default=DEFAULT_MAX_STORED_TURNS
)
self._namespace_prefix = (
os.getenv("MEM0_CHAT_NAMESPACE_PREFIX")
or os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX")
or "presentation"
).strip() or "presentation"
@staticmethod
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
return default
return str(value).strip().lower() in {"1", "true", "yes", "on"}
@staticmethod
def _to_int(value: Optional[str], default: int) -> int:
try:
parsed = int(value) if value is not None else default
return max(1, parsed)
except Exception:
return default
@staticmethod
def _normalize(value: str) -> str:
return " ".join((value or "").split())
@staticmethod
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
return isinstance(exc, (Exception, SystemExit))
def _scope_user_id(self, presentation_id: UUID, conversation_id: UUID) -> str:
return (
f"{self._namespace_prefix}:{presentation_id}:"
f"conversation:{conversation_id}"
)
def _truncate(self, text: str, limit: int = 20000) -> str:
if len(text) <= limit:
return text
return f"{text[:limit]}\n\n[TRUNCATED]"
async def _get_client(self):
if not self._enabled:
return None
return get_shared_mem0_client()
def _build_turn_payload(self, *, user_text: str, assistant_text: str) -> str:
memory_lines = [
CHAT_TURN_TAG,
f"turn_created_at={datetime.now(timezone.utc).isoformat()}",
]
if user_text:
memory_lines.append(f"user={user_text}")
if assistant_text:
memory_lines.append(f"assistant={assistant_text}")
return "\n".join(memory_lines)
@staticmethod
def _extract_text_field(item: dict[str, Any]) -> str:
memory_text = item.get("memory") or item.get("text") or item.get("data")
return str(memory_text).strip() if memory_text is not None else ""
def _collect_results(self, response: Any) -> list[dict[str, Any]]:
if isinstance(response, dict):
raw_results = (
response.get("results")
or response.get("memories")
or response.get("items")
or []
)
if isinstance(raw_results, list):
return [item for item in raw_results if isinstance(item, dict)]
return []
if isinstance(response, list):
return [item for item in response if isinstance(item, dict)]
return []
@staticmethod
def _safe_parse_datetime(raw_value: Any) -> datetime | None:
if not isinstance(raw_value, str) or not raw_value.strip():
return None
value = raw_value.strip().replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(value)
if parsed.tzinfo is None:
return parsed.replace(tzinfo=timezone.utc)
return parsed
except Exception:
return None
@staticmethod
def _extract_chat_turn_fields(text: str) -> tuple[str | None, str | None, datetime | None]:
if CHAT_TURN_TAG not in text:
return None, None, None
user_text: str | None = None
assistant_text: str | None = None
turn_created_at: datetime | None = None
for line in text.splitlines():
if line.startswith("user="):
user_text = line[len("user=") :].strip()
elif line.startswith("assistant="):
assistant_text = line[len("assistant=") :].strip()
elif line.startswith("turn_created_at="):
turn_created_at = ChatMemoryStore._safe_parse_datetime(
line[len("turn_created_at=") :].strip()
)
return user_text, assistant_text, turn_created_at
async def store_chat_turn(
self,
*,
presentation_id: UUID,
conversation_id: UUID,
user_message: str,
assistant_message: str,
) -> None:
client = await self._get_client()
if client is None:
return
user_text = self._normalize(user_message)
assistant_text = self._normalize(assistant_message)
if not user_text and not assistant_text:
return
payload = [
{
"role": "user",
"content": self._truncate(
self._build_turn_payload(
user_text=user_text,
assistant_text=assistant_text,
)
),
}
]
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
def _add():
try:
return client.add(payload, user_id=scoped_user_id, infer=False)
except TypeError:
return client.add(
messages=payload,
user_id=scoped_user_id,
infer=False,
)
try:
await asyncio.to_thread(_add)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception(
(
"Failed to add chat mem0 memory "
"(presentation_id=%s, conversation_id=%s)"
),
presentation_id,
conversation_id,
)
async def retrieve_context(
self,
*,
presentation_id: UUID,
conversation_id: UUID,
query: str,
) -> str:
client = await self._get_client()
if client is None:
return ""
trimmed_query = (query or "").strip()
if not trimmed_query:
return ""
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
def _search():
try:
return client.search(
trimmed_query,
filters={"user_id": scoped_user_id},
top_k=self._top_k,
)
except TypeError:
return client.search(
trimmed_query,
user_id=scoped_user_id,
top_k=self._top_k,
)
try:
response = await asyncio.to_thread(_search)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception(
(
"Failed to search chat mem0 memory "
"(presentation_id=%s, conversation_id=%s)"
),
presentation_id,
conversation_id,
)
return ""
results = self._collect_results(response)
memories: list[str] = []
for item in results:
normalized = self._extract_text_field(item)
if normalized:
memories.append(normalized)
if not memories:
return ""
deduped = list(dict.fromkeys(memories))
return self._truncate("\n\n".join(deduped), self._max_context_chars)
async def load_history(
self,
*,
presentation_id: UUID,
conversation_id: UUID,
) -> list[dict[str, str]]:
client = await self._get_client()
if client is None:
return []
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
def _get_all():
try:
return client.get_all(
filters={"user_id": scoped_user_id},
limit=max(10, self._max_stored_turns * 4),
)
except TypeError:
try:
return client.get_all(
user_id=scoped_user_id,
limit=max(10, self._max_stored_turns * 4),
)
except TypeError:
try:
return client.get_all(filters={"user_id": scoped_user_id})
except TypeError:
return client.get_all(user_id=scoped_user_id)
try:
response = await asyncio.to_thread(_get_all)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception(
(
"Failed to load chat mem0 history "
"(presentation_id=%s, conversation_id=%s)"
),
presentation_id,
conversation_id,
)
return []
results = self._collect_results(response)
ordered_turns: list[tuple[datetime, str, str]] = []
for index, item in enumerate(results):
text_value = self._extract_text_field(item)
if not text_value:
continue
user_text, assistant_text, embedded_timestamp = self._extract_chat_turn_fields(
text_value
)
if not user_text and not assistant_text:
continue
item_created_at = (
self._safe_parse_datetime(item.get("created_at"))
or self._safe_parse_datetime(item.get("updated_at"))
or self._safe_parse_datetime(item.get("event_at"))
)
timestamp = embedded_timestamp or item_created_at or datetime.fromtimestamp(
index, tz=timezone.utc
)
ordered_turns.append((timestamp, user_text or "", assistant_text or ""))
ordered_turns.sort(key=lambda turn: turn[0])
recent_turns = ordered_turns[-self._max_stored_turns :]
history: list[dict[str, str]] = []
for _, user_text, assistant_text in recent_turns:
if user_text:
history.append({"role": "user", "content": user_text})
if assistant_text:
history.append({"role": "assistant", "content": assistant_text})
return history
CHAT_MEMORY_STORE = ChatMemoryStore()

View file

@ -1,44 +1,14 @@
from datetime import datetime, timezone
import uuid
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.sql.key_value import KeyValueSqlModel
CHAT_CONVERSATION_KEY_PREFIX = "chat_conversation"
MAX_STORED_TURNS = 20
class ConversationMessage(BaseModel):
role: Literal["user", "assistant"]
content: str = Field(min_length=1)
model_config = ConfigDict(extra="forbid", strict=True)
@field_validator("content")
@classmethod
def normalize_content(cls, value: str) -> str:
normalized = value.strip()
if not normalized:
raise ValueError("Conversation message cannot be empty.")
return normalized
class ConversationPayload(BaseModel):
conversation_id: uuid.UUID
presentation_id: uuid.UUID
messages: list[ConversationMessage] = Field(default_factory=list)
updated_at: datetime
model_config = ConfigDict(extra="forbid", strict=True)
from services.chat.chat_memory_store import CHAT_MEMORY_STORE
from services.chat import sql_chat_history
class ChatConversationStore:
def __init__(self, sql_session: AsyncSession):
self._sql_session = sql_session
self._sql = sql_session
async def load_history(
self,
@ -46,13 +16,25 @@ class ChatConversationStore:
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
payload = await self._get_payload(conversation_id)
if not payload or payload.presentation_id != presentation_id:
return []
return [
{"role": message.role, "content": message.content}
for message in payload.messages
]
messages = await sql_chat_history.load_messages(
self._sql,
presentation_id=presentation_id,
conversation_id=conversation_id,
)
if messages:
return messages
legacy = await CHAT_MEMORY_STORE.load_history(
presentation_id=presentation_id,
conversation_id=conversation_id,
)
if legacy:
await sql_chat_history.replace_messages(
self._sql,
presentation_id=presentation_id,
conversation_id=conversation_id,
messages=legacy,
)
return legacy
async def append_turn(
self,
@ -62,87 +44,35 @@ class ChatConversationStore:
user_message: str,
assistant_message: str,
) -> None:
payload = await self._get_payload(conversation_id)
messages = list(payload.messages) if payload else []
messages.append(ConversationMessage(role="user", content=user_message))
messages.append(ConversationMessage(role="assistant", content=assistant_message))
max_messages = MAX_STORED_TURNS * 2
messages = messages[-max_messages:]
next_payload = ConversationPayload(
conversation_id=conversation_id,
await sql_chat_history.append_turn(
self._sql,
presentation_id=presentation_id,
messages=messages,
updated_at=datetime.now(timezone.utc),
conversation_id=conversation_id,
user_message=user_message,
assistant_message=assistant_message,
)
await CHAT_MEMORY_STORE.store_chat_turn(
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message=user_message,
assistant_message=assistant_message,
)
row = await self._get_row(conversation_id)
if row:
row.value = next_payload.model_dump(mode="json")
self._sql_session.add(row)
else:
self._sql_session.add(
KeyValueSqlModel(
key=self._conversation_key(conversation_id),
value=next_payload.model_dump(mode="json"),
)
)
await self._sql_session.commit()
async def retrieve_semantic_context(
self,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
query: str,
) -> str:
return await CHAT_MEMORY_STORE.retrieve_context(
presentation_id=presentation_id,
conversation_id=conversation_id,
query=query,
)
async def ensure_conversation_id(
self,
conversation_id: uuid.UUID | None,
) -> uuid.UUID:
return conversation_id or uuid.uuid4()
async def _get_payload(
self, conversation_id: uuid.UUID
) -> ConversationPayload | None:
row = await self._get_row(conversation_id)
if not row or not isinstance(row.value, dict):
return None
raw_payload: dict[str, Any] = row.value
try:
return ConversationPayload.model_validate(raw_payload)
except ValidationError:
return self._coerce_payload(raw_payload)
def _coerce_payload(self, payload: dict[str, Any]) -> ConversationPayload | None:
try:
conversation_id = uuid.UUID(str(payload.get("conversation_id")))
presentation_id = uuid.UUID(str(payload.get("presentation_id")))
except Exception:
return None
raw_messages = payload.get("messages")
messages: list[ConversationMessage] = []
if isinstance(raw_messages, list):
for entry in raw_messages:
if not isinstance(entry, dict):
continue
try:
message = ConversationMessage.model_validate(entry)
messages.append(message)
except ValidationError:
continue
return ConversationPayload(
conversation_id=conversation_id,
presentation_id=presentation_id,
messages=messages[-(MAX_STORED_TURNS * 2) :],
updated_at=datetime.now(timezone.utc),
)
async def _get_row(self, conversation_id: uuid.UUID) -> KeyValueSqlModel | None:
return await self._sql_session.scalar(
select(KeyValueSqlModel).where(
KeyValueSqlModel.key == self._conversation_key(conversation_id)
)
)
@staticmethod
def _conversation_key(conversation_id: uuid.UUID) -> str:
return f"{CHAT_CONVERSATION_KEY_PREFIX}:{conversation_id}"

View file

@ -0,0 +1,5 @@
from services.chat.memory_layer import (
PresentationChatMemoryLayer as PresentationContextStore,
)
__all__ = ["PresentationContextStore"]

View file

@ -1,13 +1,31 @@
def build_system_prompt(memory_context: str) -> str:
context_block = (
"\nMemory context (use only when relevant):\n"
f"{memory_context}\n"
if memory_context
else ""
def _trim_block(label: str, text: str) -> str:
t = (text or "").strip()
if not t:
return ""
return f"\n{label}\n(use only when relevant; may be partial)\n{t}\n"
def build_system_prompt(
presentation_memory_context: str,
chat_memory_context: str,
) -> str:
"""
presentation_memory_context: deck-scoped (documents, outlines, prior slide edits, etc.)
chat_memory_context: this thread (prior asks, assistant replies) semantic slice.
"""
presentation_block = _trim_block(
"Presentation memory (this deck: source text, outlines, and stored slide/edit notes):",
presentation_memory_context,
)
chat_block = _trim_block(
"This conversation thread (what was asked and answered in chat):",
chat_memory_context,
)
return (
"You are Presenton backend chat assistant.\n"
"You can call tools to access presentation memory.\n"
"You can call tools to access live slide data and layouts.\n"
"Distinguish: presentation memory = facts about the deck; chat memory = this threads prior Q&A. "
"Tools still win for current slide content.\n"
"- Use getPresentationOutline for outline/section questions.\n"
"- Prefer compact tool outputs to save context window; do not request full slide JSON unless needed.\n"
"- Use searchSlides for finding relevant slide content snippets (DB-backed).\n"
@ -21,5 +39,6 @@ def build_system_prompt(memory_context: str) -> str:
"- After tool outputs are sufficient, stop calling tools and provide a final answer.\n"
"- If memory is missing, state that clearly and suggest next steps.\n"
"- Do not invent slide facts that are not in tool results or memory.\n"
f"{context_block}"
f"{presentation_block}"
f"{chat_block}"
)

View file

@ -19,7 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.sql.presentation import PresentationModel
from services.chat.conversation_store import ChatConversationStore
from services.chat.memory_layer import PresentationChatMemoryLayer
from services.chat.presentation_context_store import PresentationContextStore
from services.chat.prompts import build_system_prompt
from services.chat.tools import ChatTools
from utils.llm_client_error_handler import handle_llm_client_exceptions
@ -58,7 +58,7 @@ class PresentationChatService:
self._conversation_id = conversation_id
self._conversation_store = ChatConversationStore(sql_session)
self._memory = PresentationChatMemoryLayer(sql_session, presentation_id)
self._memory = PresentationContextStore(sql_session, presentation_id)
self._tools = ChatTools(self._memory)
async def generate_reply(self, user_message: str) -> ChatTurnResult:
@ -209,6 +209,8 @@ class PresentationChatService:
if not presentation:
raise HTTPException(status_code=404, detail="Presentation not found")
# A stable conversation_id is created here before the first user message
# so mem0 and SQL can scope the thread; the client need not "chat" first.
conversation_id = await self._conversation_store.ensure_conversation_id(
self._conversation_id
)
@ -218,9 +220,19 @@ class PresentationChatService:
)
history_messages = self._convert_history_to_messages(history)
memory_context = await self._memory.retrieve_context(user_message)
presentation_memory = await self._memory.retrieve_context(user_message)
chat_memory = await self._conversation_store.retrieve_semantic_context(
presentation_id=self._presentation_id,
conversation_id=conversation_id,
query=user_message,
)
messages: list[Message] = [
SystemMessage(content=build_system_prompt(memory_context)),
SystemMessage(
content=build_system_prompt(
presentation_memory_context=presentation_memory,
chat_memory_context=chat_memory,
)
),
*history_messages,
UserMessage(content=user_message),
]

View file

@ -0,0 +1,231 @@
"""Persist presentation chat threads in ``KeyValueSqlModel``.
Each conversation is one row: key ``ppt_chat:{presentation_id}:{conversation_id}``,
value is JSON: ``{version, messages, updated_at, ...}``.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.sql.key_value import KeyValueSqlModel
CHAT_HISTORY_KEY_PREFIX = "ppt_chat"
SCHEMA_VERSION = 1
def chat_history_key(presentation_id: uuid.UUID, conversation_id: uuid.UUID) -> str:
return f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:{conversation_id}"
def _parse_conversation_key(key: str, presentation_id: uuid.UUID) -> uuid.UUID | None:
expected_prefix = f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:"
if not key.startswith(expected_prefix):
return None
rest = key[len(expected_prefix) :]
try:
return uuid.UUID(rest)
except ValueError:
return None
def _utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
async def load_messages(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
"""Load ordered user/assistant messages for LLM context (role + content only)."""
key = chat_history_key(presentation_id, conversation_id)
row = await session.scalar(
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
)
if not row or not isinstance(row.value, dict):
return []
messages = row.value.get("messages")
if not isinstance(messages, list):
return []
out: list[dict[str, str]] = []
for item in messages:
if not isinstance(item, dict):
continue
if item.get("role") not in ("user", "assistant"):
continue
content = item.get("content")
if not isinstance(content, str) or not content.strip():
continue
out.append({"role": item["role"], "content": content})
return out
async def load_messages_with_meta(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
) -> list[dict[str, Any]]:
"""Load messages including optional ``created_at`` for API / UI."""
key = chat_history_key(presentation_id, conversation_id)
row = await session.scalar(
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
)
if not row or not isinstance(row.value, dict):
return []
messages = row.value.get("messages")
if not isinstance(messages, list):
return []
out: list[dict[str, Any]] = []
for item in messages:
if not isinstance(item, dict):
continue
if item.get("role") not in ("user", "assistant"):
continue
content = item.get("content")
if not isinstance(content, str) or not content.strip():
continue
entry: dict[str, Any] = {
"role": item["role"],
"content": content,
}
created = item.get("created_at")
if isinstance(created, str) and created.strip():
entry["created_at"] = created.strip()
out.append(entry)
return out
async def replace_messages(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
messages: list[dict[str, str]],
) -> None:
"""Replace transcript (e.g. one-time sync from mem0)."""
key = chat_history_key(presentation_id, conversation_id)
row = await session.scalar(
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
)
now = _utc_now_iso()
built: list[dict[str, Any]] = []
for m in messages:
if m.get("role") not in ("user", "assistant"):
continue
content = m.get("content")
if not isinstance(content, str):
continue
built.append(
{
"role": m["role"],
"content": content,
"created_at": now,
}
)
value = {
"version": SCHEMA_VERSION,
"presentation_id": str(presentation_id),
"conversation_id": str(conversation_id),
"messages": built,
"updated_at": now,
}
if row is None:
session.add(KeyValueSqlModel(key=key, value=value))
else:
row.value = value
await session.flush()
async def append_turn(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str,
) -> None:
key = chat_history_key(presentation_id, conversation_id)
row = await session.scalar(
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
)
now = _utc_now_iso()
new_messages: list[dict[str, Any]] = [
{"role": "user", "content": user_message, "created_at": now},
{"role": "assistant", "content": assistant_message, "created_at": now},
]
if row is None:
value: dict[str, Any] = {
"version": SCHEMA_VERSION,
"presentation_id": str(presentation_id),
"conversation_id": str(conversation_id),
"messages": new_messages,
"updated_at": now,
}
session.add(KeyValueSqlModel(key=key, value=value))
else:
data = row.value if isinstance(row.value, dict) else {}
existing = data.get("messages")
if not isinstance(existing, list):
existing = []
combined = [m for m in existing if isinstance(m, dict)]
combined.extend(new_messages)
data["version"] = SCHEMA_VERSION
data["presentation_id"] = str(presentation_id)
data["conversation_id"] = str(conversation_id)
data["messages"] = combined
data["updated_at"] = now
row.value = data
await session.flush()
async def list_conversations(
session: AsyncSession, *, presentation_id: uuid.UUID
) -> list[dict[str, Any]]:
"""Return conversation summaries for a presentation, newest ``updated_at`` first."""
prefix = f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:"
result = await session.scalars(
select(KeyValueSqlModel).where(KeyValueSqlModel.key.startswith(prefix))
)
rows = list(result.all())
summaries: list[dict[str, Any]] = []
for row in rows:
cid = _parse_conversation_key(row.key, presentation_id)
if cid is None:
continue
data = row.value if isinstance(row.value, dict) else {}
updated_at: str | None = None
raw_u = data.get("updated_at")
if isinstance(raw_u, str) and raw_u.strip():
updated_at = raw_u.strip()
messages = data.get("messages")
preview: str | None = None
if isinstance(messages, list) and messages:
for item in reversed(messages):
if not isinstance(item, dict):
continue
c = item.get("content")
if isinstance(c, str) and c.strip():
preview = c.strip()
if len(preview) > 200:
preview = f"{preview[:200]}"
break
summaries.append(
{
"conversation_id": str(cid),
"updated_at": updated_at,
"last_message_preview": preview,
}
)
summaries.sort(
key=lambda s: s.get("updated_at") or "",
reverse=True,
)
return summaries

View file

@ -15,7 +15,7 @@ from services.chat.schemas import (
SaveSlideInput,
SearchSlidesInput,
)
from services.chat.memory_layer import PresentationChatMemoryLayer
from services.chat.presentation_context_store import PresentationContextStore
LOGGER = logging.getLogger(__name__)
@ -30,7 +30,7 @@ class ChatTools:
provider-specific logic, keeping them portable across llmai backends.
"""
def __init__(self, memory: PresentationChatMemoryLayer):
def __init__(self, memory: PresentationContextStore):
self._memory = memory
self._tool_handlers: dict[str, ToolHandler] = {
"getPresentationOutline": self._get_presentation_outline,

View file

@ -1,5 +1,4 @@
from collections.abc import AsyncGenerator
import os
from sqlalchemy.ext.asyncio import (
AsyncEngine,
create_async_engine,
@ -20,7 +19,6 @@ from models.sql.template import TemplateModel
from models.sql.template_create_info import TemplateCreateInfoModel
from models.sql.slide import SlideModel
from models.sql.webhook_subscription import WebhookSubscription
from utils.get_env import get_app_data_directory_env
from utils.get_env import get_migrate_database_on_startup_env
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
yield session
# Container DB (Lives inside the app data directory)
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
container_db_engine: AsyncEngine = create_async_engine(
container_db_url, connect_args={"check_same_thread": False}
)
container_db_async_session_maker = async_sessionmaker(
container_db_engine, expire_on_commit=False
)
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
async with container_db_async_session_maker() as session:
yield session
# Create Database and Tables
async def create_db_and_tables():
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
@ -76,18 +58,11 @@ async def create_db_and_tables():
TemplateModel.__table__,
WebhookSubscription.__table__,
AsyncPresentationGenerationTaskModel.__table__,
OllamaPullStatus.__table__,
],
)
)
async with container_db_engine.begin() as conn:
await conn.run_sync(
lambda sync_conn: SQLModel.metadata.create_all(
sync_conn,
tables=[OllamaPullStatus.__table__],
)
)
async def dispose_engines():
"""Dispose all engine connection pools.
@ -97,4 +72,3 @@ async def dispose_engines():
database and prevent stale / leaked connections.
"""
await sql_engine.dispose()
await container_db_engine.dispose()

View file

@ -0,0 +1,131 @@
"""Single shared mem0 OSS ``Memory`` client for the process.
All callers (presentation context, chat turns) use the same on-disk Qdrant/SQLite
and distinguish data via mem0 ``user_id``:
- Deck-level (no chat thread): ``{namespace}:{presentation_id}``
- Chat thread: ``{namespace}:{presentation_id}:conversation:{conversation_id}``
The chat flow calls ``ensure_conversation_id`` before the first turn, so a
``conversation_id`` exists before any mem0 write for that thread.
"""
from __future__ import annotations
import logging
import os
import threading
from importlib import import_module
from typing import Any, Optional
LOGGER = logging.getLogger(__name__)
_memory_init_lock = threading.Lock()
_shared_client: Any | None = None
_init_attempted = False
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
return default
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _to_int(value: Optional[str], default: int) -> int:
try:
parsed = int(value) if value is not None else default
return max(1, parsed)
except Exception:
return default
def _oss_config_from_env() -> tuple[str, str, str, str, int, dict[str, Any]]:
"""Return (mem0_dir, qdrant_path, history_db, collection, dims, from_config_dict)."""
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
qdrant_path = (
os.getenv("MEM0_QDRANT_PATH") or os.path.join(mem0_dir, "qdrant")
).strip()
history_db_path = (
os.getenv("MEM0_HISTORY_DB_PATH") or os.path.join(mem0_dir, "history.db")
).strip()
collection = (
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
).strip() or "presenton_memories"
embedder = (os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed").strip() or "fastembed"
model = (
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
).strip() or "BAAI/bge-small-en-v1.5"
dims = _to_int(os.getenv("MEM0_EMBEDDING_DIMS"), default=384)
config: dict[str, Any] = {
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": collection,
"path": qdrant_path,
"on_disk": True,
"embedding_model_dims": dims,
},
},
"embedder": {
"provider": embedder,
"config": {
"model": model,
"embedding_dims": dims,
},
},
"history_db_path": history_db_path,
}
return mem0_dir, qdrant_path, history_db_path, collection, dims, config
def memory_from_config(config: dict[str, Any], *, telemetry_base: str) -> Any:
"""Construct ``mem0.Memory``. Caller must hold ``_memory_init_lock`` if used with shared state."""
os.makedirs(telemetry_base, exist_ok=True)
import mem0.memory.main as mem0_main # type: ignore[import-untyped]
mem0_main.mem0_dir = telemetry_base
memory_cls = getattr(import_module("mem0"), "Memory")
return memory_cls.from_config(config)
def get_shared_mem0_client() -> Any | None:
"""Return the process-wide mem0 client, or ``None`` if disabled or init failed."""
global _shared_client, _init_attempted
if not _to_bool(os.getenv("MEM0_ENABLED"), default=True):
return None
if _shared_client is not None:
return _shared_client
if _init_attempted:
return None
with _memory_init_lock:
if _shared_client is not None:
return _shared_client
if _init_attempted:
return None
_init_attempted = True
try:
mem0_dir, qdrant_path, history_db, collection, dims, config = (
_oss_config_from_env()
)
os.makedirs(mem0_dir, exist_ok=True)
os.makedirs(qdrant_path, exist_ok=True)
telemetry_base = os.path.join(mem0_dir, "telemetry", "oss")
_shared_client = memory_from_config(
config,
telemetry_base=telemetry_base,
)
LOGGER.info(
"Mem0 OSS shared memory initialized (qdrant_path=%s, history_db_path=%s, collection=%s, dims=%s)",
qdrant_path,
history_db,
collection,
dims,
)
except BaseException:
LOGGER.exception("Failed to initialize shared Mem0 OSS Memory")
_shared_client = None
return _shared_client

View file

@ -2,10 +2,11 @@ import asyncio
import json
import logging
import os
from importlib import import_module
from typing import Any, Optional
from uuid import UUID
from services.mem0_oss_memory import get_shared_mem0_client
LOGGER = logging.getLogger(__name__)
@ -21,31 +22,6 @@ class Mem0PresentationMemoryService:
os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation"
).strip() or "presentation"
self._embedder_provider = (
os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed"
).strip() or "fastembed"
self._embedder_model = (
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
).strip() or "BAAI/bge-small-en-v1.5"
self._embedding_dims = self._to_int(
os.getenv("MEM0_EMBEDDING_DIMS"),
default=384,
)
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
self._mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
self._qdrant_path = (os.getenv("MEM0_QDRANT_PATH") or os.path.join(self._mem0_dir, "qdrant")).strip()
self._history_db_path = (
os.getenv("MEM0_HISTORY_DB_PATH")
or os.path.join(self._mem0_dir, "history.db")
).strip()
self._collection_name = (
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
).strip() or "presenton_memories"
self._client: Any = None
self._attempted_client_init = False
@staticmethod
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
@ -68,27 +44,6 @@ class Mem0PresentationMemoryService:
return text
return f"{text[:limit]}\n\n[TRUNCATED]"
def _get_oss_config(self) -> dict:
return {
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": self._collection_name,
"path": self._qdrant_path,
"on_disk": True,
"embedding_model_dims": self._embedding_dims,
},
},
"embedder": {
"provider": self._embedder_provider,
"config": {
"model": self._embedder_model,
"embedding_dims": self._embedding_dims,
},
},
"history_db_path": self._history_db_path,
}
@staticmethod
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
return isinstance(exc, (Exception, SystemExit))
@ -96,42 +51,7 @@ class Mem0PresentationMemoryService:
async def _get_client(self):
if not self._enabled:
return None
if self._client is not None:
return self._client
if self._attempted_client_init:
return None
self._attempted_client_init = True
try:
module = import_module("mem0")
memory_cls = getattr(module, "Memory")
os.makedirs(self._mem0_dir, exist_ok=True)
os.makedirs(self._qdrant_path, exist_ok=True)
config = self._get_oss_config()
try:
self._client = memory_cls.from_config(config)
except Exception:
# Backward compatibility across mem0 OSS versions.
self._client = memory_cls(config)
LOGGER.info(
"Mem0 OSS presentation memory service initialized (qdrant_path=%s, history_db_path=%s)",
self._qdrant_path,
self._history_db_path,
)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception("Failed to initialize Mem0 OSS Memory")
self._client = None
return self._client
return get_shared_mem0_client()
async def _add_message(self, presentation_id: UUID, message: str):
client = await self._get_client()

View file

@ -0,0 +1,136 @@
import asyncio
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
from services.chat.conversation_store import ChatConversationStore
class TestChatConversationStore:
def test_load_history_reads_sql_first(self):
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
expected_history = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
]
sql_session = MagicMock()
with patch(
"services.chat.conversation_store.sql_chat_history.load_messages",
new=AsyncMock(return_value=expected_history),
) as load_sql, patch(
"services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
new=AsyncMock(),
) as load_mem0:
store = ChatConversationStore(sql_session)
history = asyncio.run(
store.load_history(
presentation_id=presentation_id,
conversation_id=conversation_id,
)
)
load_sql.assert_awaited_once_with(
sql_session,
presentation_id=presentation_id,
conversation_id=conversation_id,
)
load_mem0.assert_not_called()
assert history == expected_history
def test_load_history_falls_back_to_mem0_and_backfills_sql(self):
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
legacy = [
{"role": "user", "content": "old"},
{"role": "assistant", "content": "from mem0"},
]
sql_session = MagicMock()
with patch(
"services.chat.conversation_store.sql_chat_history.load_messages",
new=AsyncMock(return_value=[]),
) as load_sql, patch(
"services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
new=AsyncMock(return_value=legacy),
) as load_mem0, patch(
"services.chat.conversation_store.sql_chat_history.replace_messages",
new=AsyncMock(),
) as replace_messages:
store = ChatConversationStore(sql_session)
history = asyncio.run(
store.load_history(
presentation_id=presentation_id,
conversation_id=conversation_id,
)
)
load_sql.assert_awaited_once()
load_mem0.assert_awaited_once()
replace_messages.assert_awaited_once_with(
sql_session,
presentation_id=presentation_id,
conversation_id=conversation_id,
messages=legacy,
)
assert history == legacy
def test_append_turn_persists_sql_and_mem0(self):
sql_session = MagicMock()
store = ChatConversationStore(sql_session)
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
with patch(
"services.chat.conversation_store.sql_chat_history.append_turn",
new=AsyncMock(),
) as append_sql, patch(
"services.chat.conversation_store.CHAT_MEMORY_STORE.store_chat_turn",
new=AsyncMock(),
) as store_mem0:
asyncio.run(
store.append_turn(
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message="Can you improve slide 2?",
assistant_message="Yes, I will tighten the bullet points.",
)
)
append_sql.assert_awaited_once_with(
sql_session,
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message="Can you improve slide 2?",
assistant_message="Yes, I will tighten the bullet points.",
)
store_mem0.assert_awaited_once_with(
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message="Can you improve slide 2?",
assistant_message="Yes, I will tighten the bullet points.",
)
def test_retrieve_semantic_context_delegates_to_chat_memory_store(self):
store = ChatConversationStore(MagicMock())
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
with patch(
"services.chat.conversation_store.CHAT_MEMORY_STORE.retrieve_context",
new=AsyncMock(return_value="conversation-scoped context"),
) as retrieve_context:
context = asyncio.run(
store.retrieve_semantic_context(
presentation_id=presentation_id,
conversation_id=conversation_id,
query="What did we decide?",
)
)
retrieve_context.assert_awaited_once_with(
presentation_id=presentation_id,
conversation_id=conversation_id,
query="What did we decide?",
)
assert context == "conversation-scoped context"

View file

@ -0,0 +1,249 @@
import asyncio
import uuid
from unittest.mock import patch
import services.mem0_oss_memory as mem0_oss
from services.chat.chat_memory_store import ChatMemoryStore
class FakeMemoryClient:
instances: list["FakeMemoryClient"] = []
def __init__(self, config=None):
self.config = config
self.add_calls = []
self.search_calls = []
self.get_all_calls = []
self.next_search_response = {"results": []}
self.next_get_all_response = {"results": []}
FakeMemoryClient.instances.append(self)
@classmethod
def from_config(cls, config):
return cls(config=config)
def add(self, *args, **kwargs):
messages = kwargs.get("messages") if "messages" in kwargs else None
if messages is None and args:
messages = args[0]
self.add_calls.append(
{
"messages": messages,
"user_id": kwargs.get("user_id"),
"infer": kwargs.get("infer"),
}
)
return {"ok": True}
def search(self, query, *args, **kwargs):
self.search_calls.append(
{
"query": query,
"filters": kwargs.get("filters"),
"user_id": kwargs.get("user_id"),
"top_k": kwargs.get("top_k"),
}
)
return self.next_search_response
def get_all(self, *args, **kwargs):
self.get_all_calls.append(
{
"filters": kwargs.get("filters"),
"user_id": kwargs.get("user_id"),
"limit": kwargs.get("limit"),
}
)
return self.next_get_all_response
def _mem0_oss_fresh() -> None:
mem0_oss._shared_client = None # type: ignore[attr-defined]
mem0_oss._init_attempted = False # type: ignore[attr-defined]
class TestChatMemoryStore:
def setup_method(self):
FakeMemoryClient.instances = []
_mem0_oss_fresh()
def test_store_chat_turn_uses_conversation_scoped_user_id(self):
with patch.dict(
"os.environ",
{
"MEM0_ENABLED": "true",
"MEM0_TOP_K": "4",
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
},
clear=False,
), patch(
"services.chat.chat_memory_store.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {"provider": "qdrant", "config": {}},
"embedder": {"provider": "fastembed", "config": {}},
}
),
):
store = ChatMemoryStore()
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
asyncio.run(
store.store_chat_turn(
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message="Can you tighten slide 3?",
assistant_message="Yes, I can make it shorter.",
)
)
assert len(FakeMemoryClient.instances) == 1
client = FakeMemoryClient.instances[0]
assert len(client.add_calls) == 1
expected_user_id = (
f"presentation:{presentation_id}:conversation:{conversation_id}"
)
assert client.add_calls[0]["user_id"] == expected_user_id
assert client.add_calls[0]["infer"] is False
payload = str(client.add_calls[0]["messages"][0]["content"])
assert "[chat_turn]" in payload
assert "user=Can you tighten slide 3?" in payload
assert "assistant=Yes, I can make it shorter." in payload
def test_retrieve_context_reads_only_conversation_scoped_user_id(self):
with patch.dict(
"os.environ",
{
"MEM0_ENABLED": "true",
"MEM0_TOP_K": "6",
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
},
clear=False,
), patch(
"services.chat.chat_memory_store.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {"provider": "qdrant", "config": {}},
"embedder": {"provider": "fastembed", "config": {}},
}
),
):
store = ChatMemoryStore()
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
expected_user_id = (
f"presentation:{presentation_id}:conversation:{conversation_id}"
)
asyncio.run(
store.store_chat_turn(
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message="First turn",
assistant_message="First answer",
)
)
client = FakeMemoryClient.instances[0]
client.next_search_response = {
"results": [
{"memory": "Chat memory A"},
{"memory": "Chat memory A"},
{"memory": "Chat memory B"},
]
}
context = asyncio.run(
store.retrieve_context(
presentation_id=presentation_id,
conversation_id=conversation_id,
query="What did we decide?",
)
)
assert "Chat memory A" in context
assert "Chat memory B" in context
assert context.count("Chat memory A") == 1
assert len(client.search_calls) == 1
assert client.search_calls[0]["query"] == "What did we decide?"
assert client.search_calls[0]["filters"] == {"user_id": expected_user_id}
assert client.search_calls[0]["top_k"] == 6
def test_load_history_reads_conversation_scoped_turns(self):
with patch.dict(
"os.environ",
{
"MEM0_ENABLED": "true",
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
},
clear=False,
), patch(
"services.chat.chat_memory_store.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {"provider": "qdrant", "config": {}},
"embedder": {"provider": "fastembed", "config": {}},
}
),
):
store = ChatMemoryStore()
presentation_id = uuid.uuid4()
conversation_id = uuid.uuid4()
expected_user_id = (
f"presentation:{presentation_id}:conversation:{conversation_id}"
)
asyncio.run(
store.store_chat_turn(
presentation_id=presentation_id,
conversation_id=conversation_id,
user_message="Draft intro",
assistant_message="Updated intro done.",
)
)
client = FakeMemoryClient.instances[0]
client.next_get_all_response = {
"results": [
{
"memory": (
"[chat_turn]\n"
"turn_created_at=2026-04-25T10:00:00+00:00\n"
"user=Draft intro\nassistant=Updated intro done."
),
"created_at": "2026-04-25T10:00:01+00:00",
},
{
"memory": (
"[chat_turn]\n"
"turn_created_at=2026-04-25T10:01:00+00:00\n"
"user=Add roadmap\nassistant=Roadmap slide added."
),
"created_at": "2026-04-25T10:01:01+00:00",
},
]
}
history = asyncio.run(
store.load_history(
presentation_id=presentation_id,
conversation_id=conversation_id,
)
)
assert history == [
{"role": "user", "content": "Draft intro"},
{"role": "assistant", "content": "Updated intro done."},
{"role": "user", "content": "Add roadmap"},
{"role": "assistant", "content": "Roadmap slide added."},
]
assert len(client.get_all_calls) == 1
assert client.get_all_calls[0]["filters"] == {"user_id": expected_user_id}
assert client.get_all_calls[0]["limit"] >= 10

View file

@ -2,11 +2,12 @@ import asyncio
import uuid
from unittest.mock import patch
import services.mem0_oss_memory as mem0_oss
from services.mem0_presentation_memory_service import Mem0PresentationMemoryService
class FakeMemoryClient:
instances = []
instances: list["FakeMemoryClient"] = []
def __init__(self, config=None):
self.config = config
@ -45,13 +46,15 @@ class FakeMemoryClient:
return self.next_search_response
class FakeMem0Module:
Memory = FakeMemoryClient
def _mem0_oss_fresh() -> None:
mem0_oss._shared_client = None # type: ignore[attr-defined]
mem0_oss._init_attempted = False # type: ignore[attr-defined]
class TestMem0PresentationMemoryService:
def setup_method(self):
FakeMemoryClient.instances = []
_mem0_oss_fresh()
def test_store_generation_context_uses_presentation_scope(self):
with patch.dict(
@ -62,8 +65,25 @@ class TestMem0PresentationMemoryService:
},
clear=False,
), patch(
"services.mem0_presentation_memory_service.import_module",
return_value=FakeMem0Module,
"services.mem0_presentation_memory_service.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {
"provider": "qdrant",
"config": {
"on_disk": True,
"embedding_model_dims": 384,
},
},
"embedder": {
"provider": "fastembed",
"config": {
"model": "BAAI/bge-small-en-v1.5",
"embedding_dims": 384,
},
},
}
),
):
service = Mem0PresentationMemoryService()
presentation_id = uuid.uuid4()
@ -115,8 +135,19 @@ class TestMem0PresentationMemoryService:
},
clear=False,
), patch(
"services.mem0_presentation_memory_service.import_module",
return_value=FakeMem0Module,
"services.mem0_presentation_memory_service.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {"provider": "qdrant", "config": {}},
"embedder": {
"provider": "fastembed",
"config": {
"model": "BAAI/bge-small-en-v1.5",
"embedding_dims": 384,
},
},
}
),
):
service = Mem0PresentationMemoryService()
presentation_id = uuid.uuid4()
@ -154,3 +185,4 @@ class TestMem0PresentationMemoryService:
"user_id": f"presentation:{presentation_id}"
}
assert client.search_calls[0]["top_k"] == 5

View file

@ -1185,7 +1185,7 @@ wheels = [
[[package]]
name = "llmai"
version = "0.1.9"
version = "0.2.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anthropic" },
@ -1193,9 +1193,9 @@ dependencies = [
{ name = "google-genai" },
{ name = "openai" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/34/1d4188bd842f336726004896be9ae04b693b9c21d349918631433b9f1b63/llmai-0.2.1.tar.gz", hash = "sha256:f911bd7df3eb06d1c56612ce293f926df7b3bf6c36283a353cf780c697d39d31", size = 47862, upload-time = "2026-04-24T16:28:52.417Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" },
{ url = "https://files.pythonhosted.org/packages/45/7d/63b3da1be92b721b0681db5b9e96c1cbc000b63cb70ede40b20cc5302699/llmai-0.2.1-py3-none-any.whl", hash = "sha256:4c51d1186cce1e621f74a9ec70376dc1bd2e996eee484db17dce6a6e7b79a0a7", size = 59880, upload-time = "2026-04-24T16:28:50.706Z" },
]
[[package]]
@ -1686,7 +1686,7 @@ requires-dist = [
{ name = "fastmcp", specifier = ">=2.11.0" },
{ name = "google-genai", specifier = ">=1.28.0" },
{ name = "jsonschema", specifier = ">=4.26.0" },
{ name = "llmai", specifier = "==0.1.9" },
{ name = "llmai", specifier = "==0.2.1" },
{ name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" },
{ name = "nltk", specifier = ">=3.9.1" },
{ name = "openai", specifier = ">=1.98.0" },

View file

@ -252,6 +252,9 @@ const createMessageId = () => {
return `${Date.now()}-${Math.random().toString(16).slice(2)}`;
};
const conversationStorageKey = (presentationId: string) =>
`presenton:chat:conversationId:${presentationId}`;
const AssistantMarker = () => (
<div className="mb-3 flex items-center gap-1.5 text-[#A4A7AE]">
<MessageCircleMore className="h-4 w-4" />
@ -334,6 +337,7 @@ const Chat = ({
const [input, setInput] = useState("");
const [messages, setMessages] = useState<ChatMessage[]>([]);
const [conversationId, setConversationId] = useState<string | null>(null);
const [isHistoryLoading, setIsHistoryLoading] = useState(false);
const [isSending, setIsSending] = useState(false);
const [errorMessage, setErrorMessage] = useState<string | null>(null);
const [expandedActivityByMessage, setExpandedActivityByMessage] = useState<
@ -344,11 +348,70 @@ const Chat = ({
const messagesEndRef = useRef<HTMLDivElement | null>(null);
useEffect(() => {
let cancelled = false;
setMessages([]);
setInput("");
setConversationId(null);
setErrorMessage(null);
setExpandedActivityByMessage({});
if (!presentationId) {
return;
}
setIsHistoryLoading(true);
const run = async () => {
try {
if (typeof sessionStorage === "undefined") {
return;
}
const sKey = conversationStorageKey(presentationId);
let activeId = sessionStorage.getItem(sKey) ?? null;
if (!activeId) {
const list = await PresentationChatApi.listConversations(
presentationId
);
if (list.length > 0) {
activeId = list[0]!.conversation_id;
sessionStorage.setItem(sKey, activeId);
}
}
if (!activeId) {
return;
}
const data = await PresentationChatApi.getHistory(
presentationId,
activeId
);
if (cancelled) {
return;
}
setConversationId(activeId);
setMessages(
data.messages.map((m) => ({
id: createMessageId(),
role:
m.role === "assistant"
? "assistant"
: m.role === "user"
? "user"
: "user",
content: m.content,
}))
);
} catch (error) {
console.error("Failed to load chat history:", error);
toast.error("Could not load previous chat");
} finally {
if (!cancelled) {
setIsHistoryLoading(false);
}
}
};
void run();
return () => {
cancelled = true;
};
}, [presentationId]);
useEffect(() => {
@ -372,6 +435,9 @@ const Chat = ({
setConversationId(null);
setErrorMessage(null);
setExpandedActivityByMessage({});
if (presentationId && typeof sessionStorage !== "undefined") {
sessionStorage.removeItem(conversationStorageKey(presentationId));
}
inputRef.current?.focus();
};
@ -493,7 +559,7 @@ const Chat = ({
const submitMessage = async (rawMessage: string) => {
const trimmedMessage = rawMessage.trim();
if (!trimmedMessage || isSending) {
if (!trimmedMessage || isSending || isHistoryLoading) {
return;
}
@ -578,11 +644,23 @@ const Chat = ({
)
);
settleAssistantActivities(assistantMessageId, "success");
setConversationId((previous) =>
typeof response.conversation_id === "string"
? response.conversation_id
: previous
);
setConversationId((previous) => {
const next =
typeof response.conversation_id === "string"
? response.conversation_id
: previous;
if (
next &&
presentationId &&
typeof sessionStorage !== "undefined"
) {
sessionStorage.setItem(
conversationStorageKey(presentationId),
next
);
}
return next;
});
await refreshPresentationIfNeeded(
Array.isArray(response.tool_calls) ? response.tool_calls : []
@ -655,7 +733,7 @@ const Chat = ({
<button
type="button"
onClick={resetChat}
disabled={isSending}
disabled={isSending || isHistoryLoading}
className="rounded-full p-1 text-[#8C8C8C] transition-colors hover:bg-[#F7F7F7] hover:text-[#191919] disabled:cursor-not-allowed disabled:opacity-50"
aria-label="Reset chat"
title="Reset chat"
@ -665,7 +743,12 @@ const Chat = ({
</div>
<div className="min-h-0 flex-1 overflow-y-auto px-4 pb-4 pt-9 hide-scrollbar">
{messages.length === 0 ? (
{isHistoryLoading && messages.length === 0 ? (
<div className="flex items-center justify-center py-8 text-sm text-[#99A1AF]">
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
Loading chat
</div>
) : messages.length === 0 ? (
<>
<div>
<h4 className="mb-2 text-[10px] font-normal leading-[15px] tracking-[0.367px] text-[#99A1AF]">
@ -819,7 +902,7 @@ const Chat = ({
className="min-h-[92px] w-full resize-none bg-transparent pb-10 text-sm text-[#101828] placeholder:text-[#99A1AF] focus:outline-none focus:ring-0"
rows={4}
value={input}
disabled={isSending}
disabled={isSending || isHistoryLoading}
onChange={(event) => setInput(event.target.value)}
onKeyDown={handleKeyDown}
placeholder="Improve your slides..."
@ -836,7 +919,7 @@ const Chat = ({
</button>
<button
type="submit"
disabled={!input.trim() || isSending}
disabled={!input.trim() || isSending || isHistoryLoading}
className="absolute bottom-3 right-3 flex items-center gap-1.5 px-3 py-2 text-sm font-medium text-[#191919] disabled:cursor-not-allowed disabled:opacity-60"
style={{
background:

View file

@ -14,6 +14,24 @@ export interface ChatMessageResponse {
tool_calls?: string[];
}
export interface ChatHistoryMessage {
role: string;
content: string;
created_at?: string;
}
export interface ChatHistoryData {
presentation_id: string;
conversation_id: string;
messages: ChatHistoryMessage[];
}
export interface ChatConversationSummary {
conversation_id: string;
updated_at?: string | null;
last_message_preview?: string | null;
}
export interface ChatStreamTrace {
kind?: string;
round?: number;
@ -64,6 +82,38 @@ type ChatStreamData =
| Record<string, unknown>;
export class PresentationChatApi {
static async listConversations(
presentationId: string
): Promise<ChatConversationSummary[]> {
const u = new URL(getApiUrl("/api/v1/ppt/chat/conversations"));
u.searchParams.set("presentation_id", presentationId);
const response = await fetch(u.toString(), {
headers: getHeader(),
cache: "no-cache",
});
return await ApiResponseHandler.handleResponse(
response,
"Failed to list chat conversations"
);
}
static async getHistory(
presentationId: string,
conversationId: string
): Promise<ChatHistoryData> {
const u = new URL(getApiUrl("/api/v1/ppt/chat/history"));
u.searchParams.set("presentation_id", presentationId);
u.searchParams.set("conversation_id", conversationId);
const response = await fetch(u.toString(), {
headers: getHeader(),
cache: "no-cache",
});
return await ApiResponseHandler.handleResponse(
response,
"Failed to load chat history"
);
}
static async sendMessage(
payload: ChatMessageRequest
): Promise<ChatMessageResponse> {