From ec9104c91c41d8a6247ccb0b1f4801cdaa7560ca Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Mon, 4 Aug 2025 21:56:48 +0545 Subject: [PATCH] refactor(fastapi): removes redis and uses sqlite to store background task status --- .../fastapi/api/v1/ppt/background_tasks.py | 91 ++++++++------ .../fastapi/api/v1/ppt/endpoints/ollama.py | 9 +- servers/fastapi/models/sql/image_asset.py | 5 +- servers/fastapi/models/sql/key_value.py | 5 +- .../fastapi/models/sql/ollama_pull_status.py | 7 +- servers/fastapi/models/sql/presentation.py | 5 +- servers/fastapi/models/sql/slide.py | 5 +- servers/fastapi/services/database.py | 10 +- servers/fastapi/services/redis_service.py | 115 ++++++++++++++++++ .../fastapi/services/score_based_chunker.py | 2 +- 10 files changed, 188 insertions(+), 66 deletions(-) create mode 100644 servers/fastapi/services/redis_service.py diff --git a/servers/fastapi/api/v1/ppt/background_tasks.py b/servers/fastapi/api/v1/ppt/background_tasks.py index 14fbab52..dddba98a 100644 --- a/servers/fastapi/api/v1/ppt/background_tasks.py +++ b/servers/fastapi/api/v1/ppt/background_tasks.py @@ -1,4 +1,6 @@ +from datetime import datetime from fastapi import HTTPException +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from models.ollama_model_status import OllamaModelStatus @@ -15,51 +17,60 @@ async def pull_ollama_model_background_task(model: str): ) log_event_count = 0 - async with get_container_db_async_session() as session: - session: AsyncSession = session - try: - async for event in pull_ollama_model(model): - log_event_count += 1 - if log_event_count != 1 and log_event_count % 20 != 0: - continue + session = await get_container_db_async_session().__anext__() - if "completed" in event: - saved_model_status.downloaded = event["completed"] + try: + async for event in pull_ollama_model(model): + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue - if not saved_model_status.size and "total" in event: - saved_model_status.size = event["total"] + if "completed" in event: + saved_model_status.downloaded = event["completed"] - if "status" in event: - saved_model_status.status = event["status"] + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] - session.add( - OllamaPullStatus( - id=model, status=saved_model_status.model_dump(mode="json") - ) - ) - await session.commit() + if "status" in event: + saved_model_status.status = event["status"] - except Exception as e: - saved_model_status.status = "error" - saved_model_status.done = True - session.add( - OllamaPullStatus( - id=model, status=saved_model_status.model_dump(mode="json") - ) - ) - await session.commit() - raise HTTPException( - status_code=500, - detail=f"Failed to pull model: {e}", - ) + await upsert_ollama_pull_status(session, model, saved_model_status) + except Exception as e: + saved_model_status.status = "error" saved_model_status.done = True - saved_model_status.status = "pulled" - saved_model_status.downloaded = saved_model_status.size - - session.add( - OllamaPullStatus( - id=model, status=saved_model_status.model_dump(mode="json") - ) + await upsert_ollama_pull_status(session, model, saved_model_status) + await session.close() + raise HTTPException( + status_code=500, + detail=f"Failed to pull model: {e}", ) - await session.commit() + + saved_model_status.done = True + saved_model_status.status = "pulled" + saved_model_status.downloaded = saved_model_status.size + + await upsert_ollama_pull_status(session, model, saved_model_status) + await session.close() + + +async def upsert_ollama_pull_status( + session: AsyncSession, model: str, model_status: OllamaModelStatus +): + stmt = select(OllamaPullStatus).where(OllamaPullStatus.id == model) + result = await session.execute(stmt) + existing_record = result.scalar_one_or_none() + + if existing_record: + existing_record.status = model_status.model_dump(mode="json") + existing_record.last_updated = datetime.now() + else: + new_record = OllamaPullStatus( + id=model, + status=model_status.model_dump(mode="json"), + last_updated=datetime.now(), + ) + session.add(new_record) + + await session.commit() + await session.flush() diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index 1b8f5fb0..adde8669 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta import json from typing import List from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException @@ -52,10 +53,11 @@ async def pull_model( detail=f"Failed to check pulled models: {e}", ) + saved_pull_status = None saved_model_status = None try: - result = await session.get(OllamaPullStatus, model) - saved_model_status = result.status + saved_pull_status = await session.get(OllamaPullStatus, model) + saved_model_status = saved_pull_status.status except Exception as e: pass @@ -67,8 +69,9 @@ async def pull_model( if ( saved_model_status["status"] == "error" or saved_model_status["status"] == "pulled" + or saved_pull_status.last_updated < (datetime.now() - timedelta(seconds=10)) ): - await session.delete(OllamaPullStatus, model) + await session.delete(saved_pull_status) else: return saved_model_status diff --git a/servers/fastapi/models/sql/image_asset.py b/servers/fastapi/models/sql/image_asset.py index 833cbd19..00939c87 100644 --- a/servers/fastapi/models/sql/image_asset.py +++ b/servers/fastapi/models/sql/image_asset.py @@ -2,13 +2,12 @@ from datetime import datetime from typing import Optional from sqlalchemy import JSON, Column, DateTime -from sqlmodel import Field +from sqlmodel import Field, SQLModel -from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class ImageAsset(MAIN_DB_BASE, table=True): +class ImageAsset(SQLModel, table=True): id: str = Field(default_factory=get_random_uuid, primary_key=True) created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now)) path: str diff --git a/servers/fastapi/models/sql/key_value.py b/servers/fastapi/models/sql/key_value.py index bc1f114c..3ecabf39 100644 --- a/servers/fastapi/models/sql/key_value.py +++ b/servers/fastapi/models/sql/key_value.py @@ -1,10 +1,9 @@ -from sqlmodel import Field, Column, JSON +from sqlmodel import Field, Column, JSON, SQLModel -from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class KeyValueSqlModel(MAIN_DB_BASE, table=True): +class KeyValueSqlModel(SQLModel, table=True): id: str = Field(default_factory=get_random_uuid, primary_key=True) key: str = Field(index=True) value: dict = Field(sa_column=Column(JSON)) diff --git a/servers/fastapi/models/sql/ollama_pull_status.py b/servers/fastapi/models/sql/ollama_pull_status.py index 9089e532..59599cae 100644 --- a/servers/fastapi/models/sql/ollama_pull_status.py +++ b/servers/fastapi/models/sql/ollama_pull_status.py @@ -1,7 +1,8 @@ -from sqlmodel import Field, Column, JSON -from services.database import CONTAINER_DB_BASE +from datetime import datetime +from sqlmodel import Field, Column, JSON, SQLModel, DateTime -class OllamaPullStatus(CONTAINER_DB_BASE, table=True): +class OllamaPullStatus(SQLModel, table=True): id: str = Field(primary_key=True) + last_updated: datetime = Field(sa_column=Column(DateTime, default=datetime.now)) status: dict = Field(sa_column=Column(JSON)) diff --git a/servers/fastapi/models/sql/presentation.py b/servers/fastapi/models/sql/presentation.py index da2b6380..c0b65c64 100644 --- a/servers/fastapi/models/sql/presentation.py +++ b/servers/fastapi/models/sql/presentation.py @@ -1,16 +1,15 @@ from datetime import datetime from typing import List, Optional from sqlalchemy import JSON, Column, DateTime -from sqlmodel import Field +from sqlmodel import Field, SQLModel from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from models.presentation_structure_model import PresentationStructureModel -from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class PresentationModel(MAIN_DB_BASE, table=True): +class PresentationModel(SQLModel, table=True): id: str = Field(primary_key=True) prompt: str n_slides: int diff --git a/servers/fastapi/models/sql/slide.py b/servers/fastapi/models/sql/slide.py index c1cd72c4..7c0cb7e3 100644 --- a/servers/fastapi/models/sql/slide.py +++ b/servers/fastapi/models/sql/slide.py @@ -1,11 +1,10 @@ from typing import Optional -from sqlmodel import Field, Column, JSON +from sqlmodel import Field, Column, JSON, SQLModel -from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class SlideModel(MAIN_DB_BASE, table=True): +class SlideModel(SQLModel, table=True): id: str = Field(primary_key=True, default_factory=get_random_uuid) presentation: str layout_group: str diff --git a/servers/fastapi/services/database.py b/servers/fastapi/services/database.py index f26ed23b..3f419bcf 100644 --- a/servers/fastapi/services/database.py +++ b/servers/fastapi/services/database.py @@ -6,13 +6,11 @@ from sqlalchemy.ext.asyncio import ( async_sessionmaker, AsyncSession, ) -from sqlalchemy.orm import DeclarativeBase +from sqlmodel import SQLModel from utils.get_env import get_app_data_directory_env, get_database_url_env -MAIN_DB_BASE = DeclarativeBase() - raw_database_url = get_database_url_env() or "sqlite:///" + os.path.join( get_app_data_directory_env() or "/tmp/presenton", "fastapi.db" ) @@ -40,8 +38,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: # Container DB (Lives inside the container) -CONTAINER_DB_BASE = DeclarativeBase() - container_db_url = "sqlite+aiosqlite:////app/container.db" container_db_engine: AsyncEngine = create_async_engine( container_db_url, connect_args={"check_same_thread": False} @@ -59,7 +55,7 @@ async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None] # Create Database and Tables async def create_db_and_tables(): async with sql_engine.begin() as conn: - await conn.run_sync(MAIN_DB_BASE.metadata.create_all) + await conn.run_sync(SQLModel.metadata.create_all) async with container_db_engine.begin() as conn: - await conn.run_sync(CONTAINER_DB_BASE.metadata.create_all) + await conn.run_sync(SQLModel.metadata.create_all) diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py new file mode 100644 index 00000000..f2e3d8c9 --- /dev/null +++ b/servers/fastapi/services/redis_service.py @@ -0,0 +1,115 @@ +from typing import Any, Optional +import redis +from redis.exceptions import RedisError + +from utils.get_env import ( + get_redis_db_env, + get_redis_host_env, + get_redis_password_env, + get_redis_port_env, +) + + +class RedisService: + def __init__(self): + self.redis_host = get_redis_host_env() or "localhost" + self.redis_port = int(get_redis_port_env() or "6379") + self.redis_db = int(get_redis_db_env() or "0") + self.redis_password = get_redis_password_env() or None + self.client = self._create_client() + + def _create_client(self) -> redis.Redis: + return redis.Redis( + host=self.redis_host, + port=self.redis_port, + db=self.redis_db, + password=self.redis_password, + decode_responses=True, + ) + + def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: + try: + return self.client.set(key, value, ex=expire) + except RedisError: + return False + + def get(self, key: str) -> Optional[str]: + try: + return self.client.get(key) + except RedisError: + return None + + def delete(self, key: str) -> bool: + try: + return bool(self.client.delete(key)) + except RedisError: + return False + + def exists(self, key: str) -> bool: + try: + return bool(self.client.exists(key)) + except RedisError: + return False + + def set_hash(self, name: str, mapping: dict) -> bool: + try: + return self.client.hmset(name, mapping) + except RedisError: + return False + + def get_hash(self, name: str) -> Optional[dict]: + try: + return self.client.hgetall(name) + except RedisError: + return None + + def delete_hash(self, name: str, *fields: str) -> int: + try: + return self.client.hdel(name, *fields) + except RedisError: + return 0 + + def set_list(self, name: str, values: list) -> bool: + try: + self.client.delete(name) + if values: + self.client.rpush(name, *values) + return True + except RedisError: + return False + + def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: + try: + return self.client.lrange(name, start, end) + except RedisError: + return None + + def add_to_set(self, name: str, *values: str) -> int: + try: + return self.client.sadd(name, *values) + except RedisError: + return 0 + + def get_set(self, name: str) -> Optional[set]: + try: + return self.client.smembers(name) + except RedisError: + return None + + def remove_from_set(self, name: str, *values: str) -> int: + try: + return self.client.srem(name, *values) + except RedisError: + return 0 + + def clear(self) -> bool: + try: + return self.client.flushdb() + except RedisError: + return False + + def close(self): + try: + self.client.close() + except RedisError: + pass diff --git a/servers/fastapi/services/score_based_chunker.py b/servers/fastapi/services/score_based_chunker.py index 45a468f0..0af245a2 100644 --- a/servers/fastapi/services/score_based_chunker.py +++ b/servers/fastapi/services/score_based_chunker.py @@ -5,7 +5,7 @@ import nltk from models.document_chunk import DocumentChunk try: - nltk.data.find("tokenizers/punkt") + nltk.data.find("tokenizers/punkt", paths=["./nltk"]) except LookupError: nltk.download("punkt", download_dir="./nltk")