Merge pull request #52 from presenton/removes_langchain
feat(ollama): adds support for custom ollama url, refactor: removes langchain
This commit is contained in:
commit
2719ea4e3f
45 changed files with 584731 additions and 592377 deletions
|
|
@ -7,4 +7,5 @@ build
|
|||
.git
|
||||
.gitignore
|
||||
tmp
|
||||
debug
|
||||
debug
|
||||
.fastembed_cache
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -8,4 +8,5 @@ node_modules
|
|||
out
|
||||
user_data
|
||||
tmp
|
||||
debug
|
||||
debug
|
||||
.fastembed_cache
|
||||
|
|
@ -8,16 +8,17 @@ from contextlib import asynccontextmanager
|
|||
from api.routers.presentation.router import presentation_router
|
||||
from api.services.database import sql_engine
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from api.utils.utils import is_ollama_selected, update_env_with_user_config
|
||||
from api.utils.utils import update_env_with_user_config
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
|
||||
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
|
||||
|
||||
# Ollama model download
|
||||
if not can_change_keys and is_ollama_selected():
|
||||
ollama_model = os.getenv("OLLAMA_MODEL")
|
||||
ollama_model = os.getenv("MODEL")
|
||||
pexels_api_key = os.getenv("PEXELS_API_KEY")
|
||||
if not (ollama_model or pexels_api_key):
|
||||
raise Exception("OLLAMA_MODEL and PEXELS_API_KEY must be provided")
|
||||
raise Exception("MODEL and PEXELS_API_KEY must be provided")
|
||||
|
||||
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise Exception(f"Model {ollama_model} is not supported")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
import json
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -62,7 +63,9 @@ class UserConfig(BaseModel):
|
|||
LLM: Optional[str] = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
OLLAMA_MODEL: Optional[str] = None
|
||||
MODEL: Optional[str] = None
|
||||
LLM_PROVIDER_URL: Optional[str] = None
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -73,3 +76,10 @@ class OllamaModelMetadata(BaseModel):
|
|||
icon: str
|
||||
size: str
|
||||
supports_graph: bool
|
||||
|
||||
|
||||
class SelectedLLMProvider(Enum):
|
||||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
CUSTOM = "custom"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
|
|
@ -37,7 +35,7 @@ class DecomposeDocumentsHandler:
|
|||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{str(uuid.uuid4())}.txt", self.temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.page_content.replace("<br>", "\n")
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
document_paths.append(file_path)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|||
from api.utils.utils import (
|
||||
get_presentation_dir,
|
||||
get_presentation_images_dir,
|
||||
is_ollama_selected,
|
||||
)
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
|
|
@ -69,7 +69,7 @@ class PresentationEditHandler:
|
|||
new_slide_type = new_slide_type.slide_type
|
||||
|
||||
if is_ollama_selected():
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")]
|
||||
if not model.supports_graph:
|
||||
if new_slide_type == 5:
|
||||
new_slide_type = 1
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from api.routers.presentation.models import PresentationGenerateRequest
|
|||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.utils import is_ollama_selected
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel
|
||||
from ppt_config_generator.structure_generator import generate_presentation_structure
|
||||
|
||||
|
|
@ -54,7 +54,7 @@ class PresentationGenerateDataHandler:
|
|||
)
|
||||
)
|
||||
supports_graph = True
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")]
|
||||
supports_graph = model.supports_graph
|
||||
|
||||
for each in presentation_structure.slides:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import uuid
|
||||
import re
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GenerateOutlinesRequest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import List
|
||||
import uuid, aiohttp
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -17,7 +18,8 @@ from api.services.database import get_sql_session
|
|||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir, is_ollama_selected
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
|
|
@ -25,14 +27,9 @@ from ppt_config_generator.ppt_outlines_generator import generate_ppt_content
|
|||
from ppt_generator.generator import generate_presentation
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
LLMPresentationModel,
|
||||
)
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
|
||||
|
||||
|
||||
class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
|
|
@ -79,19 +76,17 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
|||
|
||||
print("-" * 40)
|
||||
print("Generating Presentation")
|
||||
presentation_text = (
|
||||
await generate_presentation(
|
||||
PresentationMarkdownModel(
|
||||
title=presentation_content.title,
|
||||
slides=presentation_content.slides,
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
presentation_text = await generate_presentation(
|
||||
PresentationMarkdownModel(
|
||||
title=presentation_content.title,
|
||||
slides=presentation_content.slides,
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
).content
|
||||
)
|
||||
|
||||
print("-" * 40)
|
||||
print("Parsing Presentation")
|
||||
presentation_json = output_parser.parse(presentation_text)
|
||||
presentation_json = json.loads(presentation_text)
|
||||
|
||||
slide_models: List[SlideModel] = []
|
||||
for i, slide in enumerate(presentation_json["slides"]):
|
||||
|
|
|
|||
|
|
@ -17,13 +17,13 @@ from api.routers.presentation.models import (
|
|||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir, is_ollama_selected
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from ppt_config_generator.models import (
|
||||
PresentationMarkdownModel,
|
||||
PresentationStructureModel,
|
||||
)
|
||||
from ppt_generator.generator import generate_presentation_stream
|
||||
from ppt_generator.models.content_type_models import CONTENT_TYPE_MAPPING
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
LLMPresentationModel,
|
||||
|
|
@ -31,12 +31,9 @@ from ppt_generator.models.llm_models import (
|
|||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
|
||||
from ppt_generator.slide_generator import get_slide_content_from_type_and_outline
|
||||
|
||||
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
|
||||
|
||||
|
||||
class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
|
|
@ -144,20 +141,21 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin
|
|||
|
||||
async def generate_presentation_openai_google(self):
|
||||
presentation_text = ""
|
||||
async for chunk in generate_presentation_stream(
|
||||
async for event in await generate_presentation_stream(
|
||||
PresentationMarkdownModel(
|
||||
title=self.title,
|
||||
slides=self.outlines,
|
||||
notes=self.presentation.notes,
|
||||
)
|
||||
):
|
||||
presentation_text += chunk.content
|
||||
chunk = event.choices[0].delta.content
|
||||
presentation_text += chunk
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk.content}),
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = output_parser.parse(presentation_text)
|
||||
self.presentation_json = json.loads(presentation_text)
|
||||
|
||||
async def generate_presentation_ollama(self):
|
||||
presentation_structure = PresentationStructureModel(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import ollama
|
||||
import os
|
||||
import aiohttp
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import OllamaModelStatusResponse
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.model_utils import get_llm_api_key_or, get_llm_provider_url_or
|
||||
|
||||
|
||||
class ListPulledOllamaModelsHandler:
|
||||
|
|
@ -12,20 +14,25 @@ class ListPulledOllamaModelsHandler:
|
|||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
response = ollama.list()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{get_llm_provider_url_or()}/api/tags",
|
||||
headers={"Authorization": f"Bearer {get_llm_api_key_or()}"},
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
logging_service.message(response_data),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return [
|
||||
OllamaModelStatusResponse(
|
||||
name=model.model,
|
||||
size=model.size,
|
||||
name=model["model"],
|
||||
size=model["size"],
|
||||
status="pulled",
|
||||
downloaded=model.size,
|
||||
downloaded=model["size"],
|
||||
done=True,
|
||||
)
|
||||
for model in response.models
|
||||
for model in response_data["models"]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import asyncio
|
||||
import json
|
||||
import aiohttp
|
||||
from fastapi import BackgroundTasks, HTTPException
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
|
|
@ -8,7 +8,7 @@ from api.routers.presentation.handlers.list_supported_ollama_models import (
|
|||
from api.routers.presentation.models import OllamaModelStatusResponse
|
||||
from api.services.instances import REDIS_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
import ollama
|
||||
from api.utils.model_utils import get_llm_api_key_or, get_llm_provider_url_or
|
||||
|
||||
|
||||
class PullOllamaModelHandler:
|
||||
|
|
@ -33,19 +33,34 @@ class PullOllamaModelHandler:
|
|||
detail=f"Model {self.name} is not supported",
|
||||
)
|
||||
|
||||
pulled_models = ollama.list().models
|
||||
filtered_models = list(
|
||||
filter(lambda model: model.model == self.name, pulled_models)
|
||||
)
|
||||
# Check if model is already pulled using LLM_PROVIDER_URL/api/tags
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{get_llm_provider_url_or()}/api/tags",
|
||||
headers={"Authorization": f"Bearer {get_llm_api_key_or()}"},
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
pulled_models = await response.json()
|
||||
filtered_models = [
|
||||
model
|
||||
for model in pulled_models["models"]
|
||||
if model["model"] == self.name
|
||||
]
|
||||
|
||||
# If the model is already pulled, return the model
|
||||
if filtered_models:
|
||||
return OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
size=filtered_models[0].size,
|
||||
status="pulled",
|
||||
downloaded=filtered_models[0].size,
|
||||
done=True,
|
||||
# If the model is already pulled, return the model
|
||||
if filtered_models:
|
||||
return OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
size=filtered_models[0]["size"],
|
||||
status="pulled",
|
||||
downloaded=filtered_models[0]["size"],
|
||||
done=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logging_service.logger.warning(
|
||||
f"Failed to check pulled models: {e}",
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{self.name}")
|
||||
|
|
@ -64,33 +79,64 @@ class PullOllamaModelHandler:
|
|||
)
|
||||
|
||||
async def pull_model_in_background(self):
|
||||
await asyncio.to_thread(self.pull_model)
|
||||
await self.pull_model()
|
||||
|
||||
def pull_model(self):
|
||||
async def pull_model(self):
|
||||
saved_model_status = OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
log_event_count = 0
|
||||
for event in ollama.pull(self.name, stream=True):
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
if event.completed:
|
||||
saved_model_status.downloaded = event.completed
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{get_llm_provider_url_or()}/api/pull",
|
||||
json={"model": self.name},
|
||||
headers={"Authorization": f"Bearer {get_llm_api_key_or()}"},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to pull model: {await response.text()}",
|
||||
)
|
||||
|
||||
if not saved_model_status.size and event.total:
|
||||
saved_model_status.size = event.total
|
||||
async for line in response.content:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
if event.status:
|
||||
saved_model_status.status = event.status
|
||||
try:
|
||||
event = json.loads(line.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
raise e
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
|
|
|
|||
|
|
@ -82,6 +82,9 @@ from api.routers.presentation.models import (
|
|||
)
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.utils.utils import handle_errors
|
||||
from image_processor.images_finder import (
|
||||
generate_image_google,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
route_prefix = "/api/v1/ppt"
|
||||
|
|
|
|||
84
servers/fastapi/api/utils/model_utils.py
Normal file
84
servers/fastapi/api/utils/model_utils.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from api.models import SelectedLLMProvider
|
||||
|
||||
|
||||
def is_ollama_selected() -> bool:
|
||||
return get_selected_llm_provider() == SelectedLLMProvider.OLLAMA
|
||||
|
||||
|
||||
def get_llm_provider_url_or():
|
||||
return os.getenv("LLM_PROVIDER_URL") or "http://localhost:11434"
|
||||
|
||||
|
||||
def get_llm_api_key_or():
|
||||
return os.getenv("LLM_API_KEY") or "ollama"
|
||||
|
||||
|
||||
def get_selected_llm_provider() -> SelectedLLMProvider:
|
||||
return SelectedLLMProvider(os.getenv("LLM"))
|
||||
|
||||
|
||||
def get_model_base_url():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.path.join(get_llm_provider_url_or(), "v1")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM provider")
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return get_llm_api_key_or()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM API key")
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
client = AsyncOpenAI(
|
||||
base_url=get_model_base_url(),
|
||||
api_key=get_llm_api_key(),
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
else:
|
||||
return os.getenv("MODEL")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
else:
|
||||
return os.getenv("MODEL")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
else:
|
||||
return os.getenv("MODEL")
|
||||
|
|
@ -9,48 +9,11 @@ from typing import List, Optional
|
|||
import aiohttp
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from api.models import LogMetadata, UserConfig
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
def is_ollama_selected() -> bool:
|
||||
return os.getenv("LLM") == "ollama"
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = os.getenv("LLM")
|
||||
if selected_llm == "openai":
|
||||
return ChatOpenAI(model="gpt-4.1")
|
||||
elif selected_llm == "google":
|
||||
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
else:
|
||||
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = os.getenv("LLM")
|
||||
if selected_llm == "openai":
|
||||
return ChatOpenAI(model="gpt-4.1-mini")
|
||||
elif selected_llm == "google":
|
||||
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
else:
|
||||
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = os.getenv("LLM")
|
||||
if selected_llm == "openai":
|
||||
return ChatOpenAI(model="gpt-4.1-nano")
|
||||
elif selected_llm == "google":
|
||||
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
else:
|
||||
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
|
||||
|
||||
|
||||
def get_presentation_dir(presentation_id: str) -> str:
|
||||
presentation_dir = os.path.join(os.getenv("APP_DATA_DIRECTORY"), presentation_id)
|
||||
os.makedirs(presentation_dir, exist_ok=True)
|
||||
|
|
@ -81,8 +44,11 @@ def get_user_config():
|
|||
LLM=existing_config.LLM or os.getenv("LLM"),
|
||||
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY"),
|
||||
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY"),
|
||||
OLLAMA_MODEL=existing_config.OLLAMA_MODEL or os.getenv("OLLAMA_MODEL"),
|
||||
MODEL=existing_config.MODEL or os.getenv("MODEL"),
|
||||
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or os.getenv("PEXELS_API_KEY"),
|
||||
LLM_PROVIDER_URL=existing_config.LLM_PROVIDER_URL
|
||||
or os.getenv("LLM_PROVIDER_URL"),
|
||||
LLM_API_KEY=existing_config.LLM_API_KEY or os.getenv("LLM_API_KEY"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -94,10 +60,14 @@ def update_env_with_user_config():
|
|||
os.environ["OPENAI_API_KEY"] = user_config.OPENAI_API_KEY
|
||||
if user_config.GOOGLE_API_KEY:
|
||||
os.environ["GOOGLE_API_KEY"] = user_config.GOOGLE_API_KEY
|
||||
if user_config.OLLAMA_MODEL:
|
||||
os.environ["OLLAMA_MODEL"] = user_config.OLLAMA_MODEL
|
||||
if user_config.MODEL:
|
||||
os.environ["MODEL"] = user_config.MODEL
|
||||
if user_config.PEXELS_API_KEY:
|
||||
os.environ["PEXELS_API_KEY"] = user_config.PEXELS_API_KEY
|
||||
if user_config.LLM_PROVIDER_URL:
|
||||
os.environ["LLM_PROVIDER_URL"] = user_config.LLM_PROVIDER_URL
|
||||
if user_config.LLM_API_KEY:
|
||||
os.environ["LLM_API_KEY"] = user_config.LLM_API_KEY
|
||||
|
||||
|
||||
def get_resource(relative_path):
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,11 +1,10 @@
|
|||
import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from fastapi import HTTPException
|
||||
from langchain_community.document_loaders import TextLoader, PDFPlumberLoader
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import CharacterTextSplitter, MarkdownTextSplitter
|
||||
from pptx import Presentation
|
||||
import pdfplumber
|
||||
from docx import Document as DocxDocument
|
||||
|
||||
from image_processor.utils import get_page_images_from_pdf_async
|
||||
|
|
@ -30,23 +29,13 @@ class DocumentsLoader:
|
|||
def __init__(self, documents: List[str]):
|
||||
self._document_paths = documents
|
||||
|
||||
self._documents: List[Document] = []
|
||||
self._splitted_documents: List[Document] = []
|
||||
self._documents: List[str] = []
|
||||
self._images: List[List[str]] = []
|
||||
|
||||
self._markdown_splitter = MarkdownTextSplitter(chunk_size=500, chunk_overlap=50)
|
||||
self._text_splitter = CharacterTextSplitter(
|
||||
separator="/n", chunk_size=500, chunk_overlap=50
|
||||
)
|
||||
|
||||
@property
|
||||
def documents(self):
|
||||
return self._documents
|
||||
|
||||
@property
|
||||
def splitted_documents(self):
|
||||
return self._splitted_documents
|
||||
|
||||
@property
|
||||
def images(self):
|
||||
return self._images
|
||||
|
|
@ -54,90 +43,69 @@ class DocumentsLoader:
|
|||
async def load_documents(
|
||||
self,
|
||||
temp_dir: str,
|
||||
split_documents: bool = False,
|
||||
load_markdown: bool = True,
|
||||
load_text: bool = True,
|
||||
load_images: bool = False,
|
||||
):
|
||||
documents: List[Document] = []
|
||||
documents: List[str] = []
|
||||
images: List[str] = []
|
||||
|
||||
splitted_documents: List[Document] = []
|
||||
for file_path in self._document_paths:
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"File {file_path} not found"
|
||||
)
|
||||
|
||||
docs = []
|
||||
document = ""
|
||||
imgs = []
|
||||
|
||||
mime_type = mimetypes.guess_type(file_path)[0]
|
||||
if mime_type in PDF_MIME_TYPES:
|
||||
docs, imgs = await self.load_pdf(
|
||||
file_path, load_markdown, load_images, temp_dir
|
||||
document, imgs = await self.load_pdf(
|
||||
file_path, load_text, load_images, temp_dir
|
||||
)
|
||||
elif mime_type in TEXT_MIME_TYPES:
|
||||
docs = self.load_text(file_path)
|
||||
document = await self.load_text(file_path)
|
||||
elif mime_type in POWERPOINT_TYPES:
|
||||
docs = self.load_powerpoint(file_path)
|
||||
document = self.load_powerpoint(file_path)
|
||||
elif mime_type in WORD_TYPES:
|
||||
docs = self.load_msword(file_path)
|
||||
document = self.load_msword(file_path)
|
||||
|
||||
documents.extend(docs)
|
||||
documents.append(document)
|
||||
images.append(imgs)
|
||||
|
||||
if split_documents:
|
||||
splitted_documents.extend(self.split_documents(docs, mime_type))
|
||||
|
||||
self._documents = documents
|
||||
self._splitted_documents = splitted_documents
|
||||
self._images = images
|
||||
|
||||
def split_documents(self, documents: List[Document], mime_type):
|
||||
return self._text_splitter.split_documents(documents)
|
||||
|
||||
def clip_longer_documents(self, documents: List[Document], clip_after: int = 1200):
|
||||
for document in documents:
|
||||
document.page_content = document.page_content[:clip_after]
|
||||
return documents
|
||||
|
||||
async def load_pdf(
|
||||
self,
|
||||
file_path: str,
|
||||
load_markdown: bool,
|
||||
load_text: bool,
|
||||
load_images: bool,
|
||||
temp_dir: str,
|
||||
) -> Tuple[List[Document], List[str]]:
|
||||
) -> Tuple[str, List[str]]:
|
||||
image_paths = []
|
||||
documents: List[Document] = []
|
||||
document: str = ""
|
||||
|
||||
if load_markdown:
|
||||
loader = PDFPlumberLoader(file_path)
|
||||
docs = loader.load()
|
||||
pdf_document = Document(page_content="")
|
||||
pdf_document.metadata = docs[0].metadata
|
||||
for doc in docs:
|
||||
pdf_document.page_content += doc.page_content
|
||||
documents.append(pdf_document)
|
||||
if load_text:
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for page in pdf.pages:
|
||||
document += await asyncio.to_thread(page.extract_text)
|
||||
|
||||
if load_images:
|
||||
image_paths = await get_page_images_from_pdf_async(file_path, temp_dir)
|
||||
|
||||
return documents, image_paths
|
||||
return document, image_paths
|
||||
|
||||
async def decompose_pdf_to_markdown(self, document_path: str) -> str:
|
||||
raise Exception("Not Implemented")
|
||||
async def load_text(self, file_path: str) -> str:
|
||||
with open(file_path, "r") as file:
|
||||
return await asyncio.to_thread(file.read)
|
||||
|
||||
def load_text(self, file_path: str) -> List[Document]:
|
||||
loader = TextLoader(file_path)
|
||||
return loader.load()
|
||||
|
||||
def load_msword(self, file_path: str) -> List[Document]:
|
||||
def load_msword(self, file_path: str) -> str:
|
||||
document = DocxDocument(file_path)
|
||||
text = "\n".join([paragraph.text for paragraph in document.paragraphs])
|
||||
return [Document(page_content=text)]
|
||||
return text
|
||||
|
||||
def load_powerpoint(self, file_path: str) -> List[Document]:
|
||||
def load_powerpoint(self, file_path: str) -> str:
|
||||
presentation = Presentation(file_path)
|
||||
|
||||
extracted_text = ""
|
||||
|
|
@ -149,4 +117,4 @@ class DocumentsLoader:
|
|||
extracted_text += f"{paragraph.text}\n"
|
||||
extracted_text += "\n"
|
||||
extracted_text += "\n\n"
|
||||
return [Document(page_content=extracted_text)]
|
||||
return extracted_text
|
||||
|
|
|
|||
|
|
@ -1,132 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from graph_processor.utils import clip_text
|
||||
|
||||
|
||||
class PointModel(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
def to_list(self) -> List[float]:
|
||||
return [self.x, self.y]
|
||||
|
||||
|
||||
class PointWithRadius(PointModel):
|
||||
radius: Optional[float] = None
|
||||
|
||||
|
||||
class BarSeriesModel(BaseModel):
|
||||
name: str
|
||||
data: List[float] = Field(
|
||||
description="Only numbers should be given out in data. Don't include text/string in data."
|
||||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
return clip_text(self.name)
|
||||
|
||||
|
||||
class ScatterSeriesModel(BaseModel):
|
||||
name: str
|
||||
points: List[PointModel]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return clip_text(self.name)
|
||||
|
||||
|
||||
class BubbleSeriesModel(BaseModel):
|
||||
name: str
|
||||
points: List[PointWithRadius]
|
||||
|
||||
def get_name(self) -> str:
|
||||
return clip_text(self.name)
|
||||
|
||||
|
||||
class LineSeriesModel(BaseModel):
|
||||
name: str
|
||||
data: List[float] = Field(
|
||||
description="Only numbers should be given out in data. Don't include text/string in data."
|
||||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
return clip_text(self.name)
|
||||
|
||||
|
||||
class PieChartSeriesModel(BaseModel):
|
||||
data: List[float]
|
||||
|
||||
|
||||
class BarGraphDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[BarSeriesModel] = Field(
|
||||
description="There should be no more than 3 series"
|
||||
)
|
||||
|
||||
def get_categories(self) -> List[str]:
|
||||
return [clip_text(category) for category in self.categories]
|
||||
|
||||
|
||||
class ScatterChartDataModel(BaseModel):
|
||||
series: List[ScatterSeriesModel]
|
||||
|
||||
|
||||
class BubbleChartDataModel(BaseModel):
|
||||
series: List[BubbleSeriesModel]
|
||||
|
||||
|
||||
class LineChartDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[LineSeriesModel] = Field(
|
||||
description="There should be no more than 3 series"
|
||||
)
|
||||
|
||||
def get_categories(self) -> List[str]:
|
||||
return [clip_text(category) for category in self.categories]
|
||||
|
||||
|
||||
class PieChartDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[PieChartSeriesModel] = Field(
|
||||
description="One series model with list of data",
|
||||
min_length=1,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def limit_series(self):
|
||||
self.series = self.series[:1]
|
||||
return self
|
||||
|
||||
def get_categories(self) -> List[str]:
|
||||
return [clip_text(category) for category in self.categories]
|
||||
|
||||
|
||||
# class TableDataModel(BaseModel):
|
||||
# categories: List[str]
|
||||
# series: List[BarSeriesModel]
|
||||
|
||||
# def get_categories(self) -> List[str]:
|
||||
# return [clip_text(category) for category in self.categories]
|
||||
|
||||
|
||||
class GraphTypeEnum(Enum):
|
||||
pie = "pie"
|
||||
bar = "bar"
|
||||
line = "line"
|
||||
|
||||
|
||||
class GraphModel(BaseModel):
|
||||
style: Optional[dict] = {}
|
||||
name: str
|
||||
type: GraphTypeEnum
|
||||
unit: Optional[str] = Field(
|
||||
description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc."
|
||||
)
|
||||
data: PieChartDataModel | LineChartDataModel | BarGraphDataModel
|
||||
|
||||
|
||||
GRAPH_TYPE_MAPPING = {
|
||||
GraphTypeEnum.pie: PieChartDataModel,
|
||||
GraphTypeEnum.bar: BarGraphDataModel,
|
||||
GraphTypeEnum.line: LineChartDataModel,
|
||||
}
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
def clip_text(text: str, max_length: int = 6) -> str:
|
||||
# return text[:max_length] + ".." if len(text) > max_length else text
|
||||
return text
|
||||
|
|
@ -5,17 +5,17 @@ from ppt_generator.models.query_and_prompt_models import (
|
|||
IconCategoryEnum,
|
||||
IconQueryCollectionWithData,
|
||||
)
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from fastembed_vectorstore import FastembedVectorstore
|
||||
|
||||
|
||||
async def get_icon(
|
||||
vector_store: InMemoryVectorStore,
|
||||
vector_store: FastembedVectorstore,
|
||||
input: IconQueryCollectionWithData,
|
||||
) -> str:
|
||||
try:
|
||||
query = input.icon_query
|
||||
results = vector_store.similarity_search(query=query, k=1)
|
||||
icon_name = results[0].page_content
|
||||
results = vector_store.search(query, 1)
|
||||
icon_name = results[0][0].split("||")[0]
|
||||
return get_resource(f"assets/icons/bold/{icon_name}.png")
|
||||
except Exception as e:
|
||||
print("Error finding icon: ", e)
|
||||
|
|
@ -23,7 +23,7 @@ async def get_icon(
|
|||
|
||||
|
||||
async def get_icons(
|
||||
vector_store: InMemoryVectorStore,
|
||||
vector_store: FastembedVectorstore,
|
||||
query: str,
|
||||
page: int,
|
||||
limit: int,
|
||||
|
|
@ -31,7 +31,7 @@ async def get_icons(
|
|||
temp_dir: str,
|
||||
) -> List[str]:
|
||||
|
||||
results = await vector_store.asimilarity_search(query=query, k=limit)
|
||||
icon_names = [result.page_content for result in results]
|
||||
results = vector_store.search(query, limit)
|
||||
icon_names = [result[0].split("||")[0] for result in results]
|
||||
|
||||
return [get_resource(f"assets/icons/bold/{each}.png") for each in icon_names]
|
||||
|
|
|
|||
|
|
@ -1,37 +1,26 @@
|
|||
import json
|
||||
import os
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from api.utils.utils import get_resource
|
||||
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
|
||||
|
||||
# Pyinstaller
|
||||
import fastembed
|
||||
from fastembed_vectorstore import FastembedVectorstore, FastembedEmbeddingModel
|
||||
|
||||
|
||||
def get_icons_vectorstore():
|
||||
vector_store_path = get_resource("assets/icons_vectorstore.json")
|
||||
|
||||
embeddings = FastEmbedEmbeddings()
|
||||
embedding_model = FastembedEmbeddingModel.BGESmallENV15
|
||||
|
||||
if os.path.exists(vector_store_path):
|
||||
vector_store = InMemoryVectorStore.load(vector_store_path, embeddings)
|
||||
return vector_store
|
||||
|
||||
vector_store = InMemoryVectorStore(embeddings)
|
||||
|
||||
vector_store.dump(vector_store_path)
|
||||
return FastembedVectorstore.load(embedding_model, vector_store_path)
|
||||
|
||||
vector_store = FastembedVectorstore(embedding_model)
|
||||
with open(get_resource("assets/icons.json"), "r") as f:
|
||||
icons = json.load(f)
|
||||
|
||||
icon_names = [icon["name"] for icon in icons["icons"]]
|
||||
documents = []
|
||||
for each in icon_names:
|
||||
if each.split("-")[-1] == "bold":
|
||||
documents.append(Document(id=each, page_content=each))
|
||||
for each in icons["icons"]:
|
||||
if each["name"].split("-")[-1] == "bold":
|
||||
documents.append(f"{each['name']}||{each['tags']}")
|
||||
|
||||
vector_store.embed_documents(documents)
|
||||
vector_store.save(vector_store_path)
|
||||
|
||||
vector_store.add_documents(documents)
|
||||
vector_store.dump(vector_store_path)
|
||||
return vector_store
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@ import base64
|
|||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from openai import OpenAI
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig
|
||||
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from api.utils.utils import download_file, get_resource, is_ollama_selected
|
||||
from api.utils.utils import download_file, get_resource
|
||||
from api.utils.model_utils import get_llm_client, is_ollama_selected
|
||||
|
||||
|
||||
async def generate_image(
|
||||
|
|
@ -46,7 +47,7 @@ async def generate_image(
|
|||
|
||||
|
||||
async def generate_image_openai(prompt: str, output_directory: str) -> str:
|
||||
client = OpenAI()
|
||||
client = get_llm_client()
|
||||
result = await asyncio.to_thread(
|
||||
client.images.generate,
|
||||
model="dall-e-3",
|
||||
|
|
@ -66,20 +67,20 @@ async def generate_image_openai(prompt: str, output_directory: str) -> str:
|
|||
|
||||
|
||||
async def generate_image_google(prompt: str, output_directory: str) -> str:
|
||||
response = await ChatGoogleGenerativeAI(
|
||||
model="gemini-2.0-flash-preview-image-generation"
|
||||
).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]})
|
||||
|
||||
image_block = next(
|
||||
block
|
||||
for block in response.content
|
||||
if isinstance(block, dict) and block.get("image_url")
|
||||
client = genai.Client()
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash-preview-image-generation",
|
||||
contents=[prompt],
|
||||
config=GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]),
|
||||
)
|
||||
|
||||
base64_image = image_block["image_url"].get("url").split(",")[-1]
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(base64.b64decode(base64_image))
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text is not None:
|
||||
print(part.text)
|
||||
elif part.inline_data is not None:
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(part.inline_data.data)
|
||||
|
||||
return image_path
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,8 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from api.utils.utils import get_nano_model
|
||||
from api.utils.model_utils import get_llm_client, get_nano_model
|
||||
|
||||
sysmte_prompt = """
|
||||
Generate a blog-style summary of the provided document in **more than 2000 words**.
|
||||
|
|
@ -26,26 +20,25 @@ Maintain as much information as possible.
|
|||
- If **slides structure is mentioned** in document, structure the summary in the same way.
|
||||
"""
|
||||
|
||||
prompt_template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", sysmte_prompt),
|
||||
("user", "{text}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def generate_document_summary(documents: List[Document]):
|
||||
async def generate_document_summary(documents: List[str]):
|
||||
client = get_llm_client()
|
||||
model = get_nano_model()
|
||||
text_splitter = CharacterTextSplitter(chunk_size=200000, chunk_overlap=0)
|
||||
chain = prompt_template | model
|
||||
|
||||
coroutines = []
|
||||
for document in documents:
|
||||
text = document.page_content
|
||||
truncated_text = text_splitter.split_text(text)[0]
|
||||
coroutine = chain.ainvoke({"text": truncated_text})
|
||||
truncated_text = document[:200000]
|
||||
coroutine = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": sysmte_prompt},
|
||||
{"role": "user", "content": truncated_text},
|
||||
],
|
||||
)
|
||||
coroutines.append(coroutine)
|
||||
|
||||
completions: List[BaseMessage] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join([completion.content for completion in completions])
|
||||
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join(
|
||||
[completion.choices[0].message.content for completion in completions]
|
||||
)
|
||||
return combined
|
||||
|
|
|
|||
|
|
@ -1,32 +1,17 @@
|
|||
from typing import Optional
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
from api.utils.utils import get_large_model
|
||||
from api.utils.model_utils import get_large_model, get_llm_client
|
||||
from api.utils.variable_length_models import (
|
||||
get_presentation_markdown_model_with_n_slides,
|
||||
)
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_generator.fix_validation_errors import get_validated_response
|
||||
|
||||
|
||||
user_prompt_text = {
|
||||
"type": "text",
|
||||
"text": """
|
||||
**Input:**
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Number of Slides: {n_slides}
|
||||
- Additional Information: {content}
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
return ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
def get_prompt_template(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
Create a presentation based on the provided prompt, number of slides, output language, and additional informational details.
|
||||
Format the output in the specified JSON schema with structured markdown content.
|
||||
|
||||
|
|
@ -49,13 +34,18 @@ def get_prompt_template():
|
|||
- Slide **title** should not be in markdown format.
|
||||
- There must be exact **Number of Slides** as specified.
|
||||
""",
|
||||
),
|
||||
(
|
||||
"user",
|
||||
[user_prompt_text],
|
||||
),
|
||||
],
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
**Input:**
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Number of Slides: {n_slides}
|
||||
- Additional Information: {content}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def generate_ppt_content(
|
||||
|
|
@ -64,21 +54,14 @@ async def generate_ppt_content(
|
|||
language: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> PresentationMarkdownModel:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
response_model = get_presentation_markdown_model_with_n_slides(n_slides)
|
||||
|
||||
chain = get_prompt_template() | model.with_structured_output(
|
||||
response_model.model_json_schema()
|
||||
)
|
||||
|
||||
return await get_validated_response(
|
||||
chain,
|
||||
{
|
||||
"prompt": prompt,
|
||||
"n_slides": n_slides,
|
||||
"language": language or "English",
|
||||
"content": content,
|
||||
},
|
||||
response_model,
|
||||
PresentationMarkdownModel,
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_template(prompt, n_slides, language, content),
|
||||
response_format=response_model,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
from api.utils.utils import get_small_model
|
||||
from api.utils.model_utils import get_llm_client, get_small_model
|
||||
from api.utils.variable_length_models import (
|
||||
get_presentation_structure_model_with_n_slides,
|
||||
)
|
||||
|
|
@ -8,13 +6,13 @@ from ppt_config_generator.models import (
|
|||
PresentationStructureModel,
|
||||
PresentationMarkdownModel,
|
||||
)
|
||||
from ppt_generator.fix_validation_errors import get_validated_response
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
|
||||
def get_prompt(n_slides: int, data: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
You're a professional presentation designer with years of experience in designing clear and engaging presentations.
|
||||
|
||||
# Slide Types
|
||||
|
|
@ -44,33 +42,32 @@ prompt = ChatPromptTemplate.from_messages(
|
|||
|
||||
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
|
||||
""",
|
||||
),
|
||||
(
|
||||
"human",
|
||||
"""
|
||||
{data}
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
{data}
|
||||
""",
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def generate_presentation_structure(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> PresentationStructureModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
chain = prompt | model.with_structured_output(response_model.model_json_schema())
|
||||
|
||||
return await get_validated_response(
|
||||
chain,
|
||||
{
|
||||
"n_slides": len(presentation_outline.slides),
|
||||
"data": presentation_outline.to_string(),
|
||||
},
|
||||
response_model,
|
||||
PresentationStructureModel,
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt(
|
||||
len(presentation_outline.slides), presentation_outline.to_string()
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from api.utils.utils import get_large_model
|
||||
from api.utils.model_utils import get_large_model
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
|
|
@ -41,7 +40,7 @@ def get_prompt_template():
|
|||
|
||||
|
||||
async def fix_validation_errors(response_model: BaseModel, response, errors):
|
||||
model = get_large_model()
|
||||
model = ChatOllama(model=get_large_model(), temperature=0.8)
|
||||
|
||||
chain = get_prompt_template() | model.with_structured_output(
|
||||
response_model.model_json_schema()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from typing import AsyncIterator
|
||||
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
AIMessageChunk,
|
||||
AIMessage,
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.utils.model_utils import (
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
get_selected_llm_provider,
|
||||
)
|
||||
from api.utils.utils import get_large_model
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_generator.models.llm_models_with_validations import (
|
||||
LLMPresentationModelWithValidation,
|
||||
|
|
@ -70,42 +70,87 @@ CREATE_PRESENTATION_PROMPT = """
|
|||
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
|
||||
"""
|
||||
|
||||
schema = LLMPresentationModelWithValidation.model_json_schema()
|
||||
|
||||
system_prompt = f"""
|
||||
system_prompt_with_schema = f"""
|
||||
{CREATE_PRESENTATION_PROMPT}
|
||||
|
||||
Follow this schema while giving out response: {schema}.
|
||||
|
||||
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
|
||||
"""
|
||||
|
||||
ollama_system_prompt = f"""
|
||||
{CREATE_PRESENTATION_PROMPT}
|
||||
Follow this schema while giving out response: {LLMPresentationModelWithValidation.model_json_schema()}.
|
||||
|
||||
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
|
||||
"""
|
||||
|
||||
|
||||
def get_model_and_messages(
|
||||
def get_system_prompt():
|
||||
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
|
||||
return (
|
||||
system_prompt_with_schema if is_google_selected else CREATE_PRESENTATION_PROMPT
|
||||
)
|
||||
|
||||
|
||||
def get_response_format():
|
||||
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
|
||||
return (
|
||||
{
|
||||
"type": "json_object",
|
||||
}
|
||||
if is_google_selected
|
||||
else {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "LLMPresentationModel",
|
||||
"schema": LLMPresentationModelWithValidation.model_json_schema(),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def generate_presentation_stream(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
):
|
||||
user_message = HumanMessage(presentation_outline.to_string())
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
return model, system_prompt, user_message
|
||||
response_format = get_response_format()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": get_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": presentation_outline.to_string(),
|
||||
},
|
||||
],
|
||||
response_format=response_format,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def generate_presentation_stream(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> AsyncIterator[AIMessageChunk]:
|
||||
model, system_prompt, user_message = get_model_and_messages(presentation_outline)
|
||||
|
||||
return model.astream([system_prompt, user_message])
|
||||
return response
|
||||
|
||||
|
||||
async def generate_presentation(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> AIMessage:
|
||||
model, system_prompt, user_message = get_model_and_messages(presentation_outline)
|
||||
return await model.ainvoke([system_prompt, user_message])
|
||||
) -> str:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
response_format = get_response_format()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": get_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": presentation_outline.to_string(),
|
||||
},
|
||||
],
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Mapping
|
||||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ppt_generator.models.other_models import (
|
||||
|
|
@ -12,7 +12,17 @@ from ppt_generator.models.other_models import (
|
|||
TYPE8,
|
||||
TYPE9,
|
||||
)
|
||||
from graph_processor.models import GraphModel
|
||||
|
||||
|
||||
class TableDataModel(BaseModel):
|
||||
x_labels: List[str]
|
||||
y_labels: List[str]
|
||||
data: List[List[float]]
|
||||
|
||||
|
||||
class TableModel(BaseModel):
|
||||
name: str
|
||||
data: TableDataModel
|
||||
|
||||
|
||||
class HeadingModel(BaseModel):
|
||||
|
|
@ -47,9 +57,6 @@ class HeadingModel(BaseModel):
|
|||
class SlideContentModel(BaseModel):
|
||||
title: str
|
||||
|
||||
def to_llm_content(self):
|
||||
raise NotImplementedError("to_llm_content method not implemented")
|
||||
|
||||
|
||||
class Type1Content(SlideContentModel):
|
||||
body: str
|
||||
|
|
@ -110,7 +117,7 @@ class Type4Content(SlideContentModel):
|
|||
|
||||
class Type5Content(SlideContentModel):
|
||||
body: str
|
||||
graph: GraphModel
|
||||
table: TableModel
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType5Content
|
||||
|
|
@ -118,7 +125,7 @@ class Type5Content(SlideContentModel):
|
|||
return LLMType5Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
graph=self.graph,
|
||||
table=self.table,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -174,7 +181,7 @@ class Type8Content(SlideContentModel):
|
|||
|
||||
class Type9Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
graph: GraphModel
|
||||
table: TableModel
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType9Content
|
||||
|
|
@ -182,11 +189,23 @@ class Type9Content(SlideContentModel):
|
|||
return LLMType9Content(
|
||||
title=self.title,
|
||||
body=[item.to_llm_content() for item in self.body],
|
||||
graph=self.graph,
|
||||
table=self.table,
|
||||
)
|
||||
|
||||
|
||||
CONTENT_TYPE_MAPPING: Mapping[int, SlideContentModel] = {
|
||||
ContentUnion = Union[
|
||||
Type1Content,
|
||||
Type2Content,
|
||||
Type3Content,
|
||||
Type4Content,
|
||||
Type5Content,
|
||||
Type6Content,
|
||||
Type7Content,
|
||||
Type8Content,
|
||||
Type9Content,
|
||||
]
|
||||
|
||||
CONTENT_TYPE_MAPPING: Mapping[int, ContentUnion] = {
|
||||
TYPE1: Type1Content,
|
||||
TYPE2: Type2Content,
|
||||
TYPE3: Type3Content,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import List, Mapping
|
||||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graph_processor.models import GraphModel
|
||||
from ppt_generator.models.content_type_models import (
|
||||
HeadingModel,
|
||||
SlideContentModel,
|
||||
TableDataModel,
|
||||
TableModel,
|
||||
Type1Content,
|
||||
Type2Content,
|
||||
Type3Content,
|
||||
|
|
@ -28,6 +28,17 @@ from ppt_generator.models.other_models import (
|
|||
)
|
||||
|
||||
|
||||
class LLMTableDataModel(TableDataModel):
|
||||
x_labels: List[str]
|
||||
y_labels: List[str]
|
||||
data: List[List[float]]
|
||||
|
||||
|
||||
class LLMTableModel(TableModel):
|
||||
name: str
|
||||
data: LLMTableDataModel
|
||||
|
||||
|
||||
class LLMHeadingModel(BaseModel):
|
||||
heading: str
|
||||
description: str
|
||||
|
|
@ -42,17 +53,26 @@ class LLMHeadingModel(BaseModel):
|
|||
class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithIconQuery(LLMHeadingModel):
|
||||
icon_query: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class LLMSlideContentModel(BaseModel):
|
||||
title: str
|
||||
|
||||
def to_content(self) -> SlideContentModel:
|
||||
raise NotImplementedError("to_content method not implemented")
|
||||
|
||||
|
||||
class LLMType1Content(LLMSlideContentModel):
|
||||
body: str
|
||||
|
|
@ -101,13 +121,13 @@ class LLMType4Content(LLMSlideContentModel):
|
|||
|
||||
class LLMType5Content(LLMSlideContentModel):
|
||||
body: str
|
||||
graph: GraphModel
|
||||
table: LLMTableModel
|
||||
|
||||
def to_content(self) -> Type5Content:
|
||||
return Type5Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
graph=self.graph,
|
||||
table=self.table,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -149,17 +169,29 @@ class LLMType8Content(LLMSlideContentModel):
|
|||
|
||||
class LLMType9Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModel]
|
||||
graph: GraphModel
|
||||
table: LLMTableModel
|
||||
|
||||
def to_content(self) -> Type9Content:
|
||||
return Type9Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
graph=self.graph,
|
||||
table=self.table,
|
||||
)
|
||||
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
|
||||
LLMContentUnion = Union[
|
||||
LLMType1Content,
|
||||
LLMType2Content,
|
||||
LLMType3Content,
|
||||
LLMType4Content,
|
||||
LLMType5Content,
|
||||
LLMType6Content,
|
||||
LLMType7Content,
|
||||
LLMType8Content,
|
||||
LLMType9Content,
|
||||
]
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMContentUnion] = {
|
||||
TYPE1: LLMType1Content,
|
||||
TYPE2: LLMType2Content,
|
||||
TYPE3: LLMType3Content,
|
||||
|
|
@ -174,17 +206,8 @@ LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
|
|||
|
||||
class LLMSlideModel(BaseModel):
|
||||
type: int
|
||||
content: (
|
||||
LLMType1Content
|
||||
| LLMType2Content
|
||||
| LLMType4Content
|
||||
| LLMType5Content
|
||||
| LLMType6Content
|
||||
| LLMType7Content
|
||||
| LLMType8Content
|
||||
| LLMType9Content
|
||||
)
|
||||
content: LLMContentUnion
|
||||
|
||||
|
||||
class LLMPresentationModel(BaseModel):
|
||||
slides: list[LLMSlideModel]
|
||||
slides: List[LLMSlideModel]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import List, Mapping
|
||||
from typing import List, Mapping, Union
|
||||
from pydantic import Field
|
||||
|
||||
from graph_processor.models import GraphModel
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
|
|
@ -14,6 +13,8 @@ from ppt_generator.models.other_models import (
|
|||
TYPE9,
|
||||
)
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLMTableDataModel,
|
||||
LLMTableModel,
|
||||
LLMHeadingModel,
|
||||
LLMHeadingModelWithImagePrompt,
|
||||
LLMHeadingModelWithIconQuery,
|
||||
|
|
@ -32,239 +33,179 @@ from ppt_generator.models.llm_models import (
|
|||
)
|
||||
|
||||
|
||||
class LLMTableDataModelWithValidation(LLMTableDataModel):
|
||||
x_labels: List[str] = Field(
|
||||
description="X labels of the table",
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
)
|
||||
y_labels: List[str] = Field(
|
||||
description="Y labels of the table",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
data: List[List[float]] = Field(
|
||||
description="Data of the table",
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
)
|
||||
|
||||
|
||||
class LLMTableModelWithValidation(LLMTableModel):
|
||||
name: str = Field(
|
||||
description="Name of the table in about 8 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
data: LLMTableDataModelWithValidation
|
||||
|
||||
|
||||
class LLMHeadingModelWithValidation(LLMHeadingModel):
|
||||
heading: str = Field(
|
||||
description="List item heading to show in slide body",
|
||||
description="Item heading in about 6 words",
|
||||
min_length=10,
|
||||
max_length=30,
|
||||
max_length=40,
|
||||
)
|
||||
description: str = Field(
|
||||
description="Description of list item in less than 20 words.",
|
||||
min_length=80,
|
||||
max_length=150,
|
||||
description="Item description in about 12 words.",
|
||||
min_length=50,
|
||||
max_length=120,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
|
||||
image_prompt: str = Field(
|
||||
description="Prompt used to generate image for this item",
|
||||
description="Item image prompt in about 10 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
max_length=100,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
|
||||
icon_query: str = Field(
|
||||
description="Icon query to generate icon for this item",
|
||||
description="Item icon query in about 4 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
max_length=40,
|
||||
)
|
||||
|
||||
|
||||
class LLMSlideContentModelWithValidation(LLMSlideContentModel):
|
||||
title: str = Field(
|
||||
description="Slide title in about 8 words",
|
||||
min_length=10,
|
||||
max_length=80,
|
||||
)
|
||||
|
||||
|
||||
class LLMType1ContentWithValidation(LLMType1Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
min_length=100,
|
||||
max_length=200,
|
||||
description="Slide content summary in about 30 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Prompt used to generate image for this slide.",
|
||||
description="Slide image prompt in about 5 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return ""
|
||||
|
||||
|
||||
class LLMType2ContentWithValidation(LLMType2Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **1 to 4 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
|
||||
|
||||
class LLMType3ContentWithValidation(LLMType3Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Prompt used to generate image for this slide",
|
||||
description="Slide image prompt in about 5 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **3 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
|
||||
|
||||
class LLMType4ContentWithValidation(LLMType4Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **1 to 3 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
|
||||
|
||||
class LLMType5ContentWithValidation(LLMType5Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
min_length=100,
|
||||
max_length=250,
|
||||
description="Slide content summary in about 30 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
graph: GraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
@classmethod
|
||||
def get_notes(self):
|
||||
return ""
|
||||
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
|
||||
|
||||
class LLMType6ContentWithValidation(LLMType6Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
min_length=80,
|
||||
max_length=150,
|
||||
description="Slide content summary in about 20 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **1 to 3 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
|
||||
|
||||
class LLMType7ContentWithValidation(LLMType7Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **1 to 4 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
|
||||
|
||||
class LLMType8ContentWithValidation(LLMType8Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
min_length=80,
|
||||
max_length=150,
|
||||
description="Slide content summary in about 20 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **1 to 3 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
|
||||
|
||||
class LLMType9ContentWithValidation(LLMType9Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
graph: GraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
@classmethod
|
||||
def get_notes(cls):
|
||||
return """
|
||||
- The **Body** should include **1 to 3 HeadingModels**.
|
||||
- Each **Heading** must consist of **1 to 3 words**.
|
||||
- Each item **Description** can be upto 10 words.
|
||||
"""
|
||||
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
|
||||
|
||||
LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING: Mapping[int, LLMSlideContentModel] = {
|
||||
LLMContentUnionWithValidation = Union[
|
||||
LLMType1ContentWithValidation,
|
||||
LLMType2ContentWithValidation,
|
||||
LLMType3ContentWithValidation,
|
||||
LLMType4ContentWithValidation,
|
||||
LLMType5ContentWithValidation,
|
||||
LLMType6ContentWithValidation,
|
||||
LLMType7ContentWithValidation,
|
||||
LLMType8ContentWithValidation,
|
||||
LLMType9ContentWithValidation,
|
||||
]
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION: Mapping[
|
||||
int, LLMContentUnionWithValidation
|
||||
] = {
|
||||
TYPE1: LLMType1ContentWithValidation,
|
||||
TYPE2: LLMType2ContentWithValidation,
|
||||
TYPE3: LLMType3ContentWithValidation,
|
||||
|
|
@ -279,17 +220,8 @@ LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING: Mapping[int, LLMSlideContentModel] = {
|
|||
|
||||
class LLMSlideModelWithValidation(LLMSlideModel):
|
||||
type: int
|
||||
content: (
|
||||
LLMType1ContentWithValidation
|
||||
| LLMType2ContentWithValidation
|
||||
| LLMType4ContentWithValidation
|
||||
| LLMType5ContentWithValidation
|
||||
| LLMType6ContentWithValidation
|
||||
| LLMType7ContentWithValidation
|
||||
| LLMType8ContentWithValidation
|
||||
| LLMType9ContentWithValidation
|
||||
)
|
||||
content: LLMContentUnionWithValidation
|
||||
|
||||
|
||||
class LLMPresentationModelWithValidation(LLMPresentationModel):
|
||||
slides: list[LLMSlideModelWithValidation]
|
||||
slides: List[LLMSlideModelWithValidation]
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ from pptx.util import Pt
|
|||
from pptx.enum.text import PP_ALIGN
|
||||
from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE, MSO_CONNECTOR_TYPE
|
||||
|
||||
from graph_processor.models import GraphModel
|
||||
|
||||
|
||||
class PptxBoxShapeEnum(Enum):
|
||||
RECTANGLE = "rectangle"
|
||||
|
|
@ -138,14 +136,6 @@ class PptxPictureBoxModel(PptxShapeModel):
|
|||
picture: PptxPictureModel
|
||||
|
||||
|
||||
class PptxGraphBoxModel(PptxShapeModel):
|
||||
position: PptxPositionModel
|
||||
category_font: Optional[PptxFontModel] = None
|
||||
value_font: Optional[PptxFontModel] = None
|
||||
legend_font: Optional[PptxFontModel] = None
|
||||
graph: GraphModel
|
||||
|
||||
|
||||
class PptxConnectorModel(PptxShapeModel):
|
||||
type: MSO_CONNECTOR_TYPE = MSO_CONNECTOR_TYPE.STRAIGHT
|
||||
position: PptxPositionModel
|
||||
|
|
@ -159,13 +149,10 @@ class PptxSlideModel(BaseModel):
|
|||
| PptxAutoShapeBoxModel
|
||||
| PptxConnectorModel
|
||||
| PptxPictureBoxModel
|
||||
| PptxGraphBoxModel
|
||||
]
|
||||
|
||||
|
||||
class PptxPresentationModel(BaseModel):
|
||||
# theme: PresentationTheme
|
||||
# watermark: bool
|
||||
background_color: str
|
||||
shapes: Optional[List[PptxShapeModel]] = None
|
||||
slides: List[PptxSlideModel]
|
||||
|
|
|
|||
|
|
@ -6,26 +6,12 @@ from lxml import etree
|
|||
from pptx import Presentation
|
||||
from pptx.shapes.autoshape import Shape
|
||||
from pptx.slide import Slide
|
||||
from pptx.chart.data import ChartData, BubbleChartData
|
||||
from pptx.chart.chart import Chart
|
||||
from pptx.text.text import _Paragraph, TextFrame, Font, _Run
|
||||
from pptx.enum.chart import (
|
||||
XL_CHART_TYPE,
|
||||
XL_LEGEND_POSITION,
|
||||
XL_LABEL_POSITION,
|
||||
)
|
||||
from pptx.opc.constants import RELATIONSHIP_TYPE as RT
|
||||
from lxml.etree import fromstring, tostring
|
||||
from PIL import Image
|
||||
|
||||
from pptx.util import Pt
|
||||
from graph_processor.models import (
|
||||
BarGraphDataModel,
|
||||
BubbleChartDataModel,
|
||||
GraphTypeEnum,
|
||||
LineChartDataModel,
|
||||
PieChartDataModel,
|
||||
)
|
||||
from pptx.dml.color import RGBColor
|
||||
from ppt_generator.models.pptx_models import (
|
||||
PptxAutoShapeBoxModel,
|
||||
|
|
@ -33,7 +19,6 @@ from ppt_generator.models.pptx_models import (
|
|||
PptxConnectorModel,
|
||||
PptxFillModel,
|
||||
PptxFontModel,
|
||||
PptxGraphBoxModel,
|
||||
PptxParagraphModel,
|
||||
PptxPictureBoxModel,
|
||||
PptxPositionModel,
|
||||
|
|
@ -63,8 +48,6 @@ class PptxPresentationCreator:
|
|||
|
||||
self._ppt_model = ppt_model
|
||||
self._slide_models = ppt_model.slides
|
||||
# self._theme = ppt_model.theme
|
||||
# self._watermark = ppt_model.watermark
|
||||
|
||||
self._ppt = Presentation()
|
||||
self._ppt.slide_width = Pt(1280)
|
||||
|
|
@ -73,7 +56,6 @@ class PptxPresentationCreator:
|
|||
self._slide_fill = PptxFillModel(color=ppt_model.background_color)
|
||||
|
||||
def create_ppt(self):
|
||||
# self.set_presentation_theme()
|
||||
|
||||
for slide_model in self._slide_models:
|
||||
# Adding global shapes to slide
|
||||
|
|
@ -120,16 +102,9 @@ class PptxPresentationCreator:
|
|||
elif model_type is PptxTextBoxModel:
|
||||
self.add_textbox(slide, shape_model)
|
||||
|
||||
elif model_type is PptxGraphBoxModel:
|
||||
self.add_graph(slide, shape_model)
|
||||
|
||||
elif model_type is PptxConnectorModel:
|
||||
self.add_connector(slide, shape_model)
|
||||
|
||||
# if self._watermark:
|
||||
# Adding watermark
|
||||
# self.add_picture(slide, self.get_watermark_box_model())
|
||||
|
||||
def add_connector(self, slide: Slide, connector_model: PptxConnectorModel):
|
||||
if connector_model.thickness == 0:
|
||||
return
|
||||
|
|
@ -139,126 +114,6 @@ class PptxPresentationCreator:
|
|||
connector_shape.line.width = Pt(connector_model.thickness)
|
||||
connector_shape.line.color.rgb = RGBColor.from_string(connector_model.color)
|
||||
|
||||
def add_graph(self, slide: Slide, graph_box_model: PptxGraphBoxModel):
|
||||
chart_data = None
|
||||
chart_type = None
|
||||
graph = graph_box_model.graph
|
||||
match (graph.type):
|
||||
case GraphTypeEnum.bar:
|
||||
chart_data = self.get_bar_graph(graph.data)
|
||||
chart_type = XL_CHART_TYPE.COLUMN_CLUSTERED
|
||||
|
||||
case GraphTypeEnum.scatter:
|
||||
chart_data = self.get_scatter_graph(graph.data)
|
||||
chart_type = XL_CHART_TYPE.XY_SCATTER
|
||||
|
||||
case GraphTypeEnum.bubble:
|
||||
chart_data = self.get_bubble_graph(graph.data)
|
||||
chart_type = XL_CHART_TYPE.BUBBLE
|
||||
|
||||
case GraphTypeEnum.line:
|
||||
chart_data = self.get_line_graph(graph.data)
|
||||
chart_type = XL_CHART_TYPE.LINE
|
||||
|
||||
case GraphTypeEnum.pie:
|
||||
chart_data = self.get_pie_graph(graph.data)
|
||||
chart_type = XL_CHART_TYPE.PIE
|
||||
|
||||
if chart_data:
|
||||
chart: Chart = slide.shapes.add_chart(
|
||||
chart_type, *graph_box_model.position.to_pt_list(), chart_data
|
||||
).chart
|
||||
self.apply_graph_styles(chart, graph_box_model)
|
||||
|
||||
def apply_graph_styles(self, chart, graph_box_model: PptxGraphBoxModel):
|
||||
graph = graph_box_model.graph
|
||||
|
||||
if graph.type in [GraphTypeEnum.pie, GraphTypeEnum.scatter]:
|
||||
chart.has_legend = True
|
||||
chart.legend.position = XL_LEGEND_POSITION.RIGHT
|
||||
else:
|
||||
chart.has_legend = False
|
||||
|
||||
if graph_box_model.legend_font:
|
||||
self.apply_font(chart.font, graph_box_model.legend_font)
|
||||
|
||||
try:
|
||||
category_axis = chart.category_axis
|
||||
if graph_box_model.category_font:
|
||||
font = category_axis.tick_labels.font
|
||||
self.apply_font(font, graph_box_model.category_font)
|
||||
except:
|
||||
print("-" * 20)
|
||||
print("Could not apply category labels style")
|
||||
|
||||
try:
|
||||
value_axis = chart.value_axis
|
||||
tick_labels = value_axis.tick_labels
|
||||
if graph.postfix:
|
||||
tick_labels.number_format = f'0"{graph.postfix}"'
|
||||
if graph_box_model.value_font:
|
||||
self.apply_font(tick_labels.font, graph_box_model.value_font)
|
||||
except:
|
||||
print("-" * 20)
|
||||
print("Could not apply tick labels style")
|
||||
|
||||
if graph_box_model.graph.type is GraphTypeEnum.pie:
|
||||
for plot in chart.plots:
|
||||
try:
|
||||
plot.has_data_labels = True
|
||||
plot.data_labels.position = (
|
||||
XL_LABEL_POSITION.OUTSIDE_END
|
||||
if graph_box_model.graph.type is GraphTypeEnum.bar
|
||||
else XL_LABEL_POSITION.CENTER
|
||||
)
|
||||
if graph.postfix:
|
||||
plot.data_labels.number_format = f'0"{graph.postfix}"'
|
||||
if graph_box_model.value_font:
|
||||
self.apply_font(
|
||||
plot.data_labels.font,
|
||||
(
|
||||
graph_box_model.value_font
|
||||
if graph_box_model.graph.type is GraphTypeEnum.bar
|
||||
else PptxFontModel(
|
||||
# size=self._theme.fonts.p2,
|
||||
size=16,
|
||||
bold=True,
|
||||
color="ffffff",
|
||||
)
|
||||
),
|
||||
)
|
||||
except:
|
||||
print("-" * 20)
|
||||
print("Could not apply data labels style")
|
||||
|
||||
def get_bar_graph(self, graph: BarGraphDataModel):
|
||||
chart_data = ChartData()
|
||||
chart_data.categories = graph.get_categories()
|
||||
for series in graph.series:
|
||||
chart_data.add_series(series.get_name(), series.data)
|
||||
return chart_data
|
||||
|
||||
def get_bubble_graph(self, graph: BubbleChartDataModel):
|
||||
chart_data = BubbleChartData()
|
||||
for each in graph.series:
|
||||
series = chart_data.add_series(each.get_name())
|
||||
for point in each.points:
|
||||
series.add_data_point(*point.to_list())
|
||||
return chart_data
|
||||
|
||||
def get_line_graph(self, graph: LineChartDataModel):
|
||||
chart_data = ChartData()
|
||||
chart_data.categories = graph.get_categories()
|
||||
for series in graph.series:
|
||||
chart_data.add_series(series.get_name(), series.data)
|
||||
return chart_data
|
||||
|
||||
def get_pie_graph(self, graph: PieChartDataModel):
|
||||
chart_data = ChartData()
|
||||
chart_data.categories = graph.get_categories()
|
||||
chart_data.add_series("", graph.series[0].data)
|
||||
return chart_data
|
||||
|
||||
def add_picture(self, slide: Slide, picture_model: PptxPictureBoxModel):
|
||||
image_path = picture_model.picture.path
|
||||
if (
|
||||
|
|
@ -562,17 +417,5 @@ class PptxPresentationCreator:
|
|||
font.italic = font_model.italic
|
||||
font.size = Pt(font_model.size)
|
||||
|
||||
# def get_watermark_box_model(self):
|
||||
# watermark_asset_path = f"assets/images/{'watermark_dark.png' if self._theme == PresentationTheme.dark else 'watermark.png'}"
|
||||
|
||||
# return PptxPictureBoxModel(
|
||||
# position=PptxPositionModel(left=1120, top=685, width=140),
|
||||
# clip=False,
|
||||
# picture=PptxPictureModel(
|
||||
# is_network=False,
|
||||
# path=watermark_asset_path,
|
||||
# ),
|
||||
# )
|
||||
|
||||
def save(self, path: str):
|
||||
self._ppt.save(path)
|
||||
|
|
|
|||
|
|
@ -1,149 +1,142 @@
|
|||
from typing import Optional
|
||||
from api.utils.utils import get_large_model, get_small_model
|
||||
from ppt_config_generator.models import SlideMarkdownModel
|
||||
from ppt_generator.fix_validation_errors import get_validated_response
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.utils.model_utils import get_large_model, get_llm_client, get_small_model
|
||||
from ppt_config_generator.models import SlideMarkdownModel
|
||||
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
LLMSlideContentModel,
|
||||
LLMContentUnion,
|
||||
)
|
||||
from ppt_generator.models.llm_models_with_validations import (
|
||||
LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING,
|
||||
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION,
|
||||
)
|
||||
from ppt_generator.models.other_models import SlideTypeModel
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
|
||||
prompt_template_to_generate_slide_content = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
|
||||
def get_prompt_to_generate_slide_content(title: str, outline: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
|
||||
|
||||
# Steps
|
||||
1. Analyze the outline and title.
|
||||
2. Generate structured slide based on the outline and title.
|
||||
3. Generate image prompts and icon queries if mentioned in schema.
|
||||
4. Generate graph if mentioned in schema.
|
||||
|
||||
# Steps
|
||||
1. Analyze the outline and title.
|
||||
2. Generate structured slide based on the outline and title.
|
||||
3. Generate image prompts and icon queries if mentioned in schema.
|
||||
4. Generate graph if mentioned in schema.
|
||||
|
||||
# Notes
|
||||
- Slide body should not use words like "This slide", "This presentation".
|
||||
- Rephrase the slide body to make it flow naturally.
|
||||
- Do not use markdown formatting in slide body.
|
||||
- **Icon query** must be a generic single word noun.
|
||||
- **Image prompt** should be a 2-3 words phrase.
|
||||
- Try to make paragraphs as short as possible.
|
||||
{notes}
|
||||
# Notes
|
||||
- Slide body should not use words like "This slide", "This presentation".
|
||||
- Rephrase the slide body to make it flow naturally.
|
||||
- Do not use markdown formatting in slide body.
|
||||
""",
|
||||
),
|
||||
(
|
||||
"user",
|
||||
"""
|
||||
## Slide Title
|
||||
{title}
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
## Slide Title
|
||||
{title}
|
||||
|
||||
## Slide Outline
|
||||
{outline}
|
||||
""",
|
||||
),
|
||||
## Slide Outline
|
||||
{outline}
|
||||
""",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
prompt_template_to_edit_slide_content = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
def get_prompt_to_edit_slide_content(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
theme: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
Edit Slide data based on provided prompt, follow mentioned steps and notes and provide structured output.
|
||||
|
||||
# Notes
|
||||
- Provide output in language mentioned in **Input**.
|
||||
- The goal is to change Slide data based on the provided prompt.
|
||||
- Do not change **Image prompts** and **Icon queries** if not asked for in prompt.
|
||||
- Generate **Image prompts** and **Icon queries** if asked to generate or change image or icons in prompt.
|
||||
- Ensure there are no line breaks in the JSON.
|
||||
- Do not use special characters for highlighting.
|
||||
{notes}
|
||||
- Generate **Image prompts** and **Icon queries** if asked to generate or change in prompt.
|
||||
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
),
|
||||
(
|
||||
"user",
|
||||
"""
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Image Prompts and Icon Queries Language: English
|
||||
- Theme: {theme}
|
||||
- Slide data: {slide_data}
|
||||
""",
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Image Prompts and Icon Queries Language: English
|
||||
- Theme: {theme}
|
||||
- Slide data: {slide_data}
|
||||
""",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
prompt_template_to_select_slide_type = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
Select a Slide Type based on provided user prompt and current slide data.
|
||||
def get_prompt_to_select_slide_type(prompt: str, slide_data: dict, slide_type: int):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
Select a Slide Type based on provided user prompt and current slide data.
|
||||
|
||||
Select slide based on following slide description and make sure it matches user requirement:
|
||||
# Slide Types (Slide Type : Slide Description)
|
||||
- **1**: contains title, description and image.
|
||||
- **2**: contains title and list of items.
|
||||
- **4**: contains title and list of items with images.
|
||||
- **5**: contains title, description and a graph.
|
||||
- **6**: contains title, description and list of items.
|
||||
- **7**: contains title and list of items with icons.
|
||||
- **8**: contains title, description and list of items with icons.
|
||||
- **9**: contains title, list of items and a graph.
|
||||
Select slide based on following slide description and make sure it matches user requirement:
|
||||
# Slide Types (Slide Type : Slide Description)
|
||||
- **1**: contains title, description and image.
|
||||
- **2**: contains title and list of items.
|
||||
- **4**: contains title and list of items with images.
|
||||
- **5**: contains title, description and a graph.
|
||||
- **6**: contains title, description and list of items.
|
||||
- **7**: contains title and list of items with icons.
|
||||
- **8**: contains title, description and list of items with icons.
|
||||
- **9**: contains title, list of items and a graph.
|
||||
|
||||
# Notes
|
||||
- Do not select different slide type than current unless absolutely necessary as per user prompt.
|
||||
# Notes
|
||||
- Do not select different slide type than current unless absolutely necessary as per user prompt.
|
||||
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
),
|
||||
(
|
||||
"user",
|
||||
"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
- Current Slide Type: {slide_type}
|
||||
""",
|
||||
),
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
- Current Slide Type: {slide_type}
|
||||
""",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
slide_type: int, outline: SlideMarkdownModel
|
||||
) -> LLMSlideContentModel:
|
||||
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
|
||||
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
|
||||
model = get_small_model().with_structured_output(
|
||||
content_type_model_type.model_json_schema()
|
||||
)
|
||||
chain = prompt_template_to_generate_slide_content | model
|
||||
) -> LLMContentUnion:
|
||||
response_model = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
|
||||
|
||||
return await get_validated_response(
|
||||
chain,
|
||||
{
|
||||
"title": outline.title,
|
||||
"outline": outline.body,
|
||||
"notes": content_type_model_type.get_notes(),
|
||||
},
|
||||
content_type_model_type,
|
||||
validation_model,
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.5,
|
||||
messages=get_prompt_to_generate_slide_content(
|
||||
outline.title,
|
||||
outline.body,
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
|
||||
return response.choices[0].message.parsed
|
||||
|
||||
|
||||
async def get_edited_slide_content_model(
|
||||
prompt: str,
|
||||
|
|
@ -151,29 +144,24 @@ async def get_edited_slide_content_model(
|
|||
slide: SlideModel,
|
||||
theme: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
) -> LLMContentUnion:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
|
||||
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
|
||||
chain = prompt_template_to_edit_slide_content | model.with_structured_output(
|
||||
content_type_model_type.model_json_schema()
|
||||
)
|
||||
content_type_model_type = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
|
||||
slide_data = slide.content.to_llm_content().model_dump_json()
|
||||
edited_content = await get_validated_response(
|
||||
chain,
|
||||
{
|
||||
"prompt": prompt,
|
||||
"language": language or "English",
|
||||
"theme": theme,
|
||||
"slide_data": slide_data,
|
||||
"notes": "",
|
||||
},
|
||||
content_type_model_type,
|
||||
validation_model,
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_to_edit_slide_content(
|
||||
prompt,
|
||||
slide_data,
|
||||
theme,
|
||||
language,
|
||||
),
|
||||
response_format=content_type_model_type,
|
||||
)
|
||||
|
||||
return edited_content.to_content()
|
||||
return response.choices[0].message.parsed
|
||||
|
||||
|
||||
async def get_slide_type_from_prompt(
|
||||
|
|
@ -181,18 +169,15 @@ async def get_slide_type_from_prompt(
|
|||
slide: SlideModel,
|
||||
) -> SlideTypeModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
|
||||
chain = prompt_template_to_select_slide_type | model.with_structured_output(
|
||||
SlideTypeModel.model_json_schema()
|
||||
)
|
||||
slide_data = slide.content.to_llm_content().model_dump_json()
|
||||
return await get_validated_response(
|
||||
chain,
|
||||
{
|
||||
"prompt": prompt,
|
||||
"slide_data": slide_data,
|
||||
"slide_type": slide.type,
|
||||
},
|
||||
SlideTypeModel,
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_to_select_slide_type(
|
||||
prompt, slide.content.to_llm_content().model_dump_json(), slide.type
|
||||
),
|
||||
response_format=SlideTypeModel,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ dnspython==2.7.0
|
|||
email_validator==2.2.0
|
||||
fastapi==0.115.12
|
||||
fastapi-cli==0.0.7
|
||||
fastembed==0.7.0
|
||||
fastembed_vectorstore==0.1.5
|
||||
filelock==3.18.0
|
||||
filetype==1.2.0
|
||||
flatbuffers==25.2.10
|
||||
|
|
@ -28,6 +28,7 @@ fsspec==2025.3.2
|
|||
google-ai-generativelanguage==0.6.18
|
||||
google-api-core==2.24.2
|
||||
google-auth==2.40.1
|
||||
google-genai==1.23.0
|
||||
googleapis-common-protos==1.70.0
|
||||
greenlet==3.2.2
|
||||
grpcio==1.72.0rc1
|
||||
|
|
@ -44,14 +45,6 @@ Jinja2==3.1.6
|
|||
jiter==0.9.0
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
langchain==0.3.25
|
||||
langchain-community==0.3.24
|
||||
langchain-core==0.3.65
|
||||
langchain-google-genai==2.1.4
|
||||
langchain-ollama==0.3.3
|
||||
langchain-openai==0.3.16
|
||||
langchain-text-splitters==0.3.8
|
||||
langsmith==0.3.45
|
||||
loguru==0.7.3
|
||||
lxml==5.4.0
|
||||
markdown-it-py==3.0.0
|
||||
|
|
@ -65,7 +58,7 @@ mypy_extensions==1.1.0
|
|||
numpy==2.2.5
|
||||
ollama==0.5.1
|
||||
onnxruntime==1.22.0
|
||||
openai==1.78.1
|
||||
openai==1.91.0
|
||||
orjson==3.10.18
|
||||
packaging==24.2
|
||||
pdfminer.six==20250327
|
||||
|
|
@ -102,7 +95,7 @@ SQLAlchemy==2.0.41
|
|||
sqlmodel==0.0.24
|
||||
starlette==0.46.2
|
||||
sympy==1.14.0
|
||||
tenacity==9.1.2
|
||||
tenacity==8.5.0
|
||||
tiktoken==0.9.0
|
||||
tokenizers==0.21.1
|
||||
tqdm==4.67.1
|
||||
|
|
|
|||
|
|
@ -1,58 +1,35 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# search_tool = DuckDuckGoSearchRun(
|
||||
# api_wrapper=DuckDuckGoSearchAPIWrapper(max_results=50)
|
||||
# )
|
||||
|
||||
prompt_template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
Use provided prompt and search results to create an elaborate and up-to-date research report in mentioned language.
|
||||
def get_prompt_template():
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
Use provided prompt and search results to create an elaborate and up-to-date research report in mentioned language.
|
||||
|
||||
# Steps
|
||||
1. Analyze the prompt and search results.
|
||||
2. Extract topic of the report.
|
||||
3. Generate a report in markdown format.
|
||||
# Steps
|
||||
1. Analyze the prompt and search results.
|
||||
2. Extract topic of the report.
|
||||
3. Generate a report in markdown format.
|
||||
|
||||
# Notes
|
||||
- If language is not mentioned, use language from prompt.
|
||||
- Format of report should be like *Research Report*.
|
||||
- Ignore formatting if mentioned in prompt.
|
||||
# Notes
|
||||
- If language is not mentioned, use language from prompt.
|
||||
- Format of report should be like *Research Report*.
|
||||
- Ignore formatting if mentioned in prompt.
|
||||
""",
|
||||
),
|
||||
(
|
||||
"human",
|
||||
"""
|
||||
- Prompt: {prompt}
|
||||
- Language: {language}
|
||||
- Search Results: {search_results}
|
||||
},
|
||||
{
|
||||
"role": "human",
|
||||
"content": """
|
||||
- Prompt: {prompt}
|
||||
- Language: {language}
|
||||
- Search Results: {search_results}
|
||||
""",
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def get_report(query: str, language: Optional[str]):
|
||||
model = (
|
||||
ChatOpenAI(model="gpt-4.1-nano")
|
||||
if os.getenv("LLM") == "openai"
|
||||
else ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
)
|
||||
chain = prompt_template | model
|
||||
|
||||
# search_results = await search_tool.ainvoke(query)
|
||||
# response = await chain.ainvoke(
|
||||
# {
|
||||
# "prompt": query,
|
||||
# "language": language,
|
||||
# "search_results": search_results,
|
||||
# }
|
||||
# )
|
||||
# return response.content
|
||||
return "Research Report coming soon"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import uvicorn
|
||||
import argparse
|
||||
|
||||
|
|
@ -8,6 +9,8 @@ from api.main import app
|
|||
app
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the FastAPI server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, required=True, help="Port number to run the server on"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
import uvicorn
|
||||
import argparse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the FastAPI server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, required=True, help="Port number to run the server on"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import os
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
|
||||
uvicorn.run(
|
||||
"api.main:app", host="0.0.0.0", port=8000, log_level="info", reload=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ const Header = ({
|
|||
if (response.ok) {
|
||||
const { path: pdfPath } = await response.json();
|
||||
const staticFileUrl = getStaticFileUrl(pdfPath);
|
||||
window.open(staticFileUrl, '_self');
|
||||
window.open(staticFileUrl, '_blank');
|
||||
} else {
|
||||
throw new Error("Failed to export PDF");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,9 @@ export async function POST(request: Request) {
|
|||
LLM: userConfig.LLM || existingConfig.LLM,
|
||||
OPENAI_API_KEY: userConfig.OPENAI_API_KEY || existingConfig.OPENAI_API_KEY,
|
||||
GOOGLE_API_KEY: userConfig.GOOGLE_API_KEY || existingConfig.GOOGLE_API_KEY,
|
||||
OLLAMA_MODEL: userConfig.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL,
|
||||
MODEL: userConfig.MODEL || existingConfig.MODEL,
|
||||
LLM_PROVIDER_URL: userConfig.LLM_PROVIDER_URL || existingConfig.LLM_PROVIDER_URL,
|
||||
LLM_API_KEY: userConfig.LLM_API_KEY || existingConfig.LLM_API_KEY,
|
||||
PEXELS_API_KEY: userConfig.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY,
|
||||
}
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig))
|
||||
|
|
|
|||
|
|
@ -2,13 +2,29 @@
|
|||
import React, { useState, useEffect } from "react";
|
||||
import Header from "../dashboard/components/Header";
|
||||
import Wrapper from "@/components/Wrapper";
|
||||
import { Settings, Key, Loader2 } from 'lucide-react';
|
||||
import { Settings, Key, Loader2, Check, ChevronsUpDown } from 'lucide-react';
|
||||
import { toast } from '@/hooks/use-toast';
|
||||
import { RootState } from "@/store/store";
|
||||
import { useSelector } from "react-redux";
|
||||
import { handleSaveLLMConfig } from "@/utils/storeHelpers";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger } from "@/components/ui/select";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "@/components/ui/command";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
|
||||
const PROVIDER_CONFIGS: Record<string, ProviderConfig> = {
|
||||
openai: {
|
||||
|
|
@ -55,14 +71,22 @@ const SettingsPage = () => {
|
|||
done: false,
|
||||
});
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState<boolean>(false);
|
||||
|
||||
const api_key_changed = (apiKey: string) => {
|
||||
const api_key_changed = (apiKey: string, field?: string) => {
|
||||
if (llmConfig.LLM === 'openai') {
|
||||
setLlmConfig({ ...llmConfig, OPENAI_API_KEY: apiKey });
|
||||
} else if (llmConfig.LLM === 'google') {
|
||||
setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: apiKey });
|
||||
} else if (llmConfig.LLM === 'ollama') {
|
||||
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: apiKey });
|
||||
if (field === 'pexels') {
|
||||
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: apiKey });
|
||||
} else if (field === 'ollama_url') {
|
||||
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: apiKey });
|
||||
} else if (field === 'ollama_api_key') {
|
||||
setLlmConfig({ ...llmConfig, LLM_API_KEY: apiKey });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -87,7 +111,7 @@ const SettingsPage = () => {
|
|||
}
|
||||
}
|
||||
try {
|
||||
await handleSaveLLMConfig(llmConfig);
|
||||
await handleSaveLLMConfig(llmConfig, useCustomOllamaUrl);
|
||||
toast({
|
||||
title: 'Success',
|
||||
description: 'Configuration saved successfully',
|
||||
|
|
@ -116,7 +140,7 @@ const SettingsPage = () => {
|
|||
return new Promise((resolve, reject) => {
|
||||
const interval = setInterval(async () => {
|
||||
try {
|
||||
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`);
|
||||
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`);
|
||||
if (response.status === 200) {
|
||||
|
||||
const data = await response.json();
|
||||
|
|
@ -163,6 +187,14 @@ const SettingsPage = () => {
|
|||
}
|
||||
}, [userConfigState.llm_config.LLM]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!useCustomOllamaUrl) {
|
||||
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: undefined, LLM_API_KEY: undefined });
|
||||
} else {
|
||||
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', LLM_API_KEY: '' });
|
||||
}
|
||||
}, [useCustomOllamaUrl]);
|
||||
|
||||
if (!canChangeKeys) {
|
||||
return null;
|
||||
}
|
||||
|
|
@ -262,70 +294,90 @@ const SettingsPage = () => {
|
|||
</label>
|
||||
<div className="w-full">
|
||||
{ollamaModels.length > 0 ? (
|
||||
<Select value={llmConfig.OLLAMA_MODEL} onValueChange={(value) => setLlmConfig({ ...llmConfig, OLLAMA_MODEL: value })}>
|
||||
<SelectTrigger className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400">
|
||||
<div className="flex items-center justify-between w-full">
|
||||
<Popover open={openModelSelect} onOpenChange={setOpenModelSelect}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.icon}
|
||||
alt={`${llmConfig.OLLAMA_MODEL} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
{llmConfig.MODEL && (
|
||||
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={ollamaModels.find(m => m.value === llmConfig.MODEL)?.icon}
|
||||
alt={`${llmConfig.MODEL} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{llmConfig.OLLAMA_MODEL ? (
|
||||
ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.label || llmConfig.OLLAMA_MODEL
|
||||
{llmConfig.MODEL ? (
|
||||
ollamaModels.find(m => m.value === llmConfig.MODEL)?.label || llmConfig.MODEL
|
||||
) : (
|
||||
'Select a model'
|
||||
)}
|
||||
</span>
|
||||
{llmConfig.OLLAMA_MODEL && (
|
||||
{llmConfig.MODEL && (
|
||||
<span className="text-xs text-gray-500 bg-gray-100 rounded-full px-2 py-1">
|
||||
{ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.size}
|
||||
{ollamaModels.find(m => m.value === llmConfig.MODEL)?.size}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-h-80">
|
||||
<div className="p-2">
|
||||
<div className="text-xs font-semibold text-gray-500 uppercase tracking-wide mb-3 pt-3 px-2">
|
||||
Available Models
|
||||
</div>
|
||||
{ollamaModels.map((model, index) => (
|
||||
<SelectItem
|
||||
key={index}
|
||||
value={model.value}
|
||||
className="relative cursor-pointer rounded-md py-3 hover:bg-gray-50 focus:bg-gray-50 focus:outline-none transition-colors"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={model.icon}
|
||||
alt={`${model.label} icon`}
|
||||
className=" rounded-sm"
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="p-0" align="start" style={{ width: 'var(--radix-popover-trigger-width)' }}>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search model..." />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{ollamaModels.map((model, index) => (
|
||||
<CommandItem
|
||||
key={index}
|
||||
value={model.value}
|
||||
onSelect={(value) => {
|
||||
setLlmConfig({ ...llmConfig, MODEL: value });
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
llmConfig.MODEL === model.value ? "opacity-100" : "opacity-0"
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900 capitalize">
|
||||
{model.label}
|
||||
</span>
|
||||
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
|
||||
{model.size}
|
||||
</span>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={model.icon}
|
||||
alt={`${model.label} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900 capitalize">
|
||||
{model.label}
|
||||
</span>
|
||||
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
|
||||
{model.size}
|
||||
</span>
|
||||
</div>
|
||||
<span className="text-xs text-gray-600 leading-relaxed">
|
||||
{model.description}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-xs text-gray-600 leading-relaxed">
|
||||
{model.description}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</div>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
) : (
|
||||
<div className="w-full border border-gray-300 rounded-lg p-4">
|
||||
<div className="flex items-center space-x-3">
|
||||
|
|
@ -344,6 +396,62 @@ const SettingsPage = () => {
|
|||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Custom Ollama URL Configuration */}
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Use custom Ollama URL
|
||||
</label>
|
||||
<Switch
|
||||
checked={useCustomOllamaUrl}
|
||||
onCheckedChange={setUseCustomOllamaUrl}
|
||||
/>
|
||||
</div>
|
||||
{useCustomOllamaUrl && (
|
||||
<>
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Ollama URL
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
required
|
||||
placeholder="Enter your Ollama URL"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={llmConfig.LLM_PROVIDER_URL || ''}
|
||||
onChange={(e) => api_key_changed(e.target.value, 'ollama_url')}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Change this if you are using a custom Ollama instance
|
||||
</p>
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Ollama API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
required
|
||||
placeholder="Enter your Ollama API key"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={llmConfig.LLM_API_KEY || ''}
|
||||
onChange={(e) => api_key_changed(e.target.value, 'ollama_api_key')}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Provide this if you are using a custom Ollama instance
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Pexels API Key (required for images)
|
||||
|
|
@ -355,12 +463,12 @@ const SettingsPage = () => {
|
|||
placeholder="Enter your Pexels API key"
|
||||
className="flex-1 px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={llmConfig.PEXELS_API_KEY || ''}
|
||||
onChange={(e) => api_key_changed(e.target.value)}
|
||||
onChange={(e) => api_key_changed(e.target.value, 'pexels')}
|
||||
/>
|
||||
<button
|
||||
onClick={handleSaveConfig}
|
||||
disabled={isLoading || !llmConfig.OLLAMA_MODEL}
|
||||
className={`px-4 py-2 rounded-lg transition-colors ${isLoading || !llmConfig.OLLAMA_MODEL
|
||||
disabled={isLoading || !llmConfig.MODEL}
|
||||
className={`px-4 py-2 rounded-lg transition-colors ${isLoading || !llmConfig.MODEL
|
||||
? 'bg-gray-400 cursor-not-allowed'
|
||||
: 'bg-blue-600 hover:bg-blue-700'
|
||||
} text-white`}
|
||||
|
|
@ -374,7 +482,7 @@ const SettingsPage = () => {
|
|||
}
|
||||
</div>
|
||||
) : (
|
||||
!llmConfig.OLLAMA_MODEL ? 'Select Model' : 'Save'
|
||||
!llmConfig.MODEL ? 'Select Model' : 'Save'
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -41,11 +41,11 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) {
|
|||
llmConfig.LLM = 'openai';
|
||||
}
|
||||
dispatch(setLLMConfig(llmConfig));
|
||||
const isValid = hasValidLLMConfig(llmConfig);
|
||||
const isValid = hasValidLLMConfig(llmConfig, false);
|
||||
if (isValid) {
|
||||
// Check if the selected Ollama model is pulled
|
||||
if (llmConfig.LLM === 'ollama') {
|
||||
const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.OLLAMA_MODEL);
|
||||
const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.MODEL);
|
||||
if (!isPulled) {
|
||||
router.push('/');
|
||||
setLoadingToFalseAfterNavigatingTo('/');
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { toast } from "@/hooks/use-toast";
|
||||
import { Info, ExternalLink, PlayCircle, Loader2 } from "lucide-react";
|
||||
import { Info, ExternalLink, PlayCircle, Loader2, Check, ChevronsUpDown } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import {
|
||||
Accordion,
|
||||
|
|
@ -14,6 +14,22 @@ import { useSelector } from "react-redux";
|
|||
import { RootState } from "@/store/store";
|
||||
import { handleSaveLLMConfig } from "@/utils/storeHelpers";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger } from "./ui/select";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "./ui/command";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Switch } from "./ui/switch";
|
||||
|
||||
interface ModelOption {
|
||||
value: string;
|
||||
|
|
@ -100,36 +116,7 @@ const PROVIDER_CONFIGS: Record<string, ProviderConfig> = {
|
|||
},
|
||||
},
|
||||
ollama: {
|
||||
textModels: [
|
||||
{
|
||||
value: "llama3.1:8b",
|
||||
label: "Llama3.1:8b",
|
||||
description: "Balanced model for most tasks",
|
||||
icon: "/icons/ollama.png",
|
||||
size: "8GB",
|
||||
},
|
||||
{
|
||||
value: "llama3.1:70b",
|
||||
label: "Llama3.1:70b",
|
||||
description: "Large model for complex tasks",
|
||||
icon: "/icons/ollama.png",
|
||||
size: "70GB",
|
||||
},
|
||||
{
|
||||
value: "llama3.1:14b",
|
||||
label: "Llama3.1:14b",
|
||||
description: "Large model for complex tasks",
|
||||
icon: "/icons/ollama.png",
|
||||
size: "14GB",
|
||||
},
|
||||
{
|
||||
value: "llama3.1:11b",
|
||||
label: "Llama3.1:11b",
|
||||
description: "Large model for complex tasks",
|
||||
icon: "/icons/ollama.png",
|
||||
size: "11GB",
|
||||
},
|
||||
],
|
||||
textModels: [],
|
||||
imageModels: [
|
||||
{
|
||||
value: "pexels",
|
||||
|
|
@ -171,16 +158,24 @@ export default function Home() {
|
|||
done: false,
|
||||
});
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState<boolean>(false);
|
||||
|
||||
const canChangeKeys = config.can_change_keys;
|
||||
|
||||
const api_key_changed = (newApiKey: string) => {
|
||||
const api_key_changed = (newApiKey: string, field?: string) => {
|
||||
if (llmConfig.LLM === 'openai') {
|
||||
setLlmConfig({ ...llmConfig, OPENAI_API_KEY: newApiKey });
|
||||
} else if (llmConfig.LLM === 'google') {
|
||||
setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: newApiKey });
|
||||
} else if (llmConfig.LLM === 'ollama') {
|
||||
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: newApiKey });
|
||||
if (field === 'pexels') {
|
||||
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: newApiKey });
|
||||
} else if (field === 'ollama_url') {
|
||||
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: newApiKey });
|
||||
} else if (field === 'ollama_api_key') {
|
||||
setLlmConfig({ ...llmConfig, LLM_API_KEY: newApiKey });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -205,7 +200,7 @@ export default function Home() {
|
|||
}
|
||||
}
|
||||
try {
|
||||
await handleSaveLLMConfig(llmConfig);
|
||||
await handleSaveLLMConfig(llmConfig, useCustomOllamaUrl);
|
||||
toast({
|
||||
title: 'Success',
|
||||
description: 'Configuration saved successfully',
|
||||
|
|
@ -234,7 +229,7 @@ export default function Home() {
|
|||
return new Promise((resolve, reject) => {
|
||||
const interval = setInterval(async () => {
|
||||
try {
|
||||
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`);
|
||||
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`);
|
||||
if (response.status === 200) {
|
||||
const data = await response.json();
|
||||
|
||||
|
|
@ -277,6 +272,15 @@ export default function Home() {
|
|||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!useCustomOllamaUrl) {
|
||||
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: undefined, LLM_API_KEY: undefined });
|
||||
} else {
|
||||
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', LLM_API_KEY: '' });
|
||||
}
|
||||
}, [useCustomOllamaUrl]);
|
||||
|
||||
|
||||
if (!canChangeKeys) {
|
||||
return null;
|
||||
}
|
||||
|
|
@ -355,70 +359,90 @@ export default function Home() {
|
|||
</label>
|
||||
<div className="w-full">
|
||||
{ollamaModels.length > 0 ? (
|
||||
<Select value={llmConfig.OLLAMA_MODEL} onValueChange={(value) => setLlmConfig({ ...llmConfig, OLLAMA_MODEL: value })}>
|
||||
<SelectTrigger className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400">
|
||||
<div className="flex items-center justify-between w-full">
|
||||
<Popover open={openModelSelect} onOpenChange={setOpenModelSelect}>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.icon}
|
||||
alt={`${llmConfig.OLLAMA_MODEL} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
{llmConfig.MODEL && (
|
||||
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={ollamaModels.find(m => m.value === llmConfig.MODEL)?.icon}
|
||||
alt={`${llmConfig.MODEL} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{llmConfig.OLLAMA_MODEL ? (
|
||||
ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.label || llmConfig.OLLAMA_MODEL
|
||||
{llmConfig.MODEL ? (
|
||||
ollamaModels.find(m => m.value === llmConfig.MODEL)?.label || llmConfig.MODEL
|
||||
) : (
|
||||
'Select a model'
|
||||
)}
|
||||
</span>
|
||||
{llmConfig.OLLAMA_MODEL && (
|
||||
{llmConfig.MODEL && (
|
||||
<span className="text-xs text-gray-500 bg-gray-100 rounded-full px-2 py-1">
|
||||
{ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.size}
|
||||
{ollamaModels.find(m => m.value === llmConfig.MODEL)?.size}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="max-h-80">
|
||||
<div className="p-2">
|
||||
<div className="text-xs font-semibold text-gray-500 uppercase tracking-wide mb-3 pt-3 px-2">
|
||||
Available Models
|
||||
</div>
|
||||
{ollamaModels.map((model, index) => (
|
||||
<SelectItem
|
||||
key={index}
|
||||
value={model.value}
|
||||
className="relative cursor-pointer rounded-md py-3 hover:bg-gray-50 focus:bg-gray-50 focus:outline-none transition-colors"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={model.icon}
|
||||
alt={`${model.label} icon`}
|
||||
className=" rounded-sm"
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="p-0" align="start" style={{ width: 'var(--radix-popover-trigger-width)' }}>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search model..." />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{ollamaModels.map((model, index) => (
|
||||
<CommandItem
|
||||
key={index}
|
||||
value={model.value}
|
||||
onSelect={(value) => {
|
||||
setLlmConfig({ ...llmConfig, MODEL: value });
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
llmConfig.MODEL === model.value ? "opacity-100" : "opacity-0"
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900 capitalize">
|
||||
{model.label}
|
||||
</span>
|
||||
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
|
||||
{model.size}
|
||||
</span>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
|
||||
<img
|
||||
src={model.icon}
|
||||
alt={`${model.label} icon`}
|
||||
className="rounded-sm"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900 capitalize">
|
||||
{model.label}
|
||||
</span>
|
||||
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
|
||||
{model.size}
|
||||
</span>
|
||||
</div>
|
||||
<span className="text-xs text-gray-600 leading-relaxed">
|
||||
{model.description}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-xs text-gray-600 leading-relaxed">
|
||||
{model.description}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</div>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
) : (
|
||||
<div className="w-full border border-gray-300 rounded-lg p-4">
|
||||
<div className="flex items-center space-x-3">
|
||||
|
|
@ -437,6 +461,59 @@ export default function Home() {
|
|||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Use custom Ollama URL
|
||||
</label>
|
||||
<Switch
|
||||
checked={useCustomOllamaUrl}
|
||||
onCheckedChange={setUseCustomOllamaUrl}
|
||||
/>
|
||||
</div>
|
||||
{useCustomOllamaUrl && (
|
||||
<>
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Ollama URL
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
required
|
||||
placeholder="Enter your Ollama URL"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={llmConfig.LLM_PROVIDER_URL || ''}
|
||||
onChange={(e) => api_key_changed(e.target.value, 'ollama_url')}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Change this if you are using a custom Ollama instance
|
||||
</p>
|
||||
</div>
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Ollama API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
required
|
||||
placeholder="Enter your Ollama API key"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={llmConfig.LLM_API_KEY || ''}
|
||||
onChange={(e) => api_key_changed(e.target.value, 'ollama_api_key')}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Provide this if you are using a custom Ollama instance
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Pexels API Key (required for images)
|
||||
|
|
@ -448,7 +525,7 @@ export default function Home() {
|
|||
placeholder="Enter your Pexels API key"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={llmConfig.PEXELS_API_KEY || ''}
|
||||
onChange={(e) => api_key_changed(e.target.value)}
|
||||
onChange={(e) => api_key_changed(e.target.value, 'pexels')}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
|
|
@ -468,7 +545,7 @@ export default function Home() {
|
|||
Selected Models
|
||||
</h3>
|
||||
<p className="text-sm text-blue-700">
|
||||
Using {llmConfig.LLM === 'ollama' ? llmConfig.OLLAMA_MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text
|
||||
Using {llmConfig.LLM === 'ollama' ? llmConfig.MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text
|
||||
generation and {PROVIDER_CONFIGS[llmConfig.LLM!].imageModels[0].label} for
|
||||
images
|
||||
</p>
|
||||
|
|
@ -544,7 +621,7 @@ export default function Home() {
|
|||
}
|
||||
</div>
|
||||
) : (
|
||||
llmConfig.LLM === 'ollama' && !llmConfig.OLLAMA_MODEL
|
||||
llmConfig.LLM === 'ollama' && !llmConfig.MODEL
|
||||
? 'Please Select a Model'
|
||||
: 'Save Configuration'
|
||||
)}
|
||||
|
|
|
|||
4
servers/nextjs/types/global.d.ts
vendored
4
servers/nextjs/types/global.d.ts
vendored
|
|
@ -18,5 +18,7 @@ interface LLMConfig {
|
|||
OPENAI_API_KEY?: string;
|
||||
GOOGLE_API_KEY?: string;
|
||||
PEXELS_API_KEY?: string;
|
||||
OLLAMA_MODEL?: string;
|
||||
LLM_PROVIDER_URL?: string;
|
||||
LLM_API_KEY?: string;
|
||||
MODEL?: string;
|
||||
}
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import { setLLMConfig } from "@/store/slices/userConfig";
|
||||
import { store } from "@/store/store";
|
||||
|
||||
export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => {
|
||||
if (!hasValidLLMConfig(llmConfig)) {
|
||||
export const handleSaveLLMConfig = async (llmConfig: LLMConfig, useCustomOllamaUrl: boolean) => {
|
||||
if (!hasValidLLMConfig(llmConfig, useCustomOllamaUrl)) {
|
||||
throw new Error('API key cannot be empty');
|
||||
}
|
||||
|
||||
|
|
@ -14,17 +14,21 @@ export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => {
|
|||
store.dispatch(setLLMConfig(llmConfig));
|
||||
}
|
||||
|
||||
export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
||||
export const hasValidLLMConfig = (llmConfig: LLMConfig, useCustomOllamaUrl: boolean) => {
|
||||
if (!llmConfig.LLM) return false;
|
||||
const OPENAI_API_KEY = llmConfig.OPENAI_API_KEY;
|
||||
const GOOGLE_API_KEY = llmConfig.GOOGLE_API_KEY;
|
||||
const OLLAMA_MODEL = llmConfig.OLLAMA_MODEL;
|
||||
const MODEL = llmConfig.MODEL;
|
||||
const PEXELS_API_KEY = llmConfig.PEXELS_API_KEY;
|
||||
|
||||
const isOllamaBaseConfigValid = PEXELS_API_KEY !== '' && PEXELS_API_KEY !== null && PEXELS_API_KEY !== undefined && MODEL !== '' && MODEL !== null && MODEL !== undefined;
|
||||
|
||||
return llmConfig.LLM === 'openai' ?
|
||||
OPENAI_API_KEY !== '' && OPENAI_API_KEY !== null && OPENAI_API_KEY !== undefined :
|
||||
llmConfig.LLM === 'google' ?
|
||||
GOOGLE_API_KEY !== '' && GOOGLE_API_KEY !== null && GOOGLE_API_KEY !== undefined :
|
||||
llmConfig.LLM === 'ollama' ?
|
||||
PEXELS_API_KEY !== '' && PEXELS_API_KEY !== null && PEXELS_API_KEY !== undefined && OLLAMA_MODEL !== '' && OLLAMA_MODEL !== null && OLLAMA_MODEL !== undefined :
|
||||
false;
|
||||
useCustomOllamaUrl ?
|
||||
isOllamaBaseConfigValid && llmConfig.LLM_PROVIDER_URL !== '' && llmConfig.LLM_PROVIDER_URL !== null && llmConfig.LLM_PROVIDER_URL !== undefined && llmConfig.LLM_API_KEY !== '' && llmConfig.LLM_API_KEY !== null && llmConfig.LLM_API_KEY !== undefined :
|
||||
isOllamaBaseConfigValid : false;
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue