feat: Implement stock image search functionality with Pexels and Pixabay integration

This commit is contained in:
sudipnext 2026-04-07 18:00:16 +05:45
parent ca186c6d20
commit cab99c6bd2
10 changed files with 416 additions and 37 deletions

View file

@ -1,5 +1,5 @@
from typing import List
from fastapi import APIRouter, Depends, File, UploadFile, HTTPException
from fastapi import APIRouter, Depends, File, UploadFile, HTTPException, Query, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
@ -8,6 +8,9 @@ from models.sql.image_asset import ImageAsset
from services.database import get_async_session
from services.image_generation_service import ImageGenerationService
from utils.asset_directory_utils import get_images_directory
from utils.get_env import get_pexels_api_key_env, get_pixabay_api_key_env
from utils.image_provider import get_selected_image_provider
from enums.image_provider import ImageProvider
import os
import uuid
from utils.file_utils import get_file_name_with_random_uuid
@ -15,6 +18,75 @@ from utils.file_utils import get_file_name_with_random_uuid
IMAGES_ROUTER = APIRouter(prefix="/images", tags=["Images"])
def _normalize_stock_provider(provider: str | None) -> str:
normalized_provider = (provider or "").strip().lower()
if normalized_provider in {"pixels", "pixel", "pexel"}:
normalized_provider = "pexels"
if normalized_provider:
if normalized_provider in {"pexels", "pixabay"}:
return normalized_provider
raise HTTPException(
status_code=400,
detail="provider must be either 'pexels' or 'pixabay'",
)
selected_provider = get_selected_image_provider()
if selected_provider == ImageProvider.PIXABAY:
return "pixabay"
return "pexels"
@IMAGES_ROUTER.get("/search", response_model=List[str])
async def search_stock_images(
    query: str,
    limit: int = Query(default=12, ge=1, le=30),
    provider: str | None = Query(default=None),
    strict_api_key: bool = Query(default=False),
    x_provider_api_key: str | None = Header(default=None, alias="X-Provider-Api-Key"),
):
    """Search Pexels or Pixabay and return a list of image URLs.

    The provider comes from ``provider`` (or the configured default when it is
    omitted). The API key is taken from the ``X-Provider-Api-Key`` header and
    falls back to the provider's environment key. With ``strict_api_key`` set,
    a missing key is rejected with 401 before any search is made.
    """
    stock_provider = _normalize_stock_provider(provider)
    service = ImageGenerationService(get_images_directory())

    if stock_provider != "pexels":
        pixabay_key = (x_provider_api_key or get_pixabay_api_key_env() or "").strip()
        if strict_api_key and not pixabay_key:
            raise HTTPException(status_code=401, detail="Pixabay API key is required")

        found = await service.get_image_from_pixabay(
            query,
            api_key=pixabay_key,
            limit=limit,
        )
    else:
        pexels_key = (x_provider_api_key or get_pexels_api_key_env() or "").strip()
        if strict_api_key:
            if not pexels_key:
                raise HTTPException(status_code=401, detail="Pexels API key is required")

            # Pexels can answer common queries from a cached public response,
            # so strict mode first sends a one-off nonce query that forces a
            # real authentication check. Its result is deliberately discarded.
            auth_probe_query = f"__presenton_auth_check_{uuid.uuid4().hex}"
            await service.get_image_from_pexels(
                auth_probe_query,
                api_key=pexels_key,
                limit=1,
            )

        found = await service.get_image_from_pexels(
            query,
            api_key=pexels_key,
            limit=limit,
        )

    # The service hands back a bare string for single-image lookups; the
    # endpoint always answers with a list (empty when the string is empty).
    if not isinstance(found, str):
        return found
    return [found] if found else []
@IMAGES_ROUTER.get("/generate")
async def generate_image(
prompt: str, sql_session: AsyncSession = Depends(get_async_session)

View file

@ -31,6 +31,7 @@ LOGGER = logging.getLogger(__name__)
class DocumentsLoader:
DECOMPOSE_TIMEOUT_SECONDS = 600
def __init__(
self,
@ -39,7 +40,9 @@ class DocumentsLoader:
):
self._file_paths = file_paths
self._ocr_language = presentation_language_to_ocr_code(presentation_language)
self.liteparse_service = LiteParseService()
self.liteparse_service = LiteParseService(
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS
)
self.document_conversion_service = DocumentConversionService()
self.document_service: Any = (
DocumentServiceCls() if DocumentServiceCls is not None else None
@ -142,6 +145,7 @@ class DocumentsLoader:
converted_path = self.document_conversion_service.convert_office_to_pdf(
file_path,
temp_dir,
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
)
return self._parse_with_liteparse(converted_path)
@ -149,6 +153,7 @@ class DocumentsLoader:
converted_path = self.document_conversion_service.convert_office_to_pdf(
file_path,
conversion_dir,
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
)
return self._parse_with_liteparse(converted_path)
@ -157,6 +162,7 @@ class DocumentsLoader:
converted_path = self.document_conversion_service.convert_image_to_png(
file_path,
temp_dir,
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
)
return self._parse_with_liteparse(converted_path)
@ -164,6 +170,7 @@ class DocumentsLoader:
converted_path = self.document_conversion_service.convert_image_to_png(
file_path,
conversion_dir,
timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS,
)
return self._parse_with_liteparse(converted_path)

View file

@ -221,24 +221,92 @@ class ImageGenerationService:
prompt, output_directory, "gemini-3-pro-image-preview"
)
async def get_image_from_pexels(self, prompt: str) -> str:
async with aiohttp.ClientSession(trust_env=True) as session:
response = await session.get(
f"https://api.pexels.com/v1/search?query={prompt}&per_page=1",
headers={"Authorization": f"{get_pexels_api_key_env()}"},
)
data = await response.json()
image_url = data["photos"][0]["src"]["large"]
return image_url
async def get_image_from_pexels(
self, prompt: str, api_key: str | None = None, limit: int = 1
) -> str | list[str]:
per_page = max(1, min(limit, 80))
resolved_api_key = (api_key or get_pexels_api_key_env() or "").strip()
async def get_image_from_pixabay(self, prompt: str) -> str:
async with aiohttp.ClientSession(trust_env=True) as session:
response = await session.get(
f"https://pixabay.com/api/?key={get_pixabay_api_key_env()}&q={prompt}&image_type=photo&per_page=3"
"https://api.pexels.com/v1/search",
params={"query": prompt, "per_page": per_page},
headers={"Authorization": resolved_api_key} if resolved_api_key else {},
timeout=aiohttp.ClientTimeout(total=20),
)
if response.status in {401, 403}:
raise HTTPException(status_code=401, detail="Invalid Pexels API key")
if response.status != 200:
error_text = await response.text()
raise HTTPException(
status_code=502,
detail=f"Pexels request failed: {error_text}",
)
data = await response.json()
image_url = data["hits"][0]["largeImageURL"]
return image_url
photos = data.get("photos", [])
image_urls = [
photo.get("src", {}).get("large")
for photo in photos
if photo.get("src", {}).get("large")
]
if limit <= 1:
return image_urls[0] if image_urls else ""
return image_urls[:limit]
async def get_image_from_pixabay(
    self, prompt: str, api_key: str | None = None, limit: int = 1
) -> str | list[str]:
    """Search Pixabay for photos matching ``prompt``.

    Args:
        prompt: Search text; only its first 99 characters are sent as ``q``.
        api_key: Key to authenticate with. Falls back to
            ``get_pixabay_api_key_env()`` when empty or omitted.
        limit: Number of image URLs wanted.

    Returns:
        For ``limit <= 1`` a single URL string ("" when nothing matched);
        otherwise a list holding at most ``limit`` URLs.

    Raises:
        HTTPException: 401 when Pixabay answers 401/403, or answers 400 with a
            body mentioning "api key"; 400 for any other 400 answer; 502 for
            every remaining non-200 status.
    """
    # per_page is clamped to the 3..200 range regardless of the caller's limit;
    # the surplus is trimmed again when the result is built.
    per_page = max(3, min(limit, 200))
    resolved_api_key = (api_key or get_pixabay_api_key_env() or "").strip()
    request_params = {
        "key": resolved_api_key,
        "q": prompt[:99],
        "image_type": "photo",
        "per_page": per_page,
    }

    async with aiohttp.ClientSession(trust_env=True) as session:
        # ``async with`` releases the response on every exit path, including
        # the HTTPException raises below.
        async with session.get(
            "https://pixabay.com/api/",
            params=request_params,
            timeout=aiohttp.ClientTimeout(total=20),
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                key_rejected = response.status in {401, 403} or (
                    response.status == 400 and "api key" in error_text.lower()
                )
                if key_rejected:
                    raise HTTPException(
                        status_code=401,
                        detail=f"Invalid Pixabay API key: {error_text}",
                    )
                if response.status == 400:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Pixabay request invalid: {error_text}",
                    )
                raise HTTPException(
                    status_code=502,
                    detail=f"Pixabay request failed: {error_text}",
                )
            data = await response.json()

    # Keep only hits that actually carry a large image URL.
    image_urls = [
        url for hit in data.get("hits", []) if (url := hit.get("largeImageURL"))
    ]

    if limit <= 1:
        return image_urls[0] if image_urls else ""
    return image_urls[:limit]
async def generate_image_comfyui(self, prompt: str, output_directory: str) -> str:
"""

View file

@ -398,3 +398,38 @@ class TestImageGenerationEndpoint:
asyncio.run(run_test())
def test_search_stock_images_defaults_to_selected_pixabay(self, client, mock_images_directory):
    """
    Stock image search falls back to IMAGE_PROVIDER when no provider is passed
    - IMAGE_PROVIDER is set to pixabay for the duration of the request
    - /images/search must hit Pixabay (and never Pexels) instead of failing
      provider validation
    """
    pixabay_urls = ["https://example.com/pixabay_image.jpg"]

    service = Mock()
    service.get_image_from_pixabay = AsyncMock(return_value=pixabay_urls)
    service.get_image_from_pexels = AsyncMock(
        return_value=["https://example.com/pexels_image.jpg"]
    )

    with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}), patch(
        'api.v1.ppt.endpoints.images.get_images_directory',
        return_value=mock_images_directory,
    ), patch(
        'api.v1.ppt.endpoints.images.ImageGenerationService',
        return_value=service,
    ):
        response = client.get("/images/search?query=business&limit=1")

        assert response.status_code == 200
        assert response.json() == pixabay_urls
        service.get_image_from_pixabay.assert_awaited_once()
        service.get_image_from_pexels.assert_not_called()
def test_search_stock_images_invalid_provider_returns_400(self, client):
    """
    Stock image search rejects provider values it does not support
    - An unknown provider must produce HTTP 400 with the guidance message
    """
    expected_detail = "provider must be either 'pexels' or 'pixabay'"

    response = client.get("/images/search?query=business&provider=invalid-provider")

    assert response.status_code == 400
    assert response.json()["detail"] == expected_detail

View file

@ -21,6 +21,9 @@ import TextProvider from "./TextProvider";
import ImageProvider from "./ImageProvider";
import PrivacySettings from "./PrivacySettings";
import { IMAGE_PROVIDERS, LLM_PROVIDERS } from "@/utils/providerConstants";
import { ImagesApi } from "@/app/(presentation-generator)/services/api/images";
const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]);
// Button state interface
interface ButtonState {
@ -72,6 +75,36 @@ const SettingsPage = () => {
return 0;
}, [downloadingModel?.downloaded, downloadingModel?.size]);
/**
 * Verifies, before settings are saved, that the selected stock image provider
 * accepts the configured API key by running a one-result search in strict
 * API-key mode.
 *
 * Resolves to true when image generation is disabled, when the selected
 * provider is not a stock provider, or when the search succeeds. On failure
 * an error notification is shown and the promise resolves to false.
 */
const ensureSelectedStockProviderReady = async (): Promise<boolean> => {
  const provider = (llmConfig.IMAGE_PROVIDER || "").toLowerCase();
  const needsCheck =
    !llmConfig.DISABLE_IMAGE_GENERATION && STOCK_IMAGE_PROVIDERS.has(provider);
  if (!needsCheck) {
    return true;
  }

  const apiKey =
    provider === "pexels" ? llmConfig.PEXELS_API_KEY : llmConfig.PIXABAY_API_KEY;

  try {
    await ImagesApi.searchStockImages("business", 1, {
      provider,
      apiKey,
      strictApiKey: true,
    });
  } catch (error: any) {
    const fallbackMessage = `Unable to reach ${provider} with the provided API key. Please verify your settings and try again.`;
    notify.error("Cannot save settings", error?.message || fallbackMessage);
    return false;
  }
  return true;
};
const handleSaveConfig = async () => {
trackEvent(MixpanelEvent.Settings_SaveConfiguration_Button_Clicked, { pathname });
const validationError = getLLMConfigValidationError(llmConfig);
@ -79,6 +112,12 @@ const SettingsPage = () => {
notify.error("Cannot save settings", validationError);
return;
}
const providerReady = await ensureSelectedStockProviderReady();
if (!providerReady) {
return;
}
try {
setButtonState(prev => ({
...prev,

View file

@ -1,5 +1,6 @@
"use client";
import React, { useEffect, useState, useRef } from "react";
import { useSelector } from "react-redux";
import {
Sheet,
SheetContent,
@ -9,7 +10,7 @@ import {
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Button } from "@/components/ui/button";
import { Textarea } from "@/components/ui/textarea";
import { Wand2, Upload, Loader2, Delete, Trash } from "lucide-react";
import { Wand2, Upload, Loader2, Trash } from "lucide-react";
import { cn } from "@/lib/utils";
import { PresentationGenerationApi } from "../services/api/presentation-generation";
import { Skeleton } from "@/components/ui/skeleton";
@ -18,6 +19,10 @@ import { PreviousGeneratedImagesResponse } from "../services/api/params";
import { trackEvent, MixpanelEvent } from "@/utils/mixpanel";
import { ImagesApi } from "../services/api/images";
import { ImageAssetResponse } from "../services/api/types";
import { RootState } from "@/store/store";
const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]);
interface ImageEditorProps {
initialImage: string | null;
imageIdx?: number;
@ -39,6 +44,10 @@ const ImageEditor = ({
onFocusPointClick,
onImageChange,
}: ImageEditorProps) => {
const llmConfig = useSelector((state: RootState) => state.userConfig.llm_config);
const selectedImageProvider = (llmConfig?.IMAGE_PROVIDER || "").toLowerCase();
const isStockImageProvider = STOCK_IMAGE_PROVIDERS.has(selectedImageProvider);
// State management
const [previewImages, setPreviewImages] = useState(initialImage);
const [previousGeneratedImages, setPreviousGeneratedImages] = useState<
@ -53,6 +62,7 @@ const ImageEditor = ({
const [isOpen, setIsOpen] = useState(true);
const [uploadedImages, setUploadedImages] = useState<ImageAssetResponse[]>([]);
const [uploadedImagesLoading, setUploadedImagesLoading] = useState(false);
const [stockSearchResults, setStockSearchResults] = useState<string[]>([]);
// Focus point and object fit for image editing
const [isFocusPointMode, setIsFocusPointMode] = useState(false);
const [focusPoint, setFocusPoint] = useState(
@ -190,18 +200,40 @@ const ImageEditor = ({
setError("Please enter a prompt");
return;
}
const trimmedPrompt = prompt.trim();
if (!trimmedPrompt) {
setError("Please enter a prompt");
return;
}
try {
setIsGenerating(true);
setError(null);
trackEvent(MixpanelEvent.ImageEditor_GenerateImage_API_Call);
const response = await PresentationGenerationApi.generateImage({
prompt: prompt,
});
setPreviewImages(response);
if (isStockImageProvider) {
const providerApiKey =
selectedImageProvider === "pexels"
? llmConfig?.PEXELS_API_KEY
: llmConfig?.PIXABAY_API_KEY;
const results = await ImagesApi.searchStockImages(trimmedPrompt, 12, {
provider: selectedImageProvider,
apiKey: providerApiKey,
});
setStockSearchResults(results);
if (results.length > 0) {
setPreviewImages(results[0]);
}
} else {
const response = await PresentationGenerationApi.generateImage({
prompt: trimmedPrompt,
});
setPreviewImages(response);
}
} catch (err: any) {
console.error("Error in image generation", err);
setError(err.message || "Failed to generate image. Please try again.");
setError(err.message || "Failed to fetch images. Please try again.");
} finally {
setIsGenerating(false);
}
@ -286,7 +318,7 @@ const ImageEditor = ({
<Tabs defaultValue="generate" className="w-full" onValueChange={handleTabChange}>
<TabsList className="grid bg-blue-100 border border-blue-300 w-full grid-cols-3 mx-auto">
<TabsTrigger className="font-medium" value="generate">
AI Generate
{isStockImageProvider ? "Stock Search" : "AI Generate"}
</TabsTrigger>
<TabsTrigger className="font-medium" value="upload">
Upload
@ -305,10 +337,14 @@ const ImageEditor = ({
<div>
<h3 className="text-base font-medium mb-2">
Image Description
{isStockImageProvider ? "Image Keyword" : "Image Description"}
</h3>
<Textarea
placeholder="Describe the image you want to generate..."
placeholder={
isStockImageProvider
? "Enter a keyword to search stock images..."
: "Describe the image you want to generate..."
}
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
className="min-h-[100px]"
@ -321,35 +357,62 @@ const ImageEditor = ({
disabled={!prompt || isGenerating}
>
<Wand2 className="w-4 h-4 mr-2" />
{isGenerating ? "Generating..." : "Generate Image"}
{isGenerating
? (isStockImageProvider ? "Searching..." : "Generating...")
: (isStockImageProvider ? "Search Images" : "Generate Image")}
</Button>
{error && <p className="text-red-500 text-sm">{error}</p>}
<div className="grid grid-cols-2 gap-4">
{isGenerating || !previewImages ? (
{isGenerating ? (
Array.from({ length: 4 }).map((_, index) => (
<Skeleton
key={index}
className="aspect-[4/3] w-full rounded-lg"
/>
))
) : isStockImageProvider ? (
stockSearchResults.length > 0 ? (
stockSearchResults.map((imageUrl, index) => (
<div
key={`${imageUrl}-${index}`}
onClick={() => handleImageChange(imageUrl)}
className="aspect-[4/3] w-full overflow-hidden rounded-lg border cursor-pointer hover:border-blue-500 transition-colors"
>
<img
src={imageUrl}
alt={`Stock result ${index + 1}`}
className="w-full h-full object-cover"
/>
</div>
))
) : (
<p className="text-sm text-gray-500 col-span-2">Search with a keyword to view stock images.</p>
)
) : (
<div
onClick={() => handleImageChange(previewImages)}
className="aspect-[4/3] w-full overflow-hidden rounded-lg border cursor-pointer hover:border-blue-500 transition-colors"
>
{previewImages && (
!previewImages ? (
Array.from({ length: 4 }).map((_, index) => (
<Skeleton
key={index}
className="aspect-[4/3] w-full rounded-lg"
/>
))
) : (
<div
onClick={() => handleImageChange(previewImages)}
className="aspect-[4/3] w-full overflow-hidden rounded-lg border cursor-pointer hover:border-blue-500 transition-colors"
>
<img
src={previewImages}
alt={`Preview`}
className="w-full h-full object-cover"
/>
)}
</div>
</div>
)
)}
</div>
{previousGeneratedImages.length > 0 && (
{!isStockImageProvider && previousGeneratedImages.length > 0 && (
<div className="mt-4">
<h3 className="text-sm font-medium mb-2">
Previous Generated Images

View file

@ -3,6 +3,12 @@ import { ApiResponseHandler } from "./api-error-handler";
import { ImageAssetResponse } from "./types";
import { getApiUrl } from "@/utils/api";
/** Optional settings accepted by `ImagesApi.searchStockImages`. */
interface StockSearchOptions {
// Provider name; trimmed and lower-cased, then sent as the `provider` query parameter when non-empty.
provider?: string;
// API key; trimmed, then sent in the `X-Provider-Api-Key` request header when non-empty.
apiKey?: string;
// When truthy, `strict_api_key=true` is added to the query string.
strictApiKey?: boolean;
}
export class ImagesApi {
@ -43,6 +49,41 @@ export class ImagesApi {
throw error;
}
}
/**
 * Calls the backend stock image search endpoint and returns the image URLs.
 *
 * @param query   Search text, sent as the `query` parameter.
 * @param limit   Number of results to request (defaults to 12).
 * @param options Optional provider name, API key and strict-key flag.
 * @throws Re-throws any error from the request or response handling after
 *         logging it.
 */
static async searchStockImages(
  query: string,
  limit: number = 12,
  options: StockSearchOptions = {}
): Promise<string[]> {
  try {
    const searchParams = new URLSearchParams({ query, limit: String(limit) });

    const provider = (options.provider || "").trim().toLowerCase();
    if (provider) {
      searchParams.set("provider", provider);
    }
    if (options.strictApiKey) {
      searchParams.set("strict_api_key", "true");
    }

    // The provider key travels in a header and is only sent when non-empty.
    const apiKey = (options.apiKey || "").trim();
    const headers: Record<string, string> = apiKey
      ? { "X-Provider-Api-Key": apiKey }
      : {};

    const url = getApiUrl(`/api/v1/ppt/images/search?${searchParams.toString()}`);
    const response = await fetch(url, { method: "GET", headers });
    const imageUrls = await ApiResponseHandler.handleResponse(
      response,
      "Failed to search stock images"
    );
    return imageUrls as string[];
  } catch (error: any) {
    console.log("Stock image search error:", error);
    throw error;
  }
}
}

View file

@ -12,7 +12,7 @@
"use client";
import React, { useState } from "react";
import { useRouter, usePathname } from "next/navigation";
import { useDispatch } from "react-redux";
import { useDispatch, useSelector } from "react-redux";
import { clearOutlines, setPresentationId } from "@/store/slices/presentationGeneration";
import { PromptInput } from "./PromptInput";
import { LanguageType, PresentationConfig, ToneType, VerbosityType } from "../type";
@ -26,6 +26,10 @@ import Wrapper from "@/components/Wrapper";
import { setPptGenUploadState } from "@/store/slices/presentationGenUpload";
import { trackEvent, MixpanelEvent } from "@/utils/mixpanel";
import { ConfigurationSelects } from "./ConfigurationSelects";
import { RootState } from "@/store/store";
import { ImagesApi } from "../../services/api/images";
const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]);
// Types for loading state
interface LoadingState {
@ -40,11 +44,12 @@ const UploadPage = () => {
const router = useRouter();
const pathname = usePathname();
const dispatch = useDispatch();
const llmConfig = useSelector((state: RootState) => state.userConfig.llm_config);
const [files, setFiles] = useState<File[]>([]);
const [config, setConfig] = useState<PresentationConfig>({
slides: null,
language: LanguageType.English,
language: LanguageType.Auto,
prompt: "",
tone: ToneType.Default,
verbosity: VerbosityType.Standard,
@ -66,6 +71,36 @@ const UploadPage = () => {
setConfig((prev) => ({ ...prev, [key]: value } as PresentationConfig));
};
/**
 * Checks that the configured stock image provider is usable before a
 * presentation is generated, by running a one-result search in strict
 * API-key mode.
 *
 * Resolves to true when image generation is disabled, when the selected
 * provider is not a stock provider, or when the search succeeds. On failure
 * an error toast is shown and the promise resolves to false.
 */
const ensureStockImageProviderReady = async (): Promise<boolean> => {
  const provider = (llmConfig?.IMAGE_PROVIDER || "").toLowerCase();
  const needsCheck =
    !llmConfig?.DISABLE_IMAGE_GENERATION && STOCK_IMAGE_PROVIDERS.has(provider);
  if (!needsCheck) {
    return true;
  }

  try {
    const apiKey =
      provider === "pexels"
        ? llmConfig?.PEXELS_API_KEY
        : llmConfig?.PIXABAY_API_KEY;

    await ImagesApi.searchStockImages("business", 1, {
      provider,
      apiKey,
      strictApiKey: true,
    });
  } catch (error: any) {
    const fallbackMessage = `Unable to reach ${provider} right now. Please check your API key/settings and try again.`;
    toast.error(error?.message || fallbackMessage);
    return false;
  }
  return true;
};
/**
* Validates the current configuration and files
* @returns boolean indicating if the configuration is valid
@ -76,6 +111,11 @@ const UploadPage = () => {
return false;
}
if (files.length > 0 && config.language === LanguageType.Auto) {
toast.error("Please choose a language before processing uploaded documents");
return false;
}
if (!config.prompt.trim() && files.length === 0) {
toast.error("No Prompt or Document Provided");
return false;
@ -89,6 +129,9 @@ const UploadPage = () => {
const handleGeneratePresentation = async () => {
if (!validateConfiguration()) return;
const isStockProviderReady = await ensureStockImageProviderReady();
if (!isStockProviderReady) return;
try {
const hasUploadedAssets = files.length > 0;

View file

@ -15,7 +15,7 @@ export enum ThemeType {
export enum LanguageType {
// Major World Languages
// Auto = "Auto",
Auto = "Auto",
English = "English",
Spanish = "Spanish (Español)",
French = "French (Français)",

View file

@ -3,6 +3,8 @@ import { IconSearch, ImageGenerate, ImageSearch, PreviousGeneratedImagesResponse
import { ApiResponseHandler } from "./api-error-handler";
export class PresentationGenerationApi {
private static readonly DECOMPOSE_TIMEOUT_MS = 10 * 60 * 1000;
static async uploadDoc(documents: File[]) {
const formData = new FormData();
@ -29,6 +31,9 @@ export class PresentationGenerationApi {
}
static async decomposeDocuments(documentKeys: string[]) {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), this.DECOMPOSE_TIMEOUT_MS);
try {
const response = await fetch(
`/api/v1/ppt/files/decompose`,
@ -39,13 +44,19 @@ export class PresentationGenerationApi {
file_paths: documentKeys,
}),
cache: "no-cache",
signal: controller.signal,
}
);
return await ApiResponseHandler.handleResponse(response, "Failed to decompose documents");
} catch (error) {
if (error instanceof DOMException && error.name === "AbortError") {
throw new Error("File decomposition timed out after 10 minutes");
}
console.error("Error in Decompose Files", error);
throw error;
} finally {
clearTimeout(timeoutId);
}
}