presenton/servers/fastapi/services/chat/tools.py
sudipnext d1aeff6f82 feat: Revise prompt and tool descriptions for clarity and accuracy
- Updated the system prompt to provide clearer guidance on the use of deck memory and tools, emphasizing the distinction between historical context and current slide data.
- Enhanced tool descriptions in the ChatTools class to clarify their intended use cases, focusing on live slide data and the importance of using tools for current content.
- Improved instructions on handling slide indices and the retrieval of slide information, ensuring users understand the differences between memory and live data.
2026-04-26 18:28:27 +05:45

356 lines
13 KiB
Python

import json
import logging
import re
from typing import Any, Awaitable, Callable
import dirtyjson # type: ignore[import-untyped]
from llmai.shared import AssistantToolCall, Tool # type: ignore[import-not-found]
from services.chat.schemas import (
GenerateAssetsInput,
GenerateIconInput,
GenerateImageInput,
GetContentSchemaFromLayoutIdInput,
GetSlideAtIndexInput,
NoArgsInput,
SaveSlideInput,
SearchSlidesInput,
)
from services.chat.presentation_context_store import PresentationContextStore
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Signature shared by every tool handler: an async callable that receives the
# decoded tool-argument dict and returns a result dict.
ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
class ChatTools:
    """Advertise presentation tools to the chat model and dispatch its tool calls.

    Every handler delegates to the ``PresentationContextStore`` passed to the
    constructor. ``get_tool_definitions`` returns the tool specs offered to the
    model; ``execute_tool_call`` decodes a call's arguments, runs the matching
    handler and wraps the outcome in an ``{"ok": ..., "tool": ...}`` envelope.
    """

    def __init__(self, memory: PresentationContextStore):
        """Bind the tool handlers to ``memory``.

        Args:
            memory: Store that every handler delegates to.
        """
        self._memory = memory
        # Tool name -> coroutine handler. "generateImage" and "generateIcon"
        # remain dispatchable here although get_tool_definitions() only
        # advertises the batched "generateAssets" tool.
        self._tool_handlers: dict[str, ToolHandler] = {
            "getPresentationOutline": self._get_presentation_outline,
            "searchSlides": self._search_slides,
            "getSlideAtIndex": self._get_slide_at_index,
            "getAvailableLayouts": self._get_available_layouts,
            "getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
            "generateAssets": self._generate_assets,
            "generateImage": self._generate_image,
            "generateIcon": self._generate_icon,
            "saveSlide": self._save_slide,
        }

    def get_tool_definitions(self) -> list[Tool]:
        """Return the tool definitions (name, description, schema) offered to the model."""
        return [
            Tool(
                name="getPresentationOutline",
                # NOTE(review): the handler reads the "presentation_outline"
                # entry from the store; whether that entry is built from live
                # slide rows is decided by the store and is not visible here.
                description=(
                    "Live database: current deck structure. "
                    "Use for the **actual** slide list/order and compact previews—not for uploaded PDF text or pre-outline RAG. "
                    "Falls back to stored outlines only if no slide rows exist. "
                    "Return compact sections (no full slide JSON). Use for flow, sections, or 'what slides exist'."
                ),
                schema=NoArgsInput,
                strict=True,
            ),
            Tool(
                name="searchSlides",
                description=(
                    "Live SQL slides: keyword/semantic style search with snippets and indices. "
                    "Use to find on-slide text, topics, or which slide mentioned something. "
                    "For source-document-only questions, rely on deck memory; use this when the question is about **slides as built**. "
                    "Always provide both query and limit."
                ),
                schema=SearchSlidesInput,
                strict=True,
            ),
            Tool(
                name="getSlideAtIndex",
                description=(
                    "Live SQL: one slide by index—authoritative for exact current content. "
                    "Set includeFullContent=true when you need full JSON (before saveSlide or precise edits). "
                    "If user says slide N, use zero-based index N-1."
                ),
                schema=GetSlideAtIndexInput,
                strict=True,
            ),
            Tool(
                name="getAvailableLayouts",
                description=(
                    "List all available layout ids and descriptions for the current "
                    "presentation template."
                ),
                schema=NoArgsInput,
                strict=True,
            ),
            Tool(
                name="getContentSchemaFromLayoutId",
                description=(
                    "Fetch the JSON content schema for a layout id. Use before "
                    "saving slide content to validate structure."
                ),
                schema=GetContentSchemaFromLayoutIdInput,
                strict=True,
            ),
            Tool(
                name="generateAssets",
                description=(
                    "Generate multiple media assets in one call. Use for all slide "
                    "images and icons before saving content; include every needed "
                    "asset in the assets array instead of calling image/icon tools "
                    "one at a time."
                ),
                schema=GenerateAssetsInput,
                strict=True,
            ),
            Tool(
                name="saveSlide",
                description=(
                    "Save slide content for a layout. If replaceOldSlideAtIndex is "
                    "true, replace that index; otherwise insert as a new slide. "
                    "Pass content as a JSON-serialized object string and the server "
                    "will validate it against layout schema before save."
                ),
                schema=SaveSlideInput,
                strict=True,
            ),
        ]

    async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
        """Run the handler registered for ``tool_call`` and wrap its outcome.

        Args:
            tool_call: The call to execute; its ``name`` selects the handler
                and its ``arguments`` are decoded with ``_parse_args``.

        Returns:
            ``{"ok": True, "tool": name, "result": ...}`` on success, otherwise
            ``{"ok": False, "tool": name, "error": message}``. Handler and
            argument-parsing errors are logged and reported in the envelope
            rather than raised.
        """
        handler = self._tool_handlers.get(tool_call.name)
        if handler is None:
            return {
                "ok": False,
                "tool": tool_call.name,
                "error": f"Unsupported tool: {tool_call.name}",
            }
        try:
            parsed_args = self._parse_args(tool_call.arguments)
            LOGGER.info("Executing chat tool %s", tool_call.name)
            result = await handler(parsed_args)
            return {"ok": True, "tool": tool_call.name, "result": result}
        except Exception as exc:
            # Boundary handler: a failing tool must produce an error envelope,
            # not abort the caller.
            LOGGER.exception("Chat tool failed: %s", tool_call.name)
            return {
                "ok": False,
                "tool": tool_call.name,
                # Some exceptions stringify to "" (e.g. a bare TimeoutError());
                # fall back to the class name so the error is never blank.
                "error": str(exc) or type(exc).__name__,
            }

    async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
        """Summarise the stored "presentation_outline" as one compact entry per slide.

        Returns:
            ``{"found": False, "message": ..., "sections": []}`` when the
            outline is missing, is not a dict, or has no slides. Otherwise
            ``found``, ``slide_count``, ``sections`` (each with ``index``,
            ``slide_number`` and ``title``) and ``source``.
        """
        outline = await self._memory.get("presentation_outline")
        if not isinstance(outline, dict):
            return {
                "found": False,
                "message": "Presentation outline is not available in memory yet.",
                "sections": [],
            }
        slides = outline.get("slides")
        if not isinstance(slides, list) or not slides:
            return {
                "found": False,
                "message": "Presentation outline exists but has no slides.",
                "sections": [],
            }
        sections: list[dict[str, Any]] = []
        for position, slide in enumerate(slides):
            # Default to the list position; an explicit integer "index" on the
            # slide entry takes precedence.
            index = position
            content = ""
            if isinstance(slide, dict):
                raw_index = slide.get("index")
                if isinstance(raw_index, int):
                    index = raw_index
                raw_content = slide.get("content")
                if isinstance(raw_content, str):
                    content = raw_content
                elif raw_content is not None:
                    # Structured content: serialise it so a title can still be
                    # derived from its text; fall back to str() if it is not
                    # JSON-serialisable.
                    try:
                        content = json.dumps(raw_content, ensure_ascii=False)
                    except Exception:
                        content = str(raw_content)
            elif isinstance(slide, str):
                content = slide
            title = self._extract_title(content) or f"Slide {index + 1}"
            sections.append(
                {
                    "index": index,
                    "slide_number": index + 1,
                    "title": title,
                }
            )
        return {
            "found": True,
            "slide_count": len(sections),
            "sections": sections,
            "source": outline.get("source", "memory"),
        }

    async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate ``args`` as ``SearchSlidesInput`` and run the store's search.

        Returns:
            The query, the number of results and the results themselves.
        """
        payload = SearchSlidesInput(**args)
        results = await self._memory.search(payload.query, payload.limit)
        return {
            "query": payload.query,
            "count": len(results),
            "results": results,
        }

    async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
        """Fetch one slide by index, retrying at ``index - 1`` when nothing is found.

        ``includeFullContent`` defaults to ``False`` when the caller omits it.

        Returns:
            ``{"found": True, "slide": ...}`` on a direct hit; the same plus
            ``requested_index``, ``resolved_index`` and ``note`` when the
            one-based fallback matched; ``{"found": False, "message": ...}``
            otherwise.
        """
        normalized_args = dict(args)
        normalized_args.setdefault("includeFullContent", False)
        payload = GetSlideAtIndexInput(**normalized_args)
        slide = await self._memory.get_slide_at_index(
            payload.index,
            include_full_content=payload.include_full_content,
        )
        if not slide and payload.index > 0:
            # Users often refer to slides as 1-based; allow a safe fallback.
            fallback_index = payload.index - 1
            fallback_slide = await self._memory.get_slide_at_index(
                fallback_index,
                include_full_content=payload.include_full_content,
            )
            if fallback_slide:
                return {
                    "found": True,
                    "slide": fallback_slide,
                    "requested_index": payload.index,
                    "resolved_index": fallback_index,
                    "note": (
                        "No slide found at requested index; returned one-based fallback "
                        f"at index {fallback_index}."
                    ),
                }
        if not slide:
            return {
                "found": False,
                "message": f"No slide found at index {payload.index}.",
            }
        return {
            "found": True,
            "slide": slide,
        }

    async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
        """Return the layouts reported by the store together with their count."""
        layouts = await self._memory.get_available_layouts()
        return {
            "count": len(layouts),
            "layouts": layouts,
        }

    async def _get_content_schema_from_layout_id(
        self, args: dict[str, Any]
    ) -> dict[str, Any]:
        """Look up the content schema for the layout id given in ``args``.

        Returns:
            ``found``, ``layout_id`` and ``content_schema`` when the store
            returns a schema; ``found=False`` with a message when it returns
            ``None``.
        """
        payload = GetContentSchemaFromLayoutIdInput(**args)
        schema = await self._memory.get_content_schema_from_layout_id(payload.layout_id)
        if schema is None:
            return {
                "found": False,
                "layout_id": payload.layout_id,
                "message": "Layout schema not found for the provided layout id.",
            }
        return {
            "found": True,
            "layout_id": payload.layout_id,
            "content_schema": schema,
        }

    async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
        """Generate one image for ``args["prompt"]`` and return the prompt and URL."""
        payload = GenerateImageInput(**args)
        image_url = await self._memory.generate_image(payload.prompt)
        return {
            "prompt": payload.prompt,
            "url": image_url,
        }

    async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
        """Generate one icon for ``args["query"]`` and return the query and URL."""
        payload = GenerateIconInput(**args)
        icon_url = await self._memory.generate_icon(payload.query)
        return {
            "query": payload.query,
            "url": icon_url,
        }

    async def _generate_assets(self, args: dict[str, Any]) -> dict[str, Any]:
        """Generate every requested asset, one after another, in request order.

        Assets whose ``kind`` is ``"image"`` go through ``_generate_image``;
        any other kind goes through ``_generate_icon`` with the asset prompt
        used as the icon query.

        Returns:
            ``count``, ``assets`` (``index``, ``kind``, ``prompt`` and ``url``
            per asset) and a human-readable ``message``.
        """
        payload = GenerateAssetsInput(**args)
        generated_assets: list[dict[str, Any]] = []
        for index, asset in enumerate(payload.assets):
            if asset.kind == "image":
                result = await self._generate_image({"prompt": asset.prompt})
            else:
                result = await self._generate_icon({"query": asset.prompt})
            generated_assets.append(
                {
                    "index": index,
                    "kind": asset.kind,
                    "prompt": asset.prompt,
                    "url": result.get("url"),
                }
            )
        return {
            "count": len(generated_assets),
            "assets": generated_assets,
            "message": f"Generated {len(generated_assets)} asset(s).",
        }

    async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
        """Decode the slide ``content`` and persist it through the store.

        Returns:
            Whatever the store's ``save_slide`` returns.

        Raises:
            ValueError: If ``content`` does not decode to a JSON object
                (``json.JSONDecodeError`` when it is not valid JSON at all).
        """
        # JSON round trip yields plain dict/list/str values detached from the
        # caller's objects.
        payload_args = json.loads(json.dumps(dict(args), ensure_ascii=False))
        raw_content = payload_args.get("content")
        if isinstance(raw_content, dict):
            # The content is decoded from a string below, so an already-decoded
            # object is serialised back to a string first.
            payload_args["content"] = json.dumps(raw_content, ensure_ascii=False)
        payload = SaveSlideInput(**payload_args)
        try:
            # Lenient parser first; strict json as a fallback, which also
            # raises the final error when the content is not JSON at all.
            content_parsed: Any = dirtyjson.loads(payload.content)
        except Exception:
            content_parsed = json.loads(payload.content)
        if not isinstance(content_parsed, dict):
            raise ValueError("'content' must be a JSON object.")
        # Convert any parser-specific container types into plain dicts/lists.
        content_payload = json.loads(json.dumps(content_parsed, ensure_ascii=False))
        return await self._memory.save_slide(
            content=content_payload,
            layout_id=payload.layout_id,
            index=payload.index,
            replace_old_slide_at_index=payload.replace_old_slide_at_index,
        )

    @staticmethod
    def _parse_args(arguments: str | dict[str, Any] | None) -> dict[str, Any]:
        """Decode raw tool-call arguments into a plain ``dict``.

        Args:
            arguments: JSON text (parsed leniently first, then strictly), an
                already-decoded dict, or ``None``. ``None``, empty and
                whitespace-only values mean "no arguments".

        Returns:
            The arguments as plain JSON-compatible Python containers.

        Raises:
            ValueError: If the decoded value is not a JSON object
                (``json.JSONDecodeError`` when the text is not valid JSON).
        """
        if not arguments:
            return {}
        if isinstance(arguments, dict):
            # Already decoded; only the normalisation below is needed.
            parsed: Any = arguments
        else:
            if isinstance(arguments, str) and not arguments.strip():
                # Blank text carries no arguments, same as "" or None.
                return {}
            try:
                parsed = dirtyjson.loads(arguments)
            except Exception:
                parsed = json.loads(arguments)
        # Round trip converts parser-specific containers into plain dicts/lists.
        normalized = json.loads(json.dumps(parsed, ensure_ascii=False))
        if isinstance(normalized, dict):
            return normalized
        raise ValueError("Tool arguments must be a JSON object.")

    @staticmethod
    def _extract_title(markdown_content: str) -> str:
        """Derive a short title from the first meaningful line of ``markdown_content``.

        A line starting with ``#`` yields its text without the leading hashes;
        a heading marker with no text is skipped. Any other non-blank line
        yields its first 120 characters.

        Returns:
            The title, or ``""`` when no usable line exists.
        """
        for line in markdown_content.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            heading_match = re.match(r"^#+\s*(.*?)\s*$", stripped)
            if heading_match:
                heading_text = heading_match.group(1).strip()
                if heading_text:
                    return heading_text
                # Bare marker such as "##": no title here, try the next line.
                continue
            return stripped[:120]
        return ""

    @staticmethod
    def _truncate(value: str, limit: int) -> str:
        """Return ``value`` cut to ``limit`` characters plus ``"..."`` when it is longer."""
        if len(value) <= limit:
            return value
        return f"{value[:limit]}..."