449 lines
16 KiB
Python
449 lines
16 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from collections.abc import AsyncGenerator
|
|
from dataclasses import dataclass
|
|
from typing import Any, Literal
|
|
|
|
from fastapi import HTTPException
|
|
from llmai import get_client # type: ignore[import-not-found]
|
|
from llmai.shared import ( # type: ignore[import-not-found]
|
|
AssistantMessage,
|
|
Message,
|
|
SystemMessage,
|
|
TextContentPart,
|
|
ToolResponseMessage,
|
|
UserMessage,
|
|
WebSearchTool
|
|
)
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from models.sql.presentation import PresentationModel
|
|
from services.chat.conversation_store import ChatConversationStore
|
|
from services.chat.presentation_context_store import PresentationContextStore
|
|
from services.chat.prompts import build_system_prompt
|
|
from services.chat.tools import ChatTools
|
|
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
|
from utils.llm_config import get_llm_config
|
|
from utils.llm_provider import get_model
|
|
from utils.llm_utils import (
|
|
extract_text,
|
|
get_generate_kwargs,
|
|
stream_generate_events,
|
|
)
|
|
|
|
LOGGER = logging.getLogger(__name__)
|
|
MAX_TOOL_ROUNDS = 16
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ChatTurnResult:
|
|
conversation_id: uuid.UUID
|
|
response_text: str
|
|
tool_calls: list[str]
|
|
|
|
|
|
ChatStreamEventType = Literal["chunk", "complete", "status", "trace"]
|
|
ChatStreamEventValue = str | ChatTurnResult | dict[str, Any]
|
|
|
|
|
|
class PresentationChatService:
|
|
def __init__(
|
|
self,
|
|
sql_session: AsyncSession,
|
|
presentation_id: uuid.UUID,
|
|
conversation_id: uuid.UUID | None,
|
|
):
|
|
self._sql_session = sql_session
|
|
self._presentation_id = presentation_id
|
|
self._conversation_id = conversation_id
|
|
|
|
self._conversation_store = ChatConversationStore(sql_session)
|
|
self._memory = PresentationContextStore(sql_session, presentation_id)
|
|
self._tools = ChatTools(self._memory)
|
|
|
|
async def generate_reply(self, user_message: str) -> ChatTurnResult:
|
|
conversation_id, messages = await self._prepare_turn_context(user_message)
|
|
response_text, tool_calls = await self._run_llm_with_tools(messages)
|
|
return await self._persist_turn(
|
|
conversation_id=conversation_id,
|
|
user_message=user_message,
|
|
response_text=response_text,
|
|
tool_calls=tool_calls,
|
|
)
|
|
|
|
async def stream_reply(
|
|
self, user_message: str
|
|
) -> AsyncGenerator[tuple[ChatStreamEventType, ChatStreamEventValue], None]:
|
|
yield "status", "Reading deck context"
|
|
conversation_id, messages = await self._prepare_turn_context(user_message)
|
|
|
|
client = get_client(config=get_llm_config())
|
|
model = get_model()
|
|
tools = self._tools.get_tool_definitions()
|
|
tools.append(WebSearchTool())
|
|
|
|
called_tools: list[str] = []
|
|
last_tool_results: list[dict[str, Any]] = []
|
|
response_text: str | None = None
|
|
|
|
for round_index in range(MAX_TOOL_ROUNDS):
|
|
completion_chunk: Any | None = None
|
|
round_content_chunks: list[str] = []
|
|
thinking_chunks: list[str] = []
|
|
|
|
try:
|
|
async for event in stream_generate_events(
|
|
client,
|
|
**get_generate_kwargs(
|
|
model=model,
|
|
messages=messages,
|
|
tools=tools,
|
|
stream=True,
|
|
),
|
|
):
|
|
event_type = getattr(event, "type", None)
|
|
if event_type == "content":
|
|
chunk = getattr(event, "chunk", None)
|
|
if chunk:
|
|
round_content_chunks.append(chunk)
|
|
yield "chunk", chunk
|
|
elif event_type == "thinking":
|
|
thinking_text = self._event_text(event)
|
|
if thinking_text:
|
|
thinking_chunks.append(thinking_text)
|
|
elif event_type == "completion":
|
|
completion_chunk = event
|
|
except Exception as exc:
|
|
raise handle_llm_client_exceptions(exc)
|
|
|
|
thinking_summary = self._summarize_model_note(thinking_chunks)
|
|
if thinking_summary:
|
|
yield "trace", {
|
|
"kind": "model_note",
|
|
"round": round_index + 1,
|
|
"status": "info",
|
|
"message": thinking_summary,
|
|
}
|
|
|
|
completion_tool_calls = list(
|
|
getattr(completion_chunk, "tool_calls", []) or []
|
|
)
|
|
if completion_tool_calls:
|
|
tool_names = [tool_call.name for tool_call in completion_tool_calls]
|
|
called_tools.extend(tool_names)
|
|
yield "trace", {
|
|
"kind": "tool_plan",
|
|
"round": round_index + 1,
|
|
"tools": tool_names,
|
|
"message": f"Using tools: {', '.join(tool_names)}",
|
|
}
|
|
messages = (
|
|
list(getattr(completion_chunk, "messages", []) or [])
|
|
if getattr(completion_chunk, "messages", None)
|
|
else list(messages)
|
|
)
|
|
|
|
last_tool_results = []
|
|
for tool_call in completion_tool_calls:
|
|
yield "trace", {
|
|
"kind": "tool_call",
|
|
"round": round_index + 1,
|
|
"tool": tool_call.name,
|
|
"status": "start",
|
|
"message": self._tool_start_message(tool_call.name),
|
|
}
|
|
tool_result = await self._tools.execute_tool_call(tool_call)
|
|
last_tool_results.append(tool_result)
|
|
yield "trace", {
|
|
"kind": "tool_call",
|
|
"round": round_index + 1,
|
|
"tool": tool_call.name,
|
|
"status": "success" if tool_result.get("ok") else "error",
|
|
"message": self._summarize_tool_result(
|
|
tool_call.name, tool_result
|
|
),
|
|
}
|
|
tool_response_content = json.dumps(tool_result, ensure_ascii=False)
|
|
messages.append(
|
|
ToolResponseMessage(
|
|
id=tool_call.id,
|
|
content=[TextContentPart(text=tool_response_content)],
|
|
)
|
|
)
|
|
continue
|
|
|
|
response_text = "".join(round_content_chunks)
|
|
if not response_text and completion_chunk:
|
|
response_text = extract_text(getattr(completion_chunk, "content", None))
|
|
if not response_text:
|
|
response_text = "I could not generate a response for that request."
|
|
|
|
if not round_content_chunks:
|
|
yield "chunk", response_text
|
|
break
|
|
else:
|
|
LOGGER.warning("Max tool rounds reached in chat stream flow")
|
|
yield "trace", {
|
|
"kind": "limit",
|
|
"message": (
|
|
"Reached tool-call limit before final answer; "
|
|
"attempting best-effort summary."
|
|
),
|
|
}
|
|
yield "status", "Finalizing response"
|
|
response_text = await self._try_final_response_without_tools(
|
|
client=client,
|
|
model=model,
|
|
messages=messages,
|
|
)
|
|
if not response_text:
|
|
response_text = self._build_tool_limit_fallback(last_tool_results)
|
|
yield "chunk", response_text
|
|
|
|
final_response_text = response_text or "I could not generate a response for that request."
|
|
if response_text is None:
|
|
yield "chunk", final_response_text
|
|
|
|
yield "status", "Saving chat"
|
|
result = await self._persist_turn(
|
|
conversation_id=conversation_id,
|
|
user_message=user_message,
|
|
response_text=final_response_text,
|
|
tool_calls=called_tools,
|
|
)
|
|
yield "complete", result
|
|
|
|
async def _prepare_turn_context(
|
|
self, user_message: str
|
|
) -> tuple[uuid.UUID, list[Message]]:
|
|
if not (user_message or "").strip():
|
|
raise HTTPException(status_code=400, detail="Message is required")
|
|
|
|
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
|
|
if not presentation:
|
|
raise HTTPException(status_code=404, detail="Presentation not found")
|
|
|
|
conversation_id = await self._conversation_store.ensure_conversation_id(
|
|
self._conversation_id
|
|
)
|
|
history = await self._conversation_store.load_history(
|
|
presentation_id=self._presentation_id,
|
|
conversation_id=conversation_id,
|
|
)
|
|
history_messages = self._convert_history_to_messages(history)
|
|
|
|
presentation_memory = await self._memory.retrieve_context(user_message)
|
|
chat_memory = await self._conversation_store.retrieve_semantic_context(
|
|
presentation_id=self._presentation_id,
|
|
conversation_id=conversation_id,
|
|
query=user_message,
|
|
)
|
|
messages: list[Message] = [
|
|
SystemMessage(
|
|
content=build_system_prompt(
|
|
presentation_memory_context=presentation_memory,
|
|
chat_memory_context=chat_memory,
|
|
)
|
|
),
|
|
*history_messages,
|
|
UserMessage(content=user_message),
|
|
]
|
|
return conversation_id, messages
|
|
|
|
async def _persist_turn(
|
|
self,
|
|
*,
|
|
conversation_id: uuid.UUID,
|
|
user_message: str,
|
|
response_text: str,
|
|
tool_calls: list[str],
|
|
) -> ChatTurnResult:
|
|
await self._conversation_store.append_turn(
|
|
presentation_id=self._presentation_id,
|
|
conversation_id=conversation_id,
|
|
user_message=user_message,
|
|
assistant_message=response_text,
|
|
tool_calls=tool_calls,
|
|
)
|
|
await self._sql_session.commit()
|
|
|
|
return ChatTurnResult(
|
|
conversation_id=conversation_id,
|
|
response_text=response_text,
|
|
tool_calls=tool_calls,
|
|
)
|
|
|
|
async def _run_llm_with_tools(self, messages: list[Message]) -> tuple[str, list[str]]:
|
|
client = get_client(config=get_llm_config())
|
|
model = get_model()
|
|
tools = self._tools.get_tool_definitions()
|
|
|
|
called_tools: list[str] = []
|
|
last_tool_results: list[dict[str, Any]] = []
|
|
|
|
for _ in range(MAX_TOOL_ROUNDS):
|
|
try:
|
|
response = await asyncio.to_thread(
|
|
client.generate,
|
|
**get_generate_kwargs(
|
|
model=model,
|
|
messages=messages,
|
|
tools=tools,
|
|
),
|
|
)
|
|
except Exception as exc:
|
|
raise handle_llm_client_exceptions(exc)
|
|
|
|
if not response.tool_calls:
|
|
response_text = extract_text(response.content) or (
|
|
"I could not generate a response for that request."
|
|
)
|
|
return response_text, called_tools
|
|
|
|
called_tools.extend([tool_call.name for tool_call in response.tool_calls])
|
|
messages = list(response.messages) if response.messages else list(messages)
|
|
|
|
last_tool_results = []
|
|
for tool_call in response.tool_calls:
|
|
tool_result = await self._tools.execute_tool_call(tool_call)
|
|
last_tool_results.append(tool_result)
|
|
tool_response_content = json.dumps(tool_result, ensure_ascii=False)
|
|
messages.append(
|
|
ToolResponseMessage(
|
|
id=tool_call.id,
|
|
content=[TextContentPart(text=tool_response_content)],
|
|
)
|
|
)
|
|
|
|
LOGGER.warning("Max tool rounds reached in chat flow")
|
|
final_response = await self._try_final_response_without_tools(
|
|
client=client,
|
|
model=model,
|
|
messages=messages,
|
|
)
|
|
if final_response:
|
|
return final_response, called_tools
|
|
|
|
return self._build_tool_limit_fallback(last_tool_results), called_tools
|
|
|
|
async def _try_final_response_without_tools(
|
|
self,
|
|
*,
|
|
client: Any,
|
|
model: str,
|
|
messages: list[Message],
|
|
) -> str | None:
|
|
try:
|
|
response = await asyncio.to_thread(
|
|
client.generate,
|
|
**get_generate_kwargs(
|
|
model=model,
|
|
messages=messages,
|
|
),
|
|
)
|
|
except Exception:
|
|
LOGGER.warning("Final no-tool synthesis call failed", exc_info=True)
|
|
return None
|
|
|
|
return extract_text(response.content)
|
|
|
|
@staticmethod
|
|
def _summarize_model_note(chunks: list[str]) -> str:
|
|
text = "".join(chunks).strip()
|
|
if not text or text in {"{}", "[]"}:
|
|
return ""
|
|
|
|
compact = " ".join(text.split())
|
|
if compact.lower() in {"start", "end"}:
|
|
return ""
|
|
if len(compact) > 600:
|
|
return f"{compact[:600].rstrip()}..."
|
|
return compact
|
|
|
|
@staticmethod
|
|
def _event_text(event: Any) -> str:
|
|
for attr in ("chunk", "delta", "text", "content"):
|
|
value = getattr(event, attr, None)
|
|
if isinstance(value, str):
|
|
return value
|
|
return ""
|
|
|
|
@staticmethod
|
|
def _tool_start_message(tool_name: str) -> str:
|
|
labels = {
|
|
"getPresentationOutline": "Reading the presentation outline",
|
|
"searchSlides": "Searching relevant slides",
|
|
"getSlideAtIndex": "Opening the requested slide",
|
|
"getAvailableLayouts": "Checking available layouts",
|
|
"getContentSchemaFromLayoutId": "Checking the layout schema",
|
|
"generateAssets": "Generating slide assets",
|
|
"saveSlide": "Saving the slide",
|
|
}
|
|
return labels.get(tool_name, f"Running {tool_name}")
|
|
|
|
@staticmethod
|
|
def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str:
|
|
for entry in reversed(last_tool_results):
|
|
if not isinstance(entry, dict):
|
|
continue
|
|
if not entry.get("ok"):
|
|
continue
|
|
result = entry.get("result")
|
|
if not isinstance(result, dict):
|
|
continue
|
|
message = result.get("message")
|
|
if isinstance(message, str) and message.strip():
|
|
return message.strip()
|
|
|
|
return (
|
|
"I completed several tool operations but could not finalize the response "
|
|
"within the tool limit. Please ask a follow-up and I will continue."
|
|
)
|
|
|
|
@staticmethod
|
|
def _summarize_tool_result(tool_name: str, tool_result: dict[str, Any]) -> str:
|
|
if not tool_result.get("ok"):
|
|
error = tool_result.get("error")
|
|
if isinstance(error, str) and error.strip():
|
|
return f"{tool_name} failed: {error.strip()}"
|
|
return f"{tool_name} failed."
|
|
|
|
result = tool_result.get("result")
|
|
if isinstance(result, dict):
|
|
message = result.get("message")
|
|
if isinstance(message, str) and message.strip():
|
|
return message.strip()
|
|
|
|
note = result.get("note")
|
|
if isinstance(note, str) and note.strip():
|
|
return note.strip()
|
|
|
|
count = result.get("count")
|
|
if isinstance(count, int):
|
|
return f"{tool_name} returned {count} result(s)."
|
|
|
|
found = result.get("found")
|
|
if isinstance(found, bool):
|
|
return (
|
|
f"{tool_name} found requested data."
|
|
if found
|
|
else f"{tool_name} did not find matching data."
|
|
)
|
|
|
|
return f"{tool_name} completed."
|
|
|
|
@staticmethod
|
|
def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
|
|
messages: list[Message] = []
|
|
for item in history:
|
|
role = item.get("role")
|
|
content = item.get("content")
|
|
if not content:
|
|
continue
|
|
if role == "user":
|
|
messages.append(UserMessage(content=content))
|
|
elif role == "assistant":
|
|
messages.append(AssistantMessage(content=[content]))
|
|
return messages
|