Merge pull request #545 from presenton/refactor/presenton-chat-stream

Refactor/presenton chat stream
This commit is contained in:
Sudip Parajuli 2026-04-27 19:54:21 +05:45 committed by GitHub
commit cb731aa6c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 5181 additions and 618 deletions

View file

@ -1,6 +1,7 @@
#!/usr/bin/env node
/**
* CLI bridge for Python: one JSON line on stdout for LiteParse extraction.
* CLI bridge for Python: with --python-bridge plain, raw extracted text on stdout;
* with --python-bridge json, or when the flag is omitted, one JSON line (backward compatible).
*
* OCR follows LlamaIndex LiteParse guidance (built-in Tesseract by default):
* https://developers.llamaindex.ai/liteparse/guides/ocr/
@ -56,14 +57,31 @@ function emit(result, exitCode = 0) {
process.exit(exitCode);
}
/**
 * Output mode requested through --python-bridge.
 * "plain": on success, only the extracted UTF-8 text is written to stdout.
 * "json": one JSON line (legacy, huge payloads can break).
 * A missing, empty, or unrecognised value selects "json".
 */
const pyBridgeArg = readArg("--python-bridge");
const pyBridge =
  pyBridgeArg != null && String(pyBridgeArg).trim().toLowerCase() === "plain"
    ? "plain"
    : "json";
/**
 * Report a fatal error in the shape the active bridge mode expects and exit.
 * In "plain" mode the message is written to stderr and the process exits with
 * `exitCode`; otherwise the error is handed to emit() as an { ok: false } payload.
 */
function bridgeError(message, exitCode) {
  if (pyBridge !== "plain") {
    emit({ ok: false, error: message }, exitCode);
    return;
  }
  process.stderr.write(`${message}\n`);
  process.exit(exitCode);
}
const filePath = readArg("--file");
if (!filePath) {
emit({ ok: false, error: "Missing required --file argument" }, 2);
bridgeError("Missing required --file argument", 2);
}
const resolvedPath = path.resolve(filePath);
if (!fs.existsSync(resolvedPath)) {
emit({ ok: false, error: `File not found: ${resolvedPath}` }, 2);
bridgeError(`File not found: ${resolvedPath}`, 2);
}
const ocrEnabled = parseBool(readArg("--ocr-enabled"), true);
@ -117,6 +135,10 @@ try {
const result = await parser.parse(resolvedPath, true);
const text = result?.text ?? "";
if (pyBridge === "plain") {
process.stdout.write(text);
process.exit(0);
}
emit({
ok: true,
filePath: resolvedPath,
@ -133,6 +155,13 @@ try {
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
const stack = error instanceof Error ? error.stack : undefined;
if (pyBridge === "plain") {
if (stack) {
process.stderr.write(`${stack}\n`);
}
process.stderr.write(`${message}\n`);
process.exit(1);
}
if (stack) {
process.stderr.write(`${stack}\n`);
}

View file

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import async_session_maker
from utils.ollama import pull_ollama_model
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
)
log_event_count = 0
session = await get_container_db_async_session().__anext__()
async with async_session_maker() as session:
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
return
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
return
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if "status" in event:
saved_model_status.status = event["status"]
if "status" in event:
saved_model_status.status = event["status"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await upsert_ollama_pull_status(session, model, saved_model_status)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
await upsert_ollama_pull_status(session, model, saved_model_status)
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
saved_model_status.done = True
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
async def upsert_ollama_pull_status(

View file

@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from models.ollama_model_metadata import OllamaModelMetadata
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import get_async_session
from utils.ollama import list_pulled_ollama_models
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
@ -29,7 +29,7 @@ async def get_available_models():
async def pull_model(
model: str,
background_tasks: BackgroundTasks,
session: AsyncSession = Depends(get_container_db_async_session),
session: AsyncSession = Depends(get_async_session),
):
if model not in SUPPORTED_OLLAMA_MODELS:

View file

@ -54,7 +54,6 @@ def main() -> None:
p = _sqlite_file_path(sync_url)
if p is not None:
paths.append(p)
paths.append(p.parent / "container.db")
seen: set[Path] = set()
for path in paths:

View file

@ -1,5 +1,4 @@
from collections.abc import AsyncGenerator
import os
from sqlalchemy.ext.asyncio import (
AsyncEngine,
create_async_engine,
@ -21,7 +20,6 @@ from models.sql.template_create_info import TemplateCreateInfoModel
from models.sql.slide import SlideModel
from models.sql.webhook_subscription import WebhookSubscription
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
from utils.get_env import get_app_data_directory_env
from utils.get_env import get_migrate_database_on_startup_env
@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
yield session
# Container DB (Lives inside the app data directory)
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
container_db_engine: AsyncEngine = create_async_engine(
container_db_url, connect_args={"check_same_thread": False}
)
container_db_async_session_maker = async_sessionmaker(
container_db_engine, expire_on_commit=False
)
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
async with container_db_async_session_maker() as session:
yield session
# Create Database and Tables
async def create_db_and_tables():
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
@ -76,18 +58,11 @@ async def create_db_and_tables():
TemplateModel.__table__,
WebhookSubscription.__table__,
AsyncPresentationGenerationTaskModel.__table__,
OllamaPullStatus.__table__,
],
)
)
async with container_db_engine.begin() as conn:
await conn.run_sync(
lambda sync_conn: SQLModel.metadata.create_all(
sync_conn,
tables=[OllamaPullStatus.__table__],
)
)
async def dispose_engines():
"""Dispose all engine connection pools.
@ -97,4 +72,3 @@ async def dispose_engines():
database and prevent stale / leaked connections.
"""
await sql_engine.dispose()
await container_db_engine.dispose()

View file

@ -1,6 +1,8 @@
import asyncio
import json
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any, List, Optional, Tuple
@ -30,6 +32,129 @@ except Exception:
LOGGER = logging.getLogger(__name__)
def _unwrap_liteparse_json_line_if_stored(text: str) -> str:
"""If the whole JSON line from the LiteParse runner was stored as the document, keep only the text field."""
if not text:
return text
s = text.lstrip()
if not s.startswith("{"):
return text
try:
payload = json.loads(s)
except (json.JSONDecodeError, TypeError, ValueError):
return text
if not isinstance(payload, dict):
return text
if (
payload.get("ok") is True
and "filePath" in payload
and isinstance(payload.get("text"), str)
):
return payload["text"]
return text
_RE_TEXT_KEY = re.compile(r'"text"\s*:\s*"')
def _json_unescape_quoted_value(s: str, content_start: int) -> str:
"""
Unescape a JSON string value. `content_start` is the index of the first character
*inside* the value (immediately after the opening quote of the "text" field).
If the closing quote is missing (truncated), returns the unescaped rest of the string.
"""
out: list[str] = []
i = content_start
n = len(s)
while i < n:
c = s[i]
if c == "\\" and i + 1 < n:
e = s[i + 1]
if e in '"\\':
out.append(e)
i += 2
elif e == "/":
out.append("/")
i += 2
elif e == "b":
out.append("\b")
i += 2
elif e == "f":
out.append("\f")
i += 2
elif e == "n":
out.append("\n")
i += 2
elif e == "r":
out.append("\r")
i += 2
elif e == "t":
out.append("\t")
i += 2
elif e == "u" and i + 5 < n:
try:
out.append(chr(int(s[i + 2 : i + 6], 16)))
except (ValueError, OverflowError):
out.append(s[i : i + 6])
i += 6
else:
out.append(e)
i += 2
elif c == '"':
return "".join(out)
else:
out.append(c)
i += 1
return "".join(out)
def _try_extract_liteparse_text_value_from_malformed_json(s: str) -> Optional[str]:
    """Recover the "text" value from a LiteParse-shaped object that failed to parse.

    Returns ``None`` unless ``s`` starts with ``{``, its first 10000 characters
    mention both ``ok`` and ``filePath``, and a ``"text": "`` key is found.
    Otherwise returns the unescaped string body that follows that key.
    """
    if not s.startswith("{"):
        return None
    prefix = s[:10000]
    if "ok" not in prefix or "filePath" not in prefix:
        return None
    found = _RE_TEXT_KEY.search(s)
    if found is None:
        return None
    return _json_unescape_quoted_value(s, found.end())
def _clean_extracted_one_pass(t: str) -> str:
    """Run one cleaning pass over ``t``.

    Unwraps a well-formed LiteParse JSON line up to three times, then, if the
    result still starts with ``{``, tries to salvage the "text" value from
    malformed JSON. Returns the salvaged value when there is one, otherwise the
    unwrapped text.
    """
    current = t
    attempts_left = 3
    while attempts_left:
        attempts_left -= 1
        unwrapped = _unwrap_liteparse_json_line_if_stored(current)
        if unwrapped == current:
            break
        current = unwrapped
    stripped = current.lstrip()
    recovered = (
        _try_extract_liteparse_text_value_from_malformed_json(stripped)
        if stripped.startswith("{")
        else None
    )
    return current if recovered is None else recovered
def clean_extracted_document_text(text: str) -> str:
    """Reduce extracted text to the document body.

    Applies the single-pass cleaner up to four times, stopping early once a
    pass changes nothing, so a body that is itself JSON-shaped is unwrapped
    too. Empty input is returned as-is.
    """
    if not text:
        return text
    cleaned = text
    passes_left = 4
    while passes_left > 0:
        candidate = _clean_extracted_one_pass(cleaned)
        if candidate == cleaned:
            break
        cleaned = candidate
        passes_left -= 1
    return cleaned
class DocumentsLoader:
DECOMPOSE_TIMEOUT_SECONDS = 600
@ -107,6 +232,7 @@ class DocumentsLoader:
else:
document = await asyncio.to_thread(self._parse_with_liteparse, file_path)
document = clean_extracted_document_text(document)
documents.append(document)
images.append(imgs)

View file

@ -193,6 +193,11 @@ class LiteParseService:
return True, "ok"
@staticmethod
def _use_json_runner_output() -> bool:
"""If true, expect one JSON line on stdout (legacy). Default is plain UTF-8 text (better for large PDFs)."""
return (os.getenv("LITEPARSE_RUNNER_OUTPUT") or "").strip().lower() == "json"
def parse_to_markdown(
self,
file_path: str,
@ -233,6 +238,9 @@ class LiteParseService:
if tessdata:
command.extend(["--tessdata-path", tessdata])
use_json = self._use_json_runner_output()
command.extend(["--python-bridge", "json" if use_json else "plain"])
LOGGER.info(
"[LiteParse] Parsing file=%s ocr_enabled=%s ocr_language=%s",
file_path,
@ -254,6 +262,20 @@ class LiteParseService:
_command_str(command),
)
if not use_json:
if process.returncode != 0:
err = (process.stderr or "").strip() or "LiteParse failed"
raise LiteParseError(
f"{err}; returncode={process.returncode}; "
f"stderr={_snippet(process.stderr)}; stdout={_snippet(process.stdout)}"
)
return {
"ok": True,
"text": (process.stdout or "").lstrip("\ufeff"),
"filePath": file_path,
"pageCount": 0,
}
payload: Dict[str, Any]
try:
payload = self._decode_runner_output(process.stdout)

View file

@ -46,11 +46,7 @@ const PresentationPage = ({ presentation_id }: { presentation_id: string }) => {
}
}, [presentationData]);
// Ensure /app_data and /static image paths resolve through FastAPI in Electron.
useEffect(() => {
const observer = setupImageUrlConverter();
return () => observer?.disconnect();
}, []);
// Function to fetch the slides
useEffect(() => {

View file

@ -63,7 +63,8 @@ const page = () => {
<p className="text-xl font-syne text-[#101323CC]">Turn prompts or documents into presentations with AI</p>
</div>
<div
className='fixed z-0 -bottom-[14.5rem] left-0 w-full h-full'
className="fixed z-0 -bottom-[14.5rem] left-0 w-full h-full pointer-events-none"
aria-hidden
style={{
height: "341px",
borderRadius: '1440px',

View file

@ -1,5 +0,0 @@
{
"description": "Dark, theme-ready presentation layouts with covers, structured content grids, timelines, narrative splits, and chart-driven slides",
"ordered": false,
"default": false
}

View file

@ -15,6 +15,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from models.sql.async_presentation_generation_status import ( # noqa: F401, E402
AsyncPresentationGenerationTaskModel,
)
from models.sql.chat_history_message import ChatHistoryMessageModel # noqa: F401, E402
from models.sql.image_asset import ImageAsset # noqa: F401, E402
from models.sql.key_value import KeyValueSqlModel # noqa: F401, E402
from models.sql.ollama_pull_status import OllamaPullStatus # noqa: F401, E402

View file

@ -0,0 +1,48 @@
"""added_chat_history_messages_table
Revision ID: c7b70d0f31b1
Revises: 95b5127e93cd
Create Date: 2026-04-26 12:29:49.508761
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'c7b70d0f31b1'
down_revision: Union[str, None] = '95b5127e93cd'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
# Apply revision c7b70d0f31b1: create the chat_history_messages table and its indexes.
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Every column except tool_calls is NOT NULL; rows are keyed by the id UUID.
op.create_table('chat_history_messages',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('presentation_id', sa.Uuid(), nullable=False),
sa.Column('conversation_id', sa.Uuid(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('tool_calls', sa.JSON(), nullable=True),
# Deleting a presentation cascades to its chat messages.
sa.ForeignKeyConstraint(['presentation_id'], ['presentations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
# Non-unique lookup indexes on conversation_id, position and presentation_id.
op.create_index(op.f('ix_chat_history_messages_conversation_id'), 'chat_history_messages', ['conversation_id'], unique=False)
op.create_index(op.f('ix_chat_history_messages_position'), 'chat_history_messages', ['position'], unique=False)
op.create_index(op.f('ix_chat_history_messages_presentation_id'), 'chat_history_messages', ['presentation_id'], unique=False)
# ### end Alembic commands ###
# Revert revision c7b70d0f31b1: drop the three indexes, then the chat_history_messages table.
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Indexes are dropped before the table they belong to.
op.drop_index(op.f('ix_chat_history_messages_presentation_id'), table_name='chat_history_messages')
op.drop_index(op.f('ix_chat_history_messages_position'), table_name='chat_history_messages')
op.drop_index(op.f('ix_chat_history_messages_conversation_id'), table_name='chat_history_messages')
op.drop_table('chat_history_messages')
# ### end Alembic commands ###

View file

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import async_session_maker
from utils.ollama import pull_ollama_model
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
)
log_event_count = 0
session = await get_container_db_async_session().__anext__()
async with async_session_maker() as session:
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
return
try:
async for event in pull_ollama_model(model):
if "error" in event:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = event["error"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
return
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if "status" in event:
saved_model_status.status = event["status"]
if "status" in event:
saved_model_status.status = event["status"]
await upsert_ollama_pull_status(session, model, saved_model_status)
await upsert_ollama_pull_status(session, model, saved_model_status)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
await upsert_ollama_pull_status(session, model, saved_model_status)
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
saved_model_status.error = str(e)
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
saved_model_status.done = True
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
saved_model_status.error = None
await upsert_ollama_pull_status(session, model, saved_model_status)
await session.close()
async def upsert_ollama_pull_status(

View file

@ -0,0 +1,129 @@
import json
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from models.chat import (
ChatConversationListItem,
ChatHistoryMessageItem,
ChatHistoryResponse,
ChatMessageRequest,
ChatMessageResponse,
)
from models.sse_response import (
SSECompleteResponse,
SSEErrorResponse,
SSEStatusResponse,
SSETraceResponse,
SSEResponse,
)
from services.chat import ChatTurnResult, PresentationChatService
from services.chat import sql_chat_history
from services.database import get_async_session
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
@CHAT_ROUTER.get("/conversations", response_model=list[ChatConversationListItem])
async def list_chat_conversations(
    presentation_id: uuid.UUID = Query(..., description="Presentation id"),
    sql_session: AsyncSession = Depends(get_async_session),
):
    # List the chat conversations stored for one presentation.
    # Each stored entry is mapped onto a ChatConversationListItem; the
    # conversation id is normalised through str() before being parsed as a UUID.
    conversations = await sql_chat_history.list_conversations(
        sql_session, presentation_id=presentation_id
    )
    response_items = []
    for entry in conversations:
        response_items.append(
            ChatConversationListItem(
                conversation_id=uuid.UUID(str(entry["conversation_id"])),
                updated_at=entry.get("updated_at"),
                last_message_preview=entry.get("last_message_preview"),
            )
        )
    return response_items
@CHAT_ROUTER.get("/history", response_model=ChatHistoryResponse)
async def get_chat_history(
    presentation_id: uuid.UUID = Query(..., description="Presentation id"),
    conversation_id: uuid.UUID = Query(..., description="Conversation thread id"),
    sql_session: AsyncSession = Depends(get_async_session),
):
    # Return the stored messages of one conversation thread.
    # Missing role/content values become empty strings, and created_at is only
    # forwarded when the stored value is already a string.
    stored = await sql_chat_history.load_messages_with_meta(
        sql_session,
        presentation_id=presentation_id,
        conversation_id=conversation_id,
    )
    history_items = []
    for entry in stored:
        stamp = entry.get("created_at")
        history_items.append(
            ChatHistoryMessageItem(
                role=str(entry.get("role") or ""),
                content=str(entry.get("content") or ""),
                created_at=stamp if isinstance(stamp, str) else None,
            )
        )
    return ChatHistoryResponse(
        presentation_id=presentation_id,
        conversation_id=conversation_id,
        messages=history_items,
    )
@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse)
async def chat_message(
    payload: ChatMessageRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    # Non-streaming chat turn: build a chat service for the requested
    # presentation/conversation, generate one reply, and return it as a whole.
    chat_service = PresentationChatService(
        sql_session=sql_session,
        presentation_id=payload.presentation_id,
        conversation_id=payload.conversation_id,
    )
    turn = await chat_service.generate_reply(payload.message)
    reply = ChatMessageResponse(
        conversation_id=turn.conversation_id,
        response=turn.response_text,
        tool_calls=turn.tool_calls,
    )
    return reply
@CHAT_ROUTER.post("/message/stream")
async def chat_message_stream(
    payload: ChatMessageRequest,
    sql_session: AsyncSession = Depends(get_async_session),
):
    # Streaming chat turn delivered as server-sent events.
    # The service yields (event_type, value) pairs which are translated to SSE
    # frames: "chunk" -> response chunk, "status" -> status frame,
    # "trace" -> trace frame, "complete" -> final ChatMessageResponse payload.
    # Pairs with an unexpected type/value combination are ignored.
    service = PresentationChatService(
        sql_session=sql_session,
        presentation_id=payload.presentation_id,
        conversation_id=payload.conversation_id,
    )

    async def inner():
        try:
            async for event_type, value in service.stream_reply(payload.message):
                if event_type == "chunk" and isinstance(value, str):
                    yield SSEResponse(
                        event="response",
                        data=json.dumps({"type": "chunk", "chunk": value}),
                    ).to_string()
                elif event_type == "status" and isinstance(value, str):
                    yield SSEStatusResponse(status=value).to_string()
                elif event_type == "trace" and isinstance(value, dict):
                    yield SSETraceResponse(trace=value).to_string()
                elif event_type == "complete" and isinstance(value, ChatTurnResult):
                    result = value
                    complete_payload = ChatMessageResponse(
                        conversation_id=result.conversation_id,
                        response=result.response_text,
                        tool_calls=result.tool_calls,
                    )
                    yield SSECompleteResponse(
                        key="chat",
                        value=complete_payload.model_dump(mode="json"),
                    ).to_string()
        except HTTPException as exc:
            # HTTPException.detail may be any JSON-compatible value, while the
            # SSE error model declares a str field; serialise non-string details
            # so building the error frame cannot itself fail validation.
            detail = exc.detail
            if not isinstance(detail, str):
                detail = json.dumps(detail, default=str)
            yield SSEErrorResponse(detail=detail).to_string()

    return StreamingResponse(inner(), media_type="text/event-stream")

View file

@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from models.ollama_model_metadata import OllamaModelMetadata
from models.ollama_model_status import OllamaModelStatus
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from services.database import get_async_session
from utils.ollama import list_pulled_ollama_models
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
@ -29,7 +29,7 @@ async def get_available_models():
async def pull_model(
model: str,
background_tasks: BackgroundTasks,
session: AsyncSession = Depends(get_container_db_async_session),
session: AsyncSession = Depends(get_async_session),
):
if model not in SUPPORTED_OLLAMA_MODELS:

View file

@ -15,6 +15,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
from api.v1.ppt.endpoints.chat import CHAT_ROUTER
from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER
from api.v1.ppt.endpoints.theme import THEMES_ROUTER
from api.v1.ppt.endpoints.theme_generate import THEME_ROUTER
@ -29,6 +30,7 @@ API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
API_V1_PPT_ROUTER.include_router(PPTX_SLIDES_ROUTER)
API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
API_V1_PPT_ROUTER.include_router(CHAT_ROUTER)
API_V1_PPT_ROUTER.include_router(SLIDE_TO_HTML_ROUTER)
API_V1_PPT_ROUTER.include_router(HTML_TO_REACT_ROUTER)
API_V1_PPT_ROUTER.include_router(HTML_EDIT_ROUTER)

View file

@ -0,0 +1,44 @@
import uuid
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
# Request body accepted by the chat message endpoints.
class ChatMessageRequest(BaseModel):
# Presentation the chat turn is about.
presentation_id: uuid.UUID
# User message; must be between 1 and 8000 characters long.
message: str = Field(min_length=1, max_length=8000)
# Optional conversation thread id; None when the caller does not supply one.
conversation_id: Optional[uuid.UUID] = None
# Fields that are not declared on this model are rejected.
model_config = ConfigDict(extra="forbid")
# Result of one chat turn as returned to the client.
class ChatMessageResponse(BaseModel):
# Conversation thread the reply belongs to.
conversation_id: uuid.UUID
# Reply text.
response: str
# Tool call entries recorded for the turn; defaults to a new empty list.
tool_calls: list[str] = Field(default_factory=list)
# Fields that are not declared on this model are rejected.
model_config = ConfigDict(extra="forbid")
# One message entry inside a chat history response.
class ChatHistoryMessageItem(BaseModel):
role: str
content: str
# Creation timestamp as a string, or None when not available.
created_at: Optional[str] = None
# Fields that are not declared on this model are rejected.
model_config = ConfigDict(extra="forbid")
# Message history of one conversation thread of a presentation.
class ChatHistoryResponse(BaseModel):
presentation_id: uuid.UUID
conversation_id: uuid.UUID
# Messages of the thread, in the order supplied by the caller that builds the response.
messages: list[ChatHistoryMessageItem]
# Fields that are not declared on this model are rejected.
model_config = ConfigDict(extra="forbid")
# Summary entry describing one conversation in a conversation listing.
class ChatConversationListItem(BaseModel):
conversation_id: uuid.UUID
# Last-update timestamp as a string, or None when not available.
updated_at: Optional[str] = None
# Preview text of the last message, or None when not available.
last_message_preview: Optional[str] = None
# Fields that are not declared on this model are rejected.
model_config = ConfigDict(extra="forbid")

View file

@ -0,0 +1,31 @@
from datetime import datetime
from typing import Optional
import uuid
from sqlalchemy import JSON, Column, DateTime, ForeignKey, Text
from sqlmodel import Field, SQLModel
from utils.datetime_utils import get_current_utc_datetime
# SQLModel table for persisted chat messages (table "chat_history_messages").
class ChatHistoryMessageModel(SQLModel, table=True):
__tablename__ = "chat_history_messages"
# Primary key; a random UUID is generated when none is supplied.
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
# Owning presentation: indexed, NOT NULL, and deleted together with the
# presentation row (ON DELETE CASCADE).
presentation_id: uuid.UUID = Field(
sa_column=Column(
ForeignKey("presentations.id", ondelete="CASCADE"),
index=True,
nullable=False,
)
)
# Conversation thread the message belongs to (indexed).
conversation_id: uuid.UUID = Field(index=True)
# Position value for the message, constrained to >= 1 (indexed).
position: int = Field(index=True, ge=1)
role: str
# Message body stored in a TEXT column, NOT NULL.
content: str = Field(sa_column=Column(Text, nullable=False))
# Timezone-aware timestamp; the column default calls get_current_utc_datetime.
created_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True), nullable=False, default=get_current_utc_datetime
)
)
# Optional list of tool call entries stored as JSON; None when absent.
tool_calls: Optional[list[str]] = Field(sa_column=Column(JSON), default=None)

View file

@ -20,6 +20,15 @@ class SSEStatusResponse(BaseModel):
).to_string()
class SSETraceResponse(BaseModel):
    # SSE frame carrying a trace payload for the client.
    trace: object

    def to_string(self):
        # Wrap the trace in a {"type": "trace"} JSON body and send it as a
        # "response" event through the shared SSEResponse formatter.
        body = json.dumps({"type": "trace", "trace": self.trace})
        envelope = SSEResponse(event="response", data=body)
        return envelope.to_string()
class SSEErrorResponse(BaseModel):
detail: str

View file

@ -24,7 +24,8 @@ dependencies = [
"pathvalidate>=3.3.1",
"pdfplumber>=0.11.7",
"sqlmodel>=0.0.24",
"llmai==0.1.9",
"llmai==0.2.2",
"jsonschema>=4.26.0",
]
[tool.uv]

View file

@ -0,0 +1,8 @@
from services.chat.service import ChatTurnResult, PresentationChatService
from services.chat.presentation_context_store import PresentationContextStore
__all__ = [
"ChatTurnResult",
"PresentationChatService",
"PresentationContextStore",
]

View file

@ -0,0 +1,324 @@
import asyncio
from datetime import datetime, timezone
import logging
import os
from typing import Any, Optional
from uuid import UUID
from services.mem0_oss_memory import get_shared_mem0_client
LOGGER = logging.getLogger(__name__)
CHAT_TURN_TAG = "[chat_turn]"
DEFAULT_MAX_STORED_TURNS = 20
class ChatMemoryStore:
def __init__(self):
    """Read the store configuration from environment variables.

    ``MEM0_ENABLED`` (default on), ``MEM0_TOP_K`` (default 8),
    ``MEM0_MAX_CONTEXT_CHARS`` (default 6000) and ``CHAT_MAX_STORED_TURNS``
    are parsed with the class helpers. The namespace prefix comes from
    ``MEM0_CHAT_NAMESPACE_PREFIX``, then ``MEM0_PRESENTATION_NAMESPACE_PREFIX``,
    and falls back to ``"presentation"`` when unset or blank.
    """
    env = os.environ.get
    self._enabled = self._to_bool(env("MEM0_ENABLED"), default=True)
    self._top_k = self._to_int(env("MEM0_TOP_K"), default=8)
    self._max_context_chars = self._to_int(env("MEM0_MAX_CONTEXT_CHARS"), default=6000)
    self._max_stored_turns = self._to_int(
        env("CHAT_MAX_STORED_TURNS"), default=DEFAULT_MAX_STORED_TURNS
    )
    raw_prefix = env("MEM0_CHAT_NAMESPACE_PREFIX")
    if not raw_prefix:
        raw_prefix = env("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation"
    self._namespace_prefix = raw_prefix.strip() or "presentation"
@staticmethod
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
return default
return str(value).strip().lower() in {"1", "true", "yes", "on"}
@staticmethod
def _to_int(value: Optional[str], default: int) -> int:
try:
parsed = int(value) if value is not None else default
return max(1, parsed)
except Exception:
return default
@staticmethod
def _normalize(value: str) -> str:
return " ".join((value or "").split())
@staticmethod
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
return isinstance(exc, (Exception, SystemExit))
def _scope_user_id(self, presentation_id: UUID, conversation_id: UUID) -> str:
return (
f"{self._namespace_prefix}:{presentation_id}:"
f"conversation:{conversation_id}"
)
def _truncate(self, text: str, limit: int = 20000) -> str:
if len(text) <= limit:
return text
return f"{text[:limit]}\n\n[TRUNCATED]"
async def _get_client(self):
    """Return the shared mem0 client, or ``None`` when the store is disabled."""
    return get_shared_mem0_client() if self._enabled else None
def _build_turn_payload(self, *, user_text: str, assistant_text: str) -> str:
    """Serialise one chat turn as newline-separated ``key=value`` lines.

    The payload starts with the chat-turn tag and a UTC ``turn_created_at``
    timestamp; ``user=`` and ``assistant=`` lines are added only for non-empty
    texts.
    """
    stamp = datetime.now(timezone.utc).isoformat()
    lines = [CHAT_TURN_TAG, f"turn_created_at={stamp}"]
    optional_lines = (
        (user_text, f"user={user_text}"),
        (assistant_text, f"assistant={assistant_text}"),
    )
    lines.extend(line for present, line in optional_lines if present)
    return "\n".join(lines)
@staticmethod
def _extract_text_field(item: dict[str, Any]) -> str:
memory_text = item.get("memory") or item.get("text") or item.get("data")
return str(memory_text).strip() if memory_text is not None else ""
def _collect_results(self, response: Any) -> list[dict[str, Any]]:
if isinstance(response, dict):
raw_results = (
response.get("results")
or response.get("memories")
or response.get("items")
or []
)
if isinstance(raw_results, list):
return [item for item in raw_results if isinstance(item, dict)]
return []
if isinstance(response, list):
return [item for item in response if isinstance(item, dict)]
return []
@staticmethod
def _safe_parse_datetime(raw_value: Any) -> datetime | None:
if not isinstance(raw_value, str) or not raw_value.strip():
return None
value = raw_value.strip().replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(value)
if parsed.tzinfo is None:
return parsed.replace(tzinfo=timezone.utc)
return parsed
except Exception:
return None
@staticmethod
def _extract_chat_turn_fields(text: str) -> tuple[str | None, str | None, datetime | None]:
    """Parse a stored chat-turn payload back into its parts.

    Returns ``(user_text, assistant_text, turn_created_at)``; all three are
    ``None`` when ``text`` does not contain the chat-turn tag. Lines are read
    as ``key=value`` pairs and the last occurrence of each key wins.
    """
    if CHAT_TURN_TAG not in text:
        return None, None, None

    user_part: str | None = None
    assistant_part: str | None = None
    created_at: datetime | None = None
    for raw_line in text.splitlines():
        key, separator, remainder = raw_line.partition("=")
        if not separator:
            continue
        remainder = remainder.strip()
        if key == "user":
            user_part = remainder
        elif key == "assistant":
            assistant_part = remainder
        elif key == "turn_created_at":
            created_at = ChatMemoryStore._safe_parse_datetime(remainder)
    return user_part, assistant_part, created_at
async def store_chat_turn(
    self,
    *,
    presentation_id: UUID,
    conversation_id: UUID,
    user_message: str,
    assistant_message: str,
) -> None:
    """Persist one user/assistant exchange in the memory store.

    Does nothing when the store is disabled or both messages are empty after
    whitespace normalisation. The turn is written as a single ``user`` role
    message with ``infer=False``, scoped to the presentation and conversation.
    Non-fatal client errors are logged and swallowed; others propagate.
    """
    client = await self._get_client()
    if client is None:
        return

    user_text = self._normalize(user_message)
    assistant_text = self._normalize(assistant_message)
    if not (user_text or assistant_text):
        return

    turn_body = self._build_turn_payload(
        user_text=user_text,
        assistant_text=assistant_text,
    )
    payload = [{"role": "user", "content": self._truncate(turn_body)}]
    scoped_user_id = self._scope_user_id(presentation_id, conversation_id)

    def _write_memory():
        # Try the positional call first; on TypeError retry with the
        # ``messages=`` keyword form.
        try:
            return client.add(payload, user_id=scoped_user_id, infer=False)
        except TypeError:
            return client.add(
                messages=payload,
                user_id=scoped_user_id,
                infer=False,
            )

    try:
        await asyncio.to_thread(_write_memory)
    except BaseException as exc:
        if self._is_nonfatal_mem0_error(exc):
            LOGGER.exception(
                (
                    "Failed to add chat mem0 memory "
                    "(presentation_id=%s, conversation_id=%s)"
                ),
                presentation_id,
                conversation_id,
            )
        else:
            raise
async def retrieve_context(
    self,
    *,
    presentation_id: UUID,
    conversation_id: UUID,
    query: str,
) -> str:
    """Search the conversation's memories and return them as one context string.

    Returns ``""`` when the store is disabled, the query is blank, the search
    fails with a non-fatal error (which is logged), or nothing is found.
    Otherwise the distinct memory texts are joined with blank lines, in result
    order, and truncated to the configured context size.
    """
    client = await self._get_client()
    if client is None:
        return ""

    trimmed_query = (query or "").strip()
    if not trimmed_query:
        return ""

    scoped_user_id = self._scope_user_id(presentation_id, conversation_id)

    def _run_search():
        # Try the ``filters=`` call first; on TypeError retry with ``user_id=``.
        try:
            return client.search(
                trimmed_query,
                filters={"user_id": scoped_user_id},
                top_k=self._top_k,
            )
        except TypeError:
            return client.search(
                trimmed_query,
                user_id=scoped_user_id,
                top_k=self._top_k,
            )

    try:
        response = await asyncio.to_thread(_run_search)
    except BaseException as exc:
        if not self._is_nonfatal_mem0_error(exc):
            raise
        LOGGER.exception(
            (
                "Failed to search chat mem0 memory "
                "(presentation_id=%s, conversation_id=%s)"
            ),
            presentation_id,
            conversation_id,
        )
        return ""

    extracted = (self._extract_text_field(item) for item in self._collect_results(response))
    memories = [entry for entry in extracted if entry]
    if not memories:
        return ""

    # dict.fromkeys keeps first-seen order while dropping duplicates.
    unique_memories = list(dict.fromkeys(memories))
    return self._truncate("\n\n".join(unique_memories), self._max_context_chars)
async def load_history(
    self,
    *,
    presentation_id: UUID,
    conversation_id: UUID,
) -> list[dict[str, str]]:
    """Rebuild the recent chat history for a conversation from stored memories.

    Every memory whose text parses as a chat turn becomes up to two messages
    (``user`` then ``assistant``). Turns are ordered by the timestamp embedded
    in the text, else by the item's ``created_at`` / ``updated_at`` /
    ``event_at`` field, else by their position in the response. Only the last
    ``_max_stored_turns`` turns are kept.

    Returns an empty list when no client is available, when the lookup fails
    with an error accepted by ``_is_nonfatal_mem0_error``, or when
    ``_max_stored_turns`` is not positive.
    """
    client = await self._get_client()
    if client is None:
        return []

    scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
    fetch_limit = max(10, self._max_stored_turns * 4)

    def _get_all():
        # Walk through the call signatures from most to least specific; each
        # TypeError is treated as "this client does not accept that form".
        try:
            return client.get_all(
                filters={"user_id": scoped_user_id},
                limit=fetch_limit,
            )
        except TypeError:
            try:
                return client.get_all(
                    user_id=scoped_user_id,
                    limit=fetch_limit,
                )
            except TypeError:
                try:
                    return client.get_all(filters={"user_id": scoped_user_id})
                except TypeError:
                    return client.get_all(user_id=scoped_user_id)

    try:
        response = await asyncio.to_thread(_get_all)
    except BaseException as exc:
        if not self._is_nonfatal_mem0_error(exc):
            raise
        LOGGER.exception(
            (
                "Failed to load chat mem0 history "
                "(presentation_id=%s, conversation_id=%s)"
            ),
            presentation_id,
            conversation_id,
        )
        return []

    results = self._collect_results(response)
    ordered_turns: list[tuple[datetime, str, str]] = []
    for index, item in enumerate(results):
        text_value = self._extract_text_field(item)
        if not text_value:
            continue
        user_text, assistant_text, embedded_timestamp = self._extract_chat_turn_fields(
            text_value
        )
        if not user_text and not assistant_text:
            continue

        timestamp = embedded_timestamp
        # Only mapping-style items carry metadata fields; other item shapes
        # (e.g. a bare string) must not crash the whole history load.
        if timestamp is None and isinstance(item, dict):
            timestamp = (
                self._safe_parse_datetime(item.get("created_at"))
                or self._safe_parse_datetime(item.get("updated_at"))
                or self._safe_parse_datetime(item.get("event_at"))
            )
        if timestamp is None:
            # Last resort: keep response order by mapping the position to an epoch offset.
            timestamp = datetime.fromtimestamp(index, tz=timezone.utc)
        ordered_turns.append((timestamp, user_text or "", assistant_text or ""))

    ordered_turns.sort(key=lambda turn: turn[0])

    # `seq[-0:]` is the whole sequence, so a non-positive limit needs an explicit guard.
    keep = self._max_stored_turns
    recent_turns = ordered_turns[-keep:] if keep > 0 else []

    history: list[dict[str, str]] = []
    for _, user_text, assistant_text in recent_turns:
        if user_text:
            history.append({"role": "user", "content": user_text})
        if assistant_text:
            history.append({"role": "assistant", "content": assistant_text})
    return history
# Module-level ChatMemoryStore instance created at import time.
CHAT_MEMORY_STORE = ChatMemoryStore()

View file

@ -0,0 +1,80 @@
import uuid
from sqlalchemy.ext.asyncio import AsyncSession
from services.chat.chat_memory_store import CHAT_MEMORY_STORE
from services.chat import sql_chat_history
class ChatConversationStore:
    """Conversation storage facade over SQL chat history and the mem0 chat memory store.

    Ordered history is read from ``sql_chat_history`` first, with
    ``CHAT_MEMORY_STORE`` as the fallback source. New turns are written to
    both, and semantic context lookups go to ``CHAT_MEMORY_STORE`` only.
    """

    def __init__(self, sql_session: AsyncSession):
        self._sql = sql_session

    async def load_history(
        self,
        *,
        presentation_id: uuid.UUID,
        conversation_id: uuid.UUID,
    ) -> list[dict[str, str]]:
        """Return the stored messages for a conversation.

        When SQL has no messages, history is loaded from ``CHAT_MEMORY_STORE``;
        a non-empty result is written to SQL via ``replace_messages`` before
        being returned.
        """
        sql_messages = await sql_chat_history.load_messages(
            self._sql,
            presentation_id=presentation_id,
            conversation_id=conversation_id,
        )
        if sql_messages:
            return sql_messages

        legacy_messages = await CHAT_MEMORY_STORE.load_history(
            presentation_id=presentation_id,
            conversation_id=conversation_id,
        )
        if not legacy_messages:
            return legacy_messages

        await sql_chat_history.replace_messages(
            self._sql,
            presentation_id=presentation_id,
            conversation_id=conversation_id,
            messages=legacy_messages,
        )
        return legacy_messages

    async def append_turn(
        self,
        *,
        presentation_id: uuid.UUID,
        conversation_id: uuid.UUID,
        user_message: str,
        assistant_message: str,
        tool_calls: list[str] | None = None,
    ) -> None:
        """Record one exchange in SQL, then in the chat memory store.

        ``tool_calls`` is passed to the SQL writer only.
        """
        await sql_chat_history.append_turn(
            self._sql,
            presentation_id=presentation_id,
            conversation_id=conversation_id,
            user_message=user_message,
            assistant_message=assistant_message,
            tool_calls=tool_calls,
        )
        await CHAT_MEMORY_STORE.store_chat_turn(
            presentation_id=presentation_id,
            conversation_id=conversation_id,
            user_message=user_message,
            assistant_message=assistant_message,
        )

    async def retrieve_semantic_context(
        self,
        *,
        presentation_id: uuid.UUID,
        conversation_id: uuid.UUID,
        query: str,
    ) -> str:
        """Return the chat memory store's context text for ``query``."""
        semantic_context = await CHAT_MEMORY_STORE.retrieve_context(
            presentation_id=presentation_id,
            conversation_id=conversation_id,
            query=query,
        )
        return semantic_context

    async def ensure_conversation_id(
        self,
        conversation_id: uuid.UUID | None,
    ) -> uuid.UUID:
        """Return ``conversation_id`` unchanged, or a new random UUID when it is missing."""
        if conversation_id:
            return conversation_id
        return uuid.uuid4()

View file

@ -0,0 +1,491 @@
import copy
import json
import logging
import re
import uuid
from typing import Any
from jsonschema import Draft202012Validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.image_prompt import ImagePrompt
from models.sql.image_asset import ImageAsset
from models.sql.presentation import PresentationModel
from models.sql.slide import SlideModel
from services.icon_finder_service import ICON_FINDER_SERVICE
from services.image_generation_service import ImageGenerationService
from services.mem0_presentation_memory_service import MEM0_PRESENTATION_MEMORY_SERVICE
from templates.presentation_layout import SlideLayoutModel
from utils.asset_directory_utils import get_images_directory
from utils.process_slides import (
process_old_and_new_slides_and_fetch_assets,
process_slide_and_fetch_assets,
)
# Module logger for the chat memory layer.
LOGGER = logging.getLogger(__name__)
# Upper bound on how many schema validation errors are formatted and returned per validation.
MAX_SCHEMA_ERRORS = 10
# Keep URL runtime fields during validation because many slide schemas require them.
# Speaker note is handled separately and should not affect JSON-schema checks.
RUNTIME_CONTENT_FIELDS = {"__speaker_note__"}
class PresentationChatMemoryLayer:
"""
Memory abstraction for chat tools and context retrieval.
This layer intentionally hides where data comes from (SQL-backed persisted state
and mem0 retrieval) behind `get` and `search`-style methods so chat logic stays
decoupled from storage details.
Every instance is bound to one presentation id and one SQL session.
"""
# Stores the session and presentation id used by every query in this class.
def __init__(self, sql_session: AsyncSession, presentation_id: uuid.UUID):
self._sql_session = sql_session
self._presentation_id = presentation_id
async def get(self, key: str) -> Any:
    """Look up a named piece of presentation state.

    Only ``"presentation_outline"`` is supported; any other key yields ``None``.
    If the presentation has slide rows, a dict describing them (ordered by
    index) is returned. Otherwise ``presentation.outlines`` is returned, or
    ``None`` when the presentation or its outlines are missing.
    """
    if key != "presentation_outline":
        return None

    # Prefer live slides from SQL so slide count and slide indices are always current.
    slide_query = (
        select(SlideModel)
        .where(SlideModel.presentation == self._presentation_id)
        .order_by(SlideModel.index)
    )
    slide_rows = list(await self._sql_session.scalars(slide_query))
    if slide_rows:
        LOGGER.info(
            "Chat outline loaded from slides table (presentation_id=%s, slides=%d)",
            self._presentation_id,
            len(slide_rows),
        )
        outline_entries = []
        for row in slide_rows:
            outline_entries.append(
                {
                    "slide_id": str(row.id),
                    "index": row.index,
                    "layout_id": row.layout,
                    "content": row.content,
                    "speaker_note": row.speaker_note,
                }
            )
        return {
            "source": "slides_table",
            "slide_count": len(slide_rows),
            "slides": outline_entries,
        }

    presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
    if presentation and presentation.outlines:
        LOGGER.info(
            "Chat outline fallback hit from presentation.outlines (presentation_id=%s)",
            self._presentation_id,
        )
        return presentation.outlines

    LOGGER.info(
        "Chat memory miss for outline (presentation_id=%s)",
        self._presentation_id,
    )
    return None
async def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]:
    """
    Search slides directly from SQL-backed slide rows.

    Results are intentionally compact (snippet-first) to keep tool-call payloads
    small for models with limited context windows.

    Scoring per slide: +8 when the whole lower-cased query occurs in the
    serialized slide text, plus +1 for every distinct query token (a run of two
    or more letters/digits in any script) found in that text. Slides scoring 0
    are dropped. Hits are ordered by score (descending), then slide index, and
    at most ``max(1, limit)`` are returned. A blank query returns ``[]``.
    """
    trimmed_query = (query or "").strip()
    if not trimmed_query:
        return []

    slides_result = await self._sql_session.scalars(
        select(SlideModel).where(SlideModel.presentation == self._presentation_id)
    )
    slides = sorted(slides_result, key=lambda slide: slide.index)
    if not slides:
        LOGGER.info(
            "Chat memory miss for slide search (presentation_id=%s, reason=no_slides)",
            self._presentation_id,
        )
        return []

    query_lower = trimmed_query.lower()
    # `[^\W_]` is "word character except underscore", i.e. a letter or digit in
    # any script. For lower-cased ASCII text it selects exactly what
    # `[a-z0-9]` does, while also tokenizing accented, Cyrillic, CJK, ... queries.
    query_tokens = set(re.findall(r"[^\W_]{2,}", query_lower))

    ranked: list[tuple[int, dict[str, Any]]] = []
    for slide in slides:
        serialized = self._serialize_slide(slide)
        searchable = serialized.lower()
        score = 0
        if query_lower in searchable:
            score += 8
        score += sum(1 for token in query_tokens if token in searchable)
        if score <= 0:
            continue
        ranked.append(
            (
                score,
                {
                    "slide_id": str(slide.id),
                    "index": slide.index,
                    "slide_number": slide.index + 1,
                    "layout_id": slide.layout,
                    "snippet": self._build_snippet(serialized, query_lower),
                    "score": score,
                },
            )
        )

    ranked.sort(key=lambda item: (-item[0], item[1]["index"]))
    results = [entry for _, entry in ranked[: max(1, limit)]]
    LOGGER.info(
        "Chat DB slide search completed (presentation_id=%s, query=%r, hits=%d)",
        self._presentation_id,
        trimmed_query,
        len(results),
    )
    return results
async def get_slide_at_index(
    self, index: int, *, include_full_content: bool = False
) -> dict[str, Any] | None:
    """Fetch the slide stored at ``index`` for this presentation.

    Returns ``None`` when no such slide exists. Otherwise returns a summary
    dict containing a whitespace-collapsed preview of up to 420 characters;
    the slide's full ``content`` is included only when
    ``include_full_content`` is true.
    """
    lookup = select(SlideModel).where(
        SlideModel.presentation == self._presentation_id,
        SlideModel.index == index,
    )
    slide = await self._sql_session.scalar(lookup)
    if not slide:
        LOGGER.info(
            "Chat memory miss for slide by index (presentation_id=%s, index=%d)",
            self._presentation_id,
            index,
        )
        return None

    preview = self._build_snippet(
        self._serialize_slide(slide),
        query_lower="",
        window=420,
    )
    summary: dict[str, Any] = {
        "slide_id": str(slide.id),
        "index": slide.index,
        "slide_number": slide.index + 1,
        "layout_id": slide.layout,
        "content_preview": preview,
        "speaker_note": slide.speaker_note,
    }
    if include_full_content:
        summary["content"] = slide.content
    return summary
async def get_available_layouts(self) -> list[dict[str, Any]]:
    """List ``id``, ``name`` and ``description`` for each slide layout of the presentation.

    Returns ``[]`` when the presentation is missing, its ``layout`` is not a
    dict, or ``get_layout()`` raises (the failure is logged).
    """
    presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
    if not (presentation and isinstance(presentation.layout, dict)):
        return []

    try:
        parsed_layout = presentation.get_layout()
    except Exception:
        LOGGER.exception(
            "Failed to parse presentation layout (presentation_id=%s)",
            self._presentation_id,
        )
        return []

    layout_summaries: list[dict[str, Any]] = []
    for entry in parsed_layout.slides:
        layout_summaries.append(
            {
                "id": entry.id,
                "name": entry.name,
                "description": entry.description,
            }
        )
    return layout_summaries
async def get_content_schema_from_layout_id(self, layout_id: str) -> dict[str, Any] | None:
layout = await self._get_layout_by_id(layout_id)
if not layout:
return None
return layout.json_schema
async def generate_image(self, prompt: str) -> str:
    """Generate an image for ``prompt`` and return where it can be found.

    When the image service returns an ``ImageAsset`` it is added to the SQL
    session, committed, and its ``path`` is returned. Any other result is
    returned as ``str(result)``.
    """
    service = ImageGenerationService(get_images_directory())
    generated = await service.generate_image(ImagePrompt(prompt=prompt))
    if not isinstance(generated, ImageAsset):
        return str(generated)
    self._sql_session.add(generated)
    await self._sql_session.commit()
    return generated.path
async def generate_icon(self, query: str) -> str:
    """Return the top icon match for ``query``, or the placeholder icon path when there is none."""
    matches = await ICON_FINDER_SERVICE.search_icons(query, k=1)
    return matches[0] if matches else "/static/icons/placeholder.svg"
# Validate slide content against its layout schema and persist it.
#
# Parameters:
#   content: slide content to store; deep-copied before asset processing.
#   layout_id: id of a layout that must exist in this presentation.
#   index: target slide index; negative values are clamped to 0.
#   replace_old_slide_at_index: True to overwrite the slide at `index`, False to insert
#     a new slide there and move later slides down by one.
#
# Returns a dict whose "saved" flag says whether anything was written. Failures carry
# "message" and "validation_errors"; successes carry "action" ("replaced" or "created"),
# "message", "slide_id" and "index" (plus "shifted_slide_count" for inserts).
async def save_slide(
self,
*,
content: dict[str, Any],
layout_id: str,
index: int,
replace_old_slide_at_index: bool,
) -> dict[str, Any]:
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation:
return {
"saved": False,
"message": "Presentation not found.",
"validation_errors": [],
}
layout = await self._get_layout_by_id(layout_id, presentation=presentation)
if not layout:
return {
"saved": False,
"message": f"Layout '{layout_id}' was not found in this presentation.",
"validation_errors": [f"Unknown layout_id '{layout_id}'."],
}
# Nothing is written unless the content passes the layout's JSON schema.
validation_errors = self._validate_slide_content(
content=content,
schema=layout.json_schema,
)
if validation_errors:
return {
"saved": False,
"message": "Slide content failed schema validation.",
"validation_errors": validation_errors,
}
target_index = max(0, index)
image_generation_service = ImageGenerationService(get_images_directory())
# --- Replace path: overwrite the slide currently stored at target_index. ---
if replace_old_slide_at_index:
existing_slide = await self._sql_session.scalar(
select(SlideModel).where(
SlideModel.presentation == self._presentation_id,
SlideModel.index == target_index,
)
)
if not existing_slide:
return {
"saved": False,
"message": f"No existing slide found at index {target_index} to replace.",
"validation_errors": [],
}
# Work on a copy so the caller's dict is not shared with the stored slide.
updated_content = copy.deepcopy(content)
new_assets = await process_old_and_new_slides_and_fetch_assets(
image_generation_service=image_generation_service,
old_slide_content=existing_slide.content or {},
new_slide_content=updated_content,
)
# The replaced slide is given a brand-new id.
# NOTE(review): presumably so consumers treat it as a new slide - confirm.
existing_slide.id = uuid.uuid4()
existing_slide.layout = layout_id
existing_slide.layout_group = self._resolve_layout_group(
presentation=presentation,
fallback=existing_slide.layout_group,
)
existing_slide.content = updated_content
existing_slide.speaker_note = self._extract_speaker_note(updated_content)
self._sql_session.add(existing_slide)
self._sql_session.add_all(new_assets)
await self._sql_session.commit()
# The edit is recorded in presentation memory only after the SQL commit.
await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
presentation_id=self._presentation_id,
slide_index=target_index,
edit_prompt=f"[chat_tool_save_slide_replace] layout_id={layout_id}",
edited_slide_content=updated_content,
)
return {
"saved": True,
"action": "replaced",
"message": f"Slide at index {target_index} was replaced successfully.",
"slide_id": str(existing_slide.id),
"index": target_index,
}
# --- Insert path: add a new slide and shift the slides at or after it. ---
slides_result = await self._sql_session.scalars(
select(SlideModel)
.where(SlideModel.presentation == self._presentation_id)
.order_by(SlideModel.index)
)
slides = list(slides_result)
if slides:
max_index = max(slide.index for slide in slides)
# An index past the end is clamped to "append after the last slide".
insert_index = min(target_index, max_index + 1)
slides_to_shift = [slide for slide in slides if slide.index >= insert_index]
else:
insert_index = 0
slides_to_shift = []
# Slides are shifted starting from the highest index.
# NOTE(review): presumably to avoid transient duplicate indices - confirm.
for slide in sorted(slides_to_shift, key=lambda each: each.index, reverse=True):
slide.index += 1
self._sql_session.add(slide)
new_slide_content = copy.deepcopy(content)
new_slide = SlideModel(
presentation=self._presentation_id,
layout_group=self._resolve_layout_group(presentation=presentation),
layout=layout_id,
index=insert_index,
content=new_slide_content,
speaker_note=self._extract_speaker_note(new_slide_content),
)
new_assets = await process_slide_and_fetch_assets(
image_generation_service=image_generation_service,
slide=new_slide,
)
self._sql_session.add(new_slide)
self._sql_session.add_all(new_assets)
await self._sql_session.commit()
await self._sql_session.refresh(new_slide)
await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
presentation_id=self._presentation_id,
slide_index=insert_index,
edit_prompt=f"[chat_tool_save_slide_new] layout_id={layout_id}",
edited_slide_content=new_slide.content,
)
return {
"saved": True,
"action": "created",
"message": f"New slide saved at index {insert_index}.",
"slide_id": str(new_slide.id),
"index": insert_index,
"shifted_slide_count": len(slides_to_shift),
}
async def retrieve_context(self, query: str) -> str:
    """Fetch presentation-level context for ``query`` from the mem0 presentation memory service.

    The service's result is returned unchanged; a hit or miss is logged first.
    """
    context = await MEM0_PRESENTATION_MEMORY_SERVICE.retrieve_context(
        self._presentation_id,
        query,
    )
    if not context:
        LOGGER.info(
            "Chat memory semantic context miss (presentation_id=%s)",
            self._presentation_id,
        )
        return context

    LOGGER.info(
        "Chat memory semantic context hit (presentation_id=%s, chars=%d)",
        self._presentation_id,
        len(context),
    )
    return context
async def _get_layout_by_id(
self,
layout_id: str,
presentation: PresentationModel | None = None,
) -> SlideLayoutModel | None:
if not presentation:
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
if not presentation or not isinstance(presentation.layout, dict):
return None
try:
layout_model = presentation.get_layout()
except Exception:
return None
for layout in layout_model.slides:
if layout.id == layout_id:
return layout
return None
def _validate_slide_content(
    self,
    *,
    content: dict[str, Any],
    schema: dict[str, Any],
) -> list[str]:
    """Check ``content`` against ``schema`` and describe what is wrong.

    Fields named in ``RUNTIME_CONTENT_FIELDS`` are removed before validation.
    Returns at most ``MAX_SCHEMA_ERRORS`` messages of the form
    ``"<dotted.path>: <message>"``, ordered by error path; ``"$"`` stands for
    the document root. An empty list means the content is valid.
    """
    candidate = self._strip_runtime_fields(content)
    problems = list(Draft202012Validator(schema).iter_errors(candidate))
    problems.sort(key=lambda err: err.path)

    messages: list[str] = []
    for problem in problems[:MAX_SCHEMA_ERRORS]:
        dotted_path = ".".join(str(part) for part in problem.path)
        messages.append(f"{dotted_path or '$'}: {problem.message}")
    return messages
@staticmethod
def _strip_runtime_fields(value: Any) -> Any:
    """Return a copy of ``value`` with every ``RUNTIME_CONTENT_FIELDS`` key removed.

    Dicts and lists are rebuilt recursively; any other value is returned as is.
    """
    strip = PresentationChatMemoryLayer._strip_runtime_fields
    if isinstance(value, list):
        return [strip(element) for element in value]
    if isinstance(value, dict):
        return {
            field: strip(nested)
            for field, nested in value.items()
            if field not in RUNTIME_CONTENT_FIELDS
        }
    return value
@staticmethod
def _extract_speaker_note(content: dict[str, Any]) -> str:
value = content.get("__speaker_note__")
if isinstance(value, str):
return value
return ""
@staticmethod
def _resolve_layout_group(
*,
presentation: PresentationModel,
fallback: str = "presentation",
) -> str:
if isinstance(presentation.layout, dict):
name = str(presentation.layout.get("name") or "").strip()
if name:
return name
return fallback
@staticmethod
def _serialize_slide(slide: SlideModel) -> str:
content_text = ""
try:
content_text = json.dumps(slide.content or {}, ensure_ascii=False)
except Exception:
content_text = str(slide.content)
speaker_note = slide.speaker_note or ""
return f"slide_index={slide.index}\nlayout_id={slide.layout}\n{content_text}\n{speaker_note}"
@staticmethod
def _build_snippet(text: str, query_lower: str, window: int = 320) -> str:
normalized = " ".join(text.split())
if not normalized:
return ""
offset = normalized.lower().find(query_lower)
if offset == -1:
return normalized[:window]
start = max(0, offset - window // 3)
end = min(len(normalized), start + window)
return normalized[start:end]

View file

@ -0,0 +1,5 @@
# Re-exports PresentationChatMemoryLayer under the name PresentationContextStore.
from services.chat.memory_layer import (
PresentationChatMemoryLayer as PresentationContextStore,
)
__all__ = ["PresentationContextStore"]

View file

@ -0,0 +1,63 @@
def _trim_block(label: str, text: str) -> str:
t = (text or "").strip()
if not t:
return ""
return f"\n{label}\n{t}\n"
def build_system_prompt(
    presentation_memory_context: str,
    chat_memory_context: str,
) -> str:
    """Assemble the system prompt for the slide assistant.

    The fixed instructions come first. A "Deck memory" section and a
    "Chat memory" section follow, each only when its context argument is
    non-blank.
    """
    instruction_lines = [
        "You are Presenton's slide assistant. Be concise, accurate, and action-oriented.\n",
        "\n",
        "Operating priorities\n",
        "1) Complete the user's intent with the fewest reliable tool calls.\n",
        "2) Prefer verified deck state over assumptions.\n",
        "3) Keep responses short and concrete.\n",
        "\n",
        "Source-of-truth policy\n",
        "- Tool outputs from this turn are authoritative for live deck state.\n",
        "- Conversation context (user constraints, prior decisions) is next.\n",
        "- Deck memory is background context and may be partial or stale.\n",
        "- If sources conflict, trust tools over memory.\n",
        "\n",
        "When to use memory vs tools\n",
        "- Use deck memory for uploaded-document meaning, original outline intent, and planning rationale.\n",
        "- Use tools for anything about current slides: exact text, ordering, layout, slide identity, and edits.\n",
        "- If user asks what is currently on slide N or asks for a change, do not rely on memory alone.\n",
        "\n",
        "Tool-use protocol (live SQL slide data)\n",
        "- User slide numbers are 1-based; tool indexes are 0-based.\n",
        "- Start with compact reads: getPresentationOutline -> searchSlides -> getSlideAtIndex.\n",
        "- Set includeFullContent=true only when full JSON is required (typically right before saveSlide).\n",
        "- Before saveSlide, validate target layout/schema (getAvailableLayouts, getContentSchemaFromLayoutId).\n",
        "- Generate required assets in batch with generateAssets before saving.\n",
        "- saveSlide payload must match the schema exactly; do not invent fields.\n",
        "- If a tool fails, report it briefly and choose the best next step.\n",
        "\n",
        "Autonomous decision policy (default behavior)\n",
        "- For edit requests, execute the best reasonable implementation without asking for optional preferences.\n",
        "- Do not ask the user to choose among layouts/assets unless the user explicitly asks to choose.\n",
        "- If visual details are unspecified (image style, icon set, exact layout), infer from slide content and deck theme.\n",
        "- For requests like 'add images/icons' or 'make it better', pick a layout that best preserves existing intent and readability, then apply it.\n",
        "- Ask a clarification only when blocked by a required missing fact (e.g., target slide is ambiguous, conflicting constraints, or missing required data).\n",
        "- When in doubt, prefer a professional, neutral visual style and continue.\n",
        "\n",
        "Response policy\n",
        "- Never invent slide facts, tool results, or document claims.\n",
        "- If information is missing, run the right tool or ask one focused clarification.\n",
        "- After enough evidence is collected, stop calling tools and provide a brief final answer.\n",
        "- For edits, apply changes first, then report what changed and where; for lookups, state what you found.\n",
    ]
    # Each memory section collapses to "" when its context is blank.
    memory_sections = [
        _trim_block(
            "Deck memory (semantic / long-term: uploaded document text, outline drafts & prompts, stored slide-edit notes; snippets may be partial and can lag the live deck):",
            presentation_memory_context,
        ),
        _trim_block(
            "Chat memory (earlier messages in this conversation only):",
            chat_memory_context,
        ),
    ]
    return "".join(instruction_lines + memory_sections)

View file

@ -0,0 +1,81 @@
import json
from typing import Any, Literal
import dirtyjson # type: ignore[import-untyped]
from pydantic import BaseModel, ConfigDict, Field, field_validator
# Base class for the tool-argument models below: extra fields are forbidden and
# pydantic strict mode is enabled.
class StrictSchemaModel(BaseModel):
model_config = ConfigDict(extra="forbid", strict=True)
# Argument model with no fields.
class NoArgsInput(StrictSchemaModel):
pass
# Arguments for reading one slide: an index in 0..1000 and a required flag, aliased
# "includeFullContent". populate_by_name=True also accepts the Python field name.
class GetSlideAtIndexInput(StrictSchemaModel):
index: int = Field(ge=0, le=1000)
include_full_content: bool = Field(alias="includeFullContent")
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
# Arguments for a slide search: a query of 1..1000 characters and a result limit of 1..10.
class SearchSlidesInput(StrictSchemaModel):
query: str = Field(min_length=1, max_length=1000)
limit: int = Field(ge=1, le=10)
# Arguments for a layout schema lookup: a layout id of 1..200 characters, aliased "layoutId".
class GetContentSchemaFromLayoutIdInput(StrictSchemaModel):
layout_id: str = Field(alias="layoutId", min_length=1, max_length=200)
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
# Arguments for image generation: a prompt of 1..4000 characters.
class GenerateImageInput(StrictSchemaModel):
prompt: str = Field(min_length=1, max_length=4000)
# Arguments for an icon lookup: a query of 1..1000 characters.
class GenerateIconInput(StrictSchemaModel):
query: str = Field(min_length=1, max_length=1000)
# One requested asset: `kind` is "image" or "icon", and `prompt` (1..4000 characters)
# is the image prompt or the icon search query.
class GenerateAssetItemInput(StrictSchemaModel):
kind: Literal["image", "icon"]
prompt: str = Field(
min_length=1,
max_length=4000,
description="Image prompt or icon search query.",
)
# A batch of 1..12 asset requests.
class GenerateAssetsInput(StrictSchemaModel):
assets: list[GenerateAssetItemInput] = Field(min_length=1, max_length=12)
class SaveSlideInput(StrictSchemaModel):
    # Arguments for saving a slide. `content` travels as a JSON-serialized string
    # (2..200000 characters) and must decode to a JSON object. The camelCase
    # aliases are accepted alongside the Python field names (populate_by_name=True).
    # Plain comments are used here on purpose: a class docstring would be picked
    # up as the model's JSON-schema description.
    model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)

    content: str = Field(
        min_length=2,
        max_length=200000,
        description=(
            "A JSON-serialized object for slide content. "
            "Example: '{\"title\": \"Q4 Revenue\", \"bullets\": [\"North America +22%\"]}'"
        ),
    )
    layout_id: str = Field(alias="layoutId", min_length=1, max_length=200)
    index: int = Field(ge=0, le=1000)
    replace_old_slide_at_index: bool = Field(alias="replaceOldSlideAtIndex")

    @field_validator("content")
    @classmethod
    def validate_content(cls, value: str) -> str:
        """Accept ``value`` only if it decodes to a JSON object; the string itself is returned.

        ``dirtyjson`` is tried first and ``json`` is the fallback. A decode
        error from the fallback propagates to the caller.
        """
        try:
            decoded: Any = dirtyjson.loads(value)
        except Exception:
            decoded = json.loads(value)
        if isinstance(decoded, dict):
            return value
        raise ValueError("'content' must be a JSON object.")

View file

@ -0,0 +1,449 @@
import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, Literal
from fastapi import HTTPException
from llmai import get_client # type: ignore[import-not-found]
from llmai.shared import ( # type: ignore[import-not-found]
AssistantMessage,
Message,
SystemMessage,
TextContentPart,
ToolResponseMessage,
UserMessage,
WebSearchTool
)
from sqlalchemy.ext.asyncio import AsyncSession
from models.sql.presentation import PresentationModel
from services.chat.conversation_store import ChatConversationStore
from services.chat.presentation_context_store import PresentationContextStore
from services.chat.prompts import build_system_prompt
from services.chat.tools import ChatTools
from utils.llm_client_error_handler import handle_llm_client_exceptions
from utils.llm_config import get_llm_config
from utils.llm_provider import get_model
from utils.llm_utils import (
extract_text,
get_generate_kwargs,
stream_generate_events,
)
# Module logger for the chat service.
LOGGER = logging.getLogger(__name__)
# Maximum number of model/tool rounds per chat turn; once exhausted, a final answer is
# requested without tools.
MAX_TOOL_ROUNDS = 16
@dataclass(frozen=True)
class ChatTurnResult:
    """Immutable outcome of one chat turn."""

    # Conversation the turn was stored under.
    conversation_id: uuid.UUID
    # Final assistant text for the turn.
    response_text: str
    # Names of the tools the model called during the turn, in call order.
    tool_calls: list[str]
# Event tags yielded by PresentationChatService.stream_reply.
ChatStreamEventType = Literal["chunk", "complete", "status", "trace"]
# Payload that accompanies an event tag: text, the final ChatTurnResult, or a trace dict.
ChatStreamEventValue = str | ChatTurnResult | dict[str, Any]
# Runs chat turns for one presentation: builds the prompt, lets the model call tools,
# and persists the resulting exchange.
class PresentationChatService:
# conversation_id may be None; _prepare_turn_context then obtains one from the conversation store.
def __init__(
self,
sql_session: AsyncSession,
presentation_id: uuid.UUID,
conversation_id: uuid.UUID | None,
):
self._sql_session = sql_session
self._presentation_id = presentation_id
self._conversation_id = conversation_id
# All three collaborators are built on the same SQL session.
self._conversation_store = ChatConversationStore(sql_session)
self._memory = PresentationContextStore(sql_session, presentation_id)
self._tools = ChatTools(self._memory)
async def generate_reply(self, user_message: str) -> ChatTurnResult:
    """Answer ``user_message`` in one non-streaming turn and persist the exchange.

    Steps: build the turn context, run the model/tool loop, store the turn.
    """
    conversation_id, prompt_messages = await self._prepare_turn_context(user_message)
    reply_text, used_tools = await self._run_llm_with_tools(prompt_messages)
    turn_result = await self._persist_turn(
        conversation_id=conversation_id,
        user_message=user_message,
        response_text=reply_text,
        tool_calls=used_tools,
    )
    return turn_result
# Streaming variant of a chat turn. Yields (event_type, value) pairs:
#   "status"   - short progress text,
#   "chunk"    - a piece of assistant text,
#   "trace"    - a dict describing a model note, a tool plan, a tool call, or the round limit,
#   "complete" - the ChatTurnResult, yielded last after the turn has been persisted.
# LLM client errors are converted with handle_llm_client_exceptions and raised.
async def stream_reply(
self, user_message: str
) -> AsyncGenerator[tuple[ChatStreamEventType, ChatStreamEventValue], None]:
yield "status", "Reading deck context"
conversation_id, messages = await self._prepare_turn_context(user_message)
client = get_client(config=get_llm_config())
model = get_model()
tools = self._tools.get_tool_definitions()
# The streaming flow appends a WebSearchTool to the tool definitions.
tools.append(WebSearchTool())
called_tools: list[str] = []
last_tool_results: list[dict[str, Any]] = []
response_text: str | None = None
# Each iteration is one model call; a round that requests tools runs them and loops again.
for round_index in range(MAX_TOOL_ROUNDS):
completion_chunk: Any | None = None
round_content_chunks: list[str] = []
thinking_chunks: list[str] = []
try:
async for event in stream_generate_events(
client,
**get_generate_kwargs(
model=model,
messages=messages,
tools=tools,
stream=True,
),
):
event_type = getattr(event, "type", None)
if event_type == "content":
chunk = getattr(event, "chunk", None)
if chunk:
# Content is forwarded to the caller as soon as it arrives.
round_content_chunks.append(chunk)
yield "chunk", chunk
elif event_type == "thinking":
thinking_text = self._event_text(event)
if thinking_text:
thinking_chunks.append(thinking_text)
elif event_type == "completion":
completion_chunk = event
except Exception as exc:
raise handle_llm_client_exceptions(exc)
thinking_summary = self._summarize_model_note(thinking_chunks)
if thinking_summary:
yield "trace", {
"kind": "model_note",
"round": round_index + 1,
"status": "info",
"message": thinking_summary,
}
completion_tool_calls = list(
getattr(completion_chunk, "tool_calls", []) or []
)
if completion_tool_calls:
tool_names = [tool_call.name for tool_call in completion_tool_calls]
called_tools.extend(tool_names)
yield "trace", {
"kind": "tool_plan",
"round": round_index + 1,
"tools": tool_names,
"message": f"Using tools: {', '.join(tool_names)}",
}
# Continue from the message list carried on the completion event when it has one;
# otherwise from a copy of the current list.
messages = (
list(getattr(completion_chunk, "messages", []) or [])
if getattr(completion_chunk, "messages", None)
else list(messages)
)
last_tool_results = []
for tool_call in completion_tool_calls:
yield "trace", {
"kind": "tool_call",
"round": round_index + 1,
"tool": tool_call.name,
"status": "start",
"message": self._tool_start_message(tool_call.name),
}
tool_result = await self._tools.execute_tool_call(tool_call)
last_tool_results.append(tool_result)
yield "trace", {
"kind": "tool_call",
"round": round_index + 1,
"tool": tool_call.name,
"status": "success" if tool_result.get("ok") else "error",
"message": self._summarize_tool_result(
tool_call.name, tool_result
),
}
# Each tool result goes back to the model as JSON text keyed by the tool call id.
tool_response_content = json.dumps(tool_result, ensure_ascii=False)
messages.append(
ToolResponseMessage(
id=tool_call.id,
content=[TextContentPart(text=tool_response_content)],
)
)
continue
# No tool calls in this round: the streamed content is the final answer.
response_text = "".join(round_content_chunks)
if not response_text and completion_chunk:
response_text = extract_text(getattr(completion_chunk, "content", None))
if not response_text:
response_text = "I could not generate a response for that request."
# If nothing was streamed in this round, send the text once as a single chunk.
if not round_content_chunks:
yield "chunk", response_text
break
else:
# for/else: reached only when every round requested tools and the loop never hit `break`.
LOGGER.warning("Max tool rounds reached in chat stream flow")
yield "trace", {
"kind": "limit",
"message": (
"Reached tool-call limit before final answer; "
"attempting best-effort summary."
),
}
yield "status", "Finalizing response"
response_text = await self._try_final_response_without_tools(
client=client,
model=model,
messages=messages,
)
if not response_text:
response_text = self._build_tool_limit_fallback(last_tool_results)
yield "chunk", response_text
final_response_text = response_text or "I could not generate a response for that request."
# response_text is still None only if the loop body never ran.
if response_text is None:
yield "chunk", final_response_text
yield "status", "Saving chat"
result = await self._persist_turn(
conversation_id=conversation_id,
user_message=user_message,
response_text=final_response_text,
tool_calls=called_tools,
)
yield "complete", result
async def _prepare_turn_context(
    self, user_message: str
) -> tuple[uuid.UUID, list[Message]]:
    """Validate the request and build the message list for the model.

    The list is: system prompt (with presentation and chat memory context),
    the stored conversation history, then the new user message.

    Raises:
        HTTPException: 400 when the message is blank, 404 when the
            presentation does not exist.
    """
    if not (user_message or "").strip():
        raise HTTPException(status_code=400, detail="Message is required")

    if not await self._sql_session.get(PresentationModel, self._presentation_id):
        raise HTTPException(status_code=404, detail="Presentation not found")

    conversation_id = await self._conversation_store.ensure_conversation_id(
        self._conversation_id
    )
    stored_history = await self._conversation_store.load_history(
        presentation_id=self._presentation_id,
        conversation_id=conversation_id,
    )
    prior_messages = self._convert_history_to_messages(stored_history)

    deck_context = await self._memory.retrieve_context(user_message)
    chat_context = await self._conversation_store.retrieve_semantic_context(
        presentation_id=self._presentation_id,
        conversation_id=conversation_id,
        query=user_message,
    )

    system_prompt = build_system_prompt(
        presentation_memory_context=deck_context,
        chat_memory_context=chat_context,
    )
    prompt_messages: list[Message] = [SystemMessage(content=system_prompt)]
    prompt_messages.extend(prior_messages)
    prompt_messages.append(UserMessage(content=user_message))
    return conversation_id, prompt_messages
async def _persist_turn(
    self,
    *,
    conversation_id: uuid.UUID,
    user_message: str,
    response_text: str,
    tool_calls: list[str],
) -> ChatTurnResult:
    """Store the finished exchange, commit the SQL session, and return the turn summary."""
    turn_fields = {
        "presentation_id": self._presentation_id,
        "conversation_id": conversation_id,
        "user_message": user_message,
        "assistant_message": response_text,
        "tool_calls": tool_calls,
    }
    await self._conversation_store.append_turn(**turn_fields)
    await self._sql_session.commit()

    return ChatTurnResult(
        conversation_id=conversation_id,
        response_text=response_text,
        tool_calls=tool_calls,
    )
async def _run_llm_with_tools(self, messages: list[Message]) -> tuple[str, list[str]]:
    """Run the non-streaming model/tool loop.

    The model is called up to ``MAX_TOOL_ROUNDS`` times. A response without
    tool calls ends the loop with its text. Otherwise every requested tool is
    executed and its JSON result appended to the conversation before the next
    call. If the limit is reached, one final call without tools is attempted,
    and a fallback message is built from the last tool results if that also
    produces nothing.

    Returns:
        ``(response_text, names_of_called_tools)``.
    """
    llm_client = get_client(config=get_llm_config())
    model_name = get_model()
    tool_definitions = self._tools.get_tool_definitions()
    used_tool_names: list[str] = []
    latest_results: list[dict[str, Any]] = []

    for _round in range(MAX_TOOL_ROUNDS):
        request_kwargs = get_generate_kwargs(
            model=model_name,
            messages=messages,
            tools=tool_definitions,
        )
        try:
            response = await asyncio.to_thread(llm_client.generate, **request_kwargs)
        except Exception as exc:
            raise handle_llm_client_exceptions(exc)

        if not response.tool_calls:
            answer = extract_text(response.content)
            if not answer:
                answer = "I could not generate a response for that request."
            return answer, used_tool_names

        used_tool_names.extend([requested.name for requested in response.tool_calls])
        # Continue from the message list carried on the response when it has one.
        messages = list(response.messages) if response.messages else list(messages)

        latest_results = []
        for requested in response.tool_calls:
            outcome = await self._tools.execute_tool_call(requested)
            latest_results.append(outcome)
            outcome_json = json.dumps(outcome, ensure_ascii=False)
            messages.append(
                ToolResponseMessage(
                    id=requested.id,
                    content=[TextContentPart(text=outcome_json)],
                )
            )

    LOGGER.warning("Max tool rounds reached in chat flow")
    synthesized = await self._try_final_response_without_tools(
        client=llm_client,
        model=model_name,
        messages=messages,
    )
    if not synthesized:
        return self._build_tool_limit_fallback(latest_results), used_tool_names
    return synthesized, used_tool_names
async def _try_final_response_without_tools(
    self,
    *,
    client: Any,
    model: str,
    messages: list[Message],
) -> str | None:
    """Ask the model for one last completion with no tools offered.

    Returns the text extracted from the reply, or ``None`` when the generate
    call raises (the failure is logged as a warning with traceback).
    """
    try:
        completion = await asyncio.to_thread(
            client.generate,
            **get_generate_kwargs(model=model, messages=messages),
        )
    except Exception:
        LOGGER.warning("Final no-tool synthesis call failed", exc_info=True)
        return None
    else:
        return extract_text(completion.content)
@staticmethod
def _summarize_model_note(chunks: list[str]) -> str:
text = "".join(chunks).strip()
if not text or text in {"{}", "[]"}:
return ""
compact = " ".join(text.split())
if compact.lower() in {"start", "end"}:
return ""
if len(compact) > 600:
return f"{compact[:600].rstrip()}..."
return compact
@staticmethod
def _event_text(event: Any) -> str:
for attr in ("chunk", "delta", "text", "content"):
value = getattr(event, attr, None)
if isinstance(value, str):
return value
return ""
@staticmethod
def _tool_start_message(tool_name: str) -> str:
labels = {
"getPresentationOutline": "Reading the presentation outline",
"searchSlides": "Searching relevant slides",
"getSlideAtIndex": "Opening the requested slide",
"getAvailableLayouts": "Checking available layouts",
"getContentSchemaFromLayoutId": "Checking the layout schema",
"generateAssets": "Generating slide assets",
"saveSlide": "Saving the slide",
}
return labels.get(tool_name, f"Running {tool_name}")
@staticmethod
def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str:
for entry in reversed(last_tool_results):
if not isinstance(entry, dict):
continue
if not entry.get("ok"):
continue
result = entry.get("result")
if not isinstance(result, dict):
continue
message = result.get("message")
if isinstance(message, str) and message.strip():
return message.strip()
return (
"I completed several tool operations but could not finalize the response "
"within the tool limit. Please ask a follow-up and I will continue."
)
@staticmethod
def _summarize_tool_result(tool_name: str, tool_result: dict[str, Any]) -> str:
if not tool_result.get("ok"):
error = tool_result.get("error")
if isinstance(error, str) and error.strip():
return f"{tool_name} failed: {error.strip()}"
return f"{tool_name} failed."
result = tool_result.get("result")
if isinstance(result, dict):
message = result.get("message")
if isinstance(message, str) and message.strip():
return message.strip()
note = result.get("note")
if isinstance(note, str) and note.strip():
return note.strip()
count = result.get("count")
if isinstance(count, int):
return f"{tool_name} returned {count} result(s)."
found = result.get("found")
if isinstance(found, bool):
return (
f"{tool_name} found requested data."
if found
else f"{tool_name} did not find matching data."
)
return f"{tool_name} completed."
@staticmethod
def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
messages: list[Message] = []
for item in history:
role = item.get("role")
content = item.get("content")
if not content:
continue
if role == "user":
messages.append(UserMessage(content=content))
elif role == "assistant":
messages.append(AssistantMessage(content=[content]))
return messages

View file

@ -0,0 +1,212 @@
"""Persist presentation chat threads in SQL rows."""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from models.sql.chat_history_message import ChatHistoryMessageModel
from utils.datetime_utils import get_current_utc_datetime
def _compact_preview(content: str) -> str:
preview = content.strip()
if len(preview) > 200:
return f"{preview[:200]}"
return preview
def _serialize_created_at(value: Any) -> str | None:
if value is None:
return None
if hasattr(value, "isoformat"):
try:
return value.isoformat()
except Exception:
return None
if isinstance(value, str) and value.strip():
return value.strip()
return None
def _parse_created_at(value: Any) -> datetime | None:
if isinstance(value, datetime):
return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
if not isinstance(value, str) or not value.strip():
return None
normalized = value.strip().replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(normalized)
except ValueError:
return None
return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)
async def load_messages(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> list[dict[str, str]]:
    """Load a conversation as plain ``{"role", "content"}`` dicts.

    Delegates to ``load_messages_with_meta`` and keeps only entries whose role
    and content are both strings, dropping any other keys.
    """
    detailed = await load_messages_with_meta(
        session,
        presentation_id=presentation_id,
        conversation_id=conversation_id,
    )

    plain: list[dict[str, str]] = []
    for record in detailed:
        role, content = record.get("role"), record.get("content")
        if isinstance(role, str) and isinstance(content, str):
            plain.append({"role": role, "content": content})
    return plain
async def load_messages_with_meta(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
) -> list[dict[str, Any]]:
    """Load a conversation's messages, ordered by ascending position.

    Each entry holds ``role`` and ``content``; ``created_at`` is added only
    when the row's timestamp serializes to a non-empty string.
    """
    query = (
        select(ChatHistoryMessageModel)
        .where(
            ChatHistoryMessageModel.presentation_id == presentation_id,
            ChatHistoryMessageModel.conversation_id == conversation_id,
        )
        .order_by(ChatHistoryMessageModel.position.asc())
    )
    fetched = await session.scalars(query)

    serialized: list[dict[str, Any]] = []
    for record in fetched.all():
        item: dict[str, Any] = {"role": record.role, "content": record.content}
        stamp = _serialize_created_at(record.created_at)
        if stamp:
            item["created_at"] = stamp
        serialized.append(item)
    return serialized
async def replace_messages(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
    messages: list[dict[str, str]],
) -> None:
    """Replace every stored message of a conversation with ``messages``.

    Existing rows for the conversation are deleted first. Only ``user`` /
    ``assistant`` entries with non-blank string content are inserted, numbered
    consecutively from 1. An entry's own ``created_at`` is used when it parses;
    otherwise the current UTC time plus the entry's list index in microseconds
    is used. The session is flushed, not committed.
    """
    await session.execute(
        sa_delete(ChatHistoryMessageModel).where(
            ChatHistoryMessageModel.presentation_id == presentation_id,
            ChatHistoryMessageModel.conversation_id == conversation_id,
        )
    )

    position = 0
    fallback_start = get_current_utc_datetime()
    for offset, item in enumerate(messages):
        role = item.get("role")
        body = item.get("content")
        if role not in ("user", "assistant"):
            continue
        if not (isinstance(body, str) and body.strip()):
            continue

        position += 1
        # The offset is the index in the input list (skipped entries included),
        # so fallback timestamps keep the relative order of the input.
        stamp = _parse_created_at(item.get("created_at")) or (
            fallback_start + timedelta(microseconds=offset)
        )
        session.add(
            ChatHistoryMessageModel(
                presentation_id=presentation_id,
                conversation_id=conversation_id,
                position=position,
                role=role,
                content=body,
                created_at=stamp,
            )
        )

    await session.flush()
async def append_turn(
    session: AsyncSession,
    *,
    presentation_id: uuid.UUID,
    conversation_id: uuid.UUID,
    user_message: str,
    assistant_message: str,
    tool_calls: list[str] | None = None,
) -> None:
    """Append one user message and its assistant reply to a conversation.

    Positions continue after the conversation's current maximum (starting at 1
    for an empty conversation). The assistant row is stamped one microsecond
    after the user row and carries ``tool_calls`` (an empty list is stored as
    ``None``). The session is flushed, not committed.
    """
    highest = await session.scalar(
        select(func.max(ChatHistoryMessageModel.position)).where(
            ChatHistoryMessageModel.presentation_id == presentation_id,
            ChatHistoryMessageModel.conversation_id == conversation_id,
        )
    )
    user_position = int(highest or 0) + 1
    stamp = get_current_utc_datetime()

    user_row = ChatHistoryMessageModel(
        presentation_id=presentation_id,
        conversation_id=conversation_id,
        position=user_position,
        role="user",
        content=user_message,
        created_at=stamp,
    )
    session.add(user_row)

    assistant_row = ChatHistoryMessageModel(
        presentation_id=presentation_id,
        conversation_id=conversation_id,
        position=user_position + 1,
        role="assistant",
        content=assistant_message,
        created_at=stamp + timedelta(microseconds=1),
        tool_calls=tool_calls or None,
    )
    session.add(assistant_row)

    await session.flush()
async def list_conversations(
    session: AsyncSession, *, presentation_id: uuid.UUID
) -> list[dict[str, Any]]:
    """Summarize every conversation of a presentation, most recent first.

    Rows are read newest first (``created_at``, then ``position``, both
    descending); the first row seen for each conversation supplies its
    ``updated_at`` and ``last_message_preview``. Summaries are finally sorted
    by their ``updated_at`` string in descending order.
    """
    newest_first = (
        select(ChatHistoryMessageModel)
        .where(ChatHistoryMessageModel.presentation_id == presentation_id)
        .order_by(
            ChatHistoryMessageModel.created_at.desc(),
            ChatHistoryMessageModel.position.desc(),
        )
    )
    records = (await session.scalars(newest_first)).all()

    latest_by_thread: dict[str, dict[str, Any]] = {}
    for record in records:
        thread_id = str(record.conversation_id)
        if thread_id not in latest_by_thread:
            latest_by_thread[thread_id] = {
                "conversation_id": thread_id,
                "updated_at": _serialize_created_at(record.created_at),
                "last_message_preview": _compact_preview(record.content),
            }

    return sorted(
        latest_by_thread.values(),
        key=lambda summary: summary.get("updated_at") or "",
        reverse=True,
    )

View file

@ -0,0 +1,356 @@
import json
import logging
import re
from typing import Any, Awaitable, Callable
import dirtyjson # type: ignore[import-untyped]
from llmai.shared import AssistantToolCall, Tool # type: ignore[import-not-found]
from services.chat.schemas import (
GenerateAssetsInput,
GenerateIconInput,
GenerateImageInput,
GetContentSchemaFromLayoutIdInput,
GetSlideAtIndexInput,
NoArgsInput,
SaveSlideInput,
SearchSlidesInput,
)
from services.chat.presentation_context_store import PresentationContextStore
LOGGER = logging.getLogger(__name__)
ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
class ChatTools:
    """Tools the chat LLM can call to inspect and edit a presentation.

    Publishes the tool definitions offered to the model and dispatches incoming
    tool calls to coroutine handlers backed by a ``PresentationContextStore``.
    """

    def __init__(self, memory: PresentationContextStore):
        self._memory = memory
        # Tool name -> handler. generateImage / generateIcon can be dispatched
        # here although get_tool_definitions() does not advertise them.
        self._tool_handlers: dict[str, ToolHandler] = {
            "getPresentationOutline": self._get_presentation_outline,
            "searchSlides": self._search_slides,
            "getSlideAtIndex": self._get_slide_at_index,
            "getAvailableLayouts": self._get_available_layouts,
            "getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
            "generateAssets": self._generate_assets,
            "generateImage": self._generate_image,
            "generateIcon": self._generate_icon,
            "saveSlide": self._save_slide,
        }

    def get_tool_definitions(self) -> list[Tool]:
        """Return the tool definitions offered to the model (all ``strict``)."""
        catalog = [
            (
                "getPresentationOutline",
                (
                    "Live database: current deck structure. "
                    "Use for the **actual** slide list/order and compact previews—not for uploaded PDF text or pre-outline RAG. "
                    "Falls back to stored outlines only if no slide rows exist. "
                    "Return compact sections (no full slide JSON). Use for flow, sections, or 'what slides exist'."
                ),
                NoArgsInput,
            ),
            (
                "searchSlides",
                (
                    "Live SQL slides: keyword/semantic style search with snippets and indices. "
                    "Use to find on-slide text, topics, or which slide mentioned something. "
                    "For source-document-only questions, rely on deck memory; use this when the question is about **slides as built**. "
                    "Always provide both query and limit."
                ),
                SearchSlidesInput,
            ),
            (
                "getSlideAtIndex",
                (
                    "Live SQL: one slide by index—authoritative for exact current content. "
                    "Set includeFullContent=true when you need full JSON (before saveSlide or precise edits). "
                    "If user says slide N, use zero-based index N-1."
                ),
                GetSlideAtIndexInput,
            ),
            (
                "getAvailableLayouts",
                (
                    "List all available layout ids and descriptions for the current "
                    "presentation template."
                ),
                NoArgsInput,
            ),
            (
                "getContentSchemaFromLayoutId",
                (
                    "Fetch the JSON content schema for a layout id. Use before "
                    "saving slide content to validate structure."
                ),
                GetContentSchemaFromLayoutIdInput,
            ),
            (
                "generateAssets",
                (
                    "Generate multiple media assets in one call. Use for all slide "
                    "images and icons before saving content; include every needed "
                    "asset in the assets array instead of calling image/icon tools "
                    "one at a time."
                ),
                GenerateAssetsInput,
            ),
            (
                "saveSlide",
                (
                    "Save slide content for a layout. If replaceOldSlideAtIndex is "
                    "true, replace that index; otherwise insert as a new slide. "
                    "Pass content as a JSON-serialized object string and the server "
                    "will validate it against layout schema before save."
                ),
                SaveSlideInput,
            ),
        ]
        return [
            Tool(name=name, description=description, schema=schema, strict=True)
            for name, description, schema in catalog
        ]

    async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
        """Run one tool call and wrap the outcome in an ``ok`` envelope.

        Never raises for handler or argument errors: failures (including an
        unknown tool name) are reported as ``{"ok": False, "tool", "error"}``.
        """
        name = tool_call.name
        handler = self._tool_handlers.get(name)
        if not handler:
            return {
                "ok": False,
                "tool": name,
                "error": f"Unsupported tool: {name}",
            }

        try:
            parsed_args = self._parse_args(tool_call.arguments)
            LOGGER.info("Executing chat tool %s", name)
            outcome = await handler(parsed_args)
        except Exception as exc:
            LOGGER.exception("Chat tool failed: %s", name)
            return {"ok": False, "tool": name, "error": str(exc)}
        return {"ok": True, "tool": name, "result": outcome}

    async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
        """Return a compact section list (index, slide number, title) of the outline."""
        outline = await self._memory.get("presentation_outline")
        if not isinstance(outline, dict):
            return {
                "found": False,
                "message": "Presentation outline is not available in memory yet.",
                "sections": [],
            }

        slides = outline.get("slides")
        if not (isinstance(slides, list) and slides):
            return {
                "found": False,
                "message": "Presentation outline exists but has no slides.",
                "sections": [],
            }

        sections: list[dict[str, Any]] = []
        for position, slide in enumerate(slides):
            # Defaults: list position as the index, empty text as the content.
            index, content = position, ""
            if isinstance(slide, str):
                content = slide
            elif isinstance(slide, dict):
                declared_index = slide.get("index")
                if isinstance(declared_index, int):
                    index = declared_index
                body = slide.get("content")
                if isinstance(body, str):
                    content = body
                elif body is not None:
                    # Non-string content is rendered as JSON, or via str() when
                    # it cannot be JSON-encoded.
                    try:
                        content = json.dumps(body, ensure_ascii=False)
                    except Exception:
                        content = str(body)

            sections.append(
                {
                    "index": index,
                    "slide_number": index + 1,
                    "title": self._extract_title(content) or f"Slide {index + 1}",
                }
            )

        return {
            "found": True,
            "slide_count": len(sections),
            "sections": sections,
            "source": outline.get("source", "memory"),
        }

    async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
        """Search slides through the context store and echo the query with its hits."""
        request = SearchSlidesInput(**args)
        hits = await self._memory.search(request.query, request.limit)
        return {
            "query": request.query,
            "count": len(hits),
            "results": hits,
        }

    async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
        """Fetch one slide by index, retrying at ``index - 1`` when nothing is found."""
        # includeFullContent defaults to False unless the caller supplied it.
        request = GetSlideAtIndexInput(**{"includeFullContent": False, **args})
        slide = await self._memory.get_slide_at_index(
            request.index,
            include_full_content=request.include_full_content,
        )
        if slide:
            return {
                "found": True,
                "slide": slide,
            }

        if request.index > 0:
            # Users often refer to slides as 1-based; allow a safe fallback.
            previous_index = request.index - 1
            previous_slide = await self._memory.get_slide_at_index(
                previous_index,
                include_full_content=request.include_full_content,
            )
            if previous_slide:
                return {
                    "found": True,
                    "slide": previous_slide,
                    "requested_index": request.index,
                    "resolved_index": previous_index,
                    "note": (
                        "No slide found at requested index; returned one-based fallback "
                        f"at index {previous_index}."
                    ),
                }

        return {
            "found": False,
            "message": f"No slide found at index {request.index}.",
        }

    async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
        """List the layouts reported by the context store, with their count."""
        layouts = await self._memory.get_available_layouts()
        return {"count": len(layouts), "layouts": layouts}

    async def _get_content_schema_from_layout_id(
        self, args: dict[str, Any]
    ) -> dict[str, Any]:
        """Look up the content schema of a layout id."""
        request = GetContentSchemaFromLayoutIdInput(**args)
        schema = await self._memory.get_content_schema_from_layout_id(request.layout_id)
        if schema is not None:
            return {
                "found": True,
                "layout_id": request.layout_id,
                "content_schema": schema,
            }
        return {
            "found": False,
            "layout_id": request.layout_id,
            "message": "Layout schema not found for the provided layout id.",
        }

    async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
        """Generate an image for a prompt and return the prompt with its URL."""
        request = GenerateImageInput(**args)
        url = await self._memory.generate_image(request.prompt)
        return {"prompt": request.prompt, "url": url}

    async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
        """Generate an icon for a query and return the query with its URL."""
        request = GenerateIconInput(**args)
        url = await self._memory.generate_icon(request.query)
        return {"query": request.query, "url": url}

    async def _generate_assets(self, args: dict[str, Any]) -> dict[str, Any]:
        """Generate each requested asset in order; non-image kinds are made as icons."""
        request = GenerateAssetsInput(**args)
        produced: list[dict[str, Any]] = []
        for position, asset in enumerate(request.assets):
            if asset.kind == "image":
                made = await self._generate_image({"prompt": asset.prompt})
            else:
                made = await self._generate_icon({"query": asset.prompt})
            produced.append(
                {
                    "index": position,
                    "kind": asset.kind,
                    "prompt": asset.prompt,
                    "url": made.get("url"),
                }
            )

        total = len(produced)
        return {
            "count": total,
            "assets": produced,
            "message": f"Generated {total} asset(s).",
        }

    async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
        """Validate the save request, decode its JSON content and store the slide.

        Raises:
            ValueError: if the decoded ``content`` is not a JSON object.
        """
        # The JSON round-trip produces a plain, JSON-compatible copy of the args.
        plain_args = json.loads(json.dumps(dict(args), ensure_ascii=False))
        supplied_content = plain_args.get("content")
        if isinstance(supplied_content, dict):
            # The input model receives content as a JSON string.
            plain_args["content"] = json.dumps(supplied_content, ensure_ascii=False)
        request = SaveSlideInput(**plain_args)

        # Try the lenient parser first, then the strict one.
        try:
            decoded: Any = dirtyjson.loads(request.content)
        except Exception:
            decoded = json.loads(request.content)
        if not isinstance(decoded, dict):
            raise ValueError("'content' must be a JSON object.")

        slide_content = json.loads(json.dumps(decoded, ensure_ascii=False))
        return await self._memory.save_slide(
            content=slide_content,
            layout_id=request.layout_id,
            index=request.index,
            replace_old_slide_at_index=request.replace_old_slide_at_index,
        )

    @staticmethod
    def _parse_args(arguments: str | None) -> dict[str, Any]:
        """Decode raw tool-call arguments into a plain dict.

        Empty or missing arguments give ``{}``. The lenient parser is tried
        first, then strict ``json``.

        Raises:
            ValueError: if the decoded value is not a JSON object.
        """
        if not arguments:
            return {}
        try:
            decoded = dirtyjson.loads(arguments)
        except Exception:
            decoded = json.loads(arguments)

        plain = json.loads(json.dumps(decoded, ensure_ascii=False))
        if not isinstance(plain, dict):
            raise ValueError("Tool arguments must be a JSON object.")
        return plain

    @staticmethod
    def _extract_title(markdown_content: str) -> str:
        """Derive a title from the first non-blank line of the content.

        A Markdown heading (one to six ``#``) yields its text; any other line
        yields its first 120 characters; content with no non-blank line gives
        ``""``.
        """
        first_line = next(
            (line.strip() for line in markdown_content.splitlines() if line.strip()),
            "",
        )
        if not first_line:
            return ""
        heading = re.match(r"^#{1,6}\s*(.+?)\s*$", first_line)
        if heading:
            return heading.group(1).strip()
        return first_line[:120]

    @staticmethod
    def _truncate(value: str, limit: int) -> str:
        """Return ``value`` unchanged if it fits in ``limit`` chars, else cut it and add ``...``."""
        return value if len(value) <= limit else f"{value[:limit]}..."

View file

@ -1,5 +1,4 @@
from collections.abc import AsyncGenerator
import os
from sqlalchemy.ext.asyncio import (
AsyncEngine,
create_async_engine,
@ -11,6 +10,7 @@ from sqlmodel import SQLModel
from models.sql.async_presentation_generation_status import (
AsyncPresentationGenerationTaskModel,
)
from models.sql.chat_history_message import ChatHistoryMessageModel
from models.sql.image_asset import ImageAsset
from models.sql.key_value import KeyValueSqlModel
from models.sql.ollama_pull_status import OllamaPullStatus
@ -20,7 +20,6 @@ from models.sql.template import TemplateModel
from models.sql.template_create_info import TemplateCreateInfoModel
from models.sql.slide import SlideModel
from models.sql.webhook_subscription import WebhookSubscription
from utils.get_env import get_app_data_directory_env
from utils.get_env import get_migrate_database_on_startup_env
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
@ -42,22 +41,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
yield session
# Container DB (Lives inside the app data directory)
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
container_db_engine: AsyncEngine = create_async_engine(
container_db_url, connect_args={"check_same_thread": False}
)
container_db_async_session_maker = async_sessionmaker(
container_db_engine, expire_on_commit=False
)
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
async with container_db_async_session_maker() as session:
yield session
# Create Database and Tables
async def create_db_and_tables():
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
@ -70,24 +53,18 @@ async def create_db_and_tables():
PresentationModel.__table__,
SlideModel.__table__,
KeyValueSqlModel.__table__,
ChatHistoryMessageModel.__table__,
ImageAsset.__table__,
PresentationLayoutCodeModel.__table__,
TemplateCreateInfoModel.__table__,
TemplateModel.__table__,
WebhookSubscription.__table__,
AsyncPresentationGenerationTaskModel.__table__,
OllamaPullStatus.__table__,
],
)
)
async with container_db_engine.begin() as conn:
await conn.run_sync(
lambda sync_conn: SQLModel.metadata.create_all(
sync_conn,
tables=[OllamaPullStatus.__table__],
)
)
async def dispose_engines():
"""Dispose all engine connection pools.
@ -97,4 +74,3 @@ async def dispose_engines():
database and prevent stale / leaked connections.
"""
await sql_engine.dispose()
await container_db_engine.dispose()

View file

@ -1,6 +1,8 @@
import asyncio
import json
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any, List, Optional, Tuple
@ -30,6 +32,129 @@ except Exception:
LOGGER = logging.getLogger(__name__)
def _unwrap_liteparse_json_line_if_stored(text: str) -> str:
"""If the whole JSON line from the LiteParse runner was stored as the document, keep only the text field."""
if not text:
return text
s = text.lstrip()
if not s.startswith("{"):
return text
try:
payload = json.loads(s)
except (json.JSONDecodeError, TypeError, ValueError):
return text
if not isinstance(payload, dict):
return text
if (
payload.get("ok") is True
and "filePath" in payload
and isinstance(payload.get("text"), str)
):
return payload["text"]
return text
_RE_TEXT_KEY = re.compile(r'"text"\s*:\s*"')
def _json_unescape_quoted_value(s: str, content_start: int) -> str:
"""
Unescape a JSON string value. `content_start` is the index of the first character
*inside* the value (immediately after the opening quote of the "text" field).
If the closing quote is missing (truncated), returns the unescaped rest of the string.
"""
out: list[str] = []
i = content_start
n = len(s)
while i < n:
c = s[i]
if c == "\\" and i + 1 < n:
e = s[i + 1]
if e in '"\\':
out.append(e)
i += 2
elif e == "/":
out.append("/")
i += 2
elif e == "b":
out.append("\b")
i += 2
elif e == "f":
out.append("\f")
i += 2
elif e == "n":
out.append("\n")
i += 2
elif e == "r":
out.append("\r")
i += 2
elif e == "t":
out.append("\t")
i += 2
elif e == "u" and i + 5 < n:
try:
out.append(chr(int(s[i + 2 : i + 6], 16)))
except (ValueError, OverflowError):
out.append(s[i : i + 6])
i += 6
else:
out.append(e)
i += 2
elif c == '"':
return "".join(out)
else:
out.append(c)
i += 1
return "".join(out)
def _try_extract_liteparse_text_value_from_malformed_json(s: str) -> Optional[str]:
"""
When json.loads failed (e.g. truncated or corrupt), find the "text" field value
in a LiteParse-shaped object and return only the unescaped string body.
"""
if not s.startswith("{"):
return None
head = s[:10000] if len(s) > 10000 else s
if not ("ok" in head and "filePath" in head):
return None
m = _RE_TEXT_KEY.search(s)
if not m:
return None
return _json_unescape_quoted_value(s, m.end())
def _clean_extracted_one_pass(t: str) -> str:
    """Run one cleaning pass over extracted text.

    Valid LiteParse JSON wrappers are peeled off first (at most three layers).
    If what remains still starts with ``{`` after leading whitespace, the
    malformed-JSON extractor is tried and its result returned when it finds a
    text value. Otherwise the (possibly unwrapped) text is returned.
    """
    current = t
    for _ in range(3):
        unwrapped = _unwrap_liteparse_json_line_if_stored(current)
        if unwrapped == current:
            break
        current = unwrapped

    trimmed = current.lstrip()
    recovered = (
        _try_extract_liteparse_text_value_from_malformed_json(trimmed)
        if trimmed.startswith("{")
        else None
    )
    return current if recovered is None else recovered
def clean_extracted_document_text(text: str) -> str:
    """Strip LiteParse JSON wrappers so only the document body remains.

    Applies the single-pass cleaner repeatedly, up to four times, stopping as
    soon as a pass changes nothing. This covers bodies that are themselves
    JSON-shaped as well as truncated/invalid wrappers. Empty input is returned
    as-is.
    """
    current = text
    passes_left = 4 if text else 0
    while passes_left > 0:
        cleaned = _clean_extracted_one_pass(current)
        if cleaned == current:
            break
        current = cleaned
        passes_left -= 1
    return current
class DocumentsLoader:
DECOMPOSE_TIMEOUT_SECONDS = 600
@ -107,6 +232,7 @@ class DocumentsLoader:
else:
document = await asyncio.to_thread(self._parse_with_liteparse, file_path)
document = clean_extracted_document_text(document)
documents.append(document)
images.append(imgs)

View file

@ -227,6 +227,11 @@ class LiteParseService:
return True, "ok"
@staticmethod
def _use_json_runner_output() -> bool:
"""If true, expect one JSON line on stdout (legacy). Default is plain UTF-8 text (better for large PDFs)."""
return (os.getenv("LITEPARSE_RUNNER_OUTPUT") or "").strip().lower() == "json"
def parse_to_markdown(
self,
file_path: str,
@ -271,6 +276,9 @@ class LiteParseService:
if tessdata:
command.extend(["--tessdata-path", tessdata])
use_json = self._use_json_runner_output()
command.extend(["--python-bridge", "json" if use_json else "plain"])
LOGGER.info(
"[LiteParse] Parsing file=%s ocr_enabled=%s ocr_language=%s dpi=%s num_workers=%s",
file_path,
@ -294,6 +302,20 @@ class LiteParseService:
_command_str(command),
)
if not use_json:
if process.returncode != 0:
err = (process.stderr or "").strip() or "LiteParse failed"
raise LiteParseError(
f"{err}; returncode={process.returncode}; "
f"stderr={_snippet(process.stderr)}; stdout={_snippet(process.stdout)}"
)
return {
"ok": True,
"text": (process.stdout or "").lstrip("\ufeff"),
"filePath": file_path,
"pageCount": 0,
}
payload: Dict[str, Any]
try:
payload = self._decode_runner_output(process.stdout)

View file

@ -0,0 +1,131 @@
"""Single shared mem0 OSS ``Memory`` client for the process.
All callers (presentation context, chat turns) use the same on-disk Qdrant/SQLite
and distinguish data via mem0 ``user_id``:
- Deck-level (no chat thread): ``{namespace}:{presentation_id}``
- Chat thread: ``{namespace}:{presentation_id}:conversation:{conversation_id}``
The chat flow calls ``ensure_conversation_id`` before the first turn, so a
``conversation_id`` exists before any mem0 write for that thread.
"""
from __future__ import annotations
import logging
import os
import threading
from importlib import import_module
from typing import Any, Optional
LOGGER = logging.getLogger(__name__)
_memory_init_lock = threading.Lock()
_shared_client: Any | None = None
_init_attempted = False
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
return default
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _to_int(value: Optional[str], default: int) -> int:
try:
parsed = int(value) if value is not None else default
return max(1, parsed)
except Exception:
return default
def _oss_config_from_env() -> tuple[str, str, str, str, int, dict[str, Any]]:
"""Return (mem0_dir, qdrant_path, history_db, collection, dims, from_config_dict)."""
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
qdrant_path = (
os.getenv("MEM0_QDRANT_PATH") or os.path.join(mem0_dir, "qdrant")
).strip()
history_db_path = (
os.getenv("MEM0_HISTORY_DB_PATH") or os.path.join(mem0_dir, "history.db")
).strip()
collection = (
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
).strip() or "presenton_memories"
embedder = (os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed").strip() or "fastembed"
model = (
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
).strip() or "BAAI/bge-small-en-v1.5"
dims = _to_int(os.getenv("MEM0_EMBEDDING_DIMS"), default=384)
config: dict[str, Any] = {
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": collection,
"path": qdrant_path,
"on_disk": True,
"embedding_model_dims": dims,
},
},
"embedder": {
"provider": embedder,
"config": {
"model": model,
"embedding_dims": dims,
},
},
"history_db_path": history_db_path,
}
return mem0_dir, qdrant_path, history_db_path, collection, dims, config
def memory_from_config(config: dict[str, Any], *, telemetry_base: str) -> Any:
    """Build a ``mem0.Memory`` from ``config``.

    ``telemetry_base`` is created if needed and assigned to
    ``mem0.memory.main.mem0_dir`` before the instance is constructed. Caller
    must hold ``_memory_init_lock`` if used with shared state.
    """
    os.makedirs(telemetry_base, exist_ok=True)

    import mem0.memory.main as mem0_main_module  # type: ignore[import-untyped]

    mem0_main_module.mem0_dir = telemetry_base
    mem0_package = import_module("mem0")
    return mem0_package.Memory.from_config(config)
def get_shared_mem0_client() -> Any | None:
    """Return the process-wide mem0 client, or ``None`` if disabled or init failed.

    ``MEM0_ENABLED`` (default: enabled) is consulted on every call.
    Initialization is attempted at most once per process: a failed attempt is
    remembered and later calls return ``None`` without retrying.

    Failures raised as ``Exception`` or ``SystemExit`` are logged and turn
    into ``None``. Other ``BaseException`` types, notably
    ``KeyboardInterrupt``, propagate to the caller; the attempt still counts
    as made.
    """
    global _shared_client, _init_attempted
    if not _to_bool(os.getenv("MEM0_ENABLED"), default=True):
        return None

    # Checks before taking the lock, for the common already-resolved case.
    if _shared_client is not None:
        return _shared_client
    if _init_attempted:
        return None

    with _memory_init_lock:
        # Re-check under the lock: another thread may have completed (or
        # failed) initialization while this one was waiting.
        if _shared_client is not None:
            return _shared_client
        if _init_attempted:
            return None
        _init_attempted = True
        try:
            mem0_dir, qdrant_path, history_db, collection, dims, config = (
                _oss_config_from_env()
            )
            os.makedirs(mem0_dir, exist_ok=True)
            os.makedirs(qdrant_path, exist_ok=True)
            telemetry_base = os.path.join(mem0_dir, "telemetry", "oss")
            _shared_client = memory_from_config(
                config,
                telemetry_base=telemetry_base,
            )
            LOGGER.info(
                "Mem0 OSS shared memory initialized (qdrant_path=%s, history_db_path=%s, collection=%s, dims=%s)",
                qdrant_path,
                history_db,
                collection,
                dims,
            )
        except (Exception, SystemExit):
            # SystemExit is treated as non-fatal alongside ordinary errors;
            # interrupts are deliberately not caught here.
            LOGGER.exception("Failed to initialize shared Mem0 OSS Memory")
            _shared_client = None
        return _shared_client

View file

@ -2,10 +2,11 @@ import asyncio
import json
import logging
import os
from importlib import import_module
from typing import Any, Optional
from uuid import UUID
from services.mem0_oss_memory import get_shared_mem0_client
LOGGER = logging.getLogger(__name__)
@ -21,31 +22,6 @@ class Mem0PresentationMemoryService:
os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation"
).strip() or "presentation"
self._embedder_provider = (
os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed"
).strip() or "fastembed"
self._embedder_model = (
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
).strip() or "BAAI/bge-small-en-v1.5"
self._embedding_dims = self._to_int(
os.getenv("MEM0_EMBEDDING_DIMS"),
default=384,
)
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
self._mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
self._qdrant_path = (os.getenv("MEM0_QDRANT_PATH") or os.path.join(self._mem0_dir, "qdrant")).strip()
self._history_db_path = (
os.getenv("MEM0_HISTORY_DB_PATH")
or os.path.join(self._mem0_dir, "history.db")
).strip()
self._collection_name = (
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
).strip() or "presenton_memories"
self._client: Any = None
self._attempted_client_init = False
@staticmethod
def _to_bool(value: Optional[str], default: bool = False) -> bool:
if value is None:
@ -68,27 +44,6 @@ class Mem0PresentationMemoryService:
return text
return f"{text[:limit]}\n\n[TRUNCATED]"
def _get_oss_config(self) -> dict:
return {
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": self._collection_name,
"path": self._qdrant_path,
"on_disk": True,
"embedding_model_dims": self._embedding_dims,
},
},
"embedder": {
"provider": self._embedder_provider,
"config": {
"model": self._embedder_model,
"embedding_dims": self._embedding_dims,
},
},
"history_db_path": self._history_db_path,
}
@staticmethod
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
    """Return True when ``exc`` is an ``Exception`` or a ``SystemExit``.

    Other ``BaseException`` subclasses (for example ``KeyboardInterrupt``)
    are reported as fatal.
    """
    return isinstance(exc, (Exception, SystemExit))
@ -96,42 +51,7 @@ class Mem0PresentationMemoryService:
async def _get_client(self):
if not self._enabled:
return None
if self._client is not None:
return self._client
if self._attempted_client_init:
return None
self._attempted_client_init = True
try:
module = import_module("mem0")
memory_cls = getattr(module, "Memory")
os.makedirs(self._mem0_dir, exist_ok=True)
os.makedirs(self._qdrant_path, exist_ok=True)
config = self._get_oss_config()
try:
self._client = memory_cls.from_config(config)
except Exception:
# Backward compatibility across mem0 OSS versions.
self._client = memory_cls(config)
LOGGER.info(
"Mem0 OSS presentation memory service initialized (qdrant_path=%s, history_db_path=%s)",
self._qdrant_path,
self._history_db_path,
)
except BaseException as exc:
if not self._is_nonfatal_mem0_error(exc):
raise
LOGGER.exception("Failed to initialize Mem0 OSS Memory")
self._client = None
return self._client
return get_shared_mem0_client()
async def _add_message(self, presentation_id: UUID, message: str):
client = await self._get_client()

View file

@ -0,0 +1,137 @@
import asyncio
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
from services.chat.conversation_store import ChatConversationStore
class TestChatConversationStore:
    """Unit tests for ``ChatConversationStore`` with the SQL and mem0 layers mocked.

    Every test replaces the module-level collaborators referenced from
    ``services.chat.conversation_store`` with ``AsyncMock`` objects and drives
    the coroutine under test through ``asyncio.run``.
    """

    def test_load_history_reads_sql_first(self):
        """SQL rows are returned as-is and the mem0 loader is never touched."""
        pres_id = uuid.uuid4()
        conv_id = uuid.uuid4()
        stored_turns = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ]
        session = MagicMock()
        sql_loader = AsyncMock(return_value=stored_turns)
        mem0_loader = AsyncMock()

        with patch(
            "services.chat.conversation_store.sql_chat_history.load_messages",
            new=sql_loader,
        ), patch(
            "services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
            new=mem0_loader,
        ):
            conversation_store = ChatConversationStore(session)
            loaded = asyncio.run(
                conversation_store.load_history(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                )
            )

        sql_loader.assert_awaited_once_with(
            session,
            presentation_id=pres_id,
            conversation_id=conv_id,
        )
        mem0_loader.assert_not_called()
        assert loaded == stored_turns

    def test_load_history_falls_back_to_mem0_and_backfills_sql(self):
        """An empty SQL result triggers the mem0 read, whose turns are written back to SQL."""
        pres_id = uuid.uuid4()
        conv_id = uuid.uuid4()
        mem0_turns = [
            {"role": "user", "content": "old"},
            {"role": "assistant", "content": "from mem0"},
        ]
        session = MagicMock()
        sql_loader = AsyncMock(return_value=[])
        mem0_loader = AsyncMock(return_value=mem0_turns)
        sql_replacer = AsyncMock()

        with patch(
            "services.chat.conversation_store.sql_chat_history.load_messages",
            new=sql_loader,
        ), patch(
            "services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
            new=mem0_loader,
        ), patch(
            "services.chat.conversation_store.sql_chat_history.replace_messages",
            new=sql_replacer,
        ):
            conversation_store = ChatConversationStore(session)
            loaded = asyncio.run(
                conversation_store.load_history(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                )
            )

        sql_loader.assert_awaited_once()
        mem0_loader.assert_awaited_once()
        sql_replacer.assert_awaited_once_with(
            session,
            presentation_id=pres_id,
            conversation_id=conv_id,
            messages=mem0_turns,
        )
        assert loaded == mem0_turns

    def test_append_turn_persists_sql_and_mem0(self):
        """A new turn is written to SQL (with ``tool_calls=None``) and to mem0."""
        session = MagicMock()
        conversation_store = ChatConversationStore(session)
        pres_id = uuid.uuid4()
        conv_id = uuid.uuid4()
        question = "Can you improve slide 2?"
        answer = "Yes, I will tighten the bullet points."
        sql_appender = AsyncMock()
        mem0_writer = AsyncMock()

        with patch(
            "services.chat.conversation_store.sql_chat_history.append_turn",
            new=sql_appender,
        ), patch(
            "services.chat.conversation_store.CHAT_MEMORY_STORE.store_chat_turn",
            new=mem0_writer,
        ):
            asyncio.run(
                conversation_store.append_turn(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                    user_message=question,
                    assistant_message=answer,
                )
            )

        sql_appender.assert_awaited_once_with(
            session,
            presentation_id=pres_id,
            conversation_id=conv_id,
            user_message=question,
            assistant_message=answer,
            tool_calls=None,
        )
        mem0_writer.assert_awaited_once_with(
            presentation_id=pres_id,
            conversation_id=conv_id,
            user_message=question,
            assistant_message=answer,
        )

    def test_retrieve_semantic_context_delegates_to_chat_memory_store(self):
        """Semantic retrieval forwards its arguments to the chat memory store and returns its result."""
        conversation_store = ChatConversationStore(MagicMock())
        pres_id = uuid.uuid4()
        conv_id = uuid.uuid4()
        question = "What did we decide?"
        context_lookup = AsyncMock(return_value="conversation-scoped context")

        with patch(
            "services.chat.conversation_store.CHAT_MEMORY_STORE.retrieve_context",
            new=context_lookup,
        ):
            retrieved = asyncio.run(
                conversation_store.retrieve_semantic_context(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                    query=question,
                )
            )

        context_lookup.assert_awaited_once_with(
            presentation_id=pres_id,
            conversation_id=conv_id,
            query=question,
        )
        assert retrieved == "conversation-scoped context"

View file

@ -0,0 +1,249 @@
import asyncio
import uuid
from unittest.mock import patch
import services.mem0_oss_memory as mem0_oss
from services.chat.chat_memory_store import ChatMemoryStore
class FakeMemoryClient:
    """Recording test double used in place of the shared mem0 client.

    Each call to ``add``, ``search`` or ``get_all`` is appended to the matching
    ``*_calls`` list so tests can inspect the arguments afterwards. ``search``
    and ``get_all`` return whatever is currently stored in
    ``next_search_response`` / ``next_get_all_response``.
    """

    # Every instance ever constructed is registered here; tests reset the list
    # in ``setup_method``.
    instances: list["FakeMemoryClient"] = []

    def __init__(self, config=None):
        """Store ``config``, start with empty call logs, and register the instance."""
        self.config = config
        self.add_calls = []
        self.search_calls = []
        self.get_all_calls = []
        self.next_search_response = {"results": []}
        self.next_get_all_response = {"results": []}
        FakeMemoryClient.instances.append(self)

    @classmethod
    def from_config(cls, config):
        """Alternate constructor that builds the fake from a config mapping."""
        return cls(config=config)

    def add(self, *args, **kwargs):
        """Record an ``add`` call and return ``{"ok": True}``.

        The ``messages`` keyword wins; when it is absent or ``None`` the first
        positional argument (if any) is recorded instead.
        """
        recorded_messages = kwargs.get("messages")
        if recorded_messages is None and args:
            recorded_messages = args[0]
        entry = {"messages": recorded_messages}
        for key in ("user_id", "infer"):
            entry[key] = kwargs.get(key)
        self.add_calls.append(entry)
        return {"ok": True}

    def search(self, query, *args, **kwargs):
        """Record a ``search`` call and return the canned search response."""
        entry = {"query": query}
        for key in ("filters", "user_id", "top_k"):
            entry[key] = kwargs.get(key)
        self.search_calls.append(entry)
        return self.next_search_response

    def get_all(self, *args, **kwargs):
        """Record a ``get_all`` call and return the canned listing response."""
        entry = {key: kwargs.get(key) for key in ("filters", "user_id", "limit")}
        self.get_all_calls.append(entry)
        return self.next_get_all_response
def _mem0_oss_fresh() -> None:
    """Reset the private module state of ``services.mem0_oss_memory``.

    Sets ``_shared_client`` to ``None`` and ``_init_attempted`` to ``False``.
    Presumably this lets each test start without a cached client; confirm
    against the module itself.
    """
    pristine_state = (("_shared_client", None), ("_init_attempted", False))
    for attribute, value in pristine_state:
        setattr(mem0_oss, attribute, value)
class TestChatMemoryStore:
    """Tests for ``ChatMemoryStore`` running against a ``FakeMemoryClient``.

    Each test patches the environment and replaces
    ``get_shared_mem0_client`` (as referenced from
    ``services.chat.chat_memory_store``) so that it hands back one fake
    client, then inspects the calls recorded on that fake.
    """

    def setup_method(self):
        """Forget fake clients from earlier tests and reset the mem0 module state."""
        FakeMemoryClient.instances = []
        _mem0_oss_fresh()

    def test_store_chat_turn_uses_conversation_scoped_user_id(self):
        """Storing a turn calls ``add`` once with the conversation-scoped user id and ``infer=False``."""
        env_overrides = {
            "MEM0_ENABLED": "true",
            "MEM0_TOP_K": "4",
            "MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
            "APP_DATA_DIRECTORY": "/tmp/presenton-test",
        }
        with patch.dict("os.environ", env_overrides, clear=False), patch(
            "services.chat.chat_memory_store.get_shared_mem0_client",
            return_value=FakeMemoryClient.from_config(
                {
                    "vector_store": {"provider": "qdrant", "config": {}},
                    "embedder": {"provider": "fastembed", "config": {}},
                }
            ),
        ):
            memory_store = ChatMemoryStore()
            pres_id = uuid.uuid4()
            conv_id = uuid.uuid4()

            asyncio.run(
                memory_store.store_chat_turn(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                    user_message="Can you tighten slide 3?",
                    assistant_message="Yes, I can make it shorter.",
                )
            )

            assert len(FakeMemoryClient.instances) == 1
            fake = FakeMemoryClient.instances[0]
            assert len(fake.add_calls) == 1
            recorded = fake.add_calls[0]

            scoped_user_id = f"presentation:{pres_id}:conversation:{conv_id}"
            assert recorded["user_id"] == scoped_user_id
            assert recorded["infer"] is False

            # The stored payload is one tagged text blob holding both sides of the turn.
            stored_text = str(recorded["messages"][0]["content"])
            assert "[chat_turn]" in stored_text
            assert "user=Can you tighten slide 3?" in stored_text
            assert "assistant=Yes, I can make it shorter." in stored_text

    def test_retrieve_context_reads_only_conversation_scoped_user_id(self):
        """Retrieval searches once, filtered by the scoped user id, and de-duplicates hits."""
        env_overrides = {
            "MEM0_ENABLED": "true",
            "MEM0_TOP_K": "6",
            "MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
            "APP_DATA_DIRECTORY": "/tmp/presenton-test",
        }
        with patch.dict("os.environ", env_overrides, clear=False), patch(
            "services.chat.chat_memory_store.get_shared_mem0_client",
            return_value=FakeMemoryClient.from_config(
                {
                    "vector_store": {"provider": "qdrant", "config": {}},
                    "embedder": {"provider": "fastembed", "config": {}},
                }
            ),
        ):
            memory_store = ChatMemoryStore()
            pres_id = uuid.uuid4()
            conv_id = uuid.uuid4()
            scoped_user_id = f"presentation:{pres_id}:conversation:{conv_id}"

            asyncio.run(
                memory_store.store_chat_turn(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                    user_message="First turn",
                    assistant_message="First answer",
                )
            )

            fake = FakeMemoryClient.instances[0]
            # "Chat memory A" is returned twice on purpose to exercise de-duplication.
            fake.next_search_response = {
                "results": [
                    {"memory": "Chat memory A"},
                    {"memory": "Chat memory A"},
                    {"memory": "Chat memory B"},
                ]
            }

            retrieved = asyncio.run(
                memory_store.retrieve_context(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                    query="What did we decide?",
                )
            )

            assert "Chat memory A" in retrieved
            assert "Chat memory B" in retrieved
            assert retrieved.count("Chat memory A") == 1

            assert len(fake.search_calls) == 1
            search_record = fake.search_calls[0]
            assert search_record["query"] == "What did we decide?"
            assert search_record["filters"] == {"user_id": scoped_user_id}
            assert search_record["top_k"] == 6

    def test_load_history_reads_conversation_scoped_turns(self):
        """Stored ``[chat_turn]`` memories come back as alternating user/assistant messages."""
        env_overrides = {
            "MEM0_ENABLED": "true",
            "MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
            "APP_DATA_DIRECTORY": "/tmp/presenton-test",
        }
        with patch.dict("os.environ", env_overrides, clear=False), patch(
            "services.chat.chat_memory_store.get_shared_mem0_client",
            return_value=FakeMemoryClient.from_config(
                {
                    "vector_store": {"provider": "qdrant", "config": {}},
                    "embedder": {"provider": "fastembed", "config": {}},
                }
            ),
        ):
            memory_store = ChatMemoryStore()
            pres_id = uuid.uuid4()
            conv_id = uuid.uuid4()
            scoped_user_id = f"presentation:{pres_id}:conversation:{conv_id}"

            asyncio.run(
                memory_store.store_chat_turn(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                    user_message="Draft intro",
                    assistant_message="Updated intro done.",
                )
            )

            fake = FakeMemoryClient.instances[0]
            first_turn = {
                "memory": (
                    "[chat_turn]\n"
                    "turn_created_at=2026-04-25T10:00:00+00:00\n"
                    "user=Draft intro\nassistant=Updated intro done."
                ),
                "created_at": "2026-04-25T10:00:01+00:00",
            }
            second_turn = {
                "memory": (
                    "[chat_turn]\n"
                    "turn_created_at=2026-04-25T10:01:00+00:00\n"
                    "user=Add roadmap\nassistant=Roadmap slide added."
                ),
                "created_at": "2026-04-25T10:01:01+00:00",
            }
            fake.next_get_all_response = {"results": [first_turn, second_turn]}

            loaded = asyncio.run(
                memory_store.load_history(
                    presentation_id=pres_id,
                    conversation_id=conv_id,
                )
            )

            assert loaded == [
                {"role": "user", "content": "Draft intro"},
                {"role": "assistant", "content": "Updated intro done."},
                {"role": "user", "content": "Add roadmap"},
                {"role": "assistant", "content": "Roadmap slide added."},
            ]
            assert len(fake.get_all_calls) == 1
            listing_record = fake.get_all_calls[0]
            assert listing_record["filters"] == {"user_id": scoped_user_id}
            assert listing_record["limit"] >= 10

View file

@ -0,0 +1,60 @@
import json
from services.documents_loader import (
_unwrap_liteparse_json_line_if_stored,
clean_extracted_document_text,
)
def test_unwrap_strips_liteparse_json_line():
    """A LiteParse JSON line is reduced to its ``text`` value, with or without leading whitespace."""
    body = "Title\n\nBody with \"quotes\" and\nnewlines."
    payload = {"ok": True, "filePath": "/tmp/x.pdf", "text": body}
    encoded = json.dumps(payload, ensure_ascii=False)
    for candidate in (encoded, " \n" + encoded):
        assert _unwrap_liteparse_json_line_if_stored(candidate) == body
def test_unwrap_leaves_plain_text():
    """Prose that merely contains braces is handed back as the very same object."""
    prose = "Not JSON. {Braces} in prose."
    unwrapped = _unwrap_liteparse_json_line_if_stored(prose)
    assert unwrapped is prose
def test_unwrap_rejects_malformed_json():
    """Input that starts like JSON but does not parse is returned untouched (same object)."""
    broken = "{not valid json"
    unwrapped = _unwrap_liteparse_json_line_if_stored(broken)
    assert unwrapped is broken
def test_clean_extracts_text_when_json_truncated():
    """The ``text`` body is still recovered when the JSON wrapper is cut off before it closes."""
    # The closing brace is missing, so this blob is not valid JSON on its own.
    truncated = (
        '{"ok": true, "filePath": "/tmp/x.pdf", "text": " similarweb | HypeAuditor\\n\\n2024" '
    )
    cleaned = clean_extracted_document_text(truncated)
    assert "similarweb" in cleaned
    # None of the wrapper's key names may leak into the cleaned text.
    for wrapper_key in ("ok", "filePath"):
        assert wrapper_key not in cleaned
def test_clean_same_as_unwrap_for_valid_line():
    """For a well-formed LiteParse JSON line, cleaning yields exactly the ``text`` value."""
    body = "Prose only."
    payload = {"ok": True, "filePath": "/tmp/x.pdf", "text": body}
    encoded = json.dumps(payload, ensure_ascii=False)
    assert clean_extracted_document_text(encoded) == body
def test_clean_double_json_embedded_in_text_field():
    """A LiteParse line whose ``text`` is itself a LiteParse line is unwrapped down to the innermost body."""
    innermost = "Final body."

    def wrap(file_path, text):
        # Serialise one LiteParse-style result line around ``text``.
        return json.dumps(
            {"ok": True, "filePath": file_path, "text": text},
            ensure_ascii=False,
        )

    doubly_wrapped = wrap("/b.pdf", wrap("/a.pdf", innermost))
    assert clean_extracted_document_text(doubly_wrapped) == innermost

View file

@ -4,8 +4,12 @@ from unittest.mock import patch
from services.liteparse_service import LiteParseService
def _ok_process(stdout: str = '{"ok": true, "text": "ok"}'):
return SimpleNamespace(returncode=0, stdout=stdout, stderr="")
def _ok_process(
stdout: str = "ok",
returncode: int = 0,
stderr: str = "",
):
return SimpleNamespace(returncode=returncode, stdout=stdout, stderr=stderr)
class TestLiteParseService:
@ -26,13 +30,16 @@ class TestLiteParseService:
return_value=_ok_process(),
) as mock_run:
service = LiteParseService(timeout_seconds=30)
service.parse("/tmp/sample.pdf", ocr_enabled=True, ocr_language="eng")
r = service.parse("/tmp/sample.pdf", ocr_enabled=True, ocr_language="eng")
assert r["ok"] is True
assert r["text"] == "ok"
command = mock_run.call_args.args[0]
assert "--dpi" in command
assert command[command.index("--dpi") + 1] == "120"
assert "--num-workers" in command
assert command[command.index("--num-workers") + 1] == "1"
assert command[command.index("--python-bridge") + 1] == "plain"
def test_parse_uses_env_overrides(self):
with patch.dict(
@ -79,3 +86,23 @@ class TestLiteParseService:
command = mock_run.call_args.args[0]
assert command[command.index("--dpi") + 1] == "72"
assert command[command.index("--num-workers") + 1] == "1"
def test_parse_json_bridge_env(self):
    """With ``LITEPARSE_RUNNER_OUTPUT=json`` the runner is invoked with ``--python-bridge json``.

    The mocked subprocess returns one JSON line on stdout, and the parsed
    result's ``text`` must come from that JSON payload.
    """
    legacy_process = _ok_process(stdout='{"ok": true, "text": "legacy"}\n')
    env_overrides = {"LITEPARSE_RUNNER_OUTPUT": "json"}

    with patch.dict("os.environ", env_overrides, clear=False), patch.object(
        LiteParseService,
        "check_runtime_ready",
        return_value=(True, "ok"),
    ), patch(
        "services.liteparse_service.subprocess.run",
        return_value=legacy_process,
    ) as mock_run:
        parser = LiteParseService(timeout_seconds=30)
        parsed = parser.parse("/tmp/sample.pdf", ocr_enabled=True, ocr_language="eng")

        assert parsed["text"] == "legacy"

        # The first positional argument to subprocess.run is the argv list.
        argv = mock_run.call_args.args[0]
        bridge_flag_position = argv.index("--python-bridge")
        assert argv[bridge_flag_position + 1] == "json"

View file

@ -2,11 +2,12 @@ import asyncio
import uuid
from unittest.mock import patch
import services.mem0_oss_memory as mem0_oss
from services.mem0_presentation_memory_service import Mem0PresentationMemoryService
class FakeMemoryClient:
instances = []
instances: list["FakeMemoryClient"] = []
def __init__(self, config=None):
self.config = config
@ -45,13 +46,15 @@ class FakeMemoryClient:
return self.next_search_response
class FakeMem0Module:
Memory = FakeMemoryClient
def _mem0_oss_fresh() -> None:
mem0_oss._shared_client = None # type: ignore[attr-defined]
mem0_oss._init_attempted = False # type: ignore[attr-defined]
class TestMem0PresentationMemoryService:
def setup_method(self):
FakeMemoryClient.instances = []
_mem0_oss_fresh()
def test_store_generation_context_uses_presentation_scope(self):
with patch.dict(
@ -62,8 +65,25 @@ class TestMem0PresentationMemoryService:
},
clear=False,
), patch(
"services.mem0_presentation_memory_service.import_module",
return_value=FakeMem0Module,
"services.mem0_presentation_memory_service.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {
"provider": "qdrant",
"config": {
"on_disk": True,
"embedding_model_dims": 384,
},
},
"embedder": {
"provider": "fastembed",
"config": {
"model": "BAAI/bge-small-en-v1.5",
"embedding_dims": 384,
},
},
}
),
):
service = Mem0PresentationMemoryService()
presentation_id = uuid.uuid4()
@ -115,8 +135,19 @@ class TestMem0PresentationMemoryService:
},
clear=False,
), patch(
"services.mem0_presentation_memory_service.import_module",
return_value=FakeMem0Module,
"services.mem0_presentation_memory_service.get_shared_mem0_client",
return_value=FakeMemoryClient.from_config(
{
"vector_store": {"provider": "qdrant", "config": {}},
"embedder": {
"provider": "fastembed",
"config": {
"model": "BAAI/bge-small-en-v1.5",
"embedding_dims": 384,
},
},
}
),
):
service = Mem0PresentationMemoryService()
presentation_id = uuid.uuid4()
@ -154,3 +185,4 @@ class TestMem0PresentationMemoryService:
"user_id": f"presentation:{presentation_id}"
}
assert client.search_calls[0]["top_k"] == 5

View file

@ -94,7 +94,7 @@ def get_llm_config() -> ClientConfig:
raise HTTPException(status_code=400, detail="OpenAI API Key is not set")
return OpenAIClientConfig(
api_key=api_key,
api_type=OpenAIApiType.RESPONSES,
api_type=OpenAIApiType.COMPLETIONS,
)
case LLMProvider.GOOGLE:
api_key = get_google_api_key_env()

View file

@ -10,8 +10,42 @@ from llmai.shared import (
ResponseFormat,
normalize_content_parts,
)
from llmai.shared.tools import Tool # type: ignore[import-not-found]
from pydantic import BaseModel
from enums.llm_provider import LLMProvider
from utils.llm_config import get_extra_body
from utils.llm_provider import get_llm_provider
from utils.schema_utils import flatten_json_schema
def _tools_for_google_gemini(tools: list[LLMTool]) -> list[LLMTool]:
    """Return ``tools`` with each ``Tool``'s JSON schema flattened for Gemini.

    Gemini's Python SDK rejects ``$ref`` / ``$defs`` in function parameters, so
    the schema of every ``Tool`` is passed through ``flatten_json_schema`` and a
    new ``Tool`` is built around the result. Entries that are not ``Tool``
    instances, or whose ``input_schema`` is neither a dict, a ``BaseModel``
    subclass, nor a ``BaseModel`` instance, are passed through unchanged.

    Args:
        tools: Tool definitions about to be sent to the model.

    Returns:
        A new list of the same length and order as ``tools``.
    """

    def _raw_schema(schema_source):
        # Produce a plain JSON-schema dict, or None when the source is unsupported.
        if isinstance(schema_source, dict):
            # Shallow copy: the caller's top-level dict object is not reused.
            return dict(schema_source)
        if isinstance(schema_source, type) and issubclass(schema_source, BaseModel):
            return schema_source.model_json_schema()
        if isinstance(schema_source, BaseModel):
            return schema_source.__class__.model_json_schema()
        return None

    gemini_ready: list[LLMTool] = []
    for candidate in tools:
        raw_schema = (
            _raw_schema(candidate.input_schema) if isinstance(candidate, Tool) else None
        )
        if raw_schema is None:
            gemini_ready.append(candidate)
            continue

        # NOTE(review): the schema is read from ``input_schema`` but supplied to
        # the constructor as ``schema=`` -- confirm ``Tool`` accepts that keyword.
        flattened_tool = Tool(
            name=candidate.name,
            description=candidate.description,
            schema=flatten_json_schema(raw_schema),
            strict=candidate.strict,
        )
        gemini_ready.append(flattened_tool)
    return gemini_ready
def get_generate_kwargs(
@ -30,7 +64,10 @@ def get_generate_kwargs(
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if tools:
kwargs["tools"] = tools
if get_llm_provider() == LLMProvider.GOOGLE:
kwargs["tools"] = _tools_for_google_gemini(tools)
else:
kwargs["tools"] = tools
if response_format is not None:
kwargs["response_format"] = response_format

View file

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = "==3.11.*"
[[package]]
@ -811,9 +811,7 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/c6/dba32cab7e3a625b011aa5647486e2d28423a48845a2998c126dd69c85e1/greenlet-3.4.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:805bebb4945094acbab757d34d6e1098be6de8966009ab9ca54f06ff492def58", size = 285504, upload-time = "2026-04-08T15:52:14.071Z" },
{ url = "https://files.pythonhosted.org/packages/54/f4/7cb5c2b1feb9a1f50e038be79980dfa969aa91979e5e3a18fdbcfad2c517/greenlet-3.4.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:439fc2f12b9b512d9dfa681c5afe5f6b3232c708d13e6f02c845e0d9f4c2d8c6", size = 605476, upload-time = "2026-04-08T16:24:37.064Z" },
{ url = "https://files.pythonhosted.org/packages/d6/af/b66ab0b2f9a4c5a867c136bf66d9599f34f21a1bcca26a2884a29c450bd9/greenlet-3.4.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a70ed1cb0295bee1df57b63bf7f46b4e56a5c93709eea769c1fec1bb23a95875", size = 618336, upload-time = "2026-04-08T16:30:56.59Z" },
{ url = "https://files.pythonhosted.org/packages/6d/31/56c43d2b5de476f77d36ceeec436328533bff960a4cba9a07616e93063ab/greenlet-3.4.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8c5696c42e6bb5cfb7c6ff4453789081c66b9b91f061e5e9367fa15792644e76", size = 625045, upload-time = "2026-04-08T16:40:37.111Z" },
{ url = "https://files.pythonhosted.org/packages/e5/5c/8c5633ece6ba611d64bf2770219a98dd439921d6424e4e8cf16b0ac74ea5/greenlet-3.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c660bce1940a1acae5f51f0a064f1bc785d07ea16efcb4bc708090afc4d69e83", size = 613515, upload-time = "2026-04-08T15:56:32.478Z" },
{ url = "https://files.pythonhosted.org/packages/80/ca/704d4e2c90acb8bdf7ae593f5cbc95f58e82de95cc540fb75631c1054533/greenlet-3.4.0-cp311-cp311-manylinux_2_39_riscv64.whl", hash = "sha256:89995ce5ddcd2896d89615116dd39b9703bfa0c07b583b85b89bf1b5d6eddf81", size = 419745, upload-time = "2026-04-08T16:43:04.022Z" },
{ url = "https://files.pythonhosted.org/packages/a9/df/950d15bca0d90a0e7395eb777903060504cdb509b7b705631e8fb69ff415/greenlet-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ee407d4d1ca9dc632265aee1c8732c4a2d60adff848057cdebfe5fe94eb2c8a2", size = 1574623, upload-time = "2026-04-08T16:26:18.596Z" },
{ url = "https://files.pythonhosted.org/packages/1a/e7/0839afab829fcb7333c9ff6d80c040949510055d2d4d63251f0d1c7c804e/greenlet-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:956215d5e355fffa7c021d168728321fd4d31fd730ac609b1653b450f6a4bc71", size = 1639579, upload-time = "2026-04-08T15:57:29.231Z" },
{ url = "https://files.pythonhosted.org/packages/d9/2b/b4482401e9bcaf9f5c97f67ead38db89c19520ff6d0d6699979c6efcc200/greenlet-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:5cb614ace7c27571270354e9c9f696554d073f8aa9319079dcba466bbdead711", size = 238233, upload-time = "2026-04-08T17:02:54.286Z" },
@ -1187,7 +1185,7 @@ wheels = [
[[package]]
name = "llmai"
version = "0.1.9"
version = "0.2.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anthropic" },
@ -1195,9 +1193,9 @@ dependencies = [
{ name = "google-genai" },
{ name = "openai" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" }
sdist = { url = "https://files.pythonhosted.org/packages/1f/28/7dc14f9a417d933f8c799665a0a86ee31489dde40e9264ef6dc41b32759c/llmai-0.2.2.tar.gz", hash = "sha256:1f62d1e3d05fa5c43bbd948275398668d8f85c8d2fde252d34562332101dd7b3", size = 47863, upload-time = "2026-04-27T10:32:08.726Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" },
{ url = "https://files.pythonhosted.org/packages/b0/a1/e34d8aaa015fd7a2cfefa391b6eb2995aee675d7f8d93f636f4da8b07c13/llmai-0.2.2-py3-none-any.whl", hash = "sha256:d65c016983036319df704927b5e7fece494efde25b4865caf9a95d555a25449c", size = 59874, upload-time = "2026-04-27T10:32:07.272Z" },
]
[[package]]
@ -1634,6 +1632,7 @@ dependencies = [
{ name = "fastembed-vectorstore" },
{ name = "fastmcp" },
{ name = "google-genai" },
{ name = "jsonschema" },
{ name = "llmai" },
{ name = "mem0ai", extra = ["nlp"] },
{ name = "nltk" },
@ -1655,7 +1654,7 @@ requires-dist = [
{ name = "fastembed-vectorstore", specifier = ">=0.5.2" },
{ name = "fastmcp", specifier = ">=2.11.0" },
{ name = "google-genai", specifier = ">=1.28.0" },
{ name = "llmai", specifier = "==0.1.9" },
{ name = "llmai", url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl" },
{ name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" },
{ name = "nltk", specifier = ">=3.9.1" },
{ name = "openai", specifier = ">=1.98.0" },

View file

@ -45,11 +45,7 @@ const PresentationPage = ({ presentation_id }: { presentation_id: string }) => {
}
}, [presentationData]);
// Ensure /app_data and /static image paths resolve through the backend origin.
useEffect(() => {
const observer = setupImageUrlConverter();
return () => observer?.disconnect();
}, []);
// Function to fetch the slides
useEffect(() => {

File diff suppressed because it is too large Load diff

View file

@ -413,12 +413,18 @@ const PresentationHeader = ({
return (
<>
<div className="py-7 sticky top-0 bg-white z-50 mb-[17px] font-syne flex justify-between items-center gap-4">
{presentationData && !isStreaming && !isEditingTitle ? (
<ToolTip content="Rename presentation">{titleBlock}</ToolTip>
) : (
titleBlock
)}
<div className="py-[18px] px-4 sticky top-0 bg-white z-50 shadow-sm font-syne flex justify-between items-center gap-4">
<div className="flex items-center gap-3">
<img onClick={() => {
router.push("/dashboard");
}} src="/logo-with-bg.png" alt="" className="w-10 h-10 cursor-pointer object-contain" />
{presentationData && !isStreaming && !isEditingTitle ? (
<ToolTip content="Rename presentation">{titleBlock}</ToolTip>
) : (
titleBlock
)}
</div>
<div className="flex items-center gap-2.5">

View file

@ -1,5 +1,5 @@
"use client";
import React, { useEffect, useLayoutEffect, useState } from "react";
import React, { useEffect, useLayoutEffect, useRef, useState } from "react";
import { useSelector } from "react-redux";
import { RootState } from "@/store/store";
import "../../utils/prism-languages";
@ -21,8 +21,8 @@ import { PresentationPageProps } from "../types";
import LoadingState from "./LoadingState";
import { applyPresentationThemeToElement } from "../utils/applyPresentationThemeDom";
import { usePresentationUndoRedo } from "../hooks/PresentationUndoRedo";
import PresentationHeader from "./PresentationHeader";
import Chat from "./Chat";
const PresentationPage: React.FC<PresentationPageProps> = ({
presentation_id,
@ -33,6 +33,7 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
const [selectedSlide, setSelectedSlide] = useState(0);
const [isFullscreen, setIsFullscreen] = useState(false);
const [error, setError] = useState(false);
const slidesScrollContainerRef = useRef<HTMLDivElement | null>(null);
const router = useRouter();
@ -40,6 +41,11 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
const { presentationData, isStreaming } = useSelector(
(state: RootState) => state.presentationGeneration
);
const slidesLength = presentationData?.slides?.length ?? 0;
const lastStreamingSlideIndex =
slidesLength > 0
? presentationData?.slides?.[slidesLength - 1]?.index
: undefined;
// Auto-save functionality
const { isSaving } = useAutoSave({
@ -78,7 +84,38 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
fetchUserSlides
);
usePresentationUndoRedo();
useEffect(() => {
if (!isStreaming) return;
const scrollContainer = slidesScrollContainerRef.current;
if (!scrollContainer) return;
const frame = window.requestAnimationFrame(() => {
if (slidesLength <= 1) {
scrollContainer.scrollTo({ top: 0, behavior: "auto" });
return;
}
if (lastStreamingSlideIndex === undefined) return;
const slideElement = document.getElementById(
`slide-${lastStreamingSlideIndex}`
);
if (!slideElement) return;
const containerRect = scrollContainer.getBoundingClientRect();
const slideRect = slideElement.getBoundingClientRect();
const slideTop =
slideRect.top - containerRect.top + scrollContainer.scrollTop;
scrollContainer.scrollTo({
top: Math.max(slideTop, 0),
behavior: "smooth",
});
});
return () => window.cancelAnimationFrame(frame);
}, [isStreaming, lastStreamingSlideIndex, slidesLength]);
useEffect(() => {
trackEvent(MixpanelEvent.Presentation_Editor_Viewed, {
@ -141,66 +178,70 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
}
return (
<div className="h-screen overflow-hidden font-syne ">
<div className="h-screen overflow-hidden font-syne">
<div
style={{
background: "#ffffff",
background: "#EDEEEF",
}}
id="presentation-slides-wrapper"
className="flex gap-6 relative "
className="relative flex h-full flex-col overflow-hidden"
>
<div className="w-[200px]">
<SidePanel
selectedSlide={selectedSlide}
onSlideClick={handleSlideClick}
presentationId={presentation_id}
loading={loading}
/>
</div>
<div className=" w-full h-[calc(100vh-20px)] pr-[25px] overflow-y-auto">
<PresentationHeader presentation_id={presentation_id} isPresentationSaving={isSaving} currentSlide={selectedSlide} />
<div
style={{
background: "rgba(255, 255, 255, 0.10)",
boxShadow: "0 0 20.01px 0 rgba(122, 90, 248, 0.16) inset",
}}
className="p-6 rounded-[20px] font-inter flex flex-col items-center overflow-hidden justify-center border border-[#EDECEC] "
>
<div className="w-full max-w-[1280px] h-full">
{!presentationData ||
loading ||
!presentationData?.slides ||
presentationData?.slides.length === 0 ? (
<div className="relative w-full h-[calc(100vh-120px)] mx-auto">
<div className="">
{Array.from({ length: 2 }).map((_, index) => (
<Skeleton
key={index}
className="aspect-video bg-gray-400 my-4 w-full mx-auto "
/>
))}
<PresentationHeader presentation_id={presentation_id} isPresentationSaving={isSaving} currentSlide={selectedSlide} />
<div className="flex flex-1 min-h-0 gap-6 overflow-hidden">
<div className="w-[120px] h-full shrink-0 self-start sticky top-0 pt-[18px]">
<SidePanel
selectedSlide={selectedSlide}
onSlideClick={handleSlideClick}
presentationId={presentation_id}
loading={loading}
/>
</div>
<div className="w-full min-w-0 h-full flex-1 pt-[18px]">
<div
ref={slidesScrollContainerRef}
className="font-inter h-full overflow-y-auto hide-scrollbar scroll-pt-[18px]"
>
<div className="w-full max-w-[1280px] min-h-full mx-auto flex flex-col items-center pb-8">
{!presentationData ||
loading ||
!presentationData?.slides ||
presentationData?.slides.length === 0 ? (
<div className="relative w-full h-[calc(100vh-120px)] mx-auto hide-scrollbar">
<div className="">
{Array.from({ length: 2 }).map((_, index) => (
<Skeleton
key={index}
className="aspect-video bg-gray-400 my-4 w-full mx-auto "
/>
))}
</div>
{stream && <LoadingState />}
</div>
{stream && <LoadingState />}
</div>
) : (
<>
{presentationData &&
presentationData.slides &&
presentationData.slides.length > 0 &&
presentationData.slides.map((slide: any, index: number) => (
<SlideContent
key={`${slide.type}-${index}-${slide.index}`}
slide={slide}
index={index}
presentationId={presentation_id}
/>
))}
</>
)}
) : (
<>
{presentationData &&
presentationData.slides &&
presentationData.slides.length > 0 &&
presentationData.slides.map((slide: any, index: number) => (
<SlideContent
key={`${slide.type}-${index}-${slide.index}`}
slide={slide}
index={index}
presentationId={presentation_id}
/>
))}
</>
)}
</div>
</div>
</div>
<div className="w-full max-w-[370px] h-full shrink-0 self-start sticky top-0">
<Chat
presentationId={presentation_id}
currentSlide={selectedSlide}
onPresentationChanged={() => fetchUserSlides({ clearHistory: false })}
/>
</div>
</div>
</div>
</div>

View file

@ -19,11 +19,11 @@ import {
} from "@dnd-kit/sortable";
import { setPresentationData } from "@/store/slices/presentationGeneration";
import { SortableSlide } from "./SortableSlide";
import SlideScale from "../../components/PresentationRender";
import { Separator } from "@/components/ui/separator";
import { usePathname, useRouter } from "next/navigation";
import { usePathname } from "next/navigation";
import NewSlide from "./NewSlide";
import { trackEvent, MixpanelEvent } from "@/utils/mixpanel";
import { SlideThumbnailCard } from "./SlideThumbnailCard";
interface SidePanelProps {
selectedSlide: number;
@ -40,8 +40,6 @@ const SidePanel = ({
loading,
}: SidePanelProps) => {
const router = useRouter();
const pathname = usePathname();
const [showNewSlideSelection, setShowNewSlideSelection] = useState(false);
@ -132,48 +130,36 @@ const SidePanel = ({
}
return (
<div className="bg-[#F6F6F9] pt-8 px-4 w-[200px]">
<div className="px-4 w-[120px] h-full">
<img onClick={() => {
router.push("/dashboard");
}} src="/logo-with-bg.png" alt="" className="w-10 h-10 cursor-pointer object-contain" />
<Separator orientation="horizontal" className="my-6 " />
<div
className={`
relative bg-[#F6F6F9] h-full z-50 xl:z-auto
relative h-full z-50 xl:z-auto
transition-all duration-300 ease-in-out
`}
>
<div
className="w-full h-[calc(100vh-120px)] hide-scrollbar overflow-hidden slide-theme "
className="w-full h-full hide-scrollbar overflow-hidden slide-theme flex flex-col"
>
<p className="text-xl font-normal font-syne pb-3.5 text-[#000000]">Slides ({presentationData?.slides?.length})</p>
<DndContext
sensors={sensors}
collisionDetection={closestCenter}
onDragEnd={handleDragEnd}
>
<div className=" overflow-y-auto w-full hide-scrollbar h-[calc(100%-140px)] space-y-3.5">
<div className="overflow-y-auto w-full hide-scrollbar min-h-0 flex-1 space-y-3.5">
{isStreaming ? (
presentationData &&
presentationData?.slides.map((slide: any, index: number) => (
<div
<SlideThumbnailCard
key={`${slide.id}-${index}`}
onClick={() => onSlideClick(index)}
className={` cursor-pointer ring-2 rounded-[12px] transition-all duration-200 ${selectedSlide === index ? ' ring-[#5141e5]' : 'ring-gray-200'
}`}
>
<div className=" bg-white pointer-events-none relative overflow-hidden aspect-video">
<div className="absolute bg-gray-100/5 z-50 top-0 left-0 w-full h-full" />
<div className="transform scale-[0.2] flex justify-center items-center origin-top-left w-[500%] h-[500%]">
<SlideScale slide={slide} />
</div>
</div>
</div>
slide={slide}
index={index}
selected={selectedSlide === index}
onClick={() => onSlideClick(slide.index ?? index)}
/>
))
) : (
<SortableContext
@ -203,7 +189,7 @@ const SidePanel = ({
<button
type="button"
onClick={handleAddSlideClick}
className="pt-6 gap-2 flex flex-col py-2 duration-300 items-center justify-center rounded-lg cursor-pointer mx-auto"
className="py-4 gap-2 flex flex-col duration-300 items-center justify-center rounded-lg cursor-pointer mx-auto"
>
<Plus className="w-3.5 h-3.5" />
<span className="text-[11px] font-normal text-[#000000]">Add Slide</span>

View file

@ -103,30 +103,6 @@ const SlideContent = ({ slide, index, presentationId }: SlideContentProps) => {
});
}
};
// Scroll to the new slide when streaming and new slides are being generated
useEffect(() => {
if (
presentationData &&
presentationData?.slides &&
presentationData.slides.length > 1 &&
isStreaming
) {
// Scroll to the last slide (newly generated during streaming)
const lastSlideIndex = presentationData.slides.length - 1;
const slideElement = document.getElementById(
`slide-${presentationData.slides[lastSlideIndex].index}`
);
if (slideElement) {
slideElement.scrollIntoView({
behavior: "smooth",
block: "center",
});
}
}
}, [presentationData?.slides?.length, isStreaming]);
useEffect(() => {
if (slide.layout.includes("custom")) {

View file

@ -0,0 +1,54 @@
import React, { forwardRef } from "react";
import type { Slide } from "../../types/slide";
import { V1ContentRender } from "../../components/V1ContentRender";
/**
 * Props for SlideThumbnailCard. Extends the standard `<div>` attributes, so
 * callers may also pass handlers such as `onClick`, plus `className`/`style`.
 */
interface SlideThumbnailCardProps extends React.HTMLAttributes<HTMLDivElement> {
// Slide handed to V1ContentRender for the miniature preview.
slide: Slide;
// Zero-based position; the badge displays `index + 1`.
index: number;
// Switches the border colour to the "selected" variant.
selected: boolean;
}
/** Uniform scale applied to the 1280x720 slide canvas to produce the thumbnail. */
const SCALE = 0.061;

/**
 * Miniature, non-interactive preview of a single slide with a 1-based number badge.
 *
 * The slide is rendered at 1280x720 inside a clipped viewport and shrunk with a
 * CSS transform. The forwarded ref and any extra `<div>` props are applied to
 * the outer card element, so the card can be used as a drag-and-drop node.
 */
export const SlideThumbnailCard = forwardRef<
  HTMLDivElement,
  SlideThumbnailCardProps
>(({ slide, index, selected, className = "", style, ...props }, ref) => {
  // Caller-provided `style` is spread last so it can override the defaults.
  const cardStyle: React.CSSProperties = {
    backgroundColor: "var(--card-color, #ffffff)",
    borderColor: selected ? "#5141e5" : "var(--stroke, #e5e7eb)",
    ...style,
  };
  const borderToneClass = selected ? "border-[#BDB4FE]" : "border-[#EDEEEF]";
  const cardClassName = `cursor-pointer border relative p-1.5 rounded-[12px] overflow-hidden transition-all duration-200 ${borderToneClass} ${className}`;

  // Viewport is exactly as tall as the scaled-down canvas; anything else is clipped.
  const viewportStyle: React.CSSProperties = {
    height: `${720 * SCALE}px`,
    overflow: "hidden",
  };
  // Full-size canvas, scaled from its top-left corner into the viewport.
  const canvasStyle: React.CSSProperties = {
    width: 1280,
    height: 720,
    transformOrigin: "top left",
    transform: `scale(${SCALE})`,
  };

  return (
    <div ref={ref} style={cardStyle} className={cardClassName} {...props}>
      <p className="pointer-events-none absolute -left-1 top-1/2 z-50 flex h-[18px] min-w-[18px] -translate-y-1/2 items-center justify-center rounded-full border border-[#EDEEEF] bg-white px-1 text-[10px] font-medium text-[#191919] shadow-sm">
        {index + 1}
      </p>
      <div className="relative" style={viewportStyle}>
        <div
          className="absolute top-0 left-0 rounded-[10px] overflow-hidden pointer-events-none"
          style={canvasStyle}
        >
          <V1ContentRender slide={slide} isEditMode={true} />
        </div>
      </div>
    </div>
  );
});

SlideThumbnailCard.displayName = "SlideThumbnailCard";

View file

@ -1,16 +1,14 @@
import { useSortable } from '@dnd-kit/sortable';
import { CSS } from '@dnd-kit/utilities';
import { Slide } from '../../types/slide';
import type { Slide } from '../../types/slide';
import { useRef } from 'react';
import { V1ContentRender } from '../../components/V1ContentRender';
import { useSearchParams } from 'next/navigation';
import { SlideThumbnailCard } from './SlideThumbnailCard';
interface SortableSlideProps {
slide: Slide;
index: number;
selectedSlide: number;
onSlideClick: (index: any) => void;
}
const SCALE = 0.125;
export function SortableSlide({ slide, index, selectedSlide, onSlideClick }: SortableSlideProps) {
const lastClickTime = useRef(0);
@ -27,8 +25,6 @@ export function SortableSlide({ slide, index, selectedSlide, onSlideClick }: Sor
transform: CSS.Transform.toString(transform),
transition,
opacity: isDragging ? 0.5 : 1,
backgroundColor: `var(--card-color, #ffffff)`,
borderColor: selectedSlide === index ? `#5141e5` : `var(--stroke, #e5e7eb)`
};
const handleClick = (e: React.MouseEvent) => {
@ -47,33 +43,15 @@ export function SortableSlide({ slide, index, selectedSlide, onSlideClick }: Sor
};
return (
<div
<SlideThumbnailCard
ref={setNodeRef}
slide={slide}
index={index}
selected={selectedSlide === index}
style={style}
{...attributes}
{...listeners}
onClick={handleClick}
className={` cursor-pointer border relative p-1 rounded-[12px] transition-all duration-200 ${selectedSlide === index ? ' border-[#BDB4FE]' : 'border-[#EDEEEF]'
}`}
>
<div
className="relative"
style={{ height: `${720 * SCALE}px`, overflow: "hidden" }}
>
<div
className="absolute top-0 left-0 pointer-events-none"
style={{
width: 1280,
height: 720,
transformOrigin: "top left",
transform: `scale(${SCALE})`,
}}
>
<V1ContentRender slide={slide} isEditMode={true} />
</div>
</div>
</div>
/>
);
}
}

View file

@ -1,136 +1,100 @@
import { useCallback } from "react";
import { useDispatch, useSelector } from "react-redux";
import { RootState } from "@/store/store";
import { finishUndoRedo, redo, undo } from "@/store/slices/undoRedoSlice";
import { redo, undo } from "@/store/slices/undoRedoSlice";
import { useKeyboardShortcut } from "../../hooks/use-keyboard-shortcut";
import { setPresentationData } from "@/store/slices/presentationGeneration";
export const usePresentationUndoRedo = () => {
const dispatch = useDispatch();
const undoRedoState = useSelector((state: RootState) => state.undoRedo);
const { presentationData } = useSelector((state: RootState) => state.presentationGeneration);
const canUndo = undoRedoState.past.length > 0;
const canRedo = undoRedoState.future.length > 0;
const onUndo = useCallback(() => {
if (!canUndo) return;
const previousState = undoRedoState.past[undoRedoState.past.length - 1];
dispatch(undo());
if (previousState) {
const newSlides = JSON.parse(JSON.stringify(previousState.slides));
dispatch(
setPresentationData({
...presentationData!,
slides: newSlides,
})
);
}
setTimeout(() => {
dispatch(finishUndoRedo());
}, 100);
}, [canUndo, dispatch, presentationData, undoRedoState.past]);
const onRedo = useCallback(() => {
if (!canRedo) return;
const nextState = undoRedoState.future[0];
dispatch(redo());
if (nextState) {
const newSlides = JSON.parse(JSON.stringify(nextState.slides));
dispatch(
setPresentationData({
...presentationData!,
slides: newSlides,
})
);
}
setTimeout(() => {
dispatch(finishUndoRedo());
}, 100);
}, [canRedo, dispatch, presentationData, undoRedoState.future]);
// Handle undo
useKeyboardShortcut(
["z"],
(e) => {
if (e.ctrlKey && !e.shiftKey && undoRedoState.past.length > 0) {
e.preventDefault();
// Get the previous state before dispatching undo
const previousState = undoRedoState.past[undoRedoState.past.length - 1];
// Perform undo
dispatch(undo());
// Use the previousState directly instead of relying on the updated undoRedoState
if (previousState) {
// Create a deep copy to ensure no reference issues
const newSlides = JSON.parse(JSON.stringify(previousState.slides));
// Update the presentation data with the properly structured slides
dispatch(
setPresentationData({
...presentationData!,
slides: newSlides,
})
);
}
// Reset the undo/redo flag
setTimeout(() => {
dispatch(finishUndoRedo());
}, 100);
}
},
[undoRedoState.past, presentationData]
const dispatch = useDispatch();
const undoRedoState = useSelector((state: RootState) => state.undoRedo);
const { presentationData } = useSelector(
(state: RootState) => state.presentationGeneration
);
// Handle redo
const canUndo = undoRedoState.past.length > 0;
const canRedo = undoRedoState.future.length > 0;
const applySlidesSnapshot = useCallback(
(slidesSnapshot: unknown) => {
if (!presentationData || !Array.isArray(slidesSnapshot)) {
return;
}
const clonedSlides = JSON.parse(JSON.stringify(slidesSnapshot));
dispatch(
setPresentationData({
...presentationData,
slides: clonedSlides,
})
);
},
[dispatch, presentationData]
);
const onUndo = useCallback(() => {
if (!canUndo) {
return;
}
const previousState = undoRedoState.past[undoRedoState.past.length - 1];
if (!previousState) {
return;
}
dispatch(undo());
applySlidesSnapshot(previousState.slides);
}, [applySlidesSnapshot, canUndo, dispatch, undoRedoState.past]);
const onRedo = useCallback(() => {
if (!canRedo) {
return;
}
const nextState = undoRedoState.future[0];
if (!nextState) {
return;
}
dispatch(redo());
applySlidesSnapshot(nextState.slides);
}, [applySlidesSnapshot, canRedo, dispatch, undoRedoState.future]);
// Handle undo (Ctrl + Z)
useKeyboardShortcut(
["z"],
(e) => {
if (e.ctrlKey && e.shiftKey && undoRedoState.future.length > 0) {
if (e.ctrlKey && !e.shiftKey && canUndo) {
e.preventDefault();
// Get the next state before dispatching redo
const nextState = undoRedoState.future[0];
// Perform redo
dispatch(redo());
// Use the nextState directly instead of relying on the updated undoRedoState
if (nextState) {
// Create a deep copy to ensure no reference issues
const newSlides = JSON.parse(JSON.stringify(nextState.slides));
// Update the presentation data with the properly structured slides
dispatch(
setPresentationData({
...presentationData!,
slides: newSlides,
})
);
}
// Reset the undo/redo flag
setTimeout(() => {
dispatch(finishUndoRedo());
}, 100);
onUndo();
}
},
[undoRedoState.future, presentationData]
[canUndo, onUndo]
);
// Handle redo (Ctrl + Shift + Z)
useKeyboardShortcut(
["z"],
(e) => {
if (e.ctrlKey && e.shiftKey && canRedo) {
e.preventDefault();
onRedo();
}
},
[canRedo, onRedo]
);
// Handle redo (Ctrl + Y)
useKeyboardShortcut(
["y"],
(e) => {
if (e.ctrlKey && canRedo) {
e.preventDefault();
onRedo();
}
},
[canRedo, onRedo]
);
return { onUndo, onRedo, canUndo, canRedo };
}
};

View file

@ -17,14 +17,16 @@ export const usePresentationData = (
) => {
const dispatch = useDispatch();
const fetchUserSlides = useCallback(async () => {
const fetchUserSlides = useCallback(async (options?: { clearHistory?: boolean }) => {
try {
const data = await DashboardApi.getPresentation(presentationId);
if (data) {
dispatch(setPresentationData(data));
dispatch(clearHistory());
if (options?.clearHistory ?? true) {
dispatch(clearHistory());
}
setLoading(false);
}
if (data.fonts) {

View file

@ -0,0 +1,300 @@
import { buildAbsoluteApiRequestUrl, getApiUrl } from "@/utils/api";
import { ApiResponseHandler } from "./api-error-handler";
import { getHeader } from "./header";
/**
 * Request body serialised as JSON by `PresentationChatApi.sendMessage` and
 * `PresentationChatApi.streamMessage`.
 */
export interface ChatMessageRequest {
presentation_id: string;
message: string;
// Optional. Presumably continues an existing conversation when set — confirm against the backend.
conversation_id?: string;
}
/**
 * Result type of `PresentationChatApi.sendMessage` and `streamMessage`.
 * In the streaming path it is built from the `chat` object of the `complete` frame.
 */
export interface ChatMessageResponse {
// Only populated from a stream when the server sent a string value.
conversation_id?: string;
response: string;
// From a stream: non-string entries are dropped and a missing list becomes [].
tool_calls?: string[];
}
/** One entry of `ChatHistoryData.messages`. */
export interface ChatHistoryMessage {
role: string;
content: string;
// NOTE(review): timestamp format is not visible in this file — confirm with the backend.
created_at?: string;
}
/** Declared result type of `PresentationChatApi.getHistory`. */
export interface ChatHistoryData {
presentation_id: string;
conversation_id: string;
messages: ChatHistoryMessage[];
}
/** Element type of the list returned by `PresentationChatApi.listConversations`. */
export interface ChatConversationSummary {
conversation_id: string;
// May be absent or explicitly null.
updated_at?: string | null;
// May be absent or explicitly null.
last_message_preview?: string | null;
}
/**
 * Sanitised `trace` payload passed to `ChatStreamHandlers.onTrace`.
 * `streamMessage` copies each field only when the server value has the
 * declared type; anything else is reported as `undefined`.
 */
export interface ChatStreamTrace {
kind?: string;
round?: number;
tool?: string;
status?: string;
message?: string;
// Non-string entries of the server-sent array are filtered out.
tools?: string[];
}
/** Optional callbacks invoked by `PresentationChatApi.streamMessage` while it reads the stream. */
export interface ChatStreamHandlers {
// Called for every `chunk` frame whose `chunk` is a non-empty string.
onChunk?: (chunk: string) => void;
// Called for every `status` frame whose `status` is a non-blank string.
onStatus?: (status: string) => void;
// Called for every `trace` frame whose `trace` is an object.
onTrace?: (trace: ChatStreamTrace) => void;
// Called for a `complete` frame whose `chat.response` is a string, before the promise resolves.
onComplete?: (response: ChatMessageResponse) => void;
}
// Shapes of the JSON payload carried in the `data:` lines of one stream frame.
// Payload fields are typed `unknown` because they come straight from
// JSON.parse and are validated at the point of use.

/** Incremental piece of the assistant response. */
interface ChatStreamDataChunk {
type: "chunk";
chunk?: unknown;
}
/** Final frame; `chat` is expected to hold the full response object. */
interface ChatStreamDataComplete {
type: "complete";
chat?: unknown;
}
/** Server-reported failure; a non-blank string `detail` becomes the thrown error message. */
interface ChatStreamDataError {
type: "error";
detail?: unknown;
}
/** Progress/status text. */
interface ChatStreamDataStatus {
type: "status";
status?: unknown;
}
/** Trace information, forwarded as a ChatStreamTrace. */
interface ChatStreamDataTrace {
type: "trace";
trace?: unknown;
}
/**
 * Any payload a frame may carry. The trailing `Record<string, unknown>` member
 * admits unrecognised payloads, so `type` must be checked before use.
 */
type ChatStreamData =
| ChatStreamDataChunk
| ChatStreamDataComplete
| ChatStreamDataError
| ChatStreamDataStatus
| ChatStreamDataTrace
| Record<string, unknown>;
/**
 * Client for the presentation chat endpoints under `/api/v1/ppt/chat`.
 *
 * All methods are static and send the headers produced by `getHeader()`.
 */
export class PresentationChatApi {
  /**
   * Lists the chat conversations that belong to a presentation.
   *
   * @param presentationId Sent as the `presentation_id` query parameter.
   * @returns Whatever `ApiResponseHandler.handleResponse` yields for the response.
   */
  static async listConversations(
    presentationId: string
  ): Promise<ChatConversationSummary[]> {
    const u = new URL(
      buildAbsoluteApiRequestUrl("/api/v1/ppt/chat/conversations")
    );
    u.searchParams.set("presentation_id", presentationId);

    const response = await fetch(u.toString(), {
      headers: getHeader(),
      cache: "no-cache",
    });

    return await ApiResponseHandler.handleResponse(
      response,
      "Failed to list chat conversations"
    );
  }

  /**
   * Loads the message history of one conversation.
   *
   * @param presentationId Sent as the `presentation_id` query parameter.
   * @param conversationId Sent as the `conversation_id` query parameter.
   */
  static async getHistory(
    presentationId: string,
    conversationId: string
  ): Promise<ChatHistoryData> {
    const u = new URL(buildAbsoluteApiRequestUrl("/api/v1/ppt/chat/history"));
    u.searchParams.set("presentation_id", presentationId);
    u.searchParams.set("conversation_id", conversationId);

    const response = await fetch(u.toString(), {
      headers: getHeader(),
      cache: "no-cache",
    });

    return await ApiResponseHandler.handleResponse(
      response,
      "Failed to load chat history"
    );
  }

  /**
   * Sends a chat message and waits for the complete (non-streamed) reply.
   *
   * @param payload Serialised as the JSON request body.
   */
  static async sendMessage(
    payload: ChatMessageRequest
  ): Promise<ChatMessageResponse> {
    const response = await fetch(getApiUrl("/api/v1/ppt/chat/message"), {
      method: "POST",
      headers: getHeader(),
      body: JSON.stringify(payload),
      cache: "no-cache",
    });

    return await ApiResponseHandler.handleResponse(
      response,
      "Failed to send chat message"
    );
  }

  /**
   * Sends a chat message and consumes the reply as a server-sent event stream.
   *
   * Frames are separated by a blank line; both LF and CRLF line endings are
   * accepted, including a CRLF pair that is split across two network chunks.
   * Each frame is dispatched to the matching handler as soon as it is complete.
   *
   * @param payload Serialised as the JSON request body.
   * @param handlers Optional callbacks for chunk/status/trace/complete frames.
   * @param options `signal` is passed to `fetch` so the request can be aborted.
   * @returns The response built from the last valid `complete` frame.
   * @throws Error when the HTTP response is not ok, when there is no body,
   *   when the server sends an `error` frame, or when the stream ends without
   *   a valid `complete` frame.
   */
  static async streamMessage(
    payload: ChatMessageRequest,
    handlers: ChatStreamHandlers = {},
    options?: { signal?: AbortSignal }
  ): Promise<ChatMessageResponse> {
    const response = await fetch(getApiUrl("/api/v1/ppt/chat/message/stream"), {
      method: "POST",
      headers: getHeader(),
      body: JSON.stringify(payload),
      cache: "no-cache",
      signal: options?.signal,
    });

    if (!response.ok) {
      await ApiResponseHandler.handleResponse(
        response,
        "Failed to stream chat message"
      );
      // Reached only if the handler did not throw for the failed response.
      throw new Error("Failed to stream chat message");
    }

    if (!response.body) {
      throw new Error("No response body received from chat stream");
    }

    const reader = response.body.getReader();
    const decoder = new TextDecoder("utf-8");
    let buffer = "";
    let finalResponse: ChatMessageResponse | null = null;

    try {
      let streamEnded = false;
      while (!streamEnded) {
        const { done, value } = await reader.read();
        streamEnded = done;

        // The final decode() call flushes any bytes of a partial character.
        buffer += done ? decoder.decode() : decoder.decode(value, { stream: true });

        // Normalise CRLF so "\r\n\r\n" is recognised as a frame delimiter.
        // A lone trailing "\r" stays in the buffer and is folded on the next
        // pass once its "\n" arrives.
        buffer = buffer.replace(/\r\n/g, "\n");

        let delimiterIndex = buffer.indexOf("\n\n");
        while (delimiterIndex >= 0) {
          const frame = buffer.slice(0, delimiterIndex);
          buffer = buffer.slice(delimiterIndex + 2);
          finalResponse =
            PresentationChatApi.processSseFrame(frame, handlers) ?? finalResponse;
          delimiterIndex = buffer.indexOf("\n\n");
        }
      }

      // A last frame that was not terminated by a blank line is still processed.
      if (buffer.trim().length > 0) {
        finalResponse =
          PresentationChatApi.processSseFrame(buffer, handlers) ?? finalResponse;
      }
    } catch (error) {
      // Stop the transfer instead of leaving the body half-read; the original
      // error is what the caller needs, so a failing cancel is ignored.
      await reader.cancel().catch(() => undefined);
      throw error;
    } finally {
      reader.releaseLock();
    }

    if (finalResponse) {
      return finalResponse;
    }

    throw new Error("Chat stream ended before completion");
  }

  /**
   * Parses one frame of the stream and dispatches it to `handlers`.
   *
   * Frames with an event name other than `response`, frames without `data:`
   * lines, and frames whose data is not a JSON object are ignored.
   *
   * @returns The chat response when the frame is a valid `complete` frame,
   *   otherwise `undefined`.
   * @throws Error when the frame is an `error` frame.
   */
  private static processSseFrame(
    frame: string,
    handlers: ChatStreamHandlers
  ): ChatMessageResponse | undefined {
    let eventName = "";
    const dataLines: string[] = [];

    for (const line of frame.replace(/\r/g, "").split("\n")) {
      if (line.startsWith("event:")) {
        eventName = line.slice(6).trim();
        continue;
      }
      if (line.startsWith("data:")) {
        dataLines.push(line.slice(5).trimStart());
      }
    }

    if (eventName && eventName !== "response") {
      return undefined;
    }
    if (!dataLines.length) {
      return undefined;
    }

    let parsed: unknown;
    try {
      parsed = JSON.parse(dataLines.join("\n"));
    } catch {
      // Malformed JSON is skipped so one bad frame does not abort the stream.
      return undefined;
    }
    // JSON.parse may legally return null or a primitive; neither has a `type`.
    if (parsed === null || typeof parsed !== "object") {
      return undefined;
    }
    const parsedData = parsed as ChatStreamData;

    switch (parsedData.type) {
      case "chunk": {
        const chunk = (parsedData as ChatStreamDataChunk).chunk;
        if (typeof chunk === "string" && chunk.length > 0) {
          handlers.onChunk?.(chunk);
        }
        return undefined;
      }

      case "complete": {
        const chatPayload = (parsedData as ChatStreamDataComplete).chat;
        if (
          !chatPayload ||
          typeof chatPayload !== "object" ||
          typeof (chatPayload as { response?: unknown }).response !== "string"
        ) {
          return undefined;
        }

        const conversationId = (chatPayload as { conversation_id?: unknown })
          .conversation_id;
        const toolCalls = (chatPayload as { tool_calls?: unknown }).tool_calls;

        const typedResponse: ChatMessageResponse = {
          conversation_id:
            typeof conversationId === "string" ? conversationId : undefined,
          response: (chatPayload as { response: string }).response,
          tool_calls: Array.isArray(toolCalls)
            ? toolCalls.filter(
                (item): item is string => typeof item === "string"
              )
            : [],
        };
        handlers.onComplete?.(typedResponse);
        return typedResponse;
      }

      case "error": {
        const detail = (parsedData as ChatStreamDataError).detail;
        const message =
          typeof detail === "string" && detail.trim().length > 0
            ? detail
            : "Chat stream failed";
        throw new Error(message);
      }

      case "status": {
        const status = (parsedData as ChatStreamDataStatus).status;
        if (typeof status === "string" && status.trim().length > 0) {
          handlers.onStatus?.(status);
        }
        return undefined;
      }

      case "trace": {
        const trace = (parsedData as ChatStreamDataTrace).trace;
        if (trace && typeof trace === "object") {
          const typedTrace = trace as Record<string, unknown>;
          handlers.onTrace?.({
            kind:
              typeof typedTrace.kind === "string" ? typedTrace.kind : undefined,
            round:
              typeof typedTrace.round === "number"
                ? typedTrace.round
                : undefined,
            tool:
              typeof typedTrace.tool === "string" ? typedTrace.tool : undefined,
            status:
              typeof typedTrace.status === "string"
                ? typedTrace.status
                : undefined,
            message:
              typeof typedTrace.message === "string"
                ? typedTrace.message
                : undefined,
            tools: Array.isArray(typedTrace.tools)
              ? typedTrace.tools.filter(
                  (value): value is string => typeof value === "string"
                )
              : undefined,
          });
        }
        return undefined;
      }

      default:
        // Unknown payload types are ignored.
        return undefined;
    }
  }
}

View file

@ -64,7 +64,8 @@ const page = () => {
</div>
<div
className='fixed z-0 -bottom-[14.5rem] left-0 w-full h-full'
className="fixed z-0 -bottom-[14.5rem] left-0 w-full h-full pointer-events-none"
aria-hidden
style={{
height: "341px",
borderRadius: '1440px',

View file

@ -96,6 +96,7 @@ input[type="number"]::-webkit-outer-spin-button {
}
input[type="number"] {
appearance: textfield;
-moz-appearance: textfield;
}
@ -104,7 +105,6 @@ tbody tr {
display: table;
width: 100%;
table-layout: fixed;
/* even columns width , fix width of table too*/
}
thead {
@ -276,83 +276,30 @@ thead {
@apply prose prose-slate max-w-none;
}
/* .markdown-content h1 {
@apply text-xl font-bold mb-4 text-gray-900;
.chat-markdown {
@apply prose prose-slate max-w-none;
}
.markdown-content h2 {
@apply text-lg font-bold mb-3 text-gray-900;
.chat-markdown :where(p, ul, ol, pre, blockquote) {
margin-bottom: 0.75rem;
}
.markdown-content h3 {
@apply text-base font-bold mb-2 text-gray-900;
.chat-markdown :where(p:last-child, ul:last-child, ol:last-child, pre:last-child, blockquote:last-child) {
margin-bottom: 0;
}
.markdown-content h4 {
@apply text-sm font-bold mb-2 text-gray-900;
.chat-markdown :where(ul, ol) {
padding-left: 1.25rem;
}
.markdown-content h5 {
@apply text-xs font-bold mb-1 text-gray-900;
.chat-markdown :where(code) {
@apply rounded bg-[#F5F5F5] px-1 py-0.5 text-[0.85em];
}
.markdown-content h6 {
@apply text-xs font-semibold mb-1 text-gray-900;
.chat-markdown :where(pre) {
@apply overflow-x-auto rounded-lg bg-[#F5F5F5] p-3;
}
.markdown-content p {
@apply mb-4 text-base text-gray-700;
}
.markdown-content ul {
@apply list-disc pl-6 mb-4;
}
.markdown-content ol {
@apply list-decimal pl-6 mb-4;
}
.markdown-content li {
@apply mb-1;
}
.markdown-content strong,
.markdown-content b {
@apply font-bold text-gray-900;
}
.markdown-content em {
@apply italic;
}
.markdown-content blockquote {
@apply border-l-4 border-gray-300 pl-4 italic my-4;
}
.markdown-content code {
@apply bg-gray-100 px-1 py-0.5 rounded font-mono text-sm;
}
.markdown-content pre {
@apply bg-gray-100 p-4 rounded-lg my-4 overflow-x-auto;
}
.markdown-content a {
@apply text-blue-600 hover:text-blue-800 underline;
}
.markdown-content table {
@apply min-w-full border border-gray-300 my-4;
}
.markdown-content th {
@apply bg-gray-100 border border-gray-300 px-4 py-2 font-bold;
}
.markdown-content td {
@apply border border-gray-300 px-4 py-2;
} */
/* Override Tailwind Typography prose heading sizes for markdown editor */
.prose h1 {
font-size: 18px !important;

View file

@ -1,5 +0,0 @@
{
"description": "Dark, theme-ready presentation layouts with covers, structured content grids, timelines, narrative splits, and chart-driven slides",
"ordered": false,
"default": false
}

View file

@ -54,7 +54,8 @@ export default function Home() {
<div className="flex min-h-screen relative">
<div
className='fixed z-0 -bottom-[14.5rem] left-0 w-full h-full'
className="fixed z-0 -bottom-[14.5rem] left-0 w-full h-full pointer-events-none"
aria-hidden
style={{
height: "341px",
borderRadius: '1440px',

View file

@ -13,6 +13,7 @@ interface UndoRedoState {
future: HistoryState[];
maxHistorySize: number;
isUndoRedoInProgress: boolean;
pendingHistorySkips: number;
}
// Helper function for deep copy
@ -25,7 +26,8 @@ const initialState: UndoRedoState = {
present: null,
future: [],
maxHistorySize: 30,
isUndoRedoInProgress: false
isUndoRedoInProgress: false,
pendingHistorySkips: 0,
};
const undoRedoSlice = createSlice({
@ -33,15 +35,22 @@ const undoRedoSlice = createSlice({
initialState,
reducers: {
addToHistory: (state, action: PayloadAction<{slides: Slide[], actionType: string}>) => {
// Skip if undo/redo is in progress
if (state.pendingHistorySkips > 0) {
state.pendingHistorySkips -= 1;
if (state.pendingHistorySkips === 0) {
state.isUndoRedoInProgress = false;
}
return;
}
// Defensive guard for any stale in-progress state.
if (state.isUndoRedoInProgress) {
return;
}
// Deep copy the slides to avoid reference issues
const newSlides = deepCopy(action.payload.slides);
// Only add to history if the slides have actually changed
if (!state.present) {
state.present = {
@ -51,84 +60,80 @@ const undoRedoSlice = createSlice({
};
return;
}
// Skip if slides are identical
if (JSON.stringify(state.present.slides) === JSON.stringify(newSlides)) {
return;
}
// Add current state to past
state.past.push(state.present);
// Limit history size
if (state.past.length > state.maxHistorySize) {
state.past.shift();
}
// Clear future on new change
state.future = [];
// Set new present
state.present = {
slides: newSlides,
timestamp: Date.now(),
actionType: action.payload.actionType
};
},
undo: (state) => {
if (state.past.length === 0) {
if (state.past.length === 0) {
return;
}
state.isUndoRedoInProgress = true;
state.pendingHistorySkips = 1;
// Move present to future
if (state.present) {
state.future.unshift(deepCopy(state.present));
}
// Get last past state
const previous = state.past[state.past.length - 1];
state.past = state.past.slice(0, -1);
state.present = deepCopy(previous);
},
redo: (state) => {
if (state.future.length === 0) {
return;
}
state.isUndoRedoInProgress = true;
state.pendingHistorySkips = 1;
// Move present to past
if (state.present) {
state.past.push(deepCopy(state.present));
}
// Get first future state
const next = state.future[0];
state.future = state.future.slice(1);
state.present = deepCopy(next);
},
finishUndoRedo: (state) => {
state.isUndoRedoInProgress = false;
state.pendingHistorySkips = 0;
},
clearHistory: (state) => {
state.past = [];
state.future = [];
state.present = null;
// Keep present
state.isUndoRedoInProgress = false;
state.pendingHistorySkips = 0;
}
}
});

View file

@ -63,6 +63,25 @@ export function getApiUrl(path: string): string {
return normalizedPath;
}
/**
 * Resolves an API path to an absolute URL string that is safe to pass to the
 * single-argument `new URL(...)` constructor (e.g. before editing `searchParams`).
 *
 * `getApiUrl` may yield a host-less path such as `/api/v1/...`, which
 * `new URL("/api/...")` rejects. When that happens the path is resolved
 * against `baseForRelative`.
 *
 * @param path API path handed to `getApiUrl`.
 * @param baseForRelative Base used only for host-less results. Defaults to
 *   `window.location.origin` when available, otherwise `http://127.0.0.1:5000`.
 * @returns An absolute URL string.
 */
export function buildAbsoluteApiRequestUrl(
  path: string,
  baseForRelative: string = typeof window !== "undefined" &&
  window.location?.origin
    ? window.location.origin
    : "http://127.0.0.1:5000"
): string {
  const apiUrl = getApiUrl(path);
  return isAbsoluteHttpUrl(apiUrl)
    ? apiUrl
    : new URL(apiUrl, baseForRelative).toString();
}
function hasBackendAssetPrefix(path: string): boolean {
return path.startsWith("/static/") || path.startsWith("/app_data/");
}