diff --git a/servers/fastapi/api/main.py b/servers/fastapi/api/main.py index 960ce692..4908a063 100644 --- a/servers/fastapi/api/main.py +++ b/servers/fastapi/api/main.py @@ -15,10 +15,10 @@ can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false" # Ollama model download if not can_change_keys and is_ollama_selected(): - ollama_model = os.getenv("OLLAMA_MODEL") + ollama_model = os.getenv("MODEL") pexels_api_key = os.getenv("PEXELS_API_KEY") if not (ollama_model or pexels_api_key): - raise Exception("OLLAMA_MODEL and PEXELS_API_KEY must be provided") + raise Exception("MODEL and PEXELS_API_KEY must be provided") if ollama_model not in SUPPORTED_OLLAMA_MODELS: raise Exception(f"Model {ollama_model} is not supported") diff --git a/servers/fastapi/api/models.py b/servers/fastapi/api/models.py index da38aa83..96d7b8d0 100644 --- a/servers/fastapi/api/models.py +++ b/servers/fastapi/api/models.py @@ -63,7 +63,9 @@ class UserConfig(BaseModel): LLM: Optional[str] = None OPENAI_API_KEY: Optional[str] = None GOOGLE_API_KEY: Optional[str] = None - OLLAMA_MODEL: Optional[str] = None + MODEL: Optional[str] = None + LLM_PROVIDER_URL: Optional[str] = None + LLM_API_KEY: Optional[str] = None PEXELS_API_KEY: Optional[str] = None diff --git a/servers/fastapi/api/routers/presentation/handlers/edit.py b/servers/fastapi/api/routers/presentation/handlers/edit.py index 2706b53f..6cc1d514 100644 --- a/servers/fastapi/api/routers/presentation/handlers/edit.py +++ b/servers/fastapi/api/routers/presentation/handlers/edit.py @@ -69,7 +69,7 @@ class PresentationEditHandler: new_slide_type = new_slide_type.slide_type if is_ollama_selected(): - model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")] + model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")] if not model.supports_graph: if new_slide_type == 5: new_slide_type = 1 diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_data.py b/servers/fastapi/api/routers/presentation/handlers/generate_data.py index f3c6edd7..8375d7eb 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_data.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_data.py @@ -54,7 +54,7 @@ class PresentationGenerateDataHandler: ) ) supports_graph = True - model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")] + model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")] supports_graph = model.supports_graph for each in presentation_structure.slides: diff --git a/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py b/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py index 65fab26a..c96db431 100644 --- a/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py +++ b/servers/fastapi/api/routers/presentation/handlers/list_ollama_pulled_models.py @@ -1,7 +1,9 @@ -import ollama +import os +import aiohttp from api.models import LogMetadata from api.routers.presentation.models import OllamaModelStatusResponse from api.services.logging import LoggingService +from api.utils.model_utils import get_llm_api_key_or, get_llm_provider_url_or class ListPulledOllamaModelsHandler: @@ -12,20 +14,25 @@ class ListPulledOllamaModelsHandler: extra=log_metadata.model_dump(), ) - response = ollama.list() + async with aiohttp.ClientSession() as session: + async with session.get( + f"{get_llm_provider_url_or()}/api/tags", + headers={"Authorization": f"Bearer {get_llm_api_key_or()}"}, + ) as response: + response_data = await response.json() logging_service.logger.info( - logging_service.message(response.model_dump(mode="json")), + logging_service.message(response_data), extra=log_metadata.model_dump(), ) return [ OllamaModelStatusResponse( - name=model.model, - size=model.size, + name=model["model"], + size=model["size"], status="pulled", - downloaded=model.size, + downloaded=model["size"], done=True, ) - for model in response.models + for model in response_data["models"] ] diff --git a/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py b/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py index b942ddb8..f14dd239 100644 --- a/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py +++ b/servers/fastapi/api/routers/presentation/handlers/pull_ollama_model.py @@ -1,5 +1,5 @@ -import asyncio import json +import aiohttp from fastapi import BackgroundTasks, HTTPException from api.models import LogMetadata from api.routers.presentation.handlers.list_supported_ollama_models import ( @@ -8,7 +8,7 @@ from api.routers.presentation.handlers.list_supported_ollama_models import ( from api.routers.presentation.models import OllamaModelStatusResponse from api.services.instances import REDIS_SERVICE from api.services.logging import LoggingService -import ollama +from api.utils.model_utils import get_llm_api_key_or, get_llm_provider_url_or class PullOllamaModelHandler: @@ -33,19 +33,34 @@ class PullOllamaModelHandler: detail=f"Model {self.name} is not supported", ) - pulled_models = ollama.list().models - filtered_models = list( - filter(lambda model: model.model == self.name, pulled_models) - ) + # Check if model is already pulled using LLM_PROVIDER_URL/api/tags + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{get_llm_provider_url_or()}/api/tags", + headers={"Authorization": f"Bearer {get_llm_api_key_or()}"}, + ) as response: + if response.status == 200: + pulled_models = await response.json() + filtered_models = [ + model + for model in pulled_models["models"] + if model["model"] == self.name + ] - # If the model is already pulled, return the model - if filtered_models: - return OllamaModelStatusResponse( - name=self.name, - size=filtered_models[0].size, - status="pulled", - downloaded=filtered_models[0].size, - done=True, + # If the model is already pulled, return the model + if filtered_models: + return OllamaModelStatusResponse( + name=self.name, + size=filtered_models[0]["size"], + status="pulled", + downloaded=filtered_models[0]["size"], + done=True, + ) + except Exception as e: + logging_service.logger.warning( + f"Failed to check pulled models: {e}", + extra=log_metadata.model_dump(), ) saved_model_status = REDIS_SERVICE.get(f"ollama_models/{self.name}") @@ -64,33 +79,64 @@ class PullOllamaModelHandler: ) async def pull_model_in_background(self): - await asyncio.to_thread(self.pull_model) + await self.pull_model() - def pull_model(self): + async def pull_model(self): saved_model_status = OllamaModelStatusResponse( name=self.name, status="pulling", done=False, ) log_event_count = 0 - for event in ollama.pull(self.name, stream=True): - log_event_count += 1 - if log_event_count != 1 and log_event_count % 20 != 0: - continue - if event.completed: - saved_model_status.downloaded = event.completed + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{get_llm_provider_url_or()}/api/pull", + json={"model": self.name}, + headers={"Authorization": f"Bearer {get_llm_api_key_or()}"}, + ) as response: + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"Failed to pull model: {await response.text()}", + ) - if not saved_model_status.size and event.total: - saved_model_status.size = event.total + async for line in response.content: + if not line.strip(): + continue - if event.status: - saved_model_status.status = event.status + try: + event = json.loads(line.decode("utf-8")) + except json.JSONDecodeError: + continue + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue + + if "completed" in event: + saved_model_status.downloaded = event["completed"] + + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] + + if "status" in event: + saved_model_status.status = event["status"] + + REDIS_SERVICE.set( + f"ollama_models/{self.name}", + json.dumps(saved_model_status.model_dump(mode="json")), + ) + + except Exception as e: + saved_model_status.status = "error" + saved_model_status.done = True REDIS_SERVICE.set( f"ollama_models/{self.name}", json.dumps(saved_model_status.model_dump(mode="json")), ) + raise e saved_model_status.done = True saved_model_status.status = "pulled" diff --git a/servers/fastapi/api/routers/presentation/router.py b/servers/fastapi/api/routers/presentation/router.py index eb7d4806..be48b529 100644 --- a/servers/fastapi/api/routers/presentation/router.py +++ b/servers/fastapi/api/routers/presentation/router.py @@ -82,6 +82,9 @@ from api.routers.presentation.models import ( ) from api.sql_models import PresentationSqlModel from api.utils.utils import handle_errors +from image_processor.images_finder import ( + generate_image_google, +) from ppt_generator.models.slide_model import SlideModel route_prefix = "/api/v1/ppt" diff --git a/servers/fastapi/api/utils/model_utils.py b/servers/fastapi/api/utils/model_utils.py index e7a740b7..b2ee037a 100644 --- a/servers/fastapi/api/utils/model_utils.py +++ b/servers/fastapi/api/utils/model_utils.py @@ -9,6 +9,14 @@ def is_ollama_selected() -> bool: return get_selected_llm_provider() == SelectedLLMProvider.OLLAMA +def get_llm_provider_url_or(): + return os.getenv("LLM_PROVIDER_URL") or "http://localhost:11434" + + +def get_llm_api_key_or(): + return os.getenv("LLM_API_KEY") or "ollama" + + def get_selected_llm_provider() -> SelectedLLMProvider: return SelectedLLMProvider(os.getenv("LLM")) @@ -21,11 +29,9 @@ def get_model_base_url(): elif selected_llm == SelectedLLMProvider.GOOGLE: return "https://generativelanguage.googleapis.com/v1beta/openai" elif selected_llm == SelectedLLMProvider.OLLAMA: - return os.getenv("LLM_PROVIDER_URL", "http://localhost:11434/v1") - elif selected_llm == SelectedLLMProvider.CUSTOM: - return os.getenv("LLM_PROVIDER_URL") + return os.path.join(get_llm_provider_url_or(), "v1") else: - raise ValueError(f"Invalid LLM provider: {selected_llm}") + raise ValueError(f"Invalid LLM provider") def get_llm_api_key(): @@ -35,11 +41,9 @@ def get_llm_api_key(): elif selected_llm == SelectedLLMProvider.GOOGLE: return os.getenv("GOOGLE_API_KEY") elif selected_llm == SelectedLLMProvider.OLLAMA: - return os.getenv("LLM_API_KEY", "ollama") - elif selected_llm == SelectedLLMProvider.CUSTOM: - return os.getenv("LLM_API_KEY") + return get_llm_api_key_or() else: - raise ValueError(f"Invalid LLM provider: {selected_llm}") + raise ValueError(f"Invalid LLM API key") def get_llm_client(): @@ -57,7 +61,7 @@ def get_large_model(): elif selected_llm == SelectedLLMProvider.GOOGLE: return "gemini-2.0-flash" else: - return os.getenv("OLLAMA_MODEL") + return os.getenv("MODEL") def get_small_model(): @@ -67,7 +71,7 @@ def get_small_model(): elif selected_llm == SelectedLLMProvider.GOOGLE: return "gemini-2.0-flash" else: - return os.getenv("OLLAMA_MODEL") + return os.getenv("MODEL") def get_nano_model(): @@ -77,4 +81,4 @@ def get_nano_model(): elif selected_llm == SelectedLLMProvider.GOOGLE: return "gemini-2.0-flash" else: - return os.getenv("OLLAMA_MODEL") + return os.getenv("MODEL") diff --git a/servers/fastapi/api/utils/utils.py b/servers/fastapi/api/utils/utils.py index 975b751d..e297cc96 100644 --- a/servers/fastapi/api/utils/utils.py +++ b/servers/fastapi/api/utils/utils.py @@ -44,8 +44,11 @@ def get_user_config(): LLM=existing_config.LLM or os.getenv("LLM"), OPENAI_API_KEY=existing_config.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY"), GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY"), - OLLAMA_MODEL=existing_config.OLLAMA_MODEL or os.getenv("OLLAMA_MODEL"), + MODEL=existing_config.MODEL or os.getenv("MODEL"), PEXELS_API_KEY=existing_config.PEXELS_API_KEY or os.getenv("PEXELS_API_KEY"), + LLM_PROVIDER_URL=existing_config.LLM_PROVIDER_URL + or os.getenv("LLM_PROVIDER_URL"), + LLM_API_KEY=existing_config.LLM_API_KEY or os.getenv("LLM_API_KEY"), ) @@ -57,10 +60,14 @@ def update_env_with_user_config(): os.environ["OPENAI_API_KEY"] = user_config.OPENAI_API_KEY if user_config.GOOGLE_API_KEY: os.environ["GOOGLE_API_KEY"] = user_config.GOOGLE_API_KEY - if user_config.OLLAMA_MODEL: - os.environ["OLLAMA_MODEL"] = user_config.OLLAMA_MODEL + if user_config.MODEL: + os.environ["MODEL"] = user_config.MODEL if user_config.PEXELS_API_KEY: os.environ["PEXELS_API_KEY"] = user_config.PEXELS_API_KEY + if user_config.LLM_PROVIDER_URL: + os.environ["LLM_PROVIDER_URL"] = user_config.LLM_PROVIDER_URL + if user_config.LLM_API_KEY: + os.environ["LLM_API_KEY"] = user_config.LLM_API_KEY def get_resource(relative_path): @@ -129,35 +136,35 @@ async def download_files(urls: List[str], save_paths: List[str]): async def handle_errors( func, logging_service: LoggingService, log_metadata: LogMetadata, **kwargs ): - # try: - logging_service.logger.info(f"START", extra=log_metadata.model_dump()) - response = await func( - logging_service=logging_service, log_metadata=log_metadata, **kwargs - ) - is_stream = isinstance(response, StreamingResponse) - logging_service.logger.info( - "STREAMING" if is_stream else "END", extra=log_metadata.model_dump() - ) - return response + try: + logging_service.logger.info(f"START", extra=log_metadata.model_dump()) + response = await func( + logging_service=logging_service, log_metadata=log_metadata, **kwargs + ) + is_stream = isinstance(response, StreamingResponse) + logging_service.logger.info( + "STREAMING" if is_stream else "END", extra=log_metadata.model_dump() + ) + return response - # except HTTPException as e: - # log_metadata.status_code = e.status_code - # logging_service.logger.error( - # f"Raised HTTPException - {e.detail}", extra=log_metadata.model_dump() - # ) - # raise e - # except Exception as e: - # print(traceback.print_stack()) - # print(traceback.print_exc()) + except HTTPException as e: + log_metadata.status_code = e.status_code + logging_service.logger.error( + f"Raised HTTPException - {e.detail}", extra=log_metadata.model_dump() + ) + raise e + except Exception as e: + print(traceback.print_stack()) + print(traceback.print_exc()) - # log_metadata.status_code = 400 - # logging_service.logger.critical( - # "Unhandled Exception", - # exc_info=True, - # stack_info=True, - # extra=log_metadata.model_dump(), - # ) - # raise HTTPException(400, "Something went wrong while processing your request.") + log_metadata.status_code = 400 + logging_service.logger.critical( + "Unhandled Exception", + exc_info=True, + stack_info=True, + extra=log_metadata.model_dump(), + ) + raise HTTPException(400, "Something went wrong while processing your request.") def sanitize_filename(filename: str) -> str: diff --git a/servers/fastapi/image_processor/images_finder.py b/servers/fastapi/image_processor/images_finder.py index 31223b19..d915f914 100644 --- a/servers/fastapi/image_processor/images_finder.py +++ b/servers/fastapi/image_processor/images_finder.py @@ -3,13 +3,14 @@ import base64 import os import uuid import aiohttp -from openai import OpenAI +from google import genai +from google.genai.types import GenerateContentConfig from ppt_generator.models.query_and_prompt_models import ( ImagePromptWithThemeAndAspectRatio, ) from api.utils.utils import download_file, get_resource -from api.utils.model_utils import is_ollama_selected +from api.utils.model_utils import get_llm_client, is_ollama_selected async def generate_image( @@ -46,7 +47,7 @@ async def generate_image( async def generate_image_openai(prompt: str, output_directory: str) -> str: - client = OpenAI() + client = get_llm_client() result = await asyncio.to_thread( client.images.generate, model="dall-e-3", @@ -66,23 +67,22 @@ async def generate_image_openai(prompt: str, output_directory: str) -> str: async def generate_image_google(prompt: str, output_directory: str) -> str: - # response = await ChatGoogleGenerativeAI( - # model="gemini-2.0-flash-preview-image-generation" - # ).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]}) + client = genai.Client() + response = client.models.generate_content( + model="gemini-2.0-flash-preview-image-generation", + contents=[prompt], + config=GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]), + ) - # image_block = next( - # block - # for block in response.content - # if isinstance(block, dict) and block.get("image_url") - # ) + for part in response.candidates[0].content.parts: + if part.text is not None: + print(part.text) + elif part.inline_data is not None: + image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg") + with open(image_path, "wb") as f: + f.write(part.inline_data.data) - # base64_image = image_block["image_url"].get("url").split(",")[-1] - # image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg") - # with open(image_path, "wb") as f: - # f.write(base64.b64decode(base64_image)) - - # return image_path - return "" + return image_path async def get_image_from_pexels(prompt: str, output_directory: str) -> str: diff --git a/servers/fastapi/ppt_generator/models/llm_models_with_validations.py b/servers/fastapi/ppt_generator/models/llm_models_with_validations.py index bce7fc0f..00074152 100644 --- a/servers/fastapi/ppt_generator/models/llm_models_with_validations.py +++ b/servers/fastapi/ppt_generator/models/llm_models_with_validations.py @@ -53,7 +53,7 @@ class LLMTableDataModelWithValidation(LLMTableDataModel): class LLMTableModelWithValidation(LLMTableModel): name: str = Field( - description="Name of the table in less than 8 words", + description="Name of the table in about 8 words", min_length=10, max_length=50, ) @@ -62,20 +62,20 @@ class LLMTableModelWithValidation(LLMTableModel): class LLMHeadingModelWithValidation(LLMHeadingModel): heading: str = Field( - description="Item heading in less than 6 words", + description="Item heading in about 6 words", min_length=10, max_length=40, ) description: str = Field( - description="Item description in less than 15 words.", + description="Item description in about 12 words.", min_length=50, - max_length=150, + max_length=120, ) class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt): image_prompt: str = Field( - description="Item image prompt in less than 10 words", + description="Item image prompt in about 10 words", min_length=10, max_length=100, ) @@ -83,7 +83,7 @@ class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePromp class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery): icon_query: str = Field( - description="Item icon query in less than 4 words", + description="Item icon query in about 4 words", min_length=10, max_length=40, ) @@ -91,7 +91,7 @@ class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery): class LLMSlideContentModelWithValidation(LLMSlideContentModel): title: str = Field( - description="Slide title in less than 8 words", + description="Slide title in about 8 words", min_length=10, max_length=80, ) @@ -99,12 +99,12 @@ class LLMSlideContentModelWithValidation(LLMSlideContentModel): class LLMType1ContentWithValidation(LLMType1Content): body: str = Field( - description="Slide content summary in less than 30 words.", + description="Slide content summary in about 30 words.", min_length=50, max_length=300, ) image_prompt: str = Field( - description="Slide image prompt in less than 5 words", + description="Slide image prompt in about 5 words", min_length=10, max_length=30, ) @@ -125,7 +125,7 @@ class LLMType3ContentWithValidation(LLMType3Content): max_length=3, ) image_prompt: str = Field( - description="Slide image prompt in less than 5 words", + description="Slide image prompt in about 5 words", min_length=10, max_length=30, ) @@ -141,7 +141,7 @@ class LLMType4ContentWithValidation(LLMType4Content): class LLMType5ContentWithValidation(LLMType5Content): body: str = Field( - description="Slide content summary in less than 30 words.", + description="Slide content summary in about 30 words.", min_length=50, max_length=300, ) @@ -150,7 +150,7 @@ class LLMType5ContentWithValidation(LLMType5Content): class LLMType6ContentWithValidation(LLMType6Content): description: str = Field( - description="Slide content summary in less than 20 words.", + description="Slide content summary in about 20 words.", min_length=50, max_length=300, ) @@ -171,7 +171,7 @@ class LLMType7ContentWithValidation(LLMType7Content): class LLMType8ContentWithValidation(LLMType8Content): description: str = Field( - description="Slide content summary in less than 20 words.", + description="Slide content summary in about 20 words.", min_length=50, max_length=300, ) diff --git a/servers/fastapi/requirements.txt b/servers/fastapi/requirements.txt index 928ecf65..81c4d998 100644 --- a/servers/fastapi/requirements.txt +++ b/servers/fastapi/requirements.txt @@ -28,6 +28,7 @@ fsspec==2025.3.2 google-ai-generativelanguage==0.6.18 google-api-core==2.24.2 google-auth==2.40.1 +google-genai==1.23.0 googleapis-common-protos==1.70.0 greenlet==3.2.2 grpcio==1.72.0rc1 diff --git a/servers/nextjs/app/(presentation-generator)/presentation/components/Header.tsx b/servers/nextjs/app/(presentation-generator)/presentation/components/Header.tsx index 3aac3035..dd33bf83 100644 --- a/servers/nextjs/app/(presentation-generator)/presentation/components/Header.tsx +++ b/servers/nextjs/app/(presentation-generator)/presentation/components/Header.tsx @@ -215,7 +215,7 @@ const Header = ({ if (response.ok) { const { path: pdfPath } = await response.json(); const staticFileUrl = getStaticFileUrl(pdfPath); - window.open(staticFileUrl, '_self'); + window.open(staticFileUrl, '_blank'); } else { throw new Error("Failed to export PDF"); } diff --git a/servers/nextjs/app/api/user-config/route.ts b/servers/nextjs/app/api/user-config/route.ts index f8fc32f5..1fc257b4 100644 --- a/servers/nextjs/app/api/user-config/route.ts +++ b/servers/nextjs/app/api/user-config/route.ts @@ -36,7 +36,9 @@ export async function POST(request: Request) { LLM: userConfig.LLM || existingConfig.LLM, OPENAI_API_KEY: userConfig.OPENAI_API_KEY || existingConfig.OPENAI_API_KEY, GOOGLE_API_KEY: userConfig.GOOGLE_API_KEY || existingConfig.GOOGLE_API_KEY, - OLLAMA_MODEL: userConfig.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL, + MODEL: userConfig.MODEL || existingConfig.MODEL, + LLM_PROVIDER_URL: userConfig.LLM_PROVIDER_URL || existingConfig.LLM_PROVIDER_URL, + LLM_API_KEY: userConfig.LLM_API_KEY || existingConfig.LLM_API_KEY, PEXELS_API_KEY: userConfig.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY, } fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig)) diff --git a/servers/nextjs/app/settings/SettingPage.tsx b/servers/nextjs/app/settings/SettingPage.tsx index af8ad73d..703d4ac4 100644 --- a/servers/nextjs/app/settings/SettingPage.tsx +++ b/servers/nextjs/app/settings/SettingPage.tsx @@ -2,13 +2,29 @@ import React, { useState, useEffect } from "react"; import Header from "../dashboard/components/Header"; import Wrapper from "@/components/Wrapper"; -import { Settings, Key, Loader2 } from 'lucide-react'; +import { Settings, Key, Loader2, Check, ChevronsUpDown } from 'lucide-react'; import { toast } from '@/hooks/use-toast'; import { RootState } from "@/store/store"; import { useSelector } from "react-redux"; import { handleSaveLLMConfig } from "@/utils/storeHelpers"; import { useRouter } from "next/navigation"; import { Select, SelectContent, SelectItem, SelectTrigger } from "@/components/ui/select"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { cn } from "@/lib/utils"; +import { Switch } from "@/components/ui/switch"; const PROVIDER_CONFIGS: Record = { openai: { @@ -55,14 +71,22 @@ const SettingsPage = () => { done: false, }); const [isLoading, setIsLoading] = useState(false); + const [openModelSelect, setOpenModelSelect] = useState(false); + const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState(false); - const api_key_changed = (apiKey: string) => { + const api_key_changed = (apiKey: string, field?: string) => { if (llmConfig.LLM === 'openai') { setLlmConfig({ ...llmConfig, OPENAI_API_KEY: apiKey }); } else if (llmConfig.LLM === 'google') { setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: apiKey }); } else if (llmConfig.LLM === 'ollama') { - setLlmConfig({ ...llmConfig, PEXELS_API_KEY: apiKey }); + if (field === 'pexels') { + setLlmConfig({ ...llmConfig, PEXELS_API_KEY: apiKey }); + } else if (field === 'ollama_url') { + setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: apiKey }); + } else if (field === 'ollama_api_key') { + setLlmConfig({ ...llmConfig, LLM_API_KEY: apiKey }); + } } } @@ -87,7 +111,7 @@ const SettingsPage = () => { } } try { - await handleSaveLLMConfig(llmConfig); + await handleSaveLLMConfig(llmConfig, useCustomOllamaUrl); toast({ title: 'Success', description: 'Configuration saved successfully', @@ -116,7 +140,7 @@ const SettingsPage = () => { return new Promise((resolve, reject) => { const interval = setInterval(async () => { try { - const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`); + const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`); if (response.status === 200) { const data = await response.json(); @@ -163,6 +187,14 @@ const SettingsPage = () => { } }, [userConfigState.llm_config.LLM]); + useEffect(() => { + if (!useCustomOllamaUrl) { + setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: undefined, LLM_API_KEY: undefined }); + } else { + setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', LLM_API_KEY: '' }); + } + }, [useCustomOllamaUrl]); + if (!canChangeKeys) { return null; } @@ -262,70 +294,90 @@ const SettingsPage = () => {
{ollamaModels.length > 0 ? ( - + + ))} + + + + + ) : (
@@ -344,6 +396,62 @@ const SettingsPage = () => {

)}
+ + {/* Custom Ollama URL Configuration */} +
+
+ + +
+ {useCustomOllamaUrl && ( + <> +
+ +
+ api_key_changed(e.target.value, 'ollama_url')} + /> +
+

+ + Change this if you are using a custom Ollama instance +

+
+
+ +
+ api_key_changed(e.target.value, 'ollama_api_key')} + /> +
+

+ + Provide this if you are using a custom Ollama instance +

+
+ + )} +
+
) : ( - !llmConfig.OLLAMA_MODEL ? 'Select Model' : 'Save' + !llmConfig.MODEL ? 'Select Model' : 'Save' )}
diff --git a/servers/nextjs/app/storeInitializer.tsx b/servers/nextjs/app/storeInitializer.tsx index 279be8c8..edbb856d 100644 --- a/servers/nextjs/app/storeInitializer.tsx +++ b/servers/nextjs/app/storeInitializer.tsx @@ -41,11 +41,11 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) { llmConfig.LLM = 'openai'; } dispatch(setLLMConfig(llmConfig)); - const isValid = hasValidLLMConfig(llmConfig); + const isValid = hasValidLLMConfig(llmConfig, false); if (isValid) { // Check if the selected Ollama model is pulled if (llmConfig.LLM === 'ollama') { - const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.OLLAMA_MODEL); + const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.MODEL); if (!isPulled) { router.push('/'); setLoadingToFalseAfterNavigatingTo('/'); diff --git a/servers/nextjs/components/Home.tsx b/servers/nextjs/components/Home.tsx index e246c772..3e357f5f 100644 --- a/servers/nextjs/components/Home.tsx +++ b/servers/nextjs/components/Home.tsx @@ -2,7 +2,7 @@ import { useState, useEffect } from "react"; import { useRouter } from "next/navigation"; import { toast } from "@/hooks/use-toast"; -import { Info, ExternalLink, PlayCircle, Loader2 } from "lucide-react"; +import { Info, ExternalLink, PlayCircle, Loader2, Check, ChevronsUpDown } from "lucide-react"; import Link from "next/link"; import { Accordion, @@ -14,6 +14,22 @@ import { useSelector } from "react-redux"; import { RootState } from "@/store/store"; import { handleSaveLLMConfig } from "@/utils/storeHelpers"; import { Select, SelectContent, SelectItem, SelectTrigger } from "./ui/select"; +import { Button } from "./ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "./ui/command"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "./ui/popover"; +import { cn } from "@/lib/utils"; +import { Switch } from "./ui/switch"; interface ModelOption { value: string; @@ -171,16 +187,24 @@ export default function Home() { done: false, }); const [isLoading, setIsLoading] = useState(false); + const [openModelSelect, setOpenModelSelect] = useState(false); + const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState(false); const canChangeKeys = config.can_change_keys; - const api_key_changed = (newApiKey: string) => { + const api_key_changed = (newApiKey: string, field?: string) => { if (llmConfig.LLM === 'openai') { setLlmConfig({ ...llmConfig, OPENAI_API_KEY: newApiKey }); } else if (llmConfig.LLM === 'google') { setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: newApiKey }); } else if (llmConfig.LLM === 'ollama') { - setLlmConfig({ ...llmConfig, PEXELS_API_KEY: newApiKey }); + if (field === 'pexels') { + setLlmConfig({ ...llmConfig, PEXELS_API_KEY: newApiKey }); + } else if (field === 'ollama_url') { + setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: newApiKey }); + } else if (field === 'ollama_api_key') { + setLlmConfig({ ...llmConfig, LLM_API_KEY: newApiKey }); + } } } @@ -205,7 +229,7 @@ export default function Home() { } } try { - await handleSaveLLMConfig(llmConfig); + await handleSaveLLMConfig(llmConfig, useCustomOllamaUrl); toast({ title: 'Success', description: 'Configuration saved successfully', @@ -234,7 +258,7 @@ export default function Home() { return new Promise((resolve, reject) => { const interval = setInterval(async () => { try { - const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`); + const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`); if (response.status === 200) { const data = await response.json(); @@ -277,6 +301,15 @@ export default function Home() { } }, []); + useEffect(() => { + if (!useCustomOllamaUrl) { + setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: undefined, LLM_API_KEY: undefined }); + } else { + setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', LLM_API_KEY: '' }); + } + }, [useCustomOllamaUrl]); + + if (!canChangeKeys) { return null; } @@ -355,70 +388,90 @@ export default function Home() {
{ollamaModels.length > 0 ? ( - + + ))} + + + + + ) : (
@@ -437,6 +490,59 @@ export default function Home() {

)}
+
+
+ + +
+ {useCustomOllamaUrl && ( + <> +
+ +
+ api_key_changed(e.target.value, 'ollama_url')} + /> +
+

+ + Change this if you are using a custom Ollama instance +

+
+
+ +
+ api_key_changed(e.target.value, 'ollama_api_key')} + /> +
+

+ + Provide this if you are using a custom Ollama instance +

+
+ + )} +

@@ -468,7 +574,7 @@ export default function Home() { Selected Models

- Using {llmConfig.LLM === 'ollama' ? llmConfig.OLLAMA_MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text + Using {llmConfig.LLM === 'ollama' ? llmConfig.MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text generation and {PROVIDER_CONFIGS[llmConfig.LLM!].imageModels[0].label} for images

@@ -544,7 +650,7 @@ export default function Home() { }
) : ( - llmConfig.LLM === 'ollama' && !llmConfig.OLLAMA_MODEL + llmConfig.LLM === 'ollama' && !llmConfig.MODEL ? 'Please Select a Model' : 'Save Configuration' )} diff --git a/servers/nextjs/types/global.d.ts b/servers/nextjs/types/global.d.ts index 1d52650a..d3269c53 100644 --- a/servers/nextjs/types/global.d.ts +++ b/servers/nextjs/types/global.d.ts @@ -18,5 +18,7 @@ interface LLMConfig { OPENAI_API_KEY?: string; GOOGLE_API_KEY?: string; PEXELS_API_KEY?: string; - OLLAMA_MODEL?: string; + LLM_PROVIDER_URL?: string; + LLM_API_KEY?: string; + MODEL?: string; } \ No newline at end of file diff --git a/servers/nextjs/utils/storeHelpers.ts b/servers/nextjs/utils/storeHelpers.ts index ccd37a51..dbb13274 100644 --- a/servers/nextjs/utils/storeHelpers.ts +++ b/servers/nextjs/utils/storeHelpers.ts @@ -1,8 +1,8 @@ import { setLLMConfig } from "@/store/slices/userConfig"; import { store } from "@/store/store"; -export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => { - if (!hasValidLLMConfig(llmConfig)) { +export const handleSaveLLMConfig = async (llmConfig: LLMConfig, useCustomOllamaUrl: boolean) => { + if (!hasValidLLMConfig(llmConfig, useCustomOllamaUrl)) { throw new Error('API key cannot be empty'); } @@ -14,17 +14,21 @@ export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => { store.dispatch(setLLMConfig(llmConfig)); } -export const hasValidLLMConfig = (llmConfig: LLMConfig) => { +export const hasValidLLMConfig = (llmConfig: LLMConfig, useCustomOllamaUrl: boolean) => { if (!llmConfig.LLM) return false; const OPENAI_API_KEY = llmConfig.OPENAI_API_KEY; const GOOGLE_API_KEY = llmConfig.GOOGLE_API_KEY; - const OLLAMA_MODEL = llmConfig.OLLAMA_MODEL; + const MODEL = llmConfig.MODEL; const PEXELS_API_KEY = llmConfig.PEXELS_API_KEY; + + const isOllamaBaseConfigValid = PEXELS_API_KEY !== '' && PEXELS_API_KEY !== null && PEXELS_API_KEY !== undefined && MODEL !== '' && MODEL !== null && MODEL !== undefined; + return llmConfig.LLM === 'openai' ? OPENAI_API_KEY !== '' && OPENAI_API_KEY !== null && OPENAI_API_KEY !== undefined : llmConfig.LLM === 'google' ? GOOGLE_API_KEY !== '' && GOOGLE_API_KEY !== null && GOOGLE_API_KEY !== undefined : llmConfig.LLM === 'ollama' ? - PEXELS_API_KEY !== '' && PEXELS_API_KEY !== null && PEXELS_API_KEY !== undefined && OLLAMA_MODEL !== '' && OLLAMA_MODEL !== null && OLLAMA_MODEL !== undefined : - false; + useCustomOllamaUrl ? + isOllamaBaseConfigValid && llmConfig.LLM_PROVIDER_URL !== '' && llmConfig.LLM_PROVIDER_URL !== null && llmConfig.LLM_PROVIDER_URL !== undefined && llmConfig.LLM_API_KEY !== '' && llmConfig.LLM_API_KEY !== null && llmConfig.LLM_API_KEY !== undefined : + isOllamaBaseConfigValid : false; } \ No newline at end of file