presenton/servers/fastapi/services/chat/service.py

449 lines
16 KiB
Python

import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, Literal
from fastapi import HTTPException
from llmai import get_client # type: ignore[import-not-found]
from llmai.shared import ( # type: ignore[import-not-found]
AssistantMessage,
Message,
SystemMessage,
TextContentPart,
ToolResponseMessage,
UserMessage,
WebSearchTool
)
from sqlalchemy.ext.asyncio import AsyncSession
from models.sql.presentation import PresentationModel
from services.chat.conversation_store import ChatConversationStore
from services.chat.presentation_context_store import PresentationContextStore
from services.chat.prompts import build_system_prompt
from services.chat.tools import ChatTools
from utils.llm_client_error_handler import handle_llm_client_exceptions
from utils.llm_config import get_llm_config
from utils.llm_provider import get_model
from utils.llm_utils import (
extract_text,
get_generate_kwargs,
stream_generate_events,
)
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Maximum number of LLM rounds per chat turn; both the streaming and the
# non-streaming tool loops iterate over range(MAX_TOOL_ROUNDS).
MAX_TOOL_ROUNDS = 16
@dataclass(frozen=True)
class ChatTurnResult:
    """Immutable record of one completed chat turn."""

    # Conversation the turn was stored under.
    conversation_id: uuid.UUID
    # Assistant text that was persisted for the turn.
    response_text: str
    # Names of the tools requested during the turn, in call order.
    tool_calls: list[str]
# Tags of the events yielded by PresentationChatService.stream_reply.
ChatStreamEventType = Literal["chunk", "complete", "status", "trace"]
# Payload paired with a tag: text for "chunk" and "status", a dict for "trace",
# and a ChatTurnResult for "complete".
ChatStreamEventValue = str | ChatTurnResult | dict[str, Any]
class PresentationChatService:
    """Drive chat turns about one presentation, with LLM tool calling.

    A turn validates the message and loads context (``_prepare_turn_context``),
    lets the model answer while executing any tool calls it requests, and then
    stores the exchange (``_persist_turn``). ``generate_reply`` returns the
    finished turn; ``stream_reply`` yields progress events while it runs.
    """

    # Reply text used whenever the model produces no usable text.
    _EMPTY_RESPONSE_FALLBACK = "I could not generate a response for that request."

    def __init__(
        self,
        sql_session: AsyncSession,
        presentation_id: uuid.UUID,
        conversation_id: uuid.UUID | None,
    ):
        """Bind the service to a database session and a presentation.

        Args:
            sql_session: Session used for every read and for the final commit.
            presentation_id: Presentation the conversation is about.
            conversation_id: Conversation to continue. ``None`` is handed to
                the conversation store, which decides which id to use.
        """
        self._sql_session = sql_session
        self._presentation_id = presentation_id
        self._conversation_id = conversation_id
        self._conversation_store = ChatConversationStore(sql_session)
        self._memory = PresentationContextStore(sql_session, presentation_id)
        self._tools = ChatTools(self._memory)

    async def generate_reply(self, user_message: str) -> ChatTurnResult:
        """Run one complete, non-streaming chat turn and persist it.

        Args:
            user_message: Text the user sent.

        Returns:
            The stored turn: conversation id, reply text and tool names used.

        Raises:
            HTTPException: 400 for a blank message, 404 when the presentation
                does not exist. LLM failures are re-raised as whatever
                ``handle_llm_client_exceptions`` returns.
        """
        conversation_id, messages = await self._prepare_turn_context(user_message)
        response_text, tool_calls = await self._run_llm_with_tools(messages)
        return await self._persist_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text=response_text,
            tool_calls=tool_calls,
        )

    async def stream_reply(
        self, user_message: str
    ) -> AsyncGenerator[tuple[ChatStreamEventType, ChatStreamEventValue], None]:
        """Run one chat turn, yielding ``(event_type, value)`` pairs as it goes.

        Event types:
            ``"status"``: short progress text.
            ``"chunk"``: a piece of assistant text.
            ``"trace"``: dict describing a model note, tool plan, tool call
                or the tool-round limit.
            ``"complete"``: the persisted ``ChatTurnResult``; always last.

        NOTE: text streamed during a round that ends in tool calls is sent to
        the consumer but is not part of the persisted reply; only the text of
        the final round is stored.

        Raises:
            HTTPException: 400 for a blank message, 404 when the presentation
                does not exist. Errors raised while streaming from the LLM are
                re-raised as whatever ``handle_llm_client_exceptions`` returns.
        """
        yield "status", "Reading deck context"
        conversation_id, messages = await self._prepare_turn_context(user_message)

        client = get_client(config=get_llm_config())
        model = get_model()
        # Build a new list instead of appending to the one ChatTools returned,
        # so the web-search tool cannot accumulate on a shared definition list.
        tools = [*self._tools.get_tool_definitions(), WebSearchTool()]
        called_tools: list[str] = []
        last_tool_results: list[dict[str, Any]] = []
        response_text: str | None = ""

        for round_index in range(MAX_TOOL_ROUNDS):
            round_number = round_index + 1
            completion_chunk: Any | None = None
            round_content_chunks: list[str] = []
            thinking_chunks: list[str] = []

            try:
                async for event in stream_generate_events(
                    client,
                    **get_generate_kwargs(
                        model=model,
                        messages=messages,
                        tools=tools,
                        stream=True,
                    ),
                ):
                    event_type = getattr(event, "type", None)
                    if event_type == "content":
                        chunk = getattr(event, "chunk", None)
                        if chunk:
                            round_content_chunks.append(chunk)
                            yield "chunk", chunk
                    elif event_type == "thinking":
                        thinking_text = self._event_text(event)
                        if thinking_text:
                            thinking_chunks.append(thinking_text)
                    elif event_type == "completion":
                        completion_chunk = event
            except Exception as exc:
                # Chain explicitly so the original client error stays attached.
                raise handle_llm_client_exceptions(exc) from exc

            thinking_summary = self._summarize_model_note(thinking_chunks)
            if thinking_summary:
                yield "trace", {
                    "kind": "model_note",
                    "round": round_number,
                    "status": "info",
                    "message": thinking_summary,
                }

            # A missing completion event (None) simply means "no tool calls".
            completion_tool_calls = list(
                getattr(completion_chunk, "tool_calls", None) or []
            )
            if completion_tool_calls:
                tool_names = [tool_call.name for tool_call in completion_tool_calls]
                called_tools.extend(tool_names)
                yield "trace", {
                    "kind": "tool_plan",
                    "round": round_number,
                    "tools": tool_names,
                    "message": f"Using tools: {', '.join(tool_names)}",
                }

                # Continue from the transcript returned with the completion when
                # there is one; otherwise from a copy of the current transcript.
                completion_messages = getattr(completion_chunk, "messages", None)
                messages = list(completion_messages or messages)

                last_tool_results = []
                for tool_call in completion_tool_calls:
                    yield "trace", {
                        "kind": "tool_call",
                        "round": round_number,
                        "tool": tool_call.name,
                        "status": "start",
                        "message": self._tool_start_message(tool_call.name),
                    }
                    tool_result = await self._tools.execute_tool_call(tool_call)
                    last_tool_results.append(tool_result)
                    yield "trace", {
                        "kind": "tool_call",
                        "round": round_number,
                        "tool": tool_call.name,
                        "status": "success" if tool_result.get("ok") else "error",
                        "message": self._summarize_tool_result(
                            tool_call.name, tool_result
                        ),
                    }
                    messages.append(
                        self._tool_response_message(tool_call, tool_result)
                    )
                continue

            # No tool calls: this round carries the final answer.
            response_text = "".join(round_content_chunks)
            if not response_text and completion_chunk:
                response_text = extract_text(getattr(completion_chunk, "content", None))
            if not response_text:
                response_text = self._EMPTY_RESPONSE_FALLBACK
            if not round_content_chunks:
                # Nothing was streamed in this round, so emit the text once now.
                yield "chunk", response_text
            break
        else:
            # Every round ended in tool calls; ask once more without tools.
            LOGGER.warning("Max tool rounds reached in chat stream flow")
            yield "trace", {
                "kind": "limit",
                "message": (
                    "Reached tool-call limit before final answer; "
                    "attempting best-effort summary."
                ),
            }
            yield "status", "Finalizing response"
            response_text = await self._try_final_response_without_tools(
                client=client,
                model=model,
                messages=messages,
            )
            if not response_text:
                response_text = self._build_tool_limit_fallback(last_tool_results)
            yield "chunk", response_text

        if not response_text:
            # Defensive only: both loop exits above assign non-empty text.
            response_text = self._EMPTY_RESPONSE_FALLBACK
            yield "chunk", response_text

        yield "status", "Saving chat"
        result = await self._persist_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text=response_text,
            tool_calls=called_tools,
        )
        yield "complete", result

    async def _prepare_turn_context(
        self, user_message: str
    ) -> tuple[uuid.UUID, list[Message]]:
        """Validate the request and assemble the message list for the LLM.

        The list is: a system prompt built from retrieved presentation and
        chat context, the stored conversation history, then the new user
        message.

        Returns:
            ``(conversation_id, messages)``.

        Raises:
            HTTPException: 400 when the message is empty or whitespace, 404
                when the presentation cannot be loaded.
        """
        if not (user_message or "").strip():
            raise HTTPException(status_code=400, detail="Message is required")

        presentation = await self._sql_session.get(
            PresentationModel, self._presentation_id
        )
        if not presentation:
            raise HTTPException(status_code=404, detail="Presentation not found")

        conversation_id = await self._conversation_store.ensure_conversation_id(
            self._conversation_id
        )
        history = await self._conversation_store.load_history(
            presentation_id=self._presentation_id,
            conversation_id=conversation_id,
        )
        history_messages = self._convert_history_to_messages(history)

        presentation_memory = await self._memory.retrieve_context(user_message)
        chat_memory = await self._conversation_store.retrieve_semantic_context(
            presentation_id=self._presentation_id,
            conversation_id=conversation_id,
            query=user_message,
        )

        system_prompt = build_system_prompt(
            presentation_memory_context=presentation_memory,
            chat_memory_context=chat_memory,
        )
        messages: list[Message] = [
            SystemMessage(content=system_prompt),
            *history_messages,
            UserMessage(content=user_message),
        ]
        return conversation_id, messages

    async def _persist_turn(
        self,
        *,
        conversation_id: uuid.UUID,
        user_message: str,
        response_text: str,
        tool_calls: list[str],
    ) -> ChatTurnResult:
        """Append the exchange to the conversation store and commit the session.

        Returns:
            A ``ChatTurnResult`` echoing what was stored.
        """
        await self._conversation_store.append_turn(
            presentation_id=self._presentation_id,
            conversation_id=conversation_id,
            user_message=user_message,
            assistant_message=response_text,
            tool_calls=tool_calls,
        )
        await self._sql_session.commit()
        return ChatTurnResult(
            conversation_id=conversation_id,
            response_text=response_text,
            tool_calls=tool_calls,
        )

    async def _run_llm_with_tools(
        self, messages: list[Message]
    ) -> tuple[str, list[str]]:
        """Non-streaming tool loop: call the model until it answers in text.

        Each round calls ``client.generate`` in a worker thread. When the
        response carries tool calls they are executed and their results are
        appended to the transcript for the next round. After
        ``MAX_TOOL_ROUNDS`` rounds a final call without tools is attempted,
        and a canned fallback is used if that yields nothing.

        Returns:
            ``(response_text, called_tool_names)``.
        """
        client = get_client(config=get_llm_config())
        model = get_model()
        tools = self._tools.get_tool_definitions()
        called_tools: list[str] = []
        last_tool_results: list[dict[str, Any]] = []

        for _ in range(MAX_TOOL_ROUNDS):
            try:
                response = await asyncio.to_thread(
                    client.generate,
                    **get_generate_kwargs(
                        model=model,
                        messages=messages,
                        tools=tools,
                    ),
                )
            except Exception as exc:
                # Chain explicitly so the original client error stays attached.
                raise handle_llm_client_exceptions(exc) from exc

            if not response.tool_calls:
                response_text = (
                    extract_text(response.content) or self._EMPTY_RESPONSE_FALLBACK
                )
                return response_text, called_tools

            called_tools.extend(tool_call.name for tool_call in response.tool_calls)
            # Prefer the transcript returned with the response; else copy ours.
            messages = list(response.messages or messages)

            last_tool_results = []
            for tool_call in response.tool_calls:
                tool_result = await self._tools.execute_tool_call(tool_call)
                last_tool_results.append(tool_result)
                messages.append(self._tool_response_message(tool_call, tool_result))

        LOGGER.warning("Max tool rounds reached in chat flow")
        final_response = await self._try_final_response_without_tools(
            client=client,
            model=model,
            messages=messages,
        )
        if final_response:
            return final_response, called_tools
        return self._build_tool_limit_fallback(last_tool_results), called_tools

    async def _try_final_response_without_tools(
        self,
        *,
        client: Any,
        model: str,
        messages: list[Message],
    ) -> str | None:
        """Best-effort final answer: one ``generate`` call with no tools offered.

        Returns:
            The extracted reply text, or ``None`` when the call raises. The
            failure is logged with its traceback rather than propagated,
            because callers have a canned fallback for this case.
        """
        try:
            response = await asyncio.to_thread(
                client.generate,
                **get_generate_kwargs(
                    model=model,
                    messages=messages,
                ),
            )
        except Exception:
            LOGGER.warning("Final no-tool synthesis call failed", exc_info=True)
            return None
        return extract_text(response.content)

    @staticmethod
    def _tool_response_message(
        tool_call: Any, tool_result: dict[str, Any]
    ) -> ToolResponseMessage:
        """Wrap a tool result as the JSON-encoded reply to ``tool_call``."""
        payload = json.dumps(tool_result, ensure_ascii=False)
        return ToolResponseMessage(
            id=tool_call.id,
            content=[TextContentPart(text=payload)],
        )

    @staticmethod
    def _summarize_model_note(chunks: list[str]) -> str:
        """Join thinking chunks into one whitespace-normalised note.

        Returns ``""`` for empty input, for a bare ``{}`` or ``[]``, and for
        the marker words ``start`` / ``end`` (case-insensitive). Notes longer
        than 600 characters are cut to 600 and suffixed with ``...``.
        """
        text = "".join(chunks).strip()
        if not text or text in {"{}", "[]"}:
            return ""
        compact = " ".join(text.split())
        if compact.lower() in {"start", "end"}:
            return ""
        if len(compact) > 600:
            return f"{compact[:600].rstrip()}..."
        return compact

    @staticmethod
    def _event_text(event: Any) -> str:
        """Return the first string found in ``chunk``, ``delta``, ``text`` or
        ``content`` on ``event`` (checked in that order), or ``""`` if none."""
        for attr in ("chunk", "delta", "text", "content"):
            value = getattr(event, attr, None)
            if isinstance(value, str):
                return value
        return ""

    @staticmethod
    def _tool_start_message(tool_name: str) -> str:
        """Return the progress label for a tool; unknown tools get
        ``"Running <tool_name>"``."""
        labels = {
            "getPresentationOutline": "Reading the presentation outline",
            "searchSlides": "Searching relevant slides",
            "getSlideAtIndex": "Opening the requested slide",
            "getAvailableLayouts": "Checking available layouts",
            "getContentSchemaFromLayoutId": "Checking the layout schema",
            "generateAssets": "Generating slide assets",
            "saveSlide": "Saving the slide",
        }
        return labels.get(tool_name, f"Running {tool_name}")

    @staticmethod
    def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str:
        """Pick a reply when the tool-round limit was hit without a final answer.

        Scans the last round's results from newest to oldest and returns the
        first non-blank ``result["message"]`` of a successful call; otherwise
        returns a generic apology.
        """
        for entry in reversed(last_tool_results):
            if not isinstance(entry, dict) or not entry.get("ok"):
                continue
            result = entry.get("result")
            if not isinstance(result, dict):
                continue
            message = result.get("message")
            if isinstance(message, str) and message.strip():
                return message.strip()
        return (
            "I completed several tool operations but could not finalize the response "
            "within the tool limit. Please ask a follow-up and I will continue."
        )

    @staticmethod
    def _summarize_tool_result(tool_name: str, tool_result: dict[str, Any]) -> str:
        """Produce a one-line description of a tool result for trace events.

        Failures report the ``error`` text when present. Successes use, in
        order of preference: ``message``, ``note``, an integer ``count``, a
        boolean ``found``, and finally a generic "completed" line.
        """
        if not tool_result.get("ok"):
            error = tool_result.get("error")
            if isinstance(error, str) and error.strip():
                return f"{tool_name} failed: {error.strip()}"
            return f"{tool_name} failed."

        result = tool_result.get("result")
        if isinstance(result, dict):
            message = result.get("message")
            if isinstance(message, str) and message.strip():
                return message.strip()
            note = result.get("note")
            if isinstance(note, str) and note.strip():
                return note.strip()
            count = result.get("count")
            # bool is a subclass of int; a True/False "count" is not a number.
            if isinstance(count, int) and not isinstance(count, bool):
                return f"{tool_name} returned {count} result(s)."
            found = result.get("found")
            if isinstance(found, bool):
                return (
                    f"{tool_name} found requested data."
                    if found
                    else f"{tool_name} did not find matching data."
                )
        return f"{tool_name} completed."

    @staticmethod
    def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
        """Convert stored history rows into LLM message objects.

        Rows with empty content, or with a role other than ``"user"`` or
        ``"assistant"``, are dropped.
        """
        messages: list[Message] = []
        for item in history:
            role = item.get("role")
            content = item.get("content")
            if not content:
                continue
            if role == "user":
                messages.append(UserMessage(content=content))
            elif role == "assistant":
                messages.append(AssistantMessage(content=[content]))
        return messages