feat(fastapi): uses llm client wrapper to wrap all clients
This commit is contained in:
parent
0c02eafeca
commit
c63944b7cf
15 changed files with 193 additions and 177 deletions
|
|
@ -54,7 +54,6 @@ async def stream_outlines(
|
|||
presentation.outlines = [
|
||||
each.model_dump() for each in presentation_content.slides
|
||||
]
|
||||
presentation.notes = presentation_content.notes
|
||||
|
||||
sql_session.add(presentation)
|
||||
await sql_session.commit()
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -15,7 +15,6 @@ class PresentationOutlineModel(BaseModel):
|
|||
title: str = Field(
|
||||
description="Title of the presentation in about 3 to 8 words",
|
||||
)
|
||||
notes: Optional[List[str]] = Field(default=None, description="Notes for the presentation")
|
||||
slides: List[SlideOutlineModel] = Field(description="List of slides")
|
||||
|
||||
def to_string(self):
|
||||
|
|
@ -25,8 +24,8 @@ class PresentationOutlineModel(BaseModel):
|
|||
message += f" - Title: {slide.title} \n"
|
||||
message += f" - Body: {slide.body} \n"
|
||||
|
||||
if self.notes:
|
||||
message += f"# Notes: \n"
|
||||
for note in self.notes:
|
||||
message += f" - {note} \n"
|
||||
# if self.notes:
|
||||
# message += f"# Notes: \n"
|
||||
# for note in self.notes:
|
||||
# message += f" - {note} \n"
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class PresentationModel(SQLModel, table=True):
|
|||
return PresentationOutlineModel(
|
||||
title=self.title,
|
||||
slides=[SlideOutlineModel(**each) for each in self.outlines],
|
||||
notes=self.notes,
|
||||
# notes=self.notes,
|
||||
)
|
||||
|
||||
def get_layout(self):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
|
@ -24,9 +24,10 @@ from utils.llm_provider import get_llm_provider
|
|||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self):
|
||||
def __init__(self, max_tokens: int = 4000):
|
||||
self.llm_provider = get_llm_provider()
|
||||
self._client = self._get_client()
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# ? Clients
|
||||
def _get_client(self):
|
||||
|
|
@ -98,12 +99,16 @@ class LLMClient:
|
|||
def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]:
|
||||
return [message.content for message in messages if message.role == "user"]
|
||||
|
||||
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
return [message for message in messages if message.role == "user"]
|
||||
|
||||
# ? Generate Unstructured Content
|
||||
async def _generate_openai(self, model: str, messages: List[LLMMessage]):
|
||||
client: AsyncOpenAI = self._client
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
|
@ -116,6 +121,7 @@ class LLMClient:
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="text/plain",
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
|
@ -124,7 +130,12 @@ class LLMClient:
|
|||
client: AsyncAnthropic = self._client
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
text = ""
|
||||
for content in response.content:
|
||||
|
|
@ -179,6 +190,7 @@ class LLMClient:
|
|||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
|
|
@ -189,6 +201,7 @@ class LLMClient:
|
|||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format=response_format,
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.parsed
|
||||
if content:
|
||||
|
|
@ -199,6 +212,7 @@ class LLMClient:
|
|||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
is_response_format_dict = isinstance(response_format, dict)
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
|
|
@ -207,11 +221,17 @@ class LLMClient:
|
|||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
)
|
||||
content = None
|
||||
if response.text:
|
||||
return json.loads(response.text)
|
||||
return None
|
||||
content = json.loads(response.text)
|
||||
|
||||
# If response format is Pydantic model, return the model instance
|
||||
if content and not is_response_format_dict:
|
||||
return response_format(**content)
|
||||
return content
|
||||
|
||||
async def _generate_anthropic_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
|
|
@ -220,7 +240,12 @@ class LLMClient:
|
|||
is_response_format_dict = isinstance(response_format, dict)
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
|
|
@ -237,6 +262,10 @@ class LLMClient:
|
|||
for content_block in response.content:
|
||||
if content_block.type == "tool_use":
|
||||
content = content_block.input
|
||||
|
||||
# If response format is Pydantic model, return the model instance
|
||||
if content and not is_response_format_dict:
|
||||
return response_format(**content)
|
||||
return content
|
||||
|
||||
async def _generate_ollama_structured(
|
||||
|
|
@ -287,6 +316,7 @@ class LLMClient:
|
|||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
|
|
@ -300,6 +330,7 @@ class LLMClient:
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="text/plain",
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
|
|
@ -309,7 +340,12 @@ class LLMClient:
|
|||
client: AsyncAnthropic = self._client
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
|
|
@ -344,17 +380,18 @@ class LLMClient:
|
|||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
max_completion_tokens=self.max_tokens,
|
||||
response_format=(
|
||||
{
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ResponseSchema",
|
||||
"schema": response_format,
|
||||
}
|
||||
if is_response_format_dict
|
||||
else response_format
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
if is_response_format_dict
|
||||
else response_format
|
||||
),
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
|
|
@ -371,6 +408,7 @@ class LLMClient:
|
|||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
|
|
@ -383,7 +421,12 @@ class LLMClient:
|
|||
is_response_format_dict = isinstance(response_format, dict)
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
|
|
|
|||
39
servers/fastapi/services/schema_to_model_service.py
Normal file
39
servers/fastapi/services/schema_to_model_service.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import json
|
||||
import tempfile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class SchemaToModelService:
|
||||
def __init__(self):
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
self._records = {}
|
||||
|
||||
def convert(self, schema: dict, identifier: str) -> BaseModel:
|
||||
return BaseModel.model_validate(schema)
|
||||
|
||||
def schema_to_pydantic_model(self, schema: dict, class_name: str):
|
||||
schema_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
get_random_uuid() + ".json", self.temp_dir
|
||||
)
|
||||
with open(schema_path, "w") as f:
|
||||
json.dump(schema, f)
|
||||
|
||||
generated_model_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
get_random_uuid() + ".py", self.temp_dir
|
||||
)
|
||||
# generate(
|
||||
# input_=Path(schema_path),
|
||||
# input_file_type=InputFileType.JsonSchema,
|
||||
# output=Path(output_file),
|
||||
# output_model_type=DataModelType.PydanticV2BaseModel,
|
||||
# class_name=class_name,
|
||||
# use_annotated=False,
|
||||
# field_constraints=True,
|
||||
# )
|
||||
|
||||
# Path(schema_file).unlink(missing_ok=True)
|
||||
|
||||
return generated_model_path
|
||||
|
|
@ -9,8 +9,7 @@ class TempFileService:
|
|||
|
||||
def __init__(self):
|
||||
self.base_dir = get_temp_directory_env() or "/tmp/presenton"
|
||||
# TODO: Uncomment this when we want to cleanup the base dir on startup
|
||||
# self.cleanup_base_dir()
|
||||
self.cleanup_base_dir()
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
|
||||
def create_dir_in_dir(self, base_dir: str, dir_name: Optional[str] = None) -> str:
|
||||
|
|
|
|||
|
|
@ -22,12 +22,6 @@ def get_presentation_outline_model_with_n_slides(n_slides: int):
|
|||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
notes: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Important notes for the presentation styling and formatting",
|
||||
min_length=0,
|
||||
max_length=10,
|
||||
)
|
||||
slides: List[SlideOutlineModelWithValidation] = Field(
|
||||
description="List of slides", min_items=n_slides, max_items=n_slides
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,17 +1,8 @@
|
|||
import asyncio
|
||||
import json
|
||||
|
||||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from utils.llm_provider import (
|
||||
get_anthropic_llm_client,
|
||||
get_google_llm_client,
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
is_anthropic_selected,
|
||||
is_google_selected,
|
||||
)
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
system_prompt = """
|
||||
|
|
@ -44,20 +35,20 @@ def get_user_prompt(prompt: str, slide_data: dict, language: str):
|
|||
"""
|
||||
|
||||
|
||||
def get_prompt_to_edit_slide_content(
|
||||
def get_messages(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
language: str,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(prompt, slide_data, language),
|
||||
},
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(prompt, slide_data, language),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -72,37 +63,10 @@ async def get_edited_slide_content(
|
|||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
if is_google_selected():
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(prompt, slide.content, language)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_schema,
|
||||
),
|
||||
)
|
||||
slide_content_json = json.loads(response.text)
|
||||
|
||||
else:
|
||||
client = get_llm_client()
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_edit_slide_content(
|
||||
prompt,
|
||||
slide.content,
|
||||
language,
|
||||
),
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "slide_content",
|
||||
"schema": response_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
slide_content_json = json.loads(response.choices[0].message.content)
|
||||
|
||||
return slide_content_json
|
||||
client = LLMClient()
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(prompt, slide.content, language),
|
||||
response_format=response_schema,
|
||||
)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from utils.llm_provider import get_llm_client, get_nano_model
|
||||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_nano_model
|
||||
|
||||
|
||||
sysmte_prompt = """
|
||||
|
|
@ -23,23 +24,21 @@ Maintain as much information as possible.
|
|||
|
||||
|
||||
async def generate_document_summary(documents: List[str]):
|
||||
client = get_llm_client()
|
||||
client = LLMClient()
|
||||
model = get_nano_model()
|
||||
|
||||
coroutines = []
|
||||
for document in documents:
|
||||
truncated_text = document[:200000]
|
||||
coroutine = client.chat.completions.create(
|
||||
coroutine = client.generate(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": sysmte_prompt},
|
||||
{"role": "user", "content": truncated_text},
|
||||
LLMMessage(role="system", content=sysmte_prompt),
|
||||
LLMMessage(role="user", content=truncated_text),
|
||||
],
|
||||
)
|
||||
coroutines.append(coroutine)
|
||||
|
||||
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join(
|
||||
[completion.choices[0].message.content for completion in completions]
|
||||
)
|
||||
completions: List[str] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join(completions)
|
||||
return combined
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from models.llm_message import LLMMessage
|
||||
|
|
|
|||
|
|
@ -1,22 +1,19 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from utils.llm_provider import (
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
)
|
||||
from utils.get_dynamic_models import (
|
||||
get_presentation_structure_model_with_n_slides,
|
||||
)
|
||||
from models.presentation_structure_model import (
|
||||
PresentationStructureModel,
|
||||
)
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.get_dynamic_models import get_presentation_structure_model_with_n_slides
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data: str):
|
||||
def get_messages(
|
||||
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=f"""
|
||||
You're a professional presentation designer with creative freedom to design engaging presentations.
|
||||
|
||||
{presentation_layout.to_string()}
|
||||
|
|
@ -49,13 +46,13 @@ def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data
|
|||
|
||||
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=f"""
|
||||
{data}
|
||||
""",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -64,19 +61,19 @@ async def generate_presentation_structure(
|
|||
presentation_layout: PresentationLayoutModel,
|
||||
) -> PresentationStructureModel:
|
||||
|
||||
client = get_llm_client()
|
||||
client = LLMClient()
|
||||
model = get_large_model()
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_prompt(
|
||||
messages=get_messages(
|
||||
presentation_layout,
|
||||
len(presentation_outline.slides),
|
||||
presentation_outline.to_string(),
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import asyncio
|
||||
import json
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import (
|
||||
get_anthropic_llm_client,
|
||||
get_google_llm_client,
|
||||
|
|
@ -47,58 +49,37 @@ def get_user_prompt(title: str, outline: str, language: str):
|
|||
"""
|
||||
|
||||
|
||||
def get_prompt_to_generate_slide_content(title: str, outline: str, language: str):
|
||||
def get_messages(title: str, outline: str, language: str):
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(title, outline, language),
|
||||
},
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(title, outline, language),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
|
||||
):
|
||||
client = LLMClient()
|
||||
model = get_large_model()
|
||||
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
if is_google_selected():
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(outline.title, outline.body, language)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_schema,
|
||||
),
|
||||
)
|
||||
return json.loads(response.text)
|
||||
|
||||
else:
|
||||
client = get_llm_client()
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_generate_slide_content(
|
||||
outline.title,
|
||||
outline.body,
|
||||
language,
|
||||
),
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "SlideContent",
|
||||
"schema": response_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(
|
||||
outline.title,
|
||||
outline.body,
|
||||
language,
|
||||
),
|
||||
response_format=response_schema,
|
||||
)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
from utils.llm_provider import get_large_model, get_llm_client
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
|
||||
|
||||
def get_prompt_to_select_slide_layout(
|
||||
def get_messages(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
layout: PresentationLayoutModel,
|
||||
current_slide_layout: int,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=f"""
|
||||
Select a Slide Layout index based on provided user prompt and current slide data.
|
||||
{layout.to_string()}
|
||||
|
||||
|
|
@ -23,15 +25,15 @@ def get_prompt_to_select_slide_layout(
|
|||
- If user prompt is not clear, select the layout that is most relevant to the slide data.
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=f"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
- Current Slide Layout: {current_slide_layout}
|
||||
""",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -41,15 +43,14 @@ async def get_slide_layout_from_prompt(
|
|||
slide: SlideModel,
|
||||
) -> SlideLayoutModel:
|
||||
|
||||
client = get_llm_client()
|
||||
client = LLMClient()
|
||||
model = get_large_model()
|
||||
|
||||
slide_layout_ids = list(map(lambda x: x.id, layout.slides))
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
response: SlideLayoutIndex = await client.generate_structured(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_to_select_slide_layout(
|
||||
messages=get_messages(
|
||||
prompt,
|
||||
slide.content,
|
||||
layout,
|
||||
|
|
@ -57,5 +58,5 @@ async def get_slide_layout_from_prompt(
|
|||
),
|
||||
response_format=SlideLayoutIndex,
|
||||
)
|
||||
index = response.choices[0].message.parsed.index
|
||||
index = response.index
|
||||
return layout.slides[index]
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ def get_anthropic_llm_client():
|
|||
def get_large_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue