Merge pull request #545 from presenton/refactor/presenton-chat-stream
Refactor/presenton chat stream
This commit is contained in:
commit
cb731aa6c3
60 changed files with 5181 additions and 618 deletions
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env node
|
||||
/**
|
||||
* CLI bridge for Python: one JSON line on stdout for LiteParse extraction.
|
||||
* CLI bridge for Python: by default, raw extracted text on stdout (--python-bridge plain);
|
||||
* or one JSON line (--python-bridge json) for backward compatibility.
|
||||
*
|
||||
* OCR follows LlamaIndex LiteParse guidance (built-in Tesseract by default):
|
||||
* https://developers.llamaindex.ai/liteparse/guides/ocr/
|
||||
|
|
@ -56,14 +57,31 @@ function emit(result, exitCode = 0) {
|
|||
process.exit(exitCode);
|
||||
}
|
||||
|
||||
/** "plain" = success: UTF-8 text on stdout only. "json" = one JSON line (legacy, huge payloads can break). */
|
||||
const pyBridgeArg = readArg("--python-bridge");
|
||||
const pyBridge =
|
||||
pyBridgeArg == null || pyBridgeArg === ""
|
||||
? "json"
|
||||
: String(pyBridgeArg).trim().toLowerCase() === "plain"
|
||||
? "plain"
|
||||
: "json";
|
||||
|
||||
function bridgeError(message, exitCode) {
|
||||
if (pyBridge === "plain") {
|
||||
process.stderr.write(`${message}\n`);
|
||||
process.exit(exitCode);
|
||||
}
|
||||
emit({ ok: false, error: message }, exitCode);
|
||||
}
|
||||
|
||||
const filePath = readArg("--file");
|
||||
if (!filePath) {
|
||||
emit({ ok: false, error: "Missing required --file argument" }, 2);
|
||||
bridgeError("Missing required --file argument", 2);
|
||||
}
|
||||
|
||||
const resolvedPath = path.resolve(filePath);
|
||||
if (!fs.existsSync(resolvedPath)) {
|
||||
emit({ ok: false, error: `File not found: ${resolvedPath}` }, 2);
|
||||
bridgeError(`File not found: ${resolvedPath}`, 2);
|
||||
}
|
||||
|
||||
const ocrEnabled = parseBool(readArg("--ocr-enabled"), true);
|
||||
|
|
@ -117,6 +135,10 @@ try {
|
|||
|
||||
const result = await parser.parse(resolvedPath, true);
|
||||
const text = result?.text ?? "";
|
||||
if (pyBridge === "plain") {
|
||||
process.stdout.write(text);
|
||||
process.exit(0);
|
||||
}
|
||||
emit({
|
||||
ok: true,
|
||||
filePath: resolvedPath,
|
||||
|
|
@ -133,6 +155,13 @@ try {
|
|||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const stack = error instanceof Error ? error.stack : undefined;
|
||||
if (pyBridge === "plain") {
|
||||
if (stack) {
|
||||
process.stderr.write(`${stack}\n`);
|
||||
}
|
||||
process.stderr.write(`${message}\n`);
|
||||
process.exit(1);
|
||||
}
|
||||
if (stack) {
|
||||
process.stderr.write(`${stack}\n`);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import async_session_maker
|
||||
from utils.ollama import pull_ollama_model
|
||||
|
||||
|
||||
|
|
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
|
|||
)
|
||||
log_event_count = 0
|
||||
|
||||
session = await get_container_db_async_session().__anext__()
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
return
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
return
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
|
||||
|
||||
async def upsert_ollama_pull_status(
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import get_async_session
|
||||
from utils.ollama import list_pulled_ollama_models
|
||||
|
||||
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
||||
|
|
@ -29,7 +29,7 @@ async def get_available_models():
|
|||
async def pull_model(
|
||||
model: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
session: AsyncSession = Depends(get_container_db_async_session),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
|
||||
if model not in SUPPORTED_OLLAMA_MODELS:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@ def main() -> None:
|
|||
p = _sqlite_file_path(sync_url)
|
||||
if p is not None:
|
||||
paths.append(p)
|
||||
paths.append(p.parent / "container.db")
|
||||
|
||||
seen: set[Path] = set()
|
||||
for path in paths:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
import os
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
|
|
@ -21,7 +20,6 @@ from models.sql.template_create_info import TemplateCreateInfoModel
|
|||
from models.sql.slide import SlideModel
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
from utils.get_env import get_migrate_database_on_startup_env
|
||||
|
||||
|
||||
|
|
@ -42,22 +40,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
yield session
|
||||
|
||||
|
||||
# Container DB (Lives inside the app data directory)
|
||||
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
|
||||
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
|
||||
container_db_engine: AsyncEngine = create_async_engine(
|
||||
container_db_url, connect_args={"check_same_thread": False}
|
||||
)
|
||||
container_db_async_session_maker = async_sessionmaker(
|
||||
container_db_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with container_db_async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Create Database and Tables
|
||||
async def create_db_and_tables():
|
||||
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
|
||||
|
|
@ -76,18 +58,11 @@ async def create_db_and_tables():
|
|||
TemplateModel.__table__,
|
||||
WebhookSubscription.__table__,
|
||||
AsyncPresentationGenerationTaskModel.__table__,
|
||||
OllamaPullStatus.__table__,
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async with container_db_engine.begin() as conn:
|
||||
await conn.run_sync(
|
||||
lambda sync_conn: SQLModel.metadata.create_all(
|
||||
sync_conn,
|
||||
tables=[OllamaPullStatus.__table__],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def dispose_engines():
|
||||
"""Dispose all engine connection pools.
|
||||
|
|
@ -97,4 +72,3 @@ async def dispose_engines():
|
|||
database and prevent stale / leaked connections.
|
||||
"""
|
||||
await sql_engine.dispose()
|
||||
await container_db_engine.dispose()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
|
@ -30,6 +32,129 @@ except Exception:
|
|||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _unwrap_liteparse_json_line_if_stored(text: str) -> str:
|
||||
"""If the whole JSON line from the LiteParse runner was stored as the document, keep only the text field."""
|
||||
if not text:
|
||||
return text
|
||||
s = text.lstrip()
|
||||
if not s.startswith("{"):
|
||||
return text
|
||||
try:
|
||||
payload = json.loads(s)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
return text
|
||||
if not isinstance(payload, dict):
|
||||
return text
|
||||
if (
|
||||
payload.get("ok") is True
|
||||
and "filePath" in payload
|
||||
and isinstance(payload.get("text"), str)
|
||||
):
|
||||
return payload["text"]
|
||||
return text
|
||||
|
||||
|
||||
_RE_TEXT_KEY = re.compile(r'"text"\s*:\s*"')
|
||||
|
||||
|
||||
def _json_unescape_quoted_value(s: str, content_start: int) -> str:
|
||||
"""
|
||||
Unescape a JSON string value. `content_start` is the index of the first character
|
||||
*inside* the value (immediately after the opening quote of the "text" field).
|
||||
If the closing quote is missing (truncated), returns the unescaped rest of the string.
|
||||
"""
|
||||
out: list[str] = []
|
||||
i = content_start
|
||||
n = len(s)
|
||||
while i < n:
|
||||
c = s[i]
|
||||
if c == "\\" and i + 1 < n:
|
||||
e = s[i + 1]
|
||||
if e in '"\\':
|
||||
out.append(e)
|
||||
i += 2
|
||||
elif e == "/":
|
||||
out.append("/")
|
||||
i += 2
|
||||
elif e == "b":
|
||||
out.append("\b")
|
||||
i += 2
|
||||
elif e == "f":
|
||||
out.append("\f")
|
||||
i += 2
|
||||
elif e == "n":
|
||||
out.append("\n")
|
||||
i += 2
|
||||
elif e == "r":
|
||||
out.append("\r")
|
||||
i += 2
|
||||
elif e == "t":
|
||||
out.append("\t")
|
||||
i += 2
|
||||
elif e == "u" and i + 5 < n:
|
||||
try:
|
||||
out.append(chr(int(s[i + 2 : i + 6], 16)))
|
||||
except (ValueError, OverflowError):
|
||||
out.append(s[i : i + 6])
|
||||
i += 6
|
||||
else:
|
||||
out.append(e)
|
||||
i += 2
|
||||
elif c == '"':
|
||||
return "".join(out)
|
||||
else:
|
||||
out.append(c)
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _try_extract_liteparse_text_value_from_malformed_json(s: str) -> Optional[str]:
|
||||
"""
|
||||
When json.loads failed (e.g. truncated or corrupt), find the "text" field value
|
||||
in a LiteParse-shaped object and return only the unescaped string body.
|
||||
"""
|
||||
if not s.startswith("{"):
|
||||
return None
|
||||
head = s[:10000] if len(s) > 10000 else s
|
||||
if not ("ok" in head and "filePath" in head):
|
||||
return None
|
||||
m = _RE_TEXT_KEY.search(s)
|
||||
if not m:
|
||||
return None
|
||||
return _json_unescape_quoted_value(s, m.end())
|
||||
|
||||
|
||||
def _clean_extracted_one_pass(t: str) -> str:
|
||||
for _ in range(3):
|
||||
nxt = _unwrap_liteparse_json_line_if_stored(t)
|
||||
if nxt == t:
|
||||
break
|
||||
t = nxt
|
||||
s = t.lstrip()
|
||||
if s.startswith("{"):
|
||||
m = _try_extract_liteparse_text_value_from_malformed_json(s)
|
||||
if m is not None:
|
||||
return m
|
||||
return t
|
||||
|
||||
|
||||
def clean_extracted_document_text(text: str) -> str:
|
||||
"""
|
||||
Return only the document body: strip LiteParse JSON wrappers, then drop any
|
||||
leading payload before the "text" value (handles truncated/invalid JSON).
|
||||
Multiple passes in case the inner body is again JSON-shaped.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
t = text
|
||||
for _ in range(4):
|
||||
nxt = _clean_extracted_one_pass(t)
|
||||
if nxt == t:
|
||||
return t
|
||||
t = nxt
|
||||
return t
|
||||
|
||||
|
||||
class DocumentsLoader:
|
||||
DECOMPOSE_TIMEOUT_SECONDS = 600
|
||||
|
||||
|
|
@ -107,6 +232,7 @@ class DocumentsLoader:
|
|||
else:
|
||||
document = await asyncio.to_thread(self._parse_with_liteparse, file_path)
|
||||
|
||||
document = clean_extracted_document_text(document)
|
||||
documents.append(document)
|
||||
images.append(imgs)
|
||||
|
||||
|
|
|
|||
|
|
@ -193,6 +193,11 @@ class LiteParseService:
|
|||
|
||||
return True, "ok"
|
||||
|
||||
@staticmethod
|
||||
def _use_json_runner_output() -> bool:
|
||||
"""If true, expect one JSON line on stdout (legacy). Default is plain UTF-8 text (better for large PDFs)."""
|
||||
return (os.getenv("LITEPARSE_RUNNER_OUTPUT") or "").strip().lower() == "json"
|
||||
|
||||
def parse_to_markdown(
|
||||
self,
|
||||
file_path: str,
|
||||
|
|
@ -233,6 +238,9 @@ class LiteParseService:
|
|||
if tessdata:
|
||||
command.extend(["--tessdata-path", tessdata])
|
||||
|
||||
use_json = self._use_json_runner_output()
|
||||
command.extend(["--python-bridge", "json" if use_json else "plain"])
|
||||
|
||||
LOGGER.info(
|
||||
"[LiteParse] Parsing file=%s ocr_enabled=%s ocr_language=%s",
|
||||
file_path,
|
||||
|
|
@ -254,6 +262,20 @@ class LiteParseService:
|
|||
_command_str(command),
|
||||
)
|
||||
|
||||
if not use_json:
|
||||
if process.returncode != 0:
|
||||
err = (process.stderr or "").strip() or "LiteParse failed"
|
||||
raise LiteParseError(
|
||||
f"{err}; returncode={process.returncode}; "
|
||||
f"stderr={_snippet(process.stderr)}; stdout={_snippet(process.stdout)}"
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"text": (process.stdout or "").lstrip("\ufeff"),
|
||||
"filePath": file_path,
|
||||
"pageCount": 0,
|
||||
}
|
||||
|
||||
payload: Dict[str, Any]
|
||||
try:
|
||||
payload = self._decode_runner_output(process.stdout)
|
||||
|
|
|
|||
|
|
@ -46,11 +46,7 @@ const PresentationPage = ({ presentation_id }: { presentation_id: string }) => {
|
|||
}
|
||||
}, [presentationData]);
|
||||
|
||||
// Ensure /app_data and /static image paths resolve through FastAPI in Electron.
|
||||
useEffect(() => {
|
||||
const observer = setupImageUrlConverter();
|
||||
return () => observer?.disconnect();
|
||||
}, []);
|
||||
|
||||
|
||||
// Function to fetch the slides
|
||||
useEffect(() => {
|
||||
|
|
|
|||
|
|
@ -63,7 +63,8 @@ const page = () => {
|
|||
<p className="text-xl font-syne text-[#101323CC]">Turn prompts or documents into presentations with AI</p>
|
||||
</div>
|
||||
<div
|
||||
className='fixed z-0 -bottom-[14.5rem] left-0 w-full h-full'
|
||||
className="fixed z-0 -bottom-[14.5rem] left-0 w-full h-full pointer-events-none"
|
||||
aria-hidden
|
||||
style={{
|
||||
height: "341px",
|
||||
borderRadius: '1440px',
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"description": "Dark, theme-ready presentation layouts with covers, structured content grids, timelines, narrative splits, and chart-driven slides",
|
||||
"ordered": false,
|
||||
"default": false
|
||||
}
|
||||
|
|
@ -15,6 +15,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|||
from models.sql.async_presentation_generation_status import ( # noqa: F401, E402
|
||||
AsyncPresentationGenerationTaskModel,
|
||||
)
|
||||
from models.sql.chat_history_message import ChatHistoryMessageModel # noqa: F401, E402
|
||||
from models.sql.image_asset import ImageAsset # noqa: F401, E402
|
||||
from models.sql.key_value import KeyValueSqlModel # noqa: F401, E402
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus # noqa: F401, E402
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
"""added_chat_history_messages_table
|
||||
|
||||
Revision ID: c7b70d0f31b1
|
||||
Revises: 95b5127e93cd
|
||||
Create Date: 2026-04-26 12:29:49.508761
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c7b70d0f31b1'
|
||||
down_revision: Union[str, None] = '95b5127e93cd'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('chat_history_messages',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('presentation_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('conversation_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('position', sa.Integer(), nullable=False),
|
||||
sa.Column('role', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('tool_calls', sa.JSON(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['presentation_id'], ['presentations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_chat_history_messages_conversation_id'), 'chat_history_messages', ['conversation_id'], unique=False)
|
||||
op.create_index(op.f('ix_chat_history_messages_position'), 'chat_history_messages', ['position'], unique=False)
|
||||
op.create_index(op.f('ix_chat_history_messages_presentation_id'), 'chat_history_messages', ['presentation_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_chat_history_messages_presentation_id'), table_name='chat_history_messages')
|
||||
op.drop_index(op.f('ix_chat_history_messages_position'), table_name='chat_history_messages')
|
||||
op.drop_index(op.f('ix_chat_history_messages_conversation_id'), table_name='chat_history_messages')
|
||||
op.drop_table('chat_history_messages')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import async_session_maker
|
||||
from utils.ollama import pull_ollama_model
|
||||
|
||||
|
||||
|
|
@ -17,51 +17,47 @@ async def pull_ollama_model_background_task(model: str):
|
|||
)
|
||||
log_event_count = 0
|
||||
|
||||
session = await get_container_db_async_session().__anext__()
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
return
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
if "error" in event:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = event["error"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
return
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.error = str(e)
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
saved_model_status.error = None
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
|
||||
|
||||
async def upsert_ollama_pull_status(
|
||||
|
|
|
|||
129
servers/fastapi/api/v1/ppt/endpoints/chat.py
Normal file
129
servers/fastapi/api/v1/ppt/endpoints/chat.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.chat import (
|
||||
ChatConversationListItem,
|
||||
ChatHistoryMessageItem,
|
||||
ChatHistoryResponse,
|
||||
ChatMessageRequest,
|
||||
ChatMessageResponse,
|
||||
)
|
||||
from models.sse_response import (
|
||||
SSECompleteResponse,
|
||||
SSEErrorResponse,
|
||||
SSEStatusResponse,
|
||||
SSETraceResponse,
|
||||
SSEResponse,
|
||||
)
|
||||
from services.chat import ChatTurnResult, PresentationChatService
|
||||
from services.chat import sql_chat_history
|
||||
from services.database import get_async_session
|
||||
|
||||
CHAT_ROUTER = APIRouter(prefix="/chat", tags=["Chat"])
|
||||
|
||||
|
||||
@CHAT_ROUTER.get("/conversations", response_model=list[ChatConversationListItem])
|
||||
async def list_chat_conversations(
|
||||
presentation_id: uuid.UUID = Query(..., description="Presentation id"),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
raw = await sql_chat_history.list_conversations(
|
||||
sql_session, presentation_id=presentation_id
|
||||
)
|
||||
return [
|
||||
ChatConversationListItem(
|
||||
conversation_id=uuid.UUID(str(item["conversation_id"])),
|
||||
updated_at=item.get("updated_at"),
|
||||
last_message_preview=item.get("last_message_preview"),
|
||||
)
|
||||
for item in raw
|
||||
]
|
||||
|
||||
|
||||
@CHAT_ROUTER.get("/history", response_model=ChatHistoryResponse)
|
||||
async def get_chat_history(
|
||||
presentation_id: uuid.UUID = Query(..., description="Presentation id"),
|
||||
conversation_id: uuid.UUID = Query(..., description="Conversation thread id"),
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
rows = await sql_chat_history.load_messages_with_meta(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
return ChatHistoryResponse(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
messages=[
|
||||
ChatHistoryMessageItem(
|
||||
role=str(m.get("role") or ""),
|
||||
content=str(m.get("content") or ""),
|
||||
created_at=m.get("created_at")
|
||||
if isinstance(m.get("created_at"), str)
|
||||
else None,
|
||||
)
|
||||
for m in rows
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@CHAT_ROUTER.post("/message", response_model=ChatMessageResponse)
|
||||
async def chat_message(
|
||||
payload: ChatMessageRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
service = PresentationChatService(
|
||||
sql_session=sql_session,
|
||||
presentation_id=payload.presentation_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
)
|
||||
result = await service.generate_reply(payload.message)
|
||||
return ChatMessageResponse(
|
||||
conversation_id=result.conversation_id,
|
||||
response=result.response_text,
|
||||
tool_calls=result.tool_calls,
|
||||
)
|
||||
|
||||
|
||||
@CHAT_ROUTER.post("/message/stream")
|
||||
async def chat_message_stream(
|
||||
payload: ChatMessageRequest,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
service = PresentationChatService(
|
||||
sql_session=sql_session,
|
||||
presentation_id=payload.presentation_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
)
|
||||
|
||||
async def inner():
|
||||
try:
|
||||
async for event_type, value in service.stream_reply(payload.message):
|
||||
if event_type == "chunk" and isinstance(value, str):
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": value}),
|
||||
).to_string()
|
||||
elif event_type == "status" and isinstance(value, str):
|
||||
yield SSEStatusResponse(status=value).to_string()
|
||||
elif event_type == "trace" and isinstance(value, dict):
|
||||
yield SSETraceResponse(trace=value).to_string()
|
||||
elif event_type == "complete" and isinstance(value, ChatTurnResult):
|
||||
result = value
|
||||
complete_payload = ChatMessageResponse(
|
||||
conversation_id=result.conversation_id,
|
||||
response=result.response_text,
|
||||
tool_calls=result.tool_calls,
|
||||
)
|
||||
yield SSECompleteResponse(
|
||||
key="chat",
|
||||
value=complete_payload.model_dump(mode="json"),
|
||||
).to_string()
|
||||
except HTTPException as exc:
|
||||
yield SSEErrorResponse(detail=exc.detail).to_string()
|
||||
|
||||
return StreamingResponse(inner(), media_type="text/event-stream")
|
||||
|
|
@ -9,7 +9,7 @@ from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
from services.database import get_container_db_async_session
|
||||
from services.database import get_async_session
|
||||
from utils.ollama import list_pulled_ollama_models
|
||||
|
||||
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
||||
|
|
@ -29,7 +29,7 @@ async def get_available_models():
|
|||
async def pull_model(
|
||||
model: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
session: AsyncSession = Depends(get_container_db_async_session),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
|
||||
if model not in SUPPORTED_OLLAMA_MODELS:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
|||
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
|
||||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
|
||||
from api.v1.ppt.endpoints.chat import CHAT_ROUTER
|
||||
from api.v1.ppt.endpoints.pptx_slides import PPTX_FONTS_ROUTER
|
||||
from api.v1.ppt.endpoints.theme import THEMES_ROUTER
|
||||
from api.v1.ppt.endpoints.theme_generate import THEME_ROUTER
|
||||
|
|
@ -29,6 +30,7 @@ API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PPTX_SLIDES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CHAT_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(SLIDE_TO_HTML_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(HTML_TO_REACT_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(HTML_EDIT_ROUTER)
|
||||
|
|
|
|||
44
servers/fastapi/models/chat.py
Normal file
44
servers/fastapi/models/chat.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
presentation_id: uuid.UUID
|
||||
message: str = Field(min_length=1, max_length=8000)
|
||||
conversation_id: Optional[uuid.UUID] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
conversation_id: uuid.UUID
|
||||
response: str
|
||||
tool_calls: list[str] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatHistoryMessageItem(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
created_at: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatHistoryResponse(BaseModel):
|
||||
presentation_id: uuid.UUID
|
||||
conversation_id: uuid.UUID
|
||||
messages: list[ChatHistoryMessageItem]
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ChatConversationListItem(BaseModel):
|
||||
conversation_id: uuid.UUID
|
||||
updated_at: Optional[str] = None
|
||||
last_message_preview: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
31
servers/fastapi/models/sql/chat_history_message.py
Normal file
31
servers/fastapi/models/sql/chat_history_message.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, ForeignKey, Text
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
||||
class ChatHistoryMessageModel(SQLModel, table=True):
|
||||
__tablename__ = "chat_history_messages"
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
presentation_id: uuid.UUID = Field(
|
||||
sa_column=Column(
|
||||
ForeignKey("presentations.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
conversation_id: uuid.UUID = Field(index=True)
|
||||
position: int = Field(index=True, ge=1)
|
||||
role: str
|
||||
content: str = Field(sa_column=Column(Text, nullable=False))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, default=get_current_utc_datetime
|
||||
)
|
||||
)
|
||||
tool_calls: Optional[list[str]] = Field(sa_column=Column(JSON), default=None)
|
||||
|
|
@ -20,6 +20,15 @@ class SSEStatusResponse(BaseModel):
|
|||
).to_string()
|
||||
|
||||
|
||||
class SSETraceResponse(BaseModel):
|
||||
trace: object
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "trace", "trace": self.trace})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSEErrorResponse(BaseModel):
|
||||
detail: str
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ dependencies = [
|
|||
"pathvalidate>=3.3.1",
|
||||
"pdfplumber>=0.11.7",
|
||||
"sqlmodel>=0.0.24",
|
||||
"llmai==0.1.9",
|
||||
"llmai==0.2.2",
|
||||
"jsonschema>=4.26.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
|
|
|||
8
servers/fastapi/services/chat/__init__.py
Normal file
8
servers/fastapi/services/chat/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from services.chat.service import ChatTurnResult, PresentationChatService
|
||||
from services.chat.presentation_context_store import PresentationContextStore
|
||||
|
||||
__all__ = [
|
||||
"ChatTurnResult",
|
||||
"PresentationChatService",
|
||||
"PresentationContextStore",
|
||||
]
|
||||
324
servers/fastapi/services/chat/chat_memory_store.py
Normal file
324
servers/fastapi/services/chat/chat_memory_store.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from services.mem0_oss_memory import get_shared_mem0_client
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
CHAT_TURN_TAG = "[chat_turn]"
|
||||
DEFAULT_MAX_STORED_TURNS = 20
|
||||
|
||||
|
||||
class ChatMemoryStore:
|
||||
def __init__(self):
|
||||
self._enabled = self._to_bool(os.getenv("MEM0_ENABLED"), default=True)
|
||||
self._top_k = self._to_int(os.getenv("MEM0_TOP_K"), default=8)
|
||||
self._max_context_chars = self._to_int(
|
||||
os.getenv("MEM0_MAX_CONTEXT_CHARS"), default=6000
|
||||
)
|
||||
self._max_stored_turns = self._to_int(
|
||||
os.getenv("CHAT_MAX_STORED_TURNS"), default=DEFAULT_MAX_STORED_TURNS
|
||||
)
|
||||
self._namespace_prefix = (
|
||||
os.getenv("MEM0_CHAT_NAMESPACE_PREFIX")
|
||||
or os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX")
|
||||
or "presentation"
|
||||
).strip() or "presentation"
|
||||
|
||||
@staticmethod
|
||||
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
@staticmethod
|
||||
def _to_int(value: Optional[str], default: int) -> int:
|
||||
try:
|
||||
parsed = int(value) if value is not None else default
|
||||
return max(1, parsed)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _normalize(value: str) -> str:
|
||||
return " ".join((value or "").split())
|
||||
|
||||
@staticmethod
|
||||
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
|
||||
return isinstance(exc, (Exception, SystemExit))
|
||||
|
||||
def _scope_user_id(self, presentation_id: UUID, conversation_id: UUID) -> str:
|
||||
return (
|
||||
f"{self._namespace_prefix}:{presentation_id}:"
|
||||
f"conversation:{conversation_id}"
|
||||
)
|
||||
|
||||
def _truncate(self, text: str, limit: int = 20000) -> str:
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return f"{text[:limit]}\n\n[TRUNCATED]"
|
||||
|
||||
async def _get_client(self):
|
||||
if not self._enabled:
|
||||
return None
|
||||
return get_shared_mem0_client()
|
||||
|
||||
def _build_turn_payload(self, *, user_text: str, assistant_text: str) -> str:
|
||||
memory_lines = [
|
||||
CHAT_TURN_TAG,
|
||||
f"turn_created_at={datetime.now(timezone.utc).isoformat()}",
|
||||
]
|
||||
if user_text:
|
||||
memory_lines.append(f"user={user_text}")
|
||||
if assistant_text:
|
||||
memory_lines.append(f"assistant={assistant_text}")
|
||||
return "\n".join(memory_lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_field(item: dict[str, Any]) -> str:
|
||||
memory_text = item.get("memory") or item.get("text") or item.get("data")
|
||||
return str(memory_text).strip() if memory_text is not None else ""
|
||||
|
||||
def _collect_results(self, response: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(response, dict):
|
||||
raw_results = (
|
||||
response.get("results")
|
||||
or response.get("memories")
|
||||
or response.get("items")
|
||||
or []
|
||||
)
|
||||
if isinstance(raw_results, list):
|
||||
return [item for item in raw_results if isinstance(item, dict)]
|
||||
return []
|
||||
if isinstance(response, list):
|
||||
return [item for item in response if isinstance(item, dict)]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _safe_parse_datetime(raw_value: Any) -> datetime | None:
|
||||
if not isinstance(raw_value, str) or not raw_value.strip():
|
||||
return None
|
||||
value = raw_value.strip().replace("Z", "+00:00")
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value)
|
||||
if parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_chat_turn_fields(text: str) -> tuple[str | None, str | None, datetime | None]:
|
||||
if CHAT_TURN_TAG not in text:
|
||||
return None, None, None
|
||||
|
||||
user_text: str | None = None
|
||||
assistant_text: str | None = None
|
||||
turn_created_at: datetime | None = None
|
||||
for line in text.splitlines():
|
||||
if line.startswith("user="):
|
||||
user_text = line[len("user=") :].strip()
|
||||
elif line.startswith("assistant="):
|
||||
assistant_text = line[len("assistant=") :].strip()
|
||||
elif line.startswith("turn_created_at="):
|
||||
turn_created_at = ChatMemoryStore._safe_parse_datetime(
|
||||
line[len("turn_created_at=") :].strip()
|
||||
)
|
||||
return user_text, assistant_text, turn_created_at
|
||||
|
||||
async def store_chat_turn(
|
||||
self,
|
||||
*,
|
||||
presentation_id: UUID,
|
||||
conversation_id: UUID,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
) -> None:
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return
|
||||
|
||||
user_text = self._normalize(user_message)
|
||||
assistant_text = self._normalize(assistant_message)
|
||||
if not user_text and not assistant_text:
|
||||
return
|
||||
|
||||
payload = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._truncate(
|
||||
self._build_turn_payload(
|
||||
user_text=user_text,
|
||||
assistant_text=assistant_text,
|
||||
)
|
||||
),
|
||||
}
|
||||
]
|
||||
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
||||
|
||||
def _add():
|
||||
try:
|
||||
return client.add(payload, user_id=scoped_user_id, infer=False)
|
||||
except TypeError:
|
||||
return client.add(
|
||||
messages=payload,
|
||||
user_id=scoped_user_id,
|
||||
infer=False,
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_add)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception(
|
||||
(
|
||||
"Failed to add chat mem0 memory "
|
||||
"(presentation_id=%s, conversation_id=%s)"
|
||||
),
|
||||
presentation_id,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
async def retrieve_context(
|
||||
self,
|
||||
*,
|
||||
presentation_id: UUID,
|
||||
conversation_id: UUID,
|
||||
query: str,
|
||||
) -> str:
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return ""
|
||||
|
||||
trimmed_query = (query or "").strip()
|
||||
if not trimmed_query:
|
||||
return ""
|
||||
|
||||
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
||||
|
||||
def _search():
|
||||
try:
|
||||
return client.search(
|
||||
trimmed_query,
|
||||
filters={"user_id": scoped_user_id},
|
||||
top_k=self._top_k,
|
||||
)
|
||||
except TypeError:
|
||||
return client.search(
|
||||
trimmed_query,
|
||||
user_id=scoped_user_id,
|
||||
top_k=self._top_k,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(_search)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception(
|
||||
(
|
||||
"Failed to search chat mem0 memory "
|
||||
"(presentation_id=%s, conversation_id=%s)"
|
||||
),
|
||||
presentation_id,
|
||||
conversation_id,
|
||||
)
|
||||
return ""
|
||||
|
||||
results = self._collect_results(response)
|
||||
memories: list[str] = []
|
||||
for item in results:
|
||||
normalized = self._extract_text_field(item)
|
||||
if normalized:
|
||||
memories.append(normalized)
|
||||
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
deduped = list(dict.fromkeys(memories))
|
||||
return self._truncate("\n\n".join(deduped), self._max_context_chars)
|
||||
|
||||
async def load_history(
|
||||
self,
|
||||
*,
|
||||
presentation_id: UUID,
|
||||
conversation_id: UUID,
|
||||
) -> list[dict[str, str]]:
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return []
|
||||
|
||||
scoped_user_id = self._scope_user_id(presentation_id, conversation_id)
|
||||
|
||||
def _get_all():
|
||||
try:
|
||||
return client.get_all(
|
||||
filters={"user_id": scoped_user_id},
|
||||
limit=max(10, self._max_stored_turns * 4),
|
||||
)
|
||||
except TypeError:
|
||||
try:
|
||||
return client.get_all(
|
||||
user_id=scoped_user_id,
|
||||
limit=max(10, self._max_stored_turns * 4),
|
||||
)
|
||||
except TypeError:
|
||||
try:
|
||||
return client.get_all(filters={"user_id": scoped_user_id})
|
||||
except TypeError:
|
||||
return client.get_all(user_id=scoped_user_id)
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(_get_all)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception(
|
||||
(
|
||||
"Failed to load chat mem0 history "
|
||||
"(presentation_id=%s, conversation_id=%s)"
|
||||
),
|
||||
presentation_id,
|
||||
conversation_id,
|
||||
)
|
||||
return []
|
||||
|
||||
results = self._collect_results(response)
|
||||
ordered_turns: list[tuple[datetime, str, str]] = []
|
||||
for index, item in enumerate(results):
|
||||
text_value = self._extract_text_field(item)
|
||||
if not text_value:
|
||||
continue
|
||||
user_text, assistant_text, embedded_timestamp = self._extract_chat_turn_fields(
|
||||
text_value
|
||||
)
|
||||
if not user_text and not assistant_text:
|
||||
continue
|
||||
|
||||
item_created_at = (
|
||||
self._safe_parse_datetime(item.get("created_at"))
|
||||
or self._safe_parse_datetime(item.get("updated_at"))
|
||||
or self._safe_parse_datetime(item.get("event_at"))
|
||||
)
|
||||
timestamp = embedded_timestamp or item_created_at or datetime.fromtimestamp(
|
||||
index, tz=timezone.utc
|
||||
)
|
||||
ordered_turns.append((timestamp, user_text or "", assistant_text or ""))
|
||||
|
||||
ordered_turns.sort(key=lambda turn: turn[0])
|
||||
recent_turns = ordered_turns[-self._max_stored_turns :]
|
||||
|
||||
history: list[dict[str, str]] = []
|
||||
for _, user_text, assistant_text in recent_turns:
|
||||
if user_text:
|
||||
history.append({"role": "user", "content": user_text})
|
||||
if assistant_text:
|
||||
history.append({"role": "assistant", "content": assistant_text})
|
||||
return history
|
||||
|
||||
|
||||
CHAT_MEMORY_STORE = ChatMemoryStore()
|
||||
80
servers/fastapi/services/chat/conversation_store.py
Normal file
80
servers/fastapi/services/chat/conversation_store.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import uuid
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from services.chat.chat_memory_store import CHAT_MEMORY_STORE
|
||||
from services.chat import sql_chat_history
|
||||
|
||||
|
||||
class ChatConversationStore:
|
||||
def __init__(self, sql_session: AsyncSession):
|
||||
self._sql = sql_session
|
||||
|
||||
async def load_history(
|
||||
self,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
) -> list[dict[str, str]]:
|
||||
messages = await sql_chat_history.load_messages(
|
||||
self._sql,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if messages:
|
||||
return messages
|
||||
legacy = await CHAT_MEMORY_STORE.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if legacy:
|
||||
await sql_chat_history.replace_messages(
|
||||
self._sql,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
messages=legacy,
|
||||
)
|
||||
return legacy
|
||||
|
||||
async def append_turn(
|
||||
self,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
tool_calls: list[str] | None = None,
|
||||
) -> None:
|
||||
await sql_chat_history.append_turn(
|
||||
self._sql,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
assistant_message=assistant_message,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
await CHAT_MEMORY_STORE.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
assistant_message=assistant_message,
|
||||
)
|
||||
|
||||
async def retrieve_semantic_context(
|
||||
self,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
query: str,
|
||||
) -> str:
|
||||
return await CHAT_MEMORY_STORE.retrieve_context(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query=query,
|
||||
)
|
||||
|
||||
async def ensure_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID | None,
|
||||
) -> uuid.UUID:
|
||||
return conversation_id or uuid.uuid4()
|
||||
491
servers/fastapi/services/chat/memory_layer.py
Normal file
491
servers/fastapi/services/chat/memory_layer.py
Normal file
|
|
@ -0,0 +1,491 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft202012Validator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.sql.image_asset import ImageAsset
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.icon_finder_service import ICON_FINDER_SERVICE
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from services.mem0_presentation_memory_service import MEM0_PRESENTATION_MEMORY_SERVICE
|
||||
from templates.presentation_layout import SlideLayoutModel
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
from utils.process_slides import (
|
||||
process_old_and_new_slides_and_fetch_assets,
|
||||
process_slide_and_fetch_assets,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
MAX_SCHEMA_ERRORS = 10
|
||||
# Keep URL runtime fields during validation because many slide schemas require them.
|
||||
# Speaker note is handled separately and should not affect JSON-schema checks.
|
||||
RUNTIME_CONTENT_FIELDS = {"__speaker_note__"}
|
||||
|
||||
|
||||
class PresentationChatMemoryLayer:
|
||||
"""
|
||||
Memory abstraction for chat tools and context retrieval.
|
||||
|
||||
This layer intentionally hides where data comes from (SQL-backed persisted state
|
||||
and mem0 retrieval) behind `get` and `search`-style methods so chat logic stays
|
||||
decoupled from storage details.
|
||||
"""
|
||||
|
||||
def __init__(self, sql_session: AsyncSession, presentation_id: uuid.UUID):
|
||||
self._sql_session = sql_session
|
||||
self._presentation_id = presentation_id
|
||||
|
||||
async def get(self, key: str) -> Any:
|
||||
if key != "presentation_outline":
|
||||
return None
|
||||
|
||||
# Prefer live slides from SQL so slide count and slide indices are always current.
|
||||
slides_result = await self._sql_session.scalars(
|
||||
select(SlideModel)
|
||||
.where(SlideModel.presentation == self._presentation_id)
|
||||
.order_by(SlideModel.index)
|
||||
)
|
||||
slides = list(slides_result)
|
||||
if slides:
|
||||
LOGGER.info(
|
||||
"Chat outline loaded from slides table (presentation_id=%s, slides=%d)",
|
||||
self._presentation_id,
|
||||
len(slides),
|
||||
)
|
||||
return {
|
||||
"source": "slides_table",
|
||||
"slide_count": len(slides),
|
||||
"slides": [
|
||||
{
|
||||
"slide_id": str(slide.id),
|
||||
"index": slide.index,
|
||||
"layout_id": slide.layout,
|
||||
"content": slide.content,
|
||||
"speaker_note": slide.speaker_note,
|
||||
}
|
||||
for slide in slides
|
||||
],
|
||||
}
|
||||
|
||||
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
|
||||
if not presentation or not presentation.outlines:
|
||||
LOGGER.info(
|
||||
"Chat memory miss for outline (presentation_id=%s)",
|
||||
self._presentation_id,
|
||||
)
|
||||
return None
|
||||
|
||||
LOGGER.info(
|
||||
"Chat outline fallback hit from presentation.outlines (presentation_id=%s)",
|
||||
self._presentation_id,
|
||||
)
|
||||
return presentation.outlines
|
||||
|
||||
async def search(self, query: str, limit: int = 5) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search slides directly from SQL-backed slide rows.
|
||||
|
||||
Results are intentionally compact (snippet-first) to keep tool-call payloads
|
||||
small for models with limited context windows.
|
||||
"""
|
||||
|
||||
trimmed_query = (query or "").strip()
|
||||
if not trimmed_query:
|
||||
return []
|
||||
|
||||
slides_result = await self._sql_session.scalars(
|
||||
select(SlideModel).where(SlideModel.presentation == self._presentation_id)
|
||||
)
|
||||
slides = sorted(list(slides_result), key=lambda slide: slide.index)
|
||||
if not slides:
|
||||
LOGGER.info(
|
||||
"Chat memory miss for slide search (presentation_id=%s, reason=no_slides)",
|
||||
self._presentation_id,
|
||||
)
|
||||
return []
|
||||
|
||||
query_lower = trimmed_query.lower()
|
||||
query_tokens = set(re.findall(r"[a-z0-9]{2,}", query_lower))
|
||||
ranked: list[tuple[int, dict[str, Any]]] = []
|
||||
for slide in slides:
|
||||
serialized = self._serialize_slide(slide)
|
||||
searchable = serialized.lower()
|
||||
|
||||
score = 0
|
||||
if query_lower in searchable:
|
||||
score += 8
|
||||
if query_tokens:
|
||||
score += sum(1 for token in query_tokens if token in searchable)
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
ranked.append(
|
||||
(
|
||||
score,
|
||||
{
|
||||
"slide_id": str(slide.id),
|
||||
"index": slide.index,
|
||||
"slide_number": slide.index + 1,
|
||||
"layout_id": slide.layout,
|
||||
"snippet": self._build_snippet(serialized, query_lower),
|
||||
"score": score,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
ranked.sort(key=lambda item: (-item[0], item[1]["index"]))
|
||||
results = [entry for _, entry in ranked[: max(1, limit)]]
|
||||
LOGGER.info(
|
||||
"Chat DB slide search completed (presentation_id=%s, query=%r, hits=%d)",
|
||||
self._presentation_id,
|
||||
trimmed_query,
|
||||
len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
async def get_slide_at_index(
|
||||
self, index: int, *, include_full_content: bool = False
|
||||
) -> dict[str, Any] | None:
|
||||
slide = await self._sql_session.scalar(
|
||||
select(SlideModel).where(
|
||||
SlideModel.presentation == self._presentation_id,
|
||||
SlideModel.index == index,
|
||||
)
|
||||
)
|
||||
if not slide:
|
||||
LOGGER.info(
|
||||
"Chat memory miss for slide by index (presentation_id=%s, index=%d)",
|
||||
self._presentation_id,
|
||||
index,
|
||||
)
|
||||
return None
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"slide_id": str(slide.id),
|
||||
"index": slide.index,
|
||||
"slide_number": slide.index + 1,
|
||||
"layout_id": slide.layout,
|
||||
"content_preview": self._build_snippet(
|
||||
self._serialize_slide(slide),
|
||||
query_lower="",
|
||||
window=420,
|
||||
),
|
||||
"speaker_note": slide.speaker_note,
|
||||
}
|
||||
if include_full_content:
|
||||
response["content"] = slide.content
|
||||
return response
|
||||
|
||||
async def get_available_layouts(self) -> list[dict[str, Any]]:
|
||||
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
|
||||
if not presentation or not isinstance(presentation.layout, dict):
|
||||
return []
|
||||
|
||||
try:
|
||||
layout_model = presentation.get_layout()
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"Failed to parse presentation layout (presentation_id=%s)",
|
||||
self._presentation_id,
|
||||
)
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"id": layout.id,
|
||||
"name": layout.name,
|
||||
"description": layout.description,
|
||||
}
|
||||
for layout in layout_model.slides
|
||||
]
|
||||
|
||||
async def get_content_schema_from_layout_id(self, layout_id: str) -> dict[str, Any] | None:
|
||||
layout = await self._get_layout_by_id(layout_id)
|
||||
if not layout:
|
||||
return None
|
||||
return layout.json_schema
|
||||
|
||||
async def generate_image(self, prompt: str) -> str:
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
image = await image_generation_service.generate_image(ImagePrompt(prompt=prompt))
|
||||
|
||||
if isinstance(image, ImageAsset):
|
||||
self._sql_session.add(image)
|
||||
await self._sql_session.commit()
|
||||
return image.path
|
||||
|
||||
return str(image)
|
||||
|
||||
async def generate_icon(self, query: str) -> str:
|
||||
icons = await ICON_FINDER_SERVICE.search_icons(query, k=1)
|
||||
if icons:
|
||||
return icons[0]
|
||||
return "/static/icons/placeholder.svg"
|
||||
|
||||
async def save_slide(
|
||||
self,
|
||||
*,
|
||||
content: dict[str, Any],
|
||||
layout_id: str,
|
||||
index: int,
|
||||
replace_old_slide_at_index: bool,
|
||||
) -> dict[str, Any]:
|
||||
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
|
||||
if not presentation:
|
||||
return {
|
||||
"saved": False,
|
||||
"message": "Presentation not found.",
|
||||
"validation_errors": [],
|
||||
}
|
||||
|
||||
layout = await self._get_layout_by_id(layout_id, presentation=presentation)
|
||||
if not layout:
|
||||
return {
|
||||
"saved": False,
|
||||
"message": f"Layout '{layout_id}' was not found in this presentation.",
|
||||
"validation_errors": [f"Unknown layout_id '{layout_id}'."],
|
||||
}
|
||||
|
||||
validation_errors = self._validate_slide_content(
|
||||
content=content,
|
||||
schema=layout.json_schema,
|
||||
)
|
||||
if validation_errors:
|
||||
return {
|
||||
"saved": False,
|
||||
"message": "Slide content failed schema validation.",
|
||||
"validation_errors": validation_errors,
|
||||
}
|
||||
|
||||
target_index = max(0, index)
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
|
||||
if replace_old_slide_at_index:
|
||||
existing_slide = await self._sql_session.scalar(
|
||||
select(SlideModel).where(
|
||||
SlideModel.presentation == self._presentation_id,
|
||||
SlideModel.index == target_index,
|
||||
)
|
||||
)
|
||||
if not existing_slide:
|
||||
return {
|
||||
"saved": False,
|
||||
"message": f"No existing slide found at index {target_index} to replace.",
|
||||
"validation_errors": [],
|
||||
}
|
||||
|
||||
updated_content = copy.deepcopy(content)
|
||||
new_assets = await process_old_and_new_slides_and_fetch_assets(
|
||||
image_generation_service=image_generation_service,
|
||||
old_slide_content=existing_slide.content or {},
|
||||
new_slide_content=updated_content,
|
||||
)
|
||||
|
||||
existing_slide.id = uuid.uuid4()
|
||||
existing_slide.layout = layout_id
|
||||
existing_slide.layout_group = self._resolve_layout_group(
|
||||
presentation=presentation,
|
||||
fallback=existing_slide.layout_group,
|
||||
)
|
||||
existing_slide.content = updated_content
|
||||
existing_slide.speaker_note = self._extract_speaker_note(updated_content)
|
||||
self._sql_session.add(existing_slide)
|
||||
self._sql_session.add_all(new_assets)
|
||||
await self._sql_session.commit()
|
||||
|
||||
await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
|
||||
presentation_id=self._presentation_id,
|
||||
slide_index=target_index,
|
||||
edit_prompt=f"[chat_tool_save_slide_replace] layout_id={layout_id}",
|
||||
edited_slide_content=updated_content,
|
||||
)
|
||||
|
||||
return {
|
||||
"saved": True,
|
||||
"action": "replaced",
|
||||
"message": f"Slide at index {target_index} was replaced successfully.",
|
||||
"slide_id": str(existing_slide.id),
|
||||
"index": target_index,
|
||||
}
|
||||
|
||||
slides_result = await self._sql_session.scalars(
|
||||
select(SlideModel)
|
||||
.where(SlideModel.presentation == self._presentation_id)
|
||||
.order_by(SlideModel.index)
|
||||
)
|
||||
slides = list(slides_result)
|
||||
|
||||
if slides:
|
||||
max_index = max(slide.index for slide in slides)
|
||||
insert_index = min(target_index, max_index + 1)
|
||||
slides_to_shift = [slide for slide in slides if slide.index >= insert_index]
|
||||
else:
|
||||
insert_index = 0
|
||||
slides_to_shift = []
|
||||
|
||||
for slide in sorted(slides_to_shift, key=lambda each: each.index, reverse=True):
|
||||
slide.index += 1
|
||||
self._sql_session.add(slide)
|
||||
|
||||
new_slide_content = copy.deepcopy(content)
|
||||
new_slide = SlideModel(
|
||||
presentation=self._presentation_id,
|
||||
layout_group=self._resolve_layout_group(presentation=presentation),
|
||||
layout=layout_id,
|
||||
index=insert_index,
|
||||
content=new_slide_content,
|
||||
speaker_note=self._extract_speaker_note(new_slide_content),
|
||||
)
|
||||
new_assets = await process_slide_and_fetch_assets(
|
||||
image_generation_service=image_generation_service,
|
||||
slide=new_slide,
|
||||
)
|
||||
|
||||
self._sql_session.add(new_slide)
|
||||
self._sql_session.add_all(new_assets)
|
||||
await self._sql_session.commit()
|
||||
await self._sql_session.refresh(new_slide)
|
||||
|
||||
await MEM0_PRESENTATION_MEMORY_SERVICE.store_slide_edit(
|
||||
presentation_id=self._presentation_id,
|
||||
slide_index=insert_index,
|
||||
edit_prompt=f"[chat_tool_save_slide_new] layout_id={layout_id}",
|
||||
edited_slide_content=new_slide.content,
|
||||
)
|
||||
|
||||
return {
|
||||
"saved": True,
|
||||
"action": "created",
|
||||
"message": f"New slide saved at index {insert_index}.",
|
||||
"slide_id": str(new_slide.id),
|
||||
"index": insert_index,
|
||||
"shifted_slide_count": len(slides_to_shift),
|
||||
}
|
||||
|
||||
async def retrieve_context(self, query: str) -> str:
|
||||
context = await MEM0_PRESENTATION_MEMORY_SERVICE.retrieve_context(
|
||||
self._presentation_id,
|
||||
query,
|
||||
)
|
||||
if context:
|
||||
LOGGER.info(
|
||||
"Chat memory semantic context hit (presentation_id=%s, chars=%d)",
|
||||
self._presentation_id,
|
||||
len(context),
|
||||
)
|
||||
else:
|
||||
LOGGER.info(
|
||||
"Chat memory semantic context miss (presentation_id=%s)",
|
||||
self._presentation_id,
|
||||
)
|
||||
return context
|
||||
|
||||
async def _get_layout_by_id(
|
||||
self,
|
||||
layout_id: str,
|
||||
presentation: PresentationModel | None = None,
|
||||
) -> SlideLayoutModel | None:
|
||||
if not presentation:
|
||||
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
|
||||
if not presentation or not isinstance(presentation.layout, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
layout_model = presentation.get_layout()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for layout in layout_model.slides:
|
||||
if layout.id == layout_id:
|
||||
return layout
|
||||
return None
|
||||
|
||||
def _validate_slide_content(
|
||||
self,
|
||||
*,
|
||||
content: dict[str, Any],
|
||||
schema: dict[str, Any],
|
||||
) -> list[str]:
|
||||
validation_content = self._strip_runtime_fields(content)
|
||||
validator = Draft202012Validator(schema)
|
||||
errors = sorted(validator.iter_errors(validation_content), key=lambda err: err.path)
|
||||
|
||||
if not errors:
|
||||
return []
|
||||
|
||||
formatted_errors: list[str] = []
|
||||
for err in errors[:MAX_SCHEMA_ERRORS]:
|
||||
location = ".".join([str(part) for part in err.path]) or "$"
|
||||
formatted_errors.append(f"{location}: {err.message}")
|
||||
return formatted_errors
|
||||
|
||||
@staticmethod
|
||||
def _strip_runtime_fields(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, nested_value in value.items():
|
||||
if key in RUNTIME_CONTENT_FIELDS:
|
||||
continue
|
||||
sanitized[key] = PresentationChatMemoryLayer._strip_runtime_fields(
|
||||
nested_value
|
||||
)
|
||||
return sanitized
|
||||
|
||||
if isinstance(value, list):
|
||||
return [
|
||||
PresentationChatMemoryLayer._strip_runtime_fields(item) for item in value
|
||||
]
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _extract_speaker_note(content: dict[str, Any]) -> str:
|
||||
value = content.get("__speaker_note__")
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _resolve_layout_group(
|
||||
*,
|
||||
presentation: PresentationModel,
|
||||
fallback: str = "presentation",
|
||||
) -> str:
|
||||
if isinstance(presentation.layout, dict):
|
||||
name = str(presentation.layout.get("name") or "").strip()
|
||||
if name:
|
||||
return name
|
||||
return fallback
|
||||
|
||||
@staticmethod
|
||||
def _serialize_slide(slide: SlideModel) -> str:
|
||||
content_text = ""
|
||||
try:
|
||||
content_text = json.dumps(slide.content or {}, ensure_ascii=False)
|
||||
except Exception:
|
||||
content_text = str(slide.content)
|
||||
|
||||
speaker_note = slide.speaker_note or ""
|
||||
return f"slide_index={slide.index}\nlayout_id={slide.layout}\n{content_text}\n{speaker_note}"
|
||||
|
||||
@staticmethod
|
||||
def _build_snippet(text: str, query_lower: str, window: int = 320) -> str:
|
||||
normalized = " ".join(text.split())
|
||||
if not normalized:
|
||||
return ""
|
||||
|
||||
offset = normalized.lower().find(query_lower)
|
||||
if offset == -1:
|
||||
return normalized[:window]
|
||||
|
||||
start = max(0, offset - window // 3)
|
||||
end = min(len(normalized), start + window)
|
||||
return normalized[start:end]
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from services.chat.memory_layer import (
|
||||
PresentationChatMemoryLayer as PresentationContextStore,
|
||||
)
|
||||
|
||||
__all__ = ["PresentationContextStore"]
|
||||
63
servers/fastapi/services/chat/prompts.py
Normal file
63
servers/fastapi/services/chat/prompts.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
def _trim_block(label: str, text: str) -> str:
|
||||
t = (text or "").strip()
|
||||
if not t:
|
||||
return ""
|
||||
return f"\n{label}\n{t}\n"
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
presentation_memory_context: str,
|
||||
chat_memory_context: str,
|
||||
) -> str:
|
||||
presentation_block = _trim_block(
|
||||
"Deck memory (semantic / long-term: uploaded document text, outline drafts & prompts, stored slide-edit notes; snippets may be partial and can lag the live deck):",
|
||||
presentation_memory_context,
|
||||
)
|
||||
chat_block = _trim_block(
|
||||
"Chat memory (earlier messages in this conversation only):",
|
||||
chat_memory_context,
|
||||
)
|
||||
return (
|
||||
"You are Presenton's slide assistant. Be concise, accurate, and action-oriented.\n"
|
||||
"\n"
|
||||
"Operating priorities\n"
|
||||
"1) Complete the user's intent with the fewest reliable tool calls.\n"
|
||||
"2) Prefer verified deck state over assumptions.\n"
|
||||
"3) Keep responses short and concrete.\n"
|
||||
"\n"
|
||||
"Source-of-truth policy\n"
|
||||
"- Tool outputs from this turn are authoritative for live deck state.\n"
|
||||
"- Conversation context (user constraints, prior decisions) is next.\n"
|
||||
"- Deck memory is background context and may be partial or stale.\n"
|
||||
"- If sources conflict, trust tools over memory.\n"
|
||||
"\n"
|
||||
"When to use memory vs tools\n"
|
||||
"- Use deck memory for uploaded-document meaning, original outline intent, and planning rationale.\n"
|
||||
"- Use tools for anything about current slides: exact text, ordering, layout, slide identity, and edits.\n"
|
||||
"- If user asks what is currently on slide N or asks for a change, do not rely on memory alone.\n"
|
||||
"\n"
|
||||
"Tool-use protocol (live SQL slide data)\n"
|
||||
"- User slide numbers are 1-based; tool indexes are 0-based.\n"
|
||||
"- Start with compact reads: getPresentationOutline -> searchSlides -> getSlideAtIndex.\n"
|
||||
"- Set includeFullContent=true only when full JSON is required (typically right before saveSlide).\n"
|
||||
"- Before saveSlide, validate target layout/schema (getAvailableLayouts, getContentSchemaFromLayoutId).\n"
|
||||
"- Generate required assets in batch with generateAssets before saving.\n"
|
||||
"- saveSlide payload must match the schema exactly; do not invent fields.\n"
|
||||
"- If a tool fails, report it briefly and choose the best next step.\n"
|
||||
"\n"
|
||||
"Autonomous decision policy (default behavior)\n"
|
||||
"- For edit requests, execute the best reasonable implementation without asking for optional preferences.\n"
|
||||
"- Do not ask the user to choose among layouts/assets unless the user explicitly asks to choose.\n"
|
||||
"- If visual details are unspecified (image style, icon set, exact layout), infer from slide content and deck theme.\n"
|
||||
"- For requests like 'add images/icons' or 'make it better', pick a layout that best preserves existing intent and readability, then apply it.\n"
|
||||
"- Ask a clarification only when blocked by a required missing fact (e.g., target slide is ambiguous, conflicting constraints, or missing required data).\n"
|
||||
"- When in doubt, prefer a professional, neutral visual style and continue.\n"
|
||||
"\n"
|
||||
"Response policy\n"
|
||||
"- Never invent slide facts, tool results, or document claims.\n"
|
||||
"- If information is missing, run the right tool or ask one focused clarification.\n"
|
||||
"- After enough evidence is collected, stop calling tools and provide a brief final answer.\n"
|
||||
"- For edits, apply changes first, then report what changed and where; for lookups, state what you found.\n"
|
||||
f"{presentation_block}"
|
||||
f"{chat_block}"
|
||||
)
|
||||
81
servers/fastapi/services/chat/schemas.py
Normal file
81
servers/fastapi/services/chat/schemas.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
import dirtyjson # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class StrictSchemaModel(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
|
||||
|
||||
class NoArgsInput(StrictSchemaModel):
|
||||
pass
|
||||
|
||||
|
||||
class GetSlideAtIndexInput(StrictSchemaModel):
|
||||
index: int = Field(ge=0, le=1000)
|
||||
include_full_content: bool = Field(alias="includeFullContent")
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
|
||||
|
||||
|
||||
class SearchSlidesInput(StrictSchemaModel):
|
||||
query: str = Field(min_length=1, max_length=1000)
|
||||
limit: int = Field(ge=1, le=10)
|
||||
|
||||
|
||||
class GetContentSchemaFromLayoutIdInput(StrictSchemaModel):
|
||||
layout_id: str = Field(alias="layoutId", min_length=1, max_length=200)
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
|
||||
|
||||
|
||||
class GenerateImageInput(StrictSchemaModel):
|
||||
prompt: str = Field(min_length=1, max_length=4000)
|
||||
|
||||
|
||||
class GenerateIconInput(StrictSchemaModel):
|
||||
query: str = Field(min_length=1, max_length=1000)
|
||||
|
||||
|
||||
class GenerateAssetItemInput(StrictSchemaModel):
|
||||
kind: Literal["image", "icon"]
|
||||
prompt: str = Field(
|
||||
min_length=1,
|
||||
max_length=4000,
|
||||
description="Image prompt or icon search query.",
|
||||
)
|
||||
|
||||
|
||||
class GenerateAssetsInput(StrictSchemaModel):
|
||||
assets: list[GenerateAssetItemInput] = Field(min_length=1, max_length=12)
|
||||
|
||||
|
||||
class SaveSlideInput(StrictSchemaModel):
|
||||
content: str = Field(
|
||||
min_length=2,
|
||||
max_length=200000,
|
||||
description=(
|
||||
"A JSON-serialized object for slide content. "
|
||||
"Example: '{\"title\": \"Q4 Revenue\", \"bullets\": [\"North America +22%\"]}'"
|
||||
),
|
||||
)
|
||||
layout_id: str = Field(alias="layoutId", min_length=1, max_length=200)
|
||||
index: int = Field(ge=0, le=1000)
|
||||
replace_old_slide_at_index: bool = Field(alias="replaceOldSlideAtIndex")
|
||||
|
||||
model_config = ConfigDict(extra="forbid", strict=True, populate_by_name=True)
|
||||
|
||||
@field_validator("content")
|
||||
@classmethod
|
||||
def validate_content(cls, value: str) -> str:
|
||||
try:
|
||||
parsed: Any = dirtyjson.loads(value)
|
||||
except Exception:
|
||||
parsed = json.loads(value)
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("'content' must be a JSON object.")
|
||||
|
||||
return value
|
||||
449
servers/fastapi/services/chat/service.py
Normal file
449
servers/fastapi/services/chat/service.py
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import HTTPException
|
||||
from llmai import get_client # type: ignore[import-not-found]
|
||||
from llmai.shared import ( # type: ignore[import-not-found]
|
||||
AssistantMessage,
|
||||
Message,
|
||||
SystemMessage,
|
||||
TextContentPart,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
WebSearchTool
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.sql.presentation import PresentationModel
|
||||
from services.chat.conversation_store import ChatConversationStore
|
||||
from services.chat.presentation_context_store import PresentationContextStore
|
||||
from services.chat.prompts import build_system_prompt
|
||||
from services.chat.tools import ChatTools
|
||||
from utils.llm_client_error_handler import handle_llm_client_exceptions
|
||||
from utils.llm_config import get_llm_config
|
||||
from utils.llm_provider import get_model
|
||||
from utils.llm_utils import (
|
||||
extract_text,
|
||||
get_generate_kwargs,
|
||||
stream_generate_events,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
MAX_TOOL_ROUNDS = 16
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatTurnResult:
|
||||
conversation_id: uuid.UUID
|
||||
response_text: str
|
||||
tool_calls: list[str]
|
||||
|
||||
|
||||
ChatStreamEventType = Literal["chunk", "complete", "status", "trace"]
|
||||
ChatStreamEventValue = str | ChatTurnResult | dict[str, Any]
|
||||
|
||||
|
||||
class PresentationChatService:
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: AsyncSession,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID | None,
|
||||
):
|
||||
self._sql_session = sql_session
|
||||
self._presentation_id = presentation_id
|
||||
self._conversation_id = conversation_id
|
||||
|
||||
self._conversation_store = ChatConversationStore(sql_session)
|
||||
self._memory = PresentationContextStore(sql_session, presentation_id)
|
||||
self._tools = ChatTools(self._memory)
|
||||
|
||||
async def generate_reply(self, user_message: str) -> ChatTurnResult:
|
||||
conversation_id, messages = await self._prepare_turn_context(user_message)
|
||||
response_text, tool_calls = await self._run_llm_with_tools(messages)
|
||||
return await self._persist_turn(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
response_text=response_text,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
async def stream_reply(
|
||||
self, user_message: str
|
||||
) -> AsyncGenerator[tuple[ChatStreamEventType, ChatStreamEventValue], None]:
|
||||
yield "status", "Reading deck context"
|
||||
conversation_id, messages = await self._prepare_turn_context(user_message)
|
||||
|
||||
client = get_client(config=get_llm_config())
|
||||
model = get_model()
|
||||
tools = self._tools.get_tool_definitions()
|
||||
tools.append(WebSearchTool())
|
||||
|
||||
called_tools: list[str] = []
|
||||
last_tool_results: list[dict[str, Any]] = []
|
||||
response_text: str | None = None
|
||||
|
||||
for round_index in range(MAX_TOOL_ROUNDS):
|
||||
completion_chunk: Any | None = None
|
||||
round_content_chunks: list[str] = []
|
||||
thinking_chunks: list[str] = []
|
||||
|
||||
try:
|
||||
async for event in stream_generate_events(
|
||||
client,
|
||||
**get_generate_kwargs(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
),
|
||||
):
|
||||
event_type = getattr(event, "type", None)
|
||||
if event_type == "content":
|
||||
chunk = getattr(event, "chunk", None)
|
||||
if chunk:
|
||||
round_content_chunks.append(chunk)
|
||||
yield "chunk", chunk
|
||||
elif event_type == "thinking":
|
||||
thinking_text = self._event_text(event)
|
||||
if thinking_text:
|
||||
thinking_chunks.append(thinking_text)
|
||||
elif event_type == "completion":
|
||||
completion_chunk = event
|
||||
except Exception as exc:
|
||||
raise handle_llm_client_exceptions(exc)
|
||||
|
||||
thinking_summary = self._summarize_model_note(thinking_chunks)
|
||||
if thinking_summary:
|
||||
yield "trace", {
|
||||
"kind": "model_note",
|
||||
"round": round_index + 1,
|
||||
"status": "info",
|
||||
"message": thinking_summary,
|
||||
}
|
||||
|
||||
completion_tool_calls = list(
|
||||
getattr(completion_chunk, "tool_calls", []) or []
|
||||
)
|
||||
if completion_tool_calls:
|
||||
tool_names = [tool_call.name for tool_call in completion_tool_calls]
|
||||
called_tools.extend(tool_names)
|
||||
yield "trace", {
|
||||
"kind": "tool_plan",
|
||||
"round": round_index + 1,
|
||||
"tools": tool_names,
|
||||
"message": f"Using tools: {', '.join(tool_names)}",
|
||||
}
|
||||
messages = (
|
||||
list(getattr(completion_chunk, "messages", []) or [])
|
||||
if getattr(completion_chunk, "messages", None)
|
||||
else list(messages)
|
||||
)
|
||||
|
||||
last_tool_results = []
|
||||
for tool_call in completion_tool_calls:
|
||||
yield "trace", {
|
||||
"kind": "tool_call",
|
||||
"round": round_index + 1,
|
||||
"tool": tool_call.name,
|
||||
"status": "start",
|
||||
"message": self._tool_start_message(tool_call.name),
|
||||
}
|
||||
tool_result = await self._tools.execute_tool_call(tool_call)
|
||||
last_tool_results.append(tool_result)
|
||||
yield "trace", {
|
||||
"kind": "tool_call",
|
||||
"round": round_index + 1,
|
||||
"tool": tool_call.name,
|
||||
"status": "success" if tool_result.get("ok") else "error",
|
||||
"message": self._summarize_tool_result(
|
||||
tool_call.name, tool_result
|
||||
),
|
||||
}
|
||||
tool_response_content = json.dumps(tool_result, ensure_ascii=False)
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
id=tool_call.id,
|
||||
content=[TextContentPart(text=tool_response_content)],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_text = "".join(round_content_chunks)
|
||||
if not response_text and completion_chunk:
|
||||
response_text = extract_text(getattr(completion_chunk, "content", None))
|
||||
if not response_text:
|
||||
response_text = "I could not generate a response for that request."
|
||||
|
||||
if not round_content_chunks:
|
||||
yield "chunk", response_text
|
||||
break
|
||||
else:
|
||||
LOGGER.warning("Max tool rounds reached in chat stream flow")
|
||||
yield "trace", {
|
||||
"kind": "limit",
|
||||
"message": (
|
||||
"Reached tool-call limit before final answer; "
|
||||
"attempting best-effort summary."
|
||||
),
|
||||
}
|
||||
yield "status", "Finalizing response"
|
||||
response_text = await self._try_final_response_without_tools(
|
||||
client=client,
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
if not response_text:
|
||||
response_text = self._build_tool_limit_fallback(last_tool_results)
|
||||
yield "chunk", response_text
|
||||
|
||||
final_response_text = response_text or "I could not generate a response for that request."
|
||||
if response_text is None:
|
||||
yield "chunk", final_response_text
|
||||
|
||||
yield "status", "Saving chat"
|
||||
result = await self._persist_turn(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
response_text=final_response_text,
|
||||
tool_calls=called_tools,
|
||||
)
|
||||
yield "complete", result
|
||||
|
||||
async def _prepare_turn_context(
|
||||
self, user_message: str
|
||||
) -> tuple[uuid.UUID, list[Message]]:
|
||||
if not (user_message or "").strip():
|
||||
raise HTTPException(status_code=400, detail="Message is required")
|
||||
|
||||
presentation = await self._sql_session.get(PresentationModel, self._presentation_id)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
conversation_id = await self._conversation_store.ensure_conversation_id(
|
||||
self._conversation_id
|
||||
)
|
||||
history = await self._conversation_store.load_history(
|
||||
presentation_id=self._presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
history_messages = self._convert_history_to_messages(history)
|
||||
|
||||
presentation_memory = await self._memory.retrieve_context(user_message)
|
||||
chat_memory = await self._conversation_store.retrieve_semantic_context(
|
||||
presentation_id=self._presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query=user_message,
|
||||
)
|
||||
messages: list[Message] = [
|
||||
SystemMessage(
|
||||
content=build_system_prompt(
|
||||
presentation_memory_context=presentation_memory,
|
||||
chat_memory_context=chat_memory,
|
||||
)
|
||||
),
|
||||
*history_messages,
|
||||
UserMessage(content=user_message),
|
||||
]
|
||||
return conversation_id, messages
|
||||
|
||||
async def _persist_turn(
|
||||
self,
|
||||
*,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
response_text: str,
|
||||
tool_calls: list[str],
|
||||
) -> ChatTurnResult:
|
||||
await self._conversation_store.append_turn(
|
||||
presentation_id=self._presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
assistant_message=response_text,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
await self._sql_session.commit()
|
||||
|
||||
return ChatTurnResult(
|
||||
conversation_id=conversation_id,
|
||||
response_text=response_text,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
async def _run_llm_with_tools(self, messages: list[Message]) -> tuple[str, list[str]]:
|
||||
client = get_client(config=get_llm_config())
|
||||
model = get_model()
|
||||
tools = self._tools.get_tool_definitions()
|
||||
|
||||
called_tools: list[str] = []
|
||||
last_tool_results: list[dict[str, Any]] = []
|
||||
|
||||
for _ in range(MAX_TOOL_ROUNDS):
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.generate,
|
||||
**get_generate_kwargs(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise handle_llm_client_exceptions(exc)
|
||||
|
||||
if not response.tool_calls:
|
||||
response_text = extract_text(response.content) or (
|
||||
"I could not generate a response for that request."
|
||||
)
|
||||
return response_text, called_tools
|
||||
|
||||
called_tools.extend([tool_call.name for tool_call in response.tool_calls])
|
||||
messages = list(response.messages) if response.messages else list(messages)
|
||||
|
||||
last_tool_results = []
|
||||
for tool_call in response.tool_calls:
|
||||
tool_result = await self._tools.execute_tool_call(tool_call)
|
||||
last_tool_results.append(tool_result)
|
||||
tool_response_content = json.dumps(tool_result, ensure_ascii=False)
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
id=tool_call.id,
|
||||
content=[TextContentPart(text=tool_response_content)],
|
||||
)
|
||||
)
|
||||
|
||||
LOGGER.warning("Max tool rounds reached in chat flow")
|
||||
final_response = await self._try_final_response_without_tools(
|
||||
client=client,
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
if final_response:
|
||||
return final_response, called_tools
|
||||
|
||||
return self._build_tool_limit_fallback(last_tool_results), called_tools
|
||||
|
||||
async def _try_final_response_without_tools(
|
||||
self,
|
||||
*,
|
||||
client: Any,
|
||||
model: str,
|
||||
messages: list[Message],
|
||||
) -> str | None:
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
client.generate,
|
||||
**get_generate_kwargs(
|
||||
model=model,
|
||||
messages=messages,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.warning("Final no-tool synthesis call failed", exc_info=True)
|
||||
return None
|
||||
|
||||
return extract_text(response.content)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_model_note(chunks: list[str]) -> str:
|
||||
text = "".join(chunks).strip()
|
||||
if not text or text in {"{}", "[]"}:
|
||||
return ""
|
||||
|
||||
compact = " ".join(text.split())
|
||||
if compact.lower() in {"start", "end"}:
|
||||
return ""
|
||||
if len(compact) > 600:
|
||||
return f"{compact[:600].rstrip()}..."
|
||||
return compact
|
||||
|
||||
@staticmethod
|
||||
def _event_text(event: Any) -> str:
|
||||
for attr in ("chunk", "delta", "text", "content"):
|
||||
value = getattr(event, attr, None)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _tool_start_message(tool_name: str) -> str:
|
||||
labels = {
|
||||
"getPresentationOutline": "Reading the presentation outline",
|
||||
"searchSlides": "Searching relevant slides",
|
||||
"getSlideAtIndex": "Opening the requested slide",
|
||||
"getAvailableLayouts": "Checking available layouts",
|
||||
"getContentSchemaFromLayoutId": "Checking the layout schema",
|
||||
"generateAssets": "Generating slide assets",
|
||||
"saveSlide": "Saving the slide",
|
||||
}
|
||||
return labels.get(tool_name, f"Running {tool_name}")
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_limit_fallback(last_tool_results: list[dict[str, Any]]) -> str:
|
||||
for entry in reversed(last_tool_results):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if not entry.get("ok"):
|
||||
continue
|
||||
result = entry.get("result")
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
message = result.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
|
||||
return (
|
||||
"I completed several tool operations but could not finalize the response "
|
||||
"within the tool limit. Please ask a follow-up and I will continue."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _summarize_tool_result(tool_name: str, tool_result: dict[str, Any]) -> str:
|
||||
if not tool_result.get("ok"):
|
||||
error = tool_result.get("error")
|
||||
if isinstance(error, str) and error.strip():
|
||||
return f"{tool_name} failed: {error.strip()}"
|
||||
return f"{tool_name} failed."
|
||||
|
||||
result = tool_result.get("result")
|
||||
if isinstance(result, dict):
|
||||
message = result.get("message")
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
|
||||
note = result.get("note")
|
||||
if isinstance(note, str) and note.strip():
|
||||
return note.strip()
|
||||
|
||||
count = result.get("count")
|
||||
if isinstance(count, int):
|
||||
return f"{tool_name} returned {count} result(s)."
|
||||
|
||||
found = result.get("found")
|
||||
if isinstance(found, bool):
|
||||
return (
|
||||
f"{tool_name} found requested data."
|
||||
if found
|
||||
else f"{tool_name} did not find matching data."
|
||||
)
|
||||
|
||||
return f"{tool_name} completed."
|
||||
|
||||
@staticmethod
|
||||
def _convert_history_to_messages(history: list[dict[str, str]]) -> list[Message]:
|
||||
messages: list[Message] = []
|
||||
for item in history:
|
||||
role = item.get("role")
|
||||
content = item.get("content")
|
||||
if not content:
|
||||
continue
|
||||
if role == "user":
|
||||
messages.append(UserMessage(content=content))
|
||||
elif role == "assistant":
|
||||
messages.append(AssistantMessage(content=[content]))
|
||||
return messages
|
||||
212
servers/fastapi/services/chat/sql_chat_history.py
Normal file
212
servers/fastapi/services/chat/sql_chat_history.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
"""Persist presentation chat threads in SQL rows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete as sa_delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from models.sql.chat_history_message import ChatHistoryMessageModel
|
||||
from utils.datetime_utils import get_current_utc_datetime
|
||||
|
||||
|
||||
def _compact_preview(content: str) -> str:
|
||||
preview = content.strip()
|
||||
if len(preview) > 200:
|
||||
return f"{preview[:200]}…"
|
||||
return preview
|
||||
|
||||
|
||||
def _serialize_created_at(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "isoformat"):
|
||||
try:
|
||||
return value.isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _parse_created_at(value: Any) -> datetime | None:
|
||||
if isinstance(value, datetime):
|
||||
return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
normalized = value.strip().replace("Z", "+00:00")
|
||||
try:
|
||||
parsed = datetime.fromisoformat(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
async def load_messages(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
) -> list[dict[str, str]]:
|
||||
rows = await load_messages_with_meta(
|
||||
session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
return [
|
||||
{"role": row["role"], "content": row["content"]}
|
||||
for row in rows
|
||||
if isinstance(row.get("role"), str) and isinstance(row.get("content"), str)
|
||||
]
|
||||
|
||||
|
||||
async def load_messages_with_meta(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
) -> list[dict[str, Any]]:
|
||||
rows = list(
|
||||
(
|
||||
await session.scalars(
|
||||
select(ChatHistoryMessageModel)
|
||||
.where(
|
||||
ChatHistoryMessageModel.presentation_id == presentation_id,
|
||||
ChatHistoryMessageModel.conversation_id == conversation_id,
|
||||
)
|
||||
.order_by(ChatHistoryMessageModel.position.asc())
|
||||
)
|
||||
).all()
|
||||
)
|
||||
out: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
entry: dict[str, Any] = {
|
||||
"role": row.role,
|
||||
"content": row.content,
|
||||
}
|
||||
created = _serialize_created_at(row.created_at)
|
||||
if created:
|
||||
entry["created_at"] = created
|
||||
out.append(entry)
|
||||
return out
|
||||
|
||||
|
||||
async def replace_messages(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
messages: list[dict[str, str]],
|
||||
) -> None:
|
||||
await session.execute(
|
||||
sa_delete(ChatHistoryMessageModel).where(
|
||||
ChatHistoryMessageModel.presentation_id == presentation_id,
|
||||
ChatHistoryMessageModel.conversation_id == conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_position = 1
|
||||
base_time = get_current_utc_datetime()
|
||||
for index, message in enumerate(messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
if role not in ("user", "assistant"):
|
||||
continue
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
continue
|
||||
|
||||
created_at = _parse_created_at(message.get("created_at")) or (
|
||||
base_time + timedelta(microseconds=index)
|
||||
)
|
||||
session.add(
|
||||
ChatHistoryMessageModel(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
position=next_position,
|
||||
role=role,
|
||||
content=content,
|
||||
created_at=created_at,
|
||||
)
|
||||
)
|
||||
next_position += 1
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def append_turn(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
presentation_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
tool_calls: list[str] | None = None,
|
||||
) -> None:
|
||||
max_position = await session.scalar(
|
||||
select(func.max(ChatHistoryMessageModel.position)).where(
|
||||
ChatHistoryMessageModel.presentation_id == presentation_id,
|
||||
ChatHistoryMessageModel.conversation_id == conversation_id,
|
||||
)
|
||||
)
|
||||
next_position = int(max_position or 0) + 1
|
||||
now = get_current_utc_datetime()
|
||||
|
||||
session.add(
|
||||
ChatHistoryMessageModel(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
position=next_position,
|
||||
role="user",
|
||||
content=user_message,
|
||||
created_at=now,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ChatHistoryMessageModel(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
position=next_position + 1,
|
||||
role="assistant",
|
||||
content=assistant_message,
|
||||
created_at=now + timedelta(microseconds=1),
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def list_conversations(
|
||||
session: AsyncSession, *, presentation_id: uuid.UUID
|
||||
) -> list[dict[str, Any]]:
|
||||
rows = list(
|
||||
(
|
||||
await session.scalars(
|
||||
select(ChatHistoryMessageModel)
|
||||
.where(ChatHistoryMessageModel.presentation_id == presentation_id)
|
||||
.order_by(
|
||||
ChatHistoryMessageModel.created_at.desc(),
|
||||
ChatHistoryMessageModel.position.desc(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
summary_by_conversation: dict[str, dict[str, Any]] = {}
|
||||
for row in rows:
|
||||
conversation_key = str(row.conversation_id)
|
||||
if conversation_key in summary_by_conversation:
|
||||
continue
|
||||
summary_by_conversation[conversation_key] = {
|
||||
"conversation_id": conversation_key,
|
||||
"updated_at": _serialize_created_at(row.created_at),
|
||||
"last_message_preview": _compact_preview(row.content),
|
||||
}
|
||||
|
||||
summaries = list(summary_by_conversation.values())
|
||||
summaries.sort(key=lambda item: item.get("updated_at") or "", reverse=True)
|
||||
return summaries
|
||||
356
servers/fastapi/services/chat/tools.py
Normal file
356
servers/fastapi/services/chat/tools.py
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import dirtyjson # type: ignore[import-untyped]
|
||||
from llmai.shared import AssistantToolCall, Tool # type: ignore[import-not-found]
|
||||
|
||||
from services.chat.schemas import (
|
||||
GenerateAssetsInput,
|
||||
GenerateIconInput,
|
||||
GenerateImageInput,
|
||||
GetContentSchemaFromLayoutIdInput,
|
||||
GetSlideAtIndexInput,
|
||||
NoArgsInput,
|
||||
SaveSlideInput,
|
||||
SearchSlidesInput,
|
||||
)
|
||||
from services.chat.presentation_context_store import PresentationContextStore
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
|
||||
|
||||
class ChatTools:
|
||||
def __init__(self, memory: PresentationContextStore):
|
||||
self._memory = memory
|
||||
self._tool_handlers: dict[str, ToolHandler] = {
|
||||
"getPresentationOutline": self._get_presentation_outline,
|
||||
"searchSlides": self._search_slides,
|
||||
"getSlideAtIndex": self._get_slide_at_index,
|
||||
"getAvailableLayouts": self._get_available_layouts,
|
||||
"getContentSchemaFromLayoutId": self._get_content_schema_from_layout_id,
|
||||
"generateAssets": self._generate_assets,
|
||||
"generateImage": self._generate_image,
|
||||
"generateIcon": self._generate_icon,
|
||||
"saveSlide": self._save_slide,
|
||||
}
|
||||
|
||||
def get_tool_definitions(self) -> list[Tool]:
|
||||
return [
|
||||
Tool(
|
||||
name="getPresentationOutline",
|
||||
description=(
|
||||
"Live database: current deck structure. "
|
||||
"Use for the **actual** slide list/order and compact previews—not for uploaded PDF text or pre-outline RAG. "
|
||||
"Falls back to stored outlines only if no slide rows exist. "
|
||||
"Return compact sections (no full slide JSON). Use for flow, sections, or 'what slides exist'."
|
||||
),
|
||||
schema=NoArgsInput,
|
||||
strict=True,
|
||||
),
|
||||
Tool(
|
||||
name="searchSlides",
|
||||
description=(
|
||||
"Live SQL slides: keyword/semantic style search with snippets and indices. "
|
||||
"Use to find on-slide text, topics, or which slide mentioned something. "
|
||||
"For source-document-only questions, rely on deck memory; use this when the question is about **slides as built**. "
|
||||
"Always provide both query and limit."
|
||||
),
|
||||
schema=SearchSlidesInput,
|
||||
strict=True,
|
||||
),
|
||||
Tool(
|
||||
name="getSlideAtIndex",
|
||||
description=(
|
||||
"Live SQL: one slide by index—authoritative for exact current content. "
|
||||
"Set includeFullContent=true when you need full JSON (before saveSlide or precise edits). "
|
||||
"If user says slide N, use zero-based index N-1."
|
||||
),
|
||||
schema=GetSlideAtIndexInput,
|
||||
strict=True,
|
||||
),
|
||||
Tool(
|
||||
name="getAvailableLayouts",
|
||||
description=(
|
||||
"List all available layout ids and descriptions for the current "
|
||||
"presentation template."
|
||||
),
|
||||
schema=NoArgsInput,
|
||||
strict=True,
|
||||
),
|
||||
Tool(
|
||||
name="getContentSchemaFromLayoutId",
|
||||
description=(
|
||||
"Fetch the JSON content schema for a layout id. Use before "
|
||||
"saving slide content to validate structure."
|
||||
),
|
||||
schema=GetContentSchemaFromLayoutIdInput,
|
||||
strict=True,
|
||||
),
|
||||
Tool(
|
||||
name="generateAssets",
|
||||
description=(
|
||||
"Generate multiple media assets in one call. Use for all slide "
|
||||
"images and icons before saving content; include every needed "
|
||||
"asset in the assets array instead of calling image/icon tools "
|
||||
"one at a time."
|
||||
),
|
||||
schema=GenerateAssetsInput,
|
||||
strict=True,
|
||||
),
|
||||
Tool(
|
||||
name="saveSlide",
|
||||
description=(
|
||||
"Save slide content for a layout. If replaceOldSlideAtIndex is "
|
||||
"true, replace that index; otherwise insert as a new slide. "
|
||||
"Pass content as a JSON-serialized object string and the server "
|
||||
"will validate it against layout schema before save."
|
||||
),
|
||||
schema=SaveSlideInput,
|
||||
strict=True,
|
||||
),
|
||||
]
|
||||
|
||||
async def execute_tool_call(self, tool_call: AssistantToolCall) -> dict[str, Any]:
|
||||
handler = self._tool_handlers.get(tool_call.name)
|
||||
if not handler:
|
||||
return {
|
||||
"ok": False,
|
||||
"tool": tool_call.name,
|
||||
"error": f"Unsupported tool: {tool_call.name}",
|
||||
}
|
||||
|
||||
try:
|
||||
parsed_args = self._parse_args(tool_call.arguments)
|
||||
LOGGER.info("Executing chat tool %s", tool_call.name)
|
||||
result = await handler(parsed_args)
|
||||
return {"ok": True, "tool": tool_call.name, "result": result}
|
||||
except Exception as exc:
|
||||
LOGGER.exception("Chat tool failed: %s", tool_call.name)
|
||||
return {
|
||||
"ok": False,
|
||||
"tool": tool_call.name,
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
async def _get_presentation_outline(self, _: dict[str, Any]) -> dict[str, Any]:
|
||||
outline = await self._memory.get("presentation_outline")
|
||||
if not isinstance(outline, dict):
|
||||
return {
|
||||
"found": False,
|
||||
"message": "Presentation outline is not available in memory yet.",
|
||||
"sections": [],
|
||||
}
|
||||
|
||||
slides = outline.get("slides")
|
||||
if not isinstance(slides, list) or not slides:
|
||||
return {
|
||||
"found": False,
|
||||
"message": "Presentation outline exists but has no slides.",
|
||||
"sections": [],
|
||||
}
|
||||
|
||||
sections: list[dict[str, Any]] = []
|
||||
for position, slide in enumerate(slides):
|
||||
index = position
|
||||
content = ""
|
||||
if isinstance(slide, dict):
|
||||
raw_index = slide.get("index")
|
||||
if isinstance(raw_index, int):
|
||||
index = raw_index
|
||||
raw_content = slide.get("content")
|
||||
if isinstance(raw_content, str):
|
||||
content = raw_content
|
||||
elif raw_content is not None:
|
||||
try:
|
||||
content = json.dumps(raw_content, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(raw_content)
|
||||
elif isinstance(slide, str):
|
||||
content = slide
|
||||
|
||||
title = self._extract_title(content) or f"Slide {index + 1}"
|
||||
sections.append(
|
||||
{
|
||||
"index": index,
|
||||
"slide_number": index + 1,
|
||||
"title": title,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"found": True,
|
||||
"slide_count": len(sections),
|
||||
"sections": sections,
|
||||
"source": outline.get("source", "memory"),
|
||||
}
|
||||
|
||||
async def _search_slides(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = SearchSlidesInput(**args)
|
||||
results = await self._memory.search(payload.query, payload.limit)
|
||||
return {
|
||||
"query": payload.query,
|
||||
"count": len(results),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
async def _get_slide_at_index(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized_args = dict(args)
|
||||
normalized_args.setdefault("includeFullContent", False)
|
||||
payload = GetSlideAtIndexInput(**normalized_args)
|
||||
slide = await self._memory.get_slide_at_index(
|
||||
payload.index,
|
||||
include_full_content=payload.include_full_content,
|
||||
)
|
||||
if not slide and payload.index > 0:
|
||||
# Users often refer to slides as 1-based; allow a safe fallback.
|
||||
fallback_index = payload.index - 1
|
||||
fallback_slide = await self._memory.get_slide_at_index(
|
||||
fallback_index,
|
||||
include_full_content=payload.include_full_content,
|
||||
)
|
||||
if fallback_slide:
|
||||
return {
|
||||
"found": True,
|
||||
"slide": fallback_slide,
|
||||
"requested_index": payload.index,
|
||||
"resolved_index": fallback_index,
|
||||
"note": (
|
||||
"No slide found at requested index; returned one-based fallback "
|
||||
f"at index {fallback_index}."
|
||||
),
|
||||
}
|
||||
if not slide:
|
||||
return {
|
||||
"found": False,
|
||||
"message": f"No slide found at index {payload.index}.",
|
||||
}
|
||||
return {
|
||||
"found": True,
|
||||
"slide": slide,
|
||||
}
|
||||
|
||||
async def _get_available_layouts(self, _: dict[str, Any]) -> dict[str, Any]:
|
||||
layouts = await self._memory.get_available_layouts()
|
||||
return {
|
||||
"count": len(layouts),
|
||||
"layouts": layouts,
|
||||
}
|
||||
|
||||
async def _get_content_schema_from_layout_id(
|
||||
self, args: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
payload = GetContentSchemaFromLayoutIdInput(**args)
|
||||
schema = await self._memory.get_content_schema_from_layout_id(payload.layout_id)
|
||||
if schema is None:
|
||||
return {
|
||||
"found": False,
|
||||
"layout_id": payload.layout_id,
|
||||
"message": "Layout schema not found for the provided layout id.",
|
||||
}
|
||||
return {
|
||||
"found": True,
|
||||
"layout_id": payload.layout_id,
|
||||
"content_schema": schema,
|
||||
}
|
||||
|
||||
async def _generate_image(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = GenerateImageInput(**args)
|
||||
image_url = await self._memory.generate_image(payload.prompt)
|
||||
return {
|
||||
"prompt": payload.prompt,
|
||||
"url": image_url,
|
||||
}
|
||||
|
||||
async def _generate_icon(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = GenerateIconInput(**args)
|
||||
icon_url = await self._memory.generate_icon(payload.query)
|
||||
return {
|
||||
"query": payload.query,
|
||||
"url": icon_url,
|
||||
}
|
||||
|
||||
async def _generate_assets(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
payload = GenerateAssetsInput(**args)
|
||||
generated_assets: list[dict[str, Any]] = []
|
||||
|
||||
for index, asset in enumerate(payload.assets):
|
||||
if asset.kind == "image":
|
||||
result = await self._generate_image({"prompt": asset.prompt})
|
||||
else:
|
||||
result = await self._generate_icon({"query": asset.prompt})
|
||||
|
||||
generated_assets.append(
|
||||
{
|
||||
"index": index,
|
||||
"kind": asset.kind,
|
||||
"prompt": asset.prompt,
|
||||
"url": result.get("url"),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"count": len(generated_assets),
|
||||
"assets": generated_assets,
|
||||
"message": f"Generated {len(generated_assets)} asset(s).",
|
||||
}
|
||||
|
||||
async def _save_slide(self, args: dict[str, Any]) -> dict[str, Any]:
|
||||
payload_args = json.loads(json.dumps(dict(args), ensure_ascii=False))
|
||||
raw_content = payload_args.get("content")
|
||||
if isinstance(raw_content, dict):
|
||||
payload_args["content"] = json.dumps(raw_content, ensure_ascii=False)
|
||||
|
||||
payload = SaveSlideInput(**payload_args)
|
||||
try:
|
||||
content_parsed: Any = dirtyjson.loads(payload.content)
|
||||
except Exception:
|
||||
content_parsed = json.loads(payload.content)
|
||||
|
||||
if not isinstance(content_parsed, dict):
|
||||
raise ValueError("'content' must be a JSON object.")
|
||||
|
||||
content_payload = json.loads(json.dumps(content_parsed, ensure_ascii=False))
|
||||
return await self._memory.save_slide(
|
||||
content=content_payload,
|
||||
layout_id=payload.layout_id,
|
||||
index=payload.index,
|
||||
replace_old_slide_at_index=payload.replace_old_slide_at_index,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(arguments: str | None) -> dict[str, Any]:
|
||||
if not arguments:
|
||||
return {}
|
||||
|
||||
try:
|
||||
parsed = dirtyjson.loads(arguments)
|
||||
except Exception:
|
||||
parsed = json.loads(arguments)
|
||||
|
||||
normalized = json.loads(json.dumps(parsed, ensure_ascii=False))
|
||||
if isinstance(normalized, dict):
|
||||
return normalized
|
||||
|
||||
raise ValueError("Tool arguments must be a JSON object.")
|
||||
|
||||
@staticmethod
|
||||
def _extract_title(markdown_content: str) -> str:
|
||||
for line in markdown_content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
heading_match = re.match(r"^#{1,6}\s*(.+?)\s*$", stripped)
|
||||
if heading_match:
|
||||
return heading_match.group(1).strip()
|
||||
return stripped[:120]
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _truncate(value: str, limit: int) -> str:
|
||||
if len(value) <= limit:
|
||||
return value
|
||||
return f"{value[:limit]}..."
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
import os
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
|
|
@ -11,6 +10,7 @@ from sqlmodel import SQLModel
|
|||
from models.sql.async_presentation_generation_status import (
|
||||
AsyncPresentationGenerationTaskModel,
|
||||
)
|
||||
from models.sql.chat_history_message import ChatHistoryMessageModel
|
||||
from models.sql.image_asset import ImageAsset
|
||||
from models.sql.key_value import KeyValueSqlModel
|
||||
from models.sql.ollama_pull_status import OllamaPullStatus
|
||||
|
|
@ -20,7 +20,6 @@ from models.sql.template import TemplateModel
|
|||
from models.sql.template_create_info import TemplateCreateInfoModel
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sql.webhook_subscription import WebhookSubscription
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
from utils.get_env import get_migrate_database_on_startup_env
|
||||
from utils.db_utils import get_database_url_and_connect_args, get_pool_kwargs
|
||||
|
||||
|
|
@ -42,22 +41,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
yield session
|
||||
|
||||
|
||||
# Container DB (Lives inside the app data directory)
|
||||
_app_data_dir = get_app_data_directory_env() or "/tmp/presenton"
|
||||
container_db_url = f"sqlite+aiosqlite:///{os.path.join(_app_data_dir, 'container.db')}"
|
||||
container_db_engine: AsyncEngine = create_async_engine(
|
||||
container_db_url, connect_args={"check_same_thread": False}
|
||||
)
|
||||
container_db_async_session_maker = async_sessionmaker(
|
||||
container_db_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with container_db_async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Create Database and Tables
|
||||
async def create_db_and_tables():
|
||||
should_run_alembic = get_migrate_database_on_startup_env() in ["true", "True"]
|
||||
|
|
@ -70,24 +53,18 @@ async def create_db_and_tables():
|
|||
PresentationModel.__table__,
|
||||
SlideModel.__table__,
|
||||
KeyValueSqlModel.__table__,
|
||||
ChatHistoryMessageModel.__table__,
|
||||
ImageAsset.__table__,
|
||||
PresentationLayoutCodeModel.__table__,
|
||||
TemplateCreateInfoModel.__table__,
|
||||
TemplateModel.__table__,
|
||||
WebhookSubscription.__table__,
|
||||
AsyncPresentationGenerationTaskModel.__table__,
|
||||
OllamaPullStatus.__table__,
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async with container_db_engine.begin() as conn:
|
||||
await conn.run_sync(
|
||||
lambda sync_conn: SQLModel.metadata.create_all(
|
||||
sync_conn,
|
||||
tables=[OllamaPullStatus.__table__],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def dispose_engines():
|
||||
"""Dispose all engine connection pools.
|
||||
|
|
@ -97,4 +74,3 @@ async def dispose_engines():
|
|||
database and prevent stale / leaked connections.
|
||||
"""
|
||||
await sql_engine.dispose()
|
||||
await container_db_engine.dispose()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
|
@ -30,6 +32,129 @@ except Exception:
|
|||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _unwrap_liteparse_json_line_if_stored(text: str) -> str:
|
||||
"""If the whole JSON line from the LiteParse runner was stored as the document, keep only the text field."""
|
||||
if not text:
|
||||
return text
|
||||
s = text.lstrip()
|
||||
if not s.startswith("{"):
|
||||
return text
|
||||
try:
|
||||
payload = json.loads(s)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
return text
|
||||
if not isinstance(payload, dict):
|
||||
return text
|
||||
if (
|
||||
payload.get("ok") is True
|
||||
and "filePath" in payload
|
||||
and isinstance(payload.get("text"), str)
|
||||
):
|
||||
return payload["text"]
|
||||
return text
|
||||
|
||||
|
||||
_RE_TEXT_KEY = re.compile(r'"text"\s*:\s*"')
|
||||
|
||||
|
||||
def _json_unescape_quoted_value(s: str, content_start: int) -> str:
|
||||
"""
|
||||
Unescape a JSON string value. `content_start` is the index of the first character
|
||||
*inside* the value (immediately after the opening quote of the "text" field).
|
||||
If the closing quote is missing (truncated), returns the unescaped rest of the string.
|
||||
"""
|
||||
out: list[str] = []
|
||||
i = content_start
|
||||
n = len(s)
|
||||
while i < n:
|
||||
c = s[i]
|
||||
if c == "\\" and i + 1 < n:
|
||||
e = s[i + 1]
|
||||
if e in '"\\':
|
||||
out.append(e)
|
||||
i += 2
|
||||
elif e == "/":
|
||||
out.append("/")
|
||||
i += 2
|
||||
elif e == "b":
|
||||
out.append("\b")
|
||||
i += 2
|
||||
elif e == "f":
|
||||
out.append("\f")
|
||||
i += 2
|
||||
elif e == "n":
|
||||
out.append("\n")
|
||||
i += 2
|
||||
elif e == "r":
|
||||
out.append("\r")
|
||||
i += 2
|
||||
elif e == "t":
|
||||
out.append("\t")
|
||||
i += 2
|
||||
elif e == "u" and i + 5 < n:
|
||||
try:
|
||||
out.append(chr(int(s[i + 2 : i + 6], 16)))
|
||||
except (ValueError, OverflowError):
|
||||
out.append(s[i : i + 6])
|
||||
i += 6
|
||||
else:
|
||||
out.append(e)
|
||||
i += 2
|
||||
elif c == '"':
|
||||
return "".join(out)
|
||||
else:
|
||||
out.append(c)
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _try_extract_liteparse_text_value_from_malformed_json(s: str) -> Optional[str]:
|
||||
"""
|
||||
When json.loads failed (e.g. truncated or corrupt), find the "text" field value
|
||||
in a LiteParse-shaped object and return only the unescaped string body.
|
||||
"""
|
||||
if not s.startswith("{"):
|
||||
return None
|
||||
head = s[:10000] if len(s) > 10000 else s
|
||||
if not ("ok" in head and "filePath" in head):
|
||||
return None
|
||||
m = _RE_TEXT_KEY.search(s)
|
||||
if not m:
|
||||
return None
|
||||
return _json_unescape_quoted_value(s, m.end())
|
||||
|
||||
|
||||
def _clean_extracted_one_pass(t: str) -> str:
|
||||
for _ in range(3):
|
||||
nxt = _unwrap_liteparse_json_line_if_stored(t)
|
||||
if nxt == t:
|
||||
break
|
||||
t = nxt
|
||||
s = t.lstrip()
|
||||
if s.startswith("{"):
|
||||
m = _try_extract_liteparse_text_value_from_malformed_json(s)
|
||||
if m is not None:
|
||||
return m
|
||||
return t
|
||||
|
||||
|
||||
def clean_extracted_document_text(text: str) -> str:
|
||||
"""
|
||||
Return only the document body: strip LiteParse JSON wrappers, then drop any
|
||||
leading payload before the "text" value (handles truncated/invalid JSON).
|
||||
Multiple passes in case the inner body is again JSON-shaped.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
t = text
|
||||
for _ in range(4):
|
||||
nxt = _clean_extracted_one_pass(t)
|
||||
if nxt == t:
|
||||
return t
|
||||
t = nxt
|
||||
return t
|
||||
|
||||
|
||||
class DocumentsLoader:
|
||||
DECOMPOSE_TIMEOUT_SECONDS = 600
|
||||
|
||||
|
|
@ -107,6 +232,7 @@ class DocumentsLoader:
|
|||
else:
|
||||
document = await asyncio.to_thread(self._parse_with_liteparse, file_path)
|
||||
|
||||
document = clean_extracted_document_text(document)
|
||||
documents.append(document)
|
||||
images.append(imgs)
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,11 @@ class LiteParseService:
|
|||
|
||||
return True, "ok"
|
||||
|
||||
@staticmethod
|
||||
def _use_json_runner_output() -> bool:
|
||||
"""If true, expect one JSON line on stdout (legacy). Default is plain UTF-8 text (better for large PDFs)."""
|
||||
return (os.getenv("LITEPARSE_RUNNER_OUTPUT") or "").strip().lower() == "json"
|
||||
|
||||
def parse_to_markdown(
|
||||
self,
|
||||
file_path: str,
|
||||
|
|
@ -271,6 +276,9 @@ class LiteParseService:
|
|||
if tessdata:
|
||||
command.extend(["--tessdata-path", tessdata])
|
||||
|
||||
use_json = self._use_json_runner_output()
|
||||
command.extend(["--python-bridge", "json" if use_json else "plain"])
|
||||
|
||||
LOGGER.info(
|
||||
"[LiteParse] Parsing file=%s ocr_enabled=%s ocr_language=%s dpi=%s num_workers=%s",
|
||||
file_path,
|
||||
|
|
@ -294,6 +302,20 @@ class LiteParseService:
|
|||
_command_str(command),
|
||||
)
|
||||
|
||||
if not use_json:
|
||||
if process.returncode != 0:
|
||||
err = (process.stderr or "").strip() or "LiteParse failed"
|
||||
raise LiteParseError(
|
||||
f"{err}; returncode={process.returncode}; "
|
||||
f"stderr={_snippet(process.stderr)}; stdout={_snippet(process.stdout)}"
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"text": (process.stdout or "").lstrip("\ufeff"),
|
||||
"filePath": file_path,
|
||||
"pageCount": 0,
|
||||
}
|
||||
|
||||
payload: Dict[str, Any]
|
||||
try:
|
||||
payload = self._decode_runner_output(process.stdout)
|
||||
|
|
|
|||
131
servers/fastapi/services/mem0_oss_memory.py
Normal file
131
servers/fastapi/services/mem0_oss_memory.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
"""Single shared mem0 OSS ``Memory`` client for the process.
|
||||
|
||||
All callers (presentation context, chat turns) use the same on-disk Qdrant/SQLite
|
||||
and distinguish data via mem0 ``user_id``:
|
||||
|
||||
- Deck-level (no chat thread): ``{namespace}:{presentation_id}``
|
||||
- Chat thread: ``{namespace}:{presentation_id}:conversation:{conversation_id}``
|
||||
|
||||
The chat flow calls ``ensure_conversation_id`` before the first turn, so a
|
||||
``conversation_id`` exists before any mem0 write for that thread.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from importlib import import_module
|
||||
from typing import Any, Optional
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_memory_init_lock = threading.Lock()
|
||||
_shared_client: Any | None = None
|
||||
_init_attempted = False
|
||||
|
||||
|
||||
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _to_int(value: Optional[str], default: int) -> int:
|
||||
try:
|
||||
parsed = int(value) if value is not None else default
|
||||
return max(1, parsed)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _oss_config_from_env() -> tuple[str, str, str, str, int, dict[str, Any]]:
|
||||
"""Return (mem0_dir, qdrant_path, history_db, collection, dims, from_config_dict)."""
|
||||
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
|
||||
mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
|
||||
qdrant_path = (
|
||||
os.getenv("MEM0_QDRANT_PATH") or os.path.join(mem0_dir, "qdrant")
|
||||
).strip()
|
||||
history_db_path = (
|
||||
os.getenv("MEM0_HISTORY_DB_PATH") or os.path.join(mem0_dir, "history.db")
|
||||
).strip()
|
||||
collection = (
|
||||
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
|
||||
).strip() or "presenton_memories"
|
||||
embedder = (os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed").strip() or "fastembed"
|
||||
model = (
|
||||
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
|
||||
).strip() or "BAAI/bge-small-en-v1.5"
|
||||
dims = _to_int(os.getenv("MEM0_EMBEDDING_DIMS"), default=384)
|
||||
config: dict[str, Any] = {
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": collection,
|
||||
"path": qdrant_path,
|
||||
"on_disk": True,
|
||||
"embedding_model_dims": dims,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": embedder,
|
||||
"config": {
|
||||
"model": model,
|
||||
"embedding_dims": dims,
|
||||
},
|
||||
},
|
||||
"history_db_path": history_db_path,
|
||||
}
|
||||
return mem0_dir, qdrant_path, history_db_path, collection, dims, config
|
||||
|
||||
|
||||
def memory_from_config(config: dict[str, Any], *, telemetry_base: str) -> Any:
|
||||
"""Construct ``mem0.Memory``. Caller must hold ``_memory_init_lock`` if used with shared state."""
|
||||
os.makedirs(telemetry_base, exist_ok=True)
|
||||
import mem0.memory.main as mem0_main # type: ignore[import-untyped]
|
||||
|
||||
mem0_main.mem0_dir = telemetry_base
|
||||
memory_cls = getattr(import_module("mem0"), "Memory")
|
||||
return memory_cls.from_config(config)
|
||||
|
||||
|
||||
def get_shared_mem0_client() -> Any | None:
|
||||
"""Return the process-wide mem0 client, or ``None`` if disabled or init failed."""
|
||||
global _shared_client, _init_attempted
|
||||
|
||||
if not _to_bool(os.getenv("MEM0_ENABLED"), default=True):
|
||||
return None
|
||||
if _shared_client is not None:
|
||||
return _shared_client
|
||||
if _init_attempted:
|
||||
return None
|
||||
|
||||
with _memory_init_lock:
|
||||
if _shared_client is not None:
|
||||
return _shared_client
|
||||
if _init_attempted:
|
||||
return None
|
||||
_init_attempted = True
|
||||
try:
|
||||
mem0_dir, qdrant_path, history_db, collection, dims, config = (
|
||||
_oss_config_from_env()
|
||||
)
|
||||
os.makedirs(mem0_dir, exist_ok=True)
|
||||
os.makedirs(qdrant_path, exist_ok=True)
|
||||
telemetry_base = os.path.join(mem0_dir, "telemetry", "oss")
|
||||
_shared_client = memory_from_config(
|
||||
config,
|
||||
telemetry_base=telemetry_base,
|
||||
)
|
||||
LOGGER.info(
|
||||
"Mem0 OSS shared memory initialized (qdrant_path=%s, history_db_path=%s, collection=%s, dims=%s)",
|
||||
qdrant_path,
|
||||
history_db,
|
||||
collection,
|
||||
dims,
|
||||
)
|
||||
except BaseException:
|
||||
LOGGER.exception("Failed to initialize shared Mem0 OSS Memory")
|
||||
_shared_client = None
|
||||
|
||||
return _shared_client
|
||||
|
|
@ -2,10 +2,11 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from importlib import import_module
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from services.mem0_oss_memory import get_shared_mem0_client
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -21,31 +22,6 @@ class Mem0PresentationMemoryService:
|
|||
os.getenv("MEM0_PRESENTATION_NAMESPACE_PREFIX") or "presentation"
|
||||
).strip() or "presentation"
|
||||
|
||||
self._embedder_provider = (
|
||||
os.getenv("MEM0_EMBEDDER_PROVIDER") or "fastembed"
|
||||
).strip() or "fastembed"
|
||||
self._embedder_model = (
|
||||
os.getenv("MEM0_EMBEDDER_MODEL") or "BAAI/bge-small-en-v1.5"
|
||||
).strip() or "BAAI/bge-small-en-v1.5"
|
||||
self._embedding_dims = self._to_int(
|
||||
os.getenv("MEM0_EMBEDDING_DIMS"),
|
||||
default=384,
|
||||
)
|
||||
|
||||
app_data_dir = (os.getenv("APP_DATA_DIRECTORY") or "/tmp/presenton").strip()
|
||||
self._mem0_dir = (os.getenv("MEM0_DIR") or os.path.join(app_data_dir, "mem0")).strip()
|
||||
self._qdrant_path = (os.getenv("MEM0_QDRANT_PATH") or os.path.join(self._mem0_dir, "qdrant")).strip()
|
||||
self._history_db_path = (
|
||||
os.getenv("MEM0_HISTORY_DB_PATH")
|
||||
or os.path.join(self._mem0_dir, "history.db")
|
||||
).strip()
|
||||
self._collection_name = (
|
||||
os.getenv("MEM0_COLLECTION_NAME") or "presenton_memories"
|
||||
).strip() or "presenton_memories"
|
||||
|
||||
self._client: Any = None
|
||||
self._attempted_client_init = False
|
||||
|
||||
@staticmethod
|
||||
def _to_bool(value: Optional[str], default: bool = False) -> bool:
|
||||
if value is None:
|
||||
|
|
@ -68,27 +44,6 @@ class Mem0PresentationMemoryService:
|
|||
return text
|
||||
return f"{text[:limit]}\n\n[TRUNCATED]"
|
||||
|
||||
def _get_oss_config(self) -> dict:
|
||||
return {
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": self._collection_name,
|
||||
"path": self._qdrant_path,
|
||||
"on_disk": True,
|
||||
"embedding_model_dims": self._embedding_dims,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": self._embedder_provider,
|
||||
"config": {
|
||||
"model": self._embedder_model,
|
||||
"embedding_dims": self._embedding_dims,
|
||||
},
|
||||
},
|
||||
"history_db_path": self._history_db_path,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_nonfatal_mem0_error(exc: BaseException) -> bool:
|
||||
return isinstance(exc, (Exception, SystemExit))
|
||||
|
|
@ -96,42 +51,7 @@ class Mem0PresentationMemoryService:
|
|||
async def _get_client(self):
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
if self._attempted_client_init:
|
||||
return None
|
||||
|
||||
self._attempted_client_init = True
|
||||
|
||||
try:
|
||||
module = import_module("mem0")
|
||||
memory_cls = getattr(module, "Memory")
|
||||
|
||||
os.makedirs(self._mem0_dir, exist_ok=True)
|
||||
os.makedirs(self._qdrant_path, exist_ok=True)
|
||||
|
||||
config = self._get_oss_config()
|
||||
|
||||
try:
|
||||
self._client = memory_cls.from_config(config)
|
||||
except Exception:
|
||||
# Backward compatibility across mem0 OSS versions.
|
||||
self._client = memory_cls(config)
|
||||
|
||||
LOGGER.info(
|
||||
"Mem0 OSS presentation memory service initialized (qdrant_path=%s, history_db_path=%s)",
|
||||
self._qdrant_path,
|
||||
self._history_db_path,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if not self._is_nonfatal_mem0_error(exc):
|
||||
raise
|
||||
LOGGER.exception("Failed to initialize Mem0 OSS Memory")
|
||||
self._client = None
|
||||
|
||||
return self._client
|
||||
return get_shared_mem0_client()
|
||||
|
||||
async def _add_message(self, presentation_id: UUID, message: str):
|
||||
client = await self._get_client()
|
||||
|
|
|
|||
137
servers/fastapi/tests/test_chat_conversation_store.py
Normal file
137
servers/fastapi/tests/test_chat_conversation_store.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from services.chat.conversation_store import ChatConversationStore
|
||||
|
||||
|
||||
class TestChatConversationStore:
|
||||
def test_load_history_reads_sql_first(self):
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
expected_history = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
]
|
||||
sql_session = MagicMock()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.sql_chat_history.load_messages",
|
||||
new=AsyncMock(return_value=expected_history),
|
||||
) as load_sql, patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
|
||||
new=AsyncMock(),
|
||||
) as load_mem0:
|
||||
store = ChatConversationStore(sql_session)
|
||||
history = asyncio.run(
|
||||
store.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
load_sql.assert_awaited_once_with(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
load_mem0.assert_not_called()
|
||||
assert history == expected_history
|
||||
|
||||
def test_load_history_falls_back_to_mem0_and_backfills_sql(self):
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
legacy = [
|
||||
{"role": "user", "content": "old"},
|
||||
{"role": "assistant", "content": "from mem0"},
|
||||
]
|
||||
sql_session = MagicMock()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.sql_chat_history.load_messages",
|
||||
new=AsyncMock(return_value=[]),
|
||||
) as load_sql, patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.load_history",
|
||||
new=AsyncMock(return_value=legacy),
|
||||
) as load_mem0, patch(
|
||||
"services.chat.conversation_store.sql_chat_history.replace_messages",
|
||||
new=AsyncMock(),
|
||||
) as replace_messages:
|
||||
store = ChatConversationStore(sql_session)
|
||||
history = asyncio.run(
|
||||
store.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
load_sql.assert_awaited_once()
|
||||
load_mem0.assert_awaited_once()
|
||||
replace_messages.assert_awaited_once_with(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
messages=legacy,
|
||||
)
|
||||
assert history == legacy
|
||||
|
||||
def test_append_turn_persists_sql_and_mem0(self):
|
||||
sql_session = MagicMock()
|
||||
store = ChatConversationStore(sql_session)
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.sql_chat_history.append_turn",
|
||||
new=AsyncMock(),
|
||||
) as append_sql, patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.store_chat_turn",
|
||||
new=AsyncMock(),
|
||||
) as store_mem0:
|
||||
asyncio.run(
|
||||
store.append_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you improve slide 2?",
|
||||
assistant_message="Yes, I will tighten the bullet points.",
|
||||
)
|
||||
)
|
||||
|
||||
append_sql.assert_awaited_once_with(
|
||||
sql_session,
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you improve slide 2?",
|
||||
assistant_message="Yes, I will tighten the bullet points.",
|
||||
tool_calls=None,
|
||||
)
|
||||
store_mem0.assert_awaited_once_with(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you improve slide 2?",
|
||||
assistant_message="Yes, I will tighten the bullet points.",
|
||||
)
|
||||
|
||||
def test_retrieve_semantic_context_delegates_to_chat_memory_store(self):
|
||||
store = ChatConversationStore(MagicMock())
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
"services.chat.conversation_store.CHAT_MEMORY_STORE.retrieve_context",
|
||||
new=AsyncMock(return_value="conversation-scoped context"),
|
||||
) as retrieve_context:
|
||||
context = asyncio.run(
|
||||
store.retrieve_semantic_context(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What did we decide?",
|
||||
)
|
||||
)
|
||||
|
||||
retrieve_context.assert_awaited_once_with(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What did we decide?",
|
||||
)
|
||||
assert context == "conversation-scoped context"
|
||||
249
servers/fastapi/tests/test_chat_memory_store.py
Normal file
249
servers/fastapi/tests/test_chat_memory_store.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import services.mem0_oss_memory as mem0_oss
|
||||
from services.chat.chat_memory_store import ChatMemoryStore
|
||||
|
||||
|
||||
class FakeMemoryClient:
|
||||
instances: list["FakeMemoryClient"] = []
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config
|
||||
self.add_calls = []
|
||||
self.search_calls = []
|
||||
self.get_all_calls = []
|
||||
self.next_search_response = {"results": []}
|
||||
self.next_get_all_response = {"results": []}
|
||||
FakeMemoryClient.instances.append(self)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(config=config)
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
messages = kwargs.get("messages") if "messages" in kwargs else None
|
||||
if messages is None and args:
|
||||
messages = args[0]
|
||||
|
||||
self.add_calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"infer": kwargs.get("infer"),
|
||||
}
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
def search(self, query, *args, **kwargs):
|
||||
self.search_calls.append(
|
||||
{
|
||||
"query": query,
|
||||
"filters": kwargs.get("filters"),
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"top_k": kwargs.get("top_k"),
|
||||
}
|
||||
)
|
||||
return self.next_search_response
|
||||
|
||||
def get_all(self, *args, **kwargs):
|
||||
self.get_all_calls.append(
|
||||
{
|
||||
"filters": kwargs.get("filters"),
|
||||
"user_id": kwargs.get("user_id"),
|
||||
"limit": kwargs.get("limit"),
|
||||
}
|
||||
)
|
||||
return self.next_get_all_response
|
||||
|
||||
|
||||
def _mem0_oss_fresh() -> None:
|
||||
mem0_oss._shared_client = None # type: ignore[attr-defined]
|
||||
mem0_oss._init_attempted = False # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestChatMemoryStore:
|
||||
def setup_method(self):
|
||||
FakeMemoryClient.instances = []
|
||||
_mem0_oss_fresh()
|
||||
|
||||
def test_store_chat_turn_uses_conversation_scoped_user_id(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"MEM0_ENABLED": "true",
|
||||
"MEM0_TOP_K": "4",
|
||||
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
|
||||
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
|
||||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.chat.chat_memory_store.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {"provider": "fastembed", "config": {}},
|
||||
}
|
||||
),
|
||||
):
|
||||
store = ChatMemoryStore()
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
|
||||
asyncio.run(
|
||||
store.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Can you tighten slide 3?",
|
||||
assistant_message="Yes, I can make it shorter.",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(FakeMemoryClient.instances) == 1
|
||||
client = FakeMemoryClient.instances[0]
|
||||
assert len(client.add_calls) == 1
|
||||
expected_user_id = (
|
||||
f"presentation:{presentation_id}:conversation:{conversation_id}"
|
||||
)
|
||||
assert client.add_calls[0]["user_id"] == expected_user_id
|
||||
assert client.add_calls[0]["infer"] is False
|
||||
payload = str(client.add_calls[0]["messages"][0]["content"])
|
||||
assert "[chat_turn]" in payload
|
||||
assert "user=Can you tighten slide 3?" in payload
|
||||
assert "assistant=Yes, I can make it shorter." in payload
|
||||
|
||||
def test_retrieve_context_reads_only_conversation_scoped_user_id(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"MEM0_ENABLED": "true",
|
||||
"MEM0_TOP_K": "6",
|
||||
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
|
||||
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
|
||||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.chat.chat_memory_store.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {"provider": "fastembed", "config": {}},
|
||||
}
|
||||
),
|
||||
):
|
||||
store = ChatMemoryStore()
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
expected_user_id = (
|
||||
f"presentation:{presentation_id}:conversation:{conversation_id}"
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
store.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="First turn",
|
||||
assistant_message="First answer",
|
||||
)
|
||||
)
|
||||
|
||||
client = FakeMemoryClient.instances[0]
|
||||
client.next_search_response = {
|
||||
"results": [
|
||||
{"memory": "Chat memory A"},
|
||||
{"memory": "Chat memory A"},
|
||||
{"memory": "Chat memory B"},
|
||||
]
|
||||
}
|
||||
|
||||
context = asyncio.run(
|
||||
store.retrieve_context(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
query="What did we decide?",
|
||||
)
|
||||
)
|
||||
|
||||
assert "Chat memory A" in context
|
||||
assert "Chat memory B" in context
|
||||
assert context.count("Chat memory A") == 1
|
||||
|
||||
assert len(client.search_calls) == 1
|
||||
assert client.search_calls[0]["query"] == "What did we decide?"
|
||||
assert client.search_calls[0]["filters"] == {"user_id": expected_user_id}
|
||||
assert client.search_calls[0]["top_k"] == 6
|
||||
|
||||
def test_load_history_reads_conversation_scoped_turns(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"MEM0_ENABLED": "true",
|
||||
"MEM0_PRESENTATION_NAMESPACE_PREFIX": "presentation",
|
||||
"APP_DATA_DIRECTORY": "/tmp/presenton-test",
|
||||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.chat.chat_memory_store.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {"provider": "fastembed", "config": {}},
|
||||
}
|
||||
),
|
||||
):
|
||||
store = ChatMemoryStore()
|
||||
presentation_id = uuid.uuid4()
|
||||
conversation_id = uuid.uuid4()
|
||||
expected_user_id = (
|
||||
f"presentation:{presentation_id}:conversation:{conversation_id}"
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
store.store_chat_turn(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
user_message="Draft intro",
|
||||
assistant_message="Updated intro done.",
|
||||
)
|
||||
)
|
||||
|
||||
client = FakeMemoryClient.instances[0]
|
||||
client.next_get_all_response = {
|
||||
"results": [
|
||||
{
|
||||
"memory": (
|
||||
"[chat_turn]\n"
|
||||
"turn_created_at=2026-04-25T10:00:00+00:00\n"
|
||||
"user=Draft intro\nassistant=Updated intro done."
|
||||
),
|
||||
"created_at": "2026-04-25T10:00:01+00:00",
|
||||
},
|
||||
{
|
||||
"memory": (
|
||||
"[chat_turn]\n"
|
||||
"turn_created_at=2026-04-25T10:01:00+00:00\n"
|
||||
"user=Add roadmap\nassistant=Roadmap slide added."
|
||||
),
|
||||
"created_at": "2026-04-25T10:01:01+00:00",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
history = asyncio.run(
|
||||
store.load_history(
|
||||
presentation_id=presentation_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert history == [
|
||||
{"role": "user", "content": "Draft intro"},
|
||||
{"role": "assistant", "content": "Updated intro done."},
|
||||
{"role": "user", "content": "Add roadmap"},
|
||||
{"role": "assistant", "content": "Roadmap slide added."},
|
||||
]
|
||||
|
||||
assert len(client.get_all_calls) == 1
|
||||
assert client.get_all_calls[0]["filters"] == {"user_id": expected_user_id}
|
||||
assert client.get_all_calls[0]["limit"] >= 10
|
||||
60
servers/fastapi/tests/test_documents_loader_unwrap.py
Normal file
60
servers/fastapi/tests/test_documents_loader_unwrap.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
|
||||
from services.documents_loader import (
|
||||
_unwrap_liteparse_json_line_if_stored,
|
||||
clean_extracted_document_text,
|
||||
)
|
||||
|
||||
|
||||
def test_unwrap_strips_liteparse_json_line():
|
||||
inner = "Title\n\nBody with \"quotes\" and\nnewlines."
|
||||
line = json.dumps(
|
||||
{"ok": True, "filePath": "/tmp/x.pdf", "text": inner},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
assert _unwrap_liteparse_json_line_if_stored(line) == inner
|
||||
assert _unwrap_liteparse_json_line_if_stored(" \n" + line) == inner
|
||||
|
||||
|
||||
def test_unwrap_leaves_plain_text():
|
||||
t = "Not JSON. {Braces} in prose."
|
||||
assert _unwrap_liteparse_json_line_if_stored(t) is t
|
||||
|
||||
|
||||
def test_unwrap_rejects_malformed_json():
|
||||
t = "{not valid json"
|
||||
assert _unwrap_liteparse_json_line_if_stored(t) is t
|
||||
|
||||
|
||||
def test_clean_extracts_text_when_json_truncated():
|
||||
"""Drops everything before the "text" value and unescapes, even if JSON is not closed."""
|
||||
blob = (
|
||||
'{"ok": true, "filePath": "/tmp/x.pdf", "text": " similarweb | HypeAuditor\\n\\n2024" '
|
||||
)
|
||||
# Missing closing " } — json.loads will fail, fallback path should still return body
|
||||
out = clean_extracted_document_text(blob)
|
||||
assert "similarweb" in out
|
||||
assert "ok" not in out
|
||||
assert "filePath" not in out
|
||||
|
||||
|
||||
def test_clean_same_as_unwrap_for_valid_line():
|
||||
inner = "Prose only."
|
||||
line = json.dumps(
|
||||
{"ok": True, "filePath": "/tmp/x.pdf", "text": inner},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
assert clean_extracted_document_text(line) == inner
|
||||
|
||||
|
||||
def test_clean_double_json_embedded_in_text_field():
|
||||
inner2 = "Final body."
|
||||
inner1 = json.dumps(
|
||||
{"ok": True, "filePath": "/a.pdf", "text": inner2},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
outer = json.dumps(
|
||||
{"ok": True, "filePath": "/b.pdf", "text": inner1},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
assert clean_extracted_document_text(outer) == inner2
|
||||
|
|
@ -4,8 +4,12 @@ from unittest.mock import patch
|
|||
from services.liteparse_service import LiteParseService
|
||||
|
||||
|
||||
def _ok_process(stdout: str = '{"ok": true, "text": "ok"}'):
|
||||
return SimpleNamespace(returncode=0, stdout=stdout, stderr="")
|
||||
def _ok_process(
|
||||
stdout: str = "ok",
|
||||
returncode: int = 0,
|
||||
stderr: str = "",
|
||||
):
|
||||
return SimpleNamespace(returncode=returncode, stdout=stdout, stderr=stderr)
|
||||
|
||||
|
||||
class TestLiteParseService:
|
||||
|
|
@ -26,13 +30,16 @@ class TestLiteParseService:
|
|||
return_value=_ok_process(),
|
||||
) as mock_run:
|
||||
service = LiteParseService(timeout_seconds=30)
|
||||
service.parse("/tmp/sample.pdf", ocr_enabled=True, ocr_language="eng")
|
||||
r = service.parse("/tmp/sample.pdf", ocr_enabled=True, ocr_language="eng")
|
||||
assert r["ok"] is True
|
||||
assert r["text"] == "ok"
|
||||
|
||||
command = mock_run.call_args.args[0]
|
||||
assert "--dpi" in command
|
||||
assert command[command.index("--dpi") + 1] == "120"
|
||||
assert "--num-workers" in command
|
||||
assert command[command.index("--num-workers") + 1] == "1"
|
||||
assert command[command.index("--python-bridge") + 1] == "plain"
|
||||
|
||||
def test_parse_uses_env_overrides(self):
|
||||
with patch.dict(
|
||||
|
|
@ -79,3 +86,23 @@ class TestLiteParseService:
|
|||
command = mock_run.call_args.args[0]
|
||||
assert command[command.index("--dpi") + 1] == "72"
|
||||
assert command[command.index("--num-workers") + 1] == "1"
|
||||
|
||||
def test_parse_json_bridge_env(self):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"LITEPARSE_RUNNER_OUTPUT": "json"},
|
||||
clear=False,
|
||||
), patch.object(
|
||||
LiteParseService,
|
||||
"check_runtime_ready",
|
||||
return_value=(True, "ok"),
|
||||
), patch(
|
||||
"services.liteparse_service.subprocess.run",
|
||||
return_value=_ok_process(stdout='{"ok": true, "text": "legacy"}\n'),
|
||||
) as mock_run:
|
||||
service = LiteParseService(timeout_seconds=30)
|
||||
r = service.parse("/tmp/sample.pdf", ocr_enabled=True, ocr_language="eng")
|
||||
assert r["text"] == "legacy"
|
||||
|
||||
command = mock_run.call_args.args[0]
|
||||
assert command[command.index("--python-bridge") + 1] == "json"
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@ import asyncio
|
|||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import services.mem0_oss_memory as mem0_oss
|
||||
from services.mem0_presentation_memory_service import Mem0PresentationMemoryService
|
||||
|
||||
|
||||
class FakeMemoryClient:
|
||||
instances = []
|
||||
instances: list["FakeMemoryClient"] = []
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.config = config
|
||||
|
|
@ -45,13 +46,15 @@ class FakeMemoryClient:
|
|||
return self.next_search_response
|
||||
|
||||
|
||||
class FakeMem0Module:
|
||||
Memory = FakeMemoryClient
|
||||
def _mem0_oss_fresh() -> None:
|
||||
mem0_oss._shared_client = None # type: ignore[attr-defined]
|
||||
mem0_oss._init_attempted = False # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestMem0PresentationMemoryService:
|
||||
def setup_method(self):
|
||||
FakeMemoryClient.instances = []
|
||||
_mem0_oss_fresh()
|
||||
|
||||
def test_store_generation_context_uses_presentation_scope(self):
|
||||
with patch.dict(
|
||||
|
|
@ -62,8 +65,25 @@ class TestMem0PresentationMemoryService:
|
|||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.mem0_presentation_memory_service.import_module",
|
||||
return_value=FakeMem0Module,
|
||||
"services.mem0_presentation_memory_service.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"on_disk": True,
|
||||
"embedding_model_dims": 384,
|
||||
},
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "fastembed",
|
||||
"config": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"embedding_dims": 384,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
):
|
||||
service = Mem0PresentationMemoryService()
|
||||
presentation_id = uuid.uuid4()
|
||||
|
|
@ -115,8 +135,19 @@ class TestMem0PresentationMemoryService:
|
|||
},
|
||||
clear=False,
|
||||
), patch(
|
||||
"services.mem0_presentation_memory_service.import_module",
|
||||
return_value=FakeMem0Module,
|
||||
"services.mem0_presentation_memory_service.get_shared_mem0_client",
|
||||
return_value=FakeMemoryClient.from_config(
|
||||
{
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
"embedder": {
|
||||
"provider": "fastembed",
|
||||
"config": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"embedding_dims": 384,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
):
|
||||
service = Mem0PresentationMemoryService()
|
||||
presentation_id = uuid.uuid4()
|
||||
|
|
@ -154,3 +185,4 @@ class TestMem0PresentationMemoryService:
|
|||
"user_id": f"presentation:{presentation_id}"
|
||||
}
|
||||
assert client.search_calls[0]["top_k"] == 5
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ def get_llm_config() -> ClientConfig:
|
|||
raise HTTPException(status_code=400, detail="OpenAI API Key is not set")
|
||||
return OpenAIClientConfig(
|
||||
api_key=api_key,
|
||||
api_type=OpenAIApiType.RESPONSES,
|
||||
api_type=OpenAIApiType.COMPLETIONS,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
api_key = get_google_api_key_env()
|
||||
|
|
|
|||
|
|
@ -10,8 +10,42 @@ from llmai.shared import (
|
|||
ResponseFormat,
|
||||
normalize_content_parts,
|
||||
)
|
||||
from llmai.shared.tools import Tool # type: ignore[import-not-found]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.llm_config import get_extra_body
|
||||
from utils.llm_provider import get_llm_provider
|
||||
from utils.schema_utils import flatten_json_schema
|
||||
|
||||
|
||||
def _tools_for_google_gemini(tools: list[LLMTool]) -> list[LLMTool]:
|
||||
"""Gemini's Python SDK rejects ``$ref`` / ``$defs`` in function parameters; inline them."""
|
||||
converted: list[LLMTool] = []
|
||||
for tool in tools:
|
||||
if not isinstance(tool, Tool):
|
||||
converted.append(tool)
|
||||
continue
|
||||
schema_obj = tool.input_schema
|
||||
if isinstance(schema_obj, dict):
|
||||
raw = dict(schema_obj)
|
||||
elif isinstance(schema_obj, type) and issubclass(schema_obj, BaseModel):
|
||||
raw = schema_obj.model_json_schema()
|
||||
elif isinstance(schema_obj, BaseModel):
|
||||
raw = schema_obj.__class__.model_json_schema()
|
||||
else:
|
||||
converted.append(tool)
|
||||
continue
|
||||
flat = flatten_json_schema(raw)
|
||||
converted.append(
|
||||
Tool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
schema=flat,
|
||||
strict=tool.strict,
|
||||
)
|
||||
)
|
||||
return converted
|
||||
|
||||
|
||||
def get_generate_kwargs(
|
||||
|
|
@ -30,7 +64,10 @@ def get_generate_kwargs(
|
|||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if get_llm_provider() == LLMProvider.GOOGLE:
|
||||
kwargs["tools"] = _tools_for_google_gemini(tools)
|
||||
else:
|
||||
kwargs["tools"] = tools
|
||||
if response_format is not None:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
|
|
|
|||
13
servers/fastapi/uv.lock
generated
13
servers/fastapi/uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = "==3.11.*"
|
||||
|
||||
[[package]]
|
||||
|
|
@ -811,9 +811,7 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/fb/c6/dba32cab7e3a625b011aa5647486e2d28423a48845a2998c126dd69c85e1/greenlet-3.4.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:805bebb4945094acbab757d34d6e1098be6de8966009ab9ca54f06ff492def58", size = 285504, upload-time = "2026-04-08T15:52:14.071Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/f4/7cb5c2b1feb9a1f50e038be79980dfa969aa91979e5e3a18fdbcfad2c517/greenlet-3.4.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:439fc2f12b9b512d9dfa681c5afe5f6b3232c708d13e6f02c845e0d9f4c2d8c6", size = 605476, upload-time = "2026-04-08T16:24:37.064Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/af/b66ab0b2f9a4c5a867c136bf66d9599f34f21a1bcca26a2884a29c450bd9/greenlet-3.4.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a70ed1cb0295bee1df57b63bf7f46b4e56a5c93709eea769c1fec1bb23a95875", size = 618336, upload-time = "2026-04-08T16:30:56.59Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/31/56c43d2b5de476f77d36ceeec436328533bff960a4cba9a07616e93063ab/greenlet-3.4.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8c5696c42e6bb5cfb7c6ff4453789081c66b9b91f061e5e9367fa15792644e76", size = 625045, upload-time = "2026-04-08T16:40:37.111Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/5c/8c5633ece6ba611d64bf2770219a98dd439921d6424e4e8cf16b0ac74ea5/greenlet-3.4.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c660bce1940a1acae5f51f0a064f1bc785d07ea16efcb4bc708090afc4d69e83", size = 613515, upload-time = "2026-04-08T15:56:32.478Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/ca/704d4e2c90acb8bdf7ae593f5cbc95f58e82de95cc540fb75631c1054533/greenlet-3.4.0-cp311-cp311-manylinux_2_39_riscv64.whl", hash = "sha256:89995ce5ddcd2896d89615116dd39b9703bfa0c07b583b85b89bf1b5d6eddf81", size = 419745, upload-time = "2026-04-08T16:43:04.022Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/df/950d15bca0d90a0e7395eb777903060504cdb509b7b705631e8fb69ff415/greenlet-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ee407d4d1ca9dc632265aee1c8732c4a2d60adff848057cdebfe5fe94eb2c8a2", size = 1574623, upload-time = "2026-04-08T16:26:18.596Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/e7/0839afab829fcb7333c9ff6d80c040949510055d2d4d63251f0d1c7c804e/greenlet-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:956215d5e355fffa7c021d168728321fd4d31fd730ac609b1653b450f6a4bc71", size = 1639579, upload-time = "2026-04-08T15:57:29.231Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/2b/b4482401e9bcaf9f5c97f67ead38db89c19520ff6d0d6699979c6efcc200/greenlet-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:5cb614ace7c27571270354e9c9f696554d073f8aa9319079dcba466bbdead711", size = 238233, upload-time = "2026-04-08T17:02:54.286Z" },
|
||||
|
|
@ -1187,7 +1185,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "llmai"
|
||||
version = "0.1.9"
|
||||
version = "0.2.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anthropic" },
|
||||
|
|
@ -1195,9 +1193,9 @@ dependencies = [
|
|||
{ name = "google-genai" },
|
||||
{ name = "openai" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/dd/dc7cb70fb5f9b33abf457b2bded61f27189232e769badc065ca0e2d1cda2/llmai-0.1.9.tar.gz", hash = "sha256:00ee4b987dc07a65425a1296df937d7640541630fd347ca758ea1ed496880e67", size = 46798, upload-time = "2026-04-23T07:34:49.975Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1f/28/7dc14f9a417d933f8c799665a0a86ee31489dde40e9264ef6dc41b32759c/llmai-0.2.2.tar.gz", hash = "sha256:1f62d1e3d05fa5c43bbd948275398668d8f85c8d2fde252d34562332101dd7b3", size = 47863, upload-time = "2026-04-27T10:32:08.726Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl", hash = "sha256:dcd94502516586bbd6394fe2c9c610941ff4c19eae0f1316825435f35134cfb4", size = 58968, upload-time = "2026-04-23T07:34:48.375Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/a1/e34d8aaa015fd7a2cfefa391b6eb2995aee675d7f8d93f636f4da8b07c13/llmai-0.2.2-py3-none-any.whl", hash = "sha256:d65c016983036319df704927b5e7fece494efde25b4865caf9a95d555a25449c", size = 59874, upload-time = "2026-04-27T10:32:07.272Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1634,6 +1632,7 @@ dependencies = [
|
|||
{ name = "fastembed-vectorstore" },
|
||||
{ name = "fastmcp" },
|
||||
{ name = "google-genai" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "llmai" },
|
||||
{ name = "mem0ai", extra = ["nlp"] },
|
||||
{ name = "nltk" },
|
||||
|
|
@ -1655,7 +1654,7 @@ requires-dist = [
|
|||
{ name = "fastembed-vectorstore", specifier = ">=0.5.2" },
|
||||
{ name = "fastmcp", specifier = ">=2.11.0" },
|
||||
{ name = "google-genai", specifier = ">=1.28.0" },
|
||||
{ name = "llmai", specifier = "==0.1.9" },
|
||||
{ name = "llmai", url = "https://files.pythonhosted.org/packages/c6/86/5dcfd77b634947cd570680b13217b40bc72cd7d9e7f04cc1a52ff5f549a0/llmai-0.1.9-py3-none-any.whl" },
|
||||
{ name = "mem0ai", extras = ["nlp"], specifier = ">=0.1.115" },
|
||||
{ name = "nltk", specifier = ">=3.9.1" },
|
||||
{ name = "openai", specifier = ">=1.98.0" },
|
||||
|
|
|
|||
|
|
@ -45,11 +45,7 @@ const PresentationPage = ({ presentation_id }: { presentation_id: string }) => {
|
|||
}
|
||||
}, [presentationData]);
|
||||
|
||||
// Ensure /app_data and /static image paths resolve through the backend origin.
|
||||
useEffect(() => {
|
||||
const observer = setupImageUrlConverter();
|
||||
return () => observer?.disconnect();
|
||||
}, []);
|
||||
|
||||
|
||||
// Function to fetch the slides
|
||||
useEffect(() => {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -413,12 +413,18 @@ const PresentationHeader = ({
|
|||
|
||||
return (
|
||||
<>
|
||||
<div className="py-7 sticky top-0 bg-white z-50 mb-[17px] font-syne flex justify-between items-center gap-4">
|
||||
{presentationData && !isStreaming && !isEditingTitle ? (
|
||||
<ToolTip content="Rename presentation">{titleBlock}</ToolTip>
|
||||
) : (
|
||||
titleBlock
|
||||
)}
|
||||
<div className="py-[18px] px-4 sticky top-0 bg-white z-50 shadow-sm font-syne flex justify-between items-center gap-4">
|
||||
<div className="flex items-center gap-3">
|
||||
|
||||
<img onClick={() => {
|
||||
router.push("/dashboard");
|
||||
}} src="/logo-with-bg.png" alt="" className="w-10 h-10 cursor-pointer object-contain" />
|
||||
{presentationData && !isStreaming && !isEditingTitle ? (
|
||||
<ToolTip content="Rename presentation">{titleBlock}</ToolTip>
|
||||
) : (
|
||||
titleBlock
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2.5">
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"use client";
|
||||
import React, { useEffect, useLayoutEffect, useState } from "react";
|
||||
import React, { useEffect, useLayoutEffect, useRef, useState } from "react";
|
||||
import { useSelector } from "react-redux";
|
||||
import { RootState } from "@/store/store";
|
||||
import "../../utils/prism-languages";
|
||||
|
|
@ -21,8 +21,8 @@ import { PresentationPageProps } from "../types";
|
|||
import LoadingState from "./LoadingState";
|
||||
import { applyPresentationThemeToElement } from "../utils/applyPresentationThemeDom";
|
||||
|
||||
import { usePresentationUndoRedo } from "../hooks/PresentationUndoRedo";
|
||||
import PresentationHeader from "./PresentationHeader";
|
||||
import Chat from "./Chat";
|
||||
|
||||
const PresentationPage: React.FC<PresentationPageProps> = ({
|
||||
presentation_id,
|
||||
|
|
@ -33,6 +33,7 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
|
|||
const [selectedSlide, setSelectedSlide] = useState(0);
|
||||
const [isFullscreen, setIsFullscreen] = useState(false);
|
||||
const [error, setError] = useState(false);
|
||||
const slidesScrollContainerRef = useRef<HTMLDivElement | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
|
||||
|
|
@ -40,6 +41,11 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
|
|||
const { presentationData, isStreaming } = useSelector(
|
||||
(state: RootState) => state.presentationGeneration
|
||||
);
|
||||
const slidesLength = presentationData?.slides?.length ?? 0;
|
||||
const lastStreamingSlideIndex =
|
||||
slidesLength > 0
|
||||
? presentationData?.slides?.[slidesLength - 1]?.index
|
||||
: undefined;
|
||||
|
||||
// Auto-save functionality
|
||||
const { isSaving } = useAutoSave({
|
||||
|
|
@ -78,7 +84,38 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
|
|||
fetchUserSlides
|
||||
);
|
||||
|
||||
usePresentationUndoRedo();
|
||||
useEffect(() => {
|
||||
if (!isStreaming) return;
|
||||
|
||||
const scrollContainer = slidesScrollContainerRef.current;
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const frame = window.requestAnimationFrame(() => {
|
||||
if (slidesLength <= 1) {
|
||||
scrollContainer.scrollTo({ top: 0, behavior: "auto" });
|
||||
return;
|
||||
}
|
||||
|
||||
if (lastStreamingSlideIndex === undefined) return;
|
||||
|
||||
const slideElement = document.getElementById(
|
||||
`slide-${lastStreamingSlideIndex}`
|
||||
);
|
||||
if (!slideElement) return;
|
||||
|
||||
const containerRect = scrollContainer.getBoundingClientRect();
|
||||
const slideRect = slideElement.getBoundingClientRect();
|
||||
const slideTop =
|
||||
slideRect.top - containerRect.top + scrollContainer.scrollTop;
|
||||
|
||||
scrollContainer.scrollTo({
|
||||
top: Math.max(slideTop, 0),
|
||||
behavior: "smooth",
|
||||
});
|
||||
});
|
||||
|
||||
return () => window.cancelAnimationFrame(frame);
|
||||
}, [isStreaming, lastStreamingSlideIndex, slidesLength]);
|
||||
|
||||
useEffect(() => {
|
||||
trackEvent(MixpanelEvent.Presentation_Editor_Viewed, {
|
||||
|
|
@ -141,66 +178,70 @@ const PresentationPage: React.FC<PresentationPageProps> = ({
|
|||
}
|
||||
|
||||
return (
|
||||
<div className="h-screen overflow-hidden font-syne ">
|
||||
<div className="h-screen overflow-hidden font-syne">
|
||||
<div
|
||||
style={{
|
||||
background: "#ffffff",
|
||||
background: "#EDEEEF",
|
||||
}}
|
||||
id="presentation-slides-wrapper"
|
||||
className="flex gap-6 relative "
|
||||
className="relative flex h-full flex-col overflow-hidden"
|
||||
>
|
||||
<div className="w-[200px]">
|
||||
<SidePanel
|
||||
selectedSlide={selectedSlide}
|
||||
onSlideClick={handleSlideClick}
|
||||
presentationId={presentation_id}
|
||||
loading={loading}
|
||||
/>
|
||||
</div>
|
||||
<div className=" w-full h-[calc(100vh-20px)] pr-[25px] overflow-y-auto">
|
||||
<PresentationHeader presentation_id={presentation_id} isPresentationSaving={isSaving} currentSlide={selectedSlide} />
|
||||
<div
|
||||
|
||||
style={{
|
||||
background: "rgba(255, 255, 255, 0.10)",
|
||||
boxShadow: "0 0 20.01px 0 rgba(122, 90, 248, 0.16) inset",
|
||||
}}
|
||||
className="p-6 rounded-[20px] font-inter flex flex-col items-center overflow-hidden justify-center border border-[#EDECEC] "
|
||||
>
|
||||
<div className="w-full max-w-[1280px] h-full">
|
||||
|
||||
{!presentationData ||
|
||||
loading ||
|
||||
!presentationData?.slides ||
|
||||
presentationData?.slides.length === 0 ? (
|
||||
<div className="relative w-full h-[calc(100vh-120px)] mx-auto">
|
||||
<div className="">
|
||||
{Array.from({ length: 2 }).map((_, index) => (
|
||||
<Skeleton
|
||||
key={index}
|
||||
className="aspect-video bg-gray-400 my-4 w-full mx-auto "
|
||||
/>
|
||||
))}
|
||||
<PresentationHeader presentation_id={presentation_id} isPresentationSaving={isSaving} currentSlide={selectedSlide} />
|
||||
<div className="flex flex-1 min-h-0 gap-6 overflow-hidden">
|
||||
<div className="w-[120px] h-full shrink-0 self-start sticky top-0 pt-[18px]">
|
||||
<SidePanel
|
||||
selectedSlide={selectedSlide}
|
||||
onSlideClick={handleSlideClick}
|
||||
presentationId={presentation_id}
|
||||
loading={loading}
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full min-w-0 h-full flex-1 pt-[18px]">
|
||||
<div
|
||||
ref={slidesScrollContainerRef}
|
||||
className="font-inter h-full overflow-y-auto hide-scrollbar scroll-pt-[18px]"
|
||||
>
|
||||
<div className="w-full max-w-[1280px] min-h-full mx-auto flex flex-col items-center pb-8">
|
||||
{!presentationData ||
|
||||
loading ||
|
||||
!presentationData?.slides ||
|
||||
presentationData?.slides.length === 0 ? (
|
||||
<div className="relative w-full h-[calc(100vh-120px)] mx-auto hide-scrollbar">
|
||||
<div className="">
|
||||
{Array.from({ length: 2 }).map((_, index) => (
|
||||
<Skeleton
|
||||
key={index}
|
||||
className="aspect-video bg-gray-400 my-4 w-full mx-auto "
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
{stream && <LoadingState />}
|
||||
</div>
|
||||
{stream && <LoadingState />}
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{presentationData &&
|
||||
presentationData.slides &&
|
||||
presentationData.slides.length > 0 &&
|
||||
presentationData.slides.map((slide: any, index: number) => (
|
||||
<SlideContent
|
||||
key={`${slide.type}-${index}-${slide.index}`}
|
||||
slide={slide}
|
||||
index={index}
|
||||
presentationId={presentation_id}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
) : (
|
||||
<>
|
||||
{presentationData &&
|
||||
presentationData.slides &&
|
||||
presentationData.slides.length > 0 &&
|
||||
presentationData.slides.map((slide: any, index: number) => (
|
||||
<SlideContent
|
||||
key={`${slide.type}-${index}-${slide.index}`}
|
||||
slide={slide}
|
||||
index={index}
|
||||
presentationId={presentation_id}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full max-w-[370px] h-full shrink-0 self-start sticky top-0">
|
||||
<Chat
|
||||
presentationId={presentation_id}
|
||||
currentSlide={selectedSlide}
|
||||
onPresentationChanged={() => fetchUserSlides({ clearHistory: false })}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ import {
|
|||
} from "@dnd-kit/sortable";
|
||||
import { setPresentationData } from "@/store/slices/presentationGeneration";
|
||||
import { SortableSlide } from "./SortableSlide";
|
||||
import SlideScale from "../../components/PresentationRender";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { usePathname, useRouter } from "next/navigation";
|
||||
import { usePathname } from "next/navigation";
|
||||
import NewSlide from "./NewSlide";
|
||||
import { trackEvent, MixpanelEvent } from "@/utils/mixpanel";
|
||||
import { SlideThumbnailCard } from "./SlideThumbnailCard";
|
||||
|
||||
interface SidePanelProps {
|
||||
selectedSlide: number;
|
||||
|
|
@ -40,8 +40,6 @@ const SidePanel = ({
|
|||
|
||||
loading,
|
||||
}: SidePanelProps) => {
|
||||
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const [showNewSlideSelection, setShowNewSlideSelection] = useState(false);
|
||||
|
||||
|
|
@ -132,48 +130,36 @@ const SidePanel = ({
|
|||
}
|
||||
|
||||
return (
|
||||
<div className="bg-[#F6F6F9] pt-8 px-4 w-[200px]">
|
||||
<div className="px-4 w-[120px] h-full">
|
||||
|
||||
<img onClick={() => {
|
||||
router.push("/dashboard");
|
||||
}} src="/logo-with-bg.png" alt="" className="w-10 h-10 cursor-pointer object-contain" />
|
||||
|
||||
<Separator orientation="horizontal" className="my-6 " />
|
||||
<div
|
||||
className={`
|
||||
relative bg-[#F6F6F9] h-full z-50 xl:z-auto
|
||||
relative h-full z-50 xl:z-auto
|
||||
transition-all duration-300 ease-in-out
|
||||
`}
|
||||
>
|
||||
<div
|
||||
|
||||
className="w-full h-[calc(100vh-120px)] hide-scrollbar overflow-hidden slide-theme "
|
||||
className="w-full h-full hide-scrollbar overflow-hidden slide-theme flex flex-col"
|
||||
>
|
||||
|
||||
<p className="text-xl font-normal font-syne pb-3.5 text-[#000000]">Slides ({presentationData?.slides?.length})</p>
|
||||
|
||||
<DndContext
|
||||
sensors={sensors}
|
||||
collisionDetection={closestCenter}
|
||||
onDragEnd={handleDragEnd}
|
||||
>
|
||||
<div className=" overflow-y-auto w-full hide-scrollbar h-[calc(100%-140px)] space-y-3.5">
|
||||
<div className="overflow-y-auto w-full hide-scrollbar min-h-0 flex-1 space-y-3.5">
|
||||
{isStreaming ? (
|
||||
presentationData &&
|
||||
presentationData?.slides.map((slide: any, index: number) => (
|
||||
<div
|
||||
<SlideThumbnailCard
|
||||
key={`${slide.id}-${index}`}
|
||||
onClick={() => onSlideClick(index)}
|
||||
className={` cursor-pointer ring-2 rounded-[12px] transition-all duration-200 ${selectedSlide === index ? ' ring-[#5141e5]' : 'ring-gray-200'
|
||||
}`}
|
||||
>
|
||||
<div className=" bg-white pointer-events-none relative overflow-hidden aspect-video">
|
||||
<div className="absolute bg-gray-100/5 z-50 top-0 left-0 w-full h-full" />
|
||||
<div className="transform scale-[0.2] flex justify-center items-center origin-top-left w-[500%] h-[500%]">
|
||||
<SlideScale slide={slide} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
slide={slide}
|
||||
index={index}
|
||||
selected={selectedSlide === index}
|
||||
onClick={() => onSlideClick(slide.index ?? index)}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<SortableContext
|
||||
|
|
@ -203,7 +189,7 @@ const SidePanel = ({
|
|||
<button
|
||||
type="button"
|
||||
onClick={handleAddSlideClick}
|
||||
className="pt-6 gap-2 flex flex-col py-2 duration-300 items-center justify-center rounded-lg cursor-pointer mx-auto"
|
||||
className="py-4 gap-2 flex flex-col duration-300 items-center justify-center rounded-lg cursor-pointer mx-auto"
|
||||
>
|
||||
<Plus className="w-3.5 h-3.5" />
|
||||
<span className="text-[11px] font-normal text-[#000000]">Add Slide</span>
|
||||
|
|
|
|||
|
|
@ -103,30 +103,6 @@ const SlideContent = ({ slide, index, presentationId }: SlideContentProps) => {
|
|||
});
|
||||
}
|
||||
};
|
||||
// Scroll to the new slide when streaming and new slides are being generated
|
||||
useEffect(() => {
|
||||
if (
|
||||
presentationData &&
|
||||
presentationData?.slides &&
|
||||
presentationData.slides.length > 1 &&
|
||||
isStreaming
|
||||
) {
|
||||
// Scroll to the last slide (newly generated during streaming)
|
||||
const lastSlideIndex = presentationData.slides.length - 1;
|
||||
const slideElement = document.getElementById(
|
||||
`slide-${presentationData.slides[lastSlideIndex].index}`
|
||||
);
|
||||
if (slideElement) {
|
||||
slideElement.scrollIntoView({
|
||||
behavior: "smooth",
|
||||
block: "center",
|
||||
});
|
||||
}
|
||||
}
|
||||
}, [presentationData?.slides?.length, isStreaming]);
|
||||
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
if (slide.layout.includes("custom")) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
import React, { forwardRef } from "react";
|
||||
import type { Slide } from "../../types/slide";
|
||||
import { V1ContentRender } from "../../components/V1ContentRender";
|
||||
|
||||
interface SlideThumbnailCardProps extends React.HTMLAttributes<HTMLDivElement> {
|
||||
slide: Slide;
|
||||
index: number;
|
||||
selected: boolean;
|
||||
}
|
||||
|
||||
const SCALE = 0.061;
|
||||
|
||||
export const SlideThumbnailCard = forwardRef<
|
||||
HTMLDivElement,
|
||||
SlideThumbnailCardProps
|
||||
>(({ slide, index, selected, className = "", style, ...props }, ref) => {
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
style={{
|
||||
backgroundColor: "var(--card-color, #ffffff)",
|
||||
borderColor: selected ? "#5141e5" : "var(--stroke, #e5e7eb)",
|
||||
...style,
|
||||
}}
|
||||
className={`cursor-pointer border relative p-1.5 rounded-[12px] overflow-hidden transition-all duration-200 ${
|
||||
selected ? "border-[#BDB4FE]" : "border-[#EDEEEF]"
|
||||
} ${className}`}
|
||||
{...props}
|
||||
>
|
||||
<p className="pointer-events-none absolute -left-1 top-1/2 z-50 flex h-[18px] min-w-[18px] -translate-y-1/2 items-center justify-center rounded-full border border-[#EDEEEF] bg-white px-1 text-[10px] font-medium text-[#191919] shadow-sm">
|
||||
{index + 1}
|
||||
</p>
|
||||
|
||||
<div
|
||||
className="relative"
|
||||
style={{ height: `${720 * SCALE}px`, overflow: "hidden" }}
|
||||
>
|
||||
<div
|
||||
className="absolute top-0 left-0 rounded-[10px] overflow-hidden pointer-events-none"
|
||||
style={{
|
||||
width: 1280,
|
||||
height: 720,
|
||||
transformOrigin: "top left",
|
||||
transform: `scale(${SCALE})`,
|
||||
}}
|
||||
>
|
||||
<V1ContentRender slide={slide} isEditMode={true} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
SlideThumbnailCard.displayName = "SlideThumbnailCard";
|
||||
|
|
@ -1,16 +1,14 @@
|
|||
import { useSortable } from '@dnd-kit/sortable';
|
||||
import { CSS } from '@dnd-kit/utilities';
|
||||
import { Slide } from '../../types/slide';
|
||||
import type { Slide } from '../../types/slide';
|
||||
import { useRef } from 'react';
|
||||
import { V1ContentRender } from '../../components/V1ContentRender';
|
||||
import { useSearchParams } from 'next/navigation';
|
||||
import { SlideThumbnailCard } from './SlideThumbnailCard';
|
||||
interface SortableSlideProps {
|
||||
slide: Slide;
|
||||
index: number;
|
||||
selectedSlide: number;
|
||||
onSlideClick: (index: any) => void;
|
||||
}
|
||||
const SCALE = 0.125;
|
||||
|
||||
export function SortableSlide({ slide, index, selectedSlide, onSlideClick }: SortableSlideProps) {
|
||||
const lastClickTime = useRef(0);
|
||||
|
|
@ -27,8 +25,6 @@ export function SortableSlide({ slide, index, selectedSlide, onSlideClick }: Sor
|
|||
transform: CSS.Transform.toString(transform),
|
||||
transition,
|
||||
opacity: isDragging ? 0.5 : 1,
|
||||
backgroundColor: `var(--card-color, #ffffff)`,
|
||||
borderColor: selectedSlide === index ? `#5141e5` : `var(--stroke, #e5e7eb)`
|
||||
};
|
||||
|
||||
const handleClick = (e: React.MouseEvent) => {
|
||||
|
|
@ -47,33 +43,15 @@ export function SortableSlide({ slide, index, selectedSlide, onSlideClick }: Sor
|
|||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
<SlideThumbnailCard
|
||||
ref={setNodeRef}
|
||||
slide={slide}
|
||||
index={index}
|
||||
selected={selectedSlide === index}
|
||||
style={style}
|
||||
{...attributes}
|
||||
{...listeners}
|
||||
onClick={handleClick}
|
||||
className={` cursor-pointer border relative p-1 rounded-[12px] transition-all duration-200 ${selectedSlide === index ? ' border-[#BDB4FE]' : 'border-[#EDEEEF]'
|
||||
}`}
|
||||
>
|
||||
|
||||
<div
|
||||
className="relative"
|
||||
style={{ height: `${720 * SCALE}px`, overflow: "hidden" }}
|
||||
>
|
||||
<div
|
||||
className="absolute top-0 left-0 pointer-events-none"
|
||||
style={{
|
||||
width: 1280,
|
||||
height: 720,
|
||||
transformOrigin: "top left",
|
||||
transform: `scale(${SCALE})`,
|
||||
}}
|
||||
>
|
||||
<V1ContentRender slide={slide} isEditMode={true} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,136 +1,100 @@
|
|||
import { useCallback } from "react";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { RootState } from "@/store/store";
|
||||
import { finishUndoRedo, redo, undo } from "@/store/slices/undoRedoSlice";
|
||||
import { redo, undo } from "@/store/slices/undoRedoSlice";
|
||||
import { useKeyboardShortcut } from "../../hooks/use-keyboard-shortcut";
|
||||
import { setPresentationData } from "@/store/slices/presentationGeneration";
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
export const usePresentationUndoRedo = () => {
|
||||
const dispatch = useDispatch();
|
||||
const undoRedoState = useSelector((state: RootState) => state.undoRedo);
|
||||
const { presentationData } = useSelector((state: RootState) => state.presentationGeneration);
|
||||
|
||||
const canUndo = undoRedoState.past.length > 0;
|
||||
const canRedo = undoRedoState.future.length > 0;
|
||||
|
||||
const onUndo = useCallback(() => {
|
||||
if (!canUndo) return;
|
||||
|
||||
const previousState = undoRedoState.past[undoRedoState.past.length - 1];
|
||||
|
||||
dispatch(undo());
|
||||
|
||||
if (previousState) {
|
||||
const newSlides = JSON.parse(JSON.stringify(previousState.slides));
|
||||
|
||||
dispatch(
|
||||
setPresentationData({
|
||||
...presentationData!,
|
||||
slides: newSlides,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
dispatch(finishUndoRedo());
|
||||
}, 100);
|
||||
}, [canUndo, dispatch, presentationData, undoRedoState.past]);
|
||||
|
||||
const onRedo = useCallback(() => {
|
||||
if (!canRedo) return;
|
||||
|
||||
const nextState = undoRedoState.future[0];
|
||||
|
||||
dispatch(redo());
|
||||
|
||||
if (nextState) {
|
||||
const newSlides = JSON.parse(JSON.stringify(nextState.slides));
|
||||
|
||||
dispatch(
|
||||
setPresentationData({
|
||||
...presentationData!,
|
||||
slides: newSlides,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
dispatch(finishUndoRedo());
|
||||
}, 100);
|
||||
}, [canRedo, dispatch, presentationData, undoRedoState.future]);
|
||||
|
||||
// Handle undo
|
||||
useKeyboardShortcut(
|
||||
["z"],
|
||||
(e) => {
|
||||
if (e.ctrlKey && !e.shiftKey && undoRedoState.past.length > 0) {
|
||||
e.preventDefault();
|
||||
|
||||
// Get the previous state before dispatching undo
|
||||
const previousState = undoRedoState.past[undoRedoState.past.length - 1];
|
||||
|
||||
// Perform undo
|
||||
dispatch(undo());
|
||||
|
||||
// Use the previousState directly instead of relying on the updated undoRedoState
|
||||
if (previousState) {
|
||||
// Create a deep copy to ensure no reference issues
|
||||
const newSlides = JSON.parse(JSON.stringify(previousState.slides));
|
||||
|
||||
// Update the presentation data with the properly structured slides
|
||||
dispatch(
|
||||
setPresentationData({
|
||||
...presentationData!,
|
||||
slides: newSlides,
|
||||
})
|
||||
);
|
||||
}
|
||||
// Reset the undo/redo flag
|
||||
setTimeout(() => {
|
||||
dispatch(finishUndoRedo());
|
||||
}, 100);
|
||||
}
|
||||
},
|
||||
[undoRedoState.past, presentationData]
|
||||
const dispatch = useDispatch();
|
||||
const undoRedoState = useSelector((state: RootState) => state.undoRedo);
|
||||
const { presentationData } = useSelector(
|
||||
(state: RootState) => state.presentationGeneration
|
||||
);
|
||||
// Handle redo
|
||||
|
||||
const canUndo = undoRedoState.past.length > 0;
|
||||
const canRedo = undoRedoState.future.length > 0;
|
||||
|
||||
const applySlidesSnapshot = useCallback(
|
||||
(slidesSnapshot: unknown) => {
|
||||
if (!presentationData || !Array.isArray(slidesSnapshot)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const clonedSlides = JSON.parse(JSON.stringify(slidesSnapshot));
|
||||
dispatch(
|
||||
setPresentationData({
|
||||
...presentationData,
|
||||
slides: clonedSlides,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, presentationData]
|
||||
);
|
||||
|
||||
const onUndo = useCallback(() => {
|
||||
if (!canUndo) {
|
||||
return;
|
||||
}
|
||||
|
||||
const previousState = undoRedoState.past[undoRedoState.past.length - 1];
|
||||
if (!previousState) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(undo());
|
||||
applySlidesSnapshot(previousState.slides);
|
||||
}, [applySlidesSnapshot, canUndo, dispatch, undoRedoState.past]);
|
||||
|
||||
const onRedo = useCallback(() => {
|
||||
if (!canRedo) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nextState = undoRedoState.future[0];
|
||||
if (!nextState) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(redo());
|
||||
applySlidesSnapshot(nextState.slides);
|
||||
}, [applySlidesSnapshot, canRedo, dispatch, undoRedoState.future]);
|
||||
|
||||
// Handle undo (Ctrl + Z)
|
||||
useKeyboardShortcut(
|
||||
["z"],
|
||||
(e) => {
|
||||
if (e.ctrlKey && e.shiftKey && undoRedoState.future.length > 0) {
|
||||
if (e.ctrlKey && !e.shiftKey && canUndo) {
|
||||
e.preventDefault();
|
||||
|
||||
// Get the next state before dispatching redo
|
||||
const nextState = undoRedoState.future[0];
|
||||
|
||||
// Perform redo
|
||||
dispatch(redo());
|
||||
|
||||
// Use the nextState directly instead of relying on the updated undoRedoState
|
||||
if (nextState) {
|
||||
// Create a deep copy to ensure no reference issues
|
||||
const newSlides = JSON.parse(JSON.stringify(nextState.slides));
|
||||
|
||||
// Update the presentation data with the properly structured slides
|
||||
dispatch(
|
||||
setPresentationData({
|
||||
...presentationData!,
|
||||
slides: newSlides,
|
||||
})
|
||||
);
|
||||
}
|
||||
// Reset the undo/redo flag
|
||||
setTimeout(() => {
|
||||
dispatch(finishUndoRedo());
|
||||
}, 100);
|
||||
onUndo();
|
||||
}
|
||||
},
|
||||
[undoRedoState.future, presentationData]
|
||||
[canUndo, onUndo]
|
||||
);
|
||||
|
||||
// Handle redo (Ctrl + Shift + Z)
|
||||
useKeyboardShortcut(
|
||||
["z"],
|
||||
(e) => {
|
||||
if (e.ctrlKey && e.shiftKey && canRedo) {
|
||||
e.preventDefault();
|
||||
onRedo();
|
||||
}
|
||||
},
|
||||
[canRedo, onRedo]
|
||||
);
|
||||
|
||||
// Handle redo (Ctrl + Y)
|
||||
useKeyboardShortcut(
|
||||
["y"],
|
||||
(e) => {
|
||||
if (e.ctrlKey && canRedo) {
|
||||
e.preventDefault();
|
||||
onRedo();
|
||||
}
|
||||
},
|
||||
[canRedo, onRedo]
|
||||
);
|
||||
|
||||
return { onUndo, onRedo, canUndo, canRedo };
|
||||
}
|
||||
};
|
||||
|
|
@ -17,14 +17,16 @@ export const usePresentationData = (
|
|||
) => {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const fetchUserSlides = useCallback(async () => {
|
||||
const fetchUserSlides = useCallback(async (options?: { clearHistory?: boolean }) => {
|
||||
try {
|
||||
const data = await DashboardApi.getPresentation(presentationId);
|
||||
|
||||
|
||||
if (data) {
|
||||
dispatch(setPresentationData(data));
|
||||
dispatch(clearHistory());
|
||||
if (options?.clearHistory ?? true) {
|
||||
dispatch(clearHistory());
|
||||
}
|
||||
setLoading(false);
|
||||
}
|
||||
if (data.fonts) {
|
||||
|
|
|
|||
300
servers/nextjs/app/(presentation-generator)/services/api/chat.ts
Normal file
300
servers/nextjs/app/(presentation-generator)/services/api/chat.ts
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
import { buildAbsoluteApiRequestUrl, getApiUrl } from "@/utils/api";
|
||||
import { ApiResponseHandler } from "./api-error-handler";
|
||||
import { getHeader } from "./header";
|
||||
|
||||
export interface ChatMessageRequest {
|
||||
presentation_id: string;
|
||||
message: string;
|
||||
conversation_id?: string;
|
||||
}
|
||||
|
||||
export interface ChatMessageResponse {
|
||||
conversation_id?: string;
|
||||
response: string;
|
||||
tool_calls?: string[];
|
||||
}
|
||||
|
||||
export interface ChatHistoryMessage {
|
||||
role: string;
|
||||
content: string;
|
||||
created_at?: string;
|
||||
}
|
||||
|
||||
export interface ChatHistoryData {
|
||||
presentation_id: string;
|
||||
conversation_id: string;
|
||||
messages: ChatHistoryMessage[];
|
||||
}
|
||||
|
||||
export interface ChatConversationSummary {
|
||||
conversation_id: string;
|
||||
updated_at?: string | null;
|
||||
last_message_preview?: string | null;
|
||||
}
|
||||
|
||||
export interface ChatStreamTrace {
|
||||
kind?: string;
|
||||
round?: number;
|
||||
tool?: string;
|
||||
status?: string;
|
||||
message?: string;
|
||||
tools?: string[];
|
||||
}
|
||||
|
||||
export interface ChatStreamHandlers {
|
||||
onChunk?: (chunk: string) => void;
|
||||
onStatus?: (status: string) => void;
|
||||
onTrace?: (trace: ChatStreamTrace) => void;
|
||||
onComplete?: (response: ChatMessageResponse) => void;
|
||||
}
|
||||
|
||||
interface ChatStreamDataChunk {
|
||||
type: "chunk";
|
||||
chunk?: unknown;
|
||||
}
|
||||
|
||||
interface ChatStreamDataComplete {
|
||||
type: "complete";
|
||||
chat?: unknown;
|
||||
}
|
||||
|
||||
interface ChatStreamDataError {
|
||||
type: "error";
|
||||
detail?: unknown;
|
||||
}
|
||||
|
||||
interface ChatStreamDataStatus {
|
||||
type: "status";
|
||||
status?: unknown;
|
||||
}
|
||||
|
||||
interface ChatStreamDataTrace {
|
||||
type: "trace";
|
||||
trace?: unknown;
|
||||
}
|
||||
|
||||
type ChatStreamData =
|
||||
| ChatStreamDataChunk
|
||||
| ChatStreamDataComplete
|
||||
| ChatStreamDataError
|
||||
| ChatStreamDataStatus
|
||||
| ChatStreamDataTrace
|
||||
| Record<string, unknown>;
|
||||
|
||||
export class PresentationChatApi {
|
||||
static async listConversations(
|
||||
presentationId: string
|
||||
): Promise<ChatConversationSummary[]> {
|
||||
const u = new URL(
|
||||
buildAbsoluteApiRequestUrl("/api/v1/ppt/chat/conversations")
|
||||
);
|
||||
u.searchParams.set("presentation_id", presentationId);
|
||||
const response = await fetch(u.toString(), {
|
||||
headers: getHeader(),
|
||||
cache: "no-cache",
|
||||
});
|
||||
return await ApiResponseHandler.handleResponse(
|
||||
response,
|
||||
"Failed to list chat conversations"
|
||||
);
|
||||
}
|
||||
|
||||
static async getHistory(
|
||||
presentationId: string,
|
||||
conversationId: string
|
||||
): Promise<ChatHistoryData> {
|
||||
const u = new URL(buildAbsoluteApiRequestUrl("/api/v1/ppt/chat/history"));
|
||||
u.searchParams.set("presentation_id", presentationId);
|
||||
u.searchParams.set("conversation_id", conversationId);
|
||||
const response = await fetch(u.toString(), {
|
||||
headers: getHeader(),
|
||||
cache: "no-cache",
|
||||
});
|
||||
return await ApiResponseHandler.handleResponse(
|
||||
response,
|
||||
"Failed to load chat history"
|
||||
);
|
||||
}
|
||||
|
||||
static async sendMessage(
|
||||
payload: ChatMessageRequest
|
||||
): Promise<ChatMessageResponse> {
|
||||
const response = await fetch(getApiUrl("/api/v1/ppt/chat/message"), {
|
||||
method: "POST",
|
||||
headers: getHeader(),
|
||||
body: JSON.stringify(payload),
|
||||
cache: "no-cache",
|
||||
});
|
||||
|
||||
return await ApiResponseHandler.handleResponse(
|
||||
response,
|
||||
"Failed to send chat message"
|
||||
);
|
||||
}
|
||||
|
||||
static async streamMessage(
|
||||
payload: ChatMessageRequest,
|
||||
handlers: ChatStreamHandlers = {},
|
||||
options?: { signal?: AbortSignal }
|
||||
): Promise<ChatMessageResponse> {
|
||||
const response = await fetch(getApiUrl("/api/v1/ppt/chat/message/stream"), {
|
||||
method: "POST",
|
||||
headers: getHeader(),
|
||||
body: JSON.stringify(payload),
|
||||
cache: "no-cache",
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
await ApiResponseHandler.handleResponse(
|
||||
response,
|
||||
"Failed to stream chat message"
|
||||
);
|
||||
throw new Error("Failed to stream chat message");
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("No response body received from chat stream");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
let buffer = "";
|
||||
let finalResponse: ChatMessageResponse | null = null;
|
||||
|
||||
const processSseFrame = (frame: string) => {
|
||||
const normalized = frame.replaceAll("\r", "");
|
||||
const lines = normalized.split("\n");
|
||||
let eventName = "";
|
||||
const dataLines: string[] = [];
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("event:")) {
|
||||
eventName = line.slice(6).trim();
|
||||
continue;
|
||||
}
|
||||
if (line.startsWith("data:")) {
|
||||
dataLines.push(line.slice(5).trimStart());
|
||||
}
|
||||
}
|
||||
|
||||
if (eventName && eventName !== "response") {
|
||||
return;
|
||||
}
|
||||
if (!dataLines.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
let parsedData: ChatStreamData;
|
||||
try {
|
||||
parsedData = JSON.parse(dataLines.join("\n")) as ChatStreamData;
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
const payloadType = parsedData.type;
|
||||
if (payloadType === "chunk") {
|
||||
const chunk = parsedData.chunk;
|
||||
if (typeof chunk === "string" && chunk.length > 0) {
|
||||
handlers.onChunk?.(chunk);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (payloadType === "complete") {
|
||||
const chatPayload = (parsedData as ChatStreamDataComplete).chat;
|
||||
if (
|
||||
chatPayload &&
|
||||
typeof chatPayload === "object" &&
|
||||
typeof (chatPayload as { response?: unknown }).response === "string"
|
||||
) {
|
||||
const typedResponse: ChatMessageResponse = {
|
||||
conversation_id:
|
||||
typeof (chatPayload as { conversation_id?: unknown })
|
||||
.conversation_id === "string"
|
||||
? (chatPayload as { conversation_id?: string }).conversation_id
|
||||
: undefined,
|
||||
response: (chatPayload as { response: string }).response,
|
||||
tool_calls: Array.isArray(
|
||||
(chatPayload as { tool_calls?: unknown }).tool_calls
|
||||
)
|
||||
? (
|
||||
(chatPayload as { tool_calls?: unknown[] }).tool_calls ?? []
|
||||
).filter((item): item is string => typeof item === "string")
|
||||
: [],
|
||||
};
|
||||
finalResponse = typedResponse;
|
||||
handlers.onComplete?.(typedResponse);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (payloadType === "error") {
|
||||
const detail = (parsedData as ChatStreamDataError).detail;
|
||||
const message =
|
||||
typeof detail === "string" && detail.trim().length > 0
|
||||
? detail
|
||||
: "Chat stream failed";
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
if (payloadType === "status") {
|
||||
const status = (parsedData as ChatStreamDataStatus).status;
|
||||
if (typeof status === "string" && status.trim().length > 0) {
|
||||
handlers.onStatus?.(status);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (payloadType === "trace") {
|
||||
const trace = (parsedData as ChatStreamDataTrace).trace;
|
||||
if (trace && typeof trace === "object") {
|
||||
const typedTrace = trace as Record<string, unknown>;
|
||||
handlers.onTrace?.({
|
||||
kind:
|
||||
typeof typedTrace.kind === "string" ? typedTrace.kind : undefined,
|
||||
round:
|
||||
typeof typedTrace.round === "number" ? typedTrace.round : undefined,
|
||||
tool:
|
||||
typeof typedTrace.tool === "string" ? typedTrace.tool : undefined,
|
||||
status:
|
||||
typeof typedTrace.status === "string" ? typedTrace.status : undefined,
|
||||
message:
|
||||
typeof typedTrace.message === "string" ? typedTrace.message : undefined,
|
||||
tools: Array.isArray(typedTrace.tools)
|
||||
? typedTrace.tools.filter(
|
||||
(value): value is string => typeof value === "string"
|
||||
)
|
||||
: undefined,
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let delimiterIndex = buffer.indexOf("\n\n");
|
||||
while (delimiterIndex >= 0) {
|
||||
const frame = buffer.slice(0, delimiterIndex);
|
||||
buffer = buffer.slice(delimiterIndex + 2);
|
||||
processSseFrame(frame);
|
||||
delimiterIndex = buffer.indexOf("\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer.trim().length > 0) {
|
||||
processSseFrame(buffer);
|
||||
}
|
||||
|
||||
if (finalResponse) {
|
||||
return finalResponse;
|
||||
}
|
||||
|
||||
throw new Error("Chat stream ended before completion");
|
||||
}
|
||||
}
|
||||
|
|
@ -64,7 +64,8 @@ const page = () => {
|
|||
</div>
|
||||
|
||||
<div
|
||||
className='fixed z-0 -bottom-[14.5rem] left-0 w-full h-full'
|
||||
className="fixed z-0 -bottom-[14.5rem] left-0 w-full h-full pointer-events-none"
|
||||
aria-hidden
|
||||
style={{
|
||||
height: "341px",
|
||||
borderRadius: '1440px',
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ input[type="number"]::-webkit-outer-spin-button {
|
|||
}
|
||||
|
||||
input[type="number"] {
|
||||
appearance: textfield;
|
||||
-moz-appearance: textfield;
|
||||
}
|
||||
|
||||
|
|
@ -104,7 +105,6 @@ tbody tr {
|
|||
display: table;
|
||||
width: 100%;
|
||||
table-layout: fixed;
|
||||
/* even columns width , fix width of table too*/
|
||||
}
|
||||
|
||||
thead {
|
||||
|
|
@ -276,83 +276,30 @@ thead {
|
|||
@apply prose prose-slate max-w-none;
|
||||
}
|
||||
|
||||
/* .markdown-content h1 {
|
||||
@apply text-xl font-bold mb-4 text-gray-900;
|
||||
.chat-markdown {
|
||||
@apply prose prose-slate max-w-none;
|
||||
}
|
||||
|
||||
.markdown-content h2 {
|
||||
@apply text-lg font-bold mb-3 text-gray-900;
|
||||
.chat-markdown :where(p, ul, ol, pre, blockquote) {
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
|
||||
.markdown-content h3 {
|
||||
@apply text-base font-bold mb-2 text-gray-900;
|
||||
.chat-markdown :where(p:last-child, ul:last-child, ol:last-child, pre:last-child, blockquote:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.markdown-content h4 {
|
||||
@apply text-sm font-bold mb-2 text-gray-900;
|
||||
.chat-markdown :where(ul, ol) {
|
||||
padding-left: 1.25rem;
|
||||
}
|
||||
|
||||
.markdown-content h5 {
|
||||
@apply text-xs font-bold mb-1 text-gray-900;
|
||||
.chat-markdown :where(code) {
|
||||
@apply rounded bg-[#F5F5F5] px-1 py-0.5 text-[0.85em];
|
||||
}
|
||||
|
||||
.markdown-content h6 {
|
||||
@apply text-xs font-semibold mb-1 text-gray-900;
|
||||
.chat-markdown :where(pre) {
|
||||
@apply overflow-x-auto rounded-lg bg-[#F5F5F5] p-3;
|
||||
}
|
||||
|
||||
.markdown-content p {
|
||||
@apply mb-4 text-base text-gray-700;
|
||||
}
|
||||
|
||||
.markdown-content ul {
|
||||
@apply list-disc pl-6 mb-4;
|
||||
}
|
||||
|
||||
.markdown-content ol {
|
||||
@apply list-decimal pl-6 mb-4;
|
||||
}
|
||||
|
||||
.markdown-content li {
|
||||
@apply mb-1;
|
||||
}
|
||||
|
||||
.markdown-content strong,
|
||||
.markdown-content b {
|
||||
@apply font-bold text-gray-900;
|
||||
}
|
||||
|
||||
.markdown-content em {
|
||||
@apply italic;
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
@apply border-l-4 border-gray-300 pl-4 italic my-4;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
@apply bg-gray-100 px-1 py-0.5 rounded font-mono text-sm;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
@apply bg-gray-100 p-4 rounded-lg my-4 overflow-x-auto;
|
||||
}
|
||||
|
||||
.markdown-content a {
|
||||
@apply text-blue-600 hover:text-blue-800 underline;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
@apply min-w-full border border-gray-300 my-4;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
@apply bg-gray-100 border border-gray-300 px-4 py-2 font-bold;
|
||||
}
|
||||
|
||||
.markdown-content td {
|
||||
@apply border border-gray-300 px-4 py-2;
|
||||
} */
|
||||
|
||||
/* Override Tailwind Typography prose heading sizes for markdown editor */
|
||||
.prose h1 {
|
||||
font-size: 18px !important;
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"description": "Dark, theme-ready presentation layouts with covers, structured content grids, timelines, narrative splits, and chart-driven slides",
|
||||
"ordered": false,
|
||||
"default": false
|
||||
}
|
||||
|
|
@ -54,7 +54,8 @@ export default function Home() {
|
|||
|
||||
<div className="flex min-h-screen relative">
|
||||
<div
|
||||
className='fixed z-0 -bottom-[14.5rem] left-0 w-full h-full'
|
||||
className="fixed z-0 -bottom-[14.5rem] left-0 w-full h-full pointer-events-none"
|
||||
aria-hidden
|
||||
style={{
|
||||
height: "341px",
|
||||
borderRadius: '1440px',
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ interface UndoRedoState {
|
|||
future: HistoryState[];
|
||||
maxHistorySize: number;
|
||||
isUndoRedoInProgress: boolean;
|
||||
pendingHistorySkips: number;
|
||||
}
|
||||
|
||||
// Helper function for deep copy
|
||||
|
|
@ -25,7 +26,8 @@ const initialState: UndoRedoState = {
|
|||
present: null,
|
||||
future: [],
|
||||
maxHistorySize: 30,
|
||||
isUndoRedoInProgress: false
|
||||
isUndoRedoInProgress: false,
|
||||
pendingHistorySkips: 0,
|
||||
};
|
||||
|
||||
const undoRedoSlice = createSlice({
|
||||
|
|
@ -33,15 +35,22 @@ const undoRedoSlice = createSlice({
|
|||
initialState,
|
||||
reducers: {
|
||||
addToHistory: (state, action: PayloadAction<{slides: Slide[], actionType: string}>) => {
|
||||
|
||||
// Skip if undo/redo is in progress
|
||||
if (state.pendingHistorySkips > 0) {
|
||||
state.pendingHistorySkips -= 1;
|
||||
if (state.pendingHistorySkips === 0) {
|
||||
state.isUndoRedoInProgress = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Defensive guard for any stale in-progress state.
|
||||
if (state.isUndoRedoInProgress) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Deep copy the slides to avoid reference issues
|
||||
const newSlides = deepCopy(action.payload.slides);
|
||||
|
||||
|
||||
// Only add to history if the slides have actually changed
|
||||
if (!state.present) {
|
||||
state.present = {
|
||||
|
|
@ -51,84 +60,80 @@ const undoRedoSlice = createSlice({
|
|||
};
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Skip if slides are identical
|
||||
if (JSON.stringify(state.present.slides) === JSON.stringify(newSlides)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Add current state to past
|
||||
state.past.push(state.present);
|
||||
|
||||
|
||||
// Limit history size
|
||||
if (state.past.length > state.maxHistorySize) {
|
||||
state.past.shift();
|
||||
}
|
||||
|
||||
|
||||
// Clear future on new change
|
||||
state.future = [];
|
||||
|
||||
|
||||
// Set new present
|
||||
state.present = {
|
||||
slides: newSlides,
|
||||
timestamp: Date.now(),
|
||||
actionType: action.payload.actionType
|
||||
};
|
||||
|
||||
|
||||
},
|
||||
|
||||
undo: (state) => {
|
||||
if (state.past.length === 0) {
|
||||
|
||||
if (state.past.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
state.isUndoRedoInProgress = true;
|
||||
|
||||
state.pendingHistorySkips = 1;
|
||||
|
||||
// Move present to future
|
||||
if (state.present) {
|
||||
state.future.unshift(deepCopy(state.present));
|
||||
}
|
||||
|
||||
|
||||
// Get last past state
|
||||
const previous = state.past[state.past.length - 1];
|
||||
state.past = state.past.slice(0, -1);
|
||||
state.present = deepCopy(previous);
|
||||
|
||||
|
||||
},
|
||||
|
||||
redo: (state) => {
|
||||
if (state.future.length === 0) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
state.isUndoRedoInProgress = true;
|
||||
|
||||
state.pendingHistorySkips = 1;
|
||||
|
||||
// Move present to past
|
||||
if (state.present) {
|
||||
state.past.push(deepCopy(state.present));
|
||||
}
|
||||
|
||||
|
||||
// Get first future state
|
||||
const next = state.future[0];
|
||||
state.future = state.future.slice(1);
|
||||
state.present = deepCopy(next);
|
||||
|
||||
|
||||
},
|
||||
|
||||
|
||||
finishUndoRedo: (state) => {
|
||||
state.isUndoRedoInProgress = false;
|
||||
state.pendingHistorySkips = 0;
|
||||
},
|
||||
|
||||
clearHistory: (state) => {
|
||||
state.past = [];
|
||||
state.future = [];
|
||||
state.present = null;
|
||||
// Keep present
|
||||
state.isUndoRedoInProgress = false;
|
||||
state.pendingHistorySkips = 0;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -63,6 +63,25 @@ export function getApiUrl(path: string): string {
|
|||
return normalizedPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* getApiUrl may return a path without host (e.g. `/api/v1/...`). A single-argument
|
||||
* `new URL("/api/...")` call is invalid; use this before `new URL(..., ...)`-style
|
||||
* builds or to obtain an absolute string for `URL` + `searchParams`.
|
||||
*/
|
||||
export function buildAbsoluteApiRequestUrl(
|
||||
path: string,
|
||||
baseForRelative: string = typeof window !== "undefined" &&
|
||||
window.location?.origin
|
||||
? window.location.origin
|
||||
: "http://127.0.0.1:5000"
|
||||
): string {
|
||||
const resolved = getApiUrl(path);
|
||||
if (isAbsoluteHttpUrl(resolved)) {
|
||||
return resolved;
|
||||
}
|
||||
return new URL(resolved, baseForRelative).toString();
|
||||
}
|
||||
|
||||
function hasBackendAssetPrefix(path: string): boolean {
|
||||
return path.startsWith("/static/") || path.startsWith("/app_data/");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue