diff --git a/electron/servers/fastapi/api/v1/ppt/background_tasks.py b/electron/servers/fastapi/api/v1/ppt/background_tasks.py index 68f7e45e..aa7f4901 100644 --- a/electron/servers/fastapi/api/v1/ppt/background_tasks.py +++ b/electron/servers/fastapi/api/v1/ppt/background_tasks.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from models.ollama_model_status import OllamaModelStatus from models.sql.ollama_pull_status import OllamaPullStatus -from services.database import get_container_db_async_session +from services.database import async_session_maker from utils.ollama import pull_ollama_model @@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str): ) log_event_count = 0 - session = await get_container_db_async_session().__anext__() + async with async_session_maker() as session: + try: + async for event in pull_ollama_model(model): + if "error" in event: + saved_model_status.status = "error" + saved_model_status.done = True + saved_model_status.error = event["error"] + await upsert_ollama_pull_status(session, model, saved_model_status) + return - try: - async for event in pull_ollama_model(model): - if "error" in event: - saved_model_status.status = "error" - saved_model_status.done = True - saved_model_status.error = event["error"] - await upsert_ollama_pull_status(session, model, saved_model_status) - await session.close() - return + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue - log_event_count += 1 - if log_event_count != 1 and log_event_count % 20 != 0: - continue + if "completed" in event: + saved_model_status.downloaded = event["completed"] - if "completed" in event: - saved_model_status.downloaded = event["completed"] + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] - if not saved_model_status.size and "total" in event: - saved_model_status.size = event["total"] + if "status" in event: + saved_model_status.status = event["status"] - if "status" in event: - saved_model_status.status = event["status"] + await upsert_ollama_pull_status(session, model, saved_model_status) - await upsert_ollama_pull_status(session, model, saved_model_status) + except Exception as e: + saved_model_status.status = "error" + saved_model_status.done = True + saved_model_status.error = str(e) + await upsert_ollama_pull_status(session, model, saved_model_status) + raise HTTPException( + status_code=500, + detail=f"Failed to pull model: {e}", + ) - except Exception as e: - saved_model_status.status = "error" saved_model_status.done = True - saved_model_status.error = str(e) + saved_model_status.status = "pulled" + saved_model_status.downloaded = saved_model_status.size + saved_model_status.error = None + await upsert_ollama_pull_status(session, model, saved_model_status) - await session.close() - raise HTTPException( - status_code=500, - detail=f"Failed to pull model: {e}", - ) - - saved_model_status.done = True - saved_model_status.status = "pulled" - saved_model_status.downloaded = saved_model_status.size - saved_model_status.error = None - - await upsert_ollama_pull_status(session, model, saved_model_status) - await session.close() async def upsert_ollama_pull_status( diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/electron/servers/fastapi/api/v1/ppt/endpoints/ollama.py index 0dafa3e1..43505831 100644 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/electron/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS from models.ollama_model_metadata import OllamaModelMetadata from models.ollama_model_status import OllamaModelStatus from models.sql.ollama_pull_status import OllamaPullStatus -from services.database import get_container_db_async_session +from services.database import get_async_session from utils.ollama import list_pulled_ollama_models OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"]) @@ -29,7 +29,7 @@ async def get_available_models(): async def pull_model( model: str, background_tasks: BackgroundTasks, - session: AsyncSession = Depends(get_container_db_async_session), + session: AsyncSession = Depends(get_async_session), ): if model not in SUPPORTED_OLLAMA_MODELS: diff --git a/electron/servers/fastapi/scripts/fresh_sqlite_migrate.py b/electron/servers/fastapi/scripts/fresh_sqlite_migrate.py index 82923672..9afb1f40 100644 --- a/electron/servers/fastapi/scripts/fresh_sqlite_migrate.py +++ b/electron/servers/fastapi/scripts/fresh_sqlite_migrate.py @@ -54,7 +54,6 @@ def main() -> None: p = _sqlite_file_path(sync_url) if p is not None: paths.append(p) - paths.append(p.parent / "container.db") seen: set[Path] = set() for path in paths: diff --git a/electron/servers/fastapi/services/database.py b/electron/servers/fastapi/services/database.py index 96a5c4bf..049bb322 100644 --- a/electron/servers/fastapi/services/database.py +++ b/electron/servers/fastapi/services/database.py @@ -1,5 +1,4 @@ from collections.abc import AsyncGenerator -import os from sqlalchemy.ext.asyncio import ( AsyncEngine, create_async_engine, @@ -21,7 +20,6 @@ from models.sql.template_create_info import TemplateCreateInfoModel from models.sql.slide import SlideModel from models.sql.webhook_subscription import WebhookSubscription from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs -from utils.get_env import get_app_data_directory_env from utils.get_env import get_migrate_database_on_startup_env @@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: yield session -# Container DB (Lives inside the app data directory) -_app_data_dir = get_app_data_directory_env() or "/tmp/presenton" -container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}" -container_db_engine: AsyncEngine = create_async_engine( - container_db_url, connect_args={"check_same_thread": False} -) -container_db_async_session_maker = async_sessionmaker( - container_db_engine, expire_on_commit=False -) - - -async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]: - async with container_db_async_session_maker() as session: - yield session - - # Create Database and Tables async def create_db_and_tables(): should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"] @@ -76,18 +58,11 @@ async def create_db_and_tables(): TemplateModel.__table__, WebhookSubscription.__table__, AsyncPresentationGenerationTaskModel.__table__, + OllamaPullStatus.__table__, ], ) ) - async with container_db_engine.begin() as conn: - await conn.run_sync( - lambda sync_conn: SQLModel.metadata.create_all( - sync_conn, - tables=[OllamaPullStatus.__table__], - ) - ) - async def dispose_engines(): """Dispose all engine connection pools. @@ -97,4 +72,3 @@ async def dispose_engines(): database and prevent stale / leaked connections. """ await sql_engine.dispose() - await container_db_engine.dispose() diff --git a/servers/fastapi/api/v1/ppt/background_tasks.py b/servers/fastapi/api/v1/ppt/background_tasks.py index 68f7e45e..aa7f4901 100644 --- a/servers/fastapi/api/v1/ppt/background_tasks.py +++ b/servers/fastapi/api/v1/ppt/background_tasks.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from models.ollama_model_status import OllamaModelStatus from models.sql.ollama_pull_status import OllamaPullStatus -from services.database import get_container_db_async_session +from services.database import async_session_maker from utils.ollama import pull_ollama_model @@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str): ) log_event_count = 0 - session = await get_container_db_async_session().__anext__() + async with async_session_maker() as session: + try: + async for event in pull_ollama_model(model): + if "error" in event: + saved_model_status.status = "error" + saved_model_status.done = True + saved_model_status.error = event["error"] + await upsert_ollama_pull_status(session, model, saved_model_status) + return - try: - async for event in pull_ollama_model(model): - if "error" in event: - saved_model_status.status = "error" - saved_model_status.done = True - saved_model_status.error = event["error"] - await upsert_ollama_pull_status(session, model, saved_model_status) - await session.close() - return + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue - log_event_count += 1 - if log_event_count != 1 and log_event_count % 20 != 0: - continue + if "completed" in event: + saved_model_status.downloaded = event["completed"] - if "completed" in event: - saved_model_status.downloaded = event["completed"] + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] - if not saved_model_status.size and "total" in event: - saved_model_status.size = event["total"] + if "status" in event: + saved_model_status.status = event["status"] - if "status" in event: - saved_model_status.status = event["status"] + await upsert_ollama_pull_status(session, model, saved_model_status) - await upsert_ollama_pull_status(session, model, saved_model_status) + except Exception as e: + saved_model_status.status = "error" + saved_model_status.done = True + saved_model_status.error = str(e) + await upsert_ollama_pull_status(session, model, saved_model_status) + raise HTTPException( + status_code=500, + detail=f"Failed to pull model: {e}", + ) - except Exception as e: - saved_model_status.status = "error" saved_model_status.done = True - saved_model_status.error = str(e) + saved_model_status.status = "pulled" + saved_model_status.downloaded = saved_model_status.size + saved_model_status.error = None + await upsert_ollama_pull_status(session, model, saved_model_status) - await session.close() - raise HTTPException( - status_code=500, - detail=f"Failed to pull model: {e}", - ) - - saved_model_status.done = True - saved_model_status.status = "pulled" - saved_model_status.downloaded = saved_model_status.size - saved_model_status.error = None - - await upsert_ollama_pull_status(session, model, saved_model_status) - await session.close() async def upsert_ollama_pull_status( diff --git a/servers/fastapi/api/v1/ppt/endpoints/chat.py b/servers/fastapi/api/v1/ppt/endpoints/chat.py index 49d25de1..88a27026 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/chat.py +++ b/servers/fastapi/api/v1/ppt/endpoints/chat.py @@ -1,10 +1,17 @@ import json +import uuid -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession -from models.chat import ChatMessageRequest, ChatMessageResponse +from models.chat import ( + ChatConversationListItem, + ChatHistoryMessageItem, + ChatHistoryResponse, + ChatMessageRequest, + ChatMessageResponse, +) from models.sse_response import ( SSECompleteResponse, SSEErrorResponse, @@ -13,11 +20,57 @@ from models.sse_response import ( SSEResponse, ) from services.chat import ChatTurnResult, PresentationChatService +from services.chat import sql_chat_history from services.database import get_async_session CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"]) +@CHAT_ROUTER.get("/conversations", response_model=list[ChatConversationListItem]) +async def list_chat_conversations( + presentation_id: uuid.UUID = Query(..., description="Presentation id"), + sql_session: AsyncSession = Depends(get_async_session), +): + raw = await sql_chat_history.list_conversations( + sql_session, presentation_id=presentation_id + ) + return [ + ChatConversationListItem( + conversation_id=uuid.UUID(str(item["conversation_id"])), + updated_at=item.get("updated_at"), + last_message_preview=item.get("last_message_preview"), + ) + for item in raw + ] + + +@CHAT_ROUTER.get("/history", response_model=ChatHistoryResponse) +async def get_chat_history( + presentation_id: uuid.UUID = Query(..., description="Presentation id"), + conversation_id: uuid.UUID = Query(..., description="Conversation thread id"), + sql_session: AsyncSession = Depends(get_async_session), +): + rows = await sql_chat_history.load_messages_with_meta( + sql_session, + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + return ChatHistoryResponse( + presentation_id=presentation_id, + conversation_id=conversation_id, + messages=[ + ChatHistoryMessageItem( + role=str(m.get("role") or ""), + content=str(m.get("content") or ""), + created_at=m.get("created_at") + if isinstance(m.get("created_at"), str) + else None, + ) + for m in rows + ], + ) + + @CHAT_ROUTER.post("/message", response_model=ChatMessageResponse) async def chat_message( payload: ChatMessageRequest, diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index 0dafa3e1..43505831 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS from models.ollama_model_metadata import OllamaModelMetadata from models.ollama_model_status import OllamaModelStatus from models.sql.ollama_pull_status import OllamaPullStatus -from services.database import get_container_db_async_session +from services.database import get_async_session from utils.ollama import list_pulled_ollama_models OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"]) @@ -29,7 +29,7 @@ async def get_available_models(): async def pull_model( model: str, background_tasks: BackgroundTasks, - session: AsyncSession = Depends(get_container_db_async_session), + session: AsyncSession = Depends(get_async_session), ): if model not in SUPPORTED_OLLAMA_MODELS: diff --git a/servers/fastapi/models/chat.py b/servers/fastapi/models/chat.py index 98906118..de4139e9 100644 --- a/servers/fastapi/models/chat.py +++ b/servers/fastapi/models/chat.py @@ -18,3 +18,27 @@ class ChatMessageResponse(BaseModel): tool_calls: list[str] = Field(default_factory=list) model_config = ConfigDict(extra="forbid") + + +class ChatHistoryMessageItem(BaseModel): + role: str + content: str + created_at: Optional[str] = None + + model_config = ConfigDict(extra="forbid") + + +class ChatHistoryResponse(BaseModel): + presentation_id: uuid.UUID + conversation_id: uuid.UUID + messages: list[ChatHistoryMessageItem] + + model_config = ConfigDict(extra="forbid") + + +class ChatConversationListItem(BaseModel): + conversation_id: uuid.UUID + updated_at: Optional[str] = None + last_message_preview: Optional[str] = None + + model_config = ConfigDict(extra="forbid") diff --git a/servers/fastapi/services/chat/__init__.py b/servers/fastapi/services/chat/__init__.py index 2dcbe377..b11ffd02 100644 --- a/servers/fastapi/services/chat/__init__.py +++ b/servers/fastapi/services/chat/__init__.py @@ -1,6 +1,8 @@ from services.chat.service import ChatTurnResult, PresentationChatService +from services.chat.presentation_context_store import PresentationContextStore __all__ = [ "ChatTurnResult", "PresentationChatService", + "PresentationContextStore", ] diff --git a/servers/fastapi/services/chat/chat_memory_store.py b/servers/fastapi/services/chat/chat_memory_store.py new file mode 100644 index 00000000..7fef329d --- /dev/null +++ b/servers/fastapi/services/chat/chat_memory_store.py @@ -0,0 +1,324 @@ +import asyncio +from datetime import datetime, timezone +import logging +import os +from typing import Any, Optional +from uuid import UUID + +from services.mem0_oss_memory import get_shared_mem0_client + +LOGGER = logging.getLogger(__name__) +CHAT_TURN_TAG = "[chat_turn]" +DEFAULT_MAX_STORED_TURNS = 20 + + +class ChatMemoryStore: + def __init__(self): + self._enabled = self._to_bool(os.getenv("MEM0_ENABLED"), default=True) + self._top_k = self._to_int(os.getenv("MEM0_TOP_K"), default=8) + self._max_context_chars = self._to_int( + os.getenv("MEM0_MAX_CONTEXT_CHARS"), default=6000 + ) + self._max_stored_turns = self._to_int( + os.getenv("CHAT_MAX_STORED_TURNS"), default=DEFAULT_MAX_STORED_TURNS + ) + self._namespace_prefix = ( + os.getenv("MEM0_CHAT_NAMESPACE_PREFIX") + or os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") + or "presentation" + ).strip() or "presentation" + + @staticmethod + def _to_bool(value: Optional[str], default: bool = False) -> bool: + if value is None: + return default + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + @staticmethod + def _to_int(value: Optional[str], default: int) -> int: + try: + parsed = int(value) if value is not None else default + return max(1, parsed) + except Exception: + return default + + @staticmethod + def _normalize(value: str) -> str: + return " ".join((value or "").split()) + + @staticmethod + def _is_nonfatal_mem0_error(exc: BaseException) -> bool: + return isinstance(exc, (Exception, SystemExit)) + + def _scope_user_id(self, presentation_id: UUID, conversation_id: UUID) -> str: + return ( + f"{self._namespace_prefix}:{presentation_id}:" + f"conversation:{conversation_id}" + ) + + def _truncate(self, text: str, limit: int = 20000) -> str: + if len(text) <= limit: + return text + return f"{text[:limit]}\n\n[TRUNCATED]" + + async def _get_client(self): + if not self._enabled: + return None + return get_shared_mem0_client() + + def _build_turn_payload(self, *, user_text: str, assistant_text: str) -> str: + memory_lines = [ + CHAT_TURN_TAG, + f"turn_created_at={datetime.now(timezone.utc).isoformat()}", + ] + if user_text: + memory_lines.append(f"user={user_text}") + if assistant_text: + memory_lines.append(f"assistant={assistant_text}") + return "\n".join(memory_lines) + + @staticmethod + def _extract_text_field(item: dict[str, Any]) -> str: + memory_text = item.get("memory") or item.get("text") or item.get("data") + return str(memory_text).strip() if memory_text is not None else "" + + def _collect_results(self, response: Any) -> list[dict[str, Any]]: + if isinstance(response, dict): + raw_results = ( + response.get("results") + or response.get("memories") + or response.get("items") + or [] + ) + if isinstance(raw_results, list): + return [item for item in raw_results if isinstance(item, dict)] + return [] + if isinstance(response, list): + return [item for item in response if isinstance(item, dict)] + return [] + + @staticmethod + def _safe_parse_datetime(raw_value: Any) -> datetime | None: + if not isinstance(raw_value, str) or not raw_value.strip(): + return None + value = raw_value.strip().replace("Z", "+00:00") + try: + parsed = datetime.fromisoformat(value) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed + except Exception: + return None + + @staticmethod + def _extract_chat_turn_fields(text: str) -> tuple[str | None, str | None, datetime | None]: + if CHAT_TURN_TAG not in text: + return None, None, None + + user_text: str | None = None + assistant_text: str | None = None + turn_created_at: datetime | None = None + for line in text.splitlines(): + if line.startswith("user="): + user_text = line[len("user=") :].strip() + elif line.startswith("assistant="): + assistant_text = line[len("assistant=") :].strip() + elif line.startswith("turn_created_at="): + turn_created_at = ChatMemoryStore._safe_parse_datetime( + line[len("turn_created_at=") :].strip() + ) + return user_text, assistant_text, turn_created_at + + async def store_chat_turn( + self, + *, + presentation_id: UUID, + conversation_id: UUID, + user_message: str, + assistant_message: str, + ) -> None: + client = await self._get_client() + if client is None: + return + + user_text = self._normalize(user_message) + assistant_text = self._normalize(assistant_message) + if not user_text and not assistant_text: + return + + payload = [ + { + "role": "user", + "content": self._truncate( + self._build_turn_payload( + user_text=user_text, + assistant_text=assistant_text, + ) + ), + } + ] + scoped_user_id = self._scope_user_id(presentation_id, conversation_id) + + def _add(): + try: + return client.add(payload, user_id=scoped_user_id, infer=False) + except TypeError: + return client.add( + messages=payload, + user_id=scoped_user_id, + infer=False, + ) + + try: + await asyncio.to_thread(_add) + except BaseException as exc: + if not self._is_nonfatal_mem0_error(exc): + raise + LOGGER.exception( + ( + "Failed to add chat mem0 memory " + "(presentation_id=%s, conversation_id=%s)" + ), + presentation_id, + conversation_id, + ) + + async def retrieve_context( + self, + *, + presentation_id: UUID, + conversation_id: UUID, + query: str, + ) -> str: + client = await self._get_client() + if client is None: + return "" + + trimmed_query = (query or "").strip() + if not trimmed_query: + return "" + + scoped_user_id = self._scope_user_id(presentation_id, conversation_id) + + def _search(): + try: + return client.search( + trimmed_query, + filters={"user_id": scoped_user_id}, + top_k=self._top_k, + ) + except TypeError: + return client.search( + trimmed_query, + user_id=scoped_user_id, + top_k=self._top_k, + ) + + try: + response = await asyncio.to_thread(_search) + except BaseException as exc: + if not self._is_nonfatal_mem0_error(exc): + raise + LOGGER.exception( + ( + "Failed to search chat mem0 memory " + "(presentation_id=%s, conversation_id=%s)" + ), + presentation_id, + conversation_id, + ) + return "" + + results = self._collect_results(response) + memories: list[str] = [] + for item in results: + normalized = self._extract_text_field(item) + if normalized: + memories.append(normalized) + + if not memories: + return "" + + deduped = list(dict.fromkeys(memories)) + return self._truncate("\n\n".join(deduped), self._max_context_chars) + + async def load_history( + self, + *, + presentation_id: UUID, + conversation_id: UUID, + ) -> list[dict[str, str]]: + client = await self._get_client() + if client is None: + return [] + + scoped_user_id = self._scope_user_id(presentation_id, conversation_id) + + def _get_all(): + try: + return client.get_all( + filters={"user_id": scoped_user_id}, + limit=max(10, self._max_stored_turns * 4), + ) + except TypeError: + try: + return client.get_all( + user_id=scoped_user_id, + limit=max(10, self._max_stored_turns * 4), + ) + except TypeError: + try: + return client.get_all(filters={"user_id": scoped_user_id}) + except TypeError: + return client.get_all(user_id=scoped_user_id) + + try: + response = await asyncio.to_thread(_get_all) + except BaseException as exc: + if not self._is_nonfatal_mem0_error(exc): + raise + LOGGER.exception( + ( + "Failed to load chat mem0 history " + "(presentation_id=%s, conversation_id=%s)" + ), + presentation_id, + conversation_id, + ) + return [] + + results = self._collect_results(response) + ordered_turns: list[tuple[datetime, str, str]] = [] + for index, item in enumerate(results): + text_value = self._extract_text_field(item) + if not text_value: + continue + user_text, assistant_text, embedded_timestamp = self._extract_chat_turn_fields( + text_value + ) + if not user_text and not assistant_text: + continue + + item_created_at = ( + self._safe_parse_datetime(item.get("created_at")) + or self._safe_parse_datetime(item.get("updated_at")) + or self._safe_parse_datetime(item.get("event_at")) + ) + timestamp = embedded_timestamp or item_created_at or datetime.fromtimestamp( + index, tz=timezone.utc + ) + ordered_turns.append((timestamp, user_text or "", assistant_text or "")) + + ordered_turns.sort(key=lambda turn: turn[0]) + recent_turns = ordered_turns[-self._max_stored_turns :] + + history: list[dict[str, str]] = [] + for _, user_text, assistant_text in recent_turns: + if user_text: + history.append({"role": "user", "content": user_text}) + if assistant_text: + history.append({"role": "assistant", "content": assistant_text}) + return history + + +CHAT_MEMORY_STORE = ChatMemoryStore() diff --git a/servers/fastapi/services/chat/conversation_store.py b/servers/fastapi/services/chat/conversation_store.py index 8f4f1878..42b1c881 100644 --- a/servers/fastapi/services/chat/conversation_store.py +++ b/servers/fastapi/services/chat/conversation_store.py @@ -1,44 +1,14 @@ -from datetime import datetime, timezone import uuid -from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import select -from models.sql.key_value import KeyValueSqlModel - -CHAT_CONVERSATION_KEY_PREFIX = "chat_conversation" -MAX_STORED_TURNS = 20 - - -class ConversationMessage(BaseModel): - role: Literal["user", "assistant"] - content: str = Field(min_length=1) - - model_config = ConfigDict(extra="forbid", strict=True) - - @field_validator("content") - @classmethod - def normalize_content(cls, value: str) -> str: - normalized = value.strip() - if not normalized: - raise ValueError("Conversation message cannot be empty.") - return normalized - - -class ConversationPayload(BaseModel): - conversation_id: uuid.UUID - presentation_id: uuid.UUID - messages: list[ConversationMessage] = Field(default_factory=list) - updated_at: datetime - - model_config = ConfigDict(extra="forbid", strict=True) +from services.chat.chat_memory_store import CHAT_MEMORY_STORE +from services.chat import sql_chat_history class ChatConversationStore: def __init__(self, sql_session: AsyncSession): - self._sql_session = sql_session + self._sql = sql_session async def load_history( self, @@ -46,13 +16,25 @@ class ChatConversationStore: presentation_id: uuid.UUID, conversation_id: uuid.UUID, ) -> list[dict[str, str]]: - payload = await self._get_payload(conversation_id) - if not payload or payload.presentation_id != presentation_id: - return [] - return [ - {"role": message.role, "content": message.content} - for message in payload.messages - ] + messages = await sql_chat_history.load_messages( + self._sql, + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + if messages: + return messages + legacy = await CHAT_MEMORY_STORE.load_history( + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + if legacy: + await sql_chat_history.replace_messages( + self._sql, + presentation_id=presentation_id, + conversation_id=conversation_id, + messages=legacy, + ) + return legacy async def append_turn( self, @@ -62,87 +44,35 @@ class ChatConversationStore: user_message: str, assistant_message: str, ) -> None: - payload = await self._get_payload(conversation_id) - messages = list(payload.messages) if payload else [] - - messages.append(ConversationMessage(role="user", content=user_message)) - messages.append(ConversationMessage(role="assistant", content=assistant_message)) - max_messages = MAX_STORED_TURNS * 2 - messages = messages[-max_messages:] - - next_payload = ConversationPayload( - conversation_id=conversation_id, + await sql_chat_history.append_turn( + self._sql, presentation_id=presentation_id, - messages=messages, - updated_at=datetime.now(timezone.utc), + conversation_id=conversation_id, + user_message=user_message, + assistant_message=assistant_message, + ) + await CHAT_MEMORY_STORE.store_chat_turn( + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message=user_message, + assistant_message=assistant_message, ) - row = await self._get_row(conversation_id) - if row: - row.value = next_payload.model_dump(mode="json") - self._sql_session.add(row) - else: - self._sql_session.add( - KeyValueSqlModel( - key=self._conversation_key(conversation_id), - value=next_payload.model_dump(mode="json"), - ) - ) - - await self._sql_session.commit() + async def retrieve_semantic_context( + self, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, + query: str, + ) -> str: + return await CHAT_MEMORY_STORE.retrieve_context( + presentation_id=presentation_id, + conversation_id=conversation_id, + query=query, + ) async def ensure_conversation_id( self, conversation_id: uuid.UUID | None, ) -> uuid.UUID: return conversation_id or uuid.uuid4() - - async def _get_payload( - self, conversation_id: uuid.UUID - ) -> ConversationPayload | None: - row = await self._get_row(conversation_id) - if not row or not isinstance(row.value, dict): - return None - - raw_payload: dict[str, Any] = row.value - try: - return ConversationPayload.model_validate(raw_payload) - except ValidationError: - return self._coerce_payload(raw_payload) - - def _coerce_payload(self, payload: dict[str, Any]) -> ConversationPayload | None: - try: - conversation_id = uuid.UUID(str(payload.get("conversation_id"))) - presentation_id = uuid.UUID(str(payload.get("presentation_id"))) - except Exception: - return None - - raw_messages = payload.get("messages") - messages: list[ConversationMessage] = [] - if isinstance(raw_messages, list): - for entry in raw_messages: - if not isinstance(entry, dict): - continue - try: - message = ConversationMessage.model_validate(entry) - messages.append(message) - except ValidationError: - continue - - return ConversationPayload( - conversation_id=conversation_id, - presentation_id=presentation_id, - messages=messages[-(MAX_STORED_TURNS * 2) :], - updated_at=datetime.now(timezone.utc), - ) - - async def _get_row(self, conversation_id: uuid.UUID) -> KeyValueSqlModel | None: - return await self._sql_session.scalar( - select(KeyValueSqlModel).where( - KeyValueSqlModel.key == self._conversation_key(conversation_id) - ) - ) - - @staticmethod - def _conversation_key(conversation_id: uuid.UUID) -> str: - return f"{CHAT_CONVERSATION_KEY_PREFIX}:{conversation_id}" diff --git a/servers/fastapi/services/chat/presentation_context_store.py b/servers/fastapi/services/chat/presentation_context_store.py new file mode 100644 index 00000000..7100be21 --- /dev/null +++ b/servers/fastapi/services/chat/presentation_context_store.py @@ -0,0 +1,5 @@ +from services.chat.memory_layer import ( + PresentationChatMemoryLayer as PresentationContextStore, +) + +__all__ = ["PresentationContextStore"] diff --git a/servers/fastapi/services/chat/prompts.py b/servers/fastapi/services/chat/prompts.py index 6aacdbf6..0f9de2f9 100644 --- a/servers/fastapi/services/chat/prompts.py +++ b/servers/fastapi/services/chat/prompts.py @@ -1,13 +1,31 @@ -def build_system_prompt(memory_context: str) -> str: - context_block = ( - "\nMemory context (use only when relevant):\n" - f"{memory_context}\n" - if memory_context - else "" +def _trim_block(label: str, text: str) -> str: + t = (text or "").strip() + if not t: + return "" + return f"\n{label}\n(use only when relevant; may be partial)\n{t}\n" + + +def build_system_prompt( + presentation_memory_context: str, + chat_memory_context: str, +) -> str: + """ + presentation_memory_context: deck-scoped (documents, outlines, prior slide edits, etc.) + chat_memory_context: this thread (prior asks, assistant replies) — semantic slice. + """ + presentation_block = _trim_block( + "Presentation memory (this deck: source text, outlines, and stored slide/edit notes):", + presentation_memory_context, + ) + chat_block = _trim_block( + "This conversation thread (what was asked and answered in chat):", + chat_memory_context, ) return ( "You are Presenton backend chat assistant.\n" - "You can call tools to access presentation memory.\n" + "You can call tools to access live slide data and layouts.\n" + "Distinguish: presentation memory = facts about the deck; chat memory = this thread’s prior Q&A. " + "Tools still win for current slide content.\n" "- Use getPresentationOutline for outline/section questions.\n" "- Prefer compact tool outputs to save context window; do not request full slide JSON unless needed.\n" "- Use searchSlides for finding relevant slide content snippets (DB-backed).\n" @@ -21,5 +39,6 @@ def build_system_prompt(memory_context: str) -> str: "- After tool outputs are sufficient, stop calling tools and provide a final answer.\n" "- If memory is missing, state that clearly and suggest next steps.\n" "- Do not invent slide facts that are not in tool results or memory.\n" - f"{context_block}" + f"{presentation_block}" + f"{chat_block}" ) diff --git a/servers/fastapi/services/chat/service.py b/servers/fastapi/services/chat/service.py index c754691c..71413c7d 100644 --- a/servers/fastapi/services/chat/service.py +++ b/servers/fastapi/services/chat/service.py @@ -19,7 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from models.sql.presentation import PresentationModel from services.chat.conversation_store import ChatConversationStore -from services.chat.memory_layer import PresentationChatMemoryLayer +from services.chat.presentation_context_store import PresentationContextStore from services.chat.prompts import build_system_prompt from services.chat.tools import ChatTools from utils.llm_client_error_handler import handle_llm_client_exceptions @@ -58,7 +58,7 @@ class PresentationChatService: self._conversation_id = conversation_id self._conversation_store = ChatConversationStore(sql_session) - self._memory = PresentationChatMemoryLayer(sql_session, presentation_id) + self._memory = PresentationContextStore(sql_session, presentation_id) self._tools = ChatTools(self._memory) async def generate_reply(self, user_message: str) -> ChatTurnResult: @@ -209,6 +209,8 @@ class PresentationChatService: if not presentation: raise HTTPException(status_code=404, detail="Presentation not found") + # A stable conversation_id is created here before the first user message + # so mem0 and SQL can scope the thread; the client need not "chat" first. conversation_id = await self._conversation_store.ensure_conversation_id( self._conversation_id ) @@ -218,9 +220,19 @@ class PresentationChatService: ) history_messages = self._convert_history_to_messages(history) - memory_context = await self._memory.retrieve_context(user_message) + presentation_memory = await self._memory.retrieve_context(user_message) + chat_memory = await self._conversation_store.retrieve_semantic_context( + presentation_id=self._presentation_id, + conversation_id=conversation_id, + query=user_message, + ) messages: list[Message] = [ - SystemMessage(content=build_system_prompt(memory_context)), + SystemMessage( + content=build_system_prompt( + presentation_memory_context=presentation_memory, + chat_memory_context=chat_memory, + ) + ), *history_messages, UserMessage(content=user_message), ] diff --git a/servers/fastapi/services/chat/sql_chat_history.py b/servers/fastapi/services/chat/sql_chat_history.py new file mode 100644 index 00000000..4c8f98ac --- /dev/null +++ b/servers/fastapi/services/chat/sql_chat_history.py @@ -0,0 +1,231 @@ +"""Persist presentation chat threads in ``KeyValueSqlModel``. + +Each conversation is one row: key ``ppt_chat:{presentation_id}:{conversation_id}``, +value is JSON: ``{version, messages, updated_at, ...}``. +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from models.sql.key_value import KeyValueSqlModel + +CHAT_HISTORY_KEY_PREFIX = "ppt_chat" +SCHEMA_VERSION = 1 + + +def chat_history_key(presentation_id: uuid.UUID, conversation_id: uuid.UUID) -> str: + return f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:{conversation_id}" + + +def _parse_conversation_key(key: str, presentation_id: uuid.UUID) -> uuid.UUID | None: + expected_prefix = f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:" + if not key.startswith(expected_prefix): + return None + rest = key[len(expected_prefix) :] + try: + return uuid.UUID(rest) + except ValueError: + return None + + +def _utc_now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +async def load_messages( + session: AsyncSession, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, +) -> list[dict[str, str]]: + """Load ordered user/assistant messages for LLM context (role + content only).""" + key = chat_history_key(presentation_id, conversation_id) + row = await session.scalar( + select(KeyValueSqlModel).where(KeyValueSqlModel.key == key) + ) + if not row or not isinstance(row.value, dict): + return [] + messages = row.value.get("messages") + if not isinstance(messages, list): + return [] + out: list[dict[str, str]] = [] + for item in messages: + if not isinstance(item, dict): + continue + if item.get("role") not in ("user", "assistant"): + continue + content = item.get("content") + if not isinstance(content, str) or not content.strip(): + continue + out.append({"role": item["role"], "content": content}) + return out + + +async def load_messages_with_meta( + session: AsyncSession, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, +) -> list[dict[str, Any]]: + """Load messages including optional ``created_at`` for API / UI.""" + key = chat_history_key(presentation_id, conversation_id) + row = await session.scalar( + select(KeyValueSqlModel).where(KeyValueSqlModel.key == key) + ) + if not row or not isinstance(row.value, dict): + return [] + messages = row.value.get("messages") + if not isinstance(messages, list): + return [] + out: list[dict[str, Any]] = [] + for item in messages: + if not isinstance(item, dict): + continue + if item.get("role") not in ("user", "assistant"): + continue + content = item.get("content") + if not isinstance(content, str) or not content.strip(): + continue + entry: dict[str, Any] = { + "role": item["role"], + "content": content, + } + created = item.get("created_at") + if isinstance(created, str) and created.strip(): + entry["created_at"] = created.strip() + out.append(entry) + return out + + +async def replace_messages( + session: AsyncSession, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, + messages: list[dict[str, str]], +) -> None: + """Replace transcript (e.g. one-time sync from mem0).""" + key = chat_history_key(presentation_id, conversation_id) + row = await session.scalar( + select(KeyValueSqlModel).where(KeyValueSqlModel.key == key) + ) + now = _utc_now_iso() + built: list[dict[str, Any]] = [] + for m in messages: + if m.get("role") not in ("user", "assistant"): + continue + content = m.get("content") + if not isinstance(content, str): + continue + built.append( + { + "role": m["role"], + "content": content, + "created_at": now, + } + ) + value = { + "version": SCHEMA_VERSION, + "presentation_id": str(presentation_id), + "conversation_id": str(conversation_id), + "messages": built, + "updated_at": now, + } + if row is None: + session.add(KeyValueSqlModel(key=key, value=value)) + else: + row.value = value + await session.flush() + + +async def append_turn( + session: AsyncSession, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, + user_message: str, + assistant_message: str, +) -> None: + key = chat_history_key(presentation_id, conversation_id) + row = await session.scalar( + select(KeyValueSqlModel).where(KeyValueSqlModel.key == key) + ) + now = _utc_now_iso() + new_messages: list[dict[str, Any]] = [ + {"role": "user", "content": user_message, "created_at": now}, + {"role": "assistant", "content": assistant_message, "created_at": now}, + ] + if row is None: + value: dict[str, Any] = { + "version": SCHEMA_VERSION, + "presentation_id": str(presentation_id), + "conversation_id": str(conversation_id), + "messages": new_messages, + "updated_at": now, + } + session.add(KeyValueSqlModel(key=key, value=value)) + else: + data = row.value if isinstance(row.value, dict) else {} + existing = data.get("messages") + if not isinstance(existing, list): + existing = [] + combined = [m for m in existing if isinstance(m, dict)] + combined.extend(new_messages) + data["version"] = SCHEMA_VERSION + data["presentation_id"] = str(presentation_id) + data["conversation_id"] = str(conversation_id) + data["messages"] = combined + data["updated_at"] = now + row.value = data + await session.flush() + + +async def list_conversations( + session: AsyncSession, *, presentation_id: uuid.UUID +) -> list[dict[str, Any]]: + """Return conversation summaries for a presentation, newest ``updated_at`` first.""" + prefix = f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:" + result = await session.scalars( + select(KeyValueSqlModel).where(KeyValueSqlModel.key.startswith(prefix)) + ) + rows = list(result.all()) + summaries: list[dict[str, Any]] = [] + for row in rows: + cid = _parse_conversation_key(row.key, presentation_id) + if cid is None: + continue + data = row.value if isinstance(row.value, dict) else {} + updated_at: str | None = None + raw_u = data.get("updated_at") + if isinstance(raw_u, str) and raw_u.strip(): + updated_at = raw_u.strip() + messages = data.get("messages") + preview: str | None = None + if isinstance(messages, list) and messages: + for item in reversed(messages): + if not isinstance(item, dict): + continue + c = item.get("content") + if isinstance(c, str) and c.strip(): + preview = c.strip() + if len(preview) > 200: + preview = f"{preview[:200]}…" + break + summaries.append( + { + "conversation_id": str(cid), + "updated_at": updated_at, + "last_message_preview": preview, + } + ) + summaries.sort( + key=lambda s: s.get("updated_at") or "", + reverse=True, + ) + return summaries diff --git a/servers/fastapi/services/chat/tools.py b/servers/fastapi/services/chat/tools.py index c0b06a6d..1265b553 100644 --- a/servers/fastapi/services/chat/tools.py +++ b/servers/fastapi/services/chat/tools.py @@ -15,7 +15,7 @@ from services.chat.schemas import ( SaveSlideInput, SearchSlidesInput, ) -from services.chat.memory_layer import PresentationChatMemoryLayer +from services.chat.presentation_context_store import PresentationContextStore LOGGER = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class ChatTools: provider-specific logic, keeping them portable across llmai backends. """ - def __init__(self, memory: PresentationChatMemoryLayer): + def __init__(self, memory: PresentationContextStore): self._memory = memory self._tool_handlers: dict[str, ToolHandler] = { "getPresentationOutline": self._get_presentation_outline, diff --git a/servers/fastapi/services/database.py b/servers/fastapi/services/database.py index 6bd6aaf1..adf3080c 100644 --- a/servers/fastapi/services/database.py +++ b/servers/fastapi/services/database.py @@ -1,5 +1,4 @@ from collections.abc import AsyncGenerator -import os from sqlalchemy.ext.asyncio import ( AsyncEngine, create_async_engine, @@ -20,7 +19,6 @@ from models.sql.template import TemplateModel from models.sql.template_create_info import TemplateCreateInfoModel from models.sql.slide import SlideModel from models.sql.webhook_subscription import WebhookSubscription -from utils.get_env import get_app_data_directory_env from utils.get_env import get_migrate_database_on_startup_env from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs @@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: yield session -# Container DB (Lives inside the app data directory) -_app_data_dir = get_app_data_directory_env() or "/tmp/presenton" -container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}" -container_db_engine: AsyncEngine = create_async_engine( - container_db_url, connect_args={"check_same_thread": False} -) -container_db_async_session_maker = async_sessionmaker( - container_db_engine, expire_on_commit=False -) - - -async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]: - async with container_db_async_session_maker() as session: - yield session - - # Create Database and Tables async def create_db_and_tables(): should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"] @@ -76,18 +58,11 @@ async def create_db_and_tables(): TemplateModel.__table__, WebhookSubscription.__table__, AsyncPresentationGenerationTaskModel.__table__, + OllamaPullStatus.__table__, ], ) ) - async with container_db_engine.begin() as conn: - await conn.run_sync( - lambda sync_conn: SQLModel.metadata.create_all( - sync_conn, - tables=[OllamaPullStatus.__table__], - ) - ) - async def dispose_engines(): """Dispose all engine connection pools. @@ -97,4 +72,3 @@ async def dispose_engines(): database and prevent stale / leaked connections. """ await sql_engine.dispose() - await container_db_engine.dispose() diff --git a/servers/fastapi/services/mem0_oss_memory.py b/servers/fastapi/services/mem0_oss_memory.py new file mode 100644 index 00000000..b51cc739 --- /dev/null +++ b/servers/fastapi/services/mem0_oss_memory.py @@ -0,0 +1,131 @@ +"""Single shared mem0 OSS ``Memory`` client for the process. + +All callers (presentation context, chat turns) use the same on-disk Qdrant/SQLite +and distinguish data via mem0 ``user_id``: + +- Deck-level (no chat thread): ``{namespace}:{presentation_id}`` +- Chat thread: ``{namespace}:{presentation_id}:conversation:{conversation_id}`` + +The chat flow calls ``ensure_conversation_id`` before the first turn, so a +``conversation_id`` exists before any mem0 write for that thread. +""" + +from __future__ import annotations + +import logging +import os +import threading +from importlib import import_module +from typing import Any, Optional + +LOGGER = logging.getLogger(__name__) + +_memory_init_lock = threading.Lock() +_shared_client: Any | None = None +_init_attempted = False + + +def _to_bool(value: Optional[str], default: bool = False) -> bool: + if value is None: + return default + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +def _to_int(value: Optional[str], default: int) -> int: + try: + parsed = int(value) if value is not None else default + return max(1, parsed) + except Exception: + return default + + +def _oss_config_from_env() -> tuple[str, str, str, str, int, dict[str, Any]]: + """Return (mem0_dir, qdrant_path, history_db, collection, dims, from_config_dict).""" + app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip() + mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip() + qdrant_path = ( + os.getenv("MEM0_QDRANT_PATH") or os.path.join(mem0_dir, "qdrant") + ).strip() + history_db_path = ( + os.getenv("MEM0_HISTORY_DB_PATH") or os.path.join(mem0_dir, "history.db") + ).strip() + collection = ( + os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories" + ).strip() or "presenton_memories" + embedder = (os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed").strip() or "fastembed" + model = ( + os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5" + ).strip() or "BAAI/bge-small-en-v1.5" + dims = _to_int(os.getenv("MEM0_EMBEDDING_DIMS"), default=384) + config: dict[str, Any] = { + "vector_store": { + "provider": "qdrant", + "config": { + "collection_name": collection, + "path": qdrant_path, + "on_disk": True, + "embedding_model_dims": dims, + }, + }, + "embedder": { + "provider": embedder, + "config": { + "model": model, + "embedding_dims": dims, + }, + }, + "history_db_path": history_db_path, + } + return mem0_dir, qdrant_path, history_db_path, collection, dims, config + + +def memory_from_config(config: dict[str, Any], *, telemetry_base: str) -> Any: + """Construct ``mem0.Memory``. Caller must hold ``_memory_init_lock`` if used with shared state.""" + os.makedirs(telemetry_base, exist_ok=True) + import mem0.memory.main as mem0_main # type: ignore[import-untyped] + + mem0_main.mem0_dir = telemetry_base + memory_cls = getattr(import_module("mem0"), "Memory") + return memory_cls.from_config(config) + + +def get_shared_mem0_client() -> Any | None: + """Return the process-wide mem0 client, or ``None`` if disabled or init failed.""" + global _shared_client, _init_attempted + + if not _to_bool(os.getenv("MEM0_ENABLED"), default=True): + return None + if _shared_client is not None: + return _shared_client + if _init_attempted: + return None + + with _memory_init_lock: + if _shared_client is not None: + return _shared_client + if _init_attempted: + return None + _init_attempted = True + try: + mem0_dir, qdrant_path, history_db, collection, dims, config = ( + _oss_config_from_env() + ) + os.makedirs(mem0_dir, exist_ok=True) + os.makedirs(qdrant_path, exist_ok=True) + telemetry_base = os.path.join(mem0_dir, "telemetry", "oss") + _shared_client = memory_from_config( + config, + telemetry_base=telemetry_base, + ) + LOGGER.info( + "Mem0 OSS shared memory initialized (qdrant_path=%s, history_db_path=%s, collection=%s, dims=%s)", + qdrant_path, + history_db, + collection, + dims, + ) + except BaseException: + LOGGER.exception("Failed to initialize shared Mem0 OSS Memory") + _shared_client = None + + return _shared_client diff --git a/servers/fastapi/services/mem0_presentation_memory_service.py b/servers/fastapi/services/mem0_presentation_memory_service.py index 6fafc224..39af9b36 100644 --- a/servers/fastapi/services/mem0_presentation_memory_service.py +++ b/servers/fastapi/services/mem0_presentation_memory_service.py @@ -2,10 +2,11 @@ import asyncio import json import logging import os -from importlib import import_module from typing import Any, Optional from uuid import UUID +from services.mem0_oss_memory import get_shared_mem0_client + LOGGER = logging.getLogger(__name__) @@ -21,31 +22,6 @@ class Mem0PresentationMemoryService: os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation" ).strip() or "presentation" - self._embedder_provider = ( - os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed" - ).strip() or "fastembed" - self._embedder_model = ( - os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5" - ).strip() or "BAAI/bge-small-en-v1.5" - self._embedding_dims = self._to_int( - os.getenv("MEM0_EMBEDDING_DIMS"), - default=384, - ) - - app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip() - self._mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip() - self._qdrant_path = (os.getenv("MEM0_QDRANT_PATH") or os.path.join(self._mem0_dir, "qdrant")).strip() - self._history_db_path = ( - os.getenv("MEM0_HISTORY_DB_PATH") - or os.path.join(self._mem0_dir, "history.db") - ).strip() - self._collection_name = ( - os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories" - ).strip() or "presenton_memories" - - self._client: Any = None - self._attempted_client_init = False - @staticmethod def _to_bool(value: Optional[str], default: bool = False) -> bool: if value is None: @@ -68,27 +44,6 @@ class Mem0PresentationMemoryService: return text return f"{text[:limit]}\n\n[TRUNCATED]" - def _get_oss_config(self) -> dict: - return { - "vector_store": { - "provider": "qdrant", - "config": { - "collection_name": self._collection_name, - "path": self._qdrant_path, - "on_disk": True, - "embedding_model_dims": self._embedding_dims, - }, - }, - "embedder": { - "provider": self._embedder_provider, - "config": { - "model": self._embedder_model, - "embedding_dims": self._embedding_dims, - }, - }, - "history_db_path": self._history_db_path, - } - @staticmethod def _is_nonfatal_mem0_error(exc: BaseException) -> bool: return isinstance(exc, (Exception, SystemExit)) @@ -96,42 +51,7 @@ class Mem0PresentationMemoryService: async def _get_client(self): if not self._enabled: return None - - if self._client is not None: - return self._client - - if self._attempted_client_init: - return None - - self._attempted_client_init = True - - try: - module = import_module("mem0") - memory_cls = getattr(module, "Memory") - - os.makedirs(self._mem0_dir, exist_ok=True) - os.makedirs(self._qdrant_path, exist_ok=True) - - config = self._get_oss_config() - - try: - self._client = memory_cls.from_config(config) - except Exception: - # Backward compatibility across mem0 OSS versions. - self._client = memory_cls(config) - - LOGGER.info( - "Mem0 OSS presentation memory service initialized (qdrant_path=%s, history_db_path=%s)", - self._qdrant_path, - self._history_db_path, - ) - except BaseException as exc: - if not self._is_nonfatal_mem0_error(exc): - raise - LOGGER.exception("Failed to initialize Mem0 OSS Memory") - self._client = None - - return self._client + return get_shared_mem0_client() async def _add_message(self, presentation_id: UUID, message: str): client = await self._get_client() diff --git a/servers/fastapi/tests/test_chat_conversation_store.py b/servers/fastapi/tests/test_chat_conversation_store.py new file mode 100644 index 00000000..d549d4f5 --- /dev/null +++ b/servers/fastapi/tests/test_chat_conversation_store.py @@ -0,0 +1,136 @@ +import asyncio +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +from services.chat.conversation_store import ChatConversationStore + + +class TestChatConversationStore: + def test_load_history_reads_sql_first(self): + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + expected_history = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ] + sql_session = MagicMock() + + with patch( + "services.chat.conversation_store.sql_chat_history.load_messages", + new=AsyncMock(return_value=expected_history), + ) as load_sql, patch( + "services.chat.conversation_store.CHAT_MEMORY_STORE.load_history", + new=AsyncMock(), + ) as load_mem0: + store = ChatConversationStore(sql_session) + history = asyncio.run( + store.load_history( + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + ) + + load_sql.assert_awaited_once_with( + sql_session, + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + load_mem0.assert_not_called() + assert history == expected_history + + def test_load_history_falls_back_to_mem0_and_backfills_sql(self): + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + legacy = [ + {"role": "user", "content": "old"}, + {"role": "assistant", "content": "from mem0"}, + ] + sql_session = MagicMock() + + with patch( + "services.chat.conversation_store.sql_chat_history.load_messages", + new=AsyncMock(return_value=[]), + ) as load_sql, patch( + "services.chat.conversation_store.CHAT_MEMORY_STORE.load_history", + new=AsyncMock(return_value=legacy), + ) as load_mem0, patch( + "services.chat.conversation_store.sql_chat_history.replace_messages", + new=AsyncMock(), + ) as replace_messages: + store = ChatConversationStore(sql_session) + history = asyncio.run( + store.load_history( + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + ) + + load_sql.assert_awaited_once() + load_mem0.assert_awaited_once() + replace_messages.assert_awaited_once_with( + sql_session, + presentation_id=presentation_id, + conversation_id=conversation_id, + messages=legacy, + ) + assert history == legacy + + def test_append_turn_persists_sql_and_mem0(self): + sql_session = MagicMock() + store = ChatConversationStore(sql_session) + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + + with patch( + "services.chat.conversation_store.sql_chat_history.append_turn", + new=AsyncMock(), + ) as append_sql, patch( + "services.chat.conversation_store.CHAT_MEMORY_STORE.store_chat_turn", + new=AsyncMock(), + ) as store_mem0: + asyncio.run( + store.append_turn( + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message="Can you improve slide 2?", + assistant_message="Yes, I will tighten the bullet points.", + ) + ) + + append_sql.assert_awaited_once_with( + sql_session, + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message="Can you improve slide 2?", + assistant_message="Yes, I will tighten the bullet points.", + ) + store_mem0.assert_awaited_once_with( + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message="Can you improve slide 2?", + assistant_message="Yes, I will tighten the bullet points.", + ) + + def test_retrieve_semantic_context_delegates_to_chat_memory_store(self): + store = ChatConversationStore(MagicMock()) + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + + with patch( + "services.chat.conversation_store.CHAT_MEMORY_STORE.retrieve_context", + new=AsyncMock(return_value="conversation-scoped context"), + ) as retrieve_context: + context = asyncio.run( + store.retrieve_semantic_context( + presentation_id=presentation_id, + conversation_id=conversation_id, + query="What did we decide?", + ) + ) + + retrieve_context.assert_awaited_once_with( + presentation_id=presentation_id, + conversation_id=conversation_id, + query="What did we decide?", + ) + assert context == "conversation-scoped context" diff --git a/servers/fastapi/tests/test_chat_memory_store.py b/servers/fastapi/tests/test_chat_memory_store.py new file mode 100644 index 00000000..2e7d7806 --- /dev/null +++ b/servers/fastapi/tests/test_chat_memory_store.py @@ -0,0 +1,249 @@ +import asyncio +import uuid +from unittest.mock import patch + +import services.mem0_oss_memory as mem0_oss +from services.chat.chat_memory_store import ChatMemoryStore + + +class FakeMemoryClient: + instances: list["FakeMemoryClient"] = [] + + def __init__(self, config=None): + self.config = config + self.add_calls = [] + self.search_calls = [] + self.get_all_calls = [] + self.next_search_response = {"results": []} + self.next_get_all_response = {"results": []} + FakeMemoryClient.instances.append(self) + + @classmethod + def from_config(cls, config): + return cls(config=config) + + def add(self, *args, **kwargs): + messages = kwargs.get("messages") if "messages" in kwargs else None + if messages is None and args: + messages = args[0] + + self.add_calls.append( + { + "messages": messages, + "user_id": kwargs.get("user_id"), + "infer": kwargs.get("infer"), + } + ) + return {"ok": True} + + def search(self, query, *args, **kwargs): + self.search_calls.append( + { + "query": query, + "filters": kwargs.get("filters"), + "user_id": kwargs.get("user_id"), + "top_k": kwargs.get("top_k"), + } + ) + return self.next_search_response + + def get_all(self, *args, **kwargs): + self.get_all_calls.append( + { + "filters": kwargs.get("filters"), + "user_id": kwargs.get("user_id"), + "limit": kwargs.get("limit"), + } + ) + return self.next_get_all_response + + +def _mem0_oss_fresh() -> None: + mem0_oss._shared_client = None # type: ignore[attr-defined] + mem0_oss._init_attempted = False # type: ignore[attr-defined] + + +class TestChatMemoryStore: + def setup_method(self): + FakeMemoryClient.instances = [] + _mem0_oss_fresh() + + def test_store_chat_turn_uses_conversation_scoped_user_id(self): + with patch.dict( + "os.environ", + { + "MEM0_ENABLED": "true", + "MEM0_TOP_K": "4", + "MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation", + "APP_DATA_DIRECTORY": "/tmp/presenton-test", + }, + clear=False, + ), patch( + "services.chat.chat_memory_store.get_shared_mem0_client", + return_value=FakeMemoryClient.from_config( + { + "vector_store": {"provider": "qdrant", "config": {}}, + "embedder": {"provider": "fastembed", "config": {}}, + } + ), + ): + store = ChatMemoryStore() + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + + asyncio.run( + store.store_chat_turn( + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message="Can you tighten slide 3?", + assistant_message="Yes, I can make it shorter.", + ) + ) + + assert len(FakeMemoryClient.instances) == 1 + client = FakeMemoryClient.instances[0] + assert len(client.add_calls) == 1 + expected_user_id = ( + f"presentation:{presentation_id}:conversation:{conversation_id}" + ) + assert client.add_calls[0]["user_id"] == expected_user_id + assert client.add_calls[0]["infer"] is False + payload = str(client.add_calls[0]["messages"][0]["content"]) + assert "[chat_turn]" in payload + assert "user=Can you tighten slide 3?" in payload + assert "assistant=Yes, I can make it shorter." in payload + + def test_retrieve_context_reads_only_conversation_scoped_user_id(self): + with patch.dict( + "os.environ", + { + "MEM0_ENABLED": "true", + "MEM0_TOP_K": "6", + "MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation", + "APP_DATA_DIRECTORY": "/tmp/presenton-test", + }, + clear=False, + ), patch( + "services.chat.chat_memory_store.get_shared_mem0_client", + return_value=FakeMemoryClient.from_config( + { + "vector_store": {"provider": "qdrant", "config": {}}, + "embedder": {"provider": "fastembed", "config": {}}, + } + ), + ): + store = ChatMemoryStore() + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + expected_user_id = ( + f"presentation:{presentation_id}:conversation:{conversation_id}" + ) + + asyncio.run( + store.store_chat_turn( + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message="First turn", + assistant_message="First answer", + ) + ) + + client = FakeMemoryClient.instances[0] + client.next_search_response = { + "results": [ + {"memory": "Chat memory A"}, + {"memory": "Chat memory A"}, + {"memory": "Chat memory B"}, + ] + } + + context = asyncio.run( + store.retrieve_context( + presentation_id=presentation_id, + conversation_id=conversation_id, + query="What did we decide?", + ) + ) + + assert "Chat memory A" in context + assert "Chat memory B" in context + assert context.count("Chat memory A") == 1 + + assert len(client.search_calls) == 1 + assert client.search_calls[0]["query"] == "What did we decide?" + assert client.search_calls[0]["filters"] == {"user_id": expected_user_id} + assert client.search_calls[0]["top_k"] == 6 + + def test_load_history_reads_conversation_scoped_turns(self): + with patch.dict( + "os.environ", + { + "MEM0_ENABLED": "true", + "MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation", + "APP_DATA_DIRECTORY": "/tmp/presenton-test", + }, + clear=False, + ), patch( + "services.chat.chat_memory_store.get_shared_mem0_client", + return_value=FakeMemoryClient.from_config( + { + "vector_store": {"provider": "qdrant", "config": {}}, + "embedder": {"provider": "fastembed", "config": {}}, + } + ), + ): + store = ChatMemoryStore() + presentation_id = uuid.uuid4() + conversation_id = uuid.uuid4() + expected_user_id = ( + f"presentation:{presentation_id}:conversation:{conversation_id}" + ) + + asyncio.run( + store.store_chat_turn( + presentation_id=presentation_id, + conversation_id=conversation_id, + user_message="Draft intro", + assistant_message="Updated intro done.", + ) + ) + + client = FakeMemoryClient.instances[0] + client.next_get_all_response = { + "results": [ + { + "memory": ( + "[chat_turn]\n" + "turn_created_at=2026-04-25T10:00:00+00:00\n" + "user=Draft intro\nassistant=Updated intro done." + ), + "created_at": "2026-04-25T10:00:01+00:00", + }, + { + "memory": ( + "[chat_turn]\n" + "turn_created_at=2026-04-25T10:01:00+00:00\n" + "user=Add roadmap\nassistant=Roadmap slide added." + ), + "created_at": "2026-04-25T10:01:01+00:00", + }, + ] + } + + history = asyncio.run( + store.load_history( + presentation_id=presentation_id, + conversation_id=conversation_id, + ) + ) + + assert history == [ + {"role": "user", "content": "Draft intro"}, + {"role": "assistant", "content": "Updated intro done."}, + {"role": "user", "content": "Add roadmap"}, + {"role": "assistant", "content": "Roadmap slide added."}, + ] + + assert len(client.get_all_calls) == 1 + assert client.get_all_calls[0]["filters"] == {"user_id": expected_user_id} + assert client.get_all_calls[0]["limit"] >= 10 diff --git a/servers/fastapi/tests/test_mem0_presentation_memory_service.py b/servers/fastapi/tests/test_mem0_presentation_memory_service.py index ee9bf325..6f6ccc7f 100644 --- a/servers/fastapi/tests/test_mem0_presentation_memory_service.py +++ b/servers/fastapi/tests/test_mem0_presentation_memory_service.py @@ -2,11 +2,12 @@ import asyncio import uuid from unittest.mock import patch +import services.mem0_oss_memory as mem0_oss from services.mem0_presentation_memory_service import Mem0PresentationMemoryService class FakeMemoryClient: - instances = [] + instances: list["FakeMemoryClient"] = [] def __init__(self, config=None): self.config = config @@ -45,13 +46,15 @@ class FakeMemoryClient: return self.next_search_response -class FakeMem0Module: - Memory = FakeMemoryClient +def _mem0_oss_fresh() -> None: + mem0_oss._shared_client = None # type: ignore[attr-defined] + mem0_oss._init_attempted = False # type: ignore[attr-defined] class TestMem0PresentationMemoryService: def setup_method(self): FakeMemoryClient.instances = [] + _mem0_oss_fresh() def test_store_generation_context_uses_presentation_scope(self): with patch.dict( @@ -62,8 +65,25 @@ class TestMem0PresentationMemoryService: }, clear=False, ), patch( - "services.mem0_presentation_memory_service.import_module", - return_value=FakeMem0Module, + "services.mem0_presentation_memory_service.get_shared_mem0_client", + return_value=FakeMemoryClient.from_config( + { + "vector_store": { + "provider": "qdrant", + "config": { + "on_disk": True, + "embedding_model_dims": 384, + }, + }, + "embedder": { + "provider": "fastembed", + "config": { + "model": "BAAI/bge-small-en-v1.5", + "embedding_dims": 384, + }, + }, + } + ), ): service = Mem0PresentationMemoryService() presentation_id = uuid.uuid4() @@ -115,8 +135,19 @@ class TestMem0PresentationMemoryService: }, clear=False, ), patch( - "services.mem0_presentation_memory_service.import_module", - return_value=FakeMem0Module, + "services.mem0_presentation_memory_service.get_shared_mem0_client", + return_value=FakeMemoryClient.from_config( + { + "vector_store": {"provider": "qdrant", "config": {}}, + "embedder": { + "provider": "fastembed", + "config": { + "model": "BAAI/bge-small-en-v1.5", + "embedding_dims": 384, + }, + }, + } + ), ): service = Mem0PresentationMemoryService() presentation_id = uuid.uuid4() @@ -154,3 +185,4 @@ class TestMem0PresentationMemoryService: "user_id": f"presentation:{presentation_id}" } assert client.search_calls[0]["top_k"] == 5 + diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock index fdbf9f8a..83607c2b 100644 --- a/servers/fastapi/uv.lock +++ b/servers/fastapi/uv.lock @@ -1185,7 +1185,7 @@ wheels = [ [[package]] name = "llmai" -version = "0.1.9" +version = "0.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anthropic" }, @@ -1193,9 +1193,9 @@ dependencies = [ { name = "google-genai" }, { name = "openai" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/34/1d4188bd842f336726004896be9ae04b693b9c21d349918631433b9f1b63/llmai-0.2.1.tar.gz", hash = "sha256:f911bd7df3eb06d1c56612ce293f926df7b3bf6c36283a353cf780c697d39d31", size = 47862, upload-time = "2026-04-24T16:28:52.417Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" }, + { url = "https://files.pythonhosted.org/packages/45/7d/63b3da1be92b721b0681db5b9e96c1cbc000b63cb70ede40b20cc5302699/llmai-0.2.1-py3-none-any.whl", hash = "sha256:4c51d1186cce1e621f74a9ec70376dc1bd2e996eee484db17dce6a6e7b79a0a7", size = 59880, upload-time = "2026-04-24T16:28:50.706Z" }, ] [[package]] @@ -1686,7 +1686,7 @@ requires-dist = [ { name = "fastmcp", specifier = ">=2.11.0" }, { name = "google-genai", specifier = ">=1.28.0" }, { name = "jsonschema", specifier = ">=4.26.0" }, - { name = "llmai", specifier = "==0.1.9" }, + { name = "llmai", specifier = "==0.2.1" }, { name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" }, { name = "nltk", specifier = ">=3.9.1" }, { name = "openai", specifier = ">=1.98.0" }, diff --git a/servers/nextjs/app/(presentation-generator)/presentation/components/Chat.tsx b/servers/nextjs/app/(presentation-generator)/presentation/components/Chat.tsx index 6e5699e1..7ca91a74 100644 --- a/servers/nextjs/app/(presentation-generator)/presentation/components/Chat.tsx +++ b/servers/nextjs/app/(presentation-generator)/presentation/components/Chat.tsx @@ -252,6 +252,9 @@ const createMessageId = () => { return `${Date.now()}-${Math.random().toString(16).slice(2)}`; }; +const conversationStorageKey = (presentationId: string) => + `presenton:chat:conversationId:${presentationId}`; + const AssistantMarker = () => (
@@ -334,6 +337,7 @@ const Chat = ({ const [input, setInput] = useState(""); const [messages, setMessages] = useState([]); const [conversationId, setConversationId] = useState(null); + const [isHistoryLoading, setIsHistoryLoading] = useState(false); const [isSending, setIsSending] = useState(false); const [errorMessage, setErrorMessage] = useState(null); const [expandedActivityByMessage, setExpandedActivityByMessage] = useState< @@ -344,11 +348,70 @@ const Chat = ({ const messagesEndRef = useRef(null); useEffect(() => { + let cancelled = false; setMessages([]); setInput(""); setConversationId(null); setErrorMessage(null); setExpandedActivityByMessage({}); + + if (!presentationId) { + return; + } + + setIsHistoryLoading(true); + const run = async () => { + try { + if (typeof sessionStorage === "undefined") { + return; + } + const sKey = conversationStorageKey(presentationId); + let activeId = sessionStorage.getItem(sKey) ?? null; + if (!activeId) { + const list = await PresentationChatApi.listConversations( + presentationId + ); + if (list.length > 0) { + activeId = list[0]!.conversation_id; + sessionStorage.setItem(sKey, activeId); + } + } + if (!activeId) { + return; + } + const data = await PresentationChatApi.getHistory( + presentationId, + activeId + ); + if (cancelled) { + return; + } + setConversationId(activeId); + setMessages( + data.messages.map((m) => ({ + id: createMessageId(), + role: + m.role === "assistant" + ? "assistant" + : m.role === "user" + ? "user" + : "user", + content: m.content, + })) + ); + } catch (error) { + console.error("Failed to load chat history:", error); + toast.error("Could not load previous chat"); + } finally { + if (!cancelled) { + setIsHistoryLoading(false); + } + } + }; + void run(); + return () => { + cancelled = true; + }; }, [presentationId]); useEffect(() => { @@ -372,6 +435,9 @@ const Chat = ({ setConversationId(null); setErrorMessage(null); setExpandedActivityByMessage({}); + if (presentationId && typeof sessionStorage !== "undefined") { + sessionStorage.removeItem(conversationStorageKey(presentationId)); + } inputRef.current?.focus(); }; @@ -493,7 +559,7 @@ const Chat = ({ const submitMessage = async (rawMessage: string) => { const trimmedMessage = rawMessage.trim(); - if (!trimmedMessage || isSending) { + if (!trimmedMessage || isSending || isHistoryLoading) { return; } @@ -578,11 +644,23 @@ const Chat = ({ ) ); settleAssistantActivities(assistantMessageId, "success"); - setConversationId((previous) => - typeof response.conversation_id === "string" - ? response.conversation_id - : previous - ); + setConversationId((previous) => { + const next = + typeof response.conversation_id === "string" + ? response.conversation_id + : previous; + if ( + next && + presentationId && + typeof sessionStorage !== "undefined" + ) { + sessionStorage.setItem( + conversationStorageKey(presentationId), + next + ); + } + return next; + }); await refreshPresentationIfNeeded( Array.isArray(response.tool_calls) ? response.tool_calls : [] @@ -655,7 +733,7 @@ const Chat = ({