fix(fastapi): proper model availability checks
This commit is contained in:
parent
eb25416219
commit
d597464717
1 changed file with 43 additions and 41 deletions
|
|
@@ -20,58 +20,60 @@ from api.utils.model_utils import (
|
|||
|
||||
can_change_keys = os.getenv("CAN_CHANGE_KEYS") != "false"
|
||||
|
||||
# Ollama model download
|
||||
if not can_change_keys:
|
||||
if get_selected_llm_provider() == SelectedLLMProvider.OPENAI:
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
|
||||
elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
async def check_llm_model_availability():
|
||||
if not can_change_keys:
|
||||
if get_selected_llm_provider() == SelectedLLMProvider.OPENAI:
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise Exception("OPENAI_API_KEY must be provided")
|
||||
|
||||
elif is_ollama_selected():
|
||||
ollama_model = os.getenv("OLLAMA_MODEL")
|
||||
pexels_api_key = os.getenv("PEXELS_API_KEY")
|
||||
if not (ollama_model or pexels_api_key):
|
||||
raise Exception("OLLAMA_MODEL and PEXELS_API_KEY must be provided")
|
||||
elif get_selected_llm_provider() == SelectedLLMProvider.GOOGLE:
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not google_api_key:
|
||||
raise Exception("GOOGLE_API_KEY must be provided")
|
||||
|
||||
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise Exception(f"Model {ollama_model} is not supported")
|
||||
elif is_ollama_selected():
|
||||
ollama_model = os.getenv("OLLAMA_MODEL")
|
||||
if not ollama_model:
|
||||
raise Exception("OLLAMA_MODEL must be provided")
|
||||
|
||||
print("-" * 50)
|
||||
print("Pulling model: ", ollama_model)
|
||||
asyncio.run(pull_ollama_model(ollama_model))
|
||||
print("Pulled model: ", ollama_model)
|
||||
print("-" * 50)
|
||||
if ollama_model not in SUPPORTED_OLLAMA_MODELS:
|
||||
raise Exception(f"Model {ollama_model} is not supported")
|
||||
|
||||
elif is_custom_llm_selected():
|
||||
custom_model = os.getenv("CUSTOM_MODEL")
|
||||
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
|
||||
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
if not custom_model:
|
||||
raise Exception("CUSTOM_MODEL must be provided")
|
||||
if not custom_llm_url:
|
||||
raise Exception("CUSTOM_LLM_URL must be provided")
|
||||
if not custom_llm_api_key:
|
||||
raise Exception("CUSTOM_LLM_API_KEY must be provided")
|
||||
print("-" * 50)
|
||||
print("Pulling model: ", custom_model)
|
||||
models = asyncio.run(
|
||||
list_available_custom_models(custom_llm_url, custom_llm_api_key)
|
||||
)
|
||||
print("Available models: ", models)
|
||||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
print("-" * 50)
|
||||
print("Pulling model: ", ollama_model)
|
||||
async for event in pull_ollama_model(ollama_model):
|
||||
print(event)
|
||||
print("Pulled model: ", ollama_model)
|
||||
print("-" * 50)
|
||||
|
||||
elif is_custom_llm_selected():
|
||||
custom_model = os.getenv("CUSTOM_MODEL")
|
||||
custom_llm_url = os.getenv("CUSTOM_LLM_URL")
|
||||
custom_llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
if not custom_model:
|
||||
raise Exception("CUSTOM_MODEL must be provided")
|
||||
if not custom_llm_url:
|
||||
raise Exception("CUSTOM_LLM_URL must be provided")
|
||||
if not custom_llm_api_key:
|
||||
raise Exception("CUSTOM_LLM_API_KEY must be provided")
|
||||
print("-" * 50)
|
||||
print("Selecting model: ", custom_model)
|
||||
models = await list_available_custom_models(
|
||||
custom_llm_url, custom_llm_api_key
|
||||
)
|
||||
print("Available models: ", models)
|
||||
print("-" * 50)
|
||||
if custom_model not in models:
|
||||
raise Exception(f"Model {custom_model} is not available")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run startup tasks before the application begins serving.

    Creates the application data directory, creates the SQLModel tables on
    ``sql_engine``, then awaits ``check_llm_model_availability()`` before
    yielding control to the application.

    Raises:
        Exception: If ``APP_DATA_DIRECTORY`` is unset or empty, or if the
            model availability check raises.
    """
    app_data_directory = os.getenv("APP_DATA_DIRECTORY")
    # os.makedirs(None) fails with an opaque TypeError and os.makedirs("")
    # with FileNotFoundError; fail early with the same explicit message
    # style used for the other required environment variables.
    if not app_data_directory:
        raise Exception("APP_DATA_DIRECTORY must be provided")
    os.makedirs(app_data_directory, exist_ok=True)

    SQLModel.metadata.create_all(sql_engine)

    # Awaited here, inside the running event loop, before the app starts.
    await check_llm_model_availability()
    yield
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue