diff --git a/.dockerignore b/.dockerignore index 6588610c..47da7fe3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -6,4 +6,6 @@ servers/fastapi/debug servers/fastapi/.venv servers/nextjs/node_modules -servers/nextjs/.next \ No newline at end of file +servers/nextjs/.next + +container.db \ No newline at end of file diff --git a/.gitignore b/.gitignore index 862916f1..ba741405 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ debug my-doc.txt generated_models nltk -chroma \ No newline at end of file +chroma +container.db \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 52f11ff8..9ba7c3bf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,7 @@ FROM python:3.11-slim-bookworm # Install Node.js and npm RUN apt-get update && apt-get install -y \ nginx \ - curl \ - redis-server + curl # Install Node.js 20 using NodeSource repository RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \ @@ -25,7 +24,7 @@ RUN curl -fsSL https://ollama.com/install.sh | sh # Install dependencies for FastAPI RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \ - pathvalidate pdfplumber nltk chromadb sqlmodel redis \ + pathvalidate pdfplumber nltk chromadb sqlmodel \ anthropic google-genai openai fastmcp RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/Dockerfile.dev b/Dockerfile.dev index 3c9cbd74..20b37e52 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -3,8 +3,7 @@ FROM python:3.11-slim-bookworm # Install Node.js and npm RUN apt-get update && apt-get install -y \ nginx \ - curl \ - redis-server + curl # Install Node.js 20 using NodeSource repository @@ -27,7 +26,7 @@ RUN curl -fsSL http://ollama.com/install.sh | sh # Install dependencies for FastAPI RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \ - pathvalidate pdfplumber nltk chromadb sqlmodel redis \ + pathvalidate pdfplumber nltk chromadb sqlmodel \ anthropic google-genai openai fastmcp RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/servers/fastapi/api/v1/ppt/background_tasks.py b/servers/fastapi/api/v1/ppt/background_tasks.py index e9a604f6..14fbab52 100644 --- a/servers/fastapi/api/v1/ppt/background_tasks.py +++ b/servers/fastapi/api/v1/ppt/background_tasks.py @@ -1,9 +1,9 @@ -import json - from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession from models.ollama_model_status import OllamaModelStatus -from services import REDIS_SERVICE +from models.sql.ollama_pull_status import OllamaPullStatus +from services.database import get_container_db_async_session from utils.ollama import pull_ollama_model @@ -15,45 +15,51 @@ async def pull_ollama_model_background_task(model: str): ) log_event_count = 0 - try: - async for event in pull_ollama_model(model): - log_event_count += 1 - if log_event_count != 1 and log_event_count % 20 != 0: - continue + async with get_container_db_async_session() as session: + session: AsyncSession = session + try: + async for event in pull_ollama_model(model): + log_event_count += 1 + if log_event_count != 1 and log_event_count % 20 != 0: + continue - if "completed" in event: - saved_model_status.downloaded = event["completed"] + if "completed" in event: + saved_model_status.downloaded = event["completed"] - if not saved_model_status.size and "total" in event: - saved_model_status.size = event["total"] + if not saved_model_status.size and "total" in event: + saved_model_status.size = event["total"] - if "status" in event: - saved_model_status.status = event["status"] + if "status" in event: + saved_model_status.status = event["status"] - REDIS_SERVICE.set( - f"ollama_models/{model}", - json.dumps(saved_model_status.model_dump(mode="json")), + session.add( + OllamaPullStatus( + id=model, status=saved_model_status.model_dump(mode="json") + ) + ) + await session.commit() + + except Exception as e: + saved_model_status.status = "error" + saved_model_status.done = True + session.add( + OllamaPullStatus( + id=model, status=saved_model_status.model_dump(mode="json") + ) + ) + await session.commit() + raise HTTPException( + status_code=500, + detail=f"Failed to pull model: {e}", ) - except Exception as e: - saved_model_status.status = "error" saved_model_status.done = True - REDIS_SERVICE.set( - f"ollama_models/{model}", - json.dumps(saved_model_status.model_dump(mode="json")), + saved_model_status.status = "pulled" + saved_model_status.downloaded = saved_model_status.size + + session.add( + OllamaPullStatus( + id=model, status=saved_model_status.model_dump(mode="json") + ) ) - raise HTTPException( - status_code=500, - detail=f"Failed to pull model: {e}", - ) - - saved_model_status.done = True - saved_model_status.status = "pulled" - saved_model_status.downloaded = saved_model_status.size - - REDIS_SERVICE.set( - f"ollama_models/{model}", - json.dumps(saved_model_status.model_dump(mode="json")), - ) - - return saved_model_status + await session.commit() diff --git a/servers/fastapi/api/v1/ppt/endpoints/ollama.py b/servers/fastapi/api/v1/ppt/endpoints/ollama.py index 13e334a5..1b8f5fb0 100644 --- a/servers/fastapi/api/v1/ppt/endpoints/ollama.py +++ b/servers/fastapi/api/v1/ppt/endpoints/ollama.py @@ -1,12 +1,14 @@ import json from typing import List -from fastapi import APIRouter, BackgroundTasks, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession from api.v1.ppt.background_tasks import pull_ollama_model_background_task from constants.supported_ollama_models import SUPPORTED_OLLAMA_MODELS from models.ollama_model_metadata import OllamaModelMetadata from models.ollama_model_status import OllamaModelStatus -from services import REDIS_SERVICE +from models.sql.ollama_pull_status import OllamaPullStatus +from services.database import get_container_db_async_session from utils.ollama import list_pulled_ollama_models OLLAMA_ROUTER = APIRouter(prefix="/ollama", tags=["Ollama"]) @@ -23,7 +25,11 @@ async def get_available_models(): @OLLAMA_ROUTER.get("/model/pull", response_model=OllamaModelStatus) -async def pull_model(model: str, background_tasks: BackgroundTasks): +async def pull_model( + model: str, + background_tasks: BackgroundTasks, + session: AsyncSession = Depends(get_container_db_async_session), +): if model not in SUPPORTED_OLLAMA_MODELS: raise HTTPException( @@ -46,21 +52,25 @@ async def pull_model(model: str, background_tasks: BackgroundTasks): detail=f"Failed to check pulled models: {e}", ) - saved_model_status = REDIS_SERVICE.get(f"ollama_models/{model}") + saved_model_status = None + try: + result = await session.get(OllamaPullStatus, model) + saved_model_status = result.status + except Exception as e: + pass # If the model is being pulled, return the model if saved_model_status: - saved_model_status_json = json.loads(saved_model_status) # If the model is being pulled, return the model # ? If the model status is pulled in redis but was not found while listing pulled models, # ? it means the model was deleted and we need to pull it again if ( - saved_model_status_json["status"] == "error" - or saved_model_status_json["status"] == "pulled" + saved_model_status["status"] == "error" + or saved_model_status["status"] == "pulled" ): - REDIS_SERVICE.delete(f"ollama_models/{model}") + await session.delete(OllamaPullStatus, model) else: - return saved_model_status_json + return saved_model_status # If the model is not being pulled, pull the model background_tasks.add_task(pull_ollama_model_background_task, model) diff --git a/servers/fastapi/models/sql/image_asset.py b/servers/fastapi/models/sql/image_asset.py index 2c7b4053..833cbd19 100644 --- a/servers/fastapi/models/sql/image_asset.py +++ b/servers/fastapi/models/sql/image_asset.py @@ -2,12 +2,13 @@ from datetime import datetime from typing import Optional from sqlalchemy import JSON, Column, DateTime -from sqlmodel import SQLModel, Field +from sqlmodel import Field +from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class ImageAsset(SQLModel, table=True): +class ImageAsset(MAIN_DB_BASE, table=True): id: str = Field(default_factory=get_random_uuid, primary_key=True) created_at: datetime = Field(sa_column=Column(DateTime, default=datetime.now)) path: str diff --git a/servers/fastapi/models/sql/key_value.py b/servers/fastapi/models/sql/key_value.py index fadff974..bc1f114c 100644 --- a/servers/fastapi/models/sql/key_value.py +++ b/servers/fastapi/models/sql/key_value.py @@ -1,9 +1,10 @@ -from sqlmodel import SQLModel, Field, Column, JSON +from sqlmodel import Field, Column, JSON +from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class KeyValueSqlModel(SQLModel, table=True): +class KeyValueSqlModel(MAIN_DB_BASE, table=True): id: str = Field(default_factory=get_random_uuid, primary_key=True) key: str = Field(index=True) value: dict = Field(sa_column=Column(JSON)) diff --git a/servers/fastapi/models/sql/ollama_pull_status.py b/servers/fastapi/models/sql/ollama_pull_status.py new file mode 100644 index 00000000..9089e532 --- /dev/null +++ b/servers/fastapi/models/sql/ollama_pull_status.py @@ -0,0 +1,7 @@ +from sqlmodel import Field, Column, JSON +from services.database import CONTAINER_DB_BASE + + +class OllamaPullStatus(CONTAINER_DB_BASE, table=True): + id: str = Field(primary_key=True) + status: dict = Field(sa_column=Column(JSON)) diff --git a/servers/fastapi/models/sql/presentation.py b/servers/fastapi/models/sql/presentation.py index 0b2fcd1b..da2b6380 100644 --- a/servers/fastapi/models/sql/presentation.py +++ b/servers/fastapi/models/sql/presentation.py @@ -1,15 +1,16 @@ from datetime import datetime from typing import List, Optional from sqlalchemy import JSON, Column, DateTime -from sqlmodel import SQLModel, Field +from sqlmodel import Field from models.presentation_layout import PresentationLayoutModel from models.presentation_outline_model import PresentationOutlineModel from models.presentation_structure_model import PresentationStructureModel +from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class PresentationModel(SQLModel, table=True): +class PresentationModel(MAIN_DB_BASE, table=True): id: str = Field(primary_key=True) prompt: str n_slides: int diff --git a/servers/fastapi/models/sql/slide.py b/servers/fastapi/models/sql/slide.py index 268bf1da..c1cd72c4 100644 --- a/servers/fastapi/models/sql/slide.py +++ b/servers/fastapi/models/sql/slide.py @@ -1,10 +1,11 @@ from typing import Optional -from sqlmodel import SQLModel, Field, Column, JSON +from sqlmodel import Field, Column, JSON +from services.database import MAIN_DB_BASE from utils.randomizers import get_random_uuid -class SlideModel(SQLModel, table=True): +class SlideModel(MAIN_DB_BASE, table=True): id: str = Field(primary_key=True, default_factory=get_random_uuid) presentation: str layout_group: str diff --git a/servers/fastapi/services/__init__.py b/servers/fastapi/services/__init__.py index 2c4366c5..a1d47d50 100644 --- a/servers/fastapi/services/__init__.py +++ b/servers/fastapi/services/__init__.py @@ -1,6 +1,4 @@ -from services.redis_service import RedisService from services.temp_file_service import TempFileService TEMP_FILE_SERVICE = TempFileService() -REDIS_SERVICE = RedisService() diff --git a/servers/fastapi/services/database.py b/servers/fastapi/services/database.py index 6b458f73..f26ed23b 100644 --- a/servers/fastapi/services/database.py +++ b/servers/fastapi/services/database.py @@ -6,11 +6,13 @@ from sqlalchemy.ext.asyncio import ( async_sessionmaker, AsyncSession, ) -from sqlmodel import SQLModel +from sqlalchemy.orm import DeclarativeBase from utils.get_env import get_app_data_directory_env, get_database_url_env +MAIN_DB_BASE = DeclarativeBase() + raw_database_url = get_database_url_env() or "sqlite:///" + os.path.join( get_app_data_directory_env() or "/tmp/presenton", "fastapi.db" ) @@ -37,6 +39,27 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: yield session +# Container DB (Lives inside the container) +CONTAINER_DB_BASE = DeclarativeBase() + +container_db_url = "sqlite+aiosqlite:////app/container.db" +container_db_engine: AsyncEngine = create_async_engine( + container_db_url, connect_args={"check_same_thread": False} +) +container_db_async_session_maker = async_sessionmaker( + container_db_engine, expire_on_commit=False +) + + +async def get_container_db_async_session() -> AsyncGenerator[AsyncSession, None]: + async with container_db_async_session_maker() as session: + yield session + + +# Create Database and Tables async def create_db_and_tables(): async with sql_engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) + await conn.run_sync(MAIN_DB_BASE.metadata.create_all) + + async with container_db_engine.begin() as conn: + await conn.run_sync(CONTAINER_DB_BASE.metadata.create_all) diff --git a/servers/fastapi/services/redis_service.py b/servers/fastapi/services/redis_service.py deleted file mode 100644 index f2e3d8c9..00000000 --- a/servers/fastapi/services/redis_service.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Optional -import redis -from redis.exceptions import RedisError - -from utils.get_env import ( - get_redis_db_env, - get_redis_host_env, - get_redis_password_env, - get_redis_port_env, -) - - -class RedisService: - def __init__(self): - self.redis_host = get_redis_host_env() or "localhost" - self.redis_port = int(get_redis_port_env() or "6379") - self.redis_db = int(get_redis_db_env() or "0") - self.redis_password = get_redis_password_env() or None - self.client = self._create_client() - - def _create_client(self) -> redis.Redis: - return redis.Redis( - host=self.redis_host, - port=self.redis_port, - db=self.redis_db, - password=self.redis_password, - decode_responses=True, - ) - - def set(self, key: str, value: Any, expire: Optional[int] = None) -> bool: - try: - return self.client.set(key, value, ex=expire) - except RedisError: - return False - - def get(self, key: str) -> Optional[str]: - try: - return self.client.get(key) - except RedisError: - return None - - def delete(self, key: str) -> bool: - try: - return bool(self.client.delete(key)) - except RedisError: - return False - - def exists(self, key: str) -> bool: - try: - return bool(self.client.exists(key)) - except RedisError: - return False - - def set_hash(self, name: str, mapping: dict) -> bool: - try: - return self.client.hmset(name, mapping) - except RedisError: - return False - - def get_hash(self, name: str) -> Optional[dict]: - try: - return self.client.hgetall(name) - except RedisError: - return None - - def delete_hash(self, name: str, *fields: str) -> int: - try: - return self.client.hdel(name, *fields) - except RedisError: - return 0 - - def set_list(self, name: str, values: list) -> bool: - try: - self.client.delete(name) - if values: - self.client.rpush(name, *values) - return True - except RedisError: - return False - - def get_list(self, name: str, start: int = 0, end: int = -1) -> Optional[list]: - try: - return self.client.lrange(name, start, end) - except RedisError: - return None - - def add_to_set(self, name: str, *values: str) -> int: - try: - return self.client.sadd(name, *values) - except RedisError: - return 0 - - def get_set(self, name: str) -> Optional[set]: - try: - return self.client.smembers(name) - except RedisError: - return None - - def remove_from_set(self, name: str, *values: str) -> int: - try: - return self.client.srem(name, *values) - except RedisError: - return 0 - - def clear(self) -> bool: - try: - return self.client.flushdb() - except RedisError: - return False - - def close(self): - try: - self.client.close() - except RedisError: - pass