feat(fastapi): adds slide edit feature
This commit is contained in:
parent
4ade6a9df4
commit
4d63c905f5
6 changed files with 307 additions and 8 deletions
136
servers/fastapi/api/v1/ppt/endpoints/slide.py
Normal file
136
servers/fastapi/api/v1/ppt/endpoints/slide.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import asyncio
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.database import get_sql_session
|
||||
from services.icon_finder_service import IconFinderService
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
from utils.dict_utils import get_dict_at_path, get_dict_paths_with_key, set_dict_at_path
|
||||
from utils.llm_calls.edit_slide import get_edited_slide_content
|
||||
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
|
||||
|
||||
|
||||
SLIDE_ROUTER = APIRouter(prefix="/slide", tags=["Slide"])
|
||||
|
||||
|
||||
@SLIDE_ROUTER.post("/edit")
|
||||
async def edit_slide(id: Annotated[str, Body()], prompt: Annotated[str, Body()]):
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
slide = sql_session.get(SlideModel, id)
|
||||
if not slide:
|
||||
raise HTTPException(status_code=404, detail="Slide not found")
|
||||
presentation = sql_session.get(PresentationModel, slide.presentation)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
presentation_layout = presentation.get_layout()
|
||||
|
||||
slide_layout = await get_slide_layout_from_prompt(
|
||||
prompt, presentation_layout, slide
|
||||
)
|
||||
edited_slide_content = await get_edited_slide_content(
|
||||
prompt, slide_layout, slide, presentation.language
|
||||
)
|
||||
|
||||
# Get old image and icon dicts
|
||||
old_image_dict_paths = get_dict_paths_with_key(slide.content, "__image_prompt__")
|
||||
old_image_dicts = [
|
||||
get_dict_at_path(slide.content, path) for path in old_image_dict_paths
|
||||
]
|
||||
old_image_prompts = [
|
||||
old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts
|
||||
]
|
||||
old_icon_dict_paths = get_dict_paths_with_key(slide.content, "__icon_query__")
|
||||
old_icon_dicts = [
|
||||
get_dict_at_path(slide.content, path) for path in old_icon_dict_paths
|
||||
]
|
||||
old_icon_queries = [
|
||||
old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts
|
||||
]
|
||||
|
||||
# Get new image and icon dicts
|
||||
new_image_dict_paths = get_dict_paths_with_key(
|
||||
edited_slide_content, "__image_prompt__"
|
||||
)
|
||||
new_image_dicts = [
|
||||
get_dict_at_path(edited_slide_content, path) for path in new_image_dict_paths
|
||||
]
|
||||
new_icon_dict_paths = get_dict_paths_with_key(
|
||||
edited_slide_content, "__icon_query__"
|
||||
)
|
||||
new_icon_dicts = [
|
||||
get_dict_at_path(edited_slide_content, path) for path in new_icon_dict_paths
|
||||
]
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
icon_finder_service = IconFinderService()
|
||||
|
||||
async_image_fetch_tasks = []
|
||||
new_images_fetch_status = []
|
||||
|
||||
async_icon_fetch_tasks = []
|
||||
new_icons_fetch_status = []
|
||||
|
||||
for new_image in new_image_dicts:
|
||||
# Use old image url if prompt is same
|
||||
if new_image["__image_prompt__"] in old_image_prompts:
|
||||
old_image_url = old_image_dicts[
|
||||
old_image_prompts.index(new_image["__image_prompt__"])
|
||||
]["__image_url__"]
|
||||
new_image["__image_url__"] = old_image_url
|
||||
new_images_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_image_fetch_tasks.append(
|
||||
image_generation_service.generate_image(
|
||||
ImagePrompt(
|
||||
prompt=new_image["__image_prompt__"],
|
||||
)
|
||||
)
|
||||
)
|
||||
new_images_fetch_status.append(True)
|
||||
|
||||
for new_icon in new_icon_dicts:
|
||||
if new_icon["__icon_query__"] in old_icon_queries:
|
||||
old_icon_url = old_icon_dicts[
|
||||
old_icon_queries.index(new_icon["__icon_query__"])
|
||||
]["__icon_url__"]
|
||||
new_icon["__icon_url__"] = old_icon_url
|
||||
new_icons_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_icon_fetch_tasks.append(
|
||||
icon_finder_service.search_icons(new_icon["__icon_query__"])
|
||||
)
|
||||
new_icons_fetch_status.append(True)
|
||||
|
||||
new_images = await asyncio.gather(*async_image_fetch_tasks)
|
||||
new_icons = await asyncio.gather(*async_icon_fetch_tasks)
|
||||
|
||||
for i, new_image in enumerate(new_images):
|
||||
if new_images_fetch_status[i]:
|
||||
new_image_dicts[i]["__image_url__"] = new_images[i]
|
||||
|
||||
for i, new_icon in enumerate(new_icons):
|
||||
if new_icons_fetch_status[i]:
|
||||
new_icon_dicts[i]["__icon_url__"] = new_icons[i]
|
||||
|
||||
for i, new_image_dict in enumerate(new_image_dicts):
|
||||
set_dict_at_path(edited_slide_content, new_image_dict_paths[i], new_image_dict)
|
||||
|
||||
for i, new_icon_dict in enumerate(new_icon_dicts):
|
||||
set_dict_at_path(edited_slide_content, new_icon_dict_paths[i], new_icon_dict)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(slide)
|
||||
slide.content = edited_slide_content
|
||||
sql_session.commit()
|
||||
sql_session.refresh(slide)
|
||||
|
||||
return slide
|
||||
|
|
@ -7,6 +7,7 @@ from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
|||
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
|
||||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
|
||||
|
||||
|
||||
API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
|
||||
|
|
@ -14,6 +15,7 @@ API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
|
|||
API_V1_PPT_ROUTER.include_router(FILES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
|
||||
|
|
|
|||
5
servers/fastapi/models/slide_layout_index.py
Normal file
5
servers/fastapi/models/slide_layout_index.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlideLayoutIndex(BaseModel):
|
||||
index: int
|
||||
92
servers/fastapi/utils/llm_calls/edit_slide.py
Normal file
92
servers/fastapi/utils/llm_calls/edit_slide.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from utils.llm_provider import (
|
||||
get_google_llm_client,
|
||||
get_llm_client,
|
||||
get_small_model,
|
||||
is_google_selected,
|
||||
)
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
system_prompt = """
|
||||
Edit Slide data based on provided prompt, follow mentioned steps and notes and provide structured output.
|
||||
|
||||
# Notes
|
||||
- Provide output in language mentioned in **Input**.
|
||||
- The goal is to change Slide data based on the provided prompt.
|
||||
- Do not change **Image prompts** and **Icon queries** if not asked for in prompt.
|
||||
- Generate **Image prompts** and **Icon queries** if asked to generate or change in prompt.
|
||||
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
"""
|
||||
|
||||
|
||||
def get_user_prompt(prompt: str, slide_data: dict, language: Optional[str] = None):
|
||||
return f"""
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Image Prompts and Icon Queries Language: English
|
||||
- Slide data: {slide_data}
|
||||
"""
|
||||
|
||||
|
||||
def get_prompt_to_edit_slide_content(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(prompt, slide_data, language),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def get_edited_slide_content(
|
||||
prompt: str,
|
||||
slide_layout: SlideLayoutModel,
|
||||
slide: SlideModel,
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
model = get_small_model()
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
if is_google_selected():
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(prompt, slide.content, language)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_schema,
|
||||
),
|
||||
)
|
||||
slide_content_json = json.loads(response.text)
|
||||
else:
|
||||
client = get_llm_client()
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_edit_slide_content(
|
||||
prompt,
|
||||
slide.content,
|
||||
language,
|
||||
),
|
||||
response_format=response_schema,
|
||||
)
|
||||
slide_content_json = json.loads(response.choices[0].message.content)
|
||||
|
||||
return slide_content_json
|
||||
|
|
@ -36,8 +36,10 @@ def get_user_prompt(title: str, outline: str):
|
|||
"""
|
||||
|
||||
|
||||
def get_prompt_to_generate_slide_content(title: str, outline: str, schema_constraints: str = ""):
|
||||
|
||||
def get_prompt_to_generate_slide_content(
|
||||
title: str, outline: str, schema_constraints: str = ""
|
||||
):
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -54,8 +56,11 @@ async def get_slide_content_from_type_and_outline(
|
|||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel
|
||||
):
|
||||
model = get_small_model()
|
||||
|
||||
schema_constraints = generate_constraint_sentences(slide_layout.json_schema)
|
||||
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
schema_constraints = generate_constraint_sentences(response_schema)
|
||||
|
||||
if not is_google_selected():
|
||||
client = get_llm_client()
|
||||
|
|
@ -70,9 +75,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "SlideContent",
|
||||
"schema": remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
),
|
||||
"schema": response_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
@ -86,7 +89,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt + f"\n{schema_constraints}",
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=slide_layout.json_schema,
|
||||
response_json_schema=response_schema,
|
||||
),
|
||||
)
|
||||
return json.loads(response.text)
|
||||
|
|
|
|||
61
servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
Normal file
61
servers/fastapi/utils/llm_calls/select_slide_type_on_edit.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
from utils.llm_provider import get_llm_client, get_small_model
|
||||
|
||||
|
||||
def get_prompt_to_select_slide_layout(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
layout: PresentationLayoutModel,
|
||||
current_slide_layout: int,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
Select a Slide Layout index based on provided user prompt and current slide data.
|
||||
{layout.to_string()}
|
||||
|
||||
# Notes
|
||||
- Do not select different slide layout than current unless absolutely necessary as per user prompt.
|
||||
- If user prompt is not clear, select the layout that is most relevant to the slide data.
|
||||
- If user prompt is not clear, select the layout that is most relevant to the slide data.
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
- Current Slide Layout: {current_slide_layout}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def get_slide_layout_from_prompt(
|
||||
prompt: str,
|
||||
layout: PresentationLayoutModel,
|
||||
slide: SlideModel,
|
||||
) -> SlideLayoutModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
|
||||
slide_layout_ids = list(map(lambda x: x.id, layout.slides))
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_to_select_slide_layout(
|
||||
prompt,
|
||||
slide.content,
|
||||
layout,
|
||||
slide_layout_ids.index(slide.layout),
|
||||
),
|
||||
response_format=SlideLayoutIndex,
|
||||
)
|
||||
index = response.choices[0].message.parsed.index
|
||||
return layout.slides[index]
|
||||
Loading…
Add table
Reference in a new issue