refactor: reorganize migration functions and enhance legacy database handling

- Moved the legacy database URL conversion function to improve clarity.
- Introduced a new function to handle legacy database stamping before migrations.
- Updated error handling during migration to ensure legacy databases are properly managed.
- Cleaned up redundant code and improved comments for better maintainability.
This commit is contained in:
sudipnext 2026-03-26 16:29:52 +05:45
parent f050064771
commit 653e35bb3d

View file

@ -5,26 +5,14 @@ from alembic import command
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, inspect, text
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, inspect, text
from utils.db_utils import get_database_url_and_connect_args
from utils.get_env import get_migrate_database_on_startup_env
LEGACY_BASELINE_REVISION = "00b3c27a13bc"
def _to_sync_database_url(database_url: str) -> str:
# Preserve slash counts for sqlite URLs so Windows paths stay valid.
if database_url.startswith("sqlite+aiosqlite:///"):
return "sqlite:///" + database_url[len("sqlite+aiosqlite:///") :]
if database_url.startswith("postgresql+asyncpg://"):
return "postgresql://" + database_url[len("postgresql+asyncpg://") :]
if database_url.startswith("mysql+aiomysql://"):
return "mysql://" + database_url[len("mysql+aiomysql://") :]
return database_url
async def migrate_database_on_startup() -> None:
if get_migrate_database_on_startup_env() not in ["true", "True"]:
return
@ -37,15 +25,15 @@ async def migrate_database_on_startup() -> None:
raise
def run_migrations_sync() -> None:
"""Apply Alembic migrations to head (for CLI/scripts; no env gate)."""
_run_migrations()
raise
def run_migrations_sync() -> None:
"""Apply Alembic migrations to head (for CLI/scripts; no env gate)."""
_run_migrations()
def _to_sync_database_url(database_url: str) -> str:
# Preserve slash counts for sqlite URLs so Windows paths stay valid.
if database_url.startswith("sqlite+aiosqlite:///"):
return "sqlite:///" + database_url[len("sqlite+aiosqlite:///") :]
if database_url.startswith("postgresql+asyncpg://"):
return "postgresql://" + database_url[len("postgresql+asyncpg://") :]
if database_url.startswith("mysql+aiomysql://"):
return "mysql://" + database_url[len("mysql+aiomysql://") :]
return database_url
def _run_migrations() -> None:
@ -61,31 +49,42 @@ def _run_migrations() -> None:
database_url = _to_sync_database_url(database_url)
config.set_main_option("sqlalchemy.url", database_url)
_stamp_legacy_database_if_needed(config, database_url)
try:
command.upgrade(config, "head")
except Exception as exc:
# Recovery path for historical DBs that were created via create_all()
# without an alembic_version table.
except Exception:
# Safety net for edge cases; legacy DBs are stamped proactively above.
if _is_unversioned_populated_database(database_url):
script = ScriptDirectory.from_config(config)
known_revisions = {rev.revision for rev in script.walk_revisions()}
baseline_revision = (
LEGACY_BASELINE_REVISION
if LEGACY_BASELINE_REVISION in known_revisions
else script.get_base()
)
print(
"Detected existing unversioned database schema. "
f"Stamping revision to {baseline_revision} before upgrading.",
flush=True,
)
command.stamp(config, baseline_revision)
_stamp_legacy_database_if_needed(config, database_url)
command.upgrade(config, "head")
return
raise
def _stamp_legacy_database_if_needed(config: Config, database_url: str) -> None:
"""
If the DB has app tables but no migration reference in alembic_version,
treat it as a legacy DB and stamp baseline before upgrading.
"""
if not _is_unversioned_populated_database(database_url):
return
script = ScriptDirectory.from_config(config)
known_revisions = {rev.revision for rev in script.walk_revisions()}
baseline_revision = (
LEGACY_BASELINE_REVISION
if LEGACY_BASELINE_REVISION in known_revisions
else script.get_base()
)
print(
"Detected legacy database without migration reference. "
f"Stamping revision to {baseline_revision} before upgrading.",
flush=True,
)
command.stamp(config, baseline_revision)
def _is_unversioned_populated_database(database_url: str) -> bool:
known_app_tables = {
"presentations",