refactor: Update database session management and enhance chat memory services
- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks. - Refactored chat memory services to utilize a shared `mem0` client for better memory management. - Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers. - Enhanced error handling and response management in chat-related services. - Cleaned up unused code and improved overall structure for maintainability.
This commit is contained in:
parent
17ea7d9f95
commit
4e87dc8b70
25 changed files with 1508 additions and 368 deletions
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import async_session_maker
|
||||
from utils.ollama import pull_ollama_model
|
||||
|
||||
|
||||
|
|
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
|
|||
)
|
||||
log_event_count = 0
|
||||
|
||||
session = await get_container_db_async_session().__anext__()
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
return
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
return
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
|
||||
|
||||
async def upsert_ollama_pull_status(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import get_async_session
|
||||
from utils.ollama import list_pulled_ollama_models
|
||||
|
||||
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
||||
|
|
@ -29,7 +29,7 @@ async def get_available_models():
|
|||
async def pull_model(
|
||||
model: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
session: AsyncSession = Depends(get_container_db_async_session),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
|
||||
if model not in SUPPORTED_OLLAMA_MODELS:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@ def main() -> None:
|
|||
p = _sqlite_file_path(sync_url)
|
||||
if p is not None:
|
||||
paths.append(p)
|
||||
paths.append(p.parent / "container.db")
|
||||
|
||||
seen: set[Path] = set()
|
||||
for path in paths:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
import os
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
|
|
@ -21,7 +20,6 @@ from models.sql.template_create_info import TemplateCreateInfoModel
|
|||
from models.sql.slide import SlideModel
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
from utils.get_env import get_migrate_database_on_startup_env
|
||||
|
||||
|
||||
|
|
@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
yield session
|
||||
|
||||
|
||||
# Container DB (Lives inside the app data directory)
|
||||
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
|
||||
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
|
||||
container_db_engine: AsyncEngine = create_async_engine(
|
||||
container_db_url, connect_args={"check_same_thread": False}
|
||||
)
|
||||
container_db_async_session_maker = async_sessionmaker(
|
||||
container_db_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with container_db_async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Create Database and Tables
|
||||
async def create_db_and_tables():
|
||||
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
|
||||
|
|
@ -76,18 +58,11 @@ async def create_db_and_tables():
|
|||
TemplateModel.__table__,
|
||||
WebhookSubscription.__table__,
|
||||
AsyncPresentationGenerationTaskModel.__table__,
|
||||
OllamaPullStatus.__table__,
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async with container_db_engine.begin() as conn:
|
||||
await conn.run_sync(
|
||||
lambda sync_conn: SQLModel.metadata.create_all(
|
||||
sync_conn,
|
||||
tables=[OllamaPullStatus.__table__],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def dispose_engines():
|
||||
"""Dispose all engine connection pools.
|
||||
|
|
@ -97,4 +72,3 @@ async def dispose_engines():
|
|||
database and prevent stale / leaked connections.
|
||||
"""
|
||||
await sql_engine.dispose()
|
||||
await container_db_engine.dispose()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import async_session_maker
|
||||
from utils.ollama import pull_ollama_model
|
||||
|
||||
|
||||
|
|
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
|
|||
)
|
||||
log_event_count = 0
|
||||
|
||||
session = await get_container_db_async_session().__anext__()
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
return
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
return
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
|
||||
|
||||
async def upsert_ollama_pull_status(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,17 @@
|
|||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.chat import ChatMessageRequest, ChatMessageResponse
|
||||
from models.chat import (
|
||||
ChatConversationListItem,
|
||||
ChatHistoryMessageItem,
|
||||
ChatHistoryResponse,
|
||||
ChatMessageRequest,
|
||||
ChatMessageResponse,
|
||||
)
|
||||
from models.sse_response import (
|
||||
SSECompleteResponse,
|
||||
SSEErrorResponse,
|
||||
|
|
@ -13,11 +20,57 @@ from models.sse_response import (
|
|||
SSEResponse,
|
||||
)
|
||||
from services.chat import ChatTurnResult, PresentationChatService
|
||||
from services.chat import sql_chat_history
|
||||
from services.database import get_async_session
|
||||
|
||||
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
|
||||
|
||||
|
||||
@CHAT_ROUTER.get("/conversations", response_model=list[ChatConversationListItem])
|
||||
async def list_chat_conversations(
|
||||
presentation_id: uuid.UUID = Query(..., description="Presentation id"),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
raw = await sql_chat_history.list_conversations(
|
||||
sql_session, presentation_id=presentation_id
|
||||
)
|
||||
return [
|
||||
ChatConversationListItem(
|
||||
conversation_id=uuid.UUID(str(item["conversation_id"])),
|
||||
updated_at=item.get("updated_at"),
|
||||
last_message_preview=item.get("last_message_preview"),
|
||||
)
|
||||
for item in raw
|
||||
]
|
||||
|
||||
|
||||
@CHAT_ROUTER.get("/history", response_model=ChatHistoryResponse)
|
||||
async def get_chat_history(
|
||||
presentation_id: uuid.UUID = Query(..., description="Presentation id"),
|
||||
conversation_id: uuid.UUID = Query(..., description="Conversation thread id"),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
rows = await sql_chat_history.load_messages_with_meta(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
return ChatHistoryResponse(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
messages=[
|
||||
ChatHistoryMessageItem(
|
||||
role=str(m.get("role") or ""),
|
||||
content=str(m.get("content") or ""),
|
||||
created_at=m.get("created_at")
|
||||
if isinstance(m.get("created_at"), str)
|
||||
else None,
|
||||
)
|
||||
for m in rows
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse)
|
||||
async def chat_message(
|
||||
payload: ChatMessageRequest,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import get_async_session
|
||||
from utils.ollama import list_pulled_ollama_models
|
||||
|
||||
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
||||
|
|
@ -29,7 +29,7 @@ async def get_available_models():
|
|||
async def pull_model(
|
||||
model: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
session: AsyncSession = Depends(get_container_db_async_session),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
|
||||
if model not in SUPPORTED_OLLAMA_MODELS:
|
||||
|
|
|
|||
|
|
@ -18,3 +18,27 @@ class ChatMessageResponse(BaseModel):
|
|||
tool_calls: list[str] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatHistoryMessageItem(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
created_at: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatHistoryResponse(BaseModel):
|
||||
presentation_id: uuid.UUID
|
||||
conversation_id: uuid.UUID
|
||||
messages: list[ChatHistoryMessageItem]
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatConversationListItem(BaseModel):
|
||||
conversation_id: uuid.UUID
|
||||
updated_at: Optional[str] = None
|
||||
last_message_preview: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from services.chat.service import ChatTurnResult, PresentationChatService
|
||||
from services.chat.presentation_context_store import PresentationContextStore
|
||||
|
||||
__all__ = [
|
||||
"ChatTurnResult",
|
||||
"PresentationChatService",
|
||||
"PresentationContextStore",
|
||||
]
|
||||
|
|
|
|||
324
servers/fastapi/services/chat/chat_memory_store.py
Normal file
324
servers/fastapi/services/chat/chat_memory_store.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from services.mem0_oss_memory import get_shared_mem0_client
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
CHAT_TURN_TAG = "[chat_turn]"
|
||||
DEFAULT_MAX_STORED_TURNS = 20
|
||||
|
||||
|
||||
class ChatMemoryStore:
|
||||
def __init__(self):
|
||||
self._enabled = self._to_bool(os.getenv("MEM0_ENABLED"), default=True)
|
||||
self._top_k = self._to_int(os.getenv("MEM0_TOP_K"), default=8)
|
||||
self._max_context_chars = self._to_int(
|
||||
os.getenv("MEM0_MAX_CONTEXT_CHARS"), default=6000
|
||||
)
|
||||
self._max_stored_turns = self._to_int(
|
||||
os.getenv("CHAT_MAX_STORED_TURNS"), default=DEFAULT_MAX_STORED_TURNS
|
||||
)
|
||||
self._namespace_prefix = (
|
||||
os.getenv("MEM0_CHAT_NAMESPACE_PREFIX")
|
||||
or os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX")
|
||||
or "presentation"
|
||||
).strip() or "presentation"
|
||||
|
||||
@staticmethod
|
||||
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Optional[str], default: int) -> int:
|
||||
try:
|
||||
parsed = int(value) if value is not None else default
|
||||
return max(1, parsed)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _normalize(value: str) -> str:
|
||||
return " ".join((value or "").split())
|
||||
|
||||
@staticmethod
|
||||
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
|
||||
return isinstance(exc, (Exception, SystemExit))
|
||||
|
||||
def _scope_user_id(self, presentation_id: UUID, conversation_id: UUID) -> str:
|
||||
return (
|
||||
f"{self._namespace_prefix}:{presentation_id}:"
|
||||
f"conversation:{conversation_id}"
|
||||
)
|
||||
|
||||
def _truncate(self, text: str, limit: int = 20000) -> str:
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return f"{text[:limit]}\n\n[TRUNCATED]"
|
||||
|
||||
async def _get_client(self):
|
||||
if not self._enabled:
|
||||
return None
|
||||
return get_shared_mem0_client()
|
||||
|
||||
def _build_turn_payload(self, *, user_text: str, assistant_text: str) -> str:
|
||||
memory_lines = [
|
||||
CHAT_TURN_TAG,
|
||||
f"turn_created_at={datetime.now(timezone.utc).isoformat()}",
|
||||
]
|
||||
if user_text:
|
||||
memory_lines.append(f"user={user_text}")
|
||||
if assistant_text:
|
||||
memory_lines.append(f"assistant={assistant_text}")
|
||||
return "\n".join(memory_lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_field(item: dict[str, Any]) -> str:
|
||||
memory_text = item.get("memory") or item.get("text") or item.get("data")
|
||||
return str(memory_text).strip() if memory_text is not None else ""
|
||||
|
||||
def _collect_results(self, response: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(response, dict):
|
||||
raw_results = (
|
||||
response.get("results")
|
||||
or response.get("memories")
|
||||
or response.get("items")
|
||||
or []
|
||||
)
|
||||
if isinstance(raw_results, list):
|
||||
return [item for item in raw_results if isinstance(item, dict)]
|
||||
return []
|
||||
if isinstance(response, list):
|
||||
return [item for item in response if isinstance(item, dict)]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _safe_parse_datetime(raw_value: Any) -> datetime | None:
|
||||
if not isinstance(raw_value, str) or not raw_value.strip():
|
||||
return None
|
||||
value = raw_value.strip().replace("Z", "+00:00")
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value)
|
||||
if parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_chat_turn_fields(text: str) -> tuple[str | None, str | None, datetime | None]:
|
||||
if CHAT_TURN_TAG not in text:
|
||||
return None, None, None
|
||||
|
||||
user_text: str | None = None
|
||||
assistant_text: str | None = None
|
||||
turn_created_at: datetime | None = None
|
||||
for line in text.splitlines():
|
||||
if line.startswith("user="):
|
||||
user_text = line[len("user=") :].strip()
|
||||
elif line.startswith("assistant="):
|
||||
assistant_text = line[len("assistant=") :].strip()
|
||||
elif line.startswith("turn_created_at="):
|
||||
turn_created_at = ChatMemoryStore._safe_parse_datetime(
|
||||
line[len("turn_created_at=") :].strip()
|
||||
)
|
||||
return user_text, assistant_text, turn_created_at
|
||||
|
||||
async def store_chat_turn(
|
||||
self,
|
||||
*,
|
||||
presentation_id: UUID,
|
||||
conversation_id: UUID,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
) -> None:
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return
|
||||
|
||||
user_text = self._normalize(user_message)
|
||||
assistant_text = self._normalize(assistant_message)
|
||||
if not user_text and not assistant_text:
|
||||
return
|
||||
|
||||
payload = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._truncate(
|
||||
self._build_turn_payload(
|
||||
user_text=user_text,
|
||||
assistant_text=assistant_text,
|
||||
)
|
||||
),
|
||||
}
|
||||
]
|
||||
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
||||
|
||||
def _add():
|
||||
try:
|
||||
return client.add(payload, user_id=scoped_user_id, infer=False)
|
||||
except TypeError:
|
||||
return client.add(
|
||||
messages=payload,
|
||||
user_id=scoped_user_id,
|
||||
infer=False,
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_add)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception(
|
||||
(
|
||||
"Failed to add chat mem0 memory "
|
||||
"(presentation_id=%s, conversation_id=%s)"
|
||||
),
|
||||
presentation_id,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
async def retrieve_context(
|
||||
self,
|
||||
*,
|
||||
presentation_id: UUID,
|
||||
conversation_id: UUID,
|
||||
query: str,
|
||||
) -> str:
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return ""
|
||||
|
||||
trimmed_query = (query or "").strip()
|
||||
if not trimmed_query:
|
||||
return ""
|
||||
|
||||
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
||||
|
||||
def _search():
|
||||
try:
|
||||
return client.search(
|
||||
trimmed_query,
|
||||
filters={"user_id": scoped_user_id},
|
||||
top_k=self._top_k,
|
||||
)
|
||||
except TypeError:
|
||||
return client.search(
|
||||
trimmed_query,
|
||||
user_id=scoped_user_id,
|
||||
top_k=self._top_k,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(_search)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception(
|
||||
(
|
||||
"Failed to search chat mem0 memory "
|
||||
"(presentation_id=%s, conversation_id=%s)"
|
||||
),
|
||||
presentation_id,
|
||||
conversation_id,
|
||||
)
|
||||
return ""
|
||||
|
||||
results = self._collect_results(response)
|
||||
memories: list[str] = []
|
||||
for item in results:
|
||||
normalized = self._extract_text_field(item)
|
||||
if normalized:
|
||||
memories.append(normalized)
|
||||
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
deduped = list(dict.fromkeys(memories))
|
||||
return self._truncate("\n\n".join(deduped), self._max_context_chars)
|
||||
|
||||
async def load_history(
|
||||
self,
|
||||
*,
|
||||
presentation_id: UUID,
|
||||
conversation_id: UUID,
|
||||
) -> list[dict[str, str]]:
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return []
|
||||
|
||||
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
||||
|
||||
def _get_all():
|
||||
try:
|
||||
return client.get_all(
|
||||
filters={"user_id": scoped_user_id},
|
||||
limit=max(10, self._max_stored_turns * 4),
|
||||
)
|
||||
except TypeError:
|
||||
try:
|
||||
return client.get_all(
|
||||
user_id=scoped_user_id,
|
||||
limit=max(10, self._max_stored_turns * 4),
|
||||
)
|
||||
except TypeError:
|
||||
try:
|
||||
return client.get_all(filters={"user_id": scoped_user_id})
|
||||
except TypeError:
|
||||
return client.get_all(user_id=scoped_user_id)
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(_get_all)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception(
|
||||
(
|
||||
"Failed to load chat mem0 history "
|
||||
"(presentation_id=%s, conversation_id=%s)"
|
||||
),
|
||||
presentation_id,
|
||||
conversation_id,
|
||||
)
|
||||
return []
|
||||
|
||||
results = self._collect_results(response)
|
||||
ordered_turns: list[tuple[datetime, str, str]] = []
|
||||
for index, item in enumerate(results):
|
||||
text_value = self._extract_text_field(item)
|
||||
if not text_value:
|
||||
continue
|
||||
user_text, assistant_text, embedded_timestamp = self._extract_chat_turn_fields(
|
||||
text_value
|
||||
)
|
||||
if not user_text and not assistant_text:
|
||||
continue
|
||||
|
||||
item_created_at = (
|
||||
self._safe_parse_datetime(item.get("created_at"))
|
||||
or self._safe_parse_datetime(item.get("updated_at"))
|
||||
or self._safe_parse_datetime(item.get("event_at"))
|
||||
)
|
||||
timestamp = embedded_timestamp or item_created_at or datetime.fromtimestamp(
|
||||
index, tz=timezone.utc
|
||||
)
|
||||
ordered_turns.append((timestamp, user_text or "", assistant_text or ""))
|
||||
|
||||
ordered_turns.sort(key=lambda turn: turn[0])
|
||||
recent_turns = ordered_turns[-self._max_stored_turns :]
|
||||
|
||||
history: list[dict[str, str]] = []
|
||||
for _, user_text, assistant_text in recent_turns:
|
||||
if user_text:
|
||||
history.append({"role": "user", "content": user_text})
|
||||
if assistant_text:
|
||||
history.append({"role": "assistant", "content": assistant_text})
|
||||
return history
|
||||
|
||||
|
||||
CHAT_MEMORY_STORE = ChatMemoryStore()
|
||||
|
|
@ -1,44 +1,14 @@
|
|||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from models.sql.key_value import KeyValueSqlModel
|
||||
|
||||
CHAT_CONVERSATION_KEY_PREFIX = "chat_conversation"
|
||||
MAX_STORED_TURNS = 20
|
||||
|
||||
|
||||
class ConversationMessage(BaseModel):
|
||||
role: Literal["user", "assistant"]
|
||||
content: str = Field(min_length=1)
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
|
||||
@field_validator("content")
|
||||
@classmethod
|
||||
def normalize_content(cls, value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
raise ValueError("Conversation message cannot be empty.")
|
||||
return normalized
|
||||
|
||||
|
||||
class ConversationPayload(BaseModel):
|
||||
conversation_id: uuid.UUID
|
||||
presentation_id: uuid.UUID
|
||||
messages: list[ConversationMessage] = Field(default_factory=list)
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
from services.chat.chat_memory_store import CHAT_MEMORY_STORE
|
||||
from services.chat import sql_chat_history
|
||||
|
||||
|
||||
class ChatConversationStore:
|
||||
def __init__(self, sql_session: AsyncSession):
|
||||
self._sql_session = sql_session
|
||||
self._sql = sql_session
|
||||
|
||||
async def load_history(
|
||||
self,
|
||||
|
|
@ -46,13 +16,25 @@ class ChatConversationStore:
|
|||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
) -> list[dict[str, str]]:
|
||||
payload = await self._get_payload(conversation_id)
|
||||
if not payload or payload.presentation_id != presentation_id:
|
||||
return []
|
||||
return [
|
||||
{"role": message.role, "content": message.content}
|
||||
for message in payload.messages
|
||||
]
|
||||
messages = await sql_chat_history.load_messages(
|
||||
self._sql,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if messages:
|
||||
return messages
|
||||
legacy = await CHAT_MEMORY_STORE.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if legacy:
|
||||
await sql_chat_history.replace_messages(
|
||||
self._sql,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
messages=legacy,
|
||||
)
|
||||
return legacy
|
||||
|
||||
async def append_turn(
|
||||
self,
|
||||
|
|
@ -62,87 +44,35 @@ class ChatConversationStore:
|
|||
user_message: str,
|
||||
assistant_message: str,
|
||||
) -> None:
|
||||
payload = await self._get_payload(conversation_id)
|
||||
messages = list(payload.messages) if payload else []
|
||||
|
||||
messages.append(ConversationMessage(role="user", content=user_message))
|
||||
messages.append(ConversationMessage(role="assistant", content=assistant_message))
|
||||
max_messages = MAX_STORED_TURNS * 2
|
||||
messages = messages[-max_messages:]
|
||||
|
||||
next_payload = ConversationPayload(
|
||||
conversation_id=conversation_id,
|
||||
await sql_chat_history.append_turn(
|
||||
self._sql,
|
||||
presentation_id=presentation_id,
|
||||
messages=messages,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
assistant_message=assistant_message,
|
||||
)
|
||||
await CHAT_MEMORY_STORE.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
assistant_message=assistant_message,
|
||||
)
|
||||
|
||||
row = await self._get_row(conversation_id)
|
||||
if row:
|
||||
row.value = next_payload.model_dump(mode="json")
|
||||
self._sql_session.add(row)
|
||||
else:
|
||||
self._sql_session.add(
|
||||
KeyValueSqlModel(
|
||||
key=self._conversation_key(conversation_id),
|
||||
value=next_payload.model_dump(mode="json"),
|
||||
)
|
||||
)
|
||||
|
||||
await self._sql_session.commit()
|
||||
async def retrieve_semantic_context(
|
||||
self,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
query: str,
|
||||
) -> str:
|
||||
return await CHAT_MEMORY_STORE.retrieve_context(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query=query,
|
||||
)
|
||||
|
||||
async def ensure_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID | None,
|
||||
) -> uuid.UUID:
|
||||
return conversation_id or uuid.uuid4()
|
||||
|
||||
async def _get_payload(
|
||||
self, conversation_id: uuid.UUID
|
||||
) -> ConversationPayload | None:
|
||||
row = await self._get_row(conversation_id)
|
||||
if not row or not isinstance(row.value, dict):
|
||||
return None
|
||||
|
||||
raw_payload: dict[str, Any] = row.value
|
||||
try:
|
||||
return ConversationPayload.model_validate(raw_payload)
|
||||
except ValidationError:
|
||||
return self._coerce_payload(raw_payload)
|
||||
|
||||
def _coerce_payload(self, payload: dict[str, Any]) -> ConversationPayload | None:
|
||||
try:
|
||||
conversation_id = uuid.UUID(str(payload.get("conversation_id")))
|
||||
presentation_id = uuid.UUID(str(payload.get("presentation_id")))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
raw_messages = payload.get("messages")
|
||||
messages: list[ConversationMessage] = []
|
||||
if isinstance(raw_messages, list):
|
||||
for entry in raw_messages:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
try:
|
||||
message = ConversationMessage.model_validate(entry)
|
||||
messages.append(message)
|
||||
except ValidationError:
|
||||
continue
|
||||
|
||||
return ConversationPayload(
|
||||
conversation_id=conversation_id,
|
||||
presentation_id=presentation_id,
|
||||
messages=messages[-(MAX_STORED_TURNS * 2) :],
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
async def _get_row(self, conversation_id: uuid.UUID) -> KeyValueSqlModel | None:
|
||||
return await self._sql_session.scalar(
|
||||
select(KeyValueSqlModel).where(
|
||||
KeyValueSqlModel.key == self._conversation_key(conversation_id)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _conversation_key(conversation_id: uuid.UUID) -> str:
|
||||
return f"{CHAT_CONVERSATION_KEY_PREFIX}:{conversation_id}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from services.chat.memory_layer import (
|
||||
PresentationChatMemoryLayer as PresentationContextStore,
|
||||
)
|
||||
|
||||
__all__ = ["PresentationContextStore"]
|
||||
|
|
@ -1,13 +1,31 @@
|
|||
def build_system_prompt(memory_context: str) -> str:
|
||||
context_block = (
|
||||
"\nMemory context (use only when relevant):\n"
|
||||
f"{memory_context}\n"
|
||||
if memory_context
|
||||
else ""
|
||||
def _trim_block(label: str, text: str) -> str:
|
||||
t = (text or "").strip()
|
||||
if not t:
|
||||
return ""
|
||||
return f"\n{label}\n(use only when relevant; may be partial)\n{t}\n"
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
presentation_memory_context: str,
|
||||
chat_memory_context: str,
|
||||
) -> str:
|
||||
"""
|
||||
presentation_memory_context: deck-scoped (documents, outlines, prior slide edits, etc.)
|
||||
chat_memory_context: this thread (prior asks, assistant replies) — semantic slice.
|
||||
"""
|
||||
presentation_block = _trim_block(
|
||||
"Presentation memory (this deck: source text, outlines, and stored slide/edit notes):",
|
||||
presentation_memory_context,
|
||||
)
|
||||
chat_block = _trim_block(
|
||||
"This conversation thread (what was asked and answered in chat):",
|
||||
chat_memory_context,
|
||||
)
|
||||
return (
|
||||
"You are Presenton backend chat assistant.\n"
|
||||
"You can call tools to access presentation memory.\n"
|
||||
"You can call tools to access live slide data and layouts.\n"
|
||||
"Distinguish: presentation memory = facts about the deck; chat memory = this thread’s prior Q&A. "
|
||||
"Tools still win for current slide content.\n"
|
||||
"- Use getPresentationOutline for outline/section questions.\n"
|
||||
"- Prefer compact tool outputs to save context window; do not request full slide JSON unless needed.\n"
|
||||
"- Use searchSlides for finding relevant slide content snippets (DB-backed).\n"
|
||||
|
|
@ -21,5 +39,6 @@ def build_system_prompt(memory_context: str) -> str:
|
|||
"- After tool outputs are sufficient, stop calling tools and provide a final answer.\n"
|
||||
"- If memory is missing, state that clearly and suggest next steps.\n"
|
||||
"- Do not invent slide facts that are not in tool results or memory.\n"
|
||||
f"{context_block}"
|
||||
f"{presentation_block}"
|
||||
f"{chat_block}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from models.sql.presentation import PresentationModel
|
||||
from services.chat.conversation_store import ChatConversationStore
|
||||
from services.chat.memory_layer import PresentationChatMemoryLayer
|
||||
from services.chat.presentation_context_store import PresentationContextStore
|
||||
from services.chat.prompts import build_system_prompt
|
||||
from services.chat.tools import ChatTools
|
||||
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
||||
|
|
@ -58,7 +58,7 @@ class PresentationChatService:
|
|||
self._conversation_id = conversation_id
|
||||
|
||||
self._conversation_store = ChatConversationStore(sql_session)
|
||||
self._memory = PresentationChatMemoryLayer(sql_session, presentation_id)
|
||||
self._memory = PresentationContextStore(sql_session, presentation_id)
|
||||
self._tools = ChatTools(self._memory)
|
||||
|
||||
async def generate_reply(self, user_message: str) -> ChatTurnResult:
|
||||
|
|
@ -209,6 +209,8 @@ class PresentationChatService:
|
|||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
# A stable conversation_id is created here before the first user message
|
||||
# so mem0 and SQL can scope the thread; the client need not "chat" first.
|
||||
conversation_id = await self._conversation_store.ensure_conversation_id(
|
||||
self._conversation_id
|
||||
)
|
||||
|
|
@ -218,9 +220,19 @@ class PresentationChatService:
|
|||
)
|
||||
history_messages = self._convert_history_to_messages(history)
|
||||
|
||||
memory_context = await self._memory.retrieve_context(user_message)
|
||||
presentation_memory = await self._memory.retrieve_context(user_message)
|
||||
chat_memory = await self._conversation_store.retrieve_semantic_context(
|
||||
presentation_id=self._presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query=user_message,
|
||||
)
|
||||
messages: list[Message] = [
|
||||
SystemMessage(content=build_system_prompt(memory_context)),
|
||||
SystemMessage(
|
||||
content=build_system_prompt(
|
||||
presentation_memory_context=presentation_memory,
|
||||
chat_memory_context=chat_memory,
|
||||
)
|
||||
),
|
||||
*history_messages,
|
||||
UserMessage(content=user_message),
|
||||
]
|
||||
|
|
|
|||
231
servers/fastapi/services/chat/sql_chat_history.py
Normal file
231
servers/fastapi/services/chat/sql_chat_history.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""Persist presentation chat threads in ``KeyValueSqlModel``.
|
||||
|
||||
Each conversation is one row: key ``ppt_chat:{presentation_id}:{conversation_id}``,
|
||||
value is JSON: ``{version, messages, updated_at, ...}``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from models.sql.key_value import KeyValueSqlModel
|
||||
|
||||
CHAT_HISTORY_KEY_PREFIX = "ppt_chat"
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
def chat_history_key(presentation_id: uuid.UUID, conversation_id: uuid.UUID) -> str:
|
||||
return f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:{conversation_id}"
|
||||
|
||||
|
||||
def _parse_conversation_key(key: str, presentation_id: uuid.UUID) -> uuid.UUID | None:
|
||||
expected_prefix = f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:"
|
||||
if not key.startswith(expected_prefix):
|
||||
return None
|
||||
rest = key[len(expected_prefix) :]
|
||||
try:
|
||||
return uuid.UUID(rest)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
async def load_messages(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Load ordered user/assistant messages for LLM context (role + content only)."""
|
||||
key = chat_history_key(presentation_id, conversation_id)
|
||||
row = await session.scalar(
|
||||
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
|
||||
)
|
||||
if not row or not isinstance(row.value, dict):
|
||||
return []
|
||||
messages = row.value.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return []
|
||||
out: list[dict[str, str]] = []
|
||||
for item in messages:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("role") not in ("user", "assistant"):
|
||||
continue
|
||||
content = item.get("content")
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
continue
|
||||
out.append({"role": item["role"], "content": content})
|
||||
return out
|
||||
|
||||
|
||||
async def load_messages_with_meta(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Load messages including optional ``created_at`` for API / UI."""
|
||||
key = chat_history_key(presentation_id, conversation_id)
|
||||
row = await session.scalar(
|
||||
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
|
||||
)
|
||||
if not row or not isinstance(row.value, dict):
|
||||
return []
|
||||
messages = row.value.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
for item in messages:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("role") not in ("user", "assistant"):
|
||||
continue
|
||||
content = item.get("content")
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
continue
|
||||
entry: dict[str, Any] = {
|
||||
"role": item["role"],
|
||||
"content": content,
|
||||
}
|
||||
created = item.get("created_at")
|
||||
if isinstance(created, str) and created.strip():
|
||||
entry["created_at"] = created.strip()
|
||||
out.append(entry)
|
||||
return out
|
||||
|
||||
|
||||
async def replace_messages(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
messages: list[dict[str, str]],
|
||||
) -> None:
|
||||
"""Replace transcript (e.g. one-time sync from mem0)."""
|
||||
key = chat_history_key(presentation_id, conversation_id)
|
||||
row = await session.scalar(
|
||||
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
|
||||
)
|
||||
now = _utc_now_iso()
|
||||
built: list[dict[str, Any]] = []
|
||||
for m in messages:
|
||||
if m.get("role") not in ("user", "assistant"):
|
||||
continue
|
||||
content = m.get("content")
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
built.append(
|
||||
{
|
||||
"role": m["role"],
|
||||
"content": content,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
value = {
|
||||
"version": SCHEMA_VERSION,
|
||||
"presentation_id": str(presentation_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"messages": built,
|
||||
"updated_at": now,
|
||||
}
|
||||
if row is None:
|
||||
session.add(KeyValueSqlModel(key=key, value=value))
|
||||
else:
|
||||
row.value = value
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def append_turn(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
) -> None:
|
||||
key = chat_history_key(presentation_id, conversation_id)
|
||||
row = await session.scalar(
|
||||
select(KeyValueSqlModel).where(KeyValueSqlModel.key == key)
|
||||
)
|
||||
now = _utc_now_iso()
|
||||
new_messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": user_message, "created_at": now},
|
||||
{"role": "assistant", "content": assistant_message, "created_at": now},
|
||||
]
|
||||
if row is None:
|
||||
value: dict[str, Any] = {
|
||||
"version": SCHEMA_VERSION,
|
||||
"presentation_id": str(presentation_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"messages": new_messages,
|
||||
"updated_at": now,
|
||||
}
|
||||
session.add(KeyValueSqlModel(key=key, value=value))
|
||||
else:
|
||||
data = row.value if isinstance(row.value, dict) else {}
|
||||
existing = data.get("messages")
|
||||
if not isinstance(existing, list):
|
||||
existing = []
|
||||
combined = [m for m in existing if isinstance(m, dict)]
|
||||
combined.extend(new_messages)
|
||||
data["version"] = SCHEMA_VERSION
|
||||
data["presentation_id"] = str(presentation_id)
|
||||
data["conversation_id"] = str(conversation_id)
|
||||
data["messages"] = combined
|
||||
data["updated_at"] = now
|
||||
row.value = data
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def list_conversations(
|
||||
session: AsyncSession, *, presentation_id: uuid.UUID
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return conversation summaries for a presentation, newest ``updated_at`` first."""
|
||||
prefix = f"{CHAT_HISTORY_KEY_PREFIX}:{presentation_id}:"
|
||||
result = await session.scalars(
|
||||
select(KeyValueSqlModel).where(KeyValueSqlModel.key.startswith(prefix))
|
||||
)
|
||||
rows = list(result.all())
|
||||
summaries: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
cid = _parse_conversation_key(row.key, presentation_id)
|
||||
if cid is None:
|
||||
continue
|
||||
data = row.value if isinstance(row.value, dict) else {}
|
||||
updated_at: str | None = None
|
||||
raw_u = data.get("updated_at")
|
||||
if isinstance(raw_u, str) and raw_u.strip():
|
||||
updated_at = raw_u.strip()
|
||||
messages = data.get("messages")
|
||||
preview: str | None = None
|
||||
if isinstance(messages, list) and messages:
|
||||
for item in reversed(messages):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
c = item.get("content")
|
||||
if isinstance(c, str) and c.strip():
|
||||
preview = c.strip()
|
||||
if len(preview) > 200:
|
||||
preview = f"{preview[:200]}…"
|
||||
break
|
||||
summaries.append(
|
||||
{
|
||||
"conversation_id": str(cid),
|
||||
"updated_at": updated_at,
|
||||
"last_message_preview": preview,
|
||||
}
|
||||
)
|
||||
summaries.sort(
|
||||
key=lambda s: s.get("updated_at") or "",
|
||||
reverse=True,
|
||||
)
|
||||
return summaries
|
||||
|
|
@ -15,7 +15,7 @@ from services.chat.schemas import (
|
|||
SaveSlideInput,
|
||||
SearchSlidesInput,
|
||||
)
|
||||
from services.chat.memory_layer import PresentationChatMemoryLayer
|
||||
from services.chat.presentation_context_store import PresentationContextStore
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ class ChatTools:
|
|||
provider-specific logic, keeping them portable across llmai backends.
|
||||
"""
|
||||
|
||||
def __init__(self, memory: PresentationChatMemoryLayer):
|
||||
def __init__(self, memory: PresentationContextStore):
|
||||
self._memory = memory
|
||||
self._tool_handlers: dict[str, ToolHandler] = {
|
||||
"getPresentationOutline": self._get_presentation_outline,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
import os
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
|
|
@ -20,7 +19,6 @@ from models.sql.template import TemplateModel
|
|||
from models.sql.template_create_info import TemplateCreateInfoModel
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
from utils.get_env import get_migrate_database_on_startup_env
|
||||
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
|
||||
|
||||
|
|
@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
yield session
|
||||
|
||||
|
||||
# Container DB (Lives inside the app data directory)
|
||||
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
|
||||
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
|
||||
container_db_engine: AsyncEngine = create_async_engine(
|
||||
container_db_url, connect_args={"check_same_thread": False}
|
||||
)
|
||||
container_db_async_session_maker = async_sessionmaker(
|
||||
container_db_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with container_db_async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Create Database and Tables
|
||||
async def create_db_and_tables():
|
||||
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
|
||||
|
|
@ -76,18 +58,11 @@ async def create_db_and_tables():
|
|||
TemplateModel.__table__,
|
||||
WebhookSubscription.__table__,
|
||||
AsyncPresentationGenerationTaskModel.__table__,
|
||||
OllamaPullStatus.__table__,
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async with container_db_engine.begin() as conn:
|
||||
await conn.run_sync(
|
||||
lambda sync_conn: SQLModel.metadata.create_all(
|
||||
sync_conn,
|
||||
tables=[OllamaPullStatus.__table__],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def dispose_engines():
|
||||
"""Dispose all engine connection pools.
|
||||
|
|
@ -97,4 +72,3 @@ async def dispose_engines():
|
|||
database and prevent stale / leaked connections.
|
||||
"""
|
||||
await sql_engine.dispose()
|
||||
await container_db_engine.dispose()
|
||||
|
|
|
|||
131
servers/fastapi/services/mem0_oss_memory.py
Normal file
131
servers/fastapi/services/mem0_oss_memory.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""Single shared mem0 OSS ``Memory`` client for the process.
|
||||
|
||||
All callers (presentation context, chat turns) use the same on-disk Qdrant/SQLite
|
||||
and distinguish data via mem0 ``user_id``:
|
||||
|
||||
- Deck-level (no chat thread): ``{namespace}:{presentation_id}``
|
||||
- Chat thread: ``{namespace}:{presentation_id}:conversation:{conversation_id}``
|
||||
|
||||
The chat flow calls ``ensure_conversation_id`` before the first turn, so a
|
||||
``conversation_id`` exists before any mem0 write for that thread.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from importlib import import_module
|
||||
from typing import Any, Optional
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_memory_init_lock = threading.Lock()
|
||||
_shared_client: Any | None = None
|
||||
_init_attempted = False
|
||||
|
||||
|
||||
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _to_int(value: Optional[str], default: int) -> int:
|
||||
try:
|
||||
parsed = int(value) if value is not None else default
|
||||
return max(1, parsed)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _oss_config_from_env() -> tuple[str, str, str, str, int, dict[str, Any]]:
|
||||
"""Return (mem0_dir, qdrant_path, history_db, collection, dims, from_config_dict)."""
|
||||
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
|
||||
mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
|
||||
qdrant_path = (
|
||||
os.getenv("MEM0_QDRANT_PATH") or os.path.join(mem0_dir, "qdrant")
|
||||
).strip()
|
||||
history_db_path = (
|
||||
os.getenv("MEM0_HISTORY_DB_PATH") or os.path.join(mem0_dir, "history.db")
|
||||
).strip()
|
||||
collection = (
|
||||
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
|
||||
).strip() or "presenton_memories"
|
||||
embedder = (os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed").strip() or "fastembed"
|
||||
model = (
|
||||
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
|
||||
).strip() or "BAAI/bge-small-en-v1.5"
|
||||
dims = _to_int(os.getenv("MEM0_EMBEDDING_DIMS"), default=384)
|
||||
config: dict[str, Any] = {
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": collection,
|
||||
"path": qdrant_path,
|
||||
"on_disk": True,
|
||||
"embedding_model_dims": dims,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": embedder,
|
||||
"config": {
|
||||
"model": model,
|
||||
"embedding_dims": dims,
|
||||
},
|
||||
},
|
||||
"history_db_path": history_db_path,
|
||||
}
|
||||
return mem0_dir, qdrant_path, history_db_path, collection, dims, config
|
||||
|
||||
|
||||
def memory_from_config(config: dict[str, Any], *, telemetry_base: str) -> Any:
|
||||
"""Construct ``mem0.Memory``. Caller must hold ``_memory_init_lock`` if used with shared state."""
|
||||
os.makedirs(telemetry_base, exist_ok=True)
|
||||
import mem0.memory.main as mem0_main # type: ignore[import-untyped]
|
||||
|
||||
mem0_main.mem0_dir = telemetry_base
|
||||
memory_cls = getattr(import_module("mem0"), "Memory")
|
||||
return memory_cls.from_config(config)
|
||||
|
||||
|
||||
def get_shared_mem0_client() -> Any | None:
|
||||
"""Return the process-wide mem0 client, or ``None`` if disabled or init failed."""
|
||||
global _shared_client, _init_attempted
|
||||
|
||||
if not _to_bool(os.getenv("MEM0_ENABLED"), default=True):
|
||||
return None
|
||||
if _shared_client is not None:
|
||||
return _shared_client
|
||||
if _init_attempted:
|
||||
return None
|
||||
|
||||
with _memory_init_lock:
|
||||
if _shared_client is not None:
|
||||
return _shared_client
|
||||
if _init_attempted:
|
||||
return None
|
||||
_init_attempted = True
|
||||
try:
|
||||
mem0_dir, qdrant_path, history_db, collection, dims, config = (
|
||||
_oss_config_from_env()
|
||||
)
|
||||
os.makedirs(mem0_dir, exist_ok=True)
|
||||
os.makedirs(qdrant_path, exist_ok=True)
|
||||
telemetry_base = os.path.join(mem0_dir, "telemetry", "oss")
|
||||
_shared_client = memory_from_config(
|
||||
config,
|
||||
telemetry_base=telemetry_base,
|
||||
)
|
||||
LOGGER.info(
|
||||
"Mem0 OSS shared memory initialized (qdrant_path=%s, history_db_path=%s, collection=%s, dims=%s)",
|
||||
qdrant_path,
|
||||
history_db,
|
||||
collection,
|
||||
dims,
|
||||
)
|
||||
except BaseException:
|
||||
LOGGER.exception("Failed to initialize shared Mem0 OSS Memory")
|
||||
_shared_client = None
|
||||
|
||||
return _shared_client
|
||||
|
|
@ -2,10 +2,11 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from importlib import import_module
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from services.mem0_oss_memory import get_shared_mem0_client
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -21,31 +22,6 @@ class Mem0PresentationMemoryService:
|
|||
os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation"
|
||||
).strip() or "presentation"
|
||||
|
||||
self._embedder_provider = (
|
||||
os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed"
|
||||
).strip() or "fastembed"
|
||||
self._embedder_model = (
|
||||
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
|
||||
).strip() or "BAAI/bge-small-en-v1.5"
|
||||
self._embedding_dims = self._to_int(
|
||||
os.getenv("MEM0_EMBEDDING_DIMS"),
|
||||
default=384,
|
||||
)
|
||||
|
||||
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
|
||||
self._mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
|
||||
self._qdrant_path = (os.getenv("MEM0_QDRANT_PATH") or os.path.join(self._mem0_dir, "qdrant")).strip()
|
||||
self._history_db_path = (
|
||||
os.getenv("MEM0_HISTORY_DB_PATH")
|
||||
or os.path.join(self._mem0_dir, "history.db")
|
||||
).strip()
|
||||
self._collection_name = (
|
||||
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
|
||||
).strip() or "presenton_memories"
|
||||
|
||||
self._client: Any = None
|
||||
self._attempted_client_init = False
|
||||
|
||||
@staticmethod
|
||||
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
||||
if value is None:
|
||||
|
|
@ -68,27 +44,6 @@ class Mem0PresentationMemoryService:
|
|||
return text
|
||||
return f"{text[:limit]}\n\n[TRUNCATED]"
|
||||
|
||||
def _get_oss_config(self) -> dict:
|
||||
return {
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": self._collection_name,
|
||||
"path": self._qdrant_path,
|
||||
"on_disk": True,
|
||||
"embedding_model_dims": self._embedding_dims,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": self._embedder_provider,
|
||||
"config": {
|
||||
"model": self._embedder_model,
|
||||
"embedding_dims": self._embedding_dims,
|
||||
},
|
||||
},
|
||||
"history_db_path": self._history_db_path,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
|
||||
return isinstance(exc, (Exception, SystemExit))
|
||||
|
|
@ -96,42 +51,7 @@ class Mem0PresentationMemoryService:
|
|||
async def _get_client(self):
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
if self._attempted_client_init:
|
||||
return None
|
||||
|
||||
self._attempted_client_init = True
|
||||
|
||||
try:
|
||||
module = import_module("mem0")
|
||||
memory_cls = getattr(module, "Memory")
|
||||
|
||||
os.makedirs(self._mem0_dir, exist_ok=True)
|
||||
os.makedirs(self._qdrant_path, exist_ok=True)
|
||||
|
||||
config = self._get_oss_config()
|
||||
|
||||
try:
|
||||
self._client = memory_cls.from_config(config)
|
||||
except Exception:
|
||||
# Backward compatibility across mem0 OSS versions.
|
||||
self._client = memory_cls(config)
|
||||
|
||||
LOGGER.info(
|
||||
"Mem0 OSS presentation memory service initialized (qdrant_path=%s, history_db_path=%s)",
|
||||
self._qdrant_path,
|
||||
self._history_db_path,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception("Failed to initialize Mem0 OSS Memory")
|
||||
self._client = None
|
||||
|
||||
return self._client
|
||||
return get_shared_mem0_client()
|
||||
|
||||
async def _add_message(self, presentation_id: UUID, message: str):
|
||||
client = await self._get_client()
|
||||
|
|
|
|||
136
servers/fastapi/tests/test_chat_conversation_store.py
Normal file
136
servers/fastapi/tests/test_chat_conversation_store.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from services.chat.conversation_store import ChatConversationStore
|
||||
|
||||
|
||||
class TestChatConversationStore:
|
||||
def test_load_history_reads_sql_first(self):
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
expected_history = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
]
|
||||
sql_session = MagicMock()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.sql_chat_history.load_messages",
|
||||
new=AsyncMock(return_value=expected_history),
|
||||
) as load_sql, patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
|
||||
new=AsyncMock(),
|
||||
) as load_mem0:
|
||||
store = ChatConversationStore(sql_session)
|
||||
history = asyncio.run(
|
||||
store.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
load_sql.assert_awaited_once_with(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
load_mem0.assert_not_called()
|
||||
assert history == expected_history
|
||||
|
||||
def test_load_history_falls_back_to_mem0_and_backfills_sql(self):
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
legacy = [
|
||||
{"role": "user", "content": "old"},
|
||||
{"role": "assistant", "content": "from mem0"},
|
||||
]
|
||||
sql_session = MagicMock()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.sql_chat_history.load_messages",
|
||||
new=AsyncMock(return_value=[]),
|
||||
) as load_sql, patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
|
||||
new=AsyncMock(return_value=legacy),
|
||||
) as load_mem0, patch(
|
||||
"services.chat.conversation_store.sql_chat_history.replace_messages",
|
||||
new=AsyncMock(),
|
||||
) as replace_messages:
|
||||
store = ChatConversationStore(sql_session)
|
||||
history = asyncio.run(
|
||||
store.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
load_sql.assert_awaited_once()
|
||||
load_mem0.assert_awaited_once()
|
||||
replace_messages.assert_awaited_once_with(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
messages=legacy,
|
||||
)
|
||||
assert history == legacy
|
||||
|
||||
def test_append_turn_persists_sql_and_mem0(self):
|
||||
sql_session = MagicMock()
|
||||
store = ChatConversationStore(sql_session)
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.sql_chat_history.append_turn",
|
||||
new=AsyncMock(),
|
||||
) as append_sql, patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.store_chat_turn",
|
||||
new=AsyncMock(),
|
||||
) as store_mem0:
|
||||
asyncio.run(
|
||||
store.append_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you improve slide 2?",
|
||||
assistant_message="Yes, I will tighten the bullet points.",
|
||||
)
|
||||
)
|
||||
|
||||
append_sql.assert_awaited_once_with(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you improve slide 2?",
|
||||
assistant_message="Yes, I will tighten the bullet points.",
|
||||
)
|
||||
store_mem0.assert_awaited_once_with(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you improve slide 2?",
|
||||
assistant_message="Yes, I will tighten the bullet points.",
|
||||
)
|
||||
|
||||
def test_retrieve_semantic_context_delegates_to_chat_memory_store(self):
|
||||
store = ChatConversationStore(MagicMock())
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.retrieve_context",
|
||||
new=AsyncMock(return_value="conversation-scoped context"),
|
||||
) as retrieve_context:
|
||||
context = asyncio.run(
|
||||
store.retrieve_semantic_context(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What did we decide?",
|
||||
)
|
||||
)
|
||||
|
||||
retrieve_context.assert_awaited_once_with(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What did we decide?",
|
||||
)
|
||||
assert context == "conversation-scoped context"
|
||||
249
servers/fastapi/tests/test_chat_memory_store.py
Normal file
249
servers/fastapi/tests/test_chat_memory_store.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import services.mem0_oss_memory as mem0_oss
|
||||
from services.chat.chat_memory_store import ChatMemoryStore
|
||||
|
||||
|
||||
class FakeMemoryClient:
|
||||
instances: list["FakeMemoryClient"] = []
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config
|
||||
self.add_calls = []
|
||||
self.search_calls = []
|
||||
self.get_all_calls = []
|
||||
self.next_search_response = {"results": []}
|
||||
self.next_get_all_response = {"results": []}
|
||||
FakeMemoryClient.instances.append(self)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(config=config)
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
messages = kwargs.get("messages") if "messages" in kwargs else None
|
||||
if messages is None and args:
|
||||
messages = args[0]
|
||||
|
||||
self.add_calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"infer": kwargs.get("infer"),
|
||||
}
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
def search(self, query, *args, **kwargs):
|
||||
self.search_calls.append(
|
||||
{
|
||||
"query": query,
|
||||
"filters": kwargs.get("filters"),
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"top_k": kwargs.get("top_k"),
|
||||
}
|
||||
)
|
||||
return self.next_search_response
|
||||
|
||||
def get_all(self, *args, **kwargs):
|
||||
self.get_all_calls.append(
|
||||
{
|
||||
"filters": kwargs.get("filters"),
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"limit": kwargs.get("limit"),
|
||||
}
|
||||
)
|
||||
return self.next_get_all_response
|
||||
|
||||
|
||||
def _mem0_oss_fresh() -> None:
|
||||
mem0_oss._shared_client = None # type: ignore[attr-defined]
|
||||
mem0_oss._init_attempted = False # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestChatMemoryStore:
|
||||
def setup_method(self):
|
||||
FakeMemoryClient.instances = []
|
||||
_mem0_oss_fresh()
|
||||
|
||||
def test_store_chat_turn_uses_conversation_scoped_user_id(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"MEM0_ENABLED": "true",
|
||||
"MEM0_TOP_K": "4",
|
||||
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
|
||||
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
|
||||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.chat.chat_memory_store.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {"provider": "fastembed", "config": {}},
|
||||
}
|
||||
),
|
||||
):
|
||||
store = ChatMemoryStore()
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
asyncio.run(
|
||||
store.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you tighten slide 3?",
|
||||
assistant_message="Yes, I can make it shorter.",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(FakeMemoryClient.instances) == 1
|
||||
client = FakeMemoryClient.instances[0]
|
||||
assert len(client.add_calls) == 1
|
||||
expected_user_id = (
|
||||
f"presentation:{presentation_id}:conversation:{conversation_id}"
|
||||
)
|
||||
assert client.add_calls[0]["user_id"] == expected_user_id
|
||||
assert client.add_calls[0]["infer"] is False
|
||||
payload = str(client.add_calls[0]["messages"][0]["content"])
|
||||
assert "[chat_turn]" in payload
|
||||
assert "user=Can you tighten slide 3?" in payload
|
||||
assert "assistant=Yes, I can make it shorter." in payload
|
||||
|
||||
def test_retrieve_context_reads_only_conversation_scoped_user_id(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"MEM0_ENABLED": "true",
|
||||
"MEM0_TOP_K": "6",
|
||||
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
|
||||
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
|
||||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.chat.chat_memory_store.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {"provider": "fastembed", "config": {}},
|
||||
}
|
||||
),
|
||||
):
|
||||
store = ChatMemoryStore()
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
expected_user_id = (
|
||||
f"presentation:{presentation_id}:conversation:{conversation_id}"
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
store.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="First turn",
|
||||
assistant_message="First answer",
|
||||
)
|
||||
)
|
||||
|
||||
client = FakeMemoryClient.instances[0]
|
||||
client.next_search_response = {
|
||||
"results": [
|
||||
{"memory": "Chat memory A"},
|
||||
{"memory": "Chat memory A"},
|
||||
{"memory": "Chat memory B"},
|
||||
]
|
||||
}
|
||||
|
||||
context = asyncio.run(
|
||||
store.retrieve_context(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What did we decide?",
|
||||
)
|
||||
)
|
||||
|
||||
assert "Chat memory A" in context
|
||||
assert "Chat memory B" in context
|
||||
assert context.count("Chat memory A") == 1
|
||||
|
||||
assert len(client.search_calls) == 1
|
||||
assert client.search_calls[0]["query"] == "What did we decide?"
|
||||
assert client.search_calls[0]["filters"] == {"user_id": expected_user_id}
|
||||
assert client.search_calls[0]["top_k"] == 6
|
||||
|
||||
def test_load_history_reads_conversation_scoped_turns(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"MEM0_ENABLED": "true",
|
||||
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
|
||||
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
|
||||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.chat.chat_memory_store.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {"provider": "fastembed", "config": {}},
|
||||
}
|
||||
),
|
||||
):
|
||||
store = ChatMemoryStore()
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
expected_user_id = (
|
||||
f"presentation:{presentation_id}:conversation:{conversation_id}"
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
store.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Draft intro",
|
||||
assistant_message="Updated intro done.",
|
||||
)
|
||||
)
|
||||
|
||||
client = FakeMemoryClient.instances[0]
|
||||
client.next_get_all_response = {
|
||||
"results": [
|
||||
{
|
||||
"memory": (
|
||||
"[chat_turn]\n"
|
||||
"turn_created_at=2026-04-25T10:00:00+00:00\n"
|
||||
"user=Draft intro\nassistant=Updated intro done."
|
||||
),
|
||||
"created_at": "2026-04-25T10:00:01+00:00",
|
||||
},
|
||||
{
|
||||
"memory": (
|
||||
"[chat_turn]\n"
|
||||
"turn_created_at=2026-04-25T10:01:00+00:00\n"
|
||||
"user=Add roadmap\nassistant=Roadmap slide added."
|
||||
),
|
||||
"created_at": "2026-04-25T10:01:01+00:00",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
history = asyncio.run(
|
||||
store.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert history == [
|
||||
{"role": "user", "content": "Draft intro"},
|
||||
{"role": "assistant", "content": "Updated intro done."},
|
||||
{"role": "user", "content": "Add roadmap"},
|
||||
{"role": "assistant", "content": "Roadmap slide added."},
|
||||
]
|
||||
|
||||
assert len(client.get_all_calls) == 1
|
||||
assert client.get_all_calls[0]["filters"] == {"user_id": expected_user_id}
|
||||
assert client.get_all_calls[0]["limit"] >= 10
|
||||
|
|
@ -2,11 +2,12 @@ import asyncio
|
|||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import services.mem0_oss_memory as mem0_oss
|
||||
from services.mem0_presentation_memory_service import Mem0PresentationMemoryService
|
||||
|
||||
|
||||
class FakeMemoryClient:
|
||||
instances = []
|
||||
instances: list["FakeMemoryClient"] = []
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config
|
||||
|
|
@ -45,13 +46,15 @@ class FakeMemoryClient:
|
|||
return self.next_search_response
|
||||
|
||||
|
||||
class FakeMem0Module:
|
||||
Memory = FakeMemoryClient
|
||||
def _mem0_oss_fresh() -> None:
|
||||
mem0_oss._shared_client = None # type: ignore[attr-defined]
|
||||
mem0_oss._init_attempted = False # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestMem0PresentationMemoryService:
|
||||
def setup_method(self):
|
||||
FakeMemoryClient.instances = []
|
||||
_mem0_oss_fresh()
|
||||
|
||||
def test_store_generation_context_uses_presentation_scope(self):
|
||||
with patch.dict(
|
||||
|
|
@ -62,8 +65,25 @@ class TestMem0PresentationMemoryService:
|
|||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.mem0_presentation_memory_service.import_module",
|
||||
return_value=FakeMem0Module,
|
||||
"services.mem0_presentation_memory_service.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"on_disk": True,
|
||||
"embedding_model_dims": 384,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "fastembed",
|
||||
"config": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"embedding_dims": 384,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
):
|
||||
service = Mem0PresentationMemoryService()
|
||||
presentation_id = uuid.uuid4()
|
||||
|
|
@ -115,8 +135,19 @@ class TestMem0PresentationMemoryService:
|
|||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.mem0_presentation_memory_service.import_module",
|
||||
return_value=FakeMem0Module,
|
||||
"services.mem0_presentation_memory_service.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {
|
||||
"provider": "fastembed",
|
||||
"config": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"embedding_dims": 384,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
):
|
||||
service = Mem0PresentationMemoryService()
|
||||
presentation_id = uuid.uuid4()
|
||||
|
|
@ -154,3 +185,4 @@ class TestMem0PresentationMemoryService:
|
|||
"user_id": f"presentation:{presentation_id}"
|
||||
}
|
||||
assert client.search_calls[0]["top_k"] == 5
|
||||
|
||||
|
|
|
|||
8
servers/fastapi/uv.lock
generated
8
servers/fastapi/uv.lock
generated
|
|
@ -1185,7 +1185,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "llmai"
|
||||
version = "0.1.9"
|
||||
version = "0.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anthropic" },
|
||||
|
|
@ -1193,9 +1193,9 @@ dependencies = [
|
|||
{ name = "google-genai" },
|
||||
{ name = "openai" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/34/1d4188bd842f336726004896be9ae04b693b9c21d349918631433b9f1b63/llmai-0.2.1.tar.gz", hash = "sha256:f911bd7df3eb06d1c56612ce293f926df7b3bf6c36283a353cf780c697d39d31", size = 47862, upload-time = "2026-04-24T16:28:52.417Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/7d/63b3da1be92b721b0681db5b9e96c1cbc000b63cb70ede40b20cc5302699/llmai-0.2.1-py3-none-any.whl", hash = "sha256:4c51d1186cce1e621f74a9ec70376dc1bd2e996eee484db17dce6a6e7b79a0a7", size = 59880, upload-time = "2026-04-24T16:28:50.706Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1686,7 +1686,7 @@ requires-dist = [
|
|||
{ name = "fastmcp", specifier = ">=2.11.0" },
|
||||
{ name = "google-genai", specifier = ">=1.28.0" },
|
||||
{ name = "jsonschema", specifier = ">=4.26.0" },
|
||||
{ name = "llmai", specifier = "==0.1.9" },
|
||||
{ name = "llmai", specifier = "==0.2.1" },
|
||||
{ name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" },
|
||||
{ name = "nltk", specifier = ">=3.9.1" },
|
||||
{ name = "openai", specifier = ">=1.98.0" },
|
||||
|
|
|
|||
|
|
@ -252,6 +252,9 @@ const createMessageId = () => {
|
|||
return `${Date.now()}-${Math.random().toString(16).slice(2)}`;
|
||||
};
|
||||
|
||||
const conversationStorageKey = (presentationId: string) =>
|
||||
`presenton:chat:conversationId:${presentationId}`;
|
||||
|
||||
const AssistantMarker = () => (
|
||||
<div className="mb-3 flex items-center gap-1.5 text-[#A4A7AE]">
|
||||
<MessageCircleMore className="h-4 w-4" />
|
||||
|
|
@ -334,6 +337,7 @@ const Chat = ({
|
|||
const [input, setInput] = useState("");
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [conversationId, setConversationId] = useState<string | null>(null);
|
||||
const [isHistoryLoading, setIsHistoryLoading] = useState(false);
|
||||
const [isSending, setIsSending] = useState(false);
|
||||
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||
const [expandedActivityByMessage, setExpandedActivityByMessage] = useState<
|
||||
|
|
@ -344,11 +348,70 @@ const Chat = ({
|
|||
const messagesEndRef = useRef<HTMLDivElement | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setMessages([]);
|
||||
setInput("");
|
||||
setConversationId(null);
|
||||
setErrorMessage(null);
|
||||
setExpandedActivityByMessage({});
|
||||
|
||||
if (!presentationId) {
|
||||
return;
|
||||
}
|
||||
|
||||
setIsHistoryLoading(true);
|
||||
const run = async () => {
|
||||
try {
|
||||
if (typeof sessionStorage === "undefined") {
|
||||
return;
|
||||
}
|
||||
const sKey = conversationStorageKey(presentationId);
|
||||
let activeId = sessionStorage.getItem(sKey) ?? null;
|
||||
if (!activeId) {
|
||||
const list = await PresentationChatApi.listConversations(
|
||||
presentationId
|
||||
);
|
||||
if (list.length > 0) {
|
||||
activeId = list[0]!.conversation_id;
|
||||
sessionStorage.setItem(sKey, activeId);
|
||||
}
|
||||
}
|
||||
if (!activeId) {
|
||||
return;
|
||||
}
|
||||
const data = await PresentationChatApi.getHistory(
|
||||
presentationId,
|
||||
activeId
|
||||
);
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
setConversationId(activeId);
|
||||
setMessages(
|
||||
data.messages.map((m) => ({
|
||||
id: createMessageId(),
|
||||
role:
|
||||
m.role === "assistant"
|
||||
? "assistant"
|
||||
: m.role === "user"
|
||||
? "user"
|
||||
: "user",
|
||||
content: m.content,
|
||||
}))
|
||||
);
|
||||
} catch (error) {
|
||||
console.error("Failed to load chat history:", error);
|
||||
toast.error("Could not load previous chat");
|
||||
} finally {
|
||||
if (!cancelled) {
|
||||
setIsHistoryLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
void run();
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [presentationId]);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -372,6 +435,9 @@ const Chat = ({
|
|||
setConversationId(null);
|
||||
setErrorMessage(null);
|
||||
setExpandedActivityByMessage({});
|
||||
if (presentationId && typeof sessionStorage !== "undefined") {
|
||||
sessionStorage.removeItem(conversationStorageKey(presentationId));
|
||||
}
|
||||
|
||||
inputRef.current?.focus();
|
||||
};
|
||||
|
|
@ -493,7 +559,7 @@ const Chat = ({
|
|||
const submitMessage = async (rawMessage: string) => {
|
||||
const trimmedMessage = rawMessage.trim();
|
||||
|
||||
if (!trimmedMessage || isSending) {
|
||||
if (!trimmedMessage || isSending || isHistoryLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -578,11 +644,23 @@ const Chat = ({
|
|||
)
|
||||
);
|
||||
settleAssistantActivities(assistantMessageId, "success");
|
||||
setConversationId((previous) =>
|
||||
typeof response.conversation_id === "string"
|
||||
? response.conversation_id
|
||||
: previous
|
||||
);
|
||||
setConversationId((previous) => {
|
||||
const next =
|
||||
typeof response.conversation_id === "string"
|
||||
? response.conversation_id
|
||||
: previous;
|
||||
if (
|
||||
next &&
|
||||
presentationId &&
|
||||
typeof sessionStorage !== "undefined"
|
||||
) {
|
||||
sessionStorage.setItem(
|
||||
conversationStorageKey(presentationId),
|
||||
next
|
||||
);
|
||||
}
|
||||
return next;
|
||||
});
|
||||
|
||||
await refreshPresentationIfNeeded(
|
||||
Array.isArray(response.tool_calls) ? response.tool_calls : []
|
||||
|
|
@ -655,7 +733,7 @@ const Chat = ({
|
|||
<button
|
||||
type="button"
|
||||
onClick={resetChat}
|
||||
disabled={isSending}
|
||||
disabled={isSending || isHistoryLoading}
|
||||
className="rounded-full p-1 text-[#8C8C8C] transition-colors hover:bg-[#F7F7F7] hover:text-[#191919] disabled:cursor-not-allowed disabled:opacity-50"
|
||||
aria-label="Reset chat"
|
||||
title="Reset chat"
|
||||
|
|
@ -665,7 +743,12 @@ const Chat = ({
|
|||
</div>
|
||||
|
||||
<div className="min-h-0 flex-1 overflow-y-auto px-4 pb-4 pt-9 hide-scrollbar">
|
||||
{messages.length === 0 ? (
|
||||
{isHistoryLoading && messages.length === 0 ? (
|
||||
<div className="flex items-center justify-center py-8 text-sm text-[#99A1AF]">
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Loading chat…
|
||||
</div>
|
||||
) : messages.length === 0 ? (
|
||||
<>
|
||||
<div>
|
||||
<h4 className="mb-2 text-[10px] font-normal leading-[15px] tracking-[0.367px] text-[#99A1AF]">
|
||||
|
|
@ -819,7 +902,7 @@ const Chat = ({
|
|||
className="min-h-[92px] w-full resize-none bg-transparent pb-10 text-sm text-[#101828] placeholder:text-[#99A1AF] focus:outline-none focus:ring-0"
|
||||
rows={4}
|
||||
value={input}
|
||||
disabled={isSending}
|
||||
disabled={isSending || isHistoryLoading}
|
||||
onChange={(event) => setInput(event.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Improve your slides..."
|
||||
|
|
@ -836,7 +919,7 @@ const Chat = ({
|
|||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!input.trim() || isSending}
|
||||
disabled={!input.trim() || isSending || isHistoryLoading}
|
||||
className="absolute bottom-3 right-3 flex items-center gap-1.5 px-3 py-2 text-sm font-medium text-[#191919] disabled:cursor-not-allowed disabled:opacity-60"
|
||||
style={{
|
||||
background:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,24 @@ export interface ChatMessageResponse {
|
|||
tool_calls?: string[];
|
||||
}
|
||||
|
||||
export interface ChatHistoryMessage {
|
||||
role: string;
|
||||
content: string;
|
||||
created_at?: string;
|
||||
}
|
||||
|
||||
export interface ChatHistoryData {
|
||||
presentation_id: string;
|
||||
conversation_id: string;
|
||||
messages: ChatHistoryMessage[];
|
||||
}
|
||||
|
||||
export interface ChatConversationSummary {
|
||||
conversation_id: string;
|
||||
updated_at?: string | null;
|
||||
last_message_preview?: string | null;
|
||||
}
|
||||
|
||||
export interface ChatStreamTrace {
|
||||
kind?: string;
|
||||
round?: number;
|
||||
|
|
@ -64,6 +82,38 @@ type ChatStreamData =
|
|||
| Record<string, unknown>;
|
||||
|
||||
export class PresentationChatApi {
|
||||
static async listConversations(
|
||||
presentationId: string
|
||||
): Promise<ChatConversationSummary[]> {
|
||||
const u = new URL(getApiUrl("/api/v1/ppt/chat/conversations"));
|
||||
u.searchParams.set("presentation_id", presentationId);
|
||||
const response = await fetch(u.toString(), {
|
||||
headers: getHeader(),
|
||||
cache: "no-cache",
|
||||
});
|
||||
return await ApiResponseHandler.handleResponse(
|
||||
response,
|
||||
"Failed to list chat conversations"
|
||||
);
|
||||
}
|
||||
|
||||
static async getHistory(
|
||||
presentationId: string,
|
||||
conversationId: string
|
||||
): Promise<ChatHistoryData> {
|
||||
const u = new URL(getApiUrl("/api/v1/ppt/chat/history"));
|
||||
u.searchParams.set("presentation_id", presentationId);
|
||||
u.searchParams.set("conversation_id", conversationId);
|
||||
const response = await fetch(u.toString(), {
|
||||
headers: getHeader(),
|
||||
cache: "no-cache",
|
||||
});
|
||||
return await ApiResponseHandler.handleResponse(
|
||||
response,
|
||||
"Failed to load chat history"
|
||||
);
|
||||
}
|
||||
|
||||
static async sendMessage(
|
||||
payload: ChatMessageRequest
|
||||
): Promise<ChatMessageResponse> {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue