From cab99c6bd28c30849fb4935d6f8d830510fc488e Mon Sep 17 00:00:00 2001 From: sudipnext Date: Tue, 7 Apr 2026 18:00:16 +0545 Subject: [PATCH] feat: Implement stock image search functionality with Pexels and Pixabay integration --- .../fastapi/api/v1/ppt/endpoints/images.py | 74 ++++++++++++- .../fastapi/services/documents_loader.py | 9 +- .../services/image_generation_service.py | 94 +++++++++++++--- .../fastapi/tests/test_image_generation.py | 35 ++++++ .../(dashboard)/settings/SettingPage.tsx | 39 +++++++ .../components/ImageEditor.tsx | 101 ++++++++++++++---- .../services/api/images.ts | 41 +++++++ .../upload/components/UploadPage.tsx | 47 +++++++- .../(presentation-generator)/upload/type.ts | 2 +- .../services/api/presentation-generation.ts | 11 ++ 10 files changed, 416 insertions(+), 37 deletions(-) diff --git a/electron/servers/fastapi/api/v1/ppt/endpoints/images.py b/electron/servers/fastapi/api/v1/ppt/endpoints/images.py index 82192c3c..525b5f36 100644 --- a/electron/servers/fastapi/api/v1/ppt/endpoints/images.py +++ b/electron/servers/fastapi/api/v1/ppt/endpoints/images.py @@ -1,5 +1,5 @@ from typing import List -from fastapi import APIRouter, Depends, File, UploadFile, HTTPException +from fastapi import APIRouter, Depends, File, UploadFile, HTTPException, Query, Header from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import select @@ -8,6 +8,9 @@ from models.sql.image_asset import ImageAsset from services.database import get_async_session from services.image_generation_service import ImageGenerationService from utils.asset_directory_utils import get_images_directory +from utils.get_env import get_pexels_api_key_env, get_pixabay_api_key_env +from utils.image_provider import get_selected_image_provider +from enums.image_provider import ImageProvider import os import uuid from utils.file_utils import get_file_name_with_random_uuid @@ -15,6 +18,75 @@ from utils.file_utils import get_file_name_with_random_uuid IMAGES_ROUTER = APIRouter(prefix="/images", tags=["Images"]) +def _normalize_stock_provider(provider: str | None) -> str: + normalized_provider = (provider or "").strip().lower() + if normalized_provider in {"pixels", "pixel", "pexel"}: + normalized_provider = "pexels" + + if normalized_provider: + if normalized_provider in {"pexels", "pixabay"}: + return normalized_provider + raise HTTPException( + status_code=400, + detail="provider must be either 'pexels' or 'pixabay'", + ) + + selected_provider = get_selected_image_provider() + if selected_provider == ImageProvider.PIXABAY: + return "pixabay" + return "pexels" + + +@IMAGES_ROUTER.get("/search", response_model=List[str]) +async def search_stock_images( + query: str, + limit: int = Query(default=12, ge=1, le=30), + provider: str | None = Query(default=None), + strict_api_key: bool = Query(default=False), + x_provider_api_key: str | None = Header(default=None, alias="X-Provider-Api-Key"), +): + normalized_provider = _normalize_stock_provider(provider) + + image_generation_service = ImageGenerationService(get_images_directory()) + + if normalized_provider == "pexels": + api_key = (x_provider_api_key or get_pexels_api_key_env() or "").strip() + if strict_api_key and not api_key: + raise HTTPException(status_code=401, detail="Pexels API key is required") + + # Pexels can return cached public responses for common queries. + # Use a nonce query in strict mode to force a real auth check. + if strict_api_key: + validation_query = f"__presenton_auth_check_{uuid.uuid4().hex}" + await image_generation_service.get_image_from_pexels( + validation_query, + api_key=api_key, + limit=1, + ) + + images = await image_generation_service.get_image_from_pexels( + query, + api_key=api_key, + limit=limit, + ) + if isinstance(images, str): + return [images] if images else [] + return images + + api_key = (x_provider_api_key or get_pixabay_api_key_env() or "").strip() + if strict_api_key and not api_key: + raise HTTPException(status_code=401, detail="Pixabay API key is required") + + images = await image_generation_service.get_image_from_pixabay( + query, + api_key=api_key, + limit=limit, + ) + if isinstance(images, str): + return [images] if images else [] + return images + + @IMAGES_ROUTER.get("/generate") async def generate_image( prompt: str, sql_session: AsyncSession = Depends(get_async_session) diff --git a/electron/servers/fastapi/services/documents_loader.py b/electron/servers/fastapi/services/documents_loader.py index e65a659a..63f84a0b 100644 --- a/electron/servers/fastapi/services/documents_loader.py +++ b/electron/servers/fastapi/services/documents_loader.py @@ -31,6 +31,7 @@ LOGGER = logging.getLogger(__name__) class DocumentsLoader: + DECOMPOSE_TIMEOUT_SECONDS = 600 def __init__( self, @@ -39,7 +40,9 @@ class DocumentsLoader: ): self._file_paths = file_paths self._ocr_language = presentation_language_to_ocr_code(presentation_language) - self.liteparse_service = LiteParseService() + self.liteparse_service = LiteParseService( + timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS + ) self.document_conversion_service = DocumentConversionService() self.document_service: Any = ( DocumentServiceCls() if DocumentServiceCls is not None else None @@ -142,6 +145,7 @@ class DocumentsLoader: converted_path = self.document_conversion_service.convert_office_to_pdf( file_path, temp_dir, + timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS, ) return self._parse_with_liteparse(converted_path) @@ -149,6 +153,7 @@ class DocumentsLoader: converted_path = self.document_conversion_service.convert_office_to_pdf( file_path, conversion_dir, + timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS, ) return self._parse_with_liteparse(converted_path) @@ -157,6 +162,7 @@ class DocumentsLoader: converted_path = self.document_conversion_service.convert_image_to_png( file_path, temp_dir, + timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS, ) return self._parse_with_liteparse(converted_path) @@ -164,6 +170,7 @@ class DocumentsLoader: converted_path = self.document_conversion_service.convert_image_to_png( file_path, conversion_dir, + timeout_seconds=self.DECOMPOSE_TIMEOUT_SECONDS, ) return self._parse_with_liteparse(converted_path) diff --git a/electron/servers/fastapi/services/image_generation_service.py b/electron/servers/fastapi/services/image_generation_service.py index 8e14b8be..be951eba 100644 --- a/electron/servers/fastapi/services/image_generation_service.py +++ b/electron/servers/fastapi/services/image_generation_service.py @@ -221,24 +221,92 @@ class ImageGenerationService: prompt, output_directory, "gemini-3-pro-image-preview" ) - async def get_image_from_pexels(self, prompt: str) -> str: - async with aiohttp.ClientSession(trust_env=True) as session: - response = await session.get( - f"https://api.pexels.com/v1/search?query={prompt}&per_page=1", - headers={"Authorization": f"{get_pexels_api_key_env()}"}, - ) - data = await response.json() - image_url = data["photos"][0]["src"]["large"] - return image_url + async def get_image_from_pexels( + self, prompt: str, api_key: str | None = None, limit: int = 1 + ) -> str | list[str]: + per_page = max(1, min(limit, 80)) + resolved_api_key = (api_key or get_pexels_api_key_env() or "").strip() - async def get_image_from_pixabay(self, prompt: str) -> str: async with aiohttp.ClientSession(trust_env=True) as session: response = await session.get( - f"https://pixabay.com/api/?key={get_pixabay_api_key_env()}&q={prompt}&image_type=photo&per_page=3" + "https://api.pexels.com/v1/search", + params={"query": prompt, "per_page": per_page}, + headers={"Authorization": resolved_api_key} if resolved_api_key else {}, + timeout=aiohttp.ClientTimeout(total=20), ) + + if response.status in {401, 403}: + raise HTTPException(status_code=401, detail="Invalid Pexels API key") + if response.status != 200: + error_text = await response.text() + raise HTTPException( + status_code=502, + detail=f"Pexels request failed: {error_text}", + ) + data = await response.json() - image_url = data["hits"][0]["largeImageURL"] - return image_url + photos = data.get("photos", []) + image_urls = [ + photo.get("src", {}).get("large") + for photo in photos + if photo.get("src", {}).get("large") + ] + + if limit <= 1: + return image_urls[0] if image_urls else "" + return image_urls[:limit] + + async def get_image_from_pixabay( + self, prompt: str, api_key: str | None = None, limit: int = 1 + ) -> str | list[str]: + per_page = max(3, min(limit, 200)) + resolved_api_key = (api_key or get_pixabay_api_key_env() or "").strip() + + async with aiohttp.ClientSession(trust_env=True) as session: + response = await session.get( + "https://pixabay.com/api/", + params={ + "key": resolved_api_key, + "q": prompt[:99], + "image_type": "photo", + "per_page": per_page, + }, + timeout=aiohttp.ClientTimeout(total=20), + ) + + if response.status in {401, 403}: + error_text = await response.text() + raise HTTPException( + status_code=401, + detail=f"Invalid Pixabay API key: {error_text}", + ) + if response.status == 400: + error_text = await response.text() + if "api key" in error_text.lower(): + raise HTTPException( + status_code=401, + detail=f"Invalid Pixabay API key: {error_text}", + ) + raise HTTPException( + status_code=400, + detail=f"Pixabay request invalid: {error_text}", + ) + if response.status != 200: + error_text = await response.text() + raise HTTPException( + status_code=502, + detail=f"Pixabay request failed: {error_text}", + ) + + data = await response.json() + hits = data.get("hits", []) + image_urls = [ + hit.get("largeImageURL") for hit in hits if hit.get("largeImageURL") + ] + + if limit <= 1: + return image_urls[0] if image_urls else "" + return image_urls[:limit] async def generate_image_comfyui(self, prompt: str, output_directory: str) -> str: """ diff --git a/electron/servers/fastapi/tests/test_image_generation.py b/electron/servers/fastapi/tests/test_image_generation.py index bf0db108..56a602b4 100644 --- a/electron/servers/fastapi/tests/test_image_generation.py +++ b/electron/servers/fastapi/tests/test_image_generation.py @@ -398,3 +398,38 @@ class TestImageGenerationEndpoint: asyncio.run(run_test()) + def test_search_stock_images_defaults_to_selected_pixabay(self, client, mock_images_directory): + """ + Test stock image search defaults to IMAGE_PROVIDER when provider query param is omitted + - Sets IMAGE_PROVIDER to pixabay + - Ensures /images/search uses Pixabay instead of returning provider validation error + """ + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}): + with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory): + with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class: + mock_service_instance = Mock() + mock_service_instance.get_image_from_pixabay = AsyncMock( + return_value=["https://example.com/pixabay_image.jpg"] + ) + mock_service_instance.get_image_from_pexels = AsyncMock( + return_value=["https://example.com/pexels_image.jpg"] + ) + mock_service_class.return_value = mock_service_instance + + response = client.get("/images/search?query=business&limit=1") + + assert response.status_code == 200 + assert response.json() == ["https://example.com/pixabay_image.jpg"] + mock_service_instance.get_image_from_pixabay.assert_awaited_once() + mock_service_instance.get_image_from_pexels.assert_not_called() + + def test_search_stock_images_invalid_provider_returns_400(self, client): + """ + Test stock image search validates invalid provider values + - Ensures unsupported providers return HTTP 400 with clear guidance + """ + response = client.get("/images/search?query=business&provider=invalid-provider") + + assert response.status_code == 400 + assert response.json()["detail"] == "provider must be either 'pexels' or 'pixabay'" + diff --git a/electron/servers/nextjs/app/(presentation-generator)/(dashboard)/settings/SettingPage.tsx b/electron/servers/nextjs/app/(presentation-generator)/(dashboard)/settings/SettingPage.tsx index 9f28a065..bf02088a 100644 --- a/electron/servers/nextjs/app/(presentation-generator)/(dashboard)/settings/SettingPage.tsx +++ b/electron/servers/nextjs/app/(presentation-generator)/(dashboard)/settings/SettingPage.tsx @@ -21,6 +21,9 @@ import TextProvider from "./TextProvider"; import ImageProvider from "./ImageProvider"; import PrivacySettings from "./PrivacySettings"; import { IMAGE_PROVIDERS, LLM_PROVIDERS } from "@/utils/providerConstants"; +import { ImagesApi } from "@/app/(presentation-generator)/services/api/images"; + +const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]); // Button state interface interface ButtonState { @@ -72,6 +75,36 @@ const SettingsPage = () => { return 0; }, [downloadingModel?.downloaded, downloadingModel?.size]); + const ensureSelectedStockProviderReady = async (): Promise => { + if (llmConfig.DISABLE_IMAGE_GENERATION) { + return true; + } + + const provider = (llmConfig.IMAGE_PROVIDER || "").toLowerCase(); + if (!STOCK_IMAGE_PROVIDERS.has(provider)) { + return true; + } + + const providerApiKey = + provider === "pexels" ? llmConfig.PEXELS_API_KEY : llmConfig.PIXABAY_API_KEY; + + try { + await ImagesApi.searchStockImages("business", 1, { + provider, + apiKey: providerApiKey, + strictApiKey: true, + }); + return true; + } catch (error: any) { + notify.error( + "Cannot save settings", + error?.message || + `Unable to reach ${provider} with the provided API key. Please verify your settings and try again.` + ); + return false; + } + }; + const handleSaveConfig = async () => { trackEvent(MixpanelEvent.Settings_SaveConfiguration_Button_Clicked, { pathname }); const validationError = getLLMConfigValidationError(llmConfig); @@ -79,6 +112,12 @@ const SettingsPage = () => { notify.error("Cannot save settings", validationError); return; } + + const providerReady = await ensureSelectedStockProviderReady(); + if (!providerReady) { + return; + } + try { setButtonState(prev => ({ ...prev, diff --git a/electron/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx b/electron/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx index 65eae752..29aee4e6 100644 --- a/electron/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx +++ b/electron/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx @@ -1,5 +1,6 @@ "use client"; import React, { useEffect, useState, useRef } from "react"; +import { useSelector } from "react-redux"; import { Sheet, SheetContent, @@ -9,7 +10,7 @@ import { import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Button } from "@/components/ui/button"; import { Textarea } from "@/components/ui/textarea"; -import { Wand2, Upload, Loader2, Delete, Trash } from "lucide-react"; +import { Wand2, Upload, Loader2, Trash } from "lucide-react"; import { cn } from "@/lib/utils"; import { PresentationGenerationApi } from "../services/api/presentation-generation"; import { Skeleton } from "@/components/ui/skeleton"; @@ -18,6 +19,10 @@ import { PreviousGeneratedImagesResponse } from "../services/api/params"; import { trackEvent, MixpanelEvent } from "@/utils/mixpanel"; import { ImagesApi } from "../services/api/images"; import { ImageAssetResponse } from "../services/api/types"; +import { RootState } from "@/store/store"; + +const STOCK_IMAGE_PROVIDERS = new Set(["pexels", "pixabay"]); + interface ImageEditorProps { initialImage: string | null; imageIdx?: number; @@ -39,6 +44,10 @@ const ImageEditor = ({ onFocusPointClick, onImageChange, }: ImageEditorProps) => { + const llmConfig = useSelector((state: RootState) => state.userConfig.llm_config); + const selectedImageProvider = (llmConfig?.IMAGE_PROVIDER || "").toLowerCase(); + const isStockImageProvider = STOCK_IMAGE_PROVIDERS.has(selectedImageProvider); + // State management const [previewImages, setPreviewImages] = useState(initialImage); const [previousGeneratedImages, setPreviousGeneratedImages] = useState< @@ -53,6 +62,7 @@ const ImageEditor = ({ const [isOpen, setIsOpen] = useState(true); const [uploadedImages, setUploadedImages] = useState([]); const [uploadedImagesLoading, setUploadedImagesLoading] = useState(false); + const [stockSearchResults, setStockSearchResults] = useState([]); // Focus point and object fit for image editing const [isFocusPointMode, setIsFocusPointMode] = useState(false); const [focusPoint, setFocusPoint] = useState( @@ -190,18 +200,40 @@ const ImageEditor = ({ setError("Please enter a prompt"); return; } + + const trimmedPrompt = prompt.trim(); + if (!trimmedPrompt) { + setError("Please enter a prompt"); + return; + } + try { setIsGenerating(true); setError(null); trackEvent(MixpanelEvent.ImageEditor_GenerateImage_API_Call); - const response = await PresentationGenerationApi.generateImage({ - prompt: prompt, - }); - setPreviewImages(response); + if (isStockImageProvider) { + const providerApiKey = + selectedImageProvider === "pexels" + ? llmConfig?.PEXELS_API_KEY + : llmConfig?.PIXABAY_API_KEY; + const results = await ImagesApi.searchStockImages(trimmedPrompt, 12, { + provider: selectedImageProvider, + apiKey: providerApiKey, + }); + setStockSearchResults(results); + if (results.length > 0) { + setPreviewImages(results[0]); + } + } else { + const response = await PresentationGenerationApi.generateImage({ + prompt: trimmedPrompt, + }); + setPreviewImages(response); + } } catch (err: any) { console.error("Error in image generation", err); - setError(err.message || "Failed to generate image. Please try again."); + setError(err.message || "Failed to fetch images. Please try again."); } finally { setIsGenerating(false); } @@ -286,7 +318,7 @@ const ImageEditor = ({ - AI Generate + {isStockImageProvider ? "Stock Search" : "AI Generate"} Upload @@ -305,10 +337,14 @@ const ImageEditor = ({

- Image Description + {isStockImageProvider ? "Image Keyword" : "Image Description"}