- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks.
- Refactored chat memory services to utilize a shared `mem0` client for better memory management.
- Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers.
- Enhanced error handling and response management in chat-related services.
- Cleaned up unused code and improved overall structure for maintainability.
341 lines
12 KiB
Python
341 lines
12 KiB
Python
import json
|
|
import logging
|
|
import re
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
import dirtyjson # type: ignore[import-untyped]
|
|
from llmai.shared import AssistantToolCall, Tool # type: ignore[import-not-found]
|
|
|
|
from services.chat.schemas import (
|
|
GenerateIconInput,
|
|
GenerateImageInput,
|
|
GetContentSchemaFromLayoutIdInput,
|
|
GetSlideAtIndexInput,
|
|
NoArgsInput,
|
|
SaveSlideInput,
|
|
SearchSlidesInput,
|
|
)
|
|
from services.chat.presentation_context_store import PresentationContextStore
|
|
|
|
# Signature of every tool handler registered by ChatTools: it receives the
# decoded tool-call arguments and asynchronously produces the result dict.
ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]

# Module-scoped logger used when executing chat tools.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatTools:
    """
    llmai function tools for presentation chat.

    Tool implementations only use the memory abstraction layer and avoid external
    provider-specific logic, keeping them portable across llmai backends.

    Every tool advertised by ``get_tool_definitions`` maps (by name) to one
    private async handler in ``self._tool_handlers``. A handler receives the
    decoded argument dict, builds the matching input schema object from
    ``services.chat.schemas``, delegates to the ``PresentationContextStore``
    and returns a dict. ``execute_tool_call`` wraps each handler result or
    failure in an ``{"ok": ..., "tool": ...}`` envelope.
    """

    def __init__(self, memory: PresentationContextStore):
        """
        Bind the tools to a presentation context store.

        Args:
            memory: Store that backs every tool (outline, slide search and
                lookup, layouts, layout schemas, image/icon generation and
                slide saving).
        """
        self._memory = memory
        # Tool name as advertised to the model -> async handler implementing it.
        self._tool_handlers: dict[str, ToolHandler] = {
            "getPresentationOutline": self._get_presentation_outline,
            "searchSlides": self._search_slides,
            "getSlideAtIndex": self._get_slide_at_index,
            "getAvailableLayouts": self._get_available_layouts,
            "getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
            "generateImage": self._generate_image,
            "generateIcon": self._generate_icon,
            "saveSlide": self._save_slide,
        }

    def get_tool_definitions(self) -> list[Tool]:
        """
        Return the llmai ``Tool`` definitions exposed to the model.

        Tool names must stay in sync with the keys of ``self._tool_handlers``;
        every tool is declared with ``strict=True``.
        """
        return [
            Tool(
                name="getPresentationOutline",
                description=(
                    "Retrieve the current presentation outline from live slides "
                    "database state (with memory fallback when needed). "
                    "Return compact sections (no full slide JSON) to save context "
                    "window. Use when the user asks about sections, flow, or slide plan."
                ),
                schema=NoArgsInput,
                strict=True,
            ),
            Tool(
                name="searchSlides",
                description=(
                    "Search SQL slides by semantic intent/keywords and return "
                    "compact relevant snippets with slide identifiers. "
                    "Always provide both query and limit."
                ),
                schema=SearchSlidesInput,
                strict=True,
            ),
            Tool(
                name="getSlideAtIndex",
                description=(
                    "Retrieve a single slide by zero-based index, including its "
                    "layout id and compact preview by default. "
                    "Set includeFullContent=true only when full JSON is explicitly needed "
                    "(for example before editing existing content). "
                    "If user says slide N, convert to zero-based index N-1."
                ),
                schema=GetSlideAtIndexInput,
                strict=True,
            ),
            Tool(
                name="getAvailableLayouts",
                description=(
                    "List all available layout ids and descriptions for the current "
                    "presentation template."
                ),
                schema=NoArgsInput,
                strict=True,
            ),
            Tool(
                name="getContentSchemaFromLayoutId",
                description=(
                    "Fetch the JSON content schema for a layout id. Use before "
                    "saving slide content to validate structure."
                ),
                schema=GetContentSchemaFromLayoutIdInput,
                strict=True,
            ),
            Tool(
                name="generateImage",
                description=(
                    "Generate or fetch an image URL/path from a prompt and return "
                    "the usable URL/path."
                ),
                schema=GenerateImageInput,
                strict=True,
            ),
            Tool(
                name="generateIcon",
                description="Search icon memory and return the most relevant icon URL.",
                schema=GenerateIconInput,
                strict=True,
            ),
            Tool(
                name="saveSlide",
                description=(
                    "Save slide content for a layout. If replaceOldSlideAtIndex is "
                    "true, replace that index; otherwise insert as a new slide. "
                    "Pass content as a JSON-serialized object string and the server "
                    "will validate it against layout schema before save."
                ),
                schema=SaveSlideInput,
                strict=True,
            ),
        ]

    async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
        """
        Dispatch one model tool call to its handler.

        Args:
            tool_call: Call emitted by the model; ``name`` selects the handler
                and ``arguments`` holds its (JSON) arguments.

        Returns:
            ``{"ok": True, "tool": name, "result": ...}`` on success, or
            ``{"ok": False, "tool": name, "error": message}`` for an unknown
            tool or any exception raised while parsing arguments or running
            the handler. Failures are logged, not re-raised, so the chat loop
            can report them back to the model.
        """
        handler = self._tool_handlers.get(tool_call.name)
        if not handler:
            return {
                "ok": False,
                "tool": tool_call.name,
                "error": f"Unsupported tool: {tool_call.name}",
            }

        try:
            parsed_args = self._parse_args(tool_call.arguments)
            LOGGER.info("Executing chat tool %s", tool_call.name)
            result = await handler(parsed_args)
            return {"ok": True, "tool": tool_call.name, "result": result}
        except Exception as exc:
            LOGGER.exception("Chat tool failed: %s", tool_call.name)
            return {
                "ok": False,
                "tool": tool_call.name,
                "error": str(exc),
            }

    async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
        """
        Summarize the stored ``presentation_outline`` as compact sections.

        Each section carries the slide's zero-based ``index``, its 1-based
        ``slide_number`` and a short ``title`` (see ``_extract_title``), so
        the model sees the deck structure without full slide JSON.
        """
        outline = await self._memory.get("presentation_outline")
        if not isinstance(outline, dict):
            return {
                "found": False,
                "message": "Presentation outline is not available in memory yet.",
                "sections": [],
            }

        slides = outline.get("slides")
        if not isinstance(slides, list) or not slides:
            return {
                "found": False,
                "message": "Presentation outline exists but has no slides.",
                "sections": [],
            }

        sections: list[dict[str, Any]] = []
        for position, slide in enumerate(slides):
            index, content = self._coerce_outline_slide(position, slide)
            title = self._extract_title(content) or f"Slide {index + 1}"
            sections.append(
                {
                    "index": index,
                    "slide_number": index + 1,
                    "title": title,
                }
            )

        return {
            "found": True,
            "slide_count": len(sections),
            "sections": sections,
            "source": outline.get("source", "memory"),
        }

    @staticmethod
    def _coerce_outline_slide(position: int, slide: Any) -> tuple[int, str]:
        """
        Return ``(index, text)`` for one raw outline entry.

        ``index`` is the entry's own integer ``index`` field when present,
        otherwise its list ``position``. ``bool`` values are rejected even
        though ``bool`` subclasses ``int``. ``text`` is the entry's
        ``content`` (used as-is when it is a string, JSON-encoded otherwise),
        or the entry itself when the entry is a plain string. Any other
        entry yields an empty string.
        """
        index = position
        content = ""
        if isinstance(slide, dict):
            raw_index = slide.get("index")
            if isinstance(raw_index, int) and not isinstance(raw_index, bool):
                index = raw_index
            raw_content = slide.get("content")
            if isinstance(raw_content, str):
                content = raw_content
            elif raw_content is not None:
                try:
                    content = json.dumps(raw_content, ensure_ascii=False)
                except Exception:
                    # Deliberate best effort: unserializable content still
                    # yields some text to derive a title from.
                    content = str(raw_content)
        elif isinstance(slide, str):
            content = slide
        return index, content

    async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
        """Search slides through the store; returns the query, hit count and hits."""
        payload = SearchSlidesInput(**args)
        results = await self._memory.search(payload.query, payload.limit)
        return {
            "query": payload.query,
            "count": len(results),
            "results": results,
        }

    async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
        """
        Fetch one slide by zero-based index.

        ``includeFullContent`` defaults to ``False`` when the model omits it.
        When nothing exists at a positive index, the previous index is tried
        once, on the assumption that the model passed a 1-based number. The
        response then reports both ``requested_index`` and ``resolved_index``.
        """
        normalized_args = dict(args)
        normalized_args.setdefault("includeFullContent", False)
        payload = GetSlideAtIndexInput(**normalized_args)
        slide = await self._memory.get_slide_at_index(
            payload.index,
            include_full_content=payload.include_full_content,
        )
        if not slide and payload.index > 0:
            # Users often refer to slides as 1-based; allow a safe fallback.
            fallback_index = payload.index - 1
            fallback_slide = await self._memory.get_slide_at_index(
                fallback_index,
                include_full_content=payload.include_full_content,
            )
            if fallback_slide:
                return {
                    "found": True,
                    "slide": fallback_slide,
                    "requested_index": payload.index,
                    "resolved_index": fallback_index,
                    "note": (
                        "No slide found at requested index; returned one-based fallback "
                        f"at index {fallback_index}."
                    ),
                }
        if not slide:
            return {
                "found": False,
                "message": f"No slide found at index {payload.index}.",
            }
        return {
            "found": True,
            "slide": slide,
        }

    async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
        """List the layouts the store reports for the current template."""
        layouts = await self._memory.get_available_layouts()
        return {
            "count": len(layouts),
            "layouts": layouts,
        }

    async def _get_content_schema_from_layout_id(
        self, args: dict[str, Any]
    ) -> dict[str, Any]:
        """Return the content schema for a layout id, or ``found: False`` if the store has none."""
        payload = GetContentSchemaFromLayoutIdInput(**args)
        schema = await self._memory.get_content_schema_from_layout_id(payload.layout_id)
        if schema is None:
            return {
                "found": False,
                "layout_id": payload.layout_id,
                "message": "Layout schema not found for the provided layout id.",
            }
        return {
            "found": True,
            "layout_id": payload.layout_id,
            "content_schema": schema,
        }

    async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
        """Ask the store for an image matching ``prompt``; returns the prompt and URL/path."""
        payload = GenerateImageInput(**args)
        image_url = await self._memory.generate_image(payload.prompt)
        return {
            "prompt": payload.prompt,
            "url": image_url,
        }

    async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
        """Ask the store for an icon matching ``query``; returns the query and URL."""
        payload = GenerateIconInput(**args)
        icon_url = await self._memory.generate_icon(payload.query)
        return {
            "query": payload.query,
            "url": icon_url,
        }

    async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
        """
        Validate and persist slide content through the store.

        ``content`` should be a JSON object encoded as a string. If the model
        sends the object (or an array) inline, it is serialized first so the
        input schema accepts it and the object check below reports the real
        problem.

        Returns:
            Whatever the store's ``save_slide`` returns.

        Raises:
            ValueError: If ``content`` does not decode to a JSON object, or
                cannot be decoded at all.
        """
        # Only a top-level key is replaced, so a shallow copy keeps the
        # caller's dict untouched.
        payload_args = dict(args)
        raw_content = payload_args.get("content")
        if isinstance(raw_content, (dict, list)):
            payload_args["content"] = json.dumps(raw_content, ensure_ascii=False)

        payload = SaveSlideInput(**payload_args)
        content_parsed: Any = self._loads_lenient(payload.content)

        if not isinstance(content_parsed, dict):
            raise ValueError("'content' must be a JSON object.")

        # Round-trip through JSON so the store receives plain dict/list values
        # (the lenient parser returns its own container subclasses).
        content_payload = json.loads(json.dumps(content_parsed, ensure_ascii=False))
        return await self._memory.save_slide(
            content=content_payload,
            layout_id=payload.layout_id,
            index=payload.index,
            replace_old_slide_at_index=payload.replace_old_slide_at_index,
        )

    @staticmethod
    def _loads_lenient(text: str) -> Any:
        """
        Decode JSON strictly, falling back to ``dirtyjson`` for malformed input.

        The strict parser runs first so well-formed input takes the standard
        path. ``dirtyjson`` is only consulted when strict decoding fails, and
        its error propagates if it cannot recover the text either.
        """
        try:
            return json.loads(text)
        except ValueError:
            return dirtyjson.loads(text)

    @staticmethod
    def _parse_args(arguments: str | dict[str, Any] | None) -> dict[str, Any]:
        """
        Decode tool-call arguments into a plain dict.

        Args:
            arguments: JSON text from the model. Already-decoded dicts (as some
                backends provide) are accepted too. ``None``, empty and
                whitespace-only input mean "no arguments".

        Returns:
            A JSON-normalized copy of the argument object.

        Raises:
            ValueError: If the arguments are not a JSON object or cannot be
                decoded.
        """
        if not arguments:
            return {}
        if isinstance(arguments, dict):
            parsed: Any = arguments
        elif isinstance(arguments, str) and not arguments.strip():
            return {}
        else:
            parsed = ChatTools._loads_lenient(arguments)

        # Normalize container subclasses/tuples into plain JSON types and
        # detach the result from the caller's objects.
        normalized = json.loads(json.dumps(parsed, ensure_ascii=False))
        if isinstance(normalized, dict):
            return normalized

        raise ValueError("Tool arguments must be a JSON object.")

    @staticmethod
    def _extract_title(markdown_content: str) -> str:
        """
        Derive a short title from the first non-empty line of the content.

        A markdown heading (1-6 leading ``#``) yields its text. Any other first
        line is returned truncated to 120 characters. Content without a
        non-blank line yields ``""``.
        """
        for line in markdown_content.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            heading_match = re.match(r"^#{1,6}\s*(.+?)\s*$", stripped)
            if heading_match:
                return heading_match.group(1).strip()
            return stripped[:120]
        return ""

    @staticmethod
    def _truncate(value: str, limit: int) -> str:
        """Return ``value`` cut to ``limit`` characters plus ``"..."`` when longer than ``limit``."""
        # NOTE(review): no caller of this helper is visible in this file chunk.
        if len(value) <= limit:
            return value
        return f"{value[:limit]}..."
|