presenton/servers/fastapi/tests/test_image_generation.py

435 lines
24 KiB
Python

import pytest
import asyncio
import os
from unittest.mock import Mock, patch, AsyncMock
import httpx
from fastapi.testclient import TestClient
from fastapi import FastAPI
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
from models.image_prompt import ImagePrompt
from services.image_generation_service import ImageGenerationService
from models.sql.image_asset import ImageAsset
class TestImageGenerationService:
"""
Testing the image Generation Service
"""
@pytest.fixture
def mock_images_directory(self, tmp_path):
"""
Creates new images directory for every test case we run
"""
images_dir = tmp_path / "images"
images_dir.mkdir()
return str(images_dir)
@pytest.fixture
def sample_image_prompt(self):
"""
Creates a sample ImagePrompt for testing
"""
return ImagePrompt(prompt="A beautiful sunset over mountains")
def test_image_generation_service_initialization(self, mock_images_directory):
"""
Test initialization of ImageGenerationService with output directory
- Checks if the output directory is set correctly
- Checks if the image generation function is set based on environment variable
"""
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
service = ImageGenerationService(mock_images_directory)
assert service.output_directory == mock_images_directory
assert service.image_gen_func is not None or service.image_gen_func is None
def test_get_image_gen_func_pixabay_selected(self, mock_images_directory):
"""
Testing the function selection when Pixabay is selected
- Checks if the correct function is selected based on environment variable
- Ensures that the function is set to get_image_from_pixabay when Pixabay is selected
"""
with patch('services.image_generation_service.is_pixabay_selected', return_value=True):
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
service = ImageGenerationService(mock_images_directory)
assert service.image_gen_func == service.get_image_from_pixabay
def test_get_image_gen_func_pexels_selected(self, mock_images_directory):
"""
Test function selection when Pexels is selected
- Checks if the correct function is selected based on environment variable
- Ensures that the function is set to get_image_from_pexels when Pexels is selected
"""
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
service = ImageGenerationService(mock_images_directory)
assert service.image_gen_func == service.get_image_from_pexels
def test_get_image_gen_func_dalle3_selected(self, mock_images_directory):
"""
Test function selection when DALL-E 3 is selected
- Checks if the correct function is selected based on environment variable
- Ensures that the function is set to generate_image_openai when DALL-E 3 is selected
"""
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=True):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
service = ImageGenerationService(mock_images_directory)
assert service.image_gen_func == service.generate_image_openai
def test_is_stock_provider_selected(self, mock_images_directory):
"""
Test if stock provider is selected based on environment variable
- Checks if the stock provider is selected correctly based on environment variable
- Ensures that is_stock_provider_selected returns True for Pexels or Pixabay
"""
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
service = ImageGenerationService(mock_images_directory)
assert service.is_stock_provider_selected() is True
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
with patch('services.image_generation_service.is_pixabay_selected', return_value=True):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
service = ImageGenerationService(mock_images_directory)
assert service.is_stock_provider_selected() is True
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
service = ImageGenerationService(mock_images_directory)
assert service.is_stock_provider_selected() is False
def test_generate_image_with_pexels_success(self, mock_images_directory, sample_image_prompt):
"""
Test successful image generation with Pexels provider
- Mocks the Pexels API to return a valid image URL
- Ensures that the image generation function returns the expected URL
- Checks if the image generation function is called with the correct prompt
"""
async def run_test():
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_key"}):
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
service = ImageGenerationService(mock_images_directory)
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={
"photos": [{
"src": {
"large": "https://example.com/image.jpg"
}
}]
})
mock_session = AsyncMock()
mock_session.get = AsyncMock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch('aiohttp.ClientSession', return_value=mock_session):
result = await service.generate_image(sample_image_prompt)
assert result == "https://example.com/image.jpg"
asyncio.run(run_test())
def test_generate_image_with_dalle3_success(self, mock_images_directory, sample_image_prompt):
"""
Test successful image generation with DALL-E 3 provider
- Mocks the OpenAI client to return a valid image URL
- Ensures that the image generation function returns the expected URL
- Checks if the image generation function is called with the correct prompt
"""
async def run_test():
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=True):
service = ImageGenerationService(mock_images_directory)
# Create a real test file
test_image_path = f"{mock_images_directory}/test_image.jpg"
with open(test_image_path, 'w') as f:
f.write("fake image content")
# Mock generate_image_openai to return the test file path
async def mock_openai_generate(prompt, output_dir):
return test_image_path
service.generate_image_openai = mock_openai_generate
result = await service.generate_image(sample_image_prompt)
# Should return ImageAsset for AI providers
assert isinstance(result, ImageAsset)
assert result.path == test_image_path
assert result.extras["prompt"] == sample_image_prompt.prompt
def test_generate_image_no_provider_selected(self, mock_images_directory, sample_image_prompt):
"""
Test generate_image when no provider is selected
- Mocks the environment variable to simulate no provider selected
- Ensures that the function returns a placeholder image path
- Checks if the image generation function is called with the correct prompt
"""
async def run_test():
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
service = ImageGenerationService(mock_images_directory)
result = await service.generate_image(sample_image_prompt)
# Should return placeholder
assert result == "/static/images/placeholder.jpg"
asyncio.run(run_test())
def test_generate_image_provider_error(self, mock_images_directory, sample_image_prompt):
"""
Test generate_image when provider function raises an error
- Mocks the Pexels API to raise an exception
- Ensures that the function returns a placeholder image path
- Checks if the image generation function is called with the correct prompt
"""
async def run_test():
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
service = ImageGenerationService(mock_images_directory)
async def mock_pexels_error(*args, **kwargs):
raise Exception("API Error")
service.get_image_from_pexels = mock_pexels_error
result = await service.generate_image(sample_image_prompt)
assert result == "/static/images/placeholder.jpg"
asyncio.run(run_test())
def test_get_image_from_pexels_real_function(self, mock_images_directory):
"""T
Test REAL Pexels function with mocked HTTP call
- Mocks the Pexels API to return a valid image URL
- Ensures that the function returns the expected URL
- Checks if the HTTP call is made with the correct parameters
"""
async def run_test():
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_pexels_key"}):
service = ImageGenerationService(mock_images_directory)
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={
"photos": [{
"src": {
"large": "https://example.com/pexels_image.jpg"
}
}]
})
mock_session = AsyncMock()
mock_session.get = AsyncMock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch('aiohttp.ClientSession', return_value=mock_session):
result = await service.get_image_from_pexels("sunset")
assert result == "https://example.com/pexels_image.jpg"
mock_session.get.assert_called_once()
asyncio.run(run_test())
def test_get_image_from_pixabay_real_function(self, mock_images_directory):
"""
Test REAL Pixabay function with mocked HTTP call
- Mocks the Pixabay API to return a valid image URL
- Ensures that the function returns the expected URL
- Checks if the HTTP call is made with the correct parameters
"""
async def run_test():
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay", "PIXABAY_API_KEY": "test_pixabay_key"}):
service = ImageGenerationService(mock_images_directory)
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={
"hits": [{
"largeImageURL": "https://example.com/pixabay_image.jpg"
}]
})
mock_session = AsyncMock()
mock_session.get = AsyncMock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch('aiohttp.ClientSession', return_value=mock_session):
result = await service.get_image_from_pixabay("sunset")
assert result == "https://example.com/pixabay_image.jpg"
mock_session.get.assert_called_once()
asyncio.run(run_test())
class TestImageGenerationEndpoint:
"""
Testing the Image Generation API Endpoint
"""
@pytest.fixture
def app(self):
"""Create FastAPI app with the images router"""
app = FastAPI()
app.include_router(IMAGES_ROUTER)
return app
@pytest.fixture
def client(self, app):
"""Create test client"""
return TestClient(app)
@pytest.fixture
def mock_images_directory(self, tmp_path):
"""Mock images directory"""
images_dir = tmp_path / "images"
images_dir.mkdir()
return str(images_dir)
def test_generate_image_endpoint_success_stock_provider(self, client, mock_images_directory):
"""
Test successful image generation via API endpoint with stock provider
- Mocks the ImageGenerationService to return a stock image URL
- Ensures that the endpoint returns the expected URL
- Checks if the image generation function is called with the correct prompt
"""
test_prompt = "A beautiful sunset over mountains"
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
mock_service_instance = Mock()
mock_service_instance.generate_image = AsyncMock(return_value="https://example.com/stock_image.jpg")
mock_service_class.return_value = mock_service_instance
response = client.get(f"/images/generate?prompt={test_prompt}")
assert response.status_code == 200
def test_generate_image_endpoint_success_ai_provider(self, client, mock_images_directory):
"""
Test successful image generation via API endpoint with AI provider
- Mocks the ImageGenerationService to return an ImageAsset object
- Ensures that the endpoint returns the expected ImageAsset object
- Checks if the image generation function is called with the correct prompt
"""
test_prompt = "A beautiful sunset over mountains"
test_image_asset = ImageAsset(
path=f"{mock_images_directory}/test_image.jpg",
extras={"prompt": test_prompt, "theme_prompt": "professional"}
)
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
mock_service_instance = Mock()
mock_service_instance.generate_image = AsyncMock(return_value=test_image_asset)
mock_service_class.return_value = mock_service_instance
response = client.get(f"/images/generate?prompt={test_prompt}")
assert response.status_code == 200
def test_generate_image_endpoint_placeholder_response(self, client, mock_images_directory):
"""
Test endpoint returns placeholder image when no provider is selected
- Mocks the ImageGenerationService to return a placeholder image path
- Ensures that the endpoint returns the placeholder image path
- Checks if the image generation function is called with the correct prompt
"""
test_prompt = "Test prompt"
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
mock_service_instance = Mock()
mock_service_instance.generate_image = AsyncMock(return_value="/static/images/placeholder.jpg")
mock_service_class.return_value = mock_service_instance
response = client.get(f"/images/generate?prompt={test_prompt}")
assert response.status_code == 200
def test_generate_image_endpoint_with_async_client(self, mock_images_directory):
"""
Test the image generation endpoint using an async client
- Mocks the ImageGenerationService to return a valid image URL
- Ensures that the endpoint returns the expected URL
- Checks if the image generation function is called with the correct prompt
"""
async def run_test():
app = FastAPI()
app.include_router(IMAGES_ROUTER)
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
mock_service_instance = Mock()
mock_service_instance.generate_image = AsyncMock(return_value="https://example.com/image.jpg")
mock_service_class.return_value = mock_service_instance
response = await ac.get("/images/generate?prompt=test")
assert response.status_code == 200
asyncio.run(run_test())
def test_search_stock_images_defaults_to_selected_pixabay(self, client, mock_images_directory):
"""
Test stock image search defaults to IMAGE_PROVIDER when provider query param is omitted
- Sets IMAGE_PROVIDER to pixabay
- Ensures /images/search uses Pixabay instead of returning provider validation error
"""
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
mock_service_instance = Mock()
mock_service_instance.get_image_from_pixabay = AsyncMock(
return_value=["https://example.com/pixabay_image.jpg"]
)
mock_service_instance.get_image_from_pexels = AsyncMock(
return_value=["https://example.com/pexels_image.jpg"]
)
mock_service_class.return_value = mock_service_instance
response = client.get("/images/search?query=business&limit=1")
assert response.status_code == 200
assert response.json() == ["https://example.com/pixabay_image.jpg"]
mock_service_instance.get_image_from_pixabay.assert_awaited_once()
mock_service_instance.get_image_from_pexels.assert_not_called()
def test_search_stock_images_invalid_provider_returns_400(self, client):
"""
Test stock image search validates invalid provider values
- Ensures unsupported providers return HTTP 400 with clear guidance
"""
response = client.get("/images/search?query=business&provider=invalid-provider")
assert response.status_code == 400
assert response.json()["detail"] == "provider must be either 'pexels' or 'pixabay'"