feat: replace 'imagen' with 'gemini_flash' across image provider configurations and related services
This commit is contained in:
parent
21dca979ce
commit
2171dba4e5
8 changed files with 27 additions and 27 deletions
|
|
@ -3,5 +3,5 @@ from enum import Enum
|
|||
class ImageProvider(Enum):
|
||||
PEXELS = "pexels"
|
||||
PIXABAY = "pixabay"
|
||||
IMAGEN = "imagen"
|
||||
GEMINI_FLASH = "gemini_flash"
|
||||
DALLE3 = "dall-e-3"
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from utils.llm_provider import (
|
|||
from utils.image_provider import (
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
is_imagen_selected,
|
||||
is_gemini_flash_selected,
|
||||
is_dalle3_selected
|
||||
)
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ class ImageGenerationService:
|
|||
return self.get_image_from_pixabay
|
||||
elif is_pixels_selected():
|
||||
return self.get_image_from_pexels
|
||||
elif is_imagen_selected():
|
||||
elif is_gemini_flash_selected():
|
||||
return self.generate_image_google
|
||||
elif is_dalle3_selected():
|
||||
return self.generate_image_openai
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class TestImageGenerationService:
|
|||
"""
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
|
@ -65,7 +65,7 @@ class TestImageGenerationService:
|
|||
"""
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
|
@ -79,7 +79,7 @@ class TestImageGenerationService:
|
|||
"""
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=True):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
|
@ -120,7 +120,7 @@ class TestImageGenerationService:
|
|||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_key"}):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ class TestImageGenerationService:
|
|||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}):
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=True):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
||||
|
|
@ -187,7 +187,7 @@ class TestImageGenerationService:
|
|||
async def run_test():
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
|
@ -209,7 +209,7 @@ class TestImageGenerationService:
|
|||
async def run_test():
|
||||
with patch('services.image_generation_service.is_pixels_selected', return_value=True):
|
||||
with patch('services.image_generation_service.is_pixabay_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_imagen_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False):
|
||||
with patch('services.image_generation_service.is_dalle3_selected', return_value=False):
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}):
|
||||
service = ImageGenerationService(mock_images_directory)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ def is_pixabay_selected() -> bool:
|
|||
return ImageProvider.PIXABAY == get_selected_image_provider()
|
||||
|
||||
|
||||
def is_imagen_selected() -> bool:
|
||||
return ImageProvider.IMAGEN == get_selected_image_provider()
|
||||
def is_gemini_flash_selected() -> bool:
|
||||
return ImageProvider.GEMINI_FLASH == get_selected_image_provider()
|
||||
|
||||
|
||||
def is_dalle3_selected() -> bool:
|
||||
|
|
@ -33,7 +33,7 @@ def get_image_provider_api_key() -> str:
|
|||
return os.getenv("PEXELS_API_KEY")
|
||||
elif selected_image_provider == ImageProvider.PIXABAY:
|
||||
return os.getenv("PIXABAY_API_KEY")
|
||||
elif selected_image_provider == ImageProvider.IMAGEN:
|
||||
elif selected_image_provider == ImageProvider.GEMINI_FLASH:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
elif selected_image_provider == ImageProvider.DALLE3:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from utils.ollama import pull_ollama_model
|
|||
from utils.image_provider import (
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
is_imagen_selected,
|
||||
is_gemini_flash_selected,
|
||||
is_dalle3_selected,
|
||||
)
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ async def check_llm_and_image_provider_api_or_model_availability():
|
|||
if not pixabay_api_key:
|
||||
raise Exception("PIXABAY_API_KEY must be provided")
|
||||
|
||||
elif is_imagen_selected():
|
||||
elif is_gemini_flash_selected():
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
|
|
|
|||
|
|
@ -50,9 +50,9 @@ const IMAGE_PROVIDERS: Record<string, ImageProviderConfig> = {
|
|||
placeholder: "Enter your OpenAI API key",
|
||||
apiKeyField: "OPENAI_API_KEY",
|
||||
},
|
||||
imagen: {
|
||||
title: "imagen",
|
||||
description: "Required for using Imagen services from Google",
|
||||
gemini_flash: {
|
||||
title: "gemini_flash",
|
||||
description: "Required for using Gemini Flash services from Google",
|
||||
placeholder: "Enter your Google API key",
|
||||
apiKeyField: "GOOGLE_API_KEY",
|
||||
},
|
||||
|
|
@ -943,7 +943,7 @@ const SettingsPage = () => {
|
|||
return <></>;
|
||||
}
|
||||
|
||||
if (provider.title === "imagen" && llmConfig.LLM === "google") {
|
||||
if (provider.title === "gemini_flash" && llmConfig.LLM === "google") {
|
||||
return <> </>;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -87,9 +87,9 @@ const IMAGE_PROVIDERS: Record<string, ImageProviderOption> = {
|
|||
requiresApiKey: true,
|
||||
apiKeyField: "OPENAI_API_KEY",
|
||||
},
|
||||
imagen: {
|
||||
value: "imagen",
|
||||
label: "Imagen",
|
||||
gemini_flash: {
|
||||
value: "gemini_flash",
|
||||
label: "Gemini Flash",
|
||||
description: "Google's primary image generation model",
|
||||
icon: "/icons/google.png",
|
||||
requiresApiKey: true,
|
||||
|
|
@ -142,8 +142,8 @@ const PROVIDER_CONFIGS: Record<string, ProviderConfig> = {
|
|||
],
|
||||
imageModels: [
|
||||
{
|
||||
value: "imagen",
|
||||
label: "Imagen",
|
||||
value: "gemini_flash",
|
||||
label: "Gemini Flash",
|
||||
description: "Google's primary image generation model",
|
||||
icon: "/icons/google.png",
|
||||
size: "8GB",
|
||||
|
|
@ -323,7 +323,7 @@ export default function Home() {
|
|||
if (provider === "openai") {
|
||||
newConfig.IMAGE_PROVIDER = "dall-e-3";
|
||||
} else if (provider === "google") {
|
||||
newConfig.IMAGE_PROVIDER = "imagen";
|
||||
newConfig.IMAGE_PROVIDER = "gemini_flash";
|
||||
} else {
|
||||
newConfig.IMAGE_PROVIDER = "pexels"; // default for ollama and custom
|
||||
}
|
||||
|
|
@ -967,7 +967,7 @@ export default function Home() {
|
|||
return <></>;
|
||||
}
|
||||
|
||||
if (provider.value === "imagen" && llmConfig.LLM === "google") {
|
||||
if (provider.value === "gemini_flash" && llmConfig.LLM === "google") {
|
||||
return <> </>;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
return llmConfig.PIXABAY_API_KEY && llmConfig.PIXABAY_API_KEY !== "";
|
||||
case "dall-e-3":
|
||||
return OPENAI_API_KEY && OPENAI_API_KEY !== "";
|
||||
case "imagen":
|
||||
case "gemini_flash":
|
||||
return GOOGLE_API_KEY && GOOGLE_API_KEY !== "";
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue