feat(fastapi): adds missing ollama and custom llm endpoints and redis service
This commit is contained in:
parent
ffeded7045
commit
578792de76
9 changed files with 268 additions and 6 deletions
59
servers/fastapi/api/v1/ppt/background_tasks.py
Normal file
59
servers/fastapi/api/v1/ppt/background_tasks.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import json
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from services import REDIS_SERVICE
|
||||
from utils.ollama import pull_ollama_model
|
||||
|
||||
|
||||
# Redis is written on the first event and then once every this many events.
_STATUS_WRITE_INTERVAL = 20


def _save_ollama_model_status(model: str, model_status: "OllamaModelStatus") -> None:
    """Store the JSON-serialised pull status under the model's Redis key."""
    REDIS_SERVICE.set(
        f"ollama_models/{model}",
        json.dumps(model_status.model_dump(mode="json")),
    )


async def pull_ollama_model_background_task(model: str):
    """Pull an Ollama model and publish its progress to Redis.

    Iterates the events yielded by ``pull_ollama_model(model)`` and mirrors
    their ``completed``, ``total`` and ``status`` entries onto an
    ``OllamaModelStatus``. The status is stored as JSON under the Redis key
    ``ollama_models/{model}``.

    Every event updates the in-memory status; only the Redis write is
    throttled (first event, then every ``_STATUS_WRITE_INTERVAL`` events), so
    a ``total`` or ``completed`` value carried by an unsampled event is not
    lost.

    Args:
        model: Name of the Ollama model to pull.

    Returns:
        The final ``OllamaModelStatus`` with ``done=True`` and
        ``status="pulled"``.

    Raises:
        HTTPException: With status 500 if pulling fails. The ``"error"``
            status is written to Redis before raising.
    """
    model_status = OllamaModelStatus(
        name=model,
        status="pulling",
        done=False,
    )
    event_count = 0

    try:
        async for event in pull_ollama_model(model):
            event_count += 1

            if "completed" in event:
                model_status.downloaded = event["completed"]

            # Keep the first reported total; later events do not overwrite it.
            if not model_status.size and "total" in event:
                model_status.size = event["total"]

            if "status" in event:
                model_status.status = event["status"]

            if event_count == 1 or event_count % _STATUS_WRITE_INTERVAL == 0:
                _save_ollama_model_status(model, model_status)

    except Exception as e:
        model_status.status = "error"
        model_status.done = True
        _save_ollama_model_status(model, model_status)
        # NOTE(review): when this runs as a FastAPI background task the
        # response has already been sent, so pollers observe the failure
        # through the "error" status stored above rather than this exception.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to pull model: {e}",
        ) from e

    model_status.done = True
    model_status.status = "pulled"
    # A finished pull is reported as fully downloaded.
    model_status.downloaded = model_status.size

    _save_ollama_model_status(model, model_status)

    return model_status
|
||||
14
servers/fastapi/api/v1/ppt/endpoints/custom_llm.py
Normal file
14
servers/fastapi/api/v1/ppt/endpoints/custom_llm.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Annotated, List, Optional
|
||||
from fastapi import APIRouter, Body
|
||||
|
||||
from utils.custom_llm_provider import list_available_custom_models
|
||||
|
||||
CUSTOM_LLM_ROUTER = APIRouter(prefix="/custom_llm", tags=["Custom LLM"])
|
||||
|
||||
|
||||
@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str])
async def get_available_models(
    url: Annotated[Optional[str], Body()] = None,
    api_key: Annotated[Optional[str], Body()] = None,
):
    """List the models offered by a custom LLM provider.

    Both ``url`` and ``api_key`` are optional request-body fields and are
    forwarded unchanged to ``list_available_custom_models``; its result is
    returned as the response.
    """
    available_models = await list_available_custom_models(url, api_key)
    return available_models
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
import os
from typing import Annotated, List, Optional
import uuid

# HTTPException must be FastAPI's: http.client.HTTPException is an unrelated
# stdlib class that FastAPI does not translate into an HTTP error response, so
# `raise HTTPException(400, ...)` in the handlers would surface as a 500.
from fastapi import APIRouter, Body, File, HTTPException, UploadFile
|
||||
|
||||
from constants.documents import UPLOAD_ACCEPTED_FILE_TYPES
|
||||
from models.decomposed_file_info import DecomposedFileInfo
|
||||
from services import TEMP_FILE_SERVICE
|
||||
from services.documents_loader import DocumentsLoader
|
||||
from utils.randomizers import get_random_uuid
|
||||
from utils.validators import validate_files
|
||||
|
||||
FILES_ROUTER = APIRouter(prefix="/files", tags=["Files"])
|
||||
|
|
@ -18,7 +18,7 @@ async def upload_files(files: Optional[List[UploadFile]]):
|
|||
if not files:
|
||||
raise HTTPException(400, "Documents are required")
|
||||
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(get_random_uuid())
|
||||
|
||||
validate_files(files, True, True, 50, UPLOAD_ACCEPTED_FILE_TYPES)
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ async def upload_files(files: Optional[List[UploadFile]]):
|
|||
|
||||
@FILES_ROUTER.post("/decompose", response_model=List[DecomposedFileInfo])
|
||||
async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]):
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(str(uuid.uuid4()))
|
||||
temp_dir = TEMP_FILE_SERVICE.create_temp_dir(get_random_uuid())
|
||||
|
||||
txt_files = []
|
||||
other_files = []
|
||||
|
|
@ -56,7 +56,7 @@ async def decompose_files(file_paths: Annotated[List[str], Body(embed=True)]):
|
|||
response = []
|
||||
for index, parsed_doc in enumerate(parsed_documents):
|
||||
file_path = TEMP_FILE_SERVICE.create_temp_file_path(
|
||||
f"{str(uuid.uuid4())}.txt", temp_dir
|
||||
f"{get_random_uuid()}.txt", temp_dir
|
||||
)
|
||||
parsed_doc = parsed_doc.replace("<br>", "\n")
|
||||
with open(file_path, "w") as text_file:
|
||||
|
|
|
|||
72
servers/fastapi/api/v1/ppt/endpoints/ollama.py
Normal file
72
servers/fastapi/api/v1/ppt/endpoints/ollama.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import json
|
||||
from typing import List
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from api.v1.ppt.background_tasks import pull_ollama_model_background_task
|
||||
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
||||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
from services import REDIS_SERVICE
|
||||
from utils.ollama import list_pulled_ollama_models
|
||||
|
||||
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
||||
|
||||
|
||||
@OLLAMA_ROUTER.get("/models/supported", response_model=List[OllamaModelMetadata])
def get_supported_models():
    """Return the metadata of every entry in ``SUPPORTED_OLLAMA_MODELS``.

    The values are materialised into a list so the returned object matches
    the declared ``List[OllamaModelMetadata]`` response model instead of
    relying on the validator to coerce a ``dict_values`` view.
    """
    return list(SUPPORTED_OLLAMA_MODELS.values())
|
||||
|
||||
|
||||
@OLLAMA_ROUTER.get("/models/available", response_model=List[OllamaModelStatus])
async def get_available_models():
    """Return whatever ``list_pulled_ollama_models`` reports."""
    pulled_models = await list_pulled_ollama_models()
    return pulled_models
|
||||
|
||||
|
||||
@OLLAMA_ROUTER.get("/models/pull", response_model=OllamaModelStatus)
async def pull_model(model: str, background_tasks: BackgroundTasks):
    """Start pulling an Ollama model, or report the pull already under way.

    Resolution order:

    1. Reject models that are not keys of ``SUPPORTED_OLLAMA_MODELS`` (400).
    2. If ``list_pulled_ollama_models`` already lists the model, return that
       entry.
    3. If Redis holds an in-progress status for the model, return it.
    4. Otherwise record a fresh ``"pulling"`` status in Redis, schedule
       ``pull_ollama_model_background_task`` and return that status.

    Args:
        model: Name of the model to pull.
        background_tasks: FastAPI task queue used to run the pull.

    Raises:
        HTTPException: 400 for an unsupported model; 500 if listing the
            pulled models fails.
    """
    if model not in SUPPORTED_OLLAMA_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model} is not supported",
        )

    try:
        pulled_models = await list_pulled_ollama_models()
        already_pulled = next(
            (pulled for pulled in pulled_models if pulled.name == model),
            None,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to check pulled models: {e}",
        ) from e

    if already_pulled is not None:
        return already_pulled

    redis_key = f"ollama_models/{model}"
    saved_status_raw = REDIS_SERVICE.get(redis_key)

    if saved_status_raw:
        try:
            saved_status = json.loads(saved_status_raw)
        except ValueError:
            # Unreadable entry: treat it as stale instead of failing every
            # request until someone clears the key by hand.
            saved_status = None

        status = saved_status.get("status") if isinstance(saved_status, dict) else None

        # A pull that is still running is reported as-is.
        if status is not None and status not in ("error", "pulled"):
            return saved_status

        # "error" means the last attempt failed. "pulled" in Redis while the
        # model is missing from the pulled list means it was deleted. In both
        # cases (and for unreadable entries) drop the record and pull again.
        REDIS_SERVICE.delete(redis_key)

    initial_status = OllamaModelStatus(
        name=model,
        status="pulling",
        done=False,
    )

    # Publish the status before scheduling the task: the task only writes to
    # Redis once it receives its first event, and a request arriving in that
    # window would otherwise schedule a second pull of the same model.
    REDIS_SERVICE.set(
        redis_key,
        json.dumps(initial_status.model_dump(mode="json")),
    )
    background_tasks.add_task(pull_ollama_model_background_task, model)

    return initial_status
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from api.v1.ppt.endpoints.custom_llm import CUSTOM_LLM_ROUTER
|
||||
from api.v1.ppt.endpoints.files import FILES_ROUTER
|
||||
from api.v1.ppt.endpoints.icons import ICONS_ROUTER
|
||||
from api.v1.ppt.endpoints.images import IMAGES_ROUTER
|
||||
from api.v1.ppt.endpoints.ollama import OLLAMA_ROUTER
|
||||
from api.v1.ppt.endpoints.outlines import OUTLINES_ROUTER
|
||||
from api.v1.ppt.endpoints.presentation import PRESENTATION_ROUTER
|
||||
|
||||
|
|
@ -14,3 +16,5 @@ API_V1_PPT_ROUTER.include_router(OUTLINES_ROUTER)
|
|||
API_V1_PPT_ROUTER.include_router(PRESENTATION_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(IMAGES_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(ICONS_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(OLLAMA_ROUTER)
|
||||
API_V1_PPT_ROUTER.include_router(CUSTOM_LLM_ROUTER)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from models.ollama_model_metadata import OllamaModelMetadata
|
||||
|
||||
|
||||
SUPPORTED_LLAMA_MODELS = {
|
||||
SUPPORTED_OLLAMA_MODELS = {
|
||||
"llama3:8b": OllamaModelMetadata(
|
||||
label="Llama 3:8b",
|
||||
value="llama3:8b",
|
||||
|
|
@ -246,7 +246,7 @@ SUPPORTED_QWEN_MODELS = {
|
|||
}
|
||||
|
||||
SUPPORTED_OLLAMA_MODELS = {
|
||||
**SUPPORTED_LLAMA_MODELS,
|
||||
**SUPPORTED_OLLAMA_MODELS,
|
||||
**SUPPORTED_GEMMA_MODELS,
|
||||
**SUPPORTED_DEEPSEEK_MODELS,
|
||||
**SUPPORTED_QWEN_MODELS,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ aiohttp==3.12.14
|
|||
aiosignal==1.4.0
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
async-timeout==5.0.1
|
||||
attrs==25.3.0
|
||||
cachetools==5.5.2
|
||||
certifi==2025.7.14
|
||||
|
|
@ -55,6 +56,7 @@ python-dotenv==1.1.1
|
|||
python-multipart==0.0.20
|
||||
python-pptx==1.0.2
|
||||
PyYAML==6.0.2
|
||||
redis==6.2.0
|
||||
requests==2.32.4
|
||||
rich==14.0.0
|
||||
rich-toolkit==0.14.8
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from services.redis_service import RedisService
|
||||
from services.temp_file_service import TempFileService
|
||||
from services.database import sql_engine
|
||||
|
||||
|
||||
TEMP_FILE_SERVICE = TempFileService()
|
||||
SQL_ENGINE = sql_engine
|
||||
REDIS_SERVICE = RedisService()
|
||||
|
|
|
|||
109
servers/fastapi/services/redis_service.py
Normal file
109
servers/fastapi/services/redis_service.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import logging
import os
from typing import Any, Optional

import redis
from redis.exceptions import RedisError
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)


class RedisService:
    """Best-effort wrapper around a synchronous ``redis.Redis`` client.

    Connection settings come from the environment: ``REDIS_HOST`` (default
    ``"localhost"``), ``REDIS_PORT`` (default ``6379``), ``REDIS_DB``
    (default ``0``) and ``REDIS_PASSWORD`` (no default).

    Every operation catches ``RedisError``, logs it, and returns a neutral
    fallback (``False``, ``None`` or ``0``) instead of raising, so callers
    must treat those values as "failed or absent".
    """

    def __init__(self):
        self.redis_host = os.getenv("REDIS_HOST", "localhost")
        self.redis_port = int(os.getenv("REDIS_PORT", "6379"))
        self.redis_db = int(os.getenv("REDIS_DB", "0"))
        self.redis_password = os.getenv("REDIS_PASSWORD")
        self.client = self._create_client()

    def _create_client(self) -> redis.Redis:
        """Build the client from the settings read in ``__init__``."""
        return redis.Redis(
            host=self.redis_host,
            port=self.redis_port,
            db=self.redis_db,
            password=self.redis_password,
            # Have the client return str instead of bytes.
            decode_responses=True,
        )

    def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
        """Set ``key`` to ``value``, optionally expiring after ``expire`` seconds.

        Returns ``True`` on success and ``False`` on a Redis error.
        """
        try:
            return bool(self.client.set(key, value, ex=expire))
        except RedisError:
            _LOGGER.exception("Redis SET failed for key %r", key)
            return False

    def get(self, key: str) -> Optional[str]:
        """Return the value at ``key``; ``None`` if missing or on a Redis error."""
        try:
            return self.client.get(key)
        except RedisError:
            _LOGGER.exception("Redis GET failed for key %r", key)
            return None

    def delete(self, key: str) -> bool:
        """Delete ``key``; ``True`` only if a key was actually removed."""
        try:
            return bool(self.client.delete(key))
        except RedisError:
            _LOGGER.exception("Redis DEL failed for key %r", key)
            return False

    def exists(self, key: str) -> bool:
        """Return whether ``key`` exists; ``False`` on a Redis error."""
        try:
            return bool(self.client.exists(key))
        except RedisError:
            _LOGGER.exception("Redis EXISTS failed for key %r", key)
            return False

    def set_hash(self, name: str, mapping: dict) -> bool:
        """Write the fields of ``mapping`` into the hash ``name``.

        Returns ``True`` on success and ``False`` on a Redis error.
        """
        try:
            # hset(mapping=...) replaces the deprecated hmset. It returns the
            # number of newly created fields, which is 0 when only existing
            # fields were updated, so success is reported explicitly.
            self.client.hset(name, mapping=mapping)
            return True
        except RedisError:
            _LOGGER.exception("Redis HSET failed for hash %r", name)
            return False

    def get_hash(self, name: str) -> Optional[dict]:
        """Return all fields of the hash ``name``; ``None`` on a Redis error."""
        try:
            return self.client.hgetall(name)
        except RedisError:
            _LOGGER.exception("Redis HGETALL failed for hash %r", name)
            return None

    def delete_hash(self, name: str, *fields: str) -> int:
        """Remove ``fields`` from the hash ``name``; returns the count removed."""
        try:
            return self.client.hdel(name, *fields)
        except RedisError:
            _LOGGER.exception("Redis HDEL failed for hash %r", name)
            return 0

    def set_list(self, name: str, values: list) -> bool:
        """Replace the list at ``name`` with ``values``.

        An empty ``values`` simply deletes the key. Returns ``True`` on
        success and ``False`` on a Redis error.
        """
        try:
            # Run the delete and the push in one MULTI/EXEC transaction so
            # readers never observe the list half-replaced.
            with self.client.pipeline() as pipe:
                pipe.delete(name)
                if values:
                    pipe.rpush(name, *values)
                pipe.execute()
            return True
        except RedisError:
            _LOGGER.exception("Redis list replace failed for key %r", name)
            return False

    def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
        """Return items ``start``..``end`` of the list ``name`` (default: all).

        Returns ``None`` on a Redis error.
        """
        try:
            return self.client.lrange(name, start, end)
        except RedisError:
            _LOGGER.exception("Redis LRANGE failed for key %r", name)
            return None

    def add_to_set(self, name: str, *values: str) -> int:
        """Add ``values`` to the set ``name``; returns the number newly added."""
        try:
            return self.client.sadd(name, *values)
        except RedisError:
            _LOGGER.exception("Redis SADD failed for set %r", name)
            return 0

    def get_set(self, name: str) -> Optional[set]:
        """Return the members of the set ``name``; ``None`` on a Redis error."""
        try:
            return self.client.smembers(name)
        except RedisError:
            _LOGGER.exception("Redis SMEMBERS failed for set %r", name)
            return None

    def remove_from_set(self, name: str, *values: str) -> int:
        """Remove ``values`` from the set ``name``; returns the count removed."""
        try:
            return self.client.srem(name, *values)
        except RedisError:
            _LOGGER.exception("Redis SREM failed for set %r", name)
            return 0

    def clear(self) -> bool:
        """Flush the configured database; ``False`` on a Redis error."""
        try:
            return bool(self.client.flushdb())
        except RedisError:
            _LOGGER.exception("Redis FLUSHDB failed")
            return False

    def close(self):
        """Close the client connection, ignoring (but logging) Redis errors."""
        try:
            self.client.close()
        except RedisError:
            _LOGGER.exception("Closing the Redis client failed")
|
||||
Loading…
Add table
Reference in a new issue