refactor(fastapi): removes redis and uses sqlite to store background task status
This commit is contained in:
parent
c982e4b709
commit
ec9104c91c
10 changed files with 188 additions and 66 deletions
|
|
@ -1,4 +1,6 @@
|
|||
from datetime import datetime
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models.ollama_model_status import OllamaModelStatus
|
||||
|
|
@ -15,51 +17,60 @@ async def pull_ollama_model_background_task(model: str):
|
|||
)
|
||||
log_event_count = 0
|
||||
|
||||
async with get_container_db_async_session() as session:
|
||||
session: AsyncSession = session
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
session = await get_container_db_async_session().__anext__()
|
||||
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
try:
|
||||
async for event in pull_ollama_model(model):
|
||||
log_event_count += 1
|
||||
if log_event_count != 1 and log_event_count % 20 != 0:
|
||||
continue
|
||||
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
if "completed" in event:
|
||||
saved_model_status.downloaded = event["completed"]
|
||||
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
if not saved_model_status.size and "total" in event:
|
||||
saved_model_status.size = event["total"]
|
||||
|
||||
session.add(
|
||||
OllamaPullStatus(
|
||||
id=model, status=saved_model_status.model_dump(mode="json")
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
if "status" in event:
|
||||
saved_model_status.status = event["status"]
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
session.add(
|
||||
OllamaPullStatus(
|
||||
id=model, status=saved_model_status.model_dump(mode="json")
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
|
||||
except Exception as e:
|
||||
saved_model_status.status = "error"
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
|
||||
session.add(
|
||||
OllamaPullStatus(
|
||||
id=model, status=saved_model_status.model_dump(mode="json")
|
||||
)
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to pull model: {e}",
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
saved_model_status.done = True
|
||||
saved_model_status.status = "pulled"
|
||||
saved_model_status.downloaded = saved_model_status.size
|
||||
|
||||
await upsert_ollama_pull_status(session, model, saved_model_status)
|
||||
await session.close()
|
||||
|
||||
|
||||
async def upsert_ollama_pull_status(
|
||||
session: AsyncSession, model: str, model_status: OllamaModelStatus
|
||||
):
|
||||
stmt = select(OllamaPullStatus).where(OllamaPullStatus.id == model)
|
||||
result = await session.execute(stmt)
|
||||
existing_record = result.scalar_one_or_none()
|
||||
|
||||
if existing_record:
|
||||
existing_record.status = model_status.model_dump(mode="json")
|
||||
existing_record.last_updated = datetime.now()
|
||||
else:
|
||||
new_record = OllamaPullStatus(
|
||||
id=model,
|
||||
status=model_status.model_dump(mode="json"),
|
||||
last_updated=datetime.now(),
|
||||
)
|
||||
session.add(new_record)
|
||||
|
||||
await session.commit()
|
||||
await session.flush()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from typing import List
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
|
|
@ -52,10 +53,11 @@ async def pull_model(
|
|||
detail=f"Failed to check pulled models: {e}",
|
||||
)
|
||||
|
||||
saved_pull_status = None
|
||||
saved_model_status = None
|
||||
try:
|
||||
result = await session.get(OllamaPullStatus, model)
|
||||
saved_model_status = result.status
|
||||
saved_pull_status = await session.get(OllamaPullStatus, model)
|
||||
saved_model_status = saved_pull_status.status
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
|
@ -67,8 +69,9 @@ async def pull_model(
|
|||
if (
|
||||
saved_model_status["status"] == "error"
|
||||
or saved_model_status["status"] == "pulled"
|
||||
or saved_pull_status.last_updated < (datetime.now() - timedelta(seconds=10))
|
||||
):
|
||||
await session.delete(OllamaPullStatus, model)
|
||||
await session.delete(saved_pull_status)
|
||||
else:
|
||||
return saved_model_status
|
||||
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@ from datetime import datetime
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import Field
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from services.database import MAIN_DB_BASE
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class ImageAsset(MAIN_DB_BASE, table=True):
|
||||
class ImageAsset(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
path: str
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from sqlmodel import Field, Column, JSON
|
||||
from sqlmodel import Field, Column, JSON, SQLModel
|
||||
|
||||
from services.database import MAIN_DB_BASE
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class KeyValueSqlModel(MAIN_DB_BASE, table=True):
|
||||
class KeyValueSqlModel(SQLModel, table=True):
|
||||
id: str = Field(default_factory=get_random_uuid, primary_key=True)
|
||||
key: str = Field(index=True)
|
||||
value: dict = Field(sa_column=Column(JSON))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from sqlmodel import Field, Column, JSON
|
||||
from services.database import CONTAINER_DB_BASE
|
||||
from datetime import datetime
|
||||
from sqlmodel import Field, Column, JSON, SQLModel, DateTime
|
||||
|
||||
|
||||
class OllamaPullStatus(CONTAINER_DB_BASE, table=True):
|
||||
class OllamaPullStatus(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True)
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
|
||||
status: dict = Field(sa_column=Column(JSON))
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlmodel import Field
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from models.presentation_layout import PresentationLayoutModel
|
||||
from models.presentation_outline_model import PresentationOutlineModel
|
||||
from models.presentation_structure_model import PresentationStructureModel
|
||||
from services.database import MAIN_DB_BASE
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class PresentationModel(MAIN_DB_BASE, table=True):
|
||||
class PresentationModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True)
|
||||
prompt: str
|
||||
n_slides: int
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from typing import Optional
|
||||
from sqlmodel import Field, Column, JSON
|
||||
from sqlmodel import Field, Column, JSON, SQLModel
|
||||
|
||||
from services.database import MAIN_DB_BASE
|
||||
from utils.randomizers import get_random_uuid
|
||||
|
||||
|
||||
class SlideModel(MAIN_DB_BASE, table=True):
|
||||
class SlideModel(SQLModel, table=True):
|
||||
id: str = Field(primary_key=True, default_factory=get_random_uuid)
|
||||
presentation: str
|
||||
layout_group: str
|
||||
|
|
|
|||
|
|
@ -6,13 +6,11 @@ from sqlalchemy.ext.asyncio import (
|
|||
async_sessionmaker,
|
||||
AsyncSession,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from utils.get_env import get_app_data_directory_env, get_database_url_env
|
||||
|
||||
|
||||
MAIN_DB_BASE = DeclarativeBase()
|
||||
|
||||
raw_database_url = get_database_url_env() or "sqlite:///" + os.path.join(
|
||||
get_app_data_directory_env() or "/tmp/presenton", "fastapi.db"
|
||||
)
|
||||
|
|
@ -40,8 +38,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||
|
||||
|
||||
# Container DB (Lives inside the container)
|
||||
CONTAINER_DB_BASE = DeclarativeBase()
|
||||
|
||||
container_db_url = "sqlite+aiosqlite:////app/container.db"
|
||||
container_db_engine: AsyncEngine = create_async_engine(
|
||||
container_db_url, connect_args={"check_same_thread": False}
|
||||
|
|
@ -59,7 +55,7 @@ async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]
|
|||
# Create Database and Tables
|
||||
async def create_db_and_tables():
|
||||
async with sql_engine.begin() as conn:
|
||||
await conn.run_sync(MAIN_DB_BASE.metadata.create_all)
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async with container_db_engine.begin() as conn:
|
||||
await conn.run_sync(CONTAINER_DB_BASE.metadata.create_all)
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
|
|
|||
115
servers/fastapi/services/redis_service.py
Normal file
115
servers/fastapi/services/redis_service.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
from typing import Any, Optional
|
||||
import redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from utils.get_env import (
|
||||
get_redis_db_env,
|
||||
get_redis_host_env,
|
||||
get_redis_password_env,
|
||||
get_redis_port_env,
|
||||
)
|
||||
|
||||
|
||||
class RedisService:
|
||||
def __init__(self):
|
||||
self.redis_host = get_redis_host_env() or "localhost"
|
||||
self.redis_port = int(get_redis_port_env() or "6379")
|
||||
self.redis_db = int(get_redis_db_env() or "0")
|
||||
self.redis_password = get_redis_password_env() or None
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self) -> redis.Redis:
|
||||
return redis.Redis(
|
||||
host=self.redis_host,
|
||||
port=self.redis_port,
|
||||
db=self.redis_db,
|
||||
password=self.redis_password,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
|
||||
try:
|
||||
return self.client.set(key, value, ex=expire)
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
try:
|
||||
return self.client.get(key)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
try:
|
||||
return bool(self.client.delete(key))
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
try:
|
||||
return bool(self.client.exists(key))
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def set_hash(self, name: str, mapping: dict) -> bool:
|
||||
try:
|
||||
return self.client.hmset(name, mapping)
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get_hash(self, name: str) -> Optional[dict]:
|
||||
try:
|
||||
return self.client.hgetall(name)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def delete_hash(self, name: str, *fields: str) -> int:
|
||||
try:
|
||||
return self.client.hdel(name, *fields)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def set_list(self, name: str, values: list) -> bool:
|
||||
try:
|
||||
self.client.delete(name)
|
||||
if values:
|
||||
self.client.rpush(name, *values)
|
||||
return True
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
|
||||
try:
|
||||
return self.client.lrange(name, start, end)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def add_to_set(self, name: str, *values: str) -> int:
|
||||
try:
|
||||
return self.client.sadd(name, *values)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def get_set(self, name: str) -> Optional[set]:
|
||||
try:
|
||||
return self.client.smembers(name)
|
||||
except RedisError:
|
||||
return None
|
||||
|
||||
def remove_from_set(self, name: str, *values: str) -> int:
|
||||
try:
|
||||
return self.client.srem(name, *values)
|
||||
except RedisError:
|
||||
return 0
|
||||
|
||||
def clear(self) -> bool:
|
||||
try:
|
||||
return self.client.flushdb()
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.client.close()
|
||||
except RedisError:
|
||||
pass
|
||||
|
|
@ -5,7 +5,7 @@ import nltk
|
|||
from models.document_chunk import DocumentChunk
|
||||
|
||||
try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
nltk.data.find("tokenizers/punkt", paths=["./nltk"])
|
||||
except LookupError:
|
||||
nltk.download("punkt", download_dir="./nltk")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue