72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
import json
|
|
from typing import List
|
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
|
|
|
from api.v1.ppt.background_tasks import pull_ollama_model_background_task
|
|
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
|
|
from models.ollama_model_metadata import OllamaModelMetadata
|
|
from models.ollama_model_status import OllamaModelStatus
|
|
from services import REDIS_SERVICE
|
|
from utils.ollama import list_pulled_ollama_models
|
|
|
|
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
|
|
|
|
|
|
@OLLAMA_ROUTER.get("/models/supported", response_model=List[OllamaModelMetadata])
|
|
def get_supported_models():
|
|
return SUPPORTED_OLLAMA_MODELS.values()
|
|
|
|
|
|
@OLLAMA_ROUTER.get("/models/available", response_model=List[OllamaModelStatus])
|
|
async def get_available_models():
|
|
return await list_pulled_ollama_models()
|
|
|
|
|
|
@OLLAMA_ROUTER.get("/model/pull", response_model=OllamaModelStatus)
|
|
async def pull_model(model: str, background_tasks: BackgroundTasks):
|
|
|
|
if model not in SUPPORTED_OLLAMA_MODELS:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Model {model} is not supported",
|
|
)
|
|
|
|
try:
|
|
pulled_models = await list_pulled_ollama_models()
|
|
filtered_models = [
|
|
pulled_model for pulled_model in pulled_models if pulled_model.name == model
|
|
]
|
|
if filtered_models:
|
|
return filtered_models[0]
|
|
except HTTPException as e:
|
|
raise e
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Failed to check pulled models: {e}",
|
|
)
|
|
|
|
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{model}")
|
|
|
|
# If the model is being pulled, return the model
|
|
if saved_model_status:
|
|
saved_model_status_json = json.loads(saved_model_status)
|
|
# If the model is being pulled, return the model
|
|
# ? If the model status is pulled in redis but was not found while listing pulled models,
|
|
# ? it means the model was deleted and we need to pull it again
|
|
if (
|
|
saved_model_status_json["status"] == "error"
|
|
or saved_model_status_json["status"] == "pulled"
|
|
):
|
|
REDIS_SERVICE.delete(f"ollama_models/{model}")
|
|
else:
|
|
return saved_model_status_json
|
|
|
|
# If the model is not being pulled, pull the model
|
|
background_tasks.add_task(pull_ollama_model_background_task, model)
|
|
|
|
return OllamaModelStatus(
|
|
name=model,
|
|
status="pulling",
|
|
done=False,
|
|
)
|