presenton/electron/servers/fastapi/tests/test_llm_client_structured_validation.py

338 lines
11 KiB
Python

import asyncio
import uuid
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from enums.llm_provider import LLMProvider
from models.llm_message import LLMUserMessage
from models.presentation_outline_model import PresentationOutlineModel, SlideOutlineModel
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
from templates.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from utils.llm_calls.edit_slide import get_edited_slide_content
from utils.llm_calls.generate_presentation_structure import (
generate_presentation_structure,
)
from utils.llm_calls.generate_slide_content import get_slide_content_from_type_and_outline
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
def _build_client() -> LLMClient:
    """Create an ``LLMClient`` instance without running its ``__init__``.

    Only two attributes are populated: ``llm_provider`` (set to OpenAI) and a
    stub ``tool_calls_handler`` whose ``parse_tools`` always returns ``None``.
    """
    bare = object.__new__(LLMClient)
    bare.llm_provider = LLMProvider.OPENAI
    stub_handler = SimpleNamespace(parse_tools=lambda tools: None)
    bare.tool_calls_handler = stub_handler
    return bare
def _build_layout() -> PresentationLayoutModel:
    """Create a one-slide layout (id ``layout-1``) whose schema needs a string ``title``."""
    title_only_schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
        "additionalProperties": False,
    }
    title_slide = SlideLayoutModel(
        id="layout-1",
        name="Title Slide",
        description="Single title layout",
        json_schema=title_only_schema,
    )
    return PresentationLayoutModel(name="Test Layout", slides=[title_slide])
def _build_slide() -> SlideModel:
    """Create a slide at index 0 on layout ``layout-1`` for a random presentation id."""
    presentation_id = uuid.uuid4()
    initial_content = {"title": "Current title"}
    return SlideModel(
        presentation=presentation_id,
        layout_group="default",
        layout="layout-1",
        index=0,
        content=initial_content,
    )
def test_generate_structured_skips_validation_when_disabled():
    """With ``validate_schema=False`` a schema-invalid response is returned as-is.

    The fake generator answers ``{"title": 123}`` although the schema wants a
    string ``title``; the client must make exactly one call, sending only the
    original user message (no validation feedback).
    """
    client = _build_client()
    call_messages = []

    async def fake_generate(**kwargs):
        # Snapshot the list as it was at call time: if the client later
        # mutates the same list object, the assertions below would otherwise
        # inspect post-call state instead of what was actually sent.
        call_messages.append(list(kwargs["messages"]))
        return {"title": 123}

    client._generate_structured_once = AsyncMock(side_effect=fake_generate)
    response = asyncio.run(
        client.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format={
                "type": "object",
                "properties": {"title": {"type": "string"}},
                "required": ["title"],
                "additionalProperties": False,
            },
            validate_schema=False,
        )
    )
    assert response == {"title": 123}
    assert len(call_messages) == 1
    assert len(call_messages[0]) == 1
def test_generate_structured_retries_with_validation_feedback():
    """An invalid first response triggers a retry that carries validation feedback.

    The first fake response has a non-string ``title``. On the second call
    the last message must be an ``LLMUserMessage`` listing the validation
    errors (including the ``$.title`` path and the offending payload), one
    warning must be logged mentioning ``$.title``, and the valid second
    response is returned.
    """
    client = _build_client()
    call_messages = []
    responses = [
        {"title": 123},
        {"title": "Valid title"},
    ]

    async def fake_generate(**kwargs):
        # Snapshot per call so a client that reuses and mutates one list
        # between attempts cannot make earlier calls look like later ones.
        call_messages.append(list(kwargs["messages"]))
        return responses[len(call_messages) - 1]

    client._generate_structured_once = AsyncMock(side_effect=fake_generate)
    with patch("services.llm_client.LOGGER.warning") as mock_warning:
        response = asyncio.run(
            client.generate_structured(
                model="test-model",
                messages=[LLMUserMessage(content="Generate JSON")],
                response_format={
                    "type": "object",
                    "properties": {"title": {"type": "string"}},
                    "required": ["title"],
                    "additionalProperties": False,
                },
                validate_schema=True,
            )
        )
    assert response == {"title": "Valid title"}
    assert len(call_messages) == 2
    feedback_message = call_messages[1][-1]
    assert isinstance(feedback_message, LLMUserMessage)
    assert "Validation errors:" in feedback_message.content
    assert "$.title" in feedback_message.content
    assert '"title": 123' in feedback_message.content
    mock_warning.assert_called_once()
    # Look through every positional logging argument instead of a fixed
    # index, so reordering the warning's format arguments does not break
    # the test while the error path is still required to be logged.
    warning_args = mock_warning.call_args.args
    assert any("$.title" in str(arg) for arg in warning_args)
def test_generate_structured_returns_last_invalid_response_at_max_loop_count():
    """When the validation loop cap is reached, the latest (invalid) response wins.

    The cap is set to 2, so only two attempts may run and the third queued
    response must never be requested.
    """
    title_schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
        "additionalProperties": False,
    }
    queued_responses = (
        {"title": 123},
        {"title": False},
        {"title": "should not be used"},
    )
    seen_messages = []

    async def fake_generate(**kwargs):
        seen_messages.append(kwargs["messages"])
        attempt_index = len(seen_messages) - 1
        return queued_responses[attempt_index]

    client = _build_client()
    client._generate_structured_once = AsyncMock(side_effect=fake_generate)

    result = asyncio.run(
        client.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format=title_schema,
            validate_schema=True,
            validate_schema_max_loop_count=2,
        )
    )

    assert result == {"title": False}
    assert len(seen_messages) == 2
def test_generate_structured_uses_strict_schema_for_validation():
    """With ``strict=True`` validation runs against the strict form of the schema.

    The supplied schema declares no ``required`` list, yet the first
    response (missing ``subtitle``) must be rejected with a "required
    property" error mentioning ``subtitle``. This implies that strict mode
    treats every declared property as required. The complete second
    response is then returned.
    """
    client = _build_client()
    call_messages = []
    responses = [
        {"title": "Only title"},
        {"title": "Valid title", "subtitle": "Valid subtitle"},
    ]

    async def fake_generate(**kwargs):
        # Snapshot per call so the feedback assertion sees exactly what the
        # second attempt sent, even if the client mutates a shared list.
        call_messages.append(list(kwargs["messages"]))
        return responses[len(call_messages) - 1]

    client._generate_structured_once = AsyncMock(side_effect=fake_generate)
    response = asyncio.run(
        client.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format={
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "subtitle": {"type": "string"},
                },
            },
            strict=True,
            validate_schema=True,
        )
    )
    assert response == {"title": "Valid title", "subtitle": "Valid subtitle"}
    assert len(call_messages) == 2
    feedback_message = call_messages[1][-1]
    # Same feedback channel as the non-strict retry test: a user message.
    assert isinstance(feedback_message, LLMUserMessage)
    assert "required property" in feedback_message.content
    assert "subtitle" in feedback_message.content
def test_generate_structured_preserves_no_content_retries():
    """``None`` (no-content) results are retried until a real payload arrives.

    Two ``None`` results precede a valid one, so exactly three attempts are
    expected and the valid payload is what the caller receives.
    """
    title_schema = {
        "type": "object",
        "properties": {"title": {"type": "string"}},
        "required": ["title"],
        "additionalProperties": False,
    }
    attempt_results = [None, None, {"title": "Valid title"}]

    client = _build_client()
    client._generate_structured_once = AsyncMock(side_effect=attempt_results)

    result = asyncio.run(
        client.generate_structured(
            model="test-model",
            messages=[LLMUserMessage(content="Generate JSON")],
            response_format=title_schema,
        )
    )

    assert result == {"title": "Valid title"}
    assert client._generate_structured_once.await_count == 3
def test_edit_slide_enables_schema_validation():
    """``get_edited_slide_content`` must ask the LLM client for schema validation."""
    # 120-char filler for the speaker note; presumably long enough to satisfy
    # a minimum-length rule on that field — confirm against the schema.
    filler_note = "x" * 120
    llm_reply = {"title": "Edited title", "__speaker_note__": filler_note}
    fake_llm = SimpleNamespace(generate_structured=AsyncMock(return_value=llm_reply))

    with patch("utils.llm_calls.edit_slide.LLMClient", return_value=fake_llm):
        with patch(
            "utils.llm_calls.edit_slide.get_model",
            return_value="test-model",
        ):
            edited = asyncio.run(
                get_edited_slide_content(
                    prompt="Update the title",
                    slide=_build_slide(),
                    language="English",
                    slide_layout=_build_layout().slides[0],
                )
            )

    assert edited["title"] == "Edited title"
    call_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert call_kwargs["validate_schema"] is True
def test_generate_presentation_structure_enables_schema_validation():
    """``generate_presentation_structure`` must call the LLM with schema validation on.

    The structure-model factory is replaced by a stand-in that only exposes
    ``model_json_schema``; the fake LLM answers with slide index ``0``.
    """

    def build_structure_schema():
        # Returns a fresh dict on every call, like the inline lambda it replaces.
        return {
            "type": "object",
            "properties": {
                "slides": {"type": "array", "items": {"type": "integer"}},
            },
            "required": ["slides"],
            "additionalProperties": False,
        }

    fake_llm = SimpleNamespace(
        generate_structured=AsyncMock(return_value={"slides": [0]})
    )
    structure_model_stub = SimpleNamespace(model_json_schema=build_structure_schema)

    with patch(
        "utils.llm_calls.generate_presentation_structure.LLMClient",
        return_value=fake_llm,
    ):
        with patch(
            "utils.llm_calls.generate_presentation_structure.get_model",
            return_value="test-model",
        ):
            with patch(
                "utils.llm_calls.generate_presentation_structure.get_presentation_structure_model_with_n_slides",
                return_value=structure_model_stub,
            ):
                structure = asyncio.run(
                    generate_presentation_structure(
                        presentation_outline=PresentationOutlineModel(
                            slides=[SlideOutlineModel(content="Outline content")]
                        ),
                        presentation_layout=_build_layout(),
                    )
                )

    assert structure.slides == [0]
    call_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert call_kwargs["validate_schema"] is True
def test_generate_slide_content_enables_schema_validation():
    """``get_slide_content_from_type_and_outline`` must request schema validation."""
    # 120-char filler; presumably satisfies a minimum-length rule on the
    # speaker note — confirm against the schema used by the module.
    filler_note = "x" * 120
    llm_reply = {"title": "Slide title", "__speaker_note__": filler_note}
    fake_llm = SimpleNamespace(generate_structured=AsyncMock(return_value=llm_reply))

    with patch(
        "utils.llm_calls.generate_slide_content.LLMClient",
        return_value=fake_llm,
    ):
        with patch(
            "utils.llm_calls.generate_slide_content.get_model",
            return_value="test-model",
        ):
            slide_content = asyncio.run(
                get_slide_content_from_type_and_outline(
                    slide_layout=_build_layout().slides[0],
                    outline=SlideOutlineModel(content="Slide outline"),
                    language="English",
                )
            )

    assert slide_content["title"] == "Slide title"
    call_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert call_kwargs["validate_schema"] is True
def test_select_slide_type_on_edit_enables_schema_validation():
    """``get_slide_layout_from_prompt`` must request schema validation.

    The fake LLM picks index ``0``, which must resolve to the layout slide
    with id ``layout-1``.
    """
    fake_llm = SimpleNamespace(
        generate_structured=AsyncMock(return_value={"index": 0})
    )
    candidate_layout = _build_layout()

    with patch(
        "utils.llm_calls.select_slide_type_on_edit.LLMClient",
        return_value=fake_llm,
    ):
        with patch(
            "utils.llm_calls.select_slide_type_on_edit.get_model",
            return_value="test-model",
        ):
            chosen = asyncio.run(
                get_slide_layout_from_prompt(
                    prompt="Use the first layout",
                    layout=candidate_layout,
                    slide=_build_slide(),
                )
            )

    assert chosen.id == "layout-1"
    call_kwargs = fake_llm.generate_structured.await_args.kwargs
    assert call_kwargs["validate_schema"] is True