Merge pull request #518 from presenton/feat/overflow-mitigation-loop
feat: implements overflow mitigation loop on structured generation; fix: changes slide content generation system prompt to avoid clipping
This commit is contained in:
commit
36815fe7a3
10 changed files with 555 additions and 62 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable, Coroutine, Optional
|
||||
from typing import Any, Callable, Coroutine
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ dependencies = [
|
|||
"google-genai>=1.28.0",
|
||||
# Platform-specific: greenlet for macOS only (critical for SQLAlchemy async)
|
||||
"greenlet>=3.0.0; sys_platform == 'darwin'",
|
||||
"jsonschema>=4.25.0",
|
||||
"nltk>=3.9.1",
|
||||
"openai>=1.98.0",
|
||||
"pathvalidate>=3.3.1",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import dirtyjson
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from openai import APIStatusError, AsyncOpenAI, OpenAIError
|
||||
|
|
@ -69,11 +70,15 @@ from utils.schema_utils import (
|
|||
ensure_array_schemas_have_items,
|
||||
ensure_strict_json_schema,
|
||||
flatten_json_schema,
|
||||
get_schema_validation_errors,
|
||||
remove_titles_from_schema,
|
||||
)
|
||||
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self):
|
||||
self.llm_provider = get_llm_provider()
|
||||
|
|
@ -1067,6 +1072,101 @@ class LLMClient:
|
|||
depth=depth,
|
||||
)
|
||||
|
||||
async def _generate_structured_once(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
tools: Optional[List[dict]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> dict | None:
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return await self._generate_openai_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
return await self._generate_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return await self._generate_google_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return await self._generate_anthropic_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
return await self._generate_ollama_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
return await self._generate_custom_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def _get_structured_validation_feedback_message(
|
||||
self,
|
||||
content: dict,
|
||||
validation_errors: List[str],
|
||||
) -> LLMUserMessage:
|
||||
max_error_count = 10
|
||||
max_json_chars = 6000
|
||||
|
||||
formatted_errors = validation_errors[:max_error_count]
|
||||
if len(validation_errors) > max_error_count:
|
||||
formatted_errors.append(
|
||||
f"...and {len(validation_errors) - max_error_count} more validation errors."
|
||||
)
|
||||
|
||||
previous_response = json.dumps(
|
||||
content,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
if len(previous_response) > max_json_chars:
|
||||
previous_response = previous_response[:max_json_chars] + "\n... (truncated)"
|
||||
|
||||
return LLMUserMessage(
|
||||
content=(
|
||||
"The previous JSON response did not match the required response schema.\n\n"
|
||||
"Validation errors:\n"
|
||||
+ "\n".join(f"- {error}" for error in formatted_errors)
|
||||
+ "\n\nPrevious invalid JSON:\n"
|
||||
+ f"```json\n{previous_response}\n```\n\n"
|
||||
+ "Return corrected JSON only. Make sure it fully matches the required schema."
|
||||
)
|
||||
)
|
||||
|
||||
async def generate_structured(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1075,68 +1175,69 @@ class LLMClient:
|
|||
strict: bool = False,
|
||||
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
validate_schema: bool = False,
|
||||
validate_schema_max_loop_count: int = 5,
|
||||
) -> dict:
|
||||
parsed_tools = self.tool_calls_handler.parse_tools(tools)
|
||||
max_validation_loops = max(1, validate_schema_max_loop_count)
|
||||
working_messages = [*messages]
|
||||
|
||||
for attempt in range(3):
|
||||
for validation_attempt in range(max_validation_loops):
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
content = await self._generate_openai_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CODEX:
|
||||
content = await self._generate_codex_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
content = await self._generate_anthropic_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
content = await self._generate_custom_structured(
|
||||
model=model,
|
||||
messages=messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
for attempt in range(3):
|
||||
content = await self._generate_structured_once(
|
||||
model=model,
|
||||
messages=working_messages,
|
||||
response_format=response_format,
|
||||
strict=strict,
|
||||
tools=parsed_tools,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
if content is not None:
|
||||
if content is not None:
|
||||
break
|
||||
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="LLM did not return any content",
|
||||
)
|
||||
|
||||
if not validate_schema:
|
||||
return content
|
||||
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
validation_errors = get_schema_validation_errors(
|
||||
response_format,
|
||||
content,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
if not validation_errors:
|
||||
return content
|
||||
|
||||
formatted_validation_errors = " | ".join(validation_errors)
|
||||
if validation_attempt == max_validation_loops - 1:
|
||||
LOGGER.warning(
|
||||
"Validation error after max fixes, returning last response: %s",
|
||||
formatted_validation_errors,
|
||||
)
|
||||
return content
|
||||
|
||||
LOGGER.warning(
|
||||
"Validation error, attempting fix %s/%s: %s",
|
||||
validation_attempt + 1,
|
||||
max_validation_loops - 1,
|
||||
formatted_validation_errors,
|
||||
)
|
||||
working_messages.append(
|
||||
self._get_structured_validation_feedback_message(
|
||||
content,
|
||||
validation_errors,
|
||||
)
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -1754,8 +1855,6 @@ class LLMClient:
|
|||
):
|
||||
yield event
|
||||
|
||||
|
||||
|
||||
async def _stream_codex_structured(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,338 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import LLMUserMessage
|
||||
from models.presentation_outline_model import PresentationOutlineModel, SlideOutlineModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
from templates.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from utils.llm_calls.edit_slide import get_edited_slide_content
|
||||
from utils.llm_calls.generate_presentation_structure import (
|
||||
generate_presentation_structure,
|
||||
)
|
||||
from utils.llm_calls.generate_slide_content import get_slide_content_from_type_and_outline
|
||||
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
|
||||
|
||||
|
||||
def _build_client() -> LLMClient:
|
||||
client = object.__new__(LLMClient)
|
||||
client.llm_provider = LLMProvider.OPENAI
|
||||
client.tool_calls_handler = SimpleNamespace(parse_tools=lambda tools: None)
|
||||
return client
|
||||
|
||||
|
||||
def _build_layout() -> PresentationLayoutModel:
|
||||
return PresentationLayoutModel(
|
||||
name="Test Layout",
|
||||
slides=[
|
||||
SlideLayoutModel(
|
||||
id="layout-1",
|
||||
name="Title Slide",
|
||||
description="Single title layout",
|
||||
json_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
},
|
||||
"required": ["title"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_slide() -> SlideModel:
|
||||
return SlideModel(
|
||||
presentation=uuid.uuid4(),
|
||||
layout_group="default",
|
||||
layout="layout-1",
|
||||
index=0,
|
||||
content={"title": "Current title"},
|
||||
)
|
||||
|
||||
|
||||
def test_generate_structured_skips_validation_when_disabled():
|
||||
client = _build_client()
|
||||
call_messages = []
|
||||
|
||||
async def fake_generate(**kwargs):
|
||||
call_messages.append(kwargs["messages"])
|
||||
return {"title": 123}
|
||||
|
||||
client._generate_structured_once = AsyncMock(side_effect=fake_generate)
|
||||
|
||||
response = asyncio.run(
|
||||
client.generate_structured(
|
||||
model="test-model",
|
||||
messages=[LLMUserMessage(content="Generate JSON")],
|
||||
response_format={
|
||||
"type": "object",
|
||||
"properties": {"title": {"type": "string"}},
|
||||
"required": ["title"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
validate_schema=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert response == {"title": 123}
|
||||
assert len(call_messages) == 1
|
||||
assert len(call_messages[0]) == 1
|
||||
|
||||
|
||||
def test_generate_structured_retries_with_validation_feedback():
|
||||
client = _build_client()
|
||||
call_messages = []
|
||||
responses = [
|
||||
{"title": 123},
|
||||
{"title": "Valid title"},
|
||||
]
|
||||
|
||||
async def fake_generate(**kwargs):
|
||||
call_messages.append(kwargs["messages"])
|
||||
return responses[len(call_messages) - 1]
|
||||
|
||||
client._generate_structured_once = AsyncMock(side_effect=fake_generate)
|
||||
|
||||
with patch("services.llm_client.LOGGER.warning") as mock_warning:
|
||||
response = asyncio.run(
|
||||
client.generate_structured(
|
||||
model="test-model",
|
||||
messages=[LLMUserMessage(content="Generate JSON")],
|
||||
response_format={
|
||||
"type": "object",
|
||||
"properties": {"title": {"type": "string"}},
|
||||
"required": ["title"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
validate_schema=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert response == {"title": "Valid title"}
|
||||
assert len(call_messages) == 2
|
||||
feedback_message = call_messages[1][-1]
|
||||
assert isinstance(feedback_message, LLMUserMessage)
|
||||
assert "Validation errors:" in feedback_message.content
|
||||
assert "$.title" in feedback_message.content
|
||||
assert '"title": 123' in feedback_message.content
|
||||
mock_warning.assert_called_once()
|
||||
assert "$.title" in mock_warning.call_args.args[3]
|
||||
|
||||
|
||||
def test_generate_structured_returns_last_invalid_response_at_max_loop_count():
|
||||
client = _build_client()
|
||||
call_messages = []
|
||||
responses = [
|
||||
{"title": 123},
|
||||
{"title": False},
|
||||
{"title": "should not be used"},
|
||||
]
|
||||
|
||||
async def fake_generate(**kwargs):
|
||||
call_messages.append(kwargs["messages"])
|
||||
return responses[len(call_messages) - 1]
|
||||
|
||||
client._generate_structured_once = AsyncMock(side_effect=fake_generate)
|
||||
|
||||
response = asyncio.run(
|
||||
client.generate_structured(
|
||||
model="test-model",
|
||||
messages=[LLMUserMessage(content="Generate JSON")],
|
||||
response_format={
|
||||
"type": "object",
|
||||
"properties": {"title": {"type": "string"}},
|
||||
"required": ["title"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
validate_schema=True,
|
||||
validate_schema_max_loop_count=2,
|
||||
)
|
||||
)
|
||||
|
||||
assert response == {"title": False}
|
||||
assert len(call_messages) == 2
|
||||
|
||||
|
||||
def test_generate_structured_uses_strict_schema_for_validation():
|
||||
client = _build_client()
|
||||
call_messages = []
|
||||
responses = [
|
||||
{"title": "Only title"},
|
||||
{"title": "Valid title", "subtitle": "Valid subtitle"},
|
||||
]
|
||||
|
||||
async def fake_generate(**kwargs):
|
||||
call_messages.append(kwargs["messages"])
|
||||
return responses[len(call_messages) - 1]
|
||||
|
||||
client._generate_structured_once = AsyncMock(side_effect=fake_generate)
|
||||
|
||||
response = asyncio.run(
|
||||
client.generate_structured(
|
||||
model="test-model",
|
||||
messages=[LLMUserMessage(content="Generate JSON")],
|
||||
response_format={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"subtitle": {"type": "string"},
|
||||
},
|
||||
},
|
||||
strict=True,
|
||||
validate_schema=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert response == {"title": "Valid title", "subtitle": "Valid subtitle"}
|
||||
assert len(call_messages) == 2
|
||||
feedback_message = call_messages[1][-1]
|
||||
assert "required property" in feedback_message.content
|
||||
assert "subtitle" in feedback_message.content
|
||||
|
||||
|
||||
def test_generate_structured_preserves_no_content_retries():
|
||||
client = _build_client()
|
||||
client._generate_structured_once = AsyncMock(
|
||||
side_effect=[None, None, {"title": "Valid title"}]
|
||||
)
|
||||
|
||||
response = asyncio.run(
|
||||
client.generate_structured(
|
||||
model="test-model",
|
||||
messages=[LLMUserMessage(content="Generate JSON")],
|
||||
response_format={
|
||||
"type": "object",
|
||||
"properties": {"title": {"type": "string"}},
|
||||
"required": ["title"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response == {"title": "Valid title"}
|
||||
assert client._generate_structured_once.await_count == 3
|
||||
|
||||
|
||||
def test_edit_slide_enables_schema_validation():
|
||||
mock_client = SimpleNamespace(
|
||||
generate_structured=AsyncMock(
|
||||
return_value={
|
||||
"title": "Edited title",
|
||||
"__speaker_note__": "x" * 120,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
with patch("utils.llm_calls.edit_slide.LLMClient", return_value=mock_client), patch(
|
||||
"utils.llm_calls.edit_slide.get_model",
|
||||
return_value="test-model",
|
||||
):
|
||||
response = asyncio.run(
|
||||
get_edited_slide_content(
|
||||
prompt="Update the title",
|
||||
slide=_build_slide(),
|
||||
language="English",
|
||||
slide_layout=_build_layout().slides[0],
|
||||
)
|
||||
)
|
||||
|
||||
assert response["title"] == "Edited title"
|
||||
assert mock_client.generate_structured.await_args.kwargs["validate_schema"] is True
|
||||
|
||||
|
||||
def test_generate_presentation_structure_enables_schema_validation():
|
||||
mock_client = SimpleNamespace(
|
||||
generate_structured=AsyncMock(return_value={"slides": [0]})
|
||||
)
|
||||
mock_response_model = SimpleNamespace(
|
||||
model_json_schema=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slides": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
}
|
||||
},
|
||||
"required": ["slides"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"utils.llm_calls.generate_presentation_structure.LLMClient",
|
||||
return_value=mock_client,
|
||||
), patch(
|
||||
"utils.llm_calls.generate_presentation_structure.get_model",
|
||||
return_value="test-model",
|
||||
), patch(
|
||||
"utils.llm_calls.generate_presentation_structure.get_presentation_structure_model_with_n_slides",
|
||||
return_value=mock_response_model,
|
||||
):
|
||||
response = asyncio.run(
|
||||
generate_presentation_structure(
|
||||
presentation_outline=PresentationOutlineModel(
|
||||
slides=[SlideOutlineModel(content="Outline content")]
|
||||
),
|
||||
presentation_layout=_build_layout(),
|
||||
)
|
||||
)
|
||||
|
||||
assert response.slides == [0]
|
||||
assert mock_client.generate_structured.await_args.kwargs["validate_schema"] is True
|
||||
|
||||
|
||||
def test_generate_slide_content_enables_schema_validation():
|
||||
mock_client = SimpleNamespace(
|
||||
generate_structured=AsyncMock(
|
||||
return_value={
|
||||
"title": "Slide title",
|
||||
"__speaker_note__": "x" * 120,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"utils.llm_calls.generate_slide_content.LLMClient",
|
||||
return_value=mock_client,
|
||||
), patch(
|
||||
"utils.llm_calls.generate_slide_content.get_model",
|
||||
return_value="test-model",
|
||||
):
|
||||
response = asyncio.run(
|
||||
get_slide_content_from_type_and_outline(
|
||||
slide_layout=_build_layout().slides[0],
|
||||
outline=SlideOutlineModel(content="Slide outline"),
|
||||
language="English",
|
||||
)
|
||||
)
|
||||
|
||||
assert response["title"] == "Slide title"
|
||||
assert mock_client.generate_structured.await_args.kwargs["validate_schema"] is True
|
||||
|
||||
|
||||
def test_select_slide_type_on_edit_enables_schema_validation():
|
||||
mock_client = SimpleNamespace(generate_structured=AsyncMock(return_value={"index": 0}))
|
||||
layout = _build_layout()
|
||||
|
||||
with patch(
|
||||
"utils.llm_calls.select_slide_type_on_edit.LLMClient",
|
||||
return_value=mock_client,
|
||||
), patch(
|
||||
"utils.llm_calls.select_slide_type_on_edit.get_model",
|
||||
return_value="test-model",
|
||||
):
|
||||
response = asyncio.run(
|
||||
get_slide_layout_from_prompt(
|
||||
prompt="Use the first layout",
|
||||
layout=layout,
|
||||
slide=_build_slide(),
|
||||
)
|
||||
)
|
||||
|
||||
assert response.id == "layout-1"
|
||||
assert mock_client.generate_structured.await_args.kwargs["validate_schema"] is True
|
||||
|
|
@ -108,7 +108,7 @@ async def get_edited_slide_content(
|
|||
"__speaker_note__": {
|
||||
"type": "string",
|
||||
"minLength": 100,
|
||||
"maxLength": 250,
|
||||
"maxLength": 500,
|
||||
"description": "Speaker note for the slide",
|
||||
}
|
||||
},
|
||||
|
|
@ -124,6 +124,7 @@ async def get_edited_slide_content(
|
|||
),
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
validate_schema=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -167,6 +167,7 @@ async def generate_presentation_structure(
|
|||
),
|
||||
response_format=response_model.model_json_schema(),
|
||||
strict=True,
|
||||
validate_schema=True,
|
||||
)
|
||||
return PresentationStructureModel(**response)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ You need to generate structured content json based on the schema.
|
|||
# General Rules
|
||||
- Make sure to follow language guidelines.
|
||||
- Speaker note should be normal text, not markdown.
|
||||
- Never ever go over the max character limit.
|
||||
- Never ever go over the max character limit but don't clip the sentence to satisfy character limit instead rephrase it.
|
||||
- Do not add emoji in the content.
|
||||
- Don't provide $schema field in content json.
|
||||
{markdown_emphasis_rules}
|
||||
|
|
@ -167,7 +167,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
"__speaker_note__": {
|
||||
"type": "string",
|
||||
"minLength": 100,
|
||||
"maxLength": 250,
|
||||
"maxLength": 500,
|
||||
"description": "Speaker note for the slide",
|
||||
}
|
||||
},
|
||||
|
|
@ -187,6 +187,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
),
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
validate_schema=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ async def get_slide_layout_from_prompt(
|
|||
),
|
||||
response_format=SlideLayoutIndex.model_json_schema(),
|
||||
strict=True,
|
||||
validate_schema=True,
|
||||
)
|
||||
index = SlideLayoutIndex(**response).index
|
||||
return layout.slides[index]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any, List
|
||||
|
||||
from jsonschema.validators import validator_for
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from utils.dict_utils import (
|
||||
|
|
@ -323,6 +324,53 @@ def ensure_array_schemas_have_items(schema: dict) -> dict[str, Any]:
|
|||
return _ensure(result)
|
||||
|
||||
|
||||
def prepare_schema_for_validation(
|
||||
schema: dict,
|
||||
strict: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
prepared_schema = deepcopy(schema)
|
||||
if strict:
|
||||
prepared_schema = ensure_strict_json_schema(
|
||||
prepared_schema,
|
||||
path=(),
|
||||
root=prepared_schema,
|
||||
)
|
||||
return ensure_array_schemas_have_items(prepared_schema)
|
||||
|
||||
|
||||
def format_json_path(path: List[Any]) -> str:
|
||||
if not path:
|
||||
return "$"
|
||||
|
||||
formatted = "$"
|
||||
for part in path:
|
||||
if isinstance(part, int):
|
||||
formatted += f"[{part}]"
|
||||
else:
|
||||
formatted += f".{part}"
|
||||
return formatted
|
||||
|
||||
|
||||
def get_schema_validation_errors(
|
||||
schema: dict,
|
||||
instance: Any,
|
||||
strict: bool = False,
|
||||
) -> List[str]:
|
||||
prepared_schema = prepare_schema_for_validation(schema, strict=strict)
|
||||
validator_cls = validator_for(prepared_schema)
|
||||
validator_cls.check_schema(prepared_schema)
|
||||
validator = validator_cls(prepared_schema)
|
||||
|
||||
errors = sorted(
|
||||
validator.iter_errors(instance),
|
||||
key=lambda error: (format_json_path(list(error.path)), error.message),
|
||||
)
|
||||
|
||||
return [
|
||||
f"{format_json_path(list(error.path))}: {error.message}" for error in errors
|
||||
]
|
||||
|
||||
|
||||
def remove_titles_from_schema(schema: dict) -> dict[str, Any]:
|
||||
|
||||
def _strip_titles(node: Any) -> Any:
|
||||
|
|
|
|||
5
electron/servers/fastapi/uv.lock
generated
5
electron/servers/fastapi/uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = "==3.11.*"
|
||||
resolution-markers = [
|
||||
"platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
|
|
@ -624,6 +624,7 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/fc/2e/d4fcb2978f826358b673f779f78fa8a32ee37df11920dc2bb5589cbeecef/greenlet-3.2.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:784ae58bba89fa1fa5733d170d42486580cab9decda3484779f4759345b29822", size = 270219, upload-time = "2025-06-05T16:10:10.414Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/24/929f853e0202130e4fe163bc1d05a671ce8dcd604f790e14896adac43a52/greenlet-3.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0921ac4ea42a5315d3446120ad48f90c3a6b9bb93dd9b3cf4e4d84a66e42de83", size = 630383, upload-time = "2025-06-05T16:38:51.785Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/b2/0320715eb61ae70c25ceca2f1d5ae620477d246692d9cc284c13242ec31c/greenlet-3.2.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d2971d93bb99e05f8c2c0c2f4aa9484a18d98c4c3bd3c62b65b7e6ae33dfcfaf", size = 642422, upload-time = "2025-06-05T16:41:35.259Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/49/445fd1a210f4747fedf77615d941444349c6a3a4a1135bba9701337cd966/greenlet-3.2.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c667c0bf9d406b77a15c924ef3285e1e05250948001220368e039b6aa5b5034b", size = 638375, upload-time = "2025-06-05T16:48:18.235Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/c8/ca19760cf6eae75fa8dc32b487e963d863b3ee04a7637da77b616703bc37/greenlet-3.2.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:592c12fb1165be74592f5de0d70f82bc5ba552ac44800d632214b76089945147", size = 637627, upload-time = "2025-06-05T16:13:02.858Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/65/89/77acf9e3da38e9bcfca881e43b02ed467c1dedc387021fc4d9bd9928afb8/greenlet-3.2.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29e184536ba333003540790ba29829ac14bb645514fbd7e32af331e8202a62a5", size = 585502, upload-time = "2025-06-05T16:12:49.642Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/97/c6/ae244d7c95b23b7130136e07a9cc5aadd60d59b5951180dc7dc7e8edaba7/greenlet-3.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93c0bb79844a367782ec4f429d07589417052e621aa39a5ac1fb99c5aa308edc", size = 1114498, upload-time = "2025-06-05T16:36:46.598Z" },
|
||||
|
|
@ -1302,6 +1303,7 @@ dependencies = [
|
|||
{ name = "fastmcp" },
|
||||
{ name = "google-genai" },
|
||||
{ name = "greenlet", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "jsonschema" },
|
||||
{ name = "nltk" },
|
||||
{ name = "openai" },
|
||||
{ name = "pathvalidate" },
|
||||
|
|
@ -1329,6 +1331,7 @@ requires-dist = [
|
|||
{ name = "fastmcp", specifier = ">=2.11.0" },
|
||||
{ name = "google-genai", specifier = ">=1.28.0" },
|
||||
{ name = "greenlet", marker = "sys_platform == 'darwin'", specifier = ">=3.0.0" },
|
||||
{ name = "jsonschema", specifier = ">=4.25.0" },
|
||||
{ name = "nltk", specifier = ">=3.9.1" },
|
||||
{ name = "openai", specifier = ">=1.98.0" },
|
||||
{ name = "pathvalidate", specifier = ">=3.3.1" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue