Changes: better assets handling in fastapi
This commit is contained in:
parent
e6d706d366
commit
ea81417bcb
9 changed files with 231 additions and 299 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -37,6 +38,26 @@ class SSEResponse(BaseModel):
|
|||
return f"event: {self.event}\ndata: {self.data}\n\n"
|
||||
|
||||
|
||||
class SSEStatusResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "status", "status": self.status})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSECompleteResponse(BaseModel):
|
||||
key: str
|
||||
value: object
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "complete", self.key: self.value}),
|
||||
).to_string()
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
LLM: Optional[str] = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
|
|
@ -10,9 +10,15 @@ from api.routers.presentation.models import (
|
|||
)
|
||||
from api.services.instances import temp_file_service
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils import get_presentation_dir
|
||||
from api.utils import get_presentation_dir, get_presentation_images_dir
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
from ppt_generator.models.other_models import SlideType
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
IconQueryCollectionWithData,
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from ppt_generator.slide_generator import (
|
||||
get_edited_slide_content_model,
|
||||
|
|
@ -21,7 +27,6 @@ from ppt_generator.slide_generator import (
|
|||
from ppt_generator.slide_model_utils import SlideModelUtils
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from ppt_generator.models.other_models import SlideType
|
||||
|
||||
|
||||
class PresentationEditHandler:
|
||||
|
|
@ -49,193 +54,138 @@ class PresentationEditHandler:
|
|||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
slide_to_edit_sql = sql_session.exec(
|
||||
select(SlideSqlModel).where(SlideSqlModel.index == self.slide_index)
|
||||
select(SlideSqlModel).where(
|
||||
SlideSqlModel.index == self.slide_index,
|
||||
SlideSqlModel.presentation == self.presentation_id,
|
||||
)
|
||||
).first()
|
||||
|
||||
slide_to_edit = SlideModel.from_dict(
|
||||
slide_to_edit_sql.model_dump(mode="json")
|
||||
)
|
||||
slide_to_edit = SlideModel.from_dict(slide_to_edit_sql.model_dump(mode="json"))
|
||||
new_slide_type = SlideType(
|
||||
(await get_slide_type_from_prompt(self.prompt, slide_to_edit)).slide_type
|
||||
)
|
||||
|
||||
new_slide_type = await get_slide_type_from_prompt(
|
||||
self.prompt, slide_to_edit
|
||||
)
|
||||
edited_content = await get_edited_slide_content_model(
|
||||
self.prompt,
|
||||
new_slide_type,
|
||||
slide_to_edit,
|
||||
presentation.theme,
|
||||
presentation.language,
|
||||
)
|
||||
|
||||
edited_content = await get_edited_slide_content_model(
|
||||
self.prompt,
|
||||
SlideType(new_slide_type.slide_type),
|
||||
slide_to_edit,
|
||||
presentation.theme,
|
||||
presentation.language,
|
||||
)
|
||||
new_slide_model = SlideModel(
|
||||
id=slide_to_edit.id,
|
||||
index=slide_to_edit.index,
|
||||
type=new_slide_type,
|
||||
design_index=slide_to_edit.design_index,
|
||||
images=None,
|
||||
icons=None,
|
||||
presentation=slide_to_edit.presentation,
|
||||
properties=slide_to_edit.properties,
|
||||
content=edited_content,
|
||||
)
|
||||
|
||||
new_slide_model = SlideModel(
|
||||
id=slide_to_edit.id,
|
||||
index=slide_to_edit.index,
|
||||
type=SlideType(new_slide_type.slide_type),
|
||||
design_index=slide_to_edit.design_index,
|
||||
images=None,
|
||||
icons=None,
|
||||
presentation=slide_to_edit.presentation,
|
||||
content=edited_content,
|
||||
)
|
||||
new_slide_images_count = new_slide_model.images_count
|
||||
new_slide_icons_count = new_slide_model.icons_count
|
||||
|
||||
images_to_delete, images_to_generate, icons_to_delete, icons_to_generate = (
|
||||
self.get_all_assets_to_generate_and_delete(
|
||||
slide_to_edit,
|
||||
new_slide_model,
|
||||
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
|
||||
|
||||
new_slide_images: dict[int, str | ImagePromptWithThemeAndAspectRatio] = {}
|
||||
new_slide_icons: dict[int, str | IconQueryCollectionWithData] = {}
|
||||
|
||||
# ? Checks if image prompts have changed
|
||||
# ? If they have, it will search if it is same as the old one but used at a different index
|
||||
# ? If it is, it will use the old image
|
||||
# ? If it is not, it will generate a new image
|
||||
if new_slide_images_count:
|
||||
new_image_prompts = slide_model_utils.get_image_prompts()
|
||||
old_image_prompts = (
|
||||
slide_to_edit.content.image_prompts
|
||||
if slide_to_edit.images_count
|
||||
else []
|
||||
)
|
||||
for index in range(new_slide_images_count):
|
||||
new_prompt = new_slide_model.content.image_prompts[index]
|
||||
for old_prompt in old_image_prompts:
|
||||
if old_prompt != new_prompt:
|
||||
continue
|
||||
if index < len(slide_to_edit.images or []):
|
||||
new_slide_images[index] = slide_to_edit.images[index]
|
||||
break
|
||||
if not new_slide_images.get(index):
|
||||
new_slide_images[index] = new_image_prompts[index]
|
||||
|
||||
# ? Checks if icon queries have changed
|
||||
# ? If they have, it will search if it is same as the old one but used at a different index
|
||||
# ? If it is, it will use the old icon
|
||||
# ? If it is not, it will generate a new icon
|
||||
if new_slide_icons_count:
|
||||
new_icon_queries = slide_model_utils.get_icon_queries()
|
||||
old_icon_queries = (
|
||||
slide_to_edit.content.icon_queries if slide_to_edit.icons_count else []
|
||||
)
|
||||
for index in range(new_slide_icons_count):
|
||||
new_query = new_slide_model.content.icon_queries[index]
|
||||
for old_query in old_icon_queries:
|
||||
if old_query != new_query:
|
||||
continue
|
||||
if index < len(slide_to_edit.icons or []):
|
||||
new_slide_icons[index] = slide_to_edit.icons[index]
|
||||
break
|
||||
if not new_slide_icons.get(index):
|
||||
new_slide_icons[index] = new_icon_queries[index]
|
||||
|
||||
images_to_generate = []
|
||||
for each in new_slide_images.values():
|
||||
if isinstance(each, ImagePromptWithThemeAndAspectRatio):
|
||||
images_to_generate.append(each)
|
||||
|
||||
icons_to_generate = []
|
||||
for each in new_slide_icons.values():
|
||||
if isinstance(each, IconQueryCollectionWithData):
|
||||
icons_to_generate.append(each)
|
||||
|
||||
images_directory = get_presentation_images_dir(self.presentation_id)
|
||||
if icons_to_generate:
|
||||
icons_vectorstore = get_icons_vectorstore()
|
||||
|
||||
coroutines = [
|
||||
generate_image(each_prompt, images_directory)
|
||||
for each_prompt in images_to_generate
|
||||
] + [
|
||||
get_icon(icons_vectorstore, each_query) for each_query in icons_to_generate
|
||||
]
|
||||
generated_assets = await asyncio.gather(*coroutines)
|
||||
generated_image_count = len(images_to_generate)
|
||||
generate_images = generated_assets[:generated_image_count]
|
||||
generate_icons = generated_assets[generated_image_count:]
|
||||
|
||||
for each in new_slide_images:
|
||||
if isinstance(new_slide_images[each], ImagePromptWithThemeAndAspectRatio):
|
||||
new_slide_images[each] = generate_images.pop(0)
|
||||
|
||||
for each in new_slide_icons:
|
||||
if isinstance(new_slide_icons[each], IconQueryCollectionWithData):
|
||||
new_slide_icons[each] = generate_icons.pop(0)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.exec(
|
||||
update(SlideSqlModel)
|
||||
.where(SlideSqlModel.id == slide_to_edit.id)
|
||||
.values(
|
||||
type=new_slide_type.value,
|
||||
images=list(new_slide_images.values()),
|
||||
icons=list(new_slide_icons.values()),
|
||||
content=new_slide_model.content.model_dump(mode="json"),
|
||||
)
|
||||
)
|
||||
new_image_paths = slide_to_edit.images or []
|
||||
new_icon_paths = slide_to_edit.icons or []
|
||||
images_count = len(new_image_paths)
|
||||
icons_count = len(new_icon_paths)
|
||||
for index in images_to_generate:
|
||||
file_key = f"{self.presentation_dir}/images/{str(uuid.uuid4())}.jpg"
|
||||
if index < images_count:
|
||||
new_image_paths.pop(index)
|
||||
new_image_paths.insert(index, file_key)
|
||||
else:
|
||||
new_image_paths.append(file_key)
|
||||
for index in icons_to_generate:
|
||||
file_key = f"{self.presentation_dir}/icons/{str(uuid.uuid4())}.png"
|
||||
if index < icons_count:
|
||||
new_icon_paths.pop(index)
|
||||
new_icon_paths.insert(index, file_key)
|
||||
else:
|
||||
new_icon_paths.append(file_key)
|
||||
|
||||
if new_image_paths:
|
||||
new_slide_model.images = new_image_paths
|
||||
if new_icon_paths:
|
||||
new_slide_model.icons = new_icon_paths
|
||||
|
||||
# ? Images and Icons are related to this presentation will be deleted while deleting presentation.
|
||||
# objects_to_delete = [*images_to_delete, *icons_to_delete]
|
||||
# if objects_to_delete:
|
||||
# for each in objects_to_delete:
|
||||
# os.remove(each)
|
||||
|
||||
new_image_prompts = {}
|
||||
new_icon_queries = {}
|
||||
if images_to_generate:
|
||||
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
|
||||
image_prompts = slide_model_utils.get_image_prompts()
|
||||
for image_index in images_to_generate:
|
||||
new_image_prompts[new_slide_model.images[image_index]] = (
|
||||
image_prompts[image_index]
|
||||
)
|
||||
|
||||
if icons_to_generate:
|
||||
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
|
||||
icon_queries = slide_model_utils.get_icon_queries()
|
||||
for icon_index in icons_to_generate:
|
||||
new_icon_queries[new_slide_model.icons[icon_index]] = icon_queries[
|
||||
icon_index
|
||||
]
|
||||
|
||||
coroutines = [
|
||||
generate_image(value, key) for key, value in new_image_prompts.items()
|
||||
] + [get_icon(value, key) for key, value in new_icon_queries.items()]
|
||||
|
||||
await asyncio.gather(*coroutines)
|
||||
|
||||
slide_to_edit.images = new_slide_model.images
|
||||
slide_to_edit.icons = new_slide_model.icons
|
||||
slide_to_edit.content = new_slide_model.content
|
||||
slide_to_edit.type = SlideType(new_slide_type.slide_type)
|
||||
|
||||
slide_to_edit_sql.index = slide_to_edit.index
|
||||
slide_to_edit_sql.type = slide_to_edit.type.value
|
||||
slide_to_edit_sql.design_index = slide_to_edit.design_index
|
||||
slide_to_edit_sql.images = slide_to_edit.images
|
||||
slide_to_edit_sql.icons = slide_to_edit.icons
|
||||
slide_to_edit_sql.content = slide_to_edit.content.model_dump(mode="json")
|
||||
slide_to_edit_sql.properties = slide_to_edit.properties
|
||||
slide_to_edit_sql.presentation = slide_to_edit.presentation
|
||||
sql_session.commit()
|
||||
sql_session.refresh(slide_to_edit_sql)
|
||||
slide_to_edit_sql = sql_session.exec(
|
||||
select(SlideSqlModel).where(SlideSqlModel.id == slide_to_edit.id)
|
||||
).first()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(slide_to_edit.model_dump(mode="json")),
|
||||
logging_service.message(slide_to_edit_sql.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return slide_to_edit
|
||||
|
||||
def get_all_assets_to_generate_and_delete(
|
||||
self,
|
||||
old_slide_model: SlideModel,
|
||||
new_slide_model: SlideModel,
|
||||
) -> Tuple[List[str], List[str], List[str], List[str]]:
|
||||
|
||||
images_to_delete, images_to_generate = self.get_assets_to_generate_and_delete(
|
||||
old_slide_model,
|
||||
new_slide_model,
|
||||
"image_prompts",
|
||||
"images",
|
||||
)
|
||||
|
||||
icons_to_delete, icons_to_generate = self.get_assets_to_generate_and_delete(
|
||||
old_slide_model,
|
||||
new_slide_model,
|
||||
"icon_queries",
|
||||
"icons",
|
||||
)
|
||||
|
||||
return images_to_delete, images_to_generate, icons_to_delete, icons_to_generate
|
||||
|
||||
def get_assets_to_generate_and_delete(
|
||||
self,
|
||||
old_slide_model: SlideModel,
|
||||
new_slide_model: SlideModel,
|
||||
content_attr: str,
|
||||
slide_model_attr: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
|
||||
items_to_delete = []
|
||||
items_to_generate = []
|
||||
|
||||
existing_paths = getattr(old_slide_model, slide_model_attr, [])
|
||||
new_content_items = getattr(new_slide_model.content, content_attr, [])
|
||||
old_content_items = getattr(old_slide_model.content, content_attr, [])
|
||||
|
||||
# Case 1: No new items but slide has existing items - delete all
|
||||
if not new_content_items and existing_paths:
|
||||
items_to_delete.extend(existing_paths)
|
||||
return items_to_delete, items_to_generate
|
||||
|
||||
# Case 2: New items but slide has no existing items - generate all
|
||||
if new_content_items and not existing_paths:
|
||||
items_to_generate = [idx for idx in range(len(new_content_items))]
|
||||
return items_to_delete, items_to_generate
|
||||
|
||||
# Case 3: Both new and existing items - compare and update
|
||||
if new_content_items and existing_paths:
|
||||
new_count = len(new_content_items)
|
||||
old_count = len(existing_paths)
|
||||
|
||||
generate_idx = []
|
||||
for idx in range(max(new_count, old_count)):
|
||||
# Generate additional new items
|
||||
if idx >= old_count:
|
||||
generate_idx.append(idx)
|
||||
# Delete excess old items
|
||||
elif idx >= new_count:
|
||||
items_to_delete.append(existing_paths[idx])
|
||||
# Compare and update changed items
|
||||
else:
|
||||
old_value = old_content_items[idx]
|
||||
new_value = new_content_items[idx]
|
||||
if old_value != new_value:
|
||||
items_to_delete.append(existing_paths[idx])
|
||||
generate_idx.append(idx)
|
||||
|
||||
if generate_idx:
|
||||
items_to_generate = generate_idx
|
||||
|
||||
filtered_items_to_delete = []
|
||||
for each in items_to_delete:
|
||||
if not each:
|
||||
continue
|
||||
filtered_items_to_delete.append(each)
|
||||
|
||||
return filtered_items_to_delete, items_to_generate
|
||||
return slide_to_edit_sql
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from api.routers.presentation.models import (
|
|||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import temp_file_service
|
||||
from api.utils import get_presentation_dir
|
||||
from api.utils import get_presentation_dir, get_presentation_images_dir
|
||||
from image_processor.images_finder import generate_image
|
||||
|
||||
|
||||
|
|
@ -28,10 +28,8 @@ class GenerateImageHandler:
|
|||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
image_path = os.path.join(
|
||||
self.presentation_dir, "generated_images", str(uuid.uuid4()) + ".jpg"
|
||||
)
|
||||
await generate_image(self.data.prompt, image_path)
|
||||
images_directory = get_presentation_images_dir(self.data.presentation_id)
|
||||
image_path = await generate_image(self.data.prompt, images_directory)
|
||||
|
||||
response = PresentationAndPaths(
|
||||
presentation_id=self.data.presentation_id, paths=[image_path]
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import List
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlmodel import delete, select
|
||||
from sqlmodel import delete
|
||||
|
||||
from api.models import LogMetadata, SSEResponse
|
||||
from api.models import LogMetadata, SSECompleteResponse, SSEResponse, SSEStatusResponse
|
||||
|
||||
from api.routers.presentation.models import (
|
||||
PresentationAndSlides,
|
||||
|
|
@ -17,7 +15,7 @@ from api.routers.presentation.models import (
|
|||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
|
||||
from api.utils import get_presentation_dir
|
||||
from api.utils import get_presentation_dir, get_presentation_images_dir
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
|
|
@ -119,40 +117,15 @@ class PresentationGenerateStreamHandler:
|
|||
content["index"] = i
|
||||
content["presentation"] = presentation.id
|
||||
slide_model = SlideModel(**content)
|
||||
slide_content = slide_model.content
|
||||
has_images = hasattr(slide_content, "image_prompts")
|
||||
has_icons = hasattr(slide_content, "icon_queries")
|
||||
|
||||
if has_images:
|
||||
slide_model.images = [
|
||||
os.path.join(
|
||||
self.presentation_dir,
|
||||
"images",
|
||||
f"{str(uuid.uuid4())}.jpg",
|
||||
)
|
||||
for _ in range(len(slide_content.image_prompts))
|
||||
]
|
||||
|
||||
if has_icons:
|
||||
slide_model.icons = [
|
||||
os.path.join(
|
||||
self.presentation_dir,
|
||||
"icons",
|
||||
f"{str(uuid.uuid4())}.png",
|
||||
)
|
||||
for _ in range(len(slide_content.icon_queries))
|
||||
]
|
||||
|
||||
slide_models.append(slide_model)
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "status", "status": "Fetching slide assets"}),
|
||||
).to_string()
|
||||
|
||||
async for result in self.fetch_slide_assets(slide_models):
|
||||
yield result
|
||||
|
||||
print("-" * 40)
|
||||
print(slide_models)
|
||||
print("-" * 40)
|
||||
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models
|
||||
]
|
||||
|
|
@ -163,73 +136,55 @@ class PresentationGenerateStreamHandler:
|
|||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "status", "status": "Packing slide data"}),
|
||||
).to_string()
|
||||
yield SSEStatusResponse(status="Packing slide data").to_string()
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=presentation, slides=slide_sql_models
|
||||
).to_response_dict()
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "complete", "presentation": response}),
|
||||
).to_string()
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "closing", "content": "First Warning"}),
|
||||
).to_string()
|
||||
await asyncio.sleep(3)
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "closing", "content": "Final Warning"}),
|
||||
).to_string()
|
||||
yield SSECompleteResponse(key="presentation", value=response).to_string()
|
||||
|
||||
async def fetch_slide_assets(self, slide_models: List[SlideModel]):
|
||||
image_prompts = []
|
||||
icon_queries = []
|
||||
|
||||
image_paths = []
|
||||
icon_paths = []
|
||||
|
||||
for each_slide_model in slide_models:
|
||||
slide_model_utils = SlideModelUtils(self.theme, each_slide_model)
|
||||
image_prompts.extend(slide_model_utils.get_image_prompts())
|
||||
icon_queries.extend(slide_model_utils.get_icon_queries())
|
||||
|
||||
if each_slide_model.images:
|
||||
prompts = slide_model_utils.get_image_prompts()
|
||||
image_prompts.extend(prompts)
|
||||
image_paths.extend(each_slide_model.images)
|
||||
if each_slide_model.icons:
|
||||
icon_queries.extend(slide_model_utils.get_icon_queries())
|
||||
icon_paths.extend(each_slide_model.icons)
|
||||
|
||||
if icon_paths:
|
||||
if icon_queries:
|
||||
icon_vector_store = get_icons_vectorstore()
|
||||
|
||||
images_directory = get_presentation_images_dir(self.presentation_id)
|
||||
|
||||
coroutines = [
|
||||
generate_image(
|
||||
each,
|
||||
image_path,
|
||||
images_directory,
|
||||
)
|
||||
for each, image_path in zip(image_prompts, image_paths)
|
||||
] + [
|
||||
get_icon(icon_vector_store, each, icon_path)
|
||||
for each, icon_path in zip(icon_queries, icon_paths)
|
||||
]
|
||||
for each in image_prompts
|
||||
] + [get_icon(icon_vector_store, each) for each in icon_queries]
|
||||
|
||||
assets_future = asyncio.gather(*coroutines)
|
||||
|
||||
while not assets_future.done():
|
||||
status = SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"status": "Fetching slide assets..."}),
|
||||
).to_string()
|
||||
status = SSEStatusResponse(status="Fetching slide assets").to_string()
|
||||
yield status
|
||||
await asyncio.sleep(5)
|
||||
|
||||
await assets_future
|
||||
assets = await assets_future
|
||||
|
||||
yield SSEResponse(
|
||||
event="response", data=json.dumps({"status": "Slide assets fetched"})
|
||||
).to_string()
|
||||
image_prompts_len = len(image_prompts)
|
||||
|
||||
images = assets[:image_prompts_len]
|
||||
icons = assets[image_prompts_len:]
|
||||
|
||||
for each_slide_model in slide_models:
|
||||
each_slide_model.images = images[: each_slide_model.images_count]
|
||||
images = images[each_slide_model.images_count :]
|
||||
|
||||
each_slide_model.icons = icons[: each_slide_model.icons_count]
|
||||
icons = icons[each_slide_model.icons_count :]
|
||||
|
||||
yield SSEStatusResponse(status="Slide assets fetched").to_string()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,14 @@ def get_presentation_dir(presentation_id: str) -> str:
|
|||
return presentation_dir
|
||||
|
||||
|
||||
def get_presentation_images_dir(presentation_id: str) -> str:
|
||||
presentation_images_dir = os.path.join(
|
||||
get_presentation_dir(presentation_id), "images"
|
||||
)
|
||||
os.makedirs(presentation_images_dir, exist_ok=True)
|
||||
return presentation_images_dir
|
||||
|
||||
|
||||
def get_user_config():
|
||||
user_config_path = os.getenv("USER_CONFIG_PATH")
|
||||
|
||||
|
|
@ -144,14 +152,14 @@ async def handle_errors(
|
|||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
name, ext = os.path.splitext(filename)
|
||||
sanitized = re.sub(r'[\\/:*?"<>|]', '_', name)
|
||||
sanitized = re.sub(r'[\s_]+', '_', sanitized)
|
||||
sanitized = sanitized.strip(' .')
|
||||
sanitized = re.sub(r'[\\/:*?"<>|]', "_", name)
|
||||
sanitized = re.sub(r"[\s_]+", "_", sanitized)
|
||||
sanitized = sanitized.strip(" .")
|
||||
|
||||
if not sanitized:
|
||||
sanitized = 'untitled'
|
||||
|
||||
sanitized = "untitled"
|
||||
|
||||
if len(sanitized) > 200:
|
||||
sanitized = sanitized[:200]
|
||||
|
||||
|
||||
return sanitized + ext
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from api.utils import get_resource
|
||||
|
|
@ -12,22 +11,15 @@ from langchain_core.vectorstores import InMemoryVectorStore
|
|||
async def get_icon(
|
||||
vector_store: InMemoryVectorStore,
|
||||
input: IconQueryCollectionWithData,
|
||||
output_path: str,
|
||||
) -> str:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
query = input.icon_query.queries[0]
|
||||
results = vector_store.similarity_search(query=query, k=1)
|
||||
icon_name = results[0].page_content
|
||||
|
||||
with open(output_path, "wb") as f_a:
|
||||
try:
|
||||
with open(get_resource(f"assets/icons/bold/{icon_name}.png"), "rb") as f_b:
|
||||
f_a.write(f_b.read())
|
||||
except Exception as e:
|
||||
print("Error finding icon: ", e)
|
||||
with open(get_resource("assets/icons/placeholder.png"), "rb") as f_b:
|
||||
f_a.write(f_b.read())
|
||||
try:
|
||||
query = input.icon_query.queries[0]
|
||||
results = vector_store.similarity_search(query=query, k=1)
|
||||
icon_name = results[0].page_content
|
||||
return get_resource(f"assets/icons/bold/{icon_name}.png")
|
||||
except Exception as e:
|
||||
print("Error finding icon: ", e)
|
||||
return get_resource("assets/icons/placeholder.png")
|
||||
|
||||
|
||||
async def get_icons(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from openai import OpenAI
|
||||
|
|
@ -13,10 +14,8 @@ from api.utils import get_resource
|
|||
|
||||
async def generate_image(
|
||||
input: ImagePromptWithThemeAndAspectRatio,
|
||||
output_path: str,
|
||||
output_directory: str,
|
||||
) -> str:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
image_prompt = f"{input.image_prompt}, {input.theme_prompt}"
|
||||
print(f"Request - Generating Image for {image_prompt}")
|
||||
|
||||
|
|
@ -26,15 +25,17 @@ async def generate_image(
|
|||
if os.getenv("LLM") == "openai"
|
||||
else generate_image_google
|
||||
)
|
||||
await image_gen_func(image_prompt, output_path)
|
||||
image_path = await image_gen_func(image_prompt, output_directory)
|
||||
if image_path and os.path.exists(image_path):
|
||||
return image_path
|
||||
raise Exception(f"Image not found at {image_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating image: {e}")
|
||||
with open(get_resource("assets/images/placeholder.jpg"), "rb") as f_a:
|
||||
with open(output_path, "wb") as f_b:
|
||||
f_b.write(f_a.read())
|
||||
return get_resource("assets/images/placeholder.jpg")
|
||||
|
||||
|
||||
async def generate_image_openai(prompt: str, output_path: str):
|
||||
async def generate_image_openai(prompt: str, output_directory: str) -> str:
|
||||
client = OpenAI()
|
||||
result = await asyncio.to_thread(
|
||||
client.images.generate,
|
||||
|
|
@ -48,11 +49,13 @@ async def generate_image_openai(prompt: str, output_path: str):
|
|||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
image_bytes = await response.read()
|
||||
with open(output_path, "wb") as f:
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return image_path
|
||||
|
||||
|
||||
async def generate_image_google(prompt: str, output_path: str):
|
||||
async def generate_image_google(prompt: str, output_directory: str) -> str:
|
||||
response = await ChatGoogleGenerativeAI(
|
||||
model="gemini-2.0-flash-preview-image-generation"
|
||||
).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]})
|
||||
|
|
@ -64,5 +67,8 @@ async def generate_image_google(prompt: str, output_path: str):
|
|||
)
|
||||
|
||||
base64_image = image_block["image_url"].get("url").split(",")[-1]
|
||||
with open(output_path, "wb") as f:
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(base64.b64decode(base64_image))
|
||||
|
||||
return image_path
|
||||
|
|
|
|||
|
|
@ -28,4 +28,6 @@ class SlideType(Enum):
|
|||
|
||||
|
||||
class SlideTypeModel(BaseModel):
|
||||
slide_type: int = Field(default=1, gte=1, lte=9, description="Slide type from 1 to 9")
|
||||
slide_type: int = Field(
|
||||
default=1, gte=1, lte=9, description="Slide type from 1 to 9"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -58,10 +58,10 @@ class SlideModel(BaseModel):
|
|||
def images_count(self):
|
||||
if not hasattr(self.content, "image_prompts"):
|
||||
return 0
|
||||
return len(self.content.image_prompts)
|
||||
return len(self.content.image_prompts or [])
|
||||
|
||||
@property
|
||||
def icons_count(self):
|
||||
if not hasattr(self.content, "icon_queries"):
|
||||
return 0
|
||||
return len(self.content.icon_queries)
|
||||
return len(self.content.icon_queries or [])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue