From c63bf35c51e2897abbd525e63ac2c44e6e418fab Mon Sep 17 00:00:00 2001 From: sauravniraula Date: Sat, 16 Aug 2025 20:48:42 +0545 Subject: [PATCH] fix: adds support for sslmode and other params in database url --- servers/fastapi/services/database.py | 28 ++++------------ servers/fastapi/utils/db_utils.py | 48 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 22 deletions(-) create mode 100644 servers/fastapi/utils/db_utils.py diff --git a/servers/fastapi/services/database.py b/servers/fastapi/services/database.py index 64c9d48d..9aaf40f9 100644 --- a/servers/fastapi/services/database.py +++ b/servers/fastapi/services/database.py @@ -15,26 +15,10 @@ from models.sql.presentation import PresentationModel from models.sql.slide import SlideModel from models.sql.presentation_layout_code import PresentationLayoutCodeModel from models.sql.template import TemplateModel -from utils.get_env import get_app_data_directory_env, get_database_url_env +from utils.db_utils import get_database_url_and_connect_args - -raw_database_url = get_database_url_env() or "sqlite:///" + os.path.join( - get_app_data_directory_env() or "/tmp/presenton", "fastapi.db" -) - -if raw_database_url.startswith("sqlite://"): - database_url = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://", 1) -elif raw_database_url.startswith("postgresql://"): - database_url = raw_database_url.replace("postgresql://", "postgresql+asyncpg://", 1) -elif raw_database_url.startswith("mysql://"): - database_url = raw_database_url.replace("mysql://", "mysql+aiomysql://", 1) -else: - database_url = raw_database_url - -connect_args = {} -if "sqlite" in database_url: - connect_args["check_same_thread"] = False +database_url, connect_args = get_database_url_and_connect_args() sql_engine: AsyncEngine = create_async_engine(database_url, connect_args=connect_args) async_session_maker = async_sessionmaker(sql_engine, expire_on_commit=False) @@ -71,11 +55,11 @@ async def create_db_and_tables(): SlideModel.__table__, KeyValueSqlModel.__table__, ImageAsset.__table__, - PresentationLayoutCodeModel.__table__, + PresentationLayoutCodeModel.__table__, TemplateModel.__table__, - ], - ) - ) + ], + ) + ) async with container_db_engine.begin() as conn: await conn.run_sync( diff --git a/servers/fastapi/utils/db_utils.py b/servers/fastapi/utils/db_utils.py new file mode 100644 index 00000000..368740f5 --- /dev/null +++ b/servers/fastapi/utils/db_utils.py @@ -0,0 +1,48 @@ +import os +from utils.get_env import get_app_data_directory_env, get_database_url_env +from urllib.parse import urlsplit, urlunsplit, parse_qsl +import ssl + + +def get_database_url_and_connect_args() -> tuple[str, dict]: + database_url = get_database_url_env() or "sqlite:///" + os.path.join( + get_app_data_directory_env() or "/tmp/presenton", "fastapi.db" + ) + + if database_url.startswith("sqlite://"): + database_url = database_url.replace("sqlite://", "sqlite+aiosqlite://", 1) + elif database_url.startswith("postgresql://"): + database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1) + elif database_url.startswith("mysql://"): + database_url = database_url.replace("mysql://", "mysql+aiomysql://", 1) + else: + database_url = database_url + + connect_args = {} + if "sqlite" in database_url: + connect_args["check_same_thread"] = False + + try: + split_result = urlsplit(database_url) + if split_result.query: + query_params = parse_qsl(split_result.query, keep_blank_values=True) + driver_scheme = split_result.scheme + for k, v in query_params: + key_lower = k.lower() + if key_lower == "sslmode" and "postgresql+asyncpg" in driver_scheme: + if v.lower() != "disable" and "sqlite" not in database_url: + connect_args["ssl"] = ssl.create_default_context() + + database_url = urlunsplit( + ( + split_result.scheme, + split_result.netloc, + split_result.path, + "", + split_result.fragment, + ) + ) + except Exception: + pass + + return database_url, connect_args