From b27efd8cd4da41c1bee9c766cb5ecbd408f87449 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Mon, 21 Jul 2025 02:14:02 +0545 Subject: [PATCH] fix(fastapi): checks if generated image is ImageAsset or str and saves it to database accordingly, feat(fastapi): adds list generated images endpoint --- .../fastapi/api/v1/ppt/endpoints/images.py | 14 +++ .../api/v1/ppt/endpoints/presentation.py | 3 + servers/fastapi/api/v1/ppt/endpoints/slide.py | 99 +-------------- servers/fastapi/models/sql/image_asset.py | 4 +- servers/fastapi/utils/process_slides.py | 116 ++++++++++++++++++ 5 files changed, 140 insertions(+), 96 deletions(-) diff --git a/servers/fastapi/api/v1/ppt/endpoints/images.py b/servers/fastapi/api/v1/ppt/endpoints/images.py index 5af10e88..826f5436 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/images.py +++ b/servers/fastapi/api/v1/ppt/endpoints/images.py @@ -1,4 +1,6 @@ +from typing import List from fastapi import APIRouter +from sqlmodel import select from models.image_prompt import ImagePrompt from models.sql.image_asset import ImageAsset @@ -24,3 +26,15 @@ async def generate_image(prompt: str): sql_session.commit() return image.path + + +@IMAGES_ROUTER.get("/generated", response_model=List[ImageAsset]) +async def get_generated_images(): + try: + with get_sql_session() as sql_session: + images = sql_session.exec( + select(ImageAsset).order_by(ImageAsset.created_at.desc()) + ).all() + return images + except Exception as e: + return {"error": f"Failed to retrieve generated images: {str(e)}"} diff --git a/servers/fastapi/api/v1/ppt/endpoints/presentation.py b/servers/fastapi/api/v1/ppt/endpoints/presentation.py index 5594a8c7..0b70d6f0 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/presentation.py +++ b/servers/fastapi/api/v1/ppt/endpoints/presentation.py @@ -214,7 +214,10 @@ async def stream_presentation(presentation_id: str): content=slide_content, ) slides.append(slide) + + # This will mutate slide async_assets_generation_tasks.append(process_slide_and_fetch_assets(slide)) + yield SSEResponse( event="response", data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}), diff --git a/servers/fastapi/api/v1/ppt/endpoints/slide.py b/servers/fastapi/api/v1/ppt/endpoints/slide.py index d70d0306..3a5b4230 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/slide.py +++ b/servers/fastapi/api/v1/ppt/endpoints/slide.py @@ -1,18 +1,12 @@ -import asyncio from typing import Annotated from fastapi import APIRouter, Body, HTTPException -from models.image_prompt import ImagePrompt -from models.presentation_outline_model import SlideOutlineModel from models.sql.presentation import PresentationModel from models.sql.slide import SlideModel from services.database import get_sql_session -from services.icon_finder_service import IconFinderService -from services.image_generation_service import ImageGenerationService -from utils.asset_directory_utils import get_images_directory -from utils.dict_utils import get_dict_at_path, get_dict_paths_with_key, set_dict_at_path from utils.llm_calls.edit_slide import get_edited_slide_content from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt +from utils.process_slides import process_old_and_new_slides_and_fetch_assets SLIDE_ROUTER = APIRouter(prefix="/slide", tags=["Slide"]) @@ -38,99 +32,16 @@ async def edit_slide(id: Annotated[str, Body()], prompt: Annotated[str, Body()]) prompt, slide_layout, slide, presentation.language ) - # Get old image and icon dicts - old_image_dict_paths = get_dict_paths_with_key(slide.content, "__image_prompt__") - old_image_dicts = [ - get_dict_at_path(slide.content, path) for path in old_image_dict_paths - ] - old_image_prompts = [ - old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts - ] - old_icon_dict_paths = get_dict_paths_with_key(slide.content, "__icon_query__") - old_icon_dicts = [ - get_dict_at_path(slide.content, path) for path in old_icon_dict_paths - ] - old_icon_queries = [ - old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts - ] - - # Get new image and icon dicts - new_image_dict_paths = get_dict_paths_with_key( - edited_slide_content, "__image_prompt__" + # This will mutate edited_slide_content + new_assets = await process_old_and_new_slides_and_fetch_assets( + slide.content, edited_slide_content ) - new_image_dicts = [ - get_dict_at_path(edited_slide_content, path) for path in new_image_dict_paths - ] - new_icon_dict_paths = get_dict_paths_with_key( - edited_slide_content, "__icon_query__" - ) - new_icon_dicts = [ - get_dict_at_path(edited_slide_content, path) for path in new_icon_dict_paths - ] - - image_generation_service = ImageGenerationService(get_images_directory()) - icon_finder_service = IconFinderService() - - async_image_fetch_tasks = [] - new_images_fetch_status = [] - - async_icon_fetch_tasks = [] - new_icons_fetch_status = [] - - for new_image in new_image_dicts: - # Use old image url if prompt is same - if new_image["__image_prompt__"] in old_image_prompts: - old_image_url = old_image_dicts[ - old_image_prompts.index(new_image["__image_prompt__"]) - ]["__image_url__"] - new_image["__image_url__"] = old_image_url - new_images_fetch_status.append(False) - continue - - async_image_fetch_tasks.append( - image_generation_service.generate_image( - ImagePrompt( - prompt=new_image["__image_prompt__"], - ) - ) - ) - new_images_fetch_status.append(True) - - for new_icon in new_icon_dicts: - if new_icon["__icon_query__"] in old_icon_queries: - old_icon_url = old_icon_dicts[ - old_icon_queries.index(new_icon["__icon_query__"]) - ]["__icon_url__"] - new_icon["__icon_url__"] = old_icon_url - new_icons_fetch_status.append(False) - continue - - async_icon_fetch_tasks.append( - icon_finder_service.search_icons(new_icon["__icon_query__"]) - ) - new_icons_fetch_status.append(True) - - new_images = await asyncio.gather(*async_image_fetch_tasks) - new_icons = await asyncio.gather(*async_icon_fetch_tasks) - - for i, new_image in enumerate(new_images): - if new_images_fetch_status[i]: - new_image_dicts[i]["__image_url__"] = new_images[i] - - for i, new_icon in enumerate(new_icons): - if new_icons_fetch_status[i]: - new_icon_dicts[i]["__icon_url__"] = new_icons[i] - - for i, new_image_dict in enumerate(new_image_dicts): - set_dict_at_path(edited_slide_content, new_image_dict_paths[i], new_image_dict) - - for i, new_icon_dict in enumerate(new_icon_dicts): - set_dict_at_path(edited_slide_content, new_icon_dict_paths[i], new_icon_dict) with get_sql_session() as sql_session: sql_session.add(slide) slide.content = edited_slide_content slide.layout = slide_layout.id + sql_session.add_all(new_assets) sql_session.commit() sql_session.refresh(slide) diff --git a/servers/fastapi/models/sql/image_asset.py b/servers/fastapi/models/sql/image_asset.py index 9195a495..2c7b4053 100644 --- a/servers/fastapi/models/sql/image_asset.py +++ b/servers/fastapi/models/sql/image_asset.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional -from sqlalchemy import JSON, Column +from sqlalchemy import JSON, Column, DateTime from sqlmodel import SQLModel, Field from utils.randomizers import get_random_uuid @@ -9,6 +9,6 @@ from utils.randomizers import get_random_uuid class ImageAsset(SQLModel, table=True): id: str = Field(default_factory=get_random_uuid, primary_key=True) - created_at: datetime = Field(default=datetime.now()) + created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now)) path: str extras: Optional[dict] = Field(sa_column=Column(JSON), default=None) diff --git a/servers/fastapi/utils/process_slides.py b/servers/fastapi/utils/process_slides.py index ec043fb4..b4380529 100644 --- a/servers/fastapi/utils/process_slides.py +++ b/servers/fastapi/utils/process_slides.py @@ -58,3 +58,119 @@ async def process_slide_and_fetch_assets( set_dict_at_path(slide.content, icon_path, icon_dict) return return_assets + + +async def process_old_and_new_slides_and_fetch_assets( + old_slide_content: dict, + new_slide_content: dict, +) -> List[ImageAsset]: + # Finds all old images + old_image_dict_paths = get_dict_paths_with_key( + old_slide_content, "__image_prompt__" + ) + old_image_dicts = [ + get_dict_at_path(old_slide_content, path) for path in old_image_dict_paths + ] + old_image_prompts = [ + old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts + ] + + # Finds all old icons + old_icon_dict_paths = get_dict_paths_with_key(old_slide_content, "__icon_query__") + old_icon_dicts = [ + get_dict_at_path(old_slide_content, path) for path in old_icon_dict_paths + ] + old_icon_queries = [ + old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts + ] + + # Finds all new images + new_image_dict_paths = get_dict_paths_with_key( + new_slide_content, "__image_prompt__" + ) + new_image_dicts = [ + get_dict_at_path(new_slide_content, path) for path in new_image_dict_paths + ] + + # Finds all new icons + new_icon_dict_paths = get_dict_paths_with_key(new_slide_content, "__icon_query__") + new_icon_dicts = [ + get_dict_at_path(new_slide_content, path) for path in new_icon_dict_paths + ] + + # Creates services + image_generation_service = ImageGenerationService(get_images_directory()) + icon_finder_service = IconFinderService() + + # Creates async tasks for fetching new images + async_image_fetch_tasks = [] + new_images_fetch_status = [] + + # Creates async tasks for fetching new icons + async_icon_fetch_tasks = [] + new_icons_fetch_status = [] + + # Creates async tasks for fetching new images + # Use old image url if prompt is same + for new_image in new_image_dicts: + if new_image["__image_prompt__"] in old_image_prompts: + old_image_url = old_image_dicts[ + old_image_prompts.index(new_image["__image_prompt__"]) + ]["__image_url__"] + new_image["__image_url__"] = old_image_url + new_images_fetch_status.append(False) + continue + + async_image_fetch_tasks.append( + image_generation_service.generate_image( + ImagePrompt( + prompt=new_image["__image_prompt__"], + ) + ) + ) + new_images_fetch_status.append(True) + + # Creates async tasks for fetching new icons + # Use old icon url if query is same + for new_icon in new_icon_dicts: + if new_icon["__icon_query__"] in old_icon_queries: + old_icon_url = old_icon_dicts[ + old_icon_queries.index(new_icon["__icon_query__"]) + ]["__icon_url__"] + new_icon["__icon_url__"] = old_icon_url + new_icons_fetch_status.append(False) + continue + + async_icon_fetch_tasks.append( + icon_finder_service.search_icons(new_icon["__icon_query__"]) + ) + new_icons_fetch_status.append(True) + + new_images = await asyncio.gather(*async_image_fetch_tasks) + new_icons = await asyncio.gather(*async_icon_fetch_tasks) + + # list of new assets + new_assets = [] + + # Sets new image and icon urls for assets that were fetched + for i, new_image in enumerate(new_images): + if new_images_fetch_status[i]: + fetched_image = new_images[i] + if isinstance(fetched_image, ImageAsset): + new_assets.append(fetched_image) + image_url = fetched_image.path + else: + image_url = fetched_image + new_image_dicts[i]["__image_url__"] = image_url + + for i, new_icon in enumerate(new_icons): + if new_icons_fetch_status[i]: + new_icon_dicts[i]["__icon_url__"] = new_icons[i] + + for i, new_image_dict in enumerate(new_image_dicts): + set_dict_at_path(new_slide_content, new_image_dict_paths[i], new_image_dict) + + for i, new_icon_dict in enumerate(new_icon_dicts): + set_dict_at_path(new_slide_content, new_icon_dict_paths[i], new_icon_dict) + + return new_assets