feat(fastapi): change all llm calls to use openai package

This commit is contained in:
sauravniraula 2025-06-28 00:44:43 +05:45
parent 0239c794bd
commit 8ee5a4f53a
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
10 changed files with 369 additions and 177 deletions

View file

@ -80,3 +80,4 @@ class SelectedLLMProvider(Enum):
OLLAMA = "ollama"
OPENAI = "openai"
GOOGLE = "google"
CUSTOM = "custom"

View file

@ -1,3 +1,4 @@
import json
from typing import List
import uuid, aiohttp
from fastapi import HTTPException
@ -32,8 +33,6 @@ from langchain_core.output_parsers import JsonOutputParser
from ppt_generator.models.slide_model import SlideModel
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
@ -80,19 +79,17 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
print("-" * 40)
print("Generating Presentation")
presentation_text = (
await generate_presentation(
PresentationMarkdownModel(
title=presentation_content.title,
slides=presentation_content.slides,
notes=presentation_content.notes,
)
presentation_text = await generate_presentation(
PresentationMarkdownModel(
title=presentation_content.title,
slides=presentation_content.slides,
notes=presentation_content.notes,
)
).content
)
print("-" * 40)
print("Parsing Presentation")
presentation_json = output_parser.parse(presentation_text)
presentation_json = json.loads(presentation_text)
slide_models: List[SlideModel] = []
for i, slide in enumerate(presentation_json["slides"]):

View file

@ -35,8 +35,6 @@ from langchain_core.output_parsers import JsonOutputParser
from ppt_generator.slide_generator import get_slide_content_from_type_and_outline
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin):
@ -151,16 +149,14 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin
notes=self.presentation.notes,
)
):
print(event)
print("-" * 100)
return
# presentation_text += event
yield SSEResponse(
event="response",
data=json.dumps({"type": "chunk", "chunk": chunk}),
).to_string()
chunk = event.choices[0].delta.content
presentation_text += chunk
yield SSEResponse(
event="response",
data=json.dumps({"type": "chunk", "chunk": chunk}),
).to_string()
self.presentation_json = output_parser.parse(presentation_text)
self.presentation_json = json.loads(presentation_text)
async def generate_presentation_ollama(self):
presentation_structure = PresentationStructureModel(

View file

@ -16,12 +16,16 @@ def get_selected_llm_provider() -> SelectedLLMProvider:
def get_model_base_url():
selected_llm = get_selected_llm_provider()
if selected_llm == SelectedLLMProvider.OLLAMA:
return "http://localhost:11434/v1"
elif selected_llm == SelectedLLMProvider.OPENAI:
if selected_llm == SelectedLLMProvider.OPENAI:
return "https://api.openai.com/v1"
else:
elif selected_llm == SelectedLLMProvider.GOOGLE:
return "https://generativelanguage.googleapis.com/v1beta/openai"
elif selected_llm == SelectedLLMProvider.OLLAMA:
return os.getenv("LLM_PROVIDER_URL", "http://localhost:11434/v1")
elif selected_llm == SelectedLLMProvider.CUSTOM:
return os.getenv("LLM_PROVIDER_URL")
else:
raise ValueError(f"Invalid LLM provider: {selected_llm}")
def get_llm_api_key():
@ -30,8 +34,12 @@ def get_llm_api_key():
return os.getenv("OPENAI_API_KEY")
elif selected_llm == SelectedLLMProvider.GOOGLE:
return os.getenv("GOOGLE_API_KEY")
elif selected_llm == SelectedLLMProvider.OLLAMA:
return os.getenv("LLM_API_KEY", "ollama")
elif selected_llm == SelectedLLMProvider.CUSTOM:
return os.getenv("LLM_API_KEY")
else:
return "ollama"
raise ValueError(f"Invalid LLM provider: {selected_llm}")
def get_llm_client():

View file

@ -1,8 +1,15 @@
from typing import AsyncIterator
from api.utils.model_utils import get_large_model, get_llm_client
from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from api.models import SelectedLLMProvider
from api.utils.model_utils import (
get_large_model,
get_llm_client,
get_selected_llm_provider,
)
from ppt_config_generator.models import PresentationMarkdownModel
from ppt_generator.models.llm_models import LLMPresentationModel
from ppt_generator.models.llm_models_with_validations import (
LLMPresentationModelWithValidation,
)
CREATE_PRESENTATION_PROMPT = """
@ -63,56 +70,87 @@ CREATE_PRESENTATION_PROMPT = """
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
"""
# schema = LLMPresentationModel.model_json_schema()
system_prompt_with_schema = f"""
{CREATE_PRESENTATION_PROMPT}
# system_prompt = f"""
# {CREATE_PRESENTATION_PROMPT}
Follow this schema while giving out response: {LLMPresentationModelWithValidation.model_json_schema()}.
# Follow this schema while giving out response: {schema}.
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
"""
# Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
# """
# ollama_system_prompt = f"""
# {CREATE_PRESENTATION_PROMPT}
def get_system_prompt():
    """Return the system prompt for the currently selected LLM provider.

    When the Google provider is selected the prompt variant with the JSON
    schema appended is returned (for Google, get_response_format() only asks
    for a generic JSON object, so the schema is carried in the prompt text).
    Every other provider gets the plain CREATE_PRESENTATION_PROMPT.
    """
    if get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
        return system_prompt_with_schema
    return CREATE_PRESENTATION_PROMPT
# Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
# """
def get_response_format():
    """Build the ``response_format`` argument for the chat-completions call.

    Returns:
        ``{"type": "json_object"}`` when the Google provider is selected;
        otherwise a ``json_schema`` response format named
        ``LLMPresentationModel`` whose schema is generated from
        ``LLMPresentationModelWithValidation``.
    """
    if get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
        return {"type": "json_object"}

    # The schema is generated on every call, exactly as before.
    presentation_schema = LLMPresentationModelWithValidation.model_json_schema()
    return {
        "type": "json_schema",
        "json_schema": {
            "name": "LLMPresentationModel",
            "schema": presentation_schema,
        },
    }
async def generate_presentation_stream(
presentation_outline: PresentationMarkdownModel,
):
) -> AsyncStream[ChatCompletionChunk]:
client = get_llm_client()
model = get_large_model()
response_format = get_response_format()
response = await client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": CREATE_PRESENTATION_PROMPT,
"content": get_system_prompt(),
},
{
"role": "user",
"content": presentation_outline.to_string(),
},
],
response_format={
"type": "json_schema",
"json_schema": {
"name": "LLMPresentationModel",
"schema": LLMPresentationModel.model_json_schema(),
},
},
response_format=response_format,
stream=True,
)
return response
async def generate_presentation(
presentation_outline: PresentationMarkdownModel,
):
# model, system_prompt, user_message = get_model_and_messages(presentation_outline)
# return await model.ainvoke([system_prompt, user_message])
pass
) -> str:
client = get_llm_client()
model = get_large_model()
response_format = get_response_format()
response = await client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": get_system_prompt(),
},
{
"role": "user",
"content": presentation_outline.to_string(),
},
],
response_format=response_format,
)
return response.choices[0].message.content

View file

@ -57,9 +57,6 @@ class HeadingModel(BaseModel):
class SlideContentModel(BaseModel):
title: str
def to_llm_content(self):
raise NotImplementedError("to_llm_content method not implemented")
class Type1Content(SlideContentModel):
body: str

View file

@ -1,9 +1,8 @@
from typing import List, Literal, Mapping, Union
from pydantic import BaseModel, Field
from typing import List, Mapping, Union
from pydantic import BaseModel
from ppt_generator.models.content_type_models import (
HeadingModel,
SlideContentModel,
TableDataModel,
TableModel,
Type1Content,
@ -30,23 +29,19 @@ from ppt_generator.models.other_models import (
class LLMTableDataModel(TableDataModel):
x_labels: List[str] = Field(description="X labels of the table")
y_labels: List[str] = Field(description="Y labels of the table")
data: List[List[float]] = Field(description="Data of the table")
x_labels: List[str]
y_labels: List[str]
data: List[List[float]]
class LLMTableModel(TableModel):
name: str = Field(description="Name of the table")
name: str
data: LLMTableDataModel
class LLMHeadingModel(BaseModel):
heading: str = Field(
description="Item heading in less than 6 words",
)
description: str = Field(
description="Item description in less than 15 words.",
)
heading: str
description: str
def to_content(self) -> HeadingModel:
return HeadingModel(
@ -55,14 +50,8 @@ class LLMHeadingModel(BaseModel):
)
class LLMHeadingModelNew(LLMHeadingModel):
pass
class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
image_prompt: str = Field(
description="Item image prompt in less than 5 words",
)
image_prompt: str
def to_content(self) -> HeadingModel:
return HeadingModel(
@ -72,9 +61,7 @@ class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
class LLMHeadingModelWithIconQuery(LLMHeadingModel):
icon_query: str = Field(
description="Item icon query in less than 5 words",
)
icon_query: str
def to_content(self) -> HeadingModel:
return HeadingModel(
@ -84,22 +71,12 @@ class LLMHeadingModelWithIconQuery(LLMHeadingModel):
class LLMSlideContentModel(BaseModel):
# title: str = Field(
# description="Slide title in less than 8 words",
# )
def to_content(self) -> SlideContentModel:
raise NotImplementedError("to_content method not implemented")
title: str
class LLMType1Content(LLMSlideContentModel):
content_type: Literal["1"] = "1"
body: str = Field(
description="Slide content summary in less than 30 words.",
)
image_prompt: str = Field(
description="Slide image prompt in less than 5 words",
)
body: str
image_prompt: str
def to_content(self) -> Type1Content:
return Type1Content(
@ -110,15 +87,7 @@ class LLMType1Content(LLMSlideContentModel):
class LLMType2Content(LLMSlideContentModel):
content_type: Literal["2"] = Field(
"2",
description="Content type",
)
body: List[LLMHeadingModelNew] = Field(
description="Items to show in slide",
min_length=1,
max_length=4,
)
body: List[LLMHeadingModel]
def to_content(self) -> Type2Content:
return Type2Content(
@ -128,15 +97,8 @@ class LLMType2Content(LLMSlideContentModel):
class LLMType3Content(LLMSlideContentModel):
content_type: Literal["3"] = "3"
body: List[LLMHeadingModel] = Field(
description="Items to show in slide",
min_length=3,
max_length=3,
)
image_prompt: str = Field(
description="Slide image prompt in less than 5 words",
)
body: List[LLMHeadingModel]
image_prompt: str
def to_content(self) -> Type3Content:
return Type3Content(
@ -147,12 +109,7 @@ class LLMType3Content(LLMSlideContentModel):
class LLMType4Content(LLMSlideContentModel):
content_type: Literal["4"] = "4"
body: List[LLMHeadingModelWithImagePrompt] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
body: List[LLMHeadingModelWithImagePrompt]
def to_content(self) -> Type4Content:
return Type4Content(
@ -163,11 +120,8 @@ class LLMType4Content(LLMSlideContentModel):
class LLMType5Content(LLMSlideContentModel):
content_type: Literal["5"] = "5"
body: str = Field(
description="Slide content summary in less than 30 words.",
)
table: LLMTableModel = Field(description="Table to show in slide")
body: str
table: LLMTableModel
def to_content(self) -> Type5Content:
return Type5Content(
@ -178,18 +132,8 @@ class LLMType5Content(LLMSlideContentModel):
class LLMType6Content(LLMSlideContentModel):
content_type: Literal["6"] = Field(
"6",
description="Content type",
)
description: str = Field(
description="Slide content summary in less than 20 words.",
)
body: List[LLMHeadingModelNew] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
description: str
body: List[LLMHeadingModel]
def to_content(self) -> Type6Content:
return Type6Content(
@ -200,12 +144,7 @@ class LLMType6Content(LLMSlideContentModel):
class LLMType7Content(LLMSlideContentModel):
content_type: Literal["7"] = "7"
body: List[LLMHeadingModelWithIconQuery] = Field(
description="Items to show in slide",
min_length=1,
max_length=4,
)
body: List[LLMHeadingModelWithIconQuery]
def to_content(self) -> Type7Content:
return Type7Content(
@ -216,15 +155,8 @@ class LLMType7Content(LLMSlideContentModel):
class LLMType8Content(LLMSlideContentModel):
content_type: Literal["8"] = "8"
description: str = Field(
description="Slide content summary in less than 20 words.",
)
body: List[LLMHeadingModelWithImagePrompt] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
description: str
body: List[LLMHeadingModelWithImagePrompt]
def to_content(self) -> Type8Content:
return Type8Content(
@ -236,13 +168,8 @@ class LLMType8Content(LLMSlideContentModel):
class LLMType9Content(LLMSlideContentModel):
content_type: Literal["9"] = "9"
body: List[LLMHeadingModel] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
table: LLMTableModel = Field(description="Table to show in slide")
body: List[LLMHeadingModel]
table: LLMTableModel
def to_content(self) -> Type9Content:
return Type9Content(
@ -252,7 +179,19 @@ class LLMType9Content(LLMSlideContentModel):
)
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
LLMContentUnion = Union[
LLMType1Content,
LLMType2Content,
LLMType3Content,
LLMType4Content,
LLMType5Content,
LLMType6Content,
LLMType7Content,
LLMType8Content,
LLMType9Content,
]
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMContentUnion] = {
TYPE1: LLMType1Content,
TYPE2: LLMType2Content,
TYPE3: LLMType3Content,
@ -264,25 +203,10 @@ LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
TYPE9: LLMType9Content,
}
LLMContentUnion = Union[
# LLMType1Content,
LLMType2Content,
# LLMType3Content,
# LLMType4Content,
# LLMType5Content,
LLMType6Content,
# LLMType7Content,
# LLMType8Content,
# LLMType9Content,
]
class LLMSlideModel(BaseModel):
type: int
content: LLMContentUnion = Field(
description="Content of the slide",
discriminator="content_type",
)
content: LLMContentUnion
class LLMPresentationModel(BaseModel):

View file

@ -0,0 +1,227 @@
from typing import List, Mapping, Union
from pydantic import Field
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
TYPE3,
TYPE4,
TYPE5,
TYPE6,
TYPE7,
TYPE8,
TYPE9,
)
from ppt_generator.models.llm_models import (
LLMTableDataModel,
LLMTableModel,
LLMHeadingModel,
LLMHeadingModelWithImagePrompt,
LLMHeadingModelWithIconQuery,
LLMSlideContentModel,
LLMType1Content,
LLMType2Content,
LLMType3Content,
LLMType4Content,
LLMType5Content,
LLMType6Content,
LLMType7Content,
LLMType8Content,
LLMType9Content,
LLMSlideModel,
LLMPresentationModel,
)
class LLMTableDataModelWithValidation(LLMTableDataModel):
    # Table data with item-count limits on every field.
    # Kept as "#" comments on purpose: a pydantic model docstring would be
    # emitted as the schema "description" and change what is sent to the LLM.
    x_labels: List[str] = Field(
        min_length=1,
        max_length=5,
        description="X labels of the table",
    )
    y_labels: List[str] = Field(
        min_length=1,
        max_length=3,
        description="Y labels of the table",
    )
    # These bounds constrain the outer list only; the inner lists carry no
    # length limit here.
    data: List[List[float]] = Field(
        min_length=1,
        max_length=5,
        description="Data of the table",
    )
class LLMTableModelWithValidation(LLMTableModel):
    # Table whose name is limited to 10-50 characters and whose data uses
    # the length-bounded table data model.
    name: str = Field(
        min_length=10,
        max_length=50,
        description="Name of the table in less than 8 words",
    )
    data: LLMTableDataModelWithValidation
class LLMHeadingModelWithValidation(LLMHeadingModel):
    # Heading item with character limits: heading 10-40, description 50-150.
    heading: str = Field(
        min_length=10,
        max_length=40,
        description="Item heading in less than 6 words",
    )
    description: str = Field(
        min_length=50,
        max_length=150,
        description="Item description in less than 15 words.",
    )
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
    # Only image_prompt (10-100 characters) is constrained here.
    # NOTE(review): heading and description are inherited from
    # LLMHeadingModelWithImagePrompt without the limits declared on
    # LLMHeadingModelWithValidation - confirm that is intended.
    image_prompt: str = Field(
        min_length=10,
        max_length=100,
        description="Item image prompt in less than 10 words",
    )
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
    # Only icon_query (10-40 characters) is constrained here.
    # NOTE(review): heading and description are inherited from
    # LLMHeadingModelWithIconQuery without the limits declared on
    # LLMHeadingModelWithValidation - confirm that is intended.
    icon_query: str = Field(
        min_length=10,
        max_length=40,
        description="Item icon query in less than 4 words",
    )
class LLMSlideContentModelWithValidation(LLMSlideContentModel):
    # Slide content base with a title limited to 10-80 characters.
    # NOTE(review): none of the LLMType*ContentWithValidation classes visible
    # in this file derive from this class (they extend LLMType*Content), so
    # this title constraint does not reach them - confirm whether intended.
    title: str = Field(
        min_length=10,
        max_length=80,
        description="Slide title in less than 8 words",
    )
class LLMType1ContentWithValidation(LLMType1Content):
    # Type 1 content with character limits: body 50-300, image_prompt 10-30.
    body: str = Field(
        min_length=50,
        max_length=300,
        description="Slide content summary in less than 30 words.",
    )
    image_prompt: str = Field(
        min_length=10,
        max_length=30,
        description="Slide image prompt in less than 5 words",
    )
class LLMType2ContentWithValidation(LLMType2Content):
    # Type 2 content: between 1 and 4 validated heading items.
    body: List[LLMHeadingModelWithValidation] = Field(
        min_length=1,
        max_length=4,
        description="Items to show in slide",
    )
class LLMType3ContentWithValidation(LLMType3Content):
    # Type 3 content: exactly 3 validated heading items (min == max == 3)
    # plus an image prompt of 10-30 characters.
    body: List[LLMHeadingModelWithValidation] = Field(
        min_length=3,
        max_length=3,
        description="Items to show in slide",
    )
    image_prompt: str = Field(
        min_length=10,
        max_length=30,
        description="Slide image prompt in less than 5 words",
    )
class LLMType4ContentWithValidation(LLMType4Content):
    # Type 4 content: between 1 and 3 heading items carrying image prompts.
    body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
        min_length=1,
        max_length=3,
        description="Items to show in slide",
    )
class LLMType5ContentWithValidation(LLMType5Content):
    # Type 5 content: body of 50-300 characters plus a validated table.
    body: str = Field(
        min_length=50,
        max_length=300,
        description="Slide content summary in less than 30 words.",
    )
    table: LLMTableModelWithValidation = Field(
        description="Table to show in slide",
    )
class LLMType6ContentWithValidation(LLMType6Content):
    # Type 6 content: description of 50-300 characters and between 1 and 3
    # validated heading items.
    description: str = Field(
        min_length=50,
        max_length=300,
        description="Slide content summary in less than 20 words.",
    )
    body: List[LLMHeadingModelWithValidation] = Field(
        min_length=1,
        max_length=3,
        description="Items to show in slide",
    )
class LLMType7ContentWithValidation(LLMType7Content):
    # Type 7 content: between 1 and 4 heading items carrying icon queries.
    body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
        min_length=1,
        max_length=4,
        description="Items to show in slide",
    )
class LLMType8ContentWithValidation(LLMType8Content):
    # Type 8 content: description of 50-300 characters and between 1 and 3
    # heading items carrying image prompts.
    description: str = Field(
        min_length=50,
        max_length=300,
        description="Slide content summary in less than 20 words.",
    )
    body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
        min_length=1,
        max_length=3,
        description="Items to show in slide",
    )
class LLMType9ContentWithValidation(LLMType9Content):
    # Type 9 content: between 1 and 3 validated heading items plus a
    # validated table.
    body: List[LLMHeadingModelWithValidation] = Field(
        min_length=1,
        max_length=3,
        description="Items to show in slide",
    )
    table: LLMTableModelWithValidation = Field(
        description="Table to show in slide",
    )
# Union of the nine validated slide content models, listed in type order
# (1 through 9). Used as the type of LLMSlideModelWithValidation.content.
LLMContentUnionWithValidation = Union[
LLMType1ContentWithValidation,
LLMType2ContentWithValidation,
LLMType3ContentWithValidation,
LLMType4ContentWithValidation,
LLMType5ContentWithValidation,
LLMType6ContentWithValidation,
LLMType7ContentWithValidation,
LLMType8ContentWithValidation,
LLMType9ContentWithValidation,
]
# Lookup from a slide type constant (TYPE1..TYPE9) to the validated content
# model used for that slide type.
# NOTE(review): the values are model classes, not instances, so the value
# annotation would more precisely be Type[LLMContentUnionWithValidation].
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION: Mapping[
int, LLMContentUnionWithValidation
] = {
TYPE1: LLMType1ContentWithValidation,
TYPE2: LLMType2ContentWithValidation,
TYPE3: LLMType3ContentWithValidation,
TYPE4: LLMType4ContentWithValidation,
TYPE5: LLMType5ContentWithValidation,
TYPE6: LLMType6ContentWithValidation,
TYPE7: LLMType7ContentWithValidation,
TYPE8: LLMType8ContentWithValidation,
TYPE9: LLMType9ContentWithValidation,
}
class LLMSlideModelWithValidation(LLMSlideModel):
    # Slide whose content must be one of the validated content models.
    type: int
    content: LLMContentUnionWithValidation
class LLMPresentationModelWithValidation(LLMPresentationModel):
    # Presentation whose slides use the validated slide model.
    slides: List[LLMSlideModelWithValidation]

View file

@ -9,6 +9,9 @@ from ppt_generator.models.llm_models import (
LLM_CONTENT_TYPE_MAPPING,
LLMContentUnion,
)
from ppt_generator.models.llm_models_with_validations import (
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION,
)
from ppt_generator.models.other_models import SlideTypeModel
from ppt_generator.models.slide_model import SlideModel
@ -117,20 +120,21 @@ def get_prompt_to_select_slide_type(prompt: str, slide_data: dict, slide_type: i
async def get_slide_content_from_type_and_outline(
slide_type: int, outline: SlideMarkdownModel
) -> LLMContentUnion:
response_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
response_model = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
client = get_llm_client()
model = get_small_model()
response = await client.beta.chat.completions.parse(
model=model,
temperature=0.2,
temperature=0.5,
messages=get_prompt_to_generate_slide_content(
outline.title,
outline.body,
),
response_format=response_model,
)
return response.choices[0].message.parsed
@ -144,7 +148,7 @@ async def get_edited_slide_content_model(
client = get_llm_client()
model = get_large_model()
content_type_model_type = LLM_CONTENT_TYPE_MAPPING[slide_type]
content_type_model_type = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
slide_data = slide.content.to_llm_content().model_dump_json()
response = await client.beta.chat.completions.parse(
model=model,

View file

@ -102,7 +102,7 @@ SQLAlchemy==2.0.41
sqlmodel==0.0.24
starlette==0.46.2
sympy==1.14.0
tenacity==9.1.2
tenacity==8.5.0
tiktoken==0.9.0
tokenizers==0.21.1
tqdm==4.67.1