presenton/servers/fastapi/api/v1/ppt/endpoints/chat.py
sudipnext 4e87dc8b70 refactor: Update database session management and enhance chat memory services
- Replaced `get_container_db_async_session` with `async_session_maker` for improved session handling in background tasks.
- Refactored chat memory services to utilize a shared `mem0` client for better memory management.
- Introduced new methods for retrieving and storing chat history, integrating with SQL and memory layers.
- Enhanced error handling and response management in chat-related services.
- Cleaned up unused code and improved overall structure for maintainability.
2026-04-25 19:10:39 +05:45

129 lines
4.5 KiB
Python

import json
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from models.chat import (
ChatConversationListItem,
ChatHistoryMessageItem,
ChatHistoryResponse,
ChatMessageRequest,
ChatMessageResponse,
)
from models.sse_response import (
SSECompleteResponse,
SSEErrorResponse,
SSEStatusResponse,
SSETraceResponse,
SSEResponse,
)
from services.chat import ChatTurnResult, PresentationChatService
from services.chat import sql_chat_history
from services.database import get_async_session
# Router that groups every chat endpoint in this module under the "/chat"
# URL prefix and the "Chat" tag.
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
@CHAT_ROUTER.get("/conversations", response_model=list[ChatConversationListItem])
async def list_chat_conversations(
    presentation_id: uuid.UUID = Query(..., description="Presentation id"),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """List the chat conversations stored for one presentation.

    Args:
        presentation_id: Presentation whose conversations are requested.
        sql_session: Database session injected by FastAPI.

    Returns:
        One ``ChatConversationListItem`` per entry returned by
        ``sql_chat_history.list_conversations``. ``updated_at`` and
        ``last_message_preview`` are ``None`` when the entry lacks them.
    """
    conversations = await sql_chat_history.list_conversations(
        sql_session, presentation_id=presentation_id
    )

    listing = []
    for entry in conversations:
        # Round-trip through str so the id is always coerced to a UUID,
        # whatever representation the history layer hands back.
        thread_id = uuid.UUID(str(entry["conversation_id"]))
        listing.append(
            ChatConversationListItem(
                conversation_id=thread_id,
                updated_at=entry.get("updated_at"),
                last_message_preview=entry.get("last_message_preview"),
            )
        )
    return listing
@CHAT_ROUTER.get("/history", response_model=ChatHistoryResponse)
async def get_chat_history(
    presentation_id: uuid.UUID = Query(..., description="Presentation id"),
    conversation_id: uuid.UUID = Query(..., description="Conversation thread id"),
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Return the stored messages of one conversation thread.

    Args:
        presentation_id: Presentation the conversation belongs to.
        conversation_id: Conversation thread to load.
        sql_session: Database session injected by FastAPI.

    Returns:
        A ``ChatHistoryResponse`` echoing both ids and carrying one
        ``ChatHistoryMessageItem`` per row loaded from
        ``sql_chat_history.load_messages_with_meta``.
    """
    stored_messages = await sql_chat_history.load_messages_with_meta(
        sql_session,
        presentation_id=presentation_id,
        conversation_id=conversation_id,
    )

    history_items = []
    for entry in stored_messages:
        # Missing or falsy role/content collapse to an empty string.
        role = str(entry.get("role") or "")
        content = str(entry.get("content") or "")

        # Only string timestamps are forwarded; any other type is dropped.
        timestamp = entry.get("created_at")
        if not isinstance(timestamp, str):
            timestamp = None

        history_items.append(
            ChatHistoryMessageItem(
                role=role,
                content=content,
                created_at=timestamp,
            )
        )

    return ChatHistoryResponse(
        presentation_id=presentation_id,
        conversation_id=conversation_id,
        messages=history_items,
    )
@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse)
async def chat_message(
    payload: ChatMessageRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Run a single chat turn and return the reply in one response.

    Args:
        payload: Request carrying the presentation id, the conversation id
            and the user message.
        sql_session: Database session injected by FastAPI.

    Returns:
        A ``ChatMessageResponse`` built from the result of
        ``PresentationChatService.generate_reply``.
    """
    chat_service = PresentationChatService(
        sql_session=sql_session,
        presentation_id=payload.presentation_id,
        conversation_id=payload.conversation_id,
    )
    turn = await chat_service.generate_reply(payload.message)

    reply = ChatMessageResponse(
        conversation_id=turn.conversation_id,
        response=turn.response_text,
        tool_calls=turn.tool_calls,
    )
    return reply
@CHAT_ROUTER.post("/message/stream")
async def chat_message_stream(
    payload: ChatMessageRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Run a chat turn and stream its progress as server-sent events.

    Each ``(event_type, value)`` pair yielded by
    ``PresentationChatService.stream_reply`` is translated as follows:

    * ``"chunk"`` with a ``str`` value -> ``SSEResponse`` on the
      ``"response"`` event carrying a JSON ``{"type": "chunk", ...}`` body.
    * ``"status"`` with a ``str`` value -> ``SSEStatusResponse``.
    * ``"trace"`` with a ``dict`` value -> ``SSETraceResponse``.
    * ``"complete"`` with a ``ChatTurnResult`` -> ``SSECompleteResponse``
      keyed ``"chat"`` holding the serialized ``ChatMessageResponse``.

    Pairs that match none of these are skipped. An ``HTTPException`` raised
    while streaming becomes an ``SSEErrorResponse`` event; other exceptions
    propagate out of the generator.

    Args:
        payload: Request carrying the presentation id, the conversation id
            and the user message.
        sql_session: Database session injected by FastAPI.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    # NOTE(review): sql_session is a request-scoped dependency that the
    # generator keeps using while the body streams - confirm it stays open
    # for the whole stream under the installed FastAPI version.
    chat_service = PresentationChatService(
        sql_session=sql_session,
        presentation_id=payload.presentation_id,
        conversation_id=payload.conversation_id,
    )

    async def sse_events():
        try:
            async for kind, item in chat_service.stream_reply(payload.message):
                if kind == "chunk" and isinstance(item, str):
                    chunk_event = SSEResponse(
                        event="response",
                        data=json.dumps({"type": "chunk", "chunk": item}),
                    )
                    yield chunk_event.to_string()
                    continue

                if kind == "status" and isinstance(item, str):
                    yield SSEStatusResponse(status=item).to_string()
                    continue

                if kind == "trace" and isinstance(item, dict):
                    yield SSETraceResponse(trace=item).to_string()
                    continue

                if kind == "complete" and isinstance(item, ChatTurnResult):
                    final_reply = ChatMessageResponse(
                        conversation_id=item.conversation_id,
                        response=item.response_text,
                        tool_calls=item.tool_calls,
                    )
                    complete_event = SSECompleteResponse(
                        key="chat",
                        value=final_reply.model_dump(mode="json"),
                    )
                    yield complete_event.to_string()
        except HTTPException as http_error:
            # Response headers have already been sent, so the failure is
            # reported in-band as an SSE error event.
            yield SSEErrorResponse(detail=http_error.detail).to_string()

    return StreamingResponse(sse_events(), media_type="text/event-stream")