diff --git a/Dockerfile b/Dockerfile index ecf0dbbb..5fe801a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -50,4 +50,4 @@ COPY nginx.conf /etc/nginx/nginx.conf EXPOSE 80 # Start the servers -CMD ["/bin/bash", "-c", "ollama serve & service nginx start && service redis-server start && node /app/start.js"] \ No newline at end of file +CMD ["/bin/bash", "/app/docker-start.sh"] \ No newline at end of file diff --git a/Dockerfile.dev b/Dockerfile.dev index 1bf2c4db..664eb16a 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -32,6 +32,8 @@ RUN npm install # Install chrome for puppeteer RUN npx puppeteer browsers install chrome --install-deps +RUN chmod -R 777 /node_dependencies + # Copy nginx configuration COPY nginx.conf /etc/nginx/nginx.conf @@ -39,10 +41,4 @@ COPY nginx.conf /etc/nginx/nginx.conf EXPOSE 80 3000 8000 # Start the servers -CMD ["/bin/bash", "-c", "\ - rm -rf /app/servers/nextjs/node_modules && \ - ln -s /node_dependencies/node_modules /app/servers/nextjs/node_modules && \ - ollama serve & \ - service nginx start & \ - service redis-server start && \ - node /app/start.js"] \ No newline at end of file +CMD ["/bin/bash", "/app/docker-dev-start.sh"] \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 52262bf0..49cf6c49 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -48,7 +48,4 @@ services: - OPENAI_API_KEY=${OPENAI_API_KEY} - GOOGLE_API_KEY=${GOOGLE_API_KEY} - OLLAMA_MODEL=${OLLAMA_MODEL} - - PEXELS_API_KEY=${PEXELS_API_KEY} - - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2} - - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY} - - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT} \ No newline at end of file + - PEXELS_API_KEY=${PEXELS_API_KEY} \ No newline at end of file diff --git a/docker-dev-start.sh b/docker-dev-start.sh new file mode 100644 index 00000000..bd6ae2ad --- /dev/null +++ b/docker-dev-start.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +echo "Starting development server..." + +if [ -d "/node_dependencies/node_modules" ]; then + rm -rf /app/servers/nextjs/node_modules + mv /node_dependencies/node_modules /app/servers/nextjs +fi + +ollama serve & +service nginx start +service redis-server start +node /app/start.js diff --git a/docker-start.sh b/docker-start.sh new file mode 100644 index 00000000..22c80b91 --- /dev/null +++ b/docker-start.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +echo "Starting production server..." + +ollama serve & +service nginx start +service redis-server start +node /app/start.js diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_stream.py b/servers/fastapi/api/routers/presentation/handlers/generate_stream.py index 78d44aa5..d0125477 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_stream.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_stream.py @@ -24,7 +24,6 @@ from ppt_config_generator.models import ( PresentationStructureModel, ) from ppt_generator.generator import generate_presentation_stream -from ppt_generator.models.content_type_models import CONTENT_TYPE_MAPPING from ppt_generator.models.llm_models import ( LLM_CONTENT_TYPE_MAPPING, LLMPresentationModel, @@ -145,18 +144,21 @@ class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin async def generate_presentation_openai_google(self): presentation_text = "" - async for chunk in generate_presentation_stream( + async for event in await generate_presentation_stream( PresentationMarkdownModel( title=self.title, slides=self.outlines, notes=self.presentation.notes, ) ): - presentation_text += chunk.content - yield SSEResponse( - event="response", - data=json.dumps({"type": "chunk", "chunk": chunk.content}), - ).to_string() + print(event) + print("-" * 100) + return + # presentation_text += event + yield SSEResponse( + event="response", + data=json.dumps({"type": "chunk", "chunk": chunk}), + ).to_string() self.presentation_json = output_parser.parse(presentation_text) diff --git a/servers/fastapi/api/utils/utils.py b/servers/fastapi/api/utils/utils.py index 16b434d0..975b751d 100644 --- a/servers/fastapi/api/utils/utils.py +++ b/servers/fastapi/api/utils/utils.py @@ -129,35 +129,35 @@ async def download_files(urls: List[str], save_paths: List[str]): async def handle_errors( func, logging_service: LoggingService, log_metadata: LogMetadata, **kwargs ): - try: - logging_service.logger.info(f"START", extra=log_metadata.model_dump()) - response = await func( - logging_service=logging_service, log_metadata=log_metadata, **kwargs - ) - is_stream = isinstance(response, StreamingResponse) - logging_service.logger.info( - "STREAMING" if is_stream else "END", extra=log_metadata.model_dump() - ) - return response + # try: + logging_service.logger.info(f"START", extra=log_metadata.model_dump()) + response = await func( + logging_service=logging_service, log_metadata=log_metadata, **kwargs + ) + is_stream = isinstance(response, StreamingResponse) + logging_service.logger.info( + "STREAMING" if is_stream else "END", extra=log_metadata.model_dump() + ) + return response - except HTTPException as e: - log_metadata.status_code = e.status_code - logging_service.logger.error( - f"Raised HTTPException - {e.detail}", extra=log_metadata.model_dump() - ) - raise e - except Exception as e: - print(traceback.print_stack()) - print(traceback.print_exc()) + # except HTTPException as e: + # log_metadata.status_code = e.status_code + # logging_service.logger.error( + # f"Raised HTTPException - {e.detail}", extra=log_metadata.model_dump() + # ) + # raise e + # except Exception as e: + # print(traceback.print_stack()) + # print(traceback.print_exc()) - log_metadata.status_code = 400 - logging_service.logger.critical( - "Unhandled Exception", - exc_info=True, - stack_info=True, - extra=log_metadata.model_dump(), - ) - raise HTTPException(400, "Something went wrong while processing your request.") + # log_metadata.status_code = 400 + # logging_service.logger.critical( + # "Unhandled Exception", + # exc_info=True, + # stack_info=True, + # extra=log_metadata.model_dump(), + # ) + # raise HTTPException(400, "Something went wrong while processing your request.") def sanitize_filename(filename: str) -> str: diff --git a/servers/fastapi/graph_processor/models.py b/servers/fastapi/graph_processor/models.py deleted file mode 100644 index 6ced32b1..00000000 --- a/servers/fastapi/graph_processor/models.py +++ /dev/null @@ -1,132 +0,0 @@ -from enum import Enum -from typing import List, Optional -from pydantic import BaseModel, Field, model_validator - -from graph_processor.utils import clip_text - - -class PointModel(BaseModel): - x: float - y: float - - def to_list(self) -> List[float]: - return [self.x, self.y] - - -class PointWithRadius(PointModel): - radius: Optional[float] = None - - -class BarSeriesModel(BaseModel): - name: str - data: List[float] = Field( - description="Only numbers should be given out in data. Don't include text/string in data." - ) - - def get_name(self) -> str: - return clip_text(self.name) - - -class ScatterSeriesModel(BaseModel): - name: str - points: List[PointModel] - - def get_name(self) -> str: - return clip_text(self.name) - - -class BubbleSeriesModel(BaseModel): - name: str - points: List[PointWithRadius] - - def get_name(self) -> str: - return clip_text(self.name) - - -class LineSeriesModel(BaseModel): - name: str - data: List[float] = Field( - description="Only numbers should be given out in data. Don't include text/string in data." - ) - - def get_name(self) -> str: - return clip_text(self.name) - - -class PieChartSeriesModel(BaseModel): - data: List[float] - - -class BarGraphDataModel(BaseModel): - categories: List[str] - series: List[BarSeriesModel] = Field( - description="There should be no more than 3 series" - ) - - def get_categories(self) -> List[str]: - return [clip_text(category) for category in self.categories] - - -class ScatterChartDataModel(BaseModel): - series: List[ScatterSeriesModel] - - -class BubbleChartDataModel(BaseModel): - series: List[BubbleSeriesModel] - - -class LineChartDataModel(BaseModel): - categories: List[str] - series: List[LineSeriesModel] = Field( - description="There should be no more than 3 series" - ) - - def get_categories(self) -> List[str]: - return [clip_text(category) for category in self.categories] - - -class PieChartDataModel(BaseModel): - categories: List[str] - series: List[PieChartSeriesModel] = Field( - description="One series model with list of data", - min_length=1, - ) - - @model_validator(mode="after") - def limit_series(self): - self.series = self.series[:1] - return self - - def get_categories(self) -> List[str]: - return [clip_text(category) for category in self.categories] - - -# class TableDataModel(BaseModel): -# categories: List[str] -# series: List[BarSeriesModel] - -# def get_categories(self) -> List[str]: -# return [clip_text(category) for category in self.categories] - - -class GraphTypeEnum(Enum): - pie = "pie" - bar = "bar" - line = "line" - - -class GraphModel(BaseModel): - style: Optional[dict] = {} - name: str - type: GraphTypeEnum - unit: Optional[str] = Field( - description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc." - ) - data: PieChartDataModel | LineChartDataModel | BarGraphDataModel - - -GRAPH_TYPE_MAPPING = { - GraphTypeEnum.pie: PieChartDataModel, - GraphTypeEnum.bar: BarGraphDataModel, - GraphTypeEnum.line: LineChartDataModel, -} diff --git a/servers/fastapi/graph_processor/utils.py b/servers/fastapi/graph_processor/utils.py deleted file mode 100644 index 502145f0..00000000 --- a/servers/fastapi/graph_processor/utils.py +++ /dev/null @@ -1,3 +0,0 @@ -def clip_text(text: str, max_length: int = 6) -> str: - # return text[:max_length] + ".." if len(text) > max_length else text - return text diff --git a/servers/fastapi/ppt_config_generator/ppt_outlines_generator.py b/servers/fastapi/ppt_config_generator/ppt_outlines_generator.py index 3bfc17f3..06b5c1ce 100644 --- a/servers/fastapi/ppt_config_generator/ppt_outlines_generator.py +++ b/servers/fastapi/ppt_config_generator/ppt_outlines_generator.py @@ -1,33 +1,17 @@ from typing import Optional -from langchain_core.prompts import ChatPromptTemplate -from langchain_ollama import ChatOllama -from api.utils.model_utils import get_large_model +from api.utils.model_utils import get_large_model, get_llm_client from api.utils.variable_length_models import ( get_presentation_markdown_model_with_n_slides, ) from ppt_config_generator.models import PresentationMarkdownModel -from ppt_generator.fix_validation_errors import get_validated_response -user_prompt_text = { - "type": "text", - "text": """ - **Input:** - - Prompt: {prompt} - - Output Language: {language} - - Number of Slides: {n_slides} - - Additional Information: {content} - """, -} - - -def get_prompt_template(): - return ChatPromptTemplate.from_messages( - [ - ( - "system", - """ +def get_prompt_template(prompt: str, n_slides: int, language: str, content: str): + return [ + { + "role": "system", + "content": """ Create a presentation based on the provided prompt, number of slides, output language, and additional informational details. Format the output in the specified JSON schema with structured markdown content. @@ -50,13 +34,18 @@ def get_prompt_template(): - Slide **title** should not be in markdown format. - There must be exact **Number of Slides** as specified. """, - ), - ( - "user", - [user_prompt_text], - ), - ], - ) + }, + { + "role": "user", + "content": f""" + **Input:** + - Prompt: {prompt} + - Output Language: {language} + - Number of Slides: {n_slides} + - Additional Information: {content} + """, + }, + ] async def generate_ppt_content( @@ -65,21 +54,14 @@ async def generate_ppt_content( language: Optional[str] = None, content: Optional[str] = None, ) -> PresentationMarkdownModel: - model = ChatOllama(model=get_large_model(), temperature=0.8) + client = get_llm_client() + model = get_large_model() response_model = get_presentation_markdown_model_with_n_slides(n_slides) - chain = get_prompt_template() | model.with_structured_output( - response_model.model_json_schema() - ) - - return await get_validated_response( - chain, - { - "prompt": prompt, - "n_slides": n_slides, - "language": language or "English", - "content": content, - }, - response_model, - PresentationMarkdownModel, + response = await client.beta.chat.completions.parse( + model=model, + temperature=0.2, + messages=get_prompt_template(prompt, n_slides, language, content), + response_format=response_model, ) + return response.choices[0].message.parsed diff --git a/servers/fastapi/ppt_config_generator/structure_generator.py b/servers/fastapi/ppt_config_generator/structure_generator.py index 98275a6b..9e0f4972 100644 --- a/servers/fastapi/ppt_config_generator/structure_generator.py +++ b/servers/fastapi/ppt_config_generator/structure_generator.py @@ -1,7 +1,4 @@ -from langchain_core.prompts import ChatPromptTemplate -from langchain_ollama import ChatOllama - -from api.utils.model_utils import get_small_model +from api.utils.model_utils import get_llm_client, get_small_model from api.utils.variable_length_models import ( get_presentation_structure_model_with_n_slides, ) @@ -9,13 +6,13 @@ from ppt_config_generator.models import ( PresentationStructureModel, PresentationMarkdownModel, ) -from ppt_generator.fix_validation_errors import get_validated_response -prompt = ChatPromptTemplate.from_messages( - [ - ( - "system", - """ + +def get_prompt(n_slides: int, data: str): + return [ + { + "role": "system", + "content": f""" You're a professional presentation designer with years of experience in designing clear and engaging presentations. # Slide Types @@ -45,33 +42,32 @@ prompt = ChatPromptTemplate.from_messages( **Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.** """, - ), - ( - "human", - """ - {data} + }, + { + "role": "user", + "content": f""" + {data} """, - ), + }, ] -) async def generate_presentation_structure( presentation_outline: PresentationMarkdownModel, ) -> PresentationStructureModel: - model = ChatOllama(model=get_small_model(), temperature=0.8) + client = get_llm_client() + model = get_small_model() response_model = get_presentation_structure_model_with_n_slides( len(presentation_outline.slides) ) - chain = prompt | model.with_structured_output(response_model.model_json_schema()) - return await get_validated_response( - chain, - { - "n_slides": len(presentation_outline.slides), - "data": presentation_outline.to_string(), - }, - response_model, - PresentationStructureModel, + response = await client.beta.chat.completions.parse( + model=model, + temperature=0.2, + messages=get_prompt( + len(presentation_outline.slides), presentation_outline.to_string() + ), + response_format=response_model, ) + return response.choices[0].message.parsed diff --git a/servers/fastapi/ppt_generator/generator.py b/servers/fastapi/ppt_generator/generator.py index 82ba9bf5..fd44ec18 100644 --- a/servers/fastapi/ppt_generator/generator.py +++ b/servers/fastapi/ppt_generator/generator.py @@ -1,16 +1,8 @@ from typing import AsyncIterator -from langchain_core.messages import ( - HumanMessage, - AIMessageChunk, - AIMessage, -) -from langchain_ollama import ChatOllama -from api.utils.model_utils import get_large_model +from api.utils.model_utils import get_large_model, get_llm_client from ppt_config_generator.models import PresentationMarkdownModel -from ppt_generator.models.llm_models_with_validations import ( - LLMPresentationModelWithValidation, -) +from ppt_generator.models.llm_models import LLMPresentationModel CREATE_PRESENTATION_PROMPT = """ @@ -71,42 +63,56 @@ CREATE_PRESENTATION_PROMPT = """ **Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.** """ -schema = LLMPresentationModelWithValidation.model_json_schema() +# schema = LLMPresentationModel.model_json_schema() -system_prompt = f""" -{CREATE_PRESENTATION_PROMPT} +# system_prompt = f""" +# {CREATE_PRESENTATION_PROMPT} -Follow this schema while giving out response: {schema}. +# Follow this schema while giving out response: {schema}. -Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else. -""" +# Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else. +# """ -ollama_system_prompt = f""" -{CREATE_PRESENTATION_PROMPT} +# ollama_system_prompt = f""" +# {CREATE_PRESENTATION_PROMPT} -Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else. -""" +# Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else. +# """ -def get_model_and_messages( +async def generate_presentation_stream( presentation_outline: PresentationMarkdownModel, ): - user_message = HumanMessage(presentation_outline.to_string()) - model = ChatOllama(model=get_large_model(), temperature=0.8) + client = get_llm_client() + model = get_large_model() - return model, system_prompt, user_message - - -def generate_presentation_stream( - presentation_outline: PresentationMarkdownModel, -) -> AsyncIterator[AIMessageChunk]: - model, system_prompt, user_message = get_model_and_messages(presentation_outline) - - return model.astream([system_prompt, user_message]) + response = await client.chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": CREATE_PRESENTATION_PROMPT, + }, + { + "role": "user", + "content": presentation_outline.to_string(), + }, + ], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "LLMPresentationModel", + "schema": LLMPresentationModel.model_json_schema(), + }, + }, + stream=True, + ) + return response async def generate_presentation( presentation_outline: PresentationMarkdownModel, -) -> AIMessage: - model, system_prompt, user_message = get_model_and_messages(presentation_outline) - return await model.ainvoke([system_prompt, user_message]) +): + # model, system_prompt, user_message = get_model_and_messages(presentation_outline) + # return await model.ainvoke([system_prompt, user_message]) + pass diff --git a/servers/fastapi/ppt_generator/models/content_type_models.py b/servers/fastapi/ppt_generator/models/content_type_models.py index f13cb76c..4d6b1513 100644 --- a/servers/fastapi/ppt_generator/models/content_type_models.py +++ b/servers/fastapi/ppt_generator/models/content_type_models.py @@ -12,7 +12,17 @@ from ppt_generator.models.other_models import ( TYPE8, TYPE9, ) -from graph_processor.models import GraphModel + + +class TableDataModel(BaseModel): + x_labels: List[str] + y_labels: List[str] + data: List[List[float]] + + +class TableModel(BaseModel): + name: str + data: TableDataModel class HeadingModel(BaseModel): @@ -110,7 +120,7 @@ class Type4Content(SlideContentModel): class Type5Content(SlideContentModel): body: str - graph: GraphModel + table: TableModel def to_llm_content(self): from ppt_generator.models.llm_models import LLMType5Content @@ -118,7 +128,7 @@ class Type5Content(SlideContentModel): return LLMType5Content( title=self.title, body=self.body, - graph=self.graph, + table=self.table, ) @@ -174,7 +184,7 @@ class Type8Content(SlideContentModel): class Type9Content(SlideContentModel): body: List[HeadingModel] - graph: GraphModel + table: TableModel def to_llm_content(self): from ppt_generator.models.llm_models import LLMType9Content @@ -182,7 +192,7 @@ class Type9Content(SlideContentModel): return LLMType9Content( title=self.title, body=[item.to_llm_content() for item in self.body], - graph=self.graph, + table=self.table, ) diff --git a/servers/fastapi/ppt_generator/models/llm_models.py b/servers/fastapi/ppt_generator/models/llm_models.py index 1ff43feb..a08292e8 100644 --- a/servers/fastapi/ppt_generator/models/llm_models.py +++ b/servers/fastapi/ppt_generator/models/llm_models.py @@ -1,10 +1,11 @@ -from typing import List, Mapping -from pydantic import BaseModel +from typing import List, Literal, Mapping +from pydantic import BaseModel, Field -from graph_processor.models import GraphModel from ppt_generator.models.content_type_models import ( HeadingModel, SlideContentModel, + TableDataModel, + TableModel, Type1Content, Type2Content, Type3Content, @@ -28,9 +29,24 @@ from ppt_generator.models.other_models import ( ) +class LLMTableDataModel(TableDataModel): + x_labels: List[str] = Field(description="X labels of the table") + y_labels: List[str] = Field(description="Y labels of the table") + data: List[List[float]] = Field(description="Data of the table") + + +class LLMTableModel(TableModel): + name: str = Field(description="Name of the table") + data: LLMTableDataModel + + class LLMHeadingModel(BaseModel): - heading: str - description: str + heading: str = Field( + description="Item heading in less than 6 words", + ) + description: str = Field( + description="Item description in less than 15 words.", + ) def to_content(self) -> HeadingModel: return HeadingModel( @@ -40,23 +56,46 @@ class LLMHeadingModel(BaseModel): class LLMHeadingModelWithImagePrompt(LLMHeadingModel): - image_prompt: str + image_prompt: str = Field( + description="Item image prompt in less than 5 words", + ) + + def to_content(self) -> HeadingModel: + return HeadingModel( + heading=self.heading, + description=self.description, + ) class LLMHeadingModelWithIconQuery(LLMHeadingModel): - icon_query: str + icon_query: str = Field( + description="Item icon query in less than 5 words", + ) + + def to_content(self) -> HeadingModel: + return HeadingModel( + heading=self.heading, + description=self.description, + ) class LLMSlideContentModel(BaseModel): - title: str + # title: str = Field( + # description="Slide title in less than 8 words", + # ) def to_content(self) -> SlideContentModel: raise NotImplementedError("to_content method not implemented") class LLMType1Content(LLMSlideContentModel): - body: str - image_prompt: str + type: Literal["1"] = "1" + body: str = Field( + description="Slide content summary in less than 30 words.", + ) + image_prompt: str = Field( + description="Slide image prompt in less than 5 words", + ) def to_content(self) -> Type1Content: return Type1Content( @@ -67,7 +106,12 @@ class LLMType1Content(LLMSlideContentModel): class LLMType2Content(LLMSlideContentModel): - body: List[LLMHeadingModel] + type: Literal["2"] = "2" + body: List[LLMHeadingModel] = Field( + description="Items to show in slide", + min_length=1, + max_length=4, + ) def to_content(self) -> Type2Content: return Type2Content( @@ -77,8 +121,15 @@ class LLMType2Content(LLMSlideContentModel): class LLMType3Content(LLMSlideContentModel): - body: List[LLMHeadingModel] - image_prompt: str + type: Literal["3"] = "3" + body: List[LLMHeadingModel] = Field( + description="Items to show in slide", + min_length=3, + max_length=3, + ) + image_prompt: str = Field( + description="Slide image prompt in less than 5 words", + ) def to_content(self) -> Type3Content: return Type3Content( @@ -89,7 +140,12 @@ class LLMType3Content(LLMSlideContentModel): class LLMType4Content(LLMSlideContentModel): - body: List[LLMHeadingModelWithImagePrompt] + type: Literal["4"] = "4" + body: List[LLMHeadingModelWithImagePrompt] = Field( + description="Items to show in slide", + min_length=1, + max_length=3, + ) def to_content(self) -> Type4Content: return Type4Content( @@ -100,20 +156,30 @@ class LLMType4Content(LLMSlideContentModel): class LLMType5Content(LLMSlideContentModel): - body: str - graph: GraphModel + type: Literal["5"] = "5" + body: str = Field( + description="Slide content summary in less than 30 words.", + ) + table: LLMTableModel = Field(description="Table to show in slide") def to_content(self) -> Type5Content: return Type5Content( title=self.title, body=self.body, - graph=self.graph, + table=self.table, ) class LLMType6Content(LLMSlideContentModel): - description: str - body: List[LLMHeadingModel] + type: Literal["6"] = "6" + description: str = Field( + description="Slide content summary in less than 20 words.", + ) + body: List[LLMHeadingModel] = Field( + description="Items to show in slide", + min_length=1, + max_length=3, + ) def to_content(self) -> Type6Content: return Type6Content( @@ -124,7 +190,12 @@ class LLMType6Content(LLMSlideContentModel): class LLMType7Content(LLMSlideContentModel): - body: List[LLMHeadingModelWithIconQuery] + type: Literal["7"] = "7" + body: List[LLMHeadingModelWithIconQuery] = Field( + description="Items to show in slide", + min_length=1, + max_length=4, + ) def to_content(self) -> Type7Content: return Type7Content( @@ -135,8 +206,15 @@ class LLMType7Content(LLMSlideContentModel): class LLMType8Content(LLMSlideContentModel): - description: str - body: List[LLMHeadingModelWithImagePrompt] + type: Literal["8"] = "8" + description: str = Field( + description="Slide content summary in less than 20 words.", + ) + body: List[LLMHeadingModelWithImagePrompt] = Field( + description="Items to show in slide", + min_length=1, + max_length=3, + ) def to_content(self) -> Type8Content: return Type8Content( @@ -148,14 +226,19 @@ class LLMType8Content(LLMSlideContentModel): class LLMType9Content(LLMSlideContentModel): - body: List[LLMHeadingModel] - graph: GraphModel + type: Literal["9"] = "9" + body: List[LLMHeadingModel] = Field( + description="Items to show in slide", + min_length=1, + max_length=3, + ) + table: LLMTableModel = Field(description="Table to show in slide") def to_content(self) -> Type9Content: return Type9Content( title=self.title, body=[each.to_content() for each in self.body], - graph=self.graph, + table=self.table, ) @@ -180,9 +263,9 @@ class LLMSlideModel(BaseModel): | LLMType4Content | LLMType5Content | LLMType6Content - | LLMType7Content - | LLMType8Content - | LLMType9Content + # | LLMType7Content + # | LLMType8Content + # | LLMType9Content ) diff --git a/servers/fastapi/ppt_generator/models/llm_models_with_validations.py b/servers/fastapi/ppt_generator/models/llm_models_with_validations.py deleted file mode 100644 index 638d853f..00000000 --- a/servers/fastapi/ppt_generator/models/llm_models_with_validations.py +++ /dev/null @@ -1,257 +0,0 @@ -from typing import List, Mapping -from pydantic import Field - -from graph_processor.models import GraphModel -from ppt_generator.models.other_models import ( - TYPE1, - TYPE2, - TYPE3, - TYPE4, - TYPE5, - TYPE6, - TYPE7, - TYPE8, - TYPE9, -) -from ppt_generator.models.llm_models import ( - LLMHeadingModel, - LLMHeadingModelWithImagePrompt, - LLMHeadingModelWithIconQuery, - LLMSlideContentModel, - LLMType1Content, - LLMType2Content, - LLMType3Content, - LLMType4Content, - LLMType5Content, - LLMType6Content, - LLMType7Content, - LLMType8Content, - LLMType9Content, - LLMSlideModel, - LLMPresentationModel, -) - - -class LLMHeadingModelWithValidation(LLMHeadingModel): - heading: str = Field( - description="List item heading to show in slide body in less than 5 words.", - ) - description: str = Field( - description="Description of list item in less than 20 words.", - ) - - -class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt): - image_prompt: str = Field( - description="Prompt used to generate image for this item in less than 6 words.", - ) - - -class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery): - icon_query: str = Field( - description="Icon query to generate icon for this item in less than 4 words.", - ) - - -class LLMType1ContentWithValidation(LLMType1Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: str = Field( - description="Slide content summary in less than 30 words.", - ) - image_prompt: str = Field( - description="Prompt used to generate image for this slide in less than 6 words.", - ) - - @classmethod - def get_notes(cls): - return "" - - -class LLMType2ContentWithValidation(LLMType2Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: List[LLMHeadingModelWithValidation] = Field( - description="List items to show in slide's body", - min_length=1, - max_length=4, - ) - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **1 to 4 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -class LLMType3ContentWithValidation(LLMType3Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: List[LLMHeadingModelWithValidation] = Field( - description="List items to show in slide's body", - min_length=3, - max_length=3, - ) - image_prompt: str = Field( - description="Prompt used to generate image for this slide in less than 6 words.", - ) - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **3 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -class LLMType4ContentWithValidation(LLMType4Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: List[LLMHeadingModelWithImagePromptWithValidation] = Field( - description="List items to show in slide's body", - min_length=1, - max_length=3, - ) - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **1 to 3 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -class LLMType5ContentWithValidation(LLMType5Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: str = Field( - description="Slide content summary in less than 30 words.", - ) - graph: GraphModel = Field(description="Graph to show in slide") - - @classmethod - def get_notes(self): - return "" - - -class LLMType6ContentWithValidation(LLMType6Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - description: str = Field( - description="Slide content summary in less than 20 words.", - ) - body: List[LLMHeadingModelWithValidation] = Field( - description="List items to show in slide's body", - min_length=1, - max_length=3, - ) - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **1 to 3 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -class LLMType7ContentWithValidation(LLMType7Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: List[LLMHeadingModelWithIconQueryWithValidation] = Field( - description="List items to show in slide's body", - min_length=1, - max_length=4, - ) - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **1 to 4 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -class LLMType8ContentWithValidation(LLMType8Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - description: str = Field( - description="Slide content summary in less than 20 words.", - ) - body: List[LLMHeadingModelWithImagePromptWithValidation] = Field( - description="List items to show in slide's body", - min_length=1, - max_length=3, - ) - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **1 to 3 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -class LLMType9ContentWithValidation(LLMType9Content): - title: str = Field( - description="Title of the slide in less than 6 words.", - ) - body: List[LLMHeadingModelWithValidation] = Field( - description="List items to show in slide's body", - min_length=1, - max_length=3, - ) - graph: GraphModel = Field(description="Graph to show in slide") - - @classmethod - def get_notes(cls): - return """ - - The **Body** should include **1 to 3 HeadingModels**. - - Each **Heading** must consist of **1 to 3 words**. - - Each item **Description** can be upto 10 words. - """ - - -LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING: Mapping[int, LLMSlideContentModel] = { - TYPE1: LLMType1ContentWithValidation, - TYPE2: LLMType2ContentWithValidation, - TYPE3: LLMType3ContentWithValidation, - TYPE4: LLMType4ContentWithValidation, - TYPE5: LLMType5ContentWithValidation, - TYPE6: LLMType6ContentWithValidation, - TYPE7: LLMType7ContentWithValidation, - TYPE8: LLMType8ContentWithValidation, - TYPE9: LLMType9ContentWithValidation, -} - - -class LLMSlideModelWithValidation(LLMSlideModel): - type: int - content: ( - LLMType1ContentWithValidation - | LLMType2ContentWithValidation - | LLMType4ContentWithValidation - | LLMType5ContentWithValidation - | LLMType6ContentWithValidation - | LLMType7ContentWithValidation - | LLMType8ContentWithValidation - | LLMType9ContentWithValidation - ) - - -class LLMPresentationModelWithValidation(LLMPresentationModel): - slides: list[LLMSlideModelWithValidation] diff --git a/servers/fastapi/ppt_generator/models/pptx_models.py b/servers/fastapi/ppt_generator/models/pptx_models.py index 327f21dc..6808e6de 100644 --- a/servers/fastapi/ppt_generator/models/pptx_models.py +++ b/servers/fastapi/ppt_generator/models/pptx_models.py @@ -6,8 +6,6 @@ from pptx.util import Pt from pptx.enum.text import PP_ALIGN from pptx.enum.shapes import MSO_AUTO_SHAPE_TYPE, MSO_CONNECTOR_TYPE -from graph_processor.models import GraphModel - class PptxBoxShapeEnum(Enum): RECTANGLE = "rectangle" @@ -138,14 +136,6 @@ class PptxPictureBoxModel(PptxShapeModel): picture: PptxPictureModel -class PptxGraphBoxModel(PptxShapeModel): - position: PptxPositionModel - category_font: Optional[PptxFontModel] = None - value_font: Optional[PptxFontModel] = None - legend_font: Optional[PptxFontModel] = None - graph: GraphModel - - class PptxConnectorModel(PptxShapeModel): type: MSO_CONNECTOR_TYPE = MSO_CONNECTOR_TYPE.STRAIGHT position: PptxPositionModel @@ -159,13 +149,10 @@ class PptxSlideModel(BaseModel): | PptxAutoShapeBoxModel | PptxConnectorModel | PptxPictureBoxModel - | PptxGraphBoxModel ] class PptxPresentationModel(BaseModel): - # theme: PresentationTheme - # watermark: bool background_color: str shapes: Optional[List[PptxShapeModel]] = None slides: List[PptxSlideModel] diff --git a/servers/fastapi/ppt_generator/pptx_presentation_creator.py b/servers/fastapi/ppt_generator/pptx_presentation_creator.py index 9bac33ec..aa00250e 100644 --- a/servers/fastapi/ppt_generator/pptx_presentation_creator.py +++ b/servers/fastapi/ppt_generator/pptx_presentation_creator.py @@ -6,26 +6,12 @@ from lxml import etree from pptx import Presentation from pptx.shapes.autoshape import Shape from pptx.slide import Slide -from pptx.chart.data import ChartData, BubbleChartData -from pptx.chart.chart import Chart from pptx.text.text import _Paragraph, TextFrame, Font, _Run -from pptx.enum.chart import ( - XL_CHART_TYPE, - XL_LEGEND_POSITION, - XL_LABEL_POSITION, -) from pptx.opc.constants import RELATIONSHIP_TYPE as RT from lxml.etree import fromstring, tostring from PIL import Image from pptx.util import Pt -from graph_processor.models import ( - BarGraphDataModel, - BubbleChartDataModel, - GraphTypeEnum, - LineChartDataModel, - PieChartDataModel, -) from pptx.dml.color import RGBColor from ppt_generator.models.pptx_models import ( PptxAutoShapeBoxModel, @@ -33,7 +19,6 @@ from ppt_generator.models.pptx_models import ( PptxConnectorModel, PptxFillModel, PptxFontModel, - PptxGraphBoxModel, PptxParagraphModel, PptxPictureBoxModel, PptxPositionModel, @@ -63,8 +48,6 @@ class PptxPresentationCreator: self._ppt_model = ppt_model self._slide_models = ppt_model.slides - # self._theme = ppt_model.theme - # self._watermark = ppt_model.watermark self._ppt = Presentation() self._ppt.slide_width = Pt(1280) @@ -73,7 +56,6 @@ class PptxPresentationCreator: self._slide_fill = PptxFillModel(color=ppt_model.background_color) def create_ppt(self): - # self.set_presentation_theme() for slide_model in self._slide_models: # Adding global shapes to slide @@ -120,16 +102,9 @@ class PptxPresentationCreator: elif model_type is PptxTextBoxModel: self.add_textbox(slide, shape_model) - elif model_type is PptxGraphBoxModel: - self.add_graph(slide, shape_model) - elif model_type is PptxConnectorModel: self.add_connector(slide, shape_model) - # if self._watermark: - # Adding watermark - # self.add_picture(slide, self.get_watermark_box_model()) - def add_connector(self, slide: Slide, connector_model: PptxConnectorModel): if connector_model.thickness == 0: return @@ -139,126 +114,6 @@ class PptxPresentationCreator: connector_shape.line.width = Pt(connector_model.thickness) connector_shape.line.color.rgb = RGBColor.from_string(connector_model.color) - def add_graph(self, slide: Slide, graph_box_model: PptxGraphBoxModel): - chart_data = None - chart_type = None - graph = graph_box_model.graph - match (graph.type): - case GraphTypeEnum.bar: - chart_data = self.get_bar_graph(graph.data) - chart_type = XL_CHART_TYPE.COLUMN_CLUSTERED - - case GraphTypeEnum.scatter: - chart_data = self.get_scatter_graph(graph.data) - chart_type = XL_CHART_TYPE.XY_SCATTER - - case GraphTypeEnum.bubble: - chart_data = self.get_bubble_graph(graph.data) - chart_type = XL_CHART_TYPE.BUBBLE - - case GraphTypeEnum.line: - chart_data = self.get_line_graph(graph.data) - chart_type = XL_CHART_TYPE.LINE - - case GraphTypeEnum.pie: - chart_data = self.get_pie_graph(graph.data) - chart_type = XL_CHART_TYPE.PIE - - if chart_data: - chart: Chart = slide.shapes.add_chart( - chart_type, *graph_box_model.position.to_pt_list(), chart_data - ).chart - self.apply_graph_styles(chart, graph_box_model) - - def apply_graph_styles(self, chart, graph_box_model: PptxGraphBoxModel): - graph = graph_box_model.graph - - if graph.type in [GraphTypeEnum.pie, GraphTypeEnum.scatter]: - chart.has_legend = True - chart.legend.position = XL_LEGEND_POSITION.RIGHT - else: - chart.has_legend = False - - if graph_box_model.legend_font: - self.apply_font(chart.font, graph_box_model.legend_font) - - try: - category_axis = chart.category_axis - if graph_box_model.category_font: - font = category_axis.tick_labels.font - self.apply_font(font, graph_box_model.category_font) - except: - print("-" * 20) - print("Could not apply category labels style") - - try: - value_axis = chart.value_axis - tick_labels = value_axis.tick_labels - if graph.postfix: - tick_labels.number_format = f'0"{graph.postfix}"' - if graph_box_model.value_font: - self.apply_font(tick_labels.font, graph_box_model.value_font) - except: - print("-" * 20) - print("Could not apply tick labels style") - - if graph_box_model.graph.type is GraphTypeEnum.pie: - for plot in chart.plots: - try: - plot.has_data_labels = True - plot.data_labels.position = ( - XL_LABEL_POSITION.OUTSIDE_END - if graph_box_model.graph.type is GraphTypeEnum.bar - else XL_LABEL_POSITION.CENTER - ) - if graph.postfix: - plot.data_labels.number_format = f'0"{graph.postfix}"' - if graph_box_model.value_font: - self.apply_font( - plot.data_labels.font, - ( - graph_box_model.value_font - if graph_box_model.graph.type is GraphTypeEnum.bar - else PptxFontModel( - # size=self._theme.fonts.p2, - size=16, - bold=True, - color="ffffff", - ) - ), - ) - except: - print("-" * 20) - print("Could not apply data labels style") - - def get_bar_graph(self, graph: BarGraphDataModel): - chart_data = ChartData() - chart_data.categories = graph.get_categories() - for series in graph.series: - chart_data.add_series(series.get_name(), series.data) - return chart_data - - def get_bubble_graph(self, graph: BubbleChartDataModel): - chart_data = BubbleChartData() - for each in graph.series: - series = chart_data.add_series(each.get_name()) - for point in each.points: - series.add_data_point(*point.to_list()) - return chart_data - - def get_line_graph(self, graph: LineChartDataModel): - chart_data = ChartData() - chart_data.categories = graph.get_categories() - for series in graph.series: - chart_data.add_series(series.get_name(), series.data) - return chart_data - - def get_pie_graph(self, graph: PieChartDataModel): - chart_data = ChartData() - chart_data.categories = graph.get_categories() - chart_data.add_series("", graph.series[0].data) - return chart_data - def add_picture(self, slide: Slide, picture_model: PptxPictureBoxModel): image_path = picture_model.picture.path if ( @@ -562,17 +417,5 @@ class PptxPresentationCreator: font.italic = font_model.italic font.size = Pt(font_model.size) - # def get_watermark_box_model(self): - # watermark_asset_path = f"assets/images/{'watermark_dark.png' if self._theme == PresentationTheme.dark else 'watermark.png'}" - - # return PptxPictureBoxModel( - # position=PptxPositionModel(left=1120, top=685, width=140), - # clip=False, - # picture=PptxPictureModel( - # is_network=False, - # path=watermark_asset_path, - # ), - # ) - def save(self, path: str): self._ppt.save(path) diff --git a/servers/fastapi/ppt_generator/slide_generator.py b/servers/fastapi/ppt_generator/slide_generator.py index 45512a55..c3ca4b87 100644 --- a/servers/fastapi/ppt_generator/slide_generator.py +++ b/servers/fastapi/ppt_generator/slide_generator.py @@ -1,34 +1,24 @@ from typing import Optional -from langchain_ollama import ChatOllama -from openai import OpenAI +from pydantic import BaseModel + from api.utils.model_utils import get_large_model, get_llm_client, get_small_model from ppt_config_generator.models import SlideMarkdownModel -from ppt_generator.fix_validation_errors import get_validated_response - -from langchain_core.prompts import ChatPromptTemplate from ppt_generator.models.llm_models import ( LLM_CONTENT_TYPE_MAPPING, - LLMSlideContentModel, -) -from ppt_generator.models.llm_models_with_validations import ( - LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING, ) from ppt_generator.models.other_models import SlideTypeModel from ppt_generator.models.slide_model import SlideModel -def get_prompt_to_generate_slide_content( - title: str, outline: str, notes: Optional[str] = None -): +def get_prompt_to_generate_slide_content(title: str, outline: str): return [ { "role": "system", "content": f""" Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output. - # Steps 1. Analyze the outline and title. 2. Generate structured slide based on the outline and title. @@ -39,10 +29,6 @@ def get_prompt_to_generate_slide_content( - Slide body should not use words like "This slide", "This presentation". - Rephrase the slide body to make it flow naturally. - Do not use markdown formatting in slide body. - - **Icon query** must be a generic single word noun. - - **Image prompt** should be a 2-3 words phrase. - - Try to make paragraphs as short as possible. - {notes} """, }, { @@ -58,99 +44,93 @@ def get_prompt_to_generate_slide_content( ] -prompt_template_to_edit_slide_content = ChatPromptTemplate.from_messages( - [ - ( - "system", - """ +def get_prompt_to_edit_slide_content( + prompt: str, + slide_data: dict, + theme: Optional[dict] = None, + language: Optional[str] = None, +): + return [ + { + "role": "system", + "content": """ Edit Slide data based on provided prompt, follow mentioned steps and notes and provide structured output. # Notes - Provide output in language mentioned in **Input**. - The goal is to change Slide data based on the provided prompt. - Do not change **Image prompts** and **Icon queries** if not asked for in prompt. - - Generate **Image prompts** and **Icon queries** if asked to generate or change image or icons in prompt. - - Ensure there are no line breaks in the JSON. - - Do not use special characters for highlighting. - {notes} + - Generate **Image prompts** and **Icon queries** if asked to generate or change in prompt. **Go through all notes and steps and make sure they are followed, including mentioned constraints** """, - ), - ( - "user", - """ - - Prompt: {prompt} - - Output Language: {language} - - Image Prompts and Icon Queries Language: English - - Theme: {theme} - - Slide data: {slide_data} - """, - ), + }, + { + "role": "user", + "content": f""" + - Prompt: {prompt} + - Output Language: {language} + - Image Prompts and Icon Queries Language: English + - Theme: {theme} + - Slide data: {slide_data} + """, + }, ] -) -prompt_template_to_select_slide_type = ChatPromptTemplate.from_messages( - [ - ( - "system", - """ - Select a Slide Type based on provided user prompt and current slide data. +def get_prompt_to_select_slide_type(prompt: str, slide_data: dict, slide_type: int): + return [ + { + "role": "system", + "content": """ + Select a Slide Type based on provided user prompt and current slide data. - Select slide based on following slide description and make sure it matches user requirement: - # Slide Types (Slide Type : Slide Description) - - **1**: contains title, description and image. - - **2**: contains title and list of items. - - **4**: contains title and list of items with images. - - **5**: contains title, description and a graph. - - **6**: contains title, description and list of items. - - **7**: contains title and list of items with icons. - - **8**: contains title, description and list of items with icons. - - **9**: contains title, list of items and a graph. + Select slide based on following slide description and make sure it matches user requirement: + # Slide Types (Slide Type : Slide Description) + - **1**: contains title, description and image. + - **2**: contains title and list of items. + - **4**: contains title and list of items with images. + - **5**: contains title, description and a graph. + - **6**: contains title, description and list of items. + - **7**: contains title and list of items with icons. + - **8**: contains title, description and list of items with icons. + - **9**: contains title, list of items and a graph. - # Notes - - Do not select different slide type than current unless absolutely necessary as per user prompt. + # Notes + - Do not select different slide type than current unless absolutely necessary as per user prompt. - **Go through all notes and steps and make sure they are followed, including mentioned constraints** - """, - ), - ( - "user", - """ - - User Prompt: {prompt} - - Current Slide Data: {slide_data} - - Current Slide Type: {slide_type} - """, - ), + **Go through all notes and steps and make sure they are followed, including mentioned constraints** + """, + }, + { + "role": "user", + "content": f""" + - User Prompt: {prompt} + - Current Slide Data: {slide_data} + - Current Slide Type: {slide_type} + """, + }, ] -) async def get_slide_content_from_type_and_outline( slide_type: int, outline: SlideMarkdownModel -) -> LLMSlideContentModel: - response_model = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type] +) -> BaseModel: + response_model = LLM_CONTENT_TYPE_MAPPING[slide_type] client = get_llm_client() model = get_small_model() response = await client.beta.chat.completions.parse( model=model, + temperature=0.2, messages=get_prompt_to_generate_slide_content( outline.title, outline.body, - response_model.get_notes(), ), response_format=response_model, ) - - with open("debug/llm_response.json", "w") as f: - f.write(response.choices[0].message.content) - - return LLM_CONTENT_TYPE_MAPPING[slide_type].model_validate_json( - response.choices[0].message.content - ) + return response.choices[0].message.parsed async def get_edited_slide_content_model( @@ -160,28 +140,23 @@ async def get_edited_slide_content_model( theme: Optional[dict] = None, language: Optional[str] = None, ): - model = ChatOllama(model=get_large_model(), temperature=0.8) + client = get_llm_client() + model = get_large_model() - content_type_model_type = LLM_CONTENT_TYPE_WITH_VALIDATION_MAPPING[slide_type] - validation_model = LLM_CONTENT_TYPE_MAPPING[slide_type] - chain = prompt_template_to_edit_slide_content | model.with_structured_output( - content_type_model_type.model_json_schema() - ) + content_type_model_type = LLM_CONTENT_TYPE_MAPPING[slide_type] slide_data = slide.content.to_llm_content().model_dump_json() - edited_content = await get_validated_response( - chain, - { - "prompt": prompt, - "language": language or "English", - "theme": theme, - "slide_data": slide_data, - "notes": "", - }, - content_type_model_type, - validation_model, + response = await client.beta.chat.completions.parse( + model=model, + temperature=0.2, + messages=get_prompt_to_edit_slide_content( + prompt, + slide_data, + theme, + language, + ), + response_format=content_type_model_type, ) - - return edited_content.to_content() + return response.choices[0].message.parsed async def get_slide_type_from_prompt( @@ -189,18 +164,15 @@ async def get_slide_type_from_prompt( slide: SlideModel, ) -> SlideTypeModel: - model = ChatOllama(model=get_small_model(), temperature=0.8) + client = get_llm_client() + model = get_small_model() - chain = prompt_template_to_select_slide_type | model.with_structured_output( - SlideTypeModel.model_json_schema() - ) - slide_data = slide.content.to_llm_content().model_dump_json() - return await get_validated_response( - chain, - { - "prompt": prompt, - "slide_data": slide_data, - "slide_type": slide.type, - }, - SlideTypeModel, + response = await client.beta.chat.completions.parse( + model=model, + temperature=0.2, + messages=get_prompt_to_select_slide_type( + prompt, slide.content.to_llm_content().model_dump_json(), slide.type + ), + response_format=SlideTypeModel, ) + return response.choices[0].message.parsed