From 653e35bb3dc5d68e330e5b3cd27befbaa91dc759 Mon Sep 17 00:00:00 2001 From: sudipnext Date: Thu, 26 Mar 2026 16:29:52 +0545 Subject: [PATCH] refactor: reorganize migration functions and enhance legacy database handling - Moved the legacy database URL conversion function to improve clarity. - Introduced a new function to handle legacy database stamping before migrations. - Updated error handling during migration to ensure legacy databases are properly managed. - Cleaned up redundant code and improved comments for better maintainability. --- electron/servers/fastapi/migrations.py | 75 +++++++++++++------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/electron/servers/fastapi/migrations.py b/electron/servers/fastapi/migrations.py index c2b75984..4cb75de0 100644 --- a/electron/servers/fastapi/migrations.py +++ b/electron/servers/fastapi/migrations.py @@ -5,26 +5,14 @@ from alembic import command from alembic.config import Config from alembic.script import ScriptDirectory from sqlalchemy import create_engine, inspect, text -from alembic.script import ScriptDirectory -from sqlalchemy import create_engine, inspect, text from utils.db_utils import get_database_url_and_connect_args from utils.get_env import get_migrate_database_on_startup_env + LEGACY_BASELINE_REVISION = "00b3c27a13bc" -def _to_sync_database_url(database_url: str) -> str: - # Preserve slash counts for sqlite URLs so Windows paths stay valid. - if database_url.startswith("sqlite+aiosqlite:///"): - return "sqlite:///" + database_url[len("sqlite+aiosqlite:///") :] - if database_url.startswith("postgresql+asyncpg://"): - return "postgresql://" + database_url[len("postgresql+asyncpg://") :] - if database_url.startswith("mysql+aiomysql://"): - return "mysql://" + database_url[len("mysql+aiomysql://") :] - return database_url - - async def migrate_database_on_startup() -> None: if get_migrate_database_on_startup_env() not in ["true", "True"]: return @@ -37,15 +25,15 @@ async def migrate_database_on_startup() -> None: raise -def run_migrations_sync() -> None: - """Apply Alembic migrations to head (for CLI/scripts; no env gate).""" - _run_migrations() - raise - - -def run_migrations_sync() -> None: - """Apply Alembic migrations to head (for CLI/scripts; no env gate).""" - _run_migrations() +def _to_sync_database_url(database_url: str) -> str: + # Preserve slash counts for sqlite URLs so Windows paths stay valid. + if database_url.startswith("sqlite+aiosqlite:///"): + return "sqlite:///" + database_url[len("sqlite+aiosqlite:///") :] + if database_url.startswith("postgresql+asyncpg://"): + return "postgresql://" + database_url[len("postgresql+asyncpg://") :] + if database_url.startswith("mysql+aiomysql://"): + return "mysql://" + database_url[len("mysql+aiomysql://") :] + return database_url def _run_migrations() -> None: @@ -61,31 +49,42 @@ def _run_migrations() -> None: database_url = _to_sync_database_url(database_url) config.set_main_option("sqlalchemy.url", database_url) + _stamp_legacy_database_if_needed(config, database_url) try: command.upgrade(config, "head") - except Exception as exc: - # Recovery path for historical DBs that were created via create_all() - # without an alembic_version table. + except Exception: + # Safety net for edge cases; legacy DBs are stamped proactively above. if _is_unversioned_populated_database(database_url): - script = ScriptDirectory.from_config(config) - known_revisions = {rev.revision for rev in script.walk_revisions()} - baseline_revision = ( - LEGACY_BASELINE_REVISION - if LEGACY_BASELINE_REVISION in known_revisions - else script.get_base() - ) - print( - "Detected existing unversioned database schema. " - f"Stamping revision to {baseline_revision} before upgrading.", - flush=True, - ) - command.stamp(config, baseline_revision) + _stamp_legacy_database_if_needed(config, database_url) command.upgrade(config, "head") return raise +def _stamp_legacy_database_if_needed(config: Config, database_url: str) -> None: + """ + If the DB has app tables but no migration reference in alembic_version, + treat it as a legacy DB and stamp baseline before upgrading. + """ + if not _is_unversioned_populated_database(database_url): + return + + script = ScriptDirectory.from_config(config) + known_revisions = {rev.revision for rev in script.walk_revisions()} + baseline_revision = ( + LEGACY_BASELINE_REVISION + if LEGACY_BASELINE_REVISION in known_revisions + else script.get_base() + ) + print( + "Detected legacy database without migration reference. " + f"Stamping revision to {baseline_revision} before upgrading.", + flush=True, + ) + command.stamp(config, baseline_revision) + + def _is_unversioned_populated_database(database_url: str) -> bool: known_app_tables = { "presentations",