feat(fastapi): adds strict support for every schemas, proper models check, refactor
This commit is contained in:
parent
9ad017f164
commit
e542fdf869
38 changed files with 542 additions and 535 deletions
|
|
@ -1,10 +1,8 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from api.lifespan import app_lifespan
|
||||
from api.middlewares import UserConfigEnvUpdateMiddleware
|
||||
from api.v1.ppt.router import API_V1_PPT_ROUTER
|
||||
from utils.asset_directory_utils import get_exports_directory, get_images_directory, get_uploads_directory
|
||||
|
||||
|
||||
app = FastAPI(lifespan=app_lifespan)
|
||||
|
|
@ -13,25 +11,6 @@ app = FastAPI(lifespan=app_lifespan)
|
|||
# Routers
|
||||
app.include_router(API_V1_PPT_ROUTER)
|
||||
|
||||
# Static files
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
app.mount(
|
||||
"/app_data/images",
|
||||
StaticFiles(directory=get_images_directory()),
|
||||
name="app_data/images",
|
||||
)
|
||||
app.mount(
|
||||
"/app_data/exports",
|
||||
StaticFiles(directory=get_exports_directory()),
|
||||
name="app_data/exports",
|
||||
)
|
||||
app.mount(
|
||||
"/app_data/uploads",
|
||||
StaticFiles(directory=get_uploads_directory()),
|
||||
name="app_data/uploads",
|
||||
)
|
||||
|
||||
|
||||
# Middlewares
|
||||
origins = ["*"]
|
||||
app.add_middleware(
|
||||
|
|
|
|||
|
|
@ -1,20 +1,16 @@
|
|||
from typing import Annotated, List, Optional
|
||||
import anthropic
|
||||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.get_env import get_anthropic_api_key_env
|
||||
from utils.available_models import list_available_anthropic_models
|
||||
|
||||
ANTHROPIC_ROUTER = APIRouter(prefix="/anthropic", tags=["Anthropic"])
|
||||
|
||||
|
||||
@ANTHROPIC_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
api_key: Annotated[Optional[str], Body(embed=True)] = None,
|
||||
api_key: Annotated[str, Body(embed=True)],
|
||||
):
|
||||
anthropic_api_key = api_key or get_anthropic_api_key_env()
|
||||
if not anthropic_api_key:
|
||||
raise HTTPException(status_code=400, detail="Anthropic API key is required")
|
||||
|
||||
client = anthropic.Anthropic(api_key=anthropic_api_key)
|
||||
models = client.models.list(limit=20)
|
||||
return [model.id for model in models]
|
||||
try:
|
||||
return await list_available_anthropic_models(api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
from typing import Annotated, List, Optional
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.custom_llm_provider import list_available_custom_models
|
||||
|
||||
CUSTOM_LLM_ROUTER = APIRouter(prefix="/custom_llm", tags=["Custom LLM"])
|
||||
|
||||
|
||||
@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
url: Annotated[Optional[str], Body()] = None,
|
||||
api_key: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
try:
|
||||
return await list_available_custom_models(url, api_key)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
14
servers/fastapi/api/v1/ppt/endpoints/google.py
Normal file
14
servers/fastapi/api/v1/ppt/endpoints/google.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.available_models import list_available_google_models
|
||||
|
||||
GOOGLE_ROUTER = APIRouter(prefix="/google", tags=["Google"])
|
||||
|
||||
|
||||
@GOOGLE_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(api_key: Annotated[str, Body(embed=True)]):
|
||||
try:
|
||||
return await list_available_google_models(api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
17
servers/fastapi/api/v1/ppt/endpoints/openai.py
Normal file
17
servers/fastapi/api/v1/ppt/endpoints/openai.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.available_models import list_available_openai_compatible_models
|
||||
|
||||
OPENAI_ROUTER = APIRouter(prefix="/openai", tags=["OpenAI"])
|
||||
|
||||
|
||||
@OPENAI_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
url: Annotated[str, Body()],
|
||||
api_key: Annotated[str, Body()],
|
||||
):
|
||||
try:
|
||||
return await list_available_openai_compatible_models(url, api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -28,7 +28,7 @@ from utils.export_utils import export_presentation
|
|||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sse_response import SSECompleteResponse, SSEResponse
|
||||
from services import SCHEMA_TO_MODEL_SERVICE, TEMP_FILE_SERVICE
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from services.database import get_async_session
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from models.sql.presentation import PresentationModel
|
||||
|
|
@ -43,7 +43,6 @@ from utils.llm_calls.generate_slide_content import (
|
|||
)
|
||||
from utils.process_slides import process_slide_and_fetch_assets
|
||||
from utils.randomizers import get_random_uuid
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
from utils.validators import validate_files
|
||||
|
||||
PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])
|
||||
|
|
@ -220,20 +219,8 @@ async def stream_presentation(
|
|||
for i, slide_layout_index in enumerate(structure.slides):
|
||||
slide_layout = layout.slides[slide_layout_index]
|
||||
|
||||
# Generate Pydantic model from slide layout schema
|
||||
schema_model_id = f"{layout.name}/{slide_layout.id}"
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["image_url_", "icon_url_"]
|
||||
)
|
||||
schema_model_path = (
|
||||
await SCHEMA_TO_MODEL_SERVICE.get_pydantic_model_path_from_schema(
|
||||
schema_model_id, response_schema
|
||||
)
|
||||
)
|
||||
module = importlib.import_module(schema_model_path)
|
||||
response_model = module.GeneratedModel
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
response_model, outline.slides[i], presentation.language
|
||||
slide_layout, outline.slides[i], presentation.language
|
||||
)
|
||||
|
||||
slide = SlideModel(
|
||||
|
|
@ -252,9 +239,6 @@ async def stream_presentation(
|
|||
)
|
||||
)
|
||||
|
||||
# Give control to the event loop
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
|
||||
|
|
@ -491,7 +475,6 @@ async def from_template(
|
|||
new_slide_data = list(filter(lambda x: x.index == each_slide.index, data.data))
|
||||
if new_slide_data:
|
||||
updated_content = deep_update(each_slide.content, new_slide_data[0].content)
|
||||
print(f"Updated content for slide {each_slide.index}: {updated_content}")
|
||||
new_slides.append(
|
||||
each_slide.get_new_slide(new_presentation.id, updated_content)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services import SCHEMA_TO_MODEL_SERVICE
|
||||
from services.database import get_async_session
|
||||
from services.icon_finder_service import IconFinderService
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
|
|
@ -35,25 +34,12 @@ async def edit_slide(
|
|||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
presentation_layout = presentation.get_layout()
|
||||
|
||||
slide_layout = await get_slide_layout_from_prompt(
|
||||
prompt, presentation_layout, slide
|
||||
)
|
||||
|
||||
# Generate Pydantic model from slide layout schema
|
||||
schema_model_id = f"{presentation_layout.name}/{slide_layout.id}"
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["image_url_", "icon_url_"]
|
||||
)
|
||||
schema_model_path = (
|
||||
await SCHEMA_TO_MODEL_SERVICE.get_pydantic_model_path_from_schema(
|
||||
schema_model_id, response_schema
|
||||
)
|
||||
)
|
||||
module = importlib.import_module(schema_model_path)
|
||||
response_model = module.GeneratedModel
|
||||
edited_slide_content = await get_edited_slide_content(
|
||||
prompt, slide, presentation.language, response_model
|
||||
prompt, slide, presentation.language, slide_layout
|
||||
)
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from api.v1.ppt.endpoints.anthropic import ANTHROPIC_ROUTER
|
||||
from api.v1.ppt.endpoints.custom_llm import CUSTOM_LLM_ROUTER
|
||||
from api.v1.ppt.endpoints.google import GOOGLE_ROUTER
|
||||
from api.v1.ppt.endpoints.openai import OPENAI_ROUTER
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
from api.v1.ppt.endpoints.icons import ICONS_ROUTER
|
||||
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
||||
|
|
@ -20,5 +21,6 @@ API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CUSTOM_LLM_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER)
|
||||
|
|
|
|||
Binary file not shown.
6
servers/fastapi/constants/llm.py
Normal file
6
servers/fastapi/constants/llm.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
OPENAI_URL = "https://api.openai.com/v1"
|
||||
|
||||
# Default models
|
||||
DEFAULT_OPENAI_MODEL = "gpt-4.1"
|
||||
DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash"
|
||||
DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"
|
||||
|
|
@ -7,7 +7,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3:70b": OllamaModelMetadata(
|
||||
|
|
@ -15,7 +14,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="40GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:8b": OllamaModelMetadata(
|
||||
|
|
@ -23,7 +21,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.1:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:70b": OllamaModelMetadata(
|
||||
|
|
@ -31,7 +28,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.1:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:405b": OllamaModelMetadata(
|
||||
|
|
@ -39,7 +35,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.1:405b",
|
||||
description="✅ Graphs supported.",
|
||||
size="243GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:1b": OllamaModelMetadata(
|
||||
|
|
@ -47,7 +42,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.2:1b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:3b": OllamaModelMetadata(
|
||||
|
|
@ -55,7 +49,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.2:3b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.3:70b": OllamaModelMetadata(
|
||||
|
|
@ -63,7 +56,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.3:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:16x17b": OllamaModelMetadata(
|
||||
|
|
@ -71,7 +63,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama4:16x17b",
|
||||
description="✅ Graphs supported.",
|
||||
size="67GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:128x17b": OllamaModelMetadata(
|
||||
|
|
@ -79,7 +70,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama4:128x17b",
|
||||
description="✅ Graphs supported.",
|
||||
size="245GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
}
|
||||
|
|
@ -90,7 +80,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:1b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="815MB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:4b": OllamaModelMetadata(
|
||||
|
|
@ -98,7 +87,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:4b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="3.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:12b": OllamaModelMetadata(
|
||||
|
|
@ -106,7 +94,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:12b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="8.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:27b": OllamaModelMetadata(
|
||||
|
|
@ -114,7 +101,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:27b",
|
||||
description="✅ Graphs supported.",
|
||||
size="17GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
}
|
||||
|
|
@ -125,7 +111,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:1.5b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:7b": OllamaModelMetadata(
|
||||
|
|
@ -133,7 +118,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:7b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:8b": OllamaModelMetadata(
|
||||
|
|
@ -141,7 +125,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:14b": OllamaModelMetadata(
|
||||
|
|
@ -149,7 +132,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:14b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:32b": OllamaModelMetadata(
|
||||
|
|
@ -157,7 +139,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:32b",
|
||||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:70b": OllamaModelMetadata(
|
||||
|
|
@ -165,7 +146,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:671b": OllamaModelMetadata(
|
||||
|
|
@ -173,7 +153,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:671b",
|
||||
description="✅ Graphs supported.",
|
||||
size="404GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
}
|
||||
|
|
@ -184,7 +163,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:0.6b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="523MB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:1.7b": OllamaModelMetadata(
|
||||
|
|
@ -192,7 +170,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:1.7b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.4GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:4b": OllamaModelMetadata(
|
||||
|
|
@ -200,7 +177,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:4b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="2.6GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:8b": OllamaModelMetadata(
|
||||
|
|
@ -208,7 +184,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:14b": OllamaModelMetadata(
|
||||
|
|
@ -216,7 +191,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:14b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="9.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:30b": OllamaModelMetadata(
|
||||
|
|
@ -224,7 +198,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:30b",
|
||||
description="✅ Graphs supported.",
|
||||
size="19GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:32b": OllamaModelMetadata(
|
||||
|
|
@ -232,7 +205,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:32b",
|
||||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:235b": OllamaModelMetadata(
|
||||
|
|
@ -240,7 +212,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:235b",
|
||||
description="✅ Graphs supported.",
|
||||
size="142GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ class ContactInfoModel(BaseModel):
|
|||
|
||||
|
||||
class ImageModel(BaseModel):
|
||||
image_url_: str = Field(description="Image URL")
|
||||
image_prompt_: str = Field(description="Image prompt")
|
||||
__image_url__: str = Field(description="Image URL")
|
||||
__image_prompt__: str = Field(description="Image prompt")
|
||||
|
||||
|
||||
# First Slide Layout
|
||||
|
|
|
|||
|
|
@ -7,4 +7,3 @@ class OllamaModelMetadata(BaseModel):
|
|||
description: str
|
||||
icon: str
|
||||
size: str
|
||||
supports_graph: bool
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
|
@ -16,6 +17,14 @@ class PresentationLayoutModel(BaseModel):
|
|||
ordered: bool = Field(default=False)
|
||||
slides: List[SlideLayoutModel]
|
||||
|
||||
def get_slide_layout_index(self, slide_layout_id: str) -> int:
|
||||
for index, slide in enumerate(self.slides):
|
||||
if slide.id == slide_layout_id:
|
||||
return index
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Slide layout {slide_layout_id} not found"
|
||||
)
|
||||
|
||||
def to_presentation_structure(self):
|
||||
return PresentationStructureModel(
|
||||
slides=[index for index in range(len(self.slides))]
|
||||
|
|
|
|||
|
|
@ -4,16 +4,32 @@ from pydantic import BaseModel
|
|||
|
||||
class UserConfig(BaseModel):
|
||||
LLM: Optional[str] = None
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
OPENAI_MODEL: Optional[str] = None
|
||||
|
||||
# Google
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
GOOGLE_MODEL: Optional[str] = None
|
||||
|
||||
# Anthropic
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
ANTHROPIC_MODEL: Optional[str] = None
|
||||
|
||||
# Ollama
|
||||
OLLAMA_URL: Optional[str] = None
|
||||
OLLAMA_MODEL: Optional[str] = None
|
||||
|
||||
# Custom LLM
|
||||
CUSTOM_LLM_URL: Optional[str] = None
|
||||
CUSTOM_LLM_API_KEY: Optional[str] = None
|
||||
CUSTOM_MODEL: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
|
||||
# Image Provider
|
||||
IMAGE_PROVIDER: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
PIXABAY_API_KEY: Optional[str] = None
|
||||
|
||||
# Reasoning
|
||||
EXTENDED_REASONING: Optional[bool] = None
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ chromadb==1.0.15
|
|||
click==8.2.1
|
||||
coloredlogs==15.0.1
|
||||
cryptography==45.0.5
|
||||
datamodel-code-generator==0.32.0
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
durationpy==0.10
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from services.redis_service import RedisService
|
||||
from services.schema_to_model_service import SchemaToModelService
|
||||
from services.temp_file_service import TempFileService
|
||||
|
||||
|
||||
TEMP_FILE_SERVICE = TempFileService()
|
||||
REDIS_SERVICE = RedisService()
|
||||
SCHEMA_TO_MODEL_SERVICE = SchemaToModelService(TEMP_FILE_SERVICE)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ import os
|
|||
import aiohttp
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from openai import AsyncOpenAI
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.sql.image_asset import ImageAsset
|
||||
from utils.download_helpers import download_file
|
||||
from utils.get_env import get_pexels_api_key_env
|
||||
from utils.get_env import get_pixabay_api_key_env
|
||||
from utils.llm_provider import get_llm_client
|
||||
from utils.image_provider import (
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
|
|
@ -80,7 +80,7 @@ class ImageGenerationService:
|
|||
return "/static/images/placeholder.jpg"
|
||||
|
||||
async def generate_image_openai(self, prompt: str, output_directory: str) -> str:
|
||||
client = get_llm_client()
|
||||
client = AsyncOpenAI()
|
||||
result = await client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=prompt,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
|
@ -8,7 +8,6 @@ from google.genai.types import GenerateContentConfig
|
|||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
|
||||
from pydantic import BaseModel
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import LLMMessage
|
||||
from utils.async_iterator import iterator_to_async
|
||||
|
|
@ -21,6 +20,7 @@ from utils.get_env import (
|
|||
get_openai_api_key_env,
|
||||
)
|
||||
from utils.llm_provider import get_llm_provider
|
||||
from utils.schema_utils import ensure_strict_json_schema
|
||||
|
||||
|
||||
class LLMClient:
|
||||
|
|
@ -173,43 +173,45 @@ class LLMClient:
|
|||
|
||||
# ? Generate Structured Content
|
||||
async def _generate_openai_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
is_response_format_dict = isinstance(response_format, dict)
|
||||
if is_response_format_dict:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"schema": response_format,
|
||||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=self.max_tokens,
|
||||
response_schema = response_format
|
||||
if strict:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
response_schema,
|
||||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return json.loads(content)
|
||||
return None
|
||||
else:
|
||||
response = await client.chat.completions.parse(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format=response_format,
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.parsed
|
||||
if content:
|
||||
return content.model_dump(mode="json")
|
||||
return None
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return json.loads(content)
|
||||
return None
|
||||
|
||||
async def _generate_google_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
response = await asyncio.to_thread(
|
||||
|
|
@ -219,7 +221,7 @@ class LLMClient:
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_schema=response_format,
|
||||
response_json_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
)
|
||||
|
|
@ -230,10 +232,12 @@ class LLMClient:
|
|||
return content
|
||||
|
||||
async def _generate_anthropic_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
is_response_format_dict = isinstance(response_format, dict)
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
|
|
@ -246,11 +250,7 @@ class LLMClient:
|
|||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"input_schema": (
|
||||
response_format
|
||||
if is_response_format_dict
|
||||
else response_format.model_json_schema()
|
||||
),
|
||||
"input_schema": response_format,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
|
@ -262,23 +262,39 @@ class LLMClient:
|
|||
return content
|
||||
|
||||
async def _generate_ollama_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return await self._generate_openai_structured(model, messages, response_format)
|
||||
return await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
|
||||
async def _generate_custom_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return await self._generate_openai_structured(model, messages, response_format)
|
||||
return await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
|
||||
async def generate_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
) -> dict:
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
content = await self._generate_openai_structured(
|
||||
model, messages, response_format
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google_structured(
|
||||
|
|
@ -290,11 +306,11 @@ class LLMClient:
|
|||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama_structured(
|
||||
model, messages, response_format
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
content = await self._generate_custom_structured(
|
||||
model, messages, response_format
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -366,10 +382,20 @@ class LLMClient:
|
|||
|
||||
# ? Stream Structured Content
|
||||
async def _stream_openai_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
is_response_format_dict = isinstance(response_format, dict)
|
||||
response_schema = response_format
|
||||
if strict:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
response_schema,
|
||||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
|
|
@ -379,11 +405,10 @@ class LLMClient:
|
|||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ResponseSchema",
|
||||
"schema": response_format,
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
},
|
||||
}
|
||||
if is_response_format_dict
|
||||
else response_format
|
||||
),
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
|
|
@ -391,7 +416,10 @@ class LLMClient:
|
|||
yield event.delta
|
||||
|
||||
async def _stream_google_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
async for event in iterator_to_async(client.models.generate_content_stream)(
|
||||
|
|
@ -400,7 +428,7 @@ class LLMClient:
|
|||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_schema=response_format,
|
||||
response_json_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
):
|
||||
|
|
@ -408,10 +436,12 @@ class LLMClient:
|
|||
yield event.text
|
||||
|
||||
async def _stream_anthropic_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
is_response_format_dict = isinstance(response_format, dict)
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
|
|
@ -424,11 +454,7 @@ class LLMClient:
|
|||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"input_schema": (
|
||||
response_format
|
||||
if is_response_format_dict
|
||||
else response_format.model_json_schema()
|
||||
),
|
||||
"input_schema": response_format,
|
||||
}
|
||||
],
|
||||
) as stream:
|
||||
|
|
@ -438,21 +464,35 @@ class LLMClient:
|
|||
yield event.partial_json
|
||||
|
||||
def _stream_ollama_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return self._stream_openai_structured(model, messages, response_format)
|
||||
return self._stream_openai_structured(model, messages, response_format, strict)
|
||||
|
||||
def _stream_custom_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return self._stream_openai_structured(model, messages, response_format)
|
||||
return self._stream_openai_structured(model, messages, response_format, strict)
|
||||
|
||||
def stream_structured(
|
||||
self, model: str, messages: List[LLMMessage], response_format: BaseModel | dict
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self._stream_openai_structured(model, messages, response_format)
|
||||
return self._stream_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google_structured(model, messages, response_format)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
|
|
@ -460,6 +500,10 @@ class LLMClient:
|
|||
model, messages, response_format
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._stream_ollama_structured(model, messages, response_format)
|
||||
return self._stream_ollama_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._stream_custom_structured(model, messages, response_format)
|
||||
return self._stream_custom_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,78 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from fastapi import HTTPException
|
||||
from datamodel_code_generator import generate, InputFileType, DataModelType
|
||||
|
||||
from services.temp_file_service import TempFileService
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class SchemaToModelService:
|
||||
def __init__(self, temp_file_service: TempFileService):
|
||||
self.temp_file_service = temp_file_service
|
||||
self.temp_dir = self.temp_file_service.create_temp_dir()
|
||||
|
||||
self.generated_models_dir = "generated_models"
|
||||
if os.path.exists(self.generated_models_dir):
|
||||
for file in os.listdir(self.generated_models_dir):
|
||||
if file.endswith(".py"):
|
||||
os.remove(os.path.join(self.generated_models_dir, file))
|
||||
os.makedirs(self.generated_models_dir, exist_ok=True)
|
||||
|
||||
self._records: Dict[str, str] = {}
|
||||
self._fetch_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def convert_path_to_module_path(self, path: str):
|
||||
return path.replace("/", ".").replace("\\", ".").replace(".py", "")
|
||||
|
||||
async def get_pydantic_model_path_from_schema(
|
||||
self, identifier: str, schema: dict
|
||||
) -> str:
|
||||
if identifier in self._fetch_locks:
|
||||
async with self._fetch_locks[identifier]:
|
||||
return self._records[identifier]
|
||||
else:
|
||||
async_lock = asyncio.Lock()
|
||||
await async_lock.acquire()
|
||||
self._fetch_locks[identifier] = async_lock
|
||||
model_path = await self.generate_pydantic_model_from_schema_async(schema)
|
||||
model_path = self.convert_path_to_module_path(model_path)
|
||||
self._records[identifier] = model_path
|
||||
async_lock.release()
|
||||
return model_path
|
||||
|
||||
async def generate_pydantic_model_from_schema_async(self, schema: dict):
|
||||
return await asyncio.to_thread(self.generate_pydantic_model_from_schema, schema)
|
||||
|
||||
def generate_pydantic_model_from_schema(self, schema: dict):
|
||||
generated_model_path = os.path.join(
|
||||
self.generated_models_dir, get_random_uuid() + ".py"
|
||||
)
|
||||
try:
|
||||
schema_path = self.temp_file_service.create_temp_file_path(
|
||||
get_random_uuid() + ".json", self.temp_dir
|
||||
)
|
||||
with open(schema_path, "w") as f:
|
||||
json.dump(schema, f)
|
||||
|
||||
generate(
|
||||
input_=Path(schema_path),
|
||||
input_file_type=InputFileType.JsonSchema,
|
||||
output=Path(generated_model_path),
|
||||
output_model_type=DataModelType.PydanticV2BaseModel,
|
||||
class_name="GeneratedModel",
|
||||
use_annotated=False,
|
||||
field_constraints=True,
|
||||
extra_fields="ignore",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate Pydantic model from schema"
|
||||
)
|
||||
finally:
|
||||
self.temp_file_service.cleanup_temp_file(schema_path)
|
||||
|
||||
return generated_model_path
|
||||
21
servers/fastapi/utils/available_models.py
Normal file
21
servers/fastapi/utils/available_models.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from anthropic import AsyncAnthropic
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
||||
|
||||
async def list_available_openai_compatible_models(url: str, api_key: str) -> list[str]:
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=url)
|
||||
models = (await client.models.list()).data
|
||||
if models:
|
||||
return list(map(lambda x: x.id, models))
|
||||
return []
|
||||
|
||||
|
||||
async def list_available_anthropic_models(api_key: str) -> list[str]:
|
||||
client = AsyncAnthropic(api_key=api_key)
|
||||
return list(map(lambda x: x.id, (await client.models.list(limit=50)).data))
|
||||
|
||||
|
||||
async def list_available_google_models(api_key: str) -> list[str]:
|
||||
client = genai.Client(api_key=api_key)
|
||||
return list(map(lambda x: x.name, client.models.list(config={"page_size": 50})))
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
from typing import Optional
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from utils.llm_provider import get_llm_client
|
||||
|
||||
|
||||
async def list_available_custom_models(
|
||||
url: Optional[str] = None, api_key: Optional[str] = None
|
||||
) -> list[str]:
|
||||
if not url:
|
||||
client = get_llm_client()
|
||||
else:
|
||||
client = AsyncOpenAI(api_key=api_key or "null", base_url=url)
|
||||
models = []
|
||||
async for model in client.models.list():
|
||||
models.append(model.id)
|
||||
return models
|
||||
|
|
@ -78,3 +78,12 @@ def deep_update(original: dict, updates: dict) -> dict:
|
|||
if not isinstance(value, (dict, list)):
|
||||
original[key] = value
|
||||
return original
|
||||
|
||||
|
||||
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
|
||||
i = 0
|
||||
for _ in obj.keys():
|
||||
i += 1
|
||||
if i > n:
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -45,10 +45,18 @@ def get_openai_api_key_env():
|
|||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def get_openai_model_env():
|
||||
return os.getenv("OPENAI_MODEL")
|
||||
|
||||
|
||||
def get_google_api_key_env():
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
|
||||
|
||||
def get_google_model_env():
|
||||
return os.getenv("GOOGLE_MODEL")
|
||||
|
||||
|
||||
def get_custom_llm_api_key_env():
|
||||
return os.getenv("CUSTOM_LLM_API_KEY")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from pydantic import BaseModel
|
||||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.llm_provider import get_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
system_prompt = """
|
||||
|
|
@ -57,14 +56,19 @@ async def get_edited_slide_content(
|
|||
prompt: str,
|
||||
slide: SlideModel,
|
||||
language: str,
|
||||
response_model: BaseModel,
|
||||
slide_layout: SlideLayoutModel,
|
||||
):
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
client = LLMClient()
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(prompt, slide.content, language),
|
||||
response_format=response_model,
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,14 +1,7 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from utils.llm_provider import (
|
||||
get_anthropic_llm_client,
|
||||
get_google_llm_client,
|
||||
get_large_model,
|
||||
is_anthropic_selected,
|
||||
is_google_selected,
|
||||
get_llm_client,
|
||||
)
|
||||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
system_prompt = """
|
||||
You are an expert HTML slide editor. Your task is to modify slide HTML content based on user prompts while maintaining proper structure, styling, and functionality.
|
||||
|
|
@ -54,48 +47,17 @@ def get_user_prompt(prompt: str, html: str):
|
|||
|
||||
|
||||
async def get_edited_slide_html(prompt: str, html: str):
|
||||
model = get_large_model()
|
||||
llm_response = None
|
||||
model = get_model()
|
||||
|
||||
if is_anthropic_selected():
|
||||
client = get_anthropic_llm_client()
|
||||
response = await client.messages.create(
|
||||
model=model,
|
||||
messages=[get_user_prompt(prompt, html)],
|
||||
)
|
||||
for each in response.content:
|
||||
if each.type == "text":
|
||||
llm_response = each.text
|
||||
break
|
||||
|
||||
elif is_google_selected():
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(prompt, html)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="text/plain",
|
||||
),
|
||||
)
|
||||
llm_response = response.text
|
||||
|
||||
else:
|
||||
client = get_llm_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": get_user_prompt(prompt, html)},
|
||||
],
|
||||
)
|
||||
llm_response = response.choices[0].message.content
|
||||
|
||||
if not llm_response:
|
||||
return html
|
||||
|
||||
return extract_html_from_response(llm_response) or html
|
||||
client = LLMClient()
|
||||
response = await client.generate(
|
||||
model=model,
|
||||
messages=[
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
|
||||
],
|
||||
)
|
||||
return extract_html_from_response(response) or html
|
||||
|
||||
|
||||
def extract_html_from_response(response_text: str) -> Optional[str]:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import List
|
|||
|
||||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_nano_model
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
||||
sysmte_prompt = """
|
||||
|
|
@ -25,7 +25,7 @@ Maintain as much information as possible.
|
|||
|
||||
async def generate_document_summary(documents: List[str]):
|
||||
client = LLMClient()
|
||||
model = get_nano_model()
|
||||
model = get_model()
|
||||
|
||||
coroutines = []
|
||||
for document in documents:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
system_prompt = """
|
||||
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
|
||||
|
|
@ -75,7 +75,7 @@ async def generate_ppt_outline(
|
|||
language: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
):
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
response_model = get_presentation_outline_model_with_n_slides(n_slides)
|
||||
|
||||
client = LLMClient()
|
||||
|
|
@ -83,6 +83,7 @@ async def generate_ppt_outline(
|
|||
async for chunk in client.stream_structured(
|
||||
model,
|
||||
get_messages(prompt, n_slides, language, content),
|
||||
response_model,
|
||||
response_model.model_json_schema(),
|
||||
strict=True,
|
||||
):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from models.llm_message import LLMMessage
|
|||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.llm_provider import get_model
|
||||
from utils.get_dynamic_models import get_presentation_structure_model_with_n_slides
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ async def generate_presentation_structure(
|
|||
) -> PresentationStructureModel:
|
||||
|
||||
client = LLMClient()
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
|
|
@ -74,6 +74,7 @@ async def generate_presentation_structure(
|
|||
len(presentation_outline.slides),
|
||||
presentation_outline.to_string(),
|
||||
),
|
||||
response_format=response_model,
|
||||
response_format=response_model.model_json_schema(),
|
||||
strict=True,
|
||||
)
|
||||
return PresentationStructureModel(**response)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from pydantic import BaseModel
|
||||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.llm_provider import get_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
system_prompt = """
|
||||
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
|
||||
|
|
@ -14,8 +15,8 @@ system_prompt = """
|
|||
# Notes
|
||||
- Slide body should not use words like "This slide", "This presentation".
|
||||
- Rephrase the slide body to make it flow naturally.
|
||||
- Provide prompt to generate image on "image_prompt_" property.
|
||||
- Provide query to search icon on "icon_query_" property.
|
||||
- Provide prompt to generate image on "__image_prompt__" property.
|
||||
- Provide query to search icon on "__icon_query__" property.
|
||||
- Do not use markdown formatting in slide body.
|
||||
- Make sure to follow language guidelines.
|
||||
**Strictly follow the max and min character limit for every property in the slide.**
|
||||
|
|
@ -53,10 +54,14 @@ def get_messages(title: str, outline: str, language: str):
|
|||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
response_model: BaseModel, outline: SlideOutlineModel, language: str
|
||||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
|
||||
):
|
||||
client = LLMClient()
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
|
|
@ -65,6 +70,7 @@ async def get_slide_content_from_type_and_outline(
|
|||
outline.body,
|
||||
language,
|
||||
),
|
||||
response_format=response_model,
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
|||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_large_model
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
||||
def get_messages(
|
||||
|
|
@ -44,9 +44,9 @@ async def get_slide_layout_from_prompt(
|
|||
) -> SlideLayoutModel:
|
||||
|
||||
client = LLMClient()
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
|
||||
slide_layout_ids = list(map(lambda x: x.id, layout.slides))
|
||||
slide_layout_index = layout.get_slide_layout_index(slide.layout)
|
||||
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
|
|
@ -54,9 +54,10 @@ async def get_slide_layout_from_prompt(
|
|||
prompt,
|
||||
slide.content,
|
||||
layout,
|
||||
slide_layout_ids.index(slide.layout),
|
||||
slide_layout_index,
|
||||
),
|
||||
response_format=SlideLayoutIndex,
|
||||
response_format=SlideLayoutIndex.model_json_schema(),
|
||||
strict=True,
|
||||
)
|
||||
index = SlideLayoutIndex(**response).index
|
||||
return layout.slides[index]
|
||||
|
|
|
|||
|
|
@ -1,21 +1,18 @@
|
|||
import os
|
||||
import anthropic
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
||||
from constants.llm import (
|
||||
DEFAULT_ANTHROPIC_MODEL,
|
||||
DEFAULT_GOOGLE_MODEL,
|
||||
DEFAULT_OPENAI_MODEL,
|
||||
)
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_anthropic_model_env,
|
||||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_custom_model_env,
|
||||
get_google_api_key_env,
|
||||
get_google_model_env,
|
||||
get_llm_provider_env,
|
||||
get_ollama_model_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -25,14 +22,10 @@ def get_llm_provider():
|
|||
except:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, ollama, custom",
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
)
|
||||
|
||||
|
||||
def get_ollama_url():
|
||||
return get_ollama_url_env() or "http://localhost:11434"
|
||||
|
||||
|
||||
def is_openai_selected():
|
||||
return get_llm_provider() == LLMProvider.OPENAI
|
||||
|
||||
|
|
@ -53,100 +46,20 @@ def is_custom_llm_selected():
|
|||
return get_llm_provider() == LLMProvider.CUSTOM
|
||||
|
||||
|
||||
def get_model_base_url():
|
||||
selected_llm = get_llm_provider()
|
||||
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
return "https://api.anthropic.com/v1"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return os.path.join(get_ollama_url(), "v1")
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_llm_url_env()
|
||||
else:
|
||||
raise HTTPException(f"LLM provider {selected_llm} is not supported")
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
def get_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return get_openai_api_key_env()
|
||||
return get_openai_model_env() or DEFAULT_OPENAI_MODEL
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return get_google_api_key_env()
|
||||
return get_google_model_env() or DEFAULT_GOOGLE_MODEL
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
return get_anthropic_api_key_env()
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return "ollama"
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_llm_api_key_env() or "none"
|
||||
else:
|
||||
raise HTTPException(f"LLM provider {selected_llm} is not supported")
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
client = AsyncOpenAI(
|
||||
base_url=get_model_base_url(),
|
||||
api_key=get_llm_api_key(),
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_google_llm_client():
|
||||
client = genai.Client(api_key=get_google_api_key_env())
|
||||
return client
|
||||
|
||||
|
||||
def get_anthropic_llm_client():
|
||||
client = anthropic.AsyncAnthropic(api_key=get_anthropic_api_key_env())
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
return get_anthropic_model_env()
|
||||
return get_anthropic_model_env() or DEFAULT_ANTHROPIC_MODEL
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
return get_anthropic_model_env()
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
return get_anthropic_model_env()
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
import os
|
||||
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from constants.llm import OPENAI_URL
|
||||
from enums.image_provider import ImageProvider
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.custom_llm_provider import list_available_custom_models
|
||||
from utils.available_models import (
|
||||
list_available_anthropic_models,
|
||||
list_available_google_models,
|
||||
list_available_openai_compatible_models,
|
||||
)
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_anthropic_model_env,
|
||||
get_can_change_keys_env,
|
||||
get_google_model_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
get_pixabay_api_key_env,
|
||||
get_pexels_api_key_env,
|
||||
)
|
||||
|
|
@ -20,13 +28,7 @@ from utils.llm_provider import (
|
|||
is_ollama_selected,
|
||||
)
|
||||
from utils.ollama import pull_ollama_model
|
||||
from utils.image_provider import (
|
||||
get_selected_image_provider,
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
is_gemini_flash_selected,
|
||||
is_dalle3_selected,
|
||||
)
|
||||
from utils.image_provider import get_selected_image_provider
|
||||
|
||||
|
||||
async def check_llm_and_image_provider_api_or_model_availability():
|
||||
|
|
@ -36,11 +38,41 @@ async def check_llm_and_image_provider_api_or_model_availability():
|
|||
openai_api_key = get_openai_api_key_env()
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
openai_model = get_openai_model_env()
|
||||
if openai_model:
|
||||
available_models = await list_available_openai_compatible_models(
|
||||
OPENAI_URL, openai_api_key
|
||||
)
|
||||
if openai_model not in available_models:
|
||||
print("-" * 50)
|
||||
print("Available models: ", available_models)
|
||||
raise Exception(f"Model {openai_model} is not available")
|
||||
|
||||
elif get_llm_provider() == LLMProvider.GOOGLE:
|
||||
google_api_key = get_google_api_key_env()
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
google_model = get_google_model_env()
|
||||
if google_model:
|
||||
available_models = await list_available_google_models(google_api_key)
|
||||
if google_model not in available_models:
|
||||
print("-" * 50)
|
||||
print("Available models: ", available_models)
|
||||
raise Exception(f"Model {google_model} is not available")
|
||||
|
||||
elif get_llm_provider() == LLMProvider.ANTHROPIC:
|
||||
anthropic_api_key = get_anthropic_api_key_env()
|
||||
if not anthropic_api_key:
|
||||
raise Exception("ANTHROPIC_API_KEY must be provided")
|
||||
anthropic_model = get_anthropic_model_env()
|
||||
if anthropic_model:
|
||||
available_models = await list_available_anthropic_models(
|
||||
anthropic_api_key
|
||||
)
|
||||
if anthropic_model not in available_models:
|
||||
print("-" * 50)
|
||||
print("Available models: ", available_models)
|
||||
raise Exception(f"Model {anthropic_model} is not available")
|
||||
|
||||
elif is_ollama_selected():
|
||||
ollama_model = get_ollama_model_env()
|
||||
|
|
@ -67,14 +99,12 @@ async def check_llm_and_image_provider_api_or_model_availability():
|
|||
raise Exception("CUSTOM_LLM_URL must be provided")
|
||||
if not custom_llm_api_key:
|
||||
raise Exception("CUSTOM_LLM_API_KEY must be provided")
|
||||
print("-" * 50)
|
||||
print("Selecting model: ", custom_model)
|
||||
models = await list_available_custom_models(
|
||||
available_models = await list_available_openai_compatible_models(
|
||||
custom_llm_url, custom_llm_api_key
|
||||
)
|
||||
print("Available models: ", models)
|
||||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
print("Available models: ", available_models)
|
||||
if custom_model not in available_models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
|
||||
# Check for Image Provider and API keys
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from http.client import HTTPException
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from utils.get_env import get_ollama_url_env
|
||||
|
|
|
|||
|
|
@ -17,23 +17,23 @@ async def process_slide_and_fetch_assets(
|
|||
|
||||
async_tasks = []
|
||||
|
||||
image_paths = get_dict_paths_with_key(slide.content, "image_prompt_")
|
||||
icon_paths = get_dict_paths_with_key(slide.content, "icon_query_")
|
||||
image_paths = get_dict_paths_with_key(slide.content, "__image_prompt__")
|
||||
icon_paths = get_dict_paths_with_key(slide.content, "__icon_query__")
|
||||
|
||||
for image_path in image_paths:
|
||||
image_prompt_parent = get_dict_at_path(slide.content, image_path)
|
||||
__image_prompt__parent = get_dict_at_path(slide.content, image_path)
|
||||
async_tasks.append(
|
||||
image_generation_service.generate_image(
|
||||
ImagePrompt(
|
||||
prompt=image_prompt_parent["image_prompt_"],
|
||||
prompt=__image_prompt__parent["__image_prompt__"],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for icon_path in icon_paths:
|
||||
icon_query_parent = get_dict_at_path(slide.content, icon_path)
|
||||
__icon_query__parent = get_dict_at_path(slide.content, icon_path)
|
||||
async_tasks.append(
|
||||
icon_finder_service.search_icons(icon_query_parent["icon_query_"])
|
||||
icon_finder_service.search_icons(__icon_query__parent["__icon_query__"])
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*async_tasks)
|
||||
|
|
@ -45,14 +45,14 @@ async def process_slide_and_fetch_assets(
|
|||
result = results.pop()
|
||||
if isinstance(result, ImageAsset):
|
||||
return_assets.append(result)
|
||||
image_dict["image_url_"] = result.path
|
||||
image_dict["__image_url__"] = result.path
|
||||
else:
|
||||
image_dict["image_url_"] = result
|
||||
image_dict["__image_url__"] = result
|
||||
set_dict_at_path(slide.content, image_path, image_dict)
|
||||
|
||||
for icon_path in icon_paths:
|
||||
icon_dict = get_dict_at_path(slide.content, icon_path)
|
||||
icon_dict["icon_url_"] = results.pop()[0]
|
||||
icon_dict["__icon_url__"] = results.pop()[0]
|
||||
set_dict_at_path(slide.content, icon_path, icon_dict)
|
||||
|
||||
return return_assets
|
||||
|
|
@ -66,34 +66,34 @@ async def process_old_and_new_slides_and_fetch_assets(
|
|||
) -> List[ImageAsset]:
|
||||
# Finds all old images
|
||||
old_image_dict_paths = get_dict_paths_with_key(
|
||||
old_slide_content, "image_prompt_"
|
||||
old_slide_content, "__image_prompt__"
|
||||
)
|
||||
old_image_dicts = [
|
||||
get_dict_at_path(old_slide_content, path) for path in old_image_dict_paths
|
||||
]
|
||||
old_image_prompts = [
|
||||
old_image_dict["image_prompt_"] for old_image_dict in old_image_dicts
|
||||
old_image_dict["__image_prompt__"] for old_image_dict in old_image_dicts
|
||||
]
|
||||
|
||||
# Finds all old icons
|
||||
old_icon_dict_paths = get_dict_paths_with_key(old_slide_content, "icon_query_")
|
||||
old_icon_dict_paths = get_dict_paths_with_key(old_slide_content, "__icon_query__")
|
||||
old_icon_dicts = [
|
||||
get_dict_at_path(old_slide_content, path) for path in old_icon_dict_paths
|
||||
]
|
||||
old_icon_queries = [
|
||||
old_icon_dict["icon_query_"] for old_icon_dict in old_icon_dicts
|
||||
old_icon_dict["__icon_query__"] for old_icon_dict in old_icon_dicts
|
||||
]
|
||||
|
||||
# Finds all new images
|
||||
new_image_dict_paths = get_dict_paths_with_key(
|
||||
new_slide_content, "image_prompt_"
|
||||
new_slide_content, "__image_prompt__"
|
||||
)
|
||||
new_image_dicts = [
|
||||
get_dict_at_path(new_slide_content, path) for path in new_image_dict_paths
|
||||
]
|
||||
|
||||
# Finds all new icons
|
||||
new_icon_dict_paths = get_dict_paths_with_key(new_slide_content, "icon_query_")
|
||||
new_icon_dict_paths = get_dict_paths_with_key(new_slide_content, "__icon_query__")
|
||||
new_icon_dicts = [
|
||||
get_dict_at_path(new_slide_content, path) for path in new_icon_dict_paths
|
||||
]
|
||||
|
|
@ -109,18 +109,18 @@ async def process_old_and_new_slides_and_fetch_assets(
|
|||
# Creates async tasks for fetching new images
|
||||
# Use old image url if prompt is same
|
||||
for new_image in new_image_dicts:
|
||||
if new_image["image_prompt_"] in old_image_prompts:
|
||||
if new_image["__image_prompt__"] in old_image_prompts:
|
||||
old_image_url = old_image_dicts[
|
||||
old_image_prompts.index(new_image["image_prompt_"])
|
||||
]["image_url_"]
|
||||
new_image["image_url_"] = old_image_url
|
||||
old_image_prompts.index(new_image["__image_prompt__"])
|
||||
]["__image_url__"]
|
||||
new_image["__image_url__"] = old_image_url
|
||||
new_images_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_image_fetch_tasks.append(
|
||||
image_generation_service.generate_image(
|
||||
ImagePrompt(
|
||||
prompt=new_image["image_prompt_"],
|
||||
prompt=new_image["__image_prompt__"],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -129,16 +129,16 @@ async def process_old_and_new_slides_and_fetch_assets(
|
|||
# Creates async tasks for fetching new icons
|
||||
# Use old icon url if query is same
|
||||
for new_icon in new_icon_dicts:
|
||||
if new_icon["icon_query_"] in old_icon_queries:
|
||||
if new_icon["__icon_query__"] in old_icon_queries:
|
||||
old_icon_url = old_icon_dicts[
|
||||
old_icon_queries.index(new_icon["icon_query_"])
|
||||
]["icon_url_"]
|
||||
new_icon["icon_url_"] = old_icon_url
|
||||
old_icon_queries.index(new_icon["__icon_query__"])
|
||||
]["__icon_url__"]
|
||||
new_icon["__icon_url__"] = old_icon_url
|
||||
new_icons_fetch_status.append(False)
|
||||
continue
|
||||
|
||||
async_icon_fetch_tasks.append(
|
||||
icon_finder_service.search_icons(new_icon["icon_query_"])
|
||||
icon_finder_service.search_icons(new_icon["__icon_query__"])
|
||||
)
|
||||
new_icons_fetch_status.append(True)
|
||||
|
||||
|
|
@ -157,11 +157,11 @@ async def process_old_and_new_slides_and_fetch_assets(
|
|||
image_url = fetched_image.path
|
||||
else:
|
||||
image_url = fetched_image
|
||||
new_image_dicts[i]["image_url_"] = image_url
|
||||
new_image_dicts[i]["__image_url__"] = image_url
|
||||
|
||||
for i, new_icon in enumerate(new_icons):
|
||||
if new_icons_fetch_status[i]:
|
||||
new_icon_dicts[i]["icon_url_"] = new_icons[i][0]
|
||||
new_icon_dicts[i]["__icon_url__"] = new_icons[i][0]
|
||||
|
||||
for i, new_image_dict in enumerate(new_image_dicts):
|
||||
set_dict_at_path(new_slide_content, new_image_dict_paths[i], new_image_dict)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,25 @@
|
|||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from utils.dict_utils import get_dict_paths_with_key, get_dict_at_path
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from utils.dict_utils import (
|
||||
get_dict_paths_with_key,
|
||||
get_dict_at_path,
|
||||
has_more_than_n_keys,
|
||||
)
|
||||
|
||||
def resolve_refs(schema, defs):
|
||||
if isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_path = schema["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_key = ref_path.replace("#/$defs/", "")
|
||||
return resolve_refs(defs[def_key], defs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported $ref path: {ref_path}")
|
||||
else:
|
||||
return {k: resolve_refs(v, defs) for k, v in schema.items()}
|
||||
elif isinstance(schema, list):
|
||||
return [resolve_refs(item, defs) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
|
||||
def flatten_schema(schema):
|
||||
schema = deepcopy(schema)
|
||||
defs = schema.pop("$defs", {})
|
||||
return resolve_refs(schema, defs)
|
||||
supported_string_formats = [
|
||||
"date-time",
|
||||
"time",
|
||||
"date",
|
||||
"duration",
|
||||
"email",
|
||||
"hostname",
|
||||
"ipv4",
|
||||
"ipv6",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
||||
def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
|
||||
|
|
@ -50,6 +45,138 @@ def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
|
|||
return schema
|
||||
|
||||
|
||||
# From OpenAI
|
||||
def ensure_strict_json_schema(
|
||||
json_schema: object,
|
||||
*,
|
||||
path: tuple[str, ...],
|
||||
root: dict[str, object],
|
||||
) -> dict[str, Any]:
|
||||
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
|
||||
that the API expects.
|
||||
"""
|
||||
if not isinstance(json_schema, dict):
|
||||
raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
|
||||
|
||||
defs = json_schema.get("$defs")
|
||||
if isinstance(defs, dict):
|
||||
for def_name, def_schema in defs.items():
|
||||
ensure_strict_json_schema(
|
||||
def_schema, path=(*path, "$defs", def_name), root=root
|
||||
)
|
||||
|
||||
definitions = json_schema.get("definitions")
|
||||
if isinstance(definitions, dict):
|
||||
for definition_name, definition_schema in definitions.items():
|
||||
ensure_strict_json_schema(
|
||||
definition_schema,
|
||||
path=(*path, "definitions", definition_name),
|
||||
root=root,
|
||||
)
|
||||
|
||||
typ = json_schema.get("type")
|
||||
if typ == "object" and "additionalProperties" not in json_schema:
|
||||
json_schema["additionalProperties"] = False
|
||||
|
||||
# object types
|
||||
# { 'type': 'object', 'properties': { 'a': {...} } }
|
||||
properties = json_schema.get("properties")
|
||||
if isinstance(properties, dict):
|
||||
json_schema["required"] = [prop for prop in properties.keys()]
|
||||
json_schema["properties"] = {
|
||||
key: ensure_strict_json_schema(
|
||||
prop_schema, path=(*path, "properties", key), root=root
|
||||
)
|
||||
for key, prop_schema in properties.items()
|
||||
}
|
||||
|
||||
# arrays
|
||||
# { 'type': 'array', 'items': {...} }
|
||||
items = json_schema.get("items")
|
||||
if isinstance(items, dict):
|
||||
json_schema["items"] = ensure_strict_json_schema(
|
||||
items, path=(*path, "items"), root=root
|
||||
)
|
||||
|
||||
# unions
|
||||
any_of = json_schema.get("anyOf")
|
||||
if isinstance(any_of, list):
|
||||
json_schema["anyOf"] = [
|
||||
ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
|
||||
for i, variant in enumerate(any_of)
|
||||
]
|
||||
|
||||
# intersections
|
||||
all_of = json_schema.get("allOf")
|
||||
if isinstance(all_of, list):
|
||||
if len(all_of) == 1:
|
||||
json_schema.update(
|
||||
ensure_strict_json_schema(
|
||||
all_of[0], path=(*path, "allOf", "0"), root=root
|
||||
)
|
||||
)
|
||||
json_schema.pop("allOf")
|
||||
else:
|
||||
json_schema["allOf"] = [
|
||||
ensure_strict_json_schema(
|
||||
entry, path=(*path, "allOf", str(i)), root=root
|
||||
)
|
||||
for i, entry in enumerate(all_of)
|
||||
]
|
||||
|
||||
# string
|
||||
if typ == "string":
|
||||
if "format" in json_schema:
|
||||
if json_schema["format"] not in supported_string_formats:
|
||||
del json_schema["format"]
|
||||
|
||||
# strip `None` defaults as there's no meaningful distinction here
|
||||
# the schema will still be `nullable` and the model will default
|
||||
# to using `None` anyway
|
||||
if json_schema.get("default", NOT_GIVEN) is None:
|
||||
json_schema.pop("default")
|
||||
|
||||
# we can't use `$ref`s if there are also other properties defined, e.g.
|
||||
# `{"$ref": "...", "description": "my description"}`
|
||||
#
|
||||
# so we unravel the ref
|
||||
# `{"type": "string", "description": "my description"}`
|
||||
ref = json_schema.get("$ref")
|
||||
if ref and has_more_than_n_keys(json_schema, 1):
|
||||
assert isinstance(ref, str), f"Received non-string $ref - {ref}"
|
||||
|
||||
resolved = resolve_ref(root=root, ref=ref)
|
||||
if not isinstance(resolved, dict):
|
||||
raise ValueError(
|
||||
f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}"
|
||||
)
|
||||
|
||||
# properties from the json schema take priority over the ones on the `$ref`
|
||||
json_schema.update({**resolved, **json_schema})
|
||||
json_schema.pop("$ref")
|
||||
# Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
|
||||
# we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
|
||||
return ensure_strict_json_schema(json_schema, path=path, root=root)
|
||||
|
||||
return json_schema
|
||||
|
||||
|
||||
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
|
||||
if not ref.startswith("#/"):
|
||||
raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
|
||||
|
||||
path = ref[2:].split("/")
|
||||
resolved = root
|
||||
for key in path:
|
||||
value = resolved[key]
|
||||
assert isinstance(
|
||||
value, dict
|
||||
), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
|
||||
resolved = value
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
# ? Not used
|
||||
def generate_constraint_sentences(schema: dict) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,10 +25,18 @@ def set_openai_api_key_env(value):
|
|||
os.environ["OPENAI_API_KEY"] = value
|
||||
|
||||
|
||||
def set_openai_model_env(value):
|
||||
os.environ["OPENAI_MODEL"] = value
|
||||
|
||||
|
||||
def set_google_api_key_env(value):
|
||||
os.environ["GOOGLE_API_KEY"] = value
|
||||
|
||||
|
||||
def set_google_model_env(value):
|
||||
os.environ["GOOGLE_MODEL"] = value
|
||||
|
||||
|
||||
def set_anthropic_api_key_env(value):
|
||||
os.environ["ANTHROPIC_API_KEY"] = value
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,12 @@ from utils.get_env import (
|
|||
get_custom_llm_url_env,
|
||||
get_custom_model_env,
|
||||
get_google_api_key_env,
|
||||
get_google_model_env,
|
||||
get_llm_provider_env,
|
||||
get_ollama_model_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
get_pexels_api_key_env,
|
||||
get_user_config_path_env,
|
||||
get_image_provider_env,
|
||||
|
|
@ -27,10 +29,12 @@ from utils.set_env import (
|
|||
set_custom_model_env,
|
||||
set_extended_reasoning_env,
|
||||
set_google_api_key_env,
|
||||
set_google_model_env,
|
||||
set_llm_provider_env,
|
||||
set_ollama_model_env,
|
||||
set_ollama_url_env,
|
||||
set_openai_api_key_env,
|
||||
set_openai_model_env,
|
||||
set_pexels_api_key_env,
|
||||
set_image_provider_env,
|
||||
set_pixabay_api_key_env,
|
||||
|
|
@ -58,7 +62,9 @@ def get_user_config():
|
|||
return UserConfig(
|
||||
LLM=existing_config.LLM or get_llm_provider_env(),
|
||||
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or get_openai_api_key_env(),
|
||||
OPENAI_MODEL=existing_config.OPENAI_MODEL or get_openai_model_env(),
|
||||
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or get_google_api_key_env(),
|
||||
GOOGLE_MODEL=existing_config.GOOGLE_MODEL or get_google_model_env(),
|
||||
ANTHROPIC_API_KEY=existing_config.ANTHROPIC_API_KEY
|
||||
or get_anthropic_api_key_env(),
|
||||
ANTHROPIC_MODEL=existing_config.ANTHROPIC_MODEL or get_anthropic_model_env(),
|
||||
|
|
@ -81,8 +87,12 @@ def update_env_with_user_config():
|
|||
set_llm_provider_env(user_config.LLM)
|
||||
if user_config.OPENAI_API_KEY:
|
||||
set_openai_api_key_env(user_config.OPENAI_API_KEY)
|
||||
if user_config.OPENAI_MODEL:
|
||||
set_openai_model_env(user_config.OPENAI_MODEL)
|
||||
if user_config.GOOGLE_API_KEY:
|
||||
set_google_api_key_env(user_config.GOOGLE_API_KEY)
|
||||
if user_config.GOOGLE_MODEL:
|
||||
set_google_model_env(user_config.GOOGLE_MODEL)
|
||||
if user_config.ANTHROPIC_API_KEY:
|
||||
set_anthropic_api_key_env(user_config.ANTHROPIC_API_KEY)
|
||||
if user_config.ANTHROPIC_MODEL:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue