Merge main into pdf-pptx-layout
This commit is contained in:
commit
68bb4bae3a
55 changed files with 1732 additions and 485 deletions
|
|
@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep
|
|||
- **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom**
|
||||
- **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output.
|
||||
- **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled.
|
||||
- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results.
|
||||
|
||||
You can also set the following environment variables to customize the image generation provider and API keys:
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ services:
|
|||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
- DISABLE_THINKING=${DISABLE_THINKING}
|
||||
- WEB_GROUNDING=${WEB_GROUNDING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
||||
production-gpu:
|
||||
|
|
@ -60,6 +63,9 @@ services:
|
|||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
- DISABLE_THINKING=${DISABLE_THINKING}
|
||||
- WEB_GROUNDING=${WEB_GROUNDING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
||||
development:
|
||||
|
|
@ -87,6 +93,9 @@ services:
|
|||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
- DISABLE_THINKING=${DISABLE_THINKING}
|
||||
- WEB_GROUNDING=${WEB_GROUNDING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
||||
development-gpu:
|
||||
|
|
@ -122,3 +131,7 @@ services:
|
|||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- TOOL_CALLS=${TOOL_CALLS}
|
||||
- DISABLE_THINKING=${DISABLE_THINKING}
|
||||
- WEB_GROUNDING=${WEB_GROUNDING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ async def pull_model(
|
|||
# If the model is being pulled, return the model
|
||||
if saved_model_status:
|
||||
# If the model is being pulled, return the model
|
||||
# ? If the model status is pulled in redis but was not found while listing pulled models,
|
||||
# ? If the model status is pulled in database but was not found while listing pulled models,
|
||||
# ? it means the model was deleted and we need to pull it again
|
||||
if (
|
||||
saved_model_status["status"] == "error"
|
||||
|
|
|
|||
|
|
@ -72,6 +72,9 @@ async def stream_outlines(
|
|||
presentation_outlines_json = json.loads(presentation_outlines_text)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
with open("./debug/outlines.txt", "w") as f:
|
||||
f.write(presentation_outlines_text)
|
||||
print(presentation_outlines_text)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Failed to generate presentation outlines. Please try again.",
|
||||
|
|
@ -87,7 +90,8 @@ async def stream_outlines(
|
|||
|
||||
presentation.outlines = presentation_outlines.model_dump()
|
||||
presentation.title = (
|
||||
presentation_outlines.slides[0][:50]
|
||||
presentation_outlines.slides[0]
|
||||
.content[:50]
|
||||
.replace("#", "")
|
||||
.replace("/", "")
|
||||
.replace("\\", "")
|
||||
|
|
|
|||
|
|
@ -11,7 +11,10 @@ from sqlmodel import select
|
|||
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
|
||||
from models.presentation_and_path import PresentationPathAndEditPath
|
||||
from models.presentation_from_template import GetPresentationUsingTemplateRequest
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.pptx_models import PptxPresentationModel
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
|
@ -126,7 +129,7 @@ async def create_presentation(
|
|||
@PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel)
|
||||
async def prepare_presentation(
|
||||
presentation_id: Annotated[str, Body()],
|
||||
outlines: Annotated[List[str], Body()],
|
||||
outlines: Annotated[List[SlideOutlineModel], Body()],
|
||||
layout: Annotated[PresentationLayoutModel, Body()],
|
||||
title: Annotated[Optional[str], Body()] = None,
|
||||
sql_session: AsyncSession = Depends(get_async_session),
|
||||
|
|
@ -161,7 +164,9 @@ async def prepare_presentation(
|
|||
presentation_structure.slides[index] = random_slide_index
|
||||
|
||||
sql_session.add(presentation)
|
||||
presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump()
|
||||
presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
presentation.title = title or presentation.title
|
||||
presentation.set_layout(layout)
|
||||
presentation.set_structure(presentation_structure)
|
||||
|
|
|
|||
|
|
@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1"
|
|||
|
||||
# Default models
|
||||
DEFAULT_OPENAI_MODEL = "gpt-4.1"
|
||||
DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash"
|
||||
DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"
|
||||
DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash"
|
||||
DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514"
|
||||
|
|
|
|||
|
|
@ -6,61 +6,51 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
label="Llama 3:8b",
|
||||
value="llama3:8b",
|
||||
size="4.7GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3:70b": OllamaModelMetadata(
|
||||
label="Llama 3:70b",
|
||||
value="llama3:70b",
|
||||
size="40GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:8b": OllamaModelMetadata(
|
||||
label="Llama 3.1:8b",
|
||||
value="llama3.1:8b",
|
||||
size="4.9GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:70b": OllamaModelMetadata(
|
||||
label="Llama 3.1:70b",
|
||||
value="llama3.1:70b",
|
||||
size="43GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:405b": OllamaModelMetadata(
|
||||
label="Llama 3.1:405b",
|
||||
value="llama3.1:405b",
|
||||
size="243GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:1b": OllamaModelMetadata(
|
||||
label="Llama 3.2:1b",
|
||||
value="llama3.2:1b",
|
||||
size="1.3GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:3b": OllamaModelMetadata(
|
||||
label="Llama 3.2:3b",
|
||||
value="llama3.2:3b",
|
||||
size="2GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.3:70b": OllamaModelMetadata(
|
||||
label="Llama 3.3:70b",
|
||||
value="llama3.3:70b",
|
||||
size="43GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:16x17b": OllamaModelMetadata(
|
||||
label="Llama 4:16x17b",
|
||||
value="llama4:16x17b",
|
||||
size="67GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:128x17b": OllamaModelMetadata(
|
||||
label="Llama 4:128x17b",
|
||||
value="llama4:128x17b",
|
||||
size="245GB",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -69,25 +59,21 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
label="Gemma 3:1b",
|
||||
value="gemma3:1b",
|
||||
size="815MB",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:4b": OllamaModelMetadata(
|
||||
label="Gemma 3:4b",
|
||||
value="gemma3:4b",
|
||||
size="3.3GB",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:12b": OllamaModelMetadata(
|
||||
label="Gemma 3:12b",
|
||||
value="gemma3:12b",
|
||||
size="8.1GB",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:27b": OllamaModelMetadata(
|
||||
label="Gemma 3:27b",
|
||||
value="gemma3:27b",
|
||||
size="17GB",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -96,43 +82,36 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
label="DeepSeek R1:1.5b",
|
||||
value="deepseek-r1:1.5b",
|
||||
size="1.1GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:7b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:7b",
|
||||
value="deepseek-r1:7b",
|
||||
size="4.7GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:8b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:8b",
|
||||
value="deepseek-r1:8b",
|
||||
size="5.2GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:14b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:14b",
|
||||
value="deepseek-r1:14b",
|
||||
size="9GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:32b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:32b",
|
||||
value="deepseek-r1:32b",
|
||||
size="20GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:70b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:70b",
|
||||
value="deepseek-r1:70b",
|
||||
size="43GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:671b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:671b",
|
||||
value="deepseek-r1:671b",
|
||||
size="404GB",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -141,49 +120,54 @@ SUPPORTED_QWEN_MODELS = {
|
|||
label="Qwen 3:0.6b",
|
||||
value="qwen3:0.6b",
|
||||
size="523MB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:1.7b": OllamaModelMetadata(
|
||||
label="Qwen 3:1.7b",
|
||||
value="qwen3:1.7b",
|
||||
size="1.4GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:4b": OllamaModelMetadata(
|
||||
label="Qwen 3:4b",
|
||||
value="qwen3:4b",
|
||||
size="2.6GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:8b": OllamaModelMetadata(
|
||||
label="Qwen 3:8b",
|
||||
value="qwen3:8b",
|
||||
size="5.2GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:14b": OllamaModelMetadata(
|
||||
label="Qwen 3:14b",
|
||||
value="qwen3:14b",
|
||||
size="9.3GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:30b": OllamaModelMetadata(
|
||||
label="Qwen 3:30b",
|
||||
value="qwen3:30b",
|
||||
size="19GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:32b": OllamaModelMetadata(
|
||||
label="Qwen 3:32b",
|
||||
value="qwen3:32b",
|
||||
size="20GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:235b": OllamaModelMetadata(
|
||||
label="Qwen 3:235b",
|
||||
value="qwen3:235b",
|
||||
size="142GB",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
}
|
||||
|
||||
SUPPORTED_GPT_OSS_MODELS = {
|
||||
"gpt-oss:20b": OllamaModelMetadata(
|
||||
label="GPT-OSS 20b",
|
||||
value="gpt-oss:20b",
|
||||
size="14GB",
|
||||
),
|
||||
"gpt-oss:120b": OllamaModelMetadata(
|
||||
label="GPT-OSS 120b",
|
||||
value="gpt-oss:120b",
|
||||
size="65GB",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -192,4 +176,5 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
**SUPPORTED_GEMMA_MODELS,
|
||||
**SUPPORTED_DEEPSEEK_MODELS,
|
||||
**SUPPORTED_QWEN_MODELS,
|
||||
**SUPPORTED_GPT_OSS_MODELS,
|
||||
}
|
||||
|
|
|
|||
8
servers/fastapi/enums/llm_call_type.py
Normal file
8
servers/fastapi/enums/llm_call_type.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class LLMCallType(Enum):
|
||||
UNSTRUCTURED = "unstructured"
|
||||
UNSTRUCTURED_STREAM = "unstructured_stream"
|
||||
STRUCTURED = "structured"
|
||||
STRUCTURED_STREAM = "structured_stream"
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
heading: str
|
||||
|
|
@ -7,5 +9,5 @@ class DocumentChunk(BaseModel):
|
|||
heading_index: int
|
||||
score: float
|
||||
|
||||
def to_slide_outline(self) -> str:
|
||||
return f"{self.heading}\n{self.content}"
|
||||
def to_slide_outline(self) -> SlideOutlineModel:
|
||||
return SlideOutlineModel(content=f"{self.heading}\n{self.content}")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,58 @@
|
|||
from typing import Literal
|
||||
from typing import Any, List, Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
from google.genai.types import Content as GoogleContent
|
||||
|
||||
from models.llm_tool_call import AnthropicToolCall
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
role: Literal["user", "system"]
|
||||
pass
|
||||
|
||||
|
||||
class LLMUserMessage(LLMMessage):
|
||||
role: Literal["user"] = "user"
|
||||
content: str
|
||||
|
||||
|
||||
class LLMSystemMessage(LLMMessage):
|
||||
role: Literal["system"] = "system"
|
||||
content: str
|
||||
|
||||
|
||||
class OpenAIAssistantMessage(LLMMessage):
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: str | None = None
|
||||
tool_calls: Optional[List[dict]] = None
|
||||
|
||||
|
||||
class GoogleAssistantMessage(LLMMessage):
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: GoogleContent
|
||||
|
||||
|
||||
class AnthropicAssistantMessage(LLMMessage):
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: List[AnthropicToolCall]
|
||||
|
||||
|
||||
class AnthropicToolCallMessage(LLMMessage):
|
||||
type: Literal["tool_result"] = "tool_result"
|
||||
tool_use_id: str
|
||||
content: str
|
||||
|
||||
|
||||
class AnthropicUserMessage(LLMMessage):
|
||||
role: Literal["user"] = "user"
|
||||
content: List[AnthropicToolCallMessage]
|
||||
|
||||
|
||||
class OpenAIToolCallMessage(LLMMessage):
|
||||
role: Literal["tool"] = "tool"
|
||||
content: str
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class GoogleToolCallMessage(LLMMessage):
|
||||
role: Literal["tool"] = "tool"
|
||||
name: str
|
||||
response: dict
|
||||
|
|
|
|||
29
servers/fastapi/models/llm_tool_call.py
Normal file
29
servers/fastapi/models/llm_tool_call.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMToolCall(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIToolCallFunction(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class OpenAIToolCall(LLMToolCall):
|
||||
id: str
|
||||
type: Literal["function"] = "function"
|
||||
function: OpenAIToolCallFunction
|
||||
|
||||
|
||||
class GoogleToolCall(LLMToolCall):
|
||||
name: str
|
||||
arguments: Optional[dict] = None
|
||||
|
||||
|
||||
class AnthropicToolCall(LLMToolCall):
|
||||
type: Literal["tool_use"] = "tool_use"
|
||||
id: str
|
||||
name: str
|
||||
input: object
|
||||
29
servers/fastapi/models/llm_tools.py
Normal file
29
servers/fastapi/models/llm_tools.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from typing import Any, Callable, Coroutine, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LLMTool(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class LLMDynamicTool(LLMTool):
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict = {}
|
||||
handler: Callable[..., Coroutine[Any, Any, str]]
|
||||
|
||||
|
||||
class SearchWebTool(LLMTool):
|
||||
"""
|
||||
Search the web for information.
|
||||
"""
|
||||
|
||||
query: str = Field(description="The query to search the web for")
|
||||
|
||||
|
||||
class GetCurrentDatetimeTool(LLMTool):
|
||||
"""
|
||||
Get the current datetime.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
@ -4,5 +4,4 @@ from pydantic import BaseModel
|
|||
class OllamaModelMetadata(BaseModel):
|
||||
label: str
|
||||
value: str
|
||||
icon: str
|
||||
size: str
|
||||
|
|
|
|||
|
|
@ -2,8 +2,12 @@ from typing import List
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlideOutlineModel(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class PresentationOutlineModel(BaseModel):
|
||||
slides: List[str]
|
||||
slides: List[SlideOutlineModel]
|
||||
|
||||
def to_string(self):
|
||||
message = ""
|
||||
|
|
|
|||
|
|
@ -35,3 +35,6 @@ class UserConfig(BaseModel):
|
|||
TOOL_CALLS: Optional[bool] = None
|
||||
DISABLE_THINKING: Optional[bool] = None
|
||||
EXTENDED_REASONING: Optional[bool] = None
|
||||
|
||||
# Web Search
|
||||
WEB_GROUNDING: Optional[bool] = None
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"openai>=1.98.0",
|
||||
"pathvalidate>=3.3.1",
|
||||
"pdfplumber>=0.11.7",
|
||||
"pytest>=8.4.1",
|
||||
"python-pptx>=1.0.2",
|
||||
"redis>=6.2.0",
|
||||
"sqlmodel>=0.0.24",
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
201
servers/fastapi/services/llm_tool_calls_handler.py
Normal file
201
servers/fastapi/services/llm_tool_calls_handler.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
import asyncio
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Any, Callable, Coroutine, List, Optional
|
||||
from fastapi import HTTPException
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import (
|
||||
AnthropicToolCallMessage,
|
||||
GoogleToolCallMessage,
|
||||
OpenAIToolCallMessage,
|
||||
)
|
||||
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
|
||||
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
|
||||
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
|
||||
|
||||
|
||||
class LLMToolCallsHandler:
|
||||
def __init__(self, client):
|
||||
from services.llm_client import LLMClient
|
||||
|
||||
self.client: LLMClient = client
|
||||
|
||||
self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = {
|
||||
"SearchWebTool": self.search_web_tool_call_handler,
|
||||
"GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler,
|
||||
}
|
||||
self.dynamic_tools: List[LLMDynamicTool] = []
|
||||
|
||||
def get_tool_handler(
|
||||
self, tool_name: str
|
||||
) -> Callable[..., Coroutine[Any, Any, str]]:
|
||||
handler = self.tools_map.get(tool_name)
|
||||
if handler:
|
||||
return handler
|
||||
else:
|
||||
dynamic_tools = list(
|
||||
filter(lambda tool: tool.name == tool_name, self.dynamic_tools)
|
||||
)
|
||||
if dynamic_tools:
|
||||
return dynamic_tools[0].handler
|
||||
raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found")
|
||||
|
||||
def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None):
|
||||
if tools is None:
|
||||
return None
|
||||
parsed_tools = map(self.parse_tool, tools)
|
||||
return list(parsed_tools)
|
||||
|
||||
def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False):
|
||||
if isinstance(tool, LLMDynamicTool):
|
||||
self.dynamic_tools.append(tool)
|
||||
|
||||
match self.client.llm_provider:
|
||||
case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM:
|
||||
return self.parse_tool_openai(tool, strict)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self.parse_tool_anthropic(tool)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self.parse_tool_google(tool)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"LLM provider must be either openai, anthropic, or google"
|
||||
)
|
||||
|
||||
def parse_tool_openai(
|
||||
self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False
|
||||
):
|
||||
if isinstance(tool, LLMDynamicTool):
|
||||
name = tool.name
|
||||
description = tool.description
|
||||
parameters = tool.parameters
|
||||
else:
|
||||
name = tool.__name__
|
||||
description = tool.__doc__ or ""
|
||||
parameters = tool.model_json_schema()
|
||||
|
||||
if strict:
|
||||
parameters = ensure_strict_json_schema(parameters, path=(), root=parameters)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"strict": strict,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
parsed = self.parse_tool_openai(tool)
|
||||
# parsed["function"]["parameters"] = flatten_json_schema(
|
||||
# parsed["function"]["parameters"]
|
||||
# )
|
||||
return {
|
||||
"name": parsed["function"]["name"],
|
||||
"description": parsed["function"]["description"],
|
||||
"parameters": parsed["function"]["parameters"],
|
||||
}
|
||||
|
||||
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
|
||||
parsed = self.parse_tool_openai(tool)
|
||||
input_schema = parsed["function"]["parameters"]
|
||||
return {
|
||||
"name": parsed["function"]["name"],
|
||||
"description": parsed["function"]["description"],
|
||||
"input_schema": {"type": "object"} if input_schema == {} else input_schema,
|
||||
}
|
||||
|
||||
async def handle_tool_calls_openai(
|
||||
self,
|
||||
tool_calls: List[OpenAIToolCall],
|
||||
) -> List[OpenAIToolCallMessage]:
|
||||
async_tool_calls_tasks = []
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_handler = self.get_tool_handler(tool_name)
|
||||
async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments))
|
||||
|
||||
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
|
||||
tool_call_messages = [
|
||||
OpenAIToolCallMessage(
|
||||
content=result,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
for tool_call, result in zip(tool_calls, tool_call_results)
|
||||
]
|
||||
return tool_call_messages
|
||||
|
||||
async def handle_tool_calls_google(
|
||||
self,
|
||||
tool_calls: List[GoogleToolCall],
|
||||
) -> List[GoogleToolCallMessage]:
|
||||
async_tool_calls_tasks = []
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_handler = self.get_tool_handler(tool_name)
|
||||
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments)))
|
||||
|
||||
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
|
||||
tool_call_messages = [
|
||||
GoogleToolCallMessage(
|
||||
name=tool_call.name,
|
||||
response={"result": result},
|
||||
)
|
||||
for tool_call, result in zip(tool_calls, tool_call_results)
|
||||
]
|
||||
return tool_call_messages
|
||||
|
||||
async def handle_tool_calls_anthropic(
|
||||
self,
|
||||
tool_calls: List[AnthropicToolCall],
|
||||
) -> List[AnthropicToolCallMessage]:
|
||||
async_tool_calls_tasks = []
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.name
|
||||
tool_handler = self.get_tool_handler(tool_name)
|
||||
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input)))
|
||||
|
||||
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
|
||||
tool_call_messages = [
|
||||
AnthropicToolCallMessage(
|
||||
content=result,
|
||||
tool_use_id=tool_call.id,
|
||||
)
|
||||
for tool_call, result in zip(tool_calls, tool_call_results)
|
||||
]
|
||||
return tool_call_messages
|
||||
|
||||
# ? Tool call handlers
|
||||
# Search web tool call handler
|
||||
async def search_web_tool_call_handler(self, arguments: str) -> str:
|
||||
match self.client.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return await self.search_web_tool_call_handler_openai(arguments)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return await self.search_web_tool_call_handler_anthropic(arguments)
|
||||
case LLMProvider.GOOGLE:
|
||||
return await self.search_web_tool_call_handler_google(arguments)
|
||||
case _:
|
||||
return (
|
||||
"Web search tool call handler not implemented for this LLM provider: "
|
||||
+ self.client.llm_provider.value
|
||||
)
|
||||
|
||||
async def search_web_tool_call_handler_openai(self, arguments: str) -> str:
|
||||
args = SearchWebTool.model_validate_json(arguments)
|
||||
return await self.client._search_openai(args.query)
|
||||
|
||||
async def search_web_tool_call_handler_google(self, arguments: str) -> str:
|
||||
args = SearchWebTool.model_validate_json(arguments)
|
||||
return await self.client._search_google(args.query)
|
||||
|
||||
async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str:
|
||||
args = SearchWebTool.model_validate_json(arguments)
|
||||
return await self.client._search_anthropic(args.query)
|
||||
|
||||
# Get current datetime tool call handler
|
||||
async def get_current_datetime_tool_call_handler(self, _) -> str:
|
||||
current_time = datetime.now()
|
||||
return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}"
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
from typing import Any, Optional
|
||||
import redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from utils.get_env import (
|
||||
get_redis_db_env,
|
||||
get_redis_host_env,
|
||||
get_redis_password_env,
|
||||
get_redis_port_env,
|
||||
)
|
||||
|
||||
|
||||
class RedisService:
|
||||
def __init__(self):
|
||||
self.redis_host = get_redis_host_env() or "localhost"
|
||||
self.redis_port = int(get_redis_port_env() or "6379")
|
||||
self.redis_db = int(get_redis_db_env() or "0")
|
||||
self.redis_password = get_redis_password_env() or None
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self) -> redis.Redis:
|
||||
return redis.Redis(
|
||||
host=self.redis_host,
|
||||
port=self.redis_port,
|
||||
db=self.redis_db,
|
||||
password=self.redis_password,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
|
||||
try:
|
||||
return self.client.set(key, value, ex=expire)
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
try:
|
||||
return self.client.get(key)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
try:
|
||||
return bool(self.client.delete(key))
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
try:
|
||||
return bool(self.client.exists(key))
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def set_hash(self, name: str, mapping: dict) -> bool:
|
||||
try:
|
||||
return self.client.hmset(name, mapping)
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get_hash(self, name: str) -> Optional[dict]:
|
||||
try:
|
||||
return self.client.hgetall(name)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def delete_hash(self, name: str, *fields: str) -> int:
|
||||
try:
|
||||
return self.client.hdel(name, *fields)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def set_list(self, name: str, values: list) -> bool:
|
||||
try:
|
||||
self.client.delete(name)
|
||||
if values:
|
||||
self.client.rpush(name, *values)
|
||||
return True
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
|
||||
try:
|
||||
return self.client.lrange(name, start, end)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def add_to_set(self, name: str, *values: str) -> int:
|
||||
try:
|
||||
return self.client.sadd(name, *values)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def get_set(self, name: str) -> Optional[set]:
|
||||
try:
|
||||
return self.client.smembers(name)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def remove_from_set(self, name: str, *values: str) -> int:
|
||||
try:
|
||||
return self.client.srem(name, *values)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def clear(self) -> bool:
|
||||
try:
|
||||
return self.client.flushdb()
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.client.close()
|
||||
except RedisError:
|
||||
pass
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 23 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 69 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 26 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 20 KiB |
2
servers/fastapi/utils/dummy_functions.py
Normal file
2
servers/fastapi/utils/dummy_functions.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
async def do_nothing_async(_):
|
||||
return None
|
||||
|
|
@ -1,13 +1,23 @@
|
|||
from typing import List
|
||||
from pydantic import Field
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
def get_presentation_outline_model_with_n_slides(n_slides: int):
|
||||
class SlideOutlineModelWithNSlides(SlideOutlineModel):
|
||||
content: str = Field(
|
||||
description="Markdown content for each slide",
|
||||
min_length=100,
|
||||
max_length=300,
|
||||
)
|
||||
|
||||
class PresentationOutlineModelWithNSlides(PresentationOutlineModel):
|
||||
slides: List[str] = Field(
|
||||
description="Markdown content for each slide in about 100 to 200 words",
|
||||
slides: List[SlideOutlineModelWithNSlides] = Field(
|
||||
description="List of slide outlines",
|
||||
min_items=n_slides,
|
||||
max_items=n_slides,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -81,22 +81,6 @@ def get_pixabay_api_key_env():
|
|||
return os.getenv("PIXABAY_API_KEY")
|
||||
|
||||
|
||||
def get_redis_host_env():
|
||||
return os.getenv("REDIS_HOST")
|
||||
|
||||
|
||||
def get_redis_port_env():
|
||||
return os.getenv("REDIS_PORT")
|
||||
|
||||
|
||||
def get_redis_db_env():
|
||||
return os.getenv("REDIS_DB")
|
||||
|
||||
|
||||
def get_redis_password_env():
|
||||
return os.getenv("REDIS_PASSWORD")
|
||||
|
||||
|
||||
def get_tool_calls_env():
|
||||
return os.getenv("TOOL_CALLS")
|
||||
|
||||
|
|
@ -107,3 +91,7 @@ def get_disable_thinking_env():
|
|||
|
||||
def get_extended_reasoning_env():
|
||||
return os.getenv("EXTENDED_REASONING")
|
||||
|
||||
|
||||
def get_web_grounding_env():
|
||||
return os.getenv("WEB_GROUNDING")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
|
|
@ -41,12 +41,10 @@ def get_messages(
|
|||
language: str,
|
||||
):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(prompt, slide_data, language),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional
|
||||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
|
@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str):
|
|||
response = await client.generate(
|
||||
model=model,
|
||||
messages=[
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
|
||||
LLMSystemMessage(content=system_prompt),
|
||||
LLMUserMessage(content=get_user_prompt(prompt, html)),
|
||||
],
|
||||
)
|
||||
return extract_html_from_response(response) or html
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
|
||||
from services.llm_client import LLMClient
|
||||
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
|
||||
from utils.get_env import get_web_grounding_env
|
||||
from utils.llm_provider import get_model
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.user_config import get_user_config
|
||||
|
||||
system_prompt = """
|
||||
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
|
||||
|
|
@ -29,12 +32,10 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
|
|||
|
||||
def get_messages(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(prompt, n_slides, language, content),
|
||||
),
|
||||
]
|
||||
|
|
@ -51,10 +52,13 @@ async def generate_ppt_outline(
|
|||
|
||||
client = LLMClient()
|
||||
|
||||
tools = [SearchWebTool, GetCurrentDatetimeTool]
|
||||
|
||||
async for chunk in client.stream_structured(
|
||||
model,
|
||||
get_messages(prompt, n_slides, language, content),
|
||||
response_model.model_json_schema(),
|
||||
strict=True,
|
||||
tools=tools if client.enable_web_grounding() else None,
|
||||
):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
|
|
@ -11,8 +11,7 @@ def get_messages(
|
|||
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
|
||||
):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=f"""
|
||||
You're a professional presentation designer with creative freedom to design engaging presentations.
|
||||
|
||||
|
|
@ -47,8 +46,7 @@ def get_messages(
|
|||
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
|
||||
""",
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=f"""
|
||||
{data}
|
||||
""",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
|
@ -38,19 +39,17 @@ def get_user_prompt(outline: str, language: str):
|
|||
def get_messages(outline: str, language: str):
|
||||
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=get_user_prompt(outline, language),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
slide_layout: SlideLayoutModel, outline: str, language: str
|
||||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
|
||||
):
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
|
@ -62,7 +61,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(
|
||||
outline,
|
||||
outline.content,
|
||||
language,
|
||||
),
|
||||
response_format=response_schema,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.llm_message import LLMSystemMessage, LLMUserMessage
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
|
|
@ -13,8 +13,7 @@ def get_messages(
|
|||
current_slide_layout: int,
|
||||
):
|
||||
return [
|
||||
LLMMessage(
|
||||
role="system",
|
||||
LLMSystemMessage(
|
||||
content=f"""
|
||||
Select a Slide Layout index based on provided user prompt and current slide data.
|
||||
{layout.to_string()}
|
||||
|
|
@ -26,8 +25,7 @@ def get_messages(
|
|||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
LLMUserMessage(
|
||||
content=f"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
|
|
|
|||
|
|
@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
|
|||
return resolved
|
||||
|
||||
|
||||
# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions
|
||||
def flatten_json_schema(schema: dict) -> dict:
|
||||
root_schema = deepcopy(schema)
|
||||
|
||||
def _flatten(node: Any) -> Any:
|
||||
if isinstance(node, dict):
|
||||
# If node is a pure $ref (or combined with extra fields), inline it
|
||||
if "$ref" in node:
|
||||
ref_value = node["$ref"]
|
||||
assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}"
|
||||
resolved = resolve_ref(root=root_schema, ref=ref_value)
|
||||
assert isinstance(resolved, dict), (
|
||||
f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}"
|
||||
)
|
||||
# Merge: referenced first, then overlay current (excluding $ref)
|
||||
merged: dict[str, Any] = deepcopy(resolved)
|
||||
for key, value in node.items():
|
||||
if key == "$ref":
|
||||
continue
|
||||
merged[key] = value
|
||||
return _flatten(merged)
|
||||
|
||||
flattened: dict[str, Any] = {}
|
||||
for key, value in node.items():
|
||||
# Drop defs/definitions in output
|
||||
if key in ("$defs", "definitions"):
|
||||
continue
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()}
|
||||
elif key in ("items", "contains", "additionalProperties", "not"):
|
||||
if isinstance(value, dict):
|
||||
flattened[key] = _flatten(value)
|
||||
elif isinstance(value, list):
|
||||
flattened[key] = [_flatten(v) for v in value]
|
||||
else:
|
||||
flattened[key] = value
|
||||
elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list):
|
||||
flattened[key] = [_flatten(v) for v in value]
|
||||
else:
|
||||
flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value
|
||||
return flattened
|
||||
if isinstance(node, list):
|
||||
return [_flatten(v) for v in node]
|
||||
return node
|
||||
|
||||
result = _flatten(schema)
|
||||
# Ensure top-level cleanup just in case
|
||||
if isinstance(result, dict):
|
||||
result.pop("$defs", None)
|
||||
result.pop("definitions", None)
|
||||
return result
|
||||
|
||||
|
||||
# ? Not used
|
||||
def generate_constraint_sentences(schema: dict) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -79,3 +79,7 @@ def set_disable_thinking_env(value):
|
|||
|
||||
def set_extended_reasoning_env(value):
|
||||
os.environ["EXTENDED_REASONING"] = value
|
||||
|
||||
|
||||
def set_web_grounding_env(value):
|
||||
os.environ["WEB_GROUNDING"] = value
|
||||
|
|
@ -22,6 +22,7 @@ from utils.get_env import (
|
|||
get_image_provider_env,
|
||||
get_pixabay_api_key_env,
|
||||
get_extended_reasoning_env,
|
||||
get_web_grounding_env,
|
||||
)
|
||||
from utils.parsers import parse_bool_or_none
|
||||
from utils.set_env import (
|
||||
|
|
@ -43,6 +44,7 @@ from utils.set_env import (
|
|||
set_image_provider_env,
|
||||
set_pixabay_api_key_env,
|
||||
set_tool_calls_env,
|
||||
set_web_grounding_env,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -76,12 +78,26 @@ def get_user_config():
|
|||
IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(),
|
||||
PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(),
|
||||
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(),
|
||||
TOOL_CALLS=existing_config.TOOL_CALLS
|
||||
or parse_bool_or_none(get_tool_calls_env()),
|
||||
DISABLE_THINKING=existing_config.DISABLE_THINKING
|
||||
or parse_bool_or_none(get_disable_thinking_env()),
|
||||
EXTENDED_REASONING=existing_config.EXTENDED_REASONING
|
||||
or parse_bool_or_none(get_extended_reasoning_env()),
|
||||
TOOL_CALLS=(
|
||||
existing_config.TOOL_CALLS
|
||||
if existing_config.TOOL_CALLS is not None
|
||||
else (parse_bool_or_none(get_tool_calls_env()) or False)
|
||||
),
|
||||
DISABLE_THINKING=(
|
||||
existing_config.DISABLE_THINKING
|
||||
if existing_config.DISABLE_THINKING is not None
|
||||
else (parse_bool_or_none(get_disable_thinking_env()) or False)
|
||||
),
|
||||
EXTENDED_REASONING=(
|
||||
existing_config.EXTENDED_REASONING
|
||||
if existing_config.EXTENDED_REASONING is not None
|
||||
else (parse_bool_or_none(get_extended_reasoning_env()) or False)
|
||||
),
|
||||
WEB_GROUNDING=(
|
||||
existing_config.WEB_GROUNDING
|
||||
if existing_config.WEB_GROUNDING is not None
|
||||
else (parse_bool_or_none(get_web_grounding_env()) or False)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -122,5 +138,6 @@ def update_env_with_user_config():
|
|||
if user_config.DISABLE_THINKING:
|
||||
set_disable_thinking_env(str(user_config.DISABLE_THINKING))
|
||||
if user_config.EXTENDED_REASONING:
|
||||
if user_config.EXTENDED_REASONING:
|
||||
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
|
||||
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
|
||||
if user_config.WEB_GROUNDING:
|
||||
set_web_grounding_env(str(user_config.WEB_GROUNDING))
|
||||
|
|
|
|||
27
servers/fastapi/uv.lock
generated
27
servers/fastapi/uv.lock
generated
|
|
@ -1061,6 +1061,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isodate"
|
||||
version = "0.7.2"
|
||||
|
|
@ -1907,6 +1916,7 @@ dependencies = [
|
|||
{ name = "openai" },
|
||||
{ name = "pathvalidate" },
|
||||
{ name = "pdfplumber" },
|
||||
{ name = "pytest" },
|
||||
{ name = "python-pptx" },
|
||||
{ name = "redis" },
|
||||
{ name = "sqlmodel" },
|
||||
|
|
@ -1928,6 +1938,7 @@ requires-dist = [
|
|||
{ name = "openai", specifier = ">=1.98.0" },
|
||||
{ name = "pathvalidate", specifier = ">=3.3.1" },
|
||||
{ name = "pdfplumber", specifier = ">=0.11.7" },
|
||||
{ name = "pytest", specifier = ">=8.4.1" },
|
||||
{ name = "python-pptx", specifier = ">=1.0.2" },
|
||||
{ name = "redis", specifier = ">=6.2.0" },
|
||||
{ name = "sqlmodel", specifier = ">=0.0.24" },
|
||||
|
|
@ -2211,6 +2222,22 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-bidi"
|
||||
version = "0.6.6"
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import React from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { LoadingState, StreamState, LayoutGroup } from "../types/index";
|
||||
import { LoadingState, LayoutGroup } from "../types/index";
|
||||
|
||||
interface GenerateButtonProps {
|
||||
loadingState: LoadingState;
|
||||
streamState: StreamState;
|
||||
streamState: { isStreaming: boolean, isLoading: boolean };
|
||||
selectedLayoutGroup: LayoutGroup | null;
|
||||
onSubmit: () => void;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import { Button } from "@/components/ui/button";
|
|||
import { FileText } from "lucide-react";
|
||||
|
||||
interface OutlineContentProps {
|
||||
outlines: string[] | null;
|
||||
outlines: { content: string }[] | null;
|
||||
isLoading: boolean;
|
||||
isStreaming: boolean;
|
||||
onDragEnd: (event: any) => void;
|
||||
|
|
@ -32,7 +32,7 @@ const OutlineContent: React.FC<OutlineContentProps> = ({
|
|||
onDragEnd,
|
||||
onAddSlide
|
||||
}) => {
|
||||
|
||||
console.log('isLoading', isLoading)
|
||||
const sensors = useSensors(
|
||||
useSensor(PointerSensor),
|
||||
useSensor(KeyboardSensor, {
|
||||
|
|
@ -83,7 +83,18 @@ const OutlineContent: React.FC<OutlineContentProps> = ({
|
|||
collisionDetection={closestCenter}
|
||||
onDragEnd={onDragEnd}
|
||||
>
|
||||
<SortableContext
|
||||
{isStreaming ? (
|
||||
|
||||
outlines.map((item, index) => (
|
||||
<OutlineItem
|
||||
key={`slide-${index}`}
|
||||
index={index + 1}
|
||||
slideOutline={item}
|
||||
isStreaming={isStreaming}
|
||||
/>
|
||||
))
|
||||
) :
|
||||
<SortableContext
|
||||
items={outlines?.map((item, index) => ({ id: `slide-${index}` })) || []}
|
||||
strategy={verticalListSortingStrategy}
|
||||
>
|
||||
|
|
@ -95,7 +106,7 @@ const OutlineContent: React.FC<OutlineContentProps> = ({
|
|||
isStreaming={isStreaming}
|
||||
/>
|
||||
))}
|
||||
</SortableContext>
|
||||
</SortableContext>}
|
||||
</DndContext>
|
||||
|
||||
<Button
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ import { useEffect } from "react"
|
|||
|
||||
|
||||
interface OutlineItemProps {
|
||||
slideOutline: string,
|
||||
slideOutline: {
|
||||
content: string,
|
||||
},
|
||||
index: number
|
||||
isStreaming: boolean
|
||||
}
|
||||
|
|
@ -38,7 +40,7 @@ export function OutlineItem({
|
|||
}
|
||||
}, [outlines.length]);
|
||||
|
||||
const handleSlideChange = (newOutline: string) => {
|
||||
const handleSlideChange = (newOutline:any) => {
|
||||
if (isStreaming) return;
|
||||
const newData = outlines?.map((each, idx) => {
|
||||
if (idx === index - 1) {
|
||||
|
|
@ -100,10 +102,10 @@ export function OutlineItem({
|
|||
{isStreaming ? <p
|
||||
className="text-sm flex-1 font-normal"
|
||||
>
|
||||
{slideOutline || ''}
|
||||
{slideOutline.content || ''}
|
||||
</p> : <MarkdownEditor
|
||||
key={index}
|
||||
content={slideOutline || ''}
|
||||
content={slideOutline.content || ''}
|
||||
onChange={(content) => handleSlideChange(content)}
|
||||
/>}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
import React from "react";
|
||||
|
||||
const PageHeader: React.FC = () => (
|
||||
<div className="mb-8">
|
||||
{/* <h4 className="text-2xl font-bold mb-2 text-gray-900">
|
||||
Customize Your Presentation
|
||||
</h4> */}
|
||||
{/* <p className="text-gray-600">
|
||||
Review your outline and select a layout style for your presentation.
|
||||
</p> */}
|
||||
</div>
|
||||
);
|
||||
|
||||
export default PageHeader;
|
||||
|
|
@ -3,7 +3,7 @@ import { useDispatch } from "react-redux";
|
|||
import { arrayMove } from "@dnd-kit/sortable";
|
||||
import { setOutlines } from "@/store/slices/presentationGeneration";
|
||||
|
||||
export const useOutlineManagement = (outlines: string[] | null) => {
|
||||
export const useOutlineManagement = (outlines: { content: string }[] | null) => {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const handleDragEnd = useCallback((event: any) => {
|
||||
|
|
@ -12,8 +12,8 @@ export const useOutlineManagement = (outlines: string[] | null) => {
|
|||
if (!active || !over || !outlines) return;
|
||||
|
||||
if (active.id !== over.id) {
|
||||
const oldIndex = outlines.findIndex((item) => item === active.id);
|
||||
const newIndex = outlines.findIndex((item) => item === over.id);
|
||||
const oldIndex = outlines.findIndex((item) => item.content === active.id);
|
||||
const newIndex = outlines.findIndex((item) => item.content === over.id);
|
||||
const reorderedArray = arrayMove(outlines, oldIndex, newIndex);
|
||||
dispatch(setOutlines(reorderedArray));
|
||||
}
|
||||
|
|
@ -22,7 +22,7 @@ export const useOutlineManagement = (outlines: string[] | null) => {
|
|||
const handleAddSlide = useCallback(() => {
|
||||
if (!outlines) return;
|
||||
|
||||
const updatedOutlines = [...outlines, "Outline title"];
|
||||
const updatedOutlines = [...outlines, { content: "Outline title" }];
|
||||
dispatch(setOutlines(updatedOutlines));
|
||||
}, [outlines, dispatch]);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,18 +3,15 @@ import { useDispatch, useSelector } from "react-redux";
|
|||
import { toast } from "sonner";
|
||||
import { setOutlines } from "@/store/slices/presentationGeneration";
|
||||
import { jsonrepair } from "jsonrepair";
|
||||
import { StreamState } from "../types/index";
|
||||
import { RootState } from "@/store/store";
|
||||
|
||||
const DEFAULT_STREAM_STATE: StreamState = {
|
||||
isStreaming: false,
|
||||
isLoading: true,
|
||||
};
|
||||
|
||||
|
||||
export const useOutlineStreaming = (presentationId: string | null) => {
|
||||
const dispatch = useDispatch();
|
||||
const { outlines } = useSelector((state: RootState) => state.presentationGeneration);
|
||||
const [streamState, setStreamState] = useState<StreamState>(DEFAULT_STREAM_STATE);
|
||||
const [isStreaming, setIsStreaming] = useState(true);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
if (!presentationId || outlines.length > 0) return;
|
||||
|
|
@ -23,8 +20,8 @@ export const useOutlineStreaming = (presentationId: string | null) => {
|
|||
let accumulatedChunks = "";
|
||||
|
||||
const initializeStream = async () => {
|
||||
setStreamState({ isStreaming: true, isLoading: true });
|
||||
|
||||
setIsStreaming(true)
|
||||
setIsLoading(true)
|
||||
try {
|
||||
eventSource = new EventSource(
|
||||
`/api/v1/ppt/outlines/stream?presentation_id=${presentationId}`
|
||||
|
|
@ -34,13 +31,16 @@ export const useOutlineStreaming = (presentationId: string | null) => {
|
|||
const data = JSON.parse(event.data);
|
||||
switch (data.type) {
|
||||
case "chunk":
|
||||
// console.log('data', data)
|
||||
accumulatedChunks += data.chunk;
|
||||
// console.log('accumulatedChunks', accumulatedChunks)
|
||||
try {
|
||||
const repairedJson = jsonrepair(accumulatedChunks);
|
||||
const partialData = JSON.parse(repairedJson);
|
||||
console.log('partialData', partialData)
|
||||
if (partialData.slides) {
|
||||
dispatch(setOutlines(partialData.slides));
|
||||
setStreamState(prev => ({ ...prev, isLoading: false }));
|
||||
setIsLoading(false)
|
||||
}
|
||||
} catch (error) {
|
||||
// JSON isn't complete yet, continue accumulating
|
||||
|
|
@ -48,11 +48,13 @@ export const useOutlineStreaming = (presentationId: string | null) => {
|
|||
break;
|
||||
|
||||
case "complete":
|
||||
console.log('complete', data)
|
||||
try {
|
||||
const outlinesData: string[] = data.presentation.outlines.slides;
|
||||
const outlinesData: { content: string }[] = data.presentation.outlines.slides;
|
||||
dispatch(setOutlines(outlinesData));
|
||||
setStreamState({ isStreaming: false, isLoading: false });
|
||||
eventSource.close();
|
||||
setIsStreaming(false)
|
||||
setIsLoading(false)
|
||||
eventSource.close();
|
||||
} catch (error) {
|
||||
console.error("Error parsing accumulated chunks:", error);
|
||||
toast.error("Failed to parse presentation data");
|
||||
|
|
@ -62,11 +64,15 @@ export const useOutlineStreaming = (presentationId: string | null) => {
|
|||
break;
|
||||
|
||||
case "closing":
|
||||
setStreamState({ isStreaming: false, isLoading: false });
|
||||
console.log('closing', data)
|
||||
setIsStreaming(false)
|
||||
setIsLoading(false)
|
||||
eventSource.close();
|
||||
break;
|
||||
case "error":
|
||||
setStreamState({ isStreaming: false, isLoading: false });
|
||||
console.log('error', data)
|
||||
setIsStreaming(false)
|
||||
setIsLoading(false)
|
||||
eventSource.close();
|
||||
toast.error('Error in outline streaming',
|
||||
{
|
||||
|
|
@ -78,18 +84,20 @@ export const useOutlineStreaming = (presentationId: string | null) => {
|
|||
});
|
||||
|
||||
eventSource.onerror = () => {
|
||||
setStreamState({ isStreaming: false, isLoading: false });
|
||||
console.log('onerror')
|
||||
setIsStreaming(false)
|
||||
setIsLoading(false)
|
||||
eventSource.close();
|
||||
toast.error("Failed to connect to the server. Please try again.");
|
||||
};
|
||||
} catch (error) {
|
||||
setStreamState({ isStreaming: false, isLoading: false });
|
||||
console.log('error', error)
|
||||
setIsStreaming(false)
|
||||
setIsLoading(false)
|
||||
toast.error("Failed to initialize connection");
|
||||
}finally{
|
||||
setStreamState({ isStreaming: false, isLoading: false });
|
||||
}
|
||||
};
|
||||
initializeStream();
|
||||
initializeStream();
|
||||
return () => {
|
||||
if (eventSource) {
|
||||
eventSource.close();
|
||||
|
|
@ -97,5 +105,5 @@ export const useOutlineStreaming = (presentationId: string | null) => {
|
|||
};
|
||||
}, [presentationId, dispatch]);
|
||||
|
||||
return streamState;
|
||||
return { isStreaming, isLoading };
|
||||
};
|
||||
|
|
@ -15,7 +15,7 @@ const DEFAULT_LOADING_STATE: LoadingState = {
|
|||
|
||||
export const usePresentationGeneration = (
|
||||
presentationId: string | null,
|
||||
outlines: string[] | null,
|
||||
outlines: { content: string }[] | null,
|
||||
selectedLayoutGroup: LayoutGroup | null,
|
||||
setActiveTab: (tab: string) => void
|
||||
) => {
|
||||
|
|
|
|||
|
|
@ -14,10 +14,7 @@ export interface LoadingState {
|
|||
duration: number;
|
||||
}
|
||||
|
||||
export interface StreamState {
|
||||
isStreaming: boolean;
|
||||
isLoading: boolean;
|
||||
}
|
||||
|
||||
|
||||
export const TABS = {
|
||||
OUTLINE: 'outline',
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ const UploadPage = () => {
|
|||
config,
|
||||
files: responses,
|
||||
}));
|
||||
dispatch(clearOutlines());
|
||||
dispatch(clearOutlines())
|
||||
router.push("/documents-preview");
|
||||
};
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ const UploadPage = () => {
|
|||
});
|
||||
|
||||
dispatch(setPresentationId(createResponse.id));
|
||||
dispatch(clearOutlines());
|
||||
dispatch(clearOutlines())
|
||||
router.push("/outline");
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ export async function POST(request: Request) {
|
|||
userConfig.EXTENDED_REASONING === undefined
|
||||
? existingConfig.EXTENDED_REASONING
|
||||
: userConfig.EXTENDED_REASONING,
|
||||
WEB_GROUNDING:
|
||||
userConfig.WEB_GROUNDING === undefined
|
||||
? existingConfig.WEB_GROUNDING
|
||||
: userConfig.WEB_GROUNDING,
|
||||
USE_CUSTOM_URL:
|
||||
userConfig.USE_CUSTOM_URL === undefined
|
||||
? existingConfig.USE_CUSTOM_URL
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ interface AnthropicConfigProps {
|
|||
anthropicApiKey: string;
|
||||
anthropicModel: string;
|
||||
extendedReasoning: boolean;
|
||||
webGrounding?: boolean;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
|
|
@ -27,6 +28,7 @@ export default function AnthropicConfig({
|
|||
anthropicApiKey,
|
||||
anthropicModel,
|
||||
extendedReasoning,
|
||||
webGrounding,
|
||||
onInputChange,
|
||||
}: AnthropicConfigProps) {
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
|
|
@ -65,7 +67,7 @@ export default function AnthropicConfig({
|
|||
const data = await response.json();
|
||||
setAvailableModels(data);
|
||||
setModelsChecked(true);
|
||||
onInputChange("claude-3-5-sonnet-20241022", "anthropic_model");
|
||||
onInputChange("claude-sonnet-4-20250514", "anthropic_model");
|
||||
} else {
|
||||
console.error('Failed to fetch models');
|
||||
setAvailableModels([]);
|
||||
|
|
@ -226,6 +228,23 @@ export default function AnthropicConfig({
|
|||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{/* Web Grounding Toggle - at the end, below models dropdown */}
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Enable Web Grounding
|
||||
</label>
|
||||
<Switch
|
||||
checked={!!webGrounding}
|
||||
onCheckedChange={(checked) => onInputChange(checked, "web_grounding")}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
If enabled, the model can use web search grounding when available.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -13,16 +13,19 @@ import {
|
|||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
import { Switch } from "./ui/switch";
|
||||
|
||||
interface GoogleConfigProps {
|
||||
googleApiKey: string;
|
||||
googleModel: string;
|
||||
onInputChange: (value: string, field: string) => void;
|
||||
webGrounding?: boolean;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
export default function GoogleConfig({
|
||||
googleApiKey,
|
||||
googleModel,
|
||||
webGrounding,
|
||||
onInputChange
|
||||
}: GoogleConfigProps) {
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
|
|
@ -61,7 +64,7 @@ export default function GoogleConfig({
|
|||
const data = await response.json();
|
||||
setAvailableModels(data);
|
||||
setModelsChecked(true);
|
||||
onInputChange("models/gemini-2.0-flash", "google_model");
|
||||
onInputChange("models/gemini-2.5-flash", "google_model");
|
||||
} else {
|
||||
console.error('Failed to fetch models');
|
||||
setAvailableModels([]);
|
||||
|
|
@ -205,6 +208,23 @@ export default function GoogleConfig({
|
|||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{/* Web Grounding Toggle - at the end, below models dropdown */}
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Enable Web Grounding
|
||||
</label>
|
||||
<Switch
|
||||
checked={!!webGrounding}
|
||||
onCheckedChange={(checked) => onInputChange(checked, "web_grounding")}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
If enabled, the model can use web search grounding when available.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -149,6 +149,7 @@ export default function LLMProviderSelection({
|
|||
<OpenAIConfig
|
||||
openaiApiKey={llmConfig.OPENAI_API_KEY || ""}
|
||||
openaiModel={llmConfig.OPENAI_MODEL || ""}
|
||||
webGrounding={llmConfig.WEB_GROUNDING || false}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
|
@ -158,6 +159,7 @@ export default function LLMProviderSelection({
|
|||
<GoogleConfig
|
||||
googleApiKey={llmConfig.GOOGLE_API_KEY || ""}
|
||||
googleModel={llmConfig.GOOGLE_MODEL || ""}
|
||||
webGrounding={llmConfig.WEB_GROUNDING || false}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
|
@ -168,6 +170,7 @@ export default function LLMProviderSelection({
|
|||
anthropicApiKey={llmConfig.ANTHROPIC_API_KEY || ""}
|
||||
anthropicModel={llmConfig.ANTHROPIC_MODEL || ""}
|
||||
extendedReasoning={llmConfig.EXTENDED_REASONING || false}
|
||||
webGrounding={llmConfig.WEB_GROUNDING || false}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ interface OllamaModel {
|
|||
label: string;
|
||||
value: string;
|
||||
size: string;
|
||||
icon: string;
|
||||
}
|
||||
|
||||
interface OllamaConfigProps {
|
||||
|
|
@ -128,19 +127,6 @@ export default function OllamaConfig({
|
|||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
{ollamaModel && (
|
||||
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={
|
||||
ollamaModels?.find(
|
||||
(m) => m.value === ollamaModel
|
||||
)?.icon
|
||||
}
|
||||
alt={`${ollamaModel} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{ollamaModel
|
||||
? ollamaModels?.find(
|
||||
|
|
@ -189,13 +175,6 @@ export default function OllamaConfig({
|
|||
)}
|
||||
/>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={model.icon}
|
||||
alt={`${model.label} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900 capitalize">
|
||||
|
|
|
|||
|
|
@ -13,16 +13,19 @@ import {
|
|||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
import { Switch } from "./ui/switch";
|
||||
|
||||
interface OpenAIConfigProps {
|
||||
openaiApiKey: string;
|
||||
openaiModel: string;
|
||||
onInputChange: (value: string, field: string) => void;
|
||||
webGrounding?: boolean;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
export default function OpenAIConfig({
|
||||
openaiApiKey,
|
||||
openaiModel,
|
||||
webGrounding,
|
||||
onInputChange
|
||||
}: OpenAIConfigProps) {
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
|
|
@ -210,6 +213,23 @@ export default function OpenAIConfig({
|
|||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{/* Web Grounding Toggle - show at the end, below models dropdown */}
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Enable Web Grounding
|
||||
</label>
|
||||
<Switch
|
||||
checked={!!webGrounding}
|
||||
onCheckedChange={(checked) => onInputChange(checked, "web_grounding")}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
If enabled, the model can use web search grounding when available.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -18,7 +18,7 @@ interface PresentationGenerationState {
|
|||
presentation_id: string | null;
|
||||
isLoading: boolean;
|
||||
isStreaming: boolean | null;
|
||||
outlines: string[];
|
||||
outlines: { content: string }[];
|
||||
error: string | null;
|
||||
presentationData: PresentationData | null;
|
||||
isSlidesRendered: boolean;
|
||||
|
|
@ -72,7 +72,7 @@ const presentationGenerationSlice = createSlice({
|
|||
state.outlines = [];
|
||||
},
|
||||
// Set outlines
|
||||
setOutlines: (state, action: PayloadAction<string[]>) => {
|
||||
setOutlines: (state, action: PayloadAction<{ content: string }[]>) => {
|
||||
state.outlines = action.payload;
|
||||
},
|
||||
// Set presentation data
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ export interface LLMConfig {
|
|||
TOOL_CALLS?: boolean;
|
||||
DISABLE_THINKING?: boolean;
|
||||
EXTENDED_REASONING?: boolean;
|
||||
WEB_GROUNDING?: boolean;
|
||||
|
||||
// Only used in UI settings
|
||||
USE_CUSTOM_URL?: boolean;
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@ import { LLMConfig } from "@/types/llm_config";
|
|||
export interface OllamaModel {
|
||||
label: string;
|
||||
value: string;
|
||||
description: string;
|
||||
size: string;
|
||||
icon: string;
|
||||
}
|
||||
|
||||
export interface DownloadingModel {
|
||||
|
|
@ -48,6 +46,7 @@ export const updateLLMConfig = (
|
|||
tool_calls: "TOOL_CALLS",
|
||||
disable_thinking: "DISABLE_THINKING",
|
||||
extended_reasoning: "EXTENDED_REASONING",
|
||||
web_grounding: "WEB_GROUNDING",
|
||||
};
|
||||
|
||||
const configKey = fieldMappings[field];
|
||||
|
|
|
|||
1
start.js
1
start.js
|
|
@ -81,6 +81,7 @@ const setupUserConfigFromEnv = () => {
|
|||
TOOL_CALLS: process.env.TOOL_CALLS || existingConfig.TOOL_CALLS,
|
||||
DISABLE_THINKING: process.env.DISABLE_THINKING || existingConfig.DISABLE_THINKING,
|
||||
EXTENDED_REASONING: process.env.EXTENDED_REASONING || existingConfig.EXTENDED_REASONING,
|
||||
WEB_GROUNDING: process.env.WEB_GROUNDING || existingConfig.WEB_GROUNDING,
|
||||
USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL,
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue