Merge pull request #173 from presenton/feat/claude-support
feat/claude support
This commit is contained in:
commit
e094d5771c
64 changed files with 2092 additions and 928 deletions
|
|
@ -8,4 +8,5 @@ build
|
|||
.gitignore
|
||||
tmp
|
||||
debug
|
||||
.fastembed_cache
|
||||
.fastembed_cache
|
||||
generated_models
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -11,4 +11,5 @@ app_data
|
|||
tmp
|
||||
debug
|
||||
.fastembed_cache
|
||||
my-doc.txt
|
||||
my-doc.txt
|
||||
generated_models
|
||||
|
|
@ -2,12 +2,11 @@ FROM python:3.11-slim-bookworm
|
|||
|
||||
# Install Node.js and npm
|
||||
RUN apt-get update && apt-get install -y \
|
||||
|
||||
nginx \
|
||||
curl \
|
||||
redis-server
|
||||
|
||||
# Install Node.js 20 using NodeSource repository
|
||||
# Install Node.js 20 using NodeSource repository
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs
|
||||
|
||||
|
|
@ -43,7 +42,7 @@ RUN npm run build
|
|||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy FastAPI and start script
|
||||
# Copy FastAPI
|
||||
COPY servers/fastapi/ ./servers/fastapi/
|
||||
COPY start.js LICENSE NOTICE ./
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ RUN apt-get update && apt-get install -y \
|
|||
redis-server
|
||||
|
||||
|
||||
# Install Node.js 20 using NodeSource repository
|
||||
# Install Node.js 20 using NodeSource repository
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs
|
||||
apt-get install -y nodejs
|
||||
|
||||
|
||||
# Change working directory
|
||||
|
|
|
|||
|
|
@ -14,12 +14,15 @@ services:
|
|||
- LLM=${LLM}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- ANTHROPIC_MODEL=${ANTHROPIC_MODEL}
|
||||
- OLLAMA_URL=${OLLAMA_URL}
|
||||
- OLLAMA_MODEL=${OLLAMA_MODEL}
|
||||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
||||
production-gpu:
|
||||
|
|
@ -44,12 +47,15 @@ services:
|
|||
- LLM=${LLM}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- ANTHROPIC_MODEL=${ANTHROPIC_MODEL}
|
||||
- OLLAMA_URL=${OLLAMA_URL}
|
||||
- OLLAMA_MODEL=${OLLAMA_MODEL}
|
||||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
||||
development:
|
||||
|
|
@ -67,12 +73,15 @@ services:
|
|||
- LLM=${LLM}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- ANTHROPIC_MODEL=${ANTHROPIC_MODEL}
|
||||
- OLLAMA_URL=${OLLAMA_URL}
|
||||
- OLLAMA_MODEL=${OLLAMA_MODEL}
|
||||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
||||
development-gpu:
|
||||
|
|
@ -97,10 +106,13 @@ services:
|
|||
- LLM=${LLM}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||
- ANTHROPIC_MODEL=${ANTHROPIC_MODEL}
|
||||
- OLLAMA_URL=${OLLAMA_URL}
|
||||
- OLLAMA_MODEL=${OLLAMA_MODEL}
|
||||
- CUSTOM_LLM_URL=${CUSTOM_LLM_URL}
|
||||
- CUSTOM_LLM_API_KEY=${CUSTOM_LLM_API_KEY}
|
||||
- CUSTOM_MODEL=${CUSTOM_MODEL}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY}
|
||||
- EXTENDED_REASONING=${EXTENDED_REASONING}
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
|
|
@ -1,10 +1,8 @@
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from api.lifespan import app_lifespan
|
||||
from api.middlewares import UserConfigEnvUpdateMiddleware
|
||||
from api.v1.ppt.router import API_V1_PPT_ROUTER
|
||||
from utils.asset_directory_utils import get_exports_directory, get_images_directory, get_uploads_directory
|
||||
|
||||
|
||||
app = FastAPI(lifespan=app_lifespan)
|
||||
|
|
@ -13,25 +11,6 @@ app = FastAPI(lifespan=app_lifespan)
|
|||
# Routers
|
||||
app.include_router(API_V1_PPT_ROUTER)
|
||||
|
||||
# Static files
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
app.mount(
|
||||
"/app_data/images",
|
||||
StaticFiles(directory=get_images_directory()),
|
||||
name="app_data/images",
|
||||
)
|
||||
app.mount(
|
||||
"/app_data/exports",
|
||||
StaticFiles(directory=get_exports_directory()),
|
||||
name="app_data/exports",
|
||||
)
|
||||
app.mount(
|
||||
"/app_data/uploads",
|
||||
StaticFiles(directory=get_uploads_directory()),
|
||||
name="app_data/uploads",
|
||||
)
|
||||
|
||||
|
||||
# Middlewares
|
||||
origins = ["*"]
|
||||
app.add_middleware(
|
||||
|
|
|
|||
16
servers/fastapi/api/v1/ppt/endpoints/anthropic.py
Normal file
16
servers/fastapi/api/v1/ppt/endpoints/anthropic.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.available_models import list_available_anthropic_models
|
||||
|
||||
ANTHROPIC_ROUTER = APIRouter(prefix="/anthropic", tags=["Anthropic"])
|
||||
|
||||
|
||||
@ANTHROPIC_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
api_key: Annotated[str, Body(embed=True)],
|
||||
):
|
||||
try:
|
||||
return await list_available_anthropic_models(api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
from typing import Annotated, List, Optional
|
||||
from fastapi import APIRouter, Body
|
||||
|
||||
from utils.custom_llm_provider import list_available_custom_models
|
||||
|
||||
CUSTOM_LLM_ROUTER = APIRouter(prefix="/custom_llm", tags=["Custom LLM"])
|
||||
|
||||
|
||||
@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
url: Annotated[Optional[str], Body()] = None,
|
||||
api_key: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
return await list_available_custom_models(url, api_key)
|
||||
14
servers/fastapi/api/v1/ppt/endpoints/google.py
Normal file
14
servers/fastapi/api/v1/ppt/endpoints/google.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.available_models import list_available_google_models
|
||||
|
||||
GOOGLE_ROUTER = APIRouter(prefix="/google", tags=["Google"])
|
||||
|
||||
|
||||
@GOOGLE_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(api_key: Annotated[str, Body(embed=True)]):
|
||||
try:
|
||||
return await list_available_google_models(api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
17
servers/fastapi/api/v1/ppt/endpoints/openai.py
Normal file
17
servers/fastapi/api/v1/ppt/endpoints/openai.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from typing import Annotated, List
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from utils.available_models import list_available_openai_compatible_models
|
||||
|
||||
OPENAI_ROUTER = APIRouter(prefix="/openai", tags=["OpenAI"])
|
||||
|
||||
|
||||
@OPENAI_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
url: Annotated[str, Body()],
|
||||
api_key: Annotated[str, Body()],
|
||||
):
|
||||
try:
|
||||
return await list_available_openai_compatible_models(url, api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -54,7 +54,6 @@ async def stream_outlines(
|
|||
presentation.outlines = [
|
||||
each.model_dump() for each in presentation_content.slides
|
||||
]
|
||||
presentation.notes = presentation_content.notes
|
||||
|
||||
sql_session.add(presentation)
|
||||
await sql_session.commit()
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
import importlib
|
||||
from typing import Annotated, List, Literal, Optional
|
||||
from fastapi import APIRouter, Body, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import String, cast, delete
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
|
||||
|
|
@ -19,7 +20,7 @@ from models.pptx_models import PptxPresentationModel
|
|||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.presentation_with_slides import PresentationWithSlides
|
||||
from services.get_layout_by_name import get_layout_by_name
|
||||
from utils.get_layout_by_name import get_layout_by_name
|
||||
from services.icon_finder_service import IconFinderService
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.dict_utils import deep_update
|
||||
|
|
@ -217,9 +218,11 @@ async def stream_presentation(
|
|||
).to_string()
|
||||
for i, slide_layout_index in enumerate(structure.slides):
|
||||
slide_layout = layout.slides[slide_layout_index]
|
||||
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
slide_layout, outline.slides[i], presentation.language
|
||||
)
|
||||
|
||||
slide = SlideModel(
|
||||
presentation=presentation_id,
|
||||
layout_group=layout.name,
|
||||
|
|
@ -236,9 +239,6 @@ async def stream_presentation(
|
|||
)
|
||||
)
|
||||
|
||||
# Give control to the event loop
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
|
||||
|
|
@ -475,7 +475,6 @@ async def from_template(
|
|||
new_slide_data = list(filter(lambda x: x.index == each_slide.index, data.data))
|
||||
if new_slide_data:
|
||||
updated_content = deep_update(each_slide.content, new_slide_data[0].content)
|
||||
print(f"Updated content for slide {each_slide.index}: {updated_content}")
|
||||
new_slides.append(
|
||||
each_slide.get_new_slide(new_presentation.id, updated_content)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import importlib
|
||||
from typing import Annotated, Optional
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -13,6 +14,7 @@ from utils.llm_calls.edit_slide_html import get_edited_slide_html
|
|||
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
|
||||
from utils.process_slides import process_old_and_new_slides_and_fetch_assets
|
||||
from utils.randomizers import get_random_uuid
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
|
||||
SLIDE_ROUTER = APIRouter(prefix="/slide", tags=["Slide"])
|
||||
|
|
@ -32,12 +34,12 @@ async def edit_slide(
|
|||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
presentation_layout = presentation.get_layout()
|
||||
|
||||
slide_layout = await get_slide_layout_from_prompt(
|
||||
prompt, presentation_layout, slide
|
||||
)
|
||||
|
||||
edited_slide_content = await get_edited_slide_content(
|
||||
prompt, slide_layout, slide, presentation.language
|
||||
prompt, slide, presentation.language, slide_layout
|
||||
)
|
||||
|
||||
image_generation_service = ImageGenerationService(get_images_directory())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from api.v1.ppt.endpoints.custom_llm import CUSTOM_LLM_ROUTER
|
||||
from api.v1.ppt.endpoints.anthropic import ANTHROPIC_ROUTER
|
||||
from api.v1.ppt.endpoints.google import GOOGLE_ROUTER
|
||||
from api.v1.ppt.endpoints.openai import OPENAI_ROUTER
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
from api.v1.ppt.endpoints.icons import ICONS_ROUTER
|
||||
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
||||
|
|
@ -19,4 +21,6 @@ API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CUSTOM_LLM_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OPENAI_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ANTHROPIC_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(GOOGLE_ROUTER)
|
||||
|
|
|
|||
Binary file not shown.
6
servers/fastapi/constants/llm.py
Normal file
6
servers/fastapi/constants/llm.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
OPENAI_URL = "https://api.openai.com/v1"
|
||||
|
||||
# Default models
|
||||
DEFAULT_OPENAI_MODEL = "gpt-4.1"
|
||||
DEFAULT_GOOGLE_MODEL = "models/gemini-2.0-flash"
|
||||
DEFAULT_ANTHROPIC_MODEL = "claude-3-5-sonnet-20240620"
|
||||
|
|
@ -7,7 +7,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3:70b": OllamaModelMetadata(
|
||||
|
|
@ -15,7 +14,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="40GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:8b": OllamaModelMetadata(
|
||||
|
|
@ -23,7 +21,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.1:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:70b": OllamaModelMetadata(
|
||||
|
|
@ -31,7 +28,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.1:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:405b": OllamaModelMetadata(
|
||||
|
|
@ -39,7 +35,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.1:405b",
|
||||
description="✅ Graphs supported.",
|
||||
size="243GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:1b": OllamaModelMetadata(
|
||||
|
|
@ -47,7 +42,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.2:1b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:3b": OllamaModelMetadata(
|
||||
|
|
@ -55,7 +49,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.2:3b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.3:70b": OllamaModelMetadata(
|
||||
|
|
@ -63,7 +56,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama3.3:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:16x17b": OllamaModelMetadata(
|
||||
|
|
@ -71,7 +63,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama4:16x17b",
|
||||
description="✅ Graphs supported.",
|
||||
size="67GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:128x17b": OllamaModelMetadata(
|
||||
|
|
@ -79,7 +70,6 @@ SUPPORTED_OLLAMA_MODELS = {
|
|||
value="llama4:128x17b",
|
||||
description="✅ Graphs supported.",
|
||||
size="245GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
}
|
||||
|
|
@ -90,7 +80,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:1b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="815MB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:4b": OllamaModelMetadata(
|
||||
|
|
@ -98,7 +87,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:4b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="3.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:12b": OllamaModelMetadata(
|
||||
|
|
@ -106,7 +94,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:12b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="8.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:27b": OllamaModelMetadata(
|
||||
|
|
@ -114,7 +101,6 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
value="gemma3:27b",
|
||||
description="✅ Graphs supported.",
|
||||
size="17GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
}
|
||||
|
|
@ -125,7 +111,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:1.5b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:7b": OllamaModelMetadata(
|
||||
|
|
@ -133,7 +118,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:7b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:8b": OllamaModelMetadata(
|
||||
|
|
@ -141,7 +125,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:14b": OllamaModelMetadata(
|
||||
|
|
@ -149,7 +132,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:14b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:32b": OllamaModelMetadata(
|
||||
|
|
@ -157,7 +139,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:32b",
|
||||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:70b": OllamaModelMetadata(
|
||||
|
|
@ -165,7 +146,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:671b": OllamaModelMetadata(
|
||||
|
|
@ -173,7 +153,6 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
value="deepseek-r1:671b",
|
||||
description="✅ Graphs supported.",
|
||||
size="404GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
}
|
||||
|
|
@ -184,7 +163,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:0.6b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="523MB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:1.7b": OllamaModelMetadata(
|
||||
|
|
@ -192,7 +170,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:1.7b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.4GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:4b": OllamaModelMetadata(
|
||||
|
|
@ -200,7 +177,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:4b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="2.6GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:8b": OllamaModelMetadata(
|
||||
|
|
@ -208,7 +184,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:14b": OllamaModelMetadata(
|
||||
|
|
@ -216,7 +191,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:14b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="9.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:30b": OllamaModelMetadata(
|
||||
|
|
@ -224,7 +198,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:30b",
|
||||
description="✅ Graphs supported.",
|
||||
size="19GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:32b": OllamaModelMetadata(
|
||||
|
|
@ -232,7 +205,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:32b",
|
||||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:235b": OllamaModelMetadata(
|
||||
|
|
@ -240,7 +212,6 @@ SUPPORTED_QWEN_MODELS = {
|
|||
value="qwen3:235b",
|
||||
description="✅ Graphs supported.",
|
||||
size="142GB",
|
||||
supports_graph=True,
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,4 +5,5 @@ class LLMProvider(Enum):
|
|||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
ANTHROPIC = "anthropic"
|
||||
CUSTOM = "custom"
|
||||
|
|
|
|||
7
servers/fastapi/models/llm_message.py
Normal file
7
servers/fastapi/models/llm_message.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from typing import Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
role: Literal["user", "system"]
|
||||
content: str
|
||||
|
|
@ -7,4 +7,3 @@ class OllamaModelMetadata(BaseModel):
|
|||
description: str
|
||||
icon: str
|
||||
size: str
|
||||
supports_graph: bool
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List, Optional
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
|
@ -12,10 +13,18 @@ class SlideLayoutModel(BaseModel):
|
|||
|
||||
|
||||
class PresentationLayoutModel(BaseModel):
|
||||
name: Optional[str] = None
|
||||
name: str
|
||||
ordered: bool = Field(default=False)
|
||||
slides: List[SlideLayoutModel]
|
||||
|
||||
def get_slide_layout_index(self, slide_layout_id: str) -> int:
|
||||
for index, slide in enumerate(self.slides):
|
||||
if slide.id == slide_layout_id:
|
||||
return index
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Slide layout {slide_layout_id} not found"
|
||||
)
|
||||
|
||||
def to_presentation_structure(self):
|
||||
return PresentationStructureModel(
|
||||
slides=[index for index in range(len(self.slides))]
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ class PresentationOutlineModel(BaseModel):
|
|||
title: str = Field(
|
||||
description="Title of the presentation in about 3 to 8 words",
|
||||
)
|
||||
notes: Optional[List[str]] = Field(default=None, description="Notes for the presentation")
|
||||
slides: List[SlideOutlineModel] = Field(description="List of slides")
|
||||
|
||||
def to_string(self):
|
||||
|
|
@ -25,8 +24,8 @@ class PresentationOutlineModel(BaseModel):
|
|||
message += f" - Title: {slide.title} \n"
|
||||
message += f" - Body: {slide.body} \n"
|
||||
|
||||
if self.notes:
|
||||
message += f"# Notes: \n"
|
||||
for note in self.notes:
|
||||
message += f" - {note} \n"
|
||||
# if self.notes:
|
||||
# message += f"# Notes: \n"
|
||||
# for note in self.notes:
|
||||
# message += f" - {note} \n"
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class PresentationModel(SQLModel, table=True):
|
|||
return PresentationOutlineModel(
|
||||
title=self.title,
|
||||
slides=[SlideOutlineModel(**each) for each in self.outlines],
|
||||
notes=self.notes,
|
||||
# notes=self.notes,
|
||||
)
|
||||
|
||||
def get_layout(self):
|
||||
|
|
|
|||
|
|
@ -4,13 +4,32 @@ from pydantic import BaseModel
|
|||
|
||||
class UserConfig(BaseModel):
|
||||
LLM: Optional[str] = None
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
OPENAI_MODEL: Optional[str] = None
|
||||
|
||||
# Google
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
GOOGLE_MODEL: Optional[str] = None
|
||||
|
||||
# Anthropic
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
ANTHROPIC_MODEL: Optional[str] = None
|
||||
|
||||
# Ollama
|
||||
OLLAMA_URL: Optional[str] = None
|
||||
OLLAMA_MODEL: Optional[str] = None
|
||||
|
||||
# Custom LLM
|
||||
CUSTOM_LLM_URL: Optional[str] = None
|
||||
CUSTOM_LLM_API_KEY: Optional[str] = None
|
||||
CUSTOM_MODEL: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
|
||||
# Image Provider
|
||||
IMAGE_PROVIDER: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
PIXABAY_API_KEY: Optional[str] = None
|
||||
|
||||
# Reasoning
|
||||
EXTENDED_REASONING: Optional[bool] = None
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@ aiomysql==0.2.0
|
|||
aiosignal==1.4.0
|
||||
aiosqlite==0.21.0
|
||||
annotated-types==0.7.0
|
||||
anthropic==0.60.0
|
||||
anyio==4.9.0
|
||||
argcomplete==3.6.2
|
||||
async-timeout==5.0.1
|
||||
asyncpg==0.30.0
|
||||
attrs==25.3.0
|
||||
backoff==2.2.1
|
||||
bcrypt==4.3.0
|
||||
black==25.1.0
|
||||
build==1.2.2.post1
|
||||
cachetools==5.5.2
|
||||
certifi==2025.7.14
|
||||
|
|
@ -31,6 +34,7 @@ filelock==3.18.0
|
|||
flatbuffers==25.2.10
|
||||
frozenlist==1.7.0
|
||||
fsspec==2025.7.0
|
||||
genson==1.3.0
|
||||
google-auth==2.40.3
|
||||
google-genai==1.25.0
|
||||
googleapis-common-protos==1.70.0
|
||||
|
|
@ -49,7 +53,9 @@ hyperframe==6.1.0
|
|||
idna==3.10
|
||||
importlib_metadata==8.7.0
|
||||
importlib_resources==6.5.2
|
||||
inflect==7.5.0
|
||||
iniconfig==2.1.0
|
||||
isort==6.0.1
|
||||
Jinja2==3.1.6
|
||||
jiter==0.10.0
|
||||
jsonschema==4.25.0
|
||||
|
|
@ -61,8 +67,10 @@ markdown-it-py==3.0.0
|
|||
MarkupSafe==3.0.2
|
||||
mdurl==0.1.2
|
||||
mmh3==5.1.0
|
||||
more-itertools==10.7.0
|
||||
mpmath==1.3.0
|
||||
multidict==6.6.3
|
||||
mypy_extensions==1.1.0
|
||||
numpy==2.3.2
|
||||
oauthlib==3.3.1
|
||||
onnxruntime==1.22.1
|
||||
|
|
@ -76,10 +84,12 @@ opentelemetry-semantic-conventions==0.56b0
|
|||
orjson==3.11.1
|
||||
overrides==7.7.0
|
||||
packaging==25.0
|
||||
pathspec==0.12.1
|
||||
pathvalidate==3.3.1
|
||||
pdfminer.six==20250506
|
||||
pdfplumber==0.11.7
|
||||
pillow==11.3.0
|
||||
platformdirs==4.3.8
|
||||
pluggy==1.6.0
|
||||
portalocker==3.2.0
|
||||
posthog==5.4.0
|
||||
|
|
@ -122,7 +132,9 @@ starlette==0.47.1
|
|||
sympy==1.14.0
|
||||
tenacity==8.5.0
|
||||
tokenizers==0.21.2
|
||||
tomli==2.2.1
|
||||
tqdm==4.67.1
|
||||
typeguard==4.4.4
|
||||
typer==0.16.0
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.14.1
|
||||
|
|
|
|||
|
|
@ -8,14 +8,15 @@ if __name__ == "__main__":
|
|||
"--port", type=int, required=True, help="Port number to run the server on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reload", type=bool, default=False, help="Reload the server on code changes"
|
||||
"--reload", type=str, default="false", help="Reload the server on code changes"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
reload = args.reload == "true"
|
||||
|
||||
uvicorn.run(
|
||||
"api.main:app",
|
||||
host="0.0.0.0",
|
||||
port=args.port,
|
||||
log_level="info",
|
||||
reload=args.reload,
|
||||
reload=reload,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ import os
|
|||
import aiohttp
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from openai import AsyncOpenAI
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.sql.image_asset import ImageAsset
|
||||
from utils.download_helpers import download_file
|
||||
from utils.get_env import get_pexels_api_key_env
|
||||
from utils.get_env import get_pixabay_api_key_env
|
||||
from utils.llm_provider import get_llm_client
|
||||
from utils.image_provider import (
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
|
|
@ -80,7 +80,7 @@ class ImageGenerationService:
|
|||
return "/static/images/placeholder.jpg"
|
||||
|
||||
async def generate_image_openai(self, prompt: str, output_directory: str) -> str:
|
||||
client = get_llm_client()
|
||||
client = AsyncOpenAI()
|
||||
result = await client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=prompt,
|
||||
|
|
|
|||
509
servers/fastapi/services/llm_client.py
Normal file
509
servers/fastapi/services/llm_client.py
Normal file
|
|
@ -0,0 +1,509 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import List
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
from anthropic import MessageStreamEvent as AnthropicMessageStreamEvent
|
||||
from enums.llm_provider import LLMProvider
|
||||
from models.llm_message import LLMMessage
|
||||
from utils.async_iterator import iterator_to_async
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_google_api_key_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
)
|
||||
from utils.llm_provider import get_llm_provider
|
||||
from utils.schema_utils import ensure_strict_json_schema
|
||||
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, max_tokens: int = 4000):
|
||||
self.llm_provider = get_llm_provider()
|
||||
self._client = self._get_client()
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
# ? Clients
|
||||
def _get_client(self):
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self._get_openai_client()
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._get_google_client()
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self._get_anthropic_client()
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._get_ollama_client()
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._get_custom_client()
|
||||
case _:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="LLM Provider must be either openai, google, anthropic, ollama, or custom",
|
||||
)
|
||||
|
||||
def _get_openai_client(self):
|
||||
if not get_openai_api_key_env():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="OpenAI API Key is not set",
|
||||
)
|
||||
return AsyncOpenAI()
|
||||
|
||||
def _get_google_client(self):
|
||||
if not get_google_api_key_env():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google API Key is not set",
|
||||
)
|
||||
return genai.Client()
|
||||
|
||||
def _get_anthropic_client(self):
|
||||
if not get_anthropic_api_key_env():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Anthropic API Key is not set",
|
||||
)
|
||||
return AsyncAnthropic()
|
||||
|
||||
def _get_ollama_client(self):
|
||||
return AsyncOpenAI(
|
||||
base_url=(get_ollama_url_env() or "http://localhost:11434") + "/v1",
|
||||
api_key="ollama",
|
||||
)
|
||||
|
||||
def _get_custom_client(self):
|
||||
if not (get_custom_llm_api_key_env() and get_custom_llm_url_env()):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Custom LLM API Key is not set",
|
||||
)
|
||||
return AsyncOpenAI(
|
||||
base_url=get_custom_llm_url_env(),
|
||||
api_key=get_custom_llm_api_key_env(),
|
||||
)
|
||||
|
||||
# ? Prompts
|
||||
def _get_system_prompt(self, messages: List[LLMMessage]) -> str:
|
||||
for message in messages:
|
||||
if message.role == "system":
|
||||
return message.content
|
||||
return ""
|
||||
|
||||
def _get_user_prompts(self, messages: List[LLMMessage]) -> List[str]:
|
||||
return [message.content for message in messages if message.role == "user"]
|
||||
|
||||
def _get_user_llm_messages(self, messages: List[LLMMessage]) -> List[LLMMessage]:
|
||||
return [message for message in messages if message.role == "user"]
|
||||
|
||||
# ? Generate Unstructured Content
|
||||
async def _generate_openai(self, model: str, messages: List[LLMMessage]):
|
||||
client: AsyncOpenAI = self._client
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _generate_google(self, model: str, messages: List[LLMMessage]):
|
||||
client: genai.Client = self._client
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=self._get_user_prompts(messages),
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="text/plain",
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
async def _generate_anthropic(self, model: str, messages: List[LLMMessage]):
|
||||
client: AsyncAnthropic = self._client
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
text = ""
|
||||
for content in response.content:
|
||||
if content.type == "text" and isinstance(content.text, str):
|
||||
text += content.text
|
||||
if text == "":
|
||||
return None
|
||||
return text
|
||||
|
||||
async def _generate_ollama(self, model: str, messages: List[LLMMessage]):
|
||||
return await self._generate_openai(model, messages)
|
||||
|
||||
async def _generate_custom(self, model: str, messages: List[LLMMessage]):
|
||||
return await self._generate_openai(model, messages)
|
||||
|
||||
async def generate(self, model: str, messages: List[LLMMessage]):
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
content = await self._generate_openai(model, messages)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google(model, messages)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
content = await self._generate_anthropic(model, messages)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama(model, messages)
|
||||
case LLMProvider.CUSTOM:
|
||||
content = await self._generate_custom(model, messages)
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="LLM did not return any content",
|
||||
)
|
||||
return content
|
||||
|
||||
# ? Generate Structured Content
|
||||
async def _generate_openai_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
response_schema = response_format
|
||||
if strict:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
response_schema,
|
||||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": (
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
}
|
||||
),
|
||||
},
|
||||
max_completion_tokens=self.max_tokens,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return json.loads(content)
|
||||
return None
|
||||
|
||||
async def _generate_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=self._get_user_prompts(messages),
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
)
|
||||
content = None
|
||||
if response.text:
|
||||
content = json.loads(response.text)
|
||||
|
||||
return content
|
||||
|
||||
async def _generate_anthropic_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
response: AnthropicMessage = await client.messages.create(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"input_schema": response_format,
|
||||
}
|
||||
],
|
||||
)
|
||||
content: dict | None = None
|
||||
for content_block in response.content:
|
||||
if content_block.type == "tool_use":
|
||||
content = content_block.input
|
||||
|
||||
return content
|
||||
|
||||
async def _generate_ollama_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
|
||||
async def _generate_custom_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
|
||||
async def generate_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
) -> dict:
|
||||
content = None
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
content = await self._generate_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
content = await self._generate_google_structured(
|
||||
model, messages, response_format
|
||||
)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
content = await self._generate_anthropic_structured(
|
||||
model, messages, response_format
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
content = await self._generate_ollama_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
content = await self._generate_custom_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
if content is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="LLM did not return any content",
|
||||
)
|
||||
return content
|
||||
|
||||
# ? Stream Unstructured Content
|
||||
async def _stream_openai(self, model: str, messages: List[LLMMessage]):
|
||||
client: AsyncOpenAI = self._client
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
yield event.delta
|
||||
|
||||
async def _stream_google(self, model: str, messages: List[LLMMessage]):
|
||||
client: genai.Client = self._client
|
||||
async for event in iterator_to_async(client.models.generate_content_stream)(
|
||||
model=model,
|
||||
contents=self._get_user_prompts(messages),
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="text/plain",
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
yield event.text
|
||||
|
||||
async def _stream_anthropic(self, model: str, messages: List[LLMMessage]):
|
||||
client: AsyncAnthropic = self._client
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
if event.type == "text" and isinstance(event.text, str):
|
||||
yield event.text
|
||||
|
||||
def _stream_ollama(self, model: str, messages: List[LLMMessage]):
|
||||
return self._stream_openai(model, messages)
|
||||
|
||||
def _stream_custom(self, model: str, messages: List[LLMMessage]):
|
||||
return self._stream_openai(model, messages)
|
||||
|
||||
def stream(self, model: str, messages: List[LLMMessage]):
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self._stream_openai(model, messages)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google(model, messages)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self._stream_anthropic(model, messages)
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._stream_ollama(model, messages)
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._stream_custom(model, messages)
|
||||
|
||||
# ? Stream Structured Content
|
||||
async def _stream_openai_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
client: AsyncOpenAI = self._client
|
||||
response_schema = response_format
|
||||
if strict:
|
||||
response_schema = ensure_strict_json_schema(
|
||||
response_schema,
|
||||
path=(),
|
||||
root=response_schema,
|
||||
)
|
||||
async with client.chat.completions.stream(
|
||||
model=model,
|
||||
messages=[message.model_dump() for message in messages],
|
||||
max_completion_tokens=self.max_tokens,
|
||||
response_format=(
|
||||
{
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "ResponseSchema",
|
||||
"strict": strict,
|
||||
"schema": response_schema,
|
||||
},
|
||||
}
|
||||
),
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content.delta":
|
||||
yield event.delta
|
||||
|
||||
async def _stream_google_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: genai.Client = self._client
|
||||
async for event in iterator_to_async(client.models.generate_content_stream)(
|
||||
model=model,
|
||||
contents=self._get_user_prompts(messages),
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=self._get_system_prompt(messages),
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_format,
|
||||
max_output_tokens=self.max_tokens,
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
yield event.text
|
||||
|
||||
async def _stream_anthropic_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
):
|
||||
client: AsyncAnthropic = self._client
|
||||
async with client.messages.stream(
|
||||
model=model,
|
||||
system=self._get_system_prompt(messages),
|
||||
messages=[
|
||||
message.model_dump()
|
||||
for message in self._get_user_llm_messages(messages)
|
||||
],
|
||||
max_tokens=self.max_tokens,
|
||||
tools=[
|
||||
{
|
||||
"name": "ResponseSchema",
|
||||
"description": "A response to the user's message",
|
||||
"input_schema": response_format,
|
||||
}
|
||||
],
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
event: AnthropicMessageStreamEvent = event
|
||||
if event.type == "input_json" and isinstance(event.partial_json, str):
|
||||
yield event.partial_json
|
||||
|
||||
def _stream_ollama_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return self._stream_openai_structured(model, messages, response_format, strict)
|
||||
|
||||
def _stream_custom_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
return self._stream_openai_structured(model, messages, response_format, strict)
|
||||
|
||||
def stream_structured(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[LLMMessage],
|
||||
response_format: dict,
|
||||
strict: bool = False,
|
||||
):
|
||||
match self.llm_provider:
|
||||
case LLMProvider.OPENAI:
|
||||
return self._stream_openai_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.GOOGLE:
|
||||
return self._stream_google_structured(model, messages, response_format)
|
||||
case LLMProvider.ANTHROPIC:
|
||||
return self._stream_anthropic_structured(
|
||||
model, messages, response_format
|
||||
)
|
||||
case LLMProvider.OLLAMA:
|
||||
return self._stream_ollama_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
case LLMProvider.CUSTOM:
|
||||
return self._stream_custom_structured(
|
||||
model, messages, response_format, strict
|
||||
)
|
||||
|
|
@ -9,8 +9,7 @@ class TempFileService:
|
|||
|
||||
def __init__(self):
|
||||
self.base_dir = get_temp_directory_env() or "/tmp/presenton"
|
||||
# TODO: Uncomment this when we want to cleanup the base dir on startup
|
||||
# self.cleanup_base_dir()
|
||||
self.cleanup_base_dir()
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
|
||||
def create_dir_in_dir(self, base_dir: str, dir_name: Optional[str] = None) -> str:
|
||||
|
|
|
|||
21
servers/fastapi/utils/available_models.py
Normal file
21
servers/fastapi/utils/available_models.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from anthropic import AsyncAnthropic
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
||||
|
||||
async def list_available_openai_compatible_models(url: str, api_key: str) -> list[str]:
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=url)
|
||||
models = (await client.models.list()).data
|
||||
if models:
|
||||
return list(map(lambda x: x.id, models))
|
||||
return []
|
||||
|
||||
|
||||
async def list_available_anthropic_models(api_key: str) -> list[str]:
|
||||
client = AsyncAnthropic(api_key=api_key)
|
||||
return list(map(lambda x: x.id, (await client.models.list(limit=50)).data))
|
||||
|
||||
|
||||
async def list_available_google_models(api_key: str) -> list[str]:
|
||||
client = genai.Client(api_key=api_key)
|
||||
return list(map(lambda x: x.name, client.models.list(config={"page_size": 50})))
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
from typing import Optional
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from utils.llm_provider import get_llm_client
|
||||
|
||||
|
||||
async def list_available_custom_models(
|
||||
url: Optional[str] = None, api_key: Optional[str] = None
|
||||
) -> list[str]:
|
||||
if not url:
|
||||
client = get_llm_client()
|
||||
else:
|
||||
client = AsyncOpenAI(api_key=api_key or "null", base_url=url)
|
||||
models = []
|
||||
async for model in client.models.list():
|
||||
print(model)
|
||||
models.append(model.id)
|
||||
return models
|
||||
|
|
@ -78,3 +78,12 @@ def deep_update(original: dict, updates: dict) -> dict:
|
|||
if not isinstance(value, (dict, list)):
|
||||
original[key] = value
|
||||
return original
|
||||
|
||||
|
||||
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
|
||||
i = 0
|
||||
for _ in obj.keys():
|
||||
i += 1
|
||||
if i > n:
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -22,12 +22,6 @@ def get_presentation_outline_model_with_n_slides(n_slides: int):
|
|||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
notes: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Important notes for the presentation styling and formatting",
|
||||
min_length=0,
|
||||
max_length=10,
|
||||
)
|
||||
slides: List[SlideOutlineModelWithValidation] = Field(
|
||||
description="List of slides", min_items=n_slides, max_items=n_slides
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,14 @@ def get_llm_provider_env():
|
|||
return os.getenv("LLM")
|
||||
|
||||
|
||||
def get_anthropic_api_key_env():
|
||||
return os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
|
||||
def get_anthropic_model_env():
|
||||
return os.getenv("ANTHROPIC_MODEL")
|
||||
|
||||
|
||||
def get_ollama_url_env():
|
||||
return os.getenv("OLLAMA_URL")
|
||||
|
||||
|
|
@ -37,10 +45,18 @@ def get_openai_api_key_env():
|
|||
return os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def get_openai_model_env():
|
||||
return os.getenv("OPENAI_MODEL")
|
||||
|
||||
|
||||
def get_google_api_key_env():
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
|
||||
|
||||
def get_google_model_env():
|
||||
return os.getenv("GOOGLE_MODEL")
|
||||
|
||||
|
||||
def get_custom_llm_api_key_env():
|
||||
return os.getenv("CUSTOM_LLM_API_KEY")
|
||||
|
||||
|
|
@ -79,3 +95,7 @@ def get_redis_db_env():
|
|||
|
||||
def get_redis_password_env():
|
||||
return os.getenv("REDIS_PASSWORD")
|
||||
|
||||
|
||||
def get_extended_reasoning_env():
|
||||
return os.getenv("EXTENDED_REASONING")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,8 @@
|
|||
import asyncio
|
||||
import json
|
||||
|
||||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.sql.slide import SlideModel
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from utils.llm_provider import (
|
||||
get_google_llm_client,
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
is_google_selected,
|
||||
)
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
system_prompt = """
|
||||
|
|
@ -42,64 +35,40 @@ def get_user_prompt(prompt: str, slide_data: dict, language: str):
|
|||
"""
|
||||
|
||||
|
||||
def get_prompt_to_edit_slide_content(
|
||||
def get_messages(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
language: str,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(prompt, slide_data, language),
|
||||
},
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(prompt, slide_data, language),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def get_edited_slide_content(
|
||||
prompt: str,
|
||||
slide_layout: SlideLayoutModel,
|
||||
slide: SlideModel,
|
||||
language: str,
|
||||
slide_layout: SlideLayoutModel,
|
||||
):
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
if is_google_selected():
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(prompt, slide.content, language)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_schema,
|
||||
),
|
||||
)
|
||||
slide_content_json = json.loads(response.text)
|
||||
else:
|
||||
client = get_llm_client()
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_edit_slide_content(
|
||||
prompt,
|
||||
slide.content,
|
||||
language,
|
||||
),
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "slide_content",
|
||||
"schema": response_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
slide_content_json = json.loads(response.choices[0].message.content)
|
||||
|
||||
return slide_content_json
|
||||
client = LLMClient()
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(prompt, slide.content, language),
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,12 +1,7 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from utils.llm_provider import (
|
||||
get_google_llm_client,
|
||||
get_large_model,
|
||||
is_google_selected,
|
||||
get_llm_client,
|
||||
)
|
||||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
system_prompt = """
|
||||
You are an expert HTML slide editor. Your task is to modify slide HTML content based on user prompts while maintaining proper structure, styling, and functionality.
|
||||
|
|
@ -52,35 +47,17 @@ def get_user_prompt(prompt: str, html: str):
|
|||
|
||||
|
||||
async def get_edited_slide_html(prompt: str, html: str):
|
||||
model = get_large_model()
|
||||
llm_response = None
|
||||
if is_google_selected():
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(prompt, html)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="text/plain",
|
||||
),
|
||||
)
|
||||
llm_response = response.text
|
||||
else:
|
||||
client = get_llm_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": get_user_prompt(prompt, html)},
|
||||
],
|
||||
)
|
||||
llm_response = response.choices[0].message.content
|
||||
model = get_model()
|
||||
|
||||
if not llm_response:
|
||||
return html
|
||||
|
||||
return extract_html_from_response(llm_response) or html
|
||||
client = LLMClient()
|
||||
response = await client.generate(
|
||||
model=model,
|
||||
messages=[
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=get_user_prompt(prompt, html)),
|
||||
],
|
||||
)
|
||||
return extract_html_from_response(response) or html
|
||||
|
||||
|
||||
def extract_html_from_response(response_text: str) -> Optional[str]:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from utils.llm_provider import get_llm_client, get_nano_model
|
||||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
||||
sysmte_prompt = """
|
||||
|
|
@ -23,23 +24,21 @@ Maintain as much information as possible.
|
|||
|
||||
|
||||
async def generate_document_summary(documents: List[str]):
|
||||
client = get_llm_client()
|
||||
model = get_nano_model()
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
coroutines = []
|
||||
for document in documents:
|
||||
truncated_text = document[:200000]
|
||||
coroutine = client.chat.completions.create(
|
||||
coroutine = client.generate(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": sysmte_prompt},
|
||||
{"role": "user", "content": truncated_text},
|
||||
LLMMessage(role="system", content=sysmte_prompt),
|
||||
LLMMessage(role="user", content=truncated_text),
|
||||
],
|
||||
)
|
||||
coroutines.append(coroutine)
|
||||
|
||||
completions: List[ChatCompletion] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join(
|
||||
[completion.choices[0].message.content for completion in completions]
|
||||
)
|
||||
completions: List[str] = await asyncio.gather(*coroutines)
|
||||
combined = "\n\n\n\n".join(completions)
|
||||
return combined
|
||||
|
|
|
|||
|
|
@ -1,16 +1,10 @@
|
|||
import asyncio
|
||||
from typing import Optional
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
|
||||
from utils.async_iterator import iterator_to_async
|
||||
from models.llm_message import LLMMessage
|
||||
from services.llm_client import LLMClient
|
||||
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
|
||||
from utils.llm_provider import (
|
||||
get_google_llm_client,
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
is_google_selected,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
system_prompt = """
|
||||
You are an expert presentation creator. Generate structured presentations based on user requirements and format them according to the specified JSON schema with markdown content.
|
||||
|
|
@ -62,61 +56,34 @@ def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
|
|||
"""
|
||||
|
||||
|
||||
def get_prompt_template(prompt: str, n_slides: int, language: str, content: str):
|
||||
def get_messages(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(prompt, n_slides, language, content),
|
||||
},
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(prompt, n_slides, language, content),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_response_format(response_model: BaseModel):
|
||||
return {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "PresentationOutlineModel",
|
||||
"schema": response_model.model_json_schema(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def generate_ppt_outline(
|
||||
prompt: Optional[str],
|
||||
n_slides: int,
|
||||
language: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
):
|
||||
model = get_large_model()
|
||||
model = get_model()
|
||||
response_model = get_presentation_outline_model_with_n_slides(n_slides)
|
||||
|
||||
if not is_google_selected():
|
||||
client = get_llm_client()
|
||||
async for response in await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=get_prompt_template(prompt, n_slides, language, content),
|
||||
stream=True,
|
||||
response_format=get_response_format(response_model),
|
||||
):
|
||||
delta: ChoiceDelta = response.choices[0].delta
|
||||
if delta.content:
|
||||
yield delta.content
|
||||
client = LLMClient()
|
||||
|
||||
else:
|
||||
client = get_google_llm_client()
|
||||
generate_stream = iterator_to_async(client.models.generate_content_stream)
|
||||
async for event in generate_stream(
|
||||
model=model,
|
||||
contents=[get_user_prompt(prompt, n_slides, language, content)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_model.model_json_schema(),
|
||||
),
|
||||
):
|
||||
if event.text:
|
||||
yield event.text
|
||||
async for chunk in client.stream_structured(
|
||||
model,
|
||||
get_messages(prompt, n_slides, language, content),
|
||||
response_model.model_json_schema(),
|
||||
strict=True,
|
||||
):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,24 +1,19 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from utils.llm_provider import (
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
get_nano_model,
|
||||
get_small_model,
|
||||
)
|
||||
from utils.get_dynamic_models import (
|
||||
get_presentation_structure_model_with_n_slides,
|
||||
)
|
||||
from models.presentation_structure_model import (
|
||||
PresentationStructureModel,
|
||||
)
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
from utils.get_dynamic_models import get_presentation_structure_model_with_n_slides
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data: str):
|
||||
def get_messages(
|
||||
presentation_layout: PresentationLayoutModel, n_slides: int, data: str
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=f"""
|
||||
You're a professional presentation designer with creative freedom to design engaging presentations.
|
||||
|
||||
{presentation_layout.to_string()}
|
||||
|
|
@ -51,13 +46,13 @@ def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data
|
|||
|
||||
Select layout index for each of the {n_slides} slides based on what will best serve the presentation's goals.
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=f"""
|
||||
{data}
|
||||
""",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -66,20 +61,20 @@ async def generate_presentation_structure(
|
|||
presentation_layout: PresentationLayoutModel,
|
||||
) -> PresentationStructureModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_prompt(
|
||||
messages=get_messages(
|
||||
presentation_layout,
|
||||
len(presentation_outline.slides),
|
||||
presentation_outline.to_string(),
|
||||
),
|
||||
response_format=response_model,
|
||||
response_format=response_model.model_json_schema(),
|
||||
strict=True,
|
||||
)
|
||||
print(response.choices[0].message.parsed)
|
||||
return response.choices[0].message.parsed
|
||||
return PresentationStructureModel(**response)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,8 @@
|
|||
import asyncio
|
||||
import json
|
||||
from google.genai.types import GenerateContentConfig
|
||||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from utils.llm_provider import (
|
||||
get_google_llm_client,
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
is_google_selected,
|
||||
)
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
system_prompt = """
|
||||
|
|
@ -45,57 +39,38 @@ def get_user_prompt(title: str, outline: str, language: str):
|
|||
"""
|
||||
|
||||
|
||||
def get_prompt_to_generate_slide_content(title: str, outline: str, language: str):
|
||||
def get_messages(title: str, outline: str, language: str):
|
||||
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(title, outline, language),
|
||||
},
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=get_user_prompt(title, outline, language),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel, language: str
|
||||
):
|
||||
model = get_large_model()
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
response_schema = remove_fields_from_schema(
|
||||
slide_layout.json_schema, ["__image_url__", "__icon_url__"]
|
||||
)
|
||||
|
||||
if not is_google_selected():
|
||||
client = get_llm_client()
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_generate_slide_content(
|
||||
outline.title,
|
||||
outline.body,
|
||||
language,
|
||||
),
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "SlideContent",
|
||||
"schema": response_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
else:
|
||||
client = get_google_llm_client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model=model,
|
||||
contents=[get_user_prompt(outline.title, outline.body, language)],
|
||||
config=GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
response_mime_type="application/json",
|
||||
response_json_schema=response_schema,
|
||||
),
|
||||
)
|
||||
return json.loads(response.text)
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
messages=get_messages(
|
||||
outline.title,
|
||||
outline.body,
|
||||
language,
|
||||
),
|
||||
response_format=response_schema,
|
||||
strict=False,
|
||||
)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from models.llm_message import LLMMessage
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.slide_layout_index import SlideLayoutIndex
|
||||
from models.sql.slide import SlideModel
|
||||
from utils.llm_provider import get_large_model, get_llm_client
|
||||
from services.llm_client import LLMClient
|
||||
from utils.llm_provider import get_model
|
||||
|
||||
|
||||
def get_prompt_to_select_slide_layout(
|
||||
def get_messages(
|
||||
prompt: str,
|
||||
slide_data: dict,
|
||||
layout: PresentationLayoutModel,
|
||||
current_slide_layout: int,
|
||||
):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
LLMMessage(
|
||||
role="system",
|
||||
content=f"""
|
||||
Select a Slide Layout index based on provided user prompt and current slide data.
|
||||
{layout.to_string()}
|
||||
|
||||
|
|
@ -23,15 +25,15 @@ def get_prompt_to_select_slide_layout(
|
|||
- If user prompt is not clear, select the layout that is most relevant to the slide data.
|
||||
**Go through all notes and steps and make sure they are followed, including mentioned constraints**
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
),
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=f"""
|
||||
- User Prompt: {prompt}
|
||||
- Current Slide Data: {slide_data}
|
||||
- Current Slide Layout: {current_slide_layout}
|
||||
""",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -41,21 +43,21 @@ async def get_slide_layout_from_prompt(
|
|||
slide: SlideModel,
|
||||
) -> SlideLayoutModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
client = LLMClient()
|
||||
model = get_model()
|
||||
|
||||
slide_layout_ids = list(map(lambda x: x.id, layout.slides))
|
||||
slide_layout_index = layout.get_slide_layout_index(slide.layout)
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
response = await client.generate_structured(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_to_select_slide_layout(
|
||||
messages=get_messages(
|
||||
prompt,
|
||||
slide.content,
|
||||
layout,
|
||||
slide_layout_ids.index(slide.layout),
|
||||
slide_layout_index,
|
||||
),
|
||||
response_format=SlideLayoutIndex,
|
||||
response_format=SlideLayoutIndex.model_json_schema(),
|
||||
strict=True,
|
||||
)
|
||||
index = response.choices[0].message.parsed.index
|
||||
index = SlideLayoutIndex(**response).index
|
||||
return layout.slides[index]
|
||||
|
|
|
|||
|
|
@ -1,36 +1,31 @@
|
|||
import os
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
||||
from constants.llm import (
|
||||
DEFAULT_ANTHROPIC_MODEL,
|
||||
DEFAULT_GOOGLE_MODEL,
|
||||
DEFAULT_OPENAI_MODEL,
|
||||
)
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.get_env import (
|
||||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_anthropic_model_env,
|
||||
get_custom_model_env,
|
||||
get_google_api_key_env,
|
||||
get_google_model_env,
|
||||
get_llm_provider_env,
|
||||
get_ollama_model_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_llm_provider():
|
||||
try:
|
||||
return LLMProvider(get_llm_provider_env())
|
||||
except:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, ollama, custom",
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
)
|
||||
|
||||
|
||||
def get_ollama_url():
|
||||
return get_ollama_url_env() or "http://localhost:11434"
|
||||
|
||||
|
||||
def is_openai_selected():
|
||||
return get_llm_provider() == LLMProvider.OPENAI
|
||||
|
||||
|
|
@ -39,6 +34,10 @@ def is_google_selected():
|
|||
return get_llm_provider() == LLMProvider.GOOGLE
|
||||
|
||||
|
||||
def is_anthropic_selected():
|
||||
return get_llm_provider() == LLMProvider.ANTHROPIC
|
||||
|
||||
|
||||
def is_ollama_selected():
|
||||
return get_llm_provider() == LLMProvider.OLLAMA
|
||||
|
||||
|
|
@ -47,85 +46,20 @@ def is_custom_llm_selected():
|
|||
return get_llm_provider() == LLMProvider.CUSTOM
|
||||
|
||||
|
||||
def get_model_base_url():
|
||||
selected_llm = get_llm_provider()
|
||||
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return os.path.join(get_ollama_url(), "v1")
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_llm_url_env()
|
||||
else:
|
||||
raise HTTPException(f"LLM provider {selected_llm} is not supported")
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
def get_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return get_openai_api_key_env()
|
||||
return get_openai_model_env() or DEFAULT_OPENAI_MODEL
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return get_google_api_key_env()
|
||||
return get_google_model_env() or DEFAULT_GOOGLE_MODEL
|
||||
elif selected_llm == LLMProvider.ANTHROPIC:
|
||||
return get_anthropic_model_env() or DEFAULT_ANTHROPIC_MODEL
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return "ollama"
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_llm_api_key_env() or "none"
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise HTTPException(f"LLM provider {selected_llm} is not supported")
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
client = AsyncOpenAI(
|
||||
base_url=get_model_base_url(),
|
||||
api_key=get_llm_api_key(),
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Invalid LLM provider. Please select one of: openai, google, anthropic, ollama, custom",
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_google_llm_client():
|
||||
client = genai.Client(api_key=get_llm_api_key())
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return get_ollama_model_env()
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return get_custom_model_env()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
import os
|
||||
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from constants.llm import OPENAI_URL
|
||||
from enums.image_provider import ImageProvider
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.custom_llm_provider import list_available_custom_models
|
||||
from utils.available_models import (
|
||||
list_available_anthropic_models,
|
||||
list_available_google_models,
|
||||
list_available_openai_compatible_models,
|
||||
)
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_anthropic_model_env,
|
||||
get_can_change_keys_env,
|
||||
get_google_model_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
get_pixabay_api_key_env,
|
||||
get_pexels_api_key_env,
|
||||
)
|
||||
|
|
@ -20,13 +28,7 @@ from utils.llm_provider import (
|
|||
is_ollama_selected,
|
||||
)
|
||||
from utils.ollama import pull_ollama_model
|
||||
from utils.image_provider import (
|
||||
get_selected_image_provider,
|
||||
is_pixels_selected,
|
||||
is_pixabay_selected,
|
||||
is_gemini_flash_selected,
|
||||
is_dalle3_selected,
|
||||
)
|
||||
from utils.image_provider import get_selected_image_provider
|
||||
|
||||
|
||||
async def check_llm_and_image_provider_api_or_model_availability():
|
||||
|
|
@ -36,11 +38,41 @@ async def check_llm_and_image_provider_api_or_model_availability():
|
|||
openai_api_key = get_openai_api_key_env()
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
openai_model = get_openai_model_env()
|
||||
if openai_model:
|
||||
available_models = await list_available_openai_compatible_models(
|
||||
OPENAI_URL, openai_api_key
|
||||
)
|
||||
if openai_model not in available_models:
|
||||
print("-" * 50)
|
||||
print("Available models: ", available_models)
|
||||
raise Exception(f"Model {openai_model} is not available")
|
||||
|
||||
elif get_llm_provider() == LLMProvider.GOOGLE:
|
||||
google_api_key = get_google_api_key_env()
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
google_model = get_google_model_env()
|
||||
if google_model:
|
||||
available_models = await list_available_google_models(google_api_key)
|
||||
if google_model not in available_models:
|
||||
print("-" * 50)
|
||||
print("Available models: ", available_models)
|
||||
raise Exception(f"Model {google_model} is not available")
|
||||
|
||||
elif get_llm_provider() == LLMProvider.ANTHROPIC:
|
||||
anthropic_api_key = get_anthropic_api_key_env()
|
||||
if not anthropic_api_key:
|
||||
raise Exception("ANTHROPIC_API_KEY must be provided")
|
||||
anthropic_model = get_anthropic_model_env()
|
||||
if anthropic_model:
|
||||
available_models = await list_available_anthropic_models(
|
||||
anthropic_api_key
|
||||
)
|
||||
if anthropic_model not in available_models:
|
||||
print("-" * 50)
|
||||
print("Available models: ", available_models)
|
||||
raise Exception(f"Model {anthropic_model} is not available")
|
||||
|
||||
elif is_ollama_selected():
|
||||
ollama_model = get_ollama_model_env()
|
||||
|
|
@ -67,14 +99,12 @@ async def check_llm_and_image_provider_api_or_model_availability():
|
|||
raise Exception("CUSTOM_LLM_URL must be provided")
|
||||
if not custom_llm_api_key:
|
||||
raise Exception("CUSTOM_LLM_API_KEY must be provided")
|
||||
print("-" * 50)
|
||||
print("Selecting model: ", custom_model)
|
||||
models = await list_available_custom_models(
|
||||
available_models = await list_available_openai_compatible_models(
|
||||
custom_llm_url, custom_llm_api_key
|
||||
)
|
||||
print("Available models: ", models)
|
||||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
print("Available models: ", available_models)
|
||||
if custom_model not in available_models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
|
||||
# Check for Image Provider and API keys
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from http.client import HTTPException
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from utils.get_env import get_ollama_url_env
|
||||
|
|
|
|||
|
|
@ -21,19 +21,19 @@ async def process_slide_and_fetch_assets(
|
|||
icon_paths = get_dict_paths_with_key(slide.content, "__icon_query__")
|
||||
|
||||
for image_path in image_paths:
|
||||
image_prompt_parent = get_dict_at_path(slide.content, image_path)
|
||||
__image_prompt__parent = get_dict_at_path(slide.content, image_path)
|
||||
async_tasks.append(
|
||||
image_generation_service.generate_image(
|
||||
ImagePrompt(
|
||||
prompt=image_prompt_parent["__image_prompt__"],
|
||||
prompt=__image_prompt__parent["__image_prompt__"],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for icon_path in icon_paths:
|
||||
icon_query_parent = get_dict_at_path(slide.content, icon_path)
|
||||
__icon_query__parent = get_dict_at_path(slide.content, icon_path)
|
||||
async_tasks.append(
|
||||
icon_finder_service.search_icons(icon_query_parent["__icon_query__"])
|
||||
icon_finder_service.search_icons(__icon_query__parent["__icon_query__"])
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*async_tasks)
|
||||
|
|
|
|||
|
|
@ -1,30 +1,25 @@
|
|||
from copy import deepcopy
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from utils.dict_utils import get_dict_paths_with_key, get_dict_at_path
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from utils.dict_utils import (
|
||||
get_dict_paths_with_key,
|
||||
get_dict_at_path,
|
||||
has_more_than_n_keys,
|
||||
)
|
||||
|
||||
def resolve_refs(schema, defs):
|
||||
if isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_path = schema["$ref"]
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_key = ref_path.replace("#/$defs/", "")
|
||||
return resolve_refs(defs[def_key], defs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported $ref path: {ref_path}")
|
||||
else:
|
||||
return {k: resolve_refs(v, defs) for k, v in schema.items()}
|
||||
elif isinstance(schema, list):
|
||||
return [resolve_refs(item, defs) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
|
||||
def flatten_schema(schema):
|
||||
schema = deepcopy(schema)
|
||||
defs = schema.pop("$defs", {})
|
||||
return resolve_refs(schema, defs)
|
||||
supported_string_formats = [
|
||||
"date-time",
|
||||
"time",
|
||||
"date",
|
||||
"duration",
|
||||
"email",
|
||||
"hostname",
|
||||
"ipv4",
|
||||
"ipv6",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
||||
def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
|
||||
|
|
@ -50,6 +45,138 @@ def remove_fields_from_schema(schema: dict, fields_to_remove: List[str]):
|
|||
return schema
|
||||
|
||||
|
||||
# From OpenAI
|
||||
def ensure_strict_json_schema(
|
||||
json_schema: object,
|
||||
*,
|
||||
path: tuple[str, ...],
|
||||
root: dict[str, object],
|
||||
) -> dict[str, Any]:
|
||||
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
|
||||
that the API expects.
|
||||
"""
|
||||
if not isinstance(json_schema, dict):
|
||||
raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
|
||||
|
||||
defs = json_schema.get("$defs")
|
||||
if isinstance(defs, dict):
|
||||
for def_name, def_schema in defs.items():
|
||||
ensure_strict_json_schema(
|
||||
def_schema, path=(*path, "$defs", def_name), root=root
|
||||
)
|
||||
|
||||
definitions = json_schema.get("definitions")
|
||||
if isinstance(definitions, dict):
|
||||
for definition_name, definition_schema in definitions.items():
|
||||
ensure_strict_json_schema(
|
||||
definition_schema,
|
||||
path=(*path, "definitions", definition_name),
|
||||
root=root,
|
||||
)
|
||||
|
||||
typ = json_schema.get("type")
|
||||
if typ == "object" and "additionalProperties" not in json_schema:
|
||||
json_schema["additionalProperties"] = False
|
||||
|
||||
# object types
|
||||
# { 'type': 'object', 'properties': { 'a': {...} } }
|
||||
properties = json_schema.get("properties")
|
||||
if isinstance(properties, dict):
|
||||
json_schema["required"] = [prop for prop in properties.keys()]
|
||||
json_schema["properties"] = {
|
||||
key: ensure_strict_json_schema(
|
||||
prop_schema, path=(*path, "properties", key), root=root
|
||||
)
|
||||
for key, prop_schema in properties.items()
|
||||
}
|
||||
|
||||
# arrays
|
||||
# { 'type': 'array', 'items': {...} }
|
||||
items = json_schema.get("items")
|
||||
if isinstance(items, dict):
|
||||
json_schema["items"] = ensure_strict_json_schema(
|
||||
items, path=(*path, "items"), root=root
|
||||
)
|
||||
|
||||
# unions
|
||||
any_of = json_schema.get("anyOf")
|
||||
if isinstance(any_of, list):
|
||||
json_schema["anyOf"] = [
|
||||
ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
|
||||
for i, variant in enumerate(any_of)
|
||||
]
|
||||
|
||||
# intersections
|
||||
all_of = json_schema.get("allOf")
|
||||
if isinstance(all_of, list):
|
||||
if len(all_of) == 1:
|
||||
json_schema.update(
|
||||
ensure_strict_json_schema(
|
||||
all_of[0], path=(*path, "allOf", "0"), root=root
|
||||
)
|
||||
)
|
||||
json_schema.pop("allOf")
|
||||
else:
|
||||
json_schema["allOf"] = [
|
||||
ensure_strict_json_schema(
|
||||
entry, path=(*path, "allOf", str(i)), root=root
|
||||
)
|
||||
for i, entry in enumerate(all_of)
|
||||
]
|
||||
|
||||
# string
|
||||
if typ == "string":
|
||||
if "format" in json_schema:
|
||||
if json_schema["format"] not in supported_string_formats:
|
||||
del json_schema["format"]
|
||||
|
||||
# strip `None` defaults as there's no meaningful distinction here
|
||||
# the schema will still be `nullable` and the model will default
|
||||
# to using `None` anyway
|
||||
if json_schema.get("default", NOT_GIVEN) is None:
|
||||
json_schema.pop("default")
|
||||
|
||||
# we can't use `$ref`s if there are also other properties defined, e.g.
|
||||
# `{"$ref": "...", "description": "my description"}`
|
||||
#
|
||||
# so we unravel the ref
|
||||
# `{"type": "string", "description": "my description"}`
|
||||
ref = json_schema.get("$ref")
|
||||
if ref and has_more_than_n_keys(json_schema, 1):
|
||||
assert isinstance(ref, str), f"Received non-string $ref - {ref}"
|
||||
|
||||
resolved = resolve_ref(root=root, ref=ref)
|
||||
if not isinstance(resolved, dict):
|
||||
raise ValueError(
|
||||
f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}"
|
||||
)
|
||||
|
||||
# properties from the json schema take priority over the ones on the `$ref`
|
||||
json_schema.update({**resolved, **json_schema})
|
||||
json_schema.pop("$ref")
|
||||
# Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
|
||||
# we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
|
||||
return ensure_strict_json_schema(json_schema, path=path, root=root)
|
||||
|
||||
return json_schema
|
||||
|
||||
|
||||
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
|
||||
if not ref.startswith("#/"):
|
||||
raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
|
||||
|
||||
path = ref[2:].split("/")
|
||||
resolved = root
|
||||
for key in path:
|
||||
value = resolved[key]
|
||||
assert isinstance(
|
||||
value, dict
|
||||
), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
|
||||
resolved = value
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
# ? Not used
|
||||
def generate_constraint_sentences(schema: dict) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,10 +25,26 @@ def set_openai_api_key_env(value):
|
|||
os.environ["OPENAI_API_KEY"] = value
|
||||
|
||||
|
||||
def set_openai_model_env(value):
|
||||
os.environ["OPENAI_MODEL"] = value
|
||||
|
||||
|
||||
def set_google_api_key_env(value):
|
||||
os.environ["GOOGLE_API_KEY"] = value
|
||||
|
||||
|
||||
def set_google_model_env(value):
|
||||
os.environ["GOOGLE_MODEL"] = value
|
||||
|
||||
|
||||
def set_anthropic_api_key_env(value):
|
||||
os.environ["ANTHROPIC_API_KEY"] = value
|
||||
|
||||
|
||||
def set_anthropic_model_env(value):
|
||||
os.environ["ANTHROPIC_MODEL"] = value
|
||||
|
||||
|
||||
def set_custom_llm_api_key_env(value):
|
||||
os.environ["CUSTOM_LLM_API_KEY"] = value
|
||||
|
||||
|
|
@ -51,3 +67,7 @@ def set_image_provider_env(value):
|
|||
|
||||
def set_pixabay_api_key_env(value):
|
||||
os.environ["PIXABAY_API_KEY"] = value
|
||||
|
||||
|
||||
def set_extended_reasoning_env(value):
|
||||
os.environ["EXTENDED_REASONING"] = value
|
||||
|
|
|
|||
|
|
@ -3,31 +3,41 @@ import json
|
|||
|
||||
from models.user_config import UserConfig
|
||||
from utils.get_env import (
|
||||
get_anthropic_api_key_env,
|
||||
get_anthropic_model_env,
|
||||
get_custom_llm_api_key_env,
|
||||
get_custom_llm_url_env,
|
||||
get_custom_model_env,
|
||||
get_google_api_key_env,
|
||||
get_google_model_env,
|
||||
get_llm_provider_env,
|
||||
get_ollama_model_env,
|
||||
get_ollama_url_env,
|
||||
get_openai_api_key_env,
|
||||
get_openai_model_env,
|
||||
get_pexels_api_key_env,
|
||||
get_user_config_path_env,
|
||||
get_image_provider_env,
|
||||
get_pixabay_api_key_env
|
||||
get_pixabay_api_key_env,
|
||||
get_extended_reasoning_env,
|
||||
)
|
||||
from utils.set_env import (
|
||||
set_anthropic_api_key_env,
|
||||
set_anthropic_model_env,
|
||||
set_custom_llm_api_key_env,
|
||||
set_custom_llm_url_env,
|
||||
set_custom_model_env,
|
||||
set_extended_reasoning_env,
|
||||
set_google_api_key_env,
|
||||
set_google_model_env,
|
||||
set_llm_provider_env,
|
||||
set_ollama_model_env,
|
||||
set_ollama_url_env,
|
||||
set_openai_api_key_env,
|
||||
set_openai_model_env,
|
||||
set_pexels_api_key_env,
|
||||
set_image_provider_env,
|
||||
set_pixabay_api_key_env
|
||||
set_pixabay_api_key_env,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -43,10 +53,21 @@ def get_user_config():
|
|||
print("Error while loading user config")
|
||||
pass
|
||||
|
||||
new_extended_reasoning = (
|
||||
existing_config.EXTENDED_REASONING or get_extended_reasoning_env()
|
||||
)
|
||||
if new_extended_reasoning is not None:
|
||||
new_extended_reasoning = bool(new_extended_reasoning)
|
||||
|
||||
return UserConfig(
|
||||
LLM=existing_config.LLM or get_llm_provider_env(),
|
||||
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or get_openai_api_key_env(),
|
||||
OPENAI_MODEL=existing_config.OPENAI_MODEL or get_openai_model_env(),
|
||||
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or get_google_api_key_env(),
|
||||
GOOGLE_MODEL=existing_config.GOOGLE_MODEL or get_google_model_env(),
|
||||
ANTHROPIC_API_KEY=existing_config.ANTHROPIC_API_KEY
|
||||
or get_anthropic_api_key_env(),
|
||||
ANTHROPIC_MODEL=existing_config.ANTHROPIC_MODEL or get_anthropic_model_env(),
|
||||
OLLAMA_URL=existing_config.OLLAMA_URL or get_ollama_url_env(),
|
||||
OLLAMA_MODEL=existing_config.OLLAMA_MODEL or get_ollama_model_env(),
|
||||
CUSTOM_LLM_URL=existing_config.CUSTOM_LLM_URL or get_custom_llm_url_env(),
|
||||
|
|
@ -56,6 +77,7 @@ def get_user_config():
|
|||
IMAGE_PROVIDER=existing_config.IMAGE_PROVIDER or get_image_provider_env(),
|
||||
PIXABAY_API_KEY=existing_config.PIXABAY_API_KEY or get_pixabay_api_key_env(),
|
||||
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(),
|
||||
EXTENDED_REASONING=new_extended_reasoning,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -65,8 +87,16 @@ def update_env_with_user_config():
|
|||
set_llm_provider_env(user_config.LLM)
|
||||
if user_config.OPENAI_API_KEY:
|
||||
set_openai_api_key_env(user_config.OPENAI_API_KEY)
|
||||
if user_config.OPENAI_MODEL:
|
||||
set_openai_model_env(user_config.OPENAI_MODEL)
|
||||
if user_config.GOOGLE_API_KEY:
|
||||
set_google_api_key_env(user_config.GOOGLE_API_KEY)
|
||||
if user_config.GOOGLE_MODEL:
|
||||
set_google_model_env(user_config.GOOGLE_MODEL)
|
||||
if user_config.ANTHROPIC_API_KEY:
|
||||
set_anthropic_api_key_env(user_config.ANTHROPIC_API_KEY)
|
||||
if user_config.ANTHROPIC_MODEL:
|
||||
set_anthropic_model_env(user_config.ANTHROPIC_MODEL)
|
||||
if user_config.OLLAMA_URL:
|
||||
set_ollama_url_env(user_config.OLLAMA_URL)
|
||||
if user_config.OLLAMA_MODEL:
|
||||
|
|
@ -83,3 +113,6 @@ def update_env_with_user_config():
|
|||
set_pixabay_api_key_env(user_config.PIXABAY_API_KEY)
|
||||
if user_config.PEXELS_API_KEY:
|
||||
set_pexels_api_key_env(user_config.PEXELS_API_KEY)
|
||||
if user_config.EXTENDED_REASONING:
|
||||
if user_config.EXTENDED_REASONING:
|
||||
set_extended_reasoning_env(str(user_config.EXTENDED_REASONING))
|
||||
|
|
|
|||
|
|
@ -35,7 +35,11 @@ export async function POST(request: Request) {
|
|||
const mergedConfig: LLMConfig = {
|
||||
LLM: userConfig.LLM || existingConfig.LLM,
|
||||
OPENAI_API_KEY: userConfig.OPENAI_API_KEY || existingConfig.OPENAI_API_KEY,
|
||||
OPENAI_MODEL: userConfig.OPENAI_MODEL || existingConfig.OPENAI_MODEL,
|
||||
GOOGLE_API_KEY: userConfig.GOOGLE_API_KEY || existingConfig.GOOGLE_API_KEY,
|
||||
GOOGLE_MODEL: userConfig.GOOGLE_MODEL || existingConfig.GOOGLE_MODEL,
|
||||
ANTHROPIC_API_KEY: userConfig.ANTHROPIC_API_KEY || existingConfig.ANTHROPIC_API_KEY,
|
||||
ANTHROPIC_MODEL: userConfig.ANTHROPIC_MODEL || existingConfig.ANTHROPIC_MODEL,
|
||||
OLLAMA_URL: userConfig.OLLAMA_URL || existingConfig.OLLAMA_URL,
|
||||
OLLAMA_MODEL: userConfig.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL,
|
||||
CUSTOM_LLM_URL: userConfig.CUSTOM_LLM_URL || existingConfig.CUSTOM_LLM_URL,
|
||||
|
|
@ -50,6 +54,10 @@ export async function POST(request: Request) {
|
|||
userConfig.USE_CUSTOM_URL === undefined
|
||||
? existingConfig.USE_CUSTOM_URL
|
||||
: userConfig.USE_CUSTOM_URL,
|
||||
EXTENDED_REASONING:
|
||||
userConfig.EXTENDED_REASONING === undefined
|
||||
? existingConfig.EXTENDED_REASONING
|
||||
: userConfig.EXTENDED_REASONING,
|
||||
};
|
||||
fs.writeFileSync(userConfigPath, JSON.stringify(mergedConfig));
|
||||
return NextResponse.json(mergedConfig);
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ const SettingsPage = () => {
|
|||
const userConfigState = useSelector((state: RootState) => state.userConfig);
|
||||
const [llmConfig, setLlmConfig] = useState<LLMConfig>(userConfigState.llm_config);
|
||||
const canChangeKeys = userConfigState.can_change_keys;
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
const [buttonState, setButtonState] = useState<ButtonState>({
|
||||
isLoading: false,
|
||||
isDisabled: false,
|
||||
|
|
@ -55,16 +54,13 @@ const SettingsPage = () => {
|
|||
|
||||
const handleSaveConfig = async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
isLoading: true,
|
||||
isDisabled: true,
|
||||
text: "Saving Configuration..."
|
||||
}));
|
||||
|
||||
await handleSaveLLMConfig(llmConfig);
|
||||
|
||||
if (llmConfig.LLM === "ollama" && llmConfig.OLLAMA_MODEL) {
|
||||
const isPulled = await checkIfSelectedOllamaModelIsPulled(llmConfig.OLLAMA_MODEL);
|
||||
if (!isPulled) {
|
||||
|
|
@ -72,24 +68,16 @@ const SettingsPage = () => {
|
|||
await handleModelDownload();
|
||||
}
|
||||
}
|
||||
|
||||
toast.info("Configuration saved successfully");
|
||||
setIsLoading(false);
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
isDisabled: false,
|
||||
text: "Save Configuration"
|
||||
}));
|
||||
router.back();
|
||||
router.push("/upload");
|
||||
} catch (error) {
|
||||
console.error("Error:", error);
|
||||
toast.info(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: "Failed to save configuration"
|
||||
);
|
||||
setIsLoading(false);
|
||||
toast.info(error instanceof Error ? error.message : "Failed to save configuration");
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
|
|
@ -102,8 +90,8 @@ const SettingsPage = () => {
|
|||
const handleModelDownload = async () => {
|
||||
try {
|
||||
await pullOllamaModel(llmConfig.OLLAMA_MODEL!, setDownloadingModel);
|
||||
} catch (error) {
|
||||
console.error("Error downloading model:", error);
|
||||
}
|
||||
finally {
|
||||
setDownloadingModel(null);
|
||||
setShowDownloadModal(false);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) {
|
|||
}
|
||||
}
|
||||
if (llmConfig.LLM === 'custom') {
|
||||
const isAvailable = await checkIfSelectedCustomModelIsAvailable(llmConfig.CUSTOM_MODEL);
|
||||
const isAvailable = await checkIfSelectedCustomModelIsAvailable(llmConfig);
|
||||
if (!isAvailable) {
|
||||
router.push('/');
|
||||
setLoadingToFalseAfterNavigatingTo('/');
|
||||
|
|
@ -83,16 +83,20 @@ export function StoreInitializer({ children }: { children: React.ReactNode }) {
|
|||
}
|
||||
|
||||
|
||||
const checkIfSelectedCustomModelIsAvailable = async (customModel: string) => {
|
||||
const checkIfSelectedCustomModelIsAvailable = async (llmConfig: LLMConfig) => {
|
||||
try {
|
||||
const response = await fetch('/api/v1/ppt/custom_llm/models/available', {
|
||||
const response = await fetch('/api/v1/ppt/openai/models/available', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
url: llmConfig.CUSTOM_LLM_URL,
|
||||
api_key: llmConfig.CUSTOM_LLM_API_KEY,
|
||||
}),
|
||||
});
|
||||
const data = await response.json();
|
||||
return data.includes(customModel);
|
||||
return data.includes(llmConfig.CUSTOM_MODEL);
|
||||
} catch (error) {
|
||||
console.error('Error fetching custom models:', error);
|
||||
return false;
|
||||
|
|
|
|||
230
servers/nextjs/components/AnthropicConfig.tsx
Normal file
230
servers/nextjs/components/AnthropicConfig.tsx
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"use client";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Check, ChevronsUpDown, Loader2 } from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "./ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
import { Switch } from "./ui/switch";
|
||||
|
||||
interface AnthropicConfigProps {
|
||||
anthropicApiKey: string;
|
||||
anthropicModel: string;
|
||||
extendedReasoning: boolean;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
|
||||
export default function AnthropicConfig({
|
||||
anthropicApiKey,
|
||||
anthropicModel,
|
||||
extendedReasoning,
|
||||
onInputChange,
|
||||
}: AnthropicConfigProps) {
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [availableModels, setAvailableModels] = useState<string[]>([]);
|
||||
const [modelsLoading, setModelsLoading] = useState(false);
|
||||
const [modelsChecked, setModelsChecked] = useState(false);
|
||||
const [apiKey, setApiKey] = useState(anthropicApiKey);
|
||||
|
||||
useEffect(() => {
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(false);
|
||||
onInputChange("", "anthropic_model");
|
||||
}, [apiKey]);
|
||||
|
||||
const onApiKeyChange = (value: string) => {
|
||||
setApiKey(value);
|
||||
onInputChange(value, "anthropic_api_key");
|
||||
};
|
||||
|
||||
const fetchAvailableModels = async () => {
|
||||
if (!anthropicApiKey) return;
|
||||
|
||||
setModelsLoading(true);
|
||||
try {
|
||||
const response = await fetch('/api/v1/ppt/anthropic/models/available', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_key: anthropicApiKey
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
setAvailableModels(data);
|
||||
setModelsChecked(true);
|
||||
} else {
|
||||
console.error('Failed to fetch models');
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(true);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
toast.error('Error fetching models');
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(true);
|
||||
} finally {
|
||||
setModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* API Key Input */}
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Anthropic API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
value={anthropicApiKey}
|
||||
onChange={(e) => onApiKeyChange(e.target.value)}
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
placeholder="Enter your Anthropic API key"
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Your API key will be stored locally and never shared
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Extended Reasoning Toggle */}
|
||||
{/* <div>
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Extended Reasoning
|
||||
</label>
|
||||
<Switch
|
||||
checked={extendedReasoning}
|
||||
onCheckedChange={(checked) => onInputChange(checked, "extended_reasoning")}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Enable extended reasoning for more detailed and thorough responses
|
||||
</p>
|
||||
</div> */}
|
||||
|
||||
{/* Check for available models button - show when no models checked or no models found */}
|
||||
{(!modelsChecked || (modelsChecked && availableModels.length === 0)) && (
|
||||
<div className="mb-4">
|
||||
<button
|
||||
onClick={fetchAvailableModels}
|
||||
disabled={modelsLoading || !anthropicApiKey}
|
||||
className={`w-full py-2.5 px-4 rounded-lg transition-all duration-200 border-2 ${modelsLoading || !anthropicApiKey
|
||||
? "bg-gray-100 border-gray-300 cursor-not-allowed text-gray-500"
|
||||
: "bg-white border-blue-600 text-blue-600 hover:bg-blue-50 focus:ring-2 focus:ring-blue-500/20"
|
||||
}`}
|
||||
>
|
||||
{modelsLoading ? (
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Checking for models...
|
||||
</div>
|
||||
) : (
|
||||
"Check for available models"
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show message if no models found */}
|
||||
{modelsChecked && availableModels.length === 0 && (
|
||||
<div className="mb-4 p-3 bg-yellow-50 border border-yellow-200 rounded-lg">
|
||||
<p className="text-sm text-yellow-800">
|
||||
No models found. Please make sure your API key is valid and has access to Anthropic models.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Model Selection - only show if models are available */}
|
||||
{modelsChecked && availableModels.length > 0 ? (
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Select Anthropic Model
|
||||
</label>
|
||||
<div className="w-full">
|
||||
<Popover
|
||||
open={openModelSelect}
|
||||
onOpenChange={setOpenModelSelect}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{anthropicModel
|
||||
? availableModels.find(model => model === anthropicModel) || anthropicModel
|
||||
: "Select a model"}
|
||||
</span>
|
||||
</div>
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="p-0"
|
||||
align="start"
|
||||
style={{ width: "var(--radix-popover-trigger-width)" }}
|
||||
>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search models..." />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{availableModels.map((model, index) => (
|
||||
<CommandItem
|
||||
key={index}
|
||||
value={model}
|
||||
onSelect={(value) => {
|
||||
onInputChange(value, "anthropic_model");
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
anthropicModel === model
|
||||
? "opacity-100"
|
||||
: "opacity-0"
|
||||
)}
|
||||
/>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{model}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
"use client";
|
||||
import { useState, useEffect } from "react";
|
||||
import { Check, ChevronsUpDown, Loader2 } from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
|
|
@ -11,34 +12,83 @@ import {
|
|||
} from "./ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface CustomConfigProps {
|
||||
customLlmUrl: string;
|
||||
customLlmApiKey: string;
|
||||
customModel: string;
|
||||
customModels: string[];
|
||||
customModelsLoading: boolean;
|
||||
customModelsChecked: boolean;
|
||||
openModelSelect: boolean;
|
||||
onInputChange: (value: string, field: string) => void;
|
||||
onOpenModelSelectChange: (open: boolean) => void;
|
||||
onFetchCustomModels: () => void;
|
||||
}
|
||||
|
||||
export default function CustomConfig({
|
||||
customLlmUrl,
|
||||
customLlmApiKey,
|
||||
customModel,
|
||||
customModels,
|
||||
customModelsLoading,
|
||||
customModelsChecked,
|
||||
openModelSelect,
|
||||
onInputChange,
|
||||
onOpenModelSelectChange,
|
||||
onFetchCustomModels,
|
||||
}: CustomConfigProps) {
|
||||
const [customModels, setCustomModels] = useState<string[]>([]);
|
||||
const [customModelsLoading, setCustomModelsLoading] = useState(false);
|
||||
const [customModelsChecked, setCustomModelsChecked] = useState(false);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [url, setUrl] = useState(customLlmUrl);
|
||||
const [apiKey, setApiKey] = useState(customLlmApiKey);
|
||||
|
||||
useEffect(() => {
|
||||
setCustomModels([]);
|
||||
setCustomModelsChecked(false);
|
||||
onInputChange("", "custom_model");
|
||||
}, [url, apiKey]);
|
||||
|
||||
const onUrlChange = (value: string) => {
|
||||
setUrl(value);
|
||||
onInputChange(value, "custom_llm_url");
|
||||
};
|
||||
|
||||
const onApiKeyChange = (value: string) => {
|
||||
setApiKey(value);
|
||||
onInputChange(value, "custom_llm_api_key");
|
||||
};
|
||||
|
||||
const fetchCustomModels = async () => {
|
||||
if (!customLlmUrl) return;
|
||||
|
||||
try {
|
||||
setCustomModelsLoading(true);
|
||||
const response = await fetch("/api/v1/ppt/openai/models/available", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
url: customLlmUrl,
|
||||
api_key: customLlmApiKey,
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
setCustomModels(data);
|
||||
setCustomModelsChecked(true);
|
||||
} else {
|
||||
console.error('Failed to fetch custom models');
|
||||
setCustomModels([]);
|
||||
setCustomModelsChecked(true);
|
||||
toast.error('Failed to fetch custom models');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching custom models:', error);
|
||||
toast.error('Error fetching custom models');
|
||||
setCustomModels([]);
|
||||
setCustomModelsChecked(true);
|
||||
} finally {
|
||||
setCustomModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="space-y-6">
|
||||
{/* URL Input */}
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
OpenAI Compatible URL
|
||||
|
|
@ -50,12 +100,12 @@ export default function CustomConfig({
|
|||
placeholder="Enter your URL"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={customLlmUrl}
|
||||
onChange={(e) =>
|
||||
onInputChange(e.target.value, "custom_llm_url")
|
||||
}
|
||||
onChange={(e) => onUrlChange(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* API Key Input */}
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
OpenAI Compatible API Key
|
||||
|
|
@ -67,13 +117,43 @@ export default function CustomConfig({
|
|||
placeholder="Enter your API Key"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={customLlmApiKey}
|
||||
onChange={(e) =>
|
||||
onInputChange(e.target.value, "custom_llm_api_key")
|
||||
}
|
||||
onChange={(e) => onApiKeyChange(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Check for available models button - show when no models checked or no models found */}
|
||||
{(!customModelsChecked || (customModelsChecked && customModels.length === 0)) && (
|
||||
<div className="mb-4">
|
||||
<button
|
||||
onClick={fetchCustomModels}
|
||||
disabled={customModelsLoading || !customLlmUrl}
|
||||
className={`w-full py-2.5 px-4 rounded-lg transition-all duration-200 border-2 ${customModelsLoading || !customLlmUrl
|
||||
? "bg-gray-100 border-gray-300 cursor-not-allowed text-gray-500"
|
||||
: "bg-white border-blue-600 text-blue-600 hover:bg-blue-50 focus:ring-2 focus:ring-blue-500/20"
|
||||
}`}
|
||||
>
|
||||
{customModelsLoading ? (
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Checking for models...
|
||||
</div>
|
||||
) : (
|
||||
"Check for available models"
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show message if no models found */}
|
||||
{customModelsChecked && customModels.length === 0 && (
|
||||
<div className="mb-4 p-3 bg-yellow-50 border border-yellow-200 rounded-lg">
|
||||
<p className="text-sm text-yellow-800">
|
||||
No models found. Please make sure your API key is valid and has access to models.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Model selection dropdown - only show if models are available */}
|
||||
{customModelsChecked && customModels.length > 0 && (
|
||||
<div className="mb-4">
|
||||
|
|
@ -90,7 +170,7 @@ export default function CustomConfig({
|
|||
<div className="w-full">
|
||||
<Popover
|
||||
open={openModelSelect}
|
||||
onOpenChange={onOpenModelSelectChange}
|
||||
onOpenChange={setOpenModelSelect}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
|
|
@ -121,7 +201,7 @@ export default function CustomConfig({
|
|||
value={model}
|
||||
onSelect={(value) => {
|
||||
onInputChange(value, "custom_model");
|
||||
onOpenModelSelectChange(false);
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
|
|
@ -145,39 +225,6 @@ export default function CustomConfig({
|
|||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Check for available models button - show when no models checked or no models found */}
|
||||
{(!customModelsChecked ||
|
||||
(customModelsChecked && customModels.length === 0)) && (
|
||||
<div className="mb-4">
|
||||
<button
|
||||
onClick={onFetchCustomModels}
|
||||
disabled={customModelsLoading || !customLlmUrl}
|
||||
className={`w-full py-2.5 px-4 rounded-lg transition-all duration-200 border-2 ${customModelsLoading || !customLlmUrl
|
||||
? "bg-gray-100 border-gray-300 cursor-not-allowed text-gray-500"
|
||||
: "bg-white border-blue-600 text-blue-600 hover:bg-blue-50 focus:ring-2 focus:ring-blue-500/20"
|
||||
}`}
|
||||
>
|
||||
{customModelsLoading ? (
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Checking for models...
|
||||
</div>
|
||||
) : (
|
||||
"Check for available models"
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show message if no models found */}
|
||||
{customModelsChecked && customModels.length === 0 && (
|
||||
<div className="mb-4 p-3 bg-yellow-50 border border-yellow-200 rounded-lg">
|
||||
<p className="text-sm text-yellow-800">
|
||||
No models found. Please make sure models are available.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,27 +1,209 @@
|
|||
"use client";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Check, ChevronsUpDown, Loader2 } from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "./ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface GoogleConfigProps {
|
||||
googleApiKey: string;
|
||||
googleModel: string;
|
||||
onInputChange: (value: string, field: string) => void;
|
||||
}
|
||||
|
||||
export default function GoogleConfig({ googleApiKey, onInputChange }: GoogleConfigProps) {
|
||||
export default function GoogleConfig({
|
||||
googleApiKey,
|
||||
googleModel,
|
||||
onInputChange
|
||||
}: GoogleConfigProps) {
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [availableModels, setAvailableModels] = useState<string[]>([]);
|
||||
const [modelsLoading, setModelsLoading] = useState(false);
|
||||
const [modelsChecked, setModelsChecked] = useState(false);
|
||||
const [apiKey, setApiKey] = useState(googleApiKey);
|
||||
|
||||
useEffect(() => {
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(false);
|
||||
onInputChange("", "google_model");
|
||||
}, [apiKey]);
|
||||
|
||||
const onApiKeyChange = (value: string) => {
|
||||
setApiKey(value);
|
||||
onInputChange(value, "google_api_key");
|
||||
};
|
||||
|
||||
const fetchAvailableModels = async () => {
|
||||
if (!googleApiKey) return;
|
||||
|
||||
setModelsLoading(true);
|
||||
try {
|
||||
const response = await fetch('/api/v1/ppt/google/models/available', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_key: googleApiKey
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
setAvailableModels(data);
|
||||
setModelsChecked(true);
|
||||
} else {
|
||||
console.error('Failed to fetch models');
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(true);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
toast.error('Error fetching models');
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(true);
|
||||
} finally {
|
||||
setModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mb-8">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Google API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
value={googleApiKey}
|
||||
onChange={(e) => onInputChange(e.target.value, "google_api_key")}
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
placeholder="Enter your API key"
|
||||
/>
|
||||
<div className="space-y-6">
|
||||
{/* API Key Input */}
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Google API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
value={googleApiKey}
|
||||
onChange={(e) => onApiKeyChange(e.target.value)}
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
placeholder="Enter your API key"
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Your API key will be stored locally and never shared
|
||||
</p>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Your API key will be stored locally and never shared
|
||||
</p>
|
||||
|
||||
{/* Check for available models button - show when no models checked or no models found */}
|
||||
{(!modelsChecked || (modelsChecked && availableModels.length === 0)) && (
|
||||
<div className="mb-4">
|
||||
<button
|
||||
onClick={fetchAvailableModels}
|
||||
disabled={modelsLoading || !googleApiKey}
|
||||
className={`w-full py-2.5 px-4 rounded-lg transition-all duration-200 border-2 ${modelsLoading || !googleApiKey
|
||||
? "bg-gray-100 border-gray-300 cursor-not-allowed text-gray-500"
|
||||
: "bg-white border-blue-600 text-blue-600 hover:bg-blue-50 focus:ring-2 focus:ring-blue-500/20"
|
||||
}`}
|
||||
>
|
||||
{modelsLoading ? (
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Checking for models...
|
||||
</div>
|
||||
) : (
|
||||
"Check for available models"
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show message if no models found */}
|
||||
{modelsChecked && availableModels.length === 0 && (
|
||||
<div className="mb-4 p-3 bg-yellow-50 border border-yellow-200 rounded-lg">
|
||||
<p className="text-sm text-yellow-800">
|
||||
No models found. Please make sure your API key is valid and has access to Google models.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Model Selection - only show if models are available */}
|
||||
{modelsChecked && availableModels.length > 0 ? (
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Select Google Model
|
||||
</label>
|
||||
<div className="w-full">
|
||||
<Popover
|
||||
open={openModelSelect}
|
||||
onOpenChange={setOpenModelSelect}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{googleModel
|
||||
? availableModels.find(model => model === googleModel) || googleModel
|
||||
: "Select a model"}
|
||||
</span>
|
||||
</div>
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="p-0"
|
||||
align="start"
|
||||
style={{ width: "var(--radix-popover-trigger-width)" }}
|
||||
>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search models..." />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{availableModels.map((model, index) => (
|
||||
<CommandItem
|
||||
key={index}
|
||||
value={model}
|
||||
onSelect={(value) => {
|
||||
onInputChange(value, "google_model");
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
googleModel === model
|
||||
? "opacity-100"
|
||||
: "opacity-0"
|
||||
)}
|
||||
/>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{model}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -27,7 +27,6 @@ export default function Home() {
|
|||
const router = useRouter();
|
||||
const config = useSelector((state: RootState) => state.userConfig);
|
||||
const [llmConfig, setLlmConfig] = useState<LLMConfig>(config.llm_config);
|
||||
const [isLoading, setIsLoading] = useState<boolean>(false);
|
||||
|
||||
const [downloadingModel, setDownloadingModel] = useState<{
|
||||
name: string;
|
||||
|
|
@ -54,7 +53,6 @@ export default function Home() {
|
|||
|
||||
const handleSaveConfig = async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
isLoading: true,
|
||||
|
|
@ -70,7 +68,6 @@ export default function Home() {
|
|||
}
|
||||
}
|
||||
toast.info("Configuration saved successfully");
|
||||
setIsLoading(false);
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
|
|
@ -79,8 +76,7 @@ export default function Home() {
|
|||
}));
|
||||
router.push("/upload");
|
||||
} catch (error) {
|
||||
toast.info("Failed to save configuration");
|
||||
setIsLoading(false);
|
||||
toast.info(error instanceof Error ? error.message : "Failed to save configuration");
|
||||
setButtonState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
|
|
@ -93,8 +89,8 @@ export default function Home() {
|
|||
const handleModelDownload = async () => {
|
||||
try {
|
||||
await pullOllamaModel(llmConfig.OLLAMA_MODEL!, setDownloadingModel);
|
||||
} catch (error) {
|
||||
console.info("Error downloading model:", error);
|
||||
}
|
||||
finally {
|
||||
setDownloadingModel(null);
|
||||
setShowDownloadModal(false);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,16 +15,13 @@ import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
|||
import { cn } from "@/lib/utils";
|
||||
import OpenAIConfig from "./OpenAIConfig";
|
||||
import GoogleConfig from "./GoogleConfig";
|
||||
import AnthropicConfig from "./AnthropicConfig";
|
||||
import OllamaConfig from "./OllamaConfig";
|
||||
import CustomConfig from "./CustomConfig";
|
||||
import {
|
||||
OllamaModel,
|
||||
LLMConfig,
|
||||
updateLLMConfig,
|
||||
changeProvider as changeProviderUtil,
|
||||
fetchOllamaModelsWithConfig,
|
||||
setOllamaConfig,
|
||||
fetchCustomModels,
|
||||
} from "@/utils/providerUtils";
|
||||
import { IMAGE_PROVIDERS, LLM_PROVIDERS } from "@/utils/providerConstants";
|
||||
|
||||
|
|
@ -51,29 +48,34 @@ export default function LLMProviderSelection({
|
|||
setButtonState,
|
||||
}: LLMProviderSelectionProps) {
|
||||
const [llmConfig, setLlmConfig] = useState<LLMConfig>(initialLLMConfig);
|
||||
const [ollamaModels, setOllamaModels] = useState<OllamaModel[]>([]);
|
||||
const [customModels, setCustomModels] = useState<string[]>([]);
|
||||
const [customModelsLoading, setCustomModelsLoading] = useState<boolean>(false);
|
||||
const [customModelsChecked, setCustomModelsChecked] = useState<boolean>(false);
|
||||
const [ollamaModelsLoading, setOllamaModelsLoading] = useState<boolean>(false);
|
||||
const [useCustomOllamaUrl, setUseCustomOllamaUrl] = useState<boolean>(
|
||||
initialLLMConfig.USE_CUSTOM_URL || false
|
||||
);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [openImageProviderSelect, setOpenImageProviderSelect] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!llmConfig.USE_CUSTOM_URL) {
|
||||
setLlmConfig({ ...llmConfig, OLLAMA_URL: "http://localhost:11434" });
|
||||
} else {
|
||||
if (!llmConfig.OLLAMA_URL) {
|
||||
setLlmConfig({ ...llmConfig, OLLAMA_URL: "http://localhost:11434" });
|
||||
}
|
||||
}
|
||||
}, [llmConfig.USE_CUSTOM_URL]);
|
||||
|
||||
useEffect(() => {
|
||||
onConfigChange(llmConfig);
|
||||
}, [llmConfig]);
|
||||
|
||||
useEffect(() => {
|
||||
const needsModelSelection =
|
||||
(llmConfig.LLM === "openai" && !llmConfig.OPENAI_MODEL) ||
|
||||
(llmConfig.LLM === "google" && !llmConfig.GOOGLE_MODEL) ||
|
||||
(llmConfig.LLM === "ollama" && !llmConfig.OLLAMA_MODEL) ||
|
||||
(llmConfig.LLM === "custom" && !llmConfig.CUSTOM_MODEL);
|
||||
(llmConfig.LLM === "custom" && !llmConfig.CUSTOM_MODEL) ||
|
||||
(llmConfig.LLM === "anthropic" && !llmConfig.ANTHROPIC_MODEL);
|
||||
|
||||
const needsApiKey =
|
||||
((llmConfig.IMAGE_PROVIDER === "dall-e-3" || llmConfig.LLM === "openai") && !llmConfig.OPENAI_API_KEY) ||
|
||||
((llmConfig.IMAGE_PROVIDER === "gemini_flash" || llmConfig.LLM === "google") && !llmConfig.GOOGLE_API_KEY) ||
|
||||
(llmConfig.LLM === "anthropic" && !llmConfig.ANTHROPIC_API_KEY) ||
|
||||
(llmConfig.IMAGE_PROVIDER === "pexels" && !llmConfig.PEXELS_API_KEY) ||
|
||||
(llmConfig.IMAGE_PROVIDER === "pixabay" && !llmConfig.PIXABAY_API_KEY);
|
||||
|
||||
|
|
@ -86,7 +88,7 @@ export default function LLMProviderSelection({
|
|||
|
||||
}, [llmConfig]);
|
||||
|
||||
const input_field_changed = (new_value: string, field: string) => {
|
||||
const input_field_changed = (new_value: string | boolean, field: string) => {
|
||||
const updatedConfig = updateLLMConfig(llmConfig, field, new_value);
|
||||
setLlmConfig(updatedConfig);
|
||||
};
|
||||
|
|
@ -94,68 +96,8 @@ export default function LLMProviderSelection({
|
|||
const handleProviderChange = (provider: string) => {
|
||||
const newConfig = changeProviderUtil(llmConfig, provider);
|
||||
setLlmConfig(newConfig);
|
||||
if (provider === "ollama") {
|
||||
fetchOllamaModels();
|
||||
}
|
||||
};
|
||||
|
||||
const fetchOllamaModels = async () => {
|
||||
try {
|
||||
setOllamaModelsLoading(true);
|
||||
const result = await fetchOllamaModelsWithConfig(llmConfig);
|
||||
setOllamaModels(result.models);
|
||||
if (result.updatedConfig) {
|
||||
setLlmConfig(result.updatedConfig);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching Ollama models:", error);
|
||||
setOllamaModels([]);
|
||||
} finally {
|
||||
setOllamaModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const fetchCustomModelsHandler = async () => {
|
||||
try {
|
||||
setCustomModelsLoading(true);
|
||||
const models = await fetchCustomModels(
|
||||
llmConfig.CUSTOM_LLM_URL || "",
|
||||
llmConfig.CUSTOM_LLM_API_KEY || ""
|
||||
);
|
||||
setCustomModels(models);
|
||||
setCustomModelsChecked(true);
|
||||
} catch (error) {
|
||||
console.error("Error fetching custom models:", error);
|
||||
setCustomModels([]);
|
||||
} finally {
|
||||
setCustomModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const setOllamaConfigHandler = () => {
|
||||
const updatedConfig = setOllamaConfig(llmConfig, useCustomOllamaUrl);
|
||||
setLlmConfig(updatedConfig);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (llmConfig.LLM === "ollama") {
|
||||
setOllamaConfigHandler();
|
||||
fetchOllamaModels();
|
||||
}
|
||||
}, [llmConfig.LLM]);
|
||||
|
||||
useEffect(() => {
|
||||
setOllamaConfigHandler();
|
||||
}, [useCustomOllamaUrl]);
|
||||
|
||||
useEffect(() => {
|
||||
if (llmConfig.LLM === "custom") {
|
||||
setCustomModels([]);
|
||||
setCustomModelsChecked(false);
|
||||
setLlmConfig({ ...llmConfig, CUSTOM_MODEL: "" });
|
||||
}
|
||||
}, [llmConfig.CUSTOM_LLM_URL, llmConfig.CUSTOM_LLM_API_KEY]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!llmConfig.IMAGE_PROVIDER) {
|
||||
if (llmConfig.LLM === "openai") {
|
||||
|
|
@ -177,9 +119,10 @@ export default function LLMProviderSelection({
|
|||
onValueChange={handleProviderChange}
|
||||
className="w-full"
|
||||
>
|
||||
<TabsList className="grid w-full grid-cols-4 bg-transparent h-10">
|
||||
<TabsList className="grid w-full grid-cols-5 bg-transparent h-10">
|
||||
<TabsTrigger value="openai">OpenAI</TabsTrigger>
|
||||
<TabsTrigger value="google">Google</TabsTrigger>
|
||||
<TabsTrigger value="anthropic">Anthropic</TabsTrigger>
|
||||
<TabsTrigger value="ollama">Ollama</TabsTrigger>
|
||||
<TabsTrigger value="custom">Custom</TabsTrigger>
|
||||
</TabsList>
|
||||
|
|
@ -198,6 +141,7 @@ export default function LLMProviderSelection({
|
|||
<TabsContent value="openai" className="mt-6">
|
||||
<OpenAIConfig
|
||||
openaiApiKey={llmConfig.OPENAI_API_KEY || ""}
|
||||
openaiModel={llmConfig.OPENAI_MODEL || ""}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
|
@ -206,6 +150,17 @@ export default function LLMProviderSelection({
|
|||
<TabsContent value="google" className="mt-6">
|
||||
<GoogleConfig
|
||||
googleApiKey={llmConfig.GOOGLE_API_KEY || ""}
|
||||
googleModel={llmConfig.GOOGLE_MODEL || ""}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
{/* Anthropic Content */}
|
||||
<TabsContent value="anthropic" className="mt-6">
|
||||
<AnthropicConfig
|
||||
anthropicApiKey={llmConfig.ANTHROPIC_API_KEY || ""}
|
||||
anthropicModel={llmConfig.ANTHROPIC_MODEL || ""}
|
||||
extendedReasoning={llmConfig.EXTENDED_REASONING || false}
|
||||
onInputChange={input_field_changed}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
|
@ -215,16 +170,8 @@ export default function LLMProviderSelection({
|
|||
<OllamaConfig
|
||||
ollamaModel={llmConfig.OLLAMA_MODEL || ""}
|
||||
ollamaUrl={llmConfig.OLLAMA_URL || ""}
|
||||
useCustomUrl={useCustomOllamaUrl}
|
||||
ollamaModels={ollamaModels}
|
||||
ollamaModelsLoading={ollamaModelsLoading}
|
||||
useCustomUrl={llmConfig.USE_CUSTOM_URL || false}
|
||||
onInputChange={input_field_changed}
|
||||
onUseCustomUrlChange={setUseCustomOllamaUrl}
|
||||
openModelSelect={openModelSelect}
|
||||
onOpenModelSelectChange={setOpenModelSelect}
|
||||
onModelSelect={(modelName: string) => {
|
||||
input_field_changed(modelName, "ollama_model");
|
||||
}}
|
||||
/>
|
||||
</TabsContent>
|
||||
|
||||
|
|
@ -234,19 +181,13 @@ export default function LLMProviderSelection({
|
|||
customLlmUrl={llmConfig.CUSTOM_LLM_URL || ""}
|
||||
customLlmApiKey={llmConfig.CUSTOM_LLM_API_KEY || ""}
|
||||
customModel={llmConfig.CUSTOM_MODEL || ""}
|
||||
customModels={customModels}
|
||||
customModelsLoading={customModelsLoading}
|
||||
customModelsChecked={customModelsChecked}
|
||||
openModelSelect={openModelSelect}
|
||||
onInputChange={input_field_changed}
|
||||
onOpenModelSelectChange={setOpenModelSelect}
|
||||
onFetchCustomModels={fetchCustomModelsHandler}
|
||||
/>
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
|
||||
{/* Image Provider Selection */}
|
||||
<div className="mb-8">
|
||||
<div className="my-8">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Select Image Provider
|
||||
</label>
|
||||
|
|
@ -388,7 +329,13 @@ export default function LLMProviderSelection({
|
|||
? llmConfig.OLLAMA_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "custom"
|
||||
? llmConfig.CUSTOM_MODEL ?? "xxxxx"
|
||||
: LLM_PROVIDERS[llmConfig.LLM!]?.model_label || "xxxxx"}{" "}
|
||||
: llmConfig.LLM === "anthropic"
|
||||
? llmConfig.ANTHROPIC_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "google"
|
||||
? llmConfig.GOOGLE_MODEL ?? "xxxxx"
|
||||
: llmConfig.LLM === "openai"
|
||||
? llmConfig.OPENAI_MODEL ?? "xxxxx"
|
||||
: "xxxxx"}{" "}
|
||||
for text generation and{" "}
|
||||
{llmConfig.IMAGE_PROVIDER &&
|
||||
IMAGE_PROVIDERS[llmConfig.IMAGE_PROVIDER]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"use client";
|
||||
import { useState } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { Check, ChevronsUpDown, Loader2 } from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
|
|
@ -13,6 +13,7 @@ import {
|
|||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Switch } from "./ui/switch";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface OllamaModel {
|
||||
label: string;
|
||||
|
|
@ -26,30 +27,84 @@ interface OllamaConfigProps {
|
|||
ollamaModel: string;
|
||||
ollamaUrl: string;
|
||||
useCustomUrl: boolean;
|
||||
ollamaModels: OllamaModel[];
|
||||
ollamaModelsLoading?: boolean;
|
||||
onInputChange: (value: string, field: string) => void;
|
||||
onUseCustomUrlChange: (checked: boolean) => void;
|
||||
openModelSelect: boolean;
|
||||
onOpenModelSelectChange: (open: boolean) => void;
|
||||
onModelSelect?: (modelName: string) => void;
|
||||
onInputChange: (value: string | boolean, field: string) => void;
|
||||
}
|
||||
|
||||
export default function OllamaConfig({
|
||||
ollamaModel,
|
||||
ollamaUrl,
|
||||
useCustomUrl,
|
||||
ollamaModels,
|
||||
ollamaModelsLoading = false,
|
||||
onInputChange,
|
||||
onUseCustomUrlChange,
|
||||
openModelSelect,
|
||||
onOpenModelSelectChange,
|
||||
onModelSelect,
|
||||
}: OllamaConfigProps) {
|
||||
const [ollamaModels, setOllamaModels] = useState<OllamaModel[]>([]);
|
||||
const [ollamaModelsLoading, setOllamaModelsLoading] = useState(false);
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
|
||||
const fetchOllamaModels = async () => {
|
||||
try {
|
||||
setOllamaModelsLoading(true);
|
||||
const response = await fetch('/api/v1/ppt/ollama/models/supported');
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
setOllamaModels(data);
|
||||
} else {
|
||||
console.error('Failed to fetch Ollama models');
|
||||
setOllamaModels([]);
|
||||
toast.error('Failed to fetch Ollama models');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching Ollama models:', error);
|
||||
toast.error('Error fetching Ollama models');
|
||||
setOllamaModels([]);
|
||||
} finally {
|
||||
setOllamaModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchOllamaModels();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="mb-8">
|
||||
<div className="space-y-6">
|
||||
{/* URL Configuration */}
|
||||
<div>
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Use custom Ollama URL
|
||||
</label>
|
||||
<Switch
|
||||
checked={useCustomUrl}
|
||||
onCheckedChange={(checked) => onInputChange(checked, "use_custom_url")}
|
||||
/>
|
||||
</div>
|
||||
{useCustomUrl && (
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Ollama URL
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
required
|
||||
placeholder="Enter your Ollama URL"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={ollamaUrl}
|
||||
onChange={(e) => onInputChange(e.target.value, "ollama_url")}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Change this if you are using a custom Ollama instance
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Model Selection */}
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Choose a supported model
|
||||
</label>
|
||||
|
|
@ -64,7 +119,7 @@ export default function OllamaConfig({
|
|||
) : ollamaModels && ollamaModels.length > 0 ? (
|
||||
<Popover
|
||||
open={openModelSelect}
|
||||
onOpenChange={onOpenModelSelectChange}
|
||||
onOpenChange={setOpenModelSelect}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
|
|
@ -122,12 +177,8 @@ export default function OllamaConfig({
|
|||
key={index}
|
||||
value={model.value}
|
||||
onSelect={(value) => {
|
||||
if (onModelSelect) {
|
||||
onModelSelect(value);
|
||||
} else {
|
||||
onInputChange(value, "ollama_model");
|
||||
}
|
||||
onOpenModelSelectChange(false);
|
||||
onInputChange(value, "ollama_model");
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
|
|
@ -185,42 +236,6 @@ export default function OllamaConfig({
|
|||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<div className="flex items-center justify-between mb-4 bg-green-50 p-2 rounded-sm">
|
||||
<label className="text-sm font-medium text-gray-700">
|
||||
Use custom Ollama URL
|
||||
</label>
|
||||
<Switch
|
||||
checked={useCustomUrl}
|
||||
onCheckedChange={onUseCustomUrlChange}
|
||||
/>
|
||||
</div>
|
||||
{useCustomUrl && (
|
||||
<>
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
Ollama URL
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
required
|
||||
placeholder="Enter your Ollama URL"
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
value={ollamaUrl}
|
||||
onChange={(e) =>
|
||||
onInputChange(e.target.value, "ollama_url")
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Change this if you are using a custom Ollama instance
|
||||
</p>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,27 +1,214 @@
|
|||
"use client";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Check, ChevronsUpDown, Loader2 } from "lucide-react";
|
||||
import { Button } from "./ui/button";
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandInput,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "./ui/command";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "./ui/popover";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface OpenAIConfigProps {
|
||||
openaiApiKey: string;
|
||||
openaiModel: string;
|
||||
onInputChange: (value: string, field: string) => void;
|
||||
}
|
||||
|
||||
export default function OpenAIConfig({ openaiApiKey, onInputChange }: OpenAIConfigProps) {
|
||||
export default function OpenAIConfig({
|
||||
openaiApiKey,
|
||||
openaiModel,
|
||||
onInputChange
|
||||
}: OpenAIConfigProps) {
|
||||
const [openModelSelect, setOpenModelSelect] = useState(false);
|
||||
const [availableModels, setAvailableModels] = useState<string[]>([]);
|
||||
const [modelsLoading, setModelsLoading] = useState(false);
|
||||
const [modelsChecked, setModelsChecked] = useState(false);
|
||||
const [apiKey, setApiKey] = useState(openaiApiKey);
|
||||
|
||||
const openaiUrl = "https://api.openai.com/v1";
|
||||
|
||||
useEffect(() => {
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(false);
|
||||
onInputChange("", "openai_model");
|
||||
}, [apiKey]);
|
||||
|
||||
const onApiKeyChange = (value: string) => {
|
||||
setApiKey(value);
|
||||
onInputChange(value, "openai_api_key");
|
||||
};
|
||||
|
||||
const fetchAvailableModels = async () => {
|
||||
if (!openaiApiKey) return;
|
||||
|
||||
setModelsLoading(true);
|
||||
try {
|
||||
const response = await fetch('/api/v1/ppt/openai/models/available', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
url: openaiUrl,
|
||||
api_key: openaiApiKey
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
setAvailableModels(data);
|
||||
setModelsChecked(true);
|
||||
} else {
|
||||
console.error('Failed to fetch models');
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(true);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
toast.error('Error fetching models');
|
||||
setAvailableModels([]);
|
||||
setModelsChecked(true);
|
||||
} finally {
|
||||
setModelsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mb-8">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
OpenAI API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
value={openaiApiKey}
|
||||
onChange={(e) => onInputChange(e.target.value, "openai_api_key")}
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
placeholder="Enter your API key"
|
||||
/>
|
||||
<div className="space-y-6">
|
||||
{/* API Key Input */}
|
||||
<div className="mb-4">
|
||||
<label className="block text-sm font-medium text-gray-700 mb-2">
|
||||
OpenAI API Key
|
||||
</label>
|
||||
<div className="relative">
|
||||
<input
|
||||
type="text"
|
||||
value={openaiApiKey}
|
||||
onChange={(e) => onApiKeyChange(e.target.value)}
|
||||
className="w-full px-4 py-2.5 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors"
|
||||
placeholder="Enter your API key"
|
||||
/>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Your API key will be stored locally and never shared
|
||||
</p>
|
||||
</div>
|
||||
<p className="mt-2 text-sm text-gray-500 flex items-center gap-2">
|
||||
<span className="block w-1 h-1 rounded-full bg-gray-400"></span>
|
||||
Your API key will be stored locally and never shared
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
{/* Check for available models button - show when no models checked or no models found */}
|
||||
{(!modelsChecked || (modelsChecked && availableModels.length === 0)) && (
|
||||
<div className="mb-4">
|
||||
<button
|
||||
onClick={fetchAvailableModels}
|
||||
disabled={modelsLoading || !openaiApiKey}
|
||||
className={`w-full py-2.5 px-4 rounded-lg transition-all duration-200 border-2 ${modelsLoading || !openaiApiKey
|
||||
? "bg-gray-100 border-gray-300 cursor-not-allowed text-gray-500"
|
||||
: "bg-white border-blue-600 text-blue-600 hover:bg-blue-50 focus:ring-2 focus:ring-blue-500/20"
|
||||
}`}
|
||||
>
|
||||
{modelsLoading ? (
|
||||
<div className="flex items-center justify-center gap-2">
|
||||
<Loader2 className="w-4 h-4 animate-spin" />
|
||||
Checking for models...
|
||||
</div>
|
||||
) : (
|
||||
"Check for available models"
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show message if no models found */}
|
||||
{modelsChecked && availableModels.length === 0 && (
|
||||
<div className="mb-4 p-3 bg-yellow-50 border border-yellow-200 rounded-lg">
|
||||
<p className="text-sm text-yellow-800">
|
||||
No models found. Please make sure your API key is valid and has access to OpenAI models.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Model Selection - only show if models are available */}
|
||||
{modelsChecked && availableModels.length > 0 ? (
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-700 mb-3">
|
||||
Select OpenAI Model
|
||||
</label>
|
||||
<div className="w-full">
|
||||
<Popover
|
||||
open={openModelSelect}
|
||||
onOpenChange={setOpenModelSelect}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
role="combobox"
|
||||
aria-expanded={openModelSelect}
|
||||
className="w-full h-12 px-4 py-4 outline-none border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-colors hover:border-gray-400 justify-between"
|
||||
>
|
||||
<div className="flex gap-3 items-center">
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{openaiModel
|
||||
? availableModels.find(model => model === openaiModel) || openaiModel
|
||||
: "Select a model"}
|
||||
</span>
|
||||
</div>
|
||||
<ChevronsUpDown className="w-4 h-4 text-gray-500" />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
className="p-0"
|
||||
align="start"
|
||||
style={{ width: "var(--radix-popover-trigger-width)" }}
|
||||
>
|
||||
<Command>
|
||||
<CommandInput placeholder="Search models..." />
|
||||
<CommandList>
|
||||
<CommandEmpty>No model found.</CommandEmpty>
|
||||
<CommandGroup>
|
||||
{availableModels.map((model, index) => (
|
||||
<CommandItem
|
||||
key={index}
|
||||
value={model}
|
||||
onSelect={(value) => {
|
||||
onInputChange(value, "openai_model");
|
||||
setOpenModelSelect(false);
|
||||
}}
|
||||
>
|
||||
<Check
|
||||
className={cn(
|
||||
"mr-2 h-4 w-4",
|
||||
openaiModel === model
|
||||
? "opacity-100"
|
||||
: "opacity-0"
|
||||
)}
|
||||
/>
|
||||
<div className="flex gap-3 items-center">
|
||||
<div className="flex flex-col space-y-1 flex-1">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<span className="text-sm font-medium text-gray-900">
|
||||
{model}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
19
servers/nextjs/types/global.d.ts
vendored
19
servers/nextjs/types/global.d.ts
vendored
|
|
@ -15,17 +15,36 @@ interface TextFrameProps {
|
|||
|
||||
interface LLMConfig {
|
||||
LLM?: string;
|
||||
|
||||
// OpenAI
|
||||
OPENAI_API_KEY?: string;
|
||||
OPENAI_MODEL?: string;
|
||||
|
||||
// Google
|
||||
GOOGLE_API_KEY?: string;
|
||||
GOOGLE_MODEL?: string;
|
||||
|
||||
// Anthropic
|
||||
ANTHROPIC_API_KEY?: string;
|
||||
ANTHROPIC_MODEL?: string;
|
||||
|
||||
// Ollama
|
||||
OLLAMA_URL?: string;
|
||||
OLLAMA_MODEL?: string;
|
||||
|
||||
// Custom LLM
|
||||
CUSTOM_LLM_URL?: string;
|
||||
CUSTOM_LLM_API_KEY?: string;
|
||||
CUSTOM_MODEL?: string;
|
||||
|
||||
// Image providers
|
||||
IMAGE_PROVIDER?: string;
|
||||
PIXABAY_API_KEY?: string;
|
||||
PEXELS_API_KEY?: string;
|
||||
|
||||
// Extended reasoning
|
||||
EXTENDED_REASONING?: boolean;
|
||||
|
||||
// Only used in UI settings
|
||||
USE_CUSTOM_URL?: boolean;
|
||||
}
|
||||
|
|
@ -67,16 +67,17 @@ export const LLM_PROVIDERS: Record<string, LLMProviderOption> = {
|
|||
openai: {
|
||||
value: "openai",
|
||||
label: "OpenAI",
|
||||
description: "OpenAI's latest image generation model",
|
||||
model_value: "gpt-4.1",
|
||||
model_label: "GPT-4.1"
|
||||
description: "OpenAI's latest text generation model",
|
||||
},
|
||||
google: {
|
||||
value: "google",
|
||||
label: "Google",
|
||||
description: "Google's primary image generation model",
|
||||
model_value: "gemini-2.0-flash",
|
||||
model_label: "Gemini 2.0 Flash"
|
||||
description: "Google's primary text generation model",
|
||||
},
|
||||
anthropic: {
|
||||
value: "anthropic",
|
||||
label: "Anthropic",
|
||||
description: "Anthropic's Claude models",
|
||||
},
|
||||
ollama: {
|
||||
value: "ollama",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import { toast } from "sonner";
|
||||
|
||||
export interface OllamaModel {
|
||||
label: string;
|
||||
value: string;
|
||||
|
|
@ -18,16 +16,37 @@ export interface DownloadingModel {
|
|||
|
||||
export interface LLMConfig {
|
||||
LLM?: string;
|
||||
|
||||
// OpenAI
|
||||
OPENAI_API_KEY?: string;
|
||||
OPENAI_MODEL?: string;
|
||||
|
||||
// Google
|
||||
GOOGLE_API_KEY?: string;
|
||||
GOOGLE_MODEL?: string;
|
||||
|
||||
// Anthropic
|
||||
ANTHROPIC_API_KEY?: string;
|
||||
ANTHROPIC_MODEL?: string;
|
||||
|
||||
// Ollama
|
||||
OLLAMA_URL?: string;
|
||||
OLLAMA_MODEL?: string;
|
||||
|
||||
// Custom LLM
|
||||
CUSTOM_LLM_URL?: string;
|
||||
CUSTOM_LLM_API_KEY?: string;
|
||||
CUSTOM_MODEL?: string;
|
||||
|
||||
// Image providers
|
||||
IMAGE_PROVIDER?: string;
|
||||
PEXELS_API_KEY?: string;
|
||||
PIXABAY_API_KEY?: string;
|
||||
IMAGE_PROVIDER?: string;
|
||||
|
||||
// Extended reasoning
|
||||
EXTENDED_REASONING?: boolean;
|
||||
|
||||
// Only used in UI settings
|
||||
USE_CUSTOM_URL?: boolean;
|
||||
}
|
||||
|
||||
|
|
@ -42,11 +61,15 @@ export interface OllamaModelsResult {
|
|||
export const updateLLMConfig = (
|
||||
currentConfig: LLMConfig,
|
||||
field: string,
|
||||
value: string
|
||||
value: string | boolean
|
||||
): LLMConfig => {
|
||||
const fieldMappings: Record<string, keyof LLMConfig> = {
|
||||
openai_api_key: "OPENAI_API_KEY",
|
||||
openai_model: "OPENAI_MODEL",
|
||||
google_api_key: "GOOGLE_API_KEY",
|
||||
google_model: "GOOGLE_MODEL",
|
||||
anthropic_api_key: "ANTHROPIC_API_KEY",
|
||||
anthropic_model: "ANTHROPIC_MODEL",
|
||||
ollama_url: "OLLAMA_URL",
|
||||
ollama_model: "OLLAMA_MODEL",
|
||||
custom_llm_url: "CUSTOM_LLM_URL",
|
||||
|
|
@ -55,6 +78,8 @@ export const updateLLMConfig = (
|
|||
pexels_api_key: "PEXELS_API_KEY",
|
||||
pixabay_api_key: "PIXABAY_API_KEY",
|
||||
image_provider: "IMAGE_PROVIDER",
|
||||
extended_reasoning: "EXTENDED_REASONING",
|
||||
use_custom_url: "USE_CUSTOM_URL",
|
||||
};
|
||||
|
||||
const configKey = fieldMappings[field];
|
||||
|
|
@ -86,53 +111,6 @@ export const changeProvider = (
|
|||
return newConfig;
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches supported Ollama models
|
||||
*/
|
||||
export const fetchOllamaModels = async (): Promise<OllamaModel[]> => {
|
||||
try {
|
||||
const response = await fetch("/api/v1/ppt/ollama/models/supported");
|
||||
const models = await response.json();
|
||||
return models || [];
|
||||
} catch (error) {
|
||||
console.error("Error fetching ollama models:", error);
|
||||
return []; // Ensure we always return an empty array on error
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches Ollama models and validates current selection
|
||||
* Returns models and updated config if needed
|
||||
*/
|
||||
export const fetchOllamaModelsWithConfig = async (
|
||||
config: LLMConfig
|
||||
): Promise<OllamaModelsResult> => {
|
||||
try {
|
||||
const models = await fetchOllamaModels();
|
||||
|
||||
// Check if currently selected model is still available
|
||||
let updatedConfig: LLMConfig | undefined;
|
||||
if (config.OLLAMA_MODEL && models && models.length > 0) {
|
||||
const isModelAvailable = models.some(
|
||||
(model: OllamaModel) => model.value === config.OLLAMA_MODEL
|
||||
);
|
||||
if (!isModelAvailable) {
|
||||
updatedConfig = { ...config, OLLAMA_MODEL: "" };
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
models,
|
||||
updatedConfig
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error fetching ollama models:", error);
|
||||
return {
|
||||
models: [],
|
||||
updatedConfig: { ...config, OLLAMA_MODEL: "" }
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
export const checkIfSelectedOllamaModelIsPulled = async (ollamaModel: string) => {
|
||||
try {
|
||||
|
|
@ -146,37 +124,6 @@ export const checkIfSelectedOllamaModelIsPulled = async (ollamaModel: string) =>
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches available custom models
|
||||
*/
|
||||
export const fetchCustomModels = async (
|
||||
url: string,
|
||||
apiKey: string
|
||||
): Promise<string[]> => {
|
||||
try {
|
||||
const response = await fetch("/api/v1/ppt/custom_llm/models/available", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
url: url || "",
|
||||
api_key: apiKey || "",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return data;
|
||||
} catch (error) {
|
||||
toast.info("Could not fetch custom models");
|
||||
console.error("Error fetching custom models:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Resets downloading model state
|
||||
|
|
@ -234,51 +181,4 @@ export const pullOllamaModel = async (
|
|||
}
|
||||
}, 1000);
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Sets Ollama configuration based on custom URL preference
|
||||
*/
|
||||
export const setOllamaConfig = (
|
||||
currentConfig: LLMConfig,
|
||||
useCustomUrl: boolean
|
||||
): LLMConfig => {
|
||||
let customUrl = "http://localhost:11434";
|
||||
if (!useCustomUrl) {
|
||||
return {
|
||||
...currentConfig,
|
||||
OLLAMA_URL: customUrl,
|
||||
USE_CUSTOM_URL: false,
|
||||
};
|
||||
} else {
|
||||
return { ...currentConfig, USE_CUSTOM_URL: true, OLLAMA_URL: customUrl };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles saving configuration with error handling
|
||||
*/
|
||||
export const handleSaveConfiguration = async (
|
||||
llmConfig: LLMConfig,
|
||||
handleSaveLLMConfig: (config: LLMConfig) => Promise<void>,
|
||||
pullOllamaModels?: () => Promise<void>
|
||||
): Promise<void> => {
|
||||
try {
|
||||
await handleSaveLLMConfig(llmConfig);
|
||||
if (llmConfig.LLM === "ollama" && pullOllamaModels) {
|
||||
await pullOllamaModels();
|
||||
}
|
||||
toast.success("Configuration saved successfully");
|
||||
} catch (error) {
|
||||
console.error("Error:", error);
|
||||
toast.error(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: "Failed to save configuration",
|
||||
{
|
||||
description: "Failed to save configuration",
|
||||
}
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
|
@ -16,8 +16,30 @@ export const handleSaveLLMConfig = async (llmConfig: LLMConfig) => {
|
|||
export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
||||
if (!llmConfig.LLM) return false;
|
||||
if (!llmConfig.IMAGE_PROVIDER) return false;
|
||||
const OPENAI_API_KEY = llmConfig.OPENAI_API_KEY;
|
||||
const GOOGLE_API_KEY = llmConfig.GOOGLE_API_KEY;
|
||||
|
||||
const isOpenAIConfigValid =
|
||||
llmConfig.OPENAI_MODEL !== "" &&
|
||||
llmConfig.OPENAI_MODEL !== null &&
|
||||
llmConfig.OPENAI_MODEL !== undefined &&
|
||||
llmConfig.OPENAI_API_KEY !== "" &&
|
||||
llmConfig.OPENAI_API_KEY !== null &&
|
||||
llmConfig.OPENAI_API_KEY !== undefined;
|
||||
|
||||
const isGoogleConfigValid =
|
||||
llmConfig.GOOGLE_MODEL !== "" &&
|
||||
llmConfig.GOOGLE_MODEL !== null &&
|
||||
llmConfig.GOOGLE_MODEL !== undefined &&
|
||||
llmConfig.GOOGLE_API_KEY !== "" &&
|
||||
llmConfig.GOOGLE_API_KEY !== null &&
|
||||
llmConfig.GOOGLE_API_KEY !== undefined;
|
||||
|
||||
const isAnthropicConfigValid =
|
||||
llmConfig.ANTHROPIC_MODEL !== "" &&
|
||||
llmConfig.ANTHROPIC_MODEL !== null &&
|
||||
llmConfig.ANTHROPIC_MODEL !== undefined &&
|
||||
llmConfig.ANTHROPIC_API_KEY !== "" &&
|
||||
llmConfig.ANTHROPIC_API_KEY !== null &&
|
||||
llmConfig.ANTHROPIC_API_KEY !== undefined;
|
||||
|
||||
const isOllamaConfigValid =
|
||||
llmConfig.OLLAMA_MODEL !== "" &&
|
||||
|
|
@ -42,9 +64,9 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
case "pixabay":
|
||||
return llmConfig.PIXABAY_API_KEY && llmConfig.PIXABAY_API_KEY !== "";
|
||||
case "dall-e-3":
|
||||
return OPENAI_API_KEY && OPENAI_API_KEY !== "";
|
||||
return llmConfig.OPENAI_API_KEY && llmConfig.OPENAI_API_KEY !== "";
|
||||
case "gemini_flash":
|
||||
return GOOGLE_API_KEY && GOOGLE_API_KEY !== "";
|
||||
return llmConfig.GOOGLE_API_KEY && llmConfig.GOOGLE_API_KEY !== "";
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -52,18 +74,16 @@ export const hasValidLLMConfig = (llmConfig: LLMConfig) => {
|
|||
|
||||
const isLLMConfigValid =
|
||||
llmConfig.LLM === "openai"
|
||||
? OPENAI_API_KEY !== "" &&
|
||||
OPENAI_API_KEY !== null &&
|
||||
OPENAI_API_KEY !== undefined
|
||||
? isOpenAIConfigValid
|
||||
: llmConfig.LLM === "google"
|
||||
? GOOGLE_API_KEY !== "" &&
|
||||
GOOGLE_API_KEY !== null &&
|
||||
GOOGLE_API_KEY !== undefined
|
||||
: llmConfig.LLM === "ollama"
|
||||
? isOllamaConfigValid
|
||||
: llmConfig.LLM === "custom"
|
||||
? isCustomConfigValid
|
||||
: false;
|
||||
? isGoogleConfigValid
|
||||
: llmConfig.LLM === "anthropic"
|
||||
? isAnthropicConfigValid
|
||||
: llmConfig.LLM === "ollama"
|
||||
? isOllamaConfigValid
|
||||
: llmConfig.LLM === "custom"
|
||||
? isCustomConfigValid
|
||||
: false;
|
||||
|
||||
return isLLMConfigValid && isImageConfigValid();
|
||||
};
|
||||
|
|
|
|||
5
start.js
5
start.js
|
|
@ -38,15 +38,20 @@ const setupUserConfigFromEnv = () => {
|
|||
const userConfig = {
|
||||
LLM: process.env.LLM || existingConfig.LLM,
|
||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY || existingConfig.OPENAI_API_KEY,
|
||||
OPENAI_MODEL: process.env.OPENAI_MODEL || existingConfig.OPENAI_MODEL,
|
||||
GOOGLE_API_KEY: process.env.GOOGLE_API_KEY || existingConfig.GOOGLE_API_KEY,
|
||||
GOOGLE_MODEL: process.env.GOOGLE_MODEL || existingConfig.GOOGLE_MODEL,
|
||||
OLLAMA_URL: process.env.OLLAMA_URL || existingConfig.OLLAMA_URL,
|
||||
OLLAMA_MODEL: process.env.OLLAMA_MODEL || existingConfig.OLLAMA_MODEL,
|
||||
ANTHROPIC_API_KEY: process.env.ANTHROPIC_API_KEY || existingConfig.ANTHROPIC_API_KEY,
|
||||
ANTHROPIC_MODEL: process.env.ANTHROPIC_MODEL || existingConfig.ANTHROPIC_MODEL,
|
||||
CUSTOM_LLM_URL: process.env.CUSTOM_LLM_URL || existingConfig.CUSTOM_LLM_URL,
|
||||
CUSTOM_LLM_API_KEY: process.env.CUSTOM_LLM_API_KEY || existingConfig.CUSTOM_LLM_API_KEY,
|
||||
CUSTOM_MODEL: process.env.CUSTOM_MODEL || existingConfig.CUSTOM_MODEL,
|
||||
PEXELS_API_KEY: process.env.PEXELS_API_KEY || existingConfig.PEXELS_API_KEY,
|
||||
PIXABAY_API_KEY: process.env.PIXABAY_API_KEY || existingConfig.PIXABAY_API_KEY,
|
||||
IMAGE_PROVIDER: process.env.IMAGE_PROVIDER || existingConfig.IMAGE_PROVIDER,
|
||||
EXTENDED_REASONING: process.env.EXTENDED_REASONING || existingConfig.EXTENDED_REASONING,
|
||||
USE_CUSTOM_URL: process.env.USE_CUSTOM_URL || existingConfig.USE_CUSTOM_URL,
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue