202 lines
8 KiB
Python
202 lines
8 KiB
Python
import asyncio
|
|
import os
|
|
import uuid
|
|
|
|
from sqlalchemy import update
|
|
from sqlmodel import select
|
|
from api.models import LogMetadata
|
|
from api.routers.presentation.models import (
|
|
EditPresentationSlideRequest,
|
|
)
|
|
from api.services.instances import TEMP_FILE_SERVICE
|
|
from api.services.logging import LoggingService
|
|
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|
from api.utils.utils import (
|
|
get_presentation_dir,
|
|
get_presentation_images_dir,
|
|
is_ollama_selected,
|
|
)
|
|
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
|
from image_processor.images_finder import generate_image
|
|
from image_processor.icons_finder import get_icon
|
|
from ppt_generator.models.query_and_prompt_models import (
|
|
IconQueryCollectionWithData,
|
|
ImagePromptWithThemeAndAspectRatio,
|
|
)
|
|
from ppt_generator.models.slide_model import SlideModel
|
|
from ppt_generator.slide_generator import (
|
|
get_edited_slide_content_model,
|
|
get_slide_type_from_prompt,
|
|
)
|
|
from ppt_generator.slide_model_utils import SlideModelUtils
|
|
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
|
from api.services.database import get_sql_session
|
|
|
|
|
|
class PresentationEditHandler:
|
|
def __init__(self, data: EditPresentationSlideRequest):
|
|
self.data = data
|
|
self.presentation_id = data.presentation_id
|
|
|
|
self.slide_index = data.index
|
|
self.prompt = data.prompt
|
|
|
|
self.session = str(uuid.uuid4())
|
|
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
|
|
|
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
|
|
|
def __del__(self):
|
|
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
|
|
|
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
|
logging_service.logger.info(
|
|
logging_service.message(self.data.model_dump(mode="json")),
|
|
extra=log_metadata.model_dump(),
|
|
)
|
|
|
|
with get_sql_session() as sql_session:
|
|
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
|
slide_to_edit_sql = sql_session.exec(
|
|
select(SlideSqlModel).where(
|
|
SlideSqlModel.index == self.slide_index,
|
|
SlideSqlModel.presentation == self.presentation_id,
|
|
)
|
|
).first()
|
|
|
|
slide_to_edit = SlideModel.from_dict(slide_to_edit_sql.model_dump(mode="json"))
|
|
new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit)
|
|
new_slide_type = new_slide_type.slide_type
|
|
|
|
if is_ollama_selected():
|
|
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
|
if not model.supports_graph:
|
|
if new_slide_type == 5:
|
|
new_slide_type = 1
|
|
elif new_slide_type == 9:
|
|
new_slide_type = 6
|
|
|
|
edited_content = await get_edited_slide_content_model(
|
|
self.prompt,
|
|
new_slide_type,
|
|
slide_to_edit,
|
|
presentation.theme,
|
|
presentation.language,
|
|
)
|
|
|
|
new_slide_model = SlideModel(
|
|
id=slide_to_edit.id,
|
|
index=slide_to_edit.index,
|
|
type=new_slide_type,
|
|
design_index=slide_to_edit.design_index,
|
|
images=None,
|
|
icons=None,
|
|
presentation=slide_to_edit.presentation,
|
|
properties=slide_to_edit.properties,
|
|
content=edited_content,
|
|
)
|
|
|
|
new_slide_images_count = new_slide_model.images_count
|
|
new_slide_icons_count = new_slide_model.icons_count
|
|
|
|
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
|
|
|
|
new_slide_images: dict[int, str | ImagePromptWithThemeAndAspectRatio] = {}
|
|
new_slide_icons: dict[int, str | IconQueryCollectionWithData] = {}
|
|
|
|
# ? Checks if image prompts have changed
|
|
# ? If they have, it will search if it is same as the old one but used at a different index
|
|
# ? If it is, it will use the old image
|
|
# ? If it is not, it will generate a new image
|
|
if new_slide_images_count:
|
|
new_image_prompts = slide_model_utils.get_image_prompts()
|
|
old_image_prompts = (
|
|
slide_to_edit.content.image_prompts
|
|
if slide_to_edit.images_count
|
|
else []
|
|
)
|
|
for index in range(new_slide_images_count):
|
|
new_prompt = new_slide_model.content.image_prompts[index]
|
|
for old_prompt in old_image_prompts:
|
|
if old_prompt != new_prompt:
|
|
continue
|
|
if index < len(slide_to_edit.images or []):
|
|
new_slide_images[index] = slide_to_edit.images[index]
|
|
break
|
|
if not new_slide_images.get(index):
|
|
new_slide_images[index] = new_image_prompts[index]
|
|
|
|
# ? Checks if icon queries have changed
|
|
# ? If they have, it will search if it is same as the old one but used at a different index
|
|
# ? If it is, it will use the old icon
|
|
# ? If it is not, it will generate a new icon
|
|
if new_slide_icons_count:
|
|
new_icon_queries = slide_model_utils.get_icon_queries()
|
|
old_icon_queries = (
|
|
slide_to_edit.content.icon_queries if slide_to_edit.icons_count else []
|
|
)
|
|
for index in range(new_slide_icons_count):
|
|
new_query = new_slide_model.content.icon_queries[index]
|
|
for old_query in old_icon_queries:
|
|
if old_query != new_query:
|
|
continue
|
|
if index < len(slide_to_edit.icons or []):
|
|
new_slide_icons[index] = slide_to_edit.icons[index]
|
|
break
|
|
if not new_slide_icons.get(index):
|
|
new_slide_icons[index] = new_icon_queries[index]
|
|
|
|
images_to_generate = []
|
|
for each in new_slide_images.values():
|
|
if isinstance(each, ImagePromptWithThemeAndAspectRatio):
|
|
images_to_generate.append(each)
|
|
|
|
icons_to_generate = []
|
|
for each in new_slide_icons.values():
|
|
if isinstance(each, IconQueryCollectionWithData):
|
|
icons_to_generate.append(each)
|
|
|
|
images_directory = get_presentation_images_dir(self.presentation_id)
|
|
if icons_to_generate:
|
|
icons_vectorstore = get_icons_vectorstore()
|
|
|
|
coroutines = [
|
|
generate_image(each_prompt, images_directory)
|
|
for each_prompt in images_to_generate
|
|
] + [
|
|
get_icon(icons_vectorstore, each_query) for each_query in icons_to_generate
|
|
]
|
|
generated_assets = await asyncio.gather(*coroutines)
|
|
generated_image_count = len(images_to_generate)
|
|
generate_images = generated_assets[:generated_image_count]
|
|
generate_icons = generated_assets[generated_image_count:]
|
|
|
|
for each in new_slide_images:
|
|
if isinstance(new_slide_images[each], ImagePromptWithThemeAndAspectRatio):
|
|
new_slide_images[each] = generate_images.pop(0)
|
|
|
|
for each in new_slide_icons:
|
|
if isinstance(new_slide_icons[each], IconQueryCollectionWithData):
|
|
new_slide_icons[each] = generate_icons.pop(0)
|
|
|
|
with get_sql_session() as sql_session:
|
|
sql_session.exec(
|
|
update(SlideSqlModel)
|
|
.where(SlideSqlModel.id == slide_to_edit.id)
|
|
.values(
|
|
type=new_slide_type,
|
|
images=list(new_slide_images.values()),
|
|
icons=list(new_slide_icons.values()),
|
|
content=new_slide_model.content.model_dump(mode="json"),
|
|
)
|
|
)
|
|
sql_session.commit()
|
|
slide_to_edit_sql = sql_session.exec(
|
|
select(SlideSqlModel).where(SlideSqlModel.id == slide_to_edit.id)
|
|
).first()
|
|
|
|
logging_service.logger.info(
|
|
logging_service.message(slide_to_edit_sql.model_dump(mode="json")),
|
|
extra=log_metadata.model_dump(),
|
|
)
|
|
return slide_to_edit_sql
|