"""MongoDB migration framework."""
import os
import importlib.util
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from motor.motor_asyncio import AsyncIOMotorDatabase

from app.core.database import get_database
from app.core.logging import get_logger
from app.telemetry.tracing import trace_async_operation

logger = get_logger(__name__)

# Filename prefix every migration script must carry; the version follows it,
# e.g. ``migration_2024-01-31-120000_add_users_index.py``.
_MIGRATION_PREFIX = "migration_"


class Migration(ABC):
    """Base class for database migrations.

    Concrete scripts subclass this (the class must be named ``Migration`` in
    the script module) and implement :meth:`up` and :meth:`down`.
    """

    def __init__(self):
        self.version: str = "0000-00-00-000000"  # Format: YYYY-MM-DD-HHMMSS
        self.description: str = ""
        # Injected by MigrationManager.load_migration via set_database().
        self.db: Optional[AsyncIOMotorDatabase] = None

    @abstractmethod
    async def up(self) -> None:
        """Apply the migration."""
        pass

    @abstractmethod
    async def down(self) -> None:
        """Rollback the migration."""
        pass

    async def set_database(self, db: AsyncIOMotorDatabase) -> None:
        """Set the database instance."""
        self.db = db


class MigrationRecord:
    """Represents a migration record in the database."""

    def __init__(self, version: str, description: str, applied_at: datetime):
        self.version = version
        self.description = description
        self.applied_at = applied_at


class MigrationManager:
    """Manages database migrations.

    Migration scripts live in ``<this package>/scripts`` and are identified by
    the version embedded in their filename. Applied versions are tracked in the
    ``migration_history`` collection.
    """

    def __init__(self):
        self.db: Optional[AsyncIOMotorDatabase] = None
        self.migrations_dir = Path(__file__).parent / "scripts"
        self.collection_name = "migration_history"

    async def initialize(self) -> None:
        """Initialize the migration manager (connect and ensure indexes)."""
        self.db = await get_database()
        await self._ensure_migration_collection()

    async def _ensure_migration_collection(self) -> None:
        """Ensure the migration history collection exists with proper indexes."""
        collection = self.db[self.collection_name]

        # Unique on version so a migration can never be recorded twice.
        await collection.create_index([("version", 1)], unique=True)
        await collection.create_index([("applied_at", -1)])

        logger.info("Migration history collection initialized")

    @staticmethod
    def _version_from_name(migration_name: str) -> str:
        """Return the version segment of ``migration_<version>_<description>``."""
        return migration_name[len(_MIGRATION_PREFIX):].split("_", 1)[0]

    def _migrations_by_version(self) -> Dict[str, str]:
        """Map each discovered version to its module name (first file wins)."""
        by_version: Dict[str, str] = {}
        for migration_name in self.discover_migrations():
            by_version.setdefault(self._version_from_name(migration_name), migration_name)
        return by_version

    def discover_migrations(self) -> List[str]:
        """Discover all migration files in the migrations directory.

        Returns:
            Module names (file stems) starting with ``migration_``, sorted
            lexicographically, which orders them by their leading version.
        """
        if not self.migrations_dir.exists():
            logger.warning(f"Migrations directory not found: {self.migrations_dir}")
            return []

        return sorted(
            file_path.stem
            for file_path in self.migrations_dir.glob("*.py")
            if file_path.name.startswith(_MIGRATION_PREFIX)
        )

    async def load_migration(self, migration_name: str) -> Migration:
        """Dynamically load a migration class and bind it to the database.

        Raises:
            FileNotFoundError: If the script file does not exist.
            ImportError: If no import spec/loader can be created for the file.
            AttributeError: If the module defines no ``Migration`` class.
        """
        migration_path = self.migrations_dir / f"{migration_name}.py"

        if not migration_path.exists():
            raise FileNotFoundError(f"Migration file not found: {migration_path}")

        # Load the module directly from its file path.
        spec = importlib.util.spec_from_file_location(migration_name, migration_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load migration module from {migration_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Get the migration class (assume it's named Migration)
        if not hasattr(module, 'Migration'):
            raise AttributeError(f"Migration class not found in {migration_name}")

        migration_class = getattr(module, 'Migration')
        migration = migration_class()
        await migration.set_database(self.db)

        return migration

    async def get_applied_migrations(self) -> List[str]:
        """Get list of applied migration versions, ascending."""
        collection = self.db[self.collection_name]
        cursor = collection.find({}, {"version": 1}).sort("version", 1)
        return [doc["version"] async for doc in cursor]

    async def record_migration(self, migration: Migration, version: Optional[str] = None) -> None:
        """Record a successful migration in the database.

        Args:
            migration: The migration that was applied.
            version: Version key to store. Defaults to ``migration.version``;
                ``migrate_up`` passes the filename-derived version so the
                stored key matches the one used to detect applied migrations.
        """
        collection = self.db[self.collection_name]
        recorded_version = version if version is not None else migration.version

        record = {
            "version": recorded_version,
            "description": migration.description,
            "applied_at": datetime.now(timezone.utc),
        }

        await collection.insert_one(record)
        logger.info(f"Recorded migration: {recorded_version} - {migration.description}")

    async def remove_migration_record(self, version: str) -> None:
        """Remove a migration record (for rollback)."""
        collection = self.db[self.collection_name]
        await collection.delete_one({"version": version})
        logger.info(f"Removed migration record: {version}")

    @trace_async_operation("migration_manager.migrate_up")
    async def migrate_up(self, target_version: Optional[str] = None) -> List[str]:
        """
        Apply migrations up to the target version.

        Args:
            target_version: Version to migrate to. If None, applies all pending migrations.

        Returns:
            List of applied migration versions.
        """
        await self.initialize()

        applied_versions = set(await self.get_applied_migrations())

        # Pending = discovered, not yet applied, and within the target (inclusive).
        pending_migrations = []
        for migration_name in self.discover_migrations():
            version = self._version_from_name(migration_name)
            if version in applied_versions:
                continue
            if target_version is None or version <= target_version:
                pending_migrations.append((migration_name, version))

        pending_migrations.sort(key=lambda item: item[1])

        applied = []
        for migration_name, version in pending_migrations:
            try:
                logger.info(f"Applying migration: {migration_name}")
                migration = await self.load_migration(migration_name)
                if migration.version != version:
                    # The filename version is what the applied-check uses, so
                    # it must also be what gets recorded.
                    logger.warning(
                        f"Migration {migration_name} declares version "
                        f"{migration.version!r}; recording filename version {version!r}"
                    )
                await migration.up()
                await self.record_migration(migration, version=version)
                applied.append(version)
                logger.info(f"Successfully applied migration: {version}")
            except Exception as e:
                logger.error(f"Failed to apply migration {migration_name}: {e}")
                raise

        return applied

    @trace_async_operation("migration_manager.migrate_down")
    async def migrate_down(self, target_version: str) -> List[str]:
        """
        Rollback migrations down to the target version.

        Args:
            target_version: Version to rollback to (this version itself is kept).

        Returns:
            List of rolled back migration versions, newest first.
        """
        await self.initialize()

        applied_migrations = await self.get_applied_migrations()

        # Newest first, only those strictly newer than the target.
        to_rollback = [v for v in reversed(applied_migrations) if v > target_version]

        # Resolve files by exact version once, instead of substring-matching
        # against a fresh directory listing for every version.
        files_by_version = self._migrations_by_version()

        rolled_back = []
        for version in to_rollback:
            migration_name = files_by_version.get(version)
            if not migration_name:
                logger.warning(f"Migration file not found for version {version}")
                continue

            try:
                logger.info(f"Rolling back migration: {migration_name}")
                migration = await self.load_migration(migration_name)
                await migration.down()
                await self.remove_migration_record(version)
                rolled_back.append(version)
                logger.info(f"Successfully rolled back migration: {version}")
            except Exception as e:
                logger.error(f"Failed to rollback migration {version}: {e}")
                raise

        return rolled_back

    async def get_migration_status(self) -> dict:
        """Get current migration status.

        ``pending_migrations`` counts discovered scripts whose version has not
        been applied; history records without a script do not affect it.
        """
        await self.initialize()

        all_migrations = self.discover_migrations()
        applied_migrations = await self.get_applied_migrations()
        applied_set = set(applied_migrations)

        pending_count = sum(
            1 for name in all_migrations if self._version_from_name(name) not in applied_set
        )

        return {
            "total_migrations": len(all_migrations),
            "applied_migrations": len(applied_migrations),
            "pending_migrations": pending_count,
            "latest_applied": applied_migrations[-1] if applied_migrations else None,
            "all_applied": applied_migrations,
        }