From efd69cc13479747d80cbeb784661927d18c68a7e Mon Sep 17 00:00:00 2001 From: sudipnext Date: Fri, 24 Apr 2026 09:34:56 +0545 Subject: [PATCH] Add chat functionality to FastAPI presentation service - Introduced a new chat endpoint for handling user messages and generating responses. - Added models for chat message requests and responses. - Implemented a conversation store to manage chat history. - Integrated memory layer for retrieving presentation context. - Created tools for accessing presentation outlines and searching slides. - Updated dependencies to include jsonschema for validation. - Enhanced the API router to include the new chat functionality. --- servers/fastapi/api/v1/ppt/endpoints/chat.py | 26 + servers/fastapi/api/v1/ppt/router.py | 2 + servers/fastapi/models/chat.py | 20 + servers/fastapi/pyproject.toml | 1 + servers/fastapi/services/chat/__init__.py | 6 + .../services/chat/conversation_store.py | 148 ++++++ servers/fastapi/services/chat/memory_layer.py | 452 ++++++++++++++++++ servers/fastapi/services/chat/prompts.py | 23 + servers/fastapi/services/chat/schemas.py | 65 +++ servers/fastapi/services/chat/service.py | 205 ++++++++ servers/fastapi/services/chat/tools.py | 302 ++++++++++++ servers/fastapi/uv.lock | 21 +- 12 files changed, 1257 insertions(+), 14 deletions(-) create mode 100644 servers/fastapi/api/v1/ppt/endpoints/chat.py create mode 100644 servers/fastapi/models/chat.py create mode 100644 servers/fastapi/services/chat/__init__.py create mode 100644 servers/fastapi/services/chat/conversation_store.py create mode 100644 servers/fastapi/services/chat/memory_layer.py create mode 100644 servers/fastapi/services/chat/prompts.py create mode 100644 servers/fastapi/services/chat/schemas.py create mode 100644 servers/fastapi/services/chat/service.py create mode 100644 servers/fastapi/services/chat/tools.py diff --git a/servers/fastapi/api/v1/ppt/endpoints/chat.py b/servers/fastapi/api/v1/ppt/endpoints/chat.py new file mode 100644 index 
00000000..76b744a5 --- /dev/null +++ b/servers/fastapi/api/v1/ppt/endpoints/chat.py @@ -0,0 +1,26 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from models.chat import ChatMessageRequest, ChatMessageResponse +from services.chat import PresentationChatService +from services.database import get_async_session + +CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"]) + + +@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse) +async def chat_message( + payload: ChatMessageRequest, + sql_session: AsyncSession = Depends(get_async_session), +): + service = PresentationChatService( + sql_session=sql_session, + presentation_id=payload.presentation_id, + conversation_id=payload.conversation_id, + ) + result = await service.generate_reply(payload.message) + return ChatMessageResponse( + conversation_id=result.conversation_id, + response=result.response_text, + tool_calls=result.tool_calls, + ) diff --git a/servers/fastapi/api/v1/ppt/router.py b/servers/fastapi/api/v1/ppt/router.py index 42b2812b..ba159bb8 100644 --- a/servers/fastapi/api/v1/ppt/router.py +++ b/servers/fastapi/api/v1/ppt/router.py @@ -15,6 +15,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER from api.v1.ppt.endpoints.slide import SLIDE_ROUTER +from api.v1.ppt.endpoints.chat import CHAT_ROUTER from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER from api.v1.ppt.endpoints.theme import THEMES_ROUTER from api.v1.ppt.endpoints.theme_generate import THEME_ROUTER @@ -29,6 +30,7 @@ API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER) API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER) API_V1_PPT_ROUTER.include_router(PPTX_SLIDES_ROUTER) API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER) +API_V1_PPT_ROUTER.include_router(CHAT_ROUTER) API_V1_PPT_ROUTER.include_router(SLIDE_TO_HTML_ROUTER) 
API_V1_PPT_ROUTER.include_router(HTML_TO_REACT_ROUTER) API_V1_PPT_ROUTER.include_router(HTML_EDIT_ROUTER) diff --git a/servers/fastapi/models/chat.py b/servers/fastapi/models/chat.py new file mode 100644 index 00000000..98906118 --- /dev/null +++ b/servers/fastapi/models/chat.py @@ -0,0 +1,20 @@ +import uuid +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class ChatMessageRequest(BaseModel): + presentation_id: uuid.UUID + message: str = Field(min_length=1, max_length=8000) + conversation_id: Optional[uuid.UUID] = None + + model_config = ConfigDict(extra="forbid") + + +class ChatMessageResponse(BaseModel): + conversation_id: uuid.UUID + response: str + tool_calls: list[str] = Field(default_factory=list) + + model_config = ConfigDict(extra="forbid") diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml index b36d123b..565ab612 100644 --- a/servers/fastapi/pyproject.toml +++ b/servers/fastapi/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "python-pptx>=1.0.2", "sqlmodel>=0.0.24", "llmai==0.1.9", + "jsonschema>=4.26.0", ] [tool.uv] diff --git a/servers/fastapi/services/chat/__init__.py b/servers/fastapi/services/chat/__init__.py new file mode 100644 index 00000000..2dcbe377 --- /dev/null +++ b/servers/fastapi/services/chat/__init__.py @@ -0,0 +1,6 @@ +from services.chat.service import ChatTurnResult, PresentationChatService + +__all__ = [ + "ChatTurnResult", + "PresentationChatService", +] diff --git a/servers/fastapi/services/chat/conversation_store.py b/servers/fastapi/services/chat/conversation_store.py new file mode 100644 index 00000000..8f4f1878 --- /dev/null +++ b/servers/fastapi/services/chat/conversation_store.py @@ -0,0 +1,148 @@ +from datetime import datetime, timezone +import uuid +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from 
models.sql.key_value import KeyValueSqlModel + +CHAT_CONVERSATION_KEY_PREFIX = "chat_conversation" +MAX_STORED_TURNS = 20 + + +class ConversationMessage(BaseModel): + role: Literal["user", "assistant"] + content: str = Field(min_length=1) + + model_config = ConfigDict(extra="forbid", strict=True) + + @field_validator("content") + @classmethod + def normalize_content(cls, value: str) -> str: + normalized = value.strip() + if not normalized: + raise ValueError("Conversation message cannot be empty.") + return normalized + + +class ConversationPayload(BaseModel): + conversation_id: uuid.UUID + presentation_id: uuid.UUID + messages: list[ConversationMessage] = Field(default_factory=list) + updated_at: datetime + + model_config = ConfigDict(extra="forbid", strict=True) + + +class ChatConversationStore: + def __init__(self, sql_session: AsyncSession): + self._sql_session = sql_session + + async def load_history( + self, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, + ) -> list[dict[str, str]]: + payload = await self._get_payload(conversation_id) + if not payload or payload.presentation_id != presentation_id: + return [] + return [ + {"role": message.role, "content": message.content} + for message in payload.messages + ] + + async def append_turn( + self, + *, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID, + user_message: str, + assistant_message: str, + ) -> None: + payload = await self._get_payload(conversation_id) + messages = list(payload.messages) if payload else [] + + messages.append(ConversationMessage(role="user", content=user_message)) + messages.append(ConversationMessage(role="assistant", content=assistant_message)) + max_messages = MAX_STORED_TURNS * 2 + messages = messages[-max_messages:] + + next_payload = ConversationPayload( + conversation_id=conversation_id, + presentation_id=presentation_id, + messages=messages, + updated_at=datetime.now(timezone.utc), + ) + + row = await self._get_row(conversation_id) + if row: + 
row.value = next_payload.model_dump(mode="json") + self._sql_session.add(row) + else: + self._sql_session.add( + KeyValueSqlModel( + key=self._conversation_key(conversation_id), + value=next_payload.model_dump(mode="json"), + ) + ) + + await self._sql_session.commit() + + async def ensure_conversation_id( + self, + conversation_id: uuid.UUID | None, + ) -> uuid.UUID: + return conversation_id or uuid.uuid4() + + async def _get_payload( + self, conversation_id: uuid.UUID + ) -> ConversationPayload | None: + row = await self._get_row(conversation_id) + if not row or not isinstance(row.value, dict): + return None + + raw_payload: dict[str, Any] = row.value + try: + return ConversationPayload.model_validate(raw_payload) + except ValidationError: + return self._coerce_payload(raw_payload) + + def _coerce_payload(self, payload: dict[str, Any]) -> ConversationPayload | None: + try: + conversation_id = uuid.UUID(str(payload.get("conversation_id"))) + presentation_id = uuid.UUID(str(payload.get("presentation_id"))) + except Exception: + return None + + raw_messages = payload.get("messages") + messages: list[ConversationMessage] = [] + if isinstance(raw_messages, list): + for entry in raw_messages: + if not isinstance(entry, dict): + continue + try: + message = ConversationMessage.model_validate(entry) + messages.append(message) + except ValidationError: + continue + + return ConversationPayload( + conversation_id=conversation_id, + presentation_id=presentation_id, + messages=messages[-(MAX_STORED_TURNS * 2) :], + updated_at=datetime.now(timezone.utc), + ) + + async def _get_row(self, conversation_id: uuid.UUID) -> KeyValueSqlModel | None: + return await self._sql_session.scalar( + select(KeyValueSqlModel).where( + KeyValueSqlModel.key == self._conversation_key(conversation_id) + ) + ) + + @staticmethod + def _conversation_key(conversation_id: uuid.UUID) -> str: + return f"{CHAT_CONVERSATION_KEY_PREFIX}:{conversation_id}" diff --git 
a/servers/fastapi/services/chat/memory_layer.py b/servers/fastapi/services/chat/memory_layer.py new file mode 100644 index 00000000..71ecc051 --- /dev/null +++ b/servers/fastapi/services/chat/memory_layer.py @@ -0,0 +1,452 @@ +import copy +import json +import logging +import re +import uuid +from typing import Any + +from jsonschema import Draft202012Validator +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from models.image_prompt import ImagePrompt +from models.sql.image_asset import ImageAsset +from models.sql.presentation import PresentationModel +from models.sql.slide import SlideModel +from services.icon_finder_service import ICON_FINDER_SERVICE +from services.image_generation_service import ImageGenerationService +from services.mem0_presentation_memory_service import MEM0_PRESENTATION_MEMORY_SERVICE +from templates.presentation_layout import SlideLayoutModel +from utils.asset_directory_utils import get_images_directory +from utils.process_slides import ( + process_old_and_new_slides_and_fetch_assets, + process_slide_and_fetch_assets, +) + +LOGGER = logging.getLogger(__name__) +MAX_SCHEMA_ERRORS = 10 +RUNTIME_CONTENT_FIELDS = {"__speaker_note__", "__image_url__", "__icon_url__"} + + +class PresentationChatMemoryLayer: + """ + Memory abstraction for chat tools and context retrieval. + + This layer intentionally hides where data comes from (SQL-backed persisted state + and mem0 retrieval) behind `get` and `search`-style methods so chat logic stays + decoupled from storage details. 
+ """ + + def __init__(self, sql_session: AsyncSession, presentation_id: uuid.UUID): + self._sql_session = sql_session + self._presentation_id = presentation_id + + async def get(self, key: str) -> Any: + if key != "presentation_outline": + return None + + presentation = await self._sql_session.get(PresentationModel, self._presentation_id) + if not presentation or not presentation.outlines: + LOGGER.info( + "Chat memory miss for outline (presentation_id=%s)", + self._presentation_id, + ) + return None + + LOGGER.info( + "Chat memory hit for outline (presentation_id=%s)", + self._presentation_id, + ) + return presentation.outlines + + async def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]: + """ + Search stored slide memory. + + We use a keyword-ranking fallback over persisted slides so this works even + when semantic/vector memories are missing. This still satisfies memory-only + lookup because it reads from local presentation memory state. + """ + + trimmed_query = (query or "").strip() + if not trimmed_query: + return [] + + slides_result = await self._sql_session.scalars( + select(SlideModel).where(SlideModel.presentation == self._presentation_id) + ) + slides = sorted(list(slides_result), key=lambda slide: slide.index) + if not slides: + LOGGER.info( + "Chat memory miss for slide search (presentation_id=%s, reason=no_slides)", + self._presentation_id, + ) + return [] + + query_lower = trimmed_query.lower() + query_tokens = set(re.findall(r"[a-z0-9]{2,}", query_lower)) + ranked: list[tuple[int, dict[str, Any]]] = [] + for slide in slides: + serialized = self._serialize_slide(slide) + searchable = serialized.lower() + + score = 0 + if query_lower in searchable: + score += 8 + if query_tokens: + score += sum(1 for token in query_tokens if token in searchable) + if score <= 0: + continue + + ranked.append( + ( + score, + { + "slide_id": str(slide.id), + "index": slide.index, + "layout_id": slide.layout, + "content": slide.content, + "snippet": 
self._build_snippet(serialized, query_lower), + "score": score, + }, + ) + ) + + ranked.sort(key=lambda item: (-item[0], item[1]["index"])) + results = [entry for _, entry in ranked[: max(1, limit)]] + LOGGER.info( + "Chat memory search completed (presentation_id=%s, query=%r, hits=%d)", + self._presentation_id, + trimmed_query, + len(results), + ) + return results + + async def get_slide_at_index(self, index: int) -> dict[str, Any] | None: + slide = await self._sql_session.scalar( + select(SlideModel).where( + SlideModel.presentation == self._presentation_id, + SlideModel.index == index, + ) + ) + if not slide: + LOGGER.info( + "Chat memory miss for slide by index (presentation_id=%s, index=%d)", + self._presentation_id, + index, + ) + return None + + return { + "slide_id": str(slide.id), + "index": slide.index, + "layout_id": slide.layout, + "content": slide.content, + "speaker_note": slide.speaker_note, + } + + async def get_available_layouts(self) -> list[dict[str, Any]]: + presentation = await self._sql_session.get(PresentationModel, self._presentation_id) + if not presentation or not isinstance(presentation.layout, dict): + return [] + + try: + layout_model = presentation.get_layout() + except Exception: + LOGGER.exception( + "Failed to parse presentation layout (presentation_id=%s)", + self._presentation_id, + ) + return [] + + return [ + { + "id": layout.id, + "name": layout.name, + "description": layout.description, + } + for layout in layout_model.slides + ] + + async def get_content_schema_from_layout_id(self, layout_id: str) -> dict[str, Any] | None: + layout = await self._get_layout_by_id(layout_id) + if not layout: + return None + return layout.json_schema + + async def generate_image(self, prompt: str) -> str: + image_generation_service = ImageGenerationService(get_images_directory()) + image = await image_generation_service.generate_image(ImagePrompt(prompt=prompt)) + + if isinstance(image, ImageAsset): + self._sql_session.add(image) + await 
self._sql_session.commit() + return image.path + + return str(image) + + async def generate_icon(self, query: str) -> str: + icons = await ICON_FINDER_SERVICE.search_icons(query, k=1) + if icons: + return icons[0] + return "/static/icons/placeholder.svg" + + async def save_slide( + self, + *, + content: dict[str, Any], + layout_id: str, + index: int, + replace_old_slide_at_index: bool, + ) -> dict[str, Any]: + presentation = await self._sql_session.get(PresentationModel, self._presentation_id) + if not presentation: + return { + "saved": False, + "message": "Presentation not found.", + "validation_errors": [], + } + + layout = await self._get_layout_by_id(layout_id, presentation=presentation) + if not layout: + return { + "saved": False, + "message": f"Layout '{layout_id}' was not found in this presentation.", + "validation_errors": [f"Unknown layout_id '{layout_id}'."], + } + + validation_errors = self._validate_slide_content( + content=content, + schema=layout.json_schema, + ) + if validation_errors: + return { + "saved": False, + "message": "Slide content failed schema validation.", + "validation_errors": validation_errors, + } + + target_index = max(0, index) + image_generation_service = ImageGenerationService(get_images_directory()) + + if replace_old_slide_at_index: + existing_slide = await self._sql_session.scalar( + select(SlideModel).where( + SlideModel.presentation == self._presentation_id, + SlideModel.index == target_index, + ) + ) + if not existing_slide: + return { + "saved": False, + "message": f"No existing slide found at index {target_index} to replace.", + "validation_errors": [], + } + + updated_content = copy.deepcopy(content) + new_assets = await process_old_and_new_slides_and_fetch_assets( + image_generation_service=image_generation_service, + old_slide_content=existing_slide.content or {}, + new_slide_content=updated_content, + ) + + existing_slide.id = uuid.uuid4() + existing_slide.layout = layout_id + existing_slide.layout_group = 
self._resolve_layout_group( + presentation=presentation, + fallback=existing_slide.layout_group, + ) + existing_slide.content = updated_content + existing_slide.speaker_note = self._extract_speaker_note(updated_content) + self._sql_session.add(existing_slide) + self._sql_session.add_all(new_assets) + await self._sql_session.commit() + + await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit( + presentation_id=self._presentation_id, + slide_index=target_index, + edit_prompt=f"[chat_tool_save_slide_replace] layout_id={layout_id}", + edited_slide_content=updated_content, + ) + + return { + "saved": True, + "action": "replaced", + "message": f"Slide at index {target_index} was replaced successfully.", + "slide_id": str(existing_slide.id), + "index": target_index, + } + + slides_result = await self._sql_session.scalars( + select(SlideModel) + .where(SlideModel.presentation == self._presentation_id) + .order_by(SlideModel.index) + ) + slides = list(slides_result) + + if slides: + max_index = max(slide.index for slide in slides) + insert_index = min(target_index, max_index + 1) + slides_to_shift = [slide for slide in slides if slide.index >= insert_index] + else: + insert_index = 0 + slides_to_shift = [] + + for slide in sorted(slides_to_shift, key=lambda each: each.index, reverse=True): + slide.index += 1 + self._sql_session.add(slide) + + new_slide_content = copy.deepcopy(content) + new_slide = SlideModel( + presentation=self._presentation_id, + layout_group=self._resolve_layout_group(presentation=presentation), + layout=layout_id, + index=insert_index, + content=new_slide_content, + speaker_note=self._extract_speaker_note(new_slide_content), + ) + new_assets = await process_slide_and_fetch_assets( + image_generation_service=image_generation_service, + slide=new_slide, + ) + + self._sql_session.add(new_slide) + self._sql_session.add_all(new_assets) + await self._sql_session.commit() + await self._sql_session.refresh(new_slide) + + await 
MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit( + presentation_id=self._presentation_id, + slide_index=insert_index, + edit_prompt=f"[chat_tool_save_slide_new] layout_id={layout_id}", + edited_slide_content=new_slide.content, + ) + + return { + "saved": True, + "action": "created", + "message": f"New slide saved at index {insert_index}.", + "slide_id": str(new_slide.id), + "index": insert_index, + "shifted_slide_count": len(slides_to_shift), + } + + async def retrieve_context(self, query: str) -> str: + context = await MEM0_PRESENTATION_MEMORY_SERVICE.retrieve_context( + self._presentation_id, + query, + ) + if context: + LOGGER.info( + "Chat memory semantic context hit (presentation_id=%s, chars=%d)", + self._presentation_id, + len(context), + ) + else: + LOGGER.info( + "Chat memory semantic context miss (presentation_id=%s)", + self._presentation_id, + ) + return context + + async def _get_layout_by_id( + self, + layout_id: str, + presentation: PresentationModel | None = None, + ) -> SlideLayoutModel | None: + if not presentation: + presentation = await self._sql_session.get(PresentationModel, self._presentation_id) + if not presentation or not isinstance(presentation.layout, dict): + return None + + try: + layout_model = presentation.get_layout() + except Exception: + return None + + for layout in layout_model.slides: + if layout.id == layout_id: + return layout + return None + + def _validate_slide_content( + self, + *, + content: dict[str, Any], + schema: dict[str, Any], + ) -> list[str]: + validation_content = self._strip_runtime_fields(content) + validator = Draft202012Validator(schema) + errors = sorted(validator.iter_errors(validation_content), key=lambda err: err.path) + + if not errors: + return [] + + formatted_errors: list[str] = [] + for err in errors[:MAX_SCHEMA_ERRORS]: + location = ".".join([str(part) for part in err.path]) or "$" + formatted_errors.append(f"{location}: {err.message}") + return formatted_errors + + @staticmethod + def 
_strip_runtime_fields(value: Any) -> Any: + if isinstance(value, dict): + sanitized: dict[str, Any] = {} + for key, nested_value in value.items(): + if key in RUNTIME_CONTENT_FIELDS: + continue + sanitized[key] = PresentationChatMemoryLayer._strip_runtime_fields( + nested_value + ) + return sanitized + + if isinstance(value, list): + return [ + PresentationChatMemoryLayer._strip_runtime_fields(item) for item in value + ] + + return value + + @staticmethod + def _extract_speaker_note(content: dict[str, Any]) -> str: + value = content.get("__speaker_note__") + if isinstance(value, str): + return value + return "" + + @staticmethod + def _resolve_layout_group( + *, + presentation: PresentationModel, + fallback: str = "presentation", + ) -> str: + if isinstance(presentation.layout, dict): + name = str(presentation.layout.get("name") or "").strip() + if name: + return name + return fallback + + @staticmethod + def _serialize_slide(slide: SlideModel) -> str: + content_text = "" + try: + content_text = json.dumps(slide.content or {}, ensure_ascii=False) + except Exception: + content_text = str(slide.content) + + speaker_note = slide.speaker_note or "" + return f"slide_index={slide.index}\nlayout_id={slide.layout}\n{content_text}\n{speaker_note}" + + @staticmethod + def _build_snippet(text: str, query_lower: str, window: int = 320) -> str: + normalized = " ".join(text.split()) + if not normalized: + return "" + + offset = normalized.lower().find(query_lower) + if offset == -1: + return normalized[:window] + + start = max(0, offset - window // 3) + end = min(len(normalized), start + window) + return normalized[start:end] diff --git a/servers/fastapi/services/chat/prompts.py b/servers/fastapi/services/chat/prompts.py new file mode 100644 index 00000000..89b732ee --- /dev/null +++ b/servers/fastapi/services/chat/prompts.py @@ -0,0 +1,23 @@ +def build_system_prompt(memory_context: str) -> str: + context_block = ( + "\nMemory context (use only when relevant):\n" + 
f"{memory_context}\n" + if memory_context + else "" + ) + return ( + "You are Presenton backend chat assistant.\n" + "You can call tools to access presentation memory.\n" + "- Use getPresentationOutline for outline/section questions.\n" + "- Use searchSlides for finding relevant slide content.\n" + "- Use getSlideAtIndex for full content on one known slide index.\n" + "- Use getAvailableLayouts to inspect allowed layout ids.\n" + "- Use getContentSchemaFromLayoutId before saveSlide when validating structure.\n" + "- Use generateImage and generateIcon to fetch media URLs used in content.\n" + "- Use saveSlide to create/replace slides only with schema-valid content.\n" + "- For saveSlide, send content as a JSON-serialized object string.\n" + "- After tool outputs are sufficient, stop calling tools and provide a final answer.\n" + "- If memory is missing, state that clearly and suggest next steps.\n" + "- Do not invent slide facts that are not in tool results or memory.\n" + f"{context_block}" + ) diff --git a/servers/fastapi/services/chat/schemas.py b/servers/fastapi/services/chat/schemas.py new file mode 100644 index 00000000..ef440351 --- /dev/null +++ b/servers/fastapi/services/chat/schemas.py @@ -0,0 +1,65 @@ +import json +from typing import Any + +import dirtyjson # type: ignore[import-untyped] +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class StrictSchemaModel(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True) + + +class NoArgsInput(StrictSchemaModel): + pass + + +class GetSlideAtIndexInput(StrictSchemaModel): + index: int = Field(ge=0, le=1000) + + +class SearchSlidesInput(StrictSchemaModel): + query: str = Field(min_length=1, max_length=1000) + limit: int = Field(ge=1, le=10) + + +class GetContentSchemaFromLayoutIdInput(StrictSchemaModel): + layout_id: str = Field(alias="layoutId", min_length=1, max_length=200) + + model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True) + + +class 
GenerateImageInput(StrictSchemaModel): + prompt: str = Field(min_length=1, max_length=4000) + + +class GenerateIconInput(StrictSchemaModel): + query: str = Field(min_length=1, max_length=1000) + + +class SaveSlideInput(StrictSchemaModel): + content: str = Field( + min_length=2, + max_length=200000, + description=( + "A JSON-serialized object for slide content. " + "Example: '{\"title\": \"Q4 Revenue\", \"bullets\": [\"North America +22%\"]}'" + ), + ) + layout_id: str = Field(alias="layoutId", min_length=1, max_length=200) + index: int = Field(ge=0, le=1000) + replace_old_slide_at_index: bool = Field(alias="replaceOldSlideAtIndex") + + model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True) + + @field_validator("content") + @classmethod + def validate_content(cls, value: str) -> str: + try: + parsed: Any = dirtyjson.loads(value) + except Exception: + parsed = json.loads(value) + + if not isinstance(parsed, dict): + raise ValueError("'content' must be a JSON object.") + + return value diff --git a/servers/fastapi/services/chat/service.py b/servers/fastapi/services/chat/service.py new file mode 100644 index 00000000..61c3eebf --- /dev/null +++ b/servers/fastapi/services/chat/service.py @@ -0,0 +1,205 @@ +import asyncio +import json +import logging +import uuid +from dataclasses import dataclass +from typing import Any + +from fastapi import HTTPException +from llmai import get_client # type: ignore[import-not-found] +from llmai.shared import ( # type: ignore[import-not-found] + AssistantMessage, + Message, + SystemMessage, + ToolResponseMessage, + UserMessage, +) +from sqlalchemy.ext.asyncio import AsyncSession + +from models.sql.presentation import PresentationModel +from services.chat.conversation_store import ChatConversationStore +from services.chat.memory_layer import PresentationChatMemoryLayer +from services.chat.prompts import build_system_prompt +from services.chat.tools import ChatTools +from utils.llm_client_error_handler import 
handle_llm_client_exceptions +from utils.llm_config import get_llm_config +from utils.llm_provider import get_model +from utils.llm_utils import extract_text, get_generate_kwargs + +LOGGER = logging.getLogger(__name__) +MAX_TOOL_ROUNDS = 6 + + +@dataclass(frozen=True) +class ChatTurnResult: + conversation_id: uuid.UUID + response_text: str + tool_calls: list[str] + + +class PresentationChatService: + def __init__( + self, + sql_session: AsyncSession, + presentation_id: uuid.UUID, + conversation_id: uuid.UUID | None, + ): + self._sql_session = sql_session + self._presentation_id = presentation_id + self._conversation_id = conversation_id + + self._conversation_store = ChatConversationStore(sql_session) + self._memory = PresentationChatMemoryLayer(sql_session, presentation_id) + self._tools = ChatTools(self._memory) + + async def generate_reply(self, user_message: str) -> ChatTurnResult: + if not (user_message or "").strip(): + raise HTTPException(status_code=400, detail="Message is required") + + presentation = await self._sql_session.get(PresentationModel, self._presentation_id) + if not presentation: + raise HTTPException(status_code=404, detail="Presentation not found") + + conversation_id = await self._conversation_store.ensure_conversation_id( + self._conversation_id + ) + history = await self._conversation_store.load_history( + presentation_id=self._presentation_id, + conversation_id=conversation_id, + ) + history_messages = self._convert_history_to_messages(history) + + memory_context = await self._memory.retrieve_context(user_message) + messages: list[Message] = [ + SystemMessage(content=build_system_prompt(memory_context)), + *history_messages, + UserMessage(content=user_message), + ] + + response_text, tool_calls = await self._run_llm_with_tools(messages) + await self._conversation_store.append_turn( + presentation_id=self._presentation_id, + conversation_id=conversation_id, + user_message=user_message, + assistant_message=response_text, + ) + + return 
ChatTurnResult( + conversation_id=conversation_id, + response_text=response_text, + tool_calls=tool_calls, + ) + + async def _run_llm_with_tools(self, messages: list[Message]) -> tuple[str, list[str]]: + # llmai is the only LLM entrypoint; provider selection comes from app config. + client = get_client(config=get_llm_config()) + model = get_model() + tools = self._tools.get_tool_definitions() + + called_tools: list[str] = [] + last_tool_results: list[dict[str, Any]] = [] + + for _ in range(MAX_TOOL_ROUNDS): + try: + response = await asyncio.to_thread( + client.generate, + **get_generate_kwargs( + model=model, + messages=messages, + tools=tools, + ), + ) + except Exception as exc: + raise handle_llm_client_exceptions(exc) + + if not response.tool_calls: + response_text = extract_text(response.content) or ( + "I could not generate a response for that request." + ) + return response_text, called_tools + + called_tools.extend([tool_call.name for tool_call in response.tool_calls]) + # Reuse llmai-returned threaded messages so provider adapters keep state. + messages = list(response.messages) if response.messages else list(messages) + + last_tool_results = [] + for tool_call in response.tool_calls: + tool_result = await self._tools.execute_tool_call(tool_call) + last_tool_results.append(tool_result) + tool_response_content = json.dumps(tool_result, ensure_ascii=False) + # Tool responses are fed back into llmai to let the model continue. 
+ messages.append( + ToolResponseMessage( + id=tool_call.id, + content=[tool_response_content], + ) + ) + + LOGGER.warning("Max tool rounds reached in chat flow") + final_response = await self._try_final_response_without_tools( + client=client, + model=model, + messages=messages, + ) + if final_response: + return final_response, called_tools + + return self._build_tool_limit_fallback(last_tool_results), called_tools + + async def _try_final_response_without_tools( + self, + *, + client: Any, + model: str, + messages: list[Message], + ) -> str | None: + """ + Give the model one final chance to synthesize a natural-language answer + from already-executed tool outputs, without allowing more tool calls. + """ + try: + response = await asyncio.to_thread( + client.generate, + **get_generate_kwargs( + model=model, + messages=messages, + ), + ) + except Exception: + LOGGER.warning("Final no-tool synthesis call failed", exc_info=True) + return None + + return extract_text(response.content) + + @staticmethod + def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str: + for entry in reversed(last_tool_results): + if not isinstance(entry, dict): + continue + if not entry.get("ok"): + continue + result = entry.get("result") + if not isinstance(result, dict): + continue + message = result.get("message") + if isinstance(message, str) and message.strip(): + return message.strip() + + return ( + "I completed several tool operations but could not finalize the response " + "within the tool limit. Please ask a follow-up and I will continue." 
class ChatTools:
    """
    llmai function tools for presentation chat.

    Tool implementations only use the memory abstraction layer and avoid external
    provider-specific logic, keeping them portable across llmai backends.
    """

    # Markdown heading: one to six '#' characters followed by the title text.
    _HEADING_PATTERN = re.compile(r"^#{1,6}\s*(.+?)\s*$")
    # Upper bound applied to every extracted title, heading or plain line.
    _MAX_TITLE_CHARS = 120

    def __init__(self, memory: PresentationChatMemoryLayer):
        """Bind the tools to *memory* and register one handler per tool name."""
        self._memory = memory
        # Keys must match the ``name`` of the Tool entries in get_tool_definitions.
        self._tool_handlers: dict[str, ToolHandler] = {
            "getPresentationOutline": self._get_presentation_outline,
            "searchSlides": self._search_slides,
            "getSlideAtIndex": self._get_slide_at_index,
            "getAvailableLayouts": self._get_available_layouts,
            "getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
            "generateImage": self._generate_image,
            "generateIcon": self._generate_icon,
            "saveSlide": self._save_slide,
        }

    def get_tool_definitions(self) -> list[Tool]:
        """Return the llmai ``Tool`` definitions advertised to the model."""
        return [
            Tool(
                name="getPresentationOutline",
                description=(
                    "Retrieve the current presentation outline from memory. "
                    "Use when the user asks about sections, flow, or slide plan."
                ),
                schema=NoArgsInput,
                strict=True,
            ),
            Tool(
                name="searchSlides",
                description=(
                    "Search slide memory by semantic intent or keywords and return "
                    "relevant slide snippets with identifiers. "
                    "Always provide both query and limit."
                ),
                schema=SearchSlidesInput,
                strict=True,
            ),
            Tool(
                name="getSlideAtIndex",
                description=(
                    "Retrieve a single slide by zero-based index, including its "
                    "layout id and current structured content."
                ),
                schema=GetSlideAtIndexInput,
                strict=True,
            ),
            Tool(
                name="getAvailableLayouts",
                description=(
                    "List all available layout ids and descriptions for the current "
                    "presentation template."
                ),
                schema=NoArgsInput,
                strict=True,
            ),
            Tool(
                name="getContentSchemaFromLayoutId",
                description=(
                    "Fetch the JSON content schema for a layout id. Use before "
                    "saving slide content to validate structure."
                ),
                schema=GetContentSchemaFromLayoutIdInput,
                strict=True,
            ),
            Tool(
                name="generateImage",
                description=(
                    "Generate or fetch an image URL/path from a prompt and return "
                    "the usable URL/path."
                ),
                schema=GenerateImageInput,
                strict=True,
            ),
            Tool(
                name="generateIcon",
                description="Search icon memory and return the most relevant icon URL.",
                schema=GenerateIconInput,
                strict=True,
            ),
            Tool(
                name="saveSlide",
                description=(
                    "Save slide content for a layout. If replaceOldSlideAtIndex is "
                    "true, replace that index; otherwise insert as a new slide. "
                    "Pass content as a JSON-serialized object string and the server "
                    "will validate it against layout schema before save."
                ),
                schema=SaveSlideInput,
                strict=True,
            ),
        ]

    async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
        """Dispatch *tool_call* to its handler and wrap the outcome in an envelope.

        Never raises for handler or argument errors. Returns
        ``{"ok": True, "tool": name, "result": ...}`` on success and
        ``{"ok": False, "tool": name, "error": message}`` for an unknown tool
        or any exception raised while parsing arguments or running the handler.
        """
        handler = self._tool_handlers.get(tool_call.name)
        if not handler:
            return {
                "ok": False,
                "tool": tool_call.name,
                "error": f"Unsupported tool: {tool_call.name}",
            }

        try:
            parsed_args = self._parse_args(tool_call.arguments)
            LOGGER.info("Executing chat tool %s", tool_call.name)
            result = await handler(parsed_args)
            return {"ok": True, "tool": tool_call.name, "result": result}
        except Exception as exc:
            # Boundary handler: report the failure in the envelope instead of
            # letting it propagate to the caller.
            LOGGER.exception("Chat tool failed: %s", tool_call.name)
            return {
                "ok": False,
                "tool": tool_call.name,
                "error": str(exc),
            }

    async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
        """Return the stored outline plus a per-slide title/preview summary."""
        outline = await self._memory.get("presentation_outline")
        if not isinstance(outline, dict):
            return {
                "found": False,
                "message": "Presentation outline is not available in memory yet.",
                "sections": [],
            }

        slides = outline.get("slides")
        if not isinstance(slides, list) or not slides:
            return {
                "found": False,
                "message": "Presentation outline exists but has no slides.",
                "sections": [],
            }

        sections: list[dict[str, Any]] = []
        for index, slide in enumerate(slides):
            # A slide may be a dict with a "content" field or a bare string;
            # anything else contributes empty content.
            content = ""
            if isinstance(slide, dict):
                content = str(slide.get("content") or "")
            elif isinstance(slide, str):
                content = slide

            title = self._extract_title(content) or f"Slide {index + 1}"
            sections.append(
                {
                    "index": index,
                    "title": title,
                    # Collapse all whitespace runs before truncating the preview.
                    "preview": self._truncate(" ".join(content.split()), 220),
                }
            )

        return {
            "found": True,
            "slide_count": len(sections),
            "sections": sections,
            "outline": outline,
        }

    async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate *args* as ``SearchSlidesInput`` and search slide memory."""
        payload = SearchSlidesInput(**args)
        results = await self._memory.search(payload.query, payload.limit)
        return {
            "query": payload.query,
            "count": len(results),
            "results": results,
        }

    async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
        """Fetch the slide at ``index``; report ``found: False`` when absent."""
        payload = GetSlideAtIndexInput(**args)
        slide = await self._memory.get_slide_at_index(payload.index)
        if not slide:
            return {
                "found": False,
                "message": f"No slide found at index {payload.index}.",
            }
        return {
            "found": True,
            "slide": slide,
        }

    async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
        """List the layouts exposed by the memory layer with their count."""
        layouts = await self._memory.get_available_layouts()
        return {
            "count": len(layouts),
            "layouts": layouts,
        }

    async def _get_content_schema_from_layout_id(
        self, args: dict[str, Any]
    ) -> dict[str, Any]:
        """Return the content schema for ``layout_id`` or a not-found payload."""
        payload = GetContentSchemaFromLayoutIdInput(**args)
        schema = await self._memory.get_content_schema_from_layout_id(payload.layout_id)
        if schema is None:
            return {
                "found": False,
                "layout_id": payload.layout_id,
                "message": "Layout schema not found for the provided layout id.",
            }
        return {
            "found": True,
            "layout_id": payload.layout_id,
            "content_schema": schema,
        }

    async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
        """Ask the memory layer for an image for ``prompt`` and echo the prompt."""
        payload = GenerateImageInput(**args)
        image_url = await self._memory.generate_image(payload.prompt)
        return {
            "prompt": payload.prompt,
            "url": image_url,
        }

    async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
        """Ask the memory layer for an icon for ``query`` and echo the query."""
        payload = GenerateIconInput(**args)
        icon_url = await self._memory.generate_icon(payload.query)
        return {
            "query": payload.query,
            "url": icon_url,
        }

    async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate the save request and persist the slide through memory.

        ``content`` may arrive as a JSON string or as an already-decoded
        object; an object is re-serialized so ``SaveSlideInput`` always
        receives a string.

        Raises:
            ValueError: if ``content`` is not valid JSON or does not decode to
                a JSON object.
        """
        # Shallow copy: only the "content" key is rebound below.
        payload_args = dict(args)
        raw_content = payload_args.get("content")
        if isinstance(raw_content, dict):
            payload_args["content"] = json.dumps(raw_content, ensure_ascii=False)

        payload = SaveSlideInput(**payload_args)
        content_payload = self._loads_lenient(payload.content, "'content'")
        if not isinstance(content_payload, dict):
            raise ValueError("'content' must be a JSON object.")

        return await self._memory.save_slide(
            content=content_payload,
            layout_id=payload.layout_id,
            index=payload.index,
            replace_old_slide_at_index=payload.replace_old_slide_at_index,
        )

    @staticmethod
    def _loads_lenient(text: str, what: str) -> Any:
        """Decode *text* with dirtyjson, falling back to the strict json parser.

        The decoded value is round-tripped through ``json`` so callers always
        receive plain built-in containers.

        Raises:
            ValueError: naming *what* when neither parser accepts the text.
        """
        try:
            decoded = dirtyjson.loads(text)
        except Exception:
            # Deliberately broad: any failure of the lenient parser simply
            # defers to the strict parser below.
            try:
                decoded = json.loads(text)
            except (TypeError, ValueError) as exc:
                raise ValueError(f"{what} must be valid JSON.") from exc
        return json.loads(json.dumps(decoded, ensure_ascii=False))

    @staticmethod
    def _parse_args(arguments: str | dict[str, Any] | None) -> dict[str, Any]:
        """Normalize raw tool-call arguments into a plain ``dict``.

        Accepts a JSON string or an already-decoded dict. ``None``, an empty
        value, or a whitespace-only string yields ``{}``.

        Raises:
            ValueError: if the arguments are not valid JSON or do not decode
                to a JSON object.
        """
        if not arguments:
            return {}

        if isinstance(arguments, dict):
            # Already decoded: only normalize to plain JSON-compatible types.
            return json.loads(json.dumps(arguments, ensure_ascii=False))

        if isinstance(arguments, str) and not arguments.strip():
            return {}

        normalized = ChatTools._loads_lenient(arguments, "Tool arguments")
        if isinstance(normalized, dict):
            return normalized

        raise ValueError("Tool arguments must be a JSON object.")

    @staticmethod
    def _extract_title(markdown_content: str) -> str:
        """Return a title taken from the first non-blank line of the content.

        Heading markers are stripped when that line is a markdown heading. The
        result is capped at ``_MAX_TITLE_CHARS`` characters; ``""`` is returned
        when there is no non-blank line.
        """
        for line in markdown_content.splitlines():
            stripped = line.strip()
            if not stripped:
                continue
            heading_match = ChatTools._HEADING_PATTERN.match(stripped)
            title = heading_match.group(1).strip() if heading_match else stripped
            # Only the first non-blank line is ever considered.
            return title[: ChatTools._MAX_TITLE_CHARS]
        return ""

    @staticmethod
    def _truncate(value: str, limit: int) -> str:
        """Return *value* unchanged if it fits in *limit*, else cut it and add "..."."""
        if len(value) <= limit:
            return value
        return f"{value[:limit]}..."
diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock index 0a12daf7..fdbf9f8a 100644 --- a/servers/fastapi/uv.lock +++ b/servers/fastapi/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = "==3.11.*" [[package]] @@ -811,9 +811,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/c6/dba32cab7e3a625b011aa5647486e2d28423a48845a2998c126dd69c85e1/greenlet-3.4.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:805bebb4945094acbab757d34d6e1098be6de8966009ab9ca54f06ff492def58", size = 285504, upload-time = "2026-04-08T15:52:14.071Z" }, { url = "https://files.pythonhosted.org/packages/54/f4/7cb5c2b1feb9a1f50e038be79980dfa969aa91979e5e3a18fdbcfad2c517/greenlet-3.4.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:439fc2f12b9b512d9dfa681c5afe5f6b3232c708d13e6f02c845e0d9f4c2d8c6", size = 605476, upload-time = "2026-04-08T16:24:37.064Z" }, { url = "https://files.pythonhosted.org/packages/d6/af/b66ab0b2f9a4c5a867c136bf66d9599f34f21a1bcca26a2884a29c450bd9/greenlet-3.4.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a70ed1cb0295bee1df57b63bf7f46b4e56a5c93709eea769c1fec1bb23a95875", size = 618336, upload-time = "2026-04-08T16:30:56.59Z" }, - { url = "https://files.pythonhosted.org/packages/6d/31/56c43d2b5de476f77d36ceeec436328533bff960a4cba9a07616e93063ab/greenlet-3.4.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8c5696c42e6bb5cfb7c6ff4453789081c66b9b91f061e5e9367fa15792644e76", size = 625045, upload-time = "2026-04-08T16:40:37.111Z" }, { url = "https://files.pythonhosted.org/packages/e5/5c/8c5633ece6ba611d64bf2770219a98dd439921d6424e4e8cf16b0ac74ea5/greenlet-3.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c660bce1940a1acae5f51f0a064f1bc785d07ea16efcb4bc708090afc4d69e83", size = 613515, upload-time = "2026-04-08T15:56:32.478Z" }, - { url = 
"https://files.pythonhosted.org/packages/80/ca/704d4e2c90acb8bdf7ae593f5cbc95f58e82de95cc540fb75631c1054533/greenlet-3.4.0-cp311-cp311-manylinux_2_39_riscv64.whl", hash = "sha256:89995ce5ddcd2896d89615116dd39b9703bfa0c07b583b85b89bf1b5d6eddf81", size = 419745, upload-time = "2026-04-08T16:43:04.022Z" }, { url = "https://files.pythonhosted.org/packages/a9/df/950d15bca0d90a0e7395eb777903060504cdb509b7b705631e8fb69ff415/greenlet-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ee407d4d1ca9dc632265aee1c8732c4a2d60adff848057cdebfe5fe94eb2c8a2", size = 1574623, upload-time = "2026-04-08T16:26:18.596Z" }, { url = "https://files.pythonhosted.org/packages/1a/e7/0839afab829fcb7333c9ff6d80c040949510055d2d4d63251f0d1c7c804e/greenlet-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:956215d5e355fffa7c021d168728321fd4d31fd730ac609b1653b450f6a4bc71", size = 1639579, upload-time = "2026-04-08T15:57:29.231Z" }, { url = "https://files.pythonhosted.org/packages/d9/2b/b4482401e9bcaf9f5c97f67ead38db89c19520ff6d0d6699979c6efcc200/greenlet-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:5cb614ace7c27571270354e9c9f696554d073f8aa9319079dcba466bbdead711", size = 238233, upload-time = "2026-04-08T17:02:54.286Z" }, @@ -1188,23 +1186,16 @@ wheels = [ [[package]] name = "llmai" version = "0.1.9" -source = { url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl" } +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anthropic" }, { name = "boto3" }, { name = "google-genai" }, { name = "openai" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4" }, -] - -[package.metadata] -requires-dist = [ - { name = "anthropic", specifier = ">=0.79.0" }, - { name = "boto3", specifier = ">=1.42.89" }, - { name = "google-genai", specifier = ">=1.62.0" }, - { name = "openai", specifier = ">=2.18.0" }, + { url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" }, ] [[package]] @@ -1671,6 +1662,7 @@ dependencies = [ { name = "fastembed-vectorstore" }, { name = "fastmcp" }, { name = "google-genai" }, + { name = "jsonschema" }, { name = "llmai" }, { name = "mem0ai", extra = ["nlp"] }, { name = "nltk" }, @@ -1693,7 +1685,8 @@ requires-dist = [ { name = "fastembed-vectorstore", specifier = ">=0.5.2" }, { name = "fastmcp", specifier = ">=2.11.0" }, { name = "google-genai", specifier = ">=1.28.0" }, - { name = "llmai", url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl" }, + { name = "jsonschema", specifier = ">=4.26.0" }, + { name = "llmai", specifier = "==0.1.9" }, { name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" }, { name = "nltk", specifier = ">=3.9.1" }, { name = "openai", specifier = ">=1.98.0" },