video-accessibility/backend/app/migrations/migrator.py
2025-08-24 16:28:33 -05:00

253 lines
No EOL
9.1 KiB
Python

"""MongoDB migration framework."""
import importlib.util
import os
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

from motor.motor_asyncio import AsyncIOMotorDatabase

from app.core.database import get_database
from app.core.logging import get_logger
from app.telemetry.tracing import trace_async_operation
# Module-level logger; MigrationManager reports progress, warnings and failures through it.
logger = get_logger(__name__)
class Migration(ABC):
    """Abstract base for a single reversible database migration.

    Concrete migrations should override ``version`` (``YYYY-MM-DD-HHMMSS``)
    and ``description`` and implement :meth:`up` and :meth:`down`. The
    database handle is attached through :meth:`set_database`.
    """

    def __init__(self):
        # Placeholder identity; concrete migrations should overwrite these.
        self.version: str = "0000-00-00-000000"  # Format: YYYY-MM-DD-HHMMSS
        self.description: str = ""
        # Stays None until set_database() is awaited.
        self.db: Optional[AsyncIOMotorDatabase] = None

    @abstractmethod
    async def up(self) -> None:
        """Apply the change introduced by this migration."""

    @abstractmethod
    async def down(self) -> None:
        """Revert the change made by :meth:`up`."""

    async def set_database(self, db: AsyncIOMotorDatabase) -> None:
        """Attach the database handle this migration operates on."""
        self.db = db
class MigrationRecord:
    """In-memory representation of one migration history entry.

    Its fields mirror the keys of the documents written to the migration
    history collection (``version``, ``description``, ``applied_at``).

    Attributes:
        version: Migration version string (format ``YYYY-MM-DD-HHMMSS``).
        description: Human-readable description of the migration.
        applied_at: When the migration was applied.
    """

    def __init__(self, version: str, description: str, applied_at: datetime):
        self.version = version
        self.description = description
        self.applied_at = applied_at

    def __repr__(self) -> str:
        # Without this, records show up as opaque "<MigrationRecord object at 0x...>"
        # in logs and debuggers.
        return (
            f"{type(self).__name__}(version={self.version!r}, "
            f"description={self.description!r}, applied_at={self.applied_at!r})"
        )
class MigrationManager:
    """Discover, apply, roll back and report on database migrations.

    Migration scripts are ``*.py`` files in the ``scripts`` directory next to
    this module whose names start with ``migration_``. A script's version is
    the first ``_``-separated token after that prefix
    (``migration_YYYY-MM-DD-HHMMSS_description.py``). Applied versions are
    stored in the ``migration_history`` collection.
    """

    _FILE_PREFIX = "migration_"

    def __init__(self):
        self.db: Optional[AsyncIOMotorDatabase] = None
        self.migrations_dir = Path(__file__).parent / "scripts"
        self.collection_name = "migration_history"

    async def initialize(self) -> None:
        """Connect to the database and prepare the migration history collection."""
        self.db = await get_database()
        await self._ensure_migration_collection()

    async def _ensure_migration_collection(self) -> None:
        """Create the indexes used by the migration history collection."""
        collection = self.db[self.collection_name]
        # Unique on version: the same version can never be recorded twice.
        await collection.create_index([("version", 1)], unique=True)
        await collection.create_index([("applied_at", -1)])
        logger.info("Migration history collection initialized")

    @classmethod
    def _extract_version(cls, migration_name: str) -> str:
        """Return the version token of a migration module name.

        ``migration_2025-08-24-162833_add_index`` -> ``2025-08-24-162833``.
        """
        if migration_name.startswith(cls._FILE_PREFIX):
            migration_name = migration_name[len(cls._FILE_PREFIX):]
        return migration_name.split("_", 1)[0]

    def discover_migrations(self) -> List[str]:
        """Return migration module names (file stems), sorted by name.

        Names begin with the version, so name order is version order.
        Returns an empty list when the scripts directory does not exist.
        """
        if not self.migrations_dir.exists():
            logger.warning(f"Migrations directory not found: {self.migrations_dir}")
            return []
        # The "migration_" prefix already excludes dunder files such as __init__.py.
        return sorted(
            path.stem
            for path in self.migrations_dir.glob("*.py")
            if path.name.startswith(self._FILE_PREFIX)
        )

    def _migrations_by_version(self) -> dict:
        """Map each discovered migration's version to its module name."""
        return {
            self._extract_version(name): name
            for name in self.discover_migrations()
        }

    async def load_migration(self, migration_name: str) -> Migration:
        """Import a migration script and return its ``Migration`` instance.

        The script must define a class named ``Migration``. The returned
        instance already has this manager's database attached.

        Raises:
            FileNotFoundError: the script file does not exist.
            ImportError: the file cannot be loaded as a Python module.
            AttributeError: the module defines no ``Migration`` class.
        """
        migration_path = self.migrations_dir / f"{migration_name}.py"
        if not migration_path.exists():
            raise FileNotFoundError(f"Migration file not found: {migration_path}")

        spec = importlib.util.spec_from_file_location(migration_name, migration_path)
        # spec_from_file_location returns None when no loader fits the path.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load migration module from {migration_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        migration_class = getattr(module, "Migration", None)
        if migration_class is None:
            raise AttributeError(f"Migration class not found in {migration_name}")

        migration = migration_class()
        await migration.set_database(self.db)
        return migration

    async def get_applied_migrations(self) -> List[str]:
        """Return the recorded (applied) migration versions in ascending order."""
        collection = self.db[self.collection_name]
        cursor = collection.find({}, {"version": 1}).sort("version", 1)
        return [doc["version"] async for doc in cursor]

    async def record_migration(
        self, migration: Migration, version: Optional[str] = None
    ) -> None:
        """Record a successfully applied migration in the history collection.

        Args:
            migration: The migration that was applied.
            version: Version to record. Defaults to ``migration.version``.
                ``migrate_up`` passes the filename-derived version so the
                record matches what pending detection and rollback look up.
        """
        collection = self.db[self.collection_name]
        recorded_version = migration.version if version is None else version
        record = {
            "version": recorded_version,
            "description": migration.description,
            # Timezone-aware UTC; datetime.utcnow() is deprecated.
            "applied_at": datetime.now(timezone.utc),
        }
        await collection.insert_one(record)
        logger.info(f"Recorded migration: {recorded_version} - {migration.description}")

    async def remove_migration_record(self, version: str) -> None:
        """Delete the history record for ``version`` (used on rollback)."""
        collection = self.db[self.collection_name]
        await collection.delete_one({"version": version})
        logger.info(f"Removed migration record: {version}")

    @trace_async_operation("migration_manager.migrate_up")
    async def migrate_up(self, target_version: Optional[str] = None) -> List[str]:
        """
        Apply pending migrations, oldest first.

        Args:
            target_version: Highest version (inclusive) to apply. If None,
                applies all pending migrations.

        Returns:
            List of applied migration versions.

        Raises:
            Exception: re-raises the first failure; migrations after it are
                not attempted.
        """
        await self.initialize()

        applied_versions = set(await self.get_applied_migrations())

        pending = []
        for migration_name in self.discover_migrations():
            version = self._extract_version(migration_name)
            if version in applied_versions:
                continue
            if target_version is not None and version > target_version:
                continue
            pending.append((version, migration_name))
        # Stable sort: ties keep the (already sorted) filename order.
        pending.sort(key=lambda item: item[0])

        applied: List[str] = []
        for version, migration_name in pending:
            try:
                logger.info(f"Applying migration: {migration_name}")
                migration = await self.load_migration(migration_name)
                if migration.version != version:
                    # The filename version drives pending detection and rollback,
                    # so record under it; otherwise the migration would look
                    # pending on every run.
                    logger.warning(
                        f"Migration {migration_name} declares version "
                        f"{migration.version!r}; recording it as {version!r} "
                        f"to match its filename"
                    )
                await migration.up()
                await self.record_migration(migration, version=version)
                applied.append(version)
                logger.info(f"Successfully applied migration: {version}")
            except Exception as e:
                logger.error(f"Failed to apply migration {migration_name}: {e}")
                raise

        return applied

    @trace_async_operation("migration_manager.migrate_down")
    async def migrate_down(self, target_version: str) -> List[str]:
        """
        Roll back every applied migration newer than ``target_version``.

        The migration at ``target_version`` itself stays applied. Migrations
        are rolled back newest first. A recorded version with no matching
        script is skipped with a warning, and its record is kept.

        Args:
            target_version: Version to roll back to.

        Returns:
            List of rolled back migration versions.
        """
        await self.initialize()

        applied_migrations = await self.get_applied_migrations()
        to_rollback = [v for v in reversed(applied_migrations) if v > target_version]

        # Discover once and match versions exactly (not by substring).
        names_by_version = self._migrations_by_version()

        rolled_back: List[str] = []
        for version in to_rollback:
            migration_name = names_by_version.get(version)
            if migration_name is None:
                logger.warning(f"Migration file not found for version {version}")
                continue
            try:
                logger.info(f"Rolling back migration: {migration_name}")
                migration = await self.load_migration(migration_name)
                await migration.down()
                await self.remove_migration_record(version)
                rolled_back.append(version)
                logger.info(f"Successfully rolled back migration: {version}")
            except Exception as e:
                logger.error(f"Failed to rollback migration {version}: {e}")
                raise

        return rolled_back

    async def get_migration_status(self) -> dict:
        """Summarize discovered vs. applied migrations.

        ``pending_migrations`` counts discovered scripts whose version has not
        been recorded, which is the same rule ``migrate_up`` uses. Records
        without a matching script do not reduce the count.
        """
        await self.initialize()

        all_migrations = self.discover_migrations()
        applied_migrations = await self.get_applied_migrations()
        applied_set = set(applied_migrations)

        pending_count = sum(
            1
            for name in all_migrations
            if self._extract_version(name) not in applied_set
        )

        return {
            "total_migrations": len(all_migrations),
            "applied_migrations": len(applied_migrations),
            "pending_migrations": pending_count,
            "latest_applied": applied_migrations[-1] if applied_migrations else None,
            "all_applied": applied_migrations,
        }