feat(fastapi): adds strict support for every schemas, proper models check, refactor

This commit is contained in:
sauravniraula 2025-08-01 00:15:03 +05:45
parent 9ad017f164
commit e542fdf869
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
38 changed files with 542 additions and 535 deletions

View file

@ -1,10 +1,8 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from api.lifespan import app_lifespan
from api.middlewares import UserConfigEnvUpdateMiddleware
from api.v1.ppt.router import API_V1_PPT_ROUTER
from utils.asset_directory_utils import get_exports_directory, get_images_directory, get_uploads_directory
app = FastAPI(lifespan=app_lifespan)
@ -13,25 +11,6 @@ app = FastAPI(lifespan=app_lifespan)
# Routers
app.include_router(API_V1_PPT_ROUTER)
# Static files
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount(
"/app_data/images",
StaticFiles(directory=get_images_directory()),
name="app_data/images",
)
app.mount(
"/app_data/exports",
StaticFiles(directory=get_exports_directory()),
name="app_data/exports",
)
app.mount(
"/app_data/uploads",
StaticFiles(directory=get_uploads_directory()),
name="app_data/uploads",
)
# Middlewares
origins = ["*"]
app.add_middleware(

View file

@ -1,20 +1,16 @@
from typing import Annotated, List, Optional
import anthropic
from typing import Annotated, List
from fastapi import APIRouter, Body, HTTPException
from utils.get_env import get_anthropic_api_key_env
from utils.available_models import list_available_anthropic_models
ANTHROPIC_ROUTER = APIRouter(prefix="/anthropic", tags=["Anthropic"])
@ANTHROPIC_ROUTER.post("/models/available", response_model=List[str])
async def get_available_models(
api_key: Annotated[Optional[str], Body(embed=True)] = None,
api_key: Annotated[str, Body(embed=True)],
):
anthropic_api_key = api_key or get_anthropic_api_key_env()
if not anthropic_api_key:
raise HTTPException(status_code=400, detail="Anthropic API key is required")
client = anthropic.Anthropic(api_key=anthropic_api_key)
models = client.models.list(limit=20)
return [model.id for model in models]
try:
return await list_available_anthropic_models(api_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View file

@ -1,18 +0,0 @@
from typing import Annotated, List, Optional
from fastapi import APIRouter, Body, HTTPException
from utils.custom_llm_provider import list_available_custom_models
CUSTOM_LLM_ROUTER = APIRouter(prefix="/custom_llm", tags=["Custom LLM"])
@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str])
async def get_available_models(
url: Annotated[Optional[str], Body()] = None,
api_key: Annotated[Optional[str], Body()] = None,
):
try:
return await list_available_custom_models(url, api_key)
except Exception as e:
print(e)
raise HTTPException(status_code=500, detail=str(e))

View file

@ -0,0 +1,14 @@
from typing import Annotated, List
from fastapi import APIRouter, Body, HTTPException
from utils.available_models import list_available_google_models
GOOGLE_ROUTER = APIRouter(prefix="/google", tags=["Google"])
@GOOGLE_ROUTER.post("/models/available", response_model=List[str])
async def get_available_models(api_key: Annotated[str, Body(embed=True)]):
try:
return await list_available_google_models(api_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View file

@ -0,0 +1,17 @@
from typing import Annotated, List
from fastapi import APIRouter, Body, HTTPException
from utils.available_models import list_available_openai_compatible_models
OPENAI_ROUTER = APIRouter(prefix="/openai", tags=["OpenAI"])
@OPENAI_ROUTER.post("/models/available", response_model=List[str])
async def get_available_models(
url: Annotated[str, Body()],
api_key: Annotated[str, Body()],
):
try:
return await list_available_openai_compatible_models(url, api_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View file

@ -28,7 +28,7 @@ from utils.export_utils import export_presentation
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
from models.sql.slide import SlideModel
from models.sse_response import SSECompleteResponse, SSEResponse
from services import SCHEMA_TO_MODEL_SERVICE, TEMP_FILE_SERVICE
from services import TEMP_FILE_SERVICE
from services.database import get_async_session
from services.documents_loader import DocumentsLoader
from models.sql.presentation import PresentationModel
@ -43,7 +43,6 @@ from utils.llm_calls.generate_slide_content import (
)
from utils.process_slides import process_slide_and_fetch_assets
from utils.randomizers import get_random_uuid
from utils.schema_utils import remove_fields_from_schema
from utils.validators import validate_files
PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])
@ -220,20 +219,8 @@ async def stream_presentation(
for i, slide_layout_index in enumerate(structure.slides):
slide_layout = layout.slides[slide_layout_index]
# Generate Pydantic model from slide layout schema
schema_model_id = f"{layout.name}/{slide_layout.id}"
response_schema = remove_fields_from_schema(
slide_layout.json_schema, ["image_url_", "icon_url_"]
)
schema_model_path = (
await SCHEMA_TO_MODEL_SERVICE.get_pydantic_model_path_from_schema(
schema_model_id, response_schema
)
)
module = importlib.import_module(schema_model_path)
response_model = module.GeneratedModel
slide_content = await get_slide_content_from_type_and_outline(
response_model, outline.slides[i], presentation.language
slide_layout, outline.slides[i], presentation.language
)
slide = SlideModel(
@ -252,9 +239,6 @@ async def stream_presentation(
)
)
# Give control to the event loop
await asyncio.sleep(0)
yield SSEResponse(
event="response",
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
@ -491,7 +475,6 @@ async def from_template(
new_slide_data = list(filter(lambda x: x.index == each_slide.index, data.data))
if new_slide_data:
updated_content = deep_update(each_slide.content, new_slide_data[0].content)
print(f"Updated content for slide {each_slide.index}: {updated_content}")
new_slides.append(
each_slide.get_new_slide(new_presentation.id, updated_content)
)

View file

@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from models.sql.presentation import PresentationModel
from models.sql.slide import SlideModel
from services import SCHEMA_TO_MODEL_SERVICE
from services.database import get_async_session
from services.icon_finder_service import IconFinderService
from services.image_generation_service import ImageGenerationService
@ -35,25 +34,12 @@ async def edit_slide(
raise HTTPException(status_code=404, detail="Presentation not found")
presentation_layout = presentation.get_layout()
slide_layout = await get_slide_layout_from_prompt(
prompt, presentation_layout, slide
)
# Generate Pydantic model from slide layout schema
schema_model_id = f"{presentation_layout.name}/{slide_layout.id}"
response_schema = remove_fields_from_schema(
slide_layout.json_schema, ["image_url_", "icon_url_"]
)
schema_model_path = (
await SCHEMA_TO_MODEL_SERVICE.get_pydantic_model_path_from_schema(
schema_model_id, response_schema
)
)
module = importlib.import_module(schema_model_path)
response_model = module.GeneratedModel
edited_slide_content = await get_edited_slide_content(
prompt, slide, presentation.language, response_model
prompt, slide, presentation.language, slide_layout
)
image_generation_service = ImageGenerationService(get_images_directory())

View file

@ -1,7 +1,8 @@
from fastapi import APIRouter
from api.v1.ppt.endpoints.anthropic import ANTHROPIC_ROUTER
from api.v1.ppt.endpoints.custom_llm import CUSTOM_LLM_ROUTER
from api.v1.ppt.endpoints.google import GOOGLE_ROUTER
from api.v1.ppt.endpoints.openai import OPENAI_ROUTER
from api.v1.ppt.endpoints.files import FILES_ROUTER
from api.v1.ppt.endpoints.icons import ICONS_ROUTER
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
@ -20,5 +21,6 @@ API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
API_V1_PPT_ROUTER.include_router(CUSTOM_LLM_ROUTER)
API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER)
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER)
API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER)

View file

@ -0,0 +1,6 @@
OPENAI_URL = "https://api.openai.com/v1"
# Default models
DEFAULT_OPENAI_MODEL = "gpt-4.1"
DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash"
DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"

View file

@ -7,7 +7,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3:8b",
description="❌ Graphs not supported.",
size="4.7GB",
supports_graph=False,
icon="/static/icons/meta.png",
),
"llama3:70b": OllamaModelMetadata(
@ -15,7 +14,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3:70b",
description="✅ Graphs supported.",
size="40GB",
supports_graph=True,
icon="/static/icons/meta.png",
),
"llama3.1:8b": OllamaModelMetadata(
@ -23,7 +21,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3.1:8b",
description="❌ Graphs not supported.",
size="4.9GB",
supports_graph=False,
icon="/static/icons/meta.png",
),
"llama3.1:70b": OllamaModelMetadata(
@ -31,7 +28,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3.1:70b",
description="✅ Graphs supported.",
size="43GB",
supports_graph=True,
icon="/static/icons/meta.png",
),
"llama3.1:405b": OllamaModelMetadata(
@ -39,7 +35,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3.1:405b",
description="✅ Graphs supported.",
size="243GB",
supports_graph=True,
icon="/static/icons/meta.png",
),
"llama3.2:1b": OllamaModelMetadata(
@ -47,7 +42,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3.2:1b",
description="❌ Graphs not supported.",
size="1.3GB",
supports_graph=False,
icon="/static/icons/meta.png",
),
"llama3.2:3b": OllamaModelMetadata(
@ -55,7 +49,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3.2:3b",
description="❌ Graphs not supported.",
size="2GB",
supports_graph=False,
icon="/static/icons/meta.png",
),
"llama3.3:70b": OllamaModelMetadata(
@ -63,7 +56,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama3.3:70b",
description="✅ Graphs supported.",
size="43GB",
supports_graph=True,
icon="/static/icons/meta.png",
),
"llama4:16x17b": OllamaModelMetadata(
@ -71,7 +63,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama4:16x17b",
description="✅ Graphs supported.",
size="67GB",
supports_graph=True,
icon="/static/icons/meta.png",
),
"llama4:128x17b": OllamaModelMetadata(
@ -79,7 +70,6 @@ SUPPORTED_OLLAMA_MODELS = {
value="llama4:128x17b",
description="✅ Graphs supported.",
size="245GB",
supports_graph=True,
icon="/static/icons/meta.png",
),
}
@ -90,7 +80,6 @@ SUPPORTED_GEMMA_MODELS = {
value="gemma3:1b",
description="❌ Graphs not supported.",
size="815MB",
supports_graph=False,
icon="/static/icons/gemma.png",
),
"gemma3:4b": OllamaModelMetadata(
@ -98,7 +87,6 @@ SUPPORTED_GEMMA_MODELS = {
value="gemma3:4b",
description="❌ Graphs not supported.",
size="3.3GB",
supports_graph=False,
icon="/static/icons/gemma.png",
),
"gemma3:12b": OllamaModelMetadata(
@ -106,7 +94,6 @@ SUPPORTED_GEMMA_MODELS = {
value="gemma3:12b",
description="❌ Graphs not supported.",
size="8.1GB",
supports_graph=False,
icon="/static/icons/gemma.png",
),
"gemma3:27b": OllamaModelMetadata(
@ -114,7 +101,6 @@ SUPPORTED_GEMMA_MODELS = {
value="gemma3:27b",
description="✅ Graphs supported.",
size="17GB",
supports_graph=True,
icon="/static/icons/gemma.png",
),
}
@ -125,7 +111,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:1.5b",
description="❌ Graphs not supported.",
size="1.1GB",
supports_graph=False,
icon="/static/icons/deepseek.png",
),
"deepseek-r1:7b": OllamaModelMetadata(
@ -133,7 +118,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:7b",
description="❌ Graphs not supported.",
size="4.7GB",
supports_graph=False,
icon="/static/icons/deepseek.png",
),
"deepseek-r1:8b": OllamaModelMetadata(
@ -141,7 +125,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:8b",
description="❌ Graphs not supported.",
size="5.2GB",
supports_graph=False,
icon="/static/icons/deepseek.png",
),
"deepseek-r1:14b": OllamaModelMetadata(
@ -149,7 +132,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:14b",
description="❌ Graphs not supported.",
size="9GB",
supports_graph=False,
icon="/static/icons/deepseek.png",
),
"deepseek-r1:32b": OllamaModelMetadata(
@ -157,7 +139,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:32b",
description="✅ Graphs supported.",
size="20GB",
supports_graph=True,
icon="/static/icons/deepseek.png",
),
"deepseek-r1:70b": OllamaModelMetadata(
@ -165,7 +146,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:70b",
description="✅ Graphs supported.",
size="43GB",
supports_graph=True,
icon="/static/icons/deepseek.png",
),
"deepseek-r1:671b": OllamaModelMetadata(
@ -173,7 +153,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
value="deepseek-r1:671b",
description="✅ Graphs supported.",
size="404GB",
supports_graph=True,
icon="/static/icons/deepseek.png",
),
}
@ -184,7 +163,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:0.6b",
description="❌ Graphs not supported.",
size="523MB",
supports_graph=False,
icon="/static/icons/qwen.png",
),
"qwen3:1.7b": OllamaModelMetadata(
@ -192,7 +170,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:1.7b",
description="❌ Graphs not supported.",
size="1.4GB",
supports_graph=False,
icon="/static/icons/qwen.png",
),
"qwen3:4b": OllamaModelMetadata(
@ -200,7 +177,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:4b",
description="❌ Graphs not supported.",
size="2.6GB",
supports_graph=False,
icon="/static/icons/qwen.png",
),
"qwen3:8b": OllamaModelMetadata(
@ -208,7 +184,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:8b",
description="❌ Graphs not supported.",
size="5.2GB",
supports_graph=False,
icon="/static/icons/qwen.png",
),
"qwen3:14b": OllamaModelMetadata(
@ -216,7 +191,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:14b",
description="❌ Graphs not supported.",
size="9.3GB",
supports_graph=False,
icon="/static/icons/qwen.png",
),
"qwen3:30b": OllamaModelMetadata(
@ -224,7 +198,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:30b",
description="✅ Graphs supported.",
size="19GB",
supports_graph=True,
icon="/static/icons/qwen.png",
),
"qwen3:32b": OllamaModelMetadata(
@ -232,7 +205,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:32b",
description="✅ Graphs supported.",
size="20GB",
supports_graph=True,
icon="/static/icons/qwen.png",
),
"qwen3:235b": OllamaModelMetadata(
@ -240,7 +212,6 @@ SUPPORTED_QWEN_MODELS = {
value="qwen3:235b",
description="✅ Graphs supported.",
size="142GB",
supports_graph=True,
icon="/static/icons/qwen.png",
),
}

View file

@ -17,8 +17,8 @@ class ContactInfoModel(BaseModel):
class ImageModel(BaseModel):
image_url_: str = Field(description="Image URL")
image_prompt_: str = Field(description="Image prompt")
__image_url__: str = Field(description="Image URL")
__image_prompt__: str = Field(description="Image prompt")
# First Slide Layout

View file

@ -7,4 +7,3 @@ class OllamaModelMetadata(BaseModel):
description: str
icon: str
size: str
supports_graph: bool

View file

@ -1,4 +1,5 @@
from typing import List, Optional
from fastapi import HTTPException
from pydantic import BaseModel, Field
from models.presentation_structure_model import PresentationStructureModel
@ -16,6 +17,14 @@ class PresentationLayoutModel(BaseModel):
ordered: bool = Field(default=False)
slides: List[SlideLayoutModel]
def get_slide_layout_index(self, slide_layout_id: str) -> int:
for index, slide in enumerate(self.slides):
if slide.id == slide_layout_id:
return index
raise HTTPException(
status_code=404, detail=f"Slide layout {slide_layout_id} not found"
)
def to_presentation_structure(self):
return PresentationStructureModel(
slides=[index for index in range(len(self.slides))]

View file

@ -4,16 +4,32 @@ from pydantic import BaseModel
class UserConfig(BaseModel):
LLM: Optional[str] = None
# OpenAI
OPENAI_API_KEY: Optional[str] = None
OPENAI_MODEL: Optional[str] = None
# Google
GOOGLE_API_KEY: Optional[str] = None
GOOGLE_MODEL: Optional[str] = None
# Anthropic
ANTHROPIC_API_KEY: Optional[str] = None
ANTHROPIC_MODEL: Optional[str] = None
# Ollama
OLLAMA_URL: Optional[str] = None
OLLAMA_MODEL: Optional[str] = None
# Custom LLM
CUSTOM_LLM_URL: Optional[str] = None
CUSTOM_LLM_API_KEY: Optional[str] = None
CUSTOM_MODEL: Optional[str] = None
PEXELS_API_KEY: Optional[str] = None
# Image Provider
IMAGE_PROVIDER: Optional[str] = None
PEXELS_API_KEY: Optional[str] = None
PIXABAY_API_KEY: Optional[str] = None
# Reasoning
EXTENDED_REASONING: Optional[bool] = None

View file

@ -22,7 +22,6 @@ chromadb==1.0.15
click==8.2.1
coloredlogs==15.0.1
cryptography==45.0.5
datamodel-code-generator==0.32.0
distro==1.9.0
dnspython==2.7.0
durationpy==0.10

View file

@ -1,8 +1,6 @@
from services.redis_service import RedisService
from services.schema_to_model_service import SchemaToModelService
from services.temp_file_service import TempFileService
TEMP_FILE_SERVICE = TempFileService()
REDIS_SERVICE = RedisService()
SCHEMA_TO_MODEL_SERVICE = SchemaToModelService(TEMP_FILE_SERVICE)

View file

@ -3,12 +3,12 @@ import os
import aiohttp
from google import genai
from google.genai.types import GenerateContentConfig
from openai import AsyncOpenAI
from models.image_prompt import ImagePrompt
from models.sql.image_asset import ImageAsset
from utils.download_helpers import download_file
from utils.get_env import get_pexels_api_key_env
from utils.get_env import get_pixabay_api_key_env
from utils.llm_provider import get_llm_client
from utils.image_provider import (
is_pixels_selected,
is_pixabay_selected,
@ -80,7 +80,7 @@ class ImageGenerationService:
return "/static/images/placeholder.jpg"
async def generate_image_openai(self, prompt: str, output_directory: str) -> str:
client = get_llm_client()
client = AsyncOpenAI()
result = await client.images.generate(
model="dall-e-3",
prompt=prompt,

View file

@ -1,6 +1,6 @@
import asyncio
import json
from typing import List, Optional
from typing import List
from fastapi import HTTPException
from openai import AsyncOpenAI
from google import genai
@ -8,7 +8,6 @@ from google.genai.types import GenerateContentConfig
from anthropic import AsyncAnthropic
from anthropic.types import Message as AnthropicMessage
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
from pydantic import BaseModel
from enums.llm_provider import LLMProvider
from models.llm_message import LLMMessage
from utils.async_iterator import iterator_to_async
@ -21,6 +20,7 @@ from utils.get_env import (
get_openai_api_key_env,
)
from utils.llm_provider import get_llm_provider
from utils.schema_utils import ensure_strict_json_schema
class LLMClient:
@ -173,43 +173,45 @@ class LLMClient:
# ? Generate Structured Content
async def _generate_openai_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
client: AsyncOpenAI = self._client
is_response_format_dict = isinstance(response_format, dict)
if is_response_format_dict:
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
{
"name": "ResponseSchema",
"schema": response_format,
}
),
},
max_completion_tokens=self.max_tokens,
response_schema = response_format
if strict:
response_schema = ensure_strict_json_schema(
response_schema,
path=(),
root=response_schema,
)
content = response.choices[0].message.content
if content:
return json.loads(content)
return None
else:
response = await client.chat.completions.parse(
model=model,
messages=[message.model_dump() for message in messages],
response_format=response_format,
max_completion_tokens=self.max_tokens,
)
content = response.choices[0].message.parsed
if content:
return content.model_dump(mode="json")
return None
response = await client.chat.completions.create(
model=model,
messages=[message.model_dump() for message in messages],
response_format={
"type": "json_schema",
"json_schema": (
{
"name": "ResponseSchema",
"strict": strict,
"schema": response_schema,
}
),
},
max_completion_tokens=self.max_tokens,
)
content = response.choices[0].message.content
if content:
return json.loads(content)
return None
async def _generate_google_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
):
client: genai.Client = self._client
response = await asyncio.to_thread(
@ -219,7 +221,7 @@ class LLMClient:
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json",
response_schema=response_format,
response_json_schema=response_format,
max_output_tokens=self.max_tokens,
),
)
@ -230,10 +232,12 @@ class LLMClient:
return content
async def _generate_anthropic_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
):
client: AsyncAnthropic = self._client
is_response_format_dict = isinstance(response_format, dict)
response: AnthropicMessage = await client.messages.create(
model=model,
system=self._get_system_prompt(messages),
@ -246,11 +250,7 @@ class LLMClient:
{
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": (
response_format
if is_response_format_dict
else response_format.model_json_schema()
),
"input_schema": response_format,
}
],
)
@ -262,23 +262,39 @@ class LLMClient:
return content
async def _generate_ollama_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
return await self._generate_openai_structured(model, messages, response_format)
return await self._generate_openai_structured(
model, messages, response_format, strict
)
async def _generate_custom_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
return await self._generate_openai_structured(model, messages, response_format)
return await self._generate_openai_structured(
model, messages, response_format, strict
)
async def generate_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
) -> dict:
content = None
match self.llm_provider:
case LLMProvider.OPENAI:
content = await self._generate_openai_structured(
model, messages, response_format
model, messages, response_format, strict
)
case LLMProvider.GOOGLE:
content = await self._generate_google_structured(
@ -290,11 +306,11 @@ class LLMClient:
)
case LLMProvider.OLLAMA:
content = await self._generate_ollama_structured(
model, messages, response_format
model, messages, response_format, strict
)
case LLMProvider.CUSTOM:
content = await self._generate_custom_structured(
model, messages, response_format
model, messages, response_format, strict
)
if content is None:
raise HTTPException(
@ -366,10 +382,20 @@ class LLMClient:
# ? Stream Structured Content
async def _stream_openai_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
client: AsyncOpenAI = self._client
is_response_format_dict = isinstance(response_format, dict)
response_schema = response_format
if strict:
response_schema = ensure_strict_json_schema(
response_schema,
path=(),
root=response_schema,
)
async with client.chat.completions.stream(
model=model,
messages=[message.model_dump() for message in messages],
@ -379,11 +405,10 @@ class LLMClient:
"type": "json_schema",
"json_schema": {
"name": "ResponseSchema",
"schema": response_format,
"strict": strict,
"schema": response_schema,
},
}
if is_response_format_dict
else response_format
),
) as stream:
async for event in stream:
@ -391,7 +416,10 @@ class LLMClient:
yield event.delta
async def _stream_google_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
):
client: genai.Client = self._client
async for event in iterator_to_async(client.models.generate_content_stream)(
@ -400,7 +428,7 @@ class LLMClient:
config=GenerateContentConfig(
system_instruction=self._get_system_prompt(messages),
response_mime_type="application/json",
response_schema=response_format,
response_json_schema=response_format,
max_output_tokens=self.max_tokens,
),
):
@ -408,10 +436,12 @@ class LLMClient:
yield event.text
async def _stream_anthropic_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
):
client: AsyncAnthropic = self._client
is_response_format_dict = isinstance(response_format, dict)
async with client.messages.stream(
model=model,
system=self._get_system_prompt(messages),
@ -424,11 +454,7 @@ class LLMClient:
{
"name": "ResponseSchema",
"description": "A response to the user's message",
"input_schema": (
response_format
if is_response_format_dict
else response_format.model_json_schema()
),
"input_schema": response_format,
}
],
) as stream:
@ -438,21 +464,35 @@ class LLMClient:
yield event.partial_json
def _stream_ollama_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
return self._stream_openai_structured(model, messages, response_format)
return self._stream_openai_structured(model, messages, response_format, strict)
def _stream_custom_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
return self._stream_openai_structured(model, messages, response_format)
return self._stream_openai_structured(model, messages, response_format, strict)
def stream_structured(
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
self,
model: str,
messages: List[LLMMessage],
response_format: dict,
strict: bool = False,
):
match self.llm_provider:
case LLMProvider.OPENAI:
return self._stream_openai_structured(model, messages, response_format)
return self._stream_openai_structured(
model, messages, response_format, strict
)
case LLMProvider.GOOGLE:
return self._stream_google_structured(model, messages, response_format)
case LLMProvider.ANTHROPIC:
@ -460,6 +500,10 @@ class LLMClient:
model, messages, response_format
)
case LLMProvider.OLLAMA:
return self._stream_ollama_structured(model, messages, response_format)
return self._stream_ollama_structured(
model, messages, response_format, strict
)
case LLMProvider.CUSTOM:
return self._stream_custom_structured(model, messages, response_format)
return self._stream_custom_structured(
model, messages, response_format, strict
)

View file

@ -1,78 +0,0 @@
import asyncio
import json
import os
from pathlib import Path
from typing import Dict
from fastapi import HTTPException
from datamodel_code_generator import generate, InputFileType, DataModelType
from services.temp_file_service import TempFileService
from utils.randomizers import get_random_uuid
class SchemaToModelService:
def __init__(self, temp_file_service: TempFileService):
self.temp_file_service = temp_file_service
self.temp_dir = self.temp_file_service.create_temp_dir()
self.generated_models_dir = "generated_models"
if os.path.exists(self.generated_models_dir):
for file in os.listdir(self.generated_models_dir):
if file.endswith(".py"):
os.remove(os.path.join(self.generated_models_dir, file))
os.makedirs(self.generated_models_dir, exist_ok=True)
self._records: Dict[str, str] = {}
self._fetch_locks: Dict[str, asyncio.Lock] = {}
def convert_path_to_module_path(self, path: str):
return path.replace("/", ".").replace("\\", ".").replace(".py", "")
async def get_pydantic_model_path_from_schema(
self, identifier: str, schema: dict
) -> str:
if identifier in self._fetch_locks:
async with self._fetch_locks[identifier]:
return self._records[identifier]
else:
async_lock = asyncio.Lock()
await async_lock.acquire()
self._fetch_locks[identifier] = async_lock
model_path = await self.generate_pydantic_model_from_schema_async(schema)
model_path = self.convert_path_to_module_path(model_path)
self._records[identifier] = model_path
async_lock.release()
return model_path
async def generate_pydantic_model_from_schema_async(self, schema: dict):
return await asyncio.to_thread(self.generate_pydantic_model_from_schema, schema)
def generate_pydantic_model_from_schema(self, schema: dict):
generated_model_path = os.path.join(
self.generated_models_dir, get_random_uuid() + ".py"
)
try:
schema_path = self.temp_file_service.create_temp_file_path(
get_random_uuid() + ".json", self.temp_dir
)
with open(schema_path, "w") as f:
json.dump(schema, f)
generate(
input_=Path(schema_path),
input_file_type=InputFileType.JsonSchema,
output=Path(generated_model_path),
output_model_type=DataModelType.PydanticV2BaseModel,
class_name="GeneratedModel",
use_annotated=False,
field_constraints=True,
extra_fields="ignore",
)
except Exception as e:
raise HTTPException(
status_code=500, detail="Failed to generate Pydantic model from schema"
)
finally:
self.temp_file_service.cleanup_temp_file(schema_path)
return generated_model_path

View file

@ -0,0 +1,21 @@
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from google import genai
async def list_available_openai_compatible_models(url: str, api_key: str) -> list[str]:
client = AsyncOpenAI(api_key=api_key, base_url=url)
models = (await client.models.list()).data
if models:
return list(map(lambda x: x.id, models))
return []
async def list_available_anthropic_models(api_key: str) -> list[str]:
client = AsyncAnthropic(api_key=api_key)
return list(map(lambda x: x.id, (await client.models.list(limit=50)).data))
async def list_available_google_models(api_key: str) -> list[str]:
client = genai.Client(api_key=api_key)
return list(map(lambda x: x.name, client.models.list(config={"page_size": 50})))

View file

@ -1,17 +0,0 @@
from typing import Optional
from openai import AsyncOpenAI
from utils.llm_provider import get_llm_client
async def list_available_custom_models(
url: Optional[str] = None, api_key: Optional[str] = None
) -> list[str]:
if not url:
client = get_llm_client()
else:
client = AsyncOpenAI(api_key=api_key or "null", base_url=url)
models = []
async for model in client.models.list():
models.append(model.id)
return models

View file

@ -78,3 +78,12 @@ def deep_update(original: dict, updates: dict) -> dict:
if not isinstance(value, (dict, list)):
original[key] = value
return original
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
i = 0
for _ in obj.keys():
i += 1
if i > n:
return True
return False

View file

@ -45,10 +45,18 @@ def get_openai_api_key_env():
return os.getenv("OPENAI_API_KEY")
def get_openai_model_env():
return os.getenv("OPENAI_MODEL")
def get_google_api_key_env():
return os.getenv("GOOGLE_API_KEY")
def get_google_model_env():
return os.getenv("GOOGLE_MODEL")
def get_custom_llm_api_key_env():
return os.getenv("CUSTOM_LLM_API_KEY")

View file

@ -1,9 +1,8 @@
from pydantic import BaseModel
from models.llm_message import LLMMessage
from models.presentation_layout import SlideLayoutModel
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
from utils.llm_provider import get_model
from utils.schema_utils import remove_fields_from_schema
system_prompt = """
@ -57,14 +56,19 @@ async def get_edited_slide_content(
prompt: str,
slide: SlideModel,
language: str,
response_model: BaseModel,
slide_layout: SlideLayoutModel,
):
model = get_large_model()
model = get_model()
response_schema = remove_fields_from_schema(
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
)
client = LLMClient()
response = await client.generate_structured(
model=model,
messages=get_messages(prompt, slide.content, language),
response_format=response_model,
response_format=response_schema,
strict=False,
)
return response

View file

@ -1,14 +1,7 @@
import asyncio
from typing import Optional
from google.genai.types import GenerateContentConfig
from utils.llm_provider import (
get_anthropic_llm_client,
get_google_llm_client,
get_large_model,
is_anthropic_selected,
is_google_selected,
get_llm_client,
)
from models.llm_message import LLMMessage
from services.llm_client import LLMClient
from utils.llm_provider import get_model
system_prompt = """
You are an expert HTML slide editor. Your task is to modify slide HTML content based on user prompts while maintaining proper structure, styling, and functionality.
@ -54,48 +47,17 @@ def get_user_prompt(prompt: str, html: str):
async def get_edited_slide_html(prompt: str, html: str):
model = get_large_model()
llm_response = None
model = get_model()
if is_anthropic_selected():
client = get_anthropic_llm_client()
response = await client.messages.create(
model=model,
messages=[get_user_prompt(prompt, html)],
)
for each in response.content:
if each.type == "text":
llm_response = each.text
break
elif is_google_selected():
client = get_google_llm_client()
response = await asyncio.to_thread(
client.models.generate_content,
model=model,
contents=[get_user_prompt(prompt, html)],
config=GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type="text/plain",
),
)
llm_response = response.text
else:
client = get_llm_client()
response = await client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": get_user_prompt(prompt, html)},
],
)
llm_response = response.choices[0].message.content
if not llm_response:
return html
return extract_html_from_response(llm_response) or html
client = LLMClient()
response = await client.generate(
model=model,
messages=[
LLMMessage(role="system", content=system_prompt),
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
],
)
return extract_html_from_response(response) or html
def extract_html_from_response(response_text: str) -> Optional[str]:

View file

@ -3,7 +3,7 @@ from typing import List
from models.llm_message import LLMMessage
from services.llm_client import LLMClient
from utils.llm_provider import get_nano_model
from utils.llm_provider import get_model
sysmte_prompt = """
@ -25,7 +25,7 @@ Maintain as much information as possible.
async def generate_document_summary(documents: List[str]):
client = LLMClient()
model = get_nano_model()
model = get_model()
coroutines = []
for document in documents:

View file

@ -4,7 +4,7 @@ from typing import Optional
from models.llm_message import LLMMessage
from services.llm_client import LLMClient
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
from utils.llm_provider import get_large_model
from utils.llm_provider import get_model
system_prompt = """
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
@ -75,7 +75,7 @@ async def generate_ppt_outline(
language: Optional[str] = None,
content: Optional[str] = None,
):
model = get_large_model()
model = get_model()
response_model = get_presentation_outline_model_with_n_slides(n_slides)
client = LLMClient()
@ -83,6 +83,7 @@ async def generate_ppt_outline(
async for chunk in client.stream_structured(
model,
get_messages(prompt, n_slides, language, content),
response_model,
response_model.model_json_schema(),
strict=True,
):
yield chunk

View file

@ -2,7 +2,7 @@ from models.llm_message import LLMMessage
from models.presentation_layout import PresentationLayoutModel
from models.presentation_outline_model import PresentationOutlineModel
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
from utils.llm_provider import get_model
from utils.get_dynamic_models import get_presentation_structure_model_with_n_slides
from models.presentation_structure_model import PresentationStructureModel
@ -62,7 +62,7 @@ async def generate_presentation_structure(
) -> PresentationStructureModel:
client = LLMClient()
model = get_large_model()
model = get_model()
response_model = get_presentation_structure_model_with_n_slides(
len(presentation_outline.slides)
)
@ -74,6 +74,7 @@ async def generate_presentation_structure(
len(presentation_outline.slides),
presentation_outline.to_string(),
),
response_format=response_model,
response_format=response_model.model_json_schema(),
strict=True,
)
return PresentationStructureModel(**response)

View file

@ -1,8 +1,9 @@
from pydantic import BaseModel
from models.llm_message import LLMMessage
from models.presentation_layout import SlideLayoutModel
from models.presentation_outline_model import SlideOutlineModel
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
from utils.llm_provider import get_model
from utils.schema_utils import remove_fields_from_schema
system_prompt = """
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
@ -14,8 +15,8 @@ system_prompt = """
# Notes
- Slide body should not use words like "This slide", "This presentation".
- Rephrase the slide body to make it flow naturally.
- Provide prompt to generate image on "image_prompt_" property.
- Provide query to search icon on "icon_query_" property.
- Provide prompt to generate image on "__image_prompt__" property.
- Provide query to search icon on "__icon_query__" property.
- Do not use markdown formatting in slide body.
- Make sure to follow language guidelines.
**Strictly follow the max and min character limit for every property in the slide.**
@ -53,10 +54,14 @@ def get_messages(title: str, outline: str, language: str):
async def get_slide_content_from_type_and_outline(
response_model: BaseModel, outline: SlideOutlineModel, language: str
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
):
client = LLMClient()
model = get_large_model()
model = get_model()
response_schema = remove_fields_from_schema(
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
)
response = await client.generate_structured(
model=model,
@ -65,6 +70,7 @@ async def get_slide_content_from_type_and_outline(
outline.body,
language,
),
response_format=response_model,
response_format=response_schema,
strict=False,
)
return response

View file

@ -3,7 +3,7 @@ from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
from models.slide_layout_index import SlideLayoutIndex
from models.sql.slide import SlideModel
from services.llm_client import LLMClient
from utils.llm_provider import get_large_model
from utils.llm_provider import get_model
def get_messages(
@ -44,9 +44,9 @@ async def get_slide_layout_from_prompt(
) -> SlideLayoutModel:
client = LLMClient()
model = get_large_model()
model = get_model()
slide_layout_ids = list(map(lambda x: x.id, layout.slides))
slide_layout_index = layout.get_slide_layout_index(slide.layout)
response = await client.generate_structured(
model=model,
@ -54,9 +54,10 @@ async def get_slide_layout_from_prompt(
prompt,
slide.content,
layout,
slide_layout_ids.index(slide.layout),
slide_layout_index,
),
response_format=SlideLayoutIndex,
response_format=SlideLayoutIndex.model_json_schema(),
strict=True,
)
index = SlideLayoutIndex(**response).index
return layout.slides[index]

View file

@ -1,21 +1,18 @@
import os
import anthropic
from fastapi import HTTPException
from openai import AsyncOpenAI
from google import genai
from constants.llm import (
DEFAULT_ANTHROPIC_MODEL,
DEFAULT_GOOGLE_MODEL,
DEFAULT_OPENAI_MODEL,
)
from enums.llm_provider import LLMProvider
from utils.get_env import (
get_anthropic_api_key_env,
get_anthropic_model_env,
get_custom_llm_api_key_env,
get_custom_llm_url_env,
get_custom_model_env,
get_google_api_key_env,
get_google_model_env,
get_llm_provider_env,
get_ollama_model_env,
get_ollama_url_env,
get_openai_api_key_env,
get_openai_model_env,
)
@ -25,14 +22,10 @@ def get_llm_provider():
except:
raise HTTPException(
status_code=500,
detail=f"Invalid LLM provider. Please select one of: openai, google, ollama, custom",
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
)
def get_ollama_url():
return get_ollama_url_env() or "http://localhost:11434"
def is_openai_selected():
return get_llm_provider() == LLMProvider.OPENAI
@ -53,100 +46,20 @@ def is_custom_llm_selected():
return get_llm_provider() == LLMProvider.CUSTOM
def get_model_base_url():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return "https://api.openai.com/v1"
elif selected_llm == LLMProvider.GOOGLE:
return "https://generativelanguage.googleapis.com/v1beta/openai"
elif selected_llm == LLMProvider.ANTHROPIC:
return "https://api.anthropic.com/v1"
elif selected_llm == LLMProvider.OLLAMA:
return os.path.join(get_ollama_url(), "v1")
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_llm_url_env()
else:
raise HTTPException(f"LLM provider {selected_llm} is not supported")
def get_llm_api_key():
def get_model():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return get_openai_api_key_env()
return get_openai_model_env() or DEFAULT_OPENAI_MODEL
elif selected_llm == LLMProvider.GOOGLE:
return get_google_api_key_env()
return get_google_model_env() or DEFAULT_GOOGLE_MODEL
elif selected_llm == LLMProvider.ANTHROPIC:
return get_anthropic_api_key_env()
elif selected_llm == LLMProvider.OLLAMA:
return "ollama"
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_llm_api_key_env() or "none"
else:
raise HTTPException(f"LLM provider {selected_llm} is not supported")
def get_llm_client():
client = AsyncOpenAI(
base_url=get_model_base_url(),
api_key=get_llm_api_key(),
)
return client
def get_google_llm_client():
client = genai.Client(api_key=get_google_api_key_env())
return client
def get_anthropic_llm_client():
client = anthropic.AsyncAnthropic(api_key=get_anthropic_api_key_env())
return client
def get_large_model():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return "gpt-4.1"
elif selected_llm == LLMProvider.GOOGLE:
return "gemini-2.0-flash"
elif selected_llm == LLMProvider.ANTHROPIC:
return get_anthropic_model_env()
return get_anthropic_model_env() or DEFAULT_ANTHROPIC_MODEL
elif selected_llm == LLMProvider.OLLAMA:
return get_ollama_model_env()
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_model_env()
else:
raise ValueError(f"Invalid LLM model")
def get_small_model():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return "gpt-4.1-mini"
elif selected_llm == LLMProvider.GOOGLE:
return "gemini-2.0-flash"
elif selected_llm == LLMProvider.ANTHROPIC:
return get_anthropic_model_env()
elif selected_llm == LLMProvider.OLLAMA:
return get_ollama_model_env()
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_model_env()
else:
raise ValueError(f"Invalid LLM model")
def get_nano_model():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return "gpt-4.1-nano"
elif selected_llm == LLMProvider.GOOGLE:
return "gemini-2.0-flash"
elif selected_llm == LLMProvider.ANTHROPIC:
return get_anthropic_model_env()
elif selected_llm == LLMProvider.OLLAMA:
return get_ollama_model_env()
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_model_env()
else:
raise ValueError(f"Invalid LLM model")
raise HTTPException(
status_code=500,
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
)

View file

@ -1,11 +1,19 @@
import os
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from constants.llm import OPENAI_URL
from enums.image_provider import ImageProvider
from enums.llm_provider import LLMProvider
from utils.custom_llm_provider import list_available_custom_models
from utils.available_models import (
list_available_anthropic_models,
list_available_google_models,
list_available_openai_compatible_models,
)
from utils.get_env import (
get_anthropic_api_key_env,
get_anthropic_model_env,
get_can_change_keys_env,
get_google_model_env,
get_openai_api_key_env,
get_openai_model_env,
get_pixabay_api_key_env,
get_pexels_api_key_env,
)
@ -20,13 +28,7 @@ from utils.llm_provider import (
is_ollama_selected,
)
from utils.ollama import pull_ollama_model
from utils.image_provider import (
get_selected_image_provider,
is_pixels_selected,
is_pixabay_selected,
is_gemini_flash_selected,
is_dalle3_selected,
)
from utils.image_provider import get_selected_image_provider
async def check_llm_and_image_provider_api_or_model_availability():
@ -36,11 +38,41 @@ async def check_llm_and_image_provider_api_or_model_availability():
openai_api_key = get_openai_api_key_env()
if not openai_api_key:
raise Exception("OPENAI_API_KEY must be provided")
openai_model = get_openai_model_env()
if openai_model:
available_models = await list_available_openai_compatible_models(
OPENAI_URL, openai_api_key
)
if openai_model not in available_models:
print("-" * 50)
print("Available models: ", available_models)
raise Exception(f"Model {openai_model} is not available")
elif get_llm_provider() == LLMProvider.GOOGLE:
google_api_key = get_google_api_key_env()
if not google_api_key:
raise Exception("GOOGLE_API_KEY must be provided")
google_model = get_google_model_env()
if google_model:
available_models = await list_available_google_models(google_api_key)
if google_model not in available_models:
print("-" * 50)
print("Available models: ", available_models)
raise Exception(f"Model {google_model} is not available")
elif get_llm_provider() == LLMProvider.ANTHROPIC:
anthropic_api_key = get_anthropic_api_key_env()
if not anthropic_api_key:
raise Exception("ANTHROPIC_API_KEY must be provided")
anthropic_model = get_anthropic_model_env()
if anthropic_model:
available_models = await list_available_anthropic_models(
anthropic_api_key
)
if anthropic_model not in available_models:
print("-" * 50)
print("Available models: ", available_models)
raise Exception(f"Model {anthropic_model} is not available")
elif is_ollama_selected():
ollama_model = get_ollama_model_env()
@ -67,14 +99,12 @@ async def check_llm_and_image_provider_api_or_model_availability():
raise Exception("CUSTOM_LLM_URL must be provided")
if not custom_llm_api_key:
raise Exception("CUSTOM_LLM_API_KEY must be provided")
print("-" * 50)
print("Selecting model: ", custom_model)
models = await list_available_custom_models(
available_models = await list_available_openai_compatible_models(
custom_llm_url, custom_llm_api_key
)
print("Available models: ", models)
print("-" * 50)
if custom_model not in models:
print("Available models: ", available_models)
if custom_model not in available_models:
raise Exception(f"Model {custom_model} is not available")
# Check for Image Provider and API keys

View file

@ -1,7 +1,7 @@
from http.client import HTTPException
import json
from typing import AsyncGenerator
import aiohttp
from fastapi import HTTPException
from models.ollama_model_status import OllamaModelStatus
from utils.get_env import get_ollama_url_env

View file

@ -17,23 +17,23 @@ async def process_slide_and_fetch_assets(
async_tasks = []
image_paths = get_dict_paths_with_key(slide.content, "image_prompt_")
icon_paths = get_dict_paths_with_key(slide.content, "icon_query_")
image_paths = get_dict_paths_with_key(slide.content, "__image_prompt__")
icon_paths = get_dict_paths_with_key(slide.content, "__icon_query__")
for image_path in image_paths:
image_prompt_parent = get_dict_at_path(slide.content, image_path)
__image_prompt__parent = get_dict_at_path(slide.content, image_path)
async_tasks.append(
image_generation_service.generate_image(
ImagePrompt(
prompt=image_prompt_parent["image_prompt_"],
prompt=__image_prompt__parent["__image_prompt__"],
)
)
)
for icon_path in icon_paths:
icon_query_parent = get_dict_at_path(slide.content, icon_path)
__icon_query__parent = get_dict_at_path(slide.content, icon_path)
async_tasks.append(
icon_finder_service.search_icons(icon_query_parent["icon_query_"])
icon_finder_service.search_icons(__icon_query__parent["__icon_query__"])
)
results = await asyncio.gather(*async_tasks)
@ -45,14 +45,14 @@ async def process_slide_and_fetch_assets(
result = results.pop()
if isinstance(result, ImageAsset):
return_assets.append(result)
image_dict["image_url_"] = result.path
image_dict["__image_url__"] = result.path
else:
image_dict["image_url_"] = result
image_dict["__image_url__"] = result
set_dict_at_path(slide.content, image_path, image_dict)
for icon_path in icon_paths:
icon_dict = get_dict_at_path(slide.content, icon_path)
icon_dict["icon_url_"] = results.pop()[0]
icon_dict["__icon_url__"] = results.pop()[0]
set_dict_at_path(slide.content, icon_path, icon_dict)
return return_assets
@ -66,34 +66,34 @@ async def process_old_and_new_slides_and_fetch_assets(
) -> List[ImageAsset]:
# Finds all old images
old_image_dict_paths = get_dict_paths_with_key(
old_slide_content, "image_prompt_"
old_slide_content, "__image_prompt__"
)
old_image_dicts = [
get_dict_at_path(old_slide_content, path) for path in old_image_dict_paths
]
old_image_prompts = [
old_image_dict["image_prompt_"] for old_image_dict in old_image_dicts
old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts
]
# Finds all old icons
old_icon_dict_paths = get_dict_paths_with_key(old_slide_content, "icon_query_")
old_icon_dict_paths = get_dict_paths_with_key(old_slide_content, "__icon_query__")
old_icon_dicts = [
get_dict_at_path(old_slide_content, path) for path in old_icon_dict_paths
]
old_icon_queries = [
old_icon_dict["icon_query_"] for old_icon_dict in old_icon_dicts
old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts
]
# Finds all new images
new_image_dict_paths = get_dict_paths_with_key(
new_slide_content, "image_prompt_"
new_slide_content, "__image_prompt__"
)
new_image_dicts = [
get_dict_at_path(new_slide_content, path) for path in new_image_dict_paths
]
# Finds all new icons
new_icon_dict_paths = get_dict_paths_with_key(new_slide_content, "icon_query_")
new_icon_dict_paths = get_dict_paths_with_key(new_slide_content, "__icon_query__")
new_icon_dicts = [
get_dict_at_path(new_slide_content, path) for path in new_icon_dict_paths
]
@ -109,18 +109,18 @@ async def process_old_and_new_slides_and_fetch_assets(
# Creates async tasks for fetching new images
# Use old image url if prompt is same
for new_image in new_image_dicts:
if new_image["image_prompt_"] in old_image_prompts:
if new_image["__image_prompt__"] in old_image_prompts:
old_image_url = old_image_dicts[
old_image_prompts.index(new_image["image_prompt_"])
]["image_url_"]
new_image["image_url_"] = old_image_url
old_image_prompts.index(new_image["__image_prompt__"])
]["__image_url__"]
new_image["__image_url__"] = old_image_url
new_images_fetch_status.append(False)
continue
async_image_fetch_tasks.append(
image_generation_service.generate_image(
ImagePrompt(
prompt=new_image["image_prompt_"],
prompt=new_image["__image_prompt__"],
)
)
)
@ -129,16 +129,16 @@ async def process_old_and_new_slides_and_fetch_assets(
# Creates async tasks for fetching new icons
# Use old icon url if query is same
for new_icon in new_icon_dicts:
if new_icon["icon_query_"] in old_icon_queries:
if new_icon["__icon_query__"] in old_icon_queries:
old_icon_url = old_icon_dicts[
old_icon_queries.index(new_icon["icon_query_"])
]["icon_url_"]
new_icon["icon_url_"] = old_icon_url
old_icon_queries.index(new_icon["__icon_query__"])
]["__icon_url__"]
new_icon["__icon_url__"] = old_icon_url
new_icons_fetch_status.append(False)
continue
async_icon_fetch_tasks.append(
icon_finder_service.search_icons(new_icon["icon_query_"])
icon_finder_service.search_icons(new_icon["__icon_query__"])
)
new_icons_fetch_status.append(True)
@ -157,11 +157,11 @@ async def process_old_and_new_slides_and_fetch_assets(
image_url = fetched_image.path
else:
image_url = fetched_image
new_image_dicts[i]["image_url_"] = image_url
new_image_dicts[i]["__image_url__"] = image_url
for i, new_icon in enumerate(new_icons):
if new_icons_fetch_status[i]:
new_icon_dicts[i]["icon_url_"] = new_icons[i][0]
new_icon_dicts[i]["__icon_url__"] = new_icons[i][0]
for i, new_image_dict in enumerate(new_image_dicts):
set_dict_at_path(new_slide_content, new_image_dict_paths[i], new_image_dict)

View file

@ -1,30 +1,25 @@
from copy import deepcopy
from typing import List
from typing import Any, List
from utils.dict_utils import get_dict_paths_with_key, get_dict_at_path
from openai import NOT_GIVEN
from utils.dict_utils import (
get_dict_paths_with_key,
get_dict_at_path,
has_more_than_n_keys,
)
def resolve_refs(schema, defs):
if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"]
if ref_path.startswith("#/$defs/"):
def_key = ref_path.replace("#/$defs/", "")
return resolve_refs(defs[def_key], defs)
else:
raise ValueError(f"Unsupported $ref path: {ref_path}")
else:
return {k: resolve_refs(v, defs) for k, v in schema.items()}
elif isinstance(schema, list):
return [resolve_refs(item, defs) for item in schema]
else:
return schema
def flatten_schema(schema):
schema = deepcopy(schema)
defs = schema.pop("$defs", {})
return resolve_refs(schema, defs)
supported_string_formats = [
"date-time",
"time",
"date",
"duration",
"email",
"hostname",
"ipv4",
"ipv6",
"uuid",
]
def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
@ -50,6 +45,138 @@ def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
return schema
# From OpenAI
def ensure_strict_json_schema(
json_schema: object,
*,
path: tuple[str, ...],
root: dict[str, object],
) -> dict[str, Any]:
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
that the API expects.
"""
if not isinstance(json_schema, dict):
raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
defs = json_schema.get("$defs")
if isinstance(defs, dict):
for def_name, def_schema in defs.items():
ensure_strict_json_schema(
def_schema, path=(*path, "$defs", def_name), root=root
)
definitions = json_schema.get("definitions")
if isinstance(definitions, dict):
for definition_name, definition_schema in definitions.items():
ensure_strict_json_schema(
definition_schema,
path=(*path, "definitions", definition_name),
root=root,
)
typ = json_schema.get("type")
if typ == "object" and "additionalProperties" not in json_schema:
json_schema["additionalProperties"] = False
# object types
# { 'type': 'object', 'properties': { 'a': {...} } }
properties = json_schema.get("properties")
if isinstance(properties, dict):
json_schema["required"] = [prop for prop in properties.keys()]
json_schema["properties"] = {
key: ensure_strict_json_schema(
prop_schema, path=(*path, "properties", key), root=root
)
for key, prop_schema in properties.items()
}
# arrays
# { 'type': 'array', 'items': {...} }
items = json_schema.get("items")
if isinstance(items, dict):
json_schema["items"] = ensure_strict_json_schema(
items, path=(*path, "items"), root=root
)
# unions
any_of = json_schema.get("anyOf")
if isinstance(any_of, list):
json_schema["anyOf"] = [
ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
for i, variant in enumerate(any_of)
]
# intersections
all_of = json_schema.get("allOf")
if isinstance(all_of, list):
if len(all_of) == 1:
json_schema.update(
ensure_strict_json_schema(
all_of[0], path=(*path, "allOf", "0"), root=root
)
)
json_schema.pop("allOf")
else:
json_schema["allOf"] = [
ensure_strict_json_schema(
entry, path=(*path, "allOf", str(i)), root=root
)
for i, entry in enumerate(all_of)
]
# string
if typ == "string":
if "format" in json_schema:
if json_schema["format"] not in supported_string_formats:
del json_schema["format"]
# strip `None` defaults as there's no meaningful distinction here
# the schema will still be `nullable` and the model will default
# to using `None` anyway
if json_schema.get("default", NOT_GIVEN) is None:
json_schema.pop("default")
# we can't use `$ref`s if there are also other properties defined, e.g.
# `{"$ref": "...", "description": "my description"}`
#
# so we unravel the ref
# `{"type": "string", "description": "my description"}`
ref = json_schema.get("$ref")
if ref and has_more_than_n_keys(json_schema, 1):
assert isinstance(ref, str), f"Received non-string $ref - {ref}"
resolved = resolve_ref(root=root, ref=ref)
if not isinstance(resolved, dict):
raise ValueError(
f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}"
)
# properties from the json schema take priority over the ones on the `$ref`
json_schema.update({**resolved, **json_schema})
json_schema.pop("$ref")
# Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
# we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
return ensure_strict_json_schema(json_schema, path=path, root=root)
return json_schema
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
if not ref.startswith("#/"):
raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
path = ref[2:].split("/")
resolved = root
for key in path:
value = resolved[key]
assert isinstance(
value, dict
), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
resolved = value
return resolved
# ? Not used
def generate_constraint_sentences(schema: dict) -> str:
"""

View file

@ -25,10 +25,18 @@ def set_openai_api_key_env(value):
os.environ["OPENAI_API_KEY"] = value
def set_openai_model_env(value):
os.environ["OPENAI_MODEL"] = value
def set_google_api_key_env(value):
os.environ["GOOGLE_API_KEY"] = value
def set_google_model_env(value):
os.environ["GOOGLE_MODEL"] = value
def set_anthropic_api_key_env(value):
os.environ["ANTHROPIC_API_KEY"] = value

View file

@ -9,10 +9,12 @@ from utils.get_env import (
get_custom_llm_url_env,
get_custom_model_env,
get_google_api_key_env,
get_google_model_env,
get_llm_provider_env,
get_ollama_model_env,
get_ollama_url_env,
get_openai_api_key_env,
get_openai_model_env,
get_pexels_api_key_env,
get_user_config_path_env,
get_image_provider_env,
@ -27,10 +29,12 @@ from utils.set_env import (
set_custom_model_env,
set_extended_reasoning_env,
set_google_api_key_env,
set_google_model_env,
set_llm_provider_env,
set_ollama_model_env,
set_ollama_url_env,
set_openai_api_key_env,
set_openai_model_env,
set_pexels_api_key_env,
set_image_provider_env,
set_pixabay_api_key_env,
@ -58,7 +62,9 @@ def get_user_config():
return UserConfig(
LLM=existing_config.LLM or get_llm_provider_env(),
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or get_openai_api_key_env(),
OPENAI_MODEL=existing_config.OPENAI_MODEL or get_openai_model_env(),
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or get_google_api_key_env(),
GOOGLE_MODEL=existing_config.GOOGLE_MODEL or get_google_model_env(),
ANTHROPIC_API_KEY=existing_config.ANTHROPIC_API_KEY
or get_anthropic_api_key_env(),
ANTHROPIC_MODEL=existing_config.ANTHROPIC_MODEL or get_anthropic_model_env(),
@ -81,8 +87,12 @@ def update_env_with_user_config():
set_llm_provider_env(user_config.LLM)
if user_config.OPENAI_API_KEY:
set_openai_api_key_env(user_config.OPENAI_API_KEY)
if user_config.OPENAI_MODEL:
set_openai_model_env(user_config.OPENAI_MODEL)
if user_config.GOOGLE_API_KEY:
set_google_api_key_env(user_config.GOOGLE_API_KEY)
if user_config.GOOGLE_MODEL:
set_google_model_env(user_config.GOOGLE_MODEL)
if user_config.ANTHROPIC_API_KEY:
set_anthropic_api_key_env(user_config.ANTHROPIC_API_KEY)
if user_config.ANTHROPIC_MODEL: