diff --git a/servers/fastapi/api/lifespan.py b/servers/fastapi/api/lifespan.py index 0d92a2d0..184a55eb 100644 --- a/servers/fastapi/api/lifespan.py +++ b/servers/fastapi/api/lifespan.py @@ -5,12 +5,18 @@ from fastapi import FastAPI from sqlmodel import SQLModel from services import SQL_ENGINE -from utils.model_availability import check_llm_model_availability +from utils.get_env import get_app_data_directory_env +from utils.model_availability import check_llm_and_image_provider_api_or_model_availability @asynccontextmanager async def app_lifespan(_: FastAPI): - os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True) + """ + Lifespan context manager for FastAPI application. + Initializes the application data directory and checks LLM model availability. + + """ + os.makedirs(get_app_data_directory_env(), exist_ok=True) SQLModel.metadata.create_all(SQL_ENGINE) - await check_llm_model_availability() + await check_llm_and_image_provider_api_or_model_availability() yield diff --git a/servers/fastapi/api/v1/ppt/background_tasks.py b/servers/fastapi/api/v1/ppt/background_tasks.py new file mode 100644 index 00000000..e9a604f6 --- /dev/null +++ b/servers/fastapi/api/v1/ppt/background_tasks.py @@ -0,0 +1,59 @@ +import json + +from fastapi import HTTPException + +from models.ollama_model_status import OllamaModelStatus +from services import REDIS_SERVICE +from utils.ollama import pull_ollama_model + + +async def pull_ollama_model_background_task(model: str): + saved_model_status = OllamaModelStatus( + name=model, + status="pulling", + done=False, + ) + log_event_count = 0 + + try: + async for event in pull_ollama_model(model): + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue + + if "completed" in event: + saved_model_status.downloaded = event["completed"] + + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] + + if "status" in event: + saved_model_status.status = event["status"] + + REDIS_SERVICE.set( + f"ollama_models/{model}", + json.dumps(saved_model_status.model_dump(mode="json")), + ) + + except Exception as e: + saved_model_status.status = "error" + saved_model_status.done = True + REDIS_SERVICE.set( + f"ollama_models/{model}", + json.dumps(saved_model_status.model_dump(mode="json")), + ) + raise HTTPException( + status_code=500, + detail=f"Failed to pull model: {e}", + ) + + saved_model_status.done = True + saved_model_status.status = "pulled" + saved_model_status.downloaded = saved_model_status.size + + REDIS_SERVICE.set( + f"ollama_models/{model}", + json.dumps(saved_model_status.model_dump(mode="json")), + ) + + return saved_model_status diff --git a/servers/fastapi/api/v1/ppt/endpoints/custom_llm.py b/servers/fastapi/api/v1/ppt/endpoints/custom_llm.py new file mode 100644 index 00000000..8a44cb22 --- /dev/null +++ b/servers/fastapi/api/v1/ppt/endpoints/custom_llm.py @@ -0,0 +1,14 @@ +from typing import Annotated, List, Optional +from fastapi import APIRouter, Body + +from utils.custom_llm_provider import list_available_custom_models + +CUSTOM_LLM_ROUTER = APIRouter(prefix="/custom_llm", tags=["Custom LLM"]) + + +@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str]) +async def get_available_models( + url: Annotated[Optional[str], Body()] = None, + api_key: Annotated[Optional[str], Body()] = None, +): + return await list_available_custom_models(url, api_key) diff --git a/servers/fastapi/api/v1/ppt/endpoints/files.py b/servers/fastapi/api/v1/ppt/endpoints/files.py index e2f43329..b19e31d0 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/files.py +++ b/servers/fastapi/api/v1/ppt/endpoints/files.py @@ -1,13 +1,13 @@ from http.client import HTTPException import os from typing import Annotated, List, Optional -import uuid from fastapi import APIRouter, Body, File, UploadFile from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES from models.decomposed_file_info import DecomposedFileInfo from services import TEMP_FILE_SERVICE from services.documents_loader import DocumentsLoader +from utils.randomizers import get_random_uuid from utils.validators import validate_files FILES_ROUTER = APIRouter(prefix="/files", tags=["Files"]) @@ -18,7 +18,7 @@ async def upload_files(files: Optional[List[UploadFile]]): if not files: raise HTTPException(400, "Documents are required") - temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4())) + temp_dir = TEMP_FILE_SERVICE.create_temp_dir(get_random_uuid()) validate_files(files, True, True, 50, UPLOAD_ACCEPTED_FILE_TYPES) @@ -39,7 +39,7 @@ async def upload_files(files: Optional[List[UploadFile]]): @FILES_ROUTER.post("/decompose", response_model=List[DecomposedFileInfo]) async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]): - temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4())) + temp_dir = TEMP_FILE_SERVICE.create_temp_dir(get_random_uuid()) txt_files = [] other_files = [] @@ -56,7 +56,7 @@ async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]): response = [] for index, parsed_doc in enumerate(parsed_documents): file_path = TEMP_FILE_SERVICE.create_temp_file_path( - f"{str(uuid.uuid4())}.txt", temp_dir + f"{get_random_uuid()}.txt", temp_dir ) parsed_doc = parsed_doc.replace("
", "\n") with open(file_path, "w") as text_file: diff --git a/servers/fastapi/api/v1/ppt/endpoints/images.py b/servers/fastapi/api/v1/ppt/endpoints/images.py index 64acbde8..5af10e88 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/images.py +++ b/servers/fastapi/api/v1/ppt/endpoints/images.py @@ -1,6 +1,8 @@ from fastapi import APIRouter from models.image_prompt import ImagePrompt +from models.sql.image_asset import ImageAsset +from services.database import get_sql_session from services.image_generation_service import ImageGenerationService from utils.asset_directory_utils import get_images_directory @@ -13,4 +15,12 @@ async def generate_image(prompt: str): image_prompt = ImagePrompt(prompt=prompt) image_generation_service = ImageGenerationService(images_directory) - return await image_generation_service.generate_image(image_prompt) + image = await image_generation_service.generate_image(image_prompt) + if not isinstance(image, ImageAsset): + return image + + with get_sql_session() as sql_session: + sql_session.add(image) + sql_session.commit() + + return image.path diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py new file mode 100644 index 00000000..13e334a5 --- /dev/null +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -0,0 +1,72 @@ +import json +from typing import List +from fastapi import APIRouter, BackgroundTasks, HTTPException + +from api.v1.ppt.background_tasks import pull_ollama_model_background_task +from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS +from models.ollama_model_metadata import OllamaModelMetadata +from models.ollama_model_status import OllamaModelStatus +from services import REDIS_SERVICE +from utils.ollama import list_pulled_ollama_models + +OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"]) + + +@OLLAMA_ROUTER.get("/models/supported", response_model=List[OllamaModelMetadata]) +def get_supported_models(): + return SUPPORTED_OLLAMA_MODELS.values() + + +@OLLAMA_ROUTER.get("/models/available", response_model=List[OllamaModelStatus]) +async def get_available_models(): + return await list_pulled_ollama_models() + + +@OLLAMA_ROUTER.get("/model/pull", response_model=OllamaModelStatus) +async def pull_model(model: str, background_tasks: BackgroundTasks): + + if model not in SUPPORTED_OLLAMA_MODELS: + raise HTTPException( + status_code=400, + detail=f"Model {model} is not supported", + ) + + try: + pulled_models = await list_pulled_ollama_models() + filtered_models = [ + pulled_model for pulled_model in pulled_models if pulled_model.name == model + ] + if filtered_models: + return filtered_models[0] + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to check pulled models: {e}", + ) + + saved_model_status = REDIS_SERVICE.get(f"ollama_models/{model}") + + # If the model is being pulled, return the model + if saved_model_status: + saved_model_status_json = json.loads(saved_model_status) + # If the model is being pulled, return the model + # ? If the model status is pulled in redis but was not found while listing pulled models, + # ? it means the model was deleted and we need to pull it again + if ( + saved_model_status_json["status"] == "error" + or saved_model_status_json["status"] == "pulled" + ): + REDIS_SERVICE.delete(f"ollama_models/{model}") + else: + return saved_model_status_json + + # If the model is not being pulled, pull the model + background_tasks.add_task(pull_ollama_model_background_task, model) + + return OllamaModelStatus( + name=model, + status="pulling", + done=False, + ) diff --git a/servers/fastapi/api/v1/ppt/router.py b/servers/fastapi/api/v1/ppt/router.py index ddff5676..15fb1b52 100644 --- a/servers/fastapi/api/v1/ppt/router.py +++ b/servers/fastapi/api/v1/ppt/router.py @@ -1,8 +1,10 @@ from fastapi import APIRouter +from api.v1.ppt.endpoints.custom_llm import CUSTOM_LLM_ROUTER from api.v1.ppt.endpoints.files import FILES_ROUTER from api.v1.ppt.endpoints.icons import ICONS_ROUTER from api.v1.ppt.endpoints.images import IMAGES_ROUTER +from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER @@ -14,3 +16,5 @@ API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER) API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER) API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER) API_V1_PPT_ROUTER.include_router(ICONS_ROUTER) +API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER) +API_V1_PPT_ROUTER.include_router(CUSTOM_LLM_ROUTER) diff --git a/servers/fastapi/constants/supported_ollama_models.py b/servers/fastapi/constants/supported_ollama_models.py index 2589d54e..a46b5774 100644 --- a/servers/fastapi/constants/supported_ollama_models.py +++ b/servers/fastapi/constants/supported_ollama_models.py @@ -1,14 +1,14 @@ from models.ollama_model_metadata import OllamaModelMetadata -SUPPORTED_LLAMA_MODELS = { +SUPPORTED_OLLAMA_MODELS = { "llama3:8b": OllamaModelMetadata( label="Llama 3:8b", value="llama3:8b", description="❌ Graphs not supported.", size="4.7GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3:70b": OllamaModelMetadata( label="Llama 3:70b", @@ -16,7 +16,7 @@ SUPPORTED_LLAMA_MODELS = { description="✅ Graphs supported.", size="40GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3.1:8b": OllamaModelMetadata( label="Llama 3.1:8b", @@ -24,7 +24,7 @@ SUPPORTED_LLAMA_MODELS = { description="❌ Graphs not supported.", size="4.9GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3.1:70b": OllamaModelMetadata( label="Llama 3.1:70b", @@ -32,7 +32,7 @@ SUPPORTED_LLAMA_MODELS = { description="✅ Graphs supported.", size="43GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3.1:405b": OllamaModelMetadata( label="Llama 3.1:405b", @@ -40,7 +40,7 @@ SUPPORTED_LLAMA_MODELS = { description="✅ Graphs supported.", size="243GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3.2:1b": OllamaModelMetadata( label="Llama 3.2:1b", @@ -48,7 +48,7 @@ SUPPORTED_LLAMA_MODELS = { description="❌ Graphs not supported.", size="1.3GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3.2:3b": OllamaModelMetadata( label="Llama 3.2:3b", @@ -56,7 +56,7 @@ SUPPORTED_LLAMA_MODELS = { description="❌ Graphs not supported.", size="2GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama3.3:70b": OllamaModelMetadata( label="Llama 3.3:70b", @@ -64,7 +64,7 @@ SUPPORTED_LLAMA_MODELS = { description="✅ Graphs supported.", size="43GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama4:16x17b": OllamaModelMetadata( label="Llama 4:16x17b", @@ -72,7 +72,7 @@ SUPPORTED_LLAMA_MODELS = { description="✅ Graphs supported.", size="67GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), "llama4:128x17b": OllamaModelMetadata( label="Llama 4:128x17b", @@ -80,7 +80,7 @@ SUPPORTED_LLAMA_MODELS = { description="✅ Graphs supported.", size="245GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/meta.png", + icon="/static/icons/meta.png", ), } @@ -91,7 +91,7 @@ SUPPORTED_GEMMA_MODELS = { description="❌ Graphs not supported.", size="815MB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/gemma.png", + icon="/static/icons/gemma.png", ), "gemma3:4b": OllamaModelMetadata( label="Gemma 3:4b", @@ -99,7 +99,7 @@ SUPPORTED_GEMMA_MODELS = { description="❌ Graphs not supported.", size="3.3GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/gemma.png", + icon="/static/icons/gemma.png", ), "gemma3:12b": OllamaModelMetadata( label="Gemma 3:12b", @@ -107,7 +107,7 @@ SUPPORTED_GEMMA_MODELS = { description="❌ Graphs not supported.", size="8.1GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/gemma.png", + icon="/static/icons/gemma.png", ), "gemma3:27b": OllamaModelMetadata( label="Gemma 3:27b", @@ -115,7 +115,7 @@ SUPPORTED_GEMMA_MODELS = { description="✅ Graphs supported.", size="17GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/gemma.png", + icon="/static/icons/gemma.png", ), } @@ -126,7 +126,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="❌ Graphs not supported.", size="1.1GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), "deepseek-r1:7b": OllamaModelMetadata( label="DeepSeek R1:7b", @@ -134,7 +134,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="❌ Graphs not supported.", size="4.7GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), "deepseek-r1:8b": OllamaModelMetadata( label="DeepSeek R1:8b", @@ -142,7 +142,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="❌ Graphs not supported.", size="5.2GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), "deepseek-r1:14b": OllamaModelMetadata( label="DeepSeek R1:14b", @@ -150,7 +150,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="❌ Graphs not supported.", size="9GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), "deepseek-r1:32b": OllamaModelMetadata( label="DeepSeek R1:32b", @@ -158,7 +158,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="✅ Graphs supported.", size="20GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), "deepseek-r1:70b": OllamaModelMetadata( label="DeepSeek R1:70b", @@ -166,7 +166,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="✅ Graphs supported.", size="43GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), "deepseek-r1:671b": OllamaModelMetadata( label="DeepSeek R1:671b", @@ -174,7 +174,7 @@ SUPPORTED_DEEPSEEK_MODELS = { description="✅ Graphs supported.", size="404GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/deepseek.png", + icon="/static/icons/deepseek.png", ), } @@ -185,7 +185,7 @@ SUPPORTED_QWEN_MODELS = { description="❌ Graphs not supported.", size="523MB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:1.7b": OllamaModelMetadata( label="Qwen 3:1.7b", @@ -193,7 +193,7 @@ SUPPORTED_QWEN_MODELS = { description="❌ Graphs not supported.", size="1.4GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:4b": OllamaModelMetadata( label="Qwen 3:4b", @@ -201,7 +201,7 @@ SUPPORTED_QWEN_MODELS = { description="❌ Graphs not supported.", size="2.6GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:8b": OllamaModelMetadata( label="Qwen 3:8b", @@ -209,7 +209,7 @@ SUPPORTED_QWEN_MODELS = { description="❌ Graphs not supported.", size="5.2GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:14b": OllamaModelMetadata( label="Qwen 3:14b", @@ -217,7 +217,7 @@ SUPPORTED_QWEN_MODELS = { description="❌ Graphs not supported.", size="9.3GB", supports_graph=False, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:30b": OllamaModelMetadata( label="Qwen 3:30b", @@ -225,7 +225,7 @@ SUPPORTED_QWEN_MODELS = { description="✅ Graphs supported.", size="19GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:32b": OllamaModelMetadata( label="Qwen 3:32b", @@ -233,7 +233,7 @@ SUPPORTED_QWEN_MODELS = { description="✅ Graphs supported.", size="20GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), "qwen3:235b": OllamaModelMetadata( label="Qwen 3:235b", @@ -241,12 +241,12 @@ SUPPORTED_QWEN_MODELS = { description="✅ Graphs supported.", size="142GB", supports_graph=True, - icon="/static/servers/fastapi/assets/icons/qwen.png", + icon="/static/icons/qwen.png", ), } SUPPORTED_OLLAMA_MODELS = { - **SUPPORTED_LLAMA_MODELS, + **SUPPORTED_OLLAMA_MODELS, **SUPPORTED_GEMMA_MODELS, **SUPPORTED_DEEPSEEK_MODELS, **SUPPORTED_QWEN_MODELS, diff --git a/servers/fastapi/enums/image_provider.py b/servers/fastapi/enums/image_provider.py new file mode 100644 index 00000000..2c7b3bb2 --- /dev/null +++ b/servers/fastapi/enums/image_provider.py @@ -0,0 +1,7 @@ +from enum import Enum + +class ImageProvider(Enum): + PEXELS = "pexels" + PIXABAY = "pixabay" + GEMINI_FLASH = "gemini_flash" + DALLE3 = "dall-e-3" diff --git a/servers/fastapi/models/pptx_models.py b/servers/fastapi/models/pptx_models.py index ba565ca5..12a8aa91 100644 --- a/servers/fastapi/models/pptx_models.py +++ b/servers/fastapi/models/pptx_models.py @@ -144,6 +144,7 @@ class PptxConnectorModel(PptxShapeModel): class PptxSlideModel(BaseModel): + background: Optional[PptxFillModel] = None shapes: List[ PptxTextBoxModel | PptxAutoShapeBoxModel diff --git a/servers/fastapi/models/user_config.py b/servers/fastapi/models/user_config.py index e04ab5e2..930aa1e5 100644 --- a/servers/fastapi/models/user_config.py +++ b/servers/fastapi/models/user_config.py @@ -12,3 +12,5 @@ class UserConfig(BaseModel): CUSTOM_LLM_API_KEY: Optional[str] = None CUSTOM_MODEL: Optional[str] = None PEXELS_API_KEY: Optional[str] = None + IMAGE_PROVIDER: Optional[str] = None + PIXABAY_API_KEY: Optional[str] = None diff --git a/servers/fastapi/requirements.txt b/servers/fastapi/requirements.txt index 3edb9b66..afeb38a0 100644 --- a/servers/fastapi/requirements.txt +++ b/servers/fastapi/requirements.txt @@ -3,6 +3,7 @@ aiohttp==3.12.14 aiosignal==1.4.0 annotated-types==0.7.0 anyio==4.9.0 +async-timeout==5.0.1 attrs==25.3.0 cachetools==5.5.2 certifi==2025.7.14 @@ -55,6 +56,7 @@ python-dotenv==1.1.1 python-multipart==0.0.20 python-pptx==1.0.2 PyYAML==6.0.2 +redis==6.2.0 requests==2.32.4 rich==14.0.0 rich-toolkit==0.14.8 diff --git a/servers/fastapi/services/__init__.py b/servers/fastapi/services/__init__.py index 89bac591..56843e2b 100644 --- a/servers/fastapi/services/__init__.py +++ b/servers/fastapi/services/__init__.py @@ -1,6 +1,8 @@ +from services.redis_service import RedisService from services.temp_file_service import TempFileService from services.database import sql_engine TEMP_FILE_SERVICE = TempFileService() SQL_ENGINE = sql_engine +REDIS_SERVICE = RedisService() diff --git a/servers/fastapi/services/image_generation_service.py b/servers/fastapi/services/image_generation_service.py index 4acff9f1..4299ced1 100644 --- a/servers/fastapi/services/image_generation_service.py +++ b/servers/fastapi/services/image_generation_service.py @@ -8,10 +8,13 @@ from models.image_prompt import ImagePrompt from models.sql.image_asset import ImageAsset from utils.download_helpers import download_file from utils.get_env import get_pexels_api_key_env -from utils.llm_provider import ( - get_llm_client, - is_google_selected, - is_openai_selected, +from utils.get_env import get_pixabay_api_key_env +from utils.llm_provider import get_llm_client +from utils.image_provider import ( + is_pixels_selected, + is_pixabay_selected, + is_gemini_flash_selected, + is_dalle3_selected, ) @@ -19,32 +22,46 @@ class ImageGenerationService: def __init__(self, output_directory: str): self.output_directory = output_directory - - self.use_pexels = False - if get_pexels_api_key_env(): - self.use_pexels = True - self.image_gen_func = self.get_image_gen_func() def get_image_gen_func(self): - if self.use_pexels: + if is_pixabay_selected(): + return self.get_image_from_pixabay + elif is_pixels_selected(): return self.get_image_from_pexels - elif is_google_selected(): + elif is_gemini_flash_selected(): return self.generate_image_google - elif is_openai_selected(): + elif is_dalle3_selected(): return self.generate_image_openai return None + def is_stock_provider_selected(self): + return is_pixels_selected() or is_pixabay_selected() + async def generate_image(self, prompt: ImagePrompt) -> str | ImageAsset: + """ + Generates an image based on the provided prompt. + - If no image generation function is available, returns a placeholder image. + - If the stock provider is selected, it uses the prompt directly, + otherwise it uses the full image prompt with theme. + - Output Directory is used for saving the generated image not the stock provider. + """ if not self.image_gen_func: print("No image generation function found. Using placeholder image.") return "/static/images/placeholder.jpg" - image_prompt = prompt.get_image_prompt(not self.use_pexels) + image_prompt = prompt.get_image_prompt( + with_theme=not self.is_stock_provider_selected() + ) print(f"Request - Generating Image for {image_prompt}") try: - image_path = await self.image_gen_func(image_prompt, self.output_directory) + if self.is_stock_provider_selected(): + image_path = await self.image_gen_func(image_prompt) + else: + image_path = await self.image_gen_func( + image_prompt, self.output_directory + ) if image_path: if image_path.startswith("http"): return image_path @@ -102,3 +119,12 @@ class ImageGenerationService: data = await response.json() image_url = data["photos"][0]["src"]["large"] return image_url + + async def get_image_from_pixabay(self, prompt: str) -> str: + async with aiohttp.ClientSession() as session: + response = await session.get( + f"https://pixabay.com/api/?key={get_pixabay_api_key_env()}&q={prompt}&image_type=photo&per_page=1" + ) + data = await response.json() + image_url = data["hits"][0]["largeImageURL"] + return image_url diff --git a/servers/fastapi/services/pptx_presentation_creator.py b/servers/fastapi/services/pptx_presentation_creator.py index af21f0fd..875951b5 100644 --- a/servers/fastapi/services/pptx_presentation_creator.py +++ b/servers/fastapi/services/pptx_presentation_creator.py @@ -108,6 +108,9 @@ class PptxPresentationCreator: def add_and_populate_slide(self, slide_model: PptxSlideModel): slide = self._ppt.slides.add_slide(self._ppt.slide_layouts[BLANK_SLIDE_LAYOUT]) + if slide_model.background: + self.apply_fill_to_shape(slide.background, slide_model.background) + for shape_model in slide_model.shapes: model_type = type(shape_model) diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py new file mode 100644 index 00000000..f2e3d8c9 --- /dev/null +++ b/servers/fastapi/services/redis_service.py @@ -0,0 +1,115 @@ +from typing import Any, Optional +import redis +from redis.exceptions import RedisError + +from utils.get_env import ( + get_redis_db_env, + get_redis_host_env, + get_redis_password_env, + get_redis_port_env, +) + + +class RedisService: + def __init__(self): + self.redis_host = get_redis_host_env() or "localhost" + self.redis_port = int(get_redis_port_env() or "6379") + self.redis_db = int(get_redis_db_env() or "0") + self.redis_password = get_redis_password_env() or None + self.client = self._create_client() + + def _create_client(self) -> redis.Redis: + return redis.Redis( + host=self.redis_host, + port=self.redis_port, + db=self.redis_db, + password=self.redis_password, + decode_responses=True, + ) + + def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: + try: + return self.client.set(key, value, ex=expire) + except RedisError: + return False + + def get(self, key: str) -> Optional[str]: + try: + return self.client.get(key) + except RedisError: + return None + + def delete(self, key: str) -> bool: + try: + return bool(self.client.delete(key)) + except RedisError: + return False + + def exists(self, key: str) -> bool: + try: + return bool(self.client.exists(key)) + except RedisError: + return False + + def set_hash(self, name: str, mapping: dict) -> bool: + try: + return self.client.hmset(name, mapping) + except RedisError: + return False + + def get_hash(self, name: str) -> Optional[dict]: + try: + return self.client.hgetall(name) + except RedisError: + return None + + def delete_hash(self, name: str, *fields: str) -> int: + try: + return self.client.hdel(name, *fields) + except RedisError: + return 0 + + def set_list(self, name: str, values: list) -> bool: + try: + self.client.delete(name) + if values: + self.client.rpush(name, *values) + return True + except RedisError: + return False + + def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: + try: + return self.client.lrange(name, start, end) + except RedisError: + return None + + def add_to_set(self, name: str, *values: str) -> int: + try: + return self.client.sadd(name, *values) + except RedisError: + return 0 + + def get_set(self, name: str) -> Optional[set]: + try: + return self.client.smembers(name) + except RedisError: + return None + + def remove_from_set(self, name: str, *values: str) -> int: + try: + return self.client.srem(name, *values) + except RedisError: + return 0 + + def clear(self) -> bool: + try: + return self.client.flushdb() + except RedisError: + return False + + def close(self): + try: + self.client.close() + except RedisError: + pass diff --git a/servers/fastapi/services/temp_file_service.py b/servers/fastapi/services/temp_file_service.py index 31a39035..f4c59cf5 100644 --- a/servers/fastapi/services/temp_file_service.py +++ b/servers/fastapi/services/temp_file_service.py @@ -2,11 +2,13 @@ import os import uuid from typing import Optional, Union +from utils.get_env import get_temp_directory_env + class TempFileService: def __init__(self): - self.base_dir = os.getenv("TEMP_DIRECTORY") + self.base_dir = get_temp_directory_env() # TODO: Uncomment this when we want to cleanup the base dir on startup # self.cleanup_base_dir() os.makedirs(self.base_dir, exist_ok=True) diff --git a/servers/fastapi/tests/test_image_generation.py b/servers/fastapi/tests/test_image_generation.py new file mode 100644 index 00000000..bf0db108 --- /dev/null +++ b/servers/fastapi/tests/test_image_generation.py @@ -0,0 +1,400 @@ +import pytest +import asyncio +import os +from unittest.mock import Mock, patch, AsyncMock +import httpx +from fastapi.testclient import TestClient +from fastapi import FastAPI +from api.v1.ppt.endpoints.images import IMAGES_ROUTER +from models.image_prompt import ImagePrompt +from services.image_generation_service import ImageGenerationService +from models.sql.image_asset import ImageAsset + + +class TestImageGenerationService: + """ + Testing the image Generation Service + """ + + @pytest.fixture + def mock_images_directory(self, tmp_path): + """ + Creates new images directory for every test case we run + """ + images_dir = tmp_path / "images" + images_dir.mkdir() + return str(images_dir) + + @pytest.fixture + def sample_image_prompt(self): + """ + Creates a sample ImagePrompt for testing + """ + return ImagePrompt(prompt="A beautiful sunset over mountains") + + def test_image_generation_service_initialization(self, mock_images_directory): + """ + Test initialization of ImageGenerationService with output directory + - Checks if the output directory is set correctly + - Checks if the image generation function is set based on environment variable + """ + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}): + service = ImageGenerationService(mock_images_directory) + assert service.output_directory == mock_images_directory + assert service.image_gen_func is not None or service.image_gen_func is None + + def test_get_image_gen_func_pixabay_selected(self, mock_images_directory): + """ + Testing the function selection when Pixabay is selected + - Checks if the correct function is selected based on environment variable + - Ensures that the function is set to get_image_from_pixabay when Pixabay is selected + """ + with patch('services.image_generation_service.is_pixabay_selected', return_value=True): + with patch('services.image_generation_service.is_pixels_selected', return_value=False): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=False): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}): + service = ImageGenerationService(mock_images_directory) + assert service.image_gen_func == service.get_image_from_pixabay + + def test_get_image_gen_func_pexels_selected(self, mock_images_directory): + """ + Test function selection when Pexels is selected + - Checks if the correct function is selected based on environment variable + - Ensures that the function is set to get_image_from_pexels when Pexels is selected + """ + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch('services.image_generation_service.is_pixels_selected', return_value=True): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=False): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}): + service = ImageGenerationService(mock_images_directory) + assert service.image_gen_func == service.get_image_from_pexels + + def test_get_image_gen_func_dalle3_selected(self, mock_images_directory): + """ + Test function selection when DALL-E 3 is selected + - Checks if the correct function is selected based on environment variable + - Ensures that the function is set to generate_image_openai when DALL-E 3 is selected + """ + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch('services.image_generation_service.is_pixels_selected', return_value=False): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=True): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}): + service = ImageGenerationService(mock_images_directory) + assert service.image_gen_func == service.generate_image_openai + + def test_is_stock_provider_selected(self, mock_images_directory): + """ + Test if stock provider is selected based on environment variable + - Checks if the stock provider is selected correctly based on environment variable + - Ensures that is_stock_provider_selected returns True for Pexels or Pixabay + """ + with patch('services.image_generation_service.is_pixels_selected', return_value=True): + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}): + service = ImageGenerationService(mock_images_directory) + assert service.is_stock_provider_selected() is True + + with patch('services.image_generation_service.is_pixels_selected', return_value=False): + with patch('services.image_generation_service.is_pixabay_selected', return_value=True): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay"}): + service = ImageGenerationService(mock_images_directory) + assert service.is_stock_provider_selected() is True + + with patch('services.image_generation_service.is_pixels_selected', return_value=False): + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}): + service = ImageGenerationService(mock_images_directory) + assert service.is_stock_provider_selected() is False + + def test_generate_image_with_pexels_success(self, mock_images_directory, sample_image_prompt): + """ + Test successful image generation with Pexels provider + - Mocks the Pexels API to return a valid image URL + - Ensures that the image generation function returns the expected URL + - Checks if the image generation function is called with the correct prompt + """ + async def run_test(): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_key"}): + with patch('services.image_generation_service.is_pixels_selected', return_value=True): + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=False): + service = ImageGenerationService(mock_images_directory) + + mock_response = AsyncMock() + mock_response.json = AsyncMock(return_value={ + "photos": [{ + "src": { + "large": "https://example.com/image.jpg" + } + }] + }) + + mock_session = AsyncMock() + mock_session.get = AsyncMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch('aiohttp.ClientSession', return_value=mock_session): + result = await service.generate_image(sample_image_prompt) + assert result == "https://example.com/image.jpg" + + asyncio.run(run_test()) + + def test_generate_image_with_dalle3_success(self, mock_images_directory, sample_image_prompt): + """ + Test successful image generation with DALL-E 3 provider + - Mocks the OpenAI client to return a valid image URL + - Ensures that the image generation function returns the expected URL + - Checks if the image generation function is called with the correct prompt + """ + async def run_test(): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "dall-e-3"}): + with patch('services.image_generation_service.is_pixels_selected', return_value=False): + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=True): + service = ImageGenerationService(mock_images_directory) + + # Create a real test file + test_image_path = f"{mock_images_directory}/test_image.jpg" + with open(test_image_path, 'w') as f: + f.write("fake image content") + + # Mock generate_image_openai to return the test file path + async def mock_openai_generate(prompt, output_dir): + return test_image_path + + service.generate_image_openai = mock_openai_generate + + result = await service.generate_image(sample_image_prompt) + + # Should return ImageAsset for AI providers + assert isinstance(result, ImageAsset) + assert result.path == test_image_path + assert result.extras["prompt"] == sample_image_prompt.prompt + + def test_generate_image_no_provider_selected(self, mock_images_directory, sample_image_prompt): + """ + Test generate_image when no provider is selected + - Mocks the environment variable to simulate no provider selected + - Ensures that the function returns a placeholder image path + - Checks if the image generation function is called with the correct prompt + """ + async def run_test(): + with patch('services.image_generation_service.is_pixels_selected', return_value=False): + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=False): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}): + service = ImageGenerationService(mock_images_directory) + + result = await service.generate_image(sample_image_prompt) + + # Should return placeholder + assert result == "/static/images/placeholder.jpg" + + asyncio.run(run_test()) + + def test_generate_image_provider_error(self, mock_images_directory, sample_image_prompt): + """ + Test generate_image when provider function raises an error + - Mocks the Pexels API to raise an exception + - Ensures that the function returns a placeholder image path + - Checks if the image generation function is called with the correct prompt + """ + async def run_test(): + with patch('services.image_generation_service.is_pixels_selected', return_value=True): + with patch('services.image_generation_service.is_pixabay_selected', return_value=False): + with patch('services.image_generation_service.is_gemini_flash_selected', return_value=False): + with patch('services.image_generation_service.is_dalle3_selected', return_value=False): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels"}): + service = ImageGenerationService(mock_images_directory) + + async def mock_pexels_error(*args, **kwargs): + raise Exception("API Error") + + service.get_image_from_pexels = mock_pexels_error + + result = await service.generate_image(sample_image_prompt) + + assert result == "/static/images/placeholder.jpg" + + asyncio.run(run_test()) + + def test_get_image_from_pexels_real_function(self, mock_images_directory): + """T + Test REAL Pexels function with mocked HTTP call + - Mocks the Pexels API to return a valid image URL + - Ensures that the function returns the expected URL + - Checks if the HTTP call is made with the correct parameters + """ + async def run_test(): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pexels", "PEXELS_API_KEY": "test_pexels_key"}): + service = ImageGenerationService(mock_images_directory) + + mock_response = AsyncMock() + mock_response.json = AsyncMock(return_value={ + "photos": [{ + "src": { + "large": "https://example.com/pexels_image.jpg" + } + }] + }) + + mock_session = AsyncMock() + mock_session.get = AsyncMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch('aiohttp.ClientSession', return_value=mock_session): + result = await service.get_image_from_pexels("sunset") + + assert result == "https://example.com/pexels_image.jpg" + mock_session.get.assert_called_once() + + asyncio.run(run_test()) + + def test_get_image_from_pixabay_real_function(self, mock_images_directory): + """ + Test REAL Pixabay function with mocked HTTP call + - Mocks the Pixabay API to return a valid image URL + - Ensures that the function returns the expected URL + - Checks if the HTTP call is made with the correct parameters + """ + async def run_test(): + with patch.dict(os.environ, {"IMAGE_PROVIDER": "pixabay", "PIXABAY_API_KEY": "test_pixabay_key"}): + service = ImageGenerationService(mock_images_directory) + + mock_response = AsyncMock() + mock_response.json = AsyncMock(return_value={ + "hits": [{ + "largeImageURL": "https://example.com/pixabay_image.jpg" + }] + }) + + mock_session = AsyncMock() + mock_session.get = AsyncMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch('aiohttp.ClientSession', return_value=mock_session): + result = await service.get_image_from_pixabay("sunset") + + assert result == "https://example.com/pixabay_image.jpg" + mock_session.get.assert_called_once() + + asyncio.run(run_test()) + + +class TestImageGenerationEndpoint: + """ + Testing the Image Generation API Endpoint + """ + + @pytest.fixture + def app(self): + """Create FastAPI app with the images router""" + app = FastAPI() + app.include_router(IMAGES_ROUTER) + return app + + @pytest.fixture + def client(self, app): + """Create test client""" + return TestClient(app) + + @pytest.fixture + def mock_images_directory(self, tmp_path): + """Mock images directory""" + images_dir = tmp_path / "images" + images_dir.mkdir() + return str(images_dir) + + def test_generate_image_endpoint_success_stock_provider(self, client, mock_images_directory): + """ + Test successful image generation via API endpoint with stock provider + - Mocks the ImageGenerationService to return a stock image URL + - Ensures that the endpoint returns the expected URL + - Checks if the image generation function is called with the correct prompt + """ + test_prompt = "A beautiful sunset over mountains" + + with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory): + with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class: + mock_service_instance = Mock() + mock_service_instance.generate_image = AsyncMock(return_value="https://example.com/stock_image.jpg") + mock_service_class.return_value = mock_service_instance + response = client.get(f"/images/generate?prompt={test_prompt}") + assert response.status_code == 200 + + def test_generate_image_endpoint_success_ai_provider(self, client, mock_images_directory): + """ + Test successful image generation via API endpoint with AI provider + - Mocks the ImageGenerationService to return an ImageAsset object + - Ensures that the endpoint returns the expected ImageAsset object + - Checks if the image generation function is called with the correct prompt + """ + test_prompt = "A beautiful sunset over mountains" + + test_image_asset = ImageAsset( + path=f"{mock_images_directory}/test_image.jpg", + extras={"prompt": test_prompt, "theme_prompt": "professional"} + ) + + with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory): + with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class: + mock_service_instance = Mock() + mock_service_instance.generate_image = AsyncMock(return_value=test_image_asset) + mock_service_class.return_value = mock_service_instance + + response = client.get(f"/images/generate?prompt={test_prompt}") + + assert response.status_code == 200 + + def test_generate_image_endpoint_placeholder_response(self, client, mock_images_directory): + """ + Test endpoint returns placeholder image when no provider is selected + - Mocks the ImageGenerationService to return a placeholder image path + - Ensures that the endpoint returns the placeholder image path + - Checks if the image generation function is called with the correct prompt + """ + test_prompt = "Test prompt" + + with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory): + with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class: + mock_service_instance = Mock() + mock_service_instance.generate_image = AsyncMock(return_value="/static/images/placeholder.jpg") + mock_service_class.return_value = mock_service_instance + + response = client.get(f"/images/generate?prompt={test_prompt}") + + assert response.status_code == 200 + + def test_generate_image_endpoint_with_async_client(self, mock_images_directory): + """ + Test the image generation endpoint using an async client + - Mocks the ImageGenerationService to return a valid image URL + - Ensures that the endpoint returns the expected URL + - Checks if the image generation function is called with the correct prompt + """ + async def run_test(): + app = FastAPI() + app.include_router(IMAGES_ROUTER) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac: + with patch('api.v1.ppt.endpoints.images.get_images_directory', return_value=mock_images_directory): + with patch('api.v1.ppt.endpoints.images.ImageGenerationService') as mock_service_class: + mock_service_instance = Mock() + mock_service_instance.generate_image = AsyncMock(return_value="https://example.com/image.jpg") + mock_service_class.return_value = mock_service_instance + + response = await ac.get("/images/generate?prompt=test") + assert response.status_code == 200 + + asyncio.run(run_test()) + diff --git a/servers/fastapi/utils/get_env.py b/servers/fastapi/utils/get_env.py index 6544498b..265e3698 100644 --- a/servers/fastapi/utils/get_env.py +++ b/servers/fastapi/utils/get_env.py @@ -55,3 +55,27 @@ def get_custom_model_env(): def get_pexels_api_key_env(): return os.getenv("PEXELS_API_KEY") + + +def get_image_provider_env(): + return os.getenv("IMAGE_PROVIDER") + + +def get_pixabay_api_key_env(): + return os.getenv("PIXABAY_API_KEY") + + +def get_redis_host_env(): + return os.getenv("REDIS_HOST") + + +def get_redis_port_env(): + return os.getenv("REDIS_PORT") + + +def get_redis_db_env(): + return os.getenv("REDIS_DB") + + +def get_redis_password_env(): + return os.getenv("REDIS_PASSWORD") diff --git a/servers/fastapi/utils/image_provider.py b/servers/fastapi/utils/image_provider.py new file mode 100644 index 00000000..8a2f01ff --- /dev/null +++ b/servers/fastapi/utils/image_provider.py @@ -0,0 +1,47 @@ +from enums.image_provider import ImageProvider +from utils.get_env import ( + get_google_api_key_env, + get_image_provider_env, + get_openai_api_key_env, + get_pexels_api_key_env, + get_pixabay_api_key_env, +) + + +def is_pixels_selected() -> bool: + return ImageProvider.PEXELS == get_selected_image_provider() + + +def is_pixabay_selected() -> bool: + return ImageProvider.PIXABAY == get_selected_image_provider() + + +def is_gemini_flash_selected() -> bool: + return ImageProvider.GEMINI_FLASH == get_selected_image_provider() + + +def is_dalle3_selected() -> bool: + return ImageProvider.DALLE3 == get_selected_image_provider() + + +def get_selected_image_provider() -> ImageProvider: + """ + Get the selected image provider from environment variables. + Returns: + ImageProvider: The selected image provider. + """ + return ImageProvider(get_image_provider_env()) + + +def get_image_provider_api_key() -> str: + selected_image_provider = get_selected_image_provider() + if selected_image_provider == ImageProvider.PEXELS: + return get_pexels_api_key_env() + elif selected_image_provider == ImageProvider.PIXABAY: + return get_pixabay_api_key_env() + elif selected_image_provider == ImageProvider.GEMINI_FLASH: + return get_google_api_key_env() + elif selected_image_provider == ImageProvider.DALLE3: + return get_openai_api_key_env() + else: + raise ValueError(f"Invalid image provider: {selected_image_provider}") diff --git a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py index f3fac493..3d0ac08f 100644 --- a/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py +++ b/servers/fastapi/utils/llm_calls/generate_presentation_outlines.py @@ -11,29 +11,7 @@ from utils.llm_provider import ( is_google_selected, ) -# system_prompt = """ -# Create a presentation based on the provided prompt, number of slides, output language, and additional informational details. -# Format the output in the specified JSON schema with structured markdown content. -# # Steps - -# 1. Identify key points from the provided prompt, including the topic, number of slides, output language, and additional content directions. -# 2. Create a concise and descriptive title reflecting the main topic, adhering to the specified language. -# 3. Generate a clear title for each slide. -# 4. Develop comprehensive content using markdown structure: -# * Use bullet points (- or *) for lists. -# * Use **bold** for emphasis, *italic* for secondary emphasis, and `code` for technical terms. -# 5. Provide important points from prompt as notes. - -# # Notes -# - Content must be generated for every slide. -# - Images or Icons information provided in **Input** must be included in the **notes**. -# - Notes should cleary define if it is for specific slide or for the presentation. -# - Slide **body** should not contain slide **title**. -# - Slide **title** should not contain "Slide 1", "Slide 2", etc. -# - Slide **title** should not be in markdown format. -# - There must be exact **Number of Slides** as specified. -# """ system_prompt = """ You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content. @@ -183,13 +161,7 @@ async def generate_ppt_outline( async with client.beta.chat.completions.stream( model=model, messages=get_prompt_template(prompt, n_slides, language, content), - response_format={ - "type": "json_schema", - "json_schema": { - "name": "PresentationOutline", - "schema": response_model.model_json_schema(), - }, - }, + response_format=response_model, ) as stream: async for event in stream: if isinstance(event, ContentDeltaEvent): diff --git a/servers/fastapi/utils/llm_provider.py b/servers/fastapi/utils/llm_provider.py index bb069773..7999ad4a 100644 --- a/servers/fastapi/utils/llm_provider.py +++ b/servers/fastapi/utils/llm_provider.py @@ -7,8 +7,10 @@ from enums.llm_provider import LLMProvider from utils.get_env import ( get_custom_llm_api_key_env, get_custom_llm_url_env, + get_custom_model_env, get_google_api_key_env, get_llm_provider_env, + get_ollama_model_env, get_ollama_url_env, get_openai_api_key_env, ) @@ -93,9 +95,9 @@ def get_large_model(): elif selected_llm == LLMProvider.GOOGLE: return "gemini-2.0-flash" elif selected_llm == LLMProvider.OLLAMA: - return os.getenv("OLLAMA_MODEL") + return get_ollama_model_env() elif selected_llm == LLMProvider.CUSTOM: - return os.getenv("CUSTOM_MODEL") + return get_custom_model_env() else: raise ValueError(f"Invalid LLM model") @@ -107,9 +109,9 @@ def get_small_model(): elif selected_llm == LLMProvider.GOOGLE: return "gemini-2.0-flash" elif selected_llm == LLMProvider.OLLAMA: - return os.getenv("OLLAMA_MODEL") + return get_ollama_model_env() elif selected_llm == LLMProvider.CUSTOM: - return os.getenv("CUSTOM_MODEL") + return get_custom_model_env() else: raise ValueError(f"Invalid LLM model") @@ -121,8 +123,8 @@ def get_nano_model(): elif selected_llm == LLMProvider.GOOGLE: return "gemini-2.0-flash" elif selected_llm == LLMProvider.OLLAMA: - return os.getenv("OLLAMA_MODEL") + return get_ollama_model_env() elif selected_llm == LLMProvider.CUSTOM: - return os.getenv("CUSTOM_MODEL") + return get_custom_model_env() else: raise ValueError(f"Invalid LLM model") diff --git a/servers/fastapi/utils/model_availability.py b/servers/fastapi/utils/model_availability.py index 7d56291a..c1dce981 100644 --- a/servers/fastapi/utils/model_availability.py +++ b/servers/fastapi/utils/model_availability.py @@ -2,30 +2,46 @@ import os from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS from enums.llm_provider import LLMProvider from utils.custom_llm_provider import list_available_custom_models -from utils.get_env import get_can_change_keys_env +from utils.get_env import ( + get_can_change_keys_env, + get_openai_api_key_env, + get_pixabay_api_key_env, + get_pexels_api_key_env, +) +from utils.get_env import get_google_api_key_env +from utils.get_env import get_ollama_model_env +from utils.get_env import get_custom_llm_api_key_env +from utils.get_env import get_custom_llm_url_env +from utils.get_env import get_custom_model_env from utils.llm_provider import ( get_llm_provider, is_custom_llm_selected, is_ollama_selected, ) from utils.ollama import pull_ollama_model +from utils.image_provider import ( + is_pixels_selected, + is_pixabay_selected, + is_gemini_flash_selected, + is_dalle3_selected, +) -async def check_llm_model_availability(): +async def check_llm_and_image_provider_api_or_model_availability(): can_change_keys = get_can_change_keys_env() != "false" if not can_change_keys: if get_llm_provider() == LLMProvider.OPENAI: - openai_api_key = os.getenv("OPENAI_API_KEY") + openai_api_key = get_openai_api_key_env() if not openai_api_key: raise Exception("OPENAI_API_KEY must be provided") elif get_llm_provider() == LLMProvider.GOOGLE: - google_api_key = os.getenv("GOOGLE_API_KEY") + google_api_key = get_google_api_key_env() if not google_api_key: raise Exception("GOOGLE_API_KEY must be provided") elif is_ollama_selected(): - ollama_model = os.getenv("OLLAMA_MODEL") + ollama_model = get_ollama_model_env() if not ollama_model: raise Exception("OLLAMA_MODEL must be provided") @@ -40,9 +56,9 @@ async def check_llm_model_availability(): print("-" * 50) elif is_custom_llm_selected(): - custom_model = os.getenv("CUSTOM_MODEL") - custom_llm_url = os.getenv("CUSTOM_LLM_URL") - custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY") + custom_model = get_custom_model_env() + custom_llm_url = get_custom_llm_url_env() + custom_llm_api_key = get_custom_llm_api_key_env() if not custom_model: raise Exception("CUSTOM_MODEL must be provided") if not custom_llm_url: @@ -58,3 +74,22 @@ async def check_llm_model_availability(): print("-" * 50) if custom_model not in models: raise Exception(f"Model {custom_model} is not available") + elif is_pixels_selected(): + pexels_api_key = get_pexels_api_key_env() + if not pexels_api_key: + raise Exception("PEXELS_API_KEY must be provided") + + elif is_pixabay_selected(): + pixabay_api_key = get_pixabay_api_key_env() + if not pixabay_api_key: + raise Exception("PIXABAY_API_KEY must be provided") + + elif is_gemini_flash_selected(): + google_api_key = get_google_api_key_env() + if not google_api_key: + raise Exception("GOOGLE_API_KEY must be provided") + + elif is_dalle3_selected(): + openai_api_key = get_openai_api_key_env() + if not openai_api_key: + raise Exception("OPENAI_API_KEY must be provided") diff --git a/servers/fastapi/utils/set_env.py b/servers/fastapi/utils/set_env.py index c43d85a9..fbfaf221 100644 --- a/servers/fastapi/utils/set_env.py +++ b/servers/fastapi/utils/set_env.py @@ -43,3 +43,10 @@ def set_custom_model_env(value): def set_pexels_api_key_env(value): os.environ["PEXELS_API_KEY"] = value + +def set_image_provider_env(value): + os.environ["IMAGE_PROVIDER"] = value + + +def set_pixabay_api_key_env(value): + os.environ["PIXABAY_API_KEY"] = value \ No newline at end of file diff --git a/servers/fastapi/utils/user_config.py b/servers/fastapi/utils/user_config.py index af4d6624..b1065a2b 100644 --- a/servers/fastapi/utils/user_config.py +++ b/servers/fastapi/utils/user_config.py @@ -13,6 +13,8 @@ from utils.get_env import ( get_openai_api_key_env, get_pexels_api_key_env, get_user_config_path_env, + get_image_provider_env, + get_pixabay_api_key_env ) from utils.set_env import ( set_custom_llm_api_key_env, @@ -24,6 +26,8 @@ from utils.set_env import ( set_ollama_url_env, set_openai_api_key_env, set_pexels_api_key_env, + set_image_provider_env, + set_pixabay_api_key_env ) @@ -49,6 +53,8 @@ def get_user_config(): CUSTOM_LLM_API_KEY=existing_config.CUSTOM_LLM_API_KEY or get_custom_llm_api_key_env(), CUSTOM_MODEL=existing_config.CUSTOM_MODEL or get_custom_model_env(), + IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(), + PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(), ) @@ -71,5 +77,9 @@ def update_env_with_user_config(): set_custom_llm_api_key_env(user_config.CUSTOM_LLM_API_KEY) if user_config.CUSTOM_MODEL: set_custom_model_env(user_config.CUSTOM_MODEL) + if user_config.IMAGE_PROVIDER: + set_image_provider_env(user_config.IMAGE_PROVIDER) + if user_config.PIXABAY_API_KEY: + set_pixabay_api_key_env(user_config.PIXABAY_API_KEY) if user_config.PEXELS_API_KEY: set_pexels_api_key_env(user_config.PEXELS_API_KEY) diff --git a/servers/nextjs/app/(presentation-generator)/components/ChartEditor.tsx b/servers/nextjs/app/(presentation-generator)/components/ChartEditor.tsx deleted file mode 100644 index f672d01d..00000000 --- a/servers/nextjs/app/(presentation-generator)/components/ChartEditor.tsx +++ /dev/null @@ -1,505 +0,0 @@ -import React, { useState } from 'react'; -import { - Sheet, - SheetContent, - SheetTitle, - SheetHeader, -} from "@/components/ui/sheet"; -import { Button } from '@/components/ui/button'; -import { Plus, ChevronDown, Trash, BarChart3, PieChart as PieChartIcon, LineChart as LineChartIcon } from 'lucide-react'; -import { Input } from '@/components/ui/input'; -import { StoreChartData } from '../utils/chartDataTransforms'; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuTrigger, - DropdownMenuItem, -} from "@/components/ui/dropdown-menu"; -import { renderChart } from './slide_config'; -import { useSelector } from 'react-redux'; -import { RootState } from '@/store/store'; -import { Label } from '@/components/ui/label'; -import { Switch } from '@/components/ui/switch'; -import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; -import { ChartSettings } from '@/store/slices/presentationGeneration'; - -interface ChartEditorProps { - isOpen: boolean; - onClose: () => void; - chartData: StoreChartData; - onChartDataChange: (newData: StoreChartData) => void; - chartSettings: ChartSettings; - setChartSettings: (newSettings: ChartSettings) => void; -} - -const ChartEditor = ({ isOpen, onClose, chartData, onChartDataChange, chartSettings, setChartSettings }: ChartEditorProps) => { - const [selectedCell, setSelectedCell] = useState<{ row: number; col: number } | null>(null); - const { currentColors } = useSelector((state: RootState) => state.theme); - - const handleCategoryChange = (index: number, value: string) => { - const newData = { - ...chartData, - data: { - ...chartData.data, - categories: [ - ...chartData.data.categories.slice(0, index), - value, - ...chartData.data.categories.slice(index + 1) - ] - } - }; - onChartDataChange(newData); - }; - - - const handleValueChange = (categoryIndex: number, seriesIndex: number, value: string) => { - const newData = { - ...chartData, - data: { - ...chartData.data, - series: chartData.data.series.map((series, idx) => { - if (idx === seriesIndex) { - return { - ...series, - data: [...series.data.slice(0, categoryIndex), Number(value), ...series.data.slice(categoryIndex + 1)] - }; - } - return series; - }) - } - }; - onChartDataChange(newData); - }; - - const addCategory = () => { - - const newData = { - ...chartData, - data: { - ...chartData.data, - categories: [...chartData.data.categories, ''], - series: chartData.data.series.map(series => ({ - ...series, - data: [...series.data, 0] - })) - } - }; - onChartDataChange(newData); - }; - - const addSeriesBefore = (index: number) => { - if (chartData.type === 'pie' && chartData.data.series.length >= 1) { - return; - } else { - if (chartData.data.series.length >= 4) { - return; - } - } - const newData = { - ...chartData, - data: { - ...chartData.data, - series: [ - ...chartData.data.series.slice(0, index), - { - name: `Series ${chartData.data.series.length + 1}`, - data: new Array(chartData.data.categories.length).fill(0) - }, - ...chartData.data.series.slice(index) - ] - } - }; - onChartDataChange(newData); - }; - - const addSeriesAfter = (index: number) => { - if (chartData.type === 'pie' && chartData.data.series.length >= 1) { - return; - } else { - if (chartData.data.series.length >= 4) { - return; - } - } - const newData = { - ...chartData, - data: { - ...chartData.data, - series: [ - ...chartData.data.series.slice(0, index + 1), - { - name: `Series ${chartData.data.series.length + 1}`, - data: new Array(chartData.data.categories.length).fill(0) - }, - ...chartData.data.series.slice(index + 1) - ] - } - }; - onChartDataChange(newData); - }; - - const removeCategory = (index: number) => { - const newData = { - ...chartData, - data: { - ...chartData.data, - categories: chartData.data.categories.filter((_, idx) => idx !== index), - series: chartData.data.series.map(series => ({ - ...series, - data: series.data.filter((_, idx) => idx !== index) - })) - } - }; - onChartDataChange(newData); - }; - - const removeSeries = (index: number) => { - const newData = { - ...chartData, - data: { - ...chartData.data, - series: chartData.data.series.filter((_, idx) => idx !== index) - } - }; - onChartDataChange(newData); - }; - - const getColumnLetter = (index: number) => { - return String.fromCharCode(65 + index); - }; - - const isColumnSelected = (colIndex: number) => { - return selectedCell?.col === colIndex; - }; - - const isRowSelected = (rowIndex: number) => { - return selectedCell?.row === rowIndex; - }; - - const isCellSelected = (rowIndex: number, colIndex: number) => { - return selectedCell?.row === rowIndex && selectedCell?.col === colIndex; - }; - const disableAddSeries = (chartType: string) => { - if (chartType === 'pie') { - return chartData.data.series.length >= 1; - } else { - return chartData.data.series.length >= 4; - } - } - - return ( - - e.preventDefault()}> - - Chart Editor - -
-
- {/* Spreadsheet Table */} -
-
- - - - - {/* First column for categories */} - - {/* Data columns for each series */} - {chartData && chartData.data.series && chartData.data.series.map((_, index) => ( - - ))} - - - {/* New row for series names */} - - - - {chartData.data.series.map((series, index) => ( - - ))} - - - - - - {chartData.data.categories.map((category, rowIndex) => ( - - {/* Row Numbers */} - - - {/* Category Cell */} - - - - {/* Series Data Cells */} - {/* series name */} - {chartData.data.series.map((series, seriesIndex) => ( - - ))} - - - - ))} - -
- -
- A -
-
-
- - {getColumnLetter(index + 1)} - - - - - - - addSeriesBefore(index)} disabled={disableAddSeries(chartData.type)}> - - Add Column before - - addSeriesAfter(index)} disabled={disableAddSeries(chartData.type)}> - - Add Column after - - removeSeries(index)}> - - Delete Column - - - -
-
- -
- { - const newSeries = chartData.data.series.map((s, i) => - i === index ? { ...s, name: e.target.value } : s - ); - onChartDataChange({ - ...chartData, - data: { - ...chartData.data, - series: newSeries - } - }); - }} - className="border-0 focus-visible:ring-0 focus:ring-0 h-7 text-[13px] bg-transparent" - /> -
- {rowIndex + 1} - setSelectedCell({ row: rowIndex, col: 0 })} - > - handleCategoryChange(rowIndex, e.target.value)} - className="border-0 focus-visible:ring-0 focus:ring-0 h-7 text-[13px] bg-transparent" - /> - setSelectedCell({ row: rowIndex, col: seriesIndex + 1 })} - > - handleValueChange(rowIndex, seriesIndex, e.target.value)} - className="border-0 focus-visible:ring-0 focus:ring-0 h-7 text-[13px] bg-transparent text-right" - /> - - -
- - {/* Add Row Button */} -
- -
-
-
-
- - {/* Add the chart preview section */} -
-

Preview

-
- {renderChart(chartData, false, currentColors, chartSettings)} -
- - {/* Add chart type selection */} -
-

Chart Type

-
- - - -
-
- {chartData.type !== 'line' && ( -
-
- - setChartSettings({ ...chartSettings, showDataLabel: checked })} - /> -
- - {chartSettings.showDataLabel && ( -
- - - - setChartSettings({ - ...chartSettings, dataLabel: { - ...chartSettings.dataLabel, - dataLabelPosition: 'Inside' - } - })} value="inside">Inside - setChartSettings({ - ...chartSettings, dataLabel: { - ...chartSettings.dataLabel, - dataLabelPosition: 'Outside' - } - })} value="outside">Outside - - {chartData.type === 'bar' && - - - - setChartSettings({ - ...chartSettings, dataLabel: { - ...chartSettings.dataLabel, - dataLabelAlignment: 'Base' - } - })} value="base">Base - setChartSettings({ - ...chartSettings, dataLabel: { - ...chartSettings.dataLabel, - dataLabelAlignment: 'Center' - } - })} value="center">Center - setChartSettings({ - ...chartSettings, dataLabel: { - ...chartSettings.dataLabel, - dataLabelAlignment: 'End' - } - })} value="end">End - - - } - -
- )} -
- )} -
- - setChartSettings({ ...chartSettings, showLegend: checked })} - /> -
- - {chartData.type !== 'pie' &&
- - setChartSettings({ ...chartSettings, showGrid: checked })} - /> -
} - - {chartData.type !== 'pie' &&
- - setChartSettings({ ...chartSettings, showAxisLabel: checked })} - /> -
} -
-
-
-
-
-
- ); -}; - -export default ChartEditor; \ No newline at end of file diff --git a/servers/nextjs/app/(presentation-generator)/components/EditableLayoutWrapper.tsx b/servers/nextjs/app/(presentation-generator)/components/EditableLayoutWrapper.tsx new file mode 100644 index 00000000..9e0117b2 --- /dev/null +++ b/servers/nextjs/app/(presentation-generator)/components/EditableLayoutWrapper.tsx @@ -0,0 +1,321 @@ +"use client"; + +import React, { ReactNode, useRef, useEffect, useState } from 'react'; +import { useDispatch } from 'react-redux'; +import { updateSlideImage, updateSlideIcon } from '@/store/slices/presentationGeneration'; +import ImageEditor from './ImageEditor'; +import IconsEditor from './IconsEditor'; + +interface EditableLayoutWrapperProps { + children: ReactNode; + slideIndex: number; + slideData: any; + isEditMode?: boolean; +} + +interface EditableElement { + id: string; + type: 'image' | 'icon'; + src: string; + dataPath: string; + data: any; + element: HTMLImageElement; +} + +const EditableLayoutWrapper: React.FC = ({ + children, + slideIndex, + slideData, + isEditMode = true, +}) => { + const dispatch = useDispatch(); + const containerRef = useRef(null); + const [editableElements, setEditableElements] = useState([]); + const [activeEditor, setActiveEditor] = useState(null); + + /** + * Recursively searches for image/icon data in the slide data structure + */ + const findDataPath = (targetUrl: string, data: any, path: string = ''): { path: string; type: 'image' | 'icon'; data: any } | null => { + if (!data || typeof data !== 'object') return null; + + // Check current level for __image_url__ or __icon_url__ + if (data.__image_url__ && isMatchingUrl(data.__image_url__, targetUrl)) { + return { path, type: 'image', data }; + } + + if (data.__icon_url__ && isMatchingUrl(data.__icon_url__, targetUrl)) { + return { path, type: 'icon', data }; + } + + // Recursively check nested objects and arrays + for (const [key, value] of Object.entries(data)) { + const newPath = path ? `${path}.${key}` : key; + + if (Array.isArray(value)) { + for (let i = 0; i < value.length; i++) { + const result = findDataPath(targetUrl, value[i], `${newPath}[${i}]`); + if (result) return result; + } + } else if (value && typeof value === 'object') { + const result = findDataPath(targetUrl, value, newPath); + if (result) return result; + } + } + + return null; + }; + + /** + * Checks if two URLs match using various comparison strategies + */ + const isMatchingUrl = (url1: string, url2: string): boolean => { + if (!url1 || !url2) return false; + + // Direct match + if (url1 === url2) return true; + + // Remove protocol and domain differences + const cleanUrl1 = url1.replace(/^https?:\/\/[^\/]+/, '').replace(/^\/+/, ''); + const cleanUrl2 = url2.replace(/^https?:\/\/[^\/]+/, '').replace(/^\/+/, ''); + + if (cleanUrl1 === cleanUrl2) return true; + + // Handle app_data paths and placeholder URLs + if (url1.includes('/app_data/') || url2.includes('/app_data/') || + url1.includes('placeholder') || url2.includes('placeholder')) { + const getFilename = (path: string) => path.split('/').pop() || ''; + const filename1 = getFilename(url1); + const filename2 = getFilename(url2); + if (filename1 === filename2 && filename1 !== '') return true; + } + + // Extract and compare filenames for other URLs + const getFilename = (path: string) => path.split('/').pop() || ''; + const filename1 = getFilename(url1); + const filename2 = getFilename(url2); + + if (filename1 === filename2 && filename1 !== '') { + return true; + } + + // Check if one URL is contained in another (for partial matches) + if (url1.includes(url2) || url2.includes(url1)) { + return true; + } + + return false; + }; + + /** + * Finds and processes images in the DOM, making them editable + */ + const findAndProcessImages = () => { + if (!containerRef.current || !isEditMode) return; + + const imgElements = containerRef.current.querySelectorAll('img:not([data-editable-processed])'); + const newEditableElements: EditableElement[] = []; + + imgElements.forEach((img, index) => { + const htmlImg = img as HTMLImageElement; + const src = htmlImg.src; + + if (src) { + const result = findDataPath(src, slideData); + + if (result) { + const { path: dataPath, type, data } = result; + + // Mark as processed to prevent re-processing + htmlImg.setAttribute('data-editable-processed', 'true'); + + const editableElement: EditableElement = { + id: `${type}-${dataPath}-${index}`, + type, + src, + dataPath, + data, + element: htmlImg + }; + + newEditableElements.push(editableElement); + + // Add click handler directly to the image + const clickHandler = (e: Event) => { + e.preventDefault(); + e.stopPropagation(); + setActiveEditor(editableElement); + }; + + htmlImg.addEventListener('click', clickHandler); + + // Add hover effects without changing layout + htmlImg.style.cursor = 'pointer'; + htmlImg.style.transition = 'filter 0.2s, transform 0.2s'; + + const mouseEnterHandler = () => { + htmlImg.style.filter = 'brightness(0.9)'; + + }; + + const mouseLeaveHandler = () => { + htmlImg.style.filter = 'brightness(1)'; + + }; + + htmlImg.addEventListener('mouseenter', mouseEnterHandler); + htmlImg.addEventListener('mouseleave', mouseLeaveHandler); + + // Store cleanup functions + (htmlImg as any)._editableCleanup = () => { + htmlImg.removeEventListener('click', clickHandler); + htmlImg.removeEventListener('mouseenter', mouseEnterHandler); + htmlImg.removeEventListener('mouseleave', mouseLeaveHandler); + htmlImg.style.cursor = ''; + htmlImg.style.transition = ''; + htmlImg.style.filter = ''; + htmlImg.style.transform = ''; + htmlImg.removeAttribute('data-editable-processed'); + }; + } + } + }); + + setEditableElements(prev => [...prev, ...newEditableElements]); + }; + + /** + * Cleanup function to remove event listeners and reset styles + */ + const cleanupElements = () => { + editableElements.forEach(({ element }) => { + if ((element as any)._editableCleanup) { + (element as any)._editableCleanup(); + } + }); + setEditableElements([]); + }; + + // Wait for LoadableComponent to render and then process images + useEffect(() => { + const timer = setTimeout(() => { + findAndProcessImages(); + }, 300); + + return () => { + clearTimeout(timer); + cleanupElements(); + }; + }, [slideData, children]); + + // Re-run when container content changes + useEffect(() => { + if (!containerRef.current) return; + + const observer = new MutationObserver((mutations) => { + const hasNewImages = mutations.some(mutation => + Array.from(mutation.addedNodes).some(node => + node.nodeType === Node.ELEMENT_NODE && + ( + (node as Element).tagName === 'IMG' || + (node as Element).querySelector('img:not([data-editable-processed])') + ) + ) + ); + + if (hasNewImages) { + setTimeout(findAndProcessImages, 100); + } + }); + + observer.observe(containerRef.current, { + childList: true, + subtree: true + }); + + return () => observer.disconnect(); + }, [slideData]); + + /** + * Handles closing the active editor + */ + const handleEditorClose = () => { + setActiveEditor(null); + }; + + /** + * Handles image change from ImageEditor + */ + const handleImageChange = (newImageUrl: string, prompt?: string) => { + if (activeEditor && activeEditor.element) { + // Update the DOM element immediately for visual feedback + activeEditor.element.src = newImageUrl; + + // Update Redux store + dispatch(updateSlideImage({ + slideIndex, + dataPath: activeEditor.dataPath, + imageUrl: newImageUrl, + prompt: prompt || activeEditor.data?.__image_prompt__ || '' + })); + + setActiveEditor(null); + } + }; + + /** + * Handles icon change from IconsEditor + */ + const handleIconChange = (newIconUrl: string, query?: string) => { + if (activeEditor && activeEditor.element) { + // Update the DOM element immediately for visual feedback + activeEditor.element.src = newIconUrl; + + // Update Redux store + dispatch(updateSlideIcon({ + slideIndex, + dataPath: activeEditor.dataPath, + iconUrl: newIconUrl, + query: query || activeEditor.data?.__icon_query__ || '' + })); + + setActiveEditor(null); + } + }; + + return ( +
+ {children} + + {/* Render ImageEditor when an image is being edited */} + {activeEditor && activeEditor.type === 'image' && ( + +
+ + )} + + {/* Render IconsEditor when an icon is being edited */} + {activeEditor && activeEditor.type === 'icon' && ( + +
+ + )} +
+ ); +}; + +export default EditableLayoutWrapper; \ No newline at end of file diff --git a/servers/nextjs/app/(presentation-generator)/components/IconsEditor.tsx b/servers/nextjs/app/(presentation-generator)/components/IconsEditor.tsx index ca8d6e17..1cfb6568 100644 --- a/servers/nextjs/app/(presentation-generator)/components/IconsEditor.tsx +++ b/servers/nextjs/app/(presentation-generator)/components/IconsEditor.tsx @@ -7,66 +7,52 @@ import { SheetTitle, } from "@/components/ui/sheet"; import { Input } from "@/components/ui/input"; -import { PlusIcon, Search } from "lucide-react"; -import { cn } from "@/lib/utils"; -import { useDispatch, useSelector } from "react-redux"; -import { PresentationGenerationApi } from "../services/api/presentation-generation"; -import { RootState } from "@/store/store"; -import { usePathname, useSearchParams } from "next/navigation"; +import { Search } from "lucide-react"; +import { useSearchParams } from "next/navigation"; import { Skeleton } from "@/components/ui/skeleton"; import { Button } from "@/components/ui/button"; -import { updateSlideIcon } from "@/store/slices/presentationGeneration"; +import { PresentationGenerationApi } from "../services/api/presentation-generation"; import { getStaticFileUrl } from "../utils/others"; interface IconsEditorProps { icon: string; index: number; - backgroundColor: string; - hasBg: boolean; - slideIndex: number; - elementId: string; - isWhite?: boolean; className?: string; icon_prompt?: string[] | null; onClose?: () => void; + onIconChange?: (newIconUrl: string, query?: string) => void; } const IconsEditor = ({ icon: initialIcon, - index, - backgroundColor, - hasBg, - className, - slideIndex, - elementId, icon_prompt, onClose, -}: IconsEditorProps) => { - const dispatch = useDispatch(); + onIconChange, +}: IconsEditorProps) => { + // State management const [icon, setIcon] = useState(initialIcon); const [icons, setIcons] = useState([]); - const [isEditorOpen, setIsEditorOpen] = useState(false); const [searchQuery, setSearchQuery] = useState( icon_prompt?.[0] || "" ); const [loading, setLoading] = useState(true); + const searchParams = useSearchParams(); + // Update local state when initial icon changes useEffect(() => { setIcon(initialIcon); }, [initialIcon]); + // Search for icons when component opens useEffect(() => { - if (isEditorOpen) { - handleIconSearch(); - } - }, [isEditorOpen]); - - const handleIconClick = () => { - setIsEditorOpen(true); - }; + handleIconSearch(); + }, []); + /** + * Searches for icons based on the current query + */ const handleIconSearch = async () => { setLoading(true); const presentation_id = searchParams.get("id"); @@ -88,94 +74,100 @@ const IconsEditor = ({ } }; + /** + * Handles icon selection and calls the parent callback + */ const handleIconChange = (newIcon: string) => { - - setIcon(newIcon); - dispatch( - updateSlideIcon({ index: slideIndex, iconIdx: index, icon: newIcon }) - ); - setIsEditorOpen(false); + + if (onIconChange) { + onIconChange(newIcon, searchQuery || icon_prompt?.[0] || ''); + } }; return ( - onClose?.()}> - e.preventDefault()} - onClick={(e) => e.stopPropagation()} - > - - Choose Icon - -
-
{ - e.preventDefault(); - e.stopPropagation(); - handleIconSearch(); - }} - > -
- +
- setSearchQuery(e.target.value)} - onClick={(e) => e.stopPropagation()} - className="pl-10" - /> -
- - +
+ + setSearchQuery(e.target.value)} + onClick={(e) => e.stopPropagation()} + className="pl-10" + /> +
+ + - {/* Icons grid */} -
- {loading ? ( -
- {Array.from({ length: 40 }).map((_, index) => ( - - ))} -
- ) : icons.length > 0 ? ( -
- {icons.map((iconSrc, idx) => ( -
{ - e.stopPropagation(); - handleIconChange(iconSrc); - }} - className="w-12 h-12 cursor-pointer group relative rounded-lg overflow-hidden hover:bg-gray-100 p-2" - > - {`Icon -
- ))} -
- ) : ( -
- -

No icons found for your search.

-

Try refining your search query.

-
- )} + {/* Icons Grid */} +
+ {loading ? ( +
+ {Array.from({ length: 40 }).map((_, index) => ( + + ))} +
+ ) : icons.length > 0 ? ( +
+ {icons.map((iconSrc, idx) => ( +
{ + e.stopPropagation(); + handleIconChange(iconSrc); + }} + className="w-12 h-12 cursor-pointer group relative rounded-lg overflow-hidden hover:bg-gray-100 p-2 transition-colors" + > + {`Icon +
+ ))} +
+ ) : ( +
+ +

No icons found for your search.

+

Try refining your search query.

+
+ )} +
-
- - + + +
); }; diff --git a/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx b/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx index 14349032..cdf5a3c7 100644 --- a/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx +++ b/servers/nextjs/app/(presentation-generator)/components/ImageEditor.tsx @@ -13,32 +13,25 @@ import { Wand2, Upload, Move, - } from "lucide-react"; import { cn } from "@/lib/utils"; -import { useDispatch, useSelector } from "react-redux"; +import { useSelector } from "react-redux"; import { PresentationGenerationApi } from "../services/api/presentation-generation"; import { RootState } from "@/store/store"; import { useSearchParams } from "next/navigation"; import { Skeleton } from "@/components/ui/skeleton"; -import { - updateSlideImage, - updateSlideProperties, -} from "@/store/slices/presentationGeneration"; -import { getStaticFileUrl, ThemeImagePrompt } from "../utils/others"; - - +import { ThemeImagePrompt } from "../utils/others"; interface ImageEditorProps { initialImage: string | null; imageIdx?: number; - slideIndex: number; - className?: string; promptContent?: string; properties?: null | any; onClose?: () => void; + onImageChange?: (newImageUrl: string, prompt?: string) => void; + } const ImageEditor = ({ @@ -48,12 +41,13 @@ const ImageEditor = ({ promptContent, properties, onClose, + onImageChange, + }: ImageEditorProps) => { - const dispatch = useDispatch(); - const { currentTheme } = useSelector((state: RootState) => state.theme); - const searchParams = useSearchParams(); + + // State management const [image, setImage] = useState(initialImage); const [previewImages, setPreviewImages] = useState([initialImage]); const [prompt, setPrompt] = useState(""); @@ -62,6 +56,8 @@ const ImageEditor = ({ const [isUploading, setIsUploading] = useState(false); const [uploadError, setUploadError] = useState(null); const [uploadedImageUrl, setUploadedImageUrl] = useState(null); + + // Focus point and object fit for image editing const [isFocusPointMode, setIsFocusPointMode] = useState(false); const [focusPoint, setFocusPoint] = useState( (properties && @@ -77,11 +73,14 @@ const ImageEditor = ({ properties[imageIdx].initialObjectFit) || "cover" ); + + // Refs const imageRef = useRef(null); const imageContainerRef = useRef(null); const toolbarRef = useRef(null); const popoverContentRef = useRef(null); + // Update local state when initial image changes useEffect(() => { setImage(initialImage); setPreviewImages([initialImage]); @@ -97,9 +96,7 @@ const ImageEditor = ({ !toolbarRef.current.contains(event.target as Node) && !popoverContentRef.current ) { - if (isFocusPointMode) { - // saveFocusPoint(); // Save focus point before closing saveImageProperties(objectFit, focusPoint); } setIsFocusPointMode(false); @@ -110,21 +107,22 @@ const ImageEditor = ({ return () => { document.removeEventListener("mousedown", handleClickOutside); }; - }, [isFocusPointMode, focusPoint]); - - + }, [isFocusPointMode, focusPoint, objectFit]); + /** + * Handles image selection and calls the parent callback + */ const handleImageChange = (newImage: string) => { setImage(newImage); - dispatch( - updateSlideImage({ - index: slideIndex, - imageIdx: imageIdx, - image: newImage, - }) - ); + + if (onImageChange) { + onImageChange(newImage, promptContent); + } }; + /** + * Handles focus point adjustment when clicking on the image + */ const handleFocusPointClick = (e: React.MouseEvent) => { if (!isFocusPointMode || !imageRef.current) return; @@ -147,14 +145,19 @@ const ImageEditor = ({ } }; + /** + * Toggles focus point adjustment mode + */ const toggleFocusPointMode = () => { if (isFocusPointMode) { - // If turning off focus point mode, save the current focus point - // saveFocusPoint(); + saveImageProperties(objectFit, focusPoint); } setIsFocusPointMode(!isFocusPointMode); }; + /** + * Handles object fit change + */ const handleFitChange = (fit: "cover" | "contain" | "fill") => { setObjectFit(fit); @@ -162,10 +165,12 @@ const ImageEditor = ({ imageRef.current.style.objectFit = fit; } - // Save the fit change to your state saveImageProperties(fit, focusPoint); }; + /** + * Saves image properties (focus point and object fit) + */ const saveImageProperties = ( fit: "cover" | "contain" | "fill", focusPoint: { x: number; y: number } @@ -174,16 +179,12 @@ const ImageEditor = ({ initialObjectFit: fit, initialFocusPoint: focusPoint, }; - - dispatch( - updateSlideProperties({ - index: slideIndex, - itemIdx: imageIdx, - properties: propertiesData, - }) - ); + // TODO: Save to Redux store if needed }; + /** + * Generates new images using AI + */ const handleGenerateImage = async () => { try { setIsGenerating(true); @@ -208,26 +209,24 @@ const ImageEditor = ({ } }; + /** + * Handles file upload + */ const handleFileUpload = async ( event: React.ChangeEvent ) => { - const presentation_id = searchParams.get("id"); const file = event.target.files?.[0]; if (!file) return; - // Check file size (e.g., 5MB limit) + // Validate file size (5MB limit) if (file.size > 5 * 1024 * 1024) { - const error_message = "File size should be less than 5MB"; - - setUploadError(error_message); + setUploadError("File size should be less than 5MB"); return; } - // Check file type + // Validate file type if (!file.type.startsWith("image/")) { - const error_message = "Please upload an image file"; - - setUploadError(error_message); + setUploadError("Please upload an image file"); return; } @@ -249,356 +248,191 @@ const ImageEditor = ({ throw new Error(result.error || 'Upload failed'); } - // Update state with the returned path setUploadedImageUrl(result.filePath); } catch (err) { - const error_message = "Failed to upload image. Please try again."; - - setUploadError(error_message); + setUploadError("Failed to upload image. Please try again."); console.error("Upload error:", err); } finally { setIsUploading(false); } }; - - return ( - onClose?.()}> - e.preventDefault()} - onClick={(e) => e.stopPropagation()} - > - - Update Image - +
-
- - - - Edit - - - AI Generate - - - Upload - - - -
- {/* Current Image Preview */} -
-

Current Image

-
- {image ? ( - Current image { - e.stopPropagation(); - handleFocusPointClick(e); - }} - onError={(e) => { - console.error('Image failed to load:', image); - e.currentTarget.src = '/placeholder-image.png'; - }} - /> - ) : ( -
-
- -

No image selected

-
-
- )} + onClose?.()}> + e.preventDefault()} + onClick={(e) => e.stopPropagation()} + > + + Update Image + - {/* Focus Point Indicator */} - {isFocusPointMode && image && ( -
- )} -
- {/* Debug info */} - {image && ( -
-

Image Path: {image}

-

Resolved URL: {image}

-

Focus Point: {focusPoint.x.toFixed(1)}%, {focusPoint.y.toFixed(1)}%

-

Object Fit: {objectFit}

-
- )} -
- - {/* Editing Controls */} +
+ + + + AI Generate + + + Upload + + + {/* Generate Tab */} +
- {/* Focus Point Controls */} -
-
-

Focus Point

- -
- {isFocusPointMode && ( -

- Click on the image above to set the focus point -

- )} +
+

Current Prompt

+

{promptContent}

- {/* Object Fit Controls */} -
-

Image Fit

-
- - - -
-
-

Cover: Fill container, may crop image

-

Contain: Fit entire image, may show empty space

-

Fill: Stretch to fill container exactly

-
+
+

Image Description

+