presenton/servers/fastapi/services/chat/sql_chat_history.py
sudipnext 4c271170b5 feat: Add chat history management and asset generation features
- Introduced a new `ChatHistoryMessageModel` to persist chat messages in the database.
- Implemented a migration script to create the `chat_history_messages` table.
- Enhanced chat services to support storing and retrieving chat history.
- Added functionality for generating multiple media assets in a single API call.
- Updated the chat UI to display assistant activities and tool usage more effectively.
- Refactored API calls to use absolute URLs for better reliability.
2026-04-26 13:12:50 +05:45

212 lines
6.3 KiB
Python

"""Persist presentation chat threads in SQL rows."""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.sql.chat_history_message import ChatHistoryMessageModel
from utils.datetime_utils import get_current_utc_datetime
def _compact_preview(content: str) -> str:
preview = content.strip()
if len(preview) > 200:
return f"{preview[:200]}"
return preview
def _serialize_created_at(value: Any) -> str | None:
if value is None:
return None
if hasattr(value, "isoformat"):
try:
return value.isoformat()
except Exception:
return None
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _parse_created_at(value: Any) -> datetime | None:
if isinstance(value, datetime):
return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
if not isinstance(value, str) or not value.strip():
return None
normalized = value.strip().replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(normalized)
except ValueError:
return None
return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)
async def load_messages(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
rows = await load_messages_with_meta(
session,
presentation_id=presentation_id,
conversation_id=conversation_id,
)
return [
{"role": row["role"], "content": row["content"]}
for row in rows
if isinstance(row.get("role"), str) and isinstance(row.get("content"), str)
]
async def load_messages_with_meta(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
) -> list[dict[str, Any]]:
rows = list(
(
await session.scalars(
select(ChatHistoryMessageModel)
.where(
ChatHistoryMessageModel.presentation_id == presentation_id,
ChatHistoryMessageModel.conversation_id == conversation_id,
)
.order_by(ChatHistoryMessageModel.position.asc())
)
).all()
)
out: list[dict[str, Any]] = []
for row in rows:
entry: dict[str, Any] = {
"role": row.role,
"content": row.content,
}
created = _serialize_created_at(row.created_at)
if created:
entry["created_at"] = created
out.append(entry)
return out
async def replace_messages(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
messages: list[dict[str, str]],
) -> None:
await session.execute(
sa_delete(ChatHistoryMessageModel).where(
ChatHistoryMessageModel.presentation_id == presentation_id,
ChatHistoryMessageModel.conversation_id == conversation_id,
)
)
next_position = 1
base_time = get_current_utc_datetime()
for index, message in enumerate(messages):
role = message.get("role")
content = message.get("content")
if role not in ("user", "assistant"):
continue
if not isinstance(content, str) or not content.strip():
continue
created_at = _parse_created_at(message.get("created_at")) or (
base_time + timedelta(microseconds=index)
)
session.add(
ChatHistoryMessageModel(
presentation_id=presentation_id,
conversation_id=conversation_id,
position=next_position,
role=role,
content=content,
created_at=created_at,
)
)
next_position += 1
await session.flush()
async def append_turn(
session: AsyncSession,
*,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str,
tool_calls: list[str] | None = None,
) -> None:
max_position = await session.scalar(
select(func.max(ChatHistoryMessageModel.position)).where(
ChatHistoryMessageModel.presentation_id == presentation_id,
ChatHistoryMessageModel.conversation_id == conversation_id,
)
)
next_position = int(max_position or 0) + 1
now = get_current_utc_datetime()
session.add(
ChatHistoryMessageModel(
presentation_id=presentation_id,
conversation_id=conversation_id,
position=next_position,
role="user",
content=user_message,
created_at=now,
)
)
session.add(
ChatHistoryMessageModel(
presentation_id=presentation_id,
conversation_id=conversation_id,
position=next_position + 1,
role="assistant",
content=assistant_message,
created_at=now + timedelta(microseconds=1),
tool_calls=tool_calls or None,
)
)
await session.flush()
async def list_conversations(
session: AsyncSession, *, presentation_id: uuid.UUID
) -> list[dict[str, Any]]:
rows = list(
(
await session.scalars(
select(ChatHistoryMessageModel)
.where(ChatHistoryMessageModel.presentation_id == presentation_id)
.order_by(
ChatHistoryMessageModel.created_at.desc(),
ChatHistoryMessageModel.position.desc(),
)
)
).all()
)
summary_by_conversation: dict[str, dict[str, Any]] = {}
for row in rows:
conversation_key = str(row.conversation_id)
if conversation_key in summary_by_conversation:
continue
summary_by_conversation[conversation_key] = {
"conversation_id": conversation_key,
"updated_at": _serialize_created_at(row.created_at),
"last_message_preview": _compact_preview(row.content),
}
summaries = list(summary_by_conversation.values())
summaries.sort(key=lambda item: item.get("updated_at") or "", reverse=True)
return summaries