diff --git a/servers/fastapi/ppt_generator/models/content_type_models.py b/servers/fastapi/ppt_generator/models/content_type_models.py index 4d6b1513..01145157 100644 --- a/servers/fastapi/ppt_generator/models/content_type_models.py +++ b/servers/fastapi/ppt_generator/models/content_type_models.py @@ -1,4 +1,4 @@ -from typing import List, Mapping +from typing import List, Mapping, Union from pydantic import BaseModel from ppt_generator.models.other_models import ( @@ -196,7 +196,19 @@ class Type9Content(SlideContentModel): ) -CONTENT_TYPE_MAPPING: Mapping[int, SlideContentModel] = { +ContentUnion = Union[ + Type1Content, + Type2Content, + Type3Content, + Type4Content, + Type5Content, + Type6Content, + Type7Content, + Type8Content, + Type9Content, +] + +CONTENT_TYPE_MAPPING: Mapping[int, ContentUnion] = { TYPE1: Type1Content, TYPE2: Type2Content, TYPE3: Type3Content, diff --git a/servers/fastapi/ppt_generator/models/llm_models.py b/servers/fastapi/ppt_generator/models/llm_models.py index a08292e8..a8a08456 100644 --- a/servers/fastapi/ppt_generator/models/llm_models.py +++ b/servers/fastapi/ppt_generator/models/llm_models.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Mapping +from typing import List, Literal, Mapping, Union from pydantic import BaseModel, Field from ppt_generator.models.content_type_models import ( @@ -55,6 +55,10 @@ class LLMHeadingModel(BaseModel): ) +class LLMHeadingModelNew(LLMHeadingModel): + pass + + class LLMHeadingModelWithImagePrompt(LLMHeadingModel): image_prompt: str = Field( description="Item image prompt in less than 5 words", @@ -89,7 +93,7 @@ class LLMSlideContentModel(BaseModel): class LLMType1Content(LLMSlideContentModel): - type: Literal["1"] = "1" + content_type: Literal["1"] = "1" body: str = Field( description="Slide content summary in less than 30 words.", ) @@ -106,8 +110,11 @@ class LLMType1Content(LLMSlideContentModel): class LLMType2Content(LLMSlideContentModel): - type: Literal["2"] = "2" - body: List[LLMHeadingModel] = Field( + content_type: Literal["2"] = Field( + "2", + description="Content type", + ) + body: List[LLMHeadingModelNew] = Field( description="Items to show in slide", min_length=1, max_length=4, @@ -121,7 +128,7 @@ class LLMType2Content(LLMSlideContentModel): class LLMType3Content(LLMSlideContentModel): - type: Literal["3"] = "3" + content_type: Literal["3"] = "3" body: List[LLMHeadingModel] = Field( description="Items to show in slide", min_length=3, @@ -140,7 +147,7 @@ class LLMType3Content(LLMSlideContentModel): class LLMType4Content(LLMSlideContentModel): - type: Literal["4"] = "4" + content_type: Literal["4"] = "4" body: List[LLMHeadingModelWithImagePrompt] = Field( description="Items to show in slide", min_length=1, @@ -156,7 +163,7 @@ class LLMType4Content(LLMSlideContentModel): class LLMType5Content(LLMSlideContentModel): - type: Literal["5"] = "5" + content_type: Literal["5"] = "5" body: str = Field( description="Slide content summary in less than 30 words.", ) @@ -171,11 +178,14 @@ class LLMType5Content(LLMSlideContentModel): class LLMType6Content(LLMSlideContentModel): - type: Literal["6"] = "6" + content_type: Literal["6"] = Field( + "6", + description="Content type", + ) description: str = Field( description="Slide content summary in less than 20 words.", ) - body: List[LLMHeadingModel] = Field( + body: List[LLMHeadingModelNew] = Field( description="Items to show in slide", min_length=1, max_length=3, @@ -190,7 +200,7 @@ class LLMType6Content(LLMSlideContentModel): class LLMType7Content(LLMSlideContentModel): - type: Literal["7"] = "7" + content_type: Literal["7"] = "7" body: List[LLMHeadingModelWithIconQuery] = Field( description="Items to show in slide", min_length=1, @@ -206,7 +216,7 @@ class LLMType7Content(LLMSlideContentModel): class LLMType8Content(LLMSlideContentModel): - type: Literal["8"] = "8" + content_type: Literal["8"] = "8" description: str = Field( description="Slide content summary in less than 20 words.", ) @@ -226,7 +236,7 @@ class LLMType8Content(LLMSlideContentModel): class LLMType9Content(LLMSlideContentModel): - type: Literal["9"] = "9" + content_type: Literal["9"] = "9" body: List[LLMHeadingModel] = Field( description="Items to show in slide", min_length=1, @@ -254,20 +264,26 @@ LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMSlideContentModel] = { TYPE9: LLMType9Content, } +LLMContentUnion = Union[ + # LLMType1Content, + LLMType2Content, + # LLMType3Content, + # LLMType4Content, + # LLMType5Content, + LLMType6Content, + # LLMType7Content, + # LLMType8Content, + # LLMType9Content, +] + class LLMSlideModel(BaseModel): type: int - content: ( - LLMType1Content - | LLMType2Content - | LLMType4Content - | LLMType5Content - | LLMType6Content - # | LLMType7Content - # | LLMType8Content - # | LLMType9Content + content: LLMContentUnion = Field( + description="Content of the slide", + discriminator="content_type", ) class LLMPresentationModel(BaseModel): - slides: list[LLMSlideModel] + slides: List[LLMSlideModel] diff --git a/servers/fastapi/ppt_generator/slide_generator.py b/servers/fastapi/ppt_generator/slide_generator.py index c3ca4b87..31a1bbf9 100644 --- a/servers/fastapi/ppt_generator/slide_generator.py +++ b/servers/fastapi/ppt_generator/slide_generator.py @@ -7,6 +7,7 @@ from ppt_config_generator.models import SlideMarkdownModel from ppt_generator.models.llm_models import ( LLM_CONTENT_TYPE_MAPPING, + LLMContentUnion, ) from ppt_generator.models.other_models import SlideTypeModel from ppt_generator.models.slide_model import SlideModel @@ -115,7 +116,7 @@ def get_prompt_to_select_slide_type(prompt: str, slide_data: dict, slide_type: i async def get_slide_content_from_type_and_outline( slide_type: int, outline: SlideMarkdownModel -) -> BaseModel: +) -> LLMContentUnion: response_model = LLM_CONTENT_TYPE_MAPPING[slide_type] client = get_llm_client() @@ -139,7 +140,7 @@ async def get_edited_slide_content_model( slide: SlideModel, theme: Optional[dict] = None, language: Optional[str] = None, -): +) -> LLMContentUnion: client = get_llm_client() model = get_large_model()