feat(fastapi): use OpenAI client
This commit is contained in:
parent
3fe88f7725
commit
fff3b96fbe
21 changed files with 165 additions and 155 deletions
|
|
@ -6,4 +6,5 @@ out
|
|||
build
|
||||
.git
|
||||
.gitignore
|
||||
tmp
|
||||
tmp
|
||||
debug
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -7,4 +7,5 @@ __pycache__
|
|||
node_modules
|
||||
out
|
||||
user_data
|
||||
tmp
|
||||
tmp
|
||||
debug
|
||||
|
|
@ -24,9 +24,8 @@ RUN curl -fsSL https://ollama.com/install.sh | sh
|
|||
COPY servers/fastapi/requirements.txt ./
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
|
||||
# Install dependencies for Next.js
|
||||
WORKDIR /app/servers/nextjs
|
||||
WORKDIR /node_dependencies
|
||||
COPY servers/nextjs/package.json servers/nextjs/package-lock.json ./
|
||||
RUN npm install
|
||||
|
||||
|
|
@ -40,4 +39,10 @@ COPY nginx.conf /etc/nginx/nginx.conf
|
|||
EXPOSE 80 3000 8000
|
||||
|
||||
# Start the servers
|
||||
CMD ["/bin/bash", "-c", "ollama serve & service nginx start & service redis-server start && node /app/start.js"]
|
||||
CMD ["/bin/bash", "-c", "\
|
||||
rm -rf /app/servers/nextjs/node_modules && \
|
||||
ln -s /node_dependencies/node_modules /app/servers/nextjs/node_modules && \
|
||||
ollama serve & \
|
||||
service nginx start & \
|
||||
service redis-server start && \
|
||||
node /app/start.js"]
|
||||
|
|
@ -8,7 +8,8 @@ from contextlib import asynccontextmanager
|
|||
from api.routers.presentation.router import presentation_router
|
||||
from api.services.database import sql_engine
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from api.utils.utils import is_ollama_selected, update_env_with_user_config
|
||||
from api.utils.utils import update_env_with_user_config
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
|
||||
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|||
from api.utils.utils import (
|
||||
get_presentation_dir,
|
||||
get_presentation_images_dir,
|
||||
is_ollama_selected,
|
||||
)
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from api.routers.presentation.models import PresentationGenerateRequest
|
|||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.utils import is_ollama_selected
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel
|
||||
from ppt_config_generator.structure_generator import generate_presentation_structure
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from api.services.database import get_sql_session
|
|||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir, is_ollama_selected
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ from api.routers.presentation.models import (
|
|||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir, is_ollama_selected
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from ppt_config_generator.models import (
|
||||
PresentationMarkdownModel,
|
||||
PresentationStructureModel,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from api.models import SelectedLLMProvider
|
||||
|
||||
|
||||
|
|
@ -15,38 +17,56 @@ def get_model_base_url():
|
|||
selected_llm = get_selected_llm_provider()
|
||||
|
||||
if selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return "http://localhost:11434"
|
||||
return "http://localhost:11434/v1"
|
||||
elif selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
else:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
else:
|
||||
return "ollama"
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
client = AsyncOpenAI(
|
||||
base_url=get_model_base_url(),
|
||||
api_key=get_llm_api_key(),
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return ChatOpenAI(model="gpt-4.1")
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
return "gemini-2.0-flash"
|
||||
else:
|
||||
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = os.getenv("LLM")
|
||||
if selected_llm == "openai":
|
||||
return ChatOpenAI(model="gpt-4.1-mini")
|
||||
elif selected_llm == "google":
|
||||
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
else:
|
||||
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = os.getenv("LLM")
|
||||
if selected_llm == "openai":
|
||||
return ChatOpenAI(model="gpt-4.1-nano")
|
||||
elif selected_llm == "google":
|
||||
return ChatGoogleGenerativeAI(model="gemini-2.0-flash")
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
else:
|
||||
return ChatOllama(model=os.getenv("OLLAMA_MODEL"), temperature=0.8)
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from openai import OpenAI
|
|||
from ppt_generator.models.query_and_prompt_models import (
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from api.utils.utils import download_file, get_resource, is_ollama_selected
|
||||
from api.utils.utils import download_file, get_resource
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
|
||||
|
||||
async def generate_image(
|
||||
|
|
|
|||
|
|
@ -1,14 +1,10 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from api.utils.utils import get_nano_model
|
||||
from api.utils.model_utils import get_llm_client, get_nano_model
|
||||
|
||||
sysmte_prompt = """
|
||||
Generate a blog-style summary of the provided document in **more than 2000 words**.
|
||||
|
|
@ -26,26 +22,28 @@ Maintain as much information as possible.
|
|||
- If **slides structure is mentioned** in document, structure the summary in the same way.
|
||||
"""
|
||||
|
||||
prompt_template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", sysmte_prompt),
|
||||
("user", "{text}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def generate_document_summary(documents: List[Document]):
|
||||
client = get_llm_client()
|
||||
model = get_nano_model()
|
||||
|
||||
text_splitter = CharacterTextSplitter(chunk_size=200000, chunk_overlap=0)
|
||||
chain = prompt_template | model
|
||||
|
||||
coroutines = []
|
||||
for document in documents:
|
||||
text = document.page_content
|
||||
truncated_text = text_splitter.split_text(text)[0]
|
||||
coroutine = chain.ainvoke({"text": truncated_text})
|
||||
coroutine = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": sysmte_prompt},
|
||||
{"role": "user", "content": truncated_text},
|
||||
],
|
||||
)
|
||||
coroutines.append(coroutine)
|
||||
|
||||
completions: List[BaseMessage] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join([completion.content for completion in completions])
|
||||
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join(
|
||||
[completion.choices[0].message.content for completion in completions]
|
||||
)
|
||||
return combined
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Optional
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
from api.utils.utils import get_large_model
|
||||
from api.utils.model_utils import get_large_model
|
||||
from api.utils.variable_length_models import (
|
||||
get_presentation_markdown_model_with_n_slides,
|
||||
)
|
||||
|
|
@ -64,7 +65,7 @@ async def generate_ppt_content(
|
|||
language: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> PresentationMarkdownModel:
|
||||
model = get_large_model()
|
||||
model = ChatOllama(model=get_large_model(), temperature=0.8)
|
||||
response_model = get_presentation_markdown_model_with_n_slides(n_slides)
|
||||
|
||||
chain = get_prompt_template() | model.with_structured_output(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
from api.utils.utils import get_small_model
|
||||
from api.utils.model_utils import get_small_model
|
||||
from api.utils.variable_length_models import (
|
||||
get_presentation_structure_model_with_n_slides,
|
||||
)
|
||||
|
|
@ -59,7 +60,7 @@ async def generate_presentation_structure(
|
|||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> PresentationStructureModel:
|
||||
|
||||
model = get_small_model()
|
||||
model = ChatOllama(model=get_small_model(), temperature=0.8)
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from api.utils.utils import get_large_model
|
||||
from api.utils.model_utils import get_large_model
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
|
|
@ -41,7 +40,7 @@ def get_prompt_template():
|
|||
|
||||
|
||||
async def fix_validation_errors(response_model: BaseModel, response, errors):
|
||||
model = get_large_model()
|
||||
model = ChatOllama(model=get_large_model(), temperature=0.8)
|
||||
|
||||
chain = get_prompt_template() | model.with_structured_output(
|
||||
response_model.model_json_schema()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ from langchain_core.messages import (
|
|||
AIMessageChunk,
|
||||
AIMessage,
|
||||
)
|
||||
from api.utils.utils import get_large_model
|
||||
from langchain_ollama import ChatOllama
|
||||
from api.utils.model_utils import get_large_model
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_generator.models.llm_models_with_validations import (
|
||||
LLMPresentationModelWithValidation,
|
||||
|
|
@ -91,7 +92,7 @@ def get_model_and_messages(
|
|||
presentation_outline: PresentationMarkdownModel,
|
||||
):
|
||||
user_message = HumanMessage(presentation_outline.to_string())
|
||||
model = get_large_model()
|
||||
model = ChatOllama(model=get_large_model(), temperature=0.8)
|
||||
|
||||
return model, system_prompt, user_message
|
||||
|
||||
|
|
|
|||
|
|
@ -34,48 +34,34 @@ from ppt_generator.models.llm_models import (
|
|||
|
||||
class LLMHeadingModelWithValidation(LLMHeadingModel):
|
||||
heading: str = Field(
|
||||
description="List item heading to show in slide body",
|
||||
min_length=10,
|
||||
max_length=30,
|
||||
description="List item heading to show in slide body in less than 5 words.",
|
||||
)
|
||||
description: str = Field(
|
||||
description="Description of list item in less than 20 words.",
|
||||
min_length=80,
|
||||
max_length=150,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
|
||||
image_prompt: str = Field(
|
||||
description="Prompt used to generate image for this item",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Prompt used to generate image for this item in less than 6 words.",
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
|
||||
icon_query: str = Field(
|
||||
description="Icon query to generate icon for this item",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Icon query to generate icon for this item in less than 4 words.",
|
||||
)
|
||||
|
||||
|
||||
class LLMType1ContentWithValidation(LLMType1Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
min_length=100,
|
||||
max_length=200,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Prompt used to generate image for this slide.",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Prompt used to generate image for this slide in less than 6 words.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -85,9 +71,7 @@ class LLMType1ContentWithValidation(LLMType1Content):
|
|||
|
||||
class LLMType2ContentWithValidation(LLMType2Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
@ -106,9 +90,7 @@ class LLMType2ContentWithValidation(LLMType2Content):
|
|||
|
||||
class LLMType3ContentWithValidation(LLMType3Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
@ -116,9 +98,7 @@ class LLMType3ContentWithValidation(LLMType3Content):
|
|||
max_length=3,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Prompt used to generate image for this slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Prompt used to generate image for this slide in less than 6 words.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -132,9 +112,7 @@ class LLMType3ContentWithValidation(LLMType3Content):
|
|||
|
||||
class LLMType4ContentWithValidation(LLMType4Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
@ -153,14 +131,10 @@ class LLMType4ContentWithValidation(LLMType4Content):
|
|||
|
||||
class LLMType5ContentWithValidation(LLMType5Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
min_length=100,
|
||||
max_length=250,
|
||||
)
|
||||
graph: GraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
|
|
@ -171,14 +145,10 @@ class LLMType5ContentWithValidation(LLMType5Content):
|
|||
|
||||
class LLMType6ContentWithValidation(LLMType6Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
min_length=80,
|
||||
max_length=150,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
@ -197,9 +167,7 @@ class LLMType6ContentWithValidation(LLMType6Content):
|
|||
|
||||
class LLMType7ContentWithValidation(LLMType7Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
@ -218,14 +186,10 @@ class LLMType7ContentWithValidation(LLMType7Content):
|
|||
|
||||
class LLMType8ContentWithValidation(LLMType8Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
min_length=80,
|
||||
max_length=150,
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
@ -244,9 +208,7 @@ class LLMType8ContentWithValidation(LLMType8Content):
|
|||
|
||||
class LLMType9ContentWithValidation(LLMType9Content):
|
||||
title: str = Field(
|
||||
description="Title of the slide",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
description="Title of the slide in less than 6 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="List items to show in slide's body",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from typing import Optional
|
||||
from api.utils.utils import get_large_model, get_small_model
|
||||
|
||||
from langchain_ollama import ChatOllama
|
||||
from openai import OpenAI
|
||||
from api.utils.model_utils import get_large_model, get_llm_client, get_small_model
|
||||
from ppt_config_generator.models import SlideMarkdownModel
|
||||
from ppt_generator.fix_validation_errors import get_validated_response
|
||||
|
||||
|
|
@ -16,42 +19,43 @@ from ppt_generator.models.other_models import SlideTypeModel
|
|||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
|
||||
prompt_template_to_generate_slide_content = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
|
||||
def get_prompt_to_generate_slide_content(
|
||||
title: str, outline: str, notes: Optional[str] = None
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
|
||||
|
||||
|
||||
# Steps
|
||||
1. Analyze the outline and title.
|
||||
2. Generate structured slide based on the outline and title.
|
||||
3. Generate image prompts and icon queries if mentioned in schema.
|
||||
4. Generate graph if mentioned in schema.
|
||||
# Steps
|
||||
1. Analyze the outline and title.
|
||||
2. Generate structured slide based on the outline and title.
|
||||
3. Generate image prompts and icon queries if mentioned in schema.
|
||||
4. Generate graph if mentioned in schema.
|
||||
|
||||
# Notes
|
||||
- Slide body should not use words like "This slide", "This presentation".
|
||||
- Rephrase the slide body to make it flow naturally.
|
||||
- Do not use markdown formatting in slide body.
|
||||
- **Icon query** must be a generic single word noun.
|
||||
- **Image prompt** should be a 2-3 words phrase.
|
||||
- Try to make paragraphs as short as possible.
|
||||
{notes}
|
||||
# Notes
|
||||
- Slide body should not use words like "This slide", "This presentation".
|
||||
- Rephrase the slide body to make it flow naturally.
|
||||
- Do not use markdown formatting in slide body.
|
||||
- **Icon query** must be a generic single word noun.
|
||||
- **Image prompt** should be a 2-3 words phrase.
|
||||
- Try to make paragraphs as short as possible.
|
||||
{notes}
|
||||
""",
|
||||
),
|
||||
(
|
||||
"user",
|
||||
"""
|
||||
## Slide Title
|
||||
{title}
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
## Slide Title
|
||||
{title}
|
||||
|
||||
## Slide Outline
|
||||
{outline}
|
||||
""",
|
||||
),
|
||||
## Slide Outline
|
||||
{outline}
|
||||
""",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
prompt_template_to_edit_slide_content = ChatPromptTemplate.from_messages(
|
||||
|
|
@ -126,22 +130,26 @@ prompt_template_to_select_slide_type = ChatPromptTemplate.from_messages(
|
|||
async def get_slide_content_from_type_and_outline(
|
||||
slide_type: int, outline: SlideMarkdownModel
|
||||
) -> LLMSlideContentModel:
|
||||
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
|
||||
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
|
||||
model = get_small_model().with_structured_output(
|
||||
content_type_model_type.model_json_schema()
|
||||
)
|
||||
chain = prompt_template_to_generate_slide_content | model
|
||||
response_model = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
|
||||
|
||||
return await get_validated_response(
|
||||
chain,
|
||||
{
|
||||
"title": outline.title,
|
||||
"outline": outline.body,
|
||||
"notes": content_type_model_type.get_notes(),
|
||||
},
|
||||
content_type_model_type,
|
||||
validation_model,
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_generate_slide_content(
|
||||
outline.title,
|
||||
outline.body,
|
||||
response_model.get_notes(),
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
|
||||
with open("debug/llm_response.json", "w") as f:
|
||||
f.write(response.choices[0].message.content)
|
||||
|
||||
return LLM_CONTENT_TYPE_MAPPING[slide_type].model_validate_json(
|
||||
response.choices[0].message.content
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -152,7 +160,7 @@ async def get_edited_slide_content_model(
|
|||
theme: Optional[dict] = None,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
model = get_large_model()
|
||||
model = ChatOllama(model=get_large_model(), temperature=0.8)
|
||||
|
||||
content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type]
|
||||
validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
|
||||
|
|
@ -181,7 +189,7 @@ async def get_slide_type_from_prompt(
|
|||
slide: SlideModel,
|
||||
) -> SlideTypeModel:
|
||||
|
||||
model = get_small_model()
|
||||
model = ChatOllama(model=get_small_model(), temperature=0.8)
|
||||
|
||||
chain = prompt_template_to_select_slide_type | model.with_structured_output(
|
||||
SlideTypeModel.model_json_schema()
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ mypy_extensions==1.1.0
|
|||
numpy==2.2.5
|
||||
ollama==0.5.1
|
||||
onnxruntime==1.22.0
|
||||
openai==1.78.1
|
||||
openai==1.91.0
|
||||
orjson==3.10.18
|
||||
packaging==24.2
|
||||
pdfminer.six==20250327
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import uvicorn
|
||||
import argparse
|
||||
|
||||
|
|
@ -8,6 +9,8 @@ from api.main import app
|
|||
app
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the FastAPI server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, required=True, help="Port number to run the server on"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
import uvicorn
|
||||
import argparse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the FastAPI server")
|
||||
parser.add_argument(
|
||||
"--port", type=int, required=True, help="Port number to run the server on"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import os
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
|
||||
uvicorn.run(
|
||||
"api.main:app", host="0.0.0.0", port=8000, log_level="info", reload=True
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue