Merge pull request #52 from presenton/removes_langchain

feat(ollama): adds support for custom ollama url, refactor: removes langchain
This commit is contained in:
Saurav Niraula 2025-06-30 00:03:50 +05:45 committed by GitHub
commit 2719ea4e3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
45 changed files with 584731 additions and 592377 deletions

View file

@ -7,4 +7,5 @@ build
.git
.gitignore
tmp
debug
debug
.fastembed_cache

3
.gitignore vendored
View file

@ -8,4 +8,5 @@ node_modules
out
user_data
tmp
debug
debug
.fastembed_cache

View file

@ -8,16 +8,17 @@ from contextlib import asynccontextmanager
from api.routers.presentation.router import presentation_router
from api.services.database import sql_engine
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from api.utils.utils import is_ollama_selected, update_env_with_user_config
from api.utils.utils import update_env_with_user_config
from api.utils.model_utils import is_ollama_selected
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
# Ollama model download
if not can_change_keys and is_ollama_selected():
ollama_model = os.getenv("OLLAMA_MODEL")
ollama_model = os.getenv("MODEL")
pexels_api_key = os.getenv("PEXELS_API_KEY")
if not (ollama_model or pexels_api_key):
raise Exception("OLLAMA_MODEL and PEXELS_API_KEY must be provided")
raise Exception("MODEL and PEXELS_API_KEY must be provided")
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
raise Exception(f"Model {ollama_model} is not supported")

View file

@ -1,3 +1,4 @@
from enum import Enum
import json
from typing import Optional
from pydantic import BaseModel
@ -62,7 +63,9 @@ class UserConfig(BaseModel):
LLM: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
OLLAMA_MODEL: Optional[str] = None
MODEL: Optional[str] = None
LLM_PROVIDER_URL: Optional[str] = None
LLM_API_KEY: Optional[str] = None
PEXELS_API_KEY: Optional[str] = None
@ -73,3 +76,10 @@ class OllamaModelMetadata(BaseModel):
icon: str
size: str
supports_graph: bool
class SelectedLLMProvider(Enum):
OLLAMA = "ollama"
OPENAI = "openai"
GOOGLE = "google"
CUSTOM = "custom"

View file

@ -1,5 +1,3 @@
import asyncio
from typing import List
import uuid
from api.models import LogMetadata
from api.routers.presentation.models import (
@ -37,7 +35,7 @@ class DecomposeDocumentsHandler:
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
f"{str(uuid.uuid4())}.txt", self.temp_dir
)
parsed_doc = parsed_doc.page_content.replace("<br>", "\n")
parsed_doc = parsed_doc.replace("<br>", "\n")
with open(file_path, "w") as text_file:
text_file.write(parsed_doc)
document_paths.append(file_path)

View file

@ -14,8 +14,8 @@ from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from api.utils.utils import (
get_presentation_dir,
get_presentation_images_dir,
is_ollama_selected,
)
from api.utils.model_utils import is_ollama_selected
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
from image_processor.images_finder import generate_image
from image_processor.icons_finder import get_icon
@ -69,7 +69,7 @@ class PresentationEditHandler:
new_slide_type = new_slide_type.slide_type
if is_ollama_selected():
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")]
if not model.supports_graph:
if new_slide_type == 5:
new_slide_type = 1

View file

@ -11,7 +11,7 @@ from api.routers.presentation.models import PresentationGenerateRequest
from api.services.logging import LoggingService
from api.sql_models import KeyValueSqlModel, PresentationSqlModel
from api.services.database import get_sql_session
from api.utils.utils import is_ollama_selected
from api.utils.model_utils import is_ollama_selected
from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel
from ppt_config_generator.structure_generator import generate_presentation_structure
@ -54,7 +54,7 @@ class PresentationGenerateDataHandler:
)
)
supports_graph = True
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")]
supports_graph = model.supports_graph
for each in presentation_structure.slides:

View file

@ -1,5 +1,4 @@
import uuid
import re
from api.models import LogMetadata
from api.routers.presentation.models import GenerateOutlinesRequest

View file

@ -1,3 +1,4 @@
import json
from typing import List
import uuid, aiohttp
from fastapi import HTTPException
@ -17,7 +18,8 @@ from api.services.database import get_sql_session
from api.services.instances import TEMP_FILE_SERVICE
from api.services.logging import LoggingService
from api.sql_models import PresentationSqlModel, SlideSqlModel
from api.utils.utils import get_presentation_dir, is_ollama_selected
from api.utils.utils import get_presentation_dir
from api.utils.model_utils import is_ollama_selected
from document_processor.loader import DocumentsLoader
from ppt_config_generator.document_summary_generator import generate_document_summary
from ppt_config_generator.models import PresentationMarkdownModel
@ -25,14 +27,9 @@ from ppt_config_generator.ppt_outlines_generator import generate_ppt_content
from ppt_generator.generator import generate_presentation
from ppt_generator.models.llm_models import (
LLM_CONTENT_TYPE_MAPPING,
LLMPresentationModel,
)
from langchain_core.output_parsers import JsonOutputParser
from ppt_generator.models.slide_model import SlideModel
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
@ -79,19 +76,17 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
print("-" * 40)
print("Generating Presentation")
presentation_text = (
await generate_presentation(
PresentationMarkdownModel(
title=presentation_content.title,
slides=presentation_content.slides,
notes=presentation_content.notes,
)
presentation_text = await generate_presentation(
PresentationMarkdownModel(
title=presentation_content.title,
slides=presentation_content.slides,
notes=presentation_content.notes,
)
).content
)
print("-" * 40)
print("Parsing Presentation")
presentation_json = output_parser.parse(presentation_text)
presentation_json = json.loads(presentation_text)
slide_models: List[SlideModel] = []
for i, slide in enumerate(presentation_json["slides"]):

View file

@ -17,13 +17,13 @@ from api.routers.presentation.models import (
from api.services.database import get_sql_session
from api.services.logging import LoggingService
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
from api.utils.utils import get_presentation_dir, is_ollama_selected
from api.utils.utils import get_presentation_dir
from api.utils.model_utils import is_ollama_selected
from ppt_config_generator.models import (
PresentationMarkdownModel,
PresentationStructureModel,
)
from ppt_generator.generator import generate_presentation_stream
from ppt_generator.models.content_type_models import CONTENT_TYPE_MAPPING
from ppt_generator.models.llm_models import (
LLM_CONTENT_TYPE_MAPPING,
LLMPresentationModel,
@ -31,12 +31,9 @@ from ppt_generator.models.llm_models import (
)
from ppt_generator.models.slide_model import SlideModel
from api.services.instances import TEMP_FILE_SERVICE
from langchain_core.output_parsers import JsonOutputParser
from ppt_generator.slide_generator import get_slide_content_from_type_and_outline
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin):
@ -144,20 +141,21 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin
async def generate_presentation_openai_google(self):
presentation_text = ""
async for chunk in generate_presentation_stream(
async for event in await generate_presentation_stream(
PresentationMarkdownModel(
title=self.title,
slides=self.outlines,
notes=self.presentation.notes,
)
):
presentation_text += chunk.content
chunk = event.choices[0].delta.content
presentation_text += chunk
yield SSEResponse(
event="response",
data=json.dumps({"type": "chunk", "chunk": chunk.content}),
data=json.dumps({"type": "chunk", "chunk": chunk}),
).to_string()
self.presentation_json = output_parser.parse(presentation_text)
self.presentation_json = json.loads(presentation_text)
async def generate_presentation_ollama(self):
presentation_structure = PresentationStructureModel(

View file

@ -1,7 +1,9 @@
import ollama
import os
import aiohttp
from api.models import LogMetadata
from api.routers.presentation.models import OllamaModelStatusResponse
from api.services.logging import LoggingService
from api.utils.model_utils import get_llm_api_key_or, get_llm_provider_url_or
class ListPulledOllamaModelsHandler:
@ -12,20 +14,25 @@ class ListPulledOllamaModelsHandler:
extra=log_metadata.model_dump(),
)
response = ollama.list()
async with aiohttp.ClientSession() as session:
async with session.get(
f"{get_llm_provider_url_or()}/api/tags",
headers={"Authorization": f"Bearer {get_llm_api_key_or()}"},
) as response:
response_data = await response.json()
logging_service.logger.info(
logging_service.message(response.model_dump(mode="json")),
logging_service.message(response_data),
extra=log_metadata.model_dump(),
)
return [
OllamaModelStatusResponse(
name=model.model,
size=model.size,
name=model["model"],
size=model["size"],
status="pulled",
downloaded=model.size,
downloaded=model["size"],
done=True,
)
for model in response.models
for model in response_data["models"]
]

View file

@ -1,5 +1,5 @@
import asyncio
import json
import aiohttp
from fastapi import BackgroundTasks, HTTPException
from api.models import LogMetadata
from api.routers.presentation.handlers.list_supported_ollama_models import (
@ -8,7 +8,7 @@ from api.routers.presentation.handlers.list_supported_ollama_models import (
from api.routers.presentation.models import OllamaModelStatusResponse
from api.services.instances import REDIS_SERVICE
from api.services.logging import LoggingService
import ollama
from api.utils.model_utils import get_llm_api_key_or, get_llm_provider_url_or
class PullOllamaModelHandler:
@ -33,19 +33,34 @@ class PullOllamaModelHandler:
detail=f"Model {self.name} is not supported",
)
pulled_models = ollama.list().models
filtered_models = list(
filter(lambda model: model.model == self.name, pulled_models)
)
# Check if model is already pulled using LLM_PROVIDER_URL/api/tags
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{get_llm_provider_url_or()}/api/tags",
headers={"Authorization": f"Bearer {get_llm_api_key_or()}"},
) as response:
if response.status == 200:
pulled_models = await response.json()
filtered_models = [
model
for model in pulled_models["models"]
if model["model"] == self.name
]
# If the model is already pulled, return the model
if filtered_models:
return OllamaModelStatusResponse(
name=self.name,
size=filtered_models[0].size,
status="pulled",
downloaded=filtered_models[0].size,
done=True,
# If the model is already pulled, return the model
if filtered_models:
return OllamaModelStatusResponse(
name=self.name,
size=filtered_models[0]["size"],
status="pulled",
downloaded=filtered_models[0]["size"],
done=True,
)
except Exception as e:
logging_service.logger.warning(
f"Failed to check pulled models: {e}",
extra=log_metadata.model_dump(),
)
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{self.name}")
@ -64,33 +79,64 @@ class PullOllamaModelHandler:
)
async def pull_model_in_background(self):
await asyncio.to_thread(self.pull_model)
await self.pull_model()
def pull_model(self):
async def pull_model(self):
saved_model_status = OllamaModelStatusResponse(
name=self.name,
status="pulling",
done=False,
)
log_event_count = 0
for event in ollama.pull(self.name, stream=True):
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if event.completed:
saved_model_status.downloaded = event.completed
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{get_llm_provider_url_or()}/api/pull",
json={"model": self.name},
headers={"Authorization": f"Bearer {get_llm_api_key_or()}"},
) as response:
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Failed to pull model: {await response.text()}",
)
if not saved_model_status.size and event.total:
saved_model_status.size = event.total
async for line in response.content:
if not line.strip():
continue
if event.status:
saved_model_status.status = event.status
try:
event = json.loads(line.decode("utf-8"))
except json.JSONDecodeError:
continue
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if "status" in event:
saved_model_status.status = event["status"]
REDIS_SERVICE.set(
f"ollama_models/{self.name}",
json.dumps(saved_model_status.model_dump(mode="json")),
)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
REDIS_SERVICE.set(
f"ollama_models/{self.name}",
json.dumps(saved_model_status.model_dump(mode="json")),
)
raise e
saved_model_status.done = True
saved_model_status.status = "pulled"

View file

@ -82,6 +82,9 @@ from api.routers.presentation.models import (
)
from api.sql_models import PresentationSqlModel
from api.utils.utils import handle_errors
from image_processor.images_finder import (
generate_image_google,
)
from ppt_generator.models.slide_model import SlideModel
route_prefix = "/api/v1/ppt"

View file

@ -0,0 +1,84 @@
import os
from openai import AsyncOpenAI
from api.models import SelectedLLMProvider
def is_ollama_selected() -> bool:
return get_selected_llm_provider() == SelectedLLMProvider.OLLAMA
def get_llm_provider_url_or():
return os.getenv("LLM_PROVIDER_URL") or "http://localhost:11434"
def get_llm_api_key_or():
return os.getenv("LLM_API_KEY") or "ollama"
def get_selected_llm_provider() -> SelectedLLMProvider:
return SelectedLLMProvider(os.getenv("LLM"))
def get_model_base_url():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return "https://api.openai.com/v1"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "https://generativelanguage.googleapis.com/v1beta/openai"
elif selected_llm == SelectedLLMProvider.OLLAMA:
return os.path.join(get_llm_provider_url_or(), "v1")
else:
raise ValueError(f"Invalid LLM provider")
def get_llm_api_key():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return os.getenv("OPENAI_API_KEY")
elif selected_llm == SelectedLLMProvider.GOOGLE:
return os.getenv("GOOGLE_API_KEY")
elif selected_llm == SelectedLLMProvider.OLLAMA:
return get_llm_api_key_or()
else:
raise ValueError(f"Invalid LLM API key")
def get_llm_client():
client = AsyncOpenAI(
base_url=get_model_base_url(),
api_key=get_llm_api_key(),
)
return client
def get_large_model():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return "gpt-4.1"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "gemini-2.0-flash"
else:
return os.getenv("MODEL")
def get_small_model():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return "gpt-4.1-mini"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "gemini-2.0-flash"
else:
return os.getenv("MODEL")
def get_nano_model():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return "gpt-4.1-nano"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "gemini-2.0-flash"
else:
return os.getenv("MODEL")

View file

@ -9,48 +9,11 @@ from typing import List, Optional
import aiohttp
from fastapi import HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from api.models import LogMetadata, UserConfig
from api.services.logging import LoggingService
def is_ollama_selected() -> bool:
return os.getenv("LLM") == "ollama"
def get_large_model():
selected_llm = os.getenv("LLM")
if selected_llm == "openai":
return ChatOpenAI(model="gpt-4.1")
elif selected_llm == "google":
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
else:
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
def get_small_model():
selected_llm = os.getenv("LLM")
if selected_llm == "openai":
return ChatOpenAI(model="gpt-4.1-mini")
elif selected_llm == "google":
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
else:
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
def get_nano_model():
selected_llm = os.getenv("LLM")
if selected_llm == "openai":
return ChatOpenAI(model="gpt-4.1-nano")
elif selected_llm == "google":
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
else:
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
def get_presentation_dir(presentation_id: str) -> str:
presentation_dir = os.path.join(os.getenv("APP_DATA_DIRECTORY"), presentation_id)
os.makedirs(presentation_dir, exist_ok=True)
@ -81,8 +44,11 @@ def get_user_config():
LLM=existing_config.LLM or os.getenv("LLM"),
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY"),
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY"),
OLLAMA_MODEL=existing_config.OLLAMA_MODEL or os.getenv("OLLAMA_MODEL"),
MODEL=existing_config.MODEL or os.getenv("MODEL"),
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or os.getenv("PEXELS_API_KEY"),
LLM_PROVIDER_URL=existing_config.LLM_PROVIDER_URL
or os.getenv("LLM_PROVIDER_URL"),
LLM_API_KEY=existing_config.LLM_API_KEY or os.getenv("LLM_API_KEY"),
)
@ -94,10 +60,14 @@ def update_env_with_user_config():
os.environ["OPENAI_API_KEY"] = user_config.OPENAI_API_KEY
if user_config.GOOGLE_API_KEY:
os.environ["GOOGLE_API_KEY"] = user_config.GOOGLE_API_KEY
if user_config.OLLAMA_MODEL:
os.environ["OLLAMA_MODEL"] = user_config.OLLAMA_MODEL
if user_config.MODEL:
os.environ["MODEL"] = user_config.MODEL
if user_config.PEXELS_API_KEY:
os.environ["PEXELS_API_KEY"] = user_config.PEXELS_API_KEY
if user_config.LLM_PROVIDER_URL:
os.environ["LLM_PROVIDER_URL"] = user_config.LLM_PROVIDER_URL
if user_config.LLM_API_KEY:
os.environ["LLM_API_KEY"] = user_config.LLM_API_KEY
def get_resource(relative_path):

File diff suppressed because it is too large Load diff

View file

@ -1,11 +1,10 @@
import asyncio
import mimetypes
import os
from typing import List, Tuple
from fastapi import HTTPException
from langchain_community.document_loaders import TextLoader, PDFPlumberLoader
from langchain_core.documents import Document
from langchain_text_splitters import CharacterTextSplitter, MarkdownTextSplitter
from pptx import Presentation
import pdfplumber
from docx import Document as DocxDocument
from image_processor.utils import get_page_images_from_pdf_async
@ -30,23 +29,13 @@ class DocumentsLoader:
def __init__(self, documents: List[str]):
self._document_paths = documents
self._documents: List[Document] = []
self._splitted_documents: List[Document] = []
self._documents: List[str] = []
self._images: List[List[str]] = []
self._markdown_splitter = MarkdownTextSplitter(chunk_size=500, chunk_overlap=50)
self._text_splitter = CharacterTextSplitter(
separator="/n", chunk_size=500, chunk_overlap=50
)
@property
def documents(self):
return self._documents
@property
def splitted_documents(self):
return self._splitted_documents
@property
def images(self):
return self._images
@ -54,90 +43,69 @@ class DocumentsLoader:
async def load_documents(
self,
temp_dir: str,
split_documents: bool = False,
load_markdown: bool = True,
load_text: bool = True,
load_images: bool = False,
):
documents: List[Document] = []
documents: List[str] = []
images: List[str] = []
splitted_documents: List[Document] = []
for file_path in self._document_paths:
if not os.path.exists(file_path):
raise HTTPException(
status_code=404, detail=f"File {file_path} not found"
)
docs = []
document = ""
imgs = []
mime_type = mimetypes.guess_type(file_path)[0]
if mime_type in PDF_MIME_TYPES:
docs, imgs = await self.load_pdf(
file_path, load_markdown, load_images, temp_dir
document, imgs = await self.load_pdf(
file_path, load_text, load_images, temp_dir
)
elif mime_type in TEXT_MIME_TYPES:
docs = self.load_text(file_path)
document = await self.load_text(file_path)
elif mime_type in POWERPOINT_TYPES:
docs = self.load_powerpoint(file_path)
document = self.load_powerpoint(file_path)
elif mime_type in WORD_TYPES:
docs = self.load_msword(file_path)
document = self.load_msword(file_path)
documents.extend(docs)
documents.append(document)
images.append(imgs)
if split_documents:
splitted_documents.extend(self.split_documents(docs, mime_type))
self._documents = documents
self._splitted_documents = splitted_documents
self._images = images
def split_documents(self, documents: List[Document], mime_type):
return self._text_splitter.split_documents(documents)
def clip_longer_documents(self, documents: List[Document], clip_after: int = 1200):
for document in documents:
document.page_content = document.page_content[:clip_after]
return documents
async def load_pdf(
self,
file_path: str,
load_markdown: bool,
load_text: bool,
load_images: bool,
temp_dir: str,
) -> Tuple[List[Document], List[str]]:
) -> Tuple[str, List[str]]:
image_paths = []
documents: List[Document] = []
document: str = ""
if load_markdown:
loader = PDFPlumberLoader(file_path)
docs = loader.load()
pdf_document = Document(page_content="")
pdf_document.metadata = docs[0].metadata
for doc in docs:
pdf_document.page_content += doc.page_content
documents.append(pdf_document)
if load_text:
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
document += await asyncio.to_thread(page.extract_text)
if load_images:
image_paths = await get_page_images_from_pdf_async(file_path, temp_dir)
return documents, image_paths
return document, image_paths
async def decompose_pdf_to_markdown(self, document_path: str) -> str:
raise Exception("Not Implemented")
async def load_text(self, file_path: str) -> str:
with open(file_path, "r") as file:
return await asyncio.to_thread(file.read)
def load_text(self, file_path: str) -> List[Document]:
loader = TextLoader(file_path)
return loader.load()
def load_msword(self, file_path: str) -> List[Document]:
def load_msword(self, file_path: str) -> str:
document = DocxDocument(file_path)
text = "\n".join([paragraph.text for paragraph in document.paragraphs])
return [Document(page_content=text)]
return text
def load_powerpoint(self, file_path: str) -> List[Document]:
def load_powerpoint(self, file_path: str) -> str:
presentation = Presentation(file_path)
extracted_text = ""
@ -149,4 +117,4 @@ class DocumentsLoader:
extracted_text += f"{paragraph.text}\n"
extracted_text += "\n"
extracted_text += "\n\n"
return [Document(page_content=extracted_text)]
return extracted_text

View file

@ -1,132 +0,0 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
from graph_processor.utils import clip_text
class PointModel(BaseModel):
x: float
y: float
def to_list(self) -> List[float]:
return [self.x, self.y]
class PointWithRadius(PointModel):
radius: Optional[float] = None
class BarSeriesModel(BaseModel):
name: str
data: List[float] = Field(
description="Only numbers should be given out in data. Don't include text/string in data."
)
def get_name(self) -> str:
return clip_text(self.name)
class ScatterSeriesModel(BaseModel):
name: str
points: List[PointModel]
def get_name(self) -> str:
return clip_text(self.name)
class BubbleSeriesModel(BaseModel):
name: str
points: List[PointWithRadius]
def get_name(self) -> str:
return clip_text(self.name)
class LineSeriesModel(BaseModel):
name: str
data: List[float] = Field(
description="Only numbers should be given out in data. Don't include text/string in data."
)
def get_name(self) -> str:
return clip_text(self.name)
class PieChartSeriesModel(BaseModel):
data: List[float]
class BarGraphDataModel(BaseModel):
categories: List[str]
series: List[BarSeriesModel] = Field(
description="There should be no more than 3 series"
)
def get_categories(self) -> List[str]:
return [clip_text(category) for category in self.categories]
class ScatterChartDataModel(BaseModel):
series: List[ScatterSeriesModel]
class BubbleChartDataModel(BaseModel):
series: List[BubbleSeriesModel]
class LineChartDataModel(BaseModel):
categories: List[str]
series: List[LineSeriesModel] = Field(
description="There should be no more than 3 series"
)
def get_categories(self) -> List[str]:
return [clip_text(category) for category in self.categories]
class PieChartDataModel(BaseModel):
categories: List[str]
series: List[PieChartSeriesModel] = Field(
description="One series model with list of data",
min_length=1,
)
@model_validator(mode="after")
def limit_series(self):
self.series = self.series[:1]
return self
def get_categories(self) -> List[str]:
return [clip_text(category) for category in self.categories]
# class TableDataModel(BaseModel):
# categories: List[str]
# series: List[BarSeriesModel]
# def get_categories(self) -> List[str]:
# return [clip_text(category) for category in self.categories]
class GraphTypeEnum(Enum):
pie = "pie"
bar = "bar"
line = "line"
class GraphModel(BaseModel):
style: Optional[dict] = {}
name: str
type: GraphTypeEnum
unit: Optional[str] = Field(
description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc."
)
data: PieChartDataModel | LineChartDataModel | BarGraphDataModel
GRAPH_TYPE_MAPPING = {
GraphTypeEnum.pie: PieChartDataModel,
GraphTypeEnum.bar: BarGraphDataModel,
GraphTypeEnum.line: LineChartDataModel,
}

View file

@ -1,3 +0,0 @@
def clip_text(text: str, max_length: int = 6) -> str:
# return text[:max_length] + ".." if len(text) > max_length else text
return text

View file

@ -5,17 +5,17 @@ from ppt_generator.models.query_and_prompt_models import (
IconCategoryEnum,
IconQueryCollectionWithData,
)
from langchain_core.vectorstores import InMemoryVectorStore
from fastembed_vectorstore import FastembedVectorstore
async def get_icon(
vector_store: InMemoryVectorStore,
vector_store: FastembedVectorstore,
input: IconQueryCollectionWithData,
) -> str:
try:
query = input.icon_query
results = vector_store.similarity_search(query=query, k=1)
icon_name = results[0].page_content
results = vector_store.search(query, 1)
icon_name = results[0][0].split("||")[0]
return get_resource(f"assets/icons/bold/{icon_name}.png")
except Exception as e:
print("Error finding icon: ", e)
@ -23,7 +23,7 @@ async def get_icon(
async def get_icons(
vector_store: InMemoryVectorStore,
vector_store: FastembedVectorstore,
query: str,
page: int,
limit: int,
@ -31,7 +31,7 @@ async def get_icons(
temp_dir: str,
) -> List[str]:
results = await vector_store.asimilarity_search(query=query, k=limit)
icon_names = [result.page_content for result in results]
results = vector_store.search(query, limit)
icon_names = [result[0].split("||")[0] for result in results]
return [get_resource(f"assets/icons/bold/{each}.png") for each in icon_names]

View file

@ -1,37 +1,26 @@
import json
import os
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from api.utils.utils import get_resource
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
# Pyinstaller
import fastembed
from fastembed_vectorstore import FastembedVectorstore, FastembedEmbeddingModel
def get_icons_vectorstore():
vector_store_path = get_resource("assets/icons_vectorstore.json")
embeddings = FastEmbedEmbeddings()
embedding_model = FastembedEmbeddingModel.BGESmallENV15
if os.path.exists(vector_store_path):
vector_store = InMemoryVectorStore.load(vector_store_path, embeddings)
return vector_store
vector_store = InMemoryVectorStore(embeddings)
vector_store.dump(vector_store_path)
return FastembedVectorstore.load(embedding_model, vector_store_path)
vector_store = FastembedVectorstore(embedding_model)
with open(get_resource("assets/icons.json"), "r") as f:
icons = json.load(f)
icon_names = [icon["name"] for icon in icons["icons"]]
documents = []
for each in icon_names:
if each.split("-")[-1] == "bold":
documents.append(Document(id=each, page_content=each))
for each in icons["icons"]:
if each["name"].split("-")[-1] == "bold":
documents.append(f"{each['name']}||{each['tags']}")
vector_store.embed_documents(documents)
vector_store.save(vector_store_path)
vector_store.add_documents(documents)
vector_store.dump(vector_store_path)
return vector_store

View file

@ -3,13 +3,14 @@ import base64
import os
import uuid
import aiohttp
from langchain_google_genai import ChatGoogleGenerativeAI
from openai import OpenAI
from google import genai
from google.genai.types import GenerateContentConfig
from ppt_generator.models.query_and_prompt_models import (
ImagePromptWithThemeAndAspectRatio,
)
from api.utils.utils import download_file, get_resource, is_ollama_selected
from api.utils.utils import download_file, get_resource
from api.utils.model_utils import get_llm_client, is_ollama_selected
async def generate_image(
@ -46,7 +47,7 @@ async def generate_image(
async def generate_image_openai(prompt: str, output_directory: str) -> str:
client = OpenAI()
client = get_llm_client()
result = await asyncio.to_thread(
client.images.generate,
model="dall-e-3",
@ -66,20 +67,20 @@ async def generate_image_openai(prompt: str, output_directory: str) -> str:
async def generate_image_google(prompt: str, output_directory: str) -> str:
response = await ChatGoogleGenerativeAI(
model="gemini-2.0-flash-preview-image-generation"
).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]})
image_block = next(
block
for block in response.content
if isinstance(block, dict) and block.get("image_url")
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.0-flash-preview-image-generation",
contents=[prompt],
config=GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]),
)
base64_image = image_block["image_url"].get("url").split(",")[-1]
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
with open(image_path, "wb") as f:
f.write(base64.b64decode(base64_image))
for part in response.candidates[0].content.parts:
if part.text is not None:
print(part.text)
elif part.inline_data is not None:
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
with open(image_path, "wb") as f:
f.write(part.inline_data.data)
return image_path

View file

@ -1,14 +1,8 @@
import asyncio
import os
from typing import List
from langchain_core.documents import Document
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import BaseMessage
from langchain_text_splitters import CharacterTextSplitter
from openai.types.chat.chat_completion import ChatCompletion
from api.utils.utils import get_nano_model
from api.utils.model_utils import get_llm_client, get_nano_model
sysmte_prompt = """
Generate a blog-style summary of the provided document in **more than 2000 words**.
@ -26,26 +20,25 @@ Maintain as much information as possible.
- If **slides structure is mentioned** in document, structure the summary in the same way.
"""
prompt_template = ChatPromptTemplate.from_messages(
[
("system", sysmte_prompt),
("user", "{text}"),
]
)
async def generate_document_summary(documents: List[Document]):
async def generate_document_summary(documents: List[str]):
client = get_llm_client()
model = get_nano_model()
text_splitter = CharacterTextSplitter(chunk_size=200000, chunk_overlap=0)
chain = prompt_template | model
coroutines = []
for document in documents:
text = document.page_content
truncated_text = text_splitter.split_text(text)[0]
coroutine = chain.ainvoke({"text": truncated_text})
truncated_text = document[:200000]
coroutine = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": sysmte_prompt},
{"role": "user", "content": truncated_text},
],
)
coroutines.append(coroutine)
completions: List[BaseMessage] = await asyncio.gather(*coroutines)
combined = "\n\n\n\n".join([completion.content for completion in completions])
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
combined = "\n\n\n\n".join(
[completion.choices[0].message.content for completion in completions]
)
return combined

View file

@ -1,32 +1,17 @@
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from api.utils.utils import get_large_model
from api.utils.model_utils import get_large_model, get_llm_client
from api.utils.variable_length_models import (
get_presentation_markdown_model_with_n_slides,
)
from ppt_config_generator.models import PresentationMarkdownModel
from ppt_generator.fix_validation_errors import get_validated_response
user_prompt_text = {
"type": "text",
"text": """
**Input:**
- Prompt: {prompt}
- Output Language: {language}
- Number of Slides: {n_slides}
- Additional Information: {content}
""",
}
def get_prompt_template():
return ChatPromptTemplate.from_messages(
[
(
"system",
"""
def get_prompt_template(prompt: str, n_slides: int, language: str, content: str):
return [
{
"role": "system",
"content": """
Create a presentation based on the provided prompt, number of slides, output language, and additional informational details.
Format the output in the specified JSON schema with structured markdown content.
@ -49,13 +34,18 @@ def get_prompt_template():
- Slide **title** should not be in markdown format.
- There must be exact **Number of Slides** as specified.
""",
),
(
"user",
[user_prompt_text],
),
],
)
},
{
"role": "user",
"content": f"""
**Input:**
- Prompt: {prompt}
- Output Language: {language}
- Number of Slides: {n_slides}
- Additional Information: {content}
""",
},
]
async def generate_ppt_content(
@ -64,21 +54,14 @@ async def generate_ppt_content(
language: Optional[str] = None,
content: Optional[str] = None,
) -> PresentationMarkdownModel:
client = get_llm_client()
model = get_large_model()
response_model = get_presentation_markdown_model_with_n_slides(n_slides)
chain = get_prompt_template() | model.with_structured_output(
response_model.model_json_schema()
)
return await get_validated_response(
chain,
{
"prompt": prompt,
"n_slides": n_slides,
"language": language or "English",
"content": content,
},
response_model,
PresentationMarkdownModel,
response = await client.beta.chat.completions.parse(
model=model,
temperature=0.2,
messages=get_prompt_template(prompt, n_slides, language, content),
response_format=response_model,
)
return response.choices[0].message.parsed

View file

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate
from api.utils.utils import get_small_model
from api.utils.model_utils import get_llm_client, get_small_model
from api.utils.variable_length_models import (
get_presentation_structure_model_with_n_slides,
)
@ -8,13 +6,13 @@ from ppt_config_generator.models import (
PresentationStructureModel,
PresentationMarkdownModel,
)
from ppt_generator.fix_validation_errors import get_validated_response
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
def get_prompt(n_slides: int, data: str):
return [
{
"role": "system",
"content": f"""
You're a professional presentation designer with years of experience in designing clear and engaging presentations.
# Slide Types
@ -44,33 +42,32 @@ prompt = ChatPromptTemplate.from_messages(
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
""",
),
(
"human",
"""
{data}
},
{
"role": "user",
"content": f"""
{data}
""",
),
},
]
)
async def generate_presentation_structure(
presentation_outline: PresentationMarkdownModel,
) -> PresentationStructureModel:
client = get_llm_client()
model = get_small_model()
response_model = get_presentation_structure_model_with_n_slides(
len(presentation_outline.slides)
)
chain = prompt | model.with_structured_output(response_model.model_json_schema())
return await get_validated_response(
chain,
{
"n_slides": len(presentation_outline.slides),
"data": presentation_outline.to_string(),
},
response_model,
PresentationStructureModel,
response = await client.beta.chat.completions.parse(
model=model,
temperature=0.2,
messages=get_prompt(
len(presentation_outline.slides), presentation_outline.to_string()
),
response_format=response_model,
)
return response.choices[0].message.parsed

View file

@ -1,12 +1,11 @@
import os
from typing import Optional
from fastapi import HTTPException
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, ValidationError
from api.utils.utils import get_large_model
from api.utils.model_utils import get_large_model
def get_prompt_template():
@ -41,7 +40,7 @@ def get_prompt_template():
async def fix_validation_errors(response_model: BaseModel, response, errors):
model = get_large_model()
model = ChatOllama(model=get_large_model(), temperature=0.8)
chain = get_prompt_template() | model.with_structured_output(
response_model.model_json_schema()

View file

@ -1,11 +1,11 @@
from typing import AsyncIterator
from langchain_core.messages import (
HumanMessage,
AIMessageChunk,
AIMessage,
from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from api.models import SelectedLLMProvider
from api.utils.model_utils import (
get_large_model,
get_llm_client,
get_selected_llm_provider,
)
from api.utils.utils import get_large_model
from ppt_config_generator.models import PresentationMarkdownModel
from ppt_generator.models.llm_models_with_validations import (
LLMPresentationModelWithValidation,
@ -70,42 +70,87 @@ CREATE_PRESENTATION_PROMPT = """
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
"""
schema = LLMPresentationModelWithValidation.model_json_schema()
system_prompt = f"""
system_prompt_with_schema = f"""
{CREATE_PRESENTATION_PROMPT}
Follow this schema while giving out response: {schema}.
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
"""
ollama_system_prompt = f"""
{CREATE_PRESENTATION_PROMPT}
Follow this schema while giving out response: {LLMPresentationModelWithValidation.model_json_schema()}.
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
"""
def get_model_and_messages(
def get_system_prompt():
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
return (
system_prompt_with_schema if is_google_selected else CREATE_PRESENTATION_PROMPT
)
def get_response_format():
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
return (
{
"type": "json_object",
}
if is_google_selected
else {
"type": "json_schema",
"json_schema": {
"name": "LLMPresentationModel",
"schema": LLMPresentationModelWithValidation.model_json_schema(),
},
}
)
async def generate_presentation_stream(
presentation_outline: PresentationMarkdownModel,
):
user_message = HumanMessage(presentation_outline.to_string())
) -> AsyncStream[ChatCompletionChunk]:
client = get_llm_client()
model = get_large_model()
return model, system_prompt, user_message
response_format = get_response_format()
response = await client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": get_system_prompt(),
},
{
"role": "user",
"content": presentation_outline.to_string(),
},
],
response_format=response_format,
stream=True,
)
def generate_presentation_stream(
presentation_outline: PresentationMarkdownModel,
) -> AsyncIterator[AIMessageChunk]:
model, system_prompt, user_message = get_model_and_messages(presentation_outline)
return model.astream([system_prompt, user_message])
return response
async def generate_presentation(
presentation_outline: PresentationMarkdownModel,
) -> AIMessage:
model, system_prompt, user_message = get_model_and_messages(presentation_outline)
return await model.ainvoke([system_prompt, user_message])
) -> str:
client = get_llm_client()
model = get_large_model()
response_format = get_response_format()
response = await client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": get_system_prompt(),
},
{
"role": "user",
"content": presentation_outline.to_string(),
},
],
response_format=response_format,
)
return response.choices[0].message.content

View file

@ -1,4 +1,4 @@
from typing import List, Mapping
from typing import List, Mapping, Union
from pydantic import BaseModel
from ppt_generator.models.other_models import (
@ -12,7 +12,17 @@ from ppt_generator.models.other_models import (
TYPE8,
TYPE9,
)
from graph_processor.models import GraphModel
class TableDataModel(BaseModel):
x_labels: List[str]
y_labels: List[str]
data: List[List[float]]
class TableModel(BaseModel):
name: str
data: TableDataModel
class HeadingModel(BaseModel):
@ -47,9 +57,6 @@ class HeadingModel(BaseModel):
class SlideContentModel(BaseModel):
title: str
def to_llm_content(self):
raise NotImplementedError("to_llm_content method not implemented")
class Type1Content(SlideContentModel):
body: str
@ -110,7 +117,7 @@ class Type4Content(SlideContentModel):
class Type5Content(SlideContentModel):
body: str
graph: GraphModel
table: TableModel
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType5Content
@ -118,7 +125,7 @@ class Type5Content(SlideContentModel):
return LLMType5Content(
title=self.title,
body=self.body,
graph=self.graph,
table=self.table,
)
@ -174,7 +181,7 @@ class Type8Content(SlideContentModel):
class Type9Content(SlideContentModel):
body: List[HeadingModel]
graph: GraphModel
table: TableModel
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType9Content
@ -182,11 +189,23 @@ class Type9Content(SlideContentModel):
return LLMType9Content(
title=self.title,
body=[item.to_llm_content() for item in self.body],
graph=self.graph,
table=self.table,
)
CONTENT_TYPE_MAPPING: Mapping[int, SlideContentModel] = {
ContentUnion = Union[
Type1Content,
Type2Content,
Type3Content,
Type4Content,
Type5Content,
Type6Content,
Type7Content,
Type8Content,
Type9Content,
]
CONTENT_TYPE_MAPPING: Mapping[int, ContentUnion] = {
TYPE1: Type1Content,
TYPE2: Type2Content,
TYPE3: Type3Content,

View file

@ -1,10 +1,10 @@
from typing import List, Mapping
from typing import List, Mapping, Union
from pydantic import BaseModel
from graph_processor.models import GraphModel
from ppt_generator.models.content_type_models import (
HeadingModel,
SlideContentModel,
TableDataModel,
TableModel,
Type1Content,
Type2Content,
Type3Content,
@ -28,6 +28,17 @@ from ppt_generator.models.other_models import (
)
class LLMTableDataModel(TableDataModel):
x_labels: List[str]
y_labels: List[str]
data: List[List[float]]
class LLMTableModel(TableModel):
name: str
data: LLMTableDataModel
class LLMHeadingModel(BaseModel):
heading: str
description: str
@ -42,17 +53,26 @@ class LLMHeadingModel(BaseModel):
class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
image_prompt: str
def to_content(self) -> HeadingModel:
return HeadingModel(
heading=self.heading,
description=self.description,
)
class LLMHeadingModelWithIconQuery(LLMHeadingModel):
icon_query: str
def to_content(self) -> HeadingModel:
return HeadingModel(
heading=self.heading,
description=self.description,
)
class LLMSlideContentModel(BaseModel):
title: str
def to_content(self) -> SlideContentModel:
raise NotImplementedError("to_content method not implemented")
class LLMType1Content(LLMSlideContentModel):
body: str
@ -101,13 +121,13 @@ class LLMType4Content(LLMSlideContentModel):
class LLMType5Content(LLMSlideContentModel):
body: str
graph: GraphModel
table: LLMTableModel
def to_content(self) -> Type5Content:
return Type5Content(
title=self.title,
body=self.body,
graph=self.graph,
table=self.table,
)
@ -149,17 +169,29 @@ class LLMType8Content(LLMSlideContentModel):
class LLMType9Content(LLMSlideContentModel):
body: List[LLMHeadingModel]
graph: GraphModel
table: LLMTableModel
def to_content(self) -> Type9Content:
return Type9Content(
title=self.title,
body=[each.to_content() for each in self.body],
graph=self.graph,
table=self.table,
)
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
LLMContentUnion = Union[
LLMType1Content,
LLMType2Content,
LLMType3Content,
LLMType4Content,
LLMType5Content,
LLMType6Content,
LLMType7Content,
LLMType8Content,
LLMType9Content,
]
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMContentUnion] = {
TYPE1: LLMType1Content,
TYPE2: LLMType2Content,
TYPE3: LLMType3Content,
@ -174,17 +206,8 @@ LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
class LLMSlideModel(BaseModel):
type: int
content: (
LLMType1Content
| LLMType2Content
| LLMType4Content
| LLMType5Content
| LLMType6Content
| LLMType7Content
| LLMType8Content
| LLMType9Content
)
content: LLMContentUnion
class LLMPresentationModel(BaseModel):
slides: list[LLMSlideModel]
slides: List[LLMSlideModel]

View file

@ -1,7 +1,6 @@
from typing import List, Mapping
from typing import List, Mapping, Union
from pydantic import Field
from graph_processor.models import GraphModel
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
@ -14,6 +13,8 @@ from ppt_generator.models.other_models import (
TYPE9,
)
from ppt_generator.models.llm_models import (
LLMTableDataModel,
LLMTableModel,
LLMHeadingModel,
LLMHeadingModelWithImagePrompt,
LLMHeadingModelWithIconQuery,
@ -32,239 +33,179 @@ from ppt_generator.models.llm_models import (
)
class LLMTableDataModelWithValidation(LLMTableDataModel):
x_labels: List[str] = Field(
description="X labels of the table",
min_length=1,
max_length=5,
)
y_labels: List[str] = Field(
description="Y labels of the table",
min_length=1,
max_length=3,
)
data: List[List[float]] = Field(
description="Data of the table",
min_length=1,
max_length=5,
)
class LLMTableModelWithValidation(LLMTableModel):
name: str = Field(
description="Name of the table in about 8 words",
min_length=10,
max_length=50,
)
data: LLMTableDataModelWithValidation
class LLMHeadingModelWithValidation(LLMHeadingModel):
heading: str = Field(
description="List item heading to show in slide body",
description="Item heading in about 6 words",
min_length=10,
max_length=30,
max_length=40,
)
description: str = Field(
description="Description of list item in less than 20 words.",
min_length=80,
max_length=150,
description="Item description in about 12 words.",
min_length=50,
max_length=120,
)
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
image_prompt: str = Field(
description="Prompt used to generate image for this item",
description="Item image prompt in about 10 words",
min_length=10,
max_length=50,
max_length=100,
)
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
icon_query: str = Field(
description="Icon query to generate icon for this item",
description="Item icon query in about 4 words",
min_length=10,
max_length=50,
max_length=40,
)
class LLMSlideContentModelWithValidation(LLMSlideContentModel):
title: str = Field(
description="Slide title in about 8 words",
min_length=10,
max_length=80,
)
class LLMType1ContentWithValidation(LLMType1Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: str = Field(
description="Slide content summary in less than 30 words.",
min_length=100,
max_length=200,
description="Slide content summary in about 30 words.",
min_length=50,
max_length=300,
)
image_prompt: str = Field(
description="Prompt used to generate image for this slide.",
description="Slide image prompt in about 5 words",
min_length=10,
max_length=50,
max_length=30,
)
@classmethod
def get_notes(cls):
return ""
class LLMType2ContentWithValidation(LLMType2Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=1,
max_length=4,
)
@classmethod
def get_notes(cls):
return """
- The **Body** should include **1 to 4 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
class LLMType3ContentWithValidation(LLMType3Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=3,
max_length=3,
)
image_prompt: str = Field(
description="Prompt used to generate image for this slide",
description="Slide image prompt in about 5 words",
min_length=10,
max_length=50,
max_length=30,
)
@classmethod
def get_notes(cls):
return """
- The **Body** should include **3 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
class LLMType4ContentWithValidation(LLMType4Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=1,
max_length=3,
)
@classmethod
def get_notes(cls):
return """
- The **Body** should include **1 to 3 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
class LLMType5ContentWithValidation(LLMType5Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: str = Field(
description="Slide content summary in less than 30 words.",
min_length=100,
max_length=250,
description="Slide content summary in about 30 words.",
min_length=50,
max_length=300,
)
graph: GraphModel = Field(description="Graph to show in slide")
@classmethod
def get_notes(self):
return ""
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
class LLMType6ContentWithValidation(LLMType6Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
description: str = Field(
description="Slide content summary in less than 20 words.",
min_length=80,
max_length=150,
description="Slide content summary in about 20 words.",
min_length=50,
max_length=300,
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=1,
max_length=3,
)
@classmethod
def get_notes(cls):
return """
- The **Body** should include **1 to 3 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
class LLMType7ContentWithValidation(LLMType7Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=1,
max_length=4,
)
@classmethod
def get_notes(cls):
return """
- The **Body** should include **1 to 4 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
class LLMType8ContentWithValidation(LLMType8Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
description: str = Field(
description="Slide content summary in less than 20 words.",
min_length=80,
max_length=150,
description="Slide content summary in about 20 words.",
min_length=50,
max_length=300,
)
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=1,
max_length=3,
)
@classmethod
def get_notes(cls):
return """
- The **Body** should include **1 to 3 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
class LLMType9ContentWithValidation(LLMType9Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
description="Items to show in slide",
min_length=1,
max_length=3,
)
graph: GraphModel = Field(description="Graph to show in slide")
@classmethod
def get_notes(cls):
return """
- The **Body** should include **1 to 3 HeadingModels**.
- Each **Heading** must consist of **1 to 3 words**.
- Each item **Description** can be upto 10 words.
"""
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING: Mapping[int, LLMSlideContentModel] = {
LLMContentUnionWithValidation = Union[
LLMType1ContentWithValidation,
LLMType2ContentWithValidation,
LLMType3ContentWithValidation,
LLMType4ContentWithValidation,
LLMType5ContentWithValidation,
LLMType6ContentWithValidation,
LLMType7ContentWithValidation,
LLMType8ContentWithValidation,
LLMType9ContentWithValidation,
]
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION: Mapping[
int, LLMContentUnionWithValidation
] = {
TYPE1: LLMType1ContentWithValidation,
TYPE2: LLMType2ContentWithValidation,
TYPE3: LLMType3ContentWithValidation,
@ -279,17 +220,8 @@ LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING: Mapping[int, LLMSlideContentModel] = {
class LLMSlideModelWithValidation(LLMSlideModel):
type: int
content: (
LLMType1ContentWithValidation
| LLMType2ContentWithValidation
| LLMType4ContentWithValidation
| LLMType5ContentWithValidation
| LLMType6ContentWithValidation
| LLMType7ContentWithValidation
| LLMType8ContentWithValidation
| LLMType9ContentWithValidation
)
content: LLMContentUnionWithValidation
class LLMPresentationModelWithValidation(LLMPresentationModel):
slides: list[LLMSlideModelWithValidation]
slides: List[LLMSlideModelWithValidation]

View file

@ -6,8 +6,6 @@ from pptx.util import Pt
from pptx.enum.text import PP_ALIGN
from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE, MSO_CONNECTOR_TYPE
from graph_processor.models import GraphModel
class PptxBoxShapeEnum(Enum):
RECTANGLE = "rectangle"
@ -138,14 +136,6 @@ class PptxPictureBoxModel(PptxShapeModel):
picture: PptxPictureModel
class PptxGraphBoxModel(PptxShapeModel):
position: PptxPositionModel
category_font: Optional[PptxFontModel] = None
value_font: Optional[PptxFontModel] = None
legend_font: Optional[PptxFontModel] = None
graph: GraphModel
class PptxConnectorModel(PptxShapeModel):
type: MSO_CONNECTOR_TYPE = MSO_CONNECTOR_TYPE.STRAIGHT
position: PptxPositionModel
@ -159,13 +149,10 @@ class PptxSlideModel(BaseModel):
| PptxAutoShapeBoxModel
| PptxConnectorModel
| PptxPictureBoxModel
| PptxGraphBoxModel
]
class PptxPresentationModel(BaseModel):
# theme: PresentationTheme
# watermark: bool
background_color: str
shapes: Optional[List[PptxShapeModel]] = None
slides: List[PptxSlideModel]

View file

@ -6,26 +6,12 @@ from lxml import etree
from pptx import Presentation
from pptx.shapes.autoshape import Shape
from pptx.slide import Slide
from pptx.chart.data import ChartData, BubbleChartData
from pptx.chart.chart import Chart
from pptx.text.text import _Paragraph, TextFrame, Font, _Run
from pptx.enum.chart import (
XL_CHART_TYPE,
XL_LEGEND_POSITION,
XL_LABEL_POSITION,
)
from pptx.opc.constants import RELATIONSHIP_TYPE as RT
from lxml.etree import fromstring, tostring
from PIL import Image
from pptx.util import Pt
from graph_processor.models import (
BarGraphDataModel,
BubbleChartDataModel,
GraphTypeEnum,
LineChartDataModel,
PieChartDataModel,
)
from pptx.dml.color import RGBColor
from ppt_generator.models.pptx_models import (
PptxAutoShapeBoxModel,
@ -33,7 +19,6 @@ from ppt_generator.models.pptx_models import (
PptxConnectorModel,
PptxFillModel,
PptxFontModel,
PptxGraphBoxModel,
PptxParagraphModel,
PptxPictureBoxModel,
PptxPositionModel,
@ -63,8 +48,6 @@ class PptxPresentationCreator:
self._ppt_model = ppt_model
self._slide_models = ppt_model.slides
# self._theme = ppt_model.theme
# self._watermark = ppt_model.watermark
self._ppt = Presentation()
self._ppt.slide_width = Pt(1280)
@ -73,7 +56,6 @@ class PptxPresentationCreator:
self._slide_fill = PptxFillModel(color=ppt_model.background_color)
def create_ppt(self):
# self.set_presentation_theme()
for slide_model in self._slide_models:
# Adding global shapes to slide
@ -120,16 +102,9 @@ class PptxPresentationCreator:
elif model_type is PptxTextBoxModel:
self.add_textbox(slide, shape_model)
elif model_type is PptxGraphBoxModel:
self.add_graph(slide, shape_model)
elif model_type is PptxConnectorModel:
self.add_connector(slide, shape_model)
# if self._watermark:
# Adding watermark
# self.add_picture(slide, self.get_watermark_box_model())
def add_connector(self, slide: Slide, connector_model: PptxConnectorModel):
if connector_model.thickness == 0:
return
@ -139,126 +114,6 @@ class PptxPresentationCreator:
connector_shape.line.width = Pt(connector_model.thickness)
connector_shape.line.color.rgb = RGBColor.from_string(connector_model.color)
def add_graph(self, slide: Slide, graph_box_model: PptxGraphBoxModel):
chart_data = None
chart_type = None
graph = graph_box_model.graph
match (graph.type):
case GraphTypeEnum.bar:
chart_data = self.get_bar_graph(graph.data)
chart_type = XL_CHART_TYPE.COLUMN_CLUSTERED
case GraphTypeEnum.scatter:
chart_data = self.get_scatter_graph(graph.data)
chart_type = XL_CHART_TYPE.XY_SCATTER
case GraphTypeEnum.bubble:
chart_data = self.get_bubble_graph(graph.data)
chart_type = XL_CHART_TYPE.BUBBLE
case GraphTypeEnum.line:
chart_data = self.get_line_graph(graph.data)
chart_type = XL_CHART_TYPE.LINE
case GraphTypeEnum.pie:
chart_data = self.get_pie_graph(graph.data)
chart_type = XL_CHART_TYPE.PIE
if chart_data:
chart: Chart = slide.shapes.add_chart(
chart_type, *graph_box_model.position.to_pt_list(), chart_data
).chart
self.apply_graph_styles(chart, graph_box_model)
def apply_graph_styles(self, chart, graph_box_model: PptxGraphBoxModel):
graph = graph_box_model.graph
if graph.type in [GraphTypeEnum.pie, GraphTypeEnum.scatter]:
chart.has_legend = True
chart.legend.position = XL_LEGEND_POSITION.RIGHT
else:
chart.has_legend = False
if graph_box_model.legend_font:
self.apply_font(chart.font, graph_box_model.legend_font)
try:
category_axis = chart.category_axis
if graph_box_model.category_font:
font = category_axis.tick_labels.font
self.apply_font(font, graph_box_model.category_font)
except:
print("-" * 20)
print("Could not apply category labels style")
try:
value_axis = chart.value_axis
tick_labels = value_axis.tick_labels
if graph.postfix:
tick_labels.number_format = f'0"{graph.postfix}"'
if graph_box_model.value_font:
self.apply_font(tick_labels.font, graph_box_model.value_font)
except:
print("-" * 20)
print("Could not apply tick labels style")
if graph_box_model.graph.type is GraphTypeEnum.pie:
for plot in chart.plots:
try:
plot.has_data_labels = True
plot.data_labels.position = (
XL_LABEL_POSITION.OUTSIDE_END
if graph_box_model.graph.type is GraphTypeEnum.bar
else XL_LABEL_POSITION.CENTER
)
if graph.postfix:
plot.data_labels.number_format = f'0"{graph.postfix}"'
if graph_box_model.value_font:
self.apply_font(
plot.data_labels.font,
(
graph_box_model.value_font
if graph_box_model.graph.type is GraphTypeEnum.bar
else PptxFontModel(
# size=self._theme.fonts.p2,
size=16,
bold=True,
color="ffffff",
)
),
)
except:
print("-" * 20)
print("Could not apply data labels style")
def get_bar_graph(self, graph: BarGraphDataModel):
chart_data = ChartData()
chart_data.categories = graph.get_categories()
for series in graph.series:
chart_data.add_series(series.get_name(), series.data)
return chart_data
def get_bubble_graph(self, graph: BubbleChartDataModel):
chart_data = BubbleChartData()
for each in graph.series:
series = chart_data.add_series(each.get_name())
for point in each.points:
series.add_data_point(*point.to_list())
return chart_data
def get_line_graph(self, graph: LineChartDataModel):
chart_data = ChartData()
chart_data.categories = graph.get_categories()
for series in graph.series:
chart_data.add_series(series.get_name(), series.data)
return chart_data
def get_pie_graph(self, graph: PieChartDataModel):
chart_data = ChartData()
chart_data.categories = graph.get_categories()
chart_data.add_series("", graph.series[0].data)
return chart_data
def add_picture(self, slide: Slide, picture_model: PptxPictureBoxModel):
image_path = picture_model.picture.path
if (
@ -562,17 +417,5 @@ class PptxPresentationCreator:
font.italic = font_model.italic
font.size = Pt(font_model.size)
# def get_watermark_box_model(self):
# watermark_asset_path = f"assets/images/{'watermark_dark.png' if self._theme == PresentationTheme.dark else 'watermark.png'}"
# return PptxPictureBoxModel(
# position=PptxPositionModel(left=1120, top=685, width=140),
# clip=False,
# picture=PptxPictureModel(
# is_network=False,
# path=watermark_asset_path,
# ),
# )
def save(self, path: str):
self._ppt.save(path)

View file

@ -1,149 +1,142 @@
from typing import Optional
from api.utils.utils import get_large_model, get_small_model
from ppt_config_generator.models import SlideMarkdownModel
from ppt_generator.fix_validation_errors import get_validated_response
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
from api.utils.model_utils import get_large_model, get_llm_client, get_small_model
from ppt_config_generator.models import SlideMarkdownModel
from ppt_generator.models.llm_models import (
LLM_CONTENT_TYPE_MAPPING,
LLMSlideContentModel,
LLMContentUnion,
)
from ppt_generator.models.llm_models_with_validations import (
LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING,
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION,
)
from ppt_generator.models.other_models import SlideTypeModel
from ppt_generator.models.slide_model import SlideModel
prompt_template_to_generate_slide_content = ChatPromptTemplate.from_messages(
[
(
"system",
"""
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
def get_prompt_to_generate_slide_content(title: str, outline: str):
return [
{
"role": "system",
"content": f"""
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
# Steps
1. Analyze the outline and title.
2. Generate structured slide based on the outline and title.
3. Generate image prompts and icon queries if mentioned in schema.
4. Generate graph if mentioned in schema.
# Steps
1. Analyze the outline and title.
2. Generate structured slide based on the outline and title.
3. Generate image prompts and icon queries if mentioned in schema.
4. Generate graph if mentioned in schema.
# Notes
- Slide body should not use words like "This slide", "This presentation".
- Rephrase the slide body to make it flow naturally.
- Do not use markdown formatting in slide body.
- **Icon query** must be a generic single word noun.
- **Image prompt** should be a 2-3 words phrase.
- Try to make paragraphs as short as possible.
{notes}
# Notes
- Slide body should not use words like "This slide", "This presentation".
- Rephrase the slide body to make it flow naturally.
- Do not use markdown formatting in slide body.
""",
),
(
"user",
"""
## Slide Title
{title}
},
{
"role": "user",
"content": f"""
## Slide Title
{title}
## Slide Outline
{outline}
""",
),
## Slide Outline
{outline}
""",
},
]
)
prompt_template_to_edit_slide_content = ChatPromptTemplate.from_messages(
[
(
"system",
"""
def get_prompt_to_edit_slide_content(
prompt: str,
slide_data: dict,
theme: Optional[dict] = None,
language: Optional[str] = None,
):
return [
{
"role": "system",
"content": """
Edit Slide data based on provided prompt, follow mentioned steps and notes and provide structured output.
# Notes
- Provide output in language mentioned in **Input**.
- The goal is to change Slide data based on the provided prompt.
- Do not change **Image prompts** and **Icon queries** if not asked for in prompt.
- Generate **Image prompts** and **Icon queries** if asked to generate or change image or icons in prompt.
- Ensure there are no line breaks in the JSON.
- Do not use special characters for highlighting.
{notes}
- Generate **Image prompts** and **Icon queries** if asked to generate or change in prompt.
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
),
(
"user",
"""
- Prompt: {prompt}
- Output Language: {language}
- Image Prompts and Icon Queries Language: English
- Theme: {theme}
- Slide data: {slide_data}
""",
),
},
{
"role": "user",
"content": f"""
- Prompt: {prompt}
- Output Language: {language}
- Image Prompts and Icon Queries Language: English
- Theme: {theme}
- Slide data: {slide_data}
""",
},
]
)
prompt_template_to_select_slide_type = ChatPromptTemplate.from_messages(
[
(
"system",
"""
Select a Slide Type based on provided user prompt and current slide data.
def get_prompt_to_select_slide_type(prompt: str, slide_data: dict, slide_type: int):
return [
{
"role": "system",
"content": """
Select a Slide Type based on provided user prompt and current slide data.
Select slide based on following slide description and make sure it matches user requirement:
# Slide Types (Slide Type : Slide Description)
- **1**: contains title, description and image.
- **2**: contains title and list of items.
- **4**: contains title and list of items with images.
- **5**: contains title, description and a graph.
- **6**: contains title, description and list of items.
- **7**: contains title and list of items with icons.
- **8**: contains title, description and list of items with icons.
- **9**: contains title, list of items and a graph.
Select slide based on following slide description and make sure it matches user requirement:
# Slide Types (Slide Type : Slide Description)
- **1**: contains title, description and image.
- **2**: contains title and list of items.
- **4**: contains title and list of items with images.
- **5**: contains title, description and a graph.
- **6**: contains title, description and list of items.
- **7**: contains title and list of items with icons.
- **8**: contains title, description and list of items with icons.
- **9**: contains title, list of items and a graph.
# Notes
- Do not select different slide type than current unless absolutely necessary as per user prompt.
# Notes
- Do not select different slide type than current unless absolutely necessary as per user prompt.
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
),
(
"user",
"""
- User Prompt: {prompt}
- Current Slide Data: {slide_data}
- Current Slide Type: {slide_type}
""",
),
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
},
{
"role": "user",
"content": f"""
- User Prompt: {prompt}
- Current Slide Data: {slide_data}
- Current Slide Type: {slide_type}
""",
},
]
)
async def get_slide_content_from_type_and_outline(
slide_type: int, outline: SlideMarkdownModel
) -> LLMSlideContentModel:
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
model = get_small_model().with_structured_output(
content_type_model_type.model_json_schema()
)
chain = prompt_template_to_generate_slide_content | model
) -> LLMContentUnion:
response_model = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
return await get_validated_response(
chain,
{
"title": outline.title,
"outline": outline.body,
"notes": content_type_model_type.get_notes(),
},
content_type_model_type,
validation_model,
client = get_llm_client()
model = get_small_model()
response = await client.beta.chat.completions.parse(
model=model,
temperature=0.5,
messages=get_prompt_to_generate_slide_content(
outline.title,
outline.body,
),
response_format=response_model,
)
return response.choices[0].message.parsed
async def get_edited_slide_content_model(
prompt: str,
@ -151,29 +144,24 @@ async def get_edited_slide_content_model(
slide: SlideModel,
theme: Optional[dict] = None,
language: Optional[str] = None,
):
) -> LLMContentUnion:
client = get_llm_client()
model = get_large_model()
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
chain = prompt_template_to_edit_slide_content | model.with_structured_output(
content_type_model_type.model_json_schema()
)
content_type_model_type = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
slide_data = slide.content.to_llm_content().model_dump_json()
edited_content = await get_validated_response(
chain,
{
"prompt": prompt,
"language": language or "English",
"theme": theme,
"slide_data": slide_data,
"notes": "",
},
content_type_model_type,
validation_model,
response = await client.beta.chat.completions.parse(
model=model,
temperature=0.2,
messages=get_prompt_to_edit_slide_content(
prompt,
slide_data,
theme,
language,
),
response_format=content_type_model_type,
)
return edited_content.to_content()
return response.choices[0].message.parsed
async def get_slide_type_from_prompt(
@ -181,18 +169,15 @@ async def get_slide_type_from_prompt(
slide: SlideModel,
) -> SlideTypeModel:
client = get_llm_client()
model = get_small_model()
chain = prompt_template_to_select_slide_type | model.with_structured_output(
SlideTypeModel.model_json_schema()
)
slide_data = slide.content.to_llm_content().model_dump_json()
return await get_validated_response(
chain,
{
"prompt": prompt,
"slide_data": slide_data,
"slide_type": slide.type,
},
SlideTypeModel,
response = await client.beta.chat.completions.parse(
model=model,
temperature=0.2,
messages=get_prompt_to_select_slide_type(
prompt, slide.content.to_llm_content().model_dump_json(), slide.type
),
response_format=SlideTypeModel,
)
return response.choices[0].message.parsed

View file

@ -19,7 +19,7 @@ dnspython==2.7.0
email_validator==2.2.0
fastapi==0.115.12
fastapi-cli==0.0.7
fastembed==0.7.0
fastembed_vectorstore==0.1.5
filelock==3.18.0
filetype==1.2.0
flatbuffers==25.2.10
@ -28,6 +28,7 @@ fsspec==2025.3.2
google-ai-generativelanguage==0.6.18
google-api-core==2.24.2
google-auth==2.40.1
google-genai==1.23.0
googleapis-common-protos==1.70.0
greenlet==3.2.2
grpcio==1.72.0rc1
@ -44,14 +45,6 @@ Jinja2==3.1.6
jiter==0.9.0
jsonpatch==1.33
jsonpointer==3.0.0
langchain==0.3.25
langchain-community==0.3.24
langchain-core==0.3.65
langchain-google-genai==2.1.4
langchain-ollama==0.3.3
langchain-openai==0.3.16
langchain-text-splitters==0.3.8
langsmith==0.3.45
loguru==0.7.3
lxml==5.4.0
markdown-it-py==3.0.0
@ -65,7 +58,7 @@ mypy_extensions==1.1.0
numpy==2.2.5
ollama==0.5.1
onnxruntime==1.22.0
openai==1.78.1
openai==1.91.0
orjson==3.10.18
packaging==24.2
pdfminer.six==20250327
@ -102,7 +95,7 @@ SQLAlchemy==2.0.41
sqlmodel==0.0.24
starlette==0.46.2
sympy==1.14.0
tenacity==9.1.2
tenacity==8.5.0
tiktoken==0.9.0
tokenizers==0.21.1
tqdm==4.67.1

View file

@ -1,58 +1,35 @@
import os
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
# search_tool = DuckDuckGoSearchRun(
# api_wrapper=DuckDuckGoSearchAPIWrapper(max_results=50)
# )
prompt_template = ChatPromptTemplate.from_messages(
[
(
"system",
"""
Use provided prompt and search results to create an elaborate and up-to-date research report in mentioned language.
def get_prompt_template():
return [
{
"role": "system",
"content": """
Use provided prompt and search results to create an elaborate and up-to-date research report in mentioned language.
# Steps
1. Analyze the prompt and search results.
2. Extract topic of the report.
3. Generate a report in markdown format.
# Steps
1. Analyze the prompt and search results.
2. Extract topic of the report.
3. Generate a report in markdown format.
# Notes
- If language is not mentioned, use language from prompt.
- Format of report should be like *Research Report*.
- Ignore formatting if mentioned in prompt.
# Notes
- If language is not mentioned, use language from prompt.
- Format of report should be like *Research Report*.
- Ignore formatting if mentioned in prompt.
""",
),
(
"human",
"""
- Prompt: {prompt}
- Language: {language}
- Search Results: {search_results}
},
{
"role": "human",
"content": """
- Prompt: {prompt}
- Language: {language}
- Search Results: {search_results}
""",
),
},
]
)
async def get_report(query: str, language: Optional[str]):
model = (
ChatOpenAI(model="gpt-4.1-nano")
if os.getenv("LLM") == "openai"
else ChatGoogleGenerativeAI(model="gemini-2.0-flash")
)
chain = prompt_template | model
# search_results = await search_tool.ainvoke(query)
# response = await chain.ainvoke(
# {
# "prompt": query,
# "language": language,
# "search_results": search_results,
# }
# )
# return response.content
return "Research Report coming soon"

View file

@ -1,3 +1,4 @@
import os
import uvicorn
import argparse
@ -8,6 +9,8 @@ from api.main import app
app
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
parser = argparse.ArgumentParser(description="Run the FastAPI server")
parser.add_argument(
"--port", type=int, required=True, help="Port number to run the server on"

View file

@ -1,8 +1,11 @@
import os
import uvicorn
import argparse
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
parser = argparse.ArgumentParser(description="Run the FastAPI server")
parser.add_argument(
"--port", type=int, required=True, help="Port number to run the server on"

View file

@ -1,9 +1,12 @@
import os
import uvicorn
from dotenv import load_dotenv
load_dotenv()
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
uvicorn.run(
"api.main:app", host="0.0.0.0", port=8000, log_level="info", reload=True
)

View file

@ -215,7 +215,7 @@ const Header = ({
if (response.ok) {
const { path: pdfPath } = await response.json();
const staticFileUrl = getStaticFileUrl(pdfPath);
window.open(staticFileUrl, '_self');
window.open(staticFileUrl, '_blank');
} else {
throw new Error("Failed to export PDF");
}

View file

@ -36,7 +36,9 @@ export async function POST(request: Request) {
LLM: userConfig.LLM || existingConfig.LLM,
OPENAI_API_KEY: userConfig.OPENAI_API_KEY || existingConfig.OPENAI_API_KEY,
GOOGLE_API_KEY: userConfig.GOOGLE_API_KEY || existingConfig.GOOGLE_API_KEY,
OLLAMA_MODEL: userConfig.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL,
MODEL: userConfig.MODEL || existingConfig.MODEL,
LLM_PROVIDER_URL: userConfig.LLM_PROVIDER_URL || existingConfig.LLM_PROVIDER_URL,
LLM_API_KEY: userConfig.LLM_API_KEY || existingConfig.LLM_API_KEY,
PEXELS_API_KEY: userConfig.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY,
}
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig))

View file

@ -2,13 +2,29 @@
import React, { useState, useEffect } from "react";
import Header from "../dashboard/components/Header";
import Wrapper from "@/components/Wrapper";
import { Settings, Key, Loader2 } from 'lucide-react';
import { Settings, Key, Loader2, Check, ChevronsUpDown } from 'lucide-react';
import { toast } from '@/hooks/use-toast';
import { RootState } from "@/store/store";
import { useSelector } from "react-redux";
import { handleSaveLLMConfig } from "@/utils/storeHelpers";
import { useRouter } from "next/navigation";
import { Select, SelectContent, SelectItem, SelectTrigger } from "@/components/ui/select";
import { Button } from "@/components/ui/button";
import {
Command,
CommandEmpty,
CommandGroup,
CommandInput,
CommandItem,
CommandList,
} from "@/components/ui/command";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { cn } from "@/lib/utils";
import { Switch } from "@/components/ui/switch";
const PROVIDER_CONFIGS: Record<string, ProviderConfig> = {
openai: {
@ -55,14 +71,22 @@ const SettingsPage = () => {
done: false,
});
const [isLoading, setIsLoading] = useState<boolean>(false);
const [openModelSelect, setOpenModelSelect] = useState(false);
const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState<boolean>(false);
const api_key_changed = (apiKey: string) => {
const api_key_changed = (apiKey: string, field?: string) => {
if (llmConfig.LLM === 'openai') {
setLlmConfig({ ...llmConfig, OPENAI_API_KEY: apiKey });
} else if (llmConfig.LLM === 'google') {
setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: apiKey });
} else if (llmConfig.LLM === 'ollama') {
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: apiKey });
if (field === 'pexels') {
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: apiKey });
} else if (field === 'ollama_url') {
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: apiKey });
} else if (field === 'ollama_api_key') {
setLlmConfig({ ...llmConfig, LLM_API_KEY: apiKey });
}
}
}
@ -87,7 +111,7 @@ const SettingsPage = () => {
}
}
try {
await handleSaveLLMConfig(llmConfig);
await handleSaveLLMConfig(llmConfig, useCustomOllamaUrl);
toast({
title: 'Success',
description: 'Configuration saved successfully',
@ -116,7 +140,7 @@ const SettingsPage = () => {
return new Promise((resolve, reject) => {
const interval = setInterval(async () => {
try {
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`);
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`);
if (response.status === 200) {
const data = await response.json();
@ -163,6 +187,14 @@ const SettingsPage = () => {
}
}, [userConfigState.llm_config.LLM]);
useEffect(() => {
if (!useCustomOllamaUrl) {
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: undefined, LLM_API_KEY: undefined });
} else {
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', LLM_API_KEY: '' });
}
}, [useCustomOllamaUrl]);
if (!canChangeKeys) {
return null;
}
@ -262,70 +294,90 @@ const SettingsPage = () => {
</label>
<div className="w-full">
{ollamaModels.length > 0 ? (
<Select value={llmConfig.OLLAMA_MODEL} onValueChange={(value) => setLlmConfig({ ...llmConfig, OLLAMA_MODEL: value })}>
<SelectTrigger className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400">
<div className="flex items-center justify-between w-full">
<Popover open={openModelSelect} onOpenChange={setOpenModelSelect}>
<PopoverTrigger asChild>
<Button
variant="outline"
role="combobox"
aria-expanded={openModelSelect}
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
>
<div className="flex gap-3 items-center">
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.icon}
alt={`${llmConfig.OLLAMA_MODEL} icon`}
className="rounded-sm"
/>
</div>
{llmConfig.MODEL && (
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={ollamaModels.find(m => m.value === llmConfig.MODEL)?.icon}
alt={`${llmConfig.MODEL} icon`}
className="rounded-sm"
/>
</div>
)}
<span className="text-sm font-medium text-gray-900">
{llmConfig.OLLAMA_MODEL ? (
ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.label || llmConfig.OLLAMA_MODEL
{llmConfig.MODEL ? (
ollamaModels.find(m => m.value === llmConfig.MODEL)?.label || llmConfig.MODEL
) : (
'Select a model'
)}
</span>
{llmConfig.OLLAMA_MODEL && (
{llmConfig.MODEL && (
<span className="text-xs text-gray-500 bg-gray-100 rounded-full px-2 py-1">
{ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.size}
{ollamaModels.find(m => m.value === llmConfig.MODEL)?.size}
</span>
)}
</div>
</div>
</SelectTrigger>
<SelectContent className="max-h-80">
<div className="p-2">
<div className="text-xs font-semibold text-gray-500 uppercase tracking-wide mb-3 pt-3 px-2">
Available Models
</div>
{ollamaModels.map((model, index) => (
<SelectItem
key={index}
value={model.value}
className="relative cursor-pointer rounded-md py-3 hover:bg-gray-50 focus:bg-gray-50 focus:outline-none transition-colors"
>
<div className="flex gap-3 items-center">
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={model.icon}
alt={`${model.label} icon`}
className=" rounded-sm"
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
</Button>
</PopoverTrigger>
<PopoverContent className="p-0" align="start" style={{ width: 'var(--radix-popover-trigger-width)' }}>
<Command>
<CommandInput placeholder="Search model..." />
<CommandList>
<CommandEmpty>No model found.</CommandEmpty>
<CommandGroup>
{ollamaModels.map((model, index) => (
<CommandItem
key={index}
value={model.value}
onSelect={(value) => {
setLlmConfig({ ...llmConfig, MODEL: value });
setOpenModelSelect(false);
}}
>
<Check
className={cn(
"mr-2 h-4 w-4",
llmConfig.MODEL === model.value ? "opacity-100" : "opacity-0"
)}
/>
</div>
<div className="flex flex-col space-y-1 flex-1">
<div className="flex items-center justify-between gap-2">
<span className="text-sm font-medium text-gray-900 capitalize">
{model.label}
</span>
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
{model.size}
</span>
<div className="flex gap-3 items-center">
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={model.icon}
alt={`${model.label} icon`}
className="rounded-sm"
/>
</div>
<div className="flex flex-col space-y-1 flex-1">
<div className="flex items-center justify-between gap-2">
<span className="text-sm font-medium text-gray-900 capitalize">
{model.label}
</span>
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
{model.size}
</span>
</div>
<span className="text-xs text-gray-600 leading-relaxed">
{model.description}
</span>
</div>
</div>
<span className="text-xs text-gray-600 leading-relaxed">
{model.description}
</span>
</div>
</div>
</SelectItem>
))}
</div>
</SelectContent>
</Select>
</CommandItem>
))}
</CommandGroup>
</CommandList>
</Command>
</PopoverContent>
</Popover>
) : (
<div className="w-full border border-gray-300 rounded-lg p-4">
<div className="flex items-center space-x-3">
@ -344,6 +396,62 @@ const SettingsPage = () => {
</p>
)}
</div>
{/* Custom Ollama URL Configuration */}
<div>
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
<label className="text-sm font-medium text-gray-700">
Use custom Ollama URL
</label>
<Switch
checked={useCustomOllamaUrl}
onCheckedChange={setUseCustomOllamaUrl}
/>
</div>
{useCustomOllamaUrl && (
<>
<div className="mb-4">
<label className="block text-sm font-medium text-gray-700 mb-2">
Ollama URL
</label>
<div className="relative">
<input
type="text"
required
placeholder="Enter your Ollama URL"
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
value={llmConfig.LLM_PROVIDER_URL || ''}
onChange={(e) => api_key_changed(e.target.value, 'ollama_url')}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
Change this if you are using a custom Ollama instance
</p>
</div>
<div className="mb-4">
<label className="block text-sm font-medium text-gray-700 mb-2">
Ollama API Key
</label>
<div className="relative">
<input
type="text"
required
placeholder="Enter your Ollama API key"
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
value={llmConfig.LLM_API_KEY || ''}
onChange={(e) => api_key_changed(e.target.value, 'ollama_api_key')}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
Provide this if you are using a custom Ollama instance
</p>
</div>
</>
)}
</div>
<div>
<label className="block text-sm font-medium text-gray-700 mb-2">
Pexels API Key (required for images)
@ -355,12 +463,12 @@ const SettingsPage = () => {
placeholder="Enter your Pexels API key"
className="flex-1 px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
value={llmConfig.PEXELS_API_KEY || ''}
onChange={(e) => api_key_changed(e.target.value)}
onChange={(e) => api_key_changed(e.target.value, 'pexels')}
/>
<button
onClick={handleSaveConfig}
disabled={isLoading || !llmConfig.OLLAMA_MODEL}
className={`px-4 py-2 rounded-lg transition-colors ${isLoading || !llmConfig.OLLAMA_MODEL
disabled={isLoading || !llmConfig.MODEL}
className={`px-4 py-2 rounded-lg transition-colors ${isLoading || !llmConfig.MODEL
? 'bg-gray-400 cursor-not-allowed'
: 'bg-blue-600 hover:bg-blue-700'
} text-white`}
@ -374,7 +482,7 @@ const SettingsPage = () => {
}
</div>
) : (
!llmConfig.OLLAMA_MODEL ? 'Select Model' : 'Save'
!llmConfig.MODEL ? 'Select Model' : 'Save'
)}
</button>
</div>

View file

@ -41,11 +41,11 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) {
llmConfig.LLM = 'openai';
}
dispatch(setLLMConfig(llmConfig));
const isValid = hasValidLLMConfig(llmConfig);
const isValid = hasValidLLMConfig(llmConfig, false);
if (isValid) {
// Check if the selected Ollama model is pulled
if (llmConfig.LLM === 'ollama') {
const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.OLLAMA_MODEL);
const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.MODEL);
if (!isPulled) {
router.push('/');
setLoadingToFalseAfterNavigatingTo('/');

View file

@ -2,7 +2,7 @@
import { useState, useEffect } from "react";
import { useRouter } from "next/navigation";
import { toast } from "@/hooks/use-toast";
import { Info, ExternalLink, PlayCircle, Loader2 } from "lucide-react";
import { Info, ExternalLink, PlayCircle, Loader2, Check, ChevronsUpDown } from "lucide-react";
import Link from "next/link";
import {
Accordion,
@ -14,6 +14,22 @@ import { useSelector } from "react-redux";
import { RootState } from "@/store/store";
import { handleSaveLLMConfig } from "@/utils/storeHelpers";
import { Select, SelectContent, SelectItem, SelectTrigger } from "./ui/select";
import { Button } from "./ui/button";
import {
Command,
CommandEmpty,
CommandGroup,
CommandInput,
CommandItem,
CommandList,
} from "./ui/command";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "./ui/popover";
import { cn } from "@/lib/utils";
import { Switch } from "./ui/switch";
interface ModelOption {
value: string;
@ -100,36 +116,7 @@ const PROVIDER_CONFIGS: Record<string, ProviderConfig> = {
},
},
ollama: {
textModels: [
{
value: "llama3.1:8b",
label: "Llama3.1:8b",
description: "Balanced model for most tasks",
icon: "/icons/ollama.png",
size: "8GB",
},
{
value: "llama3.1:70b",
label: "Llama3.1:70b",
description: "Large model for complex tasks",
icon: "/icons/ollama.png",
size: "70GB",
},
{
value: "llama3.1:14b",
label: "Llama3.1:14b",
description: "Large model for complex tasks",
icon: "/icons/ollama.png",
size: "14GB",
},
{
value: "llama3.1:11b",
label: "Llama3.1:11b",
description: "Large model for complex tasks",
icon: "/icons/ollama.png",
size: "11GB",
},
],
textModels: [],
imageModels: [
{
value: "pexels",
@ -171,16 +158,24 @@ export default function Home() {
done: false,
});
const [isLoading, setIsLoading] = useState<boolean>(false);
const [openModelSelect, setOpenModelSelect] = useState(false);
const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState<boolean>(false);
const canChangeKeys = config.can_change_keys;
const api_key_changed = (newApiKey: string) => {
const api_key_changed = (newApiKey: string, field?: string) => {
if (llmConfig.LLM === 'openai') {
setLlmConfig({ ...llmConfig, OPENAI_API_KEY: newApiKey });
} else if (llmConfig.LLM === 'google') {
setLlmConfig({ ...llmConfig, GOOGLE_API_KEY: newApiKey });
} else if (llmConfig.LLM === 'ollama') {
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: newApiKey });
if (field === 'pexels') {
setLlmConfig({ ...llmConfig, PEXELS_API_KEY: newApiKey });
} else if (field === 'ollama_url') {
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: newApiKey });
} else if (field === 'ollama_api_key') {
setLlmConfig({ ...llmConfig, LLM_API_KEY: newApiKey });
}
}
}
@ -205,7 +200,7 @@ export default function Home() {
}
}
try {
await handleSaveLLMConfig(llmConfig);
await handleSaveLLMConfig(llmConfig, useCustomOllamaUrl);
toast({
title: 'Success',
description: 'Configuration saved successfully',
@ -234,7 +229,7 @@ export default function Home() {
return new Promise((resolve, reject) => {
const interval = setInterval(async () => {
try {
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.OLLAMA_MODEL}`);
const response = await fetch(`/api/v1/ppt/ollama/pull-model?name=${llmConfig.MODEL}`);
if (response.status === 200) {
const data = await response.json();
@ -277,6 +272,15 @@ export default function Home() {
}
}, []);
useEffect(() => {
if (!useCustomOllamaUrl) {
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: undefined, LLM_API_KEY: undefined });
} else {
setLlmConfig({ ...llmConfig, LLM_PROVIDER_URL: 'http://localhost:11434', LLM_API_KEY: '' });
}
}, [useCustomOllamaUrl]);
if (!canChangeKeys) {
return null;
}
@ -355,70 +359,90 @@ export default function Home() {
</label>
<div className="w-full">
{ollamaModels.length > 0 ? (
<Select value={llmConfig.OLLAMA_MODEL} onValueChange={(value) => setLlmConfig({ ...llmConfig, OLLAMA_MODEL: value })}>
<SelectTrigger className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400">
<div className="flex items-center justify-between w-full">
<Popover open={openModelSelect} onOpenChange={setOpenModelSelect}>
<PopoverTrigger asChild>
<Button
variant="outline"
role="combobox"
aria-expanded={openModelSelect}
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
>
<div className="flex gap-3 items-center">
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.icon}
alt={`${llmConfig.OLLAMA_MODEL} icon`}
className="rounded-sm"
/>
</div>
{llmConfig.MODEL && (
<div className="w-6 h-6 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={ollamaModels.find(m => m.value === llmConfig.MODEL)?.icon}
alt={`${llmConfig.MODEL} icon`}
className="rounded-sm"
/>
</div>
)}
<span className="text-sm font-medium text-gray-900">
{llmConfig.OLLAMA_MODEL ? (
ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.label || llmConfig.OLLAMA_MODEL
{llmConfig.MODEL ? (
ollamaModels.find(m => m.value === llmConfig.MODEL)?.label || llmConfig.MODEL
) : (
'Select a model'
)}
</span>
{llmConfig.OLLAMA_MODEL && (
{llmConfig.MODEL && (
<span className="text-xs text-gray-500 bg-gray-100 rounded-full px-2 py-1">
{ollamaModels.find(m => m.value === llmConfig.OLLAMA_MODEL)?.size}
{ollamaModels.find(m => m.value === llmConfig.MODEL)?.size}
</span>
)}
</div>
</div>
</SelectTrigger>
<SelectContent className="max-h-80">
<div className="p-2">
<div className="text-xs font-semibold text-gray-500 uppercase tracking-wide mb-3 pt-3 px-2">
Available Models
</div>
{ollamaModels.map((model, index) => (
<SelectItem
key={index}
value={model.value}
className="relative cursor-pointer rounded-md py-3 hover:bg-gray-50 focus:bg-gray-50 focus:outline-none transition-colors"
>
<div className="flex gap-3 items-center">
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={model.icon}
alt={`${model.label} icon`}
className=" rounded-sm"
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
</Button>
</PopoverTrigger>
<PopoverContent className="p-0" align="start" style={{ width: 'var(--radix-popover-trigger-width)' }}>
<Command>
<CommandInput placeholder="Search model..." />
<CommandList>
<CommandEmpty>No model found.</CommandEmpty>
<CommandGroup>
{ollamaModels.map((model, index) => (
<CommandItem
key={index}
value={model.value}
onSelect={(value) => {
setLlmConfig({ ...llmConfig, MODEL: value });
setOpenModelSelect(false);
}}
>
<Check
className={cn(
"mr-2 h-4 w-4",
llmConfig.MODEL === model.value ? "opacity-100" : "opacity-0"
)}
/>
</div>
<div className="flex flex-col space-y-1 flex-1">
<div className="flex items-center justify-between gap-2">
<span className="text-sm font-medium text-gray-900 capitalize">
{model.label}
</span>
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
{model.size}
</span>
<div className="flex gap-3 items-center">
<div className="w-8 h-8 rounded-lg flex items-center justify-center flex-shrink-0">
<img
src={model.icon}
alt={`${model.label} icon`}
className="rounded-sm"
/>
</div>
<div className="flex flex-col space-y-1 flex-1">
<div className="flex items-center justify-between gap-2">
<span className="text-sm font-medium text-gray-900 capitalize">
{model.label}
</span>
<span className="text-xs text-gray-500 bg-gray-100 px-2 py-1 rounded-full">
{model.size}
</span>
</div>
<span className="text-xs text-gray-600 leading-relaxed">
{model.description}
</span>
</div>
</div>
<span className="text-xs text-gray-600 leading-relaxed">
{model.description}
</span>
</div>
</div>
</SelectItem>
))}
</div>
</SelectContent>
</Select>
</CommandItem>
))}
</CommandGroup>
</CommandList>
</Command>
</PopoverContent>
</Popover>
) : (
<div className="w-full border border-gray-300 rounded-lg p-4">
<div className="flex items-center space-x-3">
@ -437,6 +461,59 @@ export default function Home() {
</p>
)}
</div>
<div className="mb-8">
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
<label className="text-sm font-medium text-gray-700">
Use custom Ollama URL
</label>
<Switch
checked={useCustomOllamaUrl}
onCheckedChange={setUseCustomOllamaUrl}
/>
</div>
{useCustomOllamaUrl && (
<>
<div className="mb-4">
<label className="block text-sm font-medium text-gray-700 mb-2">
Ollama URL
</label>
<div className="relative">
<input
type="text"
required
placeholder="Enter your Ollama URL"
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
value={llmConfig.LLM_PROVIDER_URL || ''}
onChange={(e) => api_key_changed(e.target.value, 'ollama_url')}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
Change this if you are using a custom Ollama instance
</p>
</div>
<div className="mb-4">
<label className="block text-sm font-medium text-gray-700 mb-2">
Ollama API Key
</label>
<div className="relative">
<input
type="text"
required
placeholder="Enter your Ollama API key"
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
value={llmConfig.LLM_API_KEY || ''}
onChange={(e) => api_key_changed(e.target.value, 'ollama_api_key')}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
Provide this if you are using a custom Ollama instance
</p>
</div>
</>
)}
</div>
<div className="mb-8">
<label className="block text-sm font-medium text-gray-700 mb-2">
Pexels API Key (required for images)
@ -448,7 +525,7 @@ export default function Home() {
placeholder="Enter your Pexels API key"
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
value={llmConfig.PEXELS_API_KEY || ''}
onChange={(e) => api_key_changed(e.target.value)}
onChange={(e) => api_key_changed(e.target.value, 'pexels')}
/>
</div>
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
@ -468,7 +545,7 @@ export default function Home() {
Selected Models
</h3>
<p className="text-sm text-blue-700">
Using {llmConfig.LLM === 'ollama' ? llmConfig.OLLAMA_MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text
Using {llmConfig.LLM === 'ollama' ? llmConfig.MODEL ?? '_____' : PROVIDER_CONFIGS[llmConfig.LLM!].textModels[0].label} for text
generation and {PROVIDER_CONFIGS[llmConfig.LLM!].imageModels[0].label} for
images
</p>
@ -544,7 +621,7 @@ export default function Home() {
}
</div>
) : (
llmConfig.LLM === 'ollama' && !llmConfig.OLLAMA_MODEL
llmConfig.LLM === 'ollama' && !llmConfig.MODEL
? 'Please Select a Model'
: 'Save Configuration'
)}

View file

@ -18,5 +18,7 @@ interface LLMConfig {
OPENAI_API_KEY?: string;
GOOGLE_API_KEY?: string;
PEXELS_API_KEY?: string;
OLLAMA_MODEL?: string;
LLM_PROVIDER_URL?: string;
LLM_API_KEY?: string;
MODEL?: string;
}

View file

@ -1,8 +1,8 @@
import { setLLMConfig } from "@/store/slices/userConfig";
import { store } from "@/store/store";
export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => {
if (!hasValidLLMConfig(llmConfig)) {
export const handleSaveLLMConfig = async (llmConfig: LLMConfig, useCustomOllamaUrl: boolean) => {
if (!hasValidLLMConfig(llmConfig, useCustomOllamaUrl)) {
throw new Error('API key cannot be empty');
}
@ -14,17 +14,21 @@ export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => {
store.dispatch(setLLMConfig(llmConfig));
}
export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
export const hasValidLLMConfig = (llmConfig: LLMConfig, useCustomOllamaUrl: boolean) => {
if (!llmConfig.LLM) return false;
const OPENAI_API_KEY = llmConfig.OPENAI_API_KEY;
const GOOGLE_API_KEY = llmConfig.GOOGLE_API_KEY;
const OLLAMA_MODEL = llmConfig.OLLAMA_MODEL;
const MODEL = llmConfig.MODEL;
const PEXELS_API_KEY = llmConfig.PEXELS_API_KEY;
const isOllamaBaseConfigValid = PEXELS_API_KEY !== '' && PEXELS_API_KEY !== null && PEXELS_API_KEY !== undefined && MODEL !== '' && MODEL !== null && MODEL !== undefined;
return llmConfig.LLM === 'openai' ?
OPENAI_API_KEY !== '' && OPENAI_API_KEY !== null && OPENAI_API_KEY !== undefined :
llmConfig.LLM === 'google' ?
GOOGLE_API_KEY !== '' && GOOGLE_API_KEY !== null && GOOGLE_API_KEY !== undefined :
llmConfig.LLM === 'ollama' ?
PEXELS_API_KEY !== '' && PEXELS_API_KEY !== null && PEXELS_API_KEY !== undefined && OLLAMA_MODEL !== '' && OLLAMA_MODEL !== null && OLLAMA_MODEL !== undefined :
false;
useCustomOllamaUrl ?
isOllamaBaseConfigValid && llmConfig.LLM_PROVIDER_URL !== '' && llmConfig.LLM_PROVIDER_URL !== null && llmConfig.LLM_PROVIDER_URL !== undefined && llmConfig.LLM_API_KEY !== '' && llmConfig.LLM_API_KEY !== null && llmConfig.LLM_API_KEY !== undefined :
isOllamaBaseConfigValid : false;
}