diff --git a/servers/fastapi/api/routers/presentation/handlers/edit.py b/servers/fastapi/api/routers/presentation/handlers/edit.py index 938f19b7..c4174927 100644 --- a/servers/fastapi/api/routers/presentation/handlers/edit.py +++ b/servers/fastapi/api/routers/presentation/handlers/edit.py @@ -15,7 +15,7 @@ from api.utils.utils import ( get_presentation_dir, get_presentation_images_dir, ) -from api.utils.model_utils import is_ollama_selected +from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected from image_processor.icons_vectorstore_utils import get_icons_vectorstore from image_processor.images_finder import generate_image from image_processor.icons_finder import get_icon @@ -68,13 +68,16 @@ class PresentationEditHandler: new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit) new_slide_type = new_slide_type.slide_type + supports_graph = not is_custom_llm_selected() if is_ollama_selected(): - model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")] - if not model.supports_graph: - if new_slide_type == 5: - new_slide_type = 1 - elif new_slide_type == 9: - new_slide_type = 6 + model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")] + supports_graph = model.supports_graph + + if not supports_graph: + if new_slide_type == 5: + new_slide_type = 1 + elif new_slide_type == 9: + new_slide_type = 6 edited_content = await get_edited_slide_content_model( self.prompt, diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_presentation.py b/servers/fastapi/api/routers/presentation/handlers/generate_presentation.py index 3123452e..55c4dfc3 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_presentation.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_presentation.py @@ -19,7 +19,7 @@ from api.services.instances import TEMP_FILE_SERVICE from api.services.logging import LoggingService from api.sql_models import PresentationSqlModel, SlideSqlModel from api.utils.utils import get_presentation_dir -from api.utils.model_utils import is_ollama_selected +from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected from document_processor.loader import DocumentsLoader from ppt_config_generator.document_summary_generator import generate_document_summary from ppt_config_generator.models import PresentationMarkdownModel @@ -45,7 +45,7 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin): TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir) async def post(self, logging_service: LoggingService, log_metadata: LogMetadata): - if is_ollama_selected(): + if is_ollama_selected() or is_custom_llm_selected(): raise HTTPException( status_code=400, detail="Ollama is not currently supported for this endpoint", diff --git a/servers/fastapi/graph_processor/models.py b/servers/fastapi/graph_processor/models.py new file mode 100644 index 00000000..9bd59b2f --- /dev/null +++ b/servers/fastapi/graph_processor/models.py @@ -0,0 +1,116 @@ +from enum import Enum +from typing import List, Optional +from pydantic import BaseModel, Field, model_validator + + +class PointModel(BaseModel): + x: float + y: float + + def to_list(self) -> List[float]: + return [self.x, self.y] + + +class PointWithRadius(PointModel): + radius: Optional[float] = None + + +class BarSeriesModel(BaseModel): + name: str + data: List[float] = Field( + description="Only numbers should be given out in data. Don't include text/string in data." + ) + + +class ScatterSeriesModel(BaseModel): + name: str + points: List[PointModel] + + +class BubbleSeriesModel(BaseModel): + name: str + points: List[PointWithRadius] + + +class LineSeriesModel(BaseModel): + name: str + data: List[float] = Field( + description="Only numbers should be given out in data. Don't include text/string in data." + ) + + +class PieChartSeriesModel(BaseModel): + data: List[float] + + +class BarGraphDataModel(BaseModel): + categories: List[str] + series: List[BarSeriesModel] = Field( + description="There should be no more than 3 series" + ) + + +class ScatterChartDataModel(BaseModel): + series: List[ScatterSeriesModel] + + +class BubbleChartDataModel(BaseModel): + series: List[BubbleSeriesModel] + + +class LineChartDataModel(BaseModel): + categories: List[str] + series: List[LineSeriesModel] = Field( + description="There should be no more than 3 series" + ) + + +class PieChartDataModel(BaseModel): + categories: List[str] + series: List[PieChartSeriesModel] = Field( + description="One series model with list of data", + min_length=1, + ) + + @model_validator(mode="after") + def limit_series(self): + self.series = self.series[:1] + return self + + +class GraphTypeEnum(Enum): + pie = "pie" + bar = "bar" + line = "line" + + +class LLMGraphModel(BaseModel): + name: str + type: GraphTypeEnum + unit: Optional[str] = Field( + description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc." + ) + data: PieChartDataModel | LineChartDataModel | BarGraphDataModel + + +class GraphModel(LLMGraphModel): + style: Optional[dict] = {} + + @classmethod + def from_llm_graph_model( + cls, llm_graph_model: LLMGraphModel, style: Optional[dict] = {} + ): + return cls( + name=llm_graph_model.name, + type=llm_graph_model.type, + unit=llm_graph_model.unit, + data=llm_graph_model.data, + style=style, + ) + + +GRAPH_TYPE_MAPPING = { + GraphTypeEnum.pie: PieChartDataModel, + GraphTypeEnum.bar: BarGraphDataModel, + GraphTypeEnum.line: LineChartDataModel, +} diff --git a/servers/fastapi/ppt_generator/models/content_type_models.py b/servers/fastapi/ppt_generator/models/content_type_models.py index 0bebd343..4e50ba8b 100644 --- a/servers/fastapi/ppt_generator/models/content_type_models.py +++ b/servers/fastapi/ppt_generator/models/content_type_models.py @@ -1,6 +1,8 @@ +from enum import Enum from typing import List, Mapping, Union from pydantic import BaseModel +from graph_processor.models import GraphModel, LLMGraphModel from ppt_generator.models.other_models import ( TYPE1, TYPE2, @@ -14,6 +16,13 @@ from ppt_generator.models.other_models import ( ) +class TableType(Enum): + TABLE = "table" + BAR = "bar" + LINE = "line" + PIE = "pie" + + class TableDataModel(BaseModel): x_labels: List[str] y_labels: List[str] @@ -22,6 +31,7 @@ class TableDataModel(BaseModel): class TableModel(BaseModel): name: str + type: TableType data: TableDataModel @@ -117,7 +127,8 @@ class Type4Content(SlideContentModel): class Type5Content(SlideContentModel): body: str - table: TableModel + # table: TableModel + graph: GraphModel def to_llm_content(self): from ppt_generator.models.llm_models import LLMType5Content @@ -125,7 +136,8 @@ class Type5Content(SlideContentModel): return LLMType5Content( title=self.title, body=self.body, - table=self.table, + # table=self.table, + graph=self.graph, ) @@ -181,7 +193,8 @@ class Type8Content(SlideContentModel): class Type9Content(SlideContentModel): body: List[HeadingModel] - table: TableModel + # table: TableModel + graph: GraphModel def to_llm_content(self): from ppt_generator.models.llm_models import LLMType9Content @@ -189,7 +202,8 @@ class Type9Content(SlideContentModel): return LLMType9Content( title=self.title, body=[item.to_llm_content() for item in self.body], - table=self.table, + # table=self.table, + graph=self.graph, ) diff --git a/servers/fastapi/ppt_generator/models/llm_models.py b/servers/fastapi/ppt_generator/models/llm_models.py index 430af6c3..9820bba6 100644 --- a/servers/fastapi/ppt_generator/models/llm_models.py +++ b/servers/fastapi/ppt_generator/models/llm_models.py @@ -1,10 +1,12 @@ from typing import List, Mapping, Union from pydantic import BaseModel +from graph_processor.models import GraphModel, LLMGraphModel from ppt_generator.models.content_type_models import ( HeadingModel, TableDataModel, TableModel, + TableType, Type1Content, Type2Content, Type3Content, @@ -36,6 +38,7 @@ class LLMTableDataModel(TableDataModel): class LLMTableModel(TableModel): name: str + type: TableType data: LLMTableDataModel @@ -121,13 +124,15 @@ class LLMType4Content(LLMSlideContentModel): class LLMType5Content(LLMSlideContentModel): body: str - table: LLMTableModel + # table: LLMTableModel + graph: LLMGraphModel def to_content(self) -> Type5Content: return Type5Content( title=self.title, body=self.body, - table=self.table, + # table=self.table, + graph=GraphModel.from_llm_graph_model(self.graph), ) @@ -169,13 +174,15 @@ class LLMType8Content(LLMSlideContentModel): class LLMType9Content(LLMSlideContentModel): body: List[LLMHeadingModel] - table: LLMTableModel + # table: LLMTableModel + graph: LLMGraphModel def to_content(self) -> Type9Content: return Type9Content( title=self.title, body=[each.to_content() for each in self.body], - table=self.table, + # table=self.table, + graph=GraphModel.from_llm_graph_model(self.graph), ) diff --git a/servers/fastapi/ppt_generator/models/llm_models_with_validations.py b/servers/fastapi/ppt_generator/models/llm_models_with_validations.py index 00074152..06e92dd4 100644 --- a/servers/fastapi/ppt_generator/models/llm_models_with_validations.py +++ b/servers/fastapi/ppt_generator/models/llm_models_with_validations.py @@ -1,6 +1,8 @@ from typing import List, Mapping, Union from pydantic import Field +from graph_processor.models import LLMGraphModel +from ppt_generator.models.content_type_models import TableType from ppt_generator.models.other_models import ( TYPE1, TYPE2, @@ -57,6 +59,7 @@ class LLMTableModelWithValidation(LLMTableModel): min_length=10, max_length=50, ) + type: TableType = Field(description="Type of the table") data: LLMTableDataModelWithValidation @@ -145,7 +148,8 @@ class LLMType5ContentWithValidation(LLMType5Content): min_length=50, max_length=300, ) - table: LLMTableModelWithValidation = Field(description="Table to show in slide") + # table: LLMTableModelWithValidation = Field(description="Table to show in slide") + graph: LLMGraphModel = Field(description="Graph to show in slide") class LLMType6ContentWithValidation(LLMType6Content): @@ -188,7 +192,8 @@ class LLMType9ContentWithValidation(LLMType9Content): min_length=1, max_length=3, ) - table: LLMTableModelWithValidation = Field(description="Table to show in slide") + # table: LLMTableModelWithValidation = Field(description="Table to show in slide") + graph: LLMGraphModel = Field(description="Graph to show in slide") LLMContentUnionWithValidation = Union[ diff --git a/servers/fastapi/ppt_generator/slide_generator.py b/servers/fastapi/ppt_generator/slide_generator.py index c37e6b44..715fbd6a 100644 --- a/servers/fastapi/ppt_generator/slide_generator.py +++ b/servers/fastapi/ppt_generator/slide_generator.py @@ -161,7 +161,8 @@ async def get_edited_slide_content_model( ), response_format=content_type_model_type, ) - return response.choices[0].message.parsed + response_data = response.choices[0].message.parsed + return response_data async def get_slide_type_from_prompt(