fix: adds support for sslmode and other params in database url
This commit is contained in:
parent
a6b19aaa34
commit
c63bf35c51
2 changed files with 54 additions and 22 deletions
|
|
@ -15,26 +15,10 @@ from models.sql.presentation import PresentationModel
|
|||
from models.sql.slide import SlideModel
|
||||
from models.sql.presentation_layout_code import PresentationLayoutCodeModel
|
||||
from models.sql.template import TemplateModel
|
||||
from utils.get_env import get_app_data_directory_env, get_database_url_env
|
||||
from utils.db_utils import get_database_url_and_connect_args
|
||||
|
||||
|
||||
|
||||
raw_database_url = get_database_url_env() or "sqlite:///" + os.path.join(
|
||||
get_app_data_directory_env() or "/tmp/presenton", "fastapi.db"
|
||||
)
|
||||
|
||||
if raw_database_url.startswith("sqlite://"):
|
||||
database_url = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://", 1)
|
||||
elif raw_database_url.startswith("postgresql://"):
|
||||
database_url = raw_database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
elif raw_database_url.startswith("mysql://"):
|
||||
database_url = raw_database_url.replace("mysql://", "mysql+aiomysql://", 1)
|
||||
else:
|
||||
database_url = raw_database_url
|
||||
|
||||
connect_args = {}
|
||||
if "sqlite" in database_url:
|
||||
connect_args["check_same_thread"] = False
|
||||
database_url, connect_args = get_database_url_and_connect_args()
|
||||
|
||||
sql_engine: AsyncEngine = create_async_engine(database_url, connect_args=connect_args)
|
||||
async_session_maker = async_sessionmaker(sql_engine, expire_on_commit=False)
|
||||
|
|
@ -71,11 +55,11 @@ async def create_db_and_tables():
|
|||
SlideModel.__table__,
|
||||
KeyValueSqlModel.__table__,
|
||||
ImageAsset.__table__,
|
||||
PresentationLayoutCodeModel.__table__,
|
||||
PresentationLayoutCodeModel.__table__,
|
||||
TemplateModel.__table__,
|
||||
],
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async with container_db_engine.begin() as conn:
|
||||
await conn.run_sync(
|
||||
|
|
|
|||
48
servers/fastapi/utils/db_utils.py
Normal file
48
servers/fastapi/utils/db_utils.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
from utils.get_env import get_app_data_directory_env, get_database_url_env
|
||||
from urllib.parse import urlsplit, urlunsplit, parse_qsl
|
||||
import ssl
|
||||
|
||||
|
||||
def get_database_url_and_connect_args() -> tuple[str, dict]:
|
||||
database_url = get_database_url_env() or "sqlite:///" + os.path.join(
|
||||
get_app_data_directory_env() or "/tmp/presenton", "fastapi.db"
|
||||
)
|
||||
|
||||
if database_url.startswith("sqlite://"):
|
||||
database_url = database_url.replace("sqlite://", "sqlite+aiosqlite://", 1)
|
||||
elif database_url.startswith("postgresql://"):
|
||||
database_url = database_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
elif database_url.startswith("mysql://"):
|
||||
database_url = database_url.replace("mysql://", "mysql+aiomysql://", 1)
|
||||
else:
|
||||
database_url = database_url
|
||||
|
||||
connect_args = {}
|
||||
if "sqlite" in database_url:
|
||||
connect_args["check_same_thread"] = False
|
||||
|
||||
try:
|
||||
split_result = urlsplit(database_url)
|
||||
if split_result.query:
|
||||
query_params = parse_qsl(split_result.query, keep_blank_values=True)
|
||||
driver_scheme = split_result.scheme
|
||||
for k, v in query_params:
|
||||
key_lower = k.lower()
|
||||
if key_lower == "sslmode" and "postgresql+asyncpg" in driver_scheme:
|
||||
if v.lower() != "disable" and "sqlite" not in database_url:
|
||||
connect_args["ssl"] = ssl.create_default_context()
|
||||
|
||||
database_url = urlunsplit(
|
||||
(
|
||||
split_result.scheme,
|
||||
split_result.netloc,
|
||||
split_result.path,
|
||||
"",
|
||||
split_result.fragment,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return database_url, connect_args
|
||||
Loading…
Add table
Reference in a new issue