refactor(fastapi): progress

This commit is contained in:
sauravniraula 2025-07-14 18:49:48 +05:45
parent 9c0e7a37b1
commit 788c8c4042
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
1630 changed files with 1332 additions and 1089 deletions

View file

@ -0,0 +1,75 @@
from contextlib import asynccontextmanager
import os
from fastapi import FastAPI
from sqlmodel import SQLModel
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from enums.llm_provider import LLMProvider
from services import SQL_ENGINE
from utils.custom_llm_provider import list_available_custom_models
from utils.llm_provider import (
get_llm_provider,
is_custom_llm_selected,
is_ollama_selected,
)
from utils.ollama import pull_ollama_model
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
async def check_llm_model_availability():
if not can_change_keys:
if get_llm_provider() == LLMProvider.OPENAI:
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
raise Exception("OPENAI_API_KEY must be provided")
elif get_llm_provider() == LLMProvider.GOOGLE:
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
raise Exception("GOOGLE_API_KEY must be provided")
elif is_ollama_selected():
ollama_model = os.getenv("OLLAMA_MODEL")
if not ollama_model:
raise Exception("OLLAMA_MODEL must be provided")
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
raise Exception(f"Model {ollama_model} is not supported")
print("-" * 50)
print("Pulling model: ", ollama_model)
async for event in pull_ollama_model(ollama_model):
print(event)
print("Pulled model: ", ollama_model)
print("-" * 50)
elif is_custom_llm_selected():
custom_model = os.getenv("CUSTOM_MODEL")
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
if not custom_model:
raise Exception("CUSTOM_MODEL must be provided")
if not custom_llm_url:
raise Exception("CUSTOM_LLM_URL must be provided")
if not custom_llm_api_key:
raise Exception("CUSTOM_LLM_API_KEY must be provided")
print("-" * 50)
print("Selecting model: ", custom_model)
models = await list_available_custom_models(
custom_llm_url, custom_llm_api_key
)
print("Available models: ", models)
print("-" * 50)
if custom_model not in models:
raise Exception(f"Model {custom_model} is not available")
@asynccontextmanager
async def app_lifespan(_: FastAPI):
os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True)
SQLModel.metadata.create_all(SQL_ENGINE)
await check_llm_model_availability()
yield

View file

@ -1,98 +1,8 @@
import asyncio
import os
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from sqlmodel import SQLModel
from contextlib import asynccontextmanager
from api.models import SelectedLLMProvider
from api.routers.presentation.router import presentation_router
from api.services.database import sql_engine
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from api.utils.utils import update_env_with_user_config
from api.utils.model_utils import (
get_selected_llm_provider,
is_custom_llm_selected,
is_ollama_selected,
list_available_custom_models,
pull_ollama_model,
)
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
from fastapi import FastAPI
from api.v1.ppt.router import API_V1_PPT_ROUTER
from api.lifespan import app_lifespan
async def check_llm_model_availability():
if not can_change_keys:
if get_selected_llm_provider() == SelectedLLMProvider.OPENAI:
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
raise Exception("OPENAI_API_KEY must be provided")
APP = FastAPI(lifespan=app_lifespan)
elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
raise Exception("GOOGLE_API_KEY must be provided")
elif is_ollama_selected():
ollama_model = os.getenv("OLLAMA_MODEL")
if not ollama_model:
raise Exception("OLLAMA_MODEL must be provided")
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
raise Exception(f"Model {ollama_model} is not supported")
print("-" * 50)
print("Pulling model: ", ollama_model)
async for event in pull_ollama_model(ollama_model):
print(event)
print("Pulled model: ", ollama_model)
print("-" * 50)
elif is_custom_llm_selected():
custom_model = os.getenv("CUSTOM_MODEL")
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
if not custom_model:
raise Exception("CUSTOM_MODEL must be provided")
if not custom_llm_url:
raise Exception("CUSTOM_LLM_URL must be provided")
if not custom_llm_api_key:
raise Exception("CUSTOM_LLM_API_KEY must be provided")
print("-" * 50)
print("Selecting model: ", custom_model)
models = await list_available_custom_models(
custom_llm_url, custom_llm_api_key
)
print("Available models: ", models)
print("-" * 50)
if custom_model not in models:
raise Exception(f"Model {custom_model} is not available")
@asynccontextmanager
async def lifespan(_: FastAPI):
os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True)
SQLModel.metadata.create_all(sql_engine)
await check_llm_model_availability()
yield
app = FastAPI(lifespan=lifespan)
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def update_env_middleware(request: Request, call_next):
if can_change_keys:
update_env_with_user_config()
return await call_next(request)
app.include_router(presentation_router)
APP.include_router(API_V1_PPT_ROUTER)

View file

@ -0,0 +1,23 @@
from fastapi import Request
from api.main import APP
from fastapi.middleware.cors import CORSMiddleware
from utils.get_env import get_can_change_keys_env
from utils.user_config import update_env_with_user_config
origins = ["*"]
APP.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@APP.middleware("http")
async def update_env_middleware(request: Request, call_next):
if get_can_change_keys_env() != "false":
update_env_with_user_config()
return await call_next(request)

View file

@ -0,0 +1,34 @@
from http.client import HTTPException
from typing import List, Optional
import uuid
from fastapi import UploadFile
from api.v1.ppt.router import API_V1_PPT_ROUTER
from constants.documents import UPLOAD_ACCEPTED_DOCUMENTS, UPLOAD_ACCEPTED_IMAGES
from services import TEMP_FILE_SERVICE
from utils.validators import validate_files
@API_V1_PPT_ROUTER.post("/files/upload")
async def upload_files(files: Optional[List[UploadFile]]):
if not files:
raise HTTPException(400, "Files are required")
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
validate_files(files, True, True, 50, UPLOAD_ACCEPTED_DOCUMENTS)
validate_files(files, True, True, 10, UPLOAD_ACCEPTED_IMAGES)
temp_files: List[str] = []
if files:
for each_file in files:
temp_path = TEMP_FILE_SERVICE.create_temp_file_path(
each_file.filename, temp_dir
)
with open(temp_path, "wb") as f:
content = await each_file.read()
f.write(content)
temp_files.append(temp_path)
return temp_files

View file

@ -0,0 +1,5 @@
from fastapi import APIRouter
API_V1_PPT = "/api/v1/ppt"
API_V1_PPT_ROUTER = APIRouter(prefix=API_V1_PPT)

View file

@ -0,0 +1,21 @@
PDF_MIME_TYPES = ["application/pdf"]
TEXT_MIME_TYPES = ["text/plain"]
POWERPOINT_TYPES = [
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
]
WORD_TYPES = [
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
]
SPREADSHEET_TYPES = ["text/csv", "application/csv"]
PNG_MIME_TYPES = ["image/png"]
JPEG_MIME_TYPES = ["image/jpeg"]
WEBP_MIME_TYPES = ["image/webp"]
UPLOAD_ACCEPTED_DOCUMENTS = (
PDF_MIME_TYPES + TEXT_MIME_TYPES + POWERPOINT_TYPES + WORD_TYPES
)
UPLOAD_ACCEPTED_IMAGES = PNG_MIME_TYPES + JPEG_MIME_TYPES + WEBP_MIME_TYPES

View file

@ -0,0 +1,253 @@
from models.ollama_model_metadata import OllamaModelMetadata
SUPPORTED_LLAMA_MODELS = {
"llama3:8b": OllamaModelMetadata(
label="Llama 3:8b",
value="llama3:8b",
description="❌ Graphs not supported.",
size="4.7GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3:70b": OllamaModelMetadata(
label="Llama 3:70b",
value="llama3:70b",
description="✅ Graphs supported.",
size="40GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3.1:8b": OllamaModelMetadata(
label="Llama 3.1:8b",
value="llama3.1:8b",
description="❌ Graphs not supported.",
size="4.9GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3.1:70b": OllamaModelMetadata(
label="Llama 3.1:70b",
value="llama3.1:70b",
description="✅ Graphs supported.",
size="43GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3.1:405b": OllamaModelMetadata(
label="Llama 3.1:405b",
value="llama3.1:405b",
description="✅ Graphs supported.",
size="243GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3.2:1b": OllamaModelMetadata(
label="Llama 3.2:1b",
value="llama3.2:1b",
description="❌ Graphs not supported.",
size="1.3GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3.2:3b": OllamaModelMetadata(
label="Llama 3.2:3b",
value="llama3.2:3b",
description="❌ Graphs not supported.",
size="2GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama3.3:70b": OllamaModelMetadata(
label="Llama 3.3:70b",
value="llama3.3:70b",
description="✅ Graphs supported.",
size="43GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama4:16x17b": OllamaModelMetadata(
label="Llama 4:16x17b",
value="llama4:16x17b",
description="✅ Graphs supported.",
size="67GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
"llama4:128x17b": OllamaModelMetadata(
label="Llama 4:128x17b",
value="llama4:128x17b",
description="✅ Graphs supported.",
size="245GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/meta.png",
),
}
SUPPORTED_GEMMA_MODELS = {
"gemma3:1b": OllamaModelMetadata(
label="Gemma 3:1b",
value="gemma3:1b",
description="❌ Graphs not supported.",
size="815MB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/gemma.png",
),
"gemma3:4b": OllamaModelMetadata(
label="Gemma 3:4b",
value="gemma3:4b",
description="❌ Graphs not supported.",
size="3.3GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/gemma.png",
),
"gemma3:12b": OllamaModelMetadata(
label="Gemma 3:12b",
value="gemma3:12b",
description="❌ Graphs not supported.",
size="8.1GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/gemma.png",
),
"gemma3:27b": OllamaModelMetadata(
label="Gemma 3:27b",
value="gemma3:27b",
description="✅ Graphs supported.",
size="17GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/gemma.png",
),
}
SUPPORTED_DEEPSEEK_MODELS = {
"deepseek-r1:1.5b": OllamaModelMetadata(
label="DeepSeek R1:1.5b",
value="deepseek-r1:1.5b",
description="❌ Graphs not supported.",
size="1.1GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
"deepseek-r1:7b": OllamaModelMetadata(
label="DeepSeek R1:7b",
value="deepseek-r1:7b",
description="❌ Graphs not supported.",
size="4.7GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
"deepseek-r1:8b": OllamaModelMetadata(
label="DeepSeek R1:8b",
value="deepseek-r1:8b",
description="❌ Graphs not supported.",
size="5.2GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
"deepseek-r1:14b": OllamaModelMetadata(
label="DeepSeek R1:14b",
value="deepseek-r1:14b",
description="❌ Graphs not supported.",
size="9GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
"deepseek-r1:32b": OllamaModelMetadata(
label="DeepSeek R1:32b",
value="deepseek-r1:32b",
description="✅ Graphs supported.",
size="20GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
"deepseek-r1:70b": OllamaModelMetadata(
label="DeepSeek R1:70b",
value="deepseek-r1:70b",
description="✅ Graphs supported.",
size="43GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
"deepseek-r1:671b": OllamaModelMetadata(
label="DeepSeek R1:671b",
value="deepseek-r1:671b",
description="✅ Graphs supported.",
size="404GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/deepseek.png",
),
}
SUPPORTED_QWEN_MODELS = {
"qwen3:0.6b": OllamaModelMetadata(
label="Qwen 3:0.6b",
value="qwen3:0.6b",
description="❌ Graphs not supported.",
size="523MB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:1.7b": OllamaModelMetadata(
label="Qwen 3:1.7b",
value="qwen3:1.7b",
description="❌ Graphs not supported.",
size="1.4GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:4b": OllamaModelMetadata(
label="Qwen 3:4b",
value="qwen3:4b",
description="❌ Graphs not supported.",
size="2.6GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:8b": OllamaModelMetadata(
label="Qwen 3:8b",
value="qwen3:8b",
description="❌ Graphs not supported.",
size="5.2GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:14b": OllamaModelMetadata(
label="Qwen 3:14b",
value="qwen3:14b",
description="❌ Graphs not supported.",
size="9.3GB",
supports_graph=False,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:30b": OllamaModelMetadata(
label="Qwen 3:30b",
value="qwen3:30b",
description="✅ Graphs supported.",
size="19GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:32b": OllamaModelMetadata(
label="Qwen 3:32b",
value="qwen3:32b",
description="✅ Graphs supported.",
size="20GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
"qwen3:235b": OllamaModelMetadata(
label="Qwen 3:235b",
value="qwen3:235b",
description="✅ Graphs supported.",
size="142GB",
supports_graph=True,
icon="/static/servers/fastapi/assets/icons/qwen.png",
),
}
SUPPORTED_OLLAMA_MODELS = {
**SUPPORTED_LLAMA_MODELS,
**SUPPORTED_GEMMA_MODELS,
**SUPPORTED_DEEPSEEK_MODELS,
**SUPPORTED_QWEN_MODELS,
}

View file

@ -0,0 +1,8 @@
from enum import Enum
class LLMProvider(Enum):
OLLAMA = "ollama"
OPENAI = "openai"
GOOGLE = "google"
CUSTOM = "custom"

View file

@ -1,116 +0,0 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
class PointModel(BaseModel):
x: float
y: float
def to_list(self) -> List[float]:
return [self.x, self.y]
class PointWithRadius(PointModel):
radius: Optional[float] = None
class BarSeriesModel(BaseModel):
name: str
data: List[float] = Field(
description="Only numbers should be given out in data. Don't include text/string in data."
)
class ScatterSeriesModel(BaseModel):
name: str
points: List[PointModel]
class BubbleSeriesModel(BaseModel):
name: str
points: List[PointWithRadius]
class LineSeriesModel(BaseModel):
name: str
data: List[float] = Field(
description="Only numbers should be given out in data. Don't include text/string in data."
)
class PieChartSeriesModel(BaseModel):
data: List[float]
class BarGraphDataModel(BaseModel):
categories: List[str]
series: List[BarSeriesModel] = Field(
description="There should be no more than 3 series"
)
class ScatterChartDataModel(BaseModel):
series: List[ScatterSeriesModel]
class BubbleChartDataModel(BaseModel):
series: List[BubbleSeriesModel]
class LineChartDataModel(BaseModel):
categories: List[str]
series: List[LineSeriesModel] = Field(
description="There should be no more than 3 series"
)
class PieChartDataModel(BaseModel):
categories: List[str]
series: List[PieChartSeriesModel] = Field(
description="One series model with list of data",
min_length=1,
)
@model_validator(mode="after")
def limit_series(self):
self.series = self.series[:1]
return self
class GraphTypeEnum(Enum):
pie = "pie"
bar = "bar"
line = "line"
class LLMGraphModel(BaseModel):
name: str
type: GraphTypeEnum
unit: Optional[str] = Field(
description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc."
)
data: PieChartDataModel | LineChartDataModel | BarGraphDataModel
class GraphModel(LLMGraphModel):
style: Optional[dict] = {}
@classmethod
def from_llm_graph_model(
cls, llm_graph_model: LLMGraphModel, style: Optional[dict] = {}
):
return cls(
name=llm_graph_model.name,
type=llm_graph_model.type,
unit=llm_graph_model.unit,
data=llm_graph_model.data,
style=style,
)
GRAPH_TYPE_MAPPING = {
GraphTypeEnum.pie: PieChartDataModel,
GraphTypeEnum.bar: BarGraphDataModel,
GraphTypeEnum.line: LineChartDataModel,
}

View file

@ -0,0 +1,10 @@
from pydantic import BaseModel
class OllamaModelMetadata(BaseModel):
label: str
value: str
description: str
icon: str
size: str
supports_graph: bool

View file

@ -0,0 +1,10 @@
from typing import Optional
from pydantic import BaseModel
class OllamaModelStatusResponse(BaseModel):
name: str
size: Optional[int] = None
downloaded: Optional[int] = None
status: str
done: bool

View file

@ -0,0 +1,14 @@
from typing import Optional
from pydantic import BaseModel
class UserConfig(BaseModel):
LLM: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
OLLAMA_URL: Optional[str] = None
OLLAMA_MODEL: Optional[str] = None
CUSTOM_LLM_URL: Optional[str] = None
CUSTOM_LLM_API_KEY: Optional[str] = None
CUSTOM_MODEL: Optional[str] = None
PEXELS_API_KEY: Optional[str] = None

View file

@ -1,232 +0,0 @@
from enum import Enum
from typing import List, Mapping, Union
from pydantic import BaseModel
from graph_processor.models import GraphModel, LLMGraphModel
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
TYPE3,
TYPE4,
TYPE5,
TYPE6,
TYPE7,
TYPE8,
TYPE9,
)
class TableType(Enum):
TABLE = "table"
BAR = "bar"
LINE = "line"
PIE = "pie"
class TableDataModel(BaseModel):
x_labels: List[str]
y_labels: List[str]
data: List[List[float]]
class TableModel(BaseModel):
name: str
type: TableType
data: TableDataModel
class HeadingModel(BaseModel):
heading: str
description: str
def to_llm_content(self, image_prompt: str = None, icon_query: str = None):
from ppt_generator.models.llm_models import (
LLMHeadingModel,
LLMHeadingModelWithImagePrompt,
LLMHeadingModelWithIconQuery,
)
if image_prompt:
return LLMHeadingModelWithImagePrompt(
heading=self.heading,
description=self.description,
image_prompt=image_prompt,
)
elif icon_query:
return LLMHeadingModelWithIconQuery(
heading=self.heading,
description=self.description,
icon_query=icon_query,
)
return LLMHeadingModel(
heading=self.heading,
description=self.description,
)
class SlideContentModel(BaseModel):
title: str
class Type1Content(SlideContentModel):
body: str
image_prompts: List[str]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType1Content
return LLMType1Content(
title=self.title,
body=self.body,
image_prompt=self.image_prompts[0] if self.image_prompts else "",
)
class Type2Content(SlideContentModel):
body: List[HeadingModel]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType2Content
return LLMType2Content(
title=self.title,
body=[item.to_llm_content() for item in self.body],
)
class Type3Content(SlideContentModel):
body: List[HeadingModel]
image_prompts: List[str]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType3Content
return LLMType3Content(
title=self.title,
body=[item.to_llm_content() for item in self.body],
image_prompt=self.image_prompts[0] if self.image_prompts else "",
)
class Type4Content(SlideContentModel):
body: List[HeadingModel]
image_prompts: List[str]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType4Content
llm_body = []
for i, item in enumerate(self.body):
image_prompt = self.image_prompts[i] if i < len(self.image_prompts) else ""
llm_body.append(item.to_llm_content(image_prompt=image_prompt))
return LLMType4Content(
title=self.title,
body=llm_body,
)
class Type5Content(SlideContentModel):
body: str
# table: TableModel
graph: GraphModel
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType5Content
return LLMType5Content(
title=self.title,
body=self.body,
# table=self.table,
graph=self.graph,
)
class Type6Content(SlideContentModel):
description: str
body: List[HeadingModel]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType6Content
return LLMType6Content(
title=self.title,
description=self.description,
body=[item.to_llm_content() for item in self.body],
)
class Type7Content(SlideContentModel):
body: List[HeadingModel]
icon_queries: List[str]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType7Content
llm_body = []
for i, item in enumerate(self.body):
icon_query = self.icon_queries[i] if i < len(self.icon_queries) else ""
llm_body.append(item.to_llm_content(icon_query=icon_query))
return LLMType7Content(
title=self.title,
body=llm_body,
)
class Type8Content(SlideContentModel):
description: str
body: List[HeadingModel]
icon_queries: List[str]
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType8Content
llm_body = []
for i, item in enumerate(self.body):
icon_query = self.icon_queries[i] if i < len(self.icon_queries) else ""
llm_body.append(item.to_llm_content(icon_query=icon_query))
return LLMType8Content(
title=self.title,
description=self.description,
body=llm_body,
)
class Type9Content(SlideContentModel):
body: List[HeadingModel]
# table: TableModel
graph: GraphModel
def to_llm_content(self):
from ppt_generator.models.llm_models import LLMType9Content
return LLMType9Content(
title=self.title,
body=[item.to_llm_content() for item in self.body],
# table=self.table,
graph=self.graph,
)
ContentUnion = Union[
Type1Content,
Type2Content,
Type3Content,
Type4Content,
Type5Content,
Type6Content,
Type7Content,
Type8Content,
Type9Content,
]
CONTENT_TYPE_MAPPING: Mapping[int, ContentUnion] = {
TYPE1: Type1Content,
TYPE2: Type2Content,
TYPE3: Type3Content,
TYPE4: Type4Content,
TYPE5: Type5Content,
TYPE6: Type6Content,
TYPE7: Type7Content,
TYPE8: Type8Content,
TYPE9: Type9Content,
}

View file

@ -1,220 +0,0 @@
from typing import List, Mapping, Union
from pydantic import BaseModel
from graph_processor.models import GraphModel, LLMGraphModel
from ppt_generator.models.content_type_models import (
HeadingModel,
TableDataModel,
TableModel,
TableType,
Type1Content,
Type2Content,
Type3Content,
Type4Content,
Type5Content,
Type6Content,
Type7Content,
Type8Content,
Type9Content,
)
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
TYPE3,
TYPE4,
TYPE5,
TYPE6,
TYPE7,
TYPE8,
TYPE9,
)
class LLMTableDataModel(TableDataModel):
x_labels: List[str]
y_labels: List[str]
data: List[List[float]]
class LLMTableModel(TableModel):
name: str
type: TableType
data: LLMTableDataModel
class LLMHeadingModel(BaseModel):
heading: str
description: str
def to_content(self) -> HeadingModel:
return HeadingModel(
heading=self.heading,
description=self.description,
)
class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
image_prompt: str
def to_content(self) -> HeadingModel:
return HeadingModel(
heading=self.heading,
description=self.description,
)
class LLMHeadingModelWithIconQuery(LLMHeadingModel):
icon_query: str
def to_content(self) -> HeadingModel:
return HeadingModel(
heading=self.heading,
description=self.description,
)
class LLMSlideContentModel(BaseModel):
title: str
class LLMType1Content(LLMSlideContentModel):
body: str
image_prompt: str
def to_content(self) -> Type1Content:
return Type1Content(
title=self.title,
body=self.body,
image_prompts=[self.image_prompt],
)
class LLMType2Content(LLMSlideContentModel):
body: List[LLMHeadingModel]
def to_content(self) -> Type2Content:
return Type2Content(
title=self.title,
body=[each.to_content() for each in self.body],
)
class LLMType3Content(LLMSlideContentModel):
body: List[LLMHeadingModel]
image_prompt: str
def to_content(self) -> Type3Content:
return Type3Content(
title=self.title,
body=[each.to_content() for each in self.body],
image_prompts=[self.image_prompt],
)
class LLMType4Content(LLMSlideContentModel):
body: List[LLMHeadingModelWithImagePrompt]
def to_content(self) -> Type4Content:
return Type4Content(
title=self.title,
body=[each.to_content() for each in self.body],
image_prompts=[each.image_prompt for each in self.body],
)
class LLMType5Content(LLMSlideContentModel):
body: str
# table: LLMTableModel
graph: LLMGraphModel
def to_content(self) -> Type5Content:
return Type5Content(
title=self.title,
body=self.body,
# table=self.table,
graph=GraphModel.from_llm_graph_model(self.graph),
)
class LLMType6Content(LLMSlideContentModel):
description: str
body: List[LLMHeadingModel]
def to_content(self) -> Type6Content:
return Type6Content(
title=self.title,
description=self.description,
body=[each.to_content() for each in self.body],
)
class LLMType7Content(LLMSlideContentModel):
body: List[LLMHeadingModelWithIconQuery]
def to_content(self) -> Type7Content:
return Type7Content(
title=self.title,
body=[each.to_content() for each in self.body],
icon_queries=[each.icon_query for each in self.body],
)
class LLMType8Content(LLMSlideContentModel):
description: str
body: List[LLMHeadingModelWithImagePrompt]
def to_content(self) -> Type8Content:
return Type8Content(
title=self.title,
description=self.description,
body=[each.to_content() for each in self.body],
icon_queries=[each.image_prompt for each in self.body],
)
class LLMType9Content(LLMSlideContentModel):
body: List[LLMHeadingModel]
# table: LLMTableModel
graph: LLMGraphModel
def to_content(self) -> Type9Content:
return Type9Content(
title=self.title,
body=[each.to_content() for each in self.body],
# table=self.table,
graph=GraphModel.from_llm_graph_model(self.graph),
)
LLMContentUnion = Union[
LLMType1Content,
LLMType2Content,
LLMType3Content,
LLMType4Content,
LLMType5Content,
LLMType6Content,
LLMType7Content,
LLMType8Content,
LLMType9Content,
]
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMContentUnion] = {
TYPE1: LLMType1Content,
TYPE2: LLMType2Content,
TYPE3: LLMType3Content,
TYPE4: LLMType4Content,
TYPE5: LLMType5Content,
TYPE6: LLMType6Content,
TYPE7: LLMType7Content,
TYPE8: LLMType8Content,
TYPE9: LLMType9Content,
}
class LLMSlideModel(BaseModel):
type: int
content: LLMContentUnion
class LLMPresentationModel(BaseModel):
slides: List[LLMSlideModel]

View file

@ -1,232 +0,0 @@
from typing import List, Mapping, Union
from pydantic import Field
from graph_processor.models import LLMGraphModel
from ppt_generator.models.content_type_models import TableType
from ppt_generator.models.other_models import (
TYPE1,
TYPE2,
TYPE3,
TYPE4,
TYPE5,
TYPE6,
TYPE7,
TYPE8,
TYPE9,
)
from ppt_generator.models.llm_models import (
LLMTableDataModel,
LLMTableModel,
LLMHeadingModel,
LLMHeadingModelWithImagePrompt,
LLMHeadingModelWithIconQuery,
LLMSlideContentModel,
LLMType1Content,
LLMType2Content,
LLMType3Content,
LLMType4Content,
LLMType5Content,
LLMType6Content,
LLMType7Content,
LLMType8Content,
LLMType9Content,
LLMSlideModel,
LLMPresentationModel,
)
class LLMTableDataModelWithValidation(LLMTableDataModel):
x_labels: List[str] = Field(
description="X labels of the table",
min_length=1,
max_length=5,
)
y_labels: List[str] = Field(
description="Y labels of the table",
min_length=1,
max_length=3,
)
data: List[List[float]] = Field(
description="Data of the table",
min_length=1,
max_length=5,
)
class LLMTableModelWithValidation(LLMTableModel):
name: str = Field(
description="Name of the table in about 8 words",
min_length=10,
max_length=50,
)
type: TableType = Field(description="Type of the table")
data: LLMTableDataModelWithValidation
class LLMHeadingModelWithValidation(LLMHeadingModel):
heading: str = Field(
description="Item heading in about 6 words",
min_length=10,
max_length=40,
)
description: str = Field(
description="Item description in about 12 words.",
min_length=50,
max_length=120,
)
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
image_prompt: str = Field(
description="Item image prompt in about 10 words",
min_length=10,
max_length=100,
)
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
icon_query: str = Field(
description="Item icon query in about 4 words",
min_length=10,
max_length=40,
)
class LLMSlideContentModelWithValidation(LLMSlideContentModel):
title: str = Field(
description="Slide title in about 8 words",
min_length=10,
max_length=80,
)
class LLMType1ContentWithValidation(LLMType1Content):
body: str = Field(
description="Slide content summary in about 30 words.",
min_length=50,
max_length=300,
)
image_prompt: str = Field(
description="Slide image prompt in about 5 words",
min_length=10,
max_length=30,
)
class LLMType2ContentWithValidation(LLMType2Content):
body: List[LLMHeadingModelWithValidation] = Field(
description="Items to show in slide",
min_length=1,
max_length=4,
)
class LLMType3ContentWithValidation(LLMType3Content):
body: List[LLMHeadingModelWithValidation] = Field(
description="Items to show in slide",
min_length=3,
max_length=3,
)
image_prompt: str = Field(
description="Slide image prompt in about 5 words",
min_length=10,
max_length=30,
)
class LLMType4ContentWithValidation(LLMType4Content):
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
class LLMType5ContentWithValidation(LLMType5Content):
body: str = Field(
description="Slide content summary in about 30 words.",
min_length=50,
max_length=300,
)
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
graph: LLMGraphModel = Field(description="Graph to show in slide")
class LLMType6ContentWithValidation(LLMType6Content):
description: str = Field(
description="Slide content summary in about 20 words.",
min_length=50,
max_length=300,
)
body: List[LLMHeadingModelWithValidation] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
class LLMType7ContentWithValidation(LLMType7Content):
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
description="Items to show in slide",
min_length=1,
max_length=4,
)
class LLMType8ContentWithValidation(LLMType8Content):
description: str = Field(
description="Slide content summary in about 20 words.",
min_length=50,
max_length=300,
)
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
class LLMType9ContentWithValidation(LLMType9Content):
body: List[LLMHeadingModelWithValidation] = Field(
description="Items to show in slide",
min_length=1,
max_length=3,
)
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
graph: LLMGraphModel = Field(description="Graph to show in slide")
LLMContentUnionWithValidation = Union[
LLMType1ContentWithValidation,
LLMType2ContentWithValidation,
LLMType3ContentWithValidation,
LLMType4ContentWithValidation,
LLMType5ContentWithValidation,
LLMType6ContentWithValidation,
LLMType7ContentWithValidation,
LLMType8ContentWithValidation,
LLMType9ContentWithValidation,
]
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION: Mapping[
int, LLMContentUnionWithValidation
] = {
TYPE1: LLMType1ContentWithValidation,
TYPE2: LLMType2ContentWithValidation,
TYPE3: LLMType3ContentWithValidation,
TYPE4: LLMType4ContentWithValidation,
TYPE5: LLMType5ContentWithValidation,
TYPE6: LLMType6ContentWithValidation,
TYPE7: LLMType7ContentWithValidation,
TYPE8: LLMType8ContentWithValidation,
TYPE9: LLMType9ContentWithValidation,
}
class LLMSlideModelWithValidation(LLMSlideModel):
type: int
content: LLMContentUnionWithValidation
class LLMPresentationModelWithValidation(LLMPresentationModel):
slides: List[LLMSlideModelWithValidation]

View file

@ -1,24 +0,0 @@
from pydantic import BaseModel, Field
# 1. contains title, description and an image.
TYPE1 = 1
# 2. contains title and list of items.
TYPE2 = 2
# 3. contains title, list of items and an image.
TYPE3 = 3
# 4. contains title and list of items and multiple images.
TYPE4 = 4
# 5. contains title, description and a graph.
TYPE5 = 5
# 6. contains title, description and list of items.
TYPE6 = 6
# 7. contains title, list of items and icons.
TYPE7 = 7
# 8. contains title, description, list of items and icons.
TYPE8 = 8
# 9. contains title, list of items and a graph.
TYPE9 = 9
class SlideTypeModel(BaseModel):
slide_type: int = Field(gte=1, lte=9, description="Slide type from 1 to 9")

View file

@ -1,66 +0,0 @@
import uuid
from typing import List, Optional
from pydantic import BaseModel
from ppt_generator.models.content_type_models import (
CONTENT_TYPE_MAPPING,
Type1Content,
Type2Content,
Type3Content,
Type4Content,
Type5Content,
Type6Content,
Type7Content,
Type8Content,
Type9Content,
)
class SlideModel(BaseModel):
id: Optional[str] = None
index: int
type: int
design_index: Optional[int] = None
images: Optional[List[str]] = None
icons: Optional[List[str]] = None
presentation: str
content: (
Type1Content
| Type2Content
| Type3Content
| Type4Content
| Type5Content
| Type6Content
| Type7Content
| Type8Content
| Type9Content
)
properties: Optional[dict] = None
@classmethod
def from_dict(cls, data):
slide_model = cls(**data)
content = data["content"]
slide_model.content = CONTENT_TYPE_MAPPING[slide_model.type](**content)
return slide_model
def to_create_dict(self, auto_id: bool = False):
temp = self.model_dump(mode="json")
if not self.id:
if auto_id:
temp["id"] = str(uuid.uuid4())
else:
temp.pop("id")
return temp
@property
def images_count(self):
if not hasattr(self.content, "image_prompts"):
return 0
return len(self.content.image_prompts or [])
@property
def icons_count(self):
if not hasattr(self.content, "icon_queries"):
return 0
return len(self.content.icon_queries or [])

View file

@ -4,12 +4,19 @@ import argparse
if __name__ == "__main__":
os.makedirs("debug", exist_ok=True)
parser = argparse.ArgumentParser(description="Run the FastAPI server")
parser.add_argument(
"--port", type=int, required=True, help="Port number to run the server on"
)
parser.add_argument(
"--reload", type=bool, default=False, help="Reload the server on code changes"
)
args = parser.parse_args()
uvicorn.run("api.main:app", host="0.0.0.0", port=args.port, log_level="info")
uvicorn.run(
"api.main:APP",
host="0.0.0.0",
port=args.port,
log_level="info",
reload=args.reload,
)

View file

@ -0,0 +1,6 @@
from services.temp_file_service import TempFileService
from services.database import sql_engine
TEMP_FILE_SERVICE = TempFileService()
SQL_ENGINE = sql_engine

View file

@ -0,0 +1,25 @@
from contextlib import contextmanager
import os
from sqlalchemy import create_engine
from sqlmodel import Session
from utils.get_env import get_app_data_directory_env, get_database_url_env
database_url = get_database_url_env() or "sqlite:///" + os.path.join(
get_app_data_directory_env(), "fastapi.db"
)
connect_args = {}
if "sqlite" in database_url:
connect_args["check_same_thread"] = False
sql_engine = create_engine(database_url, connect_args=connect_args)
@contextmanager
def get_sql_session():
session = Session(sql_engine)
try:
yield session
finally:
session.close()

View file

@ -0,0 +1,120 @@
from http.client import HTTPException
import mimetypes
import os, pdfplumber, asyncio
from typing import List, Tuple
from docx import Document
from pptx import Presentation
from constants.documents import (
PDF_MIME_TYPES,
POWERPOINT_TYPES,
TEXT_MIME_TYPES,
WORD_TYPES,
)
class DocumentsLoader:
def __init__(self, documents: List[str]):
self._document_paths = documents
self._documents: List[str] = []
self._images: List[List[str]] = []
@property
def documents(self):
return self._documents
@property
def images(self):
return self._images
async def load_documents(
self,
temp_dir: str,
load_text: bool = True,
load_images: bool = False,
):
documents: List[str] = []
images: List[str] = []
for file_path in self._document_paths:
if not os.path.exists(file_path):
raise HTTPException(
status_code=404, detail=f"File {file_path} not found"
)
document = ""
imgs = []
mime_type = mimetypes.guess_type(file_path)[0]
if mime_type in PDF_MIME_TYPES:
document, imgs = await self.load_pdf(
file_path, load_text, load_images, temp_dir
)
elif mime_type in TEXT_MIME_TYPES:
document = await self.load_text(file_path)
elif mime_type in POWERPOINT_TYPES:
document = self.load_powerpoint(file_path)
elif mime_type in WORD_TYPES:
document = self.load_msword(file_path)
documents.append(document)
images.append(imgs)
self._documents = documents
self._images = images
async def load_pdf(
self,
file_path: str,
load_text: bool,
load_images: bool,
temp_dir: str,
) -> Tuple[str, List[str]]:
image_paths = []
document: str = ""
if load_text:
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
document += await asyncio.to_thread(page.extract_text)
if load_images:
image_paths = await self.get_page_images_from_pdf_async(file_path, temp_dir)
return document, image_paths
async def load_text(self, file_path: str) -> str:
with open(file_path, "r") as file:
return await asyncio.to_thread(file.read)
def load_msword(self, file_path: str) -> str:
document = Document(file_path)
text = "\n".join([paragraph.text for paragraph in document.paragraphs])
return text
def load_powerpoint(self, file_path: str) -> str:
presentation = Presentation(file_path)
extracted_text = ""
for index, slide in enumerate(presentation.slides):
extracted_text += f"# Slide {index + 1}\n"
for shape in slide.shapes:
if shape.has_text_frame:
for paragraph in shape.text_frame.paragraphs:
extracted_text += f"{paragraph.text}\n"
extracted_text += "\n"
extracted_text += "\n\n"
return extracted_text
def get_page_images_from_pdf(self, file_path: str, temp_dir: str):
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
img = page.to_image(resolution=300)
img.save(os.path.join(temp_dir, f"page_{page.page_number}.png"))
async def get_page_images_from_pdf_async(self, file_path: str, temp_dir: str):
return await asyncio.to_thread(
self.get_page_images_from_pdf, file_path, temp_dir
)

View file

@ -0,0 +1,65 @@
import os
import uuid
from typing import Optional, Union
class TempFileService:
def __init__(self):
self.base_dir = os.getenv("TEMP_DIRECTORY")
self.cleanup_base_dir()
os.makedirs(self.base_dir, exist_ok=True)
def create_dir_in_dir(self, base_dir: str, dir_name: Optional[str] = None) -> str:
temp_dir = os.path.join(base_dir, dir_name if dir_name else str(uuid.uuid4()))
os.makedirs(temp_dir, exist_ok=True)
return temp_dir
def create_temp_dir(self, dir_name: Optional[str] = None) -> str:
return self.create_dir_in_dir(self.base_dir, dir_name)
def create_temp_file_path(
self, file_path: str, dir_path: Optional[str] = None
) -> str:
if dir_path is None:
dir_path = self.base_dir
full_path = os.path.join(dir_path, file_path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
return full_path
def create_temp_file(
self, file_path: str, content: Union[bytes, str], dir_path: Optional[str] = None
) -> str:
file_path = self.create_temp_file_path(file_path, dir_path)
mode = "wb" if isinstance(content, bytes) else "w"
with open(file_path, mode) as f:
f.write(content)
return file_path
def read_temp_file(self, file_path: str, binary: bool = True) -> Union[bytes, str]:
mode = "rb" if binary else "r"
with open(file_path, mode) as f:
return f.read()
def cleanup_temp_file(self, file_path: str):
if os.path.exists(file_path):
os.remove(file_path)
def delete_dir_files(self, dir_path: str):
if os.path.exists(dir_path):
for root, dirs, files in os.walk(dir_path, topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))
def cleanup_temp_dir(self, dir_path: str):
if os.path.exists(dir_path):
self.delete_dir_files(dir_path)
os.rmdir(dir_path)
def cleanup_base_dir(self):
self.cleanup_temp_dir(self.base_dir)

View file

@ -0,0 +1,18 @@
from typing import Optional
from openai import AsyncOpenAI
from utils.llm_provider import get_llm_client
async def list_available_custom_models(
url: Optional[str] = None, api_key: Optional[str] = None
) -> list[str]:
if not url:
client = get_llm_client()
else:
client = AsyncOpenAI(api_key=api_key or "null", base_url=url)
models = []
async for model in client.models.list():
print(model)
models.append(model.id)
return models

View file

@ -0,0 +1,57 @@
import os
def get_can_change_keys_env():
return os.getenv("CAN_CHANGE_KEYS")
def get_database_url_env():
return os.getenv("DATABASE_URL")
def get_app_data_directory_env():
return os.getenv("APP_DATA_DIRECTORY")
def get_temp_directory_env():
return os.getenv("TEMP_DIRECTORY")
def get_user_config_path_env():
return os.getenv("USER_CONFIG_PATH")
def get_llm_provider_env():
return os.getenv("LLM")
def get_ollama_url_env():
return os.getenv("OLLAMA_URL")
def get_custom_llm_url_env():
return os.getenv("CUSTOM_URL")
def get_openai_api_key_env():
return os.getenv("OPENAI_API_KEY")
def get_google_api_key_env():
return os.getenv("GOOGLE_API_KEY")
def get_custom_llm_api_key_env():
return os.getenv("CUSTOM_LLM_API_KEY")
def get_ollama_model_env():
return os.getenv("OLLAMA_MODEL")
def get_custom_model_env():
return os.getenv("CUSTOM_MODEL")
def get_pexels_api_key_env():
return os.getenv("PEXELS_API_KEY")

View file

@ -0,0 +1,66 @@
from http.client import HTTPException
import os
from openai import AsyncOpenAI
from enums.llm_provider import LLMProvider
from utils.get_env import (
get_custom_llm_api_key_env,
get_custom_llm_url_env,
get_google_api_key_env,
get_llm_provider_env,
get_ollama_url_env,
get_openai_api_key_env,
)
def get_llm_provider():
return LLMProvider(get_llm_provider_env())
def get_ollama_url():
return get_ollama_url_env() or "http://localhost:11434"
def is_ollama_selected():
return get_llm_provider() == LLMProvider.OLLAMA
def is_custom_llm_selected():
return get_llm_provider() == LLMProvider.CUSTOM
def get_model_base_url():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return "https://api.openai.com/v1"
elif selected_llm == LLMProvider.GOOGLE:
return "https://generativelanguage.googleapis.com/v1beta/openai"
elif selected_llm == LLMProvider.OLLAMA:
return os.path.join(get_ollama_url(), "v1")
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_llm_url_env()
else:
raise HTTPException(f"LLM provider {selected_llm} is not supported")
def get_llm_api_key():
selected_llm = get_llm_provider()
if selected_llm == LLMProvider.OPENAI:
return get_openai_api_key_env()
elif selected_llm == LLMProvider.GOOGLE:
return get_google_api_key_env()
elif selected_llm == LLMProvider.OLLAMA:
return "ollama"
elif selected_llm == LLMProvider.CUSTOM:
return get_custom_llm_api_key_env() or "none"
else:
raise HTTPException(f"LLM provider {selected_llm} is not supported")
def get_llm_client():
client = AsyncOpenAI(
base_url=get_model_base_url(),
api_key=get_llm_api_key(),
)
return client

View file

@ -0,0 +1,60 @@
from http.client import HTTPException
import json
from typing import AsyncGenerator
import aiohttp
from models.ollama_model_status_response import OllamaModelStatusResponse
from utils.get_env import get_ollama_url_env
async def pull_ollama_model(model: str) -> AsyncGenerator[dict, None]:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{get_ollama_url_env()}/api/pull",
json={"model": model},
) as response:
if response.status != 200:
raise HTTPException(
status_code=response.status,
detail=f"Failed to pull model: {await response.text()}",
)
async for line in response.content:
if not line.strip():
continue
try:
event = json.loads(line.decode("utf-8"))
except json.JSONDecodeError:
continue
yield event
async def list_pulled_ollama_models() -> list[OllamaModelStatusResponse]:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{get_ollama_url_env()}/api/tags",
) as response:
if response.status == 200:
pulled_models = await response.json()
return [
OllamaModelStatusResponse(
name=m["model"],
size=m["size"],
status="pulled",
downloaded=m["size"],
done=True,
)
for m in pulled_models["models"]
]
elif response.status == 403:
raise HTTPException(
status_code=403,
detail="Forbidden: Please check your Ollama Configuration",
)
else:
raise HTTPException(
status_code=response.status,
detail=f"Failed to list Ollama models: {response.status}",
)

View file

@ -0,0 +1,45 @@
import os
def set_temp_directory_env(value):
os.environ["TEMP_DIRECTORY"] = value
def set_user_config_path_env(value):
os.environ["USER_CONFIG_PATH"] = value
def set_llm_provider_env(value):
os.environ["LLM"] = value
def set_ollama_url_env(value):
os.environ["OLLAMA_URL"] = value
def set_custom_llm_url_env(value):
os.environ["CUSTOM_URL"] = value
def set_openai_api_key_env(value):
os.environ["OPENAI_API_KEY"] = value
def set_google_api_key_env(value):
os.environ["GOOGLE_API_KEY"] = value
def set_custom_llm_api_key_env(value):
os.environ["CUSTOM_LLM_API_KEY"] = value
def set_ollama_model_env(value):
os.environ["OLLAMA_MODEL"] = value
def set_custom_model_env(value):
os.environ["CUSTOM_MODEL"] = value
def set_pexels_api_key_env(value):
os.environ["PEXELS_API_KEY"] = value

View file

@ -0,0 +1,75 @@
import os
import json
from models.user_config import UserConfig
from utils.get_env import (
get_custom_llm_api_key_env,
get_custom_llm_url_env,
get_custom_model_env,
get_google_api_key_env,
get_llm_provider_env,
get_ollama_model_env,
get_ollama_url_env,
get_openai_api_key_env,
get_pexels_api_key_env,
get_user_config_path_env,
)
from utils.set_env import (
set_custom_llm_api_key_env,
set_custom_llm_url_env,
set_custom_model_env,
set_google_api_key_env,
set_llm_provider_env,
set_ollama_model_env,
set_ollama_url_env,
set_openai_api_key_env,
set_pexels_api_key_env,
)
def get_user_config():
user_config_path = get_user_config_path_env()
existing_config = UserConfig()
try:
if os.path.exists(user_config_path):
with open(user_config_path, "r") as f:
existing_config = UserConfig(**json.load(f))
except Exception as e:
print("Error while loading user config")
pass
return UserConfig(
LLM=existing_config.LLM or get_llm_provider_env(),
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or get_openai_api_key_env(),
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or get_google_api_key_env(),
OLLAMA_URL=existing_config.OLLAMA_URL or get_ollama_url_env(),
OLLAMA_MODEL=existing_config.OLLAMA_MODEL or get_ollama_model_env(),
CUSTOM_LLM_URL=existing_config.CUSTOM_LLM_URL or get_custom_llm_url_env(),
CUSTOM_LLM_API_KEY=existing_config.CUSTOM_LLM_API_KEY
or get_custom_llm_api_key_env(),
CUSTOM_MODEL=existing_config.CUSTOM_MODEL or get_custom_model_env(),
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or get_pexels_api_key_env(),
)
def update_env_with_user_config():
user_config = get_user_config()
if user_config.LLM:
set_llm_provider_env(user_config.LLM)
if user_config.OPENAI_API_KEY:
set_openai_api_key_env(user_config.OPENAI_API_KEY)
if user_config.GOOGLE_API_KEY:
set_google_api_key_env(user_config.GOOGLE_API_KEY)
if user_config.OLLAMA_URL:
set_ollama_url_env(user_config.OLLAMA_URL)
if user_config.OLLAMA_MODEL:
set_ollama_model_env(user_config.OLLAMA_MODEL)
if user_config.CUSTOM_LLM_URL:
set_custom_llm_url_env(user_config.CUSTOM_LLM_URL)
if user_config.CUSTOM_LLM_API_KEY:
set_custom_llm_api_key_env(user_config.CUSTOM_LLM_API_KEY)
if user_config.CUSTOM_MODEL:
set_custom_model_env(user_config.CUSTOM_MODEL)
if user_config.PEXELS_API_KEY:
set_pexels_api_key_env(user_config.PEXELS_API_KEY)

View file

@ -0,0 +1,27 @@
from http.client import HTTPException
from typing import List
from fastapi import UploadFile
def validate_files(
field,
nullable: bool,
multiple: bool,
max_size: int,
accepted_types: List[str],
):
if field:
files: List[UploadFile] = field if multiple else [field]
for each_file in files:
if (max_size * 1024 * 1024) < each_file.size:
raise HTTPException(
400,
f"File '{each_file.filename}' exceeded max upload size of {max_size} MB",
)
elif each_file.content_type not in accepted_types:
raise HTTPException(400, f"File '{each_file.filename}' not accepted.")
elif not (field or nullable):
raise HTTPException(400, "File must be provided.")

View file

@ -0,0 +1,97 @@
import os
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from sqlmodel import SQLModel
from contextlib import asynccontextmanager
from api.models import SelectedLLMProvider
from api.routers.presentation.router import presentation_router
from api.services.database import sql_engine
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from api.utils.utils import update_env_with_user_config
from api.utils.model_utils import (
get_selected_llm_provider,
is_custom_llm_selected,
is_ollama_selected,
list_available_custom_models,
pull_ollama_model,
)
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
async def check_llm_model_availability():
if not can_change_keys:
if get_selected_llm_provider() == SelectedLLMProvider.OPENAI:
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
raise Exception("OPENAI_API_KEY must be provided")
elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
raise Exception("GOOGLE_API_KEY must be provided")
elif is_ollama_selected():
ollama_model = os.getenv("OLLAMA_MODEL")
if not ollama_model:
raise Exception("OLLAMA_MODEL must be provided")
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
raise Exception(f"Model {ollama_model} is not supported")
print("-" * 50)
print("Pulling model: ", ollama_model)
async for event in pull_ollama_model(ollama_model):
print(event)
print("Pulled model: ", ollama_model)
print("-" * 50)
elif is_custom_llm_selected():
custom_model = os.getenv("CUSTOM_MODEL")
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
if not custom_model:
raise Exception("CUSTOM_MODEL must be provided")
if not custom_llm_url:
raise Exception("CUSTOM_LLM_URL must be provided")
if not custom_llm_api_key:
raise Exception("CUSTOM_LLM_API_KEY must be provided")
print("-" * 50)
print("Selecting model: ", custom_model)
models = await list_available_custom_models(
custom_llm_url, custom_llm_api_key
)
print("Available models: ", models)
print("-" * 50)
if custom_model not in models:
raise Exception(f"Model {custom_model} is not available")
@asynccontextmanager
async def lifespan(_: FastAPI):
os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True)
SQLModel.metadata.create_all(sql_engine)
await check_llm_model_availability()
yield
app = FastAPI(lifespan=lifespan)
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def update_env_middleware(request: Request, call_next):
if can_change_keys:
update_env_with_user_config()
return await call_next(request)
app.include_router(presentation_router)

View file

@ -23,7 +23,6 @@ from ppt_generator.models.query_and_prompt_models import (
IconQueryCollectionWithData,
ImagePromptWithThemeAndAspectRatio,
)
from ppt_generator.models.slide_model import SlideModel
from ppt_generator.slide_generator import (
get_edited_slide_content_model,
get_slide_type_from_prompt,
@ -64,7 +63,7 @@ class PresentationEditHandler:
)
).first()
slide_to_edit = SlideModel.from_dict(slide_to_edit_sql.model_dump(mode="json"))
slide_to_edit = SlideSqlModel(**slide_to_edit_sql.model_dump(mode="json"))
new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit)
new_slide_type = new_slide_type.slide_type
@ -87,7 +86,7 @@ class PresentationEditHandler:
presentation.language,
)
new_slide_model = SlideModel(
new_slide_model = SlideSqlModel(
id=slide_to_edit.id,
index=slide_to_edit.index,
type=new_slide_type,

View file

@ -0,0 +1,16 @@
from sqlmodel import select
from api.models import LogMetadata
from api.services.database import get_sql_session
from api.services.logging import LoggingService
from api.sql_models import PresentationLayoutSqlModel
class ListPresentationLayoutsHandler:
def __init__(self):
pass
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
with get_sql_session() as sql_session:
layouts = sql_session.exec(select(PresentationLayoutSqlModel)).all()
return layouts

View file

@ -0,0 +1,13 @@
from sqlmodel import select
from api.models import LogMetadata
from api.services.database import get_sql_session
from api.services.logging import LoggingService
from api.sql_models import SlideLayoutSqlModel
class ListSlideLayoutsHandler:
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
with get_sql_session() as sql_session:
layouts = sql_session.exec(select(SlideLayoutSqlModel)).all()
return layouts

View file

@ -0,0 +1,19 @@
from api.models import LogMetadata
from api.routers.presentation.models import SavePresentationLayoutsRequest
from api.services.database import get_sql_session
from api.services.logging import LoggingService
class SavePresentationLayoutsHandler:
def __init__(self, data: SavePresentationLayoutsRequest):
self.data = data
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
with get_sql_session() as sql_session:
for layout in self.data.layouts:
sql_session.merge(layout)
sql_session.commit()
return self.data

View file

@ -0,0 +1,19 @@
from api.models import LogMetadata
from api.routers.presentation.models import SaveSlideLayoutsRequest
from api.services.database import get_sql_session
from api.services.logging import LoggingService
class SaveSlideLayoutsHandler:
def __init__(self, data: SaveSlideLayoutsRequest):
self.data = data
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
with get_sql_session() as sql_session:
for layout in self.data.layouts:
sql_session.merge(layout)
sql_session.commit()
return self.data

View file

@ -11,8 +11,12 @@ from ppt_generator.models.query_and_prompt_models import (
IconCategoryEnum,
ImagePromptWithThemeAndAspectRatio,
)
from ppt_generator.models.slide_model import SlideModel
from api.sql_models import PresentationSqlModel, SlideSqlModel
from api.sql_models import (
PresentationLayoutSqlModel,
PresentationSqlModel,
SlideLayoutSqlModel,
SlideSqlModel,
)
from ollama._types import ModelDetails
@ -126,7 +130,7 @@ class PresentationAndSlides(BaseModel):
class PresentationUpdateRequest(BaseModel):
presentation_id: str
slides: List[SlideModel]
slides: List[SlideSqlModel]
class PresentationAndUrl(BaseModel):
@ -197,3 +201,11 @@ class PresentationWithOneSlide(BaseModel):
title=presentation.title,
slide=slide,
)
class SaveSlideLayoutsRequest(BaseModel):
layouts: List[SlideLayoutSqlModel]
class SavePresentationLayoutsRequest(BaseModel):
layouts: List[PresentationLayoutSqlModel]

View file

@ -40,10 +40,20 @@ from api.routers.presentation.handlers.list_available_custom_models import (
from api.routers.presentation.handlers.list_ollama_pulled_models import (
ListPulledOllamaModelsHandler,
)
from api.routers.presentation.handlers.list_presentation_layouts import (
ListPresentationLayoutsHandler,
)
from api.routers.presentation.handlers.list_slide_layouts import ListSlideLayoutsHandler
from api.routers.presentation.handlers.list_supported_ollama_models import (
ListSupportedOllamaModelsHandler,
)
from api.routers.presentation.handlers.pull_ollama_model import PullOllamaModelHandler
from api.routers.presentation.handlers.save_presentation_layouts_handler import (
SavePresentationLayoutsHandler,
)
from api.routers.presentation.handlers.save_slide_layouts_handler import (
SaveSlideLayoutsHandler,
)
from api.routers.presentation.handlers.search_icon import SearchIconHandler
from api.routers.presentation.handlers.search_image import SearchImageHandler
from api.routers.presentation.handlers.update_parsed_document import (
@ -78,15 +88,21 @@ from api.routers.presentation.models import (
PresentationAndUrls,
PresentationGenerateRequest,
PresentationPathAndEditPath,
SavePresentationLayoutsRequest,
SaveSlideLayoutsRequest,
SearchIconRequest,
SearchImageRequest,
UpdatePresentationThemeRequest,
PresentationUpdateRequest,
PresentationWithOneSlide,
)
from api.sql_models import PresentationSqlModel
from api.sql_models import (
PresentationLayoutSqlModel,
PresentationSqlModel,
SlideLayoutSqlModel,
SlideSqlModel,
)
from api.utils.utils import handle_errors
from ppt_generator.models.slide_model import SlideModel
route_prefix = "/api/v1/ppt"
presentation_router = APIRouter(prefix=route_prefix)
@ -248,7 +264,7 @@ async def update_presentation(
)
@presentation_router.post("/edit", response_model=SlideModel)
@presentation_router.post("/edit", response_model=SlideSqlModel)
async def update_presentation(
data: EditPresentationSlideRequest,
):
@ -404,3 +420,51 @@ async def list_custom_models(
logging_service,
log_metadata,
)
@presentation_router.get(
"/layout/slides/list", response_model=List[SlideLayoutSqlModel]
)
async def list_slide_layouts():
request_utils = RequestUtils(f"{route_prefix}/layout/slides/list")
logging_service, log_metadata = await request_utils.initialize_logger()
return await handle_errors(
ListSlideLayoutsHandler().get,
logging_service,
log_metadata,
)
@presentation_router.post("/layout/slides/save")
async def save_slide_layouts(data: SaveSlideLayoutsRequest):
request_utils = RequestUtils(f"{route_prefix}/layout/slides/save")
logging_service, log_metadata = await request_utils.initialize_logger()
return await handle_errors(
SaveSlideLayoutsHandler(data).post,
logging_service,
log_metadata,
)
@presentation_router.get(
"/layout/presentations/list", response_model=List[PresentationLayoutSqlModel]
)
async def list_presentation_layouts():
request_utils = RequestUtils(f"{route_prefix}/layout/presentations/list")
logging_service, log_metadata = await request_utils.initialize_logger()
return await handle_errors(
ListPresentationLayoutsHandler().get,
logging_service,
log_metadata,
)
@presentation_router.post("/layout/presentations/save")
async def save_presentation_layouts(data: SavePresentationLayoutsRequest):
request_utils = RequestUtils(f"{route_prefix}/layout/presentations/save")
logging_service, log_metadata = await request_utils.initialize_logger()
return await handle_errors(
SavePresentationLayoutsHandler(data).post,
logging_service,
log_metadata,
)

View file

@ -8,13 +8,25 @@ def get_random_uuid() -> str:
return str(uuid.uuid4())
class SlideLayoutSqlModel(SQLModel, table=True):
id: str = Field(primary_key=True)
description: Optional[str] = None
json_schema: dict = Field(sa_column=Column(JSON, nullable=False))
class PresentationLayoutSqlModel(SQLModel, table=True):
id: str = Field(default_factory=get_random_uuid, primary_key=True)
name: str
description: Optional[str] = None
slide_layouts: List[str] = Field(sa_column=Column(JSON, nullable=False))
class PresentationSqlModel(SQLModel, table=True):
id: str = Field(default_factory=get_random_uuid, primary_key=True)
created_at: datetime = Field(default=datetime.now())
prompt: Optional[str] = None
n_slides: int
theme: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
file: Optional[str] = None
title: Optional[str] = None
structure: Optional[dict] = Field(
sa_column=Column(JSON, nullable=True), default=None
@ -27,21 +39,13 @@ class PresentationSqlModel(SQLModel, table=True):
)
language: Optional[str] = None
summary: Optional[str] = None
thumbnail: Optional[str] = None
data: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
class SlideSqlModel(SQLModel, table=True):
id: str = Field(default_factory=get_random_uuid, primary_key=True)
index: int = Field(index=True)
type: int
design_index: Optional[int] = None
images: Optional[List[str]] = Field(
sa_column=Column(JSON, nullable=True), default=None
)
icons: Optional[List[str]] = Field(
sa_column=Column(JSON, nullable=True), default=None
)
layout: str
presentation: str
content: dict = Field(sa_column=Column(JSON, nullable=False), default=None)
properties: Optional[dict] = Field(

View file

@ -160,8 +160,8 @@ async def handle_errors(
)
raise e
except Exception as e:
print(traceback.print_stack())
print(traceback.print_exc())
traceback.print_stack()
traceback.print_exc()
log_metadata.status_code = 400
logging_service.logger.critical(

View file

Before

Width:  |  Height:  |  Size: 4.6 KiB

After

Width:  |  Height:  |  Size: 4.6 KiB

View file

Before

Width:  |  Height:  |  Size: 3.9 KiB

After

Width:  |  Height:  |  Size: 3.9 KiB

View file

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

View file

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

View file

Before

Width:  |  Height:  |  Size: 3.4 KiB

After

Width:  |  Height:  |  Size: 3.4 KiB

View file

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

View file

Before

Width:  |  Height:  |  Size: 3.6 KiB

After

Width:  |  Height:  |  Size: 3.6 KiB

View file

Before

Width:  |  Height:  |  Size: 4.5 KiB

After

Width:  |  Height:  |  Size: 4.5 KiB

View file

Before

Width:  |  Height:  |  Size: 3.5 KiB

After

Width:  |  Height:  |  Size: 3.5 KiB

View file

Before

Width:  |  Height:  |  Size: 5.1 KiB

After

Width:  |  Height:  |  Size: 5.1 KiB

View file

Before

Width:  |  Height:  |  Size: 2.9 KiB

After

Width:  |  Height:  |  Size: 2.9 KiB

View file

Before

Width:  |  Height:  |  Size: 5.3 KiB

After

Width:  |  Height:  |  Size: 5.3 KiB

View file

Before

Width:  |  Height:  |  Size: 5.9 KiB

After

Width:  |  Height:  |  Size: 5.9 KiB

View file

Before

Width:  |  Height:  |  Size: 2.1 KiB

After

Width:  |  Height:  |  Size: 2.1 KiB

View file

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View file

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

View file

Before

Width:  |  Height:  |  Size: 2.1 KiB

After

Width:  |  Height:  |  Size: 2.1 KiB

View file

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

View file

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

Some files were not shown because too many files have changed in this diff Show more