presenton/servers/fastapi/services/chat/service.py
sudipnext 4e87dc8b70 refactor: Update database session management and enhance chat memory services
- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks.
- Refactored chat memory services to utilize a shared `mem0` client for better memory management.
- Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers.
- Enhanced error handling and response management in chat-related services.
- Cleaned up unused code and improved overall structure for maintainability.
2026-04-25 19:10:39 +05:45

406 lines
15 KiB
Python

import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, Literal
from fastapi import HTTPException
from llmai import get_client # type: ignore[import-not-found]
from llmai.shared import ( # type: ignore[import-not-found]
AssistantMessage,
Message,
SystemMessage,
ToolResponseMessage,
UserMessage,
)
from sqlalchemy.ext.asyncio import AsyncSession
from models.sql.presentation import PresentationModel
from services.chat.conversation_store import ChatConversationStore
from services.chat.presentation_context_store import PresentationContextStore
from services.chat.prompts import build_system_prompt
from services.chat.tools import ChatTools
from utils.llm_client_error_handler import handle_llm_client_exceptions
from utils.llm_config import get_llm_config
from utils.llm_provider import get_model
from utils.llm_utils import (
extract_text,
get_generate_kwargs,
stream_generate_events,
)
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Upper bound on the number of LLM rounds in a single chat turn; both the
# streaming and non-streaming tool loops iterate `range(MAX_TOOL_ROUNDS)`.
MAX_TOOL_ROUNDS = 16
@dataclass(frozen=True)
class ChatTurnResult:
    """Immutable summary of one completed chat turn.

    Attributes are documented inline; instances cannot be reassigned after
    construction because the dataclass is frozen.
    """

    # Identifier of the conversation the turn was stored under.
    conversation_id: uuid.UUID
    # Final assistant text for the turn.
    response_text: str
    # Names of the tools invoked while producing the reply, in call order.
    tool_calls: list[str]
# Event names yielded by PresentationChatService.stream_reply as the first
# element of each (event_type, value) tuple.
ChatStreamEventType = Literal["chunk", "complete", "status", "trace"]
# Payload paired with an event name: stream_reply yields a str for "chunk" and
# "status", a dict for "trace", and a ChatTurnResult for "complete".
ChatStreamEventValue = str | ChatTurnResult | dict[str, Any]
class PresentationChatService:
    """Run a single chat turn about one presentation.

    A turn consists of: validating the request and assembling the prompt
    (system prompt + stored history + the new user message), letting the LLM
    call tools for up to ``MAX_TOOL_ROUNDS`` rounds, and persisting the
    user/assistant pair through the conversation store.

    Two entry points are offered: ``generate_reply`` (returns the final
    result) and ``stream_reply`` (async generator of progress events).
    """

    # Reply used whenever the model returns no usable text.
    _EMPTY_RESPONSE_FALLBACK = "I could not generate a response for that request."

    def __init__(
        self,
        sql_session: AsyncSession,
        presentation_id: uuid.UUID,
        conversation_id: uuid.UUID | None,
    ):
        """Wire up the stores and tools for one presentation.

        Args:
            sql_session: Session handed to the conversation store and the
                presentation context store.
            presentation_id: Presentation the chat is about.
            conversation_id: Existing conversation to continue, or ``None``;
                it is resolved via ``ensure_conversation_id`` when the turn
                is prepared.
        """
        self._sql_session = sql_session
        self._presentation_id = presentation_id
        self._conversation_id = conversation_id
        self._conversation_store = ChatConversationStore(sql_session)
        self._memory = PresentationContextStore(sql_session, presentation_id)
        self._tools = ChatTools(self._memory)

    async def generate_reply(self, user_message: str) -> ChatTurnResult:
        """Answer ``user_message`` without streaming and persist the turn.

        Raises:
            HTTPException: 400 for a blank message, 404 when the presentation
                does not exist (both raised while preparing the turn).
            Exception: whatever ``handle_llm_client_exceptions`` returns for a
                failed LLM call.
        """
        conversation_id, messages = await self._prepare_turn_context(user_message)
        response_text, tool_calls = await self._run_llm_with_tools(messages)
        return await self._persist_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text=response_text,
            tool_calls=tool_calls,
        )

    async def stream_reply(
        self, user_message: str
    ) -> AsyncGenerator[tuple[ChatStreamEventType, ChatStreamEventValue], None]:
        """Answer ``user_message`` as a stream of ``(event_type, value)`` tuples.

        Events:
            ``"status"``: short progress label (str).
            ``"chunk"``: a piece of assistant text (str).
            ``"trace"``: dict describing the tool plan, each tool call's
                start/result, or the tool-round limit being hit.
            ``"complete"``: the persisted ``ChatTurnResult``; always last.

        Raises:
            HTTPException: 400 for a blank message, 404 when the presentation
                does not exist (raised after the first status event).
            Exception: whatever ``handle_llm_client_exceptions`` returns for a
                failed streaming call.
        """
        yield "status", "Preparing context"
        conversation_id, messages = await self._prepare_turn_context(user_message)

        yield "status", "Thinking"
        client = get_client(config=get_llm_config())
        model = get_model()
        tools = self._tools.get_tool_definitions()

        called_tools: list[str] = []
        last_tool_results: list[dict[str, Any]] = []
        response_text: str | None = None

        for round_index in range(MAX_TOOL_ROUNDS):
            completion_chunk: Any | None = None
            round_content_chunks: list[str] = []

            try:
                async for event in stream_generate_events(
                    client,
                    **get_generate_kwargs(
                        model=model,
                        messages=messages,
                        tools=tools,
                        stream=True,
                    ),
                ):
                    event_type = getattr(event, "type", None)
                    if event_type == "content":
                        chunk = getattr(event, "chunk", None)
                        if chunk:
                            round_content_chunks.append(chunk)
                            yield "chunk", chunk
                    elif event_type == "completion":
                        completion_chunk = event
            except Exception as exc:
                # Chain explicitly so the provider error stays attached as
                # the cause of the translated exception.
                raise handle_llm_client_exceptions(exc) from exc

            completion_tool_calls = list(
                getattr(completion_chunk, "tool_calls", []) or []
            )

            if completion_tool_calls:
                # NOTE(review): text streamed during a tool-calling round has
                # already been sent as "chunk" events but is not part of the
                # persisted response_text (only the final round's text is) --
                # confirm this is intended.
                tool_names = [tool_call.name for tool_call in completion_tool_calls]
                called_tools.extend(tool_names)
                yield "trace", {
                    "kind": "tool_plan",
                    "round": round_index + 1,
                    "tools": tool_names,
                    "message": f"Using tools: {', '.join(tool_names)}",
                }

                # Prefer the message thread returned with the completion
                # event; otherwise continue from a copy of the current one.
                threaded_messages = getattr(completion_chunk, "messages", None)
                messages = (
                    list(threaded_messages) if threaded_messages else list(messages)
                )

                last_tool_results = []
                for tool_call in completion_tool_calls:
                    yield "trace", {
                        "kind": "tool_call",
                        "round": round_index + 1,
                        "tool": tool_call.name,
                        "status": "start",
                        "message": f"Running {tool_call.name}",
                    }
                    tool_result = await self._tools.execute_tool_call(tool_call)
                    last_tool_results.append(tool_result)
                    yield "trace", {
                        "kind": "tool_call",
                        "round": round_index + 1,
                        "tool": tool_call.name,
                        "status": "success" if tool_result.get("ok") else "error",
                        "message": self._summarize_tool_result(
                            tool_call.name, tool_result
                        ),
                    }
                    # Feed the JSON-encoded tool output back to the model.
                    tool_response_content = json.dumps(tool_result, ensure_ascii=False)
                    messages.append(
                        ToolResponseMessage(
                            id=tool_call.id,
                            content=[tool_response_content],
                        )
                    )

                yield "status", "Thinking"
                continue

            # No tool calls: this round carries the final answer.
            response_text = "".join(round_content_chunks)
            if not response_text and completion_chunk:
                response_text = extract_text(getattr(completion_chunk, "content", None))
            if not response_text:
                response_text = self._EMPTY_RESPONSE_FALLBACK
            if not round_content_chunks:
                # Nothing was streamed this round, so emit the text once.
                yield "chunk", response_text
            break
        else:
            # Every round asked for tools; make one tool-free attempt.
            LOGGER.warning("Max tool rounds reached in chat stream flow")
            yield "trace", {
                "kind": "limit",
                "message": (
                    "Reached tool-call limit before final answer; "
                    "attempting best-effort summary."
                ),
            }
            yield "status", "Finalizing response"
            response_text = await self._try_final_response_without_tools(
                client=client,
                model=model,
                messages=messages,
            )
            if not response_text:
                response_text = self._build_tool_limit_fallback(last_tool_results)
            yield "chunk", response_text

        # Both exits above leave response_text as a non-empty string that has
        # already been streamed; the `or` only narrows the Optional type.
        final_response_text = response_text or self._EMPTY_RESPONSE_FALLBACK

        yield "status", "Saving conversation"
        result = await self._persist_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text=final_response_text,
            tool_calls=called_tools,
        )
        yield "complete", result

    async def _prepare_turn_context(
        self, user_message: str
    ) -> tuple[uuid.UUID, list[Message]]:
        """Validate the request and build the message list for the LLM.

        Returns:
            ``(conversation_id, messages)`` where ``messages`` is the system
            prompt, the stored history, then the new user message.

        Raises:
            HTTPException: 400 if ``user_message`` is empty or whitespace,
                404 if the presentation cannot be loaded.
        """
        if not (user_message or "").strip():
            raise HTTPException(status_code=400, detail="Message is required")

        presentation = await self._sql_session.get(
            PresentationModel, self._presentation_id
        )
        if not presentation:
            raise HTTPException(status_code=404, detail="Presentation not found")

        # A stable conversation_id is created here before the first user message
        # so mem0 and SQL can scope the thread; the client need not "chat" first.
        conversation_id = await self._conversation_store.ensure_conversation_id(
            self._conversation_id
        )
        history = await self._conversation_store.load_history(
            presentation_id=self._presentation_id,
            conversation_id=conversation_id,
        )
        history_messages = self._convert_history_to_messages(history)

        presentation_memory = await self._memory.retrieve_context(user_message)
        chat_memory = await self._conversation_store.retrieve_semantic_context(
            presentation_id=self._presentation_id,
            conversation_id=conversation_id,
            query=user_message,
        )

        messages: list[Message] = [
            SystemMessage(
                content=build_system_prompt(
                    presentation_memory_context=presentation_memory,
                    chat_memory_context=chat_memory,
                )
            ),
            *history_messages,
            UserMessage(content=user_message),
        ]
        return conversation_id, messages

    async def _persist_turn(
        self,
        *,
        conversation_id: uuid.UUID,
        user_message: str,
        response_text: str,
        tool_calls: list[str],
    ) -> ChatTurnResult:
        """Append the user/assistant pair to the store and build the result."""
        await self._conversation_store.append_turn(
            presentation_id=self._presentation_id,
            conversation_id=conversation_id,
            user_message=user_message,
            assistant_message=response_text,
        )
        return ChatTurnResult(
            conversation_id=conversation_id,
            response_text=response_text,
            tool_calls=tool_calls,
        )

    async def _run_llm_with_tools(self, messages: list[Message]) -> tuple[str, list[str]]:
        """Run the non-streaming tool loop.

        Returns:
            ``(response_text, called_tool_names)``. When every round requests
            tools, falls back to one tool-free call and then to a canned
            message derived from the last tool results.

        Raises:
            Exception: whatever ``handle_llm_client_exceptions`` returns for a
                failed LLM call.
        """
        # llmai is the only LLM entrypoint; provider selection comes from app config.
        client = get_client(config=get_llm_config())
        model = get_model()
        tools = self._tools.get_tool_definitions()

        called_tools: list[str] = []
        last_tool_results: list[dict[str, Any]] = []

        for _ in range(MAX_TOOL_ROUNDS):
            try:
                # client.generate is run in a worker thread via to_thread.
                response = await asyncio.to_thread(
                    client.generate,
                    **get_generate_kwargs(
                        model=model,
                        messages=messages,
                        tools=tools,
                    ),
                )
            except Exception as exc:
                # Chain explicitly so the provider error stays attached as
                # the cause of the translated exception.
                raise handle_llm_client_exceptions(exc) from exc

            if not response.tool_calls:
                response_text = (
                    extract_text(response.content) or self._EMPTY_RESPONSE_FALLBACK
                )
                return response_text, called_tools

            called_tools.extend(tool_call.name for tool_call in response.tool_calls)

            # Reuse llmai-returned threaded messages so provider adapters keep state.
            messages = list(response.messages) if response.messages else list(messages)

            last_tool_results = []
            for tool_call in response.tool_calls:
                tool_result = await self._tools.execute_tool_call(tool_call)
                last_tool_results.append(tool_result)
                tool_response_content = json.dumps(tool_result, ensure_ascii=False)
                # Tool responses are fed back into llmai to let the model continue.
                messages.append(
                    ToolResponseMessage(
                        id=tool_call.id,
                        content=[tool_response_content],
                    )
                )

        LOGGER.warning("Max tool rounds reached in chat flow")
        final_response = await self._try_final_response_without_tools(
            client=client,
            model=model,
            messages=messages,
        )
        if final_response:
            return final_response, called_tools
        return self._build_tool_limit_fallback(last_tool_results), called_tools

    async def _try_final_response_without_tools(
        self,
        *,
        client: Any,
        model: str,
        messages: list[Message],
    ) -> str | None:
        """
        Give the model one final chance to synthesize a natural-language answer
        from already-executed tool outputs, without allowing more tool calls.

        Returns the extracted text, or ``None`` if the call raised (the
        failure is logged, not propagated: this is a best-effort step).
        """
        try:
            response = await asyncio.to_thread(
                client.generate,
                **get_generate_kwargs(
                    model=model,
                    messages=messages,
                ),
            )
        except Exception:
            LOGGER.warning("Final no-tool synthesis call failed", exc_info=True)
            return None
        return extract_text(response.content)

    @staticmethod
    def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str:
        """Build the reply used when the tool-round limit is hit without an answer.

        Scans ``last_tool_results`` from newest to oldest and returns the first
        non-blank ``result["message"]`` of a successful (``ok``) entry;
        otherwise returns a generic apology.
        """
        for entry in reversed(last_tool_results):
            if not isinstance(entry, dict):
                continue
            if not entry.get("ok"):
                continue
            result = entry.get("result")
            if not isinstance(result, dict):
                continue
            message = result.get("message")
            if isinstance(message, str) and message.strip():
                return message.strip()
        return (
            "I completed several tool operations but could not finalize the response "
            "within the tool limit. Please ask a follow-up and I will continue."
        )

    @staticmethod
    def _summarize_tool_result(tool_name: str, tool_result: dict[str, Any]) -> str:
        """Return a one-line description of a tool result for trace events.

        A failed result (falsy ``ok``) reports its ``error`` text when present.
        For a successful result whose ``result`` is a dict, the first usable
        field wins, in this order: ``message``, ``note``, ``count``, ``found``.
        Anything else yields ``"<tool_name> completed."``.
        """
        if not tool_result.get("ok"):
            error = tool_result.get("error")
            if isinstance(error, str) and error.strip():
                return f"{tool_name} failed: {error.strip()}"
            return f"{tool_name} failed."

        result = tool_result.get("result")
        if isinstance(result, dict):
            for key in ("message", "note"):
                text = result.get(key)
                if isinstance(text, str) and text.strip():
                    return text.strip()

            count = result.get("count")
            # bool is a subclass of int; a True/False "count" is not a count.
            if isinstance(count, int) and not isinstance(count, bool):
                return f"{tool_name} returned {count} result(s)."

            found = result.get("found")
            if isinstance(found, bool):
                return (
                    f"{tool_name} found requested data."
                    if found
                    else f"{tool_name} did not find matching data."
                )
        return f"{tool_name} completed."

    @staticmethod
    def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
        """Turn stored history rows into llmai messages.

        Rows with empty content are skipped, as are rows whose ``role`` is
        neither ``"user"`` nor ``"assistant"``. Order is preserved.
        """
        messages: list[Message] = []
        for item in history:
            role = item.get("role")
            content = item.get("content")
            if not content:
                continue
            if role == "user":
                messages.append(UserMessage(content=content))
            elif role == "assistant":
                messages.append(AssistantMessage(content=[content]))
        return messages