fix(fastapi): checks if generated image is ImageAsset or str and saves it to database accordingly, feat(fastapi): adds list generated images endpoint
This commit is contained in:
parent
eceef4f136
commit
b27efd8cd4
5 changed files with 140 additions and 96 deletions
|
|
@ -1,4 +1,6 @@
|
|||
from typing import List
|
||||
from fastapi import APIRouter
|
||||
from sqlmodel import select
|
||||
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.sql.image_asset import ImageAsset
|
||||
|
|
@ -24,3 +26,15 @@ async def generate_image(prompt: str):
|
|||
sql_session.commit()
|
||||
|
||||
return image.path
|
||||
|
||||
|
||||
@IMAGES_ROUTER.get("/generated", response_model=List[ImageAsset])
|
||||
async def get_generated_images():
|
||||
try:
|
||||
with get_sql_session() as sql_session:
|
||||
images = sql_session.exec(
|
||||
select(ImageAsset).order_by(ImageAsset.created_at.desc())
|
||||
).all()
|
||||
return images
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to retrieve generated images: {str(e)}"}
|
||||
|
|
|
|||
|
|
@ -214,7 +214,10 @@ async def stream_presentation(presentation_id: str):
|
|||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
|
||||
# This will mutate slide
|
||||
async_assets_generation_tasks.append(process_slide_and_fetch_assets(slide))
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
|
||||
|
|
|
|||
|
|
@ -1,18 +1,12 @@
|
|||
import asyncio
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.database import get_sql_session
|
||||
from services.icon_finder_service import IconFinderService
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
from utils.dict_utils import get_dict_at_path, get_dict_paths_with_key, set_dict_at_path
|
||||
from utils.llm_calls.edit_slide import get_edited_slide_content
|
||||
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
|
||||
from utils.process_slides import process_old_and_new_slides_and_fetch_assets
|
||||
|
||||
|
||||
SLIDE_ROUTER = APIRouter(prefix="/slide", tags=["Slide"])
|
||||
|
|
@ -38,99 +32,16 @@ async def edit_slide(id: Annotated[str, Body()], prompt: Annotated[str, Body()])
|
|||
prompt, slide_layout, slide, presentation.language
|
||||
)
|
||||
|
||||
# Get old image and icon dicts
|
||||
old_image_dict_paths = get_dict_paths_with_key(slide.content, "__image_prompt__")
|
||||
old_image_dicts = [
|
||||
get_dict_at_path(slide.content, path) for path in old_image_dict_paths
|
||||
]
|
||||
old_image_prompts = [
|
||||
old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts
|
||||
]
|
||||
old_icon_dict_paths = get_dict_paths_with_key(slide.content, "__icon_query__")
|
||||
old_icon_dicts = [
|
||||
get_dict_at_path(slide.content, path) for path in old_icon_dict_paths
|
||||
]
|
||||
old_icon_queries = [
|
||||
old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts
|
||||
]
|
||||
|
||||
# Get new image and icon dicts
|
||||
new_image_dict_paths = get_dict_paths_with_key(
|
||||
edited_slide_content, "__image_prompt__"
|
||||
# This will mutate edited_slide_content
|
||||
new_assets = await process_old_and_new_slides_and_fetch_assets(
|
||||
slide.content, edited_slide_content
|
||||
)
|
||||
new_image_dicts = [
|
||||
get_dict_at_path(edited_slide_content, path) for path in new_image_dict_paths
|
||||
]
|
||||
new_icon_dict_paths = get_dict_paths_with_key(
|
||||
edited_slide_content, "__icon_query__"
|
||||
)
|
||||
new_icon_dicts = [
|
||||
get_dict_at_path(edited_slide_content, path) for path in new_icon_dict_paths
|
||||
]
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
icon_finder_service = IconFinderService()
|
||||
|
||||
async_image_fetch_tasks = []
|
||||
new_images_fetch_status = []
|
||||
|
||||
async_icon_fetch_tasks = []
|
||||
new_icons_fetch_status = []
|
||||
|
||||
for new_image in new_image_dicts:
|
||||
# Use old image url if prompt is same
|
||||
if new_image["__image_prompt__"] in old_image_prompts:
|
||||
old_image_url = old_image_dicts[
|
||||
old_image_prompts.index(new_image["__image_prompt__"])
|
||||
]["__image_url__"]
|
||||
new_image["__image_url__"] = old_image_url
|
||||
new_images_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_image_fetch_tasks.append(
|
||||
image_generation_service.generate_image(
|
||||
ImagePrompt(
|
||||
prompt=new_image["__image_prompt__"],
|
||||
)
|
||||
)
|
||||
)
|
||||
new_images_fetch_status.append(True)
|
||||
|
||||
for new_icon in new_icon_dicts:
|
||||
if new_icon["__icon_query__"] in old_icon_queries:
|
||||
old_icon_url = old_icon_dicts[
|
||||
old_icon_queries.index(new_icon["__icon_query__"])
|
||||
]["__icon_url__"]
|
||||
new_icon["__icon_url__"] = old_icon_url
|
||||
new_icons_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_icon_fetch_tasks.append(
|
||||
icon_finder_service.search_icons(new_icon["__icon_query__"])
|
||||
)
|
||||
new_icons_fetch_status.append(True)
|
||||
|
||||
new_images = await asyncio.gather(*async_image_fetch_tasks)
|
||||
new_icons = await asyncio.gather(*async_icon_fetch_tasks)
|
||||
|
||||
for i, new_image in enumerate(new_images):
|
||||
if new_images_fetch_status[i]:
|
||||
new_image_dicts[i]["__image_url__"] = new_images[i]
|
||||
|
||||
for i, new_icon in enumerate(new_icons):
|
||||
if new_icons_fetch_status[i]:
|
||||
new_icon_dicts[i]["__icon_url__"] = new_icons[i]
|
||||
|
||||
for i, new_image_dict in enumerate(new_image_dicts):
|
||||
set_dict_at_path(edited_slide_content, new_image_dict_paths[i], new_image_dict)
|
||||
|
||||
for i, new_icon_dict in enumerate(new_icon_dicts):
|
||||
set_dict_at_path(edited_slide_content, new_icon_dict_paths[i], new_icon_dict)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(slide)
|
||||
slide.content = edited_slide_content
|
||||
slide.layout = slide_layout.id
|
||||
sql_session.add_all(new_assets)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(slide)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
|
@ -9,6 +9,6 @@ from utils.randomizers import get_random_uuid
|
|||
|
||||
class ImageAsset(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
created_at: datetime = Field(default=datetime.now())
|
||||
created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
path: str
|
||||
extras: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
|
|
|
|||
|
|
@ -58,3 +58,119 @@ async def process_slide_and_fetch_assets(
|
|||
set_dict_at_path(slide.content, icon_path, icon_dict)
|
||||
|
||||
return return_assets
|
||||
|
||||
|
||||
async def process_old_and_new_slides_and_fetch_assets(
|
||||
old_slide_content: dict,
|
||||
new_slide_content: dict,
|
||||
) -> List[ImageAsset]:
|
||||
# Finds all old images
|
||||
old_image_dict_paths = get_dict_paths_with_key(
|
||||
old_slide_content, "__image_prompt__"
|
||||
)
|
||||
old_image_dicts = [
|
||||
get_dict_at_path(old_slide_content, path) for path in old_image_dict_paths
|
||||
]
|
||||
old_image_prompts = [
|
||||
old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts
|
||||
]
|
||||
|
||||
# Finds all old icons
|
||||
old_icon_dict_paths = get_dict_paths_with_key(old_slide_content, "__icon_query__")
|
||||
old_icon_dicts = [
|
||||
get_dict_at_path(old_slide_content, path) for path in old_icon_dict_paths
|
||||
]
|
||||
old_icon_queries = [
|
||||
old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts
|
||||
]
|
||||
|
||||
# Finds all new images
|
||||
new_image_dict_paths = get_dict_paths_with_key(
|
||||
new_slide_content, "__image_prompt__"
|
||||
)
|
||||
new_image_dicts = [
|
||||
get_dict_at_path(new_slide_content, path) for path in new_image_dict_paths
|
||||
]
|
||||
|
||||
# Finds all new icons
|
||||
new_icon_dict_paths = get_dict_paths_with_key(new_slide_content, "__icon_query__")
|
||||
new_icon_dicts = [
|
||||
get_dict_at_path(new_slide_content, path) for path in new_icon_dict_paths
|
||||
]
|
||||
|
||||
# Creates services
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
icon_finder_service = IconFinderService()
|
||||
|
||||
# Creates async tasks for fetching new images
|
||||
async_image_fetch_tasks = []
|
||||
new_images_fetch_status = []
|
||||
|
||||
# Creates async tasks for fetching new icons
|
||||
async_icon_fetch_tasks = []
|
||||
new_icons_fetch_status = []
|
||||
|
||||
# Creates async tasks for fetching new images
|
||||
# Use old image url if prompt is same
|
||||
for new_image in new_image_dicts:
|
||||
if new_image["__image_prompt__"] in old_image_prompts:
|
||||
old_image_url = old_image_dicts[
|
||||
old_image_prompts.index(new_image["__image_prompt__"])
|
||||
]["__image_url__"]
|
||||
new_image["__image_url__"] = old_image_url
|
||||
new_images_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_image_fetch_tasks.append(
|
||||
image_generation_service.generate_image(
|
||||
ImagePrompt(
|
||||
prompt=new_image["__image_prompt__"],
|
||||
)
|
||||
)
|
||||
)
|
||||
new_images_fetch_status.append(True)
|
||||
|
||||
# Creates async tasks for fetching new icons
|
||||
# Use old icon url if query is same
|
||||
for new_icon in new_icon_dicts:
|
||||
if new_icon["__icon_query__"] in old_icon_queries:
|
||||
old_icon_url = old_icon_dicts[
|
||||
old_icon_queries.index(new_icon["__icon_query__"])
|
||||
]["__icon_url__"]
|
||||
new_icon["__icon_url__"] = old_icon_url
|
||||
new_icons_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_icon_fetch_tasks.append(
|
||||
icon_finder_service.search_icons(new_icon["__icon_query__"])
|
||||
)
|
||||
new_icons_fetch_status.append(True)
|
||||
|
||||
new_images = await asyncio.gather(*async_image_fetch_tasks)
|
||||
new_icons = await asyncio.gather(*async_icon_fetch_tasks)
|
||||
|
||||
# list of new assets
|
||||
new_assets = []
|
||||
|
||||
# Sets new image and icon urls for assets that were fetched
|
||||
for i, new_image in enumerate(new_images):
|
||||
if new_images_fetch_status[i]:
|
||||
fetched_image = new_images[i]
|
||||
if isinstance(fetched_image, ImageAsset):
|
||||
new_assets.append(fetched_image)
|
||||
image_url = fetched_image.path
|
||||
else:
|
||||
image_url = fetched_image
|
||||
new_image_dicts[i]["__image_url__"] = image_url
|
||||
|
||||
for i, new_icon in enumerate(new_icons):
|
||||
if new_icons_fetch_status[i]:
|
||||
new_icon_dicts[i]["__icon_url__"] = new_icons[i]
|
||||
|
||||
for i, new_image_dict in enumerate(new_image_dicts):
|
||||
set_dict_at_path(new_slide_content, new_image_dict_paths[i], new_image_dict)
|
||||
|
||||
for i, new_icon_dict in enumerate(new_icon_dicts):
|
||||
set_dict_at_path(new_slide_content, new_icon_dict_paths[i], new_icon_dict)
|
||||
|
||||
return new_assets
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue