diff --git a/docker-compose.yml b/docker-compose.yml index da81c479..fd92fa8b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,11 +12,13 @@ services: environment: - CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS} - LLM=${LLM} - - LLM_PROVIDER_URL=${LLM_PROVIDER_URL} - - LLM_API_KEY=${LLM_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} - GOOGLE_API_KEY=${GOOGLE_API_KEY} + - OLLAMA_URL=${OLLAMA_URL} - OLLAMA_MODEL=${OLLAMA_MODEL} + - CUSTOM_LLM_URL=${CUSTOM_LLM_URL} + - CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY} + - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} production-gpu: @@ -39,11 +41,13 @@ services: environment: - CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS} - LLM=${LLM} - - LLM_PROVIDER_URL=${LLM_PROVIDER_URL} - - LLM_API_KEY=${LLM_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} - GOOGLE_API_KEY=${GOOGLE_API_KEY} + - OLLAMA_URL=${OLLAMA_URL} - OLLAMA_MODEL=${OLLAMA_MODEL} + - CUSTOM_LLM_URL=${CUSTOM_LLM_URL} + - CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY} + - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} development: @@ -60,11 +64,13 @@ services: - NODE_ENV=development - CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS} - LLM=${LLM} - - LLM_PROVIDER_URL=${LLM_PROVIDER_URL} - - LLM_API_KEY=${LLM_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} - GOOGLE_API_KEY=${GOOGLE_API_KEY} + - OLLAMA_URL=${OLLAMA_URL} - OLLAMA_MODEL=${OLLAMA_MODEL} + - CUSTOM_LLM_URL=${CUSTOM_LLM_URL} + - CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY} + - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} development-gpu: @@ -88,9 +94,11 @@ services: - NODE_ENV=development - CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS} - LLM=${LLM} - - LLM_PROVIDER_URL=${LLM_PROVIDER_URL} - - LLM_API_KEY=${LLM_API_KEY} - OPENAI_API_KEY=${OPENAI_API_KEY} - GOOGLE_API_KEY=${GOOGLE_API_KEY} + - OLLAMA_URL=${OLLAMA_URL} - OLLAMA_MODEL=${OLLAMA_MODEL} + - CUSTOM_LLM_URL=${CUSTOM_LLM_URL} + - CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY} + - CUSTOM_MODEL=${CUSTOM_MODEL} - PEXELS_API_KEY=${PEXELS_API_KEY} \ No newline at end of file diff --git a/servers/fastapi/api/main.py b/servers/fastapi/api/main.py index 4908a063..a95af05a 100644 --- a/servers/fastapi/api/main.py +++ b/servers/fastapi/api/main.py @@ -1,40 +1,79 @@ +import asyncio import os from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -import ollama from sqlmodel import SQLModel from contextlib import asynccontextmanager +from api.models import SelectedLLMProvider from api.routers.presentation.router import presentation_router from api.services.database import sql_engine from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS from api.utils.utils import update_env_with_user_config -from api.utils.model_utils import is_ollama_selected +from api.utils.model_utils import ( + get_selected_llm_provider, + is_custom_llm_selected, + is_ollama_selected, + list_available_custom_models, + pull_ollama_model, +) can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false" -# Ollama model download -if not can_change_keys and is_ollama_selected(): - ollama_model = os.getenv("MODEL") - pexels_api_key = os.getenv("PEXELS_API_KEY") - if not (ollama_model or pexels_api_key): - raise Exception("MODEL and PEXELS_API_KEY must be provided") - if ollama_model not in SUPPORTED_OLLAMA_MODELS: - raise Exception(f"Model {ollama_model} is not supported") +async def check_llm_model_availability(): + if not can_change_keys: + if get_selected_llm_provider() == SelectedLLMProvider.OPENAI: + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + raise Exception("OPENAI_API_KEY must be provided") - print("-" * 50) - print("Pulling model: ", ollama_model) - for event in ollama.pull(ollama_model, stream=True): - print(event) - print("Pulled model: ", ollama_model) - print("-" * 50) + elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE: + google_api_key = os.getenv("GOOGLE_API_KEY") + if not google_api_key: + raise Exception("GOOGLE_API_KEY must be provided") + + elif is_ollama_selected(): + ollama_model = os.getenv("OLLAMA_MODEL") + if not ollama_model: + raise Exception("OLLAMA_MODEL must be provided") + + if ollama_model not in SUPPORTED_OLLAMA_MODELS: + raise Exception(f"Model {ollama_model} is not supported") + + print("-" * 50) + print("Pulling model: ", ollama_model) + async for event in pull_ollama_model(ollama_model): + print(event) + print("Pulled model: ", ollama_model) + print("-" * 50) + + elif is_custom_llm_selected(): + custom_model = os.getenv("CUSTOM_MODEL") + custom_llm_url = os.getenv("CUSTOM_LLM_URL") + custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY") + if not custom_model: + raise Exception("CUSTOM_MODEL must be provided") + if not custom_llm_url: + raise Exception("CUSTOM_LLM_URL must be provided") + if not custom_llm_api_key: + raise Exception("CUSTOM_LLM_API_KEY must be provided") + print("-" * 50) + print("Selecting model: ", custom_model) + models = await list_available_custom_models( + custom_llm_url, custom_llm_api_key + ) + print("Available models: ", models) + print("-" * 50) + if custom_model not in models: + raise Exception(f"Model {custom_model} is not available") @asynccontextmanager async def lifespan(_: FastAPI): os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True) SQLModel.metadata.create_all(sql_engine) + await check_llm_model_availability() yield diff --git a/servers/fastapi/api/models.py b/servers/fastapi/api/models.py index 96d7b8d0..910db4f3 100644 --- a/servers/fastapi/api/models.py +++ b/servers/fastapi/api/models.py @@ -63,9 +63,11 @@ class UserConfig(BaseModel): LLM: Optional[str] = None OPENAI_API_KEY: Optional[str] = None GOOGLE_API_KEY: Optional[str] = None - MODEL: Optional[str] = None - LLM_PROVIDER_URL: Optional[str] = None - LLM_API_KEY: Optional[str] = None + OLLAMA_URL: Optional[str] = None + OLLAMA_MODEL: Optional[str] = None + CUSTOM_LLM_URL: Optional[str] = None + CUSTOM_LLM_API_KEY: Optional[str] = None + CUSTOM_MODEL: Optional[str] = None PEXELS_API_KEY: Optional[str] = None diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_data.py b/servers/fastapi/api/routers/presentation/handlers/generate_data.py index 8375d7eb..477cddba 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_data.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_data.py @@ -11,7 +11,7 @@ from api.routers.presentation.models import PresentationGenerateRequest from api.services.logging import LoggingService from api.sql_models import KeyValueSqlModel, PresentationSqlModel from api.services.database import get_sql_session -from api.utils.model_utils import is_ollama_selected +from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel from ppt_config_generator.structure_generator import generate_presentation_structure @@ -39,7 +39,7 @@ class PresentationGenerateDataHandler: value=self.data.model_dump(mode="json"), ) - if is_ollama_selected(): + if is_ollama_selected() or is_custom_llm_selected(): with get_sql_session() as sql_session: presentation = sql_session.get( PresentationSqlModel, self.data.presentation_id @@ -53,9 +53,10 @@ class PresentationGenerateDataHandler: } ) ) - supports_graph = True - model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")] - supports_graph = model.supports_graph + supports_graph = not is_custom_llm_selected() + if is_ollama_selected(): + model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")] + supports_graph = model.supports_graph for each in presentation_structure.slides: if each.type > 9: diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_stream.py b/servers/fastapi/api/routers/presentation/handlers/generate_stream.py index 6a4bb663..fd2d2762 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_stream.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_stream.py @@ -18,7 +18,7 @@ from api.services.database import get_sql_session from api.services.logging import LoggingService from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel from api.utils.utils import get_presentation_dir -from api.utils.model_utils import is_ollama_selected +from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected from ppt_config_generator.models import ( PresentationMarkdownModel, PresentationStructureModel, @@ -99,8 +99,8 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin self.presentation_json = None # self.presentation_json will be mutated by the generator - if is_ollama_selected(): - async for result in self.generate_presentation_ollama(): + if is_ollama_selected() or is_custom_llm_selected(): + async for result in self.generate_presentation_ollama_custom(): yield result else: async for result in self.generate_presentation_openai_google(): @@ -157,7 +157,7 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin self.presentation_json = json.loads(presentation_text) - async def generate_presentation_ollama(self): + async def generate_presentation_ollama_custom(self): presentation_structure = PresentationStructureModel( **self.presentation.structure ) diff --git a/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py b/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py index 56820ff0..e55d2d87 100644 --- a/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py +++ b/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py @@ -1,9 +1,6 @@ -import aiohttp -from fastapi import HTTPException from api.models import LogMetadata -from api.routers.presentation.models import OllamaModelStatusResponse from api.services.logging import LoggingService -from api.utils.model_utils import get_llm_provider_url_or +from api.utils.model_utils import list_pulled_ollama_models class ListPulledOllamaModelsHandler: @@ -13,35 +10,10 @@ class ListPulledOllamaModelsHandler: logging_service.message("Listing Ollama models"), extra=log_metadata.model_dump(), ) - async with aiohttp.ClientSession() as session: - async with session.get( - f"{get_llm_provider_url_or()}/api/tags", - ) as response: - if response.status == 200: - response_data = await response.json() - elif response.status == 403: - raise HTTPException( - status_code=403, - detail="Forbidden: Please check your Ollama Configuration", - ) - else: - raise HTTPException( - status_code=response.status, - detail=f"Failed to list Ollama models: {response.status}", - ) + pulled_models = await list_pulled_ollama_models() logging_service.logger.info( - logging_service.message(response_data), + logging_service.message(pulled_models), extra=log_metadata.model_dump(), ) - - return [ - OllamaModelStatusResponse( - name=model["model"], - size=model["size"], - status="pulled", - downloaded=model["size"], - done=True, - ) - for model in response_data["models"] - ] + return pulled_models diff --git a/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py b/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py index 87178bc2..dd12a7b0 100644 --- a/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py +++ b/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py @@ -9,7 +9,11 @@ from api.routers.presentation.handlers.list_supported_ollama_models import ( from api.routers.presentation.models import OllamaModelStatusResponse from api.services.instances import REDIS_SERVICE from api.services.logging import LoggingService -from api.utils.model_utils import get_llm_provider_url_or +from api.utils.model_utils import ( + get_llm_provider_url_or, + list_pulled_ollama_models, + pull_ollama_model, +) class PullOllamaModelHandler: @@ -34,40 +38,13 @@ class PullOllamaModelHandler: detail=f"Model {self.name} is not supported", ) - # Check if model is already pulled using LLM_PROVIDER_URL/api/tags try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"{get_llm_provider_url_or()}/api/tags", - ) as response: - if response.status == 200: - pulled_models = await response.json() - filtered_models = [ - model - for model in pulled_models["models"] - if model["model"] == self.name - ] - - # If the model is already pulled, return the model - if filtered_models: - return OllamaModelStatusResponse( - name=self.name, - size=filtered_models[0]["size"], - status="pulled", - downloaded=filtered_models[0]["size"], - done=True, - ) - elif response.status == 403: - print(response) - raise HTTPException( - status_code=403, - detail="Forbidden: Please check your Ollama Configuration", - ) - else: - raise HTTPException( - status_code=response.status, - detail=f"Failed to list Ollama models: {response.status}", - ) + pulled_models = await list_pulled_ollama_models() + filtered_models = [ + model for model in pulled_models if model.name == self.name + ] + if filtered_models: + return filtered_models[0] except HTTPException as e: logging_service.logger.warning( logging_service.message(e.detail), @@ -122,43 +99,24 @@ class PullOllamaModelHandler: log_event_count = 0 try: - async with aiohttp.ClientSession() as session: - async with session.post( - f"{get_llm_provider_url_or()}/api/pull", - json={"model": self.name}, - ) as response: - if response.status != 200: - raise HTTPException( - status_code=response.status, - detail=f"Failed to pull model: {await response.text()}", - ) + async for event in pull_ollama_model(self.name): + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue - async for line in response.content: - if not line.strip(): - continue + if "completed" in event: + saved_model_status.downloaded = event["completed"] - try: - event = json.loads(line.decode("utf-8")) - except json.JSONDecodeError: - continue + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] - log_event_count += 1 - if log_event_count != 1 and log_event_count % 20 != 0: - continue + if "status" in event: + saved_model_status.status = event["status"] - if "completed" in event: - saved_model_status.downloaded = event["completed"] - - if not saved_model_status.size and "total" in event: - saved_model_status.size = event["total"] - - if "status" in event: - saved_model_status.status = event["status"] - - REDIS_SERVICE.set( - f"ollama_models/{self.name}", - json.dumps(saved_model_status.model_dump(mode="json")), - ) + REDIS_SERVICE.set( + f"ollama_models/{self.name}", + json.dumps(saved_model_status.model_dump(mode="json")), + ) except Exception as e: saved_model_status.status = "error" diff --git a/servers/fastapi/api/routers/presentation/router.py b/servers/fastapi/api/routers/presentation/router.py index be48b529..3b3a0894 100644 --- a/servers/fastapi/api/routers/presentation/router.py +++ b/servers/fastapi/api/routers/presentation/router.py @@ -1,6 +1,7 @@ from typing import Annotated, List, Optional import uuid from fastapi import APIRouter, BackgroundTasks, Body, File, Form, UploadFile +import openai from api.models import SessionModel from api.request_utils import RequestUtils @@ -81,6 +82,7 @@ from api.routers.presentation.models import ( PresentationUpdateRequest, ) from api.sql_models import PresentationSqlModel +from api.utils.model_utils import get_llm_client, list_available_custom_models from api.utils.utils import handle_errors from image_processor.images_finder import ( generate_image_google, @@ -389,3 +391,11 @@ async def pull_ollama_model(name: str, background_tasks: BackgroundTasks): log_metadata, background_tasks=background_tasks, ) + + +@presentation_router.post("/models/list/custom", response_model=List[str]) +async def list_custom_models( + url: Annotated[Optional[str], Body()] = None, + api_key: Annotated[Optional[str], Body()] = None, +): + return await list_available_custom_models(url, api_key) diff --git a/servers/fastapi/api/utils/model_utils.py b/servers/fastapi/api/utils/model_utils.py index e26dc4d2..2673adeb 100644 --- a/servers/fastapi/api/utils/model_utils.py +++ b/servers/fastapi/api/utils/model_utils.py @@ -1,16 +1,29 @@ +import json import os +from typing import AsyncGenerator, Optional +import aiohttp +from fastapi import HTTPException from openai import AsyncOpenAI +import openai from api.models import SelectedLLMProvider +from api.routers.presentation.models import OllamaModelStatusResponse def is_ollama_selected() -> bool: return get_selected_llm_provider() == SelectedLLMProvider.OLLAMA +def is_custom_llm_selected() -> bool: + return get_selected_llm_provider() == SelectedLLMProvider.CUSTOM + + def get_llm_provider_url_or(): - llm_provider_url = os.getenv("LLM_PROVIDER_URL") or "http://localhost:11434" + llm_provider_url = ( + os.getenv("OLLAMA_URL") if is_ollama_selected() else os.getenv("CUSTOM_LLM_URL") + ) + llm_provider_url = llm_provider_url or "http://localhost:11434" if llm_provider_url.endswith("/"): return llm_provider_url[:-1] return llm_provider_url @@ -20,6 +33,19 @@ def get_selected_llm_provider() -> SelectedLLMProvider: return SelectedLLMProvider(os.getenv("LLM")) +async def list_available_custom_models( + url: Optional[str] = None, api_key: Optional[str] = None +) -> list[str]: + if not url or not api_key: + client = get_llm_client() + else: + client = openai.AsyncOpenAI(api_key=api_key, base_url=url) + models = [] + async for model in client.models.list(): + models.append(model.id) + return models + + def get_model_base_url(): selected_llm = get_selected_llm_provider() @@ -29,6 +55,8 @@ def get_model_base_url(): return "https://generativelanguage.googleapis.com/v1beta/openai" elif selected_llm == SelectedLLMProvider.OLLAMA: return os.path.join(get_llm_provider_url_or(), "v1") + elif selected_llm == SelectedLLMProvider.CUSTOM: + return get_llm_provider_url_or() else: raise ValueError(f"Invalid LLM provider") @@ -41,6 +69,8 @@ def get_llm_api_key(): return os.getenv("GOOGLE_API_KEY") elif selected_llm == SelectedLLMProvider.OLLAMA: return "ollama" + elif selected_llm == SelectedLLMProvider.CUSTOM: + return os.getenv("CUSTOM_LLM_API_KEY") else: raise ValueError(f"Invalid LLM API key") @@ -59,8 +89,12 @@ def get_large_model(): return "gpt-4.1" elif selected_llm == SelectedLLMProvider.GOOGLE: return "gemini-2.0-flash" + elif selected_llm == SelectedLLMProvider.OLLAMA: + return os.getenv("OLLAMA_MODEL") + elif selected_llm == SelectedLLMProvider.CUSTOM: + return os.getenv("CUSTOM_MODEL") else: - return os.getenv("MODEL") + raise ValueError(f"Invalid LLM model") def get_small_model(): @@ -69,8 +103,12 @@ def get_small_model(): return "gpt-4.1-mini" elif selected_llm == SelectedLLMProvider.GOOGLE: return "gemini-2.0-flash" + elif selected_llm == SelectedLLMProvider.OLLAMA: + return os.getenv("OLLAMA_MODEL") + elif selected_llm == SelectedLLMProvider.CUSTOM: + return os.getenv("CUSTOM_MODEL") else: - return os.getenv("MODEL") + raise ValueError(f"Invalid LLM model") def get_nano_model(): @@ -79,5 +117,62 @@ def get_nano_model(): return "gpt-4.1-nano" elif selected_llm == SelectedLLMProvider.GOOGLE: return "gemini-2.0-flash" + elif selected_llm == SelectedLLMProvider.OLLAMA: + return os.getenv("OLLAMA_MODEL") + elif selected_llm == SelectedLLMProvider.CUSTOM: + return os.getenv("CUSTOM_MODEL") else: - return os.getenv("MODEL") + raise ValueError(f"Invalid LLM model") + + +async def list_pulled_ollama_models() -> list[OllamaModelStatusResponse]: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{get_llm_provider_url_or()}/api/tags", + ) as response: + if response.status == 200: + pulled_models = await response.json() + return [ + OllamaModelStatusResponse( + name=m["model"], + size=m["size"], + status="pulled", + downloaded=m["size"], + done=True, + ) + for m in pulled_models["models"] + ] + elif response.status == 403: + raise HTTPException( + status_code=403, + detail="Forbidden: Please check your Ollama Configuration", + ) + else: + raise HTTPException( + status_code=response.status, + detail=f"Failed to list Ollama models: {response.status}", + ) + + +async def pull_ollama_model(model: str) -> AsyncGenerator[dict, None]: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{get_llm_provider_url_or()}/api/pull", + json={"model": model}, + ) as response: + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"Failed to pull model: {await response.text()}", + ) + + async for line in response.content: + if not line.strip(): + continue + + try: + event = json.loads(line.decode("utf-8")) + except json.JSONDecodeError: + continue + + yield event diff --git a/servers/fastapi/api/utils/utils.py b/servers/fastapi/api/utils/utils.py index 96174bde..09ba0cec 100644 --- a/servers/fastapi/api/utils/utils.py +++ b/servers/fastapi/api/utils/utils.py @@ -42,12 +42,14 @@ def get_user_config(): return UserConfig( LLM=existing_config.LLM or os.getenv("LLM"), - LLM_PROVIDER_URL=existing_config.LLM_PROVIDER_URL - or os.getenv("LLM_PROVIDER_URL"), - LLM_API_KEY=existing_config.LLM_API_KEY or os.getenv("LLM_API_KEY"), OPENAI_API_KEY=existing_config.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY"), GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY"), - MODEL=existing_config.MODEL or os.getenv("MODEL"), + OLLAMA_URL=existing_config.OLLAMA_URL or os.getenv("OLLAMA_URL"), + OLLAMA_MODEL=existing_config.OLLAMA_MODEL or os.getenv("OLLAMA_MODEL"), + CUSTOM_LLM_URL=existing_config.CUSTOM_LLM_URL or os.getenv("CUSTOM_LLM_URL"), + CUSTOM_LLM_API_KEY=existing_config.CUSTOM_LLM_API_KEY + or os.getenv("CUSTOM_LLM_API_KEY"), + CUSTOM_MODEL=existing_config.CUSTOM_MODEL or os.getenv("CUSTOM_MODEL"), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or os.getenv("PEXELS_API_KEY"), ) @@ -56,16 +58,20 @@ def update_env_with_user_config(): user_config = get_user_config() if user_config.LLM: os.environ["LLM"] = user_config.LLM - if user_config.LLM_PROVIDER_URL: - os.environ["LLM_PROVIDER_URL"] = user_config.LLM_PROVIDER_URL - if user_config.LLM_API_KEY: - os.environ["LLM_API_KEY"] = user_config.LLM_API_KEY if user_config.OPENAI_API_KEY: os.environ["OPENAI_API_KEY"] = user_config.OPENAI_API_KEY if user_config.GOOGLE_API_KEY: os.environ["GOOGLE_API_KEY"] = user_config.GOOGLE_API_KEY - if user_config.MODEL: - os.environ["MODEL"] = user_config.MODEL + if user_config.OLLAMA_URL: + os.environ["OLLAMA_URL"] = user_config.OLLAMA_URL + if user_config.OLLAMA_MODEL: + os.environ["OLLAMA_MODEL"] = user_config.OLLAMA_MODEL + if user_config.CUSTOM_LLM_URL: + os.environ["CUSTOM_LLM_URL"] = user_config.CUSTOM_LLM_URL + if user_config.CUSTOM_LLM_API_KEY: + os.environ["CUSTOM_LLM_API_KEY"] = user_config.CUSTOM_LLM_API_KEY + if user_config.CUSTOM_MODEL: + os.environ["CUSTOM_MODEL"] = user_config.CUSTOM_MODEL if user_config.PEXELS_API_KEY: os.environ["PEXELS_API_KEY"] = user_config.PEXELS_API_KEY diff --git a/servers/fastapi/image_processor/images_finder.py b/servers/fastapi/image_processor/images_finder.py index d915f914..867442fc 100644 --- a/servers/fastapi/image_processor/images_finder.py +++ b/servers/fastapi/image_processor/images_finder.py @@ -1,5 +1,4 @@ import asyncio -import base64 import os import uuid import aiohttp @@ -10,7 +9,11 @@ from ppt_generator.models.query_and_prompt_models import ( ImagePromptWithThemeAndAspectRatio, ) from api.utils.utils import download_file, get_resource -from api.utils.model_utils import get_llm_client, is_ollama_selected +from api.utils.model_utils import ( + get_llm_client, + is_custom_llm_selected, + is_ollama_selected, +) async def generate_image( @@ -18,10 +21,11 @@ async def generate_image( output_directory: str, ) -> str: is_ollama = is_ollama_selected() + is_custom_llm = is_custom_llm_selected() image_prompt = ( input.image_prompt - if is_ollama + if is_ollama or is_custom_llm else f"{input.image_prompt}, {input.theme_prompt}" ) print(f"Request - Generating Image for {image_prompt}") @@ -29,7 +33,7 @@ async def generate_image( try: image_gen_func = ( get_image_from_pexels - if is_ollama + if is_ollama or is_custom_llm else ( generate_image_openai if os.getenv("LLM") == "openai" diff --git a/servers/nextjs/app/api/user-config/route.ts b/servers/nextjs/app/api/user-config/route.ts index 87549b91..6fe3bccf 100644 --- a/servers/nextjs/app/api/user-config/route.ts +++ b/servers/nextjs/app/api/user-config/route.ts @@ -34,11 +34,13 @@ export async function POST(request: Request) { } const mergedConfig: LLMConfig = { LLM: userConfig.LLM || existingConfig.LLM, - LLM_PROVIDER_URL: userConfig.LLM_PROVIDER_URL || existingConfig.LLM_PROVIDER_URL, - LLM_API_KEY: userConfig.LLM_API_KEY, OPENAI_API_KEY: userConfig.OPENAI_API_KEY || existingConfig.OPENAI_API_KEY, GOOGLE_API_KEY: userConfig.GOOGLE_API_KEY || existingConfig.GOOGLE_API_KEY, - MODEL: userConfig.MODEL || existingConfig.MODEL, + OLLAMA_URL: userConfig.OLLAMA_URL || existingConfig.OLLAMA_URL, + OLLAMA_MODEL: userConfig.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL, + CUSTOM_LLM_URL: userConfig.CUSTOM_LLM_URL || existingConfig.CUSTOM_LLM_URL, + CUSTOM_LLM_API_KEY: userConfig.CUSTOM_LLM_API_KEY || existingConfig.CUSTOM_LLM_API_KEY, + CUSTOM_MODEL: userConfig.CUSTOM_MODEL || existingConfig.CUSTOM_MODEL, PEXELS_API_KEY: userConfig.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY, USE_CUSTOM_URL: userConfig.USE_CUSTOM_URL === undefined ? existingConfig.USE_CUSTOM_URL : userConfig.USE_CUSTOM_URL, } diff --git a/servers/nextjs/app/settings/SettingPage.tsx b/servers/nextjs/app/settings/SettingPage.tsx index 512c929e..a93dff81 100644 --- a/servers/nextjs/app/settings/SettingPage.tsx +++ b/servers/nextjs/app/settings/SettingPage.tsx @@ -41,6 +41,11 @@ const PROVIDER_CONFIGS: Record = { title: "Ollama API Key", description: "Required for using Ollama services", placeholder: "Choose a model", + }, + custom: { + title: "Custom Model Configuration", + description: "Configure your own OpenAI-compatible model", + placeholder: "Enter your custom model details", } }; @@ -63,6 +68,7 @@ const SettingsPage = () => { size: string; icon: string; }[]>([]); + const [customModels, setCustomModels] = useState([]); const [downloadingModel, setDownloadingModel] = useState({ name: '', size: null, @@ -73,14 +79,24 @@ const SettingsPage = () => { const [isLoading, setIsLoading] = useState(false); const [openModelSelect, setOpenModelSelect] = useState(false); const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState(userConfigState.llm_config.USE_CUSTOM_URL || false); + const [customModelsLoading, setCustomModelsLoading] = useState(false); + const [customModelsChecked, setCustomModelsChecked] = useState(false); const input_field_changed = (new_value: string, field: string) => { if (field === 'openai_api_key') { setLlmConfig({ ...llmConfig, OPENAI_API_KEY: new_value }); } else if (field === 'google_api_key') { setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: new_value }); - } else if (field === 'llm_provider_url') { - setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: new_value }); + } else if (field === 'ollama_url') { + setLlmConfig({ ...llmConfig, OLLAMA_URL: new_value }); + } else if (field === 'ollama_model') { + setLlmConfig({ ...llmConfig, OLLAMA_MODEL: new_value }); + } else if (field === 'custom_llm_url') { + setLlmConfig({ ...llmConfig, CUSTOM_LLM_URL: new_value }); + } else if (field === 'custom_llm_api_key') { + setLlmConfig({ ...llmConfig, CUSTOM_LLM_API_KEY: new_value }); + } else if (field === 'custom_model') { + setLlmConfig({ ...llmConfig, CUSTOM_MODEL: new_value }); } else if (field === 'pexels_api_key') { setLlmConfig({ ...llmConfig, PEXELS_API_KEY: new_value }); } @@ -110,10 +126,30 @@ const SettingsPage = () => { } }; + const fetchOllamaModelsWithConfig = async (config: any) => { + try { + const response = await fetch('/api/v1/ppt/ollama/list-supported-models'); + const data = await response.json(); + setOllamaModels(data.models); + + // Check if currently selected model is still available + if (config.OLLAMA_MODEL && data.models.length > 0) { + const isModelAvailable = data.models.some((model: any) => model.value === config.OLLAMA_MODEL); + if (!isModelAvailable) { + setLlmConfig({ ...config, OLLAMA_MODEL: '' }); + } + } + } catch (error) { + console.error('Error fetching ollama models:', error); + } + } + const changeProvider = (provider: string) => { - setLlmConfig({ ...llmConfig, LLM: provider }); + const newConfig = { ...llmConfig, LLM: provider }; + setLlmConfig(newConfig); if (provider === 'ollama') { - fetchOllamaModels(); + // Use the new config to avoid stale state issues + fetchOllamaModelsWithConfig(newConfig); } } @@ -131,7 +167,7 @@ const SettingsPage = () => { return new Promise((resolve, reject) => { const interval = setInterval(async () => { try { - const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`); + const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`); if (response.status === 200) { const data = await response.json(); if (data.done && data.status !== 'error') { @@ -163,23 +199,64 @@ const SettingsPage = () => { } const fetchOllamaModels = async () => { + await fetchOllamaModelsWithConfig(llmConfig); + } + + const fetchCustomModels = async () => { try { - const response = await fetch('/api/v1/ppt/ollama/list-supported-models'); + setCustomModelsLoading(true); + const response = await fetch('/api/v1/ppt/models/list/custom', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + url: llmConfig.CUSTOM_LLM_URL || '', + api_key: llmConfig.CUSTOM_LLM_API_KEY || '' + }) + }); const data = await response.json(); - setOllamaModels(data.models); + setCustomModels(data); + setCustomModelsChecked(true); + + // Check if currently selected model is still available + if (llmConfig.CUSTOM_MODEL && data.length > 0) { + const isModelAvailable = data.includes(llmConfig.CUSTOM_MODEL); + if (!isModelAvailable) { + setLlmConfig({ ...llmConfig, CUSTOM_MODEL: '' }); + toast({ + title: 'Model Unavailable', + description: `The selected model "${llmConfig.CUSTOM_MODEL}" is no longer available. Please select a different model.`, + variant: 'destructive', + }); + } + } } catch (error) { - console.error('Error fetching ollama models:', error); + console.error('Error fetching custom models:', error); + toast({ + title: 'Error', + description: 'Failed to fetch available models. Please check your URL and API key.', + variant: 'destructive', + }); + } finally { + setCustomModelsLoading(false); } } const setOllamaConfig = () => { if (!useCustomOllamaUrl) { - setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', USE_CUSTOM_URL: false }); + setLlmConfig({ ...llmConfig, OLLAMA_URL: 'http://localhost:11434', USE_CUSTOM_URL: false }); } else { setLlmConfig({ ...llmConfig, USE_CUSTOM_URL: true }); } } + const onCustomModelInfoChange = (value: string, field: string) => { + setCustomModels([]); + setCustomModelsChecked(false); + setLlmConfig({ ...llmConfig, CUSTOM_MODEL: '', CUSTOM_LLM_URL: field === 'custom_llm_url' ? value : llmConfig.CUSTOM_LLM_URL, CUSTOM_LLM_API_KEY: field === 'custom_llm_api_key' ? value : llmConfig.CUSTOM_LLM_API_KEY }); + } + useEffect(() => { if (!canChangeKeys) { @@ -187,6 +264,11 @@ const SettingsPage = () => { } if (userConfigState.llm_config.LLM === 'ollama') { fetchOllamaModels(); + } else if (userConfigState.llm_config.LLM === 'custom' && + userConfigState.llm_config.CUSTOM_MODEL && + userConfigState.llm_config.CUSTOM_LLM_URL && + userConfigState.llm_config.CUSTOM_LLM_API_KEY) { + fetchCustomModels(); } }, [userConfigState.llm_config.LLM]); @@ -194,6 +276,7 @@ const SettingsPage = () => { setOllamaConfig(); }, [useCustomOllamaUrl]); + if (!canChangeKeys) { return null; } @@ -247,37 +330,20 @@ const SettingsPage = () => { {/* API Key Input */} - {llmConfig.LLM !== 'ollama' && ( + {llmConfig.LLM !== 'ollama' && llmConfig.LLM !== 'custom' && (
-
+
input_field_changed(e.target.value, llmConfig.LLM === 'openai' ? 'openai_api_key' : 'google_api_key')} - className="flex-1 px-4 py-2.5 border border-gray-300 outline-none rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors" + className="w-full px-4 py-2.5 border border-gray-300 outline-none rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors" placeholder={PROVIDER_CONFIGS[llmConfig.LLM!].placeholder} /> -

{PROVIDER_CONFIGS[llmConfig.LLM!].description}

@@ -302,25 +368,25 @@ const SettingsPage = () => { className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between" >
- {llmConfig.MODEL && ( + {llmConfig.OLLAMA_MODEL && (
m.value === llmConfig.MODEL)?.icon} - alt={`${llmConfig.MODEL} icon`} + src={ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.icon} + alt={`${llmConfig.OLLAMA_MODEL} icon`} className="rounded-sm" />
)} - {llmConfig.MODEL ? ( - ollamaModels.find(m => m.value === llmConfig.MODEL)?.label || llmConfig.MODEL + {llmConfig.OLLAMA_MODEL ? ( + ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.label || llmConfig.OLLAMA_MODEL ) : ( 'Select a model' )} - {llmConfig.MODEL && ( + {llmConfig.OLLAMA_MODEL && ( - {ollamaModels.find(m => m.value === llmConfig.MODEL)?.size} + {ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.size} )}
@@ -338,14 +404,14 @@ const SettingsPage = () => { key={index} value={model.value} onSelect={(value) => { - setLlmConfig({ ...llmConfig, MODEL: value }); + setLlmConfig({ ...llmConfig, OLLAMA_MODEL: value }); setOpenModelSelect(false); }} >
@@ -419,8 +485,8 @@ const SettingsPage = () => { required placeholder="Enter your Ollama URL" className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors" - value={llmConfig.LLM_PROVIDER_URL || ''} - onChange={(e) => input_field_changed(e.target.value, 'llm_provider_url')} + value={llmConfig.OLLAMA_URL || ''} + onChange={(e) => input_field_changed(e.target.value, 'ollama_url')} />

@@ -435,39 +501,19 @@ const SettingsPage = () => {

-
+
input_field_changed(e.target.value, 'pexels_api_key')} /> -
-

Required for using Ollama services with image generation

+

Provide a Pexels API key to generate presentation images

{downloadingModel.status && downloadingModel.status !== 'pulled' && (
@@ -476,6 +522,202 @@ const SettingsPage = () => { )}
)} + + {/* Custom Model Configuration */} + {llmConfig.LLM === 'custom' && ( +
+
+ +
+ onCustomModelInfoChange(e.target.value, 'custom_llm_url')} + /> +
+
+ +
+ +
+ onCustomModelInfoChange(e.target.value, 'custom_llm_api_key')} + /> +
+
+ + {/* Model selection dropdown - show if models are available or if there's a selected model */} + {((customModelsChecked && customModels.length > 0) || llmConfig.CUSTOM_MODEL) && ( +
+ +
+ + + + + + + + + No model found. + + {customModels.map((model, index) => ( + { + setLlmConfig({ ...llmConfig, CUSTOM_MODEL: value }); + setOpenModelSelect(false); + }} + > + + + {model} + + + ))} + + + + + +
+
+ )} + + {/* Check for available models button - show when no models checked or no models found, and no model is selected */} + {(!customModelsChecked || (customModelsChecked && customModels.length === 0)) && !llmConfig.CUSTOM_MODEL && ( +
+ +
+ )} + + {/* Show message if no models found */} + {customModelsChecked && customModels.length === 0 && ( +
+

+ No models found. Please check your URL and API key, or try again. +

+
+ )} + + {/* Refresh models button - show when there's a selected model but we want to refresh */} + {llmConfig.CUSTOM_MODEL && customModelsChecked && ( +
+ +
+ )} + +
+ +
+ input_field_changed(e.target.value, 'pexels_api_key')} + /> +
+

Provide a Pexels API key to generate presentation images

+
+
+ )} + + {/* Save Button */} + + + { + llmConfig.LLM === 'ollama' && downloadingModel.status && downloadingModel.status !== 'pulled' && ( +
+ {downloadingModel.status} +
+ ) + }
diff --git a/servers/nextjs/app/storeInitializer.tsx b/servers/nextjs/app/storeInitializer.tsx index 8e0e1436..95baf330 100644 --- a/servers/nextjs/app/storeInitializer.tsx +++ b/servers/nextjs/app/storeInitializer.tsx @@ -45,13 +45,21 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) { if (isValid) { // Check if the selected Ollama model is pulled if (llmConfig.LLM === 'ollama') { - const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.MODEL); + const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.OLLAMA_MODEL); if (!isPulled) { router.push('/'); setLoadingToFalseAfterNavigatingTo('/'); return; } } + if (llmConfig.LLM === 'custom') { + const isAvailable = await checkIfSelectedCustomModelIsAvailable(llmConfig.CUSTOM_MODEL); + if (!isAvailable) { + router.push('/'); + setLoadingToFalseAfterNavigatingTo('/'); + return; + } + } if (route === '/') { router.push('/upload'); setLoadingToFalseAfterNavigatingTo('/upload'); @@ -86,6 +94,22 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) { } } + const checkIfSelectedCustomModelIsAvailable = async (customModel: string) => { + try { + const response = await fetch('/api/v1/ppt/models/list/custom', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + }); + const data = await response.json(); + return data.includes(customModel); + } catch (error) { + console.error('Error fetching custom models:', error); + return false; + } + } + if (isLoading) { return ( diff --git a/servers/nextjs/components/Home.tsx b/servers/nextjs/components/Home.tsx index a88dbbb0..93cddbf0 100644 --- a/servers/nextjs/components/Home.tsx +++ b/servers/nextjs/components/Home.tsx @@ -30,6 +30,7 @@ import { } from "./ui/popover"; import { cn } from "@/lib/utils"; import { Switch } from "./ui/switch"; +import { setLLMConfig } from "@/store/slices/userConfig"; interface ModelOption { value: string; @@ -137,28 +138,28 @@ const PROVIDER_CONFIGS: Record = { docsUrl: "https://www.pexels.com/api/documentation/", }, }, - // custom: { - // textModels: [], - // imageModels: [ - // { - // value: "pexels", - // label: "Pexels", - // description: "Pexels is a free stock photo and video platform that allows you to download high-quality images and videos for free.", - // icon: "/icons/pexels.png", - // size: "8GB", - // }, - // ], - // apiGuide: { - // title: "How to get your Pexels API Key", - // steps: [ - // "Visit pexels.com", - // 'Click on "Get API key" in the top navigation', - // "Copy your API key - you're ready to go!", - // ], - // videoUrl: "https://www.youtube.com/watch?v=o8iyrtQyrZM&t=66s", - // docsUrl: "https://www.pexels.com/api/documentation/", - // }, - // }, + custom: { + textModels: [], + imageModels: [ + { + value: "pexels", + label: "Pexels", + description: "Pexels is a free stock photo and video platform that allows you to download high-quality images and videos for free.", + icon: "/icons/pexels.png", + size: "8GB", + }, + ], + apiGuide: { + title: "How to get your Pexels API Key", + steps: [ + "Visit pexels.com", + 'Click on "Get API key" in the top navigation', + "Copy your API key - you're ready to go!", + ], + videoUrl: "https://www.youtube.com/watch?v=o8iyrtQyrZM&t=66s", + docsUrl: "https://www.pexels.com/api/documentation/", + }, + }, }; export default function Home() { @@ -172,6 +173,7 @@ export default function Home() { size: string; icon: string; }[]>([]); + const [customModels, setCustomModels] = useState([]); const [downloadingModel, setDownloadingModel] = useState({ name: '', size: null, @@ -182,6 +184,8 @@ export default function Home() { const [isLoading, setIsLoading] = useState(false); const [openModelSelect, setOpenModelSelect] = useState(false); const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState(llmConfig.USE_CUSTOM_URL || false); + const [customModelsLoading, setCustomModelsLoading] = useState(false); + const [customModelsChecked, setCustomModelsChecked] = useState(false); const canChangeKeys = config.can_change_keys; @@ -190,8 +194,16 @@ export default function Home() { setLlmConfig({ ...llmConfig, OPENAI_API_KEY: new_value }); } else if (field === 'google_api_key') { setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: new_value }); - } else if (field === 'llm_provider_url') { - setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: new_value }); + } else if (field === 'ollama_url') { + setLlmConfig({ ...llmConfig, OLLAMA_URL: new_value }); + } else if (field === 'ollama_model') { + setLlmConfig({ ...llmConfig, OLLAMA_MODEL: new_value }); + } else if (field === 'custom_llm_url') { + setLlmConfig({ ...llmConfig, CUSTOM_LLM_URL: new_value }); + } else if (field === 'custom_llm_api_key') { + setLlmConfig({ ...llmConfig, CUSTOM_LLM_API_KEY: new_value }); + } else if (field === 'custom_model') { + setLlmConfig({ ...llmConfig, CUSTOM_MODEL: new_value }); } else if (field === 'pexels_api_key') { setLlmConfig({ ...llmConfig, PEXELS_API_KEY: new_value }); } @@ -221,10 +233,30 @@ export default function Home() { } }; + const fetchOllamaModelsWithConfig = async (config: any) => { + try { + const response = await fetch('/api/v1/ppt/ollama/list-supported-models'); + const data = await response.json(); + setOllamaModels(data.models); + + // Check if currently selected model is still available + if (config.OLLAMA_MODEL && data.models.length > 0) { + const isModelAvailable = data.models.some((model: any) => model.value === config.OLLAMA_MODEL); + if (!isModelAvailable) { + setLlmConfig({ ...config, OLLAMA_MODEL: '' }); + } + } + } catch (error) { + console.error('Error fetching ollama models:', error); + } + } + const changeProvider = (provider: string) => { - setLlmConfig({ ...llmConfig, LLM: provider }); + const newConfig = { ...llmConfig, LLM: provider }; + setLlmConfig(newConfig); if (provider === 'ollama') { - fetchOllamaModels(); + // Use the new config to avoid stale state issues + fetchOllamaModelsWithConfig(newConfig); } } @@ -242,7 +274,7 @@ export default function Home() { return new Promise((resolve, reject) => { const interval = setInterval(async () => { try { - const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`); + const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`); if (response.status === 200) { const data = await response.json(); if (data.done && data.status !== 'error') { @@ -274,18 +306,40 @@ export default function Home() { } const fetchOllamaModels = async () => { + await fetchOllamaModelsWithConfig(llmConfig); + } + + const fetchCustomModels = async () => { try { - const response = await fetch('/api/v1/ppt/ollama/list-supported-models'); + setCustomModelsLoading(true); + const response = await fetch('/api/v1/ppt/models/list/custom', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + url: llmConfig.CUSTOM_LLM_URL || '', + api_key: llmConfig.CUSTOM_LLM_API_KEY || '' + }) + }); const data = await response.json(); - setOllamaModels(data.models); + setCustomModels(data); + setCustomModelsChecked(true); } catch (error) { - console.error('Error fetching ollama models:', error); + console.error('Error fetching custom models:', error); + toast({ + title: 'Error', + description: 'Failed to fetch available models. Please check your URL and API key.', + variant: 'destructive', + }); + } finally { + setCustomModelsLoading(false); } } const setOllamaConfig = () => { if (!useCustomOllamaUrl) { - setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', USE_CUSTOM_URL: false }); + setLlmConfig({ ...llmConfig, OLLAMA_URL: 'http://localhost:11434', USE_CUSTOM_URL: false }); } else { setLlmConfig({ ...llmConfig, USE_CUSTOM_URL: true }); } @@ -304,6 +358,14 @@ export default function Home() { setOllamaConfig(); }, [useCustomOllamaUrl]); + // Reset custom models when URL or API key changes + useEffect(() => { + if (llmConfig.LLM === 'custom') { + setCustomModels([]); + setCustomModelsChecked(false); + setLlmConfig({ ...llmConfig, CUSTOM_MODEL: '' }); + } + }, [llmConfig.CUSTOM_LLM_URL, llmConfig.CUSTOM_LLM_API_KEY]); if (!canChangeKeys) { return null; @@ -355,7 +417,7 @@ export default function Home() {
{/* API Key Input */} - {llmConfig.LLM !== 'ollama' &&
+ {llmConfig.LLM !== 'ollama' && llmConfig.LLM !== 'custom' &&

- Required for generating presentation images + Provide a Pexels API key to generate presentation images

) } + { + llmConfig.LLM === 'custom' && ( + <> +
+ +
+ input_field_changed(e.target.value, 'custom_llm_url')} + /> +
+
+
+ +
+ input_field_changed(e.target.value, 'custom_llm_api_key')} + /> +
+
+ + {/* Model selection dropdown - only show if models are available */} + {customModelsChecked && customModels.length > 0 && ( +
+ +
+ + + + + + + + + No model found. + + {customModels.map((model, index) => ( + { + input_field_changed(value, 'custom_model'); + setOpenModelSelect(false); + }} + > + + + {model} + + + ))} + + + + + +
+
+ )} + + {/* Check for available models button - show when no models checked or no models found */} + {(!customModelsChecked || (customModelsChecked && customModels.length === 0)) && ( +
+ +
+ )} + + {/* Show message if no models found */} + {customModelsChecked && customModels.length === 0 && ( +
+

+ No models found. Please check your URL and API key, or try again. +

+
+ )} + +
+ +
+ input_field_changed(e.target.value, 'pexels_api_key')} + /> +
+

+ + Provide a Pexels API key to generate presentation images +

+
+ + ) + } {/* Model Information */}
@@ -550,7 +756,7 @@ export default function Home() { Selected Models

- Using {llmConfig.LLM === 'ollama' ? llmConfig.MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text + Using {llmConfig.LLM === 'ollama' ? llmConfig.OLLAMA_MODEL ?? '_____' : llmConfig.LLM === 'custom' ? llmConfig.CUSTOM_MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text generation and {PROVIDER_CONFIGS[llmConfig.LLM!].imageModels[0].label} for images

@@ -611,8 +817,8 @@ export default function Home() { {/* Save Button */}