feat(fastapi): adds presentation generation using provided slide schemas, refactor(fastapi): converts class based endpoint handlers to organized functional handlers
|
|
@ -1,23 +1,26 @@
|
|||
from http.client import HTTPException
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from typing import Annotated, List, Optional
|
||||
import uuid
|
||||
from fastapi import UploadFile
|
||||
from fastapi import APIRouter, Body, File, UploadFile
|
||||
|
||||
from api.v1.ppt.router import API_V1_PPT_ROUTER
|
||||
from constants.documents import UPLOAD_ACCEPTED_DOCUMENTS, UPLOAD_ACCEPTED_IMAGES
|
||||
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
|
||||
from models.decomposed_file_info import DecomposedFileInfo
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from utils.validators import validate_files
|
||||
|
||||
FILES_ROUTER = APIRouter(prefix="/files")
|
||||
|
||||
@API_V1_PPT_ROUTER.post("/files/upload")
|
||||
|
||||
@FILES_ROUTER.post("/upload", response_model=List[str])
|
||||
async def upload_files(files: Optional[List[UploadFile]]):
|
||||
if not files:
|
||||
raise HTTPException(400, "Files are required")
|
||||
raise HTTPException(400, "Documents are required")
|
||||
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
|
||||
|
||||
validate_files(files, True, True, 50, UPLOAD_ACCEPTED_DOCUMENTS)
|
||||
validate_files(files, True, True, 10, UPLOAD_ACCEPTED_IMAGES)
|
||||
validate_files(files, True, True, 50, UPLOAD_ACCEPTED_FILE_TYPES)
|
||||
|
||||
temp_files: List[str] = []
|
||||
if files:
|
||||
|
|
@ -32,3 +35,53 @@ async def upload_files(files: Optional[List[UploadFile]]):
|
|||
temp_files.append(temp_path)
|
||||
|
||||
return temp_files
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/decompose")
|
||||
async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]):
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
|
||||
|
||||
txt_files = []
|
||||
other_files = []
|
||||
for file_path in file_paths:
|
||||
if file_path.endswith(".txt"):
|
||||
txt_files.append(file_path)
|
||||
else:
|
||||
other_files.append(file_path)
|
||||
|
||||
documents_loader = DocumentsLoader(file_paths=other_files)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
parsed_documents = documents_loader.documents
|
||||
|
||||
response = []
|
||||
for index, parsed_doc in enumerate(parsed_documents):
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{str(uuid.uuid4())}.txt", temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
response.append(
|
||||
DecomposedFileInfo(
|
||||
name=os.path.basename(other_files[index]), file_path=file_path
|
||||
)
|
||||
)
|
||||
|
||||
# Return the txt documents as it is
|
||||
for each_file in txt_files:
|
||||
response.append(
|
||||
DecomposedFileInfo(name=os.path.basename(each_file), file_path=each_file)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@FILES_ROUTER.post("/update")
|
||||
async def update_files(
|
||||
file_path: Annotated[str, Body()],
|
||||
file: Annotated[UploadFile, File()],
|
||||
):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(await file.read())
|
||||
|
||||
return {"message": "File updated successfully"}
|
||||
|
|
|
|||
62
servers/fastapi/api/v1/ppt/endpoints/outlines.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import json
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.sql.presentation import PresentationModel
|
||||
from models.sse_response import SSECompleteResponse, SSEResponse, SSEStatusResponse
|
||||
from services.database import get_sql_session
|
||||
from utils.llm_calls.generate_presentation_outlines import generate_ppt_outline
|
||||
|
||||
OUTLINES_ROUTER = APIRouter(prefix="/outlines")
|
||||
|
||||
|
||||
@OUTLINES_ROUTER.get("/stream")
|
||||
async def stream_outlines(presentation_id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, presentation_id)
|
||||
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
|
||||
async def inner():
|
||||
yield SSEStatusResponse(
|
||||
status="Generating presentation outlines..."
|
||||
).to_string()
|
||||
|
||||
presentation_content_text = ""
|
||||
async for chunk in generate_ppt_outline(
|
||||
presentation.prompt,
|
||||
presentation.n_slides,
|
||||
presentation.language,
|
||||
presentation.summary,
|
||||
):
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
presentation_content_text += chunk
|
||||
|
||||
presentation_content = PresentationOutlineModel.model_validate_json(
|
||||
presentation_content_text
|
||||
)
|
||||
presentation_content.slides = presentation_content.slides[
|
||||
: presentation.n_slides
|
||||
]
|
||||
|
||||
presentation.title = presentation_content.title
|
||||
presentation.outlines = [
|
||||
each.model_dump() for each in presentation_content.slides
|
||||
]
|
||||
presentation.notes = presentation_content.notes
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
yield SSECompleteResponse(
|
||||
key="presentation", value=presentation.model_dump_json()
|
||||
).to_string()
|
||||
|
||||
return StreamingResponse(inner())
|
||||
174
servers/fastapi/api/v1/ppt/endpoints/presentation.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Annotated, List, Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, Body, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.presentation_with_slides import PresentationWithSlides
|
||||
from models.sql.slide import SlideModel
|
||||
from models.sse_response import SSECompleteResponse, SSEResponse
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from services.database import get_sql_session
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from models.sql.presentation import PresentationModel
|
||||
from utils.llm_calls.generate_document_summary import generate_document_summary
|
||||
from utils.llm_calls.generate_presentation_structure import (
|
||||
generate_presentation_structure,
|
||||
)
|
||||
from utils.llm_calls.generate_slide import get_slide_content_from_type_and_outline
|
||||
|
||||
PRESENTATION_ROUTER = APIRouter(prefix="/presentation")
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/create", response_model=PresentationModel)
|
||||
async def create_presentation(
|
||||
prompt: Annotated[str, Body()],
|
||||
n_slides: Annotated[int, Body()],
|
||||
language: Annotated[str, Body()],
|
||||
layout: Annotated[PresentationLayoutModel, Body()],
|
||||
file_paths: Annotated[Optional[List[str]], Body()] = None,
|
||||
):
|
||||
presentation_id = str(uuid.uuid4())
|
||||
|
||||
summary = None
|
||||
if file_paths:
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(presentation_id)
|
||||
documents_loader = DocumentsLoader(file_paths=file_paths)
|
||||
await documents_loader.load_documents(temp_dir)
|
||||
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
presentation = PresentationModel(
|
||||
id=presentation_id,
|
||||
prompt=prompt,
|
||||
n_slides=n_slides,
|
||||
language=language,
|
||||
layout=layout.model_dump(),
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
return presentation
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.post("/prepare", response_model=PresentationModel)
|
||||
async def prepare_presentation(
|
||||
presentation_id: Annotated[str, Body()],
|
||||
outlines: Annotated[List[SlideOutlineModel], Body()],
|
||||
title: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
if not outlines:
|
||||
raise HTTPException(status_code=400, detail="Outlines are required")
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, presentation_id)
|
||||
presentation.outlines = [each.model_dump() for each in outlines]
|
||||
presentation.title = title or presentation.title
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
layout = presentation.get_layout()
|
||||
total_slide_layouts = len(layout.slides)
|
||||
total_outlines = len(outlines)
|
||||
|
||||
if layout.ordered:
|
||||
presentation_structure = layout.to_presentation_structure()
|
||||
else:
|
||||
presentation_structure: PresentationStructureModel = (
|
||||
await generate_presentation_structure(
|
||||
presentation_outline=presentation.get_presentation_outline(),
|
||||
presentation_layout=layout,
|
||||
)
|
||||
)
|
||||
presentation_structure.slides = presentation_structure.slides[: len(outlines)]
|
||||
for index in range(total_outlines):
|
||||
random_slide_index = random.randint(0, total_slide_layouts - 1)
|
||||
if index >= total_outlines:
|
||||
presentation_structure.slides.append(random_slide_index)
|
||||
continue
|
||||
if presentation_structure.slides[index] >= total_slide_layouts:
|
||||
presentation_structure.slides[index] = random_slide_index
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
presentation.set_structure(presentation_structure)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
return presentation
|
||||
|
||||
|
||||
@PRESENTATION_ROUTER.get("/stream", response_model=PresentationWithSlides)
|
||||
async def stream_presentation(presentation_id: str):
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationModel, presentation_id)
|
||||
if not presentation:
|
||||
raise HTTPException(status_code=404, detail="Presentation not found")
|
||||
if not presentation.structure:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Presentation not prepared for stream",
|
||||
)
|
||||
if not presentation.outlines:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Outlines can not be empty",
|
||||
)
|
||||
|
||||
async def inner():
|
||||
structure = presentation.get_structure()
|
||||
layout = presentation.get_layout()
|
||||
outline = presentation.get_presentation_outline()
|
||||
|
||||
slides: List[SlideModel] = []
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": '{ "slides": [ '}),
|
||||
).to_string()
|
||||
for i, slide_layout_index in enumerate(structure.slides):
|
||||
slide_layout = layout.slides[slide_layout_index]
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
slide_layout, outline.slides[i]
|
||||
)
|
||||
slide = SlideModel(
|
||||
presentation=presentation_id,
|
||||
layout=slide_layout.id,
|
||||
content=slide_content,
|
||||
)
|
||||
slides.append(slide)
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": slide.model_dump_json()}),
|
||||
).to_string()
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": " ] }"}),
|
||||
).to_string()
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slides)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
for each_slide in slides:
|
||||
sql_session.refresh(each_slide)
|
||||
|
||||
response = PresentationWithSlides(
|
||||
**presentation.model_dump(),
|
||||
slides=slides,
|
||||
)
|
||||
|
||||
yield SSECompleteResponse(
|
||||
key="presentation",
|
||||
value=response.model_dump(mode="json"),
|
||||
).to_string()
|
||||
|
||||
return StreamingResponse(inner())
|
||||
|
|
@ -1,5 +1,12 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
|
||||
API_V1_PPT = "/api/v1/ppt"
|
||||
API_V1_PPT_ROUTER = APIRouter(prefix=API_V1_PPT)
|
||||
|
||||
API_V1_PPT_ROUTER = APIRouter(prefix="/api/v1/ppt")
|
||||
|
||||
API_V1_PPT_ROUTER.include_router(FILES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ JPEG_MIME_TYPES = ["image/jpeg"]
|
|||
WEBP_MIME_TYPES = ["image/webp"]
|
||||
|
||||
|
||||
UPLOAD_ACCEPTED_DOCUMENTS = (
|
||||
UPLOAD_ACCEPTED_FILE_TYPES = (
|
||||
PDF_MIME_TYPES + TEXT_MIME_TYPES + POWERPOINT_TYPES + WORD_TYPES
|
||||
)
|
||||
UPLOAD_ACCEPTED_IMAGES = PNG_MIME_TYPES + JPEG_MIME_TYPES + WEBP_MIME_TYPES
|
||||
|
|
|
|||
33
servers/fastapi/get_test_schema.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel, SlideLayoutModel
|
||||
|
||||
|
||||
class TitleDescriptionSlide(BaseModel):
|
||||
title: str = Field(min_length=10, max_length=100)
|
||||
description: str = Field(min_length=50, max_length=200)
|
||||
|
||||
|
||||
class ContentSlide(BaseModel):
|
||||
title: str = Field(min_length=10, max_length=100)
|
||||
content: List[str] = Field(min_length=1, max_length=5)
|
||||
|
||||
|
||||
presentation_layout = PresentationLayoutModel(
|
||||
name="Basic Presentation",
|
||||
slides=[
|
||||
SlideLayoutModel(
|
||||
id="title_description",
|
||||
name="Title Description",
|
||||
json_schema=TitleDescriptionSlide.model_json_schema(),
|
||||
),
|
||||
SlideLayoutModel(
|
||||
id="content",
|
||||
name="Content",
|
||||
json_schema=ContentSlide.model_json_schema(),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
print(presentation_layout.model_dump_json())
|
||||
6
servers/fastapi/models/decomposed_file_info.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DecomposedFileInfo(BaseModel):
|
||||
name: str
|
||||
file_path: str
|
||||
|
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OllamaModelStatusResponse(BaseModel):
|
||||
class OllamaModelStatus(BaseModel):
|
||||
name: str
|
||||
size: Optional[int] = None
|
||||
downloaded: Optional[int] = None
|
||||
30
servers/fastapi/models/presentation_layout.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
class SlideLayoutModel(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
json_schema: dict
|
||||
|
||||
|
||||
class PresentationLayoutModel(BaseModel):
|
||||
name: Optional[str] = None
|
||||
ordered: bool = Field(default=False)
|
||||
slides: List[SlideLayoutModel]
|
||||
|
||||
def to_presentation_structure(self):
|
||||
return PresentationStructureModel(
|
||||
slides=[index for index in range(len(self.slides))]
|
||||
)
|
||||
|
||||
def to_string(self):
|
||||
message = f"## Presentation Layout\n\n"
|
||||
for index, slide in enumerate(self.slides):
|
||||
message += f"### Slide Layout: {index}: \n"
|
||||
message += f"- Name: {slide.name or slide.json_schema.get('title')} \n"
|
||||
message += f"- Description: {slide.description} \n\n"
|
||||
return message
|
||||
|
|
@ -2,15 +2,7 @@ from typing import List, Optional
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SlideStructureModel(BaseModel):
|
||||
type: int = Field(description="Type of the slide", gte=1, lte=9)
|
||||
|
||||
|
||||
class PresentationStructureModel(BaseModel):
|
||||
slides: List[SlideStructureModel] = Field(description="List of slide structure")
|
||||
|
||||
|
||||
class SlideMarkdownModel(BaseModel):
|
||||
class SlideOutlineModel(BaseModel):
|
||||
title: str = Field(
|
||||
description="Title of the slide in about 3 to 5 words",
|
||||
)
|
||||
|
|
@ -19,12 +11,12 @@ class SlideMarkdownModel(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class PresentationMarkdownModel(BaseModel):
|
||||
class PresentationOutlineModel(BaseModel):
|
||||
title: str = Field(
|
||||
description="Title of the presentation in about 3 to 8 words",
|
||||
)
|
||||
notes: Optional[List[str]] = Field(description="Notes for the presentation")
|
||||
slides: List[SlideMarkdownModel] = Field(description="List of slides")
|
||||
slides: List[SlideOutlineModel] = Field(description="List of slides")
|
||||
|
||||
def to_string(self):
|
||||
message = f"# Presentation Title: {self.title} \n\n"
|
||||
6
servers/fastapi/models/presentation_structure_model.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PresentationStructureModel(BaseModel):
|
||||
slides: List[int] = Field(description="List of slide layout indexes")
|
||||
25
servers/fastapi/models/presentation_with_slides.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from models.sql.slide import SlideModel
|
||||
|
||||
|
||||
class PresentationWithSlides(BaseModel):
|
||||
id: str
|
||||
prompt: str
|
||||
n_slides: int
|
||||
language: str
|
||||
title: Optional[str] = None
|
||||
notes: Optional[List[str]]
|
||||
outlines: Optional[List[SlideOutlineModel]]
|
||||
summary: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
layout: PresentationLayoutModel
|
||||
structure: Optional[PresentationStructureModel]
|
||||
slides: List[SlideModel]
|
||||
9
servers/fastapi/models/sql/key_value.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class KeyValueSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
key: str = Field(index=True)
|
||||
value: dict = Field(sa_column=Column(JSON))
|
||||
49
servers/fastapi/models/sql/presentation.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
class PresentationModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True)
|
||||
prompt: str
|
||||
n_slides: int
|
||||
language: str
|
||||
title: Optional[str] = None
|
||||
notes: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
|
||||
outlines: Optional[List[dict]] = Field(sa_column=Column(JSON), default=None)
|
||||
summary: Optional[str] = None
|
||||
created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
updated_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
layout: dict = Field(sa_column=Column(JSON))
|
||||
structure: Optional[dict] = Field(sa_column=Column(JSON), default=None)
|
||||
|
||||
def get_presentation_outline(self):
|
||||
if not self.outlines:
|
||||
return None
|
||||
return PresentationOutlineModel(
|
||||
title=self.title,
|
||||
slides=[SlideOutlineModel(**each) for each in self.outlines],
|
||||
notes=self.notes,
|
||||
)
|
||||
|
||||
def get_layout(self):
|
||||
return PresentationLayoutModel(**self.layout)
|
||||
|
||||
def set_layout(self, layout: PresentationLayoutModel):
|
||||
self.layout = layout.model_dump()
|
||||
|
||||
def get_structure(self):
|
||||
if not self.structure:
|
||||
return None
|
||||
return PresentationStructureModel(**self.structure)
|
||||
|
||||
def set_structure(self, structure: PresentationStructureModel):
|
||||
self.structure = structure.model_dump()
|
||||
10
servers/fastapi/models/sql/slide.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class SlideModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True, default_factory=get_random_uuid)
|
||||
presentation: str
|
||||
layout: str
|
||||
content: dict = Field(sa_column=Column(JSON))
|
||||
31
servers/fastapi/models/sse_response.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import json
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SSEResponse(BaseModel):
|
||||
event: str
|
||||
data: str
|
||||
|
||||
def to_string(self):
|
||||
return f"event: {self.event}\ndata: {self.data}\n\n"
|
||||
|
||||
|
||||
class SSEStatusResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "status", "status": self.status})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSECompleteResponse(BaseModel):
|
||||
key: str
|
||||
value: object
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "complete", self.key: self.value}),
|
||||
).to_string()
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from http.client import HTTPException
|
||||
import mimetypes
|
||||
from fastapi import HTTPException
|
||||
import os, pdfplumber, asyncio
|
||||
from typing import List, Tuple
|
||||
from docx import Document
|
||||
|
|
@ -15,8 +15,8 @@ from constants.documents import (
|
|||
|
||||
class DocumentsLoader:
|
||||
|
||||
def __init__(self, documents: List[str]):
|
||||
self._document_paths = documents
|
||||
def __init__(self, file_paths: List[str]):
|
||||
self._file_paths = file_paths
|
||||
|
||||
self._documents: List[str] = []
|
||||
self._images: List[List[str]] = []
|
||||
|
|
@ -38,7 +38,7 @@ class DocumentsLoader:
|
|||
documents: List[str] = []
|
||||
images: List[str] = []
|
||||
|
||||
for file_path in self._document_paths:
|
||||
for file_path in self._file_paths:
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"File {file_path} not found"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ class TempFileService:
|
|||
|
||||
def __init__(self):
|
||||
self.base_dir = os.getenv("TEMP_DIRECTORY")
|
||||
self.cleanup_base_dir()
|
||||
# TODO: Uncomment this when we want to cleanup the base dir on startup
|
||||
# self.cleanup_base_dir()
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
|
||||
def create_dir_in_dir(self, base_dir: str, dir_name: Optional[str] = None) -> str:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
from typing import List, Optional
|
||||
from pydantic import Field
|
||||
from ppt_config_generator.models import (
|
||||
PresentationMarkdownModel,
|
||||
PresentationStructureModel,
|
||||
SlideMarkdownModel,
|
||||
SlideStructureModel,
|
||||
from models.presentation_outline_model import (
|
||||
PresentationOutlineModel,
|
||||
SlideOutlineModel,
|
||||
)
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
|
||||
|
||||
class SlideMarkdownModelWithValidation(SlideMarkdownModel):
|
||||
class SlideOutlineModelWithValidation(SlideOutlineModel):
|
||||
title: str = Field(
|
||||
description="Title of the slide in about 3 to 5 words",
|
||||
min_length=10,
|
||||
|
|
@ -16,8 +15,8 @@ class SlideMarkdownModelWithValidation(SlideMarkdownModel):
|
|||
)
|
||||
|
||||
|
||||
def get_presentation_markdown_model_with_n_slides(n_slides: int):
|
||||
class PresentationMarkdownModelWithNSlides(PresentationMarkdownModel):
|
||||
def get_presentation_outline_model_with_n_slides(n_slides: int):
|
||||
class PresentationOutlineModelWithNSlides(PresentationOutlineModel):
|
||||
title: str = Field(
|
||||
description="Title of the presentation in about 3 to 8 words",
|
||||
min_length=10,
|
||||
|
|
@ -28,17 +27,17 @@ def get_presentation_markdown_model_with_n_slides(n_slides: int):
|
|||
min_length=0,
|
||||
max_length=10,
|
||||
)
|
||||
slides: List[SlideMarkdownModelWithValidation] = Field(
|
||||
slides: List[SlideOutlineModelWithValidation] = Field(
|
||||
description="List of slides", min_items=n_slides, max_items=n_slides
|
||||
)
|
||||
|
||||
return PresentationMarkdownModelWithNSlides
|
||||
return PresentationOutlineModelWithNSlides
|
||||
|
||||
|
||||
def get_presentation_structure_model_with_n_slides(n_slides: int):
|
||||
class PresentationStructureModelWithNSlides(PresentationStructureModel):
|
||||
slides: List[SlideStructureModel] = Field(
|
||||
description="List of slide structure",
|
||||
slides: List[int] = Field(
|
||||
description="List of slide layouts",
|
||||
min_items=n_slides,
|
||||
max_items=n_slides,
|
||||
)
|
||||
|
|
@ -2,7 +2,8 @@ import asyncio
|
|||
from typing import List
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from api.utils.model_utils import get_llm_client, get_nano_model
|
||||
from utils.llm_provider import get_llm_client, get_nano_model
|
||||
|
||||
|
||||
sysmte_prompt = """
|
||||
Generate a blog-style summary of the provided document in **more than 2000 words**.
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
from typing import Optional
|
||||
from openai.lib.streaming.chat._events import ContentDeltaEvent
|
||||
|
||||
from utils.get_dynamic_models import get_presentation_outline_model_with_n_slides
|
||||
from utils.llm_provider import (
|
||||
get_large_model,
|
||||
get_llm_client,
|
||||
)
|
||||
|
||||
system_prompt = """
|
||||
Create a presentation based on the provided prompt, number of slides, output language, and additional informational details.
|
||||
Format the output in the specified JSON schema with structured markdown content.
|
||||
|
||||
# Steps
|
||||
|
||||
1. Identify key points from the provided prompt, including the topic, number of slides, output language, and additional content directions.
|
||||
2. Create a concise and descriptive title reflecting the main topic, adhering to the specified language.
|
||||
3. Generate a clear title for each slide.
|
||||
4. Develop comprehensive content using markdown structure:
|
||||
* Use bullet points (- or *) for lists.
|
||||
* Use **bold** for emphasis, *italic* for secondary emphasis, and `code` for technical terms.
|
||||
5. Provide important points from prompt as notes.
|
||||
|
||||
# Notes
|
||||
- Content must be generated for every slide.
|
||||
- Images or Icons information provided in **Input** must be included in the **notes**.
|
||||
- Notes should cleary define if it is for specific slide or for the presentation.
|
||||
- Slide **body** should not contain slide **title**.
|
||||
- Slide **title** should not contain "Slide 1", "Slide 2", etc.
|
||||
- Slide **title** should not be in markdown format.
|
||||
- There must be exact **Number of Slides** as specified.
|
||||
"""
|
||||
|
||||
|
||||
def get_user_prompt(prompt: str, n_slides: int, language: str, content: str):
|
||||
return f"""
|
||||
**Input:**
|
||||
- Prompt: {prompt}
|
||||
- Output Language: {language}
|
||||
- Number of Slides: {n_slides}
|
||||
- Additional Information: {content}
|
||||
"""
|
||||
|
||||
|
||||
def get_prompt_template(prompt: str, n_slides: int, language: str, content: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_user_prompt(prompt, n_slides, language, content),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def generate_ppt_outline(
|
||||
prompt: Optional[str],
|
||||
n_slides: int,
|
||||
language: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
):
|
||||
model = get_large_model()
|
||||
response_model = get_presentation_outline_model_with_n_slides(n_slides)
|
||||
|
||||
client = get_llm_client()
|
||||
async with client.beta.chat.completions.stream(
|
||||
model=model,
|
||||
messages=get_prompt_template(prompt, n_slides, language, content),
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "PresentationOutline",
|
||||
"schema": response_model.model_json_schema(),
|
||||
},
|
||||
},
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if isinstance(event, ContentDeltaEvent):
|
||||
yield event.delta
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from utils.llm_provider import get_llm_client, get_small_model
|
||||
from utils.get_dynamic_models import (
|
||||
get_presentation_structure_model_with_n_slides,
|
||||
)
|
||||
from models.presentation_structure_model import (
|
||||
PresentationStructureModel,
|
||||
)
|
||||
|
||||
|
||||
def get_prompt(presentation_layout: PresentationLayoutModel, n_slides: int, data: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
You're a professional presentation designer with years of experience in designing clear and engaging presentations.
|
||||
|
||||
{presentation_layout.to_string()}
|
||||
|
||||
# Steps
|
||||
1. Analyze provided Number of slides, Presentation title, Slides content and Presentation Layout.
|
||||
2. Select appropriate slide layout **index** for each slide.
|
||||
|
||||
# Notes
|
||||
- Slide layout should be selected based on provided content for slide and notes.
|
||||
- Don't fall into patterns like always using layout 2 and after layout 1.
|
||||
- Each presentation should have its own unique flow and rhythm.
|
||||
- Select type for {n_slides} slides.
|
||||
|
||||
**Go through notes and steps and make sure they are all followed. Rule breaks are strictly not allowed.**
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
{data}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def generate_presentation_structure(
|
||||
presentation_outline: PresentationOutlineModel,
|
||||
presentation_layout: PresentationLayoutModel,
|
||||
) -> PresentationStructureModel:
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
response_model = get_presentation_structure_model_with_n_slides(
|
||||
len(presentation_outline.slides)
|
||||
)
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt(
|
||||
presentation_layout,
|
||||
len(presentation_outline.slides),
|
||||
presentation_outline.to_string(),
|
||||
),
|
||||
response_format=response_model,
|
||||
)
|
||||
return response.choices[0].message.parsed
|
||||
61
servers/fastapi/utils/llm_calls/generate_slide.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
from models.presentation_layout import SlideLayoutModel
|
||||
from models.presentation_outline_model import SlideOutlineModel
|
||||
from utils.llm_provider import get_llm_client, get_small_model
|
||||
|
||||
|
||||
def get_prompt_to_generate_slide_content(title: str, outline: str):
|
||||
return [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""
|
||||
Generate structured slide based on provided title and outline, follow mentioned steps and notes and provide structured output.
|
||||
|
||||
# Steps
|
||||
1. Analyze the outline and title.
|
||||
2. Generate structured slide based on the outline and title.
|
||||
|
||||
# Notes
|
||||
- Slide body should not use words like "This slide", "This presentation".
|
||||
- Rephrase the slide body to make it flow naturally.
|
||||
- Do not use markdown formatting in slide body.
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
## Slide Title
|
||||
{title}
|
||||
|
||||
## Slide Outline
|
||||
{outline}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def get_slide_content_from_type_and_outline(
|
||||
slide_layout: SlideLayoutModel, outline: SlideOutlineModel
|
||||
):
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": slide_layout.name or slide_layout.id,
|
||||
"schema": slide_layout.json_schema,
|
||||
},
|
||||
}
|
||||
|
||||
client = get_llm_client()
|
||||
model = get_small_model()
|
||||
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
messages=get_prompt_to_generate_slide_content(
|
||||
outline.title,
|
||||
outline.body,
|
||||
),
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return json.loads(response.choices[0].message.content)
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from http.client import HTTPException
|
||||
import os
|
||||
from openai import AsyncOpenAI
|
||||
from google import genai
|
||||
|
||||
from enums.llm_provider import LLMProvider
|
||||
from utils.get_env import (
|
||||
|
|
@ -21,6 +22,10 @@ def get_ollama_url():
|
|||
return get_ollama_url_env() or "http://localhost:11434"
|
||||
|
||||
|
||||
def is_google_selected():
|
||||
return get_llm_provider() == LLMProvider.GOOGLE
|
||||
|
||||
|
||||
def is_ollama_selected():
|
||||
return get_llm_provider() == LLMProvider.OLLAMA
|
||||
|
||||
|
|
@ -64,3 +69,50 @@ def get_llm_client():
|
|||
api_key=get_llm_api_key(),
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_google_llm_client():
|
||||
client = genai.Client(api_key=get_llm_api_key())
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = get_llm_provider()
|
||||
if selected_llm == LLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == LLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == LLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == LLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
from typing import AsyncGenerator
|
||||
import aiohttp
|
||||
|
||||
from models.ollama_model_status_response import OllamaModelStatusResponse
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from utils.get_env import get_ollama_url_env
|
||||
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ async def pull_ollama_model(model: str) -> AsyncGenerator[dict, None]:
|
|||
yield event
|
||||
|
||||
|
||||
async def list_pulled_ollama_models() -> list[OllamaModelStatusResponse]:
|
||||
async def list_pulled_ollama_models() -> list[OllamaModelStatus]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{get_ollama_url_env()}/api/tags",
|
||||
|
|
@ -39,7 +39,7 @@ async def list_pulled_ollama_models() -> list[OllamaModelStatusResponse]:
|
|||
if response.status == 200:
|
||||
pulled_models = await response.json()
|
||||
return [
|
||||
OllamaModelStatusResponse(
|
||||
OllamaModelStatus(
|
||||
name=m["model"],
|
||||
size=m["size"],
|
||||
status="pulled",
|
||||
|
|
|
|||
5
servers/fastapi/utils/randomizers.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import uuid
|
||||
|
||||
|
||||
def get_random_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
|
@ -1 +0,0 @@
|
|||
{}
|
||||
|
|
@ -1,97 +0,0 @@
|
|||
import os
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlmodel import SQLModel
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.routers.presentation.router import presentation_router
|
||||
from api.services.database import sql_engine
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from api.utils.utils import update_env_with_user_config
|
||||
from api.utils.model_utils import (
|
||||
get_selected_llm_provider,
|
||||
is_custom_llm_selected,
|
||||
is_ollama_selected,
|
||||
list_available_custom_models,
|
||||
pull_ollama_model,
|
||||
)
|
||||
|
||||
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
|
||||
|
||||
|
||||
async def check_llm_model_availability():
|
||||
if not can_change_keys:
|
||||
if get_selected_llm_provider() == SelectedLLMProvider.OPENAI:
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
|
||||
elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
|
||||
elif is_ollama_selected():
|
||||
ollama_model = os.getenv("OLLAMA_MODEL")
|
||||
if not ollama_model:
|
||||
raise Exception("OLLAMA_MODEL must be provided")
|
||||
|
||||
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise Exception(f"Model {ollama_model} is not supported")
|
||||
|
||||
print("-" * 50)
|
||||
print("Pulling model: ", ollama_model)
|
||||
async for event in pull_ollama_model(ollama_model):
|
||||
print(event)
|
||||
print("Pulled model: ", ollama_model)
|
||||
print("-" * 50)
|
||||
|
||||
elif is_custom_llm_selected():
|
||||
custom_model = os.getenv("CUSTOM_MODEL")
|
||||
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
|
||||
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
if not custom_model:
|
||||
raise Exception("CUSTOM_MODEL must be provided")
|
||||
if not custom_llm_url:
|
||||
raise Exception("CUSTOM_LLM_URL must be provided")
|
||||
if not custom_llm_api_key:
|
||||
raise Exception("CUSTOM_LLM_API_KEY must be provided")
|
||||
print("-" * 50)
|
||||
print("Selecting model: ", custom_model)
|
||||
models = await list_available_custom_models(
|
||||
custom_llm_url, custom_llm_api_key
|
||||
)
|
||||
print("Available models: ", models)
|
||||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
os.makedirs(os.getenv("APP_DATA_DIRECTORY"), exist_ok=True)
|
||||
SQLModel.metadata.create_all(sql_engine)
|
||||
await check_llm_model_availability()
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
origins = ["*"]
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def update_env_middleware(request: Request, call_next):
|
||||
if can_change_keys:
|
||||
update_env_with_user_config()
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
app.include_router(presentation_router)
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
from enum import Enum
|
||||
import json
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.sql_models import PresentationSqlModel
|
||||
|
||||
|
||||
class LogMetadata(BaseModel):
|
||||
presentation: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
status_code: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_presentation(
|
||||
cls, presentation: PresentationSqlModel, endpoint: Optional[str] = None
|
||||
):
|
||||
return cls(
|
||||
presentation=presentation.id,
|
||||
title=presentation.title,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
@property
|
||||
def stream_name(self):
|
||||
return f"Endpoint - {self.endpoint}, Presentation - {self.presentation}"
|
||||
|
||||
|
||||
class SessionModel(BaseModel):
|
||||
session: str
|
||||
|
||||
|
||||
class SSEResponse(BaseModel):
|
||||
event: str
|
||||
data: str
|
||||
|
||||
def to_string(self):
|
||||
return f"event: {self.event}\ndata: {self.data}\n\n"
|
||||
|
||||
|
||||
class SSEStatusResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response", data=json.dumps({"type": "status", "status": self.status})
|
||||
).to_string()
|
||||
|
||||
|
||||
class SSECompleteResponse(BaseModel):
|
||||
key: str
|
||||
value: object
|
||||
|
||||
def to_string(self):
|
||||
return SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "complete", self.key: self.value}),
|
||||
).to_string()
|
||||
|
||||
|
||||
class UserConfig(BaseModel):
|
||||
LLM: Optional[str] = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
GOOGLE_API_KEY: Optional[str] = None
|
||||
OLLAMA_URL: Optional[str] = None
|
||||
OLLAMA_MODEL: Optional[str] = None
|
||||
CUSTOM_LLM_URL: Optional[str] = None
|
||||
CUSTOM_LLM_API_KEY: Optional[str] = None
|
||||
CUSTOM_MODEL: Optional[str] = None
|
||||
PEXELS_API_KEY: Optional[str] = None
|
||||
|
||||
|
||||
class OllamaModelMetadata(BaseModel):
|
||||
label: str
|
||||
value: str
|
||||
description: str
|
||||
icon: str
|
||||
size: str
|
||||
supports_graph: bool
|
||||
|
||||
|
||||
class SelectedLLMProvider(Enum):
|
||||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
GOOGLE = "google"
|
||||
CUSTOM = "custom"
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class RequestUtils:
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
async def initialize_logger(
|
||||
self,
|
||||
presentation_id: Optional[str] = None,
|
||||
):
|
||||
metadata = LogMetadata(presentation=presentation_id, endpoint=self.endpoint)
|
||||
logging_service = LoggingService(metadata.stream_name)
|
||||
|
||||
return logging_service, metadata
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
DecomposeDocumentsRequest,
|
||||
DecomposeDocumentsResponse,
|
||||
)
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from document_processor.loader import DocumentsLoader
|
||||
|
||||
|
||||
class DecomposeDocumentsHandler:
|
||||
|
||||
def __init__(self, data: DecomposeDocumentsRequest):
|
||||
self.data = data
|
||||
self.documents = list(
|
||||
filter(lambda doc: not doc.endswith(".csv"), self.data.documents or [])
|
||||
)
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
documents_loader = DocumentsLoader(self.documents)
|
||||
await documents_loader.load_documents(self.temp_dir)
|
||||
parsed_documents = documents_loader.documents
|
||||
|
||||
document_paths = []
|
||||
for parsed_doc in parsed_documents:
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{str(uuid.uuid4())}.txt", self.temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(parsed_doc)
|
||||
document_paths.append(file_path)
|
||||
|
||||
documents = {}
|
||||
for index, each in enumerate(self.documents):
|
||||
documents[each] = document_paths[index]
|
||||
|
||||
response = DecomposeDocumentsResponse(
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.utils import get_presentation_dir
|
||||
|
||||
|
||||
class DeletePresentationHandler:
|
||||
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.id)
|
||||
|
||||
async def delete(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"presentation": self.id}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.id)
|
||||
sql_session.delete(presentation)
|
||||
sql_session.commit()
|
||||
|
||||
if os.path.exists(self.presentation_dir):
|
||||
shutil.rmtree(self.presentation_dir)
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.database import get_sql_session
|
||||
from api.sql_models import SlideSqlModel
|
||||
|
||||
|
||||
class DeleteSlideHandler:
|
||||
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
async def delete(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"slide": self.id}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
slide = sql_session.get(SlideSqlModel, self.id)
|
||||
sql_session.delete(slide)
|
||||
sql_session.commit()
|
||||
|
|
@ -1,204 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
EditPresentationSlideRequest,
|
||||
)
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from api.utils.utils import (
|
||||
get_presentation_dir,
|
||||
get_presentation_images_dir,
|
||||
)
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from image_processor.icons_finder import get_icon
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
IconQueryCollectionWithData,
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from ppt_generator.slide_generator import (
|
||||
get_edited_slide_content_model,
|
||||
get_slide_type_from_prompt,
|
||||
)
|
||||
from ppt_generator.slide_model_utils import SlideModelUtils
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class PresentationEditHandler:
|
||||
def __init__(self, data: EditPresentationSlideRequest):
|
||||
self.data = data
|
||||
self.presentation_id = data.presentation_id
|
||||
|
||||
self.slide_index = data.index
|
||||
self.prompt = data.prompt
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
slide_to_edit_sql = sql_session.exec(
|
||||
select(SlideSqlModel).where(
|
||||
SlideSqlModel.index == self.slide_index,
|
||||
SlideSqlModel.presentation == self.presentation_id,
|
||||
)
|
||||
).first()
|
||||
|
||||
slide_to_edit = SlideSqlModel(**slide_to_edit_sql.model_dump(mode="json"))
|
||||
new_slide_type = await get_slide_type_from_prompt(self.prompt, slide_to_edit)
|
||||
new_slide_type = new_slide_type.slide_type
|
||||
|
||||
supports_graph = not is_custom_llm_selected()
|
||||
if is_ollama_selected():
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
supports_graph = model.supports_graph
|
||||
|
||||
if not supports_graph:
|
||||
if new_slide_type == 5:
|
||||
new_slide_type = 1
|
||||
elif new_slide_type == 9:
|
||||
new_slide_type = 6
|
||||
|
||||
edited_content = await get_edited_slide_content_model(
|
||||
self.prompt,
|
||||
new_slide_type,
|
||||
slide_to_edit,
|
||||
presentation.theme,
|
||||
presentation.language,
|
||||
)
|
||||
|
||||
new_slide_model = SlideSqlModel(
|
||||
id=slide_to_edit.id,
|
||||
index=slide_to_edit.index,
|
||||
type=new_slide_type,
|
||||
design_index=slide_to_edit.design_index,
|
||||
images=None,
|
||||
icons=None,
|
||||
presentation=slide_to_edit.presentation,
|
||||
properties=slide_to_edit.properties,
|
||||
content=edited_content.to_content(),
|
||||
)
|
||||
|
||||
new_slide_images_count = new_slide_model.images_count
|
||||
new_slide_icons_count = new_slide_model.icons_count
|
||||
|
||||
slide_model_utils = SlideModelUtils(presentation.theme, new_slide_model)
|
||||
|
||||
new_slide_images: dict[int, str | ImagePromptWithThemeAndAspectRatio] = {}
|
||||
new_slide_icons: dict[int, str | IconQueryCollectionWithData] = {}
|
||||
|
||||
# ? Checks if image prompts have changed
|
||||
# ? If they have, it will search if it is same as the old one but used at a different index
|
||||
# ? If it is, it will use the old image
|
||||
# ? If it is not, it will generate a new image
|
||||
if new_slide_images_count:
|
||||
new_image_prompts = slide_model_utils.get_image_prompts()
|
||||
old_image_prompts = (
|
||||
slide_to_edit.content.image_prompts
|
||||
if slide_to_edit.images_count
|
||||
else []
|
||||
)
|
||||
for index in range(new_slide_images_count):
|
||||
new_prompt = new_slide_model.content.image_prompts[index]
|
||||
for old_prompt in old_image_prompts:
|
||||
if old_prompt != new_prompt:
|
||||
continue
|
||||
if index < len(slide_to_edit.images or []):
|
||||
new_slide_images[index] = slide_to_edit.images[index]
|
||||
break
|
||||
if not new_slide_images.get(index):
|
||||
new_slide_images[index] = new_image_prompts[index]
|
||||
|
||||
# ? Checks if icon queries have changed
|
||||
# ? If they have, it will search if it is same as the old one but used at a different index
|
||||
# ? If it is, it will use the old icon
|
||||
# ? If it is not, it will generate a new icon
|
||||
if new_slide_icons_count:
|
||||
new_icon_queries = slide_model_utils.get_icon_queries()
|
||||
old_icon_queries = (
|
||||
slide_to_edit.content.icon_queries if slide_to_edit.icons_count else []
|
||||
)
|
||||
for index in range(new_slide_icons_count):
|
||||
new_query = new_slide_model.content.icon_queries[index]
|
||||
for old_query in old_icon_queries:
|
||||
if old_query != new_query:
|
||||
continue
|
||||
if index < len(slide_to_edit.icons or []):
|
||||
new_slide_icons[index] = slide_to_edit.icons[index]
|
||||
break
|
||||
if not new_slide_icons.get(index):
|
||||
new_slide_icons[index] = new_icon_queries[index]
|
||||
|
||||
images_to_generate = []
|
||||
for each in new_slide_images.values():
|
||||
if isinstance(each, ImagePromptWithThemeAndAspectRatio):
|
||||
images_to_generate.append(each)
|
||||
|
||||
icons_to_generate = []
|
||||
for each in new_slide_icons.values():
|
||||
if isinstance(each, IconQueryCollectionWithData):
|
||||
icons_to_generate.append(each)
|
||||
|
||||
images_directory = get_presentation_images_dir(self.presentation_id)
|
||||
if icons_to_generate:
|
||||
icons_vectorstore = get_icons_vectorstore()
|
||||
|
||||
coroutines = [
|
||||
generate_image(each_prompt, images_directory)
|
||||
for each_prompt in images_to_generate
|
||||
] + [
|
||||
get_icon(icons_vectorstore, each_query) for each_query in icons_to_generate
|
||||
]
|
||||
generated_assets = await asyncio.gather(*coroutines)
|
||||
generated_image_count = len(images_to_generate)
|
||||
generate_images = generated_assets[:generated_image_count]
|
||||
generate_icons = generated_assets[generated_image_count:]
|
||||
|
||||
for each in new_slide_images:
|
||||
if isinstance(new_slide_images[each], ImagePromptWithThemeAndAspectRatio):
|
||||
new_slide_images[each] = generate_images.pop(0)
|
||||
|
||||
for each in new_slide_icons:
|
||||
if isinstance(new_slide_icons[each], IconQueryCollectionWithData):
|
||||
new_slide_icons[each] = generate_icons.pop(0)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.exec(
|
||||
update(SlideSqlModel)
|
||||
.where(SlideSqlModel.id == slide_to_edit.id)
|
||||
.values(
|
||||
type=new_slide_type,
|
||||
images=list(new_slide_images.values()),
|
||||
icons=list(new_slide_icons.values()),
|
||||
content=new_slide_model.content.model_dump(mode="json"),
|
||||
)
|
||||
)
|
||||
sql_session.commit()
|
||||
slide_to_edit_sql = sql_session.exec(
|
||||
select(SlideSqlModel).where(SlideSqlModel.id == slide_to_edit.id)
|
||||
).first()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(slide_to_edit_sql.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return slide_to_edit_sql
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.mixins.fetch_presentation_assets import (
|
||||
FetchPresentationAssetsMixin,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
ExportAsRequest,
|
||||
PresentationAndPath,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.utils.utils import get_presentation_dir, sanitize_filename
|
||||
from ppt_generator.pptx_presentation_creator import PptxPresentationCreator
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class ExportAsPptxHandler(FetchPresentationAssetsMixin):
|
||||
|
||||
def __init__(self, data: ExportAsRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.data.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
await self.fetch_presentation_assets()
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
|
||||
ppt_path = os.path.join(
|
||||
self.presentation_dir,
|
||||
sanitize_filename(f"{presentation.title}.pptx")
|
||||
)
|
||||
ppt_creator = PptxPresentationCreator(self.data.pptx_model, self.temp_dir)
|
||||
ppt_creator.create_ppt()
|
||||
ppt_creator.save(ppt_path)
|
||||
|
||||
response = PresentationAndPath(
|
||||
presentation_id=self.data.presentation_id, path=ppt_path
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
presentation.file = ppt_path
|
||||
sql_session.commit()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
import os
|
||||
import random
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException
|
||||
from api.models import LogMetadata, SessionModel
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
SUPPORTED_OLLAMA_MODELS,
|
||||
)
|
||||
from api.routers.presentation.models import PresentationGenerateRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from ppt_config_generator.models import PresentationMarkdownModel, SlideStructureModel
|
||||
from ppt_config_generator.structure_generator import generate_presentation_structure
|
||||
|
||||
SLIDES_WITHOUT_GRAPH = [2, 4, 6, 7, 8]
|
||||
|
||||
|
||||
class PresentationGenerateDataHandler:
|
||||
|
||||
def __init__(self, data: PresentationGenerateRequest):
|
||||
self.data = data
|
||||
self.session = str(uuid.uuid4())
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump()),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
if not self.data.outlines:
|
||||
raise HTTPException(400, "Outlines can not be empty")
|
||||
|
||||
key_value_model = KeyValueSqlModel(
|
||||
id=self.session,
|
||||
key=self.session,
|
||||
value=self.data.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
presentation_structure = await generate_presentation_structure(
|
||||
PresentationMarkdownModel(
|
||||
**{
|
||||
"title": presentation.title,
|
||||
"slides": presentation.outlines,
|
||||
"notes": presentation.notes,
|
||||
}
|
||||
)
|
||||
)
|
||||
supports_graph = not is_custom_llm_selected()
|
||||
if is_ollama_selected():
|
||||
model = SUPPORTED_OLLAMA_MODELS[os.getenv("OLLAMA_MODEL")]
|
||||
supports_graph = model.supports_graph
|
||||
|
||||
for each in presentation_structure.slides:
|
||||
if each.type > 9:
|
||||
each.type = random.choice(SLIDES_WITHOUT_GRAPH)
|
||||
if each.type == 3:
|
||||
each.type = 6
|
||||
if not supports_graph:
|
||||
if each.type == 5:
|
||||
each.type = 1
|
||||
elif each.type == 9:
|
||||
each.type = 6
|
||||
|
||||
presentation_outlines_len = len(presentation.outlines)
|
||||
missing_slides_len = presentation_outlines_len - len(
|
||||
presentation_structure.slides
|
||||
)
|
||||
if missing_slides_len > 0:
|
||||
for index in range(missing_slides_len):
|
||||
selected_type = (
|
||||
random.choice(SLIDES_WITHOUT_GRAPH)
|
||||
if index != missing_slides_len - 1
|
||||
else 1
|
||||
)
|
||||
presentation_structure.slides.append(
|
||||
SlideStructureModel(type=selected_type)
|
||||
)
|
||||
elif missing_slides_len < 0:
|
||||
presentation_structure.slides = presentation_structure.slides[
|
||||
:presentation_outlines_len
|
||||
]
|
||||
|
||||
presentation.structure = presentation_structure.model_dump(mode="json")
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(key_value_model)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(key_value_model)
|
||||
|
||||
response = SessionModel(session=self.session)
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
GenerateImageRequest,
|
||||
PresentationAndPaths,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.utils.utils import get_presentation_dir, get_presentation_images_dir
|
||||
from image_processor.images_finder import generate_image
|
||||
|
||||
|
||||
class GenerateImageHandler:
|
||||
|
||||
def __init__(self, data: GenerateImageRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.data.presentation_id)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
images_directory = get_presentation_images_dir(self.data.presentation_id)
|
||||
image_path = await generate_image(self.data.prompt, images_directory)
|
||||
|
||||
response = PresentationAndPaths(
|
||||
presentation_id=self.data.presentation_id, paths=[image_path]
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
import uuid
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GenerateOutlinesRequest
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from ppt_config_generator.ppt_outlines_generator import generate_ppt_content
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class PresentationOutlinesGenerateHandler:
|
||||
def __init__(self, data: GenerateOutlinesRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
|
||||
presentation_content = await generate_ppt_content(
|
||||
presentation.prompt,
|
||||
presentation.n_slides,
|
||||
presentation.language,
|
||||
presentation.summary,
|
||||
)
|
||||
presentation_content.slides = presentation_content.slides[
|
||||
: presentation.n_slides
|
||||
]
|
||||
|
||||
presentation.title = presentation_content.title
|
||||
presentation.outlines = [
|
||||
each.model_dump() for each in presentation_content.slides
|
||||
]
|
||||
presentation.notes = presentation_content.notes
|
||||
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(presentation.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return presentation
|
||||
|
|
@ -1,185 +0,0 @@
|
|||
import json
|
||||
from typing import List
|
||||
import uuid, aiohttp
|
||||
from fastapi import HTTPException
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.handlers.export_as_pptx import ExportAsPptxHandler
|
||||
from api.routers.presentation.handlers.upload_files import UploadFilesHandler
|
||||
from api.routers.presentation.mixins.fetch_assets_on_generation import (
|
||||
FetchAssetsOnPresentationGenerationMixin,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
ExportAsRequest,
|
||||
GeneratePresentationRequest,
|
||||
PresentationAndPath,
|
||||
PresentationPathAndEditPath,
|
||||
)
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
from ppt_config_generator.models import PresentationMarkdownModel
|
||||
from ppt_config_generator.ppt_outlines_generator import generate_ppt_content
|
||||
from ppt_generator.generator import generate_presentation
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
|
||||
|
||||
class GeneratePresentationHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
def __init__(self, presentation_id: str, data: GeneratePresentationRequest):
|
||||
self.session = str(uuid.uuid4())
|
||||
self.presentation_id = presentation_id
|
||||
self.data = data
|
||||
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Ollama is not currently supported for this endpoint",
|
||||
)
|
||||
|
||||
documents_and_images_path = await UploadFilesHandler(
|
||||
documents=self.data.documents,
|
||||
images=None,
|
||||
).post(logging_service, log_metadata)
|
||||
|
||||
summary = None
|
||||
if documents_and_images_path.documents:
|
||||
documents_loader = DocumentsLoader(documents_and_images_path.documents)
|
||||
await documents_loader.load_documents(self.temp_dir)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generating Document Summary")
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generating PPT Outline")
|
||||
presentation_content = await generate_ppt_content(
|
||||
self.data.prompt,
|
||||
self.data.n_slides,
|
||||
self.data.language,
|
||||
summary,
|
||||
)
|
||||
|
||||
print("-" * 40)
|
||||
print("Generating Presentation")
|
||||
presentation_text = await generate_presentation(
|
||||
PresentationMarkdownModel(
|
||||
title=presentation_content.title,
|
||||
slides=presentation_content.slides,
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
)
|
||||
|
||||
print("-" * 40)
|
||||
print("Parsing Presentation")
|
||||
presentation_json = json.loads(presentation_text)
|
||||
|
||||
slide_models: List[SlideModel] = []
|
||||
for i, slide in enumerate(presentation_json["slides"]):
|
||||
slide["index"] = i
|
||||
slide["presentation"] = self.presentation_id
|
||||
slide["content"] = (
|
||||
LLM_CONTENT_TYPE_MAPPING[slide["type"]](**slide["content"])
|
||||
.to_content()
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
slide_model = SlideModel(**slide)
|
||||
slide_models.append(slide_model)
|
||||
|
||||
print("-" * 40)
|
||||
print("Fetching Theme Colors")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost/api/get-theme-from-name?theme={self.data.theme.value}",
|
||||
) as response:
|
||||
self.theme = await response.json()
|
||||
|
||||
print("-" * 40)
|
||||
print("Fetching Slide Assets")
|
||||
async for result in self.fetch_slide_assets(slide_models):
|
||||
print(result)
|
||||
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models
|
||||
]
|
||||
|
||||
presentation = PresentationSqlModel(
|
||||
id=self.presentation_id,
|
||||
prompt=self.data.prompt,
|
||||
n_slides=self.data.n_slides,
|
||||
language=self.data.language,
|
||||
summary=summary,
|
||||
theme=self.theme,
|
||||
title=presentation_content.title,
|
||||
outlines=[each.model_dump() for each in presentation_content.slides],
|
||||
notes=presentation_content.notes,
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.add_all(slide_sql_models)
|
||||
sql_session.commit()
|
||||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
|
||||
if self.data.export_as == "pptx":
|
||||
print("-" * 40)
|
||||
print("Fetching Slide Metadata for Export")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost/api/slide-metadata",
|
||||
json={
|
||||
"id": self.presentation_id,
|
||||
|
||||
},
|
||||
) as response:
|
||||
export_request_body = await response.json()
|
||||
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation")
|
||||
export_request_body["presentation_id"] = self.presentation_id
|
||||
print(export_request_body)
|
||||
export_request = ExportAsRequest(**export_request_body)
|
||||
|
||||
presentation_and_path = await ExportAsPptxHandler(export_request).post(
|
||||
logging_service, log_metadata
|
||||
)
|
||||
|
||||
else:
|
||||
print("-" * 40)
|
||||
print("Exporting Presentation as PDF")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost/api/export-as-pdf",
|
||||
json={
|
||||
"id": self.presentation_id,
|
||||
"title": presentation_content.title,
|
||||
},
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
presentation_and_path = PresentationAndPath(
|
||||
presentation_id=self.presentation_id,
|
||||
path=response_json["path"].replace("app", "static"),
|
||||
)
|
||||
|
||||
presentation_and_path.path = presentation_and_path.path.replace("app", "static")
|
||||
return PresentationPathAndEditPath(
|
||||
**presentation_and_path.model_dump(),
|
||||
edit_path=f"/presentation?id={self.presentation_id}",
|
||||
)
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GeneratePresentationRequirementsRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from document_processor.loader import DocumentsLoader
|
||||
from ppt_config_generator.document_summary_generator import generate_document_summary
|
||||
|
||||
|
||||
class GeneratePresentationRequirementsHandler:
|
||||
def __init__(
|
||||
self,
|
||||
presentation_id: str,
|
||||
data: GeneratePresentationRequirementsRequest,
|
||||
):
|
||||
self.data = data
|
||||
self.presentation_id = presentation_id
|
||||
self.prompt = data.prompt
|
||||
self.n_slides = data.n_slides
|
||||
self.documents = data.documents or []
|
||||
self.language = data.language
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
all_document_paths = [*self.documents]
|
||||
|
||||
documents_loader = DocumentsLoader(all_document_paths)
|
||||
await documents_loader.load_documents(self.temp_dir)
|
||||
|
||||
summary = await generate_document_summary(documents_loader.documents)
|
||||
|
||||
presentation = PresentationSqlModel(
|
||||
id=self.presentation_id,
|
||||
prompt=self.prompt,
|
||||
n_slides=self.n_slides,
|
||||
language=self.language,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add(presentation)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(presentation.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return presentation
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import GenerateResearchReportRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from research_report.generator import get_report
|
||||
|
||||
|
||||
class GenerateResearchReportHandler:
|
||||
def __init__(self, data: GenerateResearchReportRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
report = await get_report(self.data.query, self.data.language)
|
||||
|
||||
file_name = f"{report[:30]}.txt"
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(file_name, self.temp_dir)
|
||||
with open(file_path, "w") as text_file:
|
||||
text_file.write(report)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(file_path), extra=log_metadata.model_dump()
|
||||
)
|
||||
return file_path
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
import json
|
||||
from typing import List
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlmodel import delete
|
||||
|
||||
from api.models import LogMetadata, SSECompleteResponse, SSEResponse, SSEStatusResponse
|
||||
|
||||
from api.routers.presentation.mixins.fetch_assets_on_generation import (
|
||||
FetchAssetsOnPresentationGenerationMixin,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
PresentationAndSlides,
|
||||
PresentationGenerateRequest,
|
||||
)
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import KeyValueSqlModel, PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import get_presentation_dir
|
||||
from api.utils.model_utils import is_custom_llm_selected, is_ollama_selected
|
||||
from ppt_config_generator.models import (
|
||||
PresentationMarkdownModel,
|
||||
PresentationStructureModel,
|
||||
)
|
||||
from ppt_generator.generator import generate_presentation_stream
|
||||
from ppt_generator.models.llm_models import (
|
||||
LLM_CONTENT_TYPE_MAPPING,
|
||||
LLMPresentationModel,
|
||||
LLMSlideModel,
|
||||
)
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
|
||||
from ppt_generator.slide_generator import get_slide_content_from_type_and_outline
|
||||
|
||||
|
||||
class PresentationGenerateStreamHandler(FetchAssetsOnPresentationGenerationMixin):
|
||||
|
||||
def __init__(self, presentation_id: str, session: str):
|
||||
self.session = session
|
||||
self.presentation_id = presentation_id
|
||||
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def get(self, *args, **kwargs):
|
||||
with get_sql_session() as sql_session:
|
||||
key_value_model = sql_session.get(KeyValueSqlModel, self.session)
|
||||
|
||||
if not key_value_model.value:
|
||||
raise HTTPException(400, "Data not found for provided session")
|
||||
|
||||
self.data = PresentationGenerateRequest(**key_value_model.value)
|
||||
|
||||
self.presentation_id = self.data.presentation_id
|
||||
self.theme = self.data.theme
|
||||
self.images = self.data.images
|
||||
self.title = self.data.title or ""
|
||||
self.outlines = self.data.outlines
|
||||
|
||||
return StreamingResponse(
|
||||
self.get_stream(*args, **kwargs), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
async def get_stream(
|
||||
self, logging_service: LoggingService, log_metadata: LogMetadata
|
||||
):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
if not self.outlines:
|
||||
raise HTTPException(400, "Outlines can not be empty")
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
presentation.outlines = [each.model_dump() for each in self.outlines]
|
||||
presentation.title = self.title or presentation.title
|
||||
presentation.theme = self.theme
|
||||
sql_session.exec(
|
||||
delete(SlideSqlModel).where(
|
||||
SlideSqlModel.presentation == self.presentation_id
|
||||
)
|
||||
)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
self.presentation = presentation
|
||||
|
||||
yield SSEResponse(
|
||||
event="response", data=json.dumps({"status": "Analyzing information 📊"})
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = None
|
||||
|
||||
# self.presentation_json will be mutated by the generator
|
||||
if is_ollama_selected() or is_custom_llm_selected():
|
||||
async for result in self.generate_presentation_ollama_custom():
|
||||
yield result
|
||||
else:
|
||||
async for result in self.generate_presentation_openai_google():
|
||||
yield result
|
||||
|
||||
slide_models: List[SlideModel] = []
|
||||
for i, slide in enumerate(self.presentation_json["slides"]):
|
||||
slide["index"] = i
|
||||
slide["presentation"] = self.presentation.id
|
||||
slide["content"] = (
|
||||
LLM_CONTENT_TYPE_MAPPING[slide["type"]](**slide["content"])
|
||||
.to_content()
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
slide_model = SlideModel(**slide)
|
||||
slide_models.append(slide_model)
|
||||
|
||||
async for result in self.fetch_slide_assets(slide_models):
|
||||
yield result
|
||||
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in slide_models
|
||||
]
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
sql_session.add_all(slide_sql_models)
|
||||
sql_session.commit()
|
||||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
|
||||
yield SSEStatusResponse(status="Packing slide data").to_string()
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=self.presentation, slides=slide_sql_models
|
||||
).to_response_dict()
|
||||
|
||||
yield SSECompleteResponse(key="presentation", value=response).to_string()
|
||||
|
||||
async def generate_presentation_openai_google(self):
|
||||
presentation_text = ""
|
||||
async for event in await generate_presentation_stream(
|
||||
PresentationMarkdownModel(
|
||||
title=self.title,
|
||||
slides=self.outlines,
|
||||
notes=self.presentation.notes,
|
||||
)
|
||||
):
|
||||
chunk = event.choices[0].delta.content
|
||||
|
||||
if chunk is None:
|
||||
continue
|
||||
|
||||
presentation_text += chunk
|
||||
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk}),
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = json.loads(presentation_text)
|
||||
|
||||
async def generate_presentation_ollama_custom(self):
|
||||
presentation_structure = PresentationStructureModel(
|
||||
**self.presentation.structure
|
||||
)
|
||||
slide_models = []
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": '{ "slides": [ '}),
|
||||
).to_string()
|
||||
n_slides = len(presentation_structure.slides)
|
||||
for i, slide_structure in enumerate(presentation_structure.slides):
|
||||
# Informing about the start of the slide
|
||||
# This is to make sure that the client renders slide n
|
||||
# when it receives start chunk of slide n + 1
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": "{"}),
|
||||
).to_string()
|
||||
|
||||
slide_content = await get_slide_content_from_type_and_outline(
|
||||
slide_structure.type, self.outlines[i]
|
||||
)
|
||||
slide_model = LLMSlideModel(
|
||||
type=slide_structure.type,
|
||||
content=slide_content.model_dump(mode="json"),
|
||||
)
|
||||
slide_models.append(slide_model)
|
||||
chunk = json.dumps(slide_model.model_dump(mode="json"))
|
||||
|
||||
if i < n_slides - 1:
|
||||
chunk += ","
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": chunk[1:]}),
|
||||
).to_string()
|
||||
yield SSEResponse(
|
||||
event="response",
|
||||
data=json.dumps({"type": "chunk", "chunk": " ] }"}),
|
||||
).to_string()
|
||||
|
||||
self.presentation_json = LLMPresentationModel(
|
||||
slides=slide_models,
|
||||
).model_dump(mode="json")
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationAndSlides
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class GetPresentationHandler:
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"presentation": self.id}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.id)
|
||||
slide_models = sql_session.exec(
|
||||
select(SlideSqlModel).where(SlideSqlModel.presentation == self.id)
|
||||
).all()
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=presentation, slides=slide_models
|
||||
).to_response_dict()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return response
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
from sqlmodel import select, exists
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationWithOneSlide
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class GetPresentationsHandler:
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
# Get presentations that have at least one slide
|
||||
presentations = sql_session.exec(
|
||||
select(PresentationSqlModel).where(
|
||||
exists().where(
|
||||
SlideSqlModel.presentation == PresentationSqlModel.id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
presentations.sort(key=lambda x: x.created_at, reverse=True)
|
||||
presentations_with_slide = []
|
||||
for presentation in presentations:
|
||||
slide = sql_session.exec(
|
||||
select(SlideSqlModel)
|
||||
.where(SlideSqlModel.presentation == presentation.id)
|
||||
.where(SlideSqlModel.index == 0)
|
||||
).first()
|
||||
presentations_with_slide.append(
|
||||
PresentationWithOneSlide.from_presentation_and_slide(
|
||||
presentation, slide
|
||||
)
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(
|
||||
[each.model_dump(mode="json") for each in presentations]
|
||||
),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return presentations_with_slide
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from typing import Optional
|
||||
from api.services.logging import LoggingService
|
||||
from api.models import LogMetadata
|
||||
from api.utils.model_utils import list_available_custom_models
|
||||
|
||||
|
||||
class ListAvailableCustomModelsHandler:
|
||||
|
||||
def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None):
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
print("-" * 40)
|
||||
print(self.url, self.api_key)
|
||||
return await list_available_custom_models(self.url, self.api_key)
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.model_utils import list_pulled_ollama_models
|
||||
|
||||
|
||||
class ListPulledOllamaModelsHandler:
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message("Listing Ollama models"),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
pulled_models = await list_pulled_ollama_models()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(pulled_models),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return pulled_models
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationLayoutSqlModel
|
||||
|
||||
|
||||
class ListPresentationLayoutsHandler:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
with get_sql_session() as sql_session:
|
||||
layouts = sql_session.exec(select(PresentationLayoutSqlModel)).all()
|
||||
return layouts
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
from sqlmodel import select
|
||||
from api.models import LogMetadata
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import SlideLayoutSqlModel
|
||||
|
||||
|
||||
class ListSlideLayoutsHandler:
|
||||
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
with get_sql_session() as sql_session:
|
||||
layouts = sql_session.exec(select(SlideLayoutSqlModel)).all()
|
||||
return layouts
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from api.models import LogMetadata, OllamaModelMetadata
|
||||
from api.routers.presentation.models import OllamaSupportedModelsResponse
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
|
||||
|
||||
class ListSupportedOllamaModelsHandler:
|
||||
async def get(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message("Listing supported Ollama models"),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return OllamaSupportedModelsResponse(
|
||||
models=SUPPORTED_OLLAMA_MODELS.values(),
|
||||
)
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
import json
|
||||
import traceback
|
||||
import aiohttp
|
||||
from fastapi import BackgroundTasks, HTTPException
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
SUPPORTED_OLLAMA_MODELS,
|
||||
)
|
||||
from api.routers.presentation.models import OllamaModelStatusResponse
|
||||
from api.services.instances import REDIS_SERVICE
|
||||
from api.services.logging import LoggingService
|
||||
from api.utils.model_utils import (
|
||||
get_llm_provider_url_or,
|
||||
list_pulled_ollama_models,
|
||||
pull_ollama_model,
|
||||
)
|
||||
|
||||
|
||||
class PullOllamaModelHandler:
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
async def get(
|
||||
self,
|
||||
logging_service: LoggingService,
|
||||
log_metadata: LogMetadata,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.name),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
if self.name not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Model {self.name} is not supported",
|
||||
)
|
||||
|
||||
try:
|
||||
pulled_models = await list_pulled_ollama_models()
|
||||
filtered_models = [
|
||||
model for model in pulled_models if model.name == self.name
|
||||
]
|
||||
if filtered_models:
|
||||
return filtered_models[0]
|
||||
except HTTPException as e:
|
||||
logging_service.logger.warning(
|
||||
logging_service.message(e.detail),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logging_service.logger.warning(
|
||||
f"Failed to check pulled models: {e}",
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to check pulled models: {e}",
|
||||
)
|
||||
|
||||
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{self.name}")
|
||||
|
||||
# If the model is being pulled, return the model
|
||||
if saved_model_status:
|
||||
saved_model_status_json = json.loads(saved_model_status)
|
||||
# If the model is being pulled, return the model
|
||||
# ? If the model status is pulled in redis but was not found while listing pulled models,
|
||||
# ? it means the model was deleted and we need to pull it again
|
||||
if (
|
||||
saved_model_status_json["status"] == "error"
|
||||
or saved_model_status_json["status"] == "pulled"
|
||||
):
|
||||
REDIS_SERVICE.delete(f"ollama_models/{self.name}")
|
||||
else:
|
||||
return saved_model_status_json
|
||||
|
||||
# If the model is not being pulled, pull the model
|
||||
background_tasks.add_task(self.pull_model_in_background)
|
||||
|
||||
return OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
|
||||
async def pull_model_in_background(self):
|
||||
await self.pull_model()
|
||||
|
||||
async def pull_model(self):
|
||||
saved_model_status = OllamaModelStatusResponse(
|
||||
name=self.name,
|
||||
status="pulling",
|
||||
done=False,
|
||||
)
|
||||
log_event_count = 0
|
||||
|
||||
try:
|
||||
async for event in pull_ollama_model(self.name):
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
|
||||
REDIS_SERVICE.set(
|
||||
f"ollama_models/{self.name}",
|
||||
json.dumps(saved_model_status.model_dump(mode="json")),
|
||||
)
|
||||
|
||||
return saved_model_status
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import SavePresentationLayoutsRequest
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class SavePresentationLayoutsHandler:
|
||||
|
||||
def __init__(self, data: SavePresentationLayoutsRequest):
|
||||
self.data = data
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
for layout in self.data.layouts:
|
||||
sql_session.merge(layout)
|
||||
sql_session.commit()
|
||||
|
||||
return self.data
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import SaveSlideLayoutsRequest
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class SaveSlideLayoutsHandler:
|
||||
|
||||
def __init__(self, data: SaveSlideLayoutsRequest):
|
||||
self.data = data
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
for layout in self.data.layouts:
|
||||
sql_session.merge(layout)
|
||||
sql_session.commit()
|
||||
|
||||
return self.data
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
PresentationAndPaths,
|
||||
SearchIconRequest,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from image_processor.icons_finder import get_icons
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
|
||||
|
||||
class SearchIconHandler:
|
||||
|
||||
def __init__(self, data: SearchIconRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
vector_store = get_icons_vectorstore()
|
||||
|
||||
icon_paths = await get_icons(
|
||||
vector_store,
|
||||
self.data.query or "",
|
||||
self.data.page,
|
||||
self.data.limit,
|
||||
self.data.category,
|
||||
self.temp_dir,
|
||||
)
|
||||
|
||||
response = PresentationAndPaths(
|
||||
presentation_id=self.data.presentation_id, paths=icon_paths
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
import uuid
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationAndUrls, SearchImageRequest
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class SearchImageHandler:
|
||||
|
||||
def __init__(self, data: SearchImageRequest):
|
||||
self.data = data
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
response = PresentationAndUrls(
|
||||
presentation_id=self.data.presentation_id, urls=[]
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
from fastapi import UploadFile
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
class UpdateParsedDocumentHandler:
|
||||
|
||||
def __init__(self, file_path: str, file: UploadFile):
|
||||
self.file_path = file_path
|
||||
self.file = file
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message({"path": self.file_path, "file": self.file}),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with open(self.file_path, "wb") as f:
|
||||
f.write(await self.file.read())
|
||||
|
||||
return {"message": "File saved successfully"}
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import UpdatePresentationThemeRequest
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PreferencesSqlModel, PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
|
||||
|
||||
class UpdatePresentationThemeHandler:
|
||||
|
||||
def __init__(self, data: UpdatePresentationThemeRequest):
|
||||
self.data = data
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(
|
||||
PresentationSqlModel, self.data.presentation_id
|
||||
)
|
||||
preferences = sql_session.get(PreferencesSqlModel, 0)
|
||||
|
||||
if not preferences:
|
||||
preferences = PreferencesSqlModel(id=0, theme=None)
|
||||
sql_session.add(preferences)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(preferences)
|
||||
|
||||
if self.data.theme:
|
||||
theme_name = self.data.theme.get("name", None)
|
||||
if theme_name and theme_name.lower() == "custom":
|
||||
preferences.theme = self.data.theme
|
||||
|
||||
presentation.theme = self.data.theme
|
||||
sql_session.commit()
|
||||
|
||||
return {"message": "Theme updated successfully"}
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import os
|
||||
from typing import List
|
||||
from urllib.parse import unquote, urlparse
|
||||
import uuid
|
||||
|
||||
from sqlmodel import delete
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import (
|
||||
PresentationUpdateRequest,
|
||||
PresentationAndSlides,
|
||||
)
|
||||
from api.services.logging import LoggingService
|
||||
from api.sql_models import PresentationSqlModel, SlideSqlModel
|
||||
from api.utils.utils import download_files, get_presentation_dir, replace_file_name
|
||||
from api.services.database import get_sql_session
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
|
||||
|
||||
class UpdateSlideModelsHandler:
|
||||
|
||||
def __init__(self, data: PresentationUpdateRequest):
|
||||
self.data = data
|
||||
self.presentation_id = data.presentation_id
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir()
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(self.data.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
presentation_id = self.data.presentation_id
|
||||
new_slides = self.data.slides
|
||||
|
||||
# Handle images
|
||||
images_local_paths = []
|
||||
images_download_links = []
|
||||
for new_slide in new_slides:
|
||||
new_images = new_slide.images or []
|
||||
for i, image in enumerate(new_images):
|
||||
if image.startswith("http"):
|
||||
parsed_url = unquote(urlparse(image).path)
|
||||
image_name = replace_file_name(
|
||||
os.path.basename(parsed_url), str(uuid.uuid4())
|
||||
)
|
||||
image_path = f"{self.presentation_dir}/images/{image_name}"
|
||||
images_local_paths.append(image_path)
|
||||
images_download_links.append(image)
|
||||
new_slide.images[i] = image_path
|
||||
|
||||
if images_download_links:
|
||||
await download_files(images_download_links, images_local_paths)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
slide_sql_models = [
|
||||
SlideSqlModel(**each.model_dump(mode="json")) for each in new_slides
|
||||
]
|
||||
to_update_slides_ids = [each.id for each in slide_sql_models]
|
||||
sql_session.exec(
|
||||
delete(SlideSqlModel).where(SlideSqlModel.id.in_(to_update_slides_ids))
|
||||
)
|
||||
sql_session.add_all(slide_sql_models)
|
||||
sql_session.commit()
|
||||
for each in slide_sql_models:
|
||||
sql_session.refresh(each)
|
||||
presentation = sql_session.get(PresentationSqlModel, presentation_id)
|
||||
|
||||
response = PresentationAndSlides(
|
||||
presentation=presentation, slides=slide_sql_models
|
||||
)
|
||||
response = response.to_response_dict()
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
return response
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
from typing import List, Optional
|
||||
import uuid
|
||||
from fastapi import UploadFile
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import DocumentsAndImagesPath
|
||||
from api.services.logging import LoggingService
|
||||
from api.validators import validate_files
|
||||
from document_processor.loader import UPLOAD_ACCEPTED_DOCUMENTS
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
|
||||
|
||||
class UploadFilesHandler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
documents: Optional[List[UploadFile]],
|
||||
images: Optional[List[UploadFile]],
|
||||
):
|
||||
self.documents = documents
|
||||
self.images = images
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
print("Upload Temp Dir: " + self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(
|
||||
{
|
||||
"documents": self.documents,
|
||||
"images": self.images,
|
||||
}
|
||||
),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
validate_files(self.documents, True, True, 50, UPLOAD_ACCEPTED_DOCUMENTS)
|
||||
validate_files(
|
||||
self.images, True, True, 10, ["image/jpeg", "image/png", "image/webp"]
|
||||
)
|
||||
|
||||
self.documents = self.documents or []
|
||||
self.images = self.images or []
|
||||
|
||||
temp_documents: List[str] = []
|
||||
if self.documents or self.images:
|
||||
all_documents = self.documents + self.images
|
||||
for doc in all_documents:
|
||||
temp_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
doc.filename, self.temp_dir
|
||||
)
|
||||
with open(temp_path, "wb") as f:
|
||||
content = await doc.read()
|
||||
f.write(content)
|
||||
|
||||
temp_documents.append(temp_path)
|
||||
|
||||
documents_count = len(temp_documents)
|
||||
response = DocumentsAndImagesPath(
|
||||
documents=temp_documents[:documents_count],
|
||||
images=temp_documents[documents_count:],
|
||||
)
|
||||
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from fastapi import UploadFile
|
||||
|
||||
from api.models import LogMetadata
|
||||
from api.routers.presentation.models import PresentationAndPath
|
||||
from api.services.logging import LoggingService
|
||||
from api.services.instances import TEMP_FILE_SERVICE
|
||||
from api.sql_models import PresentationSqlModel
|
||||
from api.services.database import get_sql_session
|
||||
from api.utils.utils import get_presentation_dir
|
||||
|
||||
|
||||
class UploadPresentationThumbnailHandler:
|
||||
|
||||
def __init__(self, presentation_id: str, thumbnail: UploadFile):
|
||||
self.presentation_id = presentation_id
|
||||
self.thumbnail = thumbnail
|
||||
|
||||
self.session = str(uuid.uuid4())
|
||||
self.temp_dir = TEMP_FILE_SERVICE.create_temp_dir(self.session)
|
||||
|
||||
self.presentation_dir = get_presentation_dir(self.presentation_id)
|
||||
|
||||
def __del__(self):
|
||||
TEMP_FILE_SERVICE.cleanup_temp_dir(self.temp_dir)
|
||||
|
||||
async def post(self, logging_service: LoggingService, log_metadata: LogMetadata):
|
||||
logging_service.logger.info(
|
||||
logging_service.message(
|
||||
{
|
||||
"presentation_id": self.presentation_id,
|
||||
"thumbnail": self.thumbnail,
|
||||
}
|
||||
),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
with get_sql_session() as sql_session:
|
||||
presentation = sql_session.get(PresentationSqlModel, self.presentation_id)
|
||||
|
||||
with open(os.path.join(self.presentation_dir, "thumbnail.jpg"), "wb") as f:
|
||||
f.write(await self.thumbnail.read())
|
||||
|
||||
presentation.thumbnail = os.path.join(
|
||||
self.presentation_dir, "thumbnail.jpg"
|
||||
)
|
||||
sql_session.commit()
|
||||
sql_session.refresh(presentation)
|
||||
|
||||
response = PresentationAndPath(
|
||||
presentation_id=self.presentation_id, path=presentation.thumbnail
|
||||
)
|
||||
logging_service.logger.info(
|
||||
logging_service.message(response.model_dump(mode="json")),
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from api.models import SSEStatusResponse
|
||||
from api.utils.utils import get_presentation_images_dir
|
||||
from image_processor.icons_finder import get_icon
|
||||
from image_processor.icons_vectorstore_utils import get_icons_vectorstore
|
||||
from image_processor.images_finder import generate_image
|
||||
from ppt_generator.models.slide_model import SlideModel
|
||||
from ppt_generator.slide_model_utils import SlideModelUtils
|
||||
|
||||
|
||||
class FetchAssetsOnPresentationGenerationMixin:
|
||||
|
||||
async def fetch_slide_assets(self, slide_models: List[SlideModel]):
|
||||
image_prompts = []
|
||||
icon_queries = []
|
||||
|
||||
for each_slide_model in slide_models:
|
||||
slide_model_utils = SlideModelUtils(self.theme, each_slide_model)
|
||||
image_prompts.extend(slide_model_utils.get_image_prompts())
|
||||
icon_queries.extend(slide_model_utils.get_icon_queries())
|
||||
|
||||
if icon_queries:
|
||||
icon_vector_store = get_icons_vectorstore()
|
||||
|
||||
images_directory = get_presentation_images_dir(self.presentation_id)
|
||||
|
||||
coroutines = [
|
||||
generate_image(
|
||||
each,
|
||||
images_directory,
|
||||
)
|
||||
for each in image_prompts
|
||||
] + [get_icon(icon_vector_store, each) for each in icon_queries]
|
||||
|
||||
assets_future = asyncio.gather(*coroutines)
|
||||
|
||||
while not assets_future.done():
|
||||
status = SSEStatusResponse(status="Fetching slide assets").to_string()
|
||||
yield status
|
||||
await asyncio.sleep(5)
|
||||
|
||||
assets = await assets_future
|
||||
|
||||
image_prompts_len = len(image_prompts)
|
||||
|
||||
images = assets[:image_prompts_len]
|
||||
icons = assets[image_prompts_len:]
|
||||
|
||||
for each_slide_model in slide_models:
|
||||
each_slide_model.images = images[: each_slide_model.images_count]
|
||||
images = images[each_slide_model.images_count :]
|
||||
|
||||
each_slide_model.icons = icons[: each_slide_model.icons_count]
|
||||
icons = icons[each_slide_model.icons_count :]
|
||||
|
||||
yield SSEStatusResponse(status="Slide assets fetched").to_string()
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
import os
|
||||
from urllib.parse import unquote, urlparse
|
||||
import uuid
|
||||
from api.utils.utils import download_files, replace_file_name
|
||||
from ppt_generator.models.pptx_models import PptxPictureBoxModel
|
||||
|
||||
|
||||
class FetchPresentationAssetsMixin:
|
||||
|
||||
async def fetch_presentation_assets(self):
|
||||
image_urls = []
|
||||
image_local_paths = []
|
||||
|
||||
for each_slide in self.data.pptx_model.slides:
|
||||
for each_shape in each_slide.shapes:
|
||||
if isinstance(each_shape, PptxPictureBoxModel):
|
||||
image_path = each_shape.picture.path
|
||||
if image_path.startswith("http"):
|
||||
if image_path.startswith("http://localhost:3000/static"):
|
||||
image_path = image_path.replace(
|
||||
"http://localhost:3000/static", ""
|
||||
)
|
||||
image_path = "/app" + image_path
|
||||
elif image_path.startswith("http://localhost/static"):
|
||||
image_path = image_path.replace(
|
||||
"http://localhost/static", ""
|
||||
)
|
||||
image_path = "/app" + image_path
|
||||
else:
|
||||
image_urls.append(image_path)
|
||||
parsed_url = unquote(urlparse(image_path).path)
|
||||
image_name = replace_file_name(
|
||||
os.path.basename(parsed_url), str(uuid.uuid4())
|
||||
)
|
||||
image_path = os.path.join(self.temp_dir, image_name)
|
||||
image_local_paths.append(image_path)
|
||||
elif image_path.startswith("file://"):
|
||||
image_path = image_path.replace("file:///", "")
|
||||
# Check if it's a Windows path (has colon at index 1)
|
||||
if not (len(image_path) > 1 and image_path[1] == ":"):
|
||||
image_path = "/" + image_path
|
||||
|
||||
each_shape.picture.path = image_path
|
||||
each_shape.picture.is_network = False
|
||||
|
||||
await download_files(image_urls, image_local_paths)
|
||||
|
|
@ -1,211 +0,0 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.models import OllamaModelMetadata
|
||||
from ppt_config_generator.models import SlideMarkdownModel
|
||||
from ppt_generator.models.pptx_models import PptxPresentationModel
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
IconCategoryEnum,
|
||||
ImagePromptWithThemeAndAspectRatio,
|
||||
)
|
||||
from api.sql_models import (
|
||||
PresentationLayoutSqlModel,
|
||||
PresentationSqlModel,
|
||||
SlideLayoutSqlModel,
|
||||
SlideSqlModel,
|
||||
)
|
||||
from ollama._types import ModelDetails
|
||||
|
||||
|
||||
class ThemeEnum(Enum):
|
||||
DARK = "dark"
|
||||
LIGHT = "light"
|
||||
ROYAL_BLUE = "royal_blue"
|
||||
CREAM = "cream"
|
||||
LIGHT_RED = "light_red"
|
||||
DARK_PINK = "dark_pink"
|
||||
FAINT_YELLOW = "faint_yellow"
|
||||
|
||||
|
||||
class DocumentsAndImagesPath(BaseModel):
|
||||
documents: Optional[List[str]] = None
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GenerateResearchReportRequest(BaseModel):
|
||||
language: Optional[str] = None
|
||||
query: str
|
||||
|
||||
|
||||
class DecomposeDocumentsRequest(DocumentsAndImagesPath):
|
||||
pass
|
||||
|
||||
|
||||
class GeneratePresentationRequirementsRequest(BaseModel):
|
||||
prompt: Optional[str] = None
|
||||
n_slides: int
|
||||
language: str
|
||||
documents: Optional[List[str]] = None
|
||||
research_reports: Optional[List[str]] = None
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GenerateOutlinesRequest(BaseModel):
|
||||
presentation_id: str
|
||||
|
||||
|
||||
class PresentationGenerateRequest(BaseModel):
|
||||
presentation_id: str
|
||||
theme: Optional[dict] = None
|
||||
images: Optional[List[str]] = None
|
||||
outlines: List[SlideMarkdownModel]
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
|
||||
presentation_id: str
|
||||
prompt: ImagePromptWithThemeAndAspectRatio
|
||||
|
||||
|
||||
class SearchImageRequest(BaseModel):
|
||||
presentation_id: str
|
||||
query: Optional[str] = None
|
||||
page: int = 1
|
||||
limit: int = 10
|
||||
|
||||
|
||||
class SearchIconRequest(BaseModel):
|
||||
presentation_id: str
|
||||
query: Optional[str] = None
|
||||
category: Optional[IconCategoryEnum] = None
|
||||
page: int = 1
|
||||
limit: int = 10
|
||||
|
||||
|
||||
class SlideEditRequest(BaseModel):
|
||||
index: int
|
||||
prompt: str
|
||||
|
||||
|
||||
class EditPresentationRequest(BaseModel):
|
||||
presentation_id: str
|
||||
watermark: bool = True
|
||||
changes: List[SlideEditRequest]
|
||||
|
||||
|
||||
class EditPresentationSlideRequest(BaseModel):
|
||||
presentation_id: str
|
||||
index: int
|
||||
prompt: str
|
||||
|
||||
|
||||
class UpdatePresentationThemeRequest(BaseModel):
|
||||
presentation_id: str
|
||||
theme: Optional[dict] = None
|
||||
|
||||
|
||||
class ExportAsRequest(BaseModel):
|
||||
presentation_id: str
|
||||
pptx_model: PptxPresentationModel
|
||||
|
||||
|
||||
class DecomposeDocumentsResponse(BaseModel):
|
||||
documents: dict
|
||||
|
||||
|
||||
class PresentationAndSlides(BaseModel):
|
||||
presentation: PresentationSqlModel
|
||||
slides: List[SlideSqlModel]
|
||||
|
||||
def to_response_dict(self):
|
||||
presentation = self.presentation.model_dump(mode="json")
|
||||
return {
|
||||
"presentation": presentation,
|
||||
"slides": [each.model_dump(mode="json") for each in self.slides],
|
||||
}
|
||||
|
||||
|
||||
class PresentationUpdateRequest(BaseModel):
|
||||
presentation_id: str
|
||||
slides: List[SlideSqlModel]
|
||||
|
||||
|
||||
class PresentationAndUrl(BaseModel):
|
||||
presentation_id: str
|
||||
url: str
|
||||
|
||||
|
||||
class PresentationAndUrls(BaseModel):
|
||||
presentation_id: str
|
||||
urls: List[str]
|
||||
|
||||
|
||||
class PresentationAndPath(BaseModel):
|
||||
presentation_id: str
|
||||
path: str
|
||||
|
||||
|
||||
class PresentationAndPaths(BaseModel):
|
||||
presentation_id: str
|
||||
paths: List[str]
|
||||
|
||||
|
||||
class PresentationPathAndEditPath(PresentationAndPath):
|
||||
edit_path: str
|
||||
|
||||
|
||||
class UpdatePresentationTitlesRequest(BaseModel):
|
||||
presentation_id: str
|
||||
titles: List[str]
|
||||
|
||||
|
||||
class GeneratePresentationRequest(BaseModel):
|
||||
prompt: str
|
||||
n_slides: int = Field(default=8, ge=5, le=15)
|
||||
language: str = Field(default="English")
|
||||
theme: ThemeEnum = Field(default=ThemeEnum.LIGHT)
|
||||
documents: Optional[List[UploadFile]] = None
|
||||
export_as: Literal["pptx", "pdf"] = Field(default="pptx")
|
||||
|
||||
|
||||
class OllamaModelStatusResponse(BaseModel):
|
||||
name: str
|
||||
size: Optional[int] = None
|
||||
downloaded: Optional[int] = None
|
||||
status: str
|
||||
done: bool
|
||||
|
||||
|
||||
class OllamaSupportedModelsResponse(BaseModel):
|
||||
models: List[OllamaModelMetadata]
|
||||
|
||||
|
||||
class PresentationWithOneSlide(BaseModel):
|
||||
id: str
|
||||
created_at: datetime
|
||||
theme: Optional[dict] = None
|
||||
title: Optional[str] = None
|
||||
slide: SlideSqlModel
|
||||
|
||||
@classmethod
|
||||
def from_presentation_and_slide(
|
||||
cls, presentation: PresentationSqlModel, slide: SlideSqlModel
|
||||
):
|
||||
return cls(
|
||||
id=presentation.id,
|
||||
created_at=presentation.created_at,
|
||||
theme=presentation.theme,
|
||||
title=presentation.title,
|
||||
slide=slide,
|
||||
)
|
||||
|
||||
|
||||
class SaveSlideLayoutsRequest(BaseModel):
|
||||
layouts: List[SlideLayoutSqlModel]
|
||||
|
||||
|
||||
class SavePresentationLayoutsRequest(BaseModel):
|
||||
layouts: List[PresentationLayoutSqlModel]
|
||||
|
|
@ -1,470 +0,0 @@
|
|||
from typing import Annotated, List, Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, File, Form, UploadFile
|
||||
|
||||
from api.models import SessionModel
|
||||
from api.request_utils import RequestUtils
|
||||
from api.routers.presentation.handlers.decompose_documents import (
|
||||
DecomposeDocumentsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.delete_presentation import (
|
||||
DeletePresentationHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.delete_slide import DeleteSlideHandler
|
||||
from api.routers.presentation.handlers.edit import PresentationEditHandler
|
||||
from api.routers.presentation.handlers.export_as_pptx import ExportAsPptxHandler
|
||||
from api.routers.presentation.handlers.generate_data import (
|
||||
PresentationGenerateDataHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_image import GenerateImageHandler
|
||||
from api.routers.presentation.handlers.generate_presentation import (
|
||||
GeneratePresentationHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_presentation_requirements import (
|
||||
GeneratePresentationRequirementsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_research_report import (
|
||||
GenerateResearchReportHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_stream import (
|
||||
PresentationGenerateStreamHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.generate_outlines import (
|
||||
PresentationOutlinesGenerateHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.get_presentation import GetPresentationHandler
|
||||
from api.routers.presentation.handlers.get_presentations import GetPresentationsHandler
|
||||
from api.routers.presentation.handlers.list_available_custom_models import (
|
||||
ListAvailableCustomModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.list_ollama_pulled_models import (
|
||||
ListPulledOllamaModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.list_presentation_layouts import (
|
||||
ListPresentationLayoutsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.list_slide_layouts import ListSlideLayoutsHandler
|
||||
from api.routers.presentation.handlers.list_supported_ollama_models import (
|
||||
ListSupportedOllamaModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.pull_ollama_model import PullOllamaModelHandler
|
||||
from api.routers.presentation.handlers.save_presentation_layouts_handler import (
|
||||
SavePresentationLayoutsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.save_slide_layouts_handler import (
|
||||
SaveSlideLayoutsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.search_icon import SearchIconHandler
|
||||
from api.routers.presentation.handlers.search_image import SearchImageHandler
|
||||
from api.routers.presentation.handlers.update_parsed_document import (
|
||||
UpdateParsedDocumentHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.update_presentation_theme import (
|
||||
UpdatePresentationThemeHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.update_slide_models import (
|
||||
UpdateSlideModelsHandler,
|
||||
)
|
||||
from api.routers.presentation.handlers.upload_files import UploadFilesHandler
|
||||
from api.routers.presentation.handlers.upload_presentation_thumbnail import (
|
||||
UploadPresentationThumbnailHandler,
|
||||
)
|
||||
from api.routers.presentation.models import (
|
||||
DecomposeDocumentsRequest,
|
||||
DecomposeDocumentsResponse,
|
||||
DocumentsAndImagesPath,
|
||||
EditPresentationSlideRequest,
|
||||
ExportAsRequest,
|
||||
GenerateImageRequest,
|
||||
GeneratePresentationRequest,
|
||||
GeneratePresentationRequirementsRequest,
|
||||
GenerateResearchReportRequest,
|
||||
OllamaModelStatusResponse,
|
||||
OllamaSupportedModelsResponse,
|
||||
PresentationAndPath,
|
||||
PresentationAndPaths,
|
||||
PresentationAndSlides,
|
||||
GenerateOutlinesRequest,
|
||||
PresentationAndUrls,
|
||||
PresentationGenerateRequest,
|
||||
PresentationPathAndEditPath,
|
||||
SavePresentationLayoutsRequest,
|
||||
SaveSlideLayoutsRequest,
|
||||
SearchIconRequest,
|
||||
SearchImageRequest,
|
||||
UpdatePresentationThemeRequest,
|
||||
PresentationUpdateRequest,
|
||||
PresentationWithOneSlide,
|
||||
)
|
||||
from api.sql_models import (
|
||||
PresentationLayoutSqlModel,
|
||||
PresentationSqlModel,
|
||||
SlideLayoutSqlModel,
|
||||
SlideSqlModel,
|
||||
)
|
||||
from api.utils.utils import handle_errors
|
||||
|
||||
route_prefix = "/api/v1/ppt"
|
||||
presentation_router = APIRouter(prefix=route_prefix)
|
||||
|
||||
|
||||
@presentation_router.get(
|
||||
"/user_presentations", response_model=List[PresentationWithOneSlide]
|
||||
)
|
||||
async def get_user_presentations():
|
||||
request_utils = RequestUtils(f"{route_prefix}/user_presentations")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
GetPresentationsHandler().get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get("/presentation", response_model=PresentationAndSlides)
|
||||
async def get_presentation_from_id(presentation_id: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GetPresentationHandler(presentation_id).get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/files/upload", response_model=DocumentsAndImagesPath)
|
||||
async def upload_files(
|
||||
documents: Annotated[Optional[List[UploadFile]], File()] = None,
|
||||
images: Annotated[Optional[List[UploadFile]], File()] = None,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/files/upload")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
UploadFilesHandler(documents, images).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/report/generate", response_model=str)
|
||||
async def generate_research_report(
|
||||
data: GenerateResearchReportRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/report/generate")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
GenerateResearchReportHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/files/decompose", response_model=DecomposeDocumentsResponse)
|
||||
async def decompose_documents(data: DecomposeDocumentsRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/files/decompose")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
DecomposeDocumentsHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/document/update")
|
||||
async def update_document(
|
||||
path: Annotated[str, Body()],
|
||||
file: Annotated[UploadFile, File()],
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/document/update")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
UpdateParsedDocumentHandler(path, file).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/create", response_model=PresentationSqlModel)
|
||||
async def create_presentation(
|
||||
data: GeneratePresentationRequirementsRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/create")
|
||||
presentation_id = str(uuid.uuid4())
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GeneratePresentationRequirementsHandler(presentation_id, data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/outlines/generate", response_model=PresentationSqlModel)
|
||||
async def generate_outlines(data: GenerateOutlinesRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/outlines/generate")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationOutlinesGenerateHandler(data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/generate/data", response_model=SessionModel)
|
||||
async def submit_presentation_generation_data(
|
||||
data: PresentationGenerateRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/generate/data")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationGenerateDataHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get("/generate/stream")
|
||||
async def presentation_generation_stream(presentation_id: str, session: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/generate/stream")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationGenerateStreamHandler(presentation_id, session).get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/presentation/thumbnail", response_model=PresentationAndPath)
|
||||
async def update_presentation(
|
||||
presentation_id: Annotated[str, Body()],
|
||||
thumbnail: Annotated[UploadFile, File()],
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation/thumbnail")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
UploadPresentationThumbnailHandler(presentation_id, thumbnail).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/presentation/theme")
|
||||
async def update_presentation(
|
||||
data: UpdatePresentationThemeRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation/theme")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
UpdatePresentationThemeHandler(data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/edit", response_model=SlideSqlModel)
|
||||
async def update_presentation(
|
||||
data: EditPresentationSlideRequest,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/edit")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id
|
||||
)
|
||||
return await handle_errors(
|
||||
PresentationEditHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/slides/update", response_model=PresentationAndSlides)
|
||||
async def update_slide_models(data: PresentationUpdateRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/slides/update")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
UpdateSlideModelsHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/image/generate", response_model=PresentationAndPaths)
|
||||
async def generate_image(data: GenerateImageRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/image/generate")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GenerateImageHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/image/search", response_model=PresentationAndUrls)
|
||||
async def search_image(data: SearchImageRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/image/search")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
SearchImageHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/icon/search", response_model=PresentationAndPaths)
|
||||
async def search_icon(data: SearchIconRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/icon/search")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
SearchIconHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post(
|
||||
"/presentation/export_as_pptx", response_model=PresentationAndPath
|
||||
)
|
||||
async def export_as_pptx(data: ExportAsRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/presentation/export_as_pptx")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=data.presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
ExportAsPptxHandler(data).post, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.delete("/delete", status_code=204)
|
||||
async def delete_presentation(presentation_id: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/delete")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
DeletePresentationHandler(presentation_id).delete, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.delete("/slide/delete", status_code=204)
|
||||
async def delete_slide(slide_id: str, presentation_id: str):
|
||||
request_utils = RequestUtils(f"{route_prefix}/slide/delete")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
DeleteSlideHandler(slide_id).delete, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post(
|
||||
"/generate/presentation", response_model=PresentationPathAndEditPath
|
||||
)
|
||||
async def generate_presentation(data: Annotated[GeneratePresentationRequest, Form()]):
|
||||
presentation_id = str(uuid.uuid4())
|
||||
|
||||
request_utils = RequestUtils(f"{route_prefix}/generate/presentation")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger(
|
||||
presentation_id=presentation_id,
|
||||
)
|
||||
return await handle_errors(
|
||||
GeneratePresentationHandler(presentation_id, data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Ollama Support
|
||||
@presentation_router.get(
|
||||
"/ollama/list-supported-models", response_model=OllamaSupportedModelsResponse
|
||||
)
|
||||
async def list_supported_ollama_models():
|
||||
request_utils = RequestUtils(f"{route_prefix}/ollama/list-supported-models")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListSupportedOllamaModelsHandler().get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get(
|
||||
"/ollama/list-pulled-models", response_model=List[OllamaModelStatusResponse]
|
||||
)
|
||||
async def list_pulled_ollama_models():
|
||||
request_utils = RequestUtils(f"{route_prefix}/ollama/list-pulled-models")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListPulledOllamaModelsHandler().get, logging_service, log_metadata
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get("/ollama/pull-model", response_model=OllamaModelStatusResponse)
|
||||
async def pull_ollama_model(name: str, background_tasks: BackgroundTasks):
|
||||
request_utils = RequestUtils(f"{route_prefix}/ollama/pull-model")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
PullOllamaModelHandler(name).get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
background_tasks=background_tasks,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/models/list/custom", response_model=List[str])
|
||||
async def list_custom_models(
|
||||
url: Annotated[Optional[str], Body()] = None,
|
||||
api_key: Annotated[Optional[str], Body()] = None,
|
||||
):
|
||||
request_utils = RequestUtils(f"{route_prefix}/models/list/custom")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListAvailableCustomModelsHandler(url, api_key).get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get(
|
||||
"/layout/slides/list", response_model=List[SlideLayoutSqlModel]
|
||||
)
|
||||
async def list_slide_layouts():
|
||||
request_utils = RequestUtils(f"{route_prefix}/layout/slides/list")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListSlideLayoutsHandler().get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/layout/slides/save")
|
||||
async def save_slide_layouts(data: SaveSlideLayoutsRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/layout/slides/save")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
SaveSlideLayoutsHandler(data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.get(
|
||||
"/layout/presentations/list", response_model=List[PresentationLayoutSqlModel]
|
||||
)
|
||||
async def list_presentation_layouts():
|
||||
request_utils = RequestUtils(f"{route_prefix}/layout/presentations/list")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
ListPresentationLayoutsHandler().get,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
|
||||
@presentation_router.post("/layout/presentations/save")
|
||||
async def save_presentation_layouts(data: SavePresentationLayoutsRequest):
|
||||
request_utils = RequestUtils(f"{route_prefix}/layout/presentations/save")
|
||||
logging_service, log_metadata = await request_utils.initialize_logger()
|
||||
return await handle_errors(
|
||||
SavePresentationLayoutsHandler(data).post,
|
||||
logging_service,
|
||||
log_metadata,
|
||||
)
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
from contextlib import contextmanager
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
database_url = os.getenv("DATABASE_URL") or "sqlite:///" + os.path.join(
|
||||
os.getenv("APP_DATA_DIRECTORY"), "fastapi.db"
|
||||
)
|
||||
connect_args = {}
|
||||
if "sqlite" in database_url:
|
||||
connect_args["check_same_thread"] = False
|
||||
|
||||
sql_engine = create_engine(database_url, connect_args=connect_args)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_sql_session():
|
||||
session = Session(sql_engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
from api.services.redis import RedisService
|
||||
from api.services.temp_file import TempFileService
|
||||
|
||||
|
||||
TEMP_FILE_SERVICE = TempFileService()
|
||||
REDIS_SERVICE = RedisService()
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
from typing import Any
|
||||
from logging import Logger
|
||||
|
||||
|
||||
class LoggingService:
|
||||
|
||||
def __init__(self, stream_name: str):
|
||||
self._logger = Logger(stream_name)
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
return self._logger
|
||||
|
||||
def message(self, msg: Any):
|
||||
return {"msg": msg}
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
import os
|
||||
from typing import Any, Optional
|
||||
import redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
|
||||
class RedisService:
|
||||
def __init__(self):
|
||||
self.redis_host = os.getenv("REDIS_HOST", "localhost")
|
||||
self.redis_port = int(os.getenv("REDIS_PORT", "6379"))
|
||||
self.redis_db = int(os.getenv("REDIS_DB", "0"))
|
||||
self.redis_password = os.getenv("REDIS_PASSWORD")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self) -> redis.Redis:
|
||||
return redis.Redis(
|
||||
host=self.redis_host,
|
||||
port=self.redis_port,
|
||||
db=self.redis_db,
|
||||
password=self.redis_password,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
|
||||
try:
|
||||
return self.client.set(key, value, ex=expire)
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
try:
|
||||
return self.client.get(key)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
try:
|
||||
return bool(self.client.delete(key))
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
try:
|
||||
return bool(self.client.exists(key))
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def set_hash(self, name: str, mapping: dict) -> bool:
|
||||
try:
|
||||
return self.client.hmset(name, mapping)
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get_hash(self, name: str) -> Optional[dict]:
|
||||
try:
|
||||
return self.client.hgetall(name)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def delete_hash(self, name: str, *fields: str) -> int:
|
||||
try:
|
||||
return self.client.hdel(name, *fields)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def set_list(self, name: str, values: list) -> bool:
|
||||
try:
|
||||
self.client.delete(name)
|
||||
if values:
|
||||
self.client.rpush(name, *values)
|
||||
return True
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
|
||||
try:
|
||||
return self.client.lrange(name, start, end)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def add_to_set(self, name: str, *values: str) -> int:
|
||||
try:
|
||||
return self.client.sadd(name, *values)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def get_set(self, name: str) -> Optional[set]:
|
||||
try:
|
||||
return self.client.smembers(name)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def remove_from_set(self, name: str, *values: str) -> int:
|
||||
try:
|
||||
return self.client.srem(name, *values)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def clear(self) -> bool:
|
||||
try:
|
||||
return self.client.flushdb()
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.client.close()
|
||||
except RedisError:
|
||||
pass
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
import os
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class TempFileService:
|
||||
base_dir = os.getenv("TEMP_DIRECTORY")
|
||||
|
||||
def __init__(self):
|
||||
self.cleanup_base_dir()
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
|
||||
def create_dir_in_dir(self, base_dir: str, dir_name: Optional[str] = None) -> str:
|
||||
temp_dir = os.path.join(base_dir, dir_name if dir_name else str(uuid.uuid4()))
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
return temp_dir
|
||||
|
||||
def create_temp_dir(self, dir_name: Optional[str] = None) -> str:
|
||||
return self.create_dir_in_dir(self.base_dir, dir_name)
|
||||
|
||||
def create_temp_file_path(
|
||||
self, file_path: str, dir_path: Optional[str] = None
|
||||
) -> str:
|
||||
if dir_path is None:
|
||||
dir_path = self.base_dir
|
||||
|
||||
full_path = os.path.join(dir_path, file_path)
|
||||
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
return full_path
|
||||
|
||||
def create_temp_file(
|
||||
self, file_path: str, content: Union[bytes, str], dir_path: Optional[str] = None
|
||||
) -> str:
|
||||
file_path = self.create_temp_file_path(file_path, dir_path)
|
||||
mode = "wb" if isinstance(content, bytes) else "w"
|
||||
with open(file_path, mode) as f:
|
||||
f.write(content)
|
||||
|
||||
return file_path
|
||||
|
||||
def read_temp_file(self, file_path: str, binary: bool = True) -> Union[bytes, str]:
|
||||
mode = "rb" if binary else "r"
|
||||
with open(file_path, mode) as f:
|
||||
return f.read()
|
||||
|
||||
def cleanup_temp_file(self, file_path: str):
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
def delete_dir_files(self, dir_path: str):
|
||||
if os.path.exists(dir_path):
|
||||
for root, dirs, files in os.walk(dir_path, topdown=False):
|
||||
for name in files:
|
||||
os.remove(os.path.join(root, name))
|
||||
for name in dirs:
|
||||
os.rmdir(os.path.join(root, name))
|
||||
|
||||
def cleanup_temp_dir(self, dir_path: str):
|
||||
if os.path.exists(dir_path):
|
||||
self.delete_dir_files(dir_path)
|
||||
os.rmdir(dir_path)
|
||||
|
||||
def cleanup_base_dir(self):
|
||||
self.cleanup_temp_dir(self.base_dir)
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from sqlmodel import SQLModel, Field, Column, JSON
|
||||
|
||||
|
||||
def get_random_uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class SlideLayoutSqlModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True)
|
||||
description: Optional[str] = None
|
||||
json_schema: dict = Field(sa_column=Column(JSON, nullable=False))
|
||||
|
||||
|
||||
class PresentationLayoutSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
slide_layouts: List[str] = Field(sa_column=Column(JSON, nullable=False))
|
||||
|
||||
|
||||
class PresentationSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
created_at: datetime = Field(default=datetime.now())
|
||||
prompt: Optional[str] = None
|
||||
n_slides: int
|
||||
theme: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
title: Optional[str] = None
|
||||
structure: Optional[dict] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
notes: Optional[List[str]] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
outlines: Optional[List[dict]] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
language: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
data: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
|
||||
|
||||
class SlideSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
index: int = Field(index=True)
|
||||
layout: str
|
||||
presentation: str
|
||||
content: dict = Field(sa_column=Column(JSON, nullable=False), default=None)
|
||||
properties: Optional[dict] = Field(
|
||||
sa_column=Column(JSON, nullable=True), default=None
|
||||
)
|
||||
|
||||
|
||||
class KeyValueSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
key: str = Field(index=True)
|
||||
value: dict = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
|
||||
|
||||
class PreferencesSqlModel(SQLModel, table=True):
|
||||
id: int = Field(default=0, primary_key=True)
|
||||
theme: Optional[dict] = Field(sa_column=Column(JSON, nullable=True), default=None)
|
||||
|
|
@ -1,179 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
import openai
|
||||
|
||||
from api.models import SelectedLLMProvider
|
||||
from api.routers.presentation.models import OllamaModelStatusResponse
|
||||
|
||||
|
||||
def is_ollama_selected() -> bool:
|
||||
return get_selected_llm_provider() == SelectedLLMProvider.OLLAMA
|
||||
|
||||
|
||||
def is_custom_llm_selected() -> bool:
|
||||
return get_selected_llm_provider() == SelectedLLMProvider.CUSTOM
|
||||
|
||||
|
||||
def get_llm_provider_url_or():
|
||||
llm_provider_url = (
|
||||
os.getenv("OLLAMA_URL") if is_ollama_selected() else os.getenv("CUSTOM_LLM_URL")
|
||||
)
|
||||
llm_provider_url = llm_provider_url or "http://localhost:11434"
|
||||
if llm_provider_url.endswith("/"):
|
||||
return llm_provider_url[:-1]
|
||||
return llm_provider_url
|
||||
|
||||
|
||||
def get_selected_llm_provider() -> SelectedLLMProvider:
|
||||
return SelectedLLMProvider(os.getenv("LLM"))
|
||||
|
||||
|
||||
async def list_available_custom_models(
|
||||
url: Optional[str] = None, api_key: Optional[str] = None
|
||||
) -> list[str]:
|
||||
if not url:
|
||||
client = get_llm_client()
|
||||
else:
|
||||
client = openai.AsyncOpenAI(api_key=api_key or "null", base_url=url)
|
||||
models = []
|
||||
async for model in client.models.list():
|
||||
print(model)
|
||||
models.append(model.id)
|
||||
return models
|
||||
|
||||
|
||||
def get_model_base_url():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "https://api.openai.com/v1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.path.join(get_llm_provider_url_or(), "v1")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return get_llm_provider_url_or()
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM provider")
|
||||
|
||||
|
||||
def get_llm_api_key():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return os.getenv("OPENAI_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return os.getenv("GOOGLE_API_KEY")
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return "ollama"
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_LLM_API_KEY") or "null"
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM API key")
|
||||
|
||||
|
||||
def get_llm_client():
|
||||
client = AsyncOpenAI(
|
||||
base_url=get_model_base_url(),
|
||||
api_key=get_llm_api_key(),
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def get_large_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_small_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-mini"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
def get_nano_model():
|
||||
selected_llm = get_selected_llm_provider()
|
||||
if selected_llm == SelectedLLMProvider.OPENAI:
|
||||
return "gpt-4.1-nano"
|
||||
elif selected_llm == SelectedLLMProvider.GOOGLE:
|
||||
return "gemini-2.0-flash"
|
||||
elif selected_llm == SelectedLLMProvider.OLLAMA:
|
||||
return os.getenv("OLLAMA_MODEL")
|
||||
elif selected_llm == SelectedLLMProvider.CUSTOM:
|
||||
return os.getenv("CUSTOM_MODEL")
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM model")
|
||||
|
||||
|
||||
async def list_pulled_ollama_models() -> list[OllamaModelStatusResponse]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{get_llm_provider_url_or()}/api/tags",
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
pulled_models = await response.json()
|
||||
return [
|
||||
OllamaModelStatusResponse(
|
||||
name=m["model"],
|
||||
size=m["size"],
|
||||
status="pulled",
|
||||
downloaded=m["size"],
|
||||
done=True,
|
||||
)
|
||||
for m in pulled_models["models"]
|
||||
]
|
||||
elif response.status == 403:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Forbidden: Please check your Ollama Configuration",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to list Ollama models: {response.status}",
|
||||
)
|
||||
|
||||
|
||||
async def pull_ollama_model(model: str) -> AsyncGenerator[dict, None]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{get_llm_provider_url_or()}/api/pull",
|
||||
json={"model": model},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Failed to pull model: {await response.text()}",
|
||||
)
|
||||
|
||||
async for line in response.content:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
event = json.loads(line.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
yield event
|
||||
|
|
@ -1,253 +0,0 @@
|
|||
from api.models import OllamaModelMetadata
|
||||
|
||||
|
||||
SUPPORTED_LLAMA_MODELS = {
|
||||
"llama3:8b": OllamaModelMetadata(
|
||||
label="Llama 3:8b",
|
||||
value="llama3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3:70b": OllamaModelMetadata(
|
||||
label="Llama 3:70b",
|
||||
value="llama3:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="40GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3.1:8b": OllamaModelMetadata(
|
||||
label="Llama 3.1:8b",
|
||||
value="llama3.1:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3.1:70b": OllamaModelMetadata(
|
||||
label="Llama 3.1:70b",
|
||||
value="llama3.1:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3.1:405b": OllamaModelMetadata(
|
||||
label="Llama 3.1:405b",
|
||||
value="llama3.1:405b",
|
||||
description="✅ Graphs supported.",
|
||||
size="243GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3.2:1b": OllamaModelMetadata(
|
||||
label="Llama 3.2:1b",
|
||||
value="llama3.2:1b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3.2:3b": OllamaModelMetadata(
|
||||
label="Llama 3.2:3b",
|
||||
value="llama3.2:3b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama3.3:70b": OllamaModelMetadata(
|
||||
label="Llama 3.3:70b",
|
||||
value="llama3.3:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama4:16x17b": OllamaModelMetadata(
|
||||
label="Llama 4:16x17b",
|
||||
value="llama4:16x17b",
|
||||
description="✅ Graphs supported.",
|
||||
size="67GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
"llama4:128x17b": OllamaModelMetadata(
|
||||
label="Llama 4:128x17b",
|
||||
value="llama4:128x17b",
|
||||
description="✅ Graphs supported.",
|
||||
size="245GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/meta.png",
|
||||
),
|
||||
}
|
||||
|
||||
SUPPORTED_GEMMA_MODELS = {
|
||||
"gemma3:1b": OllamaModelMetadata(
|
||||
label="Gemma 3:1b",
|
||||
value="gemma3:1b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="815MB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
),
|
||||
"gemma3:4b": OllamaModelMetadata(
|
||||
label="Gemma 3:4b",
|
||||
value="gemma3:4b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="3.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
),
|
||||
"gemma3:12b": OllamaModelMetadata(
|
||||
label="Gemma 3:12b",
|
||||
value="gemma3:12b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="8.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
),
|
||||
"gemma3:27b": OllamaModelMetadata(
|
||||
label="Gemma 3:27b",
|
||||
value="gemma3:27b",
|
||||
description="✅ Graphs supported.",
|
||||
size="17GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/gemma.png",
|
||||
),
|
||||
}
|
||||
|
||||
SUPPORTED_DEEPSEEK_MODELS = {
|
||||
"deepseek-r1:1.5b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:1.5b",
|
||||
value="deepseek-r1:1.5b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.1GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:7b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:7b",
|
||||
value="deepseek-r1:7b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="4.7GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:8b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:8b",
|
||||
value="deepseek-r1:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:14b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:14b",
|
||||
value="deepseek-r1:14b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="9GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:32b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:32b",
|
||||
value="deepseek-r1:32b",
|
||||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:70b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:70b",
|
||||
value="deepseek-r1:70b",
|
||||
description="✅ Graphs supported.",
|
||||
size="43GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
"deepseek-r1:671b": OllamaModelMetadata(
|
||||
label="DeepSeek R1:671b",
|
||||
value="deepseek-r1:671b",
|
||||
description="✅ Graphs supported.",
|
||||
size="404GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/deepseek.png",
|
||||
),
|
||||
}
|
||||
|
||||
SUPPORTED_QWEN_MODELS = {
|
||||
"qwen3:0.6b": OllamaModelMetadata(
|
||||
label="Qwen 3:0.6b",
|
||||
value="qwen3:0.6b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="523MB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:1.7b": OllamaModelMetadata(
|
||||
label="Qwen 3:1.7b",
|
||||
value="qwen3:1.7b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="1.4GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:4b": OllamaModelMetadata(
|
||||
label="Qwen 3:4b",
|
||||
value="qwen3:4b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="2.6GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:8b": OllamaModelMetadata(
|
||||
label="Qwen 3:8b",
|
||||
value="qwen3:8b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="5.2GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:14b": OllamaModelMetadata(
|
||||
label="Qwen 3:14b",
|
||||
value="qwen3:14b",
|
||||
description="❌ Graphs not supported.",
|
||||
size="9.3GB",
|
||||
supports_graph=False,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:30b": OllamaModelMetadata(
|
||||
label="Qwen 3:30b",
|
||||
value="qwen3:30b",
|
||||
description="✅ Graphs supported.",
|
||||
size="19GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:32b": OllamaModelMetadata(
|
||||
label="Qwen 3:32b",
|
||||
value="qwen3:32b",
|
||||
description="✅ Graphs supported.",
|
||||
size="20GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
"qwen3:235b": OllamaModelMetadata(
|
||||
label="Qwen 3:235b",
|
||||
value="qwen3:235b",
|
||||
description="✅ Graphs supported.",
|
||||
size="142GB",
|
||||
supports_graph=True,
|
||||
icon="/static/servers/fastapi/assets/icons/qwen.png",
|
||||
),
|
||||
}
|
||||
|
||||
SUPPORTED_OLLAMA_MODELS = {
|
||||
**SUPPORTED_LLAMA_MODELS,
|
||||
**SUPPORTED_GEMMA_MODELS,
|
||||
**SUPPORTED_DEEPSEEK_MODELS,
|
||||
**SUPPORTED_QWEN_MODELS,
|
||||
}
|
||||
|
|
@ -1,188 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.models import LogMetadata, UserConfig
|
||||
from api.services.logging import LoggingService
|
||||
|
||||
|
||||
def get_presentation_dir(presentation_id: str) -> str:
|
||||
presentation_dir = os.path.join(os.getenv("APP_DATA_DIRECTORY"), presentation_id)
|
||||
os.makedirs(presentation_dir, exist_ok=True)
|
||||
return presentation_dir
|
||||
|
||||
|
||||
def get_presentation_images_dir(presentation_id: str) -> str:
|
||||
presentation_images_dir = os.path.join(
|
||||
get_presentation_dir(presentation_id), "images"
|
||||
)
|
||||
os.makedirs(presentation_images_dir, exist_ok=True)
|
||||
return presentation_images_dir
|
||||
|
||||
|
||||
def get_user_config():
|
||||
user_config_path = os.getenv("USER_CONFIG_PATH")
|
||||
|
||||
existing_config = UserConfig()
|
||||
try:
|
||||
if os.path.exists(user_config_path):
|
||||
with open(user_config_path, "r") as f:
|
||||
existing_config = UserConfig(**json.load(f))
|
||||
except Exception as e:
|
||||
print("Error while loading user config")
|
||||
pass
|
||||
|
||||
return UserConfig(
|
||||
LLM=existing_config.LLM or os.getenv("LLM"),
|
||||
OPENAI_API_KEY=existing_config.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY"),
|
||||
GOOGLE_API_KEY=existing_config.GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY"),
|
||||
OLLAMA_URL=existing_config.OLLAMA_URL or os.getenv("OLLAMA_URL"),
|
||||
OLLAMA_MODEL=existing_config.OLLAMA_MODEL or os.getenv("OLLAMA_MODEL"),
|
||||
CUSTOM_LLM_URL=existing_config.CUSTOM_LLM_URL or os.getenv("CUSTOM_LLM_URL"),
|
||||
CUSTOM_LLM_API_KEY=existing_config.CUSTOM_LLM_API_KEY
|
||||
or os.getenv("CUSTOM_LLM_API_KEY"),
|
||||
CUSTOM_MODEL=existing_config.CUSTOM_MODEL or os.getenv("CUSTOM_MODEL"),
|
||||
PEXELS_API_KEY=existing_config.PEXELS_API_KEY or os.getenv("PEXELS_API_KEY"),
|
||||
)
|
||||
|
||||
|
||||
def update_env_with_user_config():
|
||||
user_config = get_user_config()
|
||||
if user_config.LLM:
|
||||
os.environ["LLM"] = user_config.LLM
|
||||
if user_config.OPENAI_API_KEY:
|
||||
os.environ["OPENAI_API_KEY"] = user_config.OPENAI_API_KEY
|
||||
if user_config.GOOGLE_API_KEY:
|
||||
os.environ["GOOGLE_API_KEY"] = user_config.GOOGLE_API_KEY
|
||||
if user_config.OLLAMA_URL:
|
||||
os.environ["OLLAMA_URL"] = user_config.OLLAMA_URL
|
||||
if user_config.OLLAMA_MODEL:
|
||||
os.environ["OLLAMA_MODEL"] = user_config.OLLAMA_MODEL
|
||||
if user_config.CUSTOM_LLM_URL:
|
||||
os.environ["CUSTOM_LLM_URL"] = user_config.CUSTOM_LLM_URL
|
||||
if user_config.CUSTOM_LLM_API_KEY:
|
||||
os.environ["CUSTOM_LLM_API_KEY"] = user_config.CUSTOM_LLM_API_KEY
|
||||
if user_config.CUSTOM_MODEL:
|
||||
os.environ["CUSTOM_MODEL"] = user_config.CUSTOM_MODEL
|
||||
if user_config.PEXELS_API_KEY:
|
||||
os.environ["PEXELS_API_KEY"] = user_config.PEXELS_API_KEY
|
||||
|
||||
|
||||
def get_resource(relative_path):
|
||||
base_path = getattr(
|
||||
sys,
|
||||
"_MEIPASS",
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
def replace_file_name(old_name: str, new_name: str) -> str:
|
||||
splitted = old_name.split(".")
|
||||
if len(splitted) < 1:
|
||||
return new_name
|
||||
else:
|
||||
return ".".join([new_name, splitted[-1]])
|
||||
|
||||
|
||||
def save_uploaded_files(
|
||||
TEMP_FILE_SERVICE, files: List[UploadFile], file_paths: List[str], temp_dir: str
|
||||
) -> List:
|
||||
full_file_paths = []
|
||||
for index, each_file in enumerate(files):
|
||||
temp_file_path = TEMP_FILE_SERVICE.create_temp_file(
|
||||
file_paths[index], each_file.file.read(), dir_path=temp_dir
|
||||
)
|
||||
full_file_paths.append(temp_file_path)
|
||||
return full_file_paths
|
||||
|
||||
|
||||
async def download_file(url: str, save_path: str, headers: Optional[dict] = None):
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
with open(save_path, "wb") as file:
|
||||
while True:
|
||||
chunk = await response.content.read(1024)
|
||||
if not chunk:
|
||||
break
|
||||
file.write(chunk)
|
||||
print(f"File downloaded successfully to {save_path}")
|
||||
return True
|
||||
else:
|
||||
print(f"Failed to download file. HTTP status: {response.status}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"Error while downloading file from {url} to {save_path}")
|
||||
return False
|
||||
|
||||
|
||||
async def download_files(urls: List[str], save_paths: List[str]):
|
||||
for url, save_path in zip(urls, save_paths):
|
||||
print(url)
|
||||
print(save_path)
|
||||
print("-" * 10)
|
||||
coroutines = [
|
||||
download_file(url, save_paths[index]) for index, url in enumerate(urls)
|
||||
]
|
||||
await asyncio.gather(*coroutines)
|
||||
|
||||
|
||||
async def handle_errors(
|
||||
func, logging_service: LoggingService, log_metadata: LogMetadata, **kwargs
|
||||
):
|
||||
try:
|
||||
logging_service.logger.info(f"START", extra=log_metadata.model_dump())
|
||||
response = await func(
|
||||
logging_service=logging_service, log_metadata=log_metadata, **kwargs
|
||||
)
|
||||
is_stream = isinstance(response, StreamingResponse)
|
||||
logging_service.logger.info(
|
||||
"STREAMING" if is_stream else "END", extra=log_metadata.model_dump()
|
||||
)
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
log_metadata.status_code = e.status_code
|
||||
logging_service.logger.error(
|
||||
f"Raised HTTPException - {e.detail}", extra=log_metadata.model_dump()
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
traceback.print_stack()
|
||||
traceback.print_exc()
|
||||
|
||||
log_metadata.status_code = 400
|
||||
logging_service.logger.critical(
|
||||
"Unhandled Exception",
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
extra=log_metadata.model_dump(),
|
||||
)
|
||||
raise HTTPException(400, "Something went wrong while processing your request.")
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
name, ext = os.path.splitext(filename)
|
||||
sanitized = re.sub(r'[\\/:*?"<>|]', "_", name)
|
||||
sanitized = re.sub(r"[\s_]+", "_", sanitized)
|
||||
sanitized = sanitized.strip(" .")
|
||||
|
||||
if not sanitized:
|
||||
sanitized = "untitled"
|
||||
|
||||
if len(sanitized) > 200:
|
||||
sanitized = sanitized[:200]
|
||||
|
||||
return sanitized + ext
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
from typing import List
|
||||
|
||||
from fastapi import HTTPException, UploadFile
|
||||
|
||||
|
||||
def validate_files(
|
||||
field,
|
||||
nullable: bool,
|
||||
multiple: bool,
|
||||
max_size: int,
|
||||
accepted_types: List[str],
|
||||
):
|
||||
|
||||
if field:
|
||||
files: List[UploadFile] = field if multiple else [field]
|
||||
for each_file in files:
|
||||
if (max_size * 1024 * 1024) < each_file.size:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"File '{each_file.filename}' exceeded max upload size of {max_size} MB",
|
||||
)
|
||||
elif each_file.content_type not in accepted_types:
|
||||
raise HTTPException(400, f"File '{each_file.filename}' not accepted.")
|
||||
|
||||
elif not (field or nullable):
|
||||
raise HTTPException(400, "File must be provided.")
|
||||
|
Before Width: | Height: | Size: 4.6 KiB |
|
Before Width: | Height: | Size: 3.9 KiB |
|
Before Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 3.4 KiB |
|
Before Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 3.6 KiB |
|
Before Width: | Height: | Size: 4.5 KiB |
|
Before Width: | Height: | Size: 3.5 KiB |
|
Before Width: | Height: | Size: 5.1 KiB |
|
Before Width: | Height: | Size: 2.9 KiB |
|
Before Width: | Height: | Size: 5.3 KiB |
|
Before Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 1.8 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 1.7 KiB |