feat(fastapi): adds slides assets genetation logic
This commit is contained in:
parent
98cc85a931
commit
1e1637f286
5 changed files with 86 additions and 27 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from typing import Annotated, List, Optional
|
||||
|
|
@ -23,6 +24,7 @@ from utils.llm_calls.generate_presentation_structure import (
|
|||
from utils.llm_calls.generate_slide_content import (
|
||||
get_slide_content_from_type_and_outline,
|
||||
)
|
||||
from utils.process_slides import process_slide_and_fetch_assets
|
||||
|
||||
PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])
|
||||
|
||||
|
|
@ -185,6 +187,8 @@ async def stream_presentation(presentation_id: str):
|
|||
layout = presentation.get_layout()
|
||||
outline = presentation.get_presentation_outline()
|
||||
|
||||
asyncio_tasks = []
|
||||
|
||||
slides: List[SlideModel] = []
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
|
|
@ -202,6 +206,7 @@ async def stream_presentation(presentation_id: str):
|
|||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
asyncio_tasks.append(process_slide_and_fetch_assets(slide))
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
|
||||
|
|
@ -212,6 +217,8 @@ async def stream_presentation(presentation_id: str):
|
|||
data=json.dumps({"type": "chunk", "chunk": " ] }"}),
|
||||
).to_string()
|
||||
|
||||
await asyncio.gather(*asyncio_tasks)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slides)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class ContactInfoModel(BaseModel):
|
|||
|
||||
class ImageModel(BaseModel):
|
||||
url: str = Field(description="Image URL")
|
||||
image_type_: Literal["image"] = "image"
|
||||
__image_type__: Literal["image"] = "image"
|
||||
prompt: str = Field(description="Image prompt")
|
||||
|
||||
|
||||
|
|
@ -415,13 +415,16 @@ presentation_layout = PresentationLayoutModel(
|
|||
],
|
||||
)
|
||||
|
||||
print(json.dumps(FirstSlideModel.model_json_schema()))
|
||||
# print(json.dumps(FirstSlideModel.model_json_schema()))
|
||||
|
||||
# slide_schema = FirstSlideModel.model_json_schema()
|
||||
|
||||
# schema_processor = SchemaProcessor()
|
||||
# print(schema_processor.flatten_schema(slide_schema))
|
||||
# print(schema_processor.find_dict_paths_in_object(slide_schema, "_image_type"))
|
||||
# print(
|
||||
# json.dumps(
|
||||
# schema_processor.remove_image_url_fields(FirstSlideModel.model_json_schema())
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# print(PresentationOutlineModel.model_json_schema())
|
||||
|
|
|
|||
|
|
@ -27,15 +27,15 @@ class SchemaProcessor:
|
|||
return self.resolve_refs(schema, defs)
|
||||
|
||||
def find_dict_with_key(
|
||||
self, schema: dict, target_key: str, current_path: Optional[List[str]] = None
|
||||
self, data: dict, target_key: str, current_path: Optional[List[str]] = None
|
||||
) -> List[List[str]]:
|
||||
if current_path is None:
|
||||
current_path = []
|
||||
paths = []
|
||||
if target_key in schema:
|
||||
if target_key in data:
|
||||
paths.append(current_path.copy())
|
||||
|
||||
for key, value in schema.items():
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict):
|
||||
new_path = current_path + [key]
|
||||
paths.extend(self.find_dict_with_key(value, target_key, new_path))
|
||||
|
|
@ -48,8 +48,8 @@ class SchemaProcessor:
|
|||
)
|
||||
return paths
|
||||
|
||||
def get_dict_at_path(self, schema: dict, path: List[str]) -> dict:
|
||||
current = schema
|
||||
def get_dict_at_path(self, data: dict, path: List[str]) -> dict:
|
||||
current = data
|
||||
|
||||
for part in path:
|
||||
if part.isdigit():
|
||||
|
|
@ -59,19 +59,47 @@ class SchemaProcessor:
|
|||
|
||||
return current
|
||||
|
||||
def remove_image_url_fields(self, schema: dict) -> dict:
|
||||
copied_schema = schema.copy()
|
||||
def set_dict_at_path(self, data: dict, path: List[str], value) -> None:
|
||||
if not path:
|
||||
raise ValueError("Path cannot be empty")
|
||||
|
||||
image_type_paths = self.find_dict_with_key(copied_schema, "_image_type")
|
||||
current = data
|
||||
|
||||
# Navigate to the parent of the target location
|
||||
for part in path[:-1]:
|
||||
if part.isdigit():
|
||||
index = int(part)
|
||||
if index >= len(current):
|
||||
# Extend list if needed
|
||||
current.extend([{}] * (index - len(current) + 1))
|
||||
current = current[index]
|
||||
else:
|
||||
if part not in current:
|
||||
current[part] = {}
|
||||
current = current[part]
|
||||
|
||||
# Set the value at the final path component
|
||||
final_part = path[-1]
|
||||
if final_part.isdigit():
|
||||
index = int(final_part)
|
||||
if index >= len(current):
|
||||
# Extend list if needed
|
||||
current.extend([None] * (index - len(current) + 1))
|
||||
current[index] = value
|
||||
else:
|
||||
current[final_part] = value
|
||||
|
||||
def remove_image_url_fields(self, data: dict) -> dict:
|
||||
copied_data = data.copy()
|
||||
|
||||
image_type_paths = self.find_dict_with_key(copied_data, "__image_type__")
|
||||
|
||||
for path in image_type_paths:
|
||||
dict_at_path = self.get_dict_at_path(copied_schema, path)
|
||||
dict_at_path = self.get_dict_at_path(copied_data, path)
|
||||
if "properties" in dict_at_path:
|
||||
del dict_at_path["properties"]["url"]
|
||||
dict_at_parent_path = self.get_dict_at_path(copied_data, path[:-1])
|
||||
if "required" in dict_at_parent_path:
|
||||
dict_at_parent_path["required"].remove("url")
|
||||
|
||||
if dict_at_path.get("_image_type") == "image":
|
||||
if "properties" in dict_at_path and "url" in dict_at_path["properties"]:
|
||||
del dict_at_path["properties"]["url"]
|
||||
|
||||
if "required" in dict_at_path and "url" in dict_at_path["required"]:
|
||||
dict_at_path["required"].remove("url")
|
||||
|
||||
return copied_schema
|
||||
return copied_data
|
||||
|
|
|
|||
|
|
@ -66,9 +66,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "SlideContent",
|
||||
"schema": SCHEMA_PROCESSOR.remove_image_url_fields(
|
||||
slide_layout.json_schema
|
||||
),
|
||||
"schema": slide_layout.json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,33 @@
|
|||
import os
|
||||
from typing import List, Tuple
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.asset import ImageAsset
|
||||
from models.sql.slide import SlideModel
|
||||
from services import SCHEMA_PROCESSOR
|
||||
from services.icon_finder_service import IconFinderService
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
|
||||
|
||||
def process_slide_and_fetch_assets(
|
||||
async def process_slide_and_fetch_assets(
|
||||
slide: SlideModel, layout: SlideLayoutModel
|
||||
) -> Tuple[SlideModel, List[ImageAsset]]:
|
||||
pass
|
||||
) -> SlideModel:
|
||||
image_directory = os.path.join(get_app_data_directory_env(), "images")
|
||||
|
||||
image_generation_service = ImageGenerationService(image_directory)
|
||||
icon_finder_service = IconFinderService()
|
||||
|
||||
image_type_paths = SCHEMA_PROCESSOR.find_dict_with_key(
|
||||
slide.content, "__image_type__"
|
||||
)
|
||||
for path in image_type_paths:
|
||||
image_dict = SCHEMA_PROCESSOR.get_dict_at_path(slide.content, path)
|
||||
image_prompt = image_dict["prompt"]
|
||||
if image_dict["__image_type__"] == "image":
|
||||
image_path = await image_generation_service.generate_image(image_prompt)
|
||||
image_dict["url"] = image_path
|
||||
else:
|
||||
icon_path = await icon_finder_service.search_icons(image_prompt)
|
||||
image_dict["url"] = icon_path[0]
|
||||
|
||||
SCHEMA_PROCESSOR.set_dict_at_path(slide.content, path, image_dict)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue