feat(fastapi): change all llm calls to use openai package
This commit is contained in:
parent
0239c794bd
commit
8ee5a4f53a
10 changed files with 369 additions and 177 deletions
|
|
@ -80,3 +80,4 @@ class SelectedLLMProvider(Enum):
|
|||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
CUSTOM = "custom"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import List
|
||||
import uuid, aiohttp
|
||||
from fastapi import HTTPException
|
||||
|
|
@ -32,8 +33,6 @@ from langchain_core.output_parsers import JsonOutputParser
|
|||
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
|
||||
|
||||
|
||||
class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
|
|
@ -80,19 +79,17 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
|||
|
||||
print("-" * 40)
|
||||
print("Generating Presentation")
|
||||
presentation_text = (
|
||||
await generate_presentation(
|
||||
PresentationMarkdownModel(
|
||||
title=presentation_content.title,
|
||||
slides=presentation_content.slides,
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
presentation_text = await generate_presentation(
|
||||
PresentationMarkdownModel(
|
||||
title=presentation_content.title,
|
||||
slides=presentation_content.slides,
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
).content
|
||||
)
|
||||
|
||||
print("-" * 40)
|
||||
print("Parsing Presentation")
|
||||
presentation_json = output_parser.parse(presentation_text)
|
||||
presentation_json = json.loads(presentation_text)
|
||||
|
||||
slide_models: List[SlideModel] = []
|
||||
for i, slide in enumerate(presentation_json["slides"]):
|
||||
|
|
|
|||
|
|
@ -35,8 +35,6 @@ from langchain_core.output_parsers import JsonOutputParser
|
|||
|
||||
from ppt_generator.slide_generator import get_slide_content_from_type_and_outline
|
||||
|
||||
output_parser = JsonOutputParser(pydantic_object=LLMPresentationModel)
|
||||
|
||||
|
||||
class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
|
|
@ -151,16 +149,14 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin
|
|||
notes=self.presentation.notes,
|
||||
)
|
||||
):
|
||||
print(event)
|
||||
print("-" * 100)
|
||||
return
|
||||
# presentation_text += event
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
chunk = event.choices[0].delta.content
|
||||
presentation_text += chunk
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = output_parser.parse(presentation_text)
|
||||
self.presentation_json = json.loads(presentation_text)
|
||||
|
||||
async def generate_presentation_ollama(self):
|
||||
presentation_structure = PresentationStructureModel(
|
||||
|
|
|
|||
|
|
@ -16,12 +16,16 @@ def get_selected_llm_provider() -> SelectedLLMProvider:
|
|||
def get_model_base_url():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
|
||||
if selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return "http://localhost:11434/v1"
|
||||
elif selected_llm == SelectedLLMProvider.OPENAI:
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
else:
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("LLM_PROVIDER_URL", "http://localhost:11434/v1")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("LLM_PROVIDER_URL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM provider: {selected_llm}")
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
|
|
@ -30,8 +34,12 @@ def get_llm_api_key():
|
|||
return os.getenv("OPENAI_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("LLM_API_KEY", "ollama")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("LLM_API_KEY")
|
||||
else:
|
||||
return "ollama"
|
||||
raise ValueError(f"Invalid LLM provider: {selected_llm}")
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
|
|
|
|||
|
|
@ -1,8 +1,15 @@
|
|||
from typing import AsyncIterator
|
||||
|
||||
from api.utils.model_utils import get_large_model, get_llm_client
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.utils.model_utils import (
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
get_selected_llm_provider,
|
||||
)
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_generator.models.llm_models import LLMPresentationModel
|
||||
from ppt_generator.models.llm_models_with_validations import (
|
||||
LLMPresentationModelWithValidation,
|
||||
)
|
||||
|
||||
|
||||
CREATE_PRESENTATION_PROMPT = """
|
||||
|
|
@ -63,56 +70,87 @@ CREATE_PRESENTATION_PROMPT = """
|
|||
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
|
||||
"""
|
||||
|
||||
# schema = LLMPresentationModel.model_json_schema()
|
||||
system_prompt_with_schema = f"""
|
||||
{CREATE_PRESENTATION_PROMPT}
|
||||
|
||||
# system_prompt = f"""
|
||||
# {CREATE_PRESENTATION_PROMPT}
|
||||
Follow this schema while giving out response: {LLMPresentationModelWithValidation.model_json_schema()}.
|
||||
|
||||
# Follow this schema while giving out response: {schema}.
|
||||
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
|
||||
"""
|
||||
|
||||
# Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
|
||||
# """
|
||||
|
||||
# ollama_system_prompt = f"""
|
||||
# {CREATE_PRESENTATION_PROMPT}
|
||||
def get_system_prompt():
|
||||
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
|
||||
return (
|
||||
system_prompt_with_schema if is_google_selected else CREATE_PRESENTATION_PROMPT
|
||||
)
|
||||
|
||||
# Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
|
||||
# """
|
||||
|
||||
def get_response_format():
|
||||
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
|
||||
return (
|
||||
{
|
||||
"type": "json_object",
|
||||
}
|
||||
if is_google_selected
|
||||
else {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "LLMPresentationModel",
|
||||
"schema": LLMPresentationModelWithValidation.model_json_schema(),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def generate_presentation_stream(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
):
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
response_format = get_response_format()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": CREATE_PRESENTATION_PROMPT,
|
||||
"content": get_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": presentation_outline.to_string(),
|
||||
},
|
||||
],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "LLMPresentationModel",
|
||||
"schema": LLMPresentationModel.model_json_schema(),
|
||||
},
|
||||
},
|
||||
response_format=response_format,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def generate_presentation(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
):
|
||||
# model, system_prompt, user_message = get_model_and_messages(presentation_outline)
|
||||
# return await model.ainvoke([system_prompt, user_message])
|
||||
pass
|
||||
) -> str:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
response_format = get_response_format()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": get_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": presentation_outline.to_string(),
|
||||
},
|
||||
],
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -57,9 +57,6 @@ class HeadingModel(BaseModel):
|
|||
class SlideContentModel(BaseModel):
|
||||
title: str
|
||||
|
||||
def to_llm_content(self):
|
||||
raise NotImplementedError("to_llm_content method not implemented")
|
||||
|
||||
|
||||
class Type1Content(SlideContentModel):
|
||||
body: str
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from typing import List, Literal, Mapping, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ppt_generator.models.content_type_models import (
|
||||
HeadingModel,
|
||||
SlideContentModel,
|
||||
TableDataModel,
|
||||
TableModel,
|
||||
Type1Content,
|
||||
|
|
@ -30,23 +29,19 @@ from ppt_generator.models.other_models import (
|
|||
|
||||
|
||||
class LLMTableDataModel(TableDataModel):
|
||||
x_labels: List[str] = Field(description="X labels of the table")
|
||||
y_labels: List[str] = Field(description="Y labels of the table")
|
||||
data: List[List[float]] = Field(description="Data of the table")
|
||||
x_labels: List[str]
|
||||
y_labels: List[str]
|
||||
data: List[List[float]]
|
||||
|
||||
|
||||
class LLMTableModel(TableModel):
|
||||
name: str = Field(description="Name of the table")
|
||||
name: str
|
||||
data: LLMTableDataModel
|
||||
|
||||
|
||||
class LLMHeadingModel(BaseModel):
|
||||
heading: str = Field(
|
||||
description="Item heading in less than 6 words",
|
||||
)
|
||||
description: str = Field(
|
||||
description="Item description in less than 15 words.",
|
||||
)
|
||||
heading: str
|
||||
description: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
|
|
@ -55,14 +50,8 @@ class LLMHeadingModel(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class LLMHeadingModelNew(LLMHeadingModel):
|
||||
pass
|
||||
|
||||
|
||||
class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
|
||||
image_prompt: str = Field(
|
||||
description="Item image prompt in less than 5 words",
|
||||
)
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
|
|
@ -72,9 +61,7 @@ class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
|
|||
|
||||
|
||||
class LLMHeadingModelWithIconQuery(LLMHeadingModel):
|
||||
icon_query: str = Field(
|
||||
description="Item icon query in less than 5 words",
|
||||
)
|
||||
icon_query: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
|
|
@ -84,22 +71,12 @@ class LLMHeadingModelWithIconQuery(LLMHeadingModel):
|
|||
|
||||
|
||||
class LLMSlideContentModel(BaseModel):
|
||||
# title: str = Field(
|
||||
# description="Slide title in less than 8 words",
|
||||
# )
|
||||
|
||||
def to_content(self) -> SlideContentModel:
|
||||
raise NotImplementedError("to_content method not implemented")
|
||||
title: str
|
||||
|
||||
|
||||
class LLMType1Content(LLMSlideContentModel):
|
||||
content_type: Literal["1"] = "1"
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Slide image prompt in less than 5 words",
|
||||
)
|
||||
body: str
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> Type1Content:
|
||||
return Type1Content(
|
||||
|
|
@ -110,15 +87,7 @@ class LLMType1Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType2Content(LLMSlideContentModel):
|
||||
content_type: Literal["2"] = Field(
|
||||
"2",
|
||||
description="Content type",
|
||||
)
|
||||
body: List[LLMHeadingModelNew] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
body: List[LLMHeadingModel]
|
||||
|
||||
def to_content(self) -> Type2Content:
|
||||
return Type2Content(
|
||||
|
|
@ -128,15 +97,8 @@ class LLMType2Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType3Content(LLMSlideContentModel):
|
||||
content_type: Literal["3"] = "3"
|
||||
body: List[LLMHeadingModel] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Slide image prompt in less than 5 words",
|
||||
)
|
||||
body: List[LLMHeadingModel]
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> Type3Content:
|
||||
return Type3Content(
|
||||
|
|
@ -147,12 +109,7 @@ class LLMType3Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType4Content(LLMSlideContentModel):
|
||||
content_type: Literal["4"] = "4"
|
||||
body: List[LLMHeadingModelWithImagePrompt] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePrompt]
|
||||
|
||||
def to_content(self) -> Type4Content:
|
||||
return Type4Content(
|
||||
|
|
@ -163,11 +120,8 @@ class LLMType4Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType5Content(LLMSlideContentModel):
|
||||
content_type: Literal["5"] = "5"
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
)
|
||||
table: LLMTableModel = Field(description="Table to show in slide")
|
||||
body: str
|
||||
table: LLMTableModel
|
||||
|
||||
def to_content(self) -> Type5Content:
|
||||
return Type5Content(
|
||||
|
|
@ -178,18 +132,8 @@ class LLMType5Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType6Content(LLMSlideContentModel):
|
||||
content_type: Literal["6"] = Field(
|
||||
"6",
|
||||
description="Content type",
|
||||
)
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelNew] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
description: str
|
||||
body: List[LLMHeadingModel]
|
||||
|
||||
def to_content(self) -> Type6Content:
|
||||
return Type6Content(
|
||||
|
|
@ -200,12 +144,7 @@ class LLMType6Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType7Content(LLMSlideContentModel):
|
||||
content_type: Literal["7"] = "7"
|
||||
body: List[LLMHeadingModelWithIconQuery] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
body: List[LLMHeadingModelWithIconQuery]
|
||||
|
||||
def to_content(self) -> Type7Content:
|
||||
return Type7Content(
|
||||
|
|
@ -216,15 +155,8 @@ class LLMType7Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType8Content(LLMSlideContentModel):
|
||||
content_type: Literal["8"] = "8"
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePrompt] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
description: str
|
||||
body: List[LLMHeadingModelWithImagePrompt]
|
||||
|
||||
def to_content(self) -> Type8Content:
|
||||
return Type8Content(
|
||||
|
|
@ -236,13 +168,8 @@ class LLMType8Content(LLMSlideContentModel):
|
|||
|
||||
|
||||
class LLMType9Content(LLMSlideContentModel):
|
||||
content_type: Literal["9"] = "9"
|
||||
body: List[LLMHeadingModel] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
table: LLMTableModel = Field(description="Table to show in slide")
|
||||
body: List[LLMHeadingModel]
|
||||
table: LLMTableModel
|
||||
|
||||
def to_content(self) -> Type9Content:
|
||||
return Type9Content(
|
||||
|
|
@ -252,7 +179,19 @@ class LLMType9Content(LLMSlideContentModel):
|
|||
)
|
||||
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
|
||||
LLMContentUnion = Union[
|
||||
LLMType1Content,
|
||||
LLMType2Content,
|
||||
LLMType3Content,
|
||||
LLMType4Content,
|
||||
LLMType5Content,
|
||||
LLMType6Content,
|
||||
LLMType7Content,
|
||||
LLMType8Content,
|
||||
LLMType9Content,
|
||||
]
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMContentUnion] = {
|
||||
TYPE1: LLMType1Content,
|
||||
TYPE2: LLMType2Content,
|
||||
TYPE3: LLMType3Content,
|
||||
|
|
@ -264,25 +203,10 @@ LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = {
|
|||
TYPE9: LLMType9Content,
|
||||
}
|
||||
|
||||
LLMContentUnion = Union[
|
||||
# LLMType1Content,
|
||||
LLMType2Content,
|
||||
# LLMType3Content,
|
||||
# LLMType4Content,
|
||||
# LLMType5Content,
|
||||
LLMType6Content,
|
||||
# LLMType7Content,
|
||||
# LLMType8Content,
|
||||
# LLMType9Content,
|
||||
]
|
||||
|
||||
|
||||
class LLMSlideModel(BaseModel):
|
||||
type: int
|
||||
content: LLMContentUnion = Field(
|
||||
description="Content of the slide",
|
||||
discriminator="content_type",
|
||||
)
|
||||
content: LLMContentUnion
|
||||
|
||||
|
||||
class LLMPresentationModel(BaseModel):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,227 @@
|
|||
from typing import List, Mapping, Union
|
||||
from pydantic import Field
|
||||
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
TYPE3,
|
||||
TYPE4,
|
||||
TYPE5,
|
||||
TYPE6,
|
||||
TYPE7,
|
||||
TYPE8,
|
||||
TYPE9,
|
||||
)
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLMTableDataModel,
|
||||
LLMTableModel,
|
||||
LLMHeadingModel,
|
||||
LLMHeadingModelWithImagePrompt,
|
||||
LLMHeadingModelWithIconQuery,
|
||||
LLMSlideContentModel,
|
||||
LLMType1Content,
|
||||
LLMType2Content,
|
||||
LLMType3Content,
|
||||
LLMType4Content,
|
||||
LLMType5Content,
|
||||
LLMType6Content,
|
||||
LLMType7Content,
|
||||
LLMType8Content,
|
||||
LLMType9Content,
|
||||
LLMSlideModel,
|
||||
LLMPresentationModel,
|
||||
)
|
||||
|
||||
|
||||
class LLMTableDataModelWithValidation(LLMTableDataModel):
|
||||
x_labels: List[str] = Field(
|
||||
description="X labels of the table",
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
)
|
||||
y_labels: List[str] = Field(
|
||||
description="Y labels of the table",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
data: List[List[float]] = Field(
|
||||
description="Data of the table",
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
)
|
||||
|
||||
|
||||
class LLMTableModelWithValidation(LLMTableModel):
|
||||
name: str = Field(
|
||||
description="Name of the table in less than 8 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
data: LLMTableDataModelWithValidation
|
||||
|
||||
|
||||
class LLMHeadingModelWithValidation(LLMHeadingModel):
|
||||
heading: str = Field(
|
||||
description="Item heading in less than 6 words",
|
||||
min_length=10,
|
||||
max_length=40,
|
||||
)
|
||||
description: str = Field(
|
||||
description="Item description in less than 15 words.",
|
||||
min_length=50,
|
||||
max_length=150,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
|
||||
image_prompt: str = Field(
|
||||
description="Item image prompt in less than 10 words",
|
||||
min_length=10,
|
||||
max_length=100,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
|
||||
icon_query: str = Field(
|
||||
description="Item icon query in less than 4 words",
|
||||
min_length=10,
|
||||
max_length=40,
|
||||
)
|
||||
|
||||
|
||||
class LLMSlideContentModelWithValidation(LLMSlideContentModel):
|
||||
title: str = Field(
|
||||
description="Slide title in less than 8 words",
|
||||
min_length=10,
|
||||
max_length=80,
|
||||
)
|
||||
|
||||
|
||||
class LLMType1ContentWithValidation(LLMType1Content):
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Slide image prompt in less than 5 words",
|
||||
min_length=10,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
|
||||
class LLMType2ContentWithValidation(LLMType2Content):
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
|
||||
|
||||
class LLMType3ContentWithValidation(LLMType3Content):
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Slide image prompt in less than 5 words",
|
||||
min_length=10,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
|
||||
class LLMType4ContentWithValidation(LLMType4Content):
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
|
||||
class LLMType5ContentWithValidation(LLMType5Content):
|
||||
body: str = Field(
|
||||
description="Slide content summary in less than 30 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
|
||||
|
||||
class LLMType6ContentWithValidation(LLMType6Content):
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
|
||||
class LLMType7ContentWithValidation(LLMType7Content):
|
||||
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
|
||||
|
||||
class LLMType8ContentWithValidation(LLMType8Content):
|
||||
description: str = Field(
|
||||
description="Slide content summary in less than 20 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
|
||||
class LLMType9ContentWithValidation(LLMType9Content):
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
|
||||
|
||||
LLMContentUnionWithValidation = Union[
|
||||
LLMType1ContentWithValidation,
|
||||
LLMType2ContentWithValidation,
|
||||
LLMType3ContentWithValidation,
|
||||
LLMType4ContentWithValidation,
|
||||
LLMType5ContentWithValidation,
|
||||
LLMType6ContentWithValidation,
|
||||
LLMType7ContentWithValidation,
|
||||
LLMType8ContentWithValidation,
|
||||
LLMType9ContentWithValidation,
|
||||
]
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION: Mapping[
|
||||
int, LLMContentUnionWithValidation
|
||||
] = {
|
||||
TYPE1: LLMType1ContentWithValidation,
|
||||
TYPE2: LLMType2ContentWithValidation,
|
||||
TYPE3: LLMType3ContentWithValidation,
|
||||
TYPE4: LLMType4ContentWithValidation,
|
||||
TYPE5: LLMType5ContentWithValidation,
|
||||
TYPE6: LLMType6ContentWithValidation,
|
||||
TYPE7: LLMType7ContentWithValidation,
|
||||
TYPE8: LLMType8ContentWithValidation,
|
||||
TYPE9: LLMType9ContentWithValidation,
|
||||
}
|
||||
|
||||
|
||||
class LLMSlideModelWithValidation(LLMSlideModel):
|
||||
type: int
|
||||
content: LLMContentUnionWithValidation
|
||||
|
||||
|
||||
class LLMPresentationModelWithValidation(LLMPresentationModel):
|
||||
slides: List[LLMSlideModelWithValidation]
|
||||
|
|
@ -9,6 +9,9 @@ from ppt_generator.models.llm_models import (
|
|||
LLM_CONTENT_TYPE_MAPPING,
|
||||
LLMContentUnion,
|
||||
)
|
||||
from ppt_generator.models.llm_models_with_validations import (
|
||||
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION,
|
||||
)
|
||||
from ppt_generator.models.other_models import SlideTypeModel
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
|
|
@ -117,20 +120,21 @@ def get_prompt_to_select_slide_type(prompt: str, slide_data: dict, slide_type: i
|
|||
async def get_slide_content_from_type_and_outline(
|
||||
slide_type: int, outline: SlideMarkdownModel
|
||||
) -> LLMContentUnion:
|
||||
response_model = LLM_CONTENT_TYPE_MAPPING[slide_type]
|
||||
response_model = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
temperature=0.5,
|
||||
messages=get_prompt_to_generate_slide_content(
|
||||
outline.title,
|
||||
outline.body,
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
|
||||
return response.choices[0].message.parsed
|
||||
|
||||
|
||||
|
|
@ -144,7 +148,7 @@ async def get_edited_slide_content_model(
|
|||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
content_type_model_type = LLM_CONTENT_TYPE_MAPPING[slide_type]
|
||||
content_type_model_type = LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION[slide_type]
|
||||
slide_data = slide.content.to_llm_content().model_dump_json()
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ SQLAlchemy==2.0.41
|
|||
sqlmodel==0.0.24
|
||||
starlette==0.46.2
|
||||
sympy==1.14.0
|
||||
tenacity==9.1.2
|
||||
tenacity==8.5.0
|
||||
tiktoken==0.9.0
|
||||
tokenizers==0.21.1
|
||||
tqdm==4.67.1
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue