Merge pull request #198 from presenton/feat/llm_grounding_web_search

feat/llm grounding web search
This commit is contained in:
Saurav Niraula 2025-08-09 03:08:17 +05:45 committed by GitHub
commit 9526f486fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1608 additions and 345 deletions

View file

@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep
- **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom**
- **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output.
- **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled.
- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results.
You can also set the following environment variables to customize the image generation provider and API keys:

View file

@ -25,6 +25,9 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
- TOOL_CALLS=${TOOL_CALLS}
- DISABLE_THINKING=${DISABLE_THINKING}
- WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}
production-gpu:
@ -60,6 +63,9 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
- TOOL_CALLS=${TOOL_CALLS}
- DISABLE_THINKING=${DISABLE_THINKING}
- WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}
development:
@ -87,6 +93,9 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
- TOOL_CALLS=${TOOL_CALLS}
- DISABLE_THINKING=${DISABLE_THINKING}
- WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}
development-gpu:
@ -121,4 +130,7 @@ services:
- CUSTOM_MODEL=${CUSTOM_MODEL}
- PEXELS_API_KEY=${PEXELS_API_KEY}
- EXTENDED_REASONING=${EXTENDED_REASONING}
- TOOL_CALLS=${TOOL_CALLS}
- DISABLE_THINKING=${DISABLE_THINKING}
- WEB_GROUNDING=${WEB_GROUNDING}
- DATABASE_URL=${DATABASE_URL}

View file

@ -64,7 +64,7 @@ async def pull_model(
# If the model is being pulled, return the model
if saved_model_status:
# If the model is being pulled, return the model
# ? If the model status is pulled in redis but was not found while listing pulled models,
# ? If the model status is pulled in database but was not found while listing pulled models,
# ? it means the model was deleted and we need to pull it again
if (
saved_model_status["status"] == "error"

View file

@ -72,6 +72,9 @@ async def stream_outlines(
presentation_outlines_json = json.loads(presentation_outlines_text)
except Exception as e:
print(e)
with open("./debug/outlines.txt", "w") as f:
f.write(presentation_outlines_text)
print(presentation_outlines_text)
raise HTTPException(
status_code=400,
detail="Failed to generate presentation outlines. Please try again.",

View file

@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1"
# Default models
DEFAULT_OPENAI_MODEL = "gpt-4.1"
DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash"
DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"
DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash"
DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514"

View file

@ -0,0 +1,8 @@
from enum import Enum
class LLMCallType(Enum):
UNSTRUCTURED = "unstructured"
UNSTRUCTURED_STREAM = "unstructured_stream"
STRUCTURED = "structured"
STRUCTURED_STREAM = "structured_stream"

View file

@ -1,7 +1,58 @@
from typing import Literal
from typing import Any, List, Literal, Optional
from pydantic import BaseModel
from google.genai.types import Content as GoogleContent
from models.llm_tool_call import AnthropicToolCall
class LLMMessage(BaseModel):
role: Literal["user", "system"]
pass
class LLMUserMessage(LLMMessage):
role: Literal["user"] = "user"
content: str
class LLMSystemMessage(LLMMessage):
role: Literal["system"] = "system"
content: str
class OpenAIAssistantMessage(LLMMessage):
role: Literal["assistant"] = "assistant"
content: str | None = None
tool_calls: Optional[List[dict]] = None
class GoogleAssistantMessage(LLMMessage):
role: Literal["assistant"] = "assistant"
content: GoogleContent
class AnthropicAssistantMessage(LLMMessage):
role: Literal["assistant"] = "assistant"
content: List[AnthropicToolCall]
class AnthropicToolCallMessage(LLMMessage):
type: Literal["tool_result"] = "tool_result"
tool_use_id: str
content: str
class AnthropicUserMessage(LLMMessage):
role: Literal["user"] = "user"
content: List[AnthropicToolCallMessage]
class OpenAIToolCallMessage(LLMMessage):
role: Literal["tool"] = "tool"
content: str
tool_call_id: str
class GoogleToolCallMessage(LLMMessage):
role: Literal["tool"] = "tool"
name: str
response: dict

View file

@ -0,0 +1,29 @@
from typing import Literal, Optional
from pydantic import BaseModel
class LLMToolCall(BaseModel):
pass
class OpenAIToolCallFunction(BaseModel):
name: str
arguments: str
class OpenAIToolCall(LLMToolCall):
id: str
type: Literal["function"] = "function"
function: OpenAIToolCallFunction
class GoogleToolCall(LLMToolCall):
name: str
arguments: Optional[dict] = None
class AnthropicToolCall(LLMToolCall):
type: Literal["tool_use"] = "tool_use"
id: str
name: str
input: object

View file

@ -0,0 +1,29 @@
from typing import Any, Callable, Coroutine, Optional
from pydantic import BaseModel, Field
class LLMTool(BaseModel):
pass
class LLMDynamicTool(LLMTool):
name: str
description: str
parameters: dict = {}
handler: Callable[..., Coroutine[Any, Any, str]]
class SearchWebTool(LLMTool):
"""
Search the web for information.
"""
query: str = Field(description="The query to search the web for")
class GetCurrentDatetimeTool(LLMTool):
"""
Get the current datetime.
"""
pass

View file

@ -35,3 +35,6 @@ class UserConfig(BaseModel):
TOOL_CALLS: Optional[bool] = None
DISABLE_THINKING: Optional[bool] = None
EXTENDED_REASONING: Optional[bool] = None
# Web Search
WEB_GROUNDING: Optional[bool] = None

View file

@ -19,6 +19,7 @@ dependencies = [
"openai>=1.98.0",
"pathvalidate>=3.3.1",
"pdfplumber>=0.11.7",
"pytest>=8.4.1",
"python-pptx>=1.0.2",
"redis>=6.2.0",
"sqlmodel>=0.0.24",

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,201 @@
import asyncio
from datetime import datetime
import json
from typing import Any, Callable, Coroutine, List, Optional
from fastapi import HTTPException
from enums.llm_provider import LLMProvider
from models.llm_message import (
AnthropicToolCallMessage,
GoogleToolCallMessage,
OpenAIToolCallMessage,
)
from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall
from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool
from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema
class LLMToolCallsHandler:
def __init__(self, client):
from services.llm_client import LLMClient
self.client: LLMClient = client
self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = {
"SearchWebTool": self.search_web_tool_call_handler,
"GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler,
}
self.dynamic_tools: List[LLMDynamicTool] = []
def get_tool_handler(
self, tool_name: str
) -> Callable[..., Coroutine[Any, Any, str]]:
handler = self.tools_map.get(tool_name)
if handler:
return handler
else:
dynamic_tools = list(
filter(lambda tool: tool.name == tool_name, self.dynamic_tools)
)
if dynamic_tools:
return dynamic_tools[0].handler
raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found")
def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None):
if tools is None:
return None
parsed_tools = map(self.parse_tool, tools)
return list(parsed_tools)
def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False):
if isinstance(tool, LLMDynamicTool):
self.dynamic_tools.append(tool)
match self.client.llm_provider:
case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM:
return self.parse_tool_openai(tool, strict)
case LLMProvider.ANTHROPIC:
return self.parse_tool_anthropic(tool)
case LLMProvider.GOOGLE:
return self.parse_tool_google(tool)
case _:
raise ValueError(
f"LLM provider must be either openai, anthropic, or google"
)
def parse_tool_openai(
self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False
):
if isinstance(tool, LLMDynamicTool):
name = tool.name
description = tool.description
parameters = tool.parameters
else:
name = tool.__name__
description = tool.__doc__ or ""
parameters = tool.model_json_schema()
if strict:
parameters = ensure_strict_json_schema(parameters, path=(), root=parameters)
return {
"type": "function",
"function": {
"name": name,
"description": description,
"strict": strict,
"parameters": parameters,
},
}
def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool):
parsed = self.parse_tool_openai(tool)
# parsed["function"]["parameters"] = flatten_json_schema(
# parsed["function"]["parameters"]
# )
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],
"parameters": parsed["function"]["parameters"],
}
def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool):
parsed = self.parse_tool_openai(tool)
input_schema = parsed["function"]["parameters"]
return {
"name": parsed["function"]["name"],
"description": parsed["function"]["description"],
"input_schema": {"type": "object"} if input_schema == {} else input_schema,
}
async def handle_tool_calls_openai(
self,
tool_calls: List[OpenAIToolCall],
) -> List[OpenAIToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
OpenAIToolCallMessage(
content=result,
tool_call_id=tool_call.id,
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
return tool_call_messages
async def handle_tool_calls_google(
self,
tool_calls: List[GoogleToolCall],
) -> List[GoogleToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments)))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
GoogleToolCallMessage(
name=tool_call.name,
response={"result": result},
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
return tool_call_messages
async def handle_tool_calls_anthropic(
self,
tool_calls: List[AnthropicToolCall],
) -> List[AnthropicToolCallMessage]:
async_tool_calls_tasks = []
for tool_call in tool_calls:
tool_name = tool_call.name
tool_handler = self.get_tool_handler(tool_name)
async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input)))
tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks)
tool_call_messages = [
AnthropicToolCallMessage(
content=result,
tool_use_id=tool_call.id,
)
for tool_call, result in zip(tool_calls, tool_call_results)
]
return tool_call_messages
# ? Tool call handlers
# Search web tool call handler
async def search_web_tool_call_handler(self, arguments: str) -> str:
match self.client.llm_provider:
case LLMProvider.OPENAI:
return await self.search_web_tool_call_handler_openai(arguments)
case LLMProvider.ANTHROPIC:
return await self.search_web_tool_call_handler_anthropic(arguments)
case LLMProvider.GOOGLE:
return await self.search_web_tool_call_handler_google(arguments)
case _:
return (
"Web search tool call handler not implemented for this LLM provider: "
+ self.client.llm_provider.value
)
async def search_web_tool_call_handler_openai(self, arguments: str) -> str:
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_openai(args.query)
async def search_web_tool_call_handler_google(self, arguments: str) -> str:
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_google(args.query)
async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str:
args = SearchWebTool.model_validate_json(arguments)
return await self.client._search_anthropic(args.query)
# Get current datetime tool call handler
async def get_current_datetime_tool_call_handler(self, _) -> str:
current_time = datetime.now()
return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}"

View file

@ -1,115 +0,0 @@
from typing import Any, Optional
import redis
from redis.exceptions import RedisError
from utils.get_env import (
get_redis_db_env,
get_redis_host_env,
get_redis_password_env,
get_redis_port_env,
)
class RedisService:
def __init__(self):
self.redis_host = get_redis_host_env() or "localhost"
self.redis_port = int(get_redis_port_env() or "6379")
self.redis_db = int(get_redis_db_env() or "0")
self.redis_password = get_redis_password_env() or None
self.client = self._create_client()
def _create_client(self) -> redis.Redis:
return redis.Redis(
host=self.redis_host,
port=self.redis_port,
db=self.redis_db,
password=self.redis_password,
decode_responses=True,
)
def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
try:
return self.client.set(key, value, ex=expire)
except RedisError:
return False
def get(self, key: str) -> Optional[str]:
try:
return self.client.get(key)
except RedisError:
return None
def delete(self, key: str) -> bool:
try:
return bool(self.client.delete(key))
except RedisError:
return False
def exists(self, key: str) -> bool:
try:
return bool(self.client.exists(key))
except RedisError:
return False
def set_hash(self, name: str, mapping: dict) -> bool:
try:
return self.client.hmset(name, mapping)
except RedisError:
return False
def get_hash(self, name: str) -> Optional[dict]:
try:
return self.client.hgetall(name)
except RedisError:
return None
def delete_hash(self, name: str, *fields: str) -> int:
try:
return self.client.hdel(name, *fields)
except RedisError:
return 0
def set_list(self, name: str, values: list) -> bool:
try:
self.client.delete(name)
if values:
self.client.rpush(name, *values)
return True
except RedisError:
return False
def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
try:
return self.client.lrange(name, start, end)
except RedisError:
return None
def add_to_set(self, name: str, *values: str) -> int:
try:
return self.client.sadd(name, *values)
except RedisError:
return 0
def get_set(self, name: str) -> Optional[set]:
try:
return self.client.smembers(name)
except RedisError:
return None
def remove_from_set(self, name: str, *values: str) -> int:
try:
return self.client.srem(name, *values)
except RedisError:
return 0
def clear(self) -> bool:
try:
return self.client.flushdb()
except RedisError:
return False
def close(self):
try:
self.client.close()
except RedisError:
pass

View file

@ -0,0 +1,2 @@
async def do_nothing_async(_):
return None

View file

@ -81,22 +81,6 @@ def get_pixabay_api_key_env():
return os.getenv("PIXABAY_API_KEY")
def get_redis_host_env():
return os.getenv("REDIS_HOST")
def get_redis_port_env():
return os.getenv("REDIS_PORT")
def get_redis_db_env():
return os.getenv("REDIS_DB")
def get_redis_password_env():
return os.getenv("REDIS_PASSWORD")
def get_tool_calls_env():
return os.getenv("TOOL_CALLS")
@ -107,3 +91,7 @@ def get_disable_thinking_env():
def get_extended_reasoning_env():
return os.getenv("EXTENDED_REASONING")
def get_web_grounding_env():
return os.getenv("WEB_GROUNDING")

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import SlideLayoutModel
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
@ -41,12 +41,10 @@ def get_messages(
language: str,
):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=system_prompt,
),
LLMMessage(
role="user",
LLMUserMessage(
content=get_user_prompt(prompt, slide_data, language),
),
]

View file

@ -1,5 +1,5 @@
from typing import Optional
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from services.llm_client import LLMClient
from utils.llm_provider import get_model
@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str):
response = await client.generate(
model=model,
messages=[
LLMMessage(role="system", content=system_prompt),
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
LLMSystemMessage(content=system_prompt),
LLMUserMessage(content=get_user_prompt(prompt, html)),
],
)
return extract_html_from_response(response) or html

View file

@ -1,9 +1,13 @@
from typing import Optional
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool
from services.llm_client import LLMClient
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
from utils.get_env import get_web_grounding_env
from utils.llm_provider import get_model
from utils.parsers import parse_bool_or_none
from utils.user_config import get_user_config
system_prompt = """
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
@ -28,12 +32,10 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
def get_messages(prompt: str, n_slides: int, language: str, content: str):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=system_prompt,
),
LLMMessage(
role="user",
LLMUserMessage(
content=get_user_prompt(prompt, n_slides, language, content),
),
]
@ -50,10 +52,13 @@ async def generate_ppt_outline(
client = LLMClient()
tools = [SearchWebTool, GetCurrentDatetimeTool]
async for chunk in client.stream_structured(
model,
get_messages(prompt, n_slides, language, content),
response_model.model_json_schema(),
strict=True,
tools=tools if client.enable_web_grounding() else None,
):
yield chunk

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import PresentationLayoutModel
from models.presentation_outline_model import PresentationOutlineModel
from services.llm_client import LLMClient
@ -11,8 +11,7 @@ def get_messages(
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=f"""
You're a professional presentation designer with creative freedom to design engaging presentations.
@ -47,8 +46,7 @@ def get_messages(
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
""",
),
LLMMessage(
role="user",
LLMUserMessage(
content=f"""
{data}
""",

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import SlideLayoutModel
from models.presentation_outline_model import SlideOutlineModel
from services.llm_client import LLMClient
@ -39,12 +39,10 @@ def get_user_prompt(outline: str, language: str):
def get_messages(outline: str, language: str):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=system_prompt,
),
LLMMessage(
role="user",
LLMUserMessage(
content=get_user_prompt(outline, language),
),
]

View file

@ -1,4 +1,4 @@
from models.llm_message import LLMMessage
from models.llm_message import LLMSystemMessage, LLMUserMessage
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from models.slide_layout_index import SlideLayoutIndex
from models.sql.slide import SlideModel
@ -13,8 +13,7 @@ def get_messages(
current_slide_layout: int,
):
return [
LLMMessage(
role="system",
LLMSystemMessage(
content=f"""
Select a Slide Layout index based on provided user prompt and current slide data.
{layout.to_string()}
@ -26,8 +25,7 @@ def get_messages(
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
),
LLMMessage(
role="user",
LLMUserMessage(
content=f"""
- User Prompt: {prompt}
- Current Slide Data: {slide_data}

View file

@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
return resolved
# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions
def flatten_json_schema(schema: dict) -> dict:
root_schema = deepcopy(schema)
def _flatten(node: Any) -> Any:
if isinstance(node, dict):
# If node is a pure $ref (or combined with extra fields), inline it
if "$ref" in node:
ref_value = node["$ref"]
assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}"
resolved = resolve_ref(root=root_schema, ref=ref_value)
assert isinstance(resolved, dict), (
f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}"
)
# Merge: referenced first, then overlay current (excluding $ref)
merged: dict[str, Any] = deepcopy(resolved)
for key, value in node.items():
if key == "$ref":
continue
merged[key] = value
return _flatten(merged)
flattened: dict[str, Any] = {}
for key, value in node.items():
# Drop defs/definitions in output
if key in ("$defs", "definitions"):
continue
if key == "properties" and isinstance(value, dict):
flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()}
elif key in ("items", "contains", "additionalProperties", "not"):
if isinstance(value, dict):
flattened[key] = _flatten(value)
elif isinstance(value, list):
flattened[key] = [_flatten(v) for v in value]
else:
flattened[key] = value
elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list):
flattened[key] = [_flatten(v) for v in value]
else:
flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value
return flattened
if isinstance(node, list):
return [_flatten(v) for v in node]
return node
result = _flatten(schema)
# Ensure top-level cleanup just in case
if isinstance(result, dict):
result.pop("$defs", None)
result.pop("definitions", None)
return result
# ? Not used
def generate_constraint_sentences(schema: dict) -> str:
"""

View file

@ -79,3 +79,7 @@ def set_disable_thinking_env(value):
def set_extended_reasoning_env(value):
os.environ["EXTENDED_REASONING"] = value
def set_web_grounding_env(value):
os.environ["WEB_GROUNDING"] = value

View file

@ -22,6 +22,7 @@ from utils.get_env import (
get_image_provider_env,
get_pixabay_api_key_env,
get_extended_reasoning_env,
get_web_grounding_env,
)
from utils.parsers import parse_bool_or_none
from utils.set_env import (
@ -43,6 +44,7 @@ from utils.set_env import (
set_image_provider_env,
set_pixabay_api_key_env,
set_tool_calls_env,
set_web_grounding_env,
)
@ -76,12 +78,26 @@ def get_user_config():
IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(),
PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(),
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(),
TOOL_CALLS=existing_config.TOOL_CALLS
or parse_bool_or_none(get_tool_calls_env()),
DISABLE_THINKING=existing_config.DISABLE_THINKING
or parse_bool_or_none(get_disable_thinking_env()),
EXTENDED_REASONING=existing_config.EXTENDED_REASONING
or parse_bool_or_none(get_extended_reasoning_env()),
TOOL_CALLS=(
existing_config.TOOL_CALLS
if existing_config.TOOL_CALLS is not None
else (parse_bool_or_none(get_tool_calls_env()) or False)
),
DISABLE_THINKING=(
existing_config.DISABLE_THINKING
if existing_config.DISABLE_THINKING is not None
else (parse_bool_or_none(get_disable_thinking_env()) or False)
),
EXTENDED_REASONING=(
existing_config.EXTENDED_REASONING
if existing_config.EXTENDED_REASONING is not None
else (parse_bool_or_none(get_extended_reasoning_env()) or False)
),
WEB_GROUNDING=(
existing_config.WEB_GROUNDING
if existing_config.WEB_GROUNDING is not None
else (parse_bool_or_none(get_web_grounding_env()) or False)
),
)
@ -122,5 +138,6 @@ def update_env_with_user_config():
if user_config.DISABLE_THINKING:
set_disable_thinking_env(str(user_config.DISABLE_THINKING))
if user_config.EXTENDED_REASONING:
if user_config.EXTENDED_REASONING:
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
if user_config.WEB_GROUNDING:
set_web_grounding_env(str(user_config.WEB_GROUNDING))

View file

@ -1061,6 +1061,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" },
]
[[package]]
name = "iniconfig"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "isodate"
version = "0.7.2"
@ -1907,6 +1916,7 @@ dependencies = [
{ name = "openai" },
{ name = "pathvalidate" },
{ name = "pdfplumber" },
{ name = "pytest" },
{ name = "python-pptx" },
{ name = "redis" },
{ name = "sqlmodel" },
@ -1928,6 +1938,7 @@ requires-dist = [
{ name = "openai", specifier = ">=1.98.0" },
{ name = "pathvalidate", specifier = ">=3.3.1" },
{ name = "pdfplumber", specifier = ">=0.11.7" },
{ name = "pytest", specifier = ">=8.4.1" },
{ name = "python-pptx", specifier = ">=1.0.2" },
{ name = "redis", specifier = ">=6.2.0" },
{ name = "sqlmodel", specifier = ">=0.0.24" },
@ -2211,6 +2222,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" },
]
[[package]]
name = "pytest"
version = "8.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
]
[[package]]
name = "python-bidi"
version = "0.6.6"

View file

@ -57,6 +57,10 @@ export async function POST(request: Request) {
userConfig.EXTENDED_REASONING === undefined
? existingConfig.EXTENDED_REASONING
: userConfig.EXTENDED_REASONING,
WEB_GROUNDING:
userConfig.WEB_GROUNDING === undefined
? existingConfig.WEB_GROUNDING
: userConfig.WEB_GROUNDING,
USE_CUSTOM_URL:
userConfig.USE_CUSTOM_URL === undefined
? existingConfig.USE_CUSTOM_URL

View file

@ -19,6 +19,7 @@ interface AnthropicConfigProps {
anthropicApiKey: string;
anthropicModel: string;
extendedReasoning: boolean;
webGrounding?: boolean;
onInputChange: (value: string | boolean, field: string) => void;
}
@ -27,6 +28,7 @@ export default function AnthropicConfig({
anthropicApiKey,
anthropicModel,
extendedReasoning,
webGrounding,
onInputChange,
}: AnthropicConfigProps) {
const [openModelSelect, setOpenModelSelect] = useState(false);
@ -65,7 +67,7 @@ export default function AnthropicConfig({
const data = await response.json();
setAvailableModels(data);
setModelsChecked(true);
onInputChange("claude-3-5-sonnet-20241022", "anthropic_model");
onInputChange("claude-sonnet-4-20250514", "anthropic_model");
} else {
console.error('Failed to fetch models');
setAvailableModels([]);
@ -226,6 +228,23 @@ export default function AnthropicConfig({
</div>
</div>
) : null}
{/* Web Grounding Toggle - at the end, below models dropdown */}
<div>
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
<label className="text-sm font-medium text-gray-700">
Enable Web Grounding
</label>
<Switch
checked={!!webGrounding}
onCheckedChange={(checked) => onInputChange(checked, "web_grounding")}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
If enabled, the model can use web search grounding when available.
</p>
</div>
</div>
);
}

View file

@ -13,16 +13,19 @@ import {
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
import { cn } from "@/lib/utils";
import { toast } from "sonner";
import { Switch } from "./ui/switch";
interface GoogleConfigProps {
googleApiKey: string;
googleModel: string;
onInputChange: (value: string, field: string) => void;
webGrounding?: boolean;
onInputChange: (value: string | boolean, field: string) => void;
}
export default function GoogleConfig({
googleApiKey,
googleModel,
webGrounding,
onInputChange
}: GoogleConfigProps) {
const [openModelSelect, setOpenModelSelect] = useState(false);
@ -61,7 +64,7 @@ export default function GoogleConfig({
const data = await response.json();
setAvailableModels(data);
setModelsChecked(true);
onInputChange("models/gemini-2.0-flash", "google_model");
onInputChange("models/gemini-2.5-flash", "google_model");
} else {
console.error('Failed to fetch models');
setAvailableModels([]);
@ -205,6 +208,23 @@ export default function GoogleConfig({
</div>
</div>
) : null}
{/* Web Grounding Toggle - at the end, below models dropdown */}
<div>
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
<label className="text-sm font-medium text-gray-700">
Enable Web Grounding
</label>
<Switch
checked={!!webGrounding}
onCheckedChange={(checked) => onInputChange(checked, "web_grounding")}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
If enabled, the model can use web search grounding when available.
</p>
</div>
</div>
);
}

View file

@ -149,6 +149,7 @@ export default function LLMProviderSelection({
<OpenAIConfig
openaiApiKey={llmConfig.OPENAI_API_KEY || ""}
openaiModel={llmConfig.OPENAI_MODEL || ""}
webGrounding={llmConfig.WEB_GROUNDING || false}
onInputChange={input_field_changed}
/>
</TabsContent>
@ -158,6 +159,7 @@ export default function LLMProviderSelection({
<GoogleConfig
googleApiKey={llmConfig.GOOGLE_API_KEY || ""}
googleModel={llmConfig.GOOGLE_MODEL || ""}
webGrounding={llmConfig.WEB_GROUNDING || false}
onInputChange={input_field_changed}
/>
</TabsContent>
@ -168,6 +170,7 @@ export default function LLMProviderSelection({
anthropicApiKey={llmConfig.ANTHROPIC_API_KEY || ""}
anthropicModel={llmConfig.ANTHROPIC_MODEL || ""}
extendedReasoning={llmConfig.EXTENDED_REASONING || false}
webGrounding={llmConfig.WEB_GROUNDING || false}
onInputChange={input_field_changed}
/>
</TabsContent>

View file

@ -13,16 +13,19 @@ import {
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
import { cn } from "@/lib/utils";
import { toast } from "sonner";
import { Switch } from "./ui/switch";
interface OpenAIConfigProps {
openaiApiKey: string;
openaiModel: string;
onInputChange: (value: string, field: string) => void;
webGrounding?: boolean;
onInputChange: (value: string | boolean, field: string) => void;
}
export default function OpenAIConfig({
openaiApiKey,
openaiModel,
webGrounding,
onInputChange
}: OpenAIConfigProps) {
const [openModelSelect, setOpenModelSelect] = useState(false);
@ -210,6 +213,23 @@ export default function OpenAIConfig({
</div>
</div>
) : null}
{/* Web Grounding Toggle - show at the end, below models dropdown */}
<div>
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
<label className="text-sm font-medium text-gray-700">
Enable Web Grounding
</label>
<Switch
checked={!!webGrounding}
onCheckedChange={(checked) => onInputChange(checked, "web_grounding")}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
If enabled, the model can use web search grounding when available.
</p>
</div>
</div>
);
}

View file

@ -31,6 +31,7 @@ export interface LLMConfig {
TOOL_CALLS?: boolean;
DISABLE_THINKING?: boolean;
EXTENDED_REASONING?: boolean;
WEB_GROUNDING?: boolean;
// Only used in UI settings
USE_CUSTOM_URL?: boolean;

View file

@ -48,6 +48,7 @@ export const updateLLMConfig = (
tool_calls: "TOOL_CALLS",
disable_thinking: "DISABLE_THINKING",
extended_reasoning: "EXTENDED_REASONING",
web_grounding: "WEB_GROUNDING",
};
const configKey = fieldMappings[field];

View file

@ -81,6 +81,7 @@ const setupUserConfigFromEnv = () => {
TOOL_CALLS: process.env.TOOL_CALLS || existingConfig.TOOL_CALLS,
DISABLE_THINKING: process.env.DISABLE_THINKING || existingConfig.DISABLE_THINKING,
EXTENDED_REASONING: process.env.EXTENDED_REASONING || existingConfig.EXTENDED_REASONING,
WEB_GROUNDING: process.env.WEB_GROUNDING || existingConfig.WEB_GROUNDING,
USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL,
};