fix(fastapi): uses SlideOutlineModel to wrap slide outline, to fix json repair issue
This commit is contained in:
parent
c840b8fce9
commit
09f648dc5b
6 changed files with 27 additions and 10 deletions
|
|
@ -87,7 +87,8 @@ async def stream_outlines(
|
|||
|
||||
presentation.outlines = presentation_outlines.model_dump()
|
||||
presentation.title = (
|
||||
presentation_outlines.slides[0][:50]
|
||||
presentation_outlines.slides[0]
|
||||
.content[:50]
|
||||
.replace("#", "")
|
||||
.replace("/", "")
|
||||
.replace("\\", "")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
heading: str
|
||||
|
|
@ -7,5 +9,5 @@ class DocumentChunk(BaseModel):
|
|||
heading_index: int
|
||||
score: float
|
||||
|
||||
def to_slide_outline(self) -> str:
|
||||
return f"{self.heading}\n{self.content}"
|
||||
def to_slide_outline(self) -> SlideOutlineModel:
|
||||
return SlideOutlineModel(content=f"{self.heading}\n{self.content}")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,12 @@ from typing import List
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlideOutlineModel(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class PresentationOutlineModel(BaseModel):
|
||||
slides: List[str]
|
||||
slides: List[SlideOutlineModel]
|
||||
|
||||
def to_string(self):
|
||||
message = ""
|
||||
|
|
|
|||
|
|
@ -1,13 +1,23 @@
|
|||
from typing import List
|
||||
from pydantic import Field
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
def get_presentation_outline_model_with_n_slides(n_slides: int):
|
||||
class SlideOutlineModelWithNSlides(SlideOutlineModel):
|
||||
content: str = Field(
|
||||
description="Markdown content for each slide",
|
||||
min_length=100,
|
||||
max_length=300,
|
||||
)
|
||||
|
||||
class PresentationOutlineModelWithNSlides(PresentationOutlineModel):
|
||||
slides: List[str] = Field(
|
||||
description="Markdown content for each slide in about 100 to 200 words",
|
||||
slides: List[SlideOutlineModelWithNSlides] = Field(
|
||||
description="List of slide outlines",
|
||||
min_items=n_slides,
|
||||
max_items=n_slides,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from models.llm_message import LLMMessage
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
|
@ -50,7 +51,7 @@ def get_messages(outline: str, language: str):
|
|||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
slide_layout: SlideLayoutModel, outline: str, language: str
|
||||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
|
||||
):
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
|
@ -62,7 +63,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(
|
||||
outline,
|
||||
outline.content,
|
||||
language,
|
||||
),
|
||||
response_format=response_schema,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue