feat: Implement stock image search functionality with Pexels and Pixabay integration
This commit is contained in:
parent
ca186c6d20
commit
cab99c6bd2
10 changed files with 416 additions and 37 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from typing import List
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, HTTPException
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, HTTPException, Query, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
|
|
@ -8,6 +8,9 @@ from models.sql.image_asset import ImageAsset
|
|||
from services.database import get_async_session
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
from utils.get_env import get_pexels_api_key_env, get_pixabay_api_key_env
|
||||
from utils.image_provider import get_selected_image_provider
|
||||
from enums.image_provider import ImageProvider
|
||||
import os
|
||||
import uuid
|
||||
from utils.file_utils import get_file_name_with_random_uuid
|
||||
|
|
@ -15,6 +18,75 @@ from utils.file_utils import get_file_name_with_random_uuid
|
|||
IMAGES_ROUTER = APIRouter(prefix="/images", tags=["Images"])
|
||||
|
||||
|
||||
def _normalize_stock_provider(provider: str | None) -> str:
|
||||
normalized_provider = (provider or "").strip().lower()
|
||||
if normalized_provider in {"pixels", "pixel", "pexel"}:
|
||||
normalized_provider = "pexels"
|
||||
|
||||
if normalized_provider:
|
||||
if normalized_provider in {"pexels", "pixabay"}:
|
||||
return normalized_provider
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="provider must be either 'pexels' or 'pixabay'",
|
||||
)
|
||||
|
||||
selected_provider = get_selected_image_provider()
|
||||
if selected_provider == ImageProvider.PIXABAY:
|
||||
return "pixabay"
|
||||
return "pexels"
|
||||
|
||||
|
||||
@IMAGES_ROUTER.get("/search", response_model=List[str])
|
||||
async def search_stock_images(
|
||||
query: str,
|
||||
limit: int = Query(default=12, ge=1, le=30),
|
||||
provider: str | None = Query(default=None),
|
||||
strict_api_key: bool = Query(default=False),
|
||||
x_provider_api_key: str | None = Header(default=None, alias="X-Provider-Api-Key"),
|
||||
):
|
||||
normalized_provider = _normalize_stock_provider(provider)
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
|
||||
if normalized_provider == "pexels":
|
||||
api_key = (x_provider_api_key or get_pexels_api_key_env() or "").strip()
|
||||
if strict_api_key and not api_key:
|
||||
raise HTTPException(status_code=401, detail="Pexels API key is required")
|
||||
|
||||
# Pexels can return cached public responses for common queries.
|
||||
# Use a nonce query in strict mode to force a real auth check.
|
||||
if strict_api_key:
|
||||
validation_query = f"__presenton_auth_check_{uuid.uuid4().hex}"
|
||||
await image_generation_service.get_image_from_pexels(
|
||||
validation_query,
|
||||
api_key=api_key,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
images = await image_generation_service.get_image_from_pexels(
|
||||
query,
|
||||
api_key=api_key,
|
||||
limit=limit,
|
||||
)
|
||||
if isinstance(images, str):
|
||||
return [images] if images else []
|
||||
return images
|
||||
|
||||
api_key = (x_provider_api_key or get_pixabay_api_key_env() or "").strip()
|
||||
if strict_api_key and not api_key:
|
||||
raise HTTPException(status_code=401, detail="Pixabay API key is required")
|
||||
|
||||
images = await image_generation_service.get_image_from_pixabay(
|
||||
query,
|
||||
api_key=api_key,
|
||||
limit=limit,
|
||||
)
|
||||
if isinstance(images, str):
|
||||
return [images] if images else []
|
||||
return images
|
||||
|
||||
|
||||
@IMAGES_ROUTER.get("/generate")
|
||||
async def generate_image(
|
||||
prompt: str, sql_session: AsyncSession = Depends(get_async_session)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class DocumentsLoader:
|
||||
DECOMPOSE_TIMEOUT_SECONDS = 600
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -39,7 +40,9 @@ class DocumentsLoader:
|
|||
):
|
||||
self._file_paths = file_paths
|
||||
self._ocr_language = presentation_language_to_ocr_code(presentation_language)
|
||||
self.liteparse_service = LiteParseService()
|
||||
self.liteparse_service = LiteParseService(
|
||||
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS
|
||||
)
|
||||
self.document_conversion_service = DocumentConversionService()
|
||||
self.document_service: Any = (
|
||||
DocumentServiceCls() if DocumentServiceCls is not None else None
|
||||
|
|
@ -142,6 +145,7 @@ class DocumentsLoader:
|
|||
converted_path = self.document_conversion_service.convert_office_to_pdf(
|
||||
file_path,
|
||||
temp_dir,
|
||||
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
|
||||
)
|
||||
return self._parse_with_liteparse(converted_path)
|
||||
|
||||
|
|
@ -149,6 +153,7 @@ class DocumentsLoader:
|
|||
converted_path = self.document_conversion_service.convert_office_to_pdf(
|
||||
file_path,
|
||||
conversion_dir,
|
||||
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
|
||||
)
|
||||
return self._parse_with_liteparse(converted_path)
|
||||
|
||||
|
|
@ -157,6 +162,7 @@ class DocumentsLoader:
|
|||
converted_path = self.document_conversion_service.convert_image_to_png(
|
||||
file_path,
|
||||
temp_dir,
|
||||
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
|
||||
)
|
||||
return self._parse_with_liteparse(converted_path)
|
||||
|
||||
|
|
@ -164,6 +170,7 @@ class DocumentsLoader:
|
|||
converted_path = self.document_conversion_service.convert_image_to_png(
|
||||
file_path,
|
||||
conversion_dir,
|
||||
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
|
||||
)
|
||||
return self._parse_with_liteparse(converted_path)
|
||||
|
||||
|
|
|
|||
|
|
@ -221,24 +221,92 @@ class ImageGenerationService:
|
|||
prompt, output_directory, "gemini-3-pro-image-preview"
|
||||
)
|
||||
|
||||
async def get_image_from_pexels(self, prompt: str) -> str:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
response = await session.get(
|
||||
f"https://api.pexels.com/v1/search?query={prompt}&per_page=1",
|
||||
headers={"Authorization": f"{get_pexels_api_key_env()}"},
|
||||
)
|
||||
data = await response.json()
|
||||
image_url = data["photos"][0]["src"]["large"]
|
||||
return image_url
|
||||
async def get_image_from_pexels(
|
||||
self, prompt: str, api_key: str | None = None, limit: int = 1
|
||||
) -> str | list[str]:
|
||||
per_page = max(1, min(limit, 80))
|
||||
resolved_api_key = (api_key or get_pexels_api_key_env() or "").strip()
|
||||
|
||||
async def get_image_from_pixabay(self, prompt: str) -> str:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
response = await session.get(
|
||||
f"https://pixabay.com/api/?key={get_pixabay_api_key_env()}&q={prompt}&image_type=photo&per_page=3"
|
||||
"https://api.pexels.com/v1/search",
|
||||
params={"query": prompt, "per_page": per_page},
|
||||
headers={"Authorization": resolved_api_key} if resolved_api_key else {},
|
||||
timeout=aiohttp.ClientTimeout(total=20),
|
||||
)
|
||||
|
||||
if response.status in {401, 403}:
|
||||
raise HTTPException(status_code=401, detail="Invalid Pexels API key")
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Pexels request failed: {error_text}",
|
||||
)
|
||||
|
||||
data = await response.json()
|
||||
image_url = data["hits"][0]["largeImageURL"]
|
||||
return image_url
|
||||
photos = data.get("photos", [])
|
||||
image_urls = [
|
||||
photo.get("src", {}).get("large")
|
||||
for photo in photos
|
||||
if photo.get("src", {}).get("large")
|
||||
]
|
||||
|
||||
if limit <= 1:
|
||||
return image_urls[0] if image_urls else ""
|
||||
return image_urls[:limit]
|
||||
|
||||
async def get_image_from_pixabay(
|
||||
self, prompt: str, api_key: str | None = None, limit: int = 1
|
||||
) -> str | list[str]:
|
||||
per_page = max(3, min(limit, 200))
|
||||
resolved_api_key = (api_key or get_pixabay_api_key_env() or "").strip()
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
response = await session.get(
|
||||
"https://pixabay.com/api/",
|
||||
params={
|
||||
"key": resolved_api_key,
|
||||
"q": prompt[:99],
|
||||
"image_type": "photo",
|
||||
"per_page": per_page,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=20),
|
||||
)
|
||||
|
||||
if response.status in {401, 403}:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid Pixabay API key: {error_text}",
|
||||
)
|
||||
if response.status == 400:
|
||||
error_text = await response.text()
|
||||
if "api key" in error_text.lower():
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid Pixabay API key: {error_text}",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Pixabay request invalid: {error_text}",
|
||||
)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Pixabay request failed: {error_text}",
|
||||
)
|
||||
|
||||
data = await response.json()
|
||||
hits = data.get("hits", [])
|
||||
image_urls = [
|
||||
hit.get("largeImageURL") for hit in hits if hit.get("largeImageURL")
|
||||
]
|
||||
|
||||
if limit <= 1:
|
||||
return image_urls[0] if image_urls else ""
|
||||
return image_urls[:limit]
|
||||
|
||||
async def generate_image_comfyui(self, prompt: str, output_directory: str) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -398,3 +398,38 @@ class TestImageGenerationEndpoint:
|
|||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_search_stock_images_defaults_to_selected_pixabay(self, client, mock_images_directory):
|
||||
"""
|
||||
Test stock image search defaults to IMAGE_PROVIDER when provider query param is omitted
|
||||
- Sets IMAGE_PROVIDER to pixabay
|
||||
- Ensures /images/search uses Pixabay instead of returning provider validation error
|
||||
"""
|
||||
with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}):
|
||||
with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory):
|
||||
with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class:
|
||||
mock_service_instance = Mock()
|
||||
mock_service_instance.get_image_from_pixabay = AsyncMock(
|
||||
return_value=["https://example.com/pixabay_image.jpg"]
|
||||
)
|
||||
mock_service_instance.get_image_from_pexels = AsyncMock(
|
||||
return_value=["https://example.com/pexels_image.jpg"]
|
||||
)
|
||||
mock_service_class.return_value = mock_service_instance
|
||||
|
||||
response = client.get("/images/search?query=business&limit=1")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["https://example.com/pixabay_image.jpg"]
|
||||
mock_service_instance.get_image_from_pixabay.assert_awaited_once()
|
||||
mock_service_instance.get_image_from_pexels.assert_not_called()
|
||||
|
||||
def test_search_stock_images_invalid_provider_returns_400(self, client):
|
||||
"""
|
||||
Test stock image search validates invalid provider values
|
||||
- Ensures unsupported providers return HTTP 400 with clear guidance
|
||||
"""
|
||||
response = client.get("/images/search?query=business&provider=invalid-provider")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "provider must be either 'pexels' or 'pixabay'"
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ import TextProvider from "./TextProvider";
|
|||
import ImageProvider from "./ImageProvider";
|
||||
import PrivacySettings from "./PrivacySettings";
|
||||
import { IMAGE_PROVIDERS, LLM_PROVIDERS } from "@/utils/providerConstants";
|
||||
import { ImagesApi } from "@/app/(presentation-generator)/services/api/images";
|
||||
|
||||
const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]);
|
||||
|
||||
// Button state interface
|
||||
interface ButtonState {
|
||||
|
|
@ -72,6 +75,36 @@ const SettingsPage = () => {
|
|||
return 0;
|
||||
}, [downloadingModel?.downloaded, downloadingModel?.size]);
|
||||
|
||||
const ensureSelectedStockProviderReady = async (): Promise<boolean> => {
|
||||
if (llmConfig.DISABLE_IMAGE_GENERATION) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const provider = (llmConfig.IMAGE_PROVIDER || "").toLowerCase();
|
||||
if (!STOCK_IMAGE_PROVIDERS.has(provider)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const providerApiKey =
|
||||
provider === "pexels" ? llmConfig.PEXELS_API_KEY : llmConfig.PIXABAY_API_KEY;
|
||||
|
||||
try {
|
||||
await ImagesApi.searchStockImages("business", 1, {
|
||||
provider,
|
||||
apiKey: providerApiKey,
|
||||
strictApiKey: true,
|
||||
});
|
||||
return true;
|
||||
} catch (error: any) {
|
||||
notify.error(
|
||||
"Cannot save settings",
|
||||
error?.message ||
|
||||
`Unable to reach ${provider} with the provided API key. Please verify your settings and try again.`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveConfig = async () => {
|
||||
trackEvent(MixpanelEvent.Settings_SaveConfiguration_Button_Clicked, { pathname });
|
||||
const validationError = getLLMConfigValidationError(llmConfig);
|
||||
|
|
@ -79,6 +112,12 @@ const SettingsPage = () => {
|
|||
notify.error("Cannot save settings", validationError);
|
||||
return;
|
||||
}
|
||||
|
||||
const providerReady = await ensureSelectedStockProviderReady();
|
||||
if (!providerReady) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"use client";
|
||||
import React, { useEffect, useState, useRef } from "react";
|
||||
import { useSelector } from "react-redux";
|
||||
import {
|
||||
Sheet,
|
||||
SheetContent,
|
||||
|
|
@ -9,7 +10,7 @@ import {
|
|||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Wand2, Upload, Loader2, Delete, Trash } from "lucide-react";
|
||||
import { Wand2, Upload, Loader2, Trash } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { PresentationGenerationApi } from "../services/api/presentation-generation";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
|
@ -18,6 +19,10 @@ import { PreviousGeneratedImagesResponse } from "../services/api/params";
|
|||
import { trackEvent, MixpanelEvent } from "@/utils/mixpanel";
|
||||
import { ImagesApi } from "../services/api/images";
|
||||
import { ImageAssetResponse } from "../services/api/types";
|
||||
import { RootState } from "@/store/store";
|
||||
|
||||
const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]);
|
||||
|
||||
interface ImageEditorProps {
|
||||
initialImage: string | null;
|
||||
imageIdx?: number;
|
||||
|
|
@ -39,6 +44,10 @@ const ImageEditor = ({
|
|||
onFocusPointClick,
|
||||
onImageChange,
|
||||
}: ImageEditorProps) => {
|
||||
const llmConfig = useSelector((state: RootState) => state.userConfig.llm_config);
|
||||
const selectedImageProvider = (llmConfig?.IMAGE_PROVIDER || "").toLowerCase();
|
||||
const isStockImageProvider = STOCK_IMAGE_PROVIDERS.has(selectedImageProvider);
|
||||
|
||||
// State management
|
||||
const [previewImages, setPreviewImages] = useState(initialImage);
|
||||
const [previousGeneratedImages, setPreviousGeneratedImages] = useState<
|
||||
|
|
@ -53,6 +62,7 @@ const ImageEditor = ({
|
|||
const [isOpen, setIsOpen] = useState(true);
|
||||
const [uploadedImages, setUploadedImages] = useState<ImageAssetResponse[]>([]);
|
||||
const [uploadedImagesLoading, setUploadedImagesLoading] = useState(false);
|
||||
const [stockSearchResults, setStockSearchResults] = useState<string[]>([]);
|
||||
// Focus point and object fit for image editing
|
||||
const [isFocusPointMode, setIsFocusPointMode] = useState(false);
|
||||
const [focusPoint, setFocusPoint] = useState(
|
||||
|
|
@ -190,18 +200,40 @@ const ImageEditor = ({
|
|||
setError("Please enter a prompt");
|
||||
return;
|
||||
}
|
||||
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (!trimmedPrompt) {
|
||||
setError("Please enter a prompt");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setIsGenerating(true);
|
||||
setError(null);
|
||||
trackEvent(MixpanelEvent.ImageEditor_GenerateImage_API_Call);
|
||||
const response = await PresentationGenerationApi.generateImage({
|
||||
prompt: prompt,
|
||||
});
|
||||
|
||||
setPreviewImages(response);
|
||||
if (isStockImageProvider) {
|
||||
const providerApiKey =
|
||||
selectedImageProvider === "pexels"
|
||||
? llmConfig?.PEXELS_API_KEY
|
||||
: llmConfig?.PIXABAY_API_KEY;
|
||||
const results = await ImagesApi.searchStockImages(trimmedPrompt, 12, {
|
||||
provider: selectedImageProvider,
|
||||
apiKey: providerApiKey,
|
||||
});
|
||||
setStockSearchResults(results);
|
||||
if (results.length > 0) {
|
||||
setPreviewImages(results[0]);
|
||||
}
|
||||
} else {
|
||||
const response = await PresentationGenerationApi.generateImage({
|
||||
prompt: trimmedPrompt,
|
||||
});
|
||||
setPreviewImages(response);
|
||||
}
|
||||
} catch (err: any) {
|
||||
console.error("Error in image generation", err);
|
||||
setError(err.message || "Failed to generate image. Please try again.");
|
||||
setError(err.message || "Failed to fetch images. Please try again.");
|
||||
} finally {
|
||||
setIsGenerating(false);
|
||||
}
|
||||
|
|
@ -286,7 +318,7 @@ const ImageEditor = ({
|
|||
<Tabs defaultValue="generate" className="w-full" onValueChange={handleTabChange}>
|
||||
<TabsList className="grid bg-blue-100 border border-blue-300 w-full grid-cols-3 mx-auto">
|
||||
<TabsTrigger className="font-medium" value="generate">
|
||||
AI Generate
|
||||
{isStockImageProvider ? "Stock Search" : "AI Generate"}
|
||||
</TabsTrigger>
|
||||
<TabsTrigger className="font-medium" value="upload">
|
||||
Upload
|
||||
|
|
@ -305,10 +337,14 @@ const ImageEditor = ({
|
|||
|
||||
<div>
|
||||
<h3 className="text-base font-medium mb-2">
|
||||
Image Description
|
||||
{isStockImageProvider ? "Image Keyword" : "Image Description"}
|
||||
</h3>
|
||||
<Textarea
|
||||
placeholder="Describe the image you want to generate..."
|
||||
placeholder={
|
||||
isStockImageProvider
|
||||
? "Enter a keyword to search stock images..."
|
||||
: "Describe the image you want to generate..."
|
||||
}
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
className="min-h-[100px]"
|
||||
|
|
@ -321,35 +357,62 @@ const ImageEditor = ({
|
|||
disabled={!prompt || isGenerating}
|
||||
>
|
||||
<Wand2 className="w-4 h-4 mr-2" />
|
||||
{isGenerating ? "Generating..." : "Generate Image"}
|
||||
{isGenerating
|
||||
? (isStockImageProvider ? "Searching..." : "Generating...")
|
||||
: (isStockImageProvider ? "Search Images" : "Generate Image")}
|
||||
</Button>
|
||||
|
||||
{error && <p className="text-red-500 text-sm">{error}</p>}
|
||||
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
{isGenerating || !previewImages ? (
|
||||
{isGenerating ? (
|
||||
Array.from({ length: 4 }).map((_, index) => (
|
||||
<Skeleton
|
||||
key={index}
|
||||
className="aspect-[4/3] w-full rounded-lg"
|
||||
/>
|
||||
))
|
||||
) : isStockImageProvider ? (
|
||||
stockSearchResults.length > 0 ? (
|
||||
stockSearchResults.map((imageUrl, index) => (
|
||||
<div
|
||||
key={`${imageUrl}-${index}`}
|
||||
onClick={() => handleImageChange(imageUrl)}
|
||||
className="aspect-[4/3] w-full overflow-hidden rounded-lg border cursor-pointer hover:border-blue-500 transition-colors"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt={`Stock result ${index + 1}`}
|
||||
className="w-full h-full object-cover"
|
||||
/>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<p className="text-sm text-gray-500 col-span-2">Search with a keyword to view stock images.</p>
|
||||
)
|
||||
) : (
|
||||
<div
|
||||
onClick={() => handleImageChange(previewImages)}
|
||||
className="aspect-[4/3] w-full overflow-hidden rounded-lg border cursor-pointer hover:border-blue-500 transition-colors"
|
||||
>
|
||||
{previewImages && (
|
||||
!previewImages ? (
|
||||
Array.from({ length: 4 }).map((_, index) => (
|
||||
<Skeleton
|
||||
key={index}
|
||||
className="aspect-[4/3] w-full rounded-lg"
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<div
|
||||
onClick={() => handleImageChange(previewImages)}
|
||||
className="aspect-[4/3] w-full overflow-hidden rounded-lg border cursor-pointer hover:border-blue-500 transition-colors"
|
||||
>
|
||||
<img
|
||||
src={previewImages}
|
||||
alt={`Preview`}
|
||||
className="w-full h-full object-cover"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
{previousGeneratedImages.length > 0 && (
|
||||
{!isStockImageProvider && previousGeneratedImages.length > 0 && (
|
||||
<div className="mt-4">
|
||||
<h3 className="text-sm font-medium mb-2">
|
||||
Previous Generated Images
|
||||
|
|
|
|||
|
|
@ -3,6 +3,12 @@ import { ApiResponseHandler } from "./api-error-handler";
|
|||
import { ImageAssetResponse } from "./types";
|
||||
import { getApiUrl } from "@/utils/api";
|
||||
|
||||
interface StockSearchOptions {
|
||||
provider?: string;
|
||||
apiKey?: string;
|
||||
strictApiKey?: boolean;
|
||||
}
|
||||
|
||||
|
||||
export class ImagesApi {
|
||||
|
||||
|
|
@ -43,6 +49,41 @@ export class ImagesApi {
|
|||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
static async searchStockImages(
|
||||
query: string,
|
||||
limit: number = 12,
|
||||
options: StockSearchOptions = {}
|
||||
): Promise<string[]> {
|
||||
try {
|
||||
const params = new URLSearchParams({
|
||||
query,
|
||||
limit: String(limit),
|
||||
});
|
||||
const normalizedProvider = (options.provider || "").trim().toLowerCase();
|
||||
if (normalizedProvider) {
|
||||
params.set("provider", normalizedProvider);
|
||||
}
|
||||
if (options.strictApiKey) {
|
||||
params.set("strict_api_key", "true");
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
const trimmedApiKey = (options.apiKey || "").trim();
|
||||
if (trimmedApiKey) {
|
||||
headers["X-Provider-Api-Key"] = trimmedApiKey;
|
||||
}
|
||||
|
||||
const response = await fetch(getApiUrl(`/api/v1/ppt/images/search?${params.toString()}`), {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
return await ApiResponseHandler.handleResponse(response, "Failed to search stock images") as string[];
|
||||
} catch (error:any) {
|
||||
console.log("Stock image search error:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
"use client";
|
||||
import React, { useState } from "react";
|
||||
import { useRouter, usePathname } from "next/navigation";
|
||||
import { useDispatch } from "react-redux";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { clearOutlines, setPresentationId } from "@/store/slices/presentationGeneration";
|
||||
import { PromptInput } from "./PromptInput";
|
||||
import { LanguageType, PresentationConfig, ToneType, VerbosityType } from "../type";
|
||||
|
|
@ -26,6 +26,10 @@ import Wrapper from "@/components/Wrapper";
|
|||
import { setPptGenUploadState } from "@/store/slices/presentationGenUpload";
|
||||
import { trackEvent, MixpanelEvent } from "@/utils/mixpanel";
|
||||
import { ConfigurationSelects } from "./ConfigurationSelects";
|
||||
import { RootState } from "@/store/store";
|
||||
import { ImagesApi } from "../../services/api/images";
|
||||
|
||||
const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]);
|
||||
|
||||
// Types for loading state
|
||||
interface LoadingState {
|
||||
|
|
@ -40,11 +44,12 @@ const UploadPage = () => {
|
|||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const dispatch = useDispatch();
|
||||
const llmConfig = useSelector((state: RootState) => state.userConfig.llm_config);
|
||||
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
const [config, setConfig] = useState<PresentationConfig>({
|
||||
slides: null,
|
||||
language: LanguageType.English,
|
||||
language: LanguageType.Auto,
|
||||
prompt: "",
|
||||
tone: ToneType.Default,
|
||||
verbosity: VerbosityType.Standard,
|
||||
|
|
@ -66,6 +71,36 @@ const UploadPage = () => {
|
|||
setConfig((prev) => ({ ...prev, [key]: value } as PresentationConfig));
|
||||
};
|
||||
|
||||
const ensureStockImageProviderReady = async (): Promise<boolean> => {
|
||||
if (llmConfig?.DISABLE_IMAGE_GENERATION) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const selectedProvider = (llmConfig?.IMAGE_PROVIDER || "").toLowerCase();
|
||||
if (!STOCK_IMAGE_PROVIDERS.has(selectedProvider)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
try {
|
||||
const providerApiKey =
|
||||
selectedProvider === "pexels"
|
||||
? llmConfig?.PEXELS_API_KEY
|
||||
: llmConfig?.PIXABAY_API_KEY;
|
||||
await ImagesApi.searchStockImages("business", 1, {
|
||||
provider: selectedProvider,
|
||||
apiKey: providerApiKey,
|
||||
strictApiKey: true,
|
||||
});
|
||||
return true;
|
||||
} catch (error: any) {
|
||||
toast.error(
|
||||
error?.message ||
|
||||
`Unable to reach ${selectedProvider} right now. Please check your API key/settings and try again.`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Validates the current configuration and files
|
||||
* @returns boolean indicating if the configuration is valid
|
||||
|
|
@ -76,6 +111,11 @@ const UploadPage = () => {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (files.length > 0 && config.language === LanguageType.Auto) {
|
||||
toast.error("Please choose a language before processing uploaded documents");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!config.prompt.trim() && files.length === 0) {
|
||||
toast.error("No Prompt or Document Provided");
|
||||
return false;
|
||||
|
|
@ -89,6 +129,9 @@ const UploadPage = () => {
|
|||
const handleGeneratePresentation = async () => {
|
||||
if (!validateConfiguration()) return;
|
||||
|
||||
const isStockProviderReady = await ensureStockImageProviderReady();
|
||||
if (!isStockProviderReady) return;
|
||||
|
||||
try {
|
||||
const hasUploadedAssets = files.length > 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ export enum ThemeType {
|
|||
|
||||
export enum LanguageType {
|
||||
// Major World Languages
|
||||
// Auto = "Auto",
|
||||
Auto = "Auto",
|
||||
English = "English",
|
||||
Spanish = "Spanish (Español)",
|
||||
French = "French (Français)",
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import { IconSearch, ImageGenerate, ImageSearch, PreviousGeneratedImagesResponse
|
|||
import { ApiResponseHandler } from "./api-error-handler";
|
||||
|
||||
export class PresentationGenerationApi {
|
||||
private static readonly DECOMPOSE_TIMEOUT_MS = 10 * 60 * 1000;
|
||||
|
||||
static async uploadDoc(documents: File[]) {
|
||||
const formData = new FormData();
|
||||
|
||||
|
|
@ -29,6 +31,9 @@ export class PresentationGenerationApi {
|
|||
}
|
||||
|
||||
static async decomposeDocuments(documentKeys: string[]) {
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), this.DECOMPOSE_TIMEOUT_MS);
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/v1/ppt/files/decompose`,
|
||||
|
|
@ -39,13 +44,19 @@ export class PresentationGenerationApi {
|
|||
file_paths: documentKeys,
|
||||
}),
|
||||
cache: "no-cache",
|
||||
signal: controller.signal,
|
||||
}
|
||||
);
|
||||
|
||||
return await ApiResponseHandler.handleResponse(response, "Failed to decompose documents");
|
||||
} catch (error) {
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
throw new Error("File decomposition timed out after 10 minutes");
|
||||
}
|
||||
console.error("Error in Decompose Files", error);
|
||||
throw error;
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue