Changes: better assets handling in fastapi

This commit is contained in:
sauravniraula 2025-05-16 16:38:14 +05:45
parent e6d706d366
commit ea81417bcb
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
9 changed files with 231 additions and 299 deletions

View file

@ -1,3 +1,4 @@
import json
from typing import Optional
from pydantic import BaseModel
@ -37,6 +38,26 @@ class SSEResponse(BaseModel):
return f"event: {self.event}\ndata: {self.data}\n\n"
class SSEStatusResponse(BaseModel):
status: str
def to_string(self):
return SSEResponse(
event="response", data=json.dumps({"type": "status", "status": self.status})
).to_string()
class SSECompleteResponse(BaseModel):
key: str
value: object
def to_string(self):
return SSEResponse(
event="response",
data=json.dumps({"type": "complete", self.key: self.value}),
).to_string()
class UserConfig(BaseModel):
LLM: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None

View file

@ -1,8 +1,8 @@
import asyncio
import os
from typing import List, Tuple
from typing import Literal
import uuid
from sqlalchemy import update
from sqlmodel import select
from api.models import LogMetadata
from api.routers.presentation.models import (
@ -10,9 +10,15 @@ from api.routers.presentation.models import (
)
from api.services.instances import temp_file_service
from api.services.logging import LoggingService
from api.utils import get_presentation_dir
from api.utils import get_presentation_dir, get_presentation_images_dir
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
from image_processor.images_finder import generate_image
from image_processor.icons_finder import get_icon
from ppt_generator.models.other_models import SlideType
from ppt_generator.models.query_and_prompt_models import (
IconQueryCollectionWithData,
ImagePromptWithThemeAndAspectRatio,
)
from ppt_generator.models.slide_model import SlideModel
from ppt_generator.slide_generator import (
get_edited_slide_content_model,
@ -21,7 +27,6 @@ from ppt_generator.slide_generator import (
from ppt_generator.slide_model_utils import SlideModelUtils
from api.sql_models import PresentationSqlModel, SlideSqlModel
from api.services.database import get_sql_session
from ppt_generator.models.other_models import SlideType
class PresentationEditHandler:
@ -49,193 +54,138 @@ class PresentationEditHandler:
with get_sql_session() as sql_session:
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
slide_to_edit_sql = sql_session.exec(
select(SlideSqlModel).where(SlideSqlModel.index == self.slide_index)
select(SlideSqlModel).where(
SlideSqlModel.index == self.slide_index,
SlideSqlModel.presentation == self.presentation_id,
)
).first()
slide_to_edit = SlideModel.from_dict(
slide_to_edit_sql.model_dump(mode="json")
)
slide_to_edit = SlideModel.from_dict(slide_to_edit_sql.model_dump(mode="json"))
new_slide_type = SlideType(
(await get_slide_type_from_prompt(self.prompt, slide_to_edit)).slide_type
)
new_slide_type = await get_slide_type_from_prompt(
self.prompt, slide_to_edit
)
edited_content = await get_edited_slide_content_model(
self.prompt,
new_slide_type,
slide_to_edit,
presentation.theme,
presentation.language,
)
edited_content = await get_edited_slide_content_model(
self.prompt,
SlideType(new_slide_type.slide_type),
slide_to_edit,
presentation.theme,
presentation.language,
)
new_slide_model = SlideModel(
id=slide_to_edit.id,
index=slide_to_edit.index,
type=new_slide_type,
design_index=slide_to_edit.design_index,
images=None,
icons=None,
presentation=slide_to_edit.presentation,
properties=slide_to_edit.properties,
content=edited_content,
)
new_slide_model = SlideModel(
id=slide_to_edit.id,
index=slide_to_edit.index,
type=SlideType(new_slide_type.slide_type),
design_index=slide_to_edit.design_index,
images=None,
icons=None,
presentation=slide_to_edit.presentation,
content=edited_content,
)
new_slide_images_count = new_slide_model.images_count
new_slide_icons_count = new_slide_model.icons_count
images_to_delete, images_to_generate, icons_to_delete, icons_to_generate = (
self.get_all_assets_to_generate_and_delete(
slide_to_edit,
new_slide_model,
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
new_slide_images: dict[int, str | ImagePromptWithThemeAndAspectRatio] = {}
new_slide_icons: dict[int, str | IconQueryCollectionWithData] = {}
# ? Checks if image prompts have changed
# ? If they have, it will search if it is same as the old one but used at a different index
# ? If it is, it will use the old image
# ? If it is not, it will generate a new image
if new_slide_images_count:
new_image_prompts = slide_model_utils.get_image_prompts()
old_image_prompts = (
slide_to_edit.content.image_prompts
if slide_to_edit.images_count
else []
)
for index in range(new_slide_images_count):
new_prompt = new_slide_model.content.image_prompts[index]
for old_prompt in old_image_prompts:
if old_prompt != new_prompt:
continue
if index < len(slide_to_edit.images or []):
new_slide_images[index] = slide_to_edit.images[index]
break
if not new_slide_images.get(index):
new_slide_images[index] = new_image_prompts[index]
# ? Checks if icon queries have changed
# ? If they have, it will search if it is same as the old one but used at a different index
# ? If it is, it will use the old icon
# ? If it is not, it will generate a new icon
if new_slide_icons_count:
new_icon_queries = slide_model_utils.get_icon_queries()
old_icon_queries = (
slide_to_edit.content.icon_queries if slide_to_edit.icons_count else []
)
for index in range(new_slide_icons_count):
new_query = new_slide_model.content.icon_queries[index]
for old_query in old_icon_queries:
if old_query != new_query:
continue
if index < len(slide_to_edit.icons or []):
new_slide_icons[index] = slide_to_edit.icons[index]
break
if not new_slide_icons.get(index):
new_slide_icons[index] = new_icon_queries[index]
images_to_generate = []
for each in new_slide_images.values():
if isinstance(each, ImagePromptWithThemeAndAspectRatio):
images_to_generate.append(each)
icons_to_generate = []
for each in new_slide_icons.values():
if isinstance(each, IconQueryCollectionWithData):
icons_to_generate.append(each)
images_directory = get_presentation_images_dir(self.presentation_id)
if icons_to_generate:
icons_vectorstore = get_icons_vectorstore()
coroutines = [
generate_image(each_prompt, images_directory)
for each_prompt in images_to_generate
] + [
get_icon(icons_vectorstore, each_query) for each_query in icons_to_generate
]
generated_assets = await asyncio.gather(*coroutines)
generated_image_count = len(images_to_generate)
generate_images = generated_assets[:generated_image_count]
generate_icons = generated_assets[generated_image_count:]
for each in new_slide_images:
if isinstance(new_slide_images[each], ImagePromptWithThemeAndAspectRatio):
new_slide_images[each] = generate_images.pop(0)
for each in new_slide_icons:
if isinstance(new_slide_icons[each], IconQueryCollectionWithData):
new_slide_icons[each] = generate_icons.pop(0)
with get_sql_session() as sql_session:
sql_session.exec(
update(SlideSqlModel)
.where(SlideSqlModel.id == slide_to_edit.id)
.values(
type=new_slide_type.value,
images=list(new_slide_images.values()),
icons=list(new_slide_icons.values()),
content=new_slide_model.content.model_dump(mode="json"),
)
)
new_image_paths = slide_to_edit.images or []
new_icon_paths = slide_to_edit.icons or []
images_count = len(new_image_paths)
icons_count = len(new_icon_paths)
for index in images_to_generate:
file_key = f"{self.presentation_dir}/images/{str(uuid.uuid4())}.jpg"
if index < images_count:
new_image_paths.pop(index)
new_image_paths.insert(index, file_key)
else:
new_image_paths.append(file_key)
for index in icons_to_generate:
file_key = f"{self.presentation_dir}/icons/{str(uuid.uuid4())}.png"
if index < icons_count:
new_icon_paths.pop(index)
new_icon_paths.insert(index, file_key)
else:
new_icon_paths.append(file_key)
if new_image_paths:
new_slide_model.images = new_image_paths
if new_icon_paths:
new_slide_model.icons = new_icon_paths
# ? Images and Icons are related to this presentation will be deleted while deleting presentation.
# objects_to_delete = [*images_to_delete, *icons_to_delete]
# if objects_to_delete:
# for each in objects_to_delete:
# os.remove(each)
new_image_prompts = {}
new_icon_queries = {}
if images_to_generate:
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
image_prompts = slide_model_utils.get_image_prompts()
for image_index in images_to_generate:
new_image_prompts[new_slide_model.images[image_index]] = (
image_prompts[image_index]
)
if icons_to_generate:
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
icon_queries = slide_model_utils.get_icon_queries()
for icon_index in icons_to_generate:
new_icon_queries[new_slide_model.icons[icon_index]] = icon_queries[
icon_index
]
coroutines = [
generate_image(value, key) for key, value in new_image_prompts.items()
] + [get_icon(value, key) for key, value in new_icon_queries.items()]
await asyncio.gather(*coroutines)
slide_to_edit.images = new_slide_model.images
slide_to_edit.icons = new_slide_model.icons
slide_to_edit.content = new_slide_model.content
slide_to_edit.type = SlideType(new_slide_type.slide_type)
slide_to_edit_sql.index = slide_to_edit.index
slide_to_edit_sql.type = slide_to_edit.type.value
slide_to_edit_sql.design_index = slide_to_edit.design_index
slide_to_edit_sql.images = slide_to_edit.images
slide_to_edit_sql.icons = slide_to_edit.icons
slide_to_edit_sql.content = slide_to_edit.content.model_dump(mode="json")
slide_to_edit_sql.properties = slide_to_edit.properties
slide_to_edit_sql.presentation = slide_to_edit.presentation
sql_session.commit()
sql_session.refresh(slide_to_edit_sql)
slide_to_edit_sql = sql_session.exec(
select(SlideSqlModel).where(SlideSqlModel.id == slide_to_edit.id)
).first()
logging_service.logger.info(
logging_service.message(slide_to_edit.model_dump(mode="json")),
logging_service.message(slide_to_edit_sql.model_dump(mode="json")),
extra=log_metadata.model_dump(),
)
return slide_to_edit
def get_all_assets_to_generate_and_delete(
self,
old_slide_model: SlideModel,
new_slide_model: SlideModel,
) -> Tuple[List[str], List[str], List[str], List[str]]:
images_to_delete, images_to_generate = self.get_assets_to_generate_and_delete(
old_slide_model,
new_slide_model,
"image_prompts",
"images",
)
icons_to_delete, icons_to_generate = self.get_assets_to_generate_and_delete(
old_slide_model,
new_slide_model,
"icon_queries",
"icons",
)
return images_to_delete, images_to_generate, icons_to_delete, icons_to_generate
def get_assets_to_generate_and_delete(
self,
old_slide_model: SlideModel,
new_slide_model: SlideModel,
content_attr: str,
slide_model_attr: str,
) -> Tuple[List[str], List[str]]:
items_to_delete = []
items_to_generate = []
existing_paths = getattr(old_slide_model, slide_model_attr, [])
new_content_items = getattr(new_slide_model.content, content_attr, [])
old_content_items = getattr(old_slide_model.content, content_attr, [])
# Case 1: No new items but slide has existing items - delete all
if not new_content_items and existing_paths:
items_to_delete.extend(existing_paths)
return items_to_delete, items_to_generate
# Case 2: New items but slide has no existing items - generate all
if new_content_items and not existing_paths:
items_to_generate = [idx for idx in range(len(new_content_items))]
return items_to_delete, items_to_generate
# Case 3: Both new and existing items - compare and update
if new_content_items and existing_paths:
new_count = len(new_content_items)
old_count = len(existing_paths)
generate_idx = []
for idx in range(max(new_count, old_count)):
# Generate additional new items
if idx >= old_count:
generate_idx.append(idx)
# Delete excess old items
elif idx >= new_count:
items_to_delete.append(existing_paths[idx])
# Compare and update changed items
else:
old_value = old_content_items[idx]
new_value = new_content_items[idx]
if old_value != new_value:
items_to_delete.append(existing_paths[idx])
generate_idx.append(idx)
if generate_idx:
items_to_generate = generate_idx
filtered_items_to_delete = []
for each in items_to_delete:
if not each:
continue
filtered_items_to_delete.append(each)
return filtered_items_to_delete, items_to_generate
return slide_to_edit_sql

View file

@ -7,7 +7,7 @@ from api.routers.presentation.models import (
)
from api.services.logging import LoggingService
from api.services.instances import temp_file_service
from api.utils import get_presentation_dir
from api.utils import get_presentation_dir, get_presentation_images_dir
from image_processor.images_finder import generate_image
@ -28,10 +28,8 @@ class GenerateImageHandler:
extra=log_metadata.model_dump(),
)
image_path = os.path.join(
self.presentation_dir, "generated_images", str(uuid.uuid4()) + ".jpg"
)
await generate_image(self.data.prompt, image_path)
images_directory = get_presentation_images_dir(self.data.presentation_id)
image_path = await generate_image(self.data.prompt, images_directory)
response = PresentationAndPaths(
presentation_id=self.data.presentation_id, paths=[image_path]

View file

@ -1,14 +1,12 @@
import asyncio
import json
import os
from typing import List
import uuid
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from sqlmodel import delete, select
from sqlmodel import delete
from api.models import LogMetadata, SSEResponse
from api.models import LogMetadata, SSECompleteResponse, SSEResponse, SSEStatusResponse
from api.routers.presentation.models import (
PresentationAndSlides,
@ -17,7 +15,7 @@ from api.routers.presentation.models import (
from api.services.database import get_sql_session
from api.services.logging import LoggingService
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
from api.utils import get_presentation_dir
from api.utils import get_presentation_dir, get_presentation_images_dir
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
from image_processor.images_finder import generate_image
from image_processor.icons_finder import get_icon
@ -119,40 +117,15 @@ class PresentationGenerateStreamHandler:
content["index"] = i
content["presentation"] = presentation.id
slide_model = SlideModel(**content)
slide_content = slide_model.content
has_images = hasattr(slide_content, "image_prompts")
has_icons = hasattr(slide_content, "icon_queries")
if has_images:
slide_model.images = [
os.path.join(
self.presentation_dir,
"images",
f"{str(uuid.uuid4())}.jpg",
)
for _ in range(len(slide_content.image_prompts))
]
if has_icons:
slide_model.icons = [
os.path.join(
self.presentation_dir,
"icons",
f"{str(uuid.uuid4())}.png",
)
for _ in range(len(slide_content.icon_queries))
]
slide_models.append(slide_model)
yield SSEResponse(
event="response",
data=json.dumps({"type": "status", "status": "Fetching slide assets"}),
).to_string()
async for result in self.fetch_slide_assets(slide_models):
yield result
print("-" * 40)
print(slide_models)
print("-" * 40)
slide_sql_models = [
SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models
]
@ -163,73 +136,55 @@ class PresentationGenerateStreamHandler:
for each in slide_sql_models:
sql_session.refresh(each)
yield SSEResponse(
event="response",
data=json.dumps({"type": "status", "status": "Packing slide data"}),
).to_string()
yield SSEStatusResponse(status="Packing slide data").to_string()
response = PresentationAndSlides(
presentation=presentation, slides=slide_sql_models
).to_response_dict()
yield SSEResponse(
event="response",
data=json.dumps({"type": "complete", "presentation": response}),
).to_string()
yield SSEResponse(
event="response",
data=json.dumps({"type": "closing", "content": "First Warning"}),
).to_string()
await asyncio.sleep(3)
yield SSEResponse(
event="response",
data=json.dumps({"type": "closing", "content": "Final Warning"}),
).to_string()
yield SSECompleteResponse(key="presentation", value=response).to_string()
async def fetch_slide_assets(self, slide_models: List[SlideModel]):
image_prompts = []
icon_queries = []
image_paths = []
icon_paths = []
for each_slide_model in slide_models:
slide_model_utils = SlideModelUtils(self.theme, each_slide_model)
image_prompts.extend(slide_model_utils.get_image_prompts())
icon_queries.extend(slide_model_utils.get_icon_queries())
if each_slide_model.images:
prompts = slide_model_utils.get_image_prompts()
image_prompts.extend(prompts)
image_paths.extend(each_slide_model.images)
if each_slide_model.icons:
icon_queries.extend(slide_model_utils.get_icon_queries())
icon_paths.extend(each_slide_model.icons)
if icon_paths:
if icon_queries:
icon_vector_store = get_icons_vectorstore()
images_directory = get_presentation_images_dir(self.presentation_id)
coroutines = [
generate_image(
each,
image_path,
images_directory,
)
for each, image_path in zip(image_prompts, image_paths)
] + [
get_icon(icon_vector_store, each, icon_path)
for each, icon_path in zip(icon_queries, icon_paths)
]
for each in image_prompts
] + [get_icon(icon_vector_store, each) for each in icon_queries]
assets_future = asyncio.gather(*coroutines)
while not assets_future.done():
status = SSEResponse(
event="response",
data=json.dumps({"status": "Fetching slide assets..."}),
).to_string()
status = SSEStatusResponse(status="Fetching slide assets").to_string()
yield status
await asyncio.sleep(5)
await assets_future
assets = await assets_future
yield SSEResponse(
event="response", data=json.dumps({"status": "Slide assets fetched"})
).to_string()
image_prompts_len = len(image_prompts)
images = assets[:image_prompts_len]
icons = assets[image_prompts_len:]
for each_slide_model in slide_models:
each_slide_model.images = images[: each_slide_model.images_count]
images = images[each_slide_model.images_count :]
each_slide_model.icons = icons[: each_slide_model.icons_count]
icons = icons[each_slide_model.icons_count :]
yield SSEStatusResponse(status="Slide assets fetched").to_string()

View file

@ -20,6 +20,14 @@ def get_presentation_dir(presentation_id: str) -> str:
return presentation_dir
def get_presentation_images_dir(presentation_id: str) -> str:
presentation_images_dir = os.path.join(
get_presentation_dir(presentation_id), "images"
)
os.makedirs(presentation_images_dir, exist_ok=True)
return presentation_images_dir
def get_user_config():
user_config_path = os.getenv("USER_CONFIG_PATH")
@ -144,14 +152,14 @@ async def handle_errors(
def sanitize_filename(filename: str) -> str:
name, ext = os.path.splitext(filename)
sanitized = re.sub(r'[\\/:*?"<>|]', '_', name)
sanitized = re.sub(r'[\s_]+', '_', sanitized)
sanitized = sanitized.strip(' .')
sanitized = re.sub(r'[\\/:*?"<>|]', "_", name)
sanitized = re.sub(r"[\s_]+", "_", sanitized)
sanitized = sanitized.strip(" .")
if not sanitized:
sanitized = 'untitled'
sanitized = "untitled"
if len(sanitized) > 200:
sanitized = sanitized[:200]
return sanitized + ext

View file

@ -1,4 +1,3 @@
import os
from typing import List, Optional
from api.utils import get_resource
@ -12,22 +11,15 @@ from langchain_core.vectorstores import InMemoryVectorStore
async def get_icon(
vector_store: InMemoryVectorStore,
input: IconQueryCollectionWithData,
output_path: str,
) -> str:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
query = input.icon_query.queries[0]
results = vector_store.similarity_search(query=query, k=1)
icon_name = results[0].page_content
with open(output_path, "wb") as f_a:
try:
with open(get_resource(f"assets/icons/bold/{icon_name}.png"), "rb") as f_b:
f_a.write(f_b.read())
except Exception as e:
print("Error finding icon: ", e)
with open(get_resource("assets/icons/placeholder.png"), "rb") as f_b:
f_a.write(f_b.read())
try:
query = input.icon_query.queries[0]
results = vector_store.similarity_search(query=query, k=1)
icon_name = results[0].page_content
return get_resource(f"assets/icons/bold/{icon_name}.png")
except Exception as e:
print("Error finding icon: ", e)
return get_resource("assets/icons/placeholder.png")
async def get_icons(

View file

@ -1,6 +1,7 @@
import asyncio
import base64
import os
import uuid
import aiohttp
from langchain_google_genai import ChatGoogleGenerativeAI
from openai import OpenAI
@ -13,10 +14,8 @@ from api.utils import get_resource
async def generate_image(
input: ImagePromptWithThemeAndAspectRatio,
output_path: str,
output_directory: str,
) -> str:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
image_prompt = f"{input.image_prompt}, {input.theme_prompt}"
print(f"Request - Generating Image for {image_prompt}")
@ -26,15 +25,17 @@ async def generate_image(
if os.getenv("LLM") == "openai"
else generate_image_google
)
await image_gen_func(image_prompt, output_path)
image_path = await image_gen_func(image_prompt, output_directory)
if image_path and os.path.exists(image_path):
return image_path
raise Exception(f"Image not found at {image_path}")
except Exception as e:
print(f"Error generating image: {e}")
with open(get_resource("assets/images/placeholder.jpg"), "rb") as f_a:
with open(output_path, "wb") as f_b:
f_b.write(f_a.read())
return get_resource("assets/images/placeholder.jpg")
async def generate_image_openai(prompt: str, output_path: str):
async def generate_image_openai(prompt: str, output_directory: str) -> str:
client = OpenAI()
result = await asyncio.to_thread(
client.images.generate,
@ -48,11 +49,13 @@ async def generate_image_openai(prompt: str, output_path: str):
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:
image_bytes = await response.read()
with open(output_path, "wb") as f:
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
with open(image_path, "wb") as f:
f.write(image_bytes)
return image_path
async def generate_image_google(prompt: str, output_path: str):
async def generate_image_google(prompt: str, output_directory: str) -> str:
response = await ChatGoogleGenerativeAI(
model="gemini-2.0-flash-preview-image-generation"
).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]})
@ -64,5 +67,8 @@ async def generate_image_google(prompt: str, output_path: str):
)
base64_image = image_block["image_url"].get("url").split(",")[-1]
with open(output_path, "wb") as f:
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
with open(image_path, "wb") as f:
f.write(base64.b64decode(base64_image))
return image_path

View file

@ -28,4 +28,6 @@ class SlideType(Enum):
class SlideTypeModel(BaseModel):
slide_type: int = Field(default=1, gte=1, lte=9, description="Slide type from 1 to 9")
slide_type: int = Field(
default=1, gte=1, lte=9, description="Slide type from 1 to 9"
)

View file

@ -58,10 +58,10 @@ class SlideModel(BaseModel):
def images_count(self):
if not hasattr(self.content, "image_prompts"):
return 0
return len(self.content.image_prompts)
return len(self.content.image_prompts or [])
@property
def icons_count(self):
if not hasattr(self.content, "icon_queries"):
return 0
return len(self.content.icon_queries)
return len(self.content.icon_queries or [])