diff --git a/README.md b/README.md index 72a53771..a7851cc4 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ You may want to directly provide your API KEYS as environment variables and keep - **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom** - **TOOL_CALLS=[Enable/Disable Tool Calls on Custom LLM]**: If **true**, **LLM** will use Tool Call instead of Json Schema for Structured Output. - **DISABLE_THINKING=[Enable/Disable Thinking on Custom LLM]**: If **true**, Thinking will be disabled. +- **WEB_GROUNDING=[Enable/Disable Web Search for OpenAI, Google And Anthropic]**: If **true**, LLM will be able to search web for better results. You can also set the following environment variables to customize the image generation provider and API keys: diff --git a/docker-compose.yml b/docker-compose.yml index 42502b5f..832b6ebe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} production-gpu: @@ -60,6 +63,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development: @@ -87,6 +93,9 @@ services: - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} - DATABASE_URL=${DATABASE_URL} development-gpu: @@ -122,3 +131,7 @@ services: - PEXELS_API_KEY=${PEXELS_API_KEY} - DATABASE_URL=${DATABASE_URL} - EXTENDED_REASONING=${EXTENDED_REASONING} + - TOOL_CALLS=${TOOL_CALLS} + - DISABLE_THINKING=${DISABLE_THINKING} + - WEB_GROUNDING=${WEB_GROUNDING} + - DATABASE_URL=${DATABASE_URL} diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index adde8669..0dafa3e1 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -64,7 +64,7 @@ async def pull_model( # If the model is being pulled, return the model if saved_model_status: # If the model is being pulled, return the model - # ? If the model status is pulled in redis but was not found while listing pulled models, + # ? If the model status is pulled in database but was not found while listing pulled models, # ? it means the model was deleted and we need to pull it again if ( saved_model_status["status"] == "error" diff --git a/servers/fastapi/api/v1/ppt/endpoints/outlines.py b/servers/fastapi/api/v1/ppt/endpoints/outlines.py index b0ec47af..bf5b1489 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/outlines.py +++ b/servers/fastapi/api/v1/ppt/endpoints/outlines.py @@ -72,6 +72,9 @@ async def stream_outlines( presentation_outlines_json = json.loads(presentation_outlines_text) except Exception as e: print(e) + with open("./debug/outlines.txt", "w") as f: + f.write(presentation_outlines_text) + print(presentation_outlines_text) raise HTTPException( status_code=400, detail="Failed to generate presentation outlines. Please try again.", @@ -87,7 +90,8 @@ async def stream_outlines( presentation.outlines = presentation_outlines.model_dump() presentation.title = ( - presentation_outlines.slides[0][:50] + presentation_outlines.slides[0] + .content[:50] .replace("#", "") .replace("/", "") .replace("\\", "") diff --git a/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/servers/fastapi/api/v1/ppt/endpoints/presentation.py index 5b4589a2..2650782d 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/presentation.py +++ b/servers/fastapi/api/v1/ppt/endpoints/presentation.py @@ -11,7 +11,10 @@ from sqlmodel import select from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES from models.presentation_and_path import PresentationPathAndEditPath from models.presentation_from_template import GetPresentationUsingTemplateRequest -from models.presentation_outline_model import PresentationOutlineModel +from models.presentation_outline_model import ( + PresentationOutlineModel, + SlideOutlineModel, +) from models.pptx_models import PptxPresentationModel from models.presentation_layout import PresentationLayoutModel from models.presentation_structure_model import PresentationStructureModel @@ -126,7 +129,7 @@ async def create_presentation( @PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel) async def prepare_presentation( presentation_id: Annotated[str, Body()], - outlines: Annotated[List[str], Body()], + outlines: Annotated[List[SlideOutlineModel], Body()], layout: Annotated[PresentationLayoutModel, Body()], title: Annotated[Optional[str], Body()] = None, sql_session: AsyncSession = Depends(get_async_session), @@ -161,7 +164,9 @@ async def prepare_presentation( presentation_structure.slides[index] = random_slide_index sql_session.add(presentation) - presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump() + presentation.outlines = PresentationOutlineModel(slides=outlines).model_dump( + mode="json" + ) presentation.title = title or presentation.title presentation.set_layout(layout) presentation.set_structure(presentation_structure) diff --git a/servers/fastapi/constants/llm.py b/servers/fastapi/constants/llm.py index ac4bd527..7d374f30 100644 --- a/servers/fastapi/constants/llm.py +++ b/servers/fastapi/constants/llm.py @@ -2,5 +2,5 @@ OPENAI_URL = "https://api.openai.com/v1" # Default models DEFAULT_OPENAI_MODEL = "gpt-4.1" -DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash" -DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620" +DEFAULT_GOOGLE_MODEL = "models/gemini-2.5-flash" +DEFAULT_ANTHROPIC_MODEL = "claude-sonnet-4-20250514" diff --git a/servers/fastapi/constants/supported_ollama_models.py b/servers/fastapi/constants/supported_ollama_models.py index 455a1217..a02a7a18 100644 --- a/servers/fastapi/constants/supported_ollama_models.py +++ b/servers/fastapi/constants/supported_ollama_models.py @@ -6,61 +6,51 @@ SUPPORTED_OLLAMA_MODELS = { label="Llama 3:8b", value="llama3:8b", size="4.7GB", - icon="/static/icons/meta.png", ), "llama3:70b": OllamaModelMetadata( label="Llama 3:70b", value="llama3:70b", size="40GB", - icon="/static/icons/meta.png", ), "llama3.1:8b": OllamaModelMetadata( label="Llama 3.1:8b", value="llama3.1:8b", size="4.9GB", - icon="/static/icons/meta.png", ), "llama3.1:70b": OllamaModelMetadata( label="Llama 3.1:70b", value="llama3.1:70b", size="43GB", - icon="/static/icons/meta.png", ), "llama3.1:405b": OllamaModelMetadata( label="Llama 3.1:405b", value="llama3.1:405b", size="243GB", - icon="/static/icons/meta.png", ), "llama3.2:1b": OllamaModelMetadata( label="Llama 3.2:1b", value="llama3.2:1b", size="1.3GB", - icon="/static/icons/meta.png", ), "llama3.2:3b": OllamaModelMetadata( label="Llama 3.2:3b", value="llama3.2:3b", size="2GB", - icon="/static/icons/meta.png", ), "llama3.3:70b": OllamaModelMetadata( label="Llama 3.3:70b", value="llama3.3:70b", size="43GB", - icon="/static/icons/meta.png", ), "llama4:16x17b": OllamaModelMetadata( label="Llama 4:16x17b", value="llama4:16x17b", size="67GB", - icon="/static/icons/meta.png", ), "llama4:128x17b": OllamaModelMetadata( label="Llama 4:128x17b", value="llama4:128x17b", size="245GB", - icon="/static/icons/meta.png", ), } @@ -69,25 +59,21 @@ SUPPORTED_GEMMA_MODELS = { label="Gemma 3:1b", value="gemma3:1b", size="815MB", - icon="/static/icons/gemma.png", ), "gemma3:4b": OllamaModelMetadata( label="Gemma 3:4b", value="gemma3:4b", size="3.3GB", - icon="/static/icons/gemma.png", ), "gemma3:12b": OllamaModelMetadata( label="Gemma 3:12b", value="gemma3:12b", size="8.1GB", - icon="/static/icons/gemma.png", ), "gemma3:27b": OllamaModelMetadata( label="Gemma 3:27b", value="gemma3:27b", size="17GB", - icon="/static/icons/gemma.png", ), } @@ -96,43 +82,36 @@ SUPPORTED_DEEPSEEK_MODELS = { label="DeepSeek R1:1.5b", value="deepseek-r1:1.5b", size="1.1GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:7b": OllamaModelMetadata( label="DeepSeek R1:7b", value="deepseek-r1:7b", size="4.7GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:8b": OllamaModelMetadata( label="DeepSeek R1:8b", value="deepseek-r1:8b", size="5.2GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:14b": OllamaModelMetadata( label="DeepSeek R1:14b", value="deepseek-r1:14b", size="9GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:32b": OllamaModelMetadata( label="DeepSeek R1:32b", value="deepseek-r1:32b", size="20GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:70b": OllamaModelMetadata( label="DeepSeek R1:70b", value="deepseek-r1:70b", size="43GB", - icon="/static/icons/deepseek.png", ), "deepseek-r1:671b": OllamaModelMetadata( label="DeepSeek R1:671b", value="deepseek-r1:671b", size="404GB", - icon="/static/icons/deepseek.png", ), } @@ -141,49 +120,54 @@ SUPPORTED_QWEN_MODELS = { label="Qwen 3:0.6b", value="qwen3:0.6b", size="523MB", - icon="/static/icons/qwen.png", ), "qwen3:1.7b": OllamaModelMetadata( label="Qwen 3:1.7b", value="qwen3:1.7b", size="1.4GB", - icon="/static/icons/qwen.png", ), "qwen3:4b": OllamaModelMetadata( label="Qwen 3:4b", value="qwen3:4b", size="2.6GB", - icon="/static/icons/qwen.png", ), "qwen3:8b": OllamaModelMetadata( label="Qwen 3:8b", value="qwen3:8b", size="5.2GB", - icon="/static/icons/qwen.png", ), "qwen3:14b": OllamaModelMetadata( label="Qwen 3:14b", value="qwen3:14b", size="9.3GB", - icon="/static/icons/qwen.png", ), "qwen3:30b": OllamaModelMetadata( label="Qwen 3:30b", value="qwen3:30b", size="19GB", - icon="/static/icons/qwen.png", ), "qwen3:32b": OllamaModelMetadata( label="Qwen 3:32b", value="qwen3:32b", size="20GB", - icon="/static/icons/qwen.png", ), "qwen3:235b": OllamaModelMetadata( label="Qwen 3:235b", value="qwen3:235b", size="142GB", - icon="/static/icons/qwen.png", + ), +} + +SUPPORTED_GPT_OSS_MODELS = { + "gpt-oss:20b": OllamaModelMetadata( + label="GPT-OSS 20b", + value="gpt-oss:20b", + size="14GB", + ), + "gpt-oss:120b": OllamaModelMetadata( + label="GPT-OSS 120b", + value="gpt-oss:120b", + size="65GB", ), } @@ -192,4 +176,5 @@ SUPPORTED_OLLAMA_MODELS = { **SUPPORTED_GEMMA_MODELS, **SUPPORTED_DEEPSEEK_MODELS, **SUPPORTED_QWEN_MODELS, + **SUPPORTED_GPT_OSS_MODELS, } diff --git a/servers/fastapi/enums/llm_call_type.py b/servers/fastapi/enums/llm_call_type.py new file mode 100644 index 00000000..e37fe4ae --- /dev/null +++ b/servers/fastapi/enums/llm_call_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LLMCallType(Enum): + UNSTRUCTURED = "unstructured" + UNSTRUCTURED_STREAM = "unstructured_stream" + STRUCTURED = "structured" + STRUCTURED_STREAM = "structured_stream" diff --git a/servers/fastapi/models/document_chunk.py b/servers/fastapi/models/document_chunk.py index 6861e4fb..a7500be9 100644 --- a/servers/fastapi/models/document_chunk.py +++ b/servers/fastapi/models/document_chunk.py @@ -1,5 +1,7 @@ from pydantic import BaseModel +from models.presentation_outline_model import SlideOutlineModel + class DocumentChunk(BaseModel): heading: str @@ -7,5 +9,5 @@ class DocumentChunk(BaseModel): heading_index: int score: float - def to_slide_outline(self) -> str: - return f"{self.heading}\n{self.content}" + def to_slide_outline(self) -> SlideOutlineModel: + return SlideOutlineModel(content=f"{self.heading}\n{self.content}") diff --git a/servers/fastapi/models/llm_message.py b/servers/fastapi/models/llm_message.py index 51284173..db741ca4 100644 --- a/servers/fastapi/models/llm_message.py +++ b/servers/fastapi/models/llm_message.py @@ -1,7 +1,58 @@ -from typing import Literal +from typing import Any, List, Literal, Optional from pydantic import BaseModel +from google.genai.types import Content as GoogleContent + +from models.llm_tool_call import AnthropicToolCall class LLMMessage(BaseModel): - role: Literal["user", "system"] + pass + + +class LLMUserMessage(LLMMessage): + role: Literal["user"] = "user" content: str + + +class LLMSystemMessage(LLMMessage): + role: Literal["system"] = "system" + content: str + + +class OpenAIAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: str | None = None + tool_calls: Optional[List[dict]] = None + + +class GoogleAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: GoogleContent + + +class AnthropicAssistantMessage(LLMMessage): + role: Literal["assistant"] = "assistant" + content: List[AnthropicToolCall] + + +class AnthropicToolCallMessage(LLMMessage): + type: Literal["tool_result"] = "tool_result" + tool_use_id: str + content: str + + +class AnthropicUserMessage(LLMMessage): + role: Literal["user"] = "user" + content: List[AnthropicToolCallMessage] + + +class OpenAIToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + content: str + tool_call_id: str + + +class GoogleToolCallMessage(LLMMessage): + role: Literal["tool"] = "tool" + name: str + response: dict diff --git a/servers/fastapi/models/llm_tool_call.py b/servers/fastapi/models/llm_tool_call.py new file mode 100644 index 00000000..5eb1f008 --- /dev/null +++ b/servers/fastapi/models/llm_tool_call.py @@ -0,0 +1,29 @@ +from typing import Literal, Optional +from pydantic import BaseModel + + +class LLMToolCall(BaseModel): + pass + + +class OpenAIToolCallFunction(BaseModel): + name: str + arguments: str + + +class OpenAIToolCall(LLMToolCall): + id: str + type: Literal["function"] = "function" + function: OpenAIToolCallFunction + + +class GoogleToolCall(LLMToolCall): + name: str + arguments: Optional[dict] = None + + +class AnthropicToolCall(LLMToolCall): + type: Literal["tool_use"] = "tool_use" + id: str + name: str + input: object diff --git a/servers/fastapi/models/llm_tools.py b/servers/fastapi/models/llm_tools.py new file mode 100644 index 00000000..ccf64e67 --- /dev/null +++ b/servers/fastapi/models/llm_tools.py @@ -0,0 +1,29 @@ +from typing import Any, Callable, Coroutine, Optional +from pydantic import BaseModel, Field + + +class LLMTool(BaseModel): + pass + + +class LLMDynamicTool(LLMTool): + name: str + description: str + parameters: dict = {} + handler: Callable[..., Coroutine[Any, Any, str]] + + +class SearchWebTool(LLMTool): + """ + Search the web for information. + """ + + query: str = Field(description="The query to search the web for") + + +class GetCurrentDatetimeTool(LLMTool): + """ + Get the current datetime. + """ + + pass diff --git a/servers/fastapi/models/ollama_model_metadata.py b/servers/fastapi/models/ollama_model_metadata.py index 1f8ed985..88c89cc6 100644 --- a/servers/fastapi/models/ollama_model_metadata.py +++ b/servers/fastapi/models/ollama_model_metadata.py @@ -4,5 +4,4 @@ from pydantic import BaseModel class OllamaModelMetadata(BaseModel): label: str value: str - icon: str size: str diff --git a/servers/fastapi/models/presentation_outline_model.py b/servers/fastapi/models/presentation_outline_model.py index ad55ae4b..01a3b2b7 100644 --- a/servers/fastapi/models/presentation_outline_model.py +++ b/servers/fastapi/models/presentation_outline_model.py @@ -2,8 +2,12 @@ from typing import List from pydantic import BaseModel +class SlideOutlineModel(BaseModel): + content: str + + class PresentationOutlineModel(BaseModel): - slides: List[str] + slides: List[SlideOutlineModel] def to_string(self): message = "" diff --git a/servers/fastapi/models/user_config.py b/servers/fastapi/models/user_config.py index 50783544..c040d22c 100644 --- a/servers/fastapi/models/user_config.py +++ b/servers/fastapi/models/user_config.py @@ -35,3 +35,6 @@ class UserConfig(BaseModel): TOOL_CALLS: Optional[bool] = None DISABLE_THINKING: Optional[bool] = None EXTENDED_REASONING: Optional[bool] = None + + # Web Search + WEB_GROUNDING: Optional[bool] = None diff --git a/servers/fastapi/pyproject.toml b/servers/fastapi/pyproject.toml index f0caf8e3..14240244 100644 --- a/servers/fastapi/pyproject.toml +++ b/servers/fastapi/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "openai>=1.98.0", "pathvalidate>=3.3.1", "pdfplumber>=0.11.7", + "pytest>=8.4.1", "python-pptx>=1.0.2", "redis>=6.2.0", "sqlmodel>=0.0.24", diff --git a/servers/fastapi/services/llm_client.py b/servers/fastapi/services/llm_client.py index f016e763..3e8b35f2 100644 --- a/servers/fastapi/services/llm_client.py +++ b/servers/fastapi/services/llm_client.py @@ -1,16 +1,40 @@ import asyncio import json -from typing import List, Optional +from typing import AsyncGenerator, List, Optional from fastapi import HTTPException from openai import AsyncOpenAI +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) from google import genai -from google.genai.types import GenerateContentConfig +from google.genai.types import Content as GoogleContent, Part as GoogleContentPart +from google.genai.types import GenerateContentConfig, GoogleSearch +from google.genai.types import Tool as GoogleTool from anthropic import AsyncAnthropic from anthropic.types import Message as AnthropicMessage from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent from enums.llm_provider import LLMProvider -from models.llm_message import LLMMessage +from models.llm_message import ( + AnthropicAssistantMessage, + AnthropicUserMessage, + GoogleAssistantMessage, + GoogleToolCallMessage, + OpenAIAssistantMessage, + LLMMessage, + LLMSystemMessage, + LLMUserMessage, +) +from models.llm_tool_call import ( + AnthropicToolCall, + GoogleToolCall, + LLMToolCall, + OpenAIToolCall, + OpenAIToolCallFunction, +) +from models.llm_tools import LLMDynamicTool, LLMTool +from services.llm_tool_calls_handler import LLMToolCallsHandler from utils.async_iterator import iterator_to_async +from utils.dummy_functions import do_nothing_async from utils.get_env import ( get_anthropic_api_key_env, get_custom_llm_api_key_env, @@ -20,8 +44,9 @@ from utils.get_env import ( get_ollama_url_env, get_openai_api_key_env, get_tool_calls_env, + get_web_grounding_env, ) -from utils.llm_provider import get_llm_provider +from utils.llm_provider import get_llm_provider, get_model from utils.parsers import parse_bool_or_none from utils.schema_utils import ensure_strict_json_schema @@ -30,17 +55,25 @@ class LLMClient: def __init__(self): self.llm_provider = get_llm_provider() self._client = self._get_client() + self.tool_calls_handler = LLMToolCallsHandler(self) # ? Use tool calls - def use_tool_calls(self) -> bool: + def use_tool_calls_for_structured_output(self) -> bool: if self.llm_provider != LLMProvider.CUSTOM: return False return parse_bool_or_none(get_tool_calls_env()) or False + # ? Web Grounding + def enable_web_grounding(self) -> bool: + if ( + self.llm_provider == LLMProvider.OLLAMA + or self.llm_provider == LLMProvider.CUSTOM + ): + return False + return parse_bool_or_none(get_web_grounding_env()) or False + # ? Disable thinking def disable_thinking(self) -> bool: - if self.llm_provider != LLMProvider.CUSTOM: - return False return parse_bool_or_none(get_disable_thinking_env()) or False # ? Clients @@ -106,15 +139,39 @@ class LLMClient: # ? Prompts def _get_system_prompt(self, messages: List[LLMMessage]) -> str: for message in messages: - if message.role == "system": + if isinstance(message, LLMSystemMessage): return message.content return "" - def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]: - return [message.content for message in messages if message.role == "user"] + def _get_google_messages(self, messages: List[LLMMessage]) -> List[str]: + contents = [] + for message in messages: + if isinstance(message, LLMUserMessage): + contents.append( + GoogleContent( + role="user", parts=[GoogleContentPart(text=message.content)] + ) + ) + elif isinstance(message, GoogleAssistantMessage): + contents.append(message.content) + elif isinstance(message, GoogleToolCallMessage): + contents.append( + GoogleContent( + role="user", + parts=[ + GoogleContentPart.from_function_response( + name=message.name, response=message.response + ) + ], + ) + ) - def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: - return [message for message in messages if message.role == "user"] + return contents + + def _get_anthropic_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]: + return [ + message for message in messages if not isinstance(message, LLMSystemMessage) + ] # ? Generate Unstructured Content async def _generate_openai( @@ -122,89 +179,250 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> str | None: client: AsyncOpenAI = self._client response = await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, + tools=tools, + extra_body=extra_body, ) + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + assistant_message = OpenAIAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[tool_call.model_dump() for tool_call in parsed_tool_calls], + ) + new_messages = [ + *messages, + assistant_message, + *tool_call_messages, + ] + return await self._generate_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ) + return response.choices[0].message.content async def _generate_google( self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> str | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", max_output_tokens=max_tokens, ), ) - return response.text + + content = response.candidates[0].content + response_parts = content.parts + + if not response_parts: + return None + + text_content = None + tool_calls = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + if each_part.text: + text_content = each_part.text + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_anthropic( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> str | None: client: AsyncAnthropic = self._client + response: AnthropicMessage = await client.messages.create( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], + tools=tools, max_tokens=max_tokens or 4000, ) - text = "" + text_content = None + tool_calls: List[AnthropicToolCall] = [] for content in response.content: if content.type == "text" and isinstance(content.text, str): - text += content.text - if text == "": - return None - return text + text_content = content.text + + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ) + + return text_content async def _generate_ollama( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): - return await self._generate_openai(model, messages, max_tokens) + return await self._generate_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) async def _generate_custom( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + depth: int = 0, ): - return await self._generate_openai(model, messages, max_tokens) + extra_body = {"enable_thinking": not self.disable_thinking()} + return await self._generate_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, + ) async def generate( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: - content = await self._generate_openai(model, messages, max_tokens) + content = await self._generate_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - content = await self._generate_google(model, messages, max_tokens) + content = await self._generate_google( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.ANTHROPIC: - content = await self._generate_anthropic(model, messages, max_tokens) + content = await self._generate_anthropic( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.OLLAMA: - content = await self._generate_ollama(model, messages, max_tokens) + content = await self._generate_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - content = await self._generate_custom(model, messages, max_tokens) + content = await self._generate_custom( + model=model, messages=messages, max_tokens=max_tokens + ) if content is None: raise HTTPException( status_code=400, @@ -220,21 +438,43 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> dict | None: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - response_format={ + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + parameters=response_schema, + handler=do_nothing_async, + ), + strict=strict, + ) + ) + + response = await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + response_format=( + { "type": "json_schema", "json_schema": ( { @@ -243,40 +483,66 @@ class LLMClient: "schema": response_schema, } ), - }, - max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) - content = response.choices[0].message.content - else: - response = await client.chat.completions.create( - model=model, - messages=[message.model_dump() for message in messages], - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", - max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) - tool_calls = response.choices[0].message.tool_calls - if tool_calls: - content = tool_calls[0].function.arguments + } + if not use_tool_calls_for_structured_output + else None + ), + max_completion_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + ) + content = response.choices[0].message.content + + tool_calls = response.choices[0].message.tool_calls + has_response_schema = False + + if tool_calls: + for tool_call in tool_calls: + if tool_call.function.name == "ResponseSchema": + content = tool_call.function.arguments + has_response_schema = True + + if not has_response_schema: + parsed_tool_calls = [ + OpenAIToolCall( + id=tool_call.id, + type=tool_call.type, + function=OpenAIToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + for tool_call in tool_calls + ] + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_openai( + parsed_tool_calls + ) + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=response.choices[0].message.content, + tool_calls=[each.model_dump() for each in parsed_tool_calls], + ), + *tool_call_messages, + ] + content = await self._generate_openai_structured( + model=model, + messages=new_messages, + response_format=response_schema, + strict=strict, + max_tokens=max_tokens, + tools=all_tools, + extra_body=extra_body, + depth=depth + 1, + ) if content: - return json.loads(content) + if depth == 0: + return json.loads(content) + return content return None async def _generate_google_structured( @@ -285,31 +551,96 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> dict | None: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters_json_schema": response_format, + } + ] + ) + ) + response = await asyncio.to_thread( client.models.generate_content, model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ) - content = None - if response.text: - content = json.loads(response.text) - return content + content = response.candidates[0].content + response_parts = content.parts + text_content = None + + if not response_parts: + return None + + tool_calls: List[GoogleToolCall] = [] + for each_part in response_parts: + if each_part.function_call: + tool_calls.append( + GoogleToolCall( + name=each_part.function_call.name, + arguments=each_part.function_call.args, + ) + ) + + if each_part.text: + text_content = each_part.text + + for each in tool_calls: + if each.name == "ResponseSchema": + return each.arguments + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_google( + tool_calls + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=content, + ), + *tool_call_messages, + ] + return await self._generate_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + if text_content: + return json.loads(text_content) + return None async def _generate_anthropic_structured( self, model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client response: AnthropicMessage = await client.messages.create( @@ -317,7 +648,7 @@ class LLMClient: system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -325,19 +656,51 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) - content: dict | None = None - for content_block in response.content: - if content_block.type == "tool_use": - content = content_block.input + tool_calls: List[AnthropicToolCall] = [] + for content in response.content: + if content.type == "tool_use": + tool_calls.append( + AnthropicToolCall( + id=content.id, + type=content.type, + name=content.name, + input=content.input, + ) + ) - return content + for each in tool_calls: + if each.name == "ResponseSchema": + return each.input + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic(tool_calls) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + return await self._generate_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ) + + return None async def _generate_ollama_structured( self, @@ -346,9 +709,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) async def _generate_custom_structured( @@ -358,9 +727,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): + extra_body = {"enable_thinking": not self.disable_thinking()} return await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) async def generate_structured( @@ -369,29 +746,53 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ) -> dict: + parsed_tools = self.tool_calls_handler.parse_tools(tools) + content = None match self.llm_provider: case LLMProvider.OPENAI: content = await self._generate_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: content = await self._generate_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: content = await self._generate_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: content = await self._generate_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: content = await self._generate_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) if content is None: raise HTTPException( @@ -406,90 +807,285 @@ class LLMClient: model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - async with client.chat.completions.stream( + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + async for event in await client.chat.completions.create( model=model, messages=[message.model_dump() for message in messages], max_completion_tokens=max_tokens, - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta + tools=tools, + extra_body=extra_body, + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], + ), + *tool_call_messages, + ] + async for event in self._stream_openai( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + extra_body=extra_body, + depth=depth + 1, + ): + yield event async def _stream_google( self, model: str, messages: List[LLMMessage], + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + + tool_calls = None async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( system_instruction=self._get_system_prompt(messages), response_mime_type="text/plain", + tools=google_tools, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + tools: Optional[List[dict]] = None, + depth: int = 0, ): client: AsyncAnthropic = self._client + async with client.messages.stream( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, + tools=tools, ) as stream: + tool_calls: List[AnthropicToolCall] = [] async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "text" and isinstance(event.text, str): + + if event.type == "text": yield event.text + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic( + model=model, + messages=new_messages, + max_tokens=max_tokens, + tools=tools, + depth=depth + 1, + ): + yield event + def _stream_ollama( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, messages=messages, max_tokens=max_tokens, depth=depth + ) def _stream_custom( self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None, + depth: int = 0, ): - return self._stream_openai(model, messages, max_tokens) + extra_body = {"enable_thinking": not self.disable_thinking()} + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, + ) def stream( - self, model: str, messages: List[LLMMessage], max_tokens: Optional[int] = None + self, + model: str, + messages: List[LLMMessage], + max_tokens: Optional[int] = None, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: - return self._stream_openai(model, messages, max_tokens) + return self._stream_openai( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.GOOGLE: - return self._stream_google(model, messages, max_tokens) + return self._stream_google( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.ANTHROPIC: - return self._stream_anthropic(model, messages, max_tokens) + return self._stream_anthropic( + model=model, + messages=messages, + max_tokens=max_tokens, + tools=parsed_tools, + ) case LLMProvider.OLLAMA: - return self._stream_ollama(model, messages, max_tokens) + return self._stream_ollama( + model=model, messages=messages, max_tokens=max_tokens + ) case LLMProvider.CUSTOM: - return self._stream_custom(model, messages, max_tokens) + return self._stream_custom( + model=model, messages=messages, max_tokens=max_tokens + ) # ? Stream Structured Content async def _stream_openai_structured( @@ -499,62 +1095,145 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + extra_body: Optional[dict] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncOpenAI = self._client - use_tool_calls = self.use_tool_calls() + response_schema = response_format - if strict: + all_tools = [*tools] if tools else None + + use_tool_calls_for_structured_output = ( + self.use_tool_calls_for_structured_output() + ) + if strict and depth == 0: response_schema = ensure_strict_json_schema( response_schema, path=(), root=response_schema, ) - if not use_tool_calls: - async with client.chat.completions.stream( - model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - response_format=( - { - "type": "json_schema", - "json_schema": { + + if use_tool_calls_for_structured_output and depth == 0: + if all_tools is None: + all_tools = [] + all_tools.append( + self.tool_calls_handler.parse_tool( + LLMDynamicTool( + name="ResponseSchema", + description="Provide response to the user", + parameters=response_schema, + handler=do_nothing_async, + ), + strict=strict, + ) + ) + + tool_calls: List[LLMToolCall] = [] + current_index = 0 + current_id = None + current_name = None + current_arguments = None + + has_response_schema_tool_call = False + async for event in await client.chat.completions.create( + model=model, + messages=[message.model_dump() for message in messages], + max_completion_tokens=max_tokens, + tools=all_tools, + response_format=( + { + "type": "json_schema", + "json_schema": ( + { "name": "ResponseSchema", "strict": strict, "schema": response_schema, - }, - } + } + ), + } + if not use_tool_calls_for_structured_output + else None + ), + extra_body=extra_body, + stream=True, + ): + event: OpenAIChatCompletionChunk = event + content_chunk = event.choices[0].delta.content + if content_chunk: + yield content_chunk + + tool_call_chunk = event.choices[0].delta.tool_calls + if tool_call_chunk: + tool_index = tool_call_chunk[0].index + tool_id = tool_call_chunk[0].id + tool_name = tool_call_chunk[0].function.name + tool_arguments = tool_call_chunk[0].function.arguments + + if current_index != tool_index: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + current_index = tool_index + current_id = tool_id + current_name = tool_name + current_arguments = tool_arguments + else: + current_name = tool_name or current_name + current_id = tool_id or current_id + if current_arguments is None: + current_arguments = tool_arguments + else: + current_arguments += tool_arguments + + if current_name == "ResponseSchema": + if tool_arguments: + yield tool_arguments + has_response_schema_tool_call = True + + if current_id is not None: + tool_calls.append( + OpenAIToolCall( + id=current_id, + type="function", + function=OpenAIToolCallFunction( + name=current_name, + arguments=current_arguments, + ), + ) + ) + + if tool_calls and not has_response_schema_tool_call: + tool_call_messages = await self.tool_calls_handler.handle_tool_calls_openai( + tool_calls + ) + new_messages = [ + *messages, + OpenAIAssistantMessage( + role="assistant", + content=None, + tool_calls=[each.model_dump() for each in tool_calls], ), - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) as stream: - async for event in stream: - if event.type == "content.delta": - yield event.delta - else: - async with client.chat.completions.stream( + *tool_call_messages, + ] + async for event in self._stream_openai_structured( model=model, - messages=[message.model_dump() for message in messages], - max_completion_tokens=max_tokens, - tools=[ - { - "type": "function", - "function": { - "name": "ResponseSchema", - "description": "A response to the user's message", - "strict": strict, - "parameters": response_format, - }, - } - ], - tool_choice="required", - extra_body={ - "enable_thinking": not self.disable_thinking(), - }, - ) as stream: - async for event in stream: - if event.type == "tool_calls.function.arguments.delta": - yield event.arguments_delta + messages=new_messages, + max_tokens=max_tokens, + strict=strict, + tools=all_tools, + response_format=response_schema, + extra_body=extra_body, + depth=depth + 1, + ): + yield event async def _stream_google_structured( self, @@ -562,35 +1241,99 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, max_tokens: Optional[int] = None, - ): + tools: Optional[List[dict]] = None, + depth: int = 0, + ) -> AsyncGenerator[str, None]: + client: genai.Client = self._client + + google_tools = None + if tools: + google_tools = [GoogleTool(function_declarations=[tool]) for tool in tools] + google_tools.append( + GoogleTool( + function_declarations=[ + { + "name": "ResponseSchema", + "description": "Provide response to the user", + "parameters_json_schema": response_format, + } + ] + ) + ) + + tool_calls: List[GoogleToolCall] = [] + has_response_schema_tool_call = False async for event in iterator_to_async(client.models.generate_content_stream)( model=model, - contents=self._get_user_prompts(messages), + contents=self._get_google_messages(messages), config=GenerateContentConfig( + tools=google_tools, system_instruction=self._get_system_prompt(messages), - response_mime_type="application/json", - response_json_schema=response_format, + response_mime_type="application/json" if not tools else None, + response_json_schema=response_format if not tools else None, max_output_tokens=max_tokens, ), ): if event.text: yield event.text + if event.function_calls: + tool_calls = [ + GoogleToolCall( + name=each.name, + arguments=each.args, + ) + for each in event.function_calls + ] + + for each in tool_calls: + if each.name == "ResponseSchema": + has_response_schema_tool_call = True + if each.arguments: + yield json.dumps(each.arguments) + + if has_response_schema_tool_call: + continue + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_google(tool_calls) + ) + new_messages = [ + *messages, + GoogleAssistantMessage( + role="assistant", + content=event.candidates[0].content, + ), + *tool_call_messages, + ] + async for event in self._stream_google_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event + async def _stream_anthropic_structured( self, model: str, messages: List[LLMMessage], response_format: dict, + tools: Optional[List[dict]] = None, max_tokens: Optional[int] = None, - ): + depth: int = 0, + ) -> AsyncGenerator[str, None]: client: AsyncAnthropic = self._client async with client.messages.stream( model=model, system=self._get_system_prompt(messages), messages=[ message.model_dump() - for message in self._get_user_llm_messages(messages) + for message in self._get_anthropic_messages(messages) ], max_tokens=max_tokens or 4000, tools=[ @@ -598,17 +1341,72 @@ class LLMClient: "name": "ResponseSchema", "description": "A response to the user's message", "input_schema": response_format, - } + }, + *(tools or []), ], - tool_choice={ - "type": "tool", - "name": "ResponseSchema", - }, ) as stream: + tool_calls: List[AnthropicToolCall] = [] + has_response_schema_tool_call = False + is_response_schema_tool_call_started = False async for event in stream: event: AnthropicMessageStreamEvent = event - if event.type == "input_json" and isinstance(event.partial_json, str): - yield event.partial_json + if ( + event.type == "content_block_start" + and event.content_block.type == "tool_use" + ): + if event.content_block.name == "ResponseSchema": + has_response_schema_tool_call = True + is_response_schema_tool_call_started = True + + if ( + event.type == "content_block_delta" + and event.delta.type == "input_json_delta" + and is_response_schema_tool_call_started + ): + yield event.delta.partial_json + + if has_response_schema_tool_call: + continue + + if ( + event.type == "content_block_stop" + and event.content_block.type == "tool_use" + ): + tool_calls.append( + AnthropicToolCall( + id=event.content_block.id, + type=event.content_block.type, + name=event.content_block.name, + input=event.content_block.input, + ) + ) + + if tool_calls: + tool_call_messages = ( + await self.tool_calls_handler.handle_tool_calls_anthropic( + tool_calls + ) + ) + new_messages = [ + *messages, + AnthropicAssistantMessage( + role="assistant", + content=[each.model_dump() for each in tool_calls], + ), + AnthropicUserMessage( + role="user", + content=[each.model_dump() for each in tool_call_messages], + ), + ] + async for event in self._stream_anthropic_structured( + model=model, + messages=new_messages, + max_tokens=max_tokens, + response_format=response_format, + tools=tools, + depth=depth + 1, + ): + yield event def _stream_ollama_structured( self, @@ -617,9 +1415,15 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + depth=depth, ) def _stream_custom_structured( @@ -629,9 +1433,17 @@ class LLMClient: response_format: dict, strict: bool = False, max_tokens: Optional[int] = None, + depth: int = 0, ): + extra_body = {"enable_thinking": not self.disable_thinking()} return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, + extra_body=extra_body, + depth=depth, ) def stream_structured( @@ -640,26 +1452,93 @@ class LLMClient: messages: List[LLMMessage], response_format: dict, strict: bool = False, + tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None, max_tokens: Optional[int] = None, ): + parsed_tools = self.tool_calls_handler.parse_tools(tools) + match self.llm_provider: case LLMProvider.OPENAI: return self._stream_openai_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.GOOGLE: return self._stream_google_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.ANTHROPIC: return self._stream_anthropic_structured( - model, messages, response_format, max_tokens + model=model, + messages=messages, + response_format=response_format, + tools=parsed_tools, + max_tokens=max_tokens, ) case LLMProvider.OLLAMA: return self._stream_ollama_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) case LLMProvider.CUSTOM: return self._stream_custom_structured( - model, messages, response_format, strict, max_tokens + model=model, + messages=messages, + response_format=response_format, + strict=strict, + max_tokens=max_tokens, ) + + # ? Web search + async def _search_openai(self, query: str) -> str: + client: AsyncOpenAI = self._client + response = await client.responses.create( + model=get_model(), + tools=[ + { + "type": "web_search_preview", + } + ], + input=query, + ) + return response.output_text + + async def _search_google(self, query: str) -> str: + client: genai.Client = self._client + grounding_tool = GoogleTool(google_search=GoogleSearch()) + config = GenerateContentConfig(tools=[grounding_tool]) + + response = await asyncio.to_thread( + client.models.generate_content, + model=get_model(), + contents=query, + config=config, + ) + return response.text + + async def _search_anthropic(self, query: str) -> str: + client: AsyncAnthropic = self._client + + response = await client.messages.create( + model=get_model(), + max_tokens=4000, + messages=[{"role": "user", "content": query}], + tools=[ + {"type": "web_search_20250305", "name": "web_search", "max_uses": 1} + ], + ) + result = "\n".join( + [each.text for each in response.content if each.type == "text"] + ) + return result diff --git a/servers/fastapi/services/llm_tool_calls_handler.py b/servers/fastapi/services/llm_tool_calls_handler.py new file mode 100644 index 00000000..ed0d51ee --- /dev/null +++ b/servers/fastapi/services/llm_tool_calls_handler.py @@ -0,0 +1,201 @@ +import asyncio +from datetime import datetime +import json +from typing import Any, Callable, Coroutine, List, Optional +from fastapi import HTTPException +from enums.llm_provider import LLMProvider +from models.llm_message import ( + AnthropicToolCallMessage, + GoogleToolCallMessage, + OpenAIToolCallMessage, +) +from models.llm_tool_call import AnthropicToolCall, GoogleToolCall, OpenAIToolCall +from models.llm_tools import LLMDynamicTool, LLMTool, SearchWebTool +from utils.schema_utils import ensure_strict_json_schema, flatten_json_schema + + +class LLMToolCallsHandler: + def __init__(self, client): + from services.llm_client import LLMClient + + self.client: LLMClient = client + + self.tools_map: dict[str, Callable[..., Coroutine[Any, Any, str]]] = { + "SearchWebTool": self.search_web_tool_call_handler, + "GetCurrentDatetimeTool": self.get_current_datetime_tool_call_handler, + } + self.dynamic_tools: List[LLMDynamicTool] = [] + + def get_tool_handler( + self, tool_name: str + ) -> Callable[..., Coroutine[Any, Any, str]]: + handler = self.tools_map.get(tool_name) + if handler: + return handler + else: + dynamic_tools = list( + filter(lambda tool: tool.name == tool_name, self.dynamic_tools) + ) + if dynamic_tools: + return dynamic_tools[0].handler + raise HTTPException(status_code=500, detail=f"Tool {tool_name} not found") + + def parse_tools(self, tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None): + if tools is None: + return None + parsed_tools = map(self.parse_tool, tools) + return list(parsed_tools) + + def parse_tool(self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False): + if isinstance(tool, LLMDynamicTool): + self.dynamic_tools.append(tool) + + match self.client.llm_provider: + case LLMProvider.OPENAI | LLMProvider.OLLAMA | LLMProvider.CUSTOM: + return self.parse_tool_openai(tool, strict) + case LLMProvider.ANTHROPIC: + return self.parse_tool_anthropic(tool) + case LLMProvider.GOOGLE: + return self.parse_tool_google(tool) + case _: + raise ValueError( + f"LLM provider must be either openai, anthropic, or google" + ) + + def parse_tool_openai( + self, tool: type[LLMTool] | LLMDynamicTool, strict: bool = False + ): + if isinstance(tool, LLMDynamicTool): + name = tool.name + description = tool.description + parameters = tool.parameters + else: + name = tool.__name__ + description = tool.__doc__ or "" + parameters = tool.model_json_schema() + + if strict: + parameters = ensure_strict_json_schema(parameters, path=(), root=parameters) + + return { + "type": "function", + "function": { + "name": name, + "description": description, + "strict": strict, + "parameters": parameters, + }, + } + + def parse_tool_google(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + # parsed["function"]["parameters"] = flatten_json_schema( + # parsed["function"]["parameters"] + # ) + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "parameters": parsed["function"]["parameters"], + } + + def parse_tool_anthropic(self, tool: type[LLMTool] | LLMDynamicTool): + parsed = self.parse_tool_openai(tool) + input_schema = parsed["function"]["parameters"] + return { + "name": parsed["function"]["name"], + "description": parsed["function"]["description"], + "input_schema": {"type": "object"} if input_schema == {} else input_schema, + } + + async def handle_tool_calls_openai( + self, + tool_calls: List[OpenAIToolCall], + ) -> List[OpenAIToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(tool_call.function.arguments)) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + OpenAIToolCallMessage( + content=result, + tool_call_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_google( + self, + tool_calls: List[GoogleToolCall], + ) -> List[GoogleToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.arguments))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + GoogleToolCallMessage( + name=tool_call.name, + response={"result": result}, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + async def handle_tool_calls_anthropic( + self, + tool_calls: List[AnthropicToolCall], + ) -> List[AnthropicToolCallMessage]: + async_tool_calls_tasks = [] + for tool_call in tool_calls: + tool_name = tool_call.name + tool_handler = self.get_tool_handler(tool_name) + async_tool_calls_tasks.append(tool_handler(json.dumps(tool_call.input))) + + tool_call_results: List[str] = await asyncio.gather(*async_tool_calls_tasks) + tool_call_messages = [ + AnthropicToolCallMessage( + content=result, + tool_use_id=tool_call.id, + ) + for tool_call, result in zip(tool_calls, tool_call_results) + ] + return tool_call_messages + + # ? Tool call handlers + # Search web tool call handler + async def search_web_tool_call_handler(self, arguments: str) -> str: + match self.client.llm_provider: + case LLMProvider.OPENAI: + return await self.search_web_tool_call_handler_openai(arguments) + case LLMProvider.ANTHROPIC: + return await self.search_web_tool_call_handler_anthropic(arguments) + case LLMProvider.GOOGLE: + return await self.search_web_tool_call_handler_google(arguments) + case _: + return ( + "Web search tool call handler not implemented for this LLM provider: " + + self.client.llm_provider.value + ) + + async def search_web_tool_call_handler_openai(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_openai(args.query) + + async def search_web_tool_call_handler_google(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_google(args.query) + + async def search_web_tool_call_handler_anthropic(self, arguments: str) -> str: + args = SearchWebTool.model_validate_json(arguments) + return await self.client._search_anthropic(args.query) + + # Get current datetime tool call handler + async def get_current_datetime_tool_call_handler(self, _) -> str: + current_time = datetime.now() + return f"{current_time.strftime('%A, %B %d, %Y')} at {current_time.strftime('%I:%M:%S %p')}" diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py deleted file mode 100644 index f2e3d8c9..00000000 --- a/servers/fastapi/services/redis_service.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Optional -import redis -from redis.exceptions import RedisError - -from utils.get_env import ( - get_redis_db_env, - get_redis_host_env, - get_redis_password_env, - get_redis_port_env, -) - - -class RedisService: - def __init__(self): - self.redis_host = get_redis_host_env() or "localhost" - self.redis_port = int(get_redis_port_env() or "6379") - self.redis_db = int(get_redis_db_env() or "0") - self.redis_password = get_redis_password_env() or None - self.client = self._create_client() - - def _create_client(self) -> redis.Redis: - return redis.Redis( - host=self.redis_host, - port=self.redis_port, - db=self.redis_db, - password=self.redis_password, - decode_responses=True, - ) - - def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: - try: - return self.client.set(key, value, ex=expire) - except RedisError: - return False - - def get(self, key: str) -> Optional[str]: - try: - return self.client.get(key) - except RedisError: - return None - - def delete(self, key: str) -> bool: - try: - return bool(self.client.delete(key)) - except RedisError: - return False - - def exists(self, key: str) -> bool: - try: - return bool(self.client.exists(key)) - except RedisError: - return False - - def set_hash(self, name: str, mapping: dict) -> bool: - try: - return self.client.hmset(name, mapping) - except RedisError: - return False - - def get_hash(self, name: str) -> Optional[dict]: - try: - return self.client.hgetall(name) - except RedisError: - return None - - def delete_hash(self, name: str, *fields: str) -> int: - try: - return self.client.hdel(name, *fields) - except RedisError: - return 0 - - def set_list(self, name: str, values: list) -> bool: - try: - self.client.delete(name) - if values: - self.client.rpush(name, *values) - return True - except RedisError: - return False - - def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: - try: - return self.client.lrange(name, start, end) - except RedisError: - return None - - def add_to_set(self, name: str, *values: str) -> int: - try: - return self.client.sadd(name, *values) - except RedisError: - return 0 - - def get_set(self, name: str) -> Optional[set]: - try: - return self.client.smembers(name) - except RedisError: - return None - - def remove_from_set(self, name: str, *values: str) -> int: - try: - return self.client.srem(name, *values) - except RedisError: - return 0 - - def clear(self) -> bool: - try: - return self.client.flushdb() - except RedisError: - return False - - def close(self): - try: - self.client.close() - except RedisError: - pass diff --git a/servers/fastapi/static/icons/deepseek.png b/servers/fastapi/static/icons/deepseek.png deleted file mode 100644 index 798b8f18..00000000 Binary files a/servers/fastapi/static/icons/deepseek.png and /dev/null differ diff --git a/servers/fastapi/static/icons/gemma.png b/servers/fastapi/static/icons/gemma.png deleted file mode 100644 index 647d87a2..00000000 Binary files a/servers/fastapi/static/icons/gemma.png and /dev/null differ diff --git a/servers/fastapi/static/icons/meta.png b/servers/fastapi/static/icons/meta.png deleted file mode 100644 index 0a3d82c1..00000000 Binary files a/servers/fastapi/static/icons/meta.png and /dev/null differ diff --git a/servers/fastapi/static/icons/qwen.png b/servers/fastapi/static/icons/qwen.png deleted file mode 100644 index 2cee1c36..00000000 Binary files a/servers/fastapi/static/icons/qwen.png and /dev/null differ diff --git a/servers/fastapi/utils/dummy_functions.py b/servers/fastapi/utils/dummy_functions.py new file mode 100644 index 00000000..461e9695 --- /dev/null +++ b/servers/fastapi/utils/dummy_functions.py @@ -0,0 +1,2 @@ +async def do_nothing_async(_): + return None diff --git a/servers/fastapi/utils/get_dynamic_models.py b/servers/fastapi/utils/get_dynamic_models.py index 744a6a5a..fd4b2bda 100644 --- a/servers/fastapi/utils/get_dynamic_models.py +++ b/servers/fastapi/utils/get_dynamic_models.py @@ -1,13 +1,23 @@ from typing import List from pydantic import Field -from models.presentation_outline_model import PresentationOutlineModel +from models.presentation_outline_model import ( + PresentationOutlineModel, + SlideOutlineModel, +) from models.presentation_structure_model import PresentationStructureModel def get_presentation_outline_model_with_n_slides(n_slides: int): + class SlideOutlineModelWithNSlides(SlideOutlineModel): + content: str = Field( + description="Markdown content for each slide", + min_length=100, + max_length=300, + ) + class PresentationOutlineModelWithNSlides(PresentationOutlineModel): - slides: List[str] = Field( - description="Markdown content for each slide in about 100 to 200 words", + slides: List[SlideOutlineModelWithNSlides] = Field( + description="List of slide outlines", min_items=n_slides, max_items=n_slides, ) diff --git a/servers/fastapi/utils/get_env.py b/servers/fastapi/utils/get_env.py index c2c72efd..fa80b2a2 100644 --- a/servers/fastapi/utils/get_env.py +++ b/servers/fastapi/utils/get_env.py @@ -81,22 +81,6 @@ def get_pixabay_api_key_env(): return os.getenv("PIXABAY_API_KEY") -def get_redis_host_env(): - return os.getenv("REDIS_HOST") - - -def get_redis_port_env(): - return os.getenv("REDIS_PORT") - - -def get_redis_db_env(): - return os.getenv("REDIS_DB") - - -def get_redis_password_env(): - return os.getenv("REDIS_PASSWORD") - - def get_tool_calls_env(): return os.getenv("TOOL_CALLS") @@ -107,3 +91,7 @@ def get_disable_thinking_env(): def get_extended_reasoning_env(): return os.getenv("EXTENDED_REASONING") + + +def get_web_grounding_env(): + return os.getenv("WEB_GROUNDING") diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py index a8df598a..30599d08 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide.py +++ b/servers/fastapi/utils/llm_calls/edit_slide.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel from models.sql.slide import SlideModel from services.llm_client import LLMClient @@ -41,12 +41,10 @@ def get_messages( language: str, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, slide_data, language), ), ] diff --git a/servers/fastapi/utils/llm_calls/edit_slide_html.py b/servers/fastapi/utils/llm_calls/edit_slide_html.py index a5e2dfad..cf58d185 100644 --- a/servers/fastapi/utils/llm_calls/edit_slide_html.py +++ b/servers/fastapi/utils/llm_calls/edit_slide_html.py @@ -1,5 +1,5 @@ from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from services.llm_client import LLMClient from utils.llm_provider import get_model @@ -53,8 +53,8 @@ async def get_edited_slide_html(prompt: str, html: str): response = await client.generate( model=model, messages=[ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=get_user_prompt(prompt, html)), + LLMSystemMessage(content=system_prompt), + LLMUserMessage(content=get_user_prompt(prompt, html)), ], ) return extract_html_from_response(response) or html diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index 507bb6eb..892b9cff 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -1,10 +1,13 @@ -import asyncio from typing import Optional -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage +from models.llm_tools import GetCurrentDatetimeTool, SearchWebTool from services.llm_client import LLMClient from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides +from utils.get_env import get_web_grounding_env from utils.llm_provider import get_model +from utils.parsers import parse_bool_or_none +from utils.user_config import get_user_config system_prompt = """ You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content. @@ -29,12 +32,10 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str): def get_messages(prompt: str, n_slides: int, language: str, content: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(prompt, n_slides, language, content), ), ] @@ -51,10 +52,13 @@ async def generate_ppt_outline( client = LLMClient() + tools = [SearchWebTool, GetCurrentDatetimeTool] + async for chunk in client.stream_structured( model, get_messages(prompt, n_slides, language, content), response_model.model_json_schema(), strict=True, + tools=tools if client.enable_web_grounding() else None, ): yield chunk diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py index 47f47dba..1bfc0cd0 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_structure.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_structure.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from services.llm_client import LLMClient @@ -11,8 +11,7 @@ def get_messages( presentation_layout: PresentationLayoutModel, n_slides: int, data: str ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" You're a professional presentation designer with creative freedom to design engaging presentations. @@ -47,8 +46,7 @@ def get_messages( Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals. """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" {data} """, diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py index ecff518a..be19b168 100644 --- a/servers/fastapi/utils/llm_calls/generate_slide_content.py +++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py @@ -1,5 +1,6 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import SlideLayoutModel +from models.presentation_outline_model import SlideOutlineModel from services.llm_client import LLMClient from utils.llm_provider import get_model from utils.schema_utils import remove_fields_from_schema @@ -38,19 +39,17 @@ def get_user_prompt(outline: str, language: str): def get_messages(outline: str, language: str): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=system_prompt, ), - LLMMessage( - role="user", + LLMUserMessage( content=get_user_prompt(outline, language), ), ] async def get_slide_content_from_type_and_outline( - slide_layout: SlideLayoutModel, outline: str, language: str + slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str ): client = LLMClient() model = get_model() @@ -62,7 +61,7 @@ async def get_slide_content_from_type_and_outline( response = await client.generate_structured( model=model, messages=get_messages( - outline, + outline.content, language, ), response_format=response_schema, diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py index f3532b48..7235e558 100644 --- a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py +++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py @@ -1,4 +1,4 @@ -from models.llm_message import LLMMessage +from models.llm_message import LLMSystemMessage, LLMUserMessage from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel from models.slide_layout_index import SlideLayoutIndex from models.sql.slide import SlideModel @@ -13,8 +13,7 @@ def get_messages( current_slide_layout: int, ): return [ - LLMMessage( - role="system", + LLMSystemMessage( content=f""" Select a Slide Layout index based on provided user prompt and current slide data. {layout.to_string()} @@ -26,8 +25,7 @@ def get_messages( **Go through all notes and steps and make sure they are followed, including mentioned constraints** """, ), - LLMMessage( - role="user", + LLMUserMessage( content=f""" - User Prompt: {prompt} - Current Slide Data: {slide_data} diff --git a/servers/fastapi/utils/schema_utils.py b/servers/fastapi/utils/schema_utils.py index ae65f002..6cb01a0e 100644 --- a/servers/fastapi/utils/schema_utils.py +++ b/servers/fastapi/utils/schema_utils.py @@ -177,6 +177,59 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object: return resolved +# Flattens a JSON schema by inlining all $ref references and removing $defs/definitions +def flatten_json_schema(schema: dict) -> dict: + root_schema = deepcopy(schema) + + def _flatten(node: Any) -> Any: + if isinstance(node, dict): + # If node is a pure $ref (or combined with extra fields), inline it + if "$ref" in node: + ref_value = node["$ref"] + assert isinstance(ref_value, str), f"Received non-string $ref - {ref_value}" + resolved = resolve_ref(root=root_schema, ref=ref_value) + assert isinstance(resolved, dict), ( + f"Expected `$ref: {ref_value}` to resolve to a dictionary but got {type(resolved)}" + ) + # Merge: referenced first, then overlay current (excluding $ref) + merged: dict[str, Any] = deepcopy(resolved) + for key, value in node.items(): + if key == "$ref": + continue + merged[key] = value + return _flatten(merged) + + flattened: dict[str, Any] = {} + for key, value in node.items(): + # Drop defs/definitions in output + if key in ("$defs", "definitions"): + continue + if key == "properties" and isinstance(value, dict): + flattened[key] = {prop_key: _flatten(prop_val) for prop_key, prop_val in value.items()} + elif key in ("items", "contains", "additionalProperties", "not"): + if isinstance(value, dict): + flattened[key] = _flatten(value) + elif isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = value + elif key in ("allOf", "anyOf", "oneOf", "prefixItems") and isinstance(value, list): + flattened[key] = [_flatten(v) for v in value] + else: + flattened[key] = _flatten(value) if isinstance(value, (dict, list)) else value + return flattened + if isinstance(node, list): + return [_flatten(v) for v in node] + return node + + result = _flatten(schema) + # Ensure top-level cleanup just in case + if isinstance(result, dict): + result.pop("$defs", None) + result.pop("definitions", None) + return result + + # ? Not used def generate_constraint_sentences(schema: dict) -> str: """ diff --git a/servers/fastapi/utils/set_env.py b/servers/fastapi/utils/set_env.py index 7ac0e335..ea3758f3 100644 --- a/servers/fastapi/utils/set_env.py +++ b/servers/fastapi/utils/set_env.py @@ -79,3 +79,7 @@ def set_disable_thinking_env(value): def set_extended_reasoning_env(value): os.environ["EXTENDED_REASONING"] = value + + +def set_web_grounding_env(value): + os.environ["WEB_GROUNDING"] = value \ No newline at end of file diff --git a/servers/fastapi/utils/user_config.py b/servers/fastapi/utils/user_config.py index 06235d5a..49fd1722 100644 --- a/servers/fastapi/utils/user_config.py +++ b/servers/fastapi/utils/user_config.py @@ -22,6 +22,7 @@ from utils.get_env import ( get_image_provider_env, get_pixabay_api_key_env, get_extended_reasoning_env, + get_web_grounding_env, ) from utils.parsers import parse_bool_or_none from utils.set_env import ( @@ -43,6 +44,7 @@ from utils.set_env import ( set_image_provider_env, set_pixabay_api_key_env, set_tool_calls_env, + set_web_grounding_env, ) @@ -76,12 +78,26 @@ def get_user_config(): IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(), PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(), - TOOL_CALLS=existing_config.TOOL_CALLS - or parse_bool_or_none(get_tool_calls_env()), - DISABLE_THINKING=existing_config.DISABLE_THINKING - or parse_bool_or_none(get_disable_thinking_env()), - EXTENDED_REASONING=existing_config.EXTENDED_REASONING - or parse_bool_or_none(get_extended_reasoning_env()), + TOOL_CALLS=( + existing_config.TOOL_CALLS + if existing_config.TOOL_CALLS is not None + else (parse_bool_or_none(get_tool_calls_env()) or False) + ), + DISABLE_THINKING=( + existing_config.DISABLE_THINKING + if existing_config.DISABLE_THINKING is not None + else (parse_bool_or_none(get_disable_thinking_env()) or False) + ), + EXTENDED_REASONING=( + existing_config.EXTENDED_REASONING + if existing_config.EXTENDED_REASONING is not None + else (parse_bool_or_none(get_extended_reasoning_env()) or False) + ), + WEB_GROUNDING=( + existing_config.WEB_GROUNDING + if existing_config.WEB_GROUNDING is not None + else (parse_bool_or_none(get_web_grounding_env()) or False) + ), ) @@ -122,5 +138,6 @@ def update_env_with_user_config(): if user_config.DISABLE_THINKING: set_disable_thinking_env(str(user_config.DISABLE_THINKING)) if user_config.EXTENDED_REASONING: - if user_config.EXTENDED_REASONING: - set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + set_extended_reasoning_env(str(user_config.EXTENDED_REASONING)) + if user_config.WEB_GROUNDING: + set_web_grounding_env(str(user_config.WEB_GROUNDING)) diff --git a/servers/fastapi/uv.lock b/servers/fastapi/uv.lock index b579f42a..4e19c02b 100644 --- a/servers/fastapi/uv.lock +++ b/servers/fastapi/uv.lock @@ -1061,6 +1061,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "isodate" version = "0.7.2" @@ -1907,6 +1916,7 @@ dependencies = [ { name = "openai" }, { name = "pathvalidate" }, { name = "pdfplumber" }, + { name = "pytest" }, { name = "python-pptx" }, { name = "redis" }, { name = "sqlmodel" }, @@ -1928,6 +1938,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.98.0" }, { name = "pathvalidate", specifier = ">=3.3.1" }, { name = "pdfplumber", specifier = ">=0.11.7" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "python-pptx", specifier = ">=1.0.2" }, { name = "redis", specifier = ">=6.2.0" }, { name = "sqlmodel", specifier = ">=0.0.24" }, @@ -2211,6 +2222,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/dc/491b7661614ab97483abf2056be1deee4dc2490ecbf7bff9ab5cdbac86e1/pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6", size = 83178, upload-time = "2024-09-19T02:40:08.598Z" }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + [[package]] name = "python-bidi" version = "0.6.6" diff --git a/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx b/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx index bc5ee297..e5f37757 100644 --- a/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx +++ b/servers/nextjs/app/(presentation-generator)/outline/components/GenerateButton.tsx @@ -1,10 +1,10 @@ import React from "react"; import { Button } from "@/components/ui/button"; -import { LoadingState, StreamState, LayoutGroup } from "../types/index"; +import { LoadingState, LayoutGroup } from "../types/index"; interface GenerateButtonProps { loadingState: LoadingState; - streamState: StreamState; + streamState: { isStreaming: boolean, isLoading: boolean }; selectedLayoutGroup: LayoutGroup | null; onSubmit: () => void; } diff --git a/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx b/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx index a305da91..5764fe6d 100644 --- a/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx +++ b/servers/nextjs/app/(presentation-generator)/outline/components/OutlineContent.tsx @@ -18,7 +18,7 @@ import { Button } from "@/components/ui/button"; import { FileText } from "lucide-react"; interface OutlineContentProps { - outlines: string[] | null; + outlines: { content: string }[] | null; isLoading: boolean; isStreaming: boolean; onDragEnd: (event: any) => void; @@ -32,7 +32,7 @@ const OutlineContent: React.FC = ({ onDragEnd, onAddSlide }) => { - + console.log('isLoading', isLoading) const sensors = useSensors( useSensor(PointerSensor), useSensor(KeyboardSensor, { @@ -83,7 +83,18 @@ const OutlineContent: React.FC = ({ collisionDetection={closestCenter} onDragEnd={onDragEnd} > - ( + + )) + ) : + ({ id: `slide-${index}` })) || []} strategy={verticalListSortingStrategy} > @@ -95,7 +106,7 @@ const OutlineContent: React.FC = ({ isStreaming={isStreaming} /> ))} - + }