presenton/servers/fastapi/services/chat/tools.py
sudipnext 4e87dc8b70 refactor: Update database session management and enhance chat memory services
- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks.
- Refactored chat memory services to utilize a shared `mem0` client for better memory management.
- Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers.
- Enhanced error handling and response management in chat-related services.
- Cleaned up unused code and improved overall structure for maintainability.
2026-04-25 19:10:39 +05:45

341 lines
12 KiB
Python

import json
import logging
import re
from typing import Any, Awaitable, Callable
import dirtyjson # type: ignore[import-untyped]
from llmai.shared import AssistantToolCall, Tool # type: ignore[import-not-found]
from services.chat.schemas import (
GenerateIconInput,
GenerateImageInput,
GetContentSchemaFromLayoutIdInput,
GetSlideAtIndexInput,
NoArgsInput,
SaveSlideInput,
SearchSlidesInput,
)
from services.chat.presentation_context_store import PresentationContextStore
# Module-level logger, named after this module; ChatTools logs tool execution through it.
LOGGER = logging.getLogger(__name__)
# Shape shared by every tool handler: an async callable that receives the parsed
# argument dict and returns a result dict.
ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
class ChatTools:
"""
llmai function tools for presentation chat.
Tool implementations only use the memory abstraction layer and avoid external
provider-specific logic, keeping them portable across llmai backends.
"""
def __init__(self, memory: PresentationContextStore):
self._memory = memory
self._tool_handlers: dict[str, ToolHandler] = {
"getPresentationOutline": self._get_presentation_outline,
"searchSlides": self._search_slides,
"getSlideAtIndex": self._get_slide_at_index,
"getAvailableLayouts": self._get_available_layouts,
"getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
"generateImage": self._generate_image,
"generateIcon": self._generate_icon,
"saveSlide": self._save_slide,
}
def get_tool_definitions(self) -> list[Tool]:
return [
Tool(
name="getPresentationOutline",
description=(
"Retrieve the current presentation outline from live slides "
"database state (with memory fallback when needed). "
"Return compact sections (no full slide JSON) to save context "
"window. Use when the user asks about sections, flow, or slide plan."
),
schema=NoArgsInput,
strict=True,
),
Tool(
name="searchSlides",
description=(
"Search SQL slides by semantic intent/keywords and return "
"compact relevant snippets with slide identifiers. "
"Always provide both query and limit."
),
schema=SearchSlidesInput,
strict=True,
),
Tool(
name="getSlideAtIndex",
description=(
"Retrieve a single slide by zero-based index, including its "
"layout id and compact preview by default. "
"Set includeFullContent=true only when full JSON is explicitly needed "
"(for example before editing existing content). "
"If user says slide N, convert to zero-based index N-1."
),
schema=GetSlideAtIndexInput,
strict=True,
),
Tool(
name="getAvailableLayouts",
description=(
"List all available layout ids and descriptions for the current "
"presentation template."
),
schema=NoArgsInput,
strict=True,
),
Tool(
name="getContentSchemaFromLayoutId",
description=(
"Fetch the JSON content schema for a layout id. Use before "
"saving slide content to validate structure."
),
schema=GetContentSchemaFromLayoutIdInput,
strict=True,
),
Tool(
name="generateImage",
description=(
"Generate or fetch an image URL/path from a prompt and return "
"the usable URL/path."
),
schema=GenerateImageInput,
strict=True,
),
Tool(
name="generateIcon",
description="Search icon memory and return the most relevant icon URL.",
schema=GenerateIconInput,
strict=True,
),
Tool(
name="saveSlide",
description=(
"Save slide content for a layout. If replaceOldSlideAtIndex is "
"true, replace that index; otherwise insert as a new slide. "
"Pass content as a JSON-serialized object string and the server "
"will validate it against layout schema before save."
),
schema=SaveSlideInput,
strict=True,
),
]
async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
handler = self._tool_handlers.get(tool_call.name)
if not handler:
return {
"ok": False,
"tool": tool_call.name,
"error": f"Unsupported tool: {tool_call.name}",
}
try:
parsed_args = self._parse_args(tool_call.arguments)
LOGGER.info("Executing chat tool %s", tool_call.name)
result = await handler(parsed_args)
return {"ok": True, "tool": tool_call.name, "result": result}
except Exception as exc:
LOGGER.exception("Chat tool failed: %s", tool_call.name)
return {
"ok": False,
"tool": tool_call.name,
"error": str(exc),
}
async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
outline = await self._memory.get("presentation_outline")
if not isinstance(outline, dict):
return {
"found": False,
"message": "Presentation outline is not available in memory yet.",
"sections": [],
}
slides = outline.get("slides")
if not isinstance(slides, list) or not slides:
return {
"found": False,
"message": "Presentation outline exists but has no slides.",
"sections": [],
}
sections: list[dict[str, Any]] = []
for position, slide in enumerate(slides):
index = position
content = ""
if isinstance(slide, dict):
raw_index = slide.get("index")
if isinstance(raw_index, int):
index = raw_index
raw_content = slide.get("content")
if isinstance(raw_content, str):
content = raw_content
elif raw_content is not None:
try:
content = json.dumps(raw_content, ensure_ascii=False)
except Exception:
content = str(raw_content)
elif isinstance(slide, str):
content = slide
title = self._extract_title(content) or f"Slide {index + 1}"
sections.append(
{
"index": index,
"slide_number": index + 1,
"title": title,
}
)
return {
"found": True,
"slide_count": len(sections),
"sections": sections,
"source": outline.get("source", "memory"),
}
async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
payload = SearchSlidesInput(**args)
results = await self._memory.search(payload.query, payload.limit)
return {
"query": payload.query,
"count": len(results),
"results": results,
}
async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
normalized_args = dict(args)
normalized_args.setdefault("includeFullContent", False)
payload = GetSlideAtIndexInput(**normalized_args)
slide = await self._memory.get_slide_at_index(
payload.index,
include_full_content=payload.include_full_content,
)
if not slide and payload.index > 0:
# Users often refer to slides as 1-based; allow a safe fallback.
fallback_index = payload.index - 1
fallback_slide = await self._memory.get_slide_at_index(
fallback_index,
include_full_content=payload.include_full_content,
)
if fallback_slide:
return {
"found": True,
"slide": fallback_slide,
"requested_index": payload.index,
"resolved_index": fallback_index,
"note": (
"No slide found at requested index; returned one-based fallback "
f"at index {fallback_index}."
),
}
if not slide:
return {
"found": False,
"message": f"No slide found at index {payload.index}.",
}
return {
"found": True,
"slide": slide,
}
async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
layouts = await self._memory.get_available_layouts()
return {
"count": len(layouts),
"layouts": layouts,
}
async def _get_content_schema_from_layout_id(
self, args: dict[str, Any]
) -> dict[str, Any]:
payload = GetContentSchemaFromLayoutIdInput(**args)
schema = await self._memory.get_content_schema_from_layout_id(payload.layout_id)
if schema is None:
return {
"found": False,
"layout_id": payload.layout_id,
"message": "Layout schema not found for the provided layout id.",
}
return {
"found": True,
"layout_id": payload.layout_id,
"content_schema": schema,
}
async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
payload = GenerateImageInput(**args)
image_url = await self._memory.generate_image(payload.prompt)
return {
"prompt": payload.prompt,
"url": image_url,
}
async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
payload = GenerateIconInput(**args)
icon_url = await self._memory.generate_icon(payload.query)
return {
"query": payload.query,
"url": icon_url,
}
async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
payload_args = json.loads(json.dumps(dict(args), ensure_ascii=False))
raw_content = payload_args.get("content")
if isinstance(raw_content, dict):
payload_args["content"] = json.dumps(raw_content, ensure_ascii=False)
payload = SaveSlideInput(**payload_args)
try:
content_parsed: Any = dirtyjson.loads(payload.content)
except Exception:
content_parsed = json.loads(payload.content)
if not isinstance(content_parsed, dict):
raise ValueError("'content' must be a JSON object.")
content_payload = json.loads(json.dumps(content_parsed, ensure_ascii=False))
return await self._memory.save_slide(
content=content_payload,
layout_id=payload.layout_id,
index=payload.index,
replace_old_slide_at_index=payload.replace_old_slide_at_index,
)
@staticmethod
def _parse_args(arguments: str | None) -> dict[str, Any]:
if not arguments:
return {}
try:
parsed = dirtyjson.loads(arguments)
except Exception:
parsed = json.loads(arguments)
normalized = json.loads(json.dumps(parsed, ensure_ascii=False))
if isinstance(normalized, dict):
return normalized
raise ValueError("Tool arguments must be a JSON object.")
@staticmethod
def _extract_title(markdown_content: str) -> str:
for line in markdown_content.splitlines():
stripped = line.strip()
if not stripped:
continue
heading_match = re.match(r"^#{1,6}\s*(.+?)\s*$", stripped)
if heading_match:
return heading_match.group(1).strip()
return stripped[:120]
return ""
@staticmethod
def _truncate(value: str, limit: int) -> str:
if len(value) <= limit:
return value
return f"{value[:limit]}..."