From ea81417bcbfc3b0cbc5f35e5fbb763554dbce167 Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Fri, 16 May 2025 16:38:14 +0545 Subject: [PATCH] Changes: better assets handling in fastapi --- servers/fastapi/api/models.py | 21 ++ .../api/routers/presentation/handlers/edit.py | 312 ++++++++---------- .../presentation/handlers/generate_image.py | 8 +- .../presentation/handlers/generate_stream.py | 109 ++---- servers/fastapi/api/utils.py | 20 +- .../fastapi/image_processor/icons_finder.py | 24 +- .../fastapi/image_processor/images_finder.py | 28 +- .../ppt_generator/models/other_models.py | 4 +- .../ppt_generator/models/slide_model.py | 4 +- 9 files changed, 231 insertions(+), 299 deletions(-) diff --git a/servers/fastapi/api/models.py b/servers/fastapi/api/models.py index 06a0dd23..ff63c0bb 100644 --- a/servers/fastapi/api/models.py +++ b/servers/fastapi/api/models.py @@ -1,3 +1,4 @@ +import json from typing import Optional from pydantic import BaseModel @@ -37,6 +38,26 @@ class SSEResponse(BaseModel): return f"event: {self.event}\ndata: {self.data}\n\n" +class SSEStatusResponse(BaseModel): + status: str + + def to_string(self): + return SSEResponse( + event="response", data=json.dumps({"type": "status", "status": self.status}) + ).to_string() + + +class SSECompleteResponse(BaseModel): + key: str + value: object + + def to_string(self): + return SSEResponse( + event="response", + data=json.dumps({"type": "complete", self.key: self.value}), + ).to_string() + + class UserConfig(BaseModel): LLM: Optional[str] = None OPENAI_API_KEY: Optional[str] = None diff --git a/servers/fastapi/api/routers/presentation/handlers/edit.py b/servers/fastapi/api/routers/presentation/handlers/edit.py index 5db2f415..13261df1 100644 --- a/servers/fastapi/api/routers/presentation/handlers/edit.py +++ b/servers/fastapi/api/routers/presentation/handlers/edit.py @@ -1,8 +1,8 @@ import asyncio -import os -from typing import List, Tuple +from typing import Literal import uuid +from sqlalchemy import update from sqlmodel import select from api.models import LogMetadata from api.routers.presentation.models import ( @@ -10,9 +10,15 @@ from api.routers.presentation.models import ( ) from api.services.instances import temp_file_service from api.services.logging import LoggingService -from api.utils import get_presentation_dir +from api.utils import get_presentation_dir, get_presentation_images_dir +from image_processor.icons_vectorstore_utils import get_icons_vectorstore from image_processor.images_finder import generate_image from image_processor.icons_finder import get_icon +from ppt_generator.models.other_models import SlideType +from ppt_generator.models.query_and_prompt_models import ( + IconQueryCollectionWithData, + ImagePromptWithThemeAndAspectRatio, +) from ppt_generator.models.slide_model import SlideModel from ppt_generator.slide_generator import ( get_edited_slide_content_model, @@ -21,7 +27,6 @@ from ppt_generator.slide_generator import ( from ppt_generator.slide_model_utils import SlideModelUtils from api.sql_models import PresentationSqlModel, SlideSqlModel from api.services.database import get_sql_session -from ppt_generator.models.other_models import SlideType class PresentationEditHandler: @@ -49,193 +54,138 @@ class PresentationEditHandler: with get_sql_session() as sql_session: presentation = sql_session.get(PresentationSqlModel, self.presentation_id) slide_to_edit_sql = sql_session.exec( - select(SlideSqlModel).where(SlideSqlModel.index == self.slide_index) + select(SlideSqlModel).where( + SlideSqlModel.index == self.slide_index, + SlideSqlModel.presentation == self.presentation_id, + ) ).first() - slide_to_edit = SlideModel.from_dict( - slide_to_edit_sql.model_dump(mode="json") - ) + slide_to_edit = SlideModel.from_dict(slide_to_edit_sql.model_dump(mode="json")) + new_slide_type = SlideType( + (await get_slide_type_from_prompt(self.prompt, slide_to_edit)).slide_type + ) - new_slide_type = await get_slide_type_from_prompt( - self.prompt, slide_to_edit - ) + edited_content = await get_edited_slide_content_model( + self.prompt, + new_slide_type, + slide_to_edit, + presentation.theme, + presentation.language, + ) - edited_content = await get_edited_slide_content_model( - self.prompt, - SlideType(new_slide_type.slide_type), - slide_to_edit, - presentation.theme, - presentation.language, - ) + new_slide_model = SlideModel( + id=slide_to_edit.id, + index=slide_to_edit.index, + type=new_slide_type, + design_index=slide_to_edit.design_index, + images=None, + icons=None, + presentation=slide_to_edit.presentation, + properties=slide_to_edit.properties, + content=edited_content, + ) - new_slide_model = SlideModel( - id=slide_to_edit.id, - index=slide_to_edit.index, - type=SlideType(new_slide_type.slide_type), - design_index=slide_to_edit.design_index, - images=None, - icons=None, - presentation=slide_to_edit.presentation, - content=edited_content, - ) + new_slide_images_count = new_slide_model.images_count + new_slide_icons_count = new_slide_model.icons_count - images_to_delete, images_to_generate, icons_to_delete, icons_to_generate = ( - self.get_all_assets_to_generate_and_delete( - slide_to_edit, - new_slide_model, + slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model) + + new_slide_images: dict[int, str | ImagePromptWithThemeAndAspectRatio] = {} + new_slide_icons: dict[int, str | IconQueryCollectionWithData] = {} + + # ? Checks if image prompts have changed + # ? If they have, it will search if it is same as the old one but used at a different index + # ? If it is, it will use the old image + # ? If it is not, it will generate a new image + if new_slide_images_count: + new_image_prompts = slide_model_utils.get_image_prompts() + old_image_prompts = ( + slide_to_edit.content.image_prompts + if slide_to_edit.images_count + else [] + ) + for index in range(new_slide_images_count): + new_prompt = new_slide_model.content.image_prompts[index] + for old_prompt in old_image_prompts: + if old_prompt != new_prompt: + continue + if index < len(slide_to_edit.images or []): + new_slide_images[index] = slide_to_edit.images[index] + break + if not new_slide_images.get(index): + new_slide_images[index] = new_image_prompts[index] + + # ? Checks if icon queries have changed + # ? If they have, it will search if it is same as the old one but used at a different index + # ? If it is, it will use the old icon + # ? If it is not, it will generate a new icon + if new_slide_icons_count: + new_icon_queries = slide_model_utils.get_icon_queries() + old_icon_queries = ( + slide_to_edit.content.icon_queries if slide_to_edit.icons_count else [] + ) + for index in range(new_slide_icons_count): + new_query = new_slide_model.content.icon_queries[index] + for old_query in old_icon_queries: + if old_query != new_query: + continue + if index < len(slide_to_edit.icons or []): + new_slide_icons[index] = slide_to_edit.icons[index] + break + if not new_slide_icons.get(index): + new_slide_icons[index] = new_icon_queries[index] + + images_to_generate = [] + for each in new_slide_images.values(): + if isinstance(each, ImagePromptWithThemeAndAspectRatio): + images_to_generate.append(each) + + icons_to_generate = [] + for each in new_slide_icons.values(): + if isinstance(each, IconQueryCollectionWithData): + icons_to_generate.append(each) + + images_directory = get_presentation_images_dir(self.presentation_id) + if icons_to_generate: + icons_vectorstore = get_icons_vectorstore() + + coroutines = [ + generate_image(each_prompt, images_directory) + for each_prompt in images_to_generate + ] + [ + get_icon(icons_vectorstore, each_query) for each_query in icons_to_generate + ] + generated_assets = await asyncio.gather(*coroutines) + generated_image_count = len(images_to_generate) + generate_images = generated_assets[:generated_image_count] + generate_icons = generated_assets[generated_image_count:] + + for each in new_slide_images: + if isinstance(new_slide_images[each], ImagePromptWithThemeAndAspectRatio): + new_slide_images[each] = generate_images.pop(0) + + for each in new_slide_icons: + if isinstance(new_slide_icons[each], IconQueryCollectionWithData): + new_slide_icons[each] = generate_icons.pop(0) + + with get_sql_session() as sql_session: + sql_session.exec( + update(SlideSqlModel) + .where(SlideSqlModel.id == slide_to_edit.id) + .values( + type=new_slide_type.value, + images=list(new_slide_images.values()), + icons=list(new_slide_icons.values()), + content=new_slide_model.content.model_dump(mode="json"), ) ) - new_image_paths = slide_to_edit.images or [] - new_icon_paths = slide_to_edit.icons or [] - images_count = len(new_image_paths) - icons_count = len(new_icon_paths) - for index in images_to_generate: - file_key = f"{self.presentation_dir}/images/{str(uuid.uuid4())}.jpg" - if index < images_count: - new_image_paths.pop(index) - new_image_paths.insert(index, file_key) - else: - new_image_paths.append(file_key) - for index in icons_to_generate: - file_key = f"{self.presentation_dir}/icons/{str(uuid.uuid4())}.png" - if index < icons_count: - new_icon_paths.pop(index) - new_icon_paths.insert(index, file_key) - else: - new_icon_paths.append(file_key) - - if new_image_paths: - new_slide_model.images = new_image_paths - if new_icon_paths: - new_slide_model.icons = new_icon_paths - - # ? Images and Icons are related to this presentation will be deleted while deleting presentation. - # objects_to_delete = [*images_to_delete, *icons_to_delete] - # if objects_to_delete: - # for each in objects_to_delete: - # os.remove(each) - - new_image_prompts = {} - new_icon_queries = {} - if images_to_generate: - slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model) - image_prompts = slide_model_utils.get_image_prompts() - for image_index in images_to_generate: - new_image_prompts[new_slide_model.images[image_index]] = ( - image_prompts[image_index] - ) - - if icons_to_generate: - slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model) - icon_queries = slide_model_utils.get_icon_queries() - for icon_index in icons_to_generate: - new_icon_queries[new_slide_model.icons[icon_index]] = icon_queries[ - icon_index - ] - - coroutines = [ - generate_image(value, key) for key, value in new_image_prompts.items() - ] + [get_icon(value, key) for key, value in new_icon_queries.items()] - - await asyncio.gather(*coroutines) - - slide_to_edit.images = new_slide_model.images - slide_to_edit.icons = new_slide_model.icons - slide_to_edit.content = new_slide_model.content - slide_to_edit.type = SlideType(new_slide_type.slide_type) - - slide_to_edit_sql.index = slide_to_edit.index - slide_to_edit_sql.type = slide_to_edit.type.value - slide_to_edit_sql.design_index = slide_to_edit.design_index - slide_to_edit_sql.images = slide_to_edit.images - slide_to_edit_sql.icons = slide_to_edit.icons - slide_to_edit_sql.content = slide_to_edit.content.model_dump(mode="json") - slide_to_edit_sql.properties = slide_to_edit.properties - slide_to_edit_sql.presentation = slide_to_edit.presentation sql_session.commit() - sql_session.refresh(slide_to_edit_sql) + slide_to_edit_sql = sql_session.exec( + select(SlideSqlModel).where(SlideSqlModel.id == slide_to_edit.id) + ).first() logging_service.logger.info( - logging_service.message(slide_to_edit.model_dump(mode="json")), + logging_service.message(slide_to_edit_sql.model_dump(mode="json")), extra=log_metadata.model_dump(), ) - return slide_to_edit - - def get_all_assets_to_generate_and_delete( - self, - old_slide_model: SlideModel, - new_slide_model: SlideModel, - ) -> Tuple[List[str], List[str], List[str], List[str]]: - - images_to_delete, images_to_generate = self.get_assets_to_generate_and_delete( - old_slide_model, - new_slide_model, - "image_prompts", - "images", - ) - - icons_to_delete, icons_to_generate = self.get_assets_to_generate_and_delete( - old_slide_model, - new_slide_model, - "icon_queries", - "icons", - ) - - return images_to_delete, images_to_generate, icons_to_delete, icons_to_generate - - def get_assets_to_generate_and_delete( - self, - old_slide_model: SlideModel, - new_slide_model: SlideModel, - content_attr: str, - slide_model_attr: str, - ) -> Tuple[List[str], List[str]]: - - items_to_delete = [] - items_to_generate = [] - - existing_paths = getattr(old_slide_model, slide_model_attr, []) - new_content_items = getattr(new_slide_model.content, content_attr, []) - old_content_items = getattr(old_slide_model.content, content_attr, []) - - # Case 1: No new items but slide has existing items - delete all - if not new_content_items and existing_paths: - items_to_delete.extend(existing_paths) - return items_to_delete, items_to_generate - - # Case 2: New items but slide has no existing items - generate all - if new_content_items and not existing_paths: - items_to_generate = [idx for idx in range(len(new_content_items))] - return items_to_delete, items_to_generate - - # Case 3: Both new and existing items - compare and update - if new_content_items and existing_paths: - new_count = len(new_content_items) - old_count = len(existing_paths) - - generate_idx = [] - for idx in range(max(new_count, old_count)): - # Generate additional new items - if idx >= old_count: - generate_idx.append(idx) - # Delete excess old items - elif idx >= new_count: - items_to_delete.append(existing_paths[idx]) - # Compare and update changed items - else: - old_value = old_content_items[idx] - new_value = new_content_items[idx] - if old_value != new_value: - items_to_delete.append(existing_paths[idx]) - generate_idx.append(idx) - - if generate_idx: - items_to_generate = generate_idx - - filtered_items_to_delete = [] - for each in items_to_delete: - if not each: - continue - filtered_items_to_delete.append(each) - - return filtered_items_to_delete, items_to_generate + return slide_to_edit_sql diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_image.py b/servers/fastapi/api/routers/presentation/handlers/generate_image.py index 513116b5..06f037fc 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_image.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_image.py @@ -7,7 +7,7 @@ from api.routers.presentation.models import ( ) from api.services.logging import LoggingService from api.services.instances import temp_file_service -from api.utils import get_presentation_dir +from api.utils import get_presentation_dir, get_presentation_images_dir from image_processor.images_finder import generate_image @@ -28,10 +28,8 @@ class GenerateImageHandler: extra=log_metadata.model_dump(), ) - image_path = os.path.join( - self.presentation_dir, "generated_images", str(uuid.uuid4()) + ".jpg" - ) - await generate_image(self.data.prompt, image_path) + images_directory = get_presentation_images_dir(self.data.presentation_id) + image_path = await generate_image(self.data.prompt, images_directory) response = PresentationAndPaths( presentation_id=self.data.presentation_id, paths=[image_path] diff --git a/servers/fastapi/api/routers/presentation/handlers/generate_stream.py b/servers/fastapi/api/routers/presentation/handlers/generate_stream.py index 90d6e469..6bf61aa7 100644 --- a/servers/fastapi/api/routers/presentation/handlers/generate_stream.py +++ b/servers/fastapi/api/routers/presentation/handlers/generate_stream.py @@ -1,14 +1,12 @@ import asyncio import json -import os from typing import List -import uuid from fastapi import HTTPException from fastapi.responses import StreamingResponse -from sqlmodel import delete, select +from sqlmodel import delete -from api.models import LogMetadata, SSEResponse +from api.models import LogMetadata, SSECompleteResponse, SSEResponse, SSEStatusResponse from api.routers.presentation.models import ( PresentationAndSlides, @@ -17,7 +15,7 @@ from api.routers.presentation.models import ( from api.services.database import get_sql_session from api.services.logging import LoggingService from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel -from api.utils import get_presentation_dir +from api.utils import get_presentation_dir, get_presentation_images_dir from image_processor.icons_vectorstore_utils import get_icons_vectorstore from image_processor.images_finder import generate_image from image_processor.icons_finder import get_icon @@ -119,40 +117,15 @@ class PresentationGenerateStreamHandler: content["index"] = i content["presentation"] = presentation.id slide_model = SlideModel(**content) - slide_content = slide_model.content - has_images = hasattr(slide_content, "image_prompts") - has_icons = hasattr(slide_content, "icon_queries") - - if has_images: - slide_model.images = [ - os.path.join( - self.presentation_dir, - "images", - f"{str(uuid.uuid4())}.jpg", - ) - for _ in range(len(slide_content.image_prompts)) - ] - - if has_icons: - slide_model.icons = [ - os.path.join( - self.presentation_dir, - "icons", - f"{str(uuid.uuid4())}.png", - ) - for _ in range(len(slide_content.icon_queries)) - ] - slide_models.append(slide_model) - yield SSEResponse( - event="response", - data=json.dumps({"type": "status", "status": "Fetching slide assets"}), - ).to_string() - async for result in self.fetch_slide_assets(slide_models): yield result + print("-" * 40) + print(slide_models) + print("-" * 40) + slide_sql_models = [ SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models ] @@ -163,73 +136,55 @@ class PresentationGenerateStreamHandler: for each in slide_sql_models: sql_session.refresh(each) - yield SSEResponse( - event="response", - data=json.dumps({"type": "status", "status": "Packing slide data"}), - ).to_string() + yield SSEStatusResponse(status="Packing slide data").to_string() response = PresentationAndSlides( presentation=presentation, slides=slide_sql_models ).to_response_dict() - yield SSEResponse( - event="response", - data=json.dumps({"type": "complete", "presentation": response}), - ).to_string() - yield SSEResponse( - event="response", - data=json.dumps({"type": "closing", "content": "First Warning"}), - ).to_string() - await asyncio.sleep(3) - yield SSEResponse( - event="response", - data=json.dumps({"type": "closing", "content": "Final Warning"}), - ).to_string() + yield SSECompleteResponse(key="presentation", value=response).to_string() async def fetch_slide_assets(self, slide_models: List[SlideModel]): image_prompts = [] icon_queries = [] - image_paths = [] - icon_paths = [] - for each_slide_model in slide_models: slide_model_utils = SlideModelUtils(self.theme, each_slide_model) + image_prompts.extend(slide_model_utils.get_image_prompts()) + icon_queries.extend(slide_model_utils.get_icon_queries()) - if each_slide_model.images: - prompts = slide_model_utils.get_image_prompts() - image_prompts.extend(prompts) - image_paths.extend(each_slide_model.images) - if each_slide_model.icons: - icon_queries.extend(slide_model_utils.get_icon_queries()) - icon_paths.extend(each_slide_model.icons) - - if icon_paths: + if icon_queries: icon_vector_store = get_icons_vectorstore() + images_directory = get_presentation_images_dir(self.presentation_id) + coroutines = [ generate_image( each, - image_path, + images_directory, ) - for each, image_path in zip(image_prompts, image_paths) - ] + [ - get_icon(icon_vector_store, each, icon_path) - for each, icon_path in zip(icon_queries, icon_paths) - ] + for each in image_prompts + ] + [get_icon(icon_vector_store, each) for each in icon_queries] assets_future = asyncio.gather(*coroutines) while not assets_future.done(): - status = SSEResponse( - event="response", - data=json.dumps({"status": "Fetching slide assets..."}), - ).to_string() + status = SSEStatusResponse(status="Fetching slide assets").to_string() yield status await asyncio.sleep(5) - await assets_future + assets = await assets_future - yield SSEResponse( - event="response", data=json.dumps({"status": "Slide assets fetched"}) - ).to_string() + image_prompts_len = len(image_prompts) + + images = assets[:image_prompts_len] + icons = assets[image_prompts_len:] + + for each_slide_model in slide_models: + each_slide_model.images = images[: each_slide_model.images_count] + images = images[each_slide_model.images_count :] + + each_slide_model.icons = icons[: each_slide_model.icons_count] + icons = icons[each_slide_model.icons_count :] + + yield SSEStatusResponse(status="Slide assets fetched").to_string() diff --git a/servers/fastapi/api/utils.py b/servers/fastapi/api/utils.py index 5ec5f630..aef1876c 100644 --- a/servers/fastapi/api/utils.py +++ b/servers/fastapi/api/utils.py @@ -20,6 +20,14 @@ def get_presentation_dir(presentation_id: str) -> str: return presentation_dir +def get_presentation_images_dir(presentation_id: str) -> str: + presentation_images_dir = os.path.join( + get_presentation_dir(presentation_id), "images" + ) + os.makedirs(presentation_images_dir, exist_ok=True) + return presentation_images_dir + + def get_user_config(): user_config_path = os.getenv("USER_CONFIG_PATH") @@ -144,14 +152,14 @@ async def handle_errors( def sanitize_filename(filename: str) -> str: name, ext = os.path.splitext(filename) - sanitized = re.sub(r'[\\/:*?"<>|]', '_', name) - sanitized = re.sub(r'[\s_]+', '_', sanitized) - sanitized = sanitized.strip(' .') + sanitized = re.sub(r'[\\/:*?"<>|]', "_", name) + sanitized = re.sub(r"[\s_]+", "_", sanitized) + sanitized = sanitized.strip(" .") if not sanitized: - sanitized = 'untitled' - + sanitized = "untitled" + if len(sanitized) > 200: sanitized = sanitized[:200] - + return sanitized + ext diff --git a/servers/fastapi/image_processor/icons_finder.py b/servers/fastapi/image_processor/icons_finder.py index d8b3108b..0f64e069 100644 --- a/servers/fastapi/image_processor/icons_finder.py +++ b/servers/fastapi/image_processor/icons_finder.py @@ -1,4 +1,3 @@ -import os from typing import List, Optional from api.utils import get_resource @@ -12,22 +11,15 @@ from langchain_core.vectorstores import InMemoryVectorStore async def get_icon( vector_store: InMemoryVectorStore, input: IconQueryCollectionWithData, - output_path: str, ) -> str: - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - query = input.icon_query.queries[0] - results = vector_store.similarity_search(query=query, k=1) - icon_name = results[0].page_content - - with open(output_path, "wb") as f_a: - try: - with open(get_resource(f"assets/icons/bold/{icon_name}.png"), "rb") as f_b: - f_a.write(f_b.read()) - except Exception as e: - print("Error finding icon: ", e) - with open(get_resource("assets/icons/placeholder.png"), "rb") as f_b: - f_a.write(f_b.read()) + try: + query = input.icon_query.queries[0] + results = vector_store.similarity_search(query=query, k=1) + icon_name = results[0].page_content + return get_resource(f"assets/icons/bold/{icon_name}.png") + except Exception as e: + print("Error finding icon: ", e) + return get_resource("assets/icons/placeholder.png") async def get_icons( diff --git a/servers/fastapi/image_processor/images_finder.py b/servers/fastapi/image_processor/images_finder.py index fe7dd332..93c0f924 100644 --- a/servers/fastapi/image_processor/images_finder.py +++ b/servers/fastapi/image_processor/images_finder.py @@ -1,6 +1,7 @@ import asyncio import base64 import os +import uuid import aiohttp from langchain_google_genai import ChatGoogleGenerativeAI from openai import OpenAI @@ -13,10 +14,8 @@ from api.utils import get_resource async def generate_image( input: ImagePromptWithThemeAndAspectRatio, - output_path: str, + output_directory: str, ) -> str: - os.makedirs(os.path.dirname(output_path), exist_ok=True) - image_prompt = f"{input.image_prompt}, {input.theme_prompt}" print(f"Request - Generating Image for {image_prompt}") @@ -26,15 +25,17 @@ async def generate_image( if os.getenv("LLM") == "openai" else generate_image_google ) - await image_gen_func(image_prompt, output_path) + image_path = await image_gen_func(image_prompt, output_directory) + if image_path and os.path.exists(image_path): + return image_path + raise Exception(f"Image not found at {image_path}") + except Exception as e: print(f"Error generating image: {e}") - with open(get_resource("assets/images/placeholder.jpg"), "rb") as f_a: - with open(output_path, "wb") as f_b: - f_b.write(f_a.read()) + return get_resource("assets/images/placeholder.jpg") -async def generate_image_openai(prompt: str, output_path: str): +async def generate_image_openai(prompt: str, output_directory: str) -> str: client = OpenAI() result = await asyncio.to_thread( client.images.generate, @@ -48,11 +49,13 @@ async def generate_image_openai(prompt: str, output_path: str): async with aiohttp.ClientSession() as session: async with session.get(image_url) as response: image_bytes = await response.read() - with open(output_path, "wb") as f: + image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg") + with open(image_path, "wb") as f: f.write(image_bytes) + return image_path -async def generate_image_google(prompt: str, output_path: str): +async def generate_image_google(prompt: str, output_directory: str) -> str: response = await ChatGoogleGenerativeAI( model="gemini-2.0-flash-preview-image-generation" ).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]}) @@ -64,5 +67,8 @@ async def generate_image_google(prompt: str, output_path: str): ) base64_image = image_block["image_url"].get("url").split(",")[-1] - with open(output_path, "wb") as f: + image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg") + with open(image_path, "wb") as f: f.write(base64.b64decode(base64_image)) + + return image_path diff --git a/servers/fastapi/ppt_generator/models/other_models.py b/servers/fastapi/ppt_generator/models/other_models.py index 4faffac5..4cd625ef 100644 --- a/servers/fastapi/ppt_generator/models/other_models.py +++ b/servers/fastapi/ppt_generator/models/other_models.py @@ -28,4 +28,6 @@ class SlideType(Enum): class SlideTypeModel(BaseModel): - slide_type: int = Field(default=1, gte=1, lte=9, description="Slide type from 1 to 9") + slide_type: int = Field( + default=1, gte=1, lte=9, description="Slide type from 1 to 9" + ) diff --git a/servers/fastapi/ppt_generator/models/slide_model.py b/servers/fastapi/ppt_generator/models/slide_model.py index 92d5a0e8..204700a5 100644 --- a/servers/fastapi/ppt_generator/models/slide_model.py +++ b/servers/fastapi/ppt_generator/models/slide_model.py @@ -58,10 +58,10 @@ class SlideModel(BaseModel): def images_count(self): if not hasattr(self.content, "image_prompts"): return 0 - return len(self.content.image_prompts) + return len(self.content.image_prompts or []) @property def icons_count(self): if not hasattr(self.content, "icon_queries"): return 0 - return len(self.content.icon_queries) + return len(self.content.icon_queries or [])