async def pull_ollama_model_background_task(model: str):
    """Pull an Ollama model and mirror its progress into Redis.

    Progress events yielded by ``pull_ollama_model(model)`` are folded into an
    ``OllamaModelStatus`` that is stored as JSON under the key
    ``ollama_models/{model}``.

    Args:
        model: Name of the model to pull; also used to build the Redis key.

    Returns:
        The final ``OllamaModelStatus`` with ``done=True`` and
        ``status="pulled"``.

    Raises:
        HTTPException: With status 500 if anything fails while pulling. The
            status stored in Redis is set to ``"error"`` before raising.
    """
    redis_key = f"ollama_models/{model}"
    saved_model_status = OllamaModelStatus(
        name=model,
        status="pulling",
        done=False,
    )

    def persist_status() -> None:
        # Single place that serializes the current status into Redis.
        REDIS_SERVICE.set(
            redis_key,
            json.dumps(saved_model_status.model_dump(mode="json")),
        )

    event_count = 0
    try:
        async for event in pull_ollama_model(model):
            event_count += 1

            # Fold EVERY event into the in-memory status. Only the Redis write
            # is throttled below; skipping the update itself would lose
            # "total"/"completed" values carried by events that are not
            # persisted (e.g. on streams shorter than 20 events).
            if "completed" in event:
                saved_model_status.downloaded = event["completed"]

            if not saved_model_status.size and "total" in event:
                saved_model_status.size = event["total"]

            if "status" in event:
                saved_model_status.status = event["status"]

            # Persist the first event and then every 20th one to limit the
            # number of Redis writes.
            if event_count == 1 or event_count % 20 == 0:
                persist_status()

    except Exception as e:
        saved_model_status.status = "error"
        saved_model_status.done = True
        persist_status()
        # NOTE(review): this coroutine is scheduled through BackgroundTasks,
        # so this exception is presumably never turned into an HTTP response;
        # the exception type is kept for backward compatibility.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to pull model: {e}",
        ) from e

    saved_model_status.done = True
    saved_model_status.status = "pulled"
    saved_model_status.downloaded = saved_model_status.size
    persist_status()

    return saved_model_status
@CUSTOM_LLM_ROUTER.post("/models/available", response_model=List[str])
async def get_available_models(
    url: Annotated[Optional[str], Body()] = None,
    api_key: Annotated[Optional[str], Body()] = None,
):
    """List the models available from a custom LLM provider.

    Both values are read from the request body and forwarded unchanged to
    ``list_available_custom_models``.

    Args:
        url: Optional provider URL; ``None`` when omitted.
        api_key: Optional provider API key; ``None`` when omitted.

    Returns:
        Whatever ``list_available_custom_models`` returns; the route declares
        the response as a list of strings.
    """
    available_models = await list_available_custom_models(url, api_key)
    return available_models
decompose_files(file_paths: Annotated[List[str], Body(embed=True)]): response = [] for index, parsed_doc in enumerate(parsed_documents): file_path = TEMP_FILE_SERVICE.create_temp_file_path( - f"{str(uuid.uuid4())}.txt", temp_dir + f"{get_random_uuid()}.txt", temp_dir ) parsed_doc = parsed_doc.replace("
@OLLAMA_ROUTER.get("/models/supported", response_model=List[OllamaModelMetadata])
def get_supported_models():
    """Return the metadata of every entry in ``SUPPORTED_OLLAMA_MODELS``."""
    # Materialize the dict view so a concrete list matches the declared
    # ``List[OllamaModelMetadata]`` response model.
    return list(SUPPORTED_OLLAMA_MODELS.values())


@OLLAMA_ROUTER.get("/models/available", response_model=List[OllamaModelStatus])
async def get_available_models():
    """Return the models reported by ``list_pulled_ollama_models``."""
    return await list_pulled_ollama_models()


@OLLAMA_ROUTER.get("/models/pull", response_model=OllamaModelStatus)
async def pull_model(model: str, background_tasks: BackgroundTasks):
    """Start pulling a supported Ollama model, or report its current state.

    Resolution order:
      1. Reject models that are not in ``SUPPORTED_OLLAMA_MODELS`` (400).
      2. If the model is already listed as pulled, return that entry.
      3. If Redis holds an in-progress status for the model, return it.
      4. Otherwise schedule ``pull_ollama_model_background_task`` and return
         a fresh ``"pulling"`` status.

    Args:
        model: Name of the model to pull.
        background_tasks: FastAPI background task queue used to schedule the
            pull.

    Raises:
        HTTPException: 400 if the model is not supported; 500 if listing the
            pulled models fails with a non-HTTP error.
    """
    if model not in SUPPORTED_OLLAMA_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model} is not supported",
        )

    try:
        pulled_models = await list_pulled_ollama_models()
        already_pulled = next(
            (pulled for pulled in pulled_models if pulled.name == model),
            None,
        )
    except HTTPException:
        # Propagate HTTP errors from the helper untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to check pulled models: {e}",
        ) from e

    if already_pulled is not None:
        return already_pulled

    redis_key = f"ollama_models/{model}"
    saved_model_status = REDIS_SERVICE.get(redis_key)

    if saved_model_status:
        try:
            saved_model_status_json = json.loads(saved_model_status)
        except (TypeError, ValueError):
            # A corrupt record must not break the endpoint; treat it as stale.
            saved_model_status_json = None

        status = (
            saved_model_status_json.get("status")
            if isinstance(saved_model_status_json, dict)
            else None
        )

        # A record that is still in progress is returned as-is.
        # NOTE(review): a record left in progress by an interrupted pull would
        # be returned forever; consider an expiry on the Redis key.
        if status is not None and status not in ("error", "pulled"):
            return saved_model_status_json

        # "error" means the previous attempt failed. "pulled" in Redis while
        # the model is missing from the pulled list means it was deleted.
        # Unreadable records are equally useless. In all cases drop the
        # record and pull again.
        REDIS_SERVICE.delete(redis_key)

    # Nothing usable is in flight: schedule the pull.
    background_tasks.add_task(pull_ollama_model_background_task, model)

    return OllamaModelStatus(
        name=model,
        status="pulling",
        done=False,
    )
class RedisService:
    """Thin best-effort wrapper around a synchronous ``redis.Redis`` client.

    Connection settings are read from the environment variables
    ``REDIS_HOST`` (default ``"localhost"``), ``REDIS_PORT`` (default
    ``6379``), ``REDIS_DB`` (default ``0``) and ``REDIS_PASSWORD``.

    Every method catches ``RedisError`` and returns a neutral fallback value
    (``False``, ``None`` or ``0``) instead of raising.
    """

    def __init__(self):
        self.redis_host = os.getenv("REDIS_HOST", "localhost")
        self.redis_port = int(os.getenv("REDIS_PORT", "6379"))
        self.redis_db = int(os.getenv("REDIS_DB", "0"))
        self.redis_password = os.getenv("REDIS_PASSWORD")
        self.client = self._create_client()

    def _create_client(self) -> redis.Redis:
        """Build the client; responses are decoded to ``str``."""
        return redis.Redis(
            host=self.redis_host,
            port=self.redis_port,
            db=self.redis_db,
            password=self.redis_password,
            decode_responses=True,
        )

    def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
        """Store ``value`` under ``key``, optionally expiring after ``expire`` seconds."""
        try:
            # Coerce so the declared ``bool`` holds even if the client
            # replies with ``None``.
            return bool(self.client.set(key, value, ex=expire))
        except RedisError:
            return False

    def get(self, key: str) -> Optional[str]:
        """Return the value stored under ``key``, or ``None`` on error."""
        try:
            return self.client.get(key)
        except RedisError:
            return None

    def delete(self, key: str) -> bool:
        """Delete ``key``; ``True`` if the client reports a removed key."""
        try:
            return bool(self.client.delete(key))
        except RedisError:
            return False

    def exists(self, key: str) -> bool:
        """Return whether ``key`` exists; ``False`` on error."""
        try:
            return bool(self.client.exists(key))
        except RedisError:
            return False

    def set_hash(self, name: str, mapping: dict) -> bool:
        """Write all ``mapping`` fields into the hash ``name``."""
        try:
            # ``hmset`` is deprecated in redis-py; ``hset(mapping=...)`` is
            # its replacement. It returns a field count, so report success
            # explicitly to keep the boolean contract.
            self.client.hset(name, mapping=mapping)
            return True
        except RedisError:
            return False

    def get_hash(self, name: str) -> Optional[dict]:
        """Return all fields of the hash ``name``, or ``None`` on error."""
        try:
            return self.client.hgetall(name)
        except RedisError:
            return None

    def delete_hash(self, name: str, *fields: str) -> int:
        """Remove ``fields`` from the hash ``name``; ``0`` on error."""
        try:
            return self.client.hdel(name, *fields)
        except RedisError:
            return 0

    def set_list(self, name: str, values: list) -> bool:
        """Replace the list ``name`` with ``values``.

        An empty ``values`` simply deletes the key.
        """
        try:
            # Run delete + push as one MULTI/EXEC transaction so a failure
            # cannot leave the list emptied but not refilled.
            with self.client.pipeline() as pipe:
                pipe.delete(name)
                if values:
                    pipe.rpush(name, *values)
                pipe.execute()
            return True
        except RedisError:
            return False

    def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
        """Return the ``start``..``end`` range of the list ``name``, or ``None`` on error."""
        try:
            return self.client.lrange(name, start, end)
        except RedisError:
            return None

    def add_to_set(self, name: str, *values: str) -> int:
        """Add ``values`` to the set ``name``; ``0`` on error."""
        try:
            return self.client.sadd(name, *values)
        except RedisError:
            return 0

    def get_set(self, name: str) -> Optional[set]:
        """Return the members of the set ``name``, or ``None`` on error."""
        try:
            return self.client.smembers(name)
        except RedisError:
            return None

    def remove_from_set(self, name: str, *values: str) -> int:
        """Remove ``values`` from the set ``name``; ``0`` on error."""
        try:
            return self.client.srem(name, *values)
        except RedisError:
            return 0

    def clear(self) -> bool:
        """Flush the currently selected database."""
        try:
            return bool(self.client.flushdb())
        except RedisError:
            return False

    def close(self):
        """Close the client connection, ignoring Redis errors."""
        try:
            self.client.close()
        except RedisError:
            pass