Add chat functionality to FastAPI presentation service

- Introduced a new chat endpoint for handling user messages and generating responses.
- Added models for chat message requests and responses.
- Implemented a conversation store to manage chat history.
- Integrated memory layer for retrieving presentation context.
- Created tools for accessing presentation outlines and searching slides.
- Updated dependencies to include jsonschema for validation.
- Enhanced the API router to include the new chat functionality.
This commit is contained in:
sudipnext 2026-04-24 09:34:56 +05:45
parent 9272907a30
commit efd69cc134
12 changed files with 1257 additions and 14 deletions

View file

@ -0,0 +1,26 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from models.chat import ChatMessageRequest, ChatMessageResponse
from services.chat import PresentationChatService
from services.database import get_async_session
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse)
async def chat_message(
payload: ChatMessageRequest,
sql_session: AsyncSession = Depends(get_async_session),
):
service = PresentationChatService(
sql_session=sql_session,
presentation_id=payload.presentation_id,
conversation_id=payload.conversation_id,
)
result = await service.generate_reply(payload.message)
return ChatMessageResponse(
conversation_id=result.conversation_id,
response=result.response_text,
tool_calls=result.tool_calls,
)

View file

@ -15,6 +15,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
from api.v1.ppt.endpoints.chat import CHAT_ROUTER
from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER
from api.v1.ppt.endpoints.theme import THEMES_ROUTER
from api.v1.ppt.endpoints.theme_generate import THEME_ROUTER
@ -29,6 +30,7 @@ API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
API_V1_PPT_ROUTER.include_router(PPTX_SLIDES_ROUTER)
API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
API_V1_PPT_ROUTER.include_router(CHAT_ROUTER)
API_V1_PPT_ROUTER.include_router(SLIDE_TO_HTML_ROUTER)
API_V1_PPT_ROUTER.include_router(HTML_TO_REACT_ROUTER)
API_V1_PPT_ROUTER.include_router(HTML_EDIT_ROUTER)

View file

@ -0,0 +1,20 @@
import uuid
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
class ChatMessageRequest(BaseModel):
presentation_id: uuid.UUID
message: str = Field(min_length=1, max_length=8000)
conversation_id: Optional[uuid.UUID] = None
model_config = ConfigDict(extra="forbid")
class ChatMessageResponse(BaseModel):
conversation_id: uuid.UUID
response: str
tool_calls: list[str] = Field(default_factory=list)
model_config = ConfigDict(extra="forbid")

View file

@ -26,6 +26,7 @@ dependencies = [
"python-pptx>=1.0.2",
"sqlmodel>=0.0.24",
"llmai==0.1.9",
"jsonschema>=4.26.0",
]
[tool.uv]

View file

@ -0,0 +1,6 @@
from services.chat.service import ChatTurnResult, PresentationChatService
__all__ = [
"ChatTurnResult",
"PresentationChatService",
]

View file

@ -0,0 +1,148 @@
from datetime import datetime, timezone
import uuid
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.sql.key_value import KeyValueSqlModel
CHAT_CONVERSATION_KEY_PREFIX = "chat_conversation"
MAX_STORED_TURNS = 20
class ConversationMessage(BaseModel):
role: Literal["user", "assistant"]
content: str = Field(min_length=1)
model_config = ConfigDict(extra="forbid", strict=True)
@field_validator("content")
@classmethod
def normalize_content(cls, value: str) -> str:
normalized = value.strip()
if not normalized:
raise ValueError("Conversation message cannot be empty.")
return normalized
class ConversationPayload(BaseModel):
conversation_id: uuid.UUID
presentation_id: uuid.UUID
messages: list[ConversationMessage] = Field(default_factory=list)
updated_at: datetime
model_config = ConfigDict(extra="forbid", strict=True)
class ChatConversationStore:
def __init__(self, sql_session: AsyncSession):
self._sql_session = sql_session
async def load_history(
self,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
payload = await self._get_payload(conversation_id)
if not payload or payload.presentation_id != presentation_id:
return []
return [
{"role": message.role, "content": message.content}
for message in payload.messages
]
async def append_turn(
self,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str,
) -> None:
payload = await self._get_payload(conversation_id)
messages = list(payload.messages) if payload else []
messages.append(ConversationMessage(role="user", content=user_message))
messages.append(ConversationMessage(role="assistant", content=assistant_message))
max_messages = MAX_STORED_TURNS * 2
messages = messages[-max_messages:]
next_payload = ConversationPayload(
conversation_id=conversation_id,
presentation_id=presentation_id,
messages=messages,
updated_at=datetime.now(timezone.utc),
)
row = await self._get_row(conversation_id)
if row:
row.value = next_payload.model_dump(mode="json")
self._sql_session.add(row)
else:
self._sql_session.add(
KeyValueSqlModel(
key=self._conversation_key(conversation_id),
value=next_payload.model_dump(mode="json"),
)
)
await self._sql_session.commit()
async def ensure_conversation_id(
self,
conversation_id: uuid.UUID | None,
) -> uuid.UUID:
return conversation_id or uuid.uuid4()
async def _get_payload(
self, conversation_id: uuid.UUID
) -> ConversationPayload | None:
row = await self._get_row(conversation_id)
if not row or not isinstance(row.value, dict):
return None
raw_payload: dict[str, Any] = row.value
try:
return ConversationPayload.model_validate(raw_payload)
except ValidationError:
return self._coerce_payload(raw_payload)
def _coerce_payload(self, payload: dict[str, Any]) -> ConversationPayload | None:
try:
conversation_id = uuid.UUID(str(payload.get("conversation_id")))
presentation_id = uuid.UUID(str(payload.get("presentation_id")))
except Exception:
return None
raw_messages = payload.get("messages")
messages: list[ConversationMessage] = []
if isinstance(raw_messages, list):
for entry in raw_messages:
if not isinstance(entry, dict):
continue
try:
message = ConversationMessage.model_validate(entry)
messages.append(message)
except ValidationError:
continue
return ConversationPayload(
conversation_id=conversation_id,
presentation_id=presentation_id,
messages=messages[-(MAX_STORED_TURNS * 2) :],
updated_at=datetime.now(timezone.utc),
)
async def _get_row(self, conversation_id: uuid.UUID) -> KeyValueSqlModel | None:
return await self._sql_session.scalar(
select(KeyValueSqlModel).where(
KeyValueSqlModel.key == self._conversation_key(conversation_id)
)
)
@staticmethod
def _conversation_key(conversation_id: uuid.UUID) -> str:
return f"{CHAT_CONVERSATION_KEY_PREFIX}:{conversation_id}"

View file

@ -0,0 +1,452 @@
import copy
import json
import logging
import re
import uuid
from typing import Any
from jsonschema import Draft202012Validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.image_prompt import ImagePrompt
from models.sql.image_asset import ImageAsset
from models.sql.presentation import PresentationModel
from models.sql.slide import SlideModel
from services.icon_finder_service import ICON_FINDER_SERVICE
from services.image_generation_service import ImageGenerationService
from services.mem0_presentation_memory_service import MEM0_PRESENTATION_MEMORY_SERVICE
from templates.presentation_layout import SlideLayoutModel
from utils.asset_directory_utils import get_images_directory
from utils.process_slides import (
process_old_and_new_slides_and_fetch_assets,
process_slide_and_fetch_assets,
)
LOGGER = logging.getLogger(__name__)
MAX_SCHEMA_ERRORS = 10
RUNTIME_CONTENT_FIELDS = {"__speaker_note__", "__image_url__", "__icon_url__"}
class PresentationChatMemoryLayer:
"""
Memory abstraction for chat tools and context retrieval.
This layer intentionally hides where data comes from (SQL-backed persisted state
and mem0 retrieval) behind `get` and `search`-style methods so chat logic stays
decoupled from storage details.
"""
def __init__(self, sql_session: AsyncSession, presentation_id: uuid.UUID):
self._sql_session = sql_session
self._presentation_id = presentation_id
async def get(self, key: str) -> Any:
if key != "presentation_outline":
return None
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation or not presentation.outlines:
LOGGER.info(
"Chat memory miss for outline (presentation_id=%s)",
self._presentation_id,
)
return None
LOGGER.info(
"Chat memory hit for outline (presentation_id=%s)",
self._presentation_id,
)
return presentation.outlines
async def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]:
"""
Search stored slide memory.
We use a keyword-ranking fallback over persisted slides so this works even
when semantic/vector memories are missing. This still satisfies memory-only
lookup because it reads from local presentation memory state.
"""
trimmed_query = (query or "").strip()
if not trimmed_query:
return []
slides_result = await self._sql_session.scalars(
select(SlideModel).where(SlideModel.presentation == self._presentation_id)
)
slides = sorted(list(slides_result), key=lambda slide: slide.index)
if not slides:
LOGGER.info(
"Chat memory miss for slide search (presentation_id=%s, reason=no_slides)",
self._presentation_id,
)
return []
query_lower = trimmed_query.lower()
query_tokens = set(re.findall(r"[a-z0-9]{2,}", query_lower))
ranked: list[tuple[int, dict[str, Any]]] = []
for slide in slides:
serialized = self._serialize_slide(slide)
searchable = serialized.lower()
score = 0
if query_lower in searchable:
score += 8
if query_tokens:
score += sum(1 for token in query_tokens if token in searchable)
if score <= 0:
continue
ranked.append(
(
score,
{
"slide_id": str(slide.id),
"index": slide.index,
"layout_id": slide.layout,
"content": slide.content,
"snippet": self._build_snippet(serialized, query_lower),
"score": score,
},
)
)
ranked.sort(key=lambda item: (-item[0], item[1]["index"]))
results = [entry for _, entry in ranked[: max(1, limit)]]
LOGGER.info(
"Chat memory search completed (presentation_id=%s, query=%r, hits=%d)",
self._presentation_id,
trimmed_query,
len(results),
)
return results
async def get_slide_at_index(self, index: int) -> dict[str, Any] | None:
slide = await self._sql_session.scalar(
select(SlideModel).where(
SlideModel.presentation == self._presentation_id,
SlideModel.index == index,
)
)
if not slide:
LOGGER.info(
"Chat memory miss for slide by index (presentation_id=%s, index=%d)",
self._presentation_id,
index,
)
return None
return {
"slide_id": str(slide.id),
"index": slide.index,
"layout_id": slide.layout,
"content": slide.content,
"speaker_note": slide.speaker_note,
}
async def get_available_layouts(self) -> list[dict[str, Any]]:
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation or not isinstance(presentation.layout, dict):
return []
try:
layout_model = presentation.get_layout()
except Exception:
LOGGER.exception(
"Failed to parse presentation layout (presentation_id=%s)",
self._presentation_id,
)
return []
return [
{
"id": layout.id,
"name": layout.name,
"description": layout.description,
}
for layout in layout_model.slides
]
async def get_content_schema_from_layout_id(self, layout_id: str) -> dict[str, Any] | None:
layout = await self._get_layout_by_id(layout_id)
if not layout:
return None
return layout.json_schema
async def generate_image(self, prompt: str) -> str:
image_generation_service = ImageGenerationService(get_images_directory())
image = await image_generation_service.generate_image(ImagePrompt(prompt=prompt))
if isinstance(image, ImageAsset):
self._sql_session.add(image)
await self._sql_session.commit()
return image.path
return str(image)
async def generate_icon(self, query: str) -> str:
icons = await ICON_FINDER_SERVICE.search_icons(query, k=1)
if icons:
return icons[0]
return "/static/icons/placeholder.svg"
async def save_slide(
self,
*,
content: dict[str, Any],
layout_id: str,
index: int,
replace_old_slide_at_index: bool,
) -> dict[str, Any]:
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation:
return {
"saved": False,
"message": "Presentation not found.",
"validation_errors": [],
}
layout = await self._get_layout_by_id(layout_id, presentation=presentation)
if not layout:
return {
"saved": False,
"message": f"Layout '{layout_id}' was not found in this presentation.",
"validation_errors": [f"Unknown layout_id '{layout_id}'."],
}
validation_errors = self._validate_slide_content(
content=content,
schema=layout.json_schema,
)
if validation_errors:
return {
"saved": False,
"message": "Slide content failed schema validation.",
"validation_errors": validation_errors,
}
target_index = max(0, index)
image_generation_service = ImageGenerationService(get_images_directory())
if replace_old_slide_at_index:
existing_slide = await self._sql_session.scalar(
select(SlideModel).where(
SlideModel.presentation == self._presentation_id,
SlideModel.index == target_index,
)
)
if not existing_slide:
return {
"saved": False,
"message": f"No existing slide found at index {target_index} to replace.",
"validation_errors": [],
}
updated_content = copy.deepcopy(content)
new_assets = await process_old_and_new_slides_and_fetch_assets(
image_generation_service=image_generation_service,
old_slide_content=existing_slide.content or {},
new_slide_content=updated_content,
)
existing_slide.id = uuid.uuid4()
existing_slide.layout = layout_id
existing_slide.layout_group = self._resolve_layout_group(
presentation=presentation,
fallback=existing_slide.layout_group,
)
existing_slide.content = updated_content
existing_slide.speaker_note = self._extract_speaker_note(updated_content)
self._sql_session.add(existing_slide)
self._sql_session.add_all(new_assets)
await self._sql_session.commit()
await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
presentation_id=self._presentation_id,
slide_index=target_index,
edit_prompt=f"[chat_tool_save_slide_replace] layout_id={layout_id}",
edited_slide_content=updated_content,
)
return {
"saved": True,
"action": "replaced",
"message": f"Slide at index {target_index} was replaced successfully.",
"slide_id": str(existing_slide.id),
"index": target_index,
}
slides_result = await self._sql_session.scalars(
select(SlideModel)
.where(SlideModel.presentation == self._presentation_id)
.order_by(SlideModel.index)
)
slides = list(slides_result)
if slides:
max_index = max(slide.index for slide in slides)
insert_index = min(target_index, max_index + 1)
slides_to_shift = [slide for slide in slides if slide.index >= insert_index]
else:
insert_index = 0
slides_to_shift = []
for slide in sorted(slides_to_shift, key=lambda each: each.index, reverse=True):
slide.index += 1
self._sql_session.add(slide)
new_slide_content = copy.deepcopy(content)
new_slide = SlideModel(
presentation=self._presentation_id,
layout_group=self._resolve_layout_group(presentation=presentation),
layout=layout_id,
index=insert_index,
content=new_slide_content,
speaker_note=self._extract_speaker_note(new_slide_content),
)
new_assets = await process_slide_and_fetch_assets(
image_generation_service=image_generation_service,
slide=new_slide,
)
self._sql_session.add(new_slide)
self._sql_session.add_all(new_assets)
await self._sql_session.commit()
await self._sql_session.refresh(new_slide)
await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
presentation_id=self._presentation_id,
slide_index=insert_index,
edit_prompt=f"[chat_tool_save_slide_new] layout_id={layout_id}",
edited_slide_content=new_slide.content,
)
return {
"saved": True,
"action": "created",
"message": f"New slide saved at index {insert_index}.",
"slide_id": str(new_slide.id),
"index": insert_index,
"shifted_slide_count": len(slides_to_shift),
}
async def retrieve_context(self, query: str) -> str:
context = await MEM0_PRESENTATION_MEMORY_SERVICE.retrieve_context(
self._presentation_id,
query,
)
if context:
LOGGER.info(
"Chat memory semantic context hit (presentation_id=%s, chars=%d)",
self._presentation_id,
len(context),
)
else:
LOGGER.info(
"Chat memory semantic context miss (presentation_id=%s)",
self._presentation_id,
)
return context
async def _get_layout_by_id(
self,
layout_id: str,
presentation: PresentationModel | None = None,
) -> SlideLayoutModel | None:
if not presentation:
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation or not isinstance(presentation.layout, dict):
return None
try:
layout_model = presentation.get_layout()
except Exception:
return None
for layout in layout_model.slides:
if layout.id == layout_id:
return layout
return None
def _validate_slide_content(
self,
*,
content: dict[str, Any],
schema: dict[str, Any],
) -> list[str]:
validation_content = self._strip_runtime_fields(content)
validator = Draft202012Validator(schema)
errors = sorted(validator.iter_errors(validation_content), key=lambda err: err.path)
if not errors:
return []
formatted_errors: list[str] = []
for err in errors[:MAX_SCHEMA_ERRORS]:
location = ".".join([str(part) for part in err.path]) or "$"
formatted_errors.append(f"{location}: {err.message}")
return formatted_errors
@staticmethod
def _strip_runtime_fields(value: Any) -> Any:
if isinstance(value, dict):
sanitized: dict[str, Any] = {}
for key, nested_value in value.items():
if key in RUNTIME_CONTENT_FIELDS:
continue
sanitized[key] = PresentationChatMemoryLayer._strip_runtime_fields(
nested_value
)
return sanitized
if isinstance(value, list):
return [
PresentationChatMemoryLayer._strip_runtime_fields(item) for item in value
]
return value
@staticmethod
def _extract_speaker_note(content: dict[str, Any]) -> str:
value = content.get("__speaker_note__")
if isinstance(value, str):
return value
return ""
@staticmethod
def _resolve_layout_group(
*,
presentation: PresentationModel,
fallback: str = "presentation",
) -> str:
if isinstance(presentation.layout, dict):
name = str(presentation.layout.get("name") or "").strip()
if name:
return name
return fallback
@staticmethod
def _serialize_slide(slide: SlideModel) -> str:
content_text = ""
try:
content_text = json.dumps(slide.content or {}, ensure_ascii=False)
except Exception:
content_text = str(slide.content)
speaker_note = slide.speaker_note or ""
return f"slide_index={slide.index}\nlayout_id={slide.layout}\n{content_text}\n{speaker_note}"
@staticmethod
def _build_snippet(text: str, query_lower: str, window: int = 320) -> str:
normalized = " ".join(text.split())
if not normalized:
return ""
offset = normalized.lower().find(query_lower)
if offset == -1:
return normalized[:window]
start = max(0, offset - window // 3)
end = min(len(normalized), start + window)
return normalized[start:end]

View file

@ -0,0 +1,23 @@
def build_system_prompt(memory_context: str) -> str:
context_block = (
"\nMemory context (use only when relevant):\n"
f"{memory_context}\n"
if memory_context
else ""
)
return (
"You are Presenton backend chat assistant.\n"
"You can call tools to access presentation memory.\n"
"- Use getPresentationOutline for outline/section questions.\n"
"- Use searchSlides for finding relevant slide content.\n"
"- Use getSlideAtIndex for full content on one known slide index.\n"
"- Use getAvailableLayouts to inspect allowed layout ids.\n"
"- Use getContentSchemaFromLayoutId before saveSlide when validating structure.\n"
"- Use generateImage and generateIcon to fetch media URLs used in content.\n"
"- Use saveSlide to create/replace slides only with schema-valid content.\n"
"- For saveSlide, send content as a JSON-serialized object string.\n"
"- After tool outputs are sufficient, stop calling tools and provide a final answer.\n"
"- If memory is missing, state that clearly and suggest next steps.\n"
"- Do not invent slide facts that are not in tool results or memory.\n"
f"{context_block}"
)

View file

@ -0,0 +1,65 @@
import json
from typing import Any
import dirtyjson # type: ignore[import-untyped]
from pydantic import BaseModel, ConfigDict, Field, field_validator
class StrictSchemaModel(BaseModel):
model_config = ConfigDict(extra="forbid", strict=True)
class NoArgsInput(StrictSchemaModel):
pass
class GetSlideAtIndexInput(StrictSchemaModel):
index: int = Field(ge=0, le=1000)
class SearchSlidesInput(StrictSchemaModel):
query: str = Field(min_length=1, max_length=1000)
limit: int = Field(ge=1, le=10)
class GetContentSchemaFromLayoutIdInput(StrictSchemaModel):
layout_id: str = Field(alias="layoutId", min_length=1, max_length=200)
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
class GenerateImageInput(StrictSchemaModel):
prompt: str = Field(min_length=1, max_length=4000)
class GenerateIconInput(StrictSchemaModel):
query: str = Field(min_length=1, max_length=1000)
class SaveSlideInput(StrictSchemaModel):
content: str = Field(
min_length=2,
max_length=200000,
description=(
"A JSON-serialized object for slide content. "
"Example: '{\"title\": \"Q4 Revenue\", \"bullets\": [\"North America +22%\"]}'"
),
)
layout_id: str = Field(alias="layoutId", min_length=1, max_length=200)
index: int = Field(ge=0, le=1000)
replace_old_slide_at_index: bool = Field(alias="replaceOldSlideAtIndex")
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
@field_validator("content")
@classmethod
def validate_content(cls, value: str) -> str:
try:
parsed: Any = dirtyjson.loads(value)
except Exception:
parsed = json.loads(value)
if not isinstance(parsed, dict):
raise ValueError("'content' must be a JSON object.")
return value

View file

@ -0,0 +1,205 @@
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass
from typing import Any
from fastapi import HTTPException
from llmai import get_client # type: ignore[import-not-found]
from llmai.shared import ( # type: ignore[import-not-found]
AssistantMessage,
Message,
SystemMessage,
ToolResponseMessage,
UserMessage,
)
from sqlalchemy.ext.asyncio import AsyncSession
from models.sql.presentation import PresentationModel
from services.chat.conversation_store import ChatConversationStore
from services.chat.memory_layer import PresentationChatMemoryLayer
from services.chat.prompts import build_system_prompt
from services.chat.tools import ChatTools
from utils.llm_client_error_handler import handle_llm_client_exceptions
from utils.llm_config import get_llm_config
from utils.llm_provider import get_model
from utils.llm_utils import extract_text, get_generate_kwargs
LOGGER = logging.getLogger(__name__)
MAX_TOOL_ROUNDS = 6
@dataclass(frozen=True)
class ChatTurnResult:
conversation_id: uuid.UUID
response_text: str
tool_calls: list[str]
class PresentationChatService:
def __init__(
self,
sql_session: AsyncSession,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID | None,
):
self._sql_session = sql_session
self._presentation_id = presentation_id
self._conversation_id = conversation_id
self._conversation_store = ChatConversationStore(sql_session)
self._memory = PresentationChatMemoryLayer(sql_session, presentation_id)
self._tools = ChatTools(self._memory)
async def generate_reply(self, user_message: str) -> ChatTurnResult:
if not (user_message or "").strip():
raise HTTPException(status_code=400, detail="Message is required")
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation:
raise HTTPException(status_code=404, detail="Presentation not found")
conversation_id = await self._conversation_store.ensure_conversation_id(
self._conversation_id
)
history = await self._conversation_store.load_history(
presentation_id=self._presentation_id,
conversation_id=conversation_id,
)
history_messages = self._convert_history_to_messages(history)
memory_context = await self._memory.retrieve_context(user_message)
messages: list[Message] = [
SystemMessage(content=build_system_prompt(memory_context)),
*history_messages,
UserMessage(content=user_message),
]
response_text, tool_calls = await self._run_llm_with_tools(messages)
await self._conversation_store.append_turn(
presentation_id=self._presentation_id,
conversation_id=conversation_id,
user_message=user_message,
assistant_message=response_text,
)
return ChatTurnResult(
conversation_id=conversation_id,
response_text=response_text,
tool_calls=tool_calls,
)
async def _run_llm_with_tools(self, messages: list[Message]) -> tuple[str, list[str]]:
# llmai is the only LLM entrypoint; provider selection comes from app config.
client = get_client(config=get_llm_config())
model = get_model()
tools = self._tools.get_tool_definitions()
called_tools: list[str] = []
last_tool_results: list[dict[str, Any]] = []
for _ in range(MAX_TOOL_ROUNDS):
try:
response = await asyncio.to_thread(
client.generate,
**get_generate_kwargs(
model=model,
messages=messages,
tools=tools,
),
)
except Exception as exc:
raise handle_llm_client_exceptions(exc)
if not response.tool_calls:
response_text = extract_text(response.content) or (
"I could not generate a response for that request."
)
return response_text, called_tools
called_tools.extend([tool_call.name for tool_call in response.tool_calls])
# Reuse llmai-returned threaded messages so provider adapters keep state.
messages = list(response.messages) if response.messages else list(messages)
last_tool_results = []
for tool_call in response.tool_calls:
tool_result = await self._tools.execute_tool_call(tool_call)
last_tool_results.append(tool_result)
tool_response_content = json.dumps(tool_result, ensure_ascii=False)
# Tool responses are fed back into llmai to let the model continue.
messages.append(
ToolResponseMessage(
id=tool_call.id,
content=[tool_response_content],
)
)
LOGGER.warning("Max tool rounds reached in chat flow")
final_response = await self._try_final_response_without_tools(
client=client,
model=model,
messages=messages,
)
if final_response:
return final_response, called_tools
return self._build_tool_limit_fallback(last_tool_results), called_tools
async def _try_final_response_without_tools(
self,
*,
client: Any,
model: str,
messages: list[Message],
) -> str | None:
"""
Give the model one final chance to synthesize a natural-language answer
from already-executed tool outputs, without allowing more tool calls.
"""
try:
response = await asyncio.to_thread(
client.generate,
**get_generate_kwargs(
model=model,
messages=messages,
),
)
except Exception:
LOGGER.warning("Final no-tool synthesis call failed", exc_info=True)
return None
return extract_text(response.content)
@staticmethod
def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str:
for entry in reversed(last_tool_results):
if not isinstance(entry, dict):
continue
if not entry.get("ok"):
continue
result = entry.get("result")
if not isinstance(result, dict):
continue
message = result.get("message")
if isinstance(message, str) and message.strip():
return message.strip()
return (
"I completed several tool operations but could not finalize the response "
"within the tool limit. Please ask a follow-up and I will continue."
)
@staticmethod
def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
messages: list[Message] = []
for item in history:
role = item.get("role")
content = item.get("content")
if not content:
continue
if role == "user":
messages.append(UserMessage(content=content))
elif role == "assistant":
messages.append(AssistantMessage(content=[content]))
return messages

View file

@ -0,0 +1,302 @@
import json
import logging
import re
from typing import Any, Awaitable, Callable
import dirtyjson # type: ignore[import-untyped]
from llmai.shared import AssistantToolCall, Tool # type: ignore[import-not-found]
from services.chat.schemas import (
GenerateIconInput,
GenerateImageInput,
GetContentSchemaFromLayoutIdInput,
GetSlideAtIndexInput,
NoArgsInput,
SaveSlideInput,
SearchSlidesInput,
)
from services.chat.memory_layer import PresentationChatMemoryLayer
LOGGER = logging.getLogger(__name__)
ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
class ChatTools:
"""
llmai function tools for presentation chat.
Tool implementations only use the memory abstraction layer and avoid external
provider-specific logic, keeping them portable across llmai backends.
"""
def __init__(self, memory: PresentationChatMemoryLayer):
self._memory = memory
self._tool_handlers: dict[str, ToolHandler] = {
"getPresentationOutline": self._get_presentation_outline,
"searchSlides": self._search_slides,
"getSlideAtIndex": self._get_slide_at_index,
"getAvailableLayouts": self._get_available_layouts,
"getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
"generateImage": self._generate_image,
"generateIcon": self._generate_icon,
"saveSlide": self._save_slide,
}
def get_tool_definitions(self) -> list[Tool]:
return [
Tool(
name="getPresentationOutline",
description=(
"Retrieve the current presentation outline from memory. "
"Use when the user asks about sections, flow, or slide plan."
),
schema=NoArgsInput,
strict=True,
),
Tool(
name="searchSlides",
description=(
"Search slide memory by semantic intent or keywords and return "
"relevant slide snippets with identifiers. "
"Always provide both query and limit."
),
schema=SearchSlidesInput,
strict=True,
),
Tool(
name="getSlideAtIndex",
description=(
"Retrieve a single slide by zero-based index, including its "
"layout id and current structured content."
),
schema=GetSlideAtIndexInput,
strict=True,
),
Tool(
name="getAvailableLayouts",
description=(
"List all available layout ids and descriptions for the current "
"presentation template."
),
schema=NoArgsInput,
strict=True,
),
Tool(
name="getContentSchemaFromLayoutId",
description=(
"Fetch the JSON content schema for a layout id. Use before "
"saving slide content to validate structure."
),
schema=GetContentSchemaFromLayoutIdInput,
strict=True,
),
Tool(
name="generateImage",
description=(
"Generate or fetch an image URL/path from a prompt and return "
"the usable URL/path."
),
schema=GenerateImageInput,
strict=True,
),
Tool(
name="generateIcon",
description="Search icon memory and return the most relevant icon URL.",
schema=GenerateIconInput,
strict=True,
),
Tool(
name="saveSlide",
description=(
"Save slide content for a layout. If replaceOldSlideAtIndex is "
"true, replace that index; otherwise insert as a new slide. "
"Pass content as a JSON-serialized object string and the server "
"will validate it against layout schema before save."
),
schema=SaveSlideInput,
strict=True,
),
]
async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
handler = self._tool_handlers.get(tool_call.name)
if not handler:
return {
"ok": False,
"tool": tool_call.name,
"error": f"Unsupported tool: {tool_call.name}",
}
try:
parsed_args = self._parse_args(tool_call.arguments)
LOGGER.info("Executing chat tool %s", tool_call.name)
result = await handler(parsed_args)
return {"ok": True, "tool": tool_call.name, "result": result}
except Exception as exc:
LOGGER.exception("Chat tool failed: %s", tool_call.name)
return {
"ok": False,
"tool": tool_call.name,
"error": str(exc),
}
async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
outline = await self._memory.get("presentation_outline")
if not isinstance(outline, dict):
return {
"found": False,
"message": "Presentation outline is not available in memory yet.",
"sections": [],
}
slides = outline.get("slides")
if not isinstance(slides, list) or not slides:
return {
"found": False,
"message": "Presentation outline exists but has no slides.",
"sections": [],
}
sections: list[dict[str, Any]] = []
for index, slide in enumerate(slides):
content = ""
if isinstance(slide, dict):
content = str(slide.get("content") or "")
elif isinstance(slide, str):
content = slide
title = self._extract_title(content) or f"Slide {index + 1}"
sections.append(
{
"index": index,
"title": title,
"preview": self._truncate(" ".join(content.split()), 220),
}
)
return {
"found": True,
"slide_count": len(sections),
"sections": sections,
"outline": outline,
}
async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
payload = SearchSlidesInput(**args)
results = await self._memory.search(payload.query, payload.limit)
return {
"query": payload.query,
"count": len(results),
"results": results,
}
async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
payload = GetSlideAtIndexInput(**args)
slide = await self._memory.get_slide_at_index(payload.index)
if not slide:
return {
"found": False,
"message": f"No slide found at index {payload.index}.",
}
return {
"found": True,
"slide": slide,
}
async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
layouts = await self._memory.get_available_layouts()
return {
"count": len(layouts),
"layouts": layouts,
}
async def _get_content_schema_from_layout_id(
self, args: dict[str, Any]
) -> dict[str, Any]:
payload = GetContentSchemaFromLayoutIdInput(**args)
schema = await self._memory.get_content_schema_from_layout_id(payload.layout_id)
if schema is None:
return {
"found": False,
"layout_id": payload.layout_id,
"message": "Layout schema not found for the provided layout id.",
}
return {
"found": True,
"layout_id": payload.layout_id,
"content_schema": schema,
}
async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
payload = GenerateImageInput(**args)
image_url = await self._memory.generate_image(payload.prompt)
return {
"prompt": payload.prompt,
"url": image_url,
}
async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
payload = GenerateIconInput(**args)
icon_url = await self._memory.generate_icon(payload.query)
return {
"query": payload.query,
"url": icon_url,
}
async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
payload_args = json.loads(json.dumps(dict(args), ensure_ascii=False))
raw_content = payload_args.get("content")
if isinstance(raw_content, dict):
payload_args["content"] = json.dumps(raw_content, ensure_ascii=False)
payload = SaveSlideInput(**payload_args)
try:
content_parsed: Any = dirtyjson.loads(payload.content)
except Exception:
content_parsed = json.loads(payload.content)
if not isinstance(content_parsed, dict):
raise ValueError("'content' must be a JSON object.")
content_payload = json.loads(json.dumps(content_parsed, ensure_ascii=False))
return await self._memory.save_slide(
content=content_payload,
layout_id=payload.layout_id,
index=payload.index,
replace_old_slide_at_index=payload.replace_old_slide_at_index,
)
@staticmethod
def _parse_args(arguments: str | None) -> dict[str, Any]:
if not arguments:
return {}
try:
parsed = dirtyjson.loads(arguments)
except Exception:
parsed = json.loads(arguments)
normalized = json.loads(json.dumps(parsed, ensure_ascii=False))
if isinstance(normalized, dict):
return normalized
raise ValueError("Tool arguments must be a JSON object.")
@staticmethod
def _extract_title(markdown_content: str) -> str:
for line in markdown_content.splitlines():
stripped = line.strip()
if not stripped:
continue
heading_match = re.match(r"^#{1,6}\s*(.+?)\s*$", stripped)
if heading_match:
return heading_match.group(1).strip()
return stripped[:120]
return ""
@staticmethod
def _truncate(value: str, limit: int) -> str:
if len(value) <= limit:
return value
return f"{value[:limit]}..."

View file

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = "==3.11.*"
[[package]]
@ -811,9 +811,7 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/c6/dba32cab7e3a625b011aa5647486e2d28423a48845a2998c126dd69c85e1/greenlet-3.4.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:805bebb4945094acbab757d34d6e1098be6de8966009ab9ca54f06ff492def58", size = 285504, upload-time = "2026-04-08T15:52:14.071Z" },
{ url = "https://files.pythonhosted.org/packages/54/f4/7cb5c2b1feb9a1f50e038be79980dfa969aa91979e5e3a18fdbcfad2c517/greenlet-3.4.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:439fc2f12b9b512d9dfa681c5afe5f6b3232c708d13e6f02c845e0d9f4c2d8c6", size = 605476, upload-time = "2026-04-08T16:24:37.064Z" },
{ url = "https://files.pythonhosted.org/packages/d6/af/b66ab0b2f9a4c5a867c136bf66d9599f34f21a1bcca26a2884a29c450bd9/greenlet-3.4.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a70ed1cb0295bee1df57b63bf7f46b4e56a5c93709eea769c1fec1bb23a95875", size = 618336, upload-time = "2026-04-08T16:30:56.59Z" },
{ url = "https://files.pythonhosted.org/packages/6d/31/56c43d2b5de476f77d36ceeec436328533bff960a4cba9a07616e93063ab/greenlet-3.4.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8c5696c42e6bb5cfb7c6ff4453789081c66b9b91f061e5e9367fa15792644e76", size = 625045, upload-time = "2026-04-08T16:40:37.111Z" },
{ url = "https://files.pythonhosted.org/packages/e5/5c/8c5633ece6ba611d64bf2770219a98dd439921d6424e4e8cf16b0ac74ea5/greenlet-3.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c660bce1940a1acae5f51f0a064f1bc785d07ea16efcb4bc708090afc4d69e83", size = 613515, upload-time = "2026-04-08T15:56:32.478Z" },
{ url = "https://files.pythonhosted.org/packages/80/ca/704d4e2c90acb8bdf7ae593f5cbc95f58e82de95cc540fb75631c1054533/greenlet-3.4.0-cp311-cp311-manylinux_2_39_riscv64.whl", hash = "sha256:89995ce5ddcd2896d89615116dd39b9703bfa0c07b583b85b89bf1b5d6eddf81", size = 419745, upload-time = "2026-04-08T16:43:04.022Z" },
{ url = "https://files.pythonhosted.org/packages/a9/df/950d15bca0d90a0e7395eb777903060504cdb509b7b705631e8fb69ff415/greenlet-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ee407d4d1ca9dc632265aee1c8732c4a2d60adff848057cdebfe5fe94eb2c8a2", size = 1574623, upload-time = "2026-04-08T16:26:18.596Z" },
{ url = "https://files.pythonhosted.org/packages/1a/e7/0839afab829fcb7333c9ff6d80c040949510055d2d4d63251f0d1c7c804e/greenlet-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:956215d5e355fffa7c021d168728321fd4d31fd730ac609b1653b450f6a4bc71", size = 1639579, upload-time = "2026-04-08T15:57:29.231Z" },
{ url = "https://files.pythonhosted.org/packages/d9/2b/b4482401e9bcaf9f5c97f67ead38db89c19520ff6d0d6699979c6efcc200/greenlet-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:5cb614ace7c27571270354e9c9f696554d073f8aa9319079dcba466bbdead711", size = 238233, upload-time = "2026-04-08T17:02:54.286Z" },
@ -1188,23 +1186,16 @@ wheels = [
[[package]]
name = "llmai"
version = "0.1.9"
source = { url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl" }
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anthropic" },
{ name = "boto3" },
{ name = "google-genai" },
{ name = "openai" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4" },
]
[package.metadata]
requires-dist = [
{ name = "anthropic", specifier = ">=0.79.0" },
{ name = "boto3", specifier = ">=1.42.89" },
{ name = "google-genai", specifier = ">=1.62.0" },
{ name = "openai", specifier = ">=2.18.0" },
{ url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" },
]
[[package]]
@ -1671,6 +1662,7 @@ dependencies = [
{ name = "fastembed-vectorstore" },
{ name = "fastmcp" },
{ name = "google-genai" },
{ name = "jsonschema" },
{ name = "llmai" },
{ name = "mem0ai", extra = ["nlp"] },
{ name = "nltk" },
@ -1693,7 +1685,8 @@ requires-dist = [
{ name = "fastembed-vectorstore", specifier = ">=0.5.2" },
{ name = "fastmcp", specifier = ">=2.11.0" },
{ name = "google-genai", specifier = ">=1.28.0" },
{ name = "llmai", url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl" },
{ name = "jsonschema", specifier = ">=4.26.0" },
{ name = "llmai", specifier = "==0.1.9" },
{ name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" },
{ name = "nltk", specifier = ">=3.9.1" },
{ name = "openai", specifier = ">=1.98.0" },