- Introduced a new `ChatHistoryMessageModel` to persist chat messages in the database.
- Implemented a migration script to create the `chat_history_messages` table.
- Enhanced chat services to support storing and retrieving chat history.
- Added functionality for generating multiple media assets in a single API call.
- Updated the chat UI to display assistant activities and tool usage more effectively.
- Refactored API calls to use absolute URLs for better reliability.
212 lines
6.3 KiB
Python
212 lines
6.3 KiB
Python
"""Persist presentation chat threads in SQL rows."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import delete as sa_delete
|
|
from sqlalchemy import func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlmodel import select
|
|
|
|
from models.sql.chat_history_message import ChatHistoryMessageModel
|
|
from utils.datetime_utils import get_current_utc_datetime
|
|
|
|
|
|
def _compact_preview(content: str) -> str:
|
|
preview = content.strip()
|
|
if len(preview) > 200:
|
|
return f"{preview[:200]}…"
|
|
return preview
|
|
|
|
|
|
def _serialize_created_at(value: Any) -> str | None:
|
|
if value is None:
|
|
return None
|
|
if hasattr(value, "isoformat"):
|
|
try:
|
|
return value.isoformat()
|
|
except Exception:
|
|
return None
|
|
if isinstance(value, str) and value.strip():
|
|
return value.strip()
|
|
return None
|
|
|
|
|
|
def _parse_created_at(value: Any) -> datetime | None:
|
|
if isinstance(value, datetime):
|
|
return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
|
|
if not isinstance(value, str) or not value.strip():
|
|
return None
|
|
normalized = value.strip().replace("Z", "+00:00")
|
|
try:
|
|
parsed = datetime.fromisoformat(normalized)
|
|
except ValueError:
|
|
return None
|
|
return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)
|
|
|
|
|
|
async def load_messages(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
    """Return the conversation as bare ``{"role", "content"}`` dictionaries.

    Delegates to ``load_messages_with_meta`` and drops every key other than
    ``role`` and ``content``. Entries whose role or content is not a string
    are omitted.
    """
    detailed = await load_messages_with_meta(
        session,
        presentation_id=presentation_id,
        conversation_id=conversation_id,
    )

    plain: list[dict[str, str]] = []
    for entry in detailed:
        role = entry.get("role")
        content = entry.get("content")
        if isinstance(role, str) and isinstance(content, str):
            plain.append({"role": role, "content": content})
    return plain
|
|
|
|
|
|
async def load_messages_with_meta(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> list[dict[str, Any]]:
    """Load a conversation's messages in ascending ``position`` order.

    Each returned dictionary carries ``role`` and ``content``; ``created_at``
    is included only when the stored value serializes to a non-empty string.
    """
    query = (
        select(ChatHistoryMessageModel)
        .where(
            ChatHistoryMessageModel.presentation_id == presentation_id,
            ChatHistoryMessageModel.conversation_id == conversation_id,
        )
        .order_by(ChatHistoryMessageModel.position.asc())
    )
    result = await session.scalars(query)
    records = list(result.all())

    messages: list[dict[str, Any]] = []
    for record in records:
        payload: dict[str, Any] = {"role": record.role, "content": record.content}
        stamp = _serialize_created_at(record.created_at)
        if stamp:
            payload["created_at"] = stamp
        messages.append(payload)
    return messages
|
|
|
|
|
|
async def replace_messages(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
    messages: list[dict[str, str]],
) -> None:
    """Overwrite the stored conversation with ``messages``.

    All existing rows for the presentation/conversation pair are deleted,
    then one row is added per accepted message. A message is accepted only
    when its role is ``"user"`` or ``"assistant"`` and its content is a
    non-blank string. Positions are assigned consecutively from 1 over the
    accepted messages. The session is flushed at the end; no commit is issued
    here.
    """
    purge = sa_delete(ChatHistoryMessageModel).where(
        ChatHistoryMessageModel.presentation_id == presentation_id,
        ChatHistoryMessageModel.conversation_id == conversation_id,
    )
    await session.execute(purge)

    fallback_start = get_current_utc_datetime()
    stored = 0
    for offset, item in enumerate(messages):
        role = item.get("role")
        content = item.get("content")
        role_ok = role in ("user", "assistant")
        content_ok = isinstance(content, str) and bool(content.strip())
        if not (role_ok and content_ok):
            continue

        stamp = _parse_created_at(item.get("created_at"))
        if not stamp:
            # No usable timestamp supplied: derive one from the input index so
            # fallback timestamps differ between messages.
            stamp = fallback_start + timedelta(microseconds=offset)

        stored += 1
        session.add(
            ChatHistoryMessageModel(
                presentation_id=presentation_id,
                conversation_id=conversation_id,
                position=stored,
                role=role,
                content=content,
                created_at=stamp,
            )
        )

    await session.flush()
|
|
|
|
|
|
async def append_turn(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
    user_message: str,
    assistant_message: str,
    tool_calls: list[str] | None = None,
) -> None:
    """Append one user/assistant exchange to the end of a conversation.

    The user row takes the position after the current maximum for the
    conversation (1 when none exists) and the assistant row takes the one
    after that. The assistant row is stamped one microsecond later than the
    user row and carries ``tool_calls`` (an empty list is stored as ``None``).
    The session is flushed at the end; no commit is issued here.
    """
    highest_query = select(func.max(ChatHistoryMessageModel.position)).where(
        ChatHistoryMessageModel.presentation_id == presentation_id,
        ChatHistoryMessageModel.conversation_id == conversation_id,
    )
    highest = await session.scalar(highest_query)
    user_position = int(highest or 0) + 1
    stamp = get_current_utc_datetime()

    user_row = ChatHistoryMessageModel(
        presentation_id=presentation_id,
        conversation_id=conversation_id,
        position=user_position,
        role="user",
        content=user_message,
        created_at=stamp,
    )
    session.add(user_row)

    assistant_row = ChatHistoryMessageModel(
        presentation_id=presentation_id,
        conversation_id=conversation_id,
        position=user_position + 1,
        role="assistant",
        content=assistant_message,
        created_at=stamp + timedelta(microseconds=1),
        tool_calls=tool_calls or None,
    )
    session.add(assistant_row)

    await session.flush()
|
|
|
|
|
|
async def list_conversations(
    session: AsyncSession, *, presentation_id: uuid.UUID
) -> list[dict[str, Any]]:
    """Summarize every conversation stored for a presentation.

    Rows are fetched ordered by ``created_at`` descending, then ``position``
    descending, so the first row seen for each conversation supplies its
    summary. Each summary holds ``conversation_id`` (as a string),
    ``updated_at`` (serialized timestamp or ``None``) and
    ``last_message_preview``. The result is sorted by the ``updated_at``
    string in descending order, with missing values treated as ``""``.
    """
    query = (
        select(ChatHistoryMessageModel)
        .where(ChatHistoryMessageModel.presentation_id == presentation_id)
        .order_by(
            ChatHistoryMessageModel.created_at.desc(),
            ChatHistoryMessageModel.position.desc(),
        )
    )
    result = await session.scalars(query)
    records = list(result.all())

    latest: dict[str, dict[str, Any]] = {}
    for record in records:
        key = str(record.conversation_id)
        if key not in latest:
            latest[key] = {
                "conversation_id": key,
                "updated_at": _serialize_created_at(record.created_at),
                "last_message_preview": _compact_preview(record.content),
            }

    return sorted(
        latest.values(),
        key=lambda summary: summary.get("updated_at") or "",
        reverse=True,
    )
|