diff --git a/servers/fastapi/api/v1/ppt/endpoints/slide.py b/servers/fastapi/api/v1/ppt/endpoints/slide.py
new file mode 100644
index 00000000..a259f58a
--- /dev/null
+++ b/servers/fastapi/api/v1/ppt/endpoints/slide.py
@@ -0,0 +1,149 @@
+import asyncio
+from typing import Annotated
+from fastapi import APIRouter, Body, HTTPException
+
+from models.image_prompt import ImagePrompt
+from models.presentation_outline_model import SlideOutlineModel
+from models.sql.presentation import PresentationModel
+from models.sql.slide import SlideModel
+from services.database import get_sql_session
+from services.icon_finder_service import IconFinderService
+from services.image_generation_service import ImageGenerationService
+from utils.asset_directory_utils import get_images_directory
+from utils.dict_utils import get_dict_at_path, get_dict_paths_with_key, set_dict_at_path
+from utils.llm_calls.edit_slide import get_edited_slide_content
+from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
+
+
+SLIDE_ROUTER = APIRouter(prefix="/slide", tags=["Slide"])
+
+
+@SLIDE_ROUTER.post("/edit")
+async def edit_slide(id: Annotated[str, Body()], prompt: Annotated[str, Body()]):
+    """Edit one slide from a free-text prompt and persist the result.
+
+    The layout is selected and the content regenerated through the LLM
+    helpers. Images and icons whose prompt/query already existed on the slide
+    keep their stored URL; all others are fetched again.
+
+    Args:
+        id: Primary key of the slide to edit (request body field).
+        prompt: User instruction describing the edit (request body field).
+
+    Returns:
+        The refreshed ``SlideModel`` row.
+
+    Raises:
+        HTTPException: 404 if the slide or its presentation does not exist.
+    """
+    with get_sql_session() as sql_session:
+        slide = sql_session.get(SlideModel, id)
+        if not slide:
+            raise HTTPException(status_code=404, detail="Slide not found")
+        presentation = sql_session.get(PresentationModel, slide.presentation)
+        if not presentation:
+            raise HTTPException(status_code=404, detail="Presentation not found")
+
+    presentation_layout = presentation.get_layout()
+
+    slide_layout = await get_slide_layout_from_prompt(
+        prompt, presentation_layout, slide
+    )
+    edited_slide_content = await get_edited_slide_content(
+        prompt, slide_layout, slide, presentation.language
+    )
+
+    # Get old image and icon dicts
+    old_image_dict_paths = get_dict_paths_with_key(slide.content, "__image_prompt__")
+    old_image_dicts = [
+        get_dict_at_path(slide.content, path) for path in old_image_dict_paths
+    ]
+    old_image_prompts = [
+        old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts
+    ]
+    old_icon_dict_paths = get_dict_paths_with_key(slide.content, "__icon_query__")
+    old_icon_dicts = [
+        get_dict_at_path(slide.content, path) for path in old_icon_dict_paths
+    ]
+    old_icon_queries = [
+        old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts
+    ]
+
+    # Get new image and icon dicts
+    new_image_dict_paths = get_dict_paths_with_key(
+        edited_slide_content, "__image_prompt__"
+    )
+    new_image_dicts = [
+        get_dict_at_path(edited_slide_content, path) for path in new_image_dict_paths
+    ]
+    new_icon_dict_paths = get_dict_paths_with_key(
+        edited_slide_content, "__icon_query__"
+    )
+    new_icon_dicts = [
+        get_dict_at_path(edited_slide_content, path) for path in new_icon_dict_paths
+    ]
+
+    image_generation_service = ImageGenerationService(get_images_directory())
+    icon_finder_service = IconFinderService()
+
+    # The "*_to_fetch" lists stay parallel to their task lists. gather() only
+    # returns results for assets that are actually fetched, so results are
+    # paired with these lists rather than with the full new_*_dicts lists;
+    # otherwise one reused asset would shift later URLs onto the wrong dicts.
+    image_dicts_to_fetch = []
+    async_image_fetch_tasks = []
+
+    icon_dicts_to_fetch = []
+    async_icon_fetch_tasks = []
+
+    for new_image in new_image_dicts:
+        image_prompt = new_image["__image_prompt__"]
+        # Use old image url if prompt is same
+        if image_prompt in old_image_prompts:
+            reused_index = old_image_prompts.index(image_prompt)
+            new_image["__image_url__"] = old_image_dicts[reused_index]["__image_url__"]
+            continue
+
+        image_dicts_to_fetch.append(new_image)
+        async_image_fetch_tasks.append(
+            image_generation_service.generate_image(ImagePrompt(prompt=image_prompt))
+        )
+
+    for new_icon in new_icon_dicts:
+        icon_query = new_icon["__icon_query__"]
+        # Use old icon url if query is same
+        if icon_query in old_icon_queries:
+            reused_index = old_icon_queries.index(icon_query)
+            new_icon["__icon_url__"] = old_icon_dicts[reused_index]["__icon_url__"]
+            continue
+
+        icon_dicts_to_fetch.append(new_icon)
+        async_icon_fetch_tasks.append(icon_finder_service.search_icons(icon_query))
+
+    new_images = await asyncio.gather(*async_image_fetch_tasks)
+    new_icons = await asyncio.gather(*async_icon_fetch_tasks)
+
+    for image_dict, image_url in zip(image_dicts_to_fetch, new_images):
+        image_dict["__image_url__"] = image_url
+
+    # NOTE(review): the search_icons() result is stored as-is, as before;
+    # confirm it is a single URL and not a list of candidates.
+    for icon_dict, icon_url in zip(icon_dicts_to_fetch, new_icons):
+        icon_dict["__icon_url__"] = icon_url
+
+    for i, new_image_dict in enumerate(new_image_dicts):
+        set_dict_at_path(edited_slide_content, new_image_dict_paths[i], new_image_dict)
+
+    for i, new_icon_dict in enumerate(new_icon_dicts):
+        set_dict_at_path(edited_slide_content, new_icon_dict_paths[i], new_icon_dict)
+
+    with get_sql_session() as sql_session:
+        sql_session.add(slide)
+        slide.content = edited_slide_content
+        # Keep the stored layout id in step with the layout the new content
+        # was generated for.
+        slide.layout = slide_layout.id
+        sql_session.commit()
+        sql_session.refresh(slide)
+
+    return slide
diff --git a/servers/fastapi/api/v1/ppt/router.py b/servers/fastapi/api/v1/ppt/router.py
index 15fb1b52..865afc19 100644
--- a/servers/fastapi/api/v1/ppt/router.py
+++ b/servers/fastapi/api/v1/ppt/router.py
@@ -7,6 +7,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER
 from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
 from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
 from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
+from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
 
 API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
 
@@ -14,6 +15,7 @@ API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
 API_V1_PPT_ROUTER.include_router(FILES_ROUTER)
 API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
 API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
+API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
 API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
 API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
 API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
diff --git a/servers/fastapi/models/slide_layout_index.py
b/servers/fastapi/models/slide_layout_index.py
new file mode 100644
index 00000000..e5197422
--- /dev/null
+++ b/servers/fastapi/models/slide_layout_index.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+# Structured-output model for layout selection: ``index`` is used as a
+# position in ``PresentationLayoutModel.slides``.
+class SlideLayoutIndex(BaseModel):
+    index: int
diff --git a/servers/fastapi/utils/llm_calls/edit_slide.py b/servers/fastapi/utils/llm_calls/edit_slide.py
new file mode 100644
index 00000000..ddf54d25
--- /dev/null
+++ b/servers/fastapi/utils/llm_calls/edit_slide.py
@@ -0,0 +1,112 @@
+import asyncio
+import json
+from typing import Optional
+
+from models.presentation_layout import SlideLayoutModel
+from models.sql.slide import SlideModel
+from google.genai.types import GenerateContentConfig
+from utils.llm_provider import (
+    get_google_llm_client,
+    get_llm_client,
+    get_small_model,
+    is_google_selected,
+)
+from utils.schema_utils import remove_fields_from_schema
+
+system_prompt = """
+    Edit Slide data based on provided prompt, follow mentioned steps and notes and provide structured output.
+
+    # Notes
+    - Provide output in language mentioned in **Input**.
+    - The goal is to change Slide data based on the provided prompt.
+    - Do not change **Image prompts** and **Icon queries** if not asked for in prompt.
+    - Generate **Image prompts** and **Icon queries** if asked to generate or change in prompt.
+
+    **Go through all notes and steps and make sure they are followed, including mentioned constraints**
+"""
+
+
+def get_user_prompt(prompt: str, slide_data: dict, language: Optional[str] = None):
+    """Format the edit prompt, output language and slide data as the user message."""
+    return f"""
+        - Prompt: {prompt}
+        - Output Language: {language}
+        - Image Prompts and Icon Queries Language: English
+        - Slide data: {slide_data}
+    """
+
+
+def get_prompt_to_edit_slide_content(
+    prompt: str,
+    slide_data: dict,
+    language: Optional[str] = None,
+):
+    """Return the system + user chat messages used to edit a slide."""
+    return [
+        {
+            "role": "system",
+            "content": system_prompt,
+        },
+        {
+            "role": "user",
+            "content": get_user_prompt(prompt, slide_data, language),
+        },
+    ]
+
+
+async def get_edited_slide_content(
+    prompt: str,
+    slide_layout: SlideLayoutModel,
+    slide: SlideModel,
+    language: Optional[str] = None,
+):
+    """Ask the small LLM to rewrite ``slide.content`` according to *prompt*.
+
+    The reply is constrained to ``slide_layout.json_schema`` with the
+    ``__image_url__`` / ``__icon_url__`` fields removed, so the returned
+    content carries prompts and queries but no asset URLs.
+
+    Returns:
+        The edited slide content decoded from the model's JSON reply.
+    """
+    model = get_small_model()
+    response_schema = remove_fields_from_schema(
+        slide_layout.json_schema, ["__image_url__", "__icon_url__"]
+    )
+
+    if is_google_selected():
+        client = get_google_llm_client()
+        response = await asyncio.to_thread(
+            client.models.generate_content,
+            model=model,
+            contents=[get_user_prompt(prompt, slide.content, language)],
+            config=GenerateContentConfig(
+                system_instruction=system_prompt,
+                response_mime_type="application/json",
+                response_json_schema=response_schema,
+            ),
+        )
+        slide_content_json = json.loads(response.text)
+    else:
+        client = get_llm_client()
+        response = await client.beta.chat.completions.parse(
+            model=model,
+            messages=get_prompt_to_edit_slide_content(
+                prompt,
+                slide.content,
+                language,
+            ),
+            # A bare JSON schema is not a valid ``response_format``; it has
+            # to be wrapped in the ``json_schema`` envelope, as
+            # generate_slide_content.py does.
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "SlideContent",
+                    "schema": response_schema,
+                },
+            },
+        )
+        slide_content_json = json.loads(response.choices[0].message.content)
+
+    return slide_content_json
diff --git a/servers/fastapi/utils/llm_calls/generate_slide_content.py b/servers/fastapi/utils/llm_calls/generate_slide_content.py
index fd58888b..be699ccb 100644
--- a/servers/fastapi/utils/llm_calls/generate_slide_content.py
+++ b/servers/fastapi/utils/llm_calls/generate_slide_content.py
@@ -36,8 +36,10 @@ def
get_user_prompt(title: str, outline: str):
     """
 
 
-def get_prompt_to_generate_slide_content(title: str, outline: str, schema_constraints: str = ""):
-    
+def get_prompt_to_generate_slide_content(
+    title: str, outline: str, schema_constraints: str = ""
+):
+
     return [
         {
             "role": "system",
@@ -54,8 +56,11 @@ async def get_slide_content_from_type_and_outline(
     slide_layout: SlideLayoutModel, outline: SlideOutlineModel
 ):
     model = get_small_model()
-    
-    schema_constraints = generate_constraint_sentences(slide_layout.json_schema)
+
+    response_schema = remove_fields_from_schema(
+        slide_layout.json_schema, ["__image_url__", "__icon_url__"]
+    )
+    schema_constraints = generate_constraint_sentences(response_schema)
 
     if not is_google_selected():
         client = get_llm_client()
@@ -70,9 +75,7 @@ async def get_slide_content_from_type_and_outline(
                 "type": "json_schema",
                 "json_schema": {
                     "name": "SlideContent",
-                    "schema": remove_fields_from_schema(
-                        slide_layout.json_schema, ["__image_url__", "__icon_url__"]
-                    ),
+                    "schema": response_schema,
                 },
             },
         )
@@ -86,7 +89,7 @@ async def get_slide_content_from_type_and_outline(
         config=GenerateContentConfig(
             system_instruction=system_prompt + f"\n{schema_constraints}",
             response_mime_type="application/json",
-            response_json_schema=slide_layout.json_schema,
+            response_json_schema=response_schema,
         ),
     )
     return json.loads(response.text)
diff --git a/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
new file mode 100644
index 00000000..671e2778
--- /dev/null
+++ b/servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
@@ -0,0 +1,83 @@
+from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
+from models.slide_layout_index import SlideLayoutIndex
+from models.sql.slide import SlideModel
+from utils.llm_provider import get_llm_client, get_small_model
+
+
+def get_prompt_to_select_slide_layout(
+    prompt: str,
+    slide_data: dict,
+    layout: PresentationLayoutModel,
+    current_slide_layout: int,
+):
+    """Return the chat messages asking the model to pick a slide layout index.
+
+    ``current_slide_layout`` is the index of the slide's present layout; it is
+    included in the user message alongside the prompt and slide data.
+    """
+    return [
+        {
+            "role": "system",
+            "content": f"""
+                Select a Slide Layout index based on provided user prompt and current slide data.
+                {layout.to_string()}
+
+                # Notes
+                - Do not select different slide layout than current unless absolutely necessary as per user prompt.
+                - If user prompt is not clear, select the layout that is most relevant to the slide data.
+                - If user prompt is not clear, select the layout that is most relevant to the slide data.
+                **Go through all notes and steps and make sure they are followed, including mentioned constraints**
+            """,
+        },
+        {
+            "role": "user",
+            "content": f"""
+                - User Prompt: {prompt}
+                - Current Slide Data: {slide_data}
+                - Current Slide Layout: {current_slide_layout}
+            """,
+        },
+    ]
+
+
+async def get_slide_layout_from_prompt(
+    prompt: str,
+    layout: PresentationLayoutModel,
+    slide: SlideModel,
+) -> SlideLayoutModel:
+    """Choose the layout an edited slide should use.
+
+    The model returns an index into ``layout.slides``. A missing or
+    out-of-range index falls back to the slide's current layout.
+
+    Raises:
+        ValueError: If ``slide.layout`` is not the id of any layout in *layout*.
+    """
+    # NOTE(review): always uses the OpenAI-style client, while
+    # utils/llm_calls/edit_slide.py branches on is_google_selected(); confirm
+    # this call works when a Google model is configured.
+    client = get_llm_client()
+    model = get_small_model()
+
+    slide_layout_ids = [slide_layout.id for slide_layout in layout.slides]
+    current_index = slide_layout_ids.index(slide.layout)
+
+    response = await client.beta.chat.completions.parse(
+        model=model,
+        temperature=0.2,
+        messages=get_prompt_to_select_slide_layout(
+            prompt,
+            slide.content,
+            layout,
+            current_index,
+        ),
+        response_format=SlideLayoutIndex,
+    )
+
+    parsed = response.choices[0].message.parsed
+    index = current_index if parsed is None else parsed.index
+    # The index is model output: reject anything outside the list, including
+    # negative values that Python would silently resolve from the end.
+    if not 0 <= index < len(layout.slides):
+        index = current_index
+    return layout.slides[index]