presenton/servers/fastapi/services/chat/memory_layer.py
sudipnext 7b1c3c6e70 Refactor chat memory layer and tools for improved slide handling
- Updated `RUNTIME_CONTENT_FIELDS` to retain only necessary fields during validation.
- Enhanced `PresentationChatMemoryLayer` to prioritize live slide data from the database.
- Modified `search` method to clarify that it retrieves snippets from SQL-backed slides.
- Updated `get_slide_at_index` to include an option for full content retrieval.
- Adjusted tool descriptions to emphasize compact outputs and proper index handling.
- Improved handling of slide indices to accommodate user-friendly 1-based references.
2026-04-24 09:51:21 +05:45

491 lines
17 KiB
Python

import copy
import json
import logging
import re
import uuid
from typing import Any
from jsonschema import Draft202012Validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.image_prompt import ImagePrompt
from models.sql.image_asset import ImageAsset
from models.sql.presentation import PresentationModel
from models.sql.slide import SlideModel
from services.icon_finder_service import ICON_FINDER_SERVICE
from services.image_generation_service import ImageGenerationService
from services.mem0_presentation_memory_service import MEM0_PRESENTATION_MEMORY_SERVICE
from templates.presentation_layout import SlideLayoutModel
from utils.asset_directory_utils import get_images_directory
from utils.process_slides import (
process_old_and_new_slides_and_fetch_assets,
process_slide_and_fetch_assets,
)
LOGGER = logging.getLogger(__name__)

# Upper bound on how many JSON-schema violations are reported for one slide payload.
MAX_SCHEMA_ERRORS = 10

# Keys removed from slide content before JSON-schema validation.
# URL runtime fields are deliberately NOT listed here because many slide schemas
# require them. The speaker note is handled separately and should not affect
# JSON-schema checks.
# A frozenset keeps this module-level constant from being mutated by accident;
# membership tests behave exactly as with a plain set.
RUNTIME_CONTENT_FIELDS = frozenset({"__speaker_note__"})
class PresentationChatMemoryLayer:
    """
    Memory abstraction for chat tools and context retrieval.

    This layer intentionally hides where data comes from (SQL-backed persisted state
    and mem0 retrieval) behind `get` and `search`-style methods so chat logic stays
    decoupled from storage details.
    """

    def __init__(self, sql_session: AsyncSession, presentation_id: uuid.UUID):
        """Bind the layer to one SQL session and one presentation."""
        self._sql_session = sql_session
        self._presentation_id = presentation_id

    async def get(self, key: str) -> Any:
        """
        Return the stored value for `key`.

        Only the key "presentation_outline" is supported; any other key yields None.
        The outline is built from the live slide rows when any exist, otherwise it
        falls back to `presentation.outlines`. Returns None when neither is available.
        """
        if key != "presentation_outline":
            return None

        # Prefer live slides from SQL so slide count and slide indices are always current.
        slides_result = await self._sql_session.scalars(
            select(SlideModel)
            .where(SlideModel.presentation == self._presentation_id)
            .order_by(SlideModel.index)
        )
        slides = list(slides_result)
        if slides:
            LOGGER.info(
                "Chat outline loaded from slides table (presentation_id=%s, slides=%d)",
                self._presentation_id,
                len(slides),
            )
            return {
                "source": "slides_table",
                "slide_count": len(slides),
                "slides": [
                    {
                        "slide_id": str(slide.id),
                        "index": slide.index,
                        "layout_id": slide.layout,
                        "content": slide.content,
                        "speaker_note": slide.speaker_note,
                    }
                    for slide in slides
                ],
            }

        presentation = await self._sql_session.get(
            PresentationModel, self._presentation_id
        )
        if not presentation or not presentation.outlines:
            LOGGER.info(
                "Chat memory miss for outline (presentation_id=%s)",
                self._presentation_id,
            )
            return None

        LOGGER.info(
            "Chat outline fallback hit from presentation.outlines (presentation_id=%s)",
            self._presentation_id,
        )
        return presentation.outlines

    async def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]:
        """
        Search slides directly from SQL-backed slide rows.

        Results are intentionally compact (snippet-first) to keep tool-call payloads
        small for models with limited context windows.

        Scoring: +8 when the whole lowercased query occurs in the serialized slide,
        plus +1 for every distinct ASCII alphanumeric token (two or more characters)
        of the query that occurs in it. Slides scoring 0 are dropped. Results are
        ordered by descending score, then ascending slide index, and truncated to
        `max(1, limit)` entries. A blank query returns an empty list.
        """
        trimmed_query = (query or "").strip()
        if not trimmed_query:
            return []

        slides_result = await self._sql_session.scalars(
            select(SlideModel).where(SlideModel.presentation == self._presentation_id)
        )
        slides = sorted(slides_result, key=lambda slide: slide.index)
        if not slides:
            LOGGER.info(
                "Chat memory miss for slide search (presentation_id=%s, reason=no_slides)",
                self._presentation_id,
            )
            return []

        query_lower = trimmed_query.lower()
        query_tokens = set(re.findall(r"[a-z0-9]{2,}", query_lower))
        # Used to centre the snippet on a token hit when the full phrase is absent.
        snippet_terms = tuple(query_tokens)

        ranked: list[tuple[int, dict[str, Any]]] = []
        for slide in slides:
            serialized = self._serialize_slide(slide)
            searchable = serialized.lower()

            score = 0
            if query_lower in searchable:
                score += 8
            score += sum(1 for token in query_tokens if token in searchable)
            if score <= 0:
                continue

            ranked.append(
                (
                    score,
                    {
                        "slide_id": str(slide.id),
                        "index": slide.index,
                        # 1-based number for user-facing references.
                        "slide_number": slide.index + 1,
                        "layout_id": slide.layout,
                        "snippet": self._build_snippet(
                            serialized,
                            query_lower,
                            fallback_terms=snippet_terms,
                        ),
                        "score": score,
                    },
                )
            )

        ranked.sort(key=lambda item: (-item[0], item[1]["index"]))
        results = [entry for _, entry in ranked[: max(1, limit)]]
        LOGGER.info(
            "Chat DB slide search completed (presentation_id=%s, query=%r, hits=%d)",
            self._presentation_id,
            trimmed_query,
            len(results),
        )
        return results

    async def get_slide_at_index(
        self, index: int, *, include_full_content: bool = False
    ) -> dict[str, Any] | None:
        """
        Return a compact description of the slide stored at 0-based `index`.

        The result always carries a whitespace-normalized `content_preview` of at
        most 420 characters; the full `content` is attached only when
        `include_full_content` is true. Returns None when no slide has that index.
        """
        slide = await self._sql_session.scalar(
            select(SlideModel).where(
                SlideModel.presentation == self._presentation_id,
                SlideModel.index == index,
            )
        )
        if not slide:
            LOGGER.info(
                "Chat memory miss for slide by index (presentation_id=%s, index=%d)",
                self._presentation_id,
                index,
            )
            return None

        response: dict[str, Any] = {
            "slide_id": str(slide.id),
            "index": slide.index,
            "slide_number": slide.index + 1,
            "layout_id": slide.layout,
            # An empty query matches at offset 0, so the preview is the leading text.
            "content_preview": self._build_snippet(
                self._serialize_slide(slide),
                query_lower="",
                window=420,
            ),
            "speaker_note": slide.speaker_note,
        }
        if include_full_content:
            response["content"] = slide.content
        return response

    async def get_available_layouts(self) -> list[dict[str, Any]]:
        """
        List id, name and description of every layout of this presentation.

        Returns an empty list when the presentation is missing, its `layout` is not
        a dict, or the layout cannot be parsed (the parse failure is logged).
        """
        presentation = await self._sql_session.get(
            PresentationModel, self._presentation_id
        )
        if not presentation or not isinstance(presentation.layout, dict):
            return []
        try:
            layout_model = presentation.get_layout()
        except Exception:
            LOGGER.exception(
                "Failed to parse presentation layout (presentation_id=%s)",
                self._presentation_id,
            )
            return []
        return [
            {
                "id": layout.id,
                "name": layout.name,
                "description": layout.description,
            }
            for layout in layout_model.slides
        ]

    async def get_content_schema_from_layout_id(
        self, layout_id: str
    ) -> dict[str, Any] | None:
        """Return the JSON schema of layout `layout_id`, or None if it is unknown."""
        layout = await self._get_layout_by_id(layout_id)
        if not layout:
            return None
        return layout.json_schema

    async def generate_image(self, prompt: str) -> str:
        """
        Generate an image for `prompt` and return a string locating it.

        When the image service returns an `ImageAsset`, the asset is added to the
        session and committed and its `path` is returned; any other return value is
        converted with `str()`.
        """
        image_generation_service = ImageGenerationService(get_images_directory())
        image = await image_generation_service.generate_image(
            ImagePrompt(prompt=prompt)
        )
        if isinstance(image, ImageAsset):
            self._sql_session.add(image)
            await self._sql_session.commit()
            return image.path
        return str(image)

    async def generate_icon(self, query: str) -> str:
        """Return the first icon found for `query`, or a placeholder icon path."""
        icons = await ICON_FINDER_SERVICE.search_icons(query, k=1)
        if icons:
            return icons[0]
        return "/static/icons/placeholder.svg"

    async def save_slide(
        self,
        *,
        content: dict[str, Any],
        layout_id: str,
        index: int,
        replace_old_slide_at_index: bool,
    ) -> dict[str, Any]:
        """
        Validate `content` against the layout schema and persist it as a slide.

        Args:
            content: Slide content; validated against the schema of `layout_id`
                after runtime-only fields are stripped.
            layout_id: Id of a layout belonging to this presentation.
            index: 0-based target index; negative values are clamped to 0.
            replace_old_slide_at_index: When true, overwrite the slide currently at
                the target index; otherwise insert a new slide there and shift the
                slides at or after it by one.

        Returns:
            A dict with `saved` (bool) and `message`. Failures also carry
            `validation_errors`; successes carry `action`, `slide_id` and `index`
            (plus `shifted_slide_count` for inserts).
        """
        presentation = await self._sql_session.get(
            PresentationModel, self._presentation_id
        )
        if not presentation:
            return {
                "saved": False,
                "message": "Presentation not found.",
                "validation_errors": [],
            }

        layout = await self._get_layout_by_id(layout_id, presentation=presentation)
        if not layout:
            return {
                "saved": False,
                "message": f"Layout '{layout_id}' was not found in this presentation.",
                "validation_errors": [f"Unknown layout_id '{layout_id}'."],
            }

        validation_errors = self._validate_slide_content(
            content=content,
            schema=layout.json_schema,
        )
        if validation_errors:
            return {
                "saved": False,
                "message": "Slide content failed schema validation.",
                "validation_errors": validation_errors,
            }

        target_index = max(0, index)
        image_generation_service = ImageGenerationService(get_images_directory())

        if replace_old_slide_at_index:
            return await self._replace_slide(
                presentation=presentation,
                content=content,
                layout_id=layout_id,
                target_index=target_index,
                image_generation_service=image_generation_service,
            )
        return await self._insert_slide(
            presentation=presentation,
            content=content,
            layout_id=layout_id,
            target_index=target_index,
            image_generation_service=image_generation_service,
        )

    async def _replace_slide(
        self,
        *,
        presentation: PresentationModel,
        content: dict[str, Any],
        layout_id: str,
        target_index: int,
        image_generation_service: ImageGenerationService,
    ) -> dict[str, Any]:
        """Overwrite the slide at `target_index` with already-validated `content`."""
        existing_slide = await self._sql_session.scalar(
            select(SlideModel).where(
                SlideModel.presentation == self._presentation_id,
                SlideModel.index == target_index,
            )
        )
        if not existing_slide:
            return {
                "saved": False,
                "message": f"No existing slide found at index {target_index} to replace.",
                "validation_errors": [],
            }

        # Work on a copy so the caller's dict is never mutated by asset processing.
        updated_content = copy.deepcopy(content)
        new_assets = await process_old_and_new_slides_and_fetch_assets(
            image_generation_service=image_generation_service,
            old_slide_content=existing_slide.content or {},
            new_slide_content=updated_content,
        )

        # The replaced slide gets a fresh id. Keep it in a local so the response does
        # not read an ORM attribute after commit (attributes may be expired then, and
        # an implicit reload is not possible on an async session).
        new_slide_id = uuid.uuid4()
        existing_slide.id = new_slide_id
        existing_slide.layout = layout_id
        existing_slide.layout_group = self._resolve_layout_group(
            presentation=presentation,
            fallback=existing_slide.layout_group,
        )
        existing_slide.content = updated_content
        existing_slide.speaker_note = self._extract_speaker_note(updated_content)
        self._sql_session.add(existing_slide)
        self._sql_session.add_all(new_assets)
        await self._sql_session.commit()

        await self._record_slide_edit(
            slide_index=target_index,
            edit_prompt=f"[chat_tool_save_slide_replace] layout_id={layout_id}",
            edited_slide_content=updated_content,
        )
        return {
            "saved": True,
            "action": "replaced",
            "message": f"Slide at index {target_index} was replaced successfully.",
            "slide_id": str(new_slide_id),
            "index": target_index,
        }

    async def _insert_slide(
        self,
        *,
        presentation: PresentationModel,
        content: dict[str, Any],
        layout_id: str,
        target_index: int,
        image_generation_service: ImageGenerationService,
    ) -> dict[str, Any]:
        """Insert already-validated `content` as a new slide, shifting later slides."""
        slides_result = await self._sql_session.scalars(
            select(SlideModel)
            .where(SlideModel.presentation == self._presentation_id)
            .order_by(SlideModel.index)
        )
        slides = list(slides_result)
        if slides:
            # Never leave a gap: the furthest valid position is right after the last slide.
            max_index = max(slide.index for slide in slides)
            insert_index = min(target_index, max_index + 1)
            slides_to_shift = [slide for slide in slides if slide.index >= insert_index]
        else:
            insert_index = 0
            slides_to_shift = []

        new_slide_content = copy.deepcopy(content)
        new_slide = SlideModel(
            presentation=self._presentation_id,
            layout_group=self._resolve_layout_group(presentation=presentation),
            layout=layout_id,
            index=insert_index,
            content=new_slide_content,
            speaker_note=self._extract_speaker_note(new_slide_content),
        )

        # Fetch assets BEFORE touching sibling indices: if asset generation fails, the
        # session must not be left holding half-applied index shifts that a later
        # commit on the same session could persist.
        new_assets = await process_slide_and_fetch_assets(
            image_generation_service=image_generation_service,
            slide=new_slide,
        )

        for slide in sorted(slides_to_shift, key=lambda each: each.index, reverse=True):
            slide.index += 1
            self._sql_session.add(slide)
        self._sql_session.add(new_slide)
        self._sql_session.add_all(new_assets)
        await self._sql_session.commit()
        await self._sql_session.refresh(new_slide)

        await self._record_slide_edit(
            slide_index=insert_index,
            edit_prompt=f"[chat_tool_save_slide_new] layout_id={layout_id}",
            edited_slide_content=new_slide.content,
        )
        return {
            "saved": True,
            "action": "created",
            "message": f"New slide saved at index {insert_index}.",
            "slide_id": str(new_slide.id),
            "index": insert_index,
            "shifted_slide_count": len(slides_to_shift),
        }

    async def _record_slide_edit(
        self,
        *,
        slide_index: int,
        edit_prompt: str,
        edited_slide_content: Any,
    ) -> None:
        """
        Best-effort write of a slide edit to the mem0 memory service.

        This runs after the SQL commit, so a failure here must not be reported as a
        failed save (the caller could retry and duplicate the slide). Errors are
        logged and swallowed.
        """
        try:
            await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
                presentation_id=self._presentation_id,
                slide_index=slide_index,
                edit_prompt=edit_prompt,
                edited_slide_content=edited_slide_content,
            )
        except Exception:
            LOGGER.exception(
                "Failed to store slide edit in mem0 (presentation_id=%s, slide_index=%s)",
                self._presentation_id,
                slide_index,
            )

    async def retrieve_context(self, query: str) -> str:
        """
        Return the mem0 context for `query`, logging a hit or a miss.

        The value is passed through unchanged from the memory service, so a falsy
        result (logged as a miss) is returned as-is.
        """
        context = await MEM0_PRESENTATION_MEMORY_SERVICE.retrieve_context(
            self._presentation_id,
            query,
        )
        if context:
            LOGGER.info(
                "Chat memory semantic context hit (presentation_id=%s, chars=%d)",
                self._presentation_id,
                len(context),
            )
        else:
            LOGGER.info(
                "Chat memory semantic context miss (presentation_id=%s)",
                self._presentation_id,
            )
        return context

    async def _get_layout_by_id(
        self,
        layout_id: str,
        presentation: PresentationModel | None = None,
    ) -> SlideLayoutModel | None:
        """
        Find the layout with id `layout_id` in this presentation.

        `presentation` may be supplied to avoid a second lookup. Returns None when
        the presentation is missing, its `layout` is not a dict, the layout cannot
        be parsed (logged), or no layout has that id.
        """
        if not presentation:
            presentation = await self._sql_session.get(
                PresentationModel, self._presentation_id
            )
        if not presentation or not isinstance(presentation.layout, dict):
            return None
        try:
            layout_model = presentation.get_layout()
        except Exception:
            # Same failure, same log line as `get_available_layouts`, so a broken
            # layout is visible instead of looking like an unknown layout id.
            LOGGER.exception(
                "Failed to parse presentation layout (presentation_id=%s)",
                self._presentation_id,
            )
            return None
        for layout in layout_model.slides:
            if layout.id == layout_id:
                return layout
        return None

    def _validate_slide_content(
        self,
        *,
        content: dict[str, Any],
        schema: dict[str, Any],
    ) -> list[str]:
        """
        Validate `content` against `schema` (JSON Schema draft 2020-12).

        Runtime-only fields are stripped first. Returns up to `MAX_SCHEMA_ERRORS`
        messages of the form "<dotted.path>: <message>" ordered by error path, with
        "$" standing for the document root; an empty list means the content is valid.
        """
        validation_content = self._strip_runtime_fields(content)
        validator = Draft202012Validator(schema)
        errors = sorted(
            validator.iter_errors(validation_content), key=lambda err: err.path
        )
        formatted_errors: list[str] = []
        for err in errors[:MAX_SCHEMA_ERRORS]:
            location = ".".join(str(part) for part in err.path) or "$"
            formatted_errors.append(f"{location}: {err.message}")
        return formatted_errors

    @staticmethod
    def _strip_runtime_fields(value: Any) -> Any:
        """
        Return a copy of `value` without keys listed in `RUNTIME_CONTENT_FIELDS`.

        Dicts and lists are rebuilt recursively, so the keys are removed at every
        nesting level; all other values are returned unchanged.
        """
        if isinstance(value, dict):
            return {
                key: PresentationChatMemoryLayer._strip_runtime_fields(nested_value)
                for key, nested_value in value.items()
                if key not in RUNTIME_CONTENT_FIELDS
            }
        if isinstance(value, list):
            return [
                PresentationChatMemoryLayer._strip_runtime_fields(item)
                for item in value
            ]
        return value

    @staticmethod
    def _extract_speaker_note(content: dict[str, Any]) -> str:
        """Return `content["__speaker_note__"]` when it is a string, else ""."""
        value = content.get("__speaker_note__")
        if isinstance(value, str):
            return value
        return ""

    @staticmethod
    def _resolve_layout_group(
        *,
        presentation: PresentationModel,
        fallback: str = "presentation",
    ) -> str:
        """Return the stripped, non-empty `presentation.layout["name"]`, else `fallback`."""
        if isinstance(presentation.layout, dict):
            name = str(presentation.layout.get("name") or "").strip()
            if name:
                return name
        return fallback

    @staticmethod
    def _serialize_slide(slide: SlideModel) -> str:
        """
        Flatten a slide into one searchable text blob.

        The blob has four newline-separated parts: "slide_index=<index>",
        "layout_id=<layout>", the content as JSON (non-ASCII kept verbatim), and
        the speaker note ("" when missing).
        """
        try:
            content_text = json.dumps(slide.content or {}, ensure_ascii=False)
        except Exception:
            # Best effort: content that is not JSON-serializable is still searchable.
            content_text = str(slide.content)
        speaker_note = slide.speaker_note or ""
        return (
            f"slide_index={slide.index}\nlayout_id={slide.layout}\n"
            f"{content_text}\n{speaker_note}"
        )

    @staticmethod
    def _build_snippet(
        text: str,
        query_lower: str,
        window: int = 320,
        fallback_terms: tuple[str, ...] = (),
    ) -> str:
        """
        Return at most `window` characters of `text` around the first match.

        Whitespace runs in `text` are collapsed to single spaces first. The snippet
        starts a third of a window before the first case-insensitive occurrence of
        `query_lower` (which must already be lowercase). If the phrase does not
        occur, the earliest occurrence of any non-empty `fallback_terms` entry is
        used instead; if none occurs either, the leading `window` characters are
        returned. Empty or whitespace-only text yields "".
        """
        normalized = " ".join(text.split())
        if not normalized:
            return ""

        haystack = normalized.lower()
        offset = haystack.find(query_lower)
        if offset == -1:
            term_offsets = [
                position
                for position in (
                    haystack.find(term) for term in fallback_terms if term
                )
                if position != -1
            ]
            if not term_offsets:
                return normalized[:window]
            offset = min(term_offsets)

        start = max(0, offset - window // 3)
        end = min(len(normalized), start + window)
        return normalized[start:end]