feat(fastapi): uses llm client wrapper to wrap all clients

This commit is contained in:
sauravniraula 2025-07-31 13:27:21 +05:45
parent 0c02eafeca
commit c63944b7cf
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
15 changed files with 193 additions and 177 deletions

View file

@ -54,7 +54,6 @@ async def stream_outlines(
presentation.outlines = [
each.model_dump() for each in presentation_content.slides
]
presentation.notes = presentation_content.notes
sql_session.add(presentation)
await sql_session.commit()

View file

@ -15,7 +15,6 @@ class PresentationOutlineModel(BaseModel):
title: str = Field(
description="Title of the presentation in about 3 to 8 words",
)
notes: Optional[List[str]] = Field(default=None, description="Notes for the presentation")
slides: List[SlideOutlineModel] = Field(description="List of slides")
def to_string(self):
@ -25,8 +24,8 @@ class PresentationOutlineModel(BaseModel):
message += f" - Title: {slide.title} \n"
message += f" - Body: {slide.body} \n"
if self.notes:
message += f"# Notes: \n"
for note in self.notes:
message += f" - {note} \n"
# if self.notes:
# message += f"# Notes: \n"
# for note in self.notes:
# message += f" - {note} \n"
return message

View file

@ -46,7 +46,7 @@ class PresentationModel(SQLModel, table=True):
return PresentationOutlineModel(
title=self.title,
slides=[SlideOutlineModel(**each) for each in self.outlines],
notes=self.notes,
# notes=self.notes,
)
def get_layout(self):

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import List
from typing import List, Optional
from fastapi import HTTPException
from openai import AsyncOpenAI
from google import genai
@ -24,9 +24,10 @@ from utils.llm_provider import get_llm_provider
class LLMClient:
def __init__(self):
def __init__(self, max_tokens: int = 4000):
self.llm_provider = get_llm_provider()
self._client = self._get_client()
self.max_tokens = max_tokens
# ? Clients
def _get_client(self):
@ -98,12 +99,16 @@ class LLMClient:
def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]:
    """Collect the ``content`` of every message whose role is ``"user"``.

    The relative order of the input messages is preserved; messages with
    any other role are ignored.
    """
    user_contents: List[str] = []
    for entry in messages:
        if entry.role == "user":
            user_contents.append(entry.content)
    return user_contents
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
    """Return only the messages whose role is ``"user"``, in their original order.

    NOTE(review): every non-user role is dropped, not just ``"system"`` —
    confirm that callers never need to forward other roles (for example
    earlier assistant turns).
    """
    return list(filter(lambda entry: entry.role == "user", messages))
# ? Generate Unstructured Content
async def _generate_openai(self, model: str, messages: List[LLMMessage]):
    """Request a plain-text chat completion through the OpenAI client.

    Each message is serialized with ``model_dump()`` before being sent, and
    the completion length is capped at ``self.max_tokens``.

    Returns the ``content`` of the first choice in the response.
    """
    openai_client: AsyncOpenAI = self._client
    payload = [entry.model_dump() for entry in messages]
    completion = await openai_client.chat.completions.create(
        model=model,
        messages=payload,
        max_completion_tokens=self.max_tokens,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
@ -116,6 +121,7 @@ class LLMClient:
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
max_output_tokens=self.max_tokens,
),
)
return response.text
@ -124,7 +130,12 @@ class LLMClient:
client: AsyncAnthropic = self._client
response: AnthropicMessage = await client.messages.create(
model=model,
messages=[message.model_dump() for message in messages],
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
)
text = ""
for content in response.content:
@ -179,6 +190,7 @@ class LLMClient:
}
),
},
max_completion_tokens=self.max_tokens,
)
content = response.choices[0].message.content
if content:
@ -189,6 +201,7 @@ class LLMClient:
model=model,
messages=[message.model_dump() for message in messages],
response_format=response_format,
max_completion_tokens=self.max_tokens,
)
content = response.choices[0].message.parsed
if content:
@ -199,6 +212,7 @@ class LLMClient:
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
):
client: genai.Client = self._client
is_response_format_dict = isinstance(response_format, dict)
response = await asyncio.to_thread(
client.models.generate_content,
model=model,
@ -207,11 +221,17 @@ class LLMClient:
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json",
response_schema=response_format,
max_output_tokens=self.max_tokens,
),
)
content = None
if response.text:
return json.loads(response.text)
return None
content = json.loads(response.text)
# If response format is Pydantic model, return the model instance
if content and not is_response_format_dict:
return response_format(**content)
return content
async def _generate_anthropic_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
@ -220,7 +240,12 @@ class LLMClient:
is_response_format_dict = isinstance(response_format, dict)
response: AnthropicMessage = await client.messages.create(
model=model,
messages=[message.model_dump() for message in messages],
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
tools=[
{
"name": "ResponseSchema",
@ -237,6 +262,10 @@ class LLMClient:
for content_block in response.content:
if content_block.type == "tool_use":
content = content_block.input
# If response format is Pydantic model, return the model instance
if content and not is_response_format_dict:
return response_format(**content)
return content
async def _generate_ollama_structured(
@ -287,6 +316,7 @@ class LLMClient:
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
max_completion_tokens=self.max_tokens,
) as stream:
async for event in stream:
if event.type == "content.delta":
@ -300,6 +330,7 @@ class LLMClient:
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="text/plain",
max_output_tokens=self.max_tokens,
),
):
if event.text:
@ -309,7 +340,12 @@ class LLMClient:
client: AsyncAnthropic = self._client
async with client.messages.stream(
model=model,
messages=[message.model_dump() for message in messages],
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
) as stream:
async for event in stream:
event: AnthropicMessageStreamEvent = event
@ -344,17 +380,18 @@ class LLMClient:
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
{
max_completion_tokens=self.max_tokens,
response_format=(
{
"type": "json_schema",
"json_schema": {
"name": "ResponseSchema",
"schema": response_format,
}
if is_response_format_dict
else response_format
),
},
},
}
if is_response_format_dict
else response_format
),
) as stream:
async for event in stream:
if event.type == "content.delta":
@ -371,6 +408,7 @@ class LLMClient:
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json",
response_schema=response_format,
max_output_tokens=self.max_tokens,
),
):
if event.text:
@ -383,7 +421,12 @@ class LLMClient:
is_response_format_dict = isinstance(response_format, dict)
async with client.messages.stream(
model=model,
messages=[message.model_dump() for message in messages],
system=self._get_system_prompt(messages),
messages=[
message.model_dump()
for message in self._get_user_llm_messages(messages)
],
max_tokens=self.max_tokens,
tools=[
{
"name": "ResponseSchema",

View file

@ -0,0 +1,39 @@
import json
import tempfile
from pydantic import BaseModel
from services import TEMP_FILE_SERVICE
from utils.randomizers import get_random_uuid
class SchemaToModelService:
    """Work-in-progress helper for turning JSON schemas into Pydantic models.

    The service owns a scratch directory obtained from ``TEMP_FILE_SERVICE``
    and stores the intermediate schema / generated-module files there.
    """

    def __init__(self):
        # Scratch directory that holds the schema and generated-module files.
        self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
        # Not read or written by any method of this class yet.
        self._records = {}

    def convert(self, schema: dict, identifier: str) -> BaseModel:
        """Validate ``schema`` against the bare ``BaseModel`` and return the result.

        NOTE(review): ``identifier`` is unused, and validating against
        ``BaseModel`` itself does not derive any fields from the schema, so
        this looks like a placeholder — confirm the intended behaviour.
        """
        return BaseModel.model_validate(schema)

    def schema_to_pydantic_model(self, schema: dict, class_name: str):
        """Write ``schema`` to a temp JSON file and reserve a path for its model module.

        Args:
            schema: JSON-schema dictionary to serialize.
            class_name: Name intended for the generated model class; unused
                until the generation step below is enabled.

        Returns:
            The path reserved for the generated ``.py`` module. The module is
            NOT created yet because the generation call is still disabled.
        """
        schema_path = TEMP_FILE_SERVICE.create_temp_file_path(
            get_random_uuid() + ".json", self.temp_dir
        )
        # Pin the encoding so the file does not depend on the platform locale.
        with open(schema_path, "w", encoding="utf-8") as schema_file:
            json.dump(schema, schema_file)

        generated_model_path = TEMP_FILE_SERVICE.create_temp_file_path(
            get_random_uuid() + ".py", self.temp_dir
        )

        # TODO: generation is not wired up yet. The sketch below needs a
        # `generate(...)` code-generation API (plus `Path`, `InputFileType`
        # and `DataModelType`) that this module does not import yet.
        # generate(
        #     input_=Path(schema_path),
        #     input_file_type=InputFileType.JsonSchema,
        #     output=Path(generated_model_path),
        #     output_model_type=DataModelType.PydanticV2BaseModel,
        #     class_name=class_name,
        #     use_annotated=False,
        #     field_constraints=True,
        # )
        # Path(schema_path).unlink(missing_ok=True)
        return generated_model_path

View file

@ -9,8 +9,7 @@ class TempFileService:
def __init__(self):
self.base_dir = get_temp_directory_env() or "/tmp/presenton"
# TODO: Uncomment this when we want to cleanup the base dir on startup
# self.cleanup_base_dir()
self.cleanup_base_dir()
os.makedirs(self.base_dir, exist_ok=True)
def create_dir_in_dir(self, base_dir: str, dir_name: Optional[str] = None) -> str:

View file

@ -22,12 +22,6 @@ def get_presentation_outline_model_with_n_slides(n_slides: int):
min_length=10,
max_length=50,
)
notes: Optional[List[str]] = Field(
default=None,
description="Important notes for the presentation styling and formatting",
min_length=0,
max_length=10,
)
slides: List[SlideOutlineModelWithValidation] = Field(
description="List of slides", min_items=n_slides, max_items=n_slides
)

View file

@ -1,17 +1,8 @@
import asyncio
import json
from models.llm_message import LLMMessage
from models.presentation_layout import SlideLayoutModel
from models.sql.slide import SlideModel
from google.genai.types import GenerateContentConfig
from utils.llm_provider import (
get_anthropic_llm_client,
get_google_llm_client,
get_large_model,
get_llm_client,
is_anthropic_selected,
is_google_selected,
)
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
from utils.schema_utils import remove_fields_from_schema
system_prompt = """
@ -44,20 +35,20 @@ def get_user_prompt(prompt: str, slide_data: dict, language: str):
"""
def get_prompt_to_edit_slide_content(
def get_messages(
prompt: str,
slide_data: dict,
language: str,
):
return [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": get_user_prompt(prompt, slide_data, language),
},
LLMMessage(
role="system",
content=system_prompt,
),
LLMMessage(
role="user",
content=get_user_prompt(prompt, slide_data, language),
),
]
@ -72,37 +63,10 @@ async def get_edited_slide_content(
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
)
if is_google_selected():
client = get_google_llm_client()
response = await asyncio.to_thread(
client.models.generate_content,
model=model,
contents=[get_user_prompt(prompt, slide.content, language)],
config=GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type="application/json",
response_json_schema=response_schema,
),
)
slide_content_json = json.loads(response.text)
else:
client = get_llm_client()
response = await client.beta.chat.completions.parse(
model=model,
messages=get_prompt_to_edit_slide_content(
prompt,
slide.content,
language,
),
response_format={
"type": "json_schema",
"json_schema": {
"name": "slide_content",
"schema": response_schema,
},
},
)
slide_content_json = json.loads(response.choices[0].message.content)
return slide_content_json
client = LLMClient()
response = await client.generate_structured(
model=model,
messages=get_messages(prompt, slide.content, language),
response_format=response_schema,
)
return response

View file

@ -1,8 +1,9 @@
import asyncio
from typing import List
from openai.types.chat.chat_completion import ChatCompletion
from utils.llm_provider import get_llm_client, get_nano_model
from models.llm_message import LLMMessage
from services.llm_client import LLMClient
from utils.llm_provider import get_nano_model
sysmte_prompt = """
@ -23,23 +24,21 @@ Maintain as much information as possible.
async def generate_document_summary(documents: List[str]):
client = get_llm_client()
client = LLMClient()
model = get_nano_model()
coroutines = []
for document in documents:
truncated_text = document[:200000]
coroutine = client.chat.completions.create(
coroutine = client.generate(
model=model,
messages=[
{"role": "system", "content": sysmte_prompt},
{"role": "user", "content": truncated_text},
LLMMessage(role="system", content=sysmte_prompt),
LLMMessage(role="user", content=truncated_text),
],
)
coroutines.append(coroutine)
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
combined = "\n\n\n\n".join(
[completion.choices[0].message.content for completion in completions]
)
completions: List[str] = await asyncio.gather(*coroutines)
combined = "\n\n\n\n".join(completions)
return combined

View file

@ -1,3 +1,4 @@
import asyncio
from typing import Optional
from models.llm_message import LLMMessage

View file

@ -1,22 +1,19 @@
from models.llm_message import LLMMessage
from models.presentation_layout import PresentationLayoutModel
from models.presentation_outline_model import PresentationOutlineModel
from utils.llm_provider import (
get_large_model,
get_llm_client,
)
from utils.get_dynamic_models import (
get_presentation_structure_model_with_n_slides,
)
from models.presentation_structure_model import (
PresentationStructureModel,
)
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
from utils.get_dynamic_models import get_presentation_structure_model_with_n_slides
from models.presentation_structure_model import PresentationStructureModel
def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data: str):
def get_messages(
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
):
return [
{
"role": "system",
"content": f"""
LLMMessage(
role="system",
content=f"""
You're a professional presentation designer with creative freedom to design engaging presentations.
{presentation_layout.to_string()}
@ -49,13 +46,13 @@ def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
""",
},
{
"role": "user",
"content": f"""
),
LLMMessage(
role="user",
content=f"""
{data}
""",
},
),
]
@ -64,19 +61,19 @@ async def generate_presentation_structure(
presentation_layout: PresentationLayoutModel,
) -> PresentationStructureModel:
client = get_llm_client()
client = LLMClient()
model = get_large_model()
response_model = get_presentation_structure_model_with_n_slides(
len(presentation_outline.slides)
)
response = await client.beta.chat.completions.parse(
response = await client.generate_structured(
model=model,
messages=get_prompt(
messages=get_messages(
presentation_layout,
len(presentation_outline.slides),
presentation_outline.to_string(),
),
response_format=response_model,
)
return response.choices[0].message.parsed
return response

View file

@ -1,8 +1,10 @@
import asyncio
import json
from google.genai.types import GenerateContentConfig
from models.llm_message import LLMMessage
from models.presentation_layout import SlideLayoutModel
from models.presentation_outline_model import SlideOutlineModel
from services.llm_client import LLMClient
from utils.llm_provider import (
get_anthropic_llm_client,
get_google_llm_client,
@ -47,58 +49,37 @@ def get_user_prompt(title: str, outline: str, language: str):
"""
def get_prompt_to_generate_slide_content(title: str, outline: str, language: str):
def get_messages(title: str, outline: str, language: str):
return [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": get_user_prompt(title, outline, language),
},
LLMMessage(
role="system",
content=system_prompt,
),
LLMMessage(
role="user",
content=get_user_prompt(title, outline, language),
),
]
async def get_slide_content_from_type_and_outline(
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
):
client = LLMClient()
model = get_large_model()
response_schema = remove_fields_from_schema(
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
)
if is_google_selected():
client = get_google_llm_client()
response = await asyncio.to_thread(
client.models.generate_content,
model=model,
contents=[get_user_prompt(outline.title, outline.body, language)],
config=GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type="application/json",
response_json_schema=response_schema,
),
)
return json.loads(response.text)
else:
client = get_llm_client()
response = await client.beta.chat.completions.parse(
model=model,
messages=get_prompt_to_generate_slide_content(
outline.title,
outline.body,
language,
),
response_format={
"type": "json_schema",
"json_schema": {
"name": "SlideContent",
"schema": response_schema,
},
},
)
return json.loads(response.choices[0].message.content)
response = await client.generate_structured(
model=model,
messages=get_messages(
outline.title,
outline.body,
language,
),
response_format=response_schema,
)
return response

View file

@ -1,19 +1,21 @@
from models.llm_message import LLMMessage
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from models.slide_layout_index import SlideLayoutIndex
from models.sql.slide import SlideModel
from utils.llm_provider import get_large_model, get_llm_client
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
def get_prompt_to_select_slide_layout(
def get_messages(
prompt: str,
slide_data: dict,
layout: PresentationLayoutModel,
current_slide_layout: int,
):
return [
{
"role": "system",
"content": f"""
LLMMessage(
role="system",
content=f"""
Select a Slide Layout index based on provided user prompt and current slide data.
{layout.to_string()}
@ -23,15 +25,15 @@ def get_prompt_to_select_slide_layout(
- If user prompt is not clear, select the layout that is most relevant to the slide data.
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
""",
},
{
"role": "user",
"content": f"""
),
LLMMessage(
role="user",
content=f"""
- User Prompt: {prompt}
- Current Slide Data: {slide_data}
- Current Slide Layout: {current_slide_layout}
""",
},
),
]
@ -41,15 +43,14 @@ async def get_slide_layout_from_prompt(
slide: SlideModel,
) -> SlideLayoutModel:
client = get_llm_client()
client = LLMClient()
model = get_large_model()
slide_layout_ids = list(map(lambda x: x.id, layout.slides))
response = await client.beta.chat.completions.parse(
response: SlideLayoutIndex = await client.generate_structured(
model=model,
temperature=0.2,
messages=get_prompt_to_select_slide_layout(
messages=get_messages(
prompt,
slide.content,
layout,
@ -57,5 +58,5 @@ async def get_slide_layout_from_prompt(
),
response_format=SlideLayoutIndex,
)
index = response.choices[0].message.parsed.index
index = response.index
return layout.slides[index]

View file

@ -107,7 +107,7 @@ def get_anthropic_llm_client():
def get_large_model():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return "gpt-4.1"
return "gpt-4.1-nano"
elif selected_llm == LLMProvider.GOOGLE:
return "gemini-2.0-flash"
elif selected_llm == LLMProvider.ANTHROPIC: