fix(fastapi): revert tables to old graph implementations

This commit is contained in:
sauravniraula 2025-07-10 00:59:19 +05:45
parent 0c0f980678
commit 4aabf52e8c
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
7 changed files with 166 additions and 20 deletions

View file

@ -15,7 +15,7 @@ from api.utils.utils import (
get_presentation_dir,
get_presentation_images_dir,
)
from api.utils.model_utils import is_ollama_selected
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
from image_processor.images_finder import generate_image
from image_processor.icons_finder import get_icon
@ -68,13 +68,16 @@ class PresentationEditHandler:
new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit)
new_slide_type = new_slide_type.slide_type
supports_graph = not is_custom_llm_selected()
if is_ollama_selected():
model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")]
if not model.supports_graph:
if new_slide_type == 5:
new_slide_type = 1
elif new_slide_type == 9:
new_slide_type = 6
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
supports_graph = model.supports_graph
if not supports_graph:
if new_slide_type == 5:
new_slide_type = 1
elif new_slide_type == 9:
new_slide_type = 6
edited_content = await get_edited_slide_content_model(
self.prompt,

View file

@ -19,7 +19,7 @@ from api.services.instances import TEMP_FILE_SERVICE
from api.services.logging import LoggingService
from api.sql_models import PresentationSqlModel, SlideSqlModel
from api.utils.utils import get_presentation_dir
from api.utils.model_utils import is_ollama_selected
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
from document_processor.loader import DocumentsLoader
from ppt_config_generator.document_summary_generator import generate_document_summary
from ppt_config_generator.models import PresentationMarkdownModel
@ -45,7 +45,7 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
if is_ollama_selected():
if is_ollama_selected() or is_custom_llm_selected():
raise HTTPException(
status_code=400,
detail="Ollama is not currently supported for this endpoint",

View file

@ -0,0 +1,116 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
class PointModel(BaseModel):
x: float
y: float
def to_list(self) -> List[float]:
return [self.x, self.y]
class PointWithRadius(PointModel):
radius: Optional[float] = None
class BarSeriesModel(BaseModel):
name: str
data: List[float] = Field(
description="Only numbers should be given out in data. Don't include text/string in data."
)
class ScatterSeriesModel(BaseModel):
name: str
points: List[PointModel]
class BubbleSeriesModel(BaseModel):
name: str
points: List[PointWithRadius]
class LineSeriesModel(BaseModel):
name: str
data: List[float] = Field(
description="Only numbers should be given out in data. Don't include text/string in data."
)
class PieChartSeriesModel(BaseModel):
data: List[float]
class BarGraphDataModel(BaseModel):
categories: List[str]
series: List[BarSeriesModel] = Field(
description="There should be no more than 3 series"
)
class ScatterChartDataModel(BaseModel):
series: List[ScatterSeriesModel]
class BubbleChartDataModel(BaseModel):
series: List[BubbleSeriesModel]
class LineChartDataModel(BaseModel):
categories: List[str]
series: List[LineSeriesModel] = Field(
description="There should be no more than 3 series"
)
class PieChartDataModel(BaseModel):
categories: List[str]
series: List[PieChartSeriesModel] = Field(
description="One series model with list of data",
min_length=1,
)
@model_validator(mode="after")
def limit_series(self):
self.series = self.series[:1]
return self
class GraphTypeEnum(Enum):
pie = "pie"
bar = "bar"
line = "line"
class LLMGraphModel(BaseModel):
name: str
type: GraphTypeEnum
unit: Optional[str] = Field(
description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc."
)
data: PieChartDataModel | LineChartDataModel | BarGraphDataModel
class GraphModel(LLMGraphModel):
style: Optional[dict] = {}
@classmethod
def from_llm_graph_model(
cls, llm_graph_model: LLMGraphModel, style: Optional[dict] = {}
):
return cls(
name=llm_graph_model.name,
type=llm_graph_model.type,
unit=llm_graph_model.unit,
data=llm_graph_model.data,
style=style,
)
GRAPH_TYPE_MAPPING = {
GraphTypeEnum.pie: PieChartDataModel,
GraphTypeEnum.bar: BarGraphDataModel,
GraphTypeEnum.line: LineChartDataModel,
}

View file

@ -1,6 +1,8 @@
from enum import Enum
from typing import List, Mapping, Union
from pydantic import BaseModel
from graph_processor.models import GraphModel, LLMGraphModel
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
@ -14,6 +16,13 @@ from ppt_generator.models.other_models import (
)
class TableType(Enum):
TABLE = "table"
BAR = "bar"
LINE = "line"
PIE = "pie"
class TableDataModel(BaseModel):
x_labels: List[str]
y_labels: List[str]
@ -22,6 +31,7 @@ class TableDataModel(BaseModel):
class TableModel(BaseModel):
name: str
type: TableType
data: TableDataModel
@ -117,7 +127,8 @@ class Type4Content(SlideContentModel):
class Type5Content(SlideContentModel):
body: str
table: TableModel
# table: TableModel
graph: GraphModel
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType5Content
@ -125,7 +136,8 @@ class Type5Content(SlideContentModel):
return LLMType5Content(
title=self.title,
body=self.body,
table=self.table,
# table=self.table,
graph=self.graph,
)
@ -181,7 +193,8 @@ class Type8Content(SlideContentModel):
class Type9Content(SlideContentModel):
body: List[HeadingModel]
table: TableModel
# table: TableModel
graph: GraphModel
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType9Content
@ -189,7 +202,8 @@ class Type9Content(SlideContentModel):
return LLMType9Content(
title=self.title,
body=[item.to_llm_content() for item in self.body],
table=self.table,
# table=self.table,
graph=self.graph,
)

View file

@ -1,10 +1,12 @@
from typing import List, Mapping, Union
from pydantic import BaseModel
from graph_processor.models import GraphModel, LLMGraphModel
from ppt_generator.models.content_type_models import (
HeadingModel,
TableDataModel,
TableModel,
TableType,
Type1Content,
Type2Content,
Type3Content,
@ -36,6 +38,7 @@ class LLMTableDataModel(TableDataModel):
class LLMTableModel(TableModel):
name: str
type: TableType
data: LLMTableDataModel
@ -121,13 +124,15 @@ class LLMType4Content(LLMSlideContentModel):
class LLMType5Content(LLMSlideContentModel):
body: str
table: LLMTableModel
# table: LLMTableModel
graph: LLMGraphModel
def to_content(self) -> Type5Content:
return Type5Content(
title=self.title,
body=self.body,
table=self.table,
# table=self.table,
graph=GraphModel.from_llm_graph_model(self.graph),
)
@ -169,13 +174,15 @@ class LLMType8Content(LLMSlideContentModel):
class LLMType9Content(LLMSlideContentModel):
body: List[LLMHeadingModel]
table: LLMTableModel
# table: LLMTableModel
graph: LLMGraphModel
def to_content(self) -> Type9Content:
return Type9Content(
title=self.title,
body=[each.to_content() for each in self.body],
table=self.table,
# table=self.table,
graph=GraphModel.from_llm_graph_model(self.graph),
)

View file

@ -1,6 +1,8 @@
from typing import List, Mapping, Union
from pydantic import Field
from graph_processor.models import LLMGraphModel
from ppt_generator.models.content_type_models import TableType
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
@ -57,6 +59,7 @@ class LLMTableModelWithValidation(LLMTableModel):
min_length=10,
max_length=50,
)
type: TableType = Field(description="Type of the table")
data: LLMTableDataModelWithValidation
@ -145,7 +148,8 @@ class LLMType5ContentWithValidation(LLMType5Content):
min_length=50,
max_length=300,
)
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
graph: LLMGraphModel = Field(description="Graph to show in slide")
class LLMType6ContentWithValidation(LLMType6Content):
@ -188,7 +192,8 @@ class LLMType9ContentWithValidation(LLMType9Content):
min_length=1,
max_length=3,
)
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
graph: LLMGraphModel = Field(description="Graph to show in slide")
LLMContentUnionWithValidation = Union[

View file

@ -161,7 +161,8 @@ async def get_edited_slide_content_model(
),
response_format=content_type_model_type,
)
return response.choices[0].message.parsed
response_data = response.choices[0].message.parsed
return response_data
async def get_slide_type_from_prompt(