Merge branch 'feat/custom_schema_and_layout'
This commit is contained in:
commit
5e146e5148
1891 changed files with 600833 additions and 594156 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -7,6 +7,8 @@ __pycache__
|
|||
node_modules
|
||||
out
|
||||
user_data
|
||||
app_data
|
||||
tmp
|
||||
debug
|
||||
.fastembed_cache
|
||||
.fastembed_cache
|
||||
my-doc.txt
|
||||
|
|
@ -15,7 +15,7 @@ RUN apt-get update && apt-get install -y \
|
|||
WORKDIR /app
|
||||
|
||||
# Set environment variables
|
||||
ENV APP_DATA_DIRECTORY=/app/user_data
|
||||
ENV APP_DATA_DIRECTORY=/app_data
|
||||
ENV TEMP_DIRECTORY=/tmp/presenton
|
||||
|
||||
# Install ollama
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ WORKDIR /app
|
|||
RUN ls -a
|
||||
|
||||
# Set environment variables
|
||||
ENV APP_DATA_DIRECTORY=/app/user_data
|
||||
ENV APP_DATA_DIRECTORY=/app_data
|
||||
ENV TEMP_DIRECTORY=/tmp/presenton
|
||||
|
||||
# Install ollama
|
||||
|
|
@ -30,10 +30,10 @@ RUN pip install -r requirements.txt
|
|||
# Install dependencies for Next.js
|
||||
WORKDIR /node_dependencies
|
||||
COPY servers/nextjs/package.json servers/nextjs/package-lock.json ./
|
||||
RUN npm install
|
||||
RUN npm install
|
||||
|
||||
# Install chrome for puppeteer
|
||||
RUN npx puppeteer browsers install chrome@136.0.7103.92 --install-deps
|
||||
RUN npx puppeteer browsers install chrome@138.0.7204.94 --install-deps
|
||||
|
||||
RUN chmod -R 777 /node_dependencies
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ RUN chmod -R 777 /node_dependencies
|
|||
COPY nginx.conf /etc/nginx/nginx.conf
|
||||
|
||||
# Expose the port
|
||||
EXPOSE 80 3000 8000
|
||||
EXPOSE 80
|
||||
|
||||
# Start the servers
|
||||
CMD ["/bin/bash", "/app/docker-dev-start.sh"]
|
||||
16
README.md
16
README.md
|
|
@ -12,7 +12,7 @@
|
|||
# Open-Source AI Presentation Generator and API (Gamma Alternative)
|
||||
|
||||
|
||||
**Presenton** is an open-source application for generating presentations with AI — all running locally on your device. Stay in control of your data and privacy while using models like OpenAI and Gemini, or use your own hosted models through Ollama.
|
||||
**Presenton** is an open-source application for generating presentations with AI — all running locally on your device. Stay in control of your data and privacy while using models like OpenAI and Gemini, or use your own hosted models through Ollama.
|
||||
|
||||

|
||||
|
||||
|
|
@ -29,6 +29,7 @@
|
|||
* ✅ **API Presentation Generation** — Host as API to generate presentations over requests
|
||||
* ✅ **Ollama Support** — Run open-source models locally with Ollama integration
|
||||
* ✅ **OpenAI API Compatibility** — Use any OpenAI-compatible API endpoint with your own models
|
||||
* ✅ **Versatile Image Generation** — Choose from DALL-E 3, Gemini Flash, Pexels, or Pixabay for your visuals
|
||||
* ✅ **Runs Locally** — All code runs on your device
|
||||
* ✅ **Privacy-First** — No tracking, no data stored by us
|
||||
* ✅ **Flexible** — Generate presentations from prompts or outlines
|
||||
|
|
@ -70,7 +71,17 @@ You may want to directly provide your API KEYS as environment variables and keep
|
|||
- **CUSTOM_LLM_URL=[Custom OpenAI Compatible URL]**: Provide this if **LLM** is set to **custom**
|
||||
- **CUSTOM_LLM_API_KEY=[Custom OpenAI Compatible API KEY]**: Provide this if **LLM** is set to **custom**
|
||||
- **CUSTOM_MODEL=[Custom Model ID]**: Provide this if **LLM** is set to **custom**
|
||||
- **PEXELS_API_KEY=[Your Pexels API Key]**: Provide this to generate images if **LLM** is set to **ollama** or **custom**
|
||||
|
||||
You can also set the following environment variables to customize the image generation provider and API keys:
|
||||
|
||||
- **IMAGE_PROVIDER=[pexels/pixabay/gemini_flash/dall-e-3]**: Select the image provider of your choice.
|
||||
- Defaults to **dall-e-3** for OpenAI models, **gemini_flash** for Google models if not set.
|
||||
- **PEXELS_API_KEY=[Your Pexels API Key]**: Required if using **pexels** as the image provider.
|
||||
- **PIXABAY_API_KEY=[Your Pixabay API Key]**: Required if using **pixabay** as the image provider.
|
||||
- **GOOGLE_API_KEY=[Your Google API Key]**: Required if using **gemini_flash** as the image provider.
|
||||
- **OPENAI_API_KEY=[Your OpenAI API Key]**: Required if using **dall-e-3** as the image provider.
|
||||
|
||||
> **Note:** You can freely choose both the LLM (text generation) and the image provider. Supported image providers: **pexels**, **pixabay**, **gemini_flash** (Google), and **dall-e-3** (OpenAI).
|
||||
|
||||
### Using OpenAI
|
||||
```bash
|
||||
|
|
@ -199,4 +210,3 @@ For detailed info checkout [API documentation](https://docs.presenton.ai/using-p
|
|||
## License
|
||||
|
||||
Apache 2.0
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ services:
|
|||
# You can replace 5000 with any other port number of your choice to run Presenton on a different port number.
|
||||
- "5000:80"
|
||||
volumes:
|
||||
- ./user_data:/app/user_data
|
||||
- ./app_data:/app_data
|
||||
environment:
|
||||
- CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS}
|
||||
- LLM=${LLM}
|
||||
|
|
@ -38,7 +38,7 @@ services:
|
|||
# You can replace 5000 with any other port number of your choice to run Presenton on a different port number.
|
||||
- "5000:80"
|
||||
volumes:
|
||||
- ./user_data:/app/user_data
|
||||
- ./app_data:/app_data
|
||||
environment:
|
||||
- CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS}
|
||||
- LLM=${LLM}
|
||||
|
|
@ -58,10 +58,9 @@ services:
|
|||
dockerfile: Dockerfile.dev
|
||||
ports:
|
||||
- "5000:80"
|
||||
- "3000:3000"
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- .:/app
|
||||
- ./app_data:/app_data
|
||||
environment:
|
||||
- NODE_ENV=development
|
||||
- CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS}
|
||||
|
|
@ -89,10 +88,9 @@ services:
|
|||
capabilities: [gpu]
|
||||
ports:
|
||||
- "5000:80"
|
||||
- "3000:3000"
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- .:/app
|
||||
- ./app_data:/app_data
|
||||
environment:
|
||||
- NODE_ENV=development
|
||||
- CAN_CHANGE_KEYS=${CAN_CHANGE_KEYS}
|
||||
|
|
|
|||
28
nginx.conf
28
nginx.conf
|
|
@ -16,6 +16,10 @@ http {
|
|||
|
||||
location / {
|
||||
proxy_pass http://localhost:3000;
|
||||
proxy_http_version 1.1; # Required for WebSocket
|
||||
proxy_set_header Upgrade $http_upgrade; # WebSocket header
|
||||
proxy_set_header Connection "upgrade"; # WebSocket header
|
||||
proxy_set_header Host $host;
|
||||
proxy_read_timeout 30m;
|
||||
proxy_connect_timeout 30m;
|
||||
}
|
||||
|
|
@ -25,5 +29,29 @@ http {
|
|||
proxy_read_timeout 30m;
|
||||
proxy_connect_timeout 30m;
|
||||
}
|
||||
|
||||
location /static {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_read_timeout 30m;
|
||||
proxy_connect_timeout 30m;
|
||||
}
|
||||
|
||||
location /app_data {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_read_timeout 30m;
|
||||
proxy_connect_timeout 30m;
|
||||
}
|
||||
|
||||
location /docs {
|
||||
proxy_pass http://localhost:8000/docs;
|
||||
proxy_read_timeout 30m;
|
||||
proxy_connect_timeout 30m;
|
||||
}
|
||||
|
||||
location /openapi.json {
|
||||
proxy_pass http://localhost:8000/openapi.json;
|
||||
proxy_read_timeout 30m;
|
||||
proxy_connect_timeout 30m;
|
||||
}
|
||||
}
|
||||
}
|
||||
24
package-lock.json
generated
Normal file
24
package-lock.json
generated
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"name": "presenton",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"dependencies": {
|
||||
"uuid": "^11.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/uuid": {
|
||||
"version": "11.1.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.0.tgz",
|
||||
"integrity": "sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==",
|
||||
"funding": [
|
||||
"https://github.com/sponsors/broofa",
|
||||
"https://github.com/sponsors/ctavan"
|
||||
],
|
||||
"bin": {
|
||||
"uuid": "dist/esm/bin/uuid"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
5
package.json
Normal file
5
package.json
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"dependencies": {
|
||||
"uuid": "^11.1.0"
|
||||
}
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
{}
|
||||
22
servers/fastapi/api/lifespan.py
Normal file
22
servers/fastapi/api/lifespan.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from services import SQL_ENGINE
|
||||
from utils.get_env import get_app_data_directory_env
|
||||
from utils.model_availability import check_llm_and_image_provider_api_or_model_availability
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(_: FastAPI):
|
||||
"""
|
||||
Lifespan context manager for FastAPI application.
|
||||
Initializes the application data directory and checks LLM model availability.
|
||||
|
||||
"""
|
||||
os.makedirs(get_app_data_directory_env(), exist_ok=True)
|
||||
SQLModel.metadata.create_all(SQL_ENGINE)
|
||||
await check_llm_and_image_provider_api_or_model_availability()
|
||||
yield
|
||||
|
|
@ -1,83 +1,38 @@
|
|||
import asyncio
|
||||
import os
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlmodel import SQLModel
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from api.lifespan import app_lifespan
|
||||
from api.middlewares import UserConfigEnvUpdateMiddleware
|
||||
from api.v1.ppt.router import API_V1_PPT_ROUTER
|
||||
from utils.asset_directory_utils import get_exports_directory, get_images_directory, get_uploads_directory
|
||||
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.routers.presentation.router import presentation_router
|
||||
from api.services.database import sql_engine
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from api.utils.utils import update_env_with_user_config
|
||||
from api.utils.model_utils import (
|
||||
get_selected_llm_provider,
|
||||
is_custom_llm_selected,
|
||||
is_ollama_selected,
|
||||
list_available_custom_models,
|
||||
pull_ollama_model,
|
||||
|
||||
app = FastAPI(lifespan=app_lifespan)
|
||||
|
||||
|
||||
# Routers
|
||||
app.include_router(API_V1_PPT_ROUTER)
|
||||
|
||||
# Static files
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
app.mount(
|
||||
"/app_data/images",
|
||||
StaticFiles(directory=get_images_directory()),
|
||||
name="app_data/images",
|
||||
)
|
||||
app.mount(
|
||||
"/app_data/exports",
|
||||
StaticFiles(directory=get_exports_directory()),
|
||||
name="app_data/exports",
|
||||
)
|
||||
app.mount(
|
||||
"/app_data/uploads",
|
||||
StaticFiles(directory=get_uploads_directory()),
|
||||
name="app_data/uploads",
|
||||
)
|
||||
|
||||
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
|
||||
|
||||
|
||||
async def check_llm_model_availability():
|
||||
if not can_change_keys:
|
||||
if get_selected_llm_provider() == SelectedLLMProvider.OPENAI:
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
|
||||
elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
|
||||
elif is_ollama_selected():
|
||||
ollama_model = os.getenv("OLLAMA_MODEL")
|
||||
if not ollama_model:
|
||||
raise Exception("OLLAMA_MODEL must be provided")
|
||||
|
||||
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise Exception(f"Model {ollama_model} is not supported")
|
||||
|
||||
print("-" * 50)
|
||||
print("Pulling model: ", ollama_model)
|
||||
async for event in pull_ollama_model(ollama_model):
|
||||
print(event)
|
||||
print("Pulled model: ", ollama_model)
|
||||
print("-" * 50)
|
||||
|
||||
elif is_custom_llm_selected():
|
||||
custom_model = os.getenv("CUSTOM_MODEL")
|
||||
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
|
||||
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
if not custom_model:
|
||||
raise Exception("CUSTOM_MODEL must be provided")
|
||||
if not custom_llm_url:
|
||||
raise Exception("CUSTOM_LLM_URL must be provided")
|
||||
if not custom_llm_api_key:
|
||||
raise Exception("CUSTOM_LLM_API_KEY must be provided")
|
||||
print("-" * 50)
|
||||
print("Selecting model: ", custom_model)
|
||||
models = await list_available_custom_models(
|
||||
custom_llm_url, custom_llm_api_key
|
||||
)
|
||||
print("Available models: ", models)
|
||||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True)
|
||||
SQLModel.metadata.create_all(sql_engine)
|
||||
await check_llm_model_availability()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
# Middlewares
|
||||
origins = ["*"]
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
|
@ -87,12 +42,4 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def update_env_middleware(request: Request, call_next):
|
||||
if can_change_keys:
|
||||
update_env_with_user_config()
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
app.include_router(presentation_router)
|
||||
app.add_middleware(UserConfigEnvUpdateMiddleware)
|
||||
|
|
|
|||
13
servers/fastapi/api/middlewares.py
Normal file
13
servers/fastapi/api/middlewares.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
from utils.get_env import get_can_change_keys_env
|
||||
from utils.user_config import update_env_with_user_config
|
||||
|
||||
|
||||
class UserConfigEnvUpdateMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if get_can_change_keys_env() != "false":
|
||||
update_env_with_user_config()
|
||||
return await call_next(request)
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
from enum import Enum
|
||||
import json
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.sql_models import PresentationSqlModel
|
||||
|
||||
|
||||
class LogMetadata(BaseModel):
|
||||
presentation: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
status_code: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_presentation(
|
||||
cls, presentation: PresentationSqlModel, endpoint: Optional[str] = None
|
||||
):
|
||||
return cls(
|
||||
presentation=presentation.id,
|
||||
title=presentation.title,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
@property
|
||||
def stream_name(self):
|
||||
return f"Endpoint - {self.endpoint}, Presentation - {self.presentation}"
|
||||
|
||||
|
||||
class SessionModel(BaseModel):
|
||||
session: str
|
||||
|
||||
|
||||
class SSEResponse(BaseModel):
|
||||
event: str
|
||||
data: str
|
||||
|
||||
def to_string(self):
|
||||
return f"event: {self.event}\ndata: {self.data}\n\n"
|
||||
|
||||
|
||||
class SSEStatusResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "status", "status": self.status})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSECompleteResponse(BaseModel):
|
||||
key: str
|
||||
value: object
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "complete", self.key: self.value}),
|
||||
).to_string()
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
LLM: Optional[str] = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
OLLAMA_URL: Optional[str] = None
|
||||
OLLAMA_MODEL: Optional[str] = None
|
||||
CUSTOM_LLM_URL: Optional[str] = None
|
||||
CUSTOM_LLM_API_KEY: Optional[str] = None
|
||||
CUSTOM_MODEL: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
|
||||
|
||||
class OllamaModelMetadata(BaseModel):
|
||||
label: str
|
||||
value: str
|
||||
description: str
|
||||
icon: str
|
||||
size: str
|
||||
supports_graph: bool
|
||||
|
||||
|
||||
class SelectedLLMProvider(Enum):
|
||||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
CUSTOM = "custom"
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class RequestUtils:
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
async def initialize_logger(
|
||||
self,
|
||||
presentation_id: Optional[str] = None,
|
||||
):
|
||||
metadata = LogMetadata(presentation=presentation_id, endpoint=self.endpoint)
|
||||
logging_service = LoggingService(metadata.stream_name)
|
||||
|
||||
return logging_service, metadata
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
DecomposeDocumentsRequest,
|
||||
DecomposeDocumentsResponse,
|
||||
)
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from document_processor.loader import DocumentsLoader
|
||||
|
||||
|
||||
class DecomposeDocumentsHandler:
|
||||
|
||||
def __init__(self, data: DecomposeDocumentsRequest):
|
||||
self.data = data
|
||||
self.documents = list(
|
||||
filter(lambda doc: not doc.endswith(".csv"), self.data.documents or [])
|
||||
)
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
documents_loader = DocumentsLoader(self.documents)
|
||||
await documents_loader.load_documents(self.temp_dir)
|
||||
parsed_documents = documents_loader.documents
|
||||
|
||||
document_paths = []
|
||||
for parsed_doc in parsed_documents:
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{str(uuid.uuid4())}.txt", self.temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
document_paths.append(file_path)
|
||||
|
||||
documents = {}
|
||||
for index, each in enumerate(self.documents):
|
||||
documents[each] = document_paths[index]
|
||||
|
||||
response = DecomposeDocumentsResponse(
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.utils import get_presentation_dir
|
||||
|
||||
|
||||
class DeletePresentationHandler:
|
||||
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.id)
|
||||
|
||||
async def delete(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"presentation": self.id}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.id)
|
||||
sql_session.delete(presentation)
|
||||
sql_session.commit()
|
||||
|
||||
if os.path.exists(self.presentation_dir):
|
||||
shutil.rmtree(self.presentation_dir)
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.database import get_sql_session
|
||||
from api.sql_models import SlideSqlModel
|
||||
|
||||
|
||||
class DeleteSlideHandler:
|
||||
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
async def delete(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"slide": self.id}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
slide = sql_session.get(SlideSqlModel, self.id)
|
||||
sql_session.delete(slide)
|
||||
sql_session.commit()
|
||||
|
|
@ -1,205 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
EditPresentationSlideRequest,
|
||||
)
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from api.utils.utils import (
|
||||
get_presentation_dir,
|
||||
get_presentation_images_dir,
|
||||
)
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
IconQueryCollectionWithData,
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from ppt_generator.slide_generator import (
|
||||
get_edited_slide_content_model,
|
||||
get_slide_type_from_prompt,
|
||||
)
|
||||
from ppt_generator.slide_model_utils import SlideModelUtils
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class PresentationEditHandler:
|
||||
def __init__(self, data: EditPresentationSlideRequest):
|
||||
self.data = data
|
||||
self.presentation_id = data.presentation_id
|
||||
|
||||
self.slide_index = data.index
|
||||
self.prompt = data.prompt
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
slide_to_edit_sql = sql_session.exec(
|
||||
select(SlideSqlModel).where(
|
||||
SlideSqlModel.index == self.slide_index,
|
||||
SlideSqlModel.presentation == self.presentation_id,
|
||||
)
|
||||
).first()
|
||||
|
||||
slide_to_edit = SlideModel.from_dict(slide_to_edit_sql.model_dump(mode="json"))
|
||||
new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit)
|
||||
new_slide_type = new_slide_type.slide_type
|
||||
|
||||
supports_graph = not is_custom_llm_selected()
|
||||
if is_ollama_selected():
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
supports_graph = model.supports_graph
|
||||
|
||||
if not supports_graph:
|
||||
if new_slide_type == 5:
|
||||
new_slide_type = 1
|
||||
elif new_slide_type == 9:
|
||||
new_slide_type = 6
|
||||
|
||||
edited_content = await get_edited_slide_content_model(
|
||||
self.prompt,
|
||||
new_slide_type,
|
||||
slide_to_edit,
|
||||
presentation.theme,
|
||||
presentation.language,
|
||||
)
|
||||
|
||||
new_slide_model = SlideModel(
|
||||
id=slide_to_edit.id,
|
||||
index=slide_to_edit.index,
|
||||
type=new_slide_type,
|
||||
design_index=slide_to_edit.design_index,
|
||||
images=None,
|
||||
icons=None,
|
||||
presentation=slide_to_edit.presentation,
|
||||
properties=slide_to_edit.properties,
|
||||
content=edited_content.to_content(),
|
||||
)
|
||||
|
||||
new_slide_images_count = new_slide_model.images_count
|
||||
new_slide_icons_count = new_slide_model.icons_count
|
||||
|
||||
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
|
||||
|
||||
new_slide_images: dict[int, str | ImagePromptWithThemeAndAspectRatio] = {}
|
||||
new_slide_icons: dict[int, str | IconQueryCollectionWithData] = {}
|
||||
|
||||
# ? Checks if image prompts have changed
|
||||
# ? If they have, it will search if it is same as the old one but used at a different index
|
||||
# ? If it is, it will use the old image
|
||||
# ? If it is not, it will generate a new image
|
||||
if new_slide_images_count:
|
||||
new_image_prompts = slide_model_utils.get_image_prompts()
|
||||
old_image_prompts = (
|
||||
slide_to_edit.content.image_prompts
|
||||
if slide_to_edit.images_count
|
||||
else []
|
||||
)
|
||||
for index in range(new_slide_images_count):
|
||||
new_prompt = new_slide_model.content.image_prompts[index]
|
||||
for old_prompt in old_image_prompts:
|
||||
if old_prompt != new_prompt:
|
||||
continue
|
||||
if index < len(slide_to_edit.images or []):
|
||||
new_slide_images[index] = slide_to_edit.images[index]
|
||||
break
|
||||
if not new_slide_images.get(index):
|
||||
new_slide_images[index] = new_image_prompts[index]
|
||||
|
||||
# ? Checks if icon queries have changed
|
||||
# ? If they have, it will search if it is same as the old one but used at a different index
|
||||
# ? If it is, it will use the old icon
|
||||
# ? If it is not, it will generate a new icon
|
||||
if new_slide_icons_count:
|
||||
new_icon_queries = slide_model_utils.get_icon_queries()
|
||||
old_icon_queries = (
|
||||
slide_to_edit.content.icon_queries if slide_to_edit.icons_count else []
|
||||
)
|
||||
for index in range(new_slide_icons_count):
|
||||
new_query = new_slide_model.content.icon_queries[index]
|
||||
for old_query in old_icon_queries:
|
||||
if old_query != new_query:
|
||||
continue
|
||||
if index < len(slide_to_edit.icons or []):
|
||||
new_slide_icons[index] = slide_to_edit.icons[index]
|
||||
break
|
||||
if not new_slide_icons.get(index):
|
||||
new_slide_icons[index] = new_icon_queries[index]
|
||||
|
||||
images_to_generate = []
|
||||
for each in new_slide_images.values():
|
||||
if isinstance(each, ImagePromptWithThemeAndAspectRatio):
|
||||
images_to_generate.append(each)
|
||||
|
||||
icons_to_generate = []
|
||||
for each in new_slide_icons.values():
|
||||
if isinstance(each, IconQueryCollectionWithData):
|
||||
icons_to_generate.append(each)
|
||||
|
||||
images_directory = get_presentation_images_dir(self.presentation_id)
|
||||
if icons_to_generate:
|
||||
icons_vectorstore = get_icons_vectorstore()
|
||||
|
||||
coroutines = [
|
||||
generate_image(each_prompt, images_directory)
|
||||
for each_prompt in images_to_generate
|
||||
] + [
|
||||
get_icon(icons_vectorstore, each_query) for each_query in icons_to_generate
|
||||
]
|
||||
generated_assets = await asyncio.gather(*coroutines)
|
||||
generated_image_count = len(images_to_generate)
|
||||
generate_images = generated_assets[:generated_image_count]
|
||||
generate_icons = generated_assets[generated_image_count:]
|
||||
|
||||
for each in new_slide_images:
|
||||
if isinstance(new_slide_images[each], ImagePromptWithThemeAndAspectRatio):
|
||||
new_slide_images[each] = generate_images.pop(0)
|
||||
|
||||
for each in new_slide_icons:
|
||||
if isinstance(new_slide_icons[each], IconQueryCollectionWithData):
|
||||
new_slide_icons[each] = generate_icons.pop(0)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.exec(
|
||||
update(SlideSqlModel)
|
||||
.where(SlideSqlModel.id == slide_to_edit.id)
|
||||
.values(
|
||||
type=new_slide_type,
|
||||
images=list(new_slide_images.values()),
|
||||
icons=list(new_slide_icons.values()),
|
||||
content=new_slide_model.content.model_dump(mode="json"),
|
||||
)
|
||||
)
|
||||
sql_session.commit()
|
||||
slide_to_edit_sql = sql_session.exec(
|
||||
select(SlideSqlModel).where(SlideSqlModel.id == slide_to_edit.id)
|
||||
).first()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(slide_to_edit_sql.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return slide_to_edit_sql
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.mixins.fetch_presentation_assets import (
|
||||
FetchPresentationAssetsMixin,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
ExportAsRequest,
|
||||
PresentationAndPath,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.utils.utils import get_presentation_dir, sanitize_filename
|
||||
from ppt_generator.pptx_presentation_creator import PptxPresentationCreator
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class ExportAsPptxHandler(FetchPresentationAssetsMixin):
|
||||
|
||||
def __init__(self, data: ExportAsRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.data.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
await self.fetch_presentation_assets()
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
|
||||
ppt_path = os.path.join(
|
||||
self.presentation_dir,
|
||||
sanitize_filename(f"{presentation.title}.pptx")
|
||||
)
|
||||
ppt_creator = PptxPresentationCreator(self.data.pptx_model, self.temp_dir)
|
||||
ppt_creator.create_ppt()
|
||||
ppt_creator.save(ppt_path)
|
||||
|
||||
response = PresentationAndPath(
|
||||
presentation_id=self.data.presentation_id, path=ppt_path
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
presentation.file = ppt_path
|
||||
sql_session.commit()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
import os
|
||||
import random
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException
|
||||
from api.models import LogMetadata, SessionModel
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
SUPPORTED_OLLAMA_MODELS,
|
||||
)
|
||||
from api.routers.presentation.models import PresentationGenerateRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel
|
||||
from ppt_config_generator.structure_generator import generate_presentation_structure
|
||||
|
||||
SLIDES_WITHOUT_GRAPH = [2, 4, 6, 7, 8]
|
||||
|
||||
|
||||
class PresentationGenerateDataHandler:
|
||||
|
||||
def __init__(self, data: PresentationGenerateRequest):
|
||||
self.data = data
|
||||
self.session = str(uuid.uuid4())
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump()),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
if not self.data.outlines:
|
||||
raise HTTPException(400, "Outlines can not be empty")
|
||||
|
||||
key_value_model = KeyValueSqlModel(
|
||||
id=self.session,
|
||||
key=self.session,
|
||||
value=self.data.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
presentation_structure = await generate_presentation_structure(
|
||||
PresentationMarkdownModel(
|
||||
**{
|
||||
"title": presentation.title,
|
||||
"slides": presentation.outlines,
|
||||
"notes": presentation.notes,
|
||||
}
|
||||
)
|
||||
)
|
||||
supports_graph = not is_custom_llm_selected()
|
||||
if is_ollama_selected():
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
supports_graph = model.supports_graph
|
||||
|
||||
for each in presentation_structure.slides:
|
||||
if each.type > 9:
|
||||
each.type = random.choice(SLIDES_WITHOUT_GRAPH)
|
||||
if each.type == 3:
|
||||
each.type = 6
|
||||
if not supports_graph:
|
||||
if each.type == 5:
|
||||
each.type = 1
|
||||
elif each.type == 9:
|
||||
each.type = 6
|
||||
|
||||
presentation_outlines_len = len(presentation.outlines)
|
||||
missing_slides_len = presentation_outlines_len - len(
|
||||
presentation_structure.slides
|
||||
)
|
||||
if missing_slides_len > 0:
|
||||
for index in range(missing_slides_len):
|
||||
selected_type = (
|
||||
random.choice(SLIDES_WITHOUT_GRAPH)
|
||||
if index != missing_slides_len - 1
|
||||
else 1
|
||||
)
|
||||
presentation_structure.slides.append(
|
||||
SlideStructureModel(type=selected_type)
|
||||
)
|
||||
elif missing_slides_len < 0:
|
||||
presentation_structure.slides = presentation_structure.slides[
|
||||
:presentation_outlines_len
|
||||
]
|
||||
|
||||
presentation.structure = presentation_structure.model_dump(mode="json")
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(key_value_model)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(key_value_model)
|
||||
|
||||
response = SessionModel(session=self.session)
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
GenerateImageRequest,
|
||||
PresentationAndPaths,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.utils.utils import get_presentation_dir, get_presentation_images_dir
|
||||
from image_processor.images_finder import generate_image
|
||||
|
||||
|
||||
class GenerateImageHandler:
|
||||
|
||||
def __init__(self, data: GenerateImageRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.data.presentation_id)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
images_directory = get_presentation_images_dir(self.data.presentation_id)
|
||||
image_path = await generate_image(self.data.prompt, images_directory)
|
||||
|
||||
response = PresentationAndPaths(
|
||||
presentation_id=self.data.presentation_id, paths=[image_path]
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
import uuid
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GenerateOutlinesRequest
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from ppt_config_generator.ppt_outlines_generator import generate_ppt_content
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class PresentationOutlinesGenerateHandler:
|
||||
def __init__(self, data: GenerateOutlinesRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
|
||||
presentation_content = await generate_ppt_content(
|
||||
presentation.prompt,
|
||||
presentation.n_slides,
|
||||
presentation.language,
|
||||
presentation.summary,
|
||||
)
|
||||
presentation_content.slides = presentation_content.slides[
|
||||
: presentation.n_slides
|
||||
]
|
||||
|
||||
presentation.title = presentation_content.title
|
||||
presentation.outlines = [
|
||||
each.model_dump() for each in presentation_content.slides
|
||||
]
|
||||
presentation.notes = presentation_content.notes
|
||||
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(presentation.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return presentation
|
||||
|
|
@ -1,185 +0,0 @@
|
|||
import json
|
||||
from typing import List
|
||||
import uuid, aiohttp
|
||||
from fastapi import HTTPException
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.handlers.export_as_pptx import ExportAsPptxHandler
|
||||
from api.routers.presentation.handlers.upload_files import UploadFilesHandler
|
||||
from api.routers.presentation.mixins.fetch_assets_on_generation import (
|
||||
FetchAssetsOnPresentationGenerationMixin,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
ExportAsRequest,
|
||||
GeneratePresentationRequest,
|
||||
PresentationAndPath,
|
||||
PresentationPathAndEditPath,
|
||||
)
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_config_generator.ppt_outlines_generator import generate_ppt_content
|
||||
from ppt_generator.generator import generate_presentation
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
|
||||
class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
def __init__(self, presentation_id: str, data: GeneratePresentationRequest):
|
||||
self.session = str(uuid.uuid4())
|
||||
self.presentation_id = presentation_id
|
||||
self.data = data
|
||||
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Ollama is not currently supported for this endpoint",
|
||||
)
|
||||
|
||||
documents_and_images_path = await UploadFilesHandler(
|
||||
documents=self.data.documents,
|
||||
images=None,
|
||||
).post(logging_service, log_metadata)
|
||||
|
||||
summary = None
|
||||
if documents_and_images_path.documents:
|
||||
documents_loader = DocumentsLoader(documents_and_images_path.documents)
|
||||
await documents_loader.load_documents(self.temp_dir)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generating Document Summary")
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generating PPT Outline")
|
||||
presentation_content = await generate_ppt_content(
|
||||
self.data.prompt,
|
||||
self.data.n_slides,
|
||||
self.data.language,
|
||||
summary,
|
||||
)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generating Presentation")
|
||||
presentation_text = await generate_presentation(
|
||||
PresentationMarkdownModel(
|
||||
title=presentation_content.title,
|
||||
slides=presentation_content.slides,
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
)
|
||||
|
||||
print("-" * 40)
|
||||
print("Parsing Presentation")
|
||||
presentation_json = json.loads(presentation_text)
|
||||
|
||||
slide_models: List[SlideModel] = []
|
||||
for i, slide in enumerate(presentation_json["slides"]):
|
||||
slide["index"] = i
|
||||
slide["presentation"] = self.presentation_id
|
||||
slide["content"] = (
|
||||
LLM_CONTENT_TYPE_MAPPING[slide["type"]](**slide["content"])
|
||||
.to_content()
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
slide_model = SlideModel(**slide)
|
||||
slide_models.append(slide_model)
|
||||
|
||||
print("-" * 40)
|
||||
print("Fetching Theme Colors")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost/api/get-theme-from-name?theme={self.data.theme.value}",
|
||||
) as response:
|
||||
self.theme = await response.json()
|
||||
|
||||
print("-" * 40)
|
||||
print("Fetching Slide Assets")
|
||||
async for result in self.fetch_slide_assets(slide_models):
|
||||
print(result)
|
||||
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models
|
||||
]
|
||||
|
||||
presentation = PresentationSqlModel(
|
||||
id=self.presentation_id,
|
||||
prompt=self.data.prompt,
|
||||
n_slides=self.data.n_slides,
|
||||
language=self.data.language,
|
||||
summary=summary,
|
||||
theme=self.theme,
|
||||
title=presentation_content.title,
|
||||
outlines=[each.model_dump() for each in presentation_content.slides],
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slide_sql_models)
|
||||
sql_session.commit()
|
||||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
|
||||
if self.data.export_as == "pptx":
|
||||
print("-" * 40)
|
||||
print("Fetching Slide Metadata for Export")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost/api/slide-metadata",
|
||||
json={
|
||||
"id": self.presentation_id,
|
||||
|
||||
},
|
||||
) as response:
|
||||
export_request_body = await response.json()
|
||||
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation")
|
||||
export_request_body["presentation_id"] = self.presentation_id
|
||||
print(export_request_body)
|
||||
export_request = ExportAsRequest(**export_request_body)
|
||||
|
||||
presentation_and_path = await ExportAsPptxHandler(export_request).post(
|
||||
logging_service, log_metadata
|
||||
)
|
||||
|
||||
else:
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation as PDF")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost/api/export-as-pdf",
|
||||
json={
|
||||
"id": self.presentation_id,
|
||||
"title": presentation_content.title,
|
||||
},
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
presentation_and_path = PresentationAndPath(
|
||||
presentation_id=self.presentation_id,
|
||||
path=response_json["path"].replace("app", "static"),
|
||||
)
|
||||
|
||||
presentation_and_path.path = presentation_and_path.path.replace("app", "static")
|
||||
return PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={self.presentation_id}",
|
||||
)
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GeneratePresentationRequirementsRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
|
||||
|
||||
class GeneratePresentationRequirementsHandler:
|
||||
def __init__(
|
||||
self,
|
||||
presentation_id: str,
|
||||
data: GeneratePresentationRequirementsRequest,
|
||||
):
|
||||
self.data = data
|
||||
self.presentation_id = presentation_id
|
||||
self.prompt = data.prompt
|
||||
self.n_slides = data.n_slides
|
||||
self.documents = data.documents or []
|
||||
self.language = data.language
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
all_document_paths = [*self.documents]
|
||||
|
||||
documents_loader = DocumentsLoader(all_document_paths)
|
||||
await documents_loader.load_documents(self.temp_dir)
|
||||
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
presentation = PresentationSqlModel(
|
||||
id=self.presentation_id,
|
||||
prompt=self.prompt,
|
||||
n_slides=self.n_slides,
|
||||
language=self.language,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(presentation.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return presentation
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GenerateResearchReportRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from research_report.generator import get_report
|
||||
|
||||
|
||||
class GenerateResearchReportHandler:
|
||||
def __init__(self, data: GenerateResearchReportRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
report = await get_report(self.data.query, self.data.language)
|
||||
|
||||
file_name = f"{report[:30]}.txt"
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(file_name, self.temp_dir)
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(report)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(file_path), extra=log_metadata.model_dump()
|
||||
)
|
||||
return file_path
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
import json
|
||||
from typing import List
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlmodel import delete
|
||||
|
||||
from api.models import LogMetadata, SSECompleteResponse, SSEResponse, SSEStatusResponse
|
||||
|
||||
from api.routers.presentation.mixins.fetch_assets_on_generation import (
|
||||
FetchAssetsOnPresentationGenerationMixin,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
PresentationAndSlides,
|
||||
PresentationGenerateRequest,
|
||||
)
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from ppt_config_generator.models import (
|
||||
PresentationMarkdownModel,
|
||||
PresentationStructureModel,
|
||||
)
|
||||
from ppt_generator.generator import generate_presentation_stream
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
LLMPresentationModel,
|
||||
LLMSlideModel,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
|
||||
from ppt_generator.slide_generator import get_slide_content_from_type_and_outline
|
||||
|
||||
|
||||
class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
def __init__(self, presentation_id: str, session: str):
|
||||
self.session = session
|
||||
self.presentation_id = presentation_id
|
||||
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def get(self, *args, **kwargs):
|
||||
with get_sql_session() as sql_session:
|
||||
key_value_model = sql_session.get(KeyValueSqlModel, self.session)
|
||||
|
||||
if not key_value_model.value:
|
||||
raise HTTPException(400, "Data not found for provided session")
|
||||
|
||||
self.data = PresentationGenerateRequest(**key_value_model.value)
|
||||
|
||||
self.presentation_id = self.data.presentation_id
|
||||
self.theme = self.data.theme
|
||||
self.images = self.data.images
|
||||
self.title = self.data.title or ""
|
||||
self.outlines = self.data.outlines
|
||||
|
||||
return StreamingResponse(
|
||||
self.get_stream(*args, **kwargs), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
async def get_stream(
|
||||
self, logging_service: LoggingService, log_metadata: LogMetadata
|
||||
):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
if not self.outlines:
|
||||
raise HTTPException(400, "Outlines can not be empty")
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
presentation.outlines = [each.model_dump() for each in self.outlines]
|
||||
presentation.title = self.title or presentation.title
|
||||
presentation.theme = self.theme
|
||||
sql_session.exec(
|
||||
delete(SlideSqlModel).where(
|
||||
SlideSqlModel.presentation == self.presentation_id
|
||||
)
|
||||
)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
self.presentation = presentation
|
||||
|
||||
yield SSEResponse(
|
||||
event="response", data=json.dumps({"status": "Analyzing information 📊"})
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = None
|
||||
|
||||
# self.presentation_json will be mutated by the generator
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
async for result in self.generate_presentation_ollama_custom():
|
||||
yield result
|
||||
else:
|
||||
async for result in self.generate_presentation_openai_google():
|
||||
yield result
|
||||
|
||||
slide_models: List[SlideModel] = []
|
||||
for i, slide in enumerate(self.presentation_json["slides"]):
|
||||
slide["index"] = i
|
||||
slide["presentation"] = self.presentation.id
|
||||
slide["content"] = (
|
||||
LLM_CONTENT_TYPE_MAPPING[slide["type"]](**slide["content"])
|
||||
.to_content()
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
slide_model = SlideModel(**slide)
|
||||
slide_models.append(slide_model)
|
||||
|
||||
async for result in self.fetch_slide_assets(slide_models):
|
||||
yield result
|
||||
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models
|
||||
]
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add_all(slide_sql_models)
|
||||
sql_session.commit()
|
||||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
|
||||
yield SSEStatusResponse(status="Packing slide data").to_string()
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=self.presentation, slides=slide_sql_models
|
||||
).to_response_dict()
|
||||
|
||||
yield SSECompleteResponse(key="presentation", value=response).to_string()
|
||||
|
||||
async def generate_presentation_openai_google(self):
|
||||
presentation_text = ""
|
||||
async for event in await generate_presentation_stream(
|
||||
PresentationMarkdownModel(
|
||||
title=self.title,
|
||||
slides=self.outlines,
|
||||
notes=self.presentation.notes,
|
||||
)
|
||||
):
|
||||
chunk = event.choices[0].delta.content
|
||||
|
||||
if chunk is None:
|
||||
continue
|
||||
|
||||
presentation_text += chunk
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = json.loads(presentation_text)
|
||||
|
||||
async def generate_presentation_ollama_custom(self):
|
||||
presentation_structure = PresentationStructureModel(
|
||||
**self.presentation.structure
|
||||
)
|
||||
slide_models = []
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": '{ "slides": [ '}),
|
||||
).to_string()
|
||||
n_slides = len(presentation_structure.slides)
|
||||
for i, slide_structure in enumerate(presentation_structure.slides):
|
||||
# Informing about the start of the slide
|
||||
# This is to make sure that the client renders slide n
|
||||
# when it receives start chunk of slide n + 1
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": "{"}),
|
||||
).to_string()
|
||||
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
slide_structure.type, self.outlines[i]
|
||||
)
|
||||
slide_model = LLMSlideModel(
|
||||
type=slide_structure.type,
|
||||
content=slide_content.model_dump(mode="json"),
|
||||
)
|
||||
slide_models.append(slide_model)
|
||||
chunk = json.dumps(slide_model.model_dump(mode="json"))
|
||||
|
||||
if i < n_slides - 1:
|
||||
chunk += ","
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk[1:]}),
|
||||
).to_string()
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": " ] }"}),
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = LLMPresentationModel(
|
||||
slides=slide_models,
|
||||
).model_dump(mode="json")
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationAndSlides
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class GetPresentationHandler:
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"presentation": self.id}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.id)
|
||||
slide_models = sql_session.exec(
|
||||
select(SlideSqlModel).where(SlideSqlModel.presentation == self.id)
|
||||
).all()
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=presentation, slides=slide_models
|
||||
).to_response_dict()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return response
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
from sqlmodel import select, exists
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationWithOneSlide
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class GetPresentationsHandler:
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
# Get presentations that have at least one slide
|
||||
presentations = sql_session.exec(
|
||||
select(PresentationSqlModel).where(
|
||||
exists().where(
|
||||
SlideSqlModel.presentation == PresentationSqlModel.id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
presentations.sort(key=lambda x: x.created_at, reverse=True)
|
||||
presentations_with_slide = []
|
||||
for presentation in presentations:
|
||||
slide = sql_session.exec(
|
||||
select(SlideSqlModel)
|
||||
.where(SlideSqlModel.presentation == presentation.id)
|
||||
.where(SlideSqlModel.index == 0)
|
||||
).first()
|
||||
presentations_with_slide.append(
|
||||
PresentationWithOneSlide.from_presentation_and_slide(
|
||||
presentation, slide
|
||||
)
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(
|
||||
[each.model_dump(mode="json") for each in presentations]
|
||||
),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return presentations_with_slide
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from typing import Optional
|
||||
from api.services.logging import LoggingService
|
||||
from api.models import LogMetadata
|
||||
from api.utils.model_utils import list_available_custom_models
|
||||
|
||||
|
||||
class ListAvailableCustomModelsHandler:
|
||||
|
||||
def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None):
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
print("-" * 40)
|
||||
print(self.url, self.api_key)
|
||||
return await list_available_custom_models(self.url, self.api_key)
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.model_utils import list_pulled_ollama_models
|
||||
|
||||
|
||||
class ListPulledOllamaModelsHandler:
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message("Listing Ollama models"),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
pulled_models = await list_pulled_ollama_models()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(pulled_models),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return pulled_models
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from api.models import LogMetadata, OllamaModelMetadata
|
||||
from api.routers.presentation.models import OllamaSupportedModelsResponse
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
|
||||
|
||||
class ListSupportedOllamaModelsHandler:
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message("Listing supported Ollama models"),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return OllamaSupportedModelsResponse(
|
||||
models=SUPPORTED_OLLAMA_MODELS.values(),
|
||||
)
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
import json
|
||||
import traceback
|
||||
import aiohttp
|
||||
from fastapi import BackgroundTasks, HTTPException
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
SUPPORTED_OLLAMA_MODELS,
|
||||
)
|
||||
from api.routers.presentation.models import OllamaModelStatusResponse
|
||||
from api.services.instances import REDIS_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.model_utils import (
|
||||
get_llm_provider_url_or,
|
||||
list_pulled_ollama_models,
|
||||
pull_ollama_model,
|
||||
)
|
||||
|
||||
|
||||
class PullOllamaModelHandler:
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def get(
|
||||
self,
|
||||
logging_service: LoggingService,
|
||||
log_metadata: LogMetadata,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.name),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
if self.name not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Model {self.name} is not supported",
|
||||
)
|
||||
|
||||
try:
|
||||
pulled_models = await list_pulled_ollama_models()
|
||||
filtered_models = [
|
||||
model for model in pulled_models if model.name == self.name
|
||||
]
|
||||
if filtered_models:
|
||||
return filtered_models[0]
|
||||
except HTTPException as e:
|
||||
logging_service.logger.warning(
|
||||
logging_service.message(e.detail),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logging_service.logger.warning(
|
||||
f"Failed to check pulled models: {e}",
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to check pulled models: {e}",
|
||||
)
|
||||
|
||||
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{self.name}")
|
||||
|
||||
# If the model is being pulled, return the model
|
||||
if saved_model_status:
|
||||
saved_model_status_json = json.loads(saved_model_status)
|
||||
# If the model is being pulled, return the model
|
||||
# ? If the model status is pulled in redis but was not found while listing pulled models,
|
||||
# ? it means the model was deleted and we need to pull it again
|
||||
if (
|
||||
saved_model_status_json["status"] == "error"
|
||||
or saved_model_status_json["status"] == "pulled"
|
||||
):
|
||||
REDIS_SERVICE.delete(f"ollama_models/{self.name}")
|
||||
else:
|
||||
return saved_model_status_json
|
||||
|
||||
# If the model is not being pulled, pull the model
|
||||
background_tasks.add_task(self.pull_model_in_background)
|
||||
|
||||
return OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
|
||||
async def pull_model_in_background(self):
|
||||
await self.pull_model()
|
||||
|
||||
async def pull_model(self):
|
||||
saved_model_status = OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
log_event_count = 0
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(self.name):
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
return saved_model_status
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
PresentationAndPaths,
|
||||
SearchIconRequest,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from image_processor.icons_finder import get_icons
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
|
||||
|
||||
class SearchIconHandler:
|
||||
|
||||
def __init__(self, data: SearchIconRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
vector_store = get_icons_vectorstore()
|
||||
|
||||
icon_paths = await get_icons(
|
||||
vector_store,
|
||||
self.data.query or "",
|
||||
self.data.page,
|
||||
self.data.limit,
|
||||
self.data.category,
|
||||
self.temp_dir,
|
||||
)
|
||||
|
||||
response = PresentationAndPaths(
|
||||
presentation_id=self.data.presentation_id, paths=icon_paths
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationAndUrls, SearchImageRequest
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class SearchImageHandler:
|
||||
|
||||
def __init__(self, data: SearchImageRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
response = PresentationAndUrls(
|
||||
presentation_id=self.data.presentation_id, urls=[]
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
from fastapi import UploadFile
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class UpdateParsedDocumentHandler:
|
||||
|
||||
def __init__(self, file_path: str, file: UploadFile):
|
||||
self.file_path = file_path
|
||||
self.file = file
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"path": self.file_path, "file": self.file}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with open(self.file_path, "wb") as f:
|
||||
f.write(await self.file.read())
|
||||
|
||||
return {"message": "File saved successfully"}
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import UpdatePresentationThemeRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PreferencesSqlModel, PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class UpdatePresentationThemeHandler:
|
||||
|
||||
def __init__(self, data: UpdatePresentationThemeRequest):
|
||||
self.data = data
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
preferences = sql_session.get(PreferencesSqlModel, 0)
|
||||
|
||||
if not preferences:
|
||||
preferences = PreferencesSqlModel(id=0, theme=None)
|
||||
sql_session.add(preferences)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(preferences)
|
||||
|
||||
if self.data.theme:
|
||||
theme_name = self.data.theme.get("name", None)
|
||||
if theme_name and theme_name.lower() == "custom":
|
||||
preferences.theme = self.data.theme
|
||||
|
||||
presentation.theme = self.data.theme
|
||||
sql_session.commit()
|
||||
|
||||
return {"message": "Theme updated successfully"}
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import os
|
||||
from typing import List
|
||||
from urllib.parse import unquote, urlparse
|
||||
import uuid
|
||||
|
||||
from sqlmodel import delete
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
PresentationUpdateRequest,
|
||||
PresentationAndSlides,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import download_files, get_presentation_dir, replace_file_name
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
|
||||
|
||||
class UpdateSlideModelsHandler:
|
||||
|
||||
def __init__(self, data: PresentationUpdateRequest):
|
||||
self.data = data
|
||||
self.presentation_id = data.presentation_id
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
presentation_id = self.data.presentation_id
|
||||
new_slides = self.data.slides
|
||||
|
||||
# Handle images
|
||||
images_local_paths = []
|
||||
images_download_links = []
|
||||
for new_slide in new_slides:
|
||||
new_images = new_slide.images or []
|
||||
for i, image in enumerate(new_images):
|
||||
if image.startswith("http"):
|
||||
parsed_url = unquote(urlparse(image).path)
|
||||
image_name = replace_file_name(
|
||||
os.path.basename(parsed_url), str(uuid.uuid4())
|
||||
)
|
||||
image_path = f"{self.presentation_dir}/images/{image_name}"
|
||||
images_local_paths.append(image_path)
|
||||
images_download_links.append(image)
|
||||
new_slide.images[i] = image_path
|
||||
|
||||
if images_download_links:
|
||||
await download_files(images_download_links, images_local_paths)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in new_slides
|
||||
]
|
||||
to_update_slides_ids = [each.id for each in slide_sql_models]
|
||||
sql_session.exec(
|
||||
delete(SlideSqlModel).where(SlideSqlModel.id.in_(to_update_slides_ids))
|
||||
)
|
||||
sql_session.add_all(slide_sql_models)
|
||||
sql_session.commit()
|
||||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
presentation = sql_session.get(PresentationSqlModel, presentation_id)
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=presentation, slides=slide_sql_models
|
||||
)
|
||||
response = response.to_response_dict()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return response
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
from typing import List, Optional
|
||||
import uuid
|
||||
from fastapi import UploadFile
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import DocumentsAndImagesPath
|
||||
from api.services.logging import LoggingService
|
||||
from api.validators import validate_files
|
||||
from document_processor.loader import UPLOAD_ACCEPTED_DOCUMENTS
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
|
||||
|
||||
class UploadFilesHandler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
documents: Optional[List[UploadFile]],
|
||||
images: Optional[List[UploadFile]],
|
||||
):
|
||||
self.documents = documents
|
||||
self.images = images
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
print("Upload Temp Dir: " + self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(
|
||||
{
|
||||
"documents": self.documents,
|
||||
"images": self.images,
|
||||
}
|
||||
),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
validate_files(self.documents, True, True, 50, UPLOAD_ACCEPTED_DOCUMENTS)
|
||||
validate_files(
|
||||
self.images, True, True, 10, ["image/jpeg", "image/png", "image/webp"]
|
||||
)
|
||||
|
||||
self.documents = self.documents or []
|
||||
self.images = self.images or []
|
||||
|
||||
temp_documents: List[str] = []
|
||||
if self.documents or self.images:
|
||||
all_documents = self.documents + self.images
|
||||
for doc in all_documents:
|
||||
temp_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
doc.filename, self.temp_dir
|
||||
)
|
||||
with open(temp_path, "wb") as f:
|
||||
content = await doc.read()
|
||||
f.write(content)
|
||||
|
||||
temp_documents.append(temp_path)
|
||||
|
||||
documents_count = len(temp_documents)
|
||||
response = DocumentsAndImagesPath(
|
||||
documents=temp_documents[:documents_count],
|
||||
images=temp_documents[documents_count:],
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from fastapi import UploadFile
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationAndPath
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.utils import get_presentation_dir
|
||||
|
||||
|
||||
class UploadPresentationThumbnailHandler:
|
||||
|
||||
def __init__(self, presentation_id: str, thumbnail: UploadFile):
|
||||
self.presentation_id = presentation_id
|
||||
self.thumbnail = thumbnail
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(
|
||||
{
|
||||
"presentation_id": self.presentation_id,
|
||||
"thumbnail": self.thumbnail,
|
||||
}
|
||||
),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
|
||||
with open(os.path.join(self.presentation_dir, "thumbnail.jpg"), "wb") as f:
|
||||
f.write(await self.thumbnail.read())
|
||||
|
||||
presentation.thumbnail = os.path.join(
|
||||
self.presentation_dir, "thumbnail.jpg"
|
||||
)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
response = PresentationAndPath(
|
||||
presentation_id=self.presentation_id, path=presentation.thumbnail
|
||||
)
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from api.models import SSEStatusResponse
|
||||
from api.utils.utils import get_presentation_images_dir
|
||||
from image_processor.icons_finder import get_icon
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from ppt_generator.slide_model_utils import SlideModelUtils
|
||||
|
||||
|
||||
class FetchAssetsOnPresentationGenerationMixin:
|
||||
|
||||
async def fetch_slide_assets(self, slide_models: List[SlideModel]):
|
||||
image_prompts = []
|
||||
icon_queries = []
|
||||
|
||||
for each_slide_model in slide_models:
|
||||
slide_model_utils = SlideModelUtils(self.theme, each_slide_model)
|
||||
image_prompts.extend(slide_model_utils.get_image_prompts())
|
||||
icon_queries.extend(slide_model_utils.get_icon_queries())
|
||||
|
||||
if icon_queries:
|
||||
icon_vector_store = get_icons_vectorstore()
|
||||
|
||||
images_directory = get_presentation_images_dir(self.presentation_id)
|
||||
|
||||
coroutines = [
|
||||
generate_image(
|
||||
each,
|
||||
images_directory,
|
||||
)
|
||||
for each in image_prompts
|
||||
] + [get_icon(icon_vector_store, each) for each in icon_queries]
|
||||
|
||||
assets_future = asyncio.gather(*coroutines)
|
||||
|
||||
while not assets_future.done():
|
||||
status = SSEStatusResponse(status="Fetching slide assets").to_string()
|
||||
yield status
|
||||
await asyncio.sleep(5)
|
||||
|
||||
assets = await assets_future
|
||||
|
||||
image_prompts_len = len(image_prompts)
|
||||
|
||||
images = assets[:image_prompts_len]
|
||||
icons = assets[image_prompts_len:]
|
||||
|
||||
for each_slide_model in slide_models:
|
||||
each_slide_model.images = images[: each_slide_model.images_count]
|
||||
images = images[each_slide_model.images_count :]
|
||||
|
||||
each_slide_model.icons = icons[: each_slide_model.icons_count]
|
||||
icons = icons[each_slide_model.icons_count :]
|
||||
|
||||
yield SSEStatusResponse(status="Slide assets fetched").to_string()
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
import os
|
||||
from urllib.parse import unquote, urlparse
|
||||
import uuid
|
||||
from api.utils.utils import download_files, replace_file_name
|
||||
from ppt_generator.models.pptx_models import PptxPictureBoxModel
|
||||
|
||||
|
||||
class FetchPresentationAssetsMixin:
|
||||
|
||||
async def fetch_presentation_assets(self):
|
||||
image_urls = []
|
||||
image_local_paths = []
|
||||
|
||||
for each_slide in self.data.pptx_model.slides:
|
||||
for each_shape in each_slide.shapes:
|
||||
if isinstance(each_shape, PptxPictureBoxModel):
|
||||
image_path = each_shape.picture.path
|
||||
if image_path.startswith("http"):
|
||||
if image_path.startswith("http://localhost:3000/static"):
|
||||
image_path = image_path.replace(
|
||||
"http://localhost:3000/static", ""
|
||||
)
|
||||
image_path = "/app" + image_path
|
||||
elif image_path.startswith("http://localhost/static"):
|
||||
image_path = image_path.replace(
|
||||
"http://localhost/static", ""
|
||||
)
|
||||
image_path = "/app" + image_path
|
||||
else:
|
||||
image_urls.append(image_path)
|
||||
parsed_url = unquote(urlparse(image_path).path)
|
||||
image_name = replace_file_name(
|
||||
os.path.basename(parsed_url), str(uuid.uuid4())
|
||||
)
|
||||
image_path = os.path.join(self.temp_dir, image_name)
|
||||
image_local_paths.append(image_path)
|
||||
elif image_path.startswith("file://"):
|
||||
image_path = image_path.replace("file:///", "")
|
||||
# Check if it's a Windows path (has colon at index 1)
|
||||
if not (len(image_path) > 1 and image_path[1] == ":"):
|
||||
image_path = "/" + image_path
|
||||
|
||||
each_shape.picture.path = image_path
|
||||
each_shape.picture.is_network = False
|
||||
|
||||
await download_files(image_urls, image_local_paths)
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.models import OllamaModelMetadata
|
||||
from ppt_config_generator.models import SlideMarkdownModel
|
||||
from ppt_generator.models.pptx_models import PptxPresentationModel
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
IconCategoryEnum,
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from ollama._types import ModelDetails
|
||||
|
||||
|
||||
class ThemeEnum(Enum):
|
||||
DARK = "dark"
|
||||
LIGHT = "light"
|
||||
ROYAL_BLUE = "royal_blue"
|
||||
CREAM = "cream"
|
||||
LIGHT_RED = "light_red"
|
||||
DARK_PINK = "dark_pink"
|
||||
FAINT_YELLOW = "faint_yellow"
|
||||
|
||||
|
||||
class DocumentsAndImagesPath(BaseModel):
|
||||
documents: Optional[List[str]] = None
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GenerateResearchReportRequest(BaseModel):
|
||||
language: Optional[str] = None
|
||||
query: str
|
||||
|
||||
|
||||
class DecomposeDocumentsRequest(DocumentsAndImagesPath):
|
||||
pass
|
||||
|
||||
|
||||
class GeneratePresentationRequirementsRequest(BaseModel):
|
||||
prompt: Optional[str] = None
|
||||
n_slides: int
|
||||
language: str
|
||||
documents: Optional[List[str]] = None
|
||||
research_reports: Optional[List[str]] = None
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GenerateOutlinesRequest(BaseModel):
|
||||
presentation_id: str
|
||||
|
||||
|
||||
class PresentationGenerateRequest(BaseModel):
|
||||
presentation_id: str
|
||||
theme: Optional[dict] = None
|
||||
images: Optional[List[str]] = None
|
||||
outlines: List[SlideMarkdownModel]
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
|
||||
presentation_id: str
|
||||
prompt: ImagePromptWithThemeAndAspectRatio
|
||||
|
||||
|
||||
class SearchImageRequest(BaseModel):
|
||||
presentation_id: str
|
||||
query: Optional[str] = None
|
||||
page: int = 1
|
||||
limit: int = 10
|
||||
|
||||
|
||||
class SearchIconRequest(BaseModel):
|
||||
presentation_id: str
|
||||
query: Optional[str] = None
|
||||
category: Optional[IconCategoryEnum] = None
|
||||
page: int = 1
|
||||
limit: int = 10
|
||||
|
||||
|
||||
class SlideEditRequest(BaseModel):
|
||||
index: int
|
||||
prompt: str
|
||||
|
||||
|
||||
class EditPresentationRequest(BaseModel):
|
||||
presentation_id: str
|
||||
watermark: bool = True
|
||||
changes: List[SlideEditRequest]
|
||||
|
||||
|
||||
class EditPresentationSlideRequest(BaseModel):
|
||||
presentation_id: str
|
||||
index: int
|
||||
prompt: str
|
||||
|
||||
|
||||
class UpdatePresentationThemeRequest(BaseModel):
|
||||
presentation_id: str
|
||||
theme: Optional[dict] = None
|
||||
|
||||
|
||||
class ExportAsRequest(BaseModel):
|
||||
presentation_id: str
|
||||
pptx_model: PptxPresentationModel
|
||||
|
||||
|
||||
class DecomposeDocumentsResponse(BaseModel):
|
||||
documents: dict
|
||||
|
||||
|
||||
class PresentationAndSlides(BaseModel):
|
||||
presentation: PresentationSqlModel
|
||||
slides: List[SlideSqlModel]
|
||||
|
||||
def to_response_dict(self):
|
||||
presentation = self.presentation.model_dump(mode="json")
|
||||
return {
|
||||
"presentation": presentation,
|
||||
"slides": [each.model_dump(mode="json") for each in self.slides],
|
||||
}
|
||||
|
||||
|
||||
class PresentationUpdateRequest(BaseModel):
|
||||
presentation_id: str
|
||||
slides: List[SlideModel]
|
||||
|
||||
|
||||
class PresentationAndUrl(BaseModel):
|
||||
presentation_id: str
|
||||
url: str
|
||||
|
||||
|
||||
class PresentationAndUrls(BaseModel):
|
||||
presentation_id: str
|
||||
urls: List[str]
|
||||
|
||||
|
||||
class PresentationAndPath(BaseModel):
|
||||
presentation_id: str
|
||||
path: str
|
||||
|
||||
|
||||
class PresentationAndPaths(BaseModel):
|
||||
presentation_id: str
|
||||
paths: List[str]
|
||||
|
||||
|
||||
class PresentationPathAndEditPath(PresentationAndPath):
|
||||
edit_path: str
|
||||
|
||||
|
||||
class UpdatePresentationTitlesRequest(BaseModel):
|
||||
presentation_id: str
|
||||
titles: List[str]
|
||||
|
||||
|
||||
class GeneratePresentationRequest(BaseModel):
|
||||
prompt: str
|
||||
n_slides: int = Field(default=8, ge=5, le=15)
|
||||
language: str = Field(default="English")
|
||||
theme: ThemeEnum = Field(default=ThemeEnum.LIGHT)
|
||||
documents: Optional[List[UploadFile]] = None
|
||||
export_as: Literal["pptx", "pdf"] = Field(default="pptx")
|
||||
|
||||
|
||||
class OllamaModelStatusResponse(BaseModel):
|
||||
name: str
|
||||
size: Optional[int] = None
|
||||
downloaded: Optional[int] = None
|
||||
status: str
|
||||
done: bool
|
||||
|
||||
|
||||
class OllamaSupportedModelsResponse(BaseModel):
|
||||
models: List[OllamaModelMetadata]
|
||||
|
||||
|
||||
class PresentationWithOneSlide(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
theme: Optional[dict] = None
|
||||
title: Optional[str] = None
|
||||
slide: SlideSqlModel
|
||||
|
||||
@classmethod
|
||||
def from_presentation_and_slide(
|
||||
cls, presentation: PresentationSqlModel, slide: SlideSqlModel
|
||||
):
|
||||
return cls(
|
||||
id=presentation.id,
|
||||
created_at=presentation.created_at,
|
||||
theme=presentation.theme,
|
||||
title=presentation.title,
|
||||
slide=slide,
|
||||
)
|
||||
|
|
@ -1,406 +0,0 @@
|
|||
from typing import Annotated, List, Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, File, Form, UploadFile
|
||||
|
||||
from api.models import SessionModel
|
||||
from api.request_utils import RequestUtils
|
||||
from api.routers.presentation.handlers.decompose_documents import (
|
||||
DecomposeDocumentsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.delete_presentation import (
|
||||
DeletePresentationHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.delete_slide import DeleteSlideHandler
|
||||
from api.routers.presentation.handlers.edit import PresentationEditHandler
|
||||
from api.routers.presentation.handlers.export_as_pptx import ExportAsPptxHandler
|
||||
from api.routers.presentation.handlers.generate_data import (
|
||||
PresentationGenerateDataHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_image import GenerateImageHandler
|
||||
from api.routers.presentation.handlers.generate_presentation import (
|
||||
GeneratePresentationHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_presentation_requirements import (
|
||||
GeneratePresentationRequirementsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_research_report import (
|
||||
GenerateResearchReportHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_stream import (
|
||||
PresentationGenerateStreamHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_outlines import (
|
||||
PresentationOutlinesGenerateHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.get_presentation import GetPresentationHandler
|
||||
from api.routers.presentation.handlers.get_presentations import GetPresentationsHandler
|
||||
from api.routers.presentation.handlers.list_available_custom_models import (
|
||||
ListAvailableCustomModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.list_ollama_pulled_models import (
|
||||
ListPulledOllamaModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
ListSupportedOllamaModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.pull_ollama_model import PullOllamaModelHandler
|
||||
from api.routers.presentation.handlers.search_icon import SearchIconHandler
|
||||
from api.routers.presentation.handlers.search_image import SearchImageHandler
|
||||
from api.routers.presentation.handlers.update_parsed_document import (
|
||||
UpdateParsedDocumentHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.update_presentation_theme import (
|
||||
UpdatePresentationThemeHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.update_slide_models import (
|
||||
UpdateSlideModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.upload_files import UploadFilesHandler
|
||||
from api.routers.presentation.handlers.upload_presentation_thumbnail import (
|
||||
UploadPresentationThumbnailHandler,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
DecomposeDocumentsRequest,
|
||||
DecomposeDocumentsResponse,
|
||||
DocumentsAndImagesPath,
|
||||
EditPresentationSlideRequest,
|
||||
ExportAsRequest,
|
||||
GenerateImageRequest,
|
||||
GeneratePresentationRequest,
|
||||
GeneratePresentationRequirementsRequest,
|
||||
GenerateResearchReportRequest,
|
||||
OllamaModelStatusResponse,
|
||||
OllamaSupportedModelsResponse,
|
||||
PresentationAndPath,
|
||||
PresentationAndPaths,
|
||||
PresentationAndSlides,
|
||||
GenerateOutlinesRequest,
|
||||
PresentationAndUrls,
|
||||
PresentationGenerateRequest,
|
||||
PresentationPathAndEditPath,
|
||||
SearchIconRequest,
|
||||
SearchImageRequest,
|
||||
UpdatePresentationThemeRequest,
|
||||
PresentationUpdateRequest,
|
||||
PresentationWithOneSlide,
|
||||
)
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.utils.utils import handle_errors
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
route_prefix = "/api/v1/ppt"
|
||||
presentation_router = APIRouter(prefix=route_prefix)
|
||||
|
||||
|
||||
@presentation_router.get(
|
||||
"/user_presentations", response_model=List[PresentationWithOneSlide]
|
||||
)
|
||||
async def get_user_presentations():
|
||||
request_utils = RequestUtils(f"{route_prefix}/user_presentations")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
GetPresentationsHandler().get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get("/presentation", response_model=PresentationAndSlides)
|
||||
async def get_presentation_from_id(presentation_id: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GetPresentationHandler(presentation_id).get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/files/upload", response_model=DocumentsAndImagesPath)
|
||||
async def upload_files(
|
||||
documents: Annotated[Optional[List[UploadFile]], File()] = None,
|
||||
images: Annotated[Optional[List[UploadFile]], File()] = None,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/files/upload")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
UploadFilesHandler(documents, images).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/report/generate", response_model=str)
|
||||
async def generate_research_report(
|
||||
data: GenerateResearchReportRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/report/generate")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
GenerateResearchReportHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/files/decompose", response_model=DecomposeDocumentsResponse)
|
||||
async def decompose_documents(data: DecomposeDocumentsRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/files/decompose")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
DecomposeDocumentsHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/document/update")
|
||||
async def update_document(
|
||||
path: Annotated[str, Body()],
|
||||
file: Annotated[UploadFile, File()],
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/document/update")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
UpdateParsedDocumentHandler(path, file).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/create", response_model=PresentationSqlModel)
|
||||
async def create_presentation(
|
||||
data: GeneratePresentationRequirementsRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/create")
|
||||
presentation_id = str(uuid.uuid4())
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GeneratePresentationRequirementsHandler(presentation_id, data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/outlines/generate", response_model=PresentationSqlModel)
|
||||
async def generate_outlines(data: GenerateOutlinesRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/outlines/generate")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationOutlinesGenerateHandler(data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/generate/data", response_model=SessionModel)
|
||||
async def submit_presentation_generation_data(
|
||||
data: PresentationGenerateRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/generate/data")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationGenerateDataHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get("/generate/stream")
|
||||
async def presentation_generation_stream(presentation_id: str, session: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/generate/stream")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationGenerateStreamHandler(presentation_id, session).get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/presentation/thumbnail", response_model=PresentationAndPath)
|
||||
async def update_presentation(
|
||||
presentation_id: Annotated[str, Body()],
|
||||
thumbnail: Annotated[UploadFile, File()],
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation/thumbnail")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
UploadPresentationThumbnailHandler(presentation_id, thumbnail).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/presentation/theme")
|
||||
async def update_presentation(
|
||||
data: UpdatePresentationThemeRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation/theme")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
UpdatePresentationThemeHandler(data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/edit", response_model=SlideModel)
|
||||
async def update_presentation(
|
||||
data: EditPresentationSlideRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/edit")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationEditHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/slides/update", response_model=PresentationAndSlides)
|
||||
async def update_slide_models(data: PresentationUpdateRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/slides/update")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
UpdateSlideModelsHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/image/generate", response_model=PresentationAndPaths)
|
||||
async def generate_image(data: GenerateImageRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/image/generate")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GenerateImageHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/image/search", response_model=PresentationAndUrls)
|
||||
async def search_image(data: SearchImageRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/image/search")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
SearchImageHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/icon/search", response_model=PresentationAndPaths)
|
||||
async def search_icon(data: SearchIconRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/icon/search")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
SearchIconHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post(
|
||||
"/presentation/export_as_pptx", response_model=PresentationAndPath
|
||||
)
|
||||
async def export_as_pptx(data: ExportAsRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation/export_as_pptx")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
ExportAsPptxHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.delete("/delete", status_code=204)
|
||||
async def delete_presentation(presentation_id: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/delete")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
DeletePresentationHandler(presentation_id).delete, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.delete("/slide/delete", status_code=204)
|
||||
async def delete_slide(slide_id: str, presentation_id: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/slide/delete")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
DeleteSlideHandler(slide_id).delete, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post(
|
||||
"/generate/presentation", response_model=PresentationPathAndEditPath
|
||||
)
|
||||
async def generate_presentation(data: Annotated[GeneratePresentationRequest, Form()]):
|
||||
presentation_id = str(uuid.uuid4())
|
||||
|
||||
request_utils = RequestUtils(f"{route_prefix}/generate/presentation")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GeneratePresentationHandler(presentation_id, data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Ollama Support
|
||||
@presentation_router.get(
|
||||
"/ollama/list-supported-models", response_model=OllamaSupportedModelsResponse
|
||||
)
|
||||
async def list_supported_ollama_models():
|
||||
request_utils = RequestUtils(f"{route_prefix}/ollama/list-supported-models")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListSupportedOllamaModelsHandler().get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get(
|
||||
"/ollama/list-pulled-models", response_model=List[OllamaModelStatusResponse]
|
||||
)
|
||||
async def list_pulled_ollama_models():
|
||||
request_utils = RequestUtils(f"{route_prefix}/ollama/list-pulled-models")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListPulledOllamaModelsHandler().get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get("/ollama/pull-model", response_model=OllamaModelStatusResponse)
|
||||
async def pull_ollama_model(name: str, background_tasks: BackgroundTasks):
|
||||
request_utils = RequestUtils(f"{route_prefix}/ollama/pull-model")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
PullOllamaModelHandler(name).get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
background_tasks=background_tasks,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/models/list/custom", response_model=List[str])
|
||||
async def list_custom_models(
|
||||
url: Annotated[Optional[str], Body()] = None,
|
||||
api_key: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/models/list/custom")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListAvailableCustomModelsHandler(url, api_key).get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
from api.services.redis import RedisService
|
||||
from api.services.temp_file import TempFileService
|
||||
|
||||
|
||||
TEMP_FILE_SERVICE = TempFileService()
|
||||
REDIS_SERVICE = RedisService()
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
from typing import Any
|
||||
from logging import Logger
|
||||
|
||||
|
||||
class LoggingService:
|
||||
|
||||
def __init__(self, stream_name: str):
|
||||
self._logger = Logger(stream_name)
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
return self._logger
|
||||
|
||||
def message(self, msg: Any):
|
||||
return {"msg": msg}
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
|
||||
def get_random_uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class PresentationSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
created_at: datetime = Field(default=datetime.now())
|
||||
prompt: Optional[str] = None
|
||||
n_slides: int
|
||||
theme: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
file: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
structure: Optional[dict] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
notes: Optional[List[str]] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
outlines: Optional[List[dict]] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
language: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
thumbnail: Optional[str] = None
|
||||
data: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
|
||||
|
||||
class SlideSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
index: int = Field(index=True)
|
||||
type: int
|
||||
design_index: Optional[int] = None
|
||||
images: Optional[List[str]] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
icons: Optional[List[str]] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
presentation: str
|
||||
content: dict = Field(sa_column=Column(JSON, nullable=False), default=None)
|
||||
properties: Optional[dict] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
|
||||
|
||||
class KeyValueSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
key: str = Field(index=True)
|
||||
value: dict = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
|
||||
|
||||
class PreferencesSqlModel(SQLModel, table=True):
|
||||
id: int = Field(default=0, primary_key=True)
|
||||
theme: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
|
|
@ -1,179 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
import openai
|
||||
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.routers.presentation.models import OllamaModelStatusResponse
|
||||
|
||||
|
||||
def is_ollama_selected() -> bool:
|
||||
return get_selected_llm_provider() == SelectedLLMProvider.OLLAMA
|
||||
|
||||
|
||||
def is_custom_llm_selected() -> bool:
|
||||
return get_selected_llm_provider() == SelectedLLMProvider.CUSTOM
|
||||
|
||||
|
||||
def get_llm_provider_url_or():
|
||||
llm_provider_url = (
|
||||
os.getenv("OLLAMA_URL") if is_ollama_selected() else os.getenv("CUSTOM_LLM_URL")
|
||||
)
|
||||
llm_provider_url = llm_provider_url or "http://localhost:11434"
|
||||
if llm_provider_url.endswith("/"):
|
||||
return llm_provider_url[:-1]
|
||||
return llm_provider_url
|
||||
|
||||
|
||||
def get_selected_llm_provider() -> SelectedLLMProvider:
|
||||
return SelectedLLMProvider(os.getenv("LLM"))
|
||||
|
||||
|
||||
async def list_available_custom_models(
|
||||
url: Optional[str] = None, api_key: Optional[str] = None
|
||||
) -> list[str]:
|
||||
if not url:
|
||||
client = get_llm_client()
|
||||
else:
|
||||
client = openai.AsyncOpenAI(api_key=api_key or "null", base_url=url)
|
||||
models = []
|
||||
async for model in client.models.list():
|
||||
print(model)
|
||||
models.append(model.id)
|
||||
return models
|
||||
|
||||
|
||||
def get_model_base_url():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.path.join(get_llm_provider_url_or(), "v1")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return get_llm_provider_url_or()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM provider")
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return "ollama"
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_LLM_API_KEY") or "null"
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM API key")
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
client = AsyncOpenAI(
|
||||
base_url=get_model_base_url(),
|
||||
api_key=get_llm_api_key(),
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
async def list_pulled_ollama_models() -> list[OllamaModelStatusResponse]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{get_llm_provider_url_or()}/api/tags",
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
pulled_models = await response.json()
|
||||
return [
|
||||
OllamaModelStatusResponse(
|
||||
name=m["model"],
|
||||
size=m["size"],
|
||||
status="pulled",
|
||||
downloaded=m["size"],
|
||||
done=True,
|
||||
)
|
||||
for m in pulled_models["models"]
|
||||
]
|
||||
elif response.status == 403:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Forbidden: Please check your Ollama Configuration",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to list Ollama models: {response.status}",
|
||||
)
|
||||
|
||||
|
||||
async def pull_ollama_model(model: str) -> AsyncGenerator[dict, None]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{get_llm_provider_url_or()}/api/pull",
|
||||
json={"model": model},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to pull model: {await response.text()}",
|
||||
)
|
||||
|
||||
async for line in response.content:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
event = json.loads(line.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
yield event
|
||||
|
|
@ -1,188 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.models import LogMetadata, UserConfig
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
def get_presentation_dir(presentation_id: str) -> str:
|
||||
presentation_dir = os.path.join(os.getenv("APP_DATA_DIRECTORY"), presentation_id)
|
||||
os.makedirs(presentation_dir, exist_ok=True)
|
||||
return presentation_dir
|
||||
|
||||
|
||||
def get_presentation_images_dir(presentation_id: str) -> str:
|
||||
presentation_images_dir = os.path.join(
|
||||
get_presentation_dir(presentation_id), "images"
|
||||
)
|
||||
os.makedirs(presentation_images_dir, exist_ok=True)
|
||||
return presentation_images_dir
|
||||
|
||||
|
||||
def get_user_config():
|
||||
user_config_path = os.getenv("USER_CONFIG_PATH")
|
||||
|
||||
existing_config = UserConfig()
|
||||
try:
|
||||
if os.path.exists(user_config_path):
|
||||
with open(user_config_path, "r") as f:
|
||||
existing_config = UserConfig(**json.load(f))
|
||||
except Exception as e:
|
||||
print("Error while loading user config")
|
||||
pass
|
||||
|
||||
return UserConfig(
|
||||
LLM=existing_config.LLM or os.getenv("LLM"),
|
||||
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY"),
|
||||
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY"),
|
||||
OLLAMA_URL=existing_config.OLLAMA_URL or os.getenv("OLLAMA_URL"),
|
||||
OLLAMA_MODEL=existing_config.OLLAMA_MODEL or os.getenv("OLLAMA_MODEL"),
|
||||
CUSTOM_LLM_URL=existing_config.CUSTOM_LLM_URL or os.getenv("CUSTOM_LLM_URL"),
|
||||
CUSTOM_LLM_API_KEY=existing_config.CUSTOM_LLM_API_KEY
|
||||
or os.getenv("CUSTOM_LLM_API_KEY"),
|
||||
CUSTOM_MODEL=existing_config.CUSTOM_MODEL or os.getenv("CUSTOM_MODEL"),
|
||||
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or os.getenv("PEXELS_API_KEY"),
|
||||
)
|
||||
|
||||
|
||||
def update_env_with_user_config():
|
||||
user_config = get_user_config()
|
||||
if user_config.LLM:
|
||||
os.environ["LLM"] = user_config.LLM
|
||||
if user_config.OPENAI_API_KEY:
|
||||
os.environ["OPENAI_API_KEY"] = user_config.OPENAI_API_KEY
|
||||
if user_config.GOOGLE_API_KEY:
|
||||
os.environ["GOOGLE_API_KEY"] = user_config.GOOGLE_API_KEY
|
||||
if user_config.OLLAMA_URL:
|
||||
os.environ["OLLAMA_URL"] = user_config.OLLAMA_URL
|
||||
if user_config.OLLAMA_MODEL:
|
||||
os.environ["OLLAMA_MODEL"] = user_config.OLLAMA_MODEL
|
||||
if user_config.CUSTOM_LLM_URL:
|
||||
os.environ["CUSTOM_LLM_URL"] = user_config.CUSTOM_LLM_URL
|
||||
if user_config.CUSTOM_LLM_API_KEY:
|
||||
os.environ["CUSTOM_LLM_API_KEY"] = user_config.CUSTOM_LLM_API_KEY
|
||||
if user_config.CUSTOM_MODEL:
|
||||
os.environ["CUSTOM_MODEL"] = user_config.CUSTOM_MODEL
|
||||
if user_config.PEXELS_API_KEY:
|
||||
os.environ["PEXELS_API_KEY"] = user_config.PEXELS_API_KEY
|
||||
|
||||
|
||||
def get_resource(relative_path):
|
||||
base_path = getattr(
|
||||
sys,
|
||||
"_MEIPASS",
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def replace_file_name(old_name: str, new_name: str) -> str:
|
||||
splitted = old_name.split(".")
|
||||
if len(splitted) < 1:
|
||||
return new_name
|
||||
else:
|
||||
return ".".join([new_name, splitted[-1]])
|
||||
|
||||
|
||||
def save_uploaded_files(
|
||||
TEMP_FILE_SERVICE, files: List[UploadFile], file_paths: List[str], temp_dir: str
|
||||
) -> List:
|
||||
full_file_paths = []
|
||||
for index, each_file in enumerate(files):
|
||||
temp_file_path = TEMP_FILE_SERVICE.create_temp_file(
|
||||
file_paths[index], each_file.file.read(), dir_path=temp_dir
|
||||
)
|
||||
full_file_paths.append(temp_file_path)
|
||||
return full_file_paths
|
||||
|
||||
|
||||
async def download_file(url: str, save_path: str, headers: Optional[dict] = None):
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
with open(save_path, "wb") as file:
|
||||
while True:
|
||||
chunk = await response.content.read(1024)
|
||||
if not chunk:
|
||||
break
|
||||
file.write(chunk)
|
||||
print(f"File downloaded successfully to {save_path}")
|
||||
return True
|
||||
else:
|
||||
print(f"Failed to download file. HTTP status: {response.status}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"Error while downloading file from {url} to {save_path}")
|
||||
return False
|
||||
|
||||
|
||||
async def download_files(urls: List[str], save_paths: List[str]):
|
||||
for url, save_path in zip(urls, save_paths):
|
||||
print(url)
|
||||
print(save_path)
|
||||
print("-" * 10)
|
||||
coroutines = [
|
||||
download_file(url, save_paths[index]) for index, url in enumerate(urls)
|
||||
]
|
||||
await asyncio.gather(*coroutines)
|
||||
|
||||
|
||||
async def handle_errors(
|
||||
func, logging_service: LoggingService, log_metadata: LogMetadata, **kwargs
|
||||
):
|
||||
try:
|
||||
logging_service.logger.info(f"START", extra=log_metadata.model_dump())
|
||||
response = await func(
|
||||
logging_service=logging_service, log_metadata=log_metadata, **kwargs
|
||||
)
|
||||
is_stream = isinstance(response, StreamingResponse)
|
||||
logging_service.logger.info(
|
||||
"STREAMING" if is_stream else "END", extra=log_metadata.model_dump()
|
||||
)
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
log_metadata.status_code = e.status_code
|
||||
logging_service.logger.error(
|
||||
f"Raised HTTPException - {e.detail}", extra=log_metadata.model_dump()
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(traceback.print_stack())
|
||||
print(traceback.print_exc())
|
||||
|
||||
log_metadata.status_code = 400
|
||||
logging_service.logger.critical(
|
||||
"Unhandled Exception",
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
raise HTTPException(400, "Something went wrong while processing your request.")
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
name, ext = os.path.splitext(filename)
|
||||
sanitized = re.sub(r'[\\/:*?"<>|]', "_", name)
|
||||
sanitized = re.sub(r"[\s_]+", "_", sanitized)
|
||||
sanitized = sanitized.strip(" .")
|
||||
|
||||
if not sanitized:
|
||||
sanitized = "untitled"
|
||||
|
||||
if len(sanitized) > 200:
|
||||
sanitized = sanitized[:200]
|
||||
|
||||
return sanitized + ext
|
||||
59
servers/fastapi/api/v1/ppt/background_tasks.py
Normal file
59
servers/fastapi/api/v1/ppt/background_tasks.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import json
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from services import REDIS_SERVICE
|
||||
from utils.ollama import pull_ollama_model
|
||||
|
||||
|
||||
async def pull_ollama_model_background_task(model: str):
|
||||
saved_model_status = OllamaModelStatus(
|
||||
name=model,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
log_event_count = 0
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{model}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{model}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{model}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
return saved_model_status
|
||||
14
servers/fastapi/api/v1/ppt/endpoints/custom_llm.py
Normal file
14
servers/fastapi/api/v1/ppt/endpoints/custom_llm.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Annotated, List, Optional
|
||||
from fastapi import APIRouter, Body
|
||||
|
||||
from utils.custom_llm_provider import list_available_custom_models
|
||||
|
||||
CUSTOM_LLM_ROUTER = APIRouter(prefix="/custom_llm", tags=["Custom LLM"])
|
||||
|
||||
|
||||
@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str])
|
||||
async def get_available_models(
|
||||
url: Annotated[Optional[str], Body()] = None,
|
||||
api_key: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
return await list_available_custom_models(url, api_key)
|
||||
87
servers/fastapi/api/v1/ppt/endpoints/files.py
Normal file
87
servers/fastapi/api/v1/ppt/endpoints/files.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
from http.client import HTTPException
|
||||
import os
|
||||
from typing import Annotated, List, Optional
|
||||
from fastapi import APIRouter, Body, File, UploadFile
|
||||
|
||||
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
|
||||
from models.decomposed_file_info import DecomposedFileInfo
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from utils.randomizers import get_random_uuid
|
||||
from utils.validators import validate_files
|
||||
|
||||
FILES_ROUTER = APIRouter(prefix="/files", tags=["Files"])
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/upload", response_model=List[str])
|
||||
async def upload_files(files: Optional[List[UploadFile]]):
|
||||
if not files:
|
||||
raise HTTPException(400, "Documents are required")
|
||||
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(get_random_uuid())
|
||||
|
||||
validate_files(files, True, True, 50, UPLOAD_ACCEPTED_FILE_TYPES)
|
||||
|
||||
temp_files: List[str] = []
|
||||
if files:
|
||||
for each_file in files:
|
||||
temp_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
each_file.filename, temp_dir
|
||||
)
|
||||
with open(temp_path, "wb") as f:
|
||||
content = await each_file.read()
|
||||
f.write(content)
|
||||
|
||||
temp_files.append(temp_path)
|
||||
|
||||
return temp_files
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/decompose", response_model=List[DecomposedFileInfo])
|
||||
async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]):
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(get_random_uuid())
|
||||
|
||||
txt_files = []
|
||||
other_files = []
|
||||
for file_path in file_paths:
|
||||
if file_path.endswith(".txt"):
|
||||
txt_files.append(file_path)
|
||||
else:
|
||||
other_files.append(file_path)
|
||||
|
||||
documents_loader = DocumentsLoader(file_paths=other_files)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
parsed_documents = documents_loader.documents
|
||||
|
||||
response = []
|
||||
for index, parsed_doc in enumerate(parsed_documents):
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{get_random_uuid()}.txt", temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
response.append(
|
||||
DecomposedFileInfo(
|
||||
name=os.path.basename(other_files[index]), file_path=file_path
|
||||
)
|
||||
)
|
||||
|
||||
# Return the txt documents as it is
|
||||
for each_file in txt_files:
|
||||
response.append(
|
||||
DecomposedFileInfo(name=os.path.basename(each_file), file_path=each_file)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/update")
|
||||
async def update_files(
|
||||
file_path: Annotated[str, Body()],
|
||||
file: Annotated[UploadFile, File()],
|
||||
):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(await file.read())
|
||||
|
||||
return {"message": "File updated successfully"}
|
||||
11
servers/fastapi/api/v1/ppt/endpoints/icons.py
Normal file
11
servers/fastapi/api/v1/ppt/endpoints/icons.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from typing import List
|
||||
from fastapi import APIRouter
|
||||
from services.icon_finder_service import IconFinderService
|
||||
|
||||
ICONS_ROUTER = APIRouter(prefix="/icons", tags=["Icons"])
|
||||
|
||||
|
||||
@ICONS_ROUTER.get("/search", response_model=List[str])
|
||||
async def search_icons(query: str, limit: int = 20):
|
||||
icon_finder_service = IconFinderService()
|
||||
return await icon_finder_service.search_icons(query, limit)
|
||||
40
servers/fastapi/api/v1/ppt/endpoints/images.py
Normal file
40
servers/fastapi/api/v1/ppt/endpoints/images.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
from typing import List
|
||||
from fastapi import APIRouter
|
||||
from sqlmodel import select
|
||||
|
||||
from models.image_prompt import ImagePrompt
|
||||
from models.sql.image_asset import ImageAsset
|
||||
from services.database import get_sql_session
|
||||
from services.image_generation_service import ImageGenerationService
|
||||
from utils.asset_directory_utils import get_images_directory
|
||||
|
||||
IMAGES_ROUTER = APIRouter(prefix="/images", tags=["Images"])
|
||||
|
||||
|
||||
@IMAGES_ROUTER.get("/generate")
|
||||
async def generate_image(prompt: str):
|
||||
images_directory = get_images_directory()
|
||||
image_prompt = ImagePrompt(prompt=prompt)
|
||||
image_generation_service = ImageGenerationService(images_directory)
|
||||
|
||||
image = await image_generation_service.generate_image(image_prompt)
|
||||
if not isinstance(image, ImageAsset):
|
||||
return image
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(image)
|
||||
sql_session.commit()
|
||||
|
||||
return image.path
|
||||
|
||||
|
||||
@IMAGES_ROUTER.get("/generated", response_model=List[ImageAsset])
|
||||
async def get_generated_images():
|
||||
try:
|
||||
with get_sql_session() as sql_session:
|
||||
images = sql_session.exec(
|
||||
select(ImageAsset).order_by(ImageAsset.created_at.desc())
|
||||
).all()
|
||||
return images
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to retrieve generated images: {str(e)}"}
|
||||
72
servers/fastapi/api/v1/ppt/endpoints/ollama.py
Normal file
72
servers/fastapi/api/v1/ppt/endpoints/ollama.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import json
|
||||
from typing import List
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from api.v1.ppt.background_tasks import pull_ollama_model_background_task
|
||||
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from services import REDIS_SERVICE
|
||||
from utils.ollama import list_pulled_ollama_models
|
||||
|
||||
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
||||
|
||||
|
||||
@OLLAMA_ROUTER.get("/models/supported", response_model=List[OllamaModelMetadata])
|
||||
def get_supported_models():
|
||||
return SUPPORTED_OLLAMA_MODELS.values()
|
||||
|
||||
|
||||
@OLLAMA_ROUTER.get("/models/available", response_model=List[OllamaModelStatus])
|
||||
async def get_available_models():
|
||||
return await list_pulled_ollama_models()
|
||||
|
||||
|
||||
@OLLAMA_ROUTER.get("/model/pull", response_model=OllamaModelStatus)
|
||||
async def pull_model(model: str, background_tasks: BackgroundTasks):
|
||||
|
||||
if model not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Model {model} is not supported",
|
||||
)
|
||||
|
||||
try:
|
||||
pulled_models = await list_pulled_ollama_models()
|
||||
filtered_models = [
|
||||
pulled_model for pulled_model in pulled_models if pulled_model.name == model
|
||||
]
|
||||
if filtered_models:
|
||||
return filtered_models[0]
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to check pulled models: {e}",
|
||||
)
|
||||
|
||||
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{model}")
|
||||
|
||||
# If the model is being pulled, return the model
|
||||
if saved_model_status:
|
||||
saved_model_status_json = json.loads(saved_model_status)
|
||||
# If the model is being pulled, return the model
|
||||
# ? If the model status is pulled in redis but was not found while listing pulled models,
|
||||
# ? it means the model was deleted and we need to pull it again
|
||||
if (
|
||||
saved_model_status_json["status"] == "error"
|
||||
or saved_model_status_json["status"] == "pulled"
|
||||
):
|
||||
REDIS_SERVICE.delete(f"ollama_models/{model}")
|
||||
else:
|
||||
return saved_model_status_json
|
||||
|
||||
# If the model is not being pulled, pull the model
|
||||
background_tasks.add_task(pull_ollama_model_background_task, model)
|
||||
|
||||
return OllamaModelStatus(
|
||||
name=model,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
66
servers/fastapi/api/v1/ppt/endpoints/outlines.py
Normal file
66
servers/fastapi/api/v1/ppt/endpoints/outlines.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import asyncio
|
||||
import json
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sse_response import SSECompleteResponse, SSEResponse, SSEStatusResponse
|
||||
from services.database import get_sql_session
|
||||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
|
||||
OUTLINES_ROUTER = APIRouter(prefix="/outlines", tags=["Outlines"])
|
||||
|
||||
|
||||
@OUTLINES_ROUTER.get("/stream")
|
||||
async def stream_outlines(presentation_id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, presentation_id)
|
||||
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
async def inner():
|
||||
yield SSEStatusResponse(
|
||||
status="Generating presentation outlines..."
|
||||
).to_string()
|
||||
|
||||
presentation_content_text = ""
|
||||
async for chunk in generate_ppt_outline(
|
||||
presentation.prompt,
|
||||
presentation.n_slides,
|
||||
presentation.language,
|
||||
presentation.summary,
|
||||
):
|
||||
# Give control to the event loop
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
presentation_content_text += chunk
|
||||
|
||||
presentation_content_json = json.loads(presentation_content_text)
|
||||
|
||||
presentation_content = PresentationOutlineModel(**presentation_content_json)
|
||||
presentation_content.slides = presentation_content.slides[
|
||||
: presentation.n_slides
|
||||
]
|
||||
|
||||
presentation.title = presentation_content.title
|
||||
presentation.outlines = [
|
||||
each.model_dump() for each in presentation_content.slides
|
||||
]
|
||||
presentation.notes = presentation_content.notes
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
yield SSECompleteResponse(
|
||||
key="presentation", value=presentation.model_dump(mode="json")
|
||||
).to_string()
|
||||
|
||||
return StreamingResponse(inner(), media_type="text/event-stream")
|
||||
496
servers/fastapi/api/v1/ppt/endpoints/presentation.py
Normal file
496
servers/fastapi/api/v1/ppt/endpoints/presentation.py
Normal file
|
|
@ -0,0 +1,496 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Annotated, List, Optional
|
||||
import uuid, aiohttp
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import delete
|
||||
from sqlmodel import select
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.pptx_models import PptxPresentationModel
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.presentation_with_slides import PresentationWithSlides
|
||||
from models.generate_presentation_api import (
|
||||
GeneratePresentationRequest,
|
||||
PresentationAndPath,
|
||||
PresentationPathAndEditPath,
|
||||
)
|
||||
from services.get_layout_by_name import get_layout_by_name
|
||||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sse_response import SSECompleteResponse, SSEResponse
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from services.database import get_sql_session
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from models.sql.presentation import PresentationModel
|
||||
from services.pptx_presentation_creator import PptxPresentationCreator
|
||||
from utils.asset_directory_utils import get_exports_directory
|
||||
from utils.llm_calls.generate_document_summary import generate_document_summary
|
||||
from utils.llm_calls.generate_presentation_structure import (
|
||||
generate_presentation_structure,
|
||||
)
|
||||
from utils.llm_calls.generate_slide_content import (
|
||||
get_slide_content_from_type_and_outline,
|
||||
)
|
||||
from utils.process_slides import process_slide_and_fetch_assets
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
PRESENTATION_ROUTER = APIRouter(prefix="/presentation", tags=["Presentation"])
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.get("/", response_model=PresentationWithSlides)
|
||||
def get_presentation(id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, id)
|
||||
if not presentation:
|
||||
raise HTTPException(404, "Presentation not found")
|
||||
slides = sql_session.exec(
|
||||
select(SlideModel)
|
||||
.where(SlideModel.presentation == id)
|
||||
.order_by(SlideModel.index)
|
||||
)
|
||||
return PresentationWithSlides(
|
||||
**presentation.model_dump(),
|
||||
slides=slides,
|
||||
)
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.delete("/", status_code=204)
|
||||
def delete_presentation(id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, id)
|
||||
if not presentation:
|
||||
raise HTTPException(404, "Presentation not found")
|
||||
slides = sql_session.exec(
|
||||
select(SlideModel).where(SlideModel.presentation == id)
|
||||
).all()
|
||||
for slide in slides:
|
||||
sql_session.delete(slide)
|
||||
sql_session.delete(presentation)
|
||||
sql_session.commit()
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.get("/all", response_model=List[PresentationWithSlides])
|
||||
def get_all_presentations():
|
||||
with get_sql_session() as sql_session:
|
||||
presentations_with_slides = []
|
||||
presentations = sql_session.exec(select(PresentationModel))
|
||||
for presentation in presentations:
|
||||
slides = sql_session.exec(
|
||||
select(SlideModel)
|
||||
.where(SlideModel.presentation == presentation.id)
|
||||
.where(SlideModel.index == 0)
|
||||
).all()
|
||||
if not slides:
|
||||
continue
|
||||
presentations_with_slides.append(
|
||||
PresentationWithSlides(
|
||||
**presentation.model_dump(),
|
||||
slides=slides,
|
||||
)
|
||||
)
|
||||
return presentations_with_slides
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/create", response_model=PresentationModel)
|
||||
async def create_presentation(
|
||||
prompt: Annotated[str, Body()],
|
||||
n_slides: Annotated[int, Body()],
|
||||
language: Annotated[str, Body()],
|
||||
file_paths: Annotated[Optional[List[str]], Body()] = None,
|
||||
):
|
||||
presentation_id = str(uuid.uuid4())
|
||||
|
||||
summary = None
|
||||
if file_paths:
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(presentation_id)
|
||||
documents_loader = DocumentsLoader(file_paths=file_paths)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
presentation = PresentationModel(
|
||||
id=presentation_id,
|
||||
prompt=prompt,
|
||||
n_slides=n_slides,
|
||||
language=language,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
return presentation
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel)
|
||||
async def prepare_presentation(
|
||||
presentation_id: Annotated[str, Body()],
|
||||
outlines: Annotated[List[SlideOutlineModel], Body()],
|
||||
layout: Annotated[PresentationLayoutModel, Body()],
|
||||
title: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
if not outlines:
|
||||
raise HTTPException(status_code=400, detail="Outlines are required")
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, presentation_id)
|
||||
|
||||
total_slide_layouts = len(layout.slides)
|
||||
total_outlines = len(outlines)
|
||||
|
||||
if layout.ordered:
|
||||
presentation_structure = layout.to_presentation_structure()
|
||||
else:
|
||||
presentation_structure: PresentationStructureModel = (
|
||||
await generate_presentation_structure(
|
||||
presentation_outline=presentation.get_presentation_outline(),
|
||||
presentation_layout=layout,
|
||||
)
|
||||
)
|
||||
|
||||
presentation_structure.slides = presentation_structure.slides[: len(outlines)]
|
||||
for index in range(total_outlines):
|
||||
random_slide_index = random.randint(0, total_slide_layouts - 1)
|
||||
if index >= total_outlines:
|
||||
presentation_structure.slides.append(random_slide_index)
|
||||
continue
|
||||
if presentation_structure.slides[index] >= total_slide_layouts:
|
||||
presentation_structure.slides[index] = random_slide_index
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
presentation.outlines = [each.model_dump() for each in outlines]
|
||||
presentation.title = title or presentation.title
|
||||
presentation.set_layout(layout)
|
||||
presentation.set_structure(presentation_structure)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
return presentation
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.get("/stream", response_model=PresentationWithSlides)
|
||||
async def stream_presentation(presentation_id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, presentation_id)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
if not presentation.structure:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Presentation not prepared for stream",
|
||||
)
|
||||
if not presentation.outlines:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Outlines can not be empty",
|
||||
)
|
||||
|
||||
async def inner():
|
||||
structure = presentation.get_structure()
|
||||
layout = presentation.get_layout()
|
||||
outline = presentation.get_presentation_outline()
|
||||
|
||||
# These tasks will be gathered and awaited after all slides are generated
|
||||
async_assets_generation_tasks = []
|
||||
|
||||
slides: List[SlideModel] = []
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": '{ "slides": [ '}),
|
||||
).to_string()
|
||||
for i, slide_layout_index in enumerate(structure.slides):
|
||||
slide_layout = layout.slides[slide_layout_index]
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
slide_layout, outline.slides[i]
|
||||
)
|
||||
slide = SlideModel(
|
||||
presentation=presentation_id,
|
||||
layout_group=layout.name,
|
||||
layout=slide_layout.id,
|
||||
index=i,
|
||||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
|
||||
# This will mutate slide
|
||||
async_assets_generation_tasks.append(process_slide_and_fetch_assets(slide))
|
||||
|
||||
# Give control to the event loop
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
|
||||
).to_string()
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": " ] }"}),
|
||||
).to_string()
|
||||
|
||||
generated_assets_lists = await asyncio.gather(*async_assets_generation_tasks)
|
||||
generated_assets = []
|
||||
for assets_list in generated_assets_lists:
|
||||
generated_assets.extend(assets_list)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slides)
|
||||
sql_session.add_all(generated_assets)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
for each_slide in slides:
|
||||
sql_session.refresh(each_slide)
|
||||
|
||||
response = PresentationWithSlides(
|
||||
**presentation.model_dump(),
|
||||
slides=slides,
|
||||
)
|
||||
|
||||
yield SSECompleteResponse(
|
||||
key="presentation",
|
||||
value=response.model_dump(mode="json"),
|
||||
).to_string()
|
||||
|
||||
return StreamingResponse(inner(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.put("/update", response_model=PresentationWithSlides)
|
||||
def update_presentation(
|
||||
presentation_with_slides: Annotated[PresentationWithSlides, Body()],
|
||||
):
|
||||
updated_presentation = presentation_with_slides.to_presentation_model()
|
||||
updated_slides = presentation_with_slides.slides
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, updated_presentation.id)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
presentation.sqlmodel_update(updated_presentation)
|
||||
|
||||
sql_session.exec(
|
||||
delete(SlideModel).where(SlideModel.presentation == updated_presentation.id)
|
||||
)
|
||||
sql_session.add_all(updated_slides)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
for slide in updated_slides:
|
||||
sql_session.refresh(slide)
|
||||
|
||||
return PresentationWithSlides(
|
||||
**presentation.model_dump(),
|
||||
slides=updated_slides,
|
||||
)
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/export/pptx", response_model=str)
|
||||
async def create_pptx(pptx_model: Annotated[PptxPresentationModel, Body()]):
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
|
||||
pptx_creator = PptxPresentationCreator(pptx_model, temp_dir)
|
||||
await pptx_creator.create_ppt()
|
||||
|
||||
export_directory = get_exports_directory()
|
||||
pptx_path = os.path.join(
|
||||
export_directory, f"{pptx_model.name or get_random_uuid()}.pptx"
|
||||
)
|
||||
pptx_creator.save(pptx_path)
|
||||
|
||||
return pptx_path
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/generate")
|
||||
async def generate_presentation_api(
|
||||
data: Annotated[GeneratePresentationRequest, Body()],
|
||||
):
|
||||
presentation_id = str(uuid.uuid4())
|
||||
print("**" * 40)
|
||||
print(f"Generating presentation with ID: {presentation_id}")
|
||||
print(f"Received Body as JSON: {data.model_dump_json(indent=2)}")
|
||||
|
||||
# 1. Save uploaded files
|
||||
file_paths = []
|
||||
if data.documents:
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
for upload in data.documents:
|
||||
file_path = os.path.join(temp_dir, upload.filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(await upload.read())
|
||||
file_paths.append(file_path)
|
||||
|
||||
# 2. Create Presentation Summary (if documents are provided)
|
||||
summary = None
|
||||
if file_paths:
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(presentation_id)
|
||||
documents_loader = DocumentsLoader(file_paths=file_paths)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
# 3. Generate Outlines
|
||||
presentation_content_text = ""
|
||||
async for chunk in generate_ppt_outline(
|
||||
data.prompt,
|
||||
data.n_slides,
|
||||
data.language,
|
||||
summary,
|
||||
):
|
||||
presentation_content_text += chunk
|
||||
|
||||
presentation_content_json = json.loads(presentation_content_text)
|
||||
presentation_content = PresentationOutlineModel(**presentation_content_json)
|
||||
outlines = presentation_content.slides[: data.n_slides]
|
||||
total_outlines = len(outlines)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generated Presentation Content:", presentation_content_text)
|
||||
print(f"Generated {total_outlines} outlines for the presentation")
|
||||
print(f"Presentation Title: {presentation_content.title}")
|
||||
|
||||
# 4. Parse Layouts
|
||||
layout = await get_layout_by_name(data.layout)
|
||||
total_slide_layouts = len(layout.slides)
|
||||
|
||||
# 5. Generate Structure
|
||||
if layout.ordered:
|
||||
presentation_structure = layout.to_presentation_structure()
|
||||
else:
|
||||
presentation_structure: PresentationStructureModel = (
|
||||
await generate_presentation_structure(
|
||||
presentation_outline=PresentationOutlineModel(
|
||||
title=presentation_content.title,
|
||||
slides=outlines,
|
||||
notes=presentation_content.notes,
|
||||
),
|
||||
presentation_layout=layout,
|
||||
)
|
||||
)
|
||||
|
||||
presentation_structure.slides = presentation_structure.slides[:total_outlines]
|
||||
for index in range(total_outlines):
|
||||
random_slide_index = random.randint(0, total_slide_layouts - 1)
|
||||
if index >= total_outlines:
|
||||
presentation_structure.slides.append(random_slide_index)
|
||||
continue
|
||||
if presentation_structure.slides[index] >= total_slide_layouts:
|
||||
presentation_structure.slides[index] = random_slide_index
|
||||
|
||||
# 6. Create and Save PresentationModel
|
||||
presentation = PresentationModel(
|
||||
id=presentation_id,
|
||||
prompt=data.prompt,
|
||||
n_slides=data.n_slides,
|
||||
language=data.language,
|
||||
title=presentation_content.title,
|
||||
summary=summary,
|
||||
outlines=[each.model_dump() for each in outlines],
|
||||
notes=presentation_content.notes,
|
||||
layout=layout.model_dump(),
|
||||
structure=presentation_structure.model_dump(),
|
||||
)
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
# 7. Generate slide content and save slides
|
||||
slides: List[SlideModel] = []
|
||||
slide_contents: List[dict] = []
|
||||
for i, slide_layout_index in enumerate(presentation_structure.slides):
|
||||
slide_layout = layout.slides[slide_layout_index]
|
||||
print(f"Generating content for slide {i} with layout {slide_layout.id}")
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
slide_layout, outlines[i]
|
||||
)
|
||||
print(f"Generated content for slide {i}: {json.dumps(slide_content, indent=2)}")
|
||||
slide = SlideModel(
|
||||
presentation=presentation_id,
|
||||
layout_group=layout.name,
|
||||
layout=slide_layout.id,
|
||||
index=i,
|
||||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
slide_contents.append(slide_content)
|
||||
|
||||
# Process slides to fetch assets (images, icons, etc.)
|
||||
print("Processing slides to fetch assets")
|
||||
for slide in slides:
|
||||
try:
|
||||
await process_slide_and_fetch_assets(slide)
|
||||
print(f"Processed slide {slide.index} successfully")
|
||||
except Exception as e:
|
||||
print(f"Error processing slide {slide.index}: {e}")
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add_all(slides)
|
||||
sql_session.commit()
|
||||
|
||||
# 8. Export as PPTX
|
||||
if data.export_as == "pptx":
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation as PPTX")
|
||||
|
||||
# Get the converted PPTX model from your existing Next.js service
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost/api/presentation_to_pptx_model?id={presentation_id}"
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
print(f"Failed to get PPTX model: {error_text}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to convert presentation to PPTX model",
|
||||
)
|
||||
pptx_model_data = await response.json()
|
||||
print(f"Received PPTX model data: {json.dumps(pptx_model_data, indent=2)}")
|
||||
|
||||
# Create PPTX file using the converted model
|
||||
pptx_model = PptxPresentationModel(**pptx_model_data)
|
||||
print(f"Creating PPTX with model: {pptx_model.model_dump_json(indent=2)}")
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
pptx_creator = PptxPresentationCreator(pptx_model, temp_dir)
|
||||
await pptx_creator.create_ppt()
|
||||
|
||||
export_directory = get_exports_directory()
|
||||
pptx_path = os.path.join(export_directory, f"{presentation_content.title}.pptx")
|
||||
pptx_creator.save(pptx_path)
|
||||
|
||||
presentation_and_path = PresentationAndPath(
|
||||
presentation_id=presentation_id,
|
||||
path=pptx_path,
|
||||
)
|
||||
else:
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation as PDF")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"http://localhost/api/export-as-pdf",
|
||||
json={
|
||||
"id": presentation_id,
|
||||
"title": presentation_content.title,
|
||||
},
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
print(f"Received PDF export response: {json.dumps(response_json, indent=2)}")
|
||||
|
||||
presentation_and_path = PresentationAndPath(
|
||||
presentation_id=presentation_id,
|
||||
path=response_json["path"],
|
||||
)
|
||||
|
||||
return PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={presentation_id}",
|
||||
)
|
||||
83
servers/fastapi/api/v1/ppt/endpoints/slide.py
Normal file
83
servers/fastapi/api/v1/ppt/endpoints/slide.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
from typing import Annotated, Optional
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
from services.database import get_sql_session
|
||||
from utils.llm_calls.edit_slide import get_edited_slide_content
|
||||
from utils.llm_calls.edit_slide_html import get_edited_slide_html
|
||||
from utils.llm_calls.select_slide_type_on_edit import get_slide_layout_from_prompt
|
||||
from utils.process_slides import process_old_and_new_slides_and_fetch_assets
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
SLIDE_ROUTER = APIRouter(prefix="/slide", tags=["Slide"])
|
||||
|
||||
|
||||
@SLIDE_ROUTER.post("/edit")
|
||||
async def edit_slide(id: Annotated[str, Body()], prompt: Annotated[str, Body()]):
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
slide = sql_session.get(SlideModel, id)
|
||||
if not slide:
|
||||
raise HTTPException(status_code=404, detail="Slide not found")
|
||||
presentation = sql_session.get(PresentationModel, slide.presentation)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
presentation_layout = presentation.get_layout()
|
||||
|
||||
slide_layout = await get_slide_layout_from_prompt(
|
||||
prompt, presentation_layout, slide
|
||||
)
|
||||
edited_slide_content = await get_edited_slide_content(
|
||||
prompt, slide_layout, slide, presentation.language
|
||||
)
|
||||
|
||||
# This will mutate edited_slide_content
|
||||
new_assets = await process_old_and_new_slides_and_fetch_assets(
|
||||
slide.content, edited_slide_content
|
||||
)
|
||||
|
||||
# Always assign a new unique id to the slide
|
||||
slide.id = get_random_uuid()
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(slide)
|
||||
slide.content = edited_slide_content
|
||||
slide.layout = slide_layout.id
|
||||
sql_session.add_all(new_assets)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(slide)
|
||||
|
||||
return slide
|
||||
|
||||
|
||||
@SLIDE_ROUTER.post("/edit-html", response_model=SlideModel)
|
||||
async def edit_slide_html(
|
||||
id: Annotated[str, Body()],
|
||||
prompt: Annotated[str, Body()],
|
||||
html: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
with get_sql_session() as sql_session:
|
||||
slide = sql_session.get(SlideModel, id)
|
||||
if not slide:
|
||||
raise HTTPException(status_code=404, detail="Slide not found")
|
||||
|
||||
html_to_edit = html or slide.html_content
|
||||
if not html_to_edit:
|
||||
raise HTTPException(status_code=400, detail="No HTML to edit")
|
||||
|
||||
edited_slide_html = await get_edited_slide_html(prompt, html_to_edit)
|
||||
|
||||
# Always assign a new unique id to the slide
|
||||
# This is to ensure that the nextjs can track slide updates
|
||||
slide.id = get_random_uuid()
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(slide)
|
||||
slide.html_content = edited_slide_html
|
||||
sql_session.commit()
|
||||
sql_session.refresh(slide)
|
||||
|
||||
return slide
|
||||
22
servers/fastapi/api/v1/ppt/router.py
Normal file
22
servers/fastapi/api/v1/ppt/router.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from api.v1.ppt.endpoints.custom_llm import CUSTOM_LLM_ROUTER
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
from api.v1.ppt.endpoints.icons import ICONS_ROUTER
|
||||
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
||||
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
|
||||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
from api.v1.ppt.endpoints.slide import SLIDE_ROUTER
|
||||
|
||||
|
||||
API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
|
||||
|
||||
API_V1_PPT_ROUTER.include_router(FILES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(SLIDE_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CUSTOM_LLM_ROUTER)
|
||||
File diff suppressed because it is too large
Load diff
0
servers/fastapi/constants/__init__.py
Normal file
0
servers/fastapi/constants/__init__.py
Normal file
20
servers/fastapi/constants/documents.py
Normal file
20
servers/fastapi/constants/documents.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
PDF_MIME_TYPES = ["application/pdf"]
|
||||
TEXT_MIME_TYPES = ["text/plain"]
|
||||
POWERPOINT_TYPES = [
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
]
|
||||
WORD_TYPES = [
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
]
|
||||
SPREADSHEET_TYPES = ["text/csv", "application/csv"]
|
||||
|
||||
|
||||
PNG_MIME_TYPES = ["image/png"]
|
||||
JPEG_MIME_TYPES = ["image/jpeg"]
|
||||
WEBP_MIME_TYPES = ["image/webp"]
|
||||
|
||||
|
||||
UPLOAD_ACCEPTED_FILE_TYPES = (
|
||||
PDF_MIME_TYPES + TEXT_MIME_TYPES + POWERPOINT_TYPES + WORD_TYPES
|
||||
)
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
from api.models import OllamaModelMetadata
|
||||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
|
||||
|
||||
SUPPORTED_LLAMA_MODELS = {
|
||||
SUPPORTED_OLLAMA_MODELS = {
|
||||
"llama3:8b": OllamaModelMetadata(
|
||||
label="Llama 3:8b",
|
||||
value="llama3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3:70b": OllamaModelMetadata(
|
||||
label="Llama 3:70b",
|
||||
|
|
@ -16,7 +16,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="40GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:8b": OllamaModelMetadata(
|
||||
label="Llama 3.1:8b",
|
||||
|
|
@ -24,7 +24,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="4.9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:70b": OllamaModelMetadata(
|
||||
label="Llama 3.1:70b",
|
||||
|
|
@ -32,7 +32,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.1:405b": OllamaModelMetadata(
|
||||
label="Llama 3.1:405b",
|
||||
|
|
@ -40,7 +40,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="243GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:1b": OllamaModelMetadata(
|
||||
label="Llama 3.2:1b",
|
||||
|
|
@ -48,7 +48,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="1.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.2:3b": OllamaModelMetadata(
|
||||
label="Llama 3.2:3b",
|
||||
|
|
@ -56,7 +56,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama3.3:70b": OllamaModelMetadata(
|
||||
label="Llama 3.3:70b",
|
||||
|
|
@ -64,7 +64,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:16x17b": OllamaModelMetadata(
|
||||
label="Llama 4:16x17b",
|
||||
|
|
@ -72,7 +72,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="67GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
"llama4:128x17b": OllamaModelMetadata(
|
||||
label="Llama 4:128x17b",
|
||||
|
|
@ -80,7 +80,7 @@ SUPPORTED_LLAMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="245GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
icon="/static/icons/meta.png",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="815MB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:4b": OllamaModelMetadata(
|
||||
label="Gemma 3:4b",
|
||||
|
|
@ -99,7 +99,7 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="3.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:12b": OllamaModelMetadata(
|
||||
label="Gemma 3:12b",
|
||||
|
|
@ -107,7 +107,7 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="8.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
"gemma3:27b": OllamaModelMetadata(
|
||||
label="Gemma 3:27b",
|
||||
|
|
@ -115,7 +115,7 @@ SUPPORTED_GEMMA_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="17GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
icon="/static/icons/gemma.png",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -126,7 +126,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="1.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:7b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:7b",
|
||||
|
|
@ -134,7 +134,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:8b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:8b",
|
||||
|
|
@ -142,7 +142,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:14b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:14b",
|
||||
|
|
@ -150,7 +150,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:32b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:32b",
|
||||
|
|
@ -158,7 +158,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:70b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:70b",
|
||||
|
|
@ -166,7 +166,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:671b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:671b",
|
||||
|
|
@ -174,7 +174,7 @@ SUPPORTED_DEEPSEEK_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="404GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
icon="/static/icons/deepseek.png",
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="523MB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:1.7b": OllamaModelMetadata(
|
||||
label="Qwen 3:1.7b",
|
||||
|
|
@ -193,7 +193,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="1.4GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:4b": OllamaModelMetadata(
|
||||
label="Qwen 3:4b",
|
||||
|
|
@ -201,7 +201,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="2.6GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:8b": OllamaModelMetadata(
|
||||
label="Qwen 3:8b",
|
||||
|
|
@ -209,7 +209,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:14b": OllamaModelMetadata(
|
||||
label="Qwen 3:14b",
|
||||
|
|
@ -217,7 +217,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="❌ Graphs not supported.",
|
||||
size="9.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:30b": OllamaModelMetadata(
|
||||
label="Qwen 3:30b",
|
||||
|
|
@ -225,7 +225,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="19GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:32b": OllamaModelMetadata(
|
||||
label="Qwen 3:32b",
|
||||
|
|
@ -233,7 +233,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
"qwen3:235b": OllamaModelMetadata(
|
||||
label="Qwen 3:235b",
|
||||
|
|
@ -241,12 +241,12 @@ SUPPORTED_QWEN_MODELS = {
|
|||
description="✅ Graphs supported.",
|
||||
size="142GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
icon="/static/icons/qwen.png",
|
||||
),
|
||||
}
|
||||
|
||||
SUPPORTED_OLLAMA_MODELS = {
|
||||
**SUPPORTED_LLAMA_MODELS,
|
||||
**SUPPORTED_OLLAMA_MODELS,
|
||||
**SUPPORTED_GEMMA_MODELS,
|
||||
**SUPPORTED_DEEPSEEK_MODELS,
|
||||
**SUPPORTED_QWEN_MODELS,
|
||||
0
servers/fastapi/enums/__init__.py
Normal file
0
servers/fastapi/enums/__init__.py
Normal file
7
servers/fastapi/enums/image_provider.py
Normal file
7
servers/fastapi/enums/image_provider.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from enum import Enum
|
||||
|
||||
class ImageProvider(Enum):
|
||||
PEXELS = "pexels"
|
||||
PIXABAY = "pixabay"
|
||||
GEMINI_FLASH = "gemini_flash"
|
||||
DALLE3 = "dall-e-3"
|
||||
8
servers/fastapi/enums/llm_provider.py
Normal file
8
servers/fastapi/enums/llm_provider.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class LLMProvider(Enum):
|
||||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
CUSTOM = "custom"
|
||||
419
servers/fastapi/get_test_schema.py
Normal file
419
servers/fastapi/get_test_schema.py
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
import json
|
||||
from typing import List, Literal, Optional
|
||||
from pydantic import BaseModel, Field, HttpUrl, EmailStr
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from utils.dict_utils import get_dict_at_path, get_dict_paths_with_key
|
||||
from utils.schema_utils import remove_fields_from_schema
|
||||
|
||||
|
||||
class ContactInfoModel(BaseModel):
|
||||
email: Optional[EmailStr] = Field(None, description="Contact email")
|
||||
phone: Optional[str] = Field(
|
||||
None, min_length=5, max_length=50, description="Contact phone number"
|
||||
)
|
||||
website: Optional[HttpUrl] = Field(None, description="Website URL")
|
||||
|
||||
|
||||
class ImageModel(BaseModel):
|
||||
__image_url__: str = Field(description="Image URL")
|
||||
__image_prompt__: str = Field(description="Image prompt")
|
||||
|
||||
|
||||
# First Slide Layout
|
||||
class FirstSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Main title of the presentation",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=10, max_length=200, description="Optional subtitle or tagline"
|
||||
)
|
||||
author: Optional[str] = Field(
|
||||
min_length=2,
|
||||
max_length=100,
|
||||
description="Author or presenter name",
|
||||
)
|
||||
date: Optional[str] = Field(description="Presentation date")
|
||||
company: Optional[str] = Field(
|
||||
min_length=2,
|
||||
max_length=100,
|
||||
description="Company or organization name",
|
||||
)
|
||||
backgroundImage: Optional[ImageModel] = Field(
|
||||
description="Background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Bullet Point Slide Layout
|
||||
class BulletPointSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
icon: Optional[str] = Field(description="Icon to display in the slide")
|
||||
bulletPoints: List[str] = Field(
|
||||
min_length=2,
|
||||
max_length=8,
|
||||
description="List of bullet points (2-8 items)",
|
||||
)
|
||||
|
||||
|
||||
# Image Slide Layout
|
||||
class ImageSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
image: HttpUrl = Field(
|
||||
description="Main image URL",
|
||||
)
|
||||
imageCaption: Optional[str] = Field(
|
||||
min_length=5,
|
||||
max_length=200,
|
||||
description="Optional image caption or description",
|
||||
)
|
||||
content: Optional[str] = Field(
|
||||
min_length=10,
|
||||
max_length=600,
|
||||
description="Optional supporting content text",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Statistics Slide Layout
|
||||
class StatisticItemModel(BaseModel):
|
||||
value: str = Field(
|
||||
min_length=1,
|
||||
max_length=20,
|
||||
description="Statistical value (e.g., '250%', '$1.2M', '99.9%')",
|
||||
)
|
||||
label: str = Field(
|
||||
min_length=3, max_length=100, description="Description of the statistic"
|
||||
)
|
||||
trend: Optional[str] = Field(
|
||||
description="Trend direction indicator", pattern="^(up|down|neutral)$"
|
||||
)
|
||||
context: Optional[str] = Field(
|
||||
min_length=5,
|
||||
max_length=200,
|
||||
description="Additional context or time period",
|
||||
)
|
||||
|
||||
|
||||
class StatisticsSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
statistics: List[StatisticItemModel] = Field(
|
||||
min_length=2,
|
||||
max_length=6,
|
||||
description="List of statistics (2-6 items)",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Quote Slide Layout
|
||||
class QuoteSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
quote: str = Field(
|
||||
min_length=10,
|
||||
max_length=500,
|
||||
description="The main quote or testimonial",
|
||||
)
|
||||
author: str = Field(
|
||||
min_length=2,
|
||||
max_length=100,
|
||||
description="Quote author name",
|
||||
)
|
||||
authorTitle: Optional[str] = Field(
|
||||
min_length=2, max_length=100, description="Author job title or position"
|
||||
)
|
||||
company: Optional[str] = Field(
|
||||
min_length=2, max_length=100, description="Author company or organization"
|
||||
)
|
||||
authorImage: Optional[HttpUrl] = Field(description="URL to author photo")
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Timeline Slide Layout
|
||||
class TimelineItemModel(BaseModel):
|
||||
date: str = Field(min_length=2, max_length=50, description="Date or time period")
|
||||
title: str = Field(
|
||||
min_length=3, max_length=100, description="Event or milestone title"
|
||||
)
|
||||
description: str = Field(
|
||||
min_length=10, max_length=300, description="Event description"
|
||||
)
|
||||
status: str = Field(
|
||||
description="Timeline item status",
|
||||
pattern="^(completed|current|upcoming)$",
|
||||
)
|
||||
|
||||
|
||||
class TimelineSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
timelineItems: List[TimelineItemModel] = Field(
|
||||
min_length=2,
|
||||
max_length=6,
|
||||
description="Timeline events (2-6 items)",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Team Slide Layout
|
||||
class TeamMemberModel(BaseModel):
|
||||
name: str = Field(min_length=2, max_length=100, description="Team member name")
|
||||
title: str = Field(min_length=2, max_length=100, description="Job title or role")
|
||||
image: Optional[HttpUrl] = Field(description="URL to team member photo")
|
||||
bio: Optional[str] = Field(
|
||||
min_length=10,
|
||||
max_length=300,
|
||||
description="Brief biography or description",
|
||||
)
|
||||
email: Optional[EmailStr] = Field(description="Contact email")
|
||||
linkedin: Optional[HttpUrl] = Field(description="LinkedIn profile URL")
|
||||
|
||||
|
||||
class TeamSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or team description",
|
||||
)
|
||||
teamMembers: List[TeamMemberModel] = Field(
|
||||
min_length=1,
|
||||
max_length=6,
|
||||
description="Team members (1-6 people)",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Process Slide Layout
|
||||
class ProcessStepModel(BaseModel):
|
||||
step: int = Field(ge=1, le=10, description="Step number")
|
||||
title: str = Field(min_length=3, max_length=100, description="Step title")
|
||||
description: str = Field(
|
||||
min_length=10, max_length=200, description="Step description"
|
||||
)
|
||||
|
||||
|
||||
class ProcessSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
processSteps: List[ProcessStepModel] = Field(
|
||||
min_length=2,
|
||||
max_length=6,
|
||||
description="Process steps (2-6 items)",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Two Column Slide Layout
|
||||
class ColumnContentModel(BaseModel):
|
||||
title: str = Field(min_length=3, max_length=100, description="Column title")
|
||||
content: str = Field(min_length=10, max_length=800, description="Column content")
|
||||
|
||||
|
||||
class TwoColumnSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
leftColumn: ColumnContentModel = Field(
|
||||
description="Left column content",
|
||||
)
|
||||
rightColumn: ColumnContentModel = Field(
|
||||
description="Right column content",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Conclusion Slide Layout
|
||||
class ConclusionSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
keyTakeaways: List[str] = Field(
|
||||
min_length=2,
|
||||
max_length=6,
|
||||
description="Key takeaways or summary points (2-6 items)",
|
||||
)
|
||||
callToAction: Optional[str] = Field(
|
||||
min_length=5,
|
||||
max_length=150,
|
||||
description="Optional call to action or next steps",
|
||||
)
|
||||
contactInfo: Optional[ContactInfoModel] = Field(
|
||||
description="Optional contact information"
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Content Slide Layout
|
||||
class ContentSlideModel(BaseModel):
|
||||
title: str = Field(
|
||||
min_length=3,
|
||||
max_length=100,
|
||||
description="Title of the slide",
|
||||
)
|
||||
subtitle: Optional[str] = Field(
|
||||
min_length=3,
|
||||
max_length=150,
|
||||
description="Optional subtitle or description",
|
||||
)
|
||||
content: str = Field(
|
||||
min_length=10,
|
||||
max_length=1000,
|
||||
description="Main content text",
|
||||
)
|
||||
backgroundImage: Optional[HttpUrl] = Field(
|
||||
description="URL to background image for the slide"
|
||||
)
|
||||
|
||||
|
||||
# Create the presentation layout with all slide types
|
||||
presentation_layout = PresentationLayoutModel(
|
||||
name="Complete Presentation Layout",
|
||||
slides=[
|
||||
SlideLayoutModel(
|
||||
id="first-slide",
|
||||
name="First Slide",
|
||||
json_schema=FirstSlideModel.model_json_schema(),
|
||||
),
|
||||
# SlideLayoutModel(
|
||||
# id="bullet-point-slide",
|
||||
# name="Bullet Point Slide",
|
||||
# json_schema=BulletPointSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="image-slide",
|
||||
# name="Image Slide",
|
||||
# json_schema=ImageSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="statistics-slide",
|
||||
# name="Statistics Slide",
|
||||
# json_schema=StatisticsSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="quote-slide",
|
||||
# name="Quote Slide",
|
||||
# json_schema=QuoteSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="timeline-slide",
|
||||
# name="Timeline Slide",
|
||||
# json_schema=TimelineSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="team-slide",
|
||||
# name="Team Slide",
|
||||
# json_schema=TeamSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="process-slide",
|
||||
# name="Process Slide",
|
||||
# json_schema=ProcessSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="two-column-slide",
|
||||
# name="Two Column Slide",
|
||||
# json_schema=TwoColumnSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="conclusion-slide",
|
||||
# name="Conclusion Slide",
|
||||
# json_schema=ConclusionSlideModel.model_json_schema(),
|
||||
# ),
|
||||
# SlideLayoutModel(
|
||||
# id="content-slide",
|
||||
# name="Content Slide",
|
||||
# json_schema=ContentSlideModel.model_json_schema(),
|
||||
# ),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
print(json.dumps(StatisticsSlideModel.model_json_schema()))
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class PointModel(BaseModel):
|
||||
x: float
|
||||
y: float
|
||||
|
||||
def to_list(self) -> List[float]:
|
||||
return [self.x, self.y]
|
||||
|
||||
|
||||
class PointWithRadius(PointModel):
|
||||
radius: Optional[float] = None
|
||||
|
||||
|
||||
class BarSeriesModel(BaseModel):
|
||||
name: str
|
||||
data: List[float] = Field(
|
||||
description="Only numbers should be given out in data. Don't include text/string in data."
|
||||
)
|
||||
|
||||
|
||||
class ScatterSeriesModel(BaseModel):
|
||||
name: str
|
||||
points: List[PointModel]
|
||||
|
||||
|
||||
class BubbleSeriesModel(BaseModel):
|
||||
name: str
|
||||
points: List[PointWithRadius]
|
||||
|
||||
|
||||
class LineSeriesModel(BaseModel):
|
||||
name: str
|
||||
data: List[float] = Field(
|
||||
description="Only numbers should be given out in data. Don't include text/string in data."
|
||||
)
|
||||
|
||||
|
||||
class PieChartSeriesModel(BaseModel):
|
||||
data: List[float]
|
||||
|
||||
|
||||
class BarGraphDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[BarSeriesModel] = Field(
|
||||
description="There should be no more than 3 series"
|
||||
)
|
||||
|
||||
|
||||
class ScatterChartDataModel(BaseModel):
|
||||
series: List[ScatterSeriesModel]
|
||||
|
||||
|
||||
class BubbleChartDataModel(BaseModel):
|
||||
series: List[BubbleSeriesModel]
|
||||
|
||||
|
||||
class LineChartDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[LineSeriesModel] = Field(
|
||||
description="There should be no more than 3 series"
|
||||
)
|
||||
|
||||
|
||||
class PieChartDataModel(BaseModel):
|
||||
categories: List[str]
|
||||
series: List[PieChartSeriesModel] = Field(
|
||||
description="One series model with list of data",
|
||||
min_length=1,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def limit_series(self):
|
||||
self.series = self.series[:1]
|
||||
return self
|
||||
|
||||
|
||||
class GraphTypeEnum(Enum):
|
||||
pie = "pie"
|
||||
bar = "bar"
|
||||
line = "line"
|
||||
|
||||
|
||||
class LLMGraphModel(BaseModel):
|
||||
name: str
|
||||
type: GraphTypeEnum
|
||||
unit: Optional[str] = Field(
|
||||
description="Unit of the data in the graph. Example: %, kg, million USD, tonnes, etc."
|
||||
)
|
||||
data: PieChartDataModel | LineChartDataModel | BarGraphDataModel
|
||||
|
||||
|
||||
class GraphModel(LLMGraphModel):
|
||||
style: Optional[dict] = {}
|
||||
|
||||
@classmethod
|
||||
def from_llm_graph_model(
|
||||
cls, llm_graph_model: LLMGraphModel, style: Optional[dict] = {}
|
||||
):
|
||||
return cls(
|
||||
name=llm_graph_model.name,
|
||||
type=llm_graph_model.type,
|
||||
unit=llm_graph_model.unit,
|
||||
data=llm_graph_model.data,
|
||||
style=style,
|
||||
)
|
||||
|
||||
|
||||
GRAPH_TYPE_MAPPING = {
|
||||
GraphTypeEnum.pie: PieChartDataModel,
|
||||
GraphTypeEnum.bar: BarGraphDataModel,
|
||||
GraphTypeEnum.line: LineChartDataModel,
|
||||
}
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from api.utils.utils import get_resource
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
IconCategoryEnum,
|
||||
IconQueryCollectionWithData,
|
||||
)
|
||||
from fastembed_vectorstore import FastembedVectorstore
|
||||
|
||||
|
||||
async def get_icon(
|
||||
vector_store: FastembedVectorstore,
|
||||
input: IconQueryCollectionWithData,
|
||||
) -> str:
|
||||
try:
|
||||
query = input.icon_query
|
||||
results = vector_store.search(query, 1)
|
||||
icon_name = results[0][0].split("||")[0]
|
||||
return get_resource(f"assets/icons/bold/{icon_name}.png")
|
||||
except Exception as e:
|
||||
print("Error finding icon: ", e)
|
||||
return get_resource("assets/icons/placeholder.png")
|
||||
|
||||
|
||||
async def get_icons(
|
||||
vector_store: FastembedVectorstore,
|
||||
query: str,
|
||||
page: int,
|
||||
limit: int,
|
||||
category: Optional[IconCategoryEnum],
|
||||
temp_dir: str,
|
||||
) -> List[str]:
|
||||
|
||||
results = vector_store.search(query, limit)
|
||||
icon_names = [result[0].split("||")[0] for result in results]
|
||||
|
||||
return [get_resource(f"assets/icons/bold/{each}.png") for each in icon_names]
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
from api.utils.utils import get_resource
|
||||
from fastembed_vectorstore import FastembedVectorstore, FastembedEmbeddingModel
|
||||
|
||||
|
||||
def get_icons_vectorstore():
|
||||
vector_store_path = get_resource("assets/icons_vectorstore.json")
|
||||
embedding_model = FastembedEmbeddingModel.BGESmallENV15
|
||||
|
||||
if os.path.exists(vector_store_path):
|
||||
return FastembedVectorstore.load(embedding_model, vector_store_path)
|
||||
|
||||
vector_store = FastembedVectorstore(embedding_model)
|
||||
with open(get_resource("assets/icons.json"), "r") as f:
|
||||
icons = json.load(f)
|
||||
documents = []
|
||||
for each in icons["icons"]:
|
||||
if each["name"].split("-")[-1] == "bold":
|
||||
documents.append(f"{each['name']}||{each['tags']}")
|
||||
|
||||
vector_store.embed_documents(documents)
|
||||
vector_store.save(vector_store_path)
|
||||
|
||||
return vector_store
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from google import genai
|
||||
from google.genai.types import GenerateContentConfig
|
||||
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from api.utils.utils import download_file, get_resource
|
||||
from api.utils.model_utils import (
|
||||
get_llm_client,
|
||||
is_custom_llm_selected,
|
||||
is_ollama_selected,
|
||||
)
|
||||
|
||||
|
||||
async def generate_image(
|
||||
input: ImagePromptWithThemeAndAspectRatio,
|
||||
output_directory: str,
|
||||
) -> str:
|
||||
is_ollama = is_ollama_selected()
|
||||
is_custom_llm = is_custom_llm_selected()
|
||||
|
||||
image_prompt = (
|
||||
input.image_prompt
|
||||
if is_ollama or is_custom_llm
|
||||
else f"{input.image_prompt}, {input.theme_prompt}"
|
||||
)
|
||||
print(f"Request - Generating Image for {image_prompt}")
|
||||
|
||||
try:
|
||||
image_gen_func = (
|
||||
get_image_from_pexels
|
||||
if is_ollama or is_custom_llm
|
||||
else (
|
||||
generate_image_openai
|
||||
if os.getenv("LLM") == "openai"
|
||||
else generate_image_google
|
||||
)
|
||||
)
|
||||
image_path = await image_gen_func(image_prompt, output_directory)
|
||||
if image_path and os.path.exists(image_path):
|
||||
return image_path
|
||||
raise Exception(f"Image not found at {image_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating image: {e}")
|
||||
return get_resource("assets/images/placeholder.jpg")
|
||||
|
||||
|
||||
async def generate_image_openai(prompt: str, output_directory: str) -> str:
|
||||
client = get_llm_client()
|
||||
result = await client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
quality="standard",
|
||||
size="1024x1024",
|
||||
)
|
||||
image_url = result.data[0].url
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
image_bytes = await response.read()
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return image_path
|
||||
|
||||
|
||||
async def generate_image_google(prompt: str, output_directory: str) -> str:
|
||||
client = genai.Client()
|
||||
response = await asyncio.to_thread(
|
||||
client.models.generate_content,
|
||||
model="gemini-2.0-flash-preview-image-generation",
|
||||
contents=[prompt],
|
||||
config=GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]),
|
||||
)
|
||||
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text is not None:
|
||||
print(part.text)
|
||||
elif part.inline_data is not None:
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(part.inline_data.data)
|
||||
|
||||
return image_path
|
||||
|
||||
|
||||
async def get_image_from_pexels(prompt: str, output_directory: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
response = await session.get(
|
||||
f"https://api.pexels.com/v1/search?query={prompt}&per_page=1",
|
||||
headers={"Authorization": f'{os.getenv("PEXELS_API_KEY")}'},
|
||||
)
|
||||
data = await response.json()
|
||||
image_url = data["photos"][0]["src"]["large"]
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
await download_file(image_url, image_path)
|
||||
return image_path
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
import pdfplumber
|
||||
|
||||
|
||||
def get_page_images_from_pdf(document_path: str, temp_dir: str):
|
||||
images_temp_dir = TEMP_FILE_SERVICE.create_dir_in_dir(temp_dir)
|
||||
|
||||
with pdfplumber.open(document_path) as pdf:
|
||||
for page in pdf.pages:
|
||||
img = page.to_image(resolution=300)
|
||||
img.save(os.path.join(images_temp_dir, f"page_{page.page_number}.png"))
|
||||
|
||||
|
||||
async def get_page_images_from_pdf_async(document_path: str, temp_dir: str):
|
||||
return await asyncio.to_thread(get_page_images_from_pdf, document_path, temp_dir)
|
||||
0
servers/fastapi/models/__init__.py
Normal file
0
servers/fastapi/models/__init__.py
Normal file
6
servers/fastapi/models/decomposed_file_info.py
Normal file
6
servers/fastapi/models/decomposed_file_info.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DecomposedFileInfo(BaseModel):
|
||||
name: str
|
||||
file_path: str
|
||||
19
servers/fastapi/models/generate_presentation_api.py
Normal file
19
servers/fastapi/models/generate_presentation_api.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from typing import List, Optional, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import UploadFile
|
||||
|
||||
class GeneratePresentationRequest(BaseModel):
|
||||
prompt: str
|
||||
n_slides: int = Field(default=8, ge=5, le=15)
|
||||
language: str = Field(default="English")
|
||||
layout: str = Field(default="default")
|
||||
documents: Optional[List[UploadFile]] = None
|
||||
export_as: Literal["pptx", "pdf"] = Field(default="pptx")
|
||||
|
||||
|
||||
class PresentationAndPath(BaseModel):
|
||||
presentation_id: str
|
||||
path: str
|
||||
|
||||
class PresentationPathAndEditPath(PresentationAndPath):
|
||||
edit_path: str
|
||||
10
servers/fastapi/models/image_prompt.py
Normal file
10
servers/fastapi/models/image_prompt.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ImagePrompt(BaseModel):
|
||||
prompt: str
|
||||
theme_prompt: Optional[str] = None
|
||||
|
||||
def get_image_prompt(self, with_theme: bool = False) -> str:
|
||||
return f"{self.prompt}, {self.theme_prompt}" if with_theme else self.prompt
|
||||
14
servers/fastapi/models/json_path_guide.py
Normal file
14
servers/fastapi/models/json_path_guide.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DictGuide(BaseModel):
|
||||
key: str
|
||||
|
||||
|
||||
class ListGuide(BaseModel):
|
||||
index: int
|
||||
|
||||
|
||||
class JsonPathGuide(BaseModel):
|
||||
guides: List[DictGuide | ListGuide]
|
||||
10
servers/fastapi/models/ollama_model_metadata.py
Normal file
10
servers/fastapi/models/ollama_model_metadata.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OllamaModelMetadata(BaseModel):
|
||||
label: str
|
||||
value: str
|
||||
description: str
|
||||
icon: str
|
||||
size: str
|
||||
supports_graph: bool
|
||||
10
servers/fastapi/models/ollama_model_status.py
Normal file
10
servers/fastapi/models/ollama_model_status.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OllamaModelStatus(BaseModel):
|
||||
name: str
|
||||
size: Optional[int] = None
|
||||
downloaded: Optional[int] = None
|
||||
status: str
|
||||
done: bool
|
||||
|
|
@ -54,18 +54,20 @@ class PptxPositionModel(BaseModel):
|
|||
class PptxFontModel(BaseModel):
|
||||
name: str = "Inter"
|
||||
size: int = 16
|
||||
bold: bool = False
|
||||
italic: bool = False
|
||||
color: str = "000000"
|
||||
font_weight: Optional[int] = 400
|
||||
|
||||
|
||||
class PptxFillModel(BaseModel):
|
||||
color: str
|
||||
opacity: float = 1.0
|
||||
|
||||
|
||||
class PptxStrokeModel(BaseModel):
|
||||
color: str
|
||||
thickness: float
|
||||
opacity: float = 1.0
|
||||
|
||||
|
||||
class PptxShadowModel(BaseModel):
|
||||
|
|
@ -85,6 +87,7 @@ class PptxParagraphModel(BaseModel):
|
|||
spacing: Optional[PptxSpacingModel] = None
|
||||
alignment: Optional[PP_ALIGN] = None
|
||||
font: Optional[PptxFontModel] = None
|
||||
line_height: Optional[float] = None
|
||||
text: Optional[str] = None
|
||||
text_runs: Optional[List[PptxTextRunModel]] = None
|
||||
|
||||
|
|
@ -129,6 +132,7 @@ class PptxPictureBoxModel(PptxShapeModel):
|
|||
position: PptxPositionModel
|
||||
margin: Optional[PptxSpacingModel] = None
|
||||
clip: bool = True
|
||||
opacity: Optional[float] = None
|
||||
overlay: Optional[str] = None
|
||||
border_radius: Optional[List[int]] = None
|
||||
shape: Optional[PptxBoxShapeEnum] = None
|
||||
|
|
@ -141,9 +145,11 @@ class PptxConnectorModel(PptxShapeModel):
|
|||
position: PptxPositionModel
|
||||
thickness: float = 0.5
|
||||
color: str = "000000"
|
||||
opacity: float = 1.0
|
||||
|
||||
|
||||
class PptxSlideModel(BaseModel):
|
||||
background: Optional[PptxFillModel] = None
|
||||
shapes: List[
|
||||
PptxTextBoxModel
|
||||
| PptxAutoShapeBoxModel
|
||||
|
|
@ -153,6 +159,6 @@ class PptxSlideModel(BaseModel):
|
|||
|
||||
|
||||
class PptxPresentationModel(BaseModel):
|
||||
background_color: str
|
||||
name: Optional[str] = None
|
||||
shapes: Optional[List[PptxShapeModel]] = None
|
||||
slides: List[PptxSlideModel]
|
||||
30
servers/fastapi/models/presentation_layout.py
Normal file
30
servers/fastapi/models/presentation_layout.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
class SlideLayoutModel(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
json_schema: dict
|
||||
|
||||
|
||||
class PresentationLayoutModel(BaseModel):
|
||||
name: Optional[str] = None
|
||||
ordered: bool = Field(default=False)
|
||||
slides: List[SlideLayoutModel]
|
||||
|
||||
def to_presentation_structure(self):
|
||||
return PresentationStructureModel(
|
||||
slides=[index for index in range(len(self.slides))]
|
||||
)
|
||||
|
||||
def to_string(self):
|
||||
message = f"## Presentation Layout\n\n"
|
||||
for index, slide in enumerate(self.slides):
|
||||
message += f"### Slide Layout: {index}: \n"
|
||||
message += f"- Name: {slide.name or slide.json_schema.get('title')} \n"
|
||||
message += f"- Description: {slide.description} \n\n"
|
||||
return message
|
||||
|
|
@ -2,15 +2,7 @@ from typing import List, Optional
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SlideStructureModel(BaseModel):
|
||||
type: int = Field(description="Type of the slide", gte=1, lte=9)
|
||||
|
||||
|
||||
class PresentationStructureModel(BaseModel):
|
||||
slides: List[SlideStructureModel] = Field(description="List of slide structure")
|
||||
|
||||
|
||||
class SlideMarkdownModel(BaseModel):
|
||||
class SlideOutlineModel(BaseModel):
|
||||
title: str = Field(
|
||||
description="Title of the slide in about 3 to 5 words",
|
||||
)
|
||||
|
|
@ -19,12 +11,12 @@ class SlideMarkdownModel(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class PresentationMarkdownModel(BaseModel):
|
||||
class PresentationOutlineModel(BaseModel):
|
||||
title: str = Field(
|
||||
description="Title of the presentation in about 3 to 8 words",
|
||||
)
|
||||
notes: Optional[List[str]] = Field(description="Notes for the presentation")
|
||||
slides: List[SlideMarkdownModel] = Field(description="List of slides")
|
||||
slides: List[SlideOutlineModel] = Field(description="List of slides")
|
||||
|
||||
def to_string(self):
|
||||
message = f"# Presentation Title: {self.title} \n\n"
|
||||
6
servers/fastapi/models/presentation_structure_model.py
Normal file
6
servers/fastapi/models/presentation_structure_model.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PresentationStructureModel(BaseModel):
|
||||
slides: List[int] = Field(description="List of slide layout indexes")
|
||||
29
servers/fastapi/models/presentation_with_slides.py
Normal file
29
servers/fastapi/models/presentation_with_slides.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sql.slide import SlideModel
|
||||
|
||||
|
||||
class PresentationWithSlides(BaseModel):
|
||||
id: str
|
||||
prompt: str
|
||||
n_slides: int
|
||||
language: str
|
||||
title: Optional[str] = None
|
||||
notes: Optional[List[str]]
|
||||
outlines: Optional[List[SlideOutlineModel]]
|
||||
summary: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
layout: PresentationLayoutModel
|
||||
structure: Optional[PresentationStructureModel]
|
||||
slides: List[SlideModel]
|
||||
|
||||
def to_presentation_model(self) -> PresentationModel:
|
||||
return PresentationModel(**self.model_dump())
|
||||
5
servers/fastapi/models/slide_layout_index.py
Normal file
5
servers/fastapi/models/slide_layout_index.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SlideLayoutIndex(BaseModel):
|
||||
index: int
|
||||
14
servers/fastapi/models/sql/image_asset.py
Normal file
14
servers/fastapi/models/sql/image_asset.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class ImageAsset(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
path: str
|
||||
extras: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
9
servers/fastapi/models/sql/key_value.py
Normal file
9
servers/fastapi/models/sql/key_value.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class KeyValueSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
key: str = Field(index=True)
|
||||
value: dict = Field(sa_column=Column(JSON))
|
||||
49
servers/fastapi/models/sql/presentation.py
Normal file
49
servers/fastapi/models/sql/presentation.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
class PresentationModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True)
|
||||
prompt: str
|
||||
n_slides: int
|
||||
language: str
|
||||
title: Optional[str] = None
|
||||
notes: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
|
||||
outlines: Optional[List[dict]] = Field(sa_column=Column(JSON), default=None)
|
||||
summary: Optional[str] = None
|
||||
created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
updated_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
layout: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
structure: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
|
||||
def get_presentation_outline(self):
|
||||
if not self.outlines:
|
||||
return None
|
||||
return PresentationOutlineModel(
|
||||
title=self.title,
|
||||
slides=[SlideOutlineModel(**each) for each in self.outlines],
|
||||
notes=self.notes,
|
||||
)
|
||||
|
||||
def get_layout(self):
|
||||
return PresentationLayoutModel(**self.layout)
|
||||
|
||||
def set_layout(self, layout: PresentationLayoutModel):
|
||||
self.layout = layout.model_dump()
|
||||
|
||||
def get_structure(self):
|
||||
if not self.structure:
|
||||
return None
|
||||
return PresentationStructureModel(**self.structure)
|
||||
|
||||
def set_structure(self, structure: PresentationStructureModel):
|
||||
self.structure = structure.model_dump()
|
||||
15
servers/fastapi/models/sql/slide.py
Normal file
15
servers/fastapi/models/sql/slide.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class SlideModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True, default_factory=get_random_uuid)
|
||||
presentation: str
|
||||
layout_group: str
|
||||
layout: str
|
||||
index: int
|
||||
content: dict = Field(sa_column=Column(JSON))
|
||||
html_content: Optional[str]
|
||||
properties: Optional[dict] = Field(sa_column=Column(JSON))
|
||||
40
servers/fastapi/models/sse_response.py
Normal file
40
servers/fastapi/models/sse_response.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SSEResponse(BaseModel):
|
||||
event: str
|
||||
data: str
|
||||
|
||||
def to_string(self):
|
||||
return f"event: {self.event}\ndata: {self.data}\n\n"
|
||||
|
||||
|
||||
class SSEStatusResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "status", "status": self.status})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSEErrorResponse(BaseModel):
|
||||
detail: str
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "error", "detail": self.detail})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSECompleteResponse(BaseModel):
|
||||
key: str
|
||||
value: object
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "complete", self.key: self.value}),
|
||||
).to_string()
|
||||
16
servers/fastapi/models/user_config.py
Normal file
16
servers/fastapi/models/user_config.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
LLM: Optional[str] = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
OLLAMA_URL: Optional[str] = None
|
||||
OLLAMA_MODEL: Optional[str] = None
|
||||
CUSTOM_LLM_URL: Optional[str] = None
|
||||
CUSTOM_LLM_API_KEY: Optional[str] = None
|
||||
CUSTOM_MODEL: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
IMAGE_PROVIDER: Optional[str] = None
|
||||
PIXABAY_API_KEY: Optional[str] = None
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
from langchain.schema import BaseOutputParser
|
||||
|
||||
|
||||
class StripMarkdownOutputParser(BaseOutputParser):
|
||||
def parse(self, text: str) -> str:
|
||||
# Remove triple backticks and any optional language hint like ```markdown
|
||||
import re
|
||||
|
||||
return re.sub(r"^```[\w]*\n?|```$", "", text.strip(), flags=re.MULTILINE)
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.utils.model_utils import get_large_model, get_llm_client
|
||||
from api.utils.variable_length_models import (
|
||||
get_presentation_markdown_model_with_n_slides,
|
||||
)
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
|
||||
|
||||
def get_prompt_template(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
Create a presentation based on the provided prompt, number of slides, output language, and additional informational details.
|
||||
Format the output in the specified JSON schema with structured markdown content.
|
||||
|
||||
# Steps
|
||||
|
||||
1. Identify key points from the provided prompt, including the topic, number of slides, output language, and additional content directions.
|
||||
2. Create a concise and descriptive title reflecting the main topic, adhering to the specified language.
|
||||
3. Generate a clear title for each slide.
|
||||
4. Develop comprehensive content using markdown structure:
|
||||
* Use bullet points (- or *) for lists.
|
||||
* Use **bold** for emphasis, *italic* for secondary emphasis, and `code` for technical terms.
|
||||
5. Provide important points from prompt as notes.
|
||||
|
||||
# Notes
|
||||
- Content must be generated for every slide.
|
||||
- Images or Icons information provided in **Input** must be included in the **notes**.
|
||||
- Notes should cleary define if it is for specific slide or for the presentation.
|
||||
- Slide **body** should not contain slide **title**.
|
||||
- Slide **title** should not contain "Slide 1", "Slide 2", etc.
|
||||
- Slide **title** should not be in markdown format.
|
||||
- There must be exact **Number of Slides** as specified.
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
**Input:**
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Number of Slides: {n_slides}
|
||||
- Additional Information: {content}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def generate_ppt_content(
|
||||
prompt: Optional[str],
|
||||
n_slides: int,
|
||||
language: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
) -> PresentationMarkdownModel:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
response_model = get_presentation_markdown_model_with_n_slides(n_slides)
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt_template(prompt, n_slides, language, content),
|
||||
response_format=response_model,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
from api.utils.model_utils import get_llm_client, get_small_model
|
||||
from api.utils.variable_length_models import (
|
||||
get_presentation_structure_model_with_n_slides,
|
||||
)
|
||||
from ppt_config_generator.models import (
|
||||
PresentationStructureModel,
|
||||
PresentationMarkdownModel,
|
||||
)
|
||||
|
||||
|
||||
def get_prompt(n_slides: int, data: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
You're a professional presentation designer with years of experience in designing clear and engaging presentations.
|
||||
|
||||
# Slide Types
|
||||
- **1**: contains title, description and image.
|
||||
- **2**: contains title and list of items.
|
||||
- **4**: contains title and list of items with images.
|
||||
- **5**: contains title, description and a graph.
|
||||
- **6**: contains title, description and list of items.
|
||||
- **7**: contains title and list of items with icons.
|
||||
- **8**: contains title, description and list of items with icons.
|
||||
- **9**: contains title, list of items and a graph.
|
||||
|
||||
# Steps
|
||||
1. Analyze provided Number of slides, Presentation title, Slides content and Slide types.
|
||||
2. Select appropriate slide type for each slide.
|
||||
3. Provide output in json format as per given schema.
|
||||
|
||||
# Notes
|
||||
- Slide type should be selected based on provided content for slide and notes.
|
||||
- Feel free to select slide type with images and icons.
|
||||
- Introduction and Conclusion should have type **1**.
|
||||
- Don't fall into patterns like always using type 2 and after type 1.
|
||||
- Each presentation should have its own unique flow and rhythm.
|
||||
- Do not select type **3** for any slide.
|
||||
- Do not select type **5** or **9** if outline does not have table.
|
||||
- Select type for {n_slides} slides.
|
||||
|
||||
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
{data}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def generate_presentation_structure(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> PresentationStructureModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
temperature=0.2,
|
||||
messages=get_prompt(
|
||||
len(presentation_outline.slides), presentation_outline.to_string()
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from api.utils.model_utils import get_large_model
|
||||
|
||||
|
||||
def get_prompt_template():
|
||||
return ChatPromptTemplate(
|
||||
messages=[
|
||||
(
|
||||
"system",
|
||||
"""
|
||||
Analyze the provided [Input] and [Errors] then provide structured output by fixing the errors.
|
||||
|
||||
# Steps
|
||||
1. Go through the provided [Input].
|
||||
2. Find mentioned [Errors] in the [Input].
|
||||
3. Check provided schema and follow every constraints.
|
||||
4. Provide structured output.
|
||||
|
||||
# Notes
|
||||
- Only output fields mentioned in the schema.
|
||||
- Check if fields' key may have been misnamed in the provided **Input**.
|
||||
- Change fields' name to match the schema.
|
||||
""",
|
||||
),
|
||||
(
|
||||
"user",
|
||||
"""
|
||||
- Input: {input}
|
||||
- Errors: {errors}
|
||||
""",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
async def fix_validation_errors(response_model: BaseModel, response, errors):
|
||||
model = ChatOllama(model=get_large_model(), temperature=0.8)
|
||||
|
||||
chain = get_prompt_template() | model.with_structured_output(
|
||||
response_model.model_json_schema()
|
||||
)
|
||||
return await chain.ainvoke({"input": response, "errors": errors})
|
||||
|
||||
|
||||
async def get_validated_response(
|
||||
chain,
|
||||
input_dict,
|
||||
response_model: BaseModel,
|
||||
validation_model: Optional[BaseModel] = None,
|
||||
retries: int = 1,
|
||||
):
|
||||
response = await chain.ainvoke(input_dict)
|
||||
validation_model = validation_model or response_model
|
||||
|
||||
attempt = 0
|
||||
while retries >= attempt:
|
||||
attempt += 1
|
||||
print("-" * 50)
|
||||
print(f"Validation Retry attempt - {attempt}")
|
||||
try:
|
||||
if response and type(response) is list:
|
||||
response = response[0]["args"]
|
||||
|
||||
validated_response = validation_model(**response)
|
||||
return validated_response
|
||||
except ValidationError as e:
|
||||
if retries < attempt:
|
||||
break
|
||||
|
||||
error_details = []
|
||||
for error in e.errors():
|
||||
error_details.append(
|
||||
{
|
||||
"loc": " -> ".join(str(loc) for loc in error["loc"]),
|
||||
"msg": error["msg"],
|
||||
"type": error["type"],
|
||||
}
|
||||
)
|
||||
|
||||
response = await fix_validation_errors(
|
||||
response_model, response, error_details
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Error while validating response")
|
||||
|
|
@ -1,156 +0,0 @@
|
|||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.utils.model_utils import (
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
get_selected_llm_provider,
|
||||
)
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_generator.models.llm_models_with_validations import (
|
||||
LLMPresentationModelWithValidation,
|
||||
)
|
||||
|
||||
|
||||
CREATE_PRESENTATION_PROMPT = """
|
||||
You're a professional presenter with years of experience in creating clear and engaging presentations.
|
||||
|
||||
Create a presentation using the provided title, slide titles and body following specified steps and guidelines.
|
||||
|
||||
Analyze all inputs, to construct each slide with appropriate content and format.
|
||||
|
||||
|
||||
# Slide Types
|
||||
- **1**: contains title, description and image.
|
||||
- **2**: contains title and list of items.
|
||||
- **4**: contains title and list of items with images.
|
||||
- **5**: contains title, description and a graph.
|
||||
- **6**: contains title, description and list of items.
|
||||
- **7**: contains title and list of items with icons.
|
||||
- **8**: contains title, description and list of items with icons.
|
||||
|
||||
# Steps
|
||||
1. Analyze provided presentation title, slide titles and body.
|
||||
2. Select slide type for each slide.
|
||||
3. Output should be in json format as per given schema.
|
||||
4. **Adherence to schema should be beyond all the rules mentioned in notes.**
|
||||
|
||||
# Notes
|
||||
- Generate output in language mentioned in *Input*.
|
||||
- Freely select type with images and icons.
|
||||
- Introduction and Conclusion should have *Type 1* if graph is not assigned.
|
||||
- Try to select **different types for every slides**.
|
||||
- Don't select Type **3** for any slide.
|
||||
- Do not include same graph twice in presentation without any changes to the other.
|
||||
- Every series in a graph should have data in same unit. Example: all series should be in percentage or all series should be in number of items.
|
||||
- Type **9** and **5** should be only picked if graph is available.
|
||||
- **Strictly keep the text under given limit.**
|
||||
- For slide content follow these rules:
|
||||
- Highlighting in markdown format should be used to emphasize numbers and data.
|
||||
- Adhere to length contraints in **body** and **description**. Focus on direct communication within character constrainsts than lengthy explanation.
|
||||
- **body** and **description** in slides should never exceed character limits of 200 characters.
|
||||
- Specify **don't include text in image** in image prompt.
|
||||
- All the numbers should be bolded with **bold** tag in body or description of slide.
|
||||
- Image prompt should cleary define how image should look like.
|
||||
- Image prompt should not ask to generate **numbers, graphs, dashboard and report**.
|
||||
- Examples of image prompts:
|
||||
- a travel agent presenting a detailed itinerary with photos of destinations, showcasing specific experiences, highlighting travel highlights
|
||||
- a person smiling while traveling, with a beautiful background scenery, such as mountains, beach, or city, golden hour lighting
|
||||
- a humanoid robot standing tall, gazing confidently at the horizon, bathed in warm sunlight, the background showing a futuristic cityscape with sleek buildings and flying vehicles
|
||||
- Descriptions should be clear and to the point.
|
||||
- Descriptions should not use words like "This slide", "This presentation".
|
||||
- If **body** contains items, *choose number of items randomly between mentioned constraints.*
|
||||
- **Icon queries** must be a generic **single word noun**.
|
||||
- Provide 3 icon query for each icon where,
|
||||
- First one should be specific like "Led bulb".
|
||||
- Second one should be more generic that first like "bulb".
|
||||
- Third one should be simplest like "light".
|
||||
|
||||
**Follow the all the length constraints provided in the schema and notes.**
|
||||
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
|
||||
"""
|
||||
|
||||
system_prompt_with_schema = f"""
|
||||
{CREATE_PRESENTATION_PROMPT}
|
||||
|
||||
Follow this schema while giving out response: {LLMPresentationModelWithValidation.model_json_schema()}.
|
||||
|
||||
Make description short and obey the character limits. Output should be in JSON format. Give out only JSON, nothing else.
|
||||
"""
|
||||
|
||||
|
||||
def get_system_prompt():
|
||||
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
|
||||
return (
|
||||
system_prompt_with_schema if is_google_selected else CREATE_PRESENTATION_PROMPT
|
||||
)
|
||||
|
||||
|
||||
def get_response_format():
|
||||
is_google_selected = get_selected_llm_provider() == SelectedLLMProvider.GOOGLE
|
||||
return (
|
||||
{
|
||||
"type": "json_object",
|
||||
}
|
||||
if is_google_selected
|
||||
else {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "LLMPresentationModel",
|
||||
"schema": LLMPresentationModelWithValidation.model_json_schema(),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def generate_presentation_stream(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
response_format = get_response_format()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": get_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": presentation_outline.to_string(),
|
||||
},
|
||||
],
|
||||
response_format=response_format,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def generate_presentation(
|
||||
presentation_outline: PresentationMarkdownModel,
|
||||
) -> str:
|
||||
client = get_llm_client()
|
||||
model = get_large_model()
|
||||
|
||||
response_format = get_response_format()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": get_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": presentation_outline.to_string(),
|
||||
},
|
||||
],
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
|
@ -1,232 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graph_processor.models import GraphModel, LLMGraphModel
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
TYPE3,
|
||||
TYPE4,
|
||||
TYPE5,
|
||||
TYPE6,
|
||||
TYPE7,
|
||||
TYPE8,
|
||||
TYPE9,
|
||||
)
|
||||
|
||||
|
||||
class TableType(Enum):
|
||||
TABLE = "table"
|
||||
BAR = "bar"
|
||||
LINE = "line"
|
||||
PIE = "pie"
|
||||
|
||||
|
||||
class TableDataModel(BaseModel):
|
||||
x_labels: List[str]
|
||||
y_labels: List[str]
|
||||
data: List[List[float]]
|
||||
|
||||
|
||||
class TableModel(BaseModel):
|
||||
name: str
|
||||
type: TableType
|
||||
data: TableDataModel
|
||||
|
||||
|
||||
class HeadingModel(BaseModel):
|
||||
heading: str
|
||||
description: str
|
||||
|
||||
def to_llm_content(self, image_prompt: str = None, icon_query: str = None):
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLMHeadingModel,
|
||||
LLMHeadingModelWithImagePrompt,
|
||||
LLMHeadingModelWithIconQuery,
|
||||
)
|
||||
|
||||
if image_prompt:
|
||||
return LLMHeadingModelWithImagePrompt(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
image_prompt=image_prompt,
|
||||
)
|
||||
elif icon_query:
|
||||
return LLMHeadingModelWithIconQuery(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
icon_query=icon_query,
|
||||
)
|
||||
return LLMHeadingModel(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class SlideContentModel(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class Type1Content(SlideContentModel):
|
||||
body: str
|
||||
image_prompts: List[str]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType1Content
|
||||
|
||||
return LLMType1Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
image_prompt=self.image_prompts[0] if self.image_prompts else "",
|
||||
)
|
||||
|
||||
|
||||
class Type2Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType2Content
|
||||
|
||||
return LLMType2Content(
|
||||
title=self.title,
|
||||
body=[item.to_llm_content() for item in self.body],
|
||||
)
|
||||
|
||||
|
||||
class Type3Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
image_prompts: List[str]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType3Content
|
||||
|
||||
return LLMType3Content(
|
||||
title=self.title,
|
||||
body=[item.to_llm_content() for item in self.body],
|
||||
image_prompt=self.image_prompts[0] if self.image_prompts else "",
|
||||
)
|
||||
|
||||
|
||||
class Type4Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
image_prompts: List[str]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType4Content
|
||||
|
||||
llm_body = []
|
||||
for i, item in enumerate(self.body):
|
||||
image_prompt = self.image_prompts[i] if i < len(self.image_prompts) else ""
|
||||
llm_body.append(item.to_llm_content(image_prompt=image_prompt))
|
||||
return LLMType4Content(
|
||||
title=self.title,
|
||||
body=llm_body,
|
||||
)
|
||||
|
||||
|
||||
class Type5Content(SlideContentModel):
|
||||
body: str
|
||||
# table: TableModel
|
||||
graph: GraphModel
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType5Content
|
||||
|
||||
return LLMType5Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
# table=self.table,
|
||||
graph=self.graph,
|
||||
)
|
||||
|
||||
|
||||
class Type6Content(SlideContentModel):
|
||||
description: str
|
||||
body: List[HeadingModel]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType6Content
|
||||
|
||||
return LLMType6Content(
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
body=[item.to_llm_content() for item in self.body],
|
||||
)
|
||||
|
||||
|
||||
class Type7Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
icon_queries: List[str]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType7Content
|
||||
|
||||
llm_body = []
|
||||
for i, item in enumerate(self.body):
|
||||
icon_query = self.icon_queries[i] if i < len(self.icon_queries) else ""
|
||||
llm_body.append(item.to_llm_content(icon_query=icon_query))
|
||||
return LLMType7Content(
|
||||
title=self.title,
|
||||
body=llm_body,
|
||||
)
|
||||
|
||||
|
||||
class Type8Content(SlideContentModel):
|
||||
description: str
|
||||
body: List[HeadingModel]
|
||||
icon_queries: List[str]
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType8Content
|
||||
|
||||
llm_body = []
|
||||
for i, item in enumerate(self.body):
|
||||
icon_query = self.icon_queries[i] if i < len(self.icon_queries) else ""
|
||||
llm_body.append(item.to_llm_content(icon_query=icon_query))
|
||||
return LLMType8Content(
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
body=llm_body,
|
||||
)
|
||||
|
||||
|
||||
class Type9Content(SlideContentModel):
|
||||
body: List[HeadingModel]
|
||||
# table: TableModel
|
||||
graph: GraphModel
|
||||
|
||||
def to_llm_content(self):
|
||||
from ppt_generator.models.llm_models import LLMType9Content
|
||||
|
||||
return LLMType9Content(
|
||||
title=self.title,
|
||||
body=[item.to_llm_content() for item in self.body],
|
||||
# table=self.table,
|
||||
graph=self.graph,
|
||||
)
|
||||
|
||||
|
||||
ContentUnion = Union[
|
||||
Type1Content,
|
||||
Type2Content,
|
||||
Type3Content,
|
||||
Type4Content,
|
||||
Type5Content,
|
||||
Type6Content,
|
||||
Type7Content,
|
||||
Type8Content,
|
||||
Type9Content,
|
||||
]
|
||||
|
||||
CONTENT_TYPE_MAPPING: Mapping[int, ContentUnion] = {
|
||||
TYPE1: Type1Content,
|
||||
TYPE2: Type2Content,
|
||||
TYPE3: Type3Content,
|
||||
TYPE4: Type4Content,
|
||||
TYPE5: Type5Content,
|
||||
TYPE6: Type6Content,
|
||||
TYPE7: Type7Content,
|
||||
TYPE8: Type8Content,
|
||||
TYPE9: Type9Content,
|
||||
}
|
||||
|
|
@ -1,220 +0,0 @@
|
|||
from typing import List, Mapping, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graph_processor.models import GraphModel, LLMGraphModel
|
||||
from ppt_generator.models.content_type_models import (
|
||||
HeadingModel,
|
||||
TableDataModel,
|
||||
TableModel,
|
||||
TableType,
|
||||
Type1Content,
|
||||
Type2Content,
|
||||
Type3Content,
|
||||
Type4Content,
|
||||
Type5Content,
|
||||
Type6Content,
|
||||
Type7Content,
|
||||
Type8Content,
|
||||
Type9Content,
|
||||
)
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
TYPE3,
|
||||
TYPE4,
|
||||
TYPE5,
|
||||
TYPE6,
|
||||
TYPE7,
|
||||
TYPE8,
|
||||
TYPE9,
|
||||
)
|
||||
|
||||
|
||||
class LLMTableDataModel(TableDataModel):
|
||||
x_labels: List[str]
|
||||
y_labels: List[str]
|
||||
data: List[List[float]]
|
||||
|
||||
|
||||
class LLMTableModel(TableModel):
|
||||
name: str
|
||||
type: TableType
|
||||
data: LLMTableDataModel
|
||||
|
||||
|
||||
class LLMHeadingModel(BaseModel):
|
||||
heading: str
|
||||
description: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithImagePrompt(LLMHeadingModel):
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithIconQuery(LLMHeadingModel):
|
||||
icon_query: str
|
||||
|
||||
def to_content(self) -> HeadingModel:
|
||||
return HeadingModel(
|
||||
heading=self.heading,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
class LLMSlideContentModel(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class LLMType1Content(LLMSlideContentModel):
|
||||
body: str
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> Type1Content:
|
||||
return Type1Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
image_prompts=[self.image_prompt],
|
||||
)
|
||||
|
||||
|
||||
class LLMType2Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModel]
|
||||
|
||||
def to_content(self) -> Type2Content:
|
||||
return Type2Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
)
|
||||
|
||||
|
||||
class LLMType3Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModel]
|
||||
image_prompt: str
|
||||
|
||||
def to_content(self) -> Type3Content:
|
||||
return Type3Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
image_prompts=[self.image_prompt],
|
||||
)
|
||||
|
||||
|
||||
class LLMType4Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModelWithImagePrompt]
|
||||
|
||||
def to_content(self) -> Type4Content:
|
||||
return Type4Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
image_prompts=[each.image_prompt for each in self.body],
|
||||
)
|
||||
|
||||
|
||||
class LLMType5Content(LLMSlideContentModel):
|
||||
body: str
|
||||
# table: LLMTableModel
|
||||
graph: LLMGraphModel
|
||||
|
||||
def to_content(self) -> Type5Content:
|
||||
return Type5Content(
|
||||
title=self.title,
|
||||
body=self.body,
|
||||
# table=self.table,
|
||||
graph=GraphModel.from_llm_graph_model(self.graph),
|
||||
)
|
||||
|
||||
|
||||
class LLMType6Content(LLMSlideContentModel):
|
||||
description: str
|
||||
body: List[LLMHeadingModel]
|
||||
|
||||
def to_content(self) -> Type6Content:
|
||||
return Type6Content(
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
body=[each.to_content() for each in self.body],
|
||||
)
|
||||
|
||||
|
||||
class LLMType7Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModelWithIconQuery]
|
||||
|
||||
def to_content(self) -> Type7Content:
|
||||
return Type7Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
icon_queries=[each.icon_query for each in self.body],
|
||||
)
|
||||
|
||||
|
||||
class LLMType8Content(LLMSlideContentModel):
|
||||
description: str
|
||||
body: List[LLMHeadingModelWithImagePrompt]
|
||||
|
||||
def to_content(self) -> Type8Content:
|
||||
return Type8Content(
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
body=[each.to_content() for each in self.body],
|
||||
icon_queries=[each.image_prompt for each in self.body],
|
||||
)
|
||||
|
||||
|
||||
class LLMType9Content(LLMSlideContentModel):
|
||||
body: List[LLMHeadingModel]
|
||||
# table: LLMTableModel
|
||||
graph: LLMGraphModel
|
||||
|
||||
def to_content(self) -> Type9Content:
|
||||
return Type9Content(
|
||||
title=self.title,
|
||||
body=[each.to_content() for each in self.body],
|
||||
# table=self.table,
|
||||
graph=GraphModel.from_llm_graph_model(self.graph),
|
||||
)
|
||||
|
||||
|
||||
LLMContentUnion = Union[
|
||||
LLMType1Content,
|
||||
LLMType2Content,
|
||||
LLMType3Content,
|
||||
LLMType4Content,
|
||||
LLMType5Content,
|
||||
LLMType6Content,
|
||||
LLMType7Content,
|
||||
LLMType8Content,
|
||||
LLMType9Content,
|
||||
]
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING: Mapping[int, LLMContentUnion] = {
|
||||
TYPE1: LLMType1Content,
|
||||
TYPE2: LLMType2Content,
|
||||
TYPE3: LLMType3Content,
|
||||
TYPE4: LLMType4Content,
|
||||
TYPE5: LLMType5Content,
|
||||
TYPE6: LLMType6Content,
|
||||
TYPE7: LLMType7Content,
|
||||
TYPE8: LLMType8Content,
|
||||
TYPE9: LLMType9Content,
|
||||
}
|
||||
|
||||
|
||||
class LLMSlideModel(BaseModel):
|
||||
type: int
|
||||
content: LLMContentUnion
|
||||
|
||||
|
||||
class LLMPresentationModel(BaseModel):
|
||||
slides: List[LLMSlideModel]
|
||||
|
|
@ -1,232 +0,0 @@
|
|||
from typing import List, Mapping, Union
|
||||
from pydantic import Field
|
||||
|
||||
from graph_processor.models import LLMGraphModel
|
||||
from ppt_generator.models.content_type_models import TableType
|
||||
from ppt_generator.models.other_models import (
|
||||
TYPE1,
|
||||
TYPE2,
|
||||
TYPE3,
|
||||
TYPE4,
|
||||
TYPE5,
|
||||
TYPE6,
|
||||
TYPE7,
|
||||
TYPE8,
|
||||
TYPE9,
|
||||
)
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLMTableDataModel,
|
||||
LLMTableModel,
|
||||
LLMHeadingModel,
|
||||
LLMHeadingModelWithImagePrompt,
|
||||
LLMHeadingModelWithIconQuery,
|
||||
LLMSlideContentModel,
|
||||
LLMType1Content,
|
||||
LLMType2Content,
|
||||
LLMType3Content,
|
||||
LLMType4Content,
|
||||
LLMType5Content,
|
||||
LLMType6Content,
|
||||
LLMType7Content,
|
||||
LLMType8Content,
|
||||
LLMType9Content,
|
||||
LLMSlideModel,
|
||||
LLMPresentationModel,
|
||||
)
|
||||
|
||||
|
||||
class LLMTableDataModelWithValidation(LLMTableDataModel):
|
||||
x_labels: List[str] = Field(
|
||||
description="X labels of the table",
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
)
|
||||
y_labels: List[str] = Field(
|
||||
description="Y labels of the table",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
data: List[List[float]] = Field(
|
||||
description="Data of the table",
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
)
|
||||
|
||||
|
||||
class LLMTableModelWithValidation(LLMTableModel):
|
||||
name: str = Field(
|
||||
description="Name of the table in about 8 words",
|
||||
min_length=10,
|
||||
max_length=50,
|
||||
)
|
||||
type: TableType = Field(description="Type of the table")
|
||||
data: LLMTableDataModelWithValidation
|
||||
|
||||
|
||||
class LLMHeadingModelWithValidation(LLMHeadingModel):
|
||||
heading: str = Field(
|
||||
description="Item heading in about 6 words",
|
||||
min_length=10,
|
||||
max_length=40,
|
||||
)
|
||||
description: str = Field(
|
||||
description="Item description in about 12 words.",
|
||||
min_length=50,
|
||||
max_length=120,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithImagePromptWithValidation(LLMHeadingModelWithImagePrompt):
|
||||
image_prompt: str = Field(
|
||||
description="Item image prompt in about 10 words",
|
||||
min_length=10,
|
||||
max_length=100,
|
||||
)
|
||||
|
||||
|
||||
class LLMHeadingModelWithIconQueryWithValidation(LLMHeadingModelWithIconQuery):
|
||||
icon_query: str = Field(
|
||||
description="Item icon query in about 4 words",
|
||||
min_length=10,
|
||||
max_length=40,
|
||||
)
|
||||
|
||||
|
||||
class LLMSlideContentModelWithValidation(LLMSlideContentModel):
|
||||
title: str = Field(
|
||||
description="Slide title in about 8 words",
|
||||
min_length=10,
|
||||
max_length=80,
|
||||
)
|
||||
|
||||
|
||||
class LLMType1ContentWithValidation(LLMType1Content):
|
||||
body: str = Field(
|
||||
description="Slide content summary in about 30 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Slide image prompt in about 5 words",
|
||||
min_length=10,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
|
||||
class LLMType2ContentWithValidation(LLMType2Content):
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
|
||||
|
||||
class LLMType3ContentWithValidation(LLMType3Content):
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
)
|
||||
image_prompt: str = Field(
|
||||
description="Slide image prompt in about 5 words",
|
||||
min_length=10,
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
|
||||
class LLMType4ContentWithValidation(LLMType4Content):
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
|
||||
class LLMType5ContentWithValidation(LLMType5Content):
|
||||
body: str = Field(
|
||||
description="Slide content summary in about 30 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
graph: LLMGraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
|
||||
class LLMType6ContentWithValidation(LLMType6Content):
|
||||
description: str = Field(
|
||||
description="Slide content summary in about 20 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
|
||||
class LLMType7ContentWithValidation(LLMType7Content):
|
||||
body: List[LLMHeadingModelWithIconQueryWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=4,
|
||||
)
|
||||
|
||||
|
||||
class LLMType8ContentWithValidation(LLMType8Content):
|
||||
description: str = Field(
|
||||
description="Slide content summary in about 20 words.",
|
||||
min_length=50,
|
||||
max_length=300,
|
||||
)
|
||||
body: List[LLMHeadingModelWithImagePromptWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
|
||||
|
||||
class LLMType9ContentWithValidation(LLMType9Content):
|
||||
body: List[LLMHeadingModelWithValidation] = Field(
|
||||
description="Items to show in slide",
|
||||
min_length=1,
|
||||
max_length=3,
|
||||
)
|
||||
# table: LLMTableModelWithValidation = Field(description="Table to show in slide")
|
||||
graph: LLMGraphModel = Field(description="Graph to show in slide")
|
||||
|
||||
|
||||
LLMContentUnionWithValidation = Union[
|
||||
LLMType1ContentWithValidation,
|
||||
LLMType2ContentWithValidation,
|
||||
LLMType3ContentWithValidation,
|
||||
LLMType4ContentWithValidation,
|
||||
LLMType5ContentWithValidation,
|
||||
LLMType6ContentWithValidation,
|
||||
LLMType7ContentWithValidation,
|
||||
LLMType8ContentWithValidation,
|
||||
LLMType9ContentWithValidation,
|
||||
]
|
||||
|
||||
LLM_CONTENT_TYPE_MAPPING_WITH_VALIDATION: Mapping[
|
||||
int, LLMContentUnionWithValidation
|
||||
] = {
|
||||
TYPE1: LLMType1ContentWithValidation,
|
||||
TYPE2: LLMType2ContentWithValidation,
|
||||
TYPE3: LLMType3ContentWithValidation,
|
||||
TYPE4: LLMType4ContentWithValidation,
|
||||
TYPE5: LLMType5ContentWithValidation,
|
||||
TYPE6: LLMType6ContentWithValidation,
|
||||
TYPE7: LLMType7ContentWithValidation,
|
||||
TYPE8: LLMType8ContentWithValidation,
|
||||
TYPE9: LLMType9ContentWithValidation,
|
||||
}
|
||||
|
||||
|
||||
class LLMSlideModelWithValidation(LLMSlideModel):
|
||||
type: int
|
||||
content: LLMContentUnionWithValidation
|
||||
|
||||
|
||||
class LLMPresentationModelWithValidation(LLMPresentationModel):
|
||||
slides: List[LLMSlideModelWithValidation]
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
# 1. contains title, description and an image.
|
||||
TYPE1 = 1
|
||||
# 2. contains title and list of items.
|
||||
TYPE2 = 2
|
||||
# 3. contains title, list of items and an image.
|
||||
TYPE3 = 3
|
||||
# 4. contains title and list of items and multiple images.
|
||||
TYPE4 = 4
|
||||
# 5. contains title, description and a graph.
|
||||
TYPE5 = 5
|
||||
# 6. contains title, description and list of items.
|
||||
TYPE6 = 6
|
||||
# 7. contains title, list of items and icons.
|
||||
TYPE7 = 7
|
||||
# 8. contains title, description, list of items and icons.
|
||||
TYPE8 = 8
|
||||
# 9. contains title, list of items and a graph.
|
||||
TYPE9 = 9
|
||||
|
||||
|
||||
class SlideTypeModel(BaseModel):
|
||||
slide_type: int = Field(gte=1, lte=9, description="Slide type from 1 to 9")
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ImageAspectRatio(Enum):
|
||||
r_1_1 = "1:1"
|
||||
r_3_2 = "3:2"
|
||||
r_2_3 = "2:3"
|
||||
r_5_4 = "5:4"
|
||||
r_4_5 = "4:5"
|
||||
r_16_9 = "16:9"
|
||||
r_9_16 = "9:16"
|
||||
r_21_9 = "21:9"
|
||||
r_9_21 = "9:21"
|
||||
|
||||
|
||||
class ImagePromptWithThemeAndAspectRatio(BaseModel):
|
||||
theme_prompt: str
|
||||
image_prompt: str
|
||||
aspect_ratio: ImageAspectRatio
|
||||
|
||||
|
||||
class IconFrameEnum(Enum):
|
||||
filled_rounded_rectangle = 1
|
||||
filled_circle = 2
|
||||
|
||||
|
||||
class IconCategoryEnum(Enum):
|
||||
solid = "solid"
|
||||
semi_solid = "semi-solid"
|
||||
outline = "outline"
|
||||
|
||||
|
||||
class IconQueryCollectionWithData(BaseModel):
|
||||
category: IconCategoryEnum = IconCategoryEnum.solid
|
||||
index: int
|
||||
theme: Optional[dict] = None
|
||||
icon_query: str
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue