100 lines
3.6 KiB
Python
100 lines
3.6 KiB
Python
import os
|
||
from utils.get_env import get_app_data_directory_env, get_database_url_env
|
||
from urllib.parse import urlsplit, urlunsplit, parse_qsl
|
||
import ssl
|
||
|
||
|
||
def _ensure_sqlite_parent_dir(database_url: str) -> None:
    """Create the parent directory of a file-based SQLite database if needed.

    Does nothing when *database_url* does not start with ``sqlite://`` or
    carries no path. The path is read the way SQLAlchemy reads SQLite URLs:
    the first slash after the (empty) host only separates host from database,
    so ``sqlite:///data/app.db`` is relative to the working directory, while
    ``sqlite:////data/app.db`` is absolute.

    Args:
        database_url: Database URL to inspect.

    Raises:
        OSError: If the directory cannot be created.
    """
    if not database_url.startswith("sqlite://"):
        return

    url_path = urlsplit(database_url).path
    if not url_path:
        return

    # Drop the host/database separator slash. This keeps relative paths
    # relative ("/data/app.db" -> "data/app.db"), leaves absolute ones
    # absolute ("//data/app.db" -> "/data/app.db") and turns the Windows
    # form "/C:/dir/app.db" into "C:/dir/app.db".
    db_path = url_path[1:] if url_path.startswith("/") else url_path

    # An empty dirname means a bare file name (or ":memory:"): nothing to do.
    parent = os.path.dirname(db_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
|
||
def _int_env(name: str, default: int) -> int:
    """Return the environment variable *name* parsed as an integer.

    *default* is returned when the variable is unset or when its value is
    not accepted by ``int()``.
    """
    text = os.environ.get(name)
    if text is not None:
        try:
            return int(text)
        except ValueError:
            # Unparseable value: fall through to the default below.
            pass
    return default
|
||
|
||
|
||
def get_pool_kwargs() -> dict:
    """Build SQLAlchemy engine pool keyword arguments from environment variables.

    Supported variables (all optional):
        DB_POOL_SIZE     – max persistent connections (default 5)
        DB_MAX_OVERFLOW  – extra connections above pool_size (default 10)
        DB_POOL_TIMEOUT  – seconds to wait for a connection (default 30)
        DB_POOL_RECYCLE  – seconds before a connection is recycled (default 1800)
        DB_POOL_PRE_PING – enable connection liveness check (default true);
                           only "false", "0" or "no" (any case, surrounding
                           whitespace ignored) turn it off

    Integer variables that are unset or not valid integers use their default.

    This function does not look at the database URL: it always returns all
    five settings. Callers whose engine uses a pool class that rejects the
    sizing options (NOTE(review): typically SQLite with ``StaticPool`` /
    ``NullPool`` — confirm against the engine setup) must leave them out
    themselves.

    Returns:
        A dict with the keys ``pool_size``, ``max_overflow``,
        ``pool_timeout``, ``pool_recycle`` and ``pool_pre_ping``.
    """
    # Strip before comparing so a padded value such as " false " still
    # disables the check instead of silently enabling it.
    pre_ping_flag = os.getenv("DB_POOL_PRE_PING", "true").strip().lower()

    return {
        "pool_size": _int_env("DB_POOL_SIZE", 5),
        "max_overflow": _int_env("DB_MAX_OVERFLOW", 10),
        "pool_timeout": _int_env("DB_POOL_TIMEOUT", 30),
        "pool_recycle": _int_env("DB_POOL_RECYCLE", 1800),
        "pool_pre_ping": pre_ping_flag not in ("false", "0", "no"),
    }
|
||
|
||
|
||
def get_database_url_and_connect_args() -> tuple[str, dict]:
    """Resolve the async database URL and the matching driver connect args.

    The URL comes from ``get_database_url_env()``; when that is empty a
    SQLite file named ``fastapi.db`` inside ``get_app_data_directory_env()``
    (or ``/tmp/presenton``) is used. The parent directory of a SQLite file is
    created first. Plain ``sqlite://``, ``postgresql://`` and ``mysql://``
    schemes are rewritten to their ``aiosqlite`` / ``asyncpg`` / ``aiomysql``
    driver forms; any other scheme is passed through unchanged.

    Connect args:
        * SQLite URLs get ``check_same_thread=False``.
        * ``postgresql+asyncpg`` URLs with an ``sslmode`` query parameter
          other than ``disable`` get ``ssl`` set to a default SSL context.

    If the URL has a query string, the whole query string is removed from
    the returned URL.

    Returns:
        ``(database_url, connect_args)``.
    """
    # Function-scope import: only needed for the best-effort warning below.
    import logging

    database_url = get_database_url_env() or "sqlite:///" + os.path.join(
        get_app_data_directory_env() or "/tmp/presenton", "fastapi.db"
    )

    _ensure_sqlite_parent_dir(database_url)

    # Swap a bare dialect scheme for its async driver variant.
    async_prefixes = (
        ("sqlite://", "sqlite+aiosqlite://"),
        ("postgresql://", "postgresql+asyncpg://"),
        ("mysql://", "mysql+aiomysql://"),
    )
    for sync_prefix, async_prefix in async_prefixes:
        if database_url.startswith(sync_prefix):
            database_url = async_prefix + database_url[len(sync_prefix):]
            break

    # Decide by scheme, not by substring of the whole URL: a database name or
    # password containing "sqlite" must not change how other dialects are
    # treated.
    scheme = database_url.partition(":")[0].lower()

    connect_args: dict = {}
    if scheme.split("+", 1)[0] == "sqlite":
        connect_args["check_same_thread"] = False

    # Split by hand instead of urlsplit()/urlunsplit(): rebuilding with
    # urlunsplit() drops the empty authority for schemes it does not know,
    # turning "sqlite+aiosqlite:///x.db" into "sqlite+aiosqlite:/x.db".
    without_fragment, _, fragment = database_url.partition("#")
    base_url, _, query = without_fragment.partition("?")
    if not query:
        return database_url, connect_args

    if "postgresql+asyncpg" in scheme:
        try:
            ssl_requested = any(
                key.lower() == "sslmode" and value.lower() != "disable"
                for key, value in parse_qsl(query, keep_blank_values=True)
            )
            if ssl_requested:
                connect_args["ssl"] = ssl.create_default_context()
        except Exception as exc:
            # Best effort, as before, but no longer silent. Only the
            # exception type is logged so URL credentials cannot leak.
            logging.getLogger(__name__).warning(
                "Could not derive SSL settings from the database URL (%s)",
                type(exc).__name__,
            )

    database_url = f"{base_url}#{fragment}" if fragment else base_url
    return database_url, connect_args
|