presenton/servers/fastapi/migrations.py
sudipnext f050064771 feat: add legacy database migration support and new database schema
- Introduced functions to handle legacy database stamping and migration.
- Added a new Alembic migration script for initializing the database schema.
- Enhanced the migration process to check for unversioned databases and apply necessary stamps before upgrades.
- Created new migration files for adding a theme column to presentations.
2026-03-26 15:33:49 +05:45

114 lines
4 KiB
Python

import asyncio
from pathlib import Path
from alembic import command
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, inspect, text
from utils.db_utils import get_database_url_and_connect_args
from utils.get_env import get_migrate_database_on_startup_env
# Revision id that legacy databases (known app tables present, but no row in
# alembic_version) are stamped at before the upgrade runs.
# _stamp_legacy_database_if_needed falls back to the script directory's base
# revision when this id is not among the known revisions.
LEGACY_BASELINE_REVISION: str = "00b3c27a13bc"
async def migrate_database_on_startup() -> None:
    """Apply pending Alembic migrations if startup migration is enabled.

    Migration is opt-in. It runs only when get_migrate_database_on_startup_env()
    returns exactly ``"true"`` or ``"True"``. The synchronous Alembic work in
    ``_run_migrations`` is executed on a worker thread via ``asyncio.to_thread``.

    Raises:
        Exception: any error from ``_run_migrations`` is printed and then
            re-raised unchanged.
    """
    flag = get_migrate_database_on_startup_env()
    if flag in ["true", "True"]:
        try:
            await asyncio.to_thread(_run_migrations)
            print("Migrations run successfully", flush=True)
        except Exception as err:
            print(f"Error running migrations: {err}", flush=True)
            raise
def _to_sync_database_url(database_url: str) -> str:
# Preserve slash counts for sqlite URLs so Windows paths stay valid.
if database_url.startswith("sqlite+aiosqlite:///"):
return "sqlite:///" + database_url[len("sqlite+aiosqlite:///") :]
if database_url.startswith("postgresql+asyncpg://"):
return "postgresql://" + database_url[len("postgresql+asyncpg://") :]
if database_url.startswith("mysql+aiomysql://"):
return "mysql://" + database_url[len("mysql+aiomysql://") :]
return database_url
def _run_migrations() -> None:
# migrations.py lives at servers/fastapi/migrations.py
# so parents[0] = servers/fastapi/, where alembic/ lives alongside it.
base_dir = Path(__file__).resolve().parents[0]
config = Config()
config.set_main_option("script_location", str(base_dir / "alembic"))
database_url, _ = get_database_url_and_connect_args()
# Alembic uses synchronous engines; strip async driver prefixes.
database_url = _to_sync_database_url(database_url)
config.set_main_option("sqlalchemy.url", database_url)
_stamp_legacy_database_if_needed(config, database_url)
try:
command.upgrade(config, "head")
except Exception:
# Safety net for edge cases; legacy DBs are stamped proactively above.
if _is_unversioned_populated_database(database_url):
_stamp_legacy_database_if_needed(config, database_url)
command.upgrade(config, "head")
return
raise
def _stamp_legacy_database_if_needed(config: Config, database_url: str) -> None:
    """Stamp an unversioned but already-populated database at the baseline.

    A database with known app tables but no row in ``alembic_version`` is
    treated as legacy. It is stamped at ``LEGACY_BASELINE_REVISION`` if that
    revision exists in the script directory, otherwise at the script
    directory's base revision. Stamping records the revision without running
    migrations, so a following ``upgrade`` starts from there. Databases
    without known app tables, or with a recorded revision, are left untouched.

    NOTE(review): this assumes legacy databases match the baseline schema. If
    one already contains changes from later revisions, the subsequent upgrade
    may fail — confirm.
    """
    if _is_unversioned_populated_database(database_url):
        scripts = ScriptDirectory.from_config(config)
        available_revisions = set()
        for rev in scripts.walk_revisions():
            available_revisions.add(rev.revision)
        if LEGACY_BASELINE_REVISION in available_revisions:
            target_revision = LEGACY_BASELINE_REVISION
        else:
            target_revision = scripts.get_base()
        print(
            "Detected legacy database without migration reference. "
            f"Stamping revision to {target_revision} before upgrading.",
            flush=True,
        )
        command.stamp(config, target_revision)
def _is_unversioned_populated_database(database_url: str) -> bool:
    """Report whether the database has app tables but no applied revision.

    Connects with a synchronous engine built from *database_url* and lists
    table names via the SQLAlchemy inspector. Returns True only when at least
    one known application table exists and ``alembic_version`` is missing or
    empty. The engine is always disposed before returning.
    """
    app_tables = frozenset(
        (
            "presentations",
            "slides",
            "templates",
            "keyvaluesqlmodel",
            "imageasset",
            "presentation_layout_codes",
            "async_presentation_generation_tasks",
            "webhook_subscriptions",
        )
    )
    engine = create_engine(database_url)
    try:
        with engine.connect() as conn:
            existing_tables = set(inspect(conn).get_table_names())
            revision_recorded = False
            if "alembic_version" in existing_tables:
                row_total = conn.execute(
                    text("SELECT COUNT(*) FROM alembic_version")
                ).scalar_one()
                revision_recorded = row_total > 0
            has_app_tables = not app_tables.isdisjoint(existing_tables)
            return has_app_tables and not revision_recorded
    finally:
        engine.dispose()