feat: add image provider configuration and validation
- Updated LLMConfig interface to include IMAGE_PROVIDER and PIXABAY_API_KEY. - Enhanced handleSaveLLMConfig to log the saving process and validate IMAGE_PROVIDER. - Implemented image provider validation logic in hasValidLLMConfig to check for required API keys based on the selected provider. - Modified start.js to read IMAGE_PROVIDER and PIXABAY_API_KEY from environment variables and include them in the user configuration setup.
This commit is contained in:
parent
957ad959dd
commit
21dca979ce
16 changed files with 2717 additions and 1552 deletions
|
|
@ -5,12 +5,17 @@ from fastapi import FastAPI
|
|||
from sqlmodel import SQLModel
|
||||
|
||||
from services import SQL_ENGINE
|
||||
from utils.model_availability import check_llm_model_availability
|
||||
from utils.model_availability import check_llm_and_image_provider_api_or_model_availability
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(_: FastAPI):
|
||||
"""
|
||||
Lifespan context manager for FastAPI application.
|
||||
Initializes the application data directory and checks LLM model availability.
|
||||
|
||||
"""
|
||||
os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True)
|
||||
SQLModel.metadata.create_all(SQL_ENGINE)
|
||||
await check_llm_model_availability()
|
||||
await check_llm_and_image_provider_api_or_model_availability()
|
||||
yield
|
||||
|
|
|
|||
7
servers/fastapi/enums/image_provider.py
Normal file
7
servers/fastapi/enums/image_provider.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from enum import Enum
|
||||
|
||||
class ImageProvider(Enum):
|
||||
PEXELS = "pexels"
|
||||
PIXABAY = "pixabay"
|
||||
IMAGEN = "imagen"
|
||||
DALLE3 = "dall-e-3"
|
||||
|
|
@ -12,3 +12,5 @@ class UserConfig(BaseModel):
|
|||
CUSTOM_LLM_API_KEY: Optional[str] = None
|
||||
CUSTOM_MODEL: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
IMAGE_PROVIDER: Optional[str] = None
|
||||
PIXABAY_API_KEY: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -8,43 +8,59 @@ from models.image_prompt import ImagePrompt
|
|||
from models.sql.image_asset import ImageAsset
|
||||
from utils.download_helpers import download_file
|
||||
from utils.get_env import get_pexels_api_key_env
|
||||
from utils.get_env import get_pixabay_api_key_env
|
||||
from utils.llm_provider import (
|
||||
get_llm_client,
|
||||
is_google_selected,
|
||||
is_openai_selected,
|
||||
)
|
||||
|
||||
from utils.image_provider import (
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
is_imagen_selected,
|
||||
is_dalle3_selected
|
||||
)
|
||||
|
||||
class ImageGenerationService:
|
||||
|
||||
def __init__(self, output_directory: str):
|
||||
self.output_directory = output_directory
|
||||
|
||||
self.use_pexels = False
|
||||
if get_pexels_api_key_env():
|
||||
self.use_pexels = True
|
||||
|
||||
self.image_gen_func = self.get_image_gen_func()
|
||||
|
||||
def get_image_gen_func(self):
|
||||
if self.use_pexels:
|
||||
if is_pixabay_selected():
|
||||
return self.get_image_from_pixabay
|
||||
elif is_pixels_selected():
|
||||
return self.get_image_from_pexels
|
||||
elif is_google_selected():
|
||||
elif is_imagen_selected():
|
||||
return self.generate_image_google
|
||||
elif is_openai_selected():
|
||||
elif is_dalle3_selected():
|
||||
return self.generate_image_openai
|
||||
return None
|
||||
|
||||
def is_stock_provider_selected(self):
|
||||
return is_pixels_selected() or is_pixabay_selected()
|
||||
|
||||
async def generate_image(self, prompt: ImagePrompt) -> str | ImageAsset:
|
||||
"""
|
||||
Generates an image based on the provided prompt.
|
||||
- If no image generation function is available, returns a placeholder image.
|
||||
- If the stock provider is selected, it uses the prompt directly,
|
||||
otherwise it uses the full image prompt with theme.
|
||||
- Output Directory is used for saving the generated image not the stock provider.
|
||||
"""
|
||||
if not self.image_gen_func:
|
||||
print("No image generation function found. Using placeholder image.")
|
||||
return "/static/images/placeholder.jpg"
|
||||
|
||||
image_prompt = prompt.get_image_prompt(not self.use_pexels)
|
||||
image_prompt = prompt.get_image_prompt(with_theme=not self.is_stock_provider_selected())
|
||||
print(f"Request - Generating Image for {image_prompt}")
|
||||
|
||||
try:
|
||||
image_path = await self.image_gen_func(image_prompt, self.output_directory)
|
||||
if self.is_stock_provider_selected():
|
||||
image_path = await self.image_gen_func(image_prompt)
|
||||
else:
|
||||
image_path = await self.image_gen_func(image_prompt, self.output_directory)
|
||||
if image_path:
|
||||
if image_path.startswith("http"):
|
||||
return image_path
|
||||
|
|
@ -102,3 +118,12 @@ class ImageGenerationService:
|
|||
data = await response.json()
|
||||
image_url = data["photos"][0]["src"]["large"]
|
||||
return image_url
|
||||
|
||||
async def get_image_from_pixabay(self, prompt: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
response = await session.get(
|
||||
f"https://pixabay.com/api/?key={os.getenv('PIXABAY_API_KEY')}&q={prompt}&image_type=photo&per_page=3"
|
||||
)
|
||||
data = await response.json()
|
||||
image_url = data["hits"][0]["largeImageURL"]
|
||||
return image_url
|
||||
|
|
|
|||
400
servers/fastapi/tests/test_image_generation.py
Normal file
400
servers/fastapi/tests/test_image_generation.py
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
||||
from models.image_prompt import ImagePrompt
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from models.sql.image_asset import ImageAsset
|
||||
|
||||
|
||||
class TestImageGenerationService:
|
||||
"""
|
||||
Testing the image Generation Service
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_images_directory(self, tmp_path):
|
||||
"""
|
||||
Creates new images directory for every test case we run
|
||||
"""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
return str(images_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image_prompt(self):
|
||||
"""
|
||||
Creates a sample ImagePrompt for testing
|
||||
"""
|
||||
return ImagePrompt(prompt="A beautiful sunset over mountains")
|
||||
|
||||
def test_image_generation_service_initialization(self, mock_images_directory):
|
||||
"""
|
||||
Test initialization of ImageGenerationService with output directory
|
||||
- Checks if the output directory is set correctly
|
||||
- Checks if the image generation function is set based on environment variable
|
||||
"""
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.output_directory == mock_images_directory
|
||||
assert service.image_gen_func is not None or service.image_gen_func is None
|
||||
|
||||
def test_get_image_gen_func_pixabay_selected(self, mock_images_directory):
|
||||
"""
|
||||
Testing the function selection when Pixabay is selected
|
||||
- Checks if the correct function is selected based on environment variable
|
||||
- Ensures that the function is set to get_image_from_pixabay when Pixabay is selected
|
||||
"""
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.image_gen_func == service.get_image_from_pixabay
|
||||
|
||||
def test_get_image_gen_func_pexels_selected(self, mock_images_directory):
|
||||
"""
|
||||
Test function selection when Pexels is selected
|
||||
- Checks if the correct function is selected based on environment variable
|
||||
- Ensures that the function is set to get_image_from_pexels when Pexels is selected
|
||||
"""
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.image_gen_func == service.get_image_from_pexels
|
||||
|
||||
def test_get_image_gen_func_dalle3_selected(self, mock_images_directory):
|
||||
"""
|
||||
Test function selection when DALL-E 3 is selected
|
||||
- Checks if the correct function is selected based on environment variable
|
||||
- Ensures that the function is set to generate_image_openai when DALL-E 3 is selected
|
||||
"""
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=True):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.image_gen_func == service.generate_image_openai
|
||||
|
||||
def test_is_stock_provider_selected(self, mock_images_directory):
|
||||
"""
|
||||
Test if stock provider is selected based on environment variable
|
||||
- Checks if the stock provider is selected correctly based on environment variable
|
||||
- Ensures that is_stock_provider_selected returns True for Pexels or Pixabay
|
||||
"""
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.is_stock_provider_selected() is True
|
||||
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=True):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.is_stock_provider_selected() is True
|
||||
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
assert service.is_stock_provider_selected() is False
|
||||
|
||||
def test_generate_image_with_pexels_success(self, mock_images_directory, sample_image_prompt):
|
||||
"""
|
||||
Test successful image generation with Pexels provider
|
||||
- Mocks the Pexels API to return a valid image URL
|
||||
- Ensures that the image generation function returns the expected URL
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
async def run_test():
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_key"}):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
"photos": [{
|
||||
"src": {
|
||||
"large": "https://example.com/image.jpg"
|
||||
}
|
||||
}]
|
||||
})
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.get = AsyncMock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch('aiohttp.ClientSession', return_value=mock_session):
|
||||
result = await service.generate_image(sample_image_prompt)
|
||||
assert result == "https://example.com/image.jpg"
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_generate_image_with_dalle3_success(self, mock_images_directory, sample_image_prompt):
|
||||
"""
|
||||
Test successful image generation with DALL-E 3 provider
|
||||
- Mocks the OpenAI client to return a valid image URL
|
||||
- Ensures that the image generation function returns the expected URL
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
async def run_test():
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=True):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
# Create a real test file
|
||||
test_image_path = f"{mock_images_directory}/test_image.jpg"
|
||||
with open(test_image_path, 'w') as f:
|
||||
f.write("fake image content")
|
||||
|
||||
# Mock generate_image_openai to return the test file path
|
||||
async def mock_openai_generate(prompt, output_dir):
|
||||
return test_image_path
|
||||
|
||||
service.generate_image_openai = mock_openai_generate
|
||||
|
||||
result = await service.generate_image(sample_image_prompt)
|
||||
|
||||
# Should return ImageAsset for AI providers
|
||||
assert isinstance(result, ImageAsset)
|
||||
assert result.path == test_image_path
|
||||
assert result.extras["prompt"] == sample_image_prompt.prompt
|
||||
|
||||
def test_generate_image_no_provider_selected(self, mock_images_directory, sample_image_prompt):
|
||||
"""
|
||||
Test generate_image when no provider is selected
|
||||
- Mocks the environment variable to simulate no provider selected
|
||||
- Ensures that the function returns a placeholder image path
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
async def run_test():
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
result = await service.generate_image(sample_image_prompt)
|
||||
|
||||
# Should return placeholder
|
||||
assert result == "/static/images/placeholder.jpg"
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_generate_image_provider_error(self, mock_images_directory, sample_image_prompt):
|
||||
"""
|
||||
Test generate_image when provider function raises an error
|
||||
- Mocks the Pexels API to raise an exception
|
||||
- Ensures that the function returns a placeholder image path
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
async def run_test():
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
async def mock_pexels_error(*args, **kwargs):
|
||||
raise Exception("API Error")
|
||||
|
||||
service.get_image_from_pexels = mock_pexels_error
|
||||
|
||||
result = await service.generate_image(sample_image_prompt)
|
||||
|
||||
assert result == "/static/images/placeholder.jpg"
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_get_image_from_pexels_real_function(self, mock_images_directory):
|
||||
"""T
|
||||
Test REAL Pexels function with mocked HTTP call
|
||||
- Mocks the Pexels API to return a valid image URL
|
||||
- Ensures that the function returns the expected URL
|
||||
- Checks if the HTTP call is made with the correct parameters
|
||||
"""
|
||||
async def run_test():
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_pexels_key"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
"photos": [{
|
||||
"src": {
|
||||
"large": "https://example.com/pexels_image.jpg"
|
||||
}
|
||||
}]
|
||||
})
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.get = AsyncMock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch('aiohttp.ClientSession', return_value=mock_session):
|
||||
result = await service.get_image_from_pexels("sunset")
|
||||
|
||||
assert result == "https://example.com/pexels_image.jpg"
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_get_image_from_pixabay_real_function(self, mock_images_directory):
|
||||
"""
|
||||
Test REAL Pixabay function with mocked HTTP call
|
||||
- Mocks the Pixabay API to return a valid image URL
|
||||
- Ensures that the function returns the expected URL
|
||||
- Checks if the HTTP call is made with the correct parameters
|
||||
"""
|
||||
async def run_test():
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay", "PIXABAY_API_KEY": "test_pixabay_key"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={
|
||||
"hits": [{
|
||||
"largeImageURL": "https://example.com/pixabay_image.jpg"
|
||||
}]
|
||||
})
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.get = AsyncMock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch('aiohttp.ClientSession', return_value=mock_session):
|
||||
result = await service.get_image_from_pixabay("sunset")
|
||||
|
||||
assert result == "https://example.com/pixabay_image.jpg"
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
||||
class TestImageGenerationEndpoint:
|
||||
"""
|
||||
Testing the Image Generation API Endpoint
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create FastAPI app with the images router"""
|
||||
app = FastAPI()
|
||||
app.include_router(IMAGES_ROUTER)
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
"""Create test client"""
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_images_directory(self, tmp_path):
|
||||
"""Mock images directory"""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
return str(images_dir)
|
||||
|
||||
def test_generate_image_endpoint_success_stock_provider(self, client, mock_images_directory):
|
||||
"""
|
||||
Test successful image generation via API endpoint with stock provider
|
||||
- Mocks the ImageGenerationService to return a stock image URL
|
||||
- Ensures that the endpoint returns the expected URL
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
test_prompt = "A beautiful sunset over mountains"
|
||||
|
||||
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
|
||||
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.generate_image = AsyncMock(return_value="https://example.com/stock_image.jpg")
|
||||
mock_service_class.return_value = mock_service_instance
|
||||
response = client.get(f"/images/generate?prompt={test_prompt}")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_generate_image_endpoint_success_ai_provider(self, client, mock_images_directory):
|
||||
"""
|
||||
Test successful image generation via API endpoint with AI provider
|
||||
- Mocks the ImageGenerationService to return an ImageAsset object
|
||||
- Ensures that the endpoint returns the expected ImageAsset object
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
test_prompt = "A beautiful sunset over mountains"
|
||||
|
||||
test_image_asset = ImageAsset(
|
||||
path=f"{mock_images_directory}/test_image.jpg",
|
||||
extras={"prompt": test_prompt, "theme_prompt": "professional"}
|
||||
)
|
||||
|
||||
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
|
||||
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.generate_image = AsyncMock(return_value=test_image_asset)
|
||||
mock_service_class.return_value = mock_service_instance
|
||||
|
||||
response = client.get(f"/images/generate?prompt={test_prompt}")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_generate_image_endpoint_placeholder_response(self, client, mock_images_directory):
|
||||
"""
|
||||
Test endpoint returns placeholder image when no provider is selected
|
||||
- Mocks the ImageGenerationService to return a placeholder image path
|
||||
- Ensures that the endpoint returns the placeholder image path
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
test_prompt = "Test prompt"
|
||||
|
||||
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
|
||||
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.generate_image = AsyncMock(return_value="/static/images/placeholder.jpg")
|
||||
mock_service_class.return_value = mock_service_instance
|
||||
|
||||
response = client.get(f"/images/generate?prompt={test_prompt}")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_generate_image_endpoint_with_async_client(self, mock_images_directory):
|
||||
"""
|
||||
Test the image generation endpoint using an async client
|
||||
- Mocks the ImageGenerationService to return a valid image URL
|
||||
- Ensures that the endpoint returns the expected URL
|
||||
- Checks if the image generation function is called with the correct prompt
|
||||
"""
|
||||
async def run_test():
|
||||
app = FastAPI()
|
||||
app.include_router(IMAGES_ROUTER)
|
||||
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
|
||||
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.generate_image = AsyncMock(return_value="https://example.com/image.jpg")
|
||||
mock_service_class.return_value = mock_service_instance
|
||||
|
||||
response = await ac.get("/images/generate?prompt=test")
|
||||
assert response.status_code == 200
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
|
|
@ -55,3 +55,9 @@ def get_custom_model_env():
|
|||
|
||||
def get_pexels_api_key_env():
|
||||
return os.getenv("PEXELS_API_KEY")
|
||||
|
||||
def get_image_provider_env():
|
||||
return os.getenv("IMAGE_PROVIDER")
|
||||
|
||||
def get_pixabay_api_key_env():
|
||||
return os.getenv("PIXABAY_API_KEY")
|
||||
41
servers/fastapi/utils/image_provider.py
Normal file
41
servers/fastapi/utils/image_provider.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import os
|
||||
from enums.image_provider import ImageProvider
|
||||
|
||||
|
||||
def is_pixels_selected() -> bool:
|
||||
return ImageProvider.PEXELS == get_selected_image_provider()
|
||||
|
||||
|
||||
def is_pixabay_selected() -> bool:
|
||||
return ImageProvider.PIXABAY == get_selected_image_provider()
|
||||
|
||||
|
||||
def is_imagen_selected() -> bool:
|
||||
return ImageProvider.IMAGEN == get_selected_image_provider()
|
||||
|
||||
|
||||
def is_dalle3_selected() -> bool:
|
||||
return ImageProvider.DALLE3 == get_selected_image_provider()
|
||||
|
||||
|
||||
def get_selected_image_provider() -> ImageProvider:
|
||||
"""
|
||||
Get the selected image provider from environment variables.
|
||||
Returns:
|
||||
ImageProvider: The selected image provider.
|
||||
"""
|
||||
return ImageProvider(os.getenv("IMAGE_PROVIDER"))
|
||||
|
||||
|
||||
def get_image_provider_api_key() -> str:
|
||||
selected_image_provider = get_selected_image_provider()
|
||||
if selected_image_provider == ImageProvider.PEXELS:
|
||||
return os.getenv("PEXELS_API_KEY")
|
||||
elif selected_image_provider == ImageProvider.PIXABAY:
|
||||
return os.getenv("PIXABAY_API_KEY")
|
||||
elif selected_image_provider == ImageProvider.IMAGEN:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
elif selected_image_provider == ImageProvider.DALLE3:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
else:
|
||||
raise ValueError(f"Invalid image provider: {selected_image_provider}")
|
||||
|
|
@ -9,9 +9,14 @@ from utils.llm_provider import (
|
|||
is_ollama_selected,
|
||||
)
|
||||
from utils.ollama import pull_ollama_model
|
||||
from utils.image_provider import (
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
is_imagen_selected,
|
||||
is_dalle3_selected,
|
||||
)
|
||||
|
||||
|
||||
async def check_llm_model_availability():
|
||||
async def check_llm_and_image_provider_api_or_model_availability():
|
||||
can_change_keys = get_can_change_keys_env() != "false"
|
||||
if not can_change_keys:
|
||||
if get_llm_provider() == LLMProvider.OPENAI:
|
||||
|
|
@ -58,3 +63,22 @@ async def check_llm_model_availability():
|
|||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
elif is_pixels_selected():
|
||||
pexels_api_key = os.getenv("PEXELS_API_KEY")
|
||||
if not pexels_api_key:
|
||||
raise Exception("PEXELS_API_KEY must be provided")
|
||||
|
||||
elif is_pixabay_selected():
|
||||
pixabay_api_key = os.getenv("PIXABAY_API_KEY")
|
||||
if not pixabay_api_key:
|
||||
raise Exception("PIXABAY_API_KEY must be provided")
|
||||
|
||||
elif is_imagen_selected():
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
|
||||
elif is_dalle3_selected():
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
|
|
@ -43,3 +43,10 @@ def set_custom_model_env(value):
|
|||
|
||||
def set_pexels_api_key_env(value):
|
||||
os.environ["PEXELS_API_KEY"] = value
|
||||
|
||||
def set_image_provider_env(value):
|
||||
os.environ["IMAGE_PROVIDER"] = value
|
||||
|
||||
|
||||
def set_pixabay_api_key_env(value):
|
||||
os.environ["PIXABAY_API_KEY"] = value
|
||||
|
|
@ -13,6 +13,8 @@ from utils.get_env import (
|
|||
get_openai_api_key_env,
|
||||
get_pexels_api_key_env,
|
||||
get_user_config_path_env,
|
||||
get_image_provider_env,
|
||||
get_pixabay_api_key_env
|
||||
)
|
||||
from utils.set_env import (
|
||||
set_custom_llm_api_key_env,
|
||||
|
|
@ -24,6 +26,8 @@ from utils.set_env import (
|
|||
set_ollama_url_env,
|
||||
set_openai_api_key_env,
|
||||
set_pexels_api_key_env,
|
||||
set_image_provider_env,
|
||||
set_pixabay_api_key_env
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -49,6 +53,8 @@ def get_user_config():
|
|||
CUSTOM_LLM_API_KEY=existing_config.CUSTOM_LLM_API_KEY
|
||||
or get_custom_llm_api_key_env(),
|
||||
CUSTOM_MODEL=existing_config.CUSTOM_MODEL or get_custom_model_env(),
|
||||
IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(),
|
||||
PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(),
|
||||
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(),
|
||||
)
|
||||
|
||||
|
|
@ -71,5 +77,9 @@ def update_env_with_user_config():
|
|||
set_custom_llm_api_key_env(user_config.CUSTOM_LLM_API_KEY)
|
||||
if user_config.CUSTOM_MODEL:
|
||||
set_custom_model_env(user_config.CUSTOM_MODEL)
|
||||
if user_config.IMAGE_PROVIDER:
|
||||
set_image_provider_env(user_config.IMAGE_PROVIDER)
|
||||
if user_config.PIXABAY_API_KEY:
|
||||
set_pixabay_api_key_env(user_config.PIXABAY_API_KEY)
|
||||
if user_config.PEXELS_API_KEY:
|
||||
set_pexels_api_key_env(user_config.PEXELS_API_KEY)
|
||||
|
|
|
|||
|
|
@ -1,36 +1,36 @@
|
|||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import { NextResponse } from "next/server";
|
||||
import fs from "fs";
|
||||
|
||||
const userConfigPath = process.env.USER_CONFIG_PATH!;
|
||||
const canChangeKeys = process.env.CAN_CHANGE_KEYS !== 'false';
|
||||
|
||||
const canChangeKeys = process.env.CAN_CHANGE_KEYS !== "false";
|
||||
console.log("UserConfigPath:", userConfigPath);
|
||||
export async function GET() {
|
||||
if (!canChangeKeys) {
|
||||
return NextResponse.json({
|
||||
error: 'You are not allowed to access this resource',
|
||||
})
|
||||
error: "You are not allowed to access this resource",
|
||||
});
|
||||
}
|
||||
|
||||
if (!fs.existsSync(userConfigPath)) {
|
||||
return NextResponse.json({})
|
||||
return NextResponse.json({});
|
||||
}
|
||||
const configData = fs.readFileSync(userConfigPath, 'utf-8')
|
||||
return NextResponse.json(JSON.parse(configData))
|
||||
const configData = fs.readFileSync(userConfigPath, "utf-8");
|
||||
return NextResponse.json(JSON.parse(configData));
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
if (!canChangeKeys) {
|
||||
return NextResponse.json({
|
||||
error: 'You are not allowed to access this resource',
|
||||
})
|
||||
error: "You are not allowed to access this resource",
|
||||
});
|
||||
}
|
||||
|
||||
const userConfig = await request.json()
|
||||
const userConfig = await request.json();
|
||||
|
||||
let existingConfig: LLMConfig = {}
|
||||
let existingConfig: LLMConfig = {};
|
||||
if (fs.existsSync(userConfigPath)) {
|
||||
const configData = fs.readFileSync(userConfigPath, 'utf-8')
|
||||
existingConfig = JSON.parse(configData)
|
||||
const configData = fs.readFileSync(userConfigPath, "utf-8");
|
||||
existingConfig = JSON.parse(configData);
|
||||
}
|
||||
const mergedConfig: LLMConfig = {
|
||||
LLM: userConfig.LLM || existingConfig.LLM,
|
||||
|
|
@ -39,11 +39,18 @@ export async function POST(request: Request) {
|
|||
OLLAMA_URL: userConfig.OLLAMA_URL || existingConfig.OLLAMA_URL,
|
||||
OLLAMA_MODEL: userConfig.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL,
|
||||
CUSTOM_LLM_URL: userConfig.CUSTOM_LLM_URL || existingConfig.CUSTOM_LLM_URL,
|
||||
CUSTOM_LLM_API_KEY: userConfig.CUSTOM_LLM_API_KEY || existingConfig.CUSTOM_LLM_API_KEY,
|
||||
CUSTOM_LLM_API_KEY:
|
||||
userConfig.CUSTOM_LLM_API_KEY || existingConfig.CUSTOM_LLM_API_KEY,
|
||||
CUSTOM_MODEL: userConfig.CUSTOM_MODEL || existingConfig.CUSTOM_MODEL,
|
||||
PIXABAY_API_KEY:
|
||||
userConfig.PIXABAY_API_KEY || existingConfig.PIXABAY_API_KEY,
|
||||
IMAGE_PROVIDER: userConfig.IMAGE_PROVIDER || existingConfig.IMAGE_PROVIDER,
|
||||
PEXELS_API_KEY: userConfig.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY,
|
||||
USE_CUSTOM_URL: userConfig.USE_CUSTOM_URL === undefined ? existingConfig.USE_CUSTOM_URL : userConfig.USE_CUSTOM_URL,
|
||||
}
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig))
|
||||
return NextResponse.json(mergedConfig)
|
||||
}
|
||||
USE_CUSTOM_URL:
|
||||
userConfig.USE_CUSTOM_URL === undefined
|
||||
? existingConfig.USE_CUSTOM_URL
|
||||
: userConfig.USE_CUSTOM_URL,
|
||||
};
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig));
|
||||
return NextResponse.json(mergedConfig);
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
2
servers/nextjs/types/global.d.ts
vendored
2
servers/nextjs/types/global.d.ts
vendored
|
|
@ -22,6 +22,8 @@ interface LLMConfig {
|
|||
CUSTOM_LLM_URL?: string;
|
||||
CUSTOM_LLM_API_KEY?: string;
|
||||
CUSTOM_MODEL?: string;
|
||||
IMAGE_PROVIDER?: string;
|
||||
PIXABAY_API_KEY?: string;
|
||||
PEXELS_API_KEY?: string;
|
||||
|
||||
// Only used in UI settings
|
||||
|
|
|
|||
|
|
@ -3,29 +3,68 @@ import { store } from "@/store/store";
|
|||
|
||||
export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => {
|
||||
if (!hasValidLLMConfig(llmConfig)) {
|
||||
throw new Error('Provided configuration is not valid');
|
||||
throw new Error("Provided configuration is not valid");
|
||||
}
|
||||
|
||||
await fetch('/api/user-config', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(llmConfig)
|
||||
console.log("StoreHelperLLMConfig: Saving LLM config", llmConfig);
|
||||
await fetch("/api/user-config", {
|
||||
method: "POST",
|
||||
body: JSON.stringify(llmConfig),
|
||||
});
|
||||
|
||||
store.dispatch(setLLMConfig(llmConfig));
|
||||
}
|
||||
};
|
||||
|
||||
export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
||||
if (!llmConfig.LLM) return false;
|
||||
if (!llmConfig.IMAGE_PROVIDER) return false;
|
||||
const OPENAI_API_KEY = llmConfig.OPENAI_API_KEY;
|
||||
const GOOGLE_API_KEY = llmConfig.GOOGLE_API_KEY;
|
||||
|
||||
const isOllamaConfigValid = llmConfig.OLLAMA_MODEL !== '' && llmConfig.OLLAMA_MODEL !== null && llmConfig.OLLAMA_MODEL !== undefined && llmConfig.OLLAMA_URL !== '' && llmConfig.OLLAMA_URL !== null && llmConfig.OLLAMA_URL !== undefined;
|
||||
const isCustomConfigValid = llmConfig.CUSTOM_LLM_URL !== '' && llmConfig.CUSTOM_LLM_URL !== null && llmConfig.CUSTOM_LLM_URL !== undefined && llmConfig.CUSTOM_MODEL !== '' && llmConfig.CUSTOM_MODEL !== null && llmConfig.CUSTOM_MODEL !== undefined;
|
||||
const isOllamaConfigValid =
|
||||
llmConfig.OLLAMA_MODEL !== "" &&
|
||||
llmConfig.OLLAMA_MODEL !== null &&
|
||||
llmConfig.OLLAMA_MODEL !== undefined &&
|
||||
llmConfig.OLLAMA_URL !== "" &&
|
||||
llmConfig.OLLAMA_URL !== null &&
|
||||
llmConfig.OLLAMA_URL !== undefined;
|
||||
|
||||
return llmConfig.LLM === 'openai' ?
|
||||
OPENAI_API_KEY !== '' && OPENAI_API_KEY !== null && OPENAI_API_KEY !== undefined :
|
||||
llmConfig.LLM === 'google' ?
|
||||
GOOGLE_API_KEY !== '' && GOOGLE_API_KEY !== null && GOOGLE_API_KEY !== undefined :
|
||||
llmConfig.LLM === 'ollama' ? isOllamaConfigValid :
|
||||
llmConfig.LLM === 'custom' ? isCustomConfigValid : false;
|
||||
}
|
||||
const isCustomConfigValid =
|
||||
llmConfig.CUSTOM_LLM_URL !== "" &&
|
||||
llmConfig.CUSTOM_LLM_URL !== null &&
|
||||
llmConfig.CUSTOM_LLM_URL !== undefined &&
|
||||
llmConfig.CUSTOM_MODEL !== "" &&
|
||||
llmConfig.CUSTOM_MODEL !== null &&
|
||||
llmConfig.CUSTOM_MODEL !== undefined;
|
||||
|
||||
const isImageConfigValid = () => {
|
||||
switch (llmConfig.IMAGE_PROVIDER) {
|
||||
case "pexels":
|
||||
return llmConfig.PEXELS_API_KEY && llmConfig.PEXELS_API_KEY !== "";
|
||||
case "pixabay":
|
||||
return llmConfig.PIXABAY_API_KEY && llmConfig.PIXABAY_API_KEY !== "";
|
||||
case "dall-e-3":
|
||||
return OPENAI_API_KEY && OPENAI_API_KEY !== "";
|
||||
case "imagen":
|
||||
return GOOGLE_API_KEY && GOOGLE_API_KEY !== "";
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const isLLMConfigValid =
|
||||
llmConfig.LLM === "openai"
|
||||
? OPENAI_API_KEY !== "" &&
|
||||
OPENAI_API_KEY !== null &&
|
||||
OPENAI_API_KEY !== undefined
|
||||
: llmConfig.LLM === "google"
|
||||
? GOOGLE_API_KEY !== "" &&
|
||||
GOOGLE_API_KEY !== null &&
|
||||
GOOGLE_API_KEY !== undefined
|
||||
: llmConfig.LLM === "ollama"
|
||||
? isOllamaConfigValid
|
||||
: llmConfig.LLM === "custom"
|
||||
? isCustomConfigValid
|
||||
: false;
|
||||
|
||||
return isLLMConfigValid && isImageConfigValid();
|
||||
};
|
||||
|
|
|
|||
5
start.js
5
start.js
|
|
@ -1,3 +1,5 @@
|
|||
/* This script starts the FastAPI and Next.js servers, setting up user configuration if necessary. It reads environment variables to configure API keys and other settings, ensuring that the user configuration file is created if it doesn't exist. The script also handles the starting of both servers and keeps the Node.js process alive until one of the servers exits. */
|
||||
|
||||
const path = require('path');
|
||||
const { spawn } = require('child_process');
|
||||
const fs = require('fs');
|
||||
|
|
@ -43,12 +45,13 @@ const setupUserConfigFromEnv = () => {
|
|||
CUSTOM_LLM_API_KEY: process.env.CUSTOM_LLM_API_KEY || existingConfig.CUSTOM_LLM_API_KEY,
|
||||
CUSTOM_MODEL: process.env.CUSTOM_MODEL || existingConfig.CUSTOM_MODEL,
|
||||
PEXELS_API_KEY: process.env.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY,
|
||||
PIXABAY_API_KEY: process.env.PIXABAY_API_KEY || existingConfig.PIXABAY_API_KEY,
|
||||
IMAGE_PROVIDER: process.env.IMAGE_PROVIDER || existingConfig.IMAGE_PROVIDER,
|
||||
USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL,
|
||||
};
|
||||
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(userConfig));
|
||||
}
|
||||
|
||||
const startServers = async () => {
|
||||
|
||||
const fastApiProcess = spawn(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue