Merge pull request #518 from presenton/feat/overflow-mitigation-loop

feat: implements overflow mitigation loop on structured generation; fix: changes slide content generation system prompt to avoid clipping
This commit is contained in:
Saurav Niraula 2026-04-15 18:34:27 +05:45 committed by GitHub
commit 36815fe7a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 555 additions and 62 deletions

View file

@ -1,4 +1,4 @@
from typing import Any, Callable, Coroutine, Optional
from typing import Any, Callable, Coroutine
from pydantic import BaseModel, Field

View file

@ -17,6 +17,7 @@ dependencies = [
"google-genai>=1.28.0",
# Platform-specific: greenlet for macOS only (critical for SQLAlchemy async)
"greenlet>=3.0.0; sys_platform == 'darwin'",
"jsonschema>=4.25.0",
"nltk>=3.9.1",
"openai>=1.98.0",
"pathvalidate>=3.3.1",

View file

@ -1,6 +1,7 @@
import asyncio
import dirtyjson
import json
import logging
from typing import AsyncGenerator, List, Optional, Dict, Any
from fastapi import HTTPException
from openai import APIStatusError, AsyncOpenAI, OpenAIError
@ -69,11 +70,15 @@ from utils.schema_utils import (
ensure_array_schemas_have_items,
ensure_strict_json_schema,
flatten_json_schema,
get_schema_validation_errors,
remove_titles_from_schema,
)
LOGGER = logging.getLogger(__name__)
class LLMClient:
def __init__(self):
self.llm_provider = get_llm_provider()
@ -1067,6 +1072,101 @@ class LLMClient:
depth=depth,
)
async def _generate_structured_once(
    self,
    model: str,
    messages: List[LLMMessage],
    response_format: dict,
    strict: bool = False,
    tools: Optional[List[dict]] = None,
    max_tokens: Optional[int] = None,
) -> dict | None:
    """Run one structured-generation call against the configured provider.

    Dispatches on ``self.llm_provider`` to the matching
    ``_generate_*_structured`` coroutine and returns its result unchanged.

    Only the arguments each provider helper accepts are forwarded:
    OpenAI and Codex receive both ``strict`` and ``tools``; Google and
    Anthropic receive ``tools`` only; Ollama and Custom receive ``strict``
    only.

    Returns:
        Whatever the provider helper returns, or ``None`` when
        ``self.llm_provider`` matches none of the handled providers.
    """
    provider = self.llm_provider
    # Arguments every provider helper takes.
    shared_kwargs = {
        "model": model,
        "messages": messages,
        "response_format": response_format,
        "max_tokens": max_tokens,
    }

    if provider == LLMProvider.OPENAI:
        return await self._generate_openai_structured(
            strict=strict, tools=tools, **shared_kwargs
        )
    if provider == LLMProvider.CODEX:
        return await self._generate_codex_structured(
            strict=strict, tools=tools, **shared_kwargs
        )
    if provider == LLMProvider.GOOGLE:
        return await self._generate_google_structured(tools=tools, **shared_kwargs)
    if provider == LLMProvider.ANTHROPIC:
        return await self._generate_anthropic_structured(
            tools=tools, **shared_kwargs
        )
    if provider == LLMProvider.OLLAMA:
        return await self._generate_ollama_structured(strict=strict, **shared_kwargs)
    if provider == LLMProvider.CUSTOM:
        return await self._generate_custom_structured(strict=strict, **shared_kwargs)

    # No handled provider matched: the caller sees "no content".
    return None
def _get_structured_validation_feedback_message(
    self,
    content: dict,
    validation_errors: List[str],
) -> LLMUserMessage:
    """Build the user message that asks the model to repair an invalid reply.

    The message lists at most 10 validation errors (plus a summary line
    counting the remainder) and embeds the previous response as pretty-printed
    JSON, cut off after 6000 characters with a ``... (truncated)`` marker.

    Args:
        content: The previous response that failed schema validation.
        validation_errors: Error strings describing the schema violations.

    Returns:
        An ``LLMUserMessage`` whose content holds the errors, the previous
        JSON and the instruction to return corrected JSON only.
    """
    error_limit = 10
    json_char_limit = 6000

    # Slicing copies, so the caller's list is not modified by the append.
    shown_errors = validation_errors[:error_limit]
    hidden_count = len(validation_errors) - error_limit
    if hidden_count > 0:
        shown_errors.append(f"...and {hidden_count} more validation errors.")

    serialized = json.dumps(content, ensure_ascii=False, indent=2, default=str)
    if len(serialized) > json_char_limit:
        serialized = serialized[:json_char_limit] + "\n... (truncated)"

    error_lines = "\n".join(f"- {error}" for error in shown_errors)
    body = (
        "The previous JSON response did not match the required response schema.\n\n"
        "Validation errors:\n"
        f"{error_lines}"
        "\n\nPrevious invalid JSON:\n"
        f"```json\n{serialized}\n```\n\n"
        "Return corrected JSON only. Make sure it fully matches the required schema."
    )
    return LLMUserMessage(content=body)
async def generate_structured(
self,
model: str,
@ -1075,68 +1175,69 @@ class LLMClient:
strict: bool = False,
tools: Optional[List[type[LLMTool] | LLMDynamicTool]] = None,
max_tokens: Optional[int] = None,
validate_schema: bool = False,
validate_schema_max_loop_count: int = 5,
) -> dict:
parsed_tools = self.tool_calls_handler.parse_tools(tools)
max_validation_loops = max(1, validate_schema_max_loop_count)
working_messages = [*messages]
for attempt in range(3):
for validation_attempt in range(max_validation_loops):
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
content = await self._generate_openai_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.CODEX:
content = await self._generate_codex_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.GOOGLE:
content = await self._generate_google_structured(
model=model,
messages=messages,
response_format=response_format,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.ANTHROPIC:
content = await self._generate_anthropic_structured(
model=model,
messages=messages,
response_format=response_format,
tools=parsed_tools,
max_tokens=max_tokens,
)
case LLMProvider.OLLAMA:
content = await self._generate_ollama_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
max_tokens=max_tokens,
)
case LLMProvider.CUSTOM:
content = await self._generate_custom_structured(
model=model,
messages=messages,
response_format=response_format,
strict=strict,
max_tokens=max_tokens,
)
for attempt in range(3):
content = await self._generate_structured_once(
model=model,
messages=working_messages,
response_format=response_format,
strict=strict,
tools=parsed_tools,
max_tokens=max_tokens,
)
if content is not None:
if content is not None:
break
if attempt < 2:
await asyncio.sleep(0.5 * (attempt + 1))
if content is None:
raise HTTPException(
status_code=400,
detail="LLM did not return any content",
)
if not validate_schema:
return content
if attempt < 2:
await asyncio.sleep(0.5 * (attempt + 1))
validation_errors = get_schema_validation_errors(
response_format,
content,
strict=strict,
)
if not validation_errors:
return content
formatted_validation_errors = " | ".join(validation_errors)
if validation_attempt == max_validation_loops - 1:
LOGGER.warning(
"Validation error after max fixes, returning last response: %s",
formatted_validation_errors,
)
return content
LOGGER.warning(
"Validation error, attempting fix %s/%s: %s",
validation_attempt + 1,
max_validation_loops - 1,
formatted_validation_errors,
)
working_messages.append(
self._get_structured_validation_feedback_message(
content,
validation_errors,
)
)
raise HTTPException(
status_code=400,
@ -1754,8 +1855,6 @@ class LLMClient:
):
yield event
async def _stream_codex_structured(
self,
model: str,

View file

@ -0,0 +1,338 @@
import asyncio
import uuid
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from enums.llm_provider import LLMProvider
from models.llm_message import LLMUserMessage
from models.presentation_outline_model import PresentationOutlineModel, SlideOutlineModel
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
from templates.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from utils.llm_calls.edit_slide import get_edited_slide_content
from utils.llm_calls.generate_presentation_structure import (
generate_presentation_structure,
)
from utils.llm_calls.generate_slide_content import get_slide_content_from_type_and_outline
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
def _build_client() -> LLMClient:
    """Return an ``LLMClient`` created without running ``__init__``.

    The instance is given the OpenAI provider and a stand-in
    ``tool_calls_handler`` whose ``parse_tools`` always returns ``None``.
    """

    def _parse_no_tools(tools):
        return None

    stub_client = object.__new__(LLMClient)
    stub_client.llm_provider = LLMProvider.OPENAI
    stub_client.tool_calls_handler = SimpleNamespace(parse_tools=_parse_no_tools)
    return stub_client
def _build_layout() -> PresentationLayoutModel:
    """Return a presentation layout holding a single title-only slide layout."""
    # Schema of the only slide layout: one required string field, no extras.
    title_only_schema = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
        },
        "required": ["title"],
        "additionalProperties": False,
    }
    title_slide_layout = SlideLayoutModel(
        id="layout-1",
        name="Title Slide",
        description="Single title layout",
        json_schema=title_only_schema,
    )
    return PresentationLayoutModel(name="Test Layout", slides=[title_slide_layout])
def _build_slide() -> SlideModel:
    """Return a slide fixture at index 0 that uses layout ``layout-1``."""
    slide_fields = {
        "presentation": uuid.uuid4(),  # fresh random presentation id per call
        "layout_group": "default",
        "layout": "layout-1",
        "index": 0,
        "content": {"title": "Current title"},
    }
    return SlideModel(**slide_fields)
def test_generate_structured_skips_validation_when_disabled():
    """With ``validate_schema=False`` a schema-violating reply is returned as-is.

    The provider stub returns an integer title for a string-typed field; the
    test expects that payload back after exactly one call that received
    exactly one message.
    """
    llm = _build_client()
    seen_messages = []

    async def record_call(**kwargs):
        seen_messages.append(kwargs["messages"])
        return {"title": 123}

    llm._generate_structured_once = AsyncMock(side_effect=record_call)

    title_schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
        "additionalProperties": False,
    }
    result = asyncio.run(
        llm.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format=title_schema,
            validate_schema=False,
        )
    )

    assert result == {"title": 123}
    assert len(seen_messages) == 1
    assert len(seen_messages[0]) == 1
def test_generate_structured_retries_with_validation_feedback():
    """An invalid first reply triggers a second call carrying feedback.

    Checks that the last message of the second call is an ``LLMUserMessage``
    mentioning the validation errors, the ``$.title`` path and the previous
    JSON, and that exactly one warning was logged with ``$.title`` in its
    fourth positional argument.
    """
    llm = _build_client()
    seen_messages = []
    scripted_replies = [
        {"title": 123},
        {"title": "Valid title"},
    ]

    async def replay(**kwargs):
        seen_messages.append(kwargs["messages"])
        return scripted_replies[len(seen_messages) - 1]

    llm._generate_structured_once = AsyncMock(side_effect=replay)

    title_schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
        "additionalProperties": False,
    }
    with patch("services.llm_client.LOGGER.warning") as warning_spy:
        result = asyncio.run(
            llm.generate_structured(
                model="test-model",
                messages=[LLMUserMessage(content="Generate JSON")],
                response_format=title_schema,
                validate_schema=True,
            )
        )

    assert result == {"title": "Valid title"}
    assert len(seen_messages) == 2

    feedback = seen_messages[1][-1]
    assert isinstance(feedback, LLMUserMessage)
    for expected_fragment in ("Validation errors:", "$.title", '"title": 123'):
        assert expected_fragment in feedback.content

    warning_spy.assert_called_once()
    assert "$.title" in warning_spy.call_args.args[3]
def test_generate_structured_returns_last_invalid_response_at_max_loop_count():
    """With a loop count of 2, the second invalid reply is returned.

    Three replies are scripted, but the test expects only two provider calls
    and the second (still invalid) payload as the result.
    """
    llm = _build_client()
    seen_messages = []
    scripted_replies = [
        {"title": 123},
        {"title": False},
        {"title": "should not be used"},
    ]

    async def replay(**kwargs):
        seen_messages.append(kwargs["messages"])
        return scripted_replies[len(seen_messages) - 1]

    llm._generate_structured_once = AsyncMock(side_effect=replay)

    title_schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
        "additionalProperties": False,
    }
    result = asyncio.run(
        llm.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format=title_schema,
            validate_schema=True,
            validate_schema_max_loop_count=2,
        )
    )

    assert result == {"title": False}
    assert len(seen_messages) == 2
def test_generate_structured_uses_strict_schema_for_validation():
    """With ``strict=True`` a reply missing ``subtitle`` is sent back for repair.

    The schema passed in declares no ``required`` list; the test expects the
    first reply (title only) to be rejected with feedback mentioning a
    "required property" and ``subtitle``, and the second reply to be returned.
    """
    llm = _build_client()
    seen_messages = []
    scripted_replies = [
        {"title": "Only title"},
        {"title": "Valid title", "subtitle": "Valid subtitle"},
    ]

    async def replay(**kwargs):
        seen_messages.append(kwargs["messages"])
        return scripted_replies[len(seen_messages) - 1]

    llm._generate_structured_once = AsyncMock(side_effect=replay)

    two_field_schema = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "subtitle": {"type": "string"},
        },
    }
    result = asyncio.run(
        llm.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format=two_field_schema,
            strict=True,
            validate_schema=True,
        )
    )

    assert result == {"title": "Valid title", "subtitle": "Valid subtitle"}
    assert len(seen_messages) == 2

    feedback = seen_messages[1][-1]
    assert "required property" in feedback.content
    assert "subtitle" in feedback.content
def test_generate_structured_preserves_no_content_retries():
    """Two empty (``None``) provider results are retried; the third is returned.

    ``generate_structured`` awaits ``asyncio.sleep`` between empty attempts.
    That sleep is patched out here so the test checks the retry count without
    spending real wall-clock time on the backoff.
    """
    client = _build_client()
    client._generate_structured_once = AsyncMock(
        side_effect=[None, None, {"title": "Valid title"}]
    )

    # Replace the backoff sleep with an awaitable no-op for the duration of
    # the call only; the patch is undone when the ``with`` block exits.
    with patch("services.llm_client.asyncio.sleep", new=AsyncMock()):
        response = asyncio.run(
            client.generate_structured(
                model="test-model",
                messages=[LLMUserMessage(content="Generate JSON")],
                response_format={
                    "type": "object",
                    "properties": {"title": {"type": "string"}},
                    "required": ["title"],
                    "additionalProperties": False,
                },
            )
        )

    assert response == {"title": "Valid title"}
    assert client._generate_structured_once.await_count == 3
def test_edit_slide_enables_schema_validation():
    """``get_edited_slide_content`` must pass ``validate_schema=True``.

    ``LLMClient`` and ``get_model`` are patched inside the ``edit_slide``
    module; the test inspects the keyword arguments of the awaited
    ``generate_structured`` call.
    """
    canned_reply = {
        "title": "Edited title",
        "__speaker_note__": "x" * 120,
    }
    fake_llm = SimpleNamespace(
        generate_structured=AsyncMock(return_value=canned_reply)
    )
    client_patch = patch("utils.llm_calls.edit_slide.LLMClient", return_value=fake_llm)
    model_patch = patch("utils.llm_calls.edit_slide.get_model", return_value="test-model")

    with client_patch, model_patch:
        edited = asyncio.run(
            get_edited_slide_content(
                prompt="Update the title",
                slide=_build_slide(),
                language="English",
                slide_layout=_build_layout().slides[0],
            )
        )

    assert edited["title"] == "Edited title"
    forwarded_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert forwarded_kwargs["validate_schema"] is True
def test_generate_presentation_structure_enables_schema_validation():
    """``generate_presentation_structure`` must pass ``validate_schema=True``.

    The LLM client, the model lookup and the response-model factory are all
    patched inside the ``generate_presentation_structure`` module; the stubbed
    response model only exposes ``model_json_schema``.
    """
    module_path = "utils.llm_calls.generate_presentation_structure"
    slides_schema = {
        "type": "object",
        "properties": {
            "slides": {
                "type": "array",
                "items": {"type": "integer"},
            }
        },
        "required": ["slides"],
        "additionalProperties": False,
    }
    fake_llm = SimpleNamespace(
        generate_structured=AsyncMock(return_value={"slides": [0]})
    )
    fake_response_model = SimpleNamespace(model_json_schema=lambda: slides_schema)

    client_patch = patch(f"{module_path}.LLMClient", return_value=fake_llm)
    model_patch = patch(f"{module_path}.get_model", return_value="test-model")
    factory_patch = patch(
        f"{module_path}.get_presentation_structure_model_with_n_slides",
        return_value=fake_response_model,
    )

    with client_patch, model_patch, factory_patch:
        structure = asyncio.run(
            generate_presentation_structure(
                presentation_outline=PresentationOutlineModel(
                    slides=[SlideOutlineModel(content="Outline content")]
                ),
                presentation_layout=_build_layout(),
            )
        )

    assert structure.slides == [0]
    forwarded_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert forwarded_kwargs["validate_schema"] is True
def test_generate_slide_content_enables_schema_validation():
    """``get_slide_content_from_type_and_outline`` must pass ``validate_schema=True``.

    ``LLMClient`` and ``get_model`` are patched inside the
    ``generate_slide_content`` module; the test inspects the keyword arguments
    of the awaited ``generate_structured`` call.
    """
    module_path = "utils.llm_calls.generate_slide_content"
    canned_reply = {
        "title": "Slide title",
        "__speaker_note__": "x" * 120,
    }
    fake_llm = SimpleNamespace(
        generate_structured=AsyncMock(return_value=canned_reply)
    )
    client_patch = patch(f"{module_path}.LLMClient", return_value=fake_llm)
    model_patch = patch(f"{module_path}.get_model", return_value="test-model")

    with client_patch, model_patch:
        generated = asyncio.run(
            get_slide_content_from_type_and_outline(
                slide_layout=_build_layout().slides[0],
                outline=SlideOutlineModel(content="Slide outline"),
                language="English",
            )
        )

    assert generated["title"] == "Slide title"
    forwarded_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert forwarded_kwargs["validate_schema"] is True
def test_select_slide_type_on_edit_enables_schema_validation():
    """``get_slide_layout_from_prompt`` must pass ``validate_schema=True``.

    The stubbed client answers with index 0, so the test also expects the
    returned layout to be the fixture's only slide layout (``layout-1``).
    """
    module_path = "utils.llm_calls.select_slide_type_on_edit"
    fake_llm = SimpleNamespace(
        generate_structured=AsyncMock(return_value={"index": 0})
    )
    presentation_layout = _build_layout()
    client_patch = patch(f"{module_path}.LLMClient", return_value=fake_llm)
    model_patch = patch(f"{module_path}.get_model", return_value="test-model")

    with client_patch, model_patch:
        chosen_layout = asyncio.run(
            get_slide_layout_from_prompt(
                prompt="Use the first layout",
                layout=presentation_layout,
                slide=_build_slide(),
            )
        )

    assert chosen_layout.id == "layout-1"
    forwarded_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert forwarded_kwargs["validate_schema"] is True

View file

@ -108,7 +108,7 @@ async def get_edited_slide_content(
"__speaker_note__": {
"type": "string",
"minLength": 100,
"maxLength": 250,
"maxLength": 500,
"description": "Speaker note for the slide",
}
},
@ -124,6 +124,7 @@ async def get_edited_slide_content(
),
response_format=response_schema,
strict=False,
validate_schema=True,
)
return response

View file

@ -167,6 +167,7 @@ async def generate_presentation_structure(
),
response_format=response_model.model_json_schema(),
strict=True,
validate_schema=True,
)
return PresentationStructureModel(**response)
except Exception as e:

View file

@ -24,7 +24,7 @@ You need to generate structured content json based on the schema.
# General Rules
- Make sure to follow language guidelines.
- Speaker note should be normal text, not markdown.
- Never ever go over the max character limit.
- Never ever go over the max character limit but don't clip the sentence to satisfy character limit instead rephrase it.
- Do not add emoji in the content.
- Don't provide $schema field in content json.
{markdown_emphasis_rules}
@ -167,7 +167,7 @@ async def get_slide_content_from_type_and_outline(
"__speaker_note__": {
"type": "string",
"minLength": 100,
"maxLength": 250,
"maxLength": 500,
"description": "Speaker note for the slide",
}
},
@ -187,6 +187,7 @@ async def get_slide_content_from_type_and_outline(
),
response_format=response_schema,
strict=False,
validate_schema=True,
)
return response

View file

@ -58,6 +58,7 @@ async def get_slide_layout_from_prompt(
),
response_format=SlideLayoutIndex.model_json_schema(),
strict=True,
validate_schema=True,
)
index = SlideLayoutIndex(**response).index
return layout.slides[index]

View file

@ -1,6 +1,7 @@
from copy import deepcopy
from typing import Any, List
from jsonschema.validators import validator_for
from openai import NOT_GIVEN
from utils.dict_utils import (
@ -323,6 +324,53 @@ def ensure_array_schemas_have_items(schema: dict) -> dict[str, Any]:
return _ensure(result)
def prepare_schema_for_validation(
    schema: dict,
    strict: bool = False,
) -> dict[str, Any]:
    """Return a copy of ``schema`` prepared for validating a response.

    The input is deep-copied first, so the caller's schema object is not
    handed to the helpers. When ``strict`` is true the copy is passed through
    ``ensure_strict_json_schema``; in both cases the result then goes through
    ``ensure_array_schemas_have_items``.

    Args:
        schema: The JSON schema to prepare.
        strict: Whether to apply the strict-schema transformation first.

    Returns:
        The prepared schema dictionary.
    """
    working_copy = deepcopy(schema)
    if not strict:
        return ensure_array_schemas_have_items(working_copy)

    strict_schema = ensure_strict_json_schema(
        working_copy,
        path=(),
        root=working_copy,
    )
    return ensure_array_schemas_have_items(strict_schema)
def format_json_path(path: List[Any]) -> str:
    """Render a list of path parts as a ``$``-rooted path string.

    Integer parts become ``[n]`` index accessors and every other part becomes
    ``.part``. An empty (falsy) path yields just ``"$"``.

    Args:
        path: Path parts from the document root to the target value.

    Returns:
        The formatted path, e.g. ``$.slides[0].title``.
    """
    if not path:
        return "$"
    segments = [
        f"[{part}]" if isinstance(part, int) else f".{part}" for part in path
    ]
    return "$" + "".join(segments)
def get_schema_validation_errors(
    schema: dict,
    instance: Any,
    strict: bool = False,
) -> List[str]:
    """Validate ``instance`` against ``schema`` and describe every violation.

    The schema is first run through ``prepare_schema_for_validation``. The
    validator class is chosen with ``validator_for`` and the prepared schema
    is checked with ``check_schema`` before use; nothing here catches what
    that check raises.

    Args:
        schema: The JSON schema the instance should satisfy.
        instance: The value to validate.
        strict: Forwarded to ``prepare_schema_for_validation``.

    Returns:
        One ``"<path>: <message>"`` string per validation error, ordered by
        formatted path and then by message. Empty when there are no errors.
    """
    validation_schema = prepare_schema_for_validation(schema, strict=strict)
    validator_type = validator_for(validation_schema)
    validator_type.check_schema(validation_schema)

    # Pair each error's formatted location with its message once, then order
    # the pairs; tuple ordering sorts by location first and message second.
    located_messages = [
        (format_json_path(list(err.path)), err.message)
        for err in validator_type(validation_schema).iter_errors(instance)
    ]
    located_messages.sort()
    return [f"{location}: {message}" for location, message in located_messages]
def remove_titles_from_schema(schema: dict) -> dict[str, Any]:
def _strip_titles(node: Any) -> Any:

View file

@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = "==3.11.*"
resolution-markers = [
"platform_machine == 'aarch64' and sys_platform == 'linux'",
@ -624,6 +624,7 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fc/2e/d4fcb2978f826358b673f779f78fa8a32ee37df11920dc2bb5589cbeecef/greenlet-3.2.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:784ae58bba89fa1fa5733d170d42486580cab9decda3484779f4759345b29822", size = 270219, upload-time = "2025-06-05T16:10:10.414Z" },
{ url = "https://files.pythonhosted.org/packages/16/24/929f853e0202130e4fe163bc1d05a671ce8dcd604f790e14896adac43a52/greenlet-3.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0921ac4ea42a5315d3446120ad48f90c3a6b9bb93dd9b3cf4e4d84a66e42de83", size = 630383, upload-time = "2025-06-05T16:38:51.785Z" },
{ url = "https://files.pythonhosted.org/packages/d1/b2/0320715eb61ae70c25ceca2f1d5ae620477d246692d9cc284c13242ec31c/greenlet-3.2.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d2971d93bb99e05f8c2c0c2f4aa9484a18d98c4c3bd3c62b65b7e6ae33dfcfaf", size = 642422, upload-time = "2025-06-05T16:41:35.259Z" },
{ url = "https://files.pythonhosted.org/packages/bd/49/445fd1a210f4747fedf77615d941444349c6a3a4a1135bba9701337cd966/greenlet-3.2.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c667c0bf9d406b77a15c924ef3285e1e05250948001220368e039b6aa5b5034b", size = 638375, upload-time = "2025-06-05T16:48:18.235Z" },
{ url = "https://files.pythonhosted.org/packages/7e/c8/ca19760cf6eae75fa8dc32b487e963d863b3ee04a7637da77b616703bc37/greenlet-3.2.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:592c12fb1165be74592f5de0d70f82bc5ba552ac44800d632214b76089945147", size = 637627, upload-time = "2025-06-05T16:13:02.858Z" },
{ url = "https://files.pythonhosted.org/packages/65/89/77acf9e3da38e9bcfca881e43b02ed467c1dedc387021fc4d9bd9928afb8/greenlet-3.2.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29e184536ba333003540790ba29829ac14bb645514fbd7e32af331e8202a62a5", size = 585502, upload-time = "2025-06-05T16:12:49.642Z" },
{ url = "https://files.pythonhosted.org/packages/97/c6/ae244d7c95b23b7130136e07a9cc5aadd60d59b5951180dc7dc7e8edaba7/greenlet-3.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93c0bb79844a367782ec4f429d07589417052e621aa39a5ac1fb99c5aa308edc", size = 1114498, upload-time = "2025-06-05T16:36:46.598Z" },
@ -1302,6 +1303,7 @@ dependencies = [
{ name = "fastmcp" },
{ name = "google-genai" },
{ name = "greenlet", marker = "sys_platform == 'darwin'" },
{ name = "jsonschema" },
{ name = "nltk" },
{ name = "openai" },
{ name = "pathvalidate" },
@ -1329,6 +1331,7 @@ requires-dist = [
{ name = "fastmcp", specifier = ">=2.11.0" },
{ name = "google-genai", specifier = ">=1.28.0" },
{ name = "greenlet", marker = "sys_platform == 'darwin'", specifier = ">=3.0.0" },
{ name = "jsonschema", specifier = ">=4.25.0" },
{ name = "nltk", specifier = ">=3.9.1" },
{ name = "openai", specifier = ">=1.98.0" },
{ name = "pathvalidate", specifier = ">=3.3.1" },