feat(fastapi): use OpenAI client

This commit is contained in:
sauravniraula 2025-06-24 19:53:32 +05:45
parent 3fe88f7725
commit fff3b96fbe
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
21 changed files with 165 additions and 155 deletions

View file

@ -6,4 +6,5 @@ out
build
.git
.gitignore
tmp
tmp
debug

3
.gitignore vendored
View file

@ -7,4 +7,5 @@ __pycache__
node_modules
out
user_data
tmp
tmp
debug

View file

@ -24,9 +24,8 @@ RUN curl -fsSL https://ollama.com/install.sh | sh
COPY servers/fastapi/requirements.txt ./
RUN pip install -r requirements.txt
# Install dependencies for Next.js
WORKDIR /app/servers/nextjs
WORKDIR /node_dependencies
COPY servers/nextjs/package.json servers/nextjs/package-lock.json ./
RUN npm install
@ -40,4 +39,10 @@ COPY nginx.conf /etc/nginx/nginx.conf
EXPOSE 80 3000 8000
# Start the servers
CMD ["/bin/bash", "-c", "ollama serve & service nginx start & service redis-server start && node /app/start.js"]
CMD ["/bin/bash", "-c", "\
rm -rf /app/servers/nextjs/node_modules && \
ln -s /node_dependencies/node_modules /app/servers/nextjs/node_modules && \
ollama serve & \
service nginx start & \
service redis-server start && \
node /app/start.js"]

View file

@ -8,7 +8,8 @@ from contextlib import asynccontextmanager
from api.routers.presentation.router import presentation_router
from api.services.database import sql_engine
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from api.utils.utils import is_ollama_selected, update_env_with_user_config
from api.utils.utils import update_env_with_user_config
from api.utils.model_utils import is_ollama_selected
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"

View file

@ -14,8 +14,8 @@ from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from api.utils.utils import (
get_presentation_dir,
get_presentation_images_dir,
is_ollama_selected,
)
from api.utils.model_utils import is_ollama_selected
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
from image_processor.images_finder import generate_image
from image_processor.icons_finder import get_icon

View file

@ -11,7 +11,7 @@ from api.routers.presentation.models import PresentationGenerateRequest
from api.services.logging import LoggingService
from api.sql_models import KeyValueSqlModel, PresentationSqlModel
from api.services.database import get_sql_session
from api.utils.utils import is_ollama_selected
from api.utils.model_utils import is_ollama_selected
from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel
from ppt_config_generator.structure_generator import generate_presentation_structure

View file

@ -17,7 +17,8 @@ from api.services.database import get_sql_session
from api.services.instances import TEMP_FILE_SERVICE
from api.services.logging import LoggingService
from api.sql_models import PresentationSqlModel, SlideSqlModel
from api.utils.utils import get_presentation_dir, is_ollama_selected
from api.utils.utils import get_presentation_dir
from api.utils.model_utils import is_ollama_selected
from document_processor.loader import DocumentsLoader
from ppt_config_generator.document_summary_generator import generate_document_summary
from ppt_config_generator.models import PresentationMarkdownModel

View file

@ -17,7 +17,8 @@ from api.routers.presentation.models import (
from api.services.database import get_sql_session
from api.services.logging import LoggingService
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
from api.utils.utils import get_presentation_dir, is_ollama_selected
from api.utils.utils import get_presentation_dir
from api.utils.model_utils import is_ollama_selected
from ppt_config_generator.models import (
PresentationMarkdownModel,
PresentationStructureModel,

View file

@ -1,5 +1,7 @@
import os
from openai import AsyncOpenAI
from api.models import SelectedLLMProvider
@ -15,38 +17,56 @@ def get_model_base_url():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OLLAMA:
return "http://localhost:11434"
return "http://localhost:11434/v1"
elif selected_llm == SelectedLLMProvider.OPENAI:
return "https://api.openai.com/v1"
else:
return "https://generativelanguage.googleapis.com/v1beta/openai"
def get_llm_api_key():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return os.getenv("OPENAI_API_KEY")
elif selected_llm == SelectedLLMProvider.GOOGLE:
return os.getenv("GOOGLE_API_KEY")
else:
return "ollama"
def get_llm_client():
client = AsyncOpenAI(
base_url=get_model_base_url(),
api_key=get_llm_api_key(),
)
return client
def get_large_model():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return ChatOpenAI(model="gpt-4.1")
return "gpt-4.1"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
return "gemini-2.0-flash"
else:
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
return os.getenv("OLLAMA_MODEL")
def get_small_model():
selected_llm = os.getenv("LLM")
if selected_llm == "openai":
return ChatOpenAI(model="gpt-4.1-mini")
elif selected_llm == "google":
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return "gpt-4.1-mini"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "gemini-2.0-flash"
else:
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
return os.getenv("OLLAMA_MODEL")
def get_nano_model():
selected_llm = os.getenv("LLM")
if selected_llm == "openai":
return ChatOpenAI(model="gpt-4.1-nano")
elif selected_llm == "google":
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OPENAI:
return "gpt-4.1-nano"
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "gemini-2.0-flash"
else:
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
return os.getenv("OLLAMA_MODEL")

View file

@ -9,7 +9,8 @@ from openai import OpenAI
from ppt_generator.models.query_and_prompt_models import (
ImagePromptWithThemeAndAspectRatio,
)
from api.utils.utils import download_file, get_resource, is_ollama_selected
from api.utils.utils import download_file, get_resource
from api.utils.model_utils import is_ollama_selected
async def generate_image(

View file

@ -1,14 +1,10 @@
import asyncio
import os
from typing import List
from langchain_core.documents import Document
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import BaseMessage
from langchain_text_splitters import CharacterTextSplitter
from openai.types.chat.chat_completion import ChatCompletion
from api.utils.utils import get_nano_model
from api.utils.model_utils import get_llm_client, get_nano_model
sysmte_prompt = """
Generate a blog-style summary of the provided document in **more than 2000 words**.
@ -26,26 +22,28 @@ Maintain as much information as possible.
- If **slides structure is mentioned** in document, structure the summary in the same way.
"""
prompt_template = ChatPromptTemplate.from_messages(
[
("system", sysmte_prompt),
("user", "{text}"),
]
)
async def generate_document_summary(documents: List[Document]):
client = get_llm_client()
model = get_nano_model()
text_splitter = CharacterTextSplitter(chunk_size=200000, chunk_overlap=0)
chain = prompt_template | model
coroutines = []
for document in documents:
text = document.page_content
truncated_text = text_splitter.split_text(text)[0]
coroutine = chain.ainvoke({"text": truncated_text})
coroutine = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": sysmte_prompt},
{"role": "user", "content": truncated_text},
],
)
coroutines.append(coroutine)
completions: List[BaseMessage] = await asyncio.gather(*coroutines)
combined = "\n\n\n\n".join([completion.content for completion in completions])
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
combined = "\n\n\n\n".join(
[completion.choices[0].message.content for completion in completions]
)
return combined

View file

@ -1,7 +1,8 @@
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from api.utils.utils import get_large_model
from api.utils.model_utils import get_large_model
from api.utils.variable_length_models import (
get_presentation_markdown_model_with_n_slides,
)
@ -64,7 +65,7 @@ async def generate_ppt_content(
language: Optional[str] = None,
content: Optional[str] = None,
) -> PresentationMarkdownModel:
model = get_large_model()
model = ChatOllama(model=get_large_model(), temperature=0.8)
response_model = get_presentation_markdown_model_with_n_slides(n_slides)
chain = get_prompt_template() | model.with_structured_output(

View file

@ -1,6 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from api.utils.utils import get_small_model
from api.utils.model_utils import get_small_model
from api.utils.variable_length_models import (
get_presentation_structure_model_with_n_slides,
)
@ -59,7 +60,7 @@ async def generate_presentation_structure(
presentation_outline: PresentationMarkdownModel,
) -> PresentationStructureModel:
model = get_small_model()
model = ChatOllama(model=get_small_model(), temperature=0.8)
response_model = get_presentation_structure_model_with_n_slides(
len(presentation_outline.slides)
)

View file

@ -1,12 +1,11 @@
import os
from typing import Optional
from fastapi import HTTPException
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, ValidationError
from api.utils.utils import get_large_model
from api.utils.model_utils import get_large_model
def get_prompt_template():
@ -41,7 +40,7 @@ def get_prompt_template():
async def fix_validation_errors(response_model: BaseModel, response, errors):
model = get_large_model()
model = ChatOllama(model=get_large_model(), temperature=0.8)
chain = get_prompt_template() | model.with_structured_output(
response_model.model_json_schema()

View file

@ -5,7 +5,8 @@ from langchain_core.messages import (
AIMessageChunk,
AIMessage,
)
from api.utils.utils import get_large_model
from langchain_ollama import ChatOllama
from api.utils.model_utils import get_large_model
from ppt_config_generator.models import PresentationMarkdownModel
from ppt_generator.models.llm_models_with_validations import (
LLMPresentationModelWithValidation,
@ -91,7 +92,7 @@ def get_model_and_messages(
presentation_outline: PresentationMarkdownModel,
):
user_message = HumanMessage(presentation_outline.to_string())
model = get_large_model()
model = ChatOllama(model=get_large_model(), temperature=0.8)
return model, system_prompt, user_message

View file

@ -34,48 +34,34 @@ from ppt_generator.models.llm_models import (
class LLMHeadingModelWithValidation(LLMHeadingModel):
heading: str = Field(
description="List item heading to show in slide body",
min_length=10,
max_length=30,
description="List item heading to show in slide body in less than 5 words.",
)
description: str = Field(
description="Description of list item in less than 20 words.",
min_length=80,
max_length=150,
)
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
image_prompt: str = Field(
description="Prompt used to generate image for this item",
min_length=10,
max_length=50,
description="Prompt used to generate image for this item in less than 6 words.",
)
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
icon_query: str = Field(
description="Icon query to generate icon for this item",
min_length=10,
max_length=50,
description="Icon query to generate icon for this item in less than 4 words.",
)
class LLMType1ContentWithValidation(LLMType1Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: str = Field(
description="Slide content summary in less than 30 words.",
min_length=100,
max_length=200,
)
image_prompt: str = Field(
description="Prompt used to generate image for this slide.",
min_length=10,
max_length=50,
description="Prompt used to generate image for this slide in less than 6 words.",
)
@classmethod
@ -85,9 +71,7 @@ class LLMType1ContentWithValidation(LLMType1Content):
class LLMType2ContentWithValidation(LLMType2Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
@ -106,9 +90,7 @@ class LLMType2ContentWithValidation(LLMType2Content):
class LLMType3ContentWithValidation(LLMType3Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
@ -116,9 +98,7 @@ class LLMType3ContentWithValidation(LLMType3Content):
max_length=3,
)
image_prompt: str = Field(
description="Prompt used to generate image for this slide",
min_length=10,
max_length=50,
description="Prompt used to generate image for this slide in less than 6 words.",
)
@classmethod
@ -132,9 +112,7 @@ class LLMType3ContentWithValidation(LLMType3Content):
class LLMType4ContentWithValidation(LLMType4Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
description="List items to show in slide's body",
@ -153,14 +131,10 @@ class LLMType4ContentWithValidation(LLMType4Content):
class LLMType5ContentWithValidation(LLMType5Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: str = Field(
description="Slide content summary in less than 30 words.",
min_length=100,
max_length=250,
)
graph: GraphModel = Field(description="Graph to show in slide")
@ -171,14 +145,10 @@ class LLMType5ContentWithValidation(LLMType5Content):
class LLMType6ContentWithValidation(LLMType6Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
description: str = Field(
description="Slide content summary in less than 20 words.",
min_length=80,
max_length=150,
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",
@ -197,9 +167,7 @@ class LLMType6ContentWithValidation(LLMType6Content):
class LLMType7ContentWithValidation(LLMType7Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
description="List items to show in slide's body",
@ -218,14 +186,10 @@ class LLMType7ContentWithValidation(LLMType7Content):
class LLMType8ContentWithValidation(LLMType8Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
description: str = Field(
description="Slide content summary in less than 20 words.",
min_length=80,
max_length=150,
)
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
description="List items to show in slide's body",
@ -244,9 +208,7 @@ class LLMType8ContentWithValidation(LLMType8Content):
class LLMType9ContentWithValidation(LLMType9Content):
title: str = Field(
description="Title of the slide",
min_length=10,
max_length=50,
description="Title of the slide in less than 6 words.",
)
body: List[LLMHeadingModelWithValidation] = Field(
description="List items to show in slide's body",

View file

@ -1,5 +1,8 @@
from typing import Optional
from api.utils.utils import get_large_model, get_small_model
from langchain_ollama import ChatOllama
from openai import OpenAI
from api.utils.model_utils import get_large_model, get_llm_client, get_small_model
from ppt_config_generator.models import SlideMarkdownModel
from ppt_generator.fix_validation_errors import get_validated_response
@ -16,42 +19,43 @@ from ppt_generator.models.other_models import SlideTypeModel
from ppt_generator.models.slide_model import SlideModel
prompt_template_to_generate_slide_content = ChatPromptTemplate.from_messages(
[
(
"system",
"""
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
def get_prompt_to_generate_slide_content(
title: str, outline: str, notes: Optional[str] = None
):
return [
{
"role": "system",
"content": f"""
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
# Steps
1. Analyze the outline and title.
2. Generate structured slide based on the outline and title.
3. Generate image prompts and icon queries if mentioned in schema.
4. Generate graph if mentioned in schema.
# Steps
1. Analyze the outline and title.
2. Generate structured slide based on the outline and title.
3. Generate image prompts and icon queries if mentioned in schema.
4. Generate graph if mentioned in schema.
# Notes
- Slide body should not use words like "This slide", "This presentation".
- Rephrase the slide body to make it flow naturally.
- Do not use markdown formatting in slide body.
- **Icon query** must be a generic single word noun.
- **Image prompt** should be a 2-3 words phrase.
- Try to make paragraphs as short as possible.
{notes}
# Notes
- Slide body should not use words like "This slide", "This presentation".
- Rephrase the slide body to make it flow naturally.
- Do not use markdown formatting in slide body.
- **Icon query** must be a generic single word noun.
- **Image prompt** should be a 2-3 words phrase.
- Try to make paragraphs as short as possible.
{notes}
""",
),
(
"user",
"""
## Slide Title
{title}
},
{
"role": "user",
"content": f"""
## Slide Title
{title}
## Slide Outline
{outline}
""",
),
## Slide Outline
{outline}
""",
},
]
)
prompt_template_to_edit_slide_content = ChatPromptTemplate.from_messages(
@ -126,22 +130,26 @@ prompt_template_to_select_slide_type = ChatPromptTemplate.from_messages(
async def get_slide_content_from_type_and_outline(
slide_type: int, outline: SlideMarkdownModel
) -> LLMSlideContentModel:
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
model = get_small_model().with_structured_output(
content_type_model_type.model_json_schema()
)
chain = prompt_template_to_generate_slide_content | model
response_model = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
return await get_validated_response(
chain,
{
"title": outline.title,
"outline": outline.body,
"notes": content_type_model_type.get_notes(),
},
content_type_model_type,
validation_model,
client = get_llm_client()
model = get_small_model()
response = await client.beta.chat.completions.parse(
model=model,
messages=get_prompt_to_generate_slide_content(
outline.title,
outline.body,
response_model.get_notes(),
),
response_format=response_model,
)
with open("debug/llm_response.json", "w") as f:
f.write(response.choices[0].message.content)
return LLM_CONTENT_TYPE_MAPPING[slide_type].model_validate_json(
response.choices[0].message.content
)
@ -152,7 +160,7 @@ async def get_edited_slide_content_model(
theme: Optional[dict] = None,
language: Optional[str] = None,
):
model = get_large_model()
model = ChatOllama(model=get_large_model(), temperature=0.8)
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
@ -181,7 +189,7 @@ async def get_slide_type_from_prompt(
slide: SlideModel,
) -> SlideTypeModel:
model = get_small_model()
model = ChatOllama(model=get_small_model(), temperature=0.8)
chain = prompt_template_to_select_slide_type | model.with_structured_output(
SlideTypeModel.model_json_schema()

View file

@ -65,7 +65,7 @@ mypy_extensions==1.1.0
numpy==2.2.5
ollama==0.5.1
onnxruntime==1.22.0
openai==1.78.1
openai==1.91.0
orjson==3.10.18
packaging==24.2
pdfminer.six==20250327

View file

@ -1,3 +1,4 @@
import os
import uvicorn
import argparse
@ -8,6 +9,8 @@ from api.main import app
app
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
parser = argparse.ArgumentParser(description="Run the FastAPI server")
parser.add_argument(
"--port", type=int, required=True, help="Port number to run the server on"

View file

@ -1,8 +1,11 @@
import os
import uvicorn
import argparse
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
parser = argparse.ArgumentParser(description="Run the FastAPI server")
parser.add_argument(
"--port", type=int, required=True, help="Port number to run the server on"

View file

@ -1,9 +1,12 @@
import os
import uvicorn
from dotenv import load_dotenv
load_dotenv()
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
uvicorn.run(
"api.main:app", host="0.0.0.0", port=8000, log_level="info", reload=True
)