145 lines
No EOL
5.5 KiB
Python
Executable file
145 lines
No EOL
5.5 KiB
Python
Executable file
"""
|
|
Tests for the LLM Service
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
import json
|
|
from app.services.llm_service import LLMService, LLMServiceError
|
|
|
|
class TestLLMService(unittest.TestCase):
    """Unit tests for ``LLMService``.

    The service's collaborators -- the ``genai.GenerativeModel`` constructor
    and, depending on the test, the service's own ``get_model`` /
    ``generate_content`` / ``parse_json_response`` helpers -- are patched
    with mocks so that each test exercises only the method it targets.
    """

    # Patch targets, all located inside the module under test.
    _GENERATIVE_MODEL = "app.services.llm_service.genai.GenerativeModel"
    _GET_MODEL = "app.services.llm_service.LLMService.get_model"
    _GENERATE_CONTENT = "app.services.llm_service.LLMService.generate_content"
    _PARSE_JSON = "app.services.llm_service.LLMService.parse_json_response"

    def test_get_model(self):
        """get_model() constructs a GenerativeModel and returns that instance."""
        with patch(self._GENERATIVE_MODEL) as model_cls:
            fake_model = MagicMock()
            model_cls.return_value = fake_model

            # No argument: the service picks its default model name.
            returned = LLMService.get_model()
            model_cls.assert_called_once()
            self.assertEqual(returned, fake_model)

            # Explicit name: must be passed as the sole positional argument.
            # reset_mock() clears call history but keeps return_value.
            model_cls.reset_mock()
            wanted_name = "custom-model-name"
            returned = LLMService.get_model(wanted_name)
            model_cls.assert_called_once_with(wanted_name)
            self.assertEqual(returned, fake_model)

    def test_generate_content(self):
        """generate_content() returns the ``.text`` of the model's reply."""
        reply = MagicMock()
        reply.text = "Generated text response"
        fake_model = MagicMock()
        fake_model.generate_content.return_value = reply

        with patch(self._GET_MODEL) as get_model:
            get_model.return_value = fake_model

            # Only the prompt is given; every other option is defaulted.
            text = LLMService.generate_content("Test prompt")
            get_model.assert_called_once()
            fake_model.generate_content.assert_called_once()
            self.assertEqual(text, "Generated text response")

            # All optional arguments supplied. Clear the call history first;
            # reset_mock() leaves the configured return values in place.
            get_model.reset_mock()
            fake_model.generate_content.reset_mock()
            text = LLMService.generate_content(
                prompt="Custom prompt",
                temperature=0.5,
                max_tokens=100,
                model_name="custom-model",
            )
            get_model.assert_called_once_with("custom-model")
            fake_model.generate_content.assert_called_once()
            self.assertEqual(text, "Generated text response")

    def test_generate_content_error(self):
        """A failure inside the model call surfaces as LLMServiceError."""
        broken_model = MagicMock()
        broken_model.generate_content.side_effect = Exception("Model error")

        with patch(self._GET_MODEL) as get_model:
            get_model.return_value = broken_model

            with self.assertRaises(LLMServiceError) as caught:
                LLMService.generate_content("Test prompt")

            self.assertIn("Error generating content", str(caught.exception))

    def test_parse_json_response_valid(self):
        """parse_json_response() accepts bare JSON and fenced code blocks."""
        expected = {"key": "value", "number": 42}
        # Default json.dumps separators yield '{"key": "value", "number": 42}'.
        payload = json.dumps(expected)

        candidates = (
            payload,                     # bare JSON text
            f"```json\n{payload}\n```",  # markdown fence tagged as json
            f"```\n{payload}\n```",      # untagged markdown fence
        )
        for raw in candidates:
            self.assertEqual(LLMService.parse_json_response(raw), expected)

    def test_parse_json_response_invalid(self):
        """Text that is not JSON raises LLMServiceError with a clear message."""
        with self.assertRaises(LLMServiceError) as caught:
            LLMService.parse_json_response("This is not JSON")

        self.assertIn("Failed to parse JSON response", str(caught.exception))

    def test_generate_structured_response(self):
        """generate_structured_response() chains generation and JSON parsing."""
        raw_reply = '{"result": "success"}'

        with patch(self._GENERATE_CONTENT) as generate:
            with patch(self._PARSE_JSON) as parse:
                generate.return_value = raw_reply
                parse.return_value = {"result": "success"}

                result = LLMService.generate_structured_response(
                    prompt="Generate JSON",
                    temperature=0.5,
                )

                # Unspecified options must be forwarded explicitly as None.
                generate.assert_called_once_with(
                    prompt="Generate JSON",
                    temperature=0.5,
                    max_tokens=None,
                    model_name=None,
                    system_prompt=None,
                )
                parse.assert_called_once_with(raw_reply)
                # Compare against a fresh literal, not the mock's own dict.
                self.assertEqual(result, {"result": "success"})

    def test_generate_structured_response_error(self):
        """An LLMServiceError from generation propagates unchanged."""
        with patch(self._GENERATE_CONTENT) as generate:
            generate.side_effect = LLMServiceError("Generation failed")

            with self.assertRaises(LLMServiceError) as caught:
                LLMService.generate_structured_response("Generate JSON")

            self.assertEqual(str(caught.exception), "Generation failed")
|
|
|
|
if __name__ == "__main__":
    # Let this module be executed directly as a script; discovers and runs
    # the TestCase classes defined above in this module.
    unittest.main(module="__main__")