refactor: removes redis and uses sqlite to store background tasks status

This commit is contained in:
sauravniraula 2025-08-04 17:32:25 +05:45
parent ee271b80dc
commit c982e4b709
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
14 changed files with 115 additions and 181 deletions

View file

@ -6,4 +6,6 @@ servers/fastapi/debug
servers/fastapi/.venv
servers/nextjs/node_modules
servers/nextjs/.next
servers/nextjs/.next
container.db

3
.gitignore vendored
View file

@ -14,4 +14,5 @@ debug
my-doc.txt
generated_models
nltk
chroma
chroma
container.db

View file

@ -3,8 +3,7 @@ FROM python:3.11-slim-bookworm
# Install Node.js and npm
RUN apt-get update && apt-get install -y \
nginx \
curl \
redis-server
curl
# Install Node.js 20 using NodeSource repository
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
@ -25,7 +24,7 @@ RUN curl -fsSL https://ollama.com/install.sh | sh
# Install dependencies for FastAPI
RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \
pathvalidate pdfplumber nltk chromadb sqlmodel redis \
pathvalidate pdfplumber nltk chromadb sqlmodel \
anthropic google-genai openai fastmcp
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu

View file

@ -3,8 +3,7 @@ FROM python:3.11-slim-bookworm
# Install Node.js and npm
RUN apt-get update && apt-get install -y \
nginx \
curl \
redis-server
curl
# Install Node.js 20 using NodeSource repository
@ -27,7 +26,7 @@ RUN curl -fsSL http://ollama.com/install.sh | sh
# Install dependencies for FastAPI
RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \
pathvalidate pdfplumber nltk chromadb sqlmodel redis \
pathvalidate pdfplumber nltk chromadb sqlmodel \
anthropic google-genai openai fastmcp
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu

View file

@ -1,9 +1,9 @@
import json
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from models.ollama_model_status import OllamaModelStatus
from services import REDIS_SERVICE
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from utils.ollama import pull_ollama_model
@ -15,45 +15,51 @@ async def pull_ollama_model_background_task(model: str):
)
log_event_count = 0
try:
async for event in pull_ollama_model(model):
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
async with get_container_db_async_session() as session:
session: AsyncSession = session
try:
async for event in pull_ollama_model(model):
log_event_count += 1
if log_event_count != 1 and log_event_count % 20 != 0:
continue
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if "completed" in event:
saved_model_status.downloaded = event["completed"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if not saved_model_status.size and "total" in event:
saved_model_status.size = event["total"]
if "status" in event:
saved_model_status.status = event["status"]
if "status" in event:
saved_model_status.status = event["status"]
REDIS_SERVICE.set(
f"ollama_models/{model}",
json.dumps(saved_model_status.model_dump(mode="json")),
session.add(
OllamaPullStatus(
id=model, status=saved_model_status.model_dump(mode="json")
)
)
await session.commit()
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
session.add(
OllamaPullStatus(
id=model, status=saved_model_status.model_dump(mode="json")
)
)
await session.commit()
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
except Exception as e:
saved_model_status.status = "error"
saved_model_status.done = True
REDIS_SERVICE.set(
f"ollama_models/{model}",
json.dumps(saved_model_status.model_dump(mode="json")),
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
session.add(
OllamaPullStatus(
id=model, status=saved_model_status.model_dump(mode="json")
)
)
raise HTTPException(
status_code=500,
detail=f"Failed to pull model: {e}",
)
saved_model_status.done = True
saved_model_status.status = "pulled"
saved_model_status.downloaded = saved_model_status.size
REDIS_SERVICE.set(
f"ollama_models/{model}",
json.dumps(saved_model_status.model_dump(mode="json")),
)
return saved_model_status
await session.commit()

View file

@ -1,12 +1,14 @@
import json
from typing import List
from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from api.v1.ppt.background_tasks import pull_ollama_model_background_task
from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS
from models.ollama_model_metadata import OllamaModelMetadata
from models.ollama_model_status import OllamaModelStatus
from services import REDIS_SERVICE
from models.sql.ollama_pull_status import OllamaPullStatus
from services.database import get_container_db_async_session
from utils.ollama import list_pulled_ollama_models
OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"])
@ -23,7 +25,11 @@ async def get_available_models():
@OLLAMA_ROUTER.get("/model/pull", response_model=OllamaModelStatus)
async def pull_model(model: str, background_tasks: BackgroundTasks):
async def pull_model(
model: str,
background_tasks: BackgroundTasks,
session: AsyncSession = Depends(get_container_db_async_session),
):
if model not in SUPPORTED_OLLAMA_MODELS:
raise HTTPException(
@ -46,21 +52,25 @@ async def pull_model(model: str, background_tasks: BackgroundTasks):
detail=f"Failed to check pulled models: {e}",
)
saved_model_status = REDIS_SERVICE.get(f"ollama_models/{model}")
saved_model_status = None
try:
result = await session.get(OllamaPullStatus, model)
saved_model_status = result.status
except Exception as e:
pass
# If the model is being pulled, return the model
if saved_model_status:
saved_model_status_json = json.loads(saved_model_status)
# If the model is being pulled, return the model
# ? If the model status is pulled in redis but was not found while listing pulled models,
# ? it means the model was deleted and we need to pull it again
if (
saved_model_status_json["status"] == "error"
or saved_model_status_json["status"] == "pulled"
saved_model_status["status"] == "error"
or saved_model_status["status"] == "pulled"
):
REDIS_SERVICE.delete(f"ollama_models/{model}")
await session.delete(OllamaPullStatus, model)
else:
return saved_model_status_json
return saved_model_status
# If the model is not being pulled, pull the model
background_tasks.add_task(pull_ollama_model_background_task, model)

View file

@ -2,12 +2,13 @@ from datetime import datetime
from typing import Optional
from sqlalchemy import JSON, Column, DateTime
from sqlmodel import SQLModel, Field
from sqlmodel import Field
from services.database import MAIN_DB_BASE
from utils.randomizers import get_random_uuid
class ImageAsset(SQLModel, table=True):
class ImageAsset(MAIN_DB_BASE, table=True):
id: str = Field(default_factory=get_random_uuid, primary_key=True)
created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now))
path: str

View file

@ -1,9 +1,10 @@
from sqlmodel import SQLModel, Field, Column, JSON
from sqlmodel import Field, Column, JSON
from services.database import MAIN_DB_BASE
from utils.randomizers import get_random_uuid
class KeyValueSqlModel(SQLModel, table=True):
class KeyValueSqlModel(MAIN_DB_BASE, table=True):
id: str = Field(default_factory=get_random_uuid, primary_key=True)
key: str = Field(index=True)
value: dict = Field(sa_column=Column(JSON))

View file

@ -0,0 +1,7 @@
from sqlmodel import Field, Column, JSON
from services.database import CONTAINER_DB_BASE
class OllamaPullStatus(CONTAINER_DB_BASE, table=True):
id: str = Field(primary_key=True)
status: dict = Field(sa_column=Column(JSON))

View file

@ -1,15 +1,16 @@
from datetime import datetime
from typing import List, Optional
from sqlalchemy import JSON, Column, DateTime
from sqlmodel import SQLModel, Field
from sqlmodel import Field
from models.presentation_layout import PresentationLayoutModel
from models.presentation_outline_model import PresentationOutlineModel
from models.presentation_structure_model import PresentationStructureModel
from services.database import MAIN_DB_BASE
from utils.randomizers import get_random_uuid
class PresentationModel(SQLModel, table=True):
class PresentationModel(MAIN_DB_BASE, table=True):
id: str = Field(primary_key=True)
prompt: str
n_slides: int

View file

@ -1,10 +1,11 @@
from typing import Optional
from sqlmodel import SQLModel, Field, Column, JSON
from sqlmodel import Field, Column, JSON
from services.database import MAIN_DB_BASE
from utils.randomizers import get_random_uuid
class SlideModel(SQLModel, table=True):
class SlideModel(MAIN_DB_BASE, table=True):
id: str = Field(primary_key=True, default_factory=get_random_uuid)
presentation: str
layout_group: str

View file

@ -1,6 +1,4 @@
from services.redis_service import RedisService
from services.temp_file_service import TempFileService
TEMP_FILE_SERVICE = TempFileService()
REDIS_SERVICE = RedisService()

View file

@ -6,11 +6,13 @@ from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncSession,
)
from sqlmodel import SQLModel
from sqlalchemy.orm import DeclarativeBase
from utils.get_env import get_app_data_directory_env, get_database_url_env
MAIN_DB_BASE = DeclarativeBase()
raw_database_url = get_database_url_env() or "sqlite:///" + os.path.join(
get_app_data_directory_env() or "/tmp/presenton", "fastapi.db"
)
@ -37,6 +39,27 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
yield session
# Container DB (Lives inside the container)
CONTAINER_DB_BASE = DeclarativeBase()
container_db_url = "sqlite+aiosqlite:////app/container.db"
container_db_engine: AsyncEngine = create_async_engine(
container_db_url, connect_args={"check_same_thread": False}
)
container_db_async_session_maker = async_sessionmaker(
container_db_engine, expire_on_commit=False
)
async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]:
async with container_db_async_session_maker() as session:
yield session
# Create Database and Tables
async def create_db_and_tables():
async with sql_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.run_sync(MAIN_DB_BASE.metadata.create_all)
async with container_db_engine.begin() as conn:
await conn.run_sync(CONTAINER_DB_BASE.metadata.create_all)

View file

@ -1,115 +0,0 @@
from typing import Any, Optional
import redis
from redis.exceptions import RedisError
from utils.get_env import (
get_redis_db_env,
get_redis_host_env,
get_redis_password_env,
get_redis_port_env,
)
class RedisService:
def __init__(self):
self.redis_host = get_redis_host_env() or "localhost"
self.redis_port = int(get_redis_port_env() or "6379")
self.redis_db = int(get_redis_db_env() or "0")
self.redis_password = get_redis_password_env() or None
self.client = self._create_client()
def _create_client(self) -> redis.Redis:
return redis.Redis(
host=self.redis_host,
port=self.redis_port,
db=self.redis_db,
password=self.redis_password,
decode_responses=True,
)
def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool:
try:
return self.client.set(key, value, ex=expire)
except RedisError:
return False
def get(self, key: str) -> Optional[str]:
try:
return self.client.get(key)
except RedisError:
return None
def delete(self, key: str) -> bool:
try:
return bool(self.client.delete(key))
except RedisError:
return False
def exists(self, key: str) -> bool:
try:
return bool(self.client.exists(key))
except RedisError:
return False
def set_hash(self, name: str, mapping: dict) -> bool:
try:
return self.client.hmset(name, mapping)
except RedisError:
return False
def get_hash(self, name: str) -> Optional[dict]:
try:
return self.client.hgetall(name)
except RedisError:
return None
def delete_hash(self, name: str, *fields: str) -> int:
try:
return self.client.hdel(name, *fields)
except RedisError:
return 0
def set_list(self, name: str, values: list) -> bool:
try:
self.client.delete(name)
if values:
self.client.rpush(name, *values)
return True
except RedisError:
return False
def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]:
try:
return self.client.lrange(name, start, end)
except RedisError:
return None
def add_to_set(self, name: str, *values: str) -> int:
try:
return self.client.sadd(name, *values)
except RedisError:
return 0
def get_set(self, name: str) -> Optional[set]:
try:
return self.client.smembers(name)
except RedisError:
return None
def remove_from_set(self, name: str, *values: str) -> int:
try:
return self.client.srem(name, *values)
except RedisError:
return 0
def clear(self) -> bool:
try:
return self.client.flushdb()
except RedisError:
return False
def close(self):
try:
self.client.close()
except RedisError:
pass