feat: Implement streaming chat functionality in FastAPI and update UI
- Added a new FastAPI endpoint for streaming chat messages, allowing real-time interaction.
- Enhanced `PresentationChatService` to support streaming replies, with event types for chunked responses, status updates, and tool tracing.
- Updated the chat UI to handle and display assistant activity, including loading states and tool usage.
- Introduced new models for SSE responses and integrated them into the chat service.
- Improved error handling and response management in the chat API.
This commit is contained in:
parent
ebf3f05d6e
commit
17ea7d9f95
5 changed files with 783 additions and 32 deletions
|
|
@ -1,8 +1,18 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.chat import ChatMessageRequest, ChatMessageResponse
|
||||
from services.chat import PresentationChatService
|
||||
from models.sse_response import (
|
||||
SSECompleteResponse,
|
||||
SSEErrorResponse,
|
||||
SSEStatusResponse,
|
||||
SSETraceResponse,
|
||||
SSEResponse,
|
||||
)
|
||||
from services.chat import ChatTurnResult, PresentationChatService
|
||||
from services.database import get_async_session
|
||||
|
||||
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
|
||||
|
|
@ -24,3 +34,43 @@ async def chat_message(
|
|||
response=result.response_text,
|
||||
tool_calls=result.tool_calls,
|
||||
)
|
||||
|
||||
|
||||
@CHAT_ROUTER.post("/message/stream")
async def chat_message_stream(
    payload: ChatMessageRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    """Stream one chat turn to the client as Server-Sent Events.

    Builds a ``PresentationChatService`` for the requested presentation and
    conversation, then relays every ``(event_type, value)`` pair produced by
    ``service.stream_reply`` as an SSE string:

    * ``"chunk"`` (str)              -> ``SSEResponse`` carrying ``{"type": "chunk", ...}``
    * ``"status"`` (str)             -> ``SSEStatusResponse``
    * ``"trace"`` (dict)             -> ``SSETraceResponse``
    * ``"complete"`` (ChatTurnResult) -> ``SSECompleteResponse`` wrapping a
      ``ChatMessageResponse`` under the key ``"chat"``

    Pairs whose value type does not match the event type are dropped.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    service = PresentationChatService(
        sql_session=sql_session,
        presentation_id=payload.presentation_id,
        conversation_id=payload.conversation_id,
    )

    async def inner():
        try:
            async for event_type, value in service.stream_reply(payload.message):
                if event_type == "chunk" and isinstance(value, str):
                    yield SSEResponse(
                        event="response",
                        data=json.dumps({"type": "chunk", "chunk": value}),
                    ).to_string()
                elif event_type == "status" and isinstance(value, str):
                    yield SSEStatusResponse(status=value).to_string()
                elif event_type == "trace" and isinstance(value, dict):
                    yield SSETraceResponse(trace=value).to_string()
                elif event_type == "complete" and isinstance(value, ChatTurnResult):
                    result = value
                    complete_payload = ChatMessageResponse(
                        conversation_id=result.conversation_id,
                        response=result.response_text,
                        tool_calls=result.tool_calls,
                    )
                    yield SSECompleteResponse(
                        key="chat",
                        value=complete_payload.model_dump(mode="json"),
                    ).to_string()
        except HTTPException as exc:
            # The HTTP status line has already been sent by the time the
            # generator runs, so failures are reported in-band as an SSE
            # error event instead of an HTTP error response.
            #
            # FastAPI allows HTTPException.detail to be any JSON-able value
            # (dict, list, ...), while SSEErrorResponse.detail is a str.
            # Coerce here so building the error event cannot itself fail
            # and silently truncate the stream.
            detail = exc.detail
            if not isinstance(detail, str):
                detail = json.dumps(detail, default=str)
            yield SSEErrorResponse(detail=detail).to_string()

    return StreamingResponse(inner(), media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -20,6 +20,15 @@ class SSEStatusResponse(BaseModel):
|
|||
).to_string()
|
||||
|
||||
|
||||
class SSETraceResponse(BaseModel):
    """SSE payload describing one trace event of a streamed chat turn."""

    # Arbitrary trace payload. It is passed straight to json.dumps() in
    # to_string(), so it has to be JSON-serialisable.
    trace: object

    def to_string(self):
        """Serialise the trace as an SSE ``response`` event string."""
        body = json.dumps({"type": "trace", "trace": self.trace})
        envelope = SSEResponse(event="response", data=body)
        return envelope.to_string()
|
||||
|
||||
|
||||
class SSEErrorResponse(BaseModel):
|
||||
detail: str
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
from llmai import get_client # type: ignore[import-not-found]
|
||||
|
|
@ -24,7 +25,11 @@ from services.chat.tools import ChatTools
|
|||
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
||||
from utils.llm_config import get_llm_config
|
||||
from utils.llm_provider import get_model
|
||||
from utils.llm_utils import extract_text, get_generate_kwargs
|
||||
from utils.llm_utils import (
|
||||
extract_text,
|
||||
get_generate_kwargs,
|
||||
stream_generate_events,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
MAX_TOOL_ROUNDS = 16
|
||||
|
|
@ -37,6 +42,10 @@ class ChatTurnResult:
|
|||
tool_calls: list[str]
|
||||
|
||||
|
||||
ChatStreamEventType = Literal["chunk", "complete", "status", "trace"]
|
||||
ChatStreamEventValue = str | ChatTurnResult | dict[str, Any]
|
||||
|
||||
|
||||
class PresentationChatService:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -53,6 +62,146 @@ class PresentationChatService:
|
|||
self._tools = ChatTools(self._memory)
|
||||
|
||||
async def generate_reply(self, user_message: str) -> ChatTurnResult:
    """Produce a complete (non-streaming) reply for one user message.

    Prepares the turn context, runs the LLM/tool loop to completion and
    persists the turn.

    Args:
        user_message: Raw text sent by the user.

    Returns:
        The ``ChatTurnResult`` returned by ``_persist_turn``.
    """
    conversation_id, messages = await self._prepare_turn_context(user_message)
    llm_outcome = await self._run_llm_with_tools(messages)
    response_text, tool_calls = llm_outcome
    persisted_turn = await self._persist_turn(
        conversation_id=conversation_id,
        user_message=user_message,
        response_text=response_text,
        tool_calls=tool_calls,
    )
    return persisted_turn
|
||||
|
||||
async def stream_reply(
    self, user_message: str
) -> AsyncGenerator[tuple[ChatStreamEventType, ChatStreamEventValue], None]:
    """Run one chat turn, yielding progress events while it executes.

    Yields ``(event_type, value)`` tuples:

    * ``"status"``   - short phase label (``str``), e.g. ``"Thinking"``.
    * ``"chunk"``    - a piece of assistant text (``str``).
    * ``"trace"``    - ``dict`` describing a tool plan, a tool call
      start/result, or the tool-round limit being reached.
    * ``"complete"`` - the ``ChatTurnResult`` returned by ``_persist_turn``;
      always the last event.

    The LLM is called for at most ``MAX_TOOL_ROUNDS`` rounds. A round that
    ends with tool calls executes them, appends their results to the
    message list and starts another round; a round without tool calls is
    the final answer. If every round requested tools, one tool-free
    completion is attempted and, failing that, a fallback text is used.

    Args:
        user_message: Raw text sent by the user.

    Raises:
        Whatever ``handle_llm_client_exceptions`` returns for an exception
        raised while streaming from the LLM client.
    """
    yield "status", "Preparing context"
    conversation_id, messages = await self._prepare_turn_context(user_message)
    yield "status", "Thinking"

    client = get_client(config=get_llm_config())
    model = get_model()
    tools = self._tools.get_tool_definitions()

    called_tools: list[str] = []
    last_tool_results: list[dict[str, Any]] = []
    response_text: str | None = None

    for round_index in range(MAX_TOOL_ROUNDS):
        round_number = round_index + 1  # rounds are reported 1-based
        completion_chunk: Any | None = None
        round_content_chunks: list[str] = []

        try:
            generate_kwargs = get_generate_kwargs(
                model=model,
                messages=messages,
                tools=tools,
                stream=True,
            )
            async for event in stream_generate_events(client, **generate_kwargs):
                event_kind = getattr(event, "type", None)
                if event_kind == "content":
                    chunk = getattr(event, "chunk", None)
                    if chunk:
                        round_content_chunks.append(chunk)
                        yield "chunk", chunk
                elif event_kind == "completion":
                    completion_chunk = event
        except Exception as exc:
            raise handle_llm_client_exceptions(exc)

        completion_tool_calls = list(
            getattr(completion_chunk, "tool_calls", []) or []
        )

        if not completion_tool_calls:
            # No tools requested: this round carries the final answer.
            response_text = "".join(round_content_chunks)
            if not response_text and completion_chunk:
                response_text = extract_text(
                    getattr(completion_chunk, "content", None)
                )
            if not response_text:
                response_text = "I could not generate a response for that request."

            # Nothing was streamed during the round, so emit the text now.
            if not round_content_chunks:
                yield "chunk", response_text
            break

        tool_names = [tool_call.name for tool_call in completion_tool_calls]
        called_tools.extend(tool_names)
        yield "trace", {
            "kind": "tool_plan",
            "round": round_number,
            "tools": tool_names,
            "message": f"Using tools: {', '.join(tool_names)}",
        }

        # Continue from the message list reported by the completion event
        # when it has one; otherwise from a copy of the current list.
        completion_messages = getattr(completion_chunk, "messages", None)
        messages = (
            list(completion_messages) if completion_messages else list(messages)
        )

        last_tool_results = []
        for tool_call in completion_tool_calls:
            yield "trace", {
                "kind": "tool_call",
                "round": round_number,
                "tool": tool_call.name,
                "status": "start",
                "message": f"Running {tool_call.name}",
            }
            tool_result = await self._tools.execute_tool_call(tool_call)
            last_tool_results.append(tool_result)
            yield "trace", {
                "kind": "tool_call",
                "round": round_number,
                "tool": tool_call.name,
                "status": "success" if tool_result.get("ok") else "error",
                "message": self._summarize_tool_result(tool_call.name, tool_result),
            }
            tool_response_content = json.dumps(tool_result, ensure_ascii=False)
            messages.append(
                ToolResponseMessage(
                    id=tool_call.id,
                    content=[tool_response_content],
                )
            )
        yield "status", "Thinking"
    else:
        # Every round asked for tools; try to close the turn without them.
        LOGGER.warning("Max tool rounds reached in chat stream flow")
        yield "trace", {
            "kind": "limit",
            "message": (
                "Reached tool-call limit before final answer; "
                "attempting best-effort summary."
            ),
        }
        yield "status", "Finalizing response"
        response_text = await self._try_final_response_without_tools(
            client=client,
            model=model,
            messages=messages,
        )
        if not response_text:
            response_text = self._build_tool_limit_fallback(last_tool_results)
        yield "chunk", response_text

    final_response_text = (
        response_text or "I could not generate a response for that request."
    )
    if response_text is None:
        yield "chunk", final_response_text

    yield "status", "Saving conversation"
    result = await self._persist_turn(
        conversation_id=conversation_id,
        user_message=user_message,
        response_text=final_response_text,
        tool_calls=called_tools,
    )
    yield "complete", result
|
||||
|
||||
async def _prepare_turn_context(
|
||||
self, user_message: str
|
||||
) -> tuple[uuid.UUID, list[Message]]:
|
||||
if not (user_message or "").strip():
|
||||
raise HTTPException(status_code=400, detail="Message is required")
|
||||
|
||||
|
|
@ -75,8 +224,16 @@ class PresentationChatService:
|
|||
*history_messages,
|
||||
UserMessage(content=user_message),
|
||||
]
|
||||
return conversation_id, messages
|
||||
|
||||
response_text, tool_calls = await self._run_llm_with_tools(messages)
|
||||
async def _persist_turn(
|
||||
self,
|
||||
*,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
response_text: str,
|
||||
tool_calls: list[str],
|
||||
) -> ChatTurnResult:
|
||||
await self._conversation_store.append_turn(
|
||||
presentation_id=self._presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
|
|
@ -190,6 +347,38 @@ class PresentationChatService:
|
|||
"within the tool limit. Please ask a follow-up and I will continue."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_tool_result(tool_name: str, tool_result: dict[str, Any]) -> str:
|
||||
if not tool_result.get("ok"):
|
||||
error = tool_result.get("error")
|
||||
if isinstance(error, str) and error.strip():
|
||||
return f"{tool_name} failed: {error.strip()}"
|
||||
return f"{tool_name} failed."
|
||||
|
||||
result = tool_result.get("result")
|
||||
if isinstance(result, dict):
|
||||
message = result.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
|
||||
note = result.get("note")
|
||||
if isinstance(note, str) and note.strip():
|
||||
return note.strip()
|
||||
|
||||
count = result.get("count")
|
||||
if isinstance(count, int):
|
||||
return f"{tool_name} returned {count} result(s)."
|
||||
|
||||
found = result.get("found")
|
||||
if isinstance(found, bool):
|
||||
return (
|
||||
f"{tool_name} found requested data."
|
||||
if found
|
||||
else f"{tool_name} did not find matching data."
|
||||
)
|
||||
|
||||
return f"{tool_name} completed."
|
||||
|
||||
@staticmethod
|
||||
def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
|
||||
messages: list[Message] = []
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
Activity,
|
||||
CheckCircle2,
|
||||
ChevronDown,
|
||||
ChevronRight,
|
||||
ChevronUp,
|
||||
CircleDot,
|
||||
Loader2,
|
||||
MessageCircle,
|
||||
MessageCircleMore,
|
||||
Plus,
|
||||
RefreshCw,
|
||||
Send,
|
||||
XCircle,
|
||||
} from "lucide-react";
|
||||
import React, {
|
||||
FormEvent,
|
||||
|
|
@ -19,6 +24,7 @@ import React, {
|
|||
} from "react";
|
||||
import { toast } from "sonner";
|
||||
import { PresentationChatApi } from "../../services/api/chat";
|
||||
import type { ChatStreamTrace } from "../../services/api/chat";
|
||||
|
||||
const suggestions: { id: string; icon: ReactNode; suggestion: string }[] = [
|
||||
{
|
||||
|
|
@ -223,6 +229,7 @@ type ChatMessage = {
|
|||
role: "user" | "assistant" | "error";
|
||||
content: string;
|
||||
toolCalls?: string[];
|
||||
activity?: AssistantActivity[];
|
||||
};
|
||||
|
||||
type ChatProps = {
|
||||
|
|
@ -231,6 +238,12 @@ type ChatProps = {
|
|||
onPresentationChanged?: () => Promise<void> | void;
|
||||
};
|
||||
|
||||
/** One entry in the activity ("Trace") list attached to an assistant message. */
type AssistantActivity = {
  /** Unique id of the entry; used as the React key when the list is rendered. */
  id: string;
  /** Text shown for the entry. */
  label: string;
  /** Selects the icon rendered by ActivityIcon. */
  state: "running" | "success" | "error" | "info";
};
|
||||
|
||||
const createMessageId = () => {
|
||||
if (typeof crypto !== "undefined" && "randomUUID" in crypto) {
|
||||
return crypto.randomUUID();
|
||||
|
|
@ -246,6 +259,73 @@ const AssistantMarker = () => (
|
|||
</div>
|
||||
);
|
||||
|
||||
/**
 * Map a status label from the chat stream to an activity state.
 *
 * Labels containing "preparing", "thinking", "finalizing" or "saving"
 * (case-insensitive) are treated as in-progress; anything else is "info".
 */
const inferStatusState = (status: string): AssistantActivity["state"] => {
  const lowered = status.trim().toLowerCase();
  const inProgressKeywords = ["preparing", "thinking", "finalizing", "saving"];
  const isInProgress = inProgressKeywords.some((keyword) =>
    lowered.includes(keyword)
  );

  return isInProgress ? "running" : "info";
};
|
||||
|
||||
/**
 * Convert a stream trace into an activity entry (without an id).
 *
 * A non-blank `message` is used as the label, with the state derived from
 * `status` ("error" / "success", anything else counts as "running").
 * Without a message, a label is synthesised from `tool` + `status`, or from
 * the tool list of a "tool_plan" trace. Returns null when the trace offers
 * nothing to display.
 */
const formatTraceActivity = (
  trace: ChatStreamTrace
): Omit<AssistantActivity, "id"> | null => {
  const { message, status, tool, kind, tools } = trace;

  if (typeof message === "string" && message.trim().length > 0) {
    let state: AssistantActivity["state"] = "running";
    if (status === "error") {
      state = "error";
    } else if (status === "success") {
      state = "success";
    }
    return { label: message.trim(), state };
  }

  if (tool) {
    switch (status) {
      case "start":
        return { label: `Running ${tool}`, state: "running" };
      case "success":
        return { label: `${tool} completed`, state: "success" };
      case "error":
        return { label: `${tool} failed`, state: "error" };
    }
  }

  if (kind === "tool_plan" && Array.isArray(tools) && tools.length) {
    return { label: `Using tools: ${tools.join(", ")}`, state: "info" };
  }

  return null;
};
|
||||
|
||||
/** Small status icon for an activity entry: spinner, check, cross, or dot. */
const ActivityIcon = ({ state }: { state: AssistantActivity["state"] }) => {
  switch (state) {
    case "running":
      return <Loader2 className="h-3 w-3 animate-spin text-[#98A2B3]" />;
    case "success":
      return <CheckCircle2 className="h-3 w-3 text-[#12B76A]" />;
    case "error":
      return <XCircle className="h-3 w-3 text-[#F04438]" />;
    default:
      // Any remaining state ("info") gets the neutral dot.
      return <CircleDot className="h-3 w-3 text-[#98A2B3]" />;
  }
};
|
||||
|
||||
const Chat = ({
|
||||
presentationId,
|
||||
currentSlide,
|
||||
|
|
@ -253,8 +333,12 @@ const Chat = ({
|
|||
}: ChatProps) => {
|
||||
const [input, setInput] = useState("");
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [conversationId, setConversationId] = useState<string | null>(null);
|
||||
const [isSending, setIsSending] = useState(false);
|
||||
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||
const [expandedActivityByMessage, setExpandedActivityByMessage] = useState<
|
||||
Record<string, boolean>
|
||||
>({});
|
||||
|
||||
const inputRef = useRef<HTMLTextAreaElement | null>(null);
|
||||
const messagesEndRef = useRef<HTMLDivElement | null>(null);
|
||||
|
|
@ -262,7 +346,9 @@ const Chat = ({
|
|||
useEffect(() => {
|
||||
setMessages([]);
|
||||
setInput("");
|
||||
setConversationId(null);
|
||||
setErrorMessage(null);
|
||||
setExpandedActivityByMessage({});
|
||||
}, [presentationId]);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -283,7 +369,9 @@ const Chat = ({
|
|||
const resetChat = () => {
|
||||
setMessages([]);
|
||||
setInput("");
|
||||
setConversationId(null);
|
||||
setErrorMessage(null);
|
||||
setExpandedActivityByMessage({});
|
||||
|
||||
inputRef.current?.focus();
|
||||
};
|
||||
|
|
@ -301,6 +389,107 @@ const Chat = ({
|
|||
}
|
||||
};
|
||||
|
||||
/**
 * Record a new activity on the given assistant message.
 *
 * - Blank labels are ignored.
 * - An entry identical to the last one (same label and state) is a no-op.
 * - If the last entry is still "running" it is settled first: to "error"
 *   when the incoming state is "error", otherwise to "success".
 * - If the (settled) last entry has the same label but a different state,
 *   its state is replaced instead of appending a duplicate.
 * - Otherwise a new entry with a fresh id is appended.
 */
const appendAssistantActivity = (
  assistantMessageId: string,
  activity: Omit<AssistantActivity, "id">
) => {
  const label = activity.label.trim();
  if (!label) {
    return;
  }
  const incomingState = activity.state;

  const applyTo = (message: ChatMessage): ChatMessage => {
    const existing = message.activity ?? [];
    const tail = existing[existing.length - 1];

    if (tail && tail.label === label && tail.state === incomingState) {
      return message;
    }

    // Close out a still-running tail before recording the new activity.
    let settled: AssistantActivity[] = existing;
    if (tail && tail.state === "running") {
      const settledState: AssistantActivity["state"] =
        incomingState === "error" ? "error" : "success";
      settled = [...existing.slice(0, -1), { ...tail, state: settledState }];
    }

    const settledTail = settled[settled.length - 1];
    if (
      settledTail &&
      settledTail.label === label &&
      settledTail.state !== incomingState
    ) {
      return {
        ...message,
        activity: [
          ...settled.slice(0, -1),
          { ...settledTail, state: incomingState },
        ],
      };
    }

    return {
      ...message,
      activity: [
        ...settled,
        { id: createMessageId(), label, state: incomingState },
      ],
    };
  };

  setMessages((previous) =>
    previous.map((message) =>
      message.id === assistantMessageId ? applyTo(message) : message
    )
  );
};
|
||||
|
||||
/**
 * Flip every still-"running" activity of the given assistant message to
 * `finalState`. Messages with no activity, and all other messages, are
 * returned unchanged.
 */
const settleAssistantActivities = (
  assistantMessageId: string,
  finalState: "success" | "error"
) => {
  setMessages((previous) =>
    previous.map((message) => {
      const isTarget = message.id === assistantMessageId;
      if (!isTarget || !message.activity?.length) {
        return message;
      }

      const settled = message.activity.map((entry) => {
        if (entry.state !== "running") {
          return entry;
        }
        return { ...entry, state: finalState };
      });

      return { ...message, activity: settled };
    })
  );
};
|
||||
|
||||
/** Toggle whether the activity list of the given message is expanded. */
const toggleActivityExpanded = (messageId: string) => {
  setExpandedActivityByMessage((previous) => {
    const wasExpanded = previous[messageId];
    return { ...previous, [messageId]: !wasExpanded };
  });
};
|
||||
|
||||
const submitMessage = async (rawMessage: string) => {
|
||||
const trimmedMessage = rawMessage.trim();
|
||||
|
||||
|
|
@ -319,28 +508,81 @@ const Chat = ({
|
|||
content: trimmedMessage,
|
||||
};
|
||||
|
||||
setMessages((previous) => [...previous, userMessage]);
|
||||
const assistantMessageId = createMessageId();
|
||||
setMessages((previous) => [
|
||||
...previous,
|
||||
userMessage,
|
||||
{
|
||||
id: assistantMessageId,
|
||||
role: "assistant",
|
||||
content: "",
|
||||
toolCalls: [],
|
||||
activity: [],
|
||||
},
|
||||
]);
|
||||
setExpandedActivityByMessage((previous) => ({
|
||||
...previous,
|
||||
[assistantMessageId]: true,
|
||||
}));
|
||||
setInput("");
|
||||
setErrorMessage(null);
|
||||
setIsSending(true);
|
||||
|
||||
try {
|
||||
const response = await PresentationChatApi.sendMessage({
|
||||
presentation_id: presentationId,
|
||||
message: buildBackendMessage(trimmedMessage),
|
||||
});
|
||||
|
||||
setMessages((previous) => [
|
||||
...previous,
|
||||
const response = await PresentationChatApi.streamMessage(
|
||||
{
|
||||
id: createMessageId(),
|
||||
role: "assistant",
|
||||
content: response.response,
|
||||
toolCalls: Array.isArray(response.tool_calls)
|
||||
? response.tool_calls
|
||||
: [],
|
||||
presentation_id: presentationId,
|
||||
message: buildBackendMessage(trimmedMessage),
|
||||
conversation_id: conversationId ?? undefined,
|
||||
},
|
||||
]);
|
||||
{
|
||||
onChunk: (chunk) => {
|
||||
setMessages((previous) =>
|
||||
previous.map((message) =>
|
||||
message.id === assistantMessageId
|
||||
? {
|
||||
...message,
|
||||
content: `${message.content}${chunk}`,
|
||||
}
|
||||
: message
|
||||
)
|
||||
);
|
||||
},
|
||||
onStatus: (status) => {
|
||||
appendAssistantActivity(assistantMessageId, {
|
||||
label: status,
|
||||
state: inferStatusState(status),
|
||||
});
|
||||
},
|
||||
onTrace: (trace) => {
|
||||
const traceActivity = formatTraceActivity(trace);
|
||||
if (!traceActivity) {
|
||||
return;
|
||||
}
|
||||
appendAssistantActivity(assistantMessageId, traceActivity);
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
setMessages((previous) =>
|
||||
previous.map((message) =>
|
||||
message.id === assistantMessageId
|
||||
? {
|
||||
...message,
|
||||
content: response.response,
|
||||
toolCalls: Array.isArray(response.tool_calls)
|
||||
? response.tool_calls
|
||||
: [],
|
||||
}
|
||||
: message
|
||||
)
|
||||
);
|
||||
settleAssistantActivities(assistantMessageId, "success");
|
||||
setConversationId((previous) =>
|
||||
typeof response.conversation_id === "string"
|
||||
? response.conversation_id
|
||||
: previous
|
||||
);
|
||||
|
||||
await refreshPresentationIfNeeded(
|
||||
Array.isArray(response.tool_calls) ? response.tool_calls : []
|
||||
|
|
@ -349,6 +591,11 @@ const Chat = ({
|
|||
const message =
|
||||
error instanceof Error ? error.message : "Failed to send chat message";
|
||||
|
||||
settleAssistantActivities(assistantMessageId, "error");
|
||||
appendAssistantActivity(assistantMessageId, {
|
||||
label: message,
|
||||
state: "error",
|
||||
});
|
||||
setErrorMessage(message);
|
||||
setMessages((previous) => [
|
||||
...previous,
|
||||
|
|
@ -485,8 +732,58 @@ const Chat = ({
|
|||
: "text-[#535862]"
|
||||
}`}
|
||||
>
|
||||
{message.content}
|
||||
{message.content ||
|
||||
(isSending && message.role === "assistant"
|
||||
? message.activity?.[message.activity.length - 1]?.label ||
|
||||
"Working on it..."
|
||||
: "")}
|
||||
</div>
|
||||
{message.activity && message.activity.length > 0 && (
|
||||
<div className="mt-3 overflow-hidden rounded-[10px] border border-[#ECEEF2] bg-[#FAFAFB]">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleActivityExpanded(message.id)}
|
||||
className="flex w-full items-center justify-between px-3 py-2 text-left"
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<Activity className="h-3.5 w-3.5 text-[#667085]" />
|
||||
<span className="text-[11px] font-medium text-[#667085]">
|
||||
Trace
|
||||
</span>
|
||||
<span className="text-[11px] text-[#98A2B3]">
|
||||
{message.activity.length} events
|
||||
</span>
|
||||
</div>
|
||||
{expandedActivityByMessage[message.id] ? (
|
||||
<ChevronUp className="h-3.5 w-3.5 text-[#98A2B3]" />
|
||||
) : (
|
||||
<ChevronDown className="h-3.5 w-3.5 text-[#98A2B3]" />
|
||||
)}
|
||||
</button>
|
||||
{!expandedActivityByMessage[message.id] && (
|
||||
<div className="px-3 pb-2 text-xs leading-4 text-[#98A2B3]">
|
||||
{message.activity[message.activity.length - 1]?.label}
|
||||
</div>
|
||||
)}
|
||||
{expandedActivityByMessage[message.id] && (
|
||||
<div className="border-t border-[#ECEEF2] px-3 py-2">
|
||||
<div className="flex flex-col gap-1.5">
|
||||
{message.activity.map((activityItem) => (
|
||||
<div
|
||||
key={activityItem.id}
|
||||
className="flex items-start gap-2 text-xs leading-4 text-[#667085]"
|
||||
>
|
||||
<span className="mt-0.5 shrink-0">
|
||||
<ActivityIcon state={activityItem.state} />
|
||||
</span>
|
||||
<span>{activityItem.label}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{message.toolCalls && message.toolCalls.length > 0 && (
|
||||
<div className="mt-3 flex flex-wrap gap-1">
|
||||
{message.toolCalls.map((toolCall) => (
|
||||
|
|
@ -505,16 +802,6 @@ const Chat = ({
|
|||
</div>
|
||||
)}
|
||||
|
||||
{isSending && (
|
||||
<div className="mt-9 max-w-[92%]">
|
||||
<AssistantMarker />
|
||||
<div className="flex items-center gap-2 text-sm font-normal leading-5 text-[#A4A7AE]">
|
||||
<span>Got it. Let me analyze your slide</span>
|
||||
<Loader2 className="h-3.5 w-3.5 animate-spin text-[#D5D7DA]" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { getHeader } from "./header";
|
|||
export interface ChatMessageRequest {
|
||||
presentation_id: string;
|
||||
message: string;
|
||||
conversation_id?: string;
|
||||
}
|
||||
|
||||
export interface ChatMessageResponse {
|
||||
|
|
@ -13,6 +14,55 @@ export interface ChatMessageResponse {
|
|||
tool_calls?: string[];
|
||||
}
|
||||
|
||||
/**
 * Trace event received from the chat stream, after validation.
 * Every field is optional: a field is left undefined when the server
 * payload omits it or sends it with an unexpected type.
 */
export interface ChatStreamTrace {
  /** Category of the trace, e.g. "tool_plan". */
  kind?: string;
  /** Tool round the trace belongs to. */
  round?: number;
  /** Name of the tool the trace refers to. */
  tool?: string;
  /** Progress marker; "start", "success" and "error" are recognised by the UI. */
  status?: string;
  /** Human-readable description of the event. */
  message?: string;
  /** Tool names announced by a "tool_plan" trace. */
  tools?: string[];
}
|
||||
|
||||
/** Optional callbacks invoked by `streamMessage` as stream events arrive. */
export interface ChatStreamHandlers {
  /** Called with each non-empty piece of assistant text. */
  onChunk?: (chunk: string) => void;
  /** Called with each non-blank status label. */
  onStatus?: (status: string) => void;
  /** Called with each trace event. */
  onTrace?: (trace: ChatStreamTrace) => void;
  /** Called with the final response when a valid "complete" event arrives. */
  onComplete?: (response: ChatMessageResponse) => void;
}
|
||||
|
||||
// Shapes of the JSON carried in the `data:` field of a stream frame.
// Payload fields are typed `unknown` on purpose: they come off the wire
// and are validated where they are consumed.

/** A piece of assistant text. */
interface ChatStreamDataChunk {
  type: "chunk";
  chunk?: unknown;
}

/** Final event carrying the full chat response under `chat`. */
interface ChatStreamDataComplete {
  type: "complete";
  chat?: unknown;
}

/** Server-reported failure. */
interface ChatStreamDataError {
  type: "error";
  detail?: unknown;
}

/** Status label update. */
interface ChatStreamDataStatus {
  type: "status";
  status?: unknown;
}

/** Trace event. */
interface ChatStreamDataTrace {
  type: "trace";
  trace?: unknown;
}

/** Any parsed frame payload; the Record arm admits unrecognised shapes. */
type ChatStreamData =
  | ChatStreamDataChunk
  | ChatStreamDataComplete
  | ChatStreamDataError
  | ChatStreamDataStatus
  | ChatStreamDataTrace
  | Record<string, unknown>;
|
||||
|
||||
export class PresentationChatApi {
|
||||
static async sendMessage(
|
||||
payload: ChatMessageRequest
|
||||
|
|
@ -29,4 +79,170 @@ export class PresentationChatApi {
|
|||
"Failed to send chat message"
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Send a chat message and consume the server's SSE reply stream.
 *
 * Frames are separated by a blank line. Only frames whose `event:` field is
 * absent or equal to "response" are processed; their `data:` lines are
 * joined and parsed as JSON, then dispatched on the payload's `type`
 * ("chunk", "status", "trace", "complete", "error"). Frames that fail to
 * parse are ignored.
 *
 * @param payload  Request body, sent as JSON.
 * @param handlers Optional callbacks invoked as events arrive.
 * @param options  Optional `signal` to abort the request.
 * @returns The response carried by the "complete" event.
 * @throws Error when the HTTP request fails, when the response has no body,
 *         when the stream carries an "error" event (its `detail` becomes the
 *         message), or when the stream ends without a "complete" event.
 */
static async streamMessage(
  payload: ChatMessageRequest,
  handlers: ChatStreamHandlers = {},
  options?: { signal?: AbortSignal }
): Promise<ChatMessageResponse> {
  const response = await fetch(getApiUrl("/api/v1/ppt/chat/message/stream"), {
    method: "POST",
    headers: getHeader(),
    body: JSON.stringify(payload),
    cache: "no-cache",
    signal: options?.signal,
  });

  if (!response.ok) {
    await ApiResponseHandler.handleResponse(
      response,
      "Failed to stream chat message"
    );
    throw new Error("Failed to stream chat message");
  }

  if (!response.body) {
    throw new Error("No response body received from chat stream");
  }

  const reader = response.body.getReader();
  const decoder = new TextDecoder("utf-8");
  let buffer = "";
  let finalResponse: ChatMessageResponse | null = null;

  const processSseFrame = (frame: string) => {
    const lines = frame.replaceAll("\r", "").split("\n");
    let eventName = "";
    const dataLines: string[] = [];

    for (const line of lines) {
      if (line.startsWith("event:")) {
        eventName = line.slice(6).trim();
        continue;
      }
      if (line.startsWith("data:")) {
        dataLines.push(line.slice(5).trimStart());
      }
    }

    if (eventName && eventName !== "response") {
      return;
    }
    if (!dataLines.length) {
      return;
    }

    let parsedData: ChatStreamData;
    try {
      parsedData = JSON.parse(dataLines.join("\n")) as ChatStreamData;
    } catch {
      // Best effort: skip frames whose data is not valid JSON.
      return;
    }

    const payloadType = parsedData.type;

    if (payloadType === "chunk") {
      const chunk = parsedData.chunk;
      if (typeof chunk === "string" && chunk.length > 0) {
        handlers.onChunk?.(chunk);
      }
      return;
    }

    if (payloadType === "complete") {
      const chatPayload = (parsedData as ChatStreamDataComplete).chat;
      if (chatPayload && typeof chatPayload === "object") {
        const chat = chatPayload as Record<string, unknown>;
        if (typeof chat.response === "string") {
          const typedResponse: ChatMessageResponse = {
            conversation_id:
              typeof chat.conversation_id === "string"
                ? chat.conversation_id
                : undefined,
            response: chat.response,
            tool_calls: Array.isArray(chat.tool_calls)
              ? chat.tool_calls.filter(
                  (item): item is string => typeof item === "string"
                )
              : [],
          };
          finalResponse = typedResponse;
          handlers.onComplete?.(typedResponse);
        }
      }
      return;
    }

    if (payloadType === "error") {
      const detail = (parsedData as ChatStreamDataError).detail;
      const message =
        typeof detail === "string" && detail.trim().length > 0
          ? detail
          : "Chat stream failed";
      throw new Error(message);
    }

    if (payloadType === "status") {
      const status = (parsedData as ChatStreamDataStatus).status;
      if (typeof status === "string" && status.trim().length > 0) {
        handlers.onStatus?.(status);
      }
      return;
    }

    if (payloadType === "trace") {
      const trace = (parsedData as ChatStreamDataTrace).trace;
      if (trace && typeof trace === "object") {
        const typedTrace = trace as Record<string, unknown>;
        handlers.onTrace?.({
          kind:
            typeof typedTrace.kind === "string" ? typedTrace.kind : undefined,
          round:
            typeof typedTrace.round === "number" ? typedTrace.round : undefined,
          tool:
            typeof typedTrace.tool === "string" ? typedTrace.tool : undefined,
          status:
            typeof typedTrace.status === "string"
              ? typedTrace.status
              : undefined,
          message:
            typeof typedTrace.message === "string"
              ? typedTrace.message
              : undefined,
          tools: Array.isArray(typedTrace.tools)
            ? typedTrace.tools.filter(
                (value): value is string => typeof value === "string"
              )
            : undefined,
        });
      }
    }
  };

  // Process every complete (blank-line terminated) frame in the buffer.
  const drainFrames = () => {
    let delimiterIndex = buffer.indexOf("\n\n");
    while (delimiterIndex >= 0) {
      const frame = buffer.slice(0, delimiterIndex);
      buffer = buffer.slice(delimiterIndex + 2);
      processSseFrame(frame);
      delimiterIndex = buffer.indexOf("\n\n");
    }
  };

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      // Drop CRs before looking for the delimiter so that frames terminated
      // by "\r\n\r\n" are recognised too.
      buffer += decoder.decode(value, { stream: true }).replaceAll("\r", "");
      drainFrames();
    }

    // Flush any bytes the decoder is still holding from a split character.
    buffer += decoder.decode().replaceAll("\r", "");
    drainFrames();

    if (buffer.trim().length > 0) {
      processSseFrame(buffer);
    }
  } catch (error) {
    // Stop the transfer when a handler or an "error" event aborts
    // processing; a failure to cancel must not mask the original error.
    await reader.cancel().catch(() => undefined);
    throw error;
  } finally {
    reader.releaseLock();
  }

  if (finalResponse) {
    return finalResponse;
  }

  throw new Error("Chat stream ended before completion");
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue