fix(fastapi): revert tables to old graph implementations
This commit is contained in:
parent
0c0f980678
commit
4aabf52e8c
7 changed files with 166 additions and 20 deletions
|
|
@ -15,7 +15,7 @@ from api.utils.utils import (
|
|||
get_presentation_dir,
|
||||
get_presentation_images_dir,
|
||||
)
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
|
|
@ -68,13 +68,16 @@ class PresentationEditHandler:
|
|||
new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit)
|
||||
new_slide_type = new_slide_type.slide_type
|
||||
|
||||
supports_graph = not is_custom_llm_selected()
|
||||
if is_ollama_selected():
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("MODEL")]
|
||||
if not model.supports_graph:
|
||||
if new_slide_type == 5:
|
||||
new_slide_type = 1
|
||||
elif new_slide_type == 9:
|
||||
new_slide_type = 6
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
supports_graph = model.supports_graph
|
||||
|
||||
if not supports_graph:
|
||||
if new_slide_type == 5:
|
||||
new_slide_type = 1
|
||||
elif new_slide_type == 9:
|
||||
new_slide_type = 6
|
||||
|
||||
edited_content = await get_edited_slide_content_model(
|
||||
self.prompt,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from api.services.instances import TEMP_FILE_SERVICE
|
|||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_ollama_selected
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
|
|
@ -45,7 +45,7 @@ class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
|||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
if is_ollama_selected():
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Ollama is not currently supported for this endpoint",
|
||||
|
|
|
|||
116
servers/fastapi/graph_processor/models.py
Normal file
116
servers/fastapi/graph_processor/models.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class PointModel(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
def to_list(self) -> List[float]:
|
||||
return [self.x, self.y]
|
||||
|
||||
|
||||
class PointWithRadius(PointModel):
|
||||
radius: Optional[float] = None
|
||||
|
||||
|
||||
class BarSeriesModel(BaseModel):
|
||||
name: str
|
||||
data: List[float] = Field(
|
||||
description="Only numbers should be given out in data. Don't include text/string in data."
|
||||
)
|
||||
|
||||
|
||||
class ScatterSeriesModel(BaseModel):
|
||||
name: str
|
||||
points: List[PointModel]
|
||||
|
||||
|
||||
class BubbleSeriesModel(BaseModel):
|
||||
name: str
|
||||
points: List[PointWithRadius]
|
||||
|
||||
|
||||
class LineSeriesModel(BaseModel):
|
||||
name: str
|
||||
data: List[float] = Field(
|
||||
description="Only numbers should be given out in data. Don't include text/string in data."
|
||||
)
|
||||
|
||||
|
||||
class PieChartSeriesModel(BaseModel):
|
||||
data: List[float]
|
||||
|
||||
|
||||
class BarGraphDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[BarSeriesModel] = Field(
|
||||
description="There should be no more than 3 series"
|
||||
)
|
||||
|
||||
|
||||
class ScatterChartDataModel(BaseModel):
|
||||
series: List[ScatterSeriesModel]
|
||||
|
||||
|
||||
class BubbleChartDataModel(BaseModel):
|
||||
series: List[BubbleSeriesModel]
|
||||
|
||||
|
||||
class LineChartDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[LineSeriesModel] = Field(
|
||||
description="There should be no more than 3 series"
|
||||
)
|
||||
|
||||
|
||||
class PieChartDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[PieChartSeriesModel] = Field(
|
||||
description="One series model with list of data",
|
||||
min_length=1,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def limit_series(self):
|
||||
self.series = self.series[:1]
|
||||
return self
|
||||
|
||||
|
||||
class GraphTypeEnum(Enum):
|
||||
pie = "pie"
|
||||
bar = "bar"
|
||||
line = "line"
|
||||
|
||||
|
||||
class LLMGraphModel(BaseModel):
|
||||
name: str
|
||||
type: GraphTypeEnum
|
||||
unit: Optional[str] = Field(
|
||||
description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc."
|
||||
)
|
||||
data: PieChartDataModel | LineChartDataModel | BarGraphDataModel
|
||||
|
||||
|
||||
class GraphModel(LLMGraphModel):
|
||||
style: Optional[dict] = {}
|
||||
|
||||
@classmethod
|
||||
def from_llm_graph_model(
|
||||
cls, llm_graph_model: LLMGraphModel, style: Optional[dict] = {}
|
||||
):
|
||||
return cls(
|
||||
name=llm_graph_model.name,
|
||||
type=llm_graph_model.type,
|
||||
unit=llm_graph_model.unit,
|
||||
data=llm_graph_model.data,
|
||||
style=style,
|
||||
)
|
||||
|
||||
|
||||
GRAPH_TYPE_MAPPING = {
|
||||
GraphTypeEnum.pie: PieChartDataModel,
|
||||
GraphTypeEnum.bar: BarGraphDataModel,
|
||||
GraphTypeEnum.line: LineChartDataModel,
|
||||
}
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
from enum import Enum
|
||||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graph_processor.models import GraphModel, LLMGraphModel
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
|
|
@ -14,6 +16,13 @@ from ppt_generator.models.other_models import (
|
|||
)
|
||||
|
||||
|
||||
class TableType(Enum):
|
||||
TABLE = "table"
|
||||
BAR = "bar"
|
||||
LINE = "line"
|
||||
PIE = "pie"
|
||||
|
||||
|
||||
class TableDataModel(BaseModel):
|
||||
x_labels: List[str]
|
||||
y_labels: List[str]
|
||||
|
|
@ -22,6 +31,7 @@ class TableDataModel(BaseModel):
|
|||
|
||||
class TableModel(BaseModel):
|
||||
name: str
|
||||
type: TableType
|
||||
data: TableDataModel
|
||||
|
||||
|
||||
|
|
@ -117,7 +127,8 @@ class Type4Content(SlideContentModel):
|
|||
|
||||
class Type5Content(SlideContentModel):
|
||||
body: str
|
||||
table: TableModel
|
||||
# table: TableModel
|
||||
graph: GraphModel
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType5Content
|
||||
|
|
@ -125,7 +136,8 @@ class Type5Content(SlideContentModel):
|
|||
return LLMType5Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
table=self.table,
|
||||
# table=self.table,
|
||||
graph=self.graph,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -181,7 +193,8 @@ class Type8Content(SlideContentModel):
|
|||
|
||||
class Type9Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
table: TableModel
|
||||
# table: TableModel
|
||||
graph: GraphModel
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType9Content
|
||||
|
|
@ -189,7 +202,8 @@ class Type9Content(SlideContentModel):
|
|||
return LLMType9Content(
|
||||
title=self.title,
|
||||
body=[item.to_llm_content() for item in self.body],
|
||||
table=self.table,
|
||||
# table=self.table,
|
||||
graph=self.graph,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graph_processor.models import GraphModel, LLMGraphModel
|
||||
from ppt_generator.models.content_type_models import (
|
||||
HeadingModel,
|
||||
TableDataModel,
|
||||
TableModel,
|
||||
TableType,
|
||||
Type1Content,
|
||||
Type2Content,
|
||||
Type3Content,
|
||||
|
|
@ -36,6 +38,7 @@ class LLMTableDataModel(TableDataModel):
|
|||
|
||||
class LLMTableModel(TableModel):
|
||||
name: str
|
||||
type: TableType
|
||||
data: LLMTableDataModel
|
||||
|
||||
|
||||
|
|
@ -121,13 +124,15 @@ class LLMType4Content(LLMSlideContentModel):
|
|||
|
||||
class LLMType5Content(LLMSlideContentModel):
|
||||
body: str
|
||||
table: LLMTableModel
|
||||
# table: LLMTableModel
|
||||
graph: LLMGraphModel
|
||||
|
||||
def to_content(self) -> Type5Content:
|
||||
return Type5Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
table=self.table,
|
||||
# table=self.table,
|
||||
graph=GraphModel.from_llm_graph_model(self.graph),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -169,13 +174,15 @@ class LLMType8Content(LLMSlideContentModel):
|
|||
|
||||
class LLMType9Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModel]
|
||||
table: LLMTableModel
|
||||
# table: LLMTableModel
|
||||
graph: LLMGraphModel
|
||||
|
||||
def to_content(self) -> Type9Content:
|
||||
return Type9Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
table=self.table,
|
||||
# table=self.table,
|
||||
graph=GraphModel.from_llm_graph_model(self.graph),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from typing import List, Mapping, Union
|
||||
from pydantic import Field
|
||||
|
||||
from graph_processor.models import LLMGraphModel
|
||||
from ppt_generator.models.content_type_models import TableType
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
|
|
@ -57,6 +59,7 @@ class LLMTableModelWithValidation(LLMTableModel):
|
|||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
type: TableType = Field(description="Type of the table")
|
||||
data: LLMTableDataModelWithValidation
|
||||
|
||||
|
||||
|
|
@ -145,7 +148,8 @@ class LLMType5ContentWithValidation(LLMType5Content):
|
|||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
graph: LLMGraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
|
||||
class LLMType6ContentWithValidation(LLMType6Content):
|
||||
|
|
@ -188,7 +192,8 @@ class LLMType9ContentWithValidation(LLMType9Content):
|
|||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
graph: LLMGraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
|
||||
LLMContentUnionWithValidation = Union[
|
||||
|
|
|
|||
|
|
@ -161,7 +161,8 @@ async def get_edited_slide_content_model(
|
|||
),
|
||||
response_format=content_type_model_type,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
response_data = response.choices[0].message.parsed
|
||||
return response_data
|
||||
|
||||
|
||||
async def get_slide_type_from_prompt(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue