253 lines
8.7 KiB
Python
253 lines
8.7 KiB
Python
"""MongoDB migration framework."""
|
|
|
|
import importlib.util
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from pathlib import Path

from motor.motor_asyncio import AsyncIOMotorDatabase

from app.core.database import get_database
from app.core.logging import get_logger
from app.telemetry.tracing import trace_async_operation
|
|
|
|
# Module-level logger; MigrationManager reports progress and failures through it.
logger = get_logger(__name__)
|
|
|
|
|
|
class Migration(ABC):
    """Abstract base for a single database migration.

    Concrete migrations declare ``version`` and ``description`` as class
    attributes and implement ``up`` (apply) and ``down`` (rollback).
    ``MigrationManager.load_migration`` instantiates the subclass and hands it
    the database through ``set_database`` before returning it.
    """

    # Placeholder values; each concrete migration is expected to override them.
    version: str = "0000-00-00-000000"
    description: str = ""

    def __init__(self):
        # No database is attached until set_database() is called.
        self.db: AsyncIOMotorDatabase | None = None

    @abstractmethod
    async def up(self) -> None:
        """Apply the migration."""

    @abstractmethod
    async def down(self) -> None:
        """Rollback the migration."""

    async def set_database(self, db: AsyncIOMotorDatabase) -> None:
        """Attach the database handle that up()/down() operate on."""
        self.db = db
|
|
|
|
|
|
class MigrationRecord:
    """Represents a migration record in the database.

    Holds the same three fields that ``MigrationManager.record_migration``
    writes to the history collection: ``version``, ``description`` and
    ``applied_at``.
    """

    def __init__(self, version: str, description: str, applied_at: datetime):
        self.version = version
        self.description = description
        self.applied_at = applied_at

    def __repr__(self) -> str:
        # Explicit repr so records are readable in logs and debuggers.
        return (
            f"{type(self).__name__}(version={self.version!r}, "
            f"description={self.description!r}, applied_at={self.applied_at!r})"
        )
|
|
|
|
|
|
class MigrationManager:
    """Manages database migrations.

    Migration scripts are ``*.py`` files in the ``scripts`` directory next to
    this module, named ``migration_<version>_<description>.py`` where
    ``<version>`` contains no underscores (the expected format is
    ``YYYY-MM-DD-HHMMSS``). Applied versions are tracked in the
    ``migration_history`` collection.
    """

    def __init__(self):
        self.db: AsyncIOMotorDatabase | None = None
        self.migrations_dir = Path(__file__).parent / "scripts"
        self.collection_name = "migration_history"

    @staticmethod
    def _version_from_name(migration_name: str) -> str:
        """Return the version segment of a migration module name.

        ``migration_2024-01-15-120000_add_users`` -> ``2024-01-15-120000``.
        """
        return migration_name.removeprefix("migration_").split("_", 1)[0]

    async def initialize(self) -> None:
        """Connect to the database and ensure the history collection's indexes."""
        self.db = await get_database()
        await self._ensure_migration_collection()

    async def _ensure_migration_collection(self) -> None:
        """Ensure the migration history collection exists with proper indexes."""
        collection = self.db[self.collection_name]

        # Unique index on version prevents recording the same migration twice.
        await collection.create_index([("version", 1)], unique=True)
        await collection.create_index([("applied_at", -1)])

        logger.info("Migration history collection initialized")

    def discover_migrations(self) -> list[str]:
        """Return the sorted module names (file stems) of all migration scripts.

        Only ``*.py`` files whose name starts with ``migration_`` are included.
        Since every name shares that prefix and the version comes first,
        lexical order equals version order as long as versions are fixed-width.

        Returns:
            Sorted list of module names; empty if the directory is missing.
        """
        if not self.migrations_dir.exists():
            logger.warning(f"Migrations directory not found: {self.migrations_dir}")
            return []

        return sorted(
            path.stem
            for path in self.migrations_dir.glob("*.py")
            if path.name.startswith("migration_")
        )

    async def load_migration(self, migration_name: str) -> Migration:
        """Import a migration script and return its ``Migration`` instance.

        Args:
            migration_name: Module name (file stem) inside ``migrations_dir``.

        Returns:
            An instance of the module's ``Migration`` class with the database set.

        Raises:
            FileNotFoundError: The script file does not exist.
            ImportError: No import spec/loader could be created for the file.
            AttributeError: The module defines no ``Migration`` attribute.
        """
        migration_path = self.migrations_dir / f"{migration_name}.py"

        if not migration_path.exists():
            raise FileNotFoundError(f"Migration file not found: {migration_path}")

        spec = importlib.util.spec_from_file_location(migration_name, migration_path)
        # spec_from_file_location may return None (or a spec without a loader);
        # fail with a clear error instead of an AttributeError on None.
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load migration module from {migration_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # By convention the script exposes its migration class as `Migration`.
        if not hasattr(module, "Migration"):
            raise AttributeError(f"Migration class not found in {migration_name}")

        migration = module.Migration()
        await migration.set_database(self.db)
        return migration

    async def get_applied_migrations(self) -> list[str]:
        """Return applied migration versions in ascending order."""
        cursor = (
            self.db[self.collection_name]
            .find({}, {"version": 1})
            .sort("version", 1)
        )
        return [doc["version"] async for doc in cursor]

    async def record_migration(self, migration: Migration) -> None:
        """Record a successful migration in the history collection."""
        collection = self.db[self.collection_name]

        record = {
            "version": migration.version,
            "description": migration.description,
            # Timezone-aware UTC; datetime.utcnow() is deprecated and naive.
            "applied_at": datetime.now(timezone.utc),
        }

        await collection.insert_one(record)
        logger.info(f"Recorded migration: {migration.version} - {migration.description}")

    async def remove_migration_record(self, version: str) -> None:
        """Remove a migration record (for rollback)."""
        collection = self.db[self.collection_name]
        await collection.delete_one({"version": version})
        logger.info(f"Removed migration record: {version}")

    @trace_async_operation("migration_manager.migrate_up")
    async def migrate_up(self, target_version: str | None = None) -> list[str]:
        """
        Apply migrations up to the target version.

        Args:
            target_version: Highest version (inclusive) to apply. If None,
                applies all pending migrations.

        Returns:
            List of applied migration versions.

        Raises:
            ValueError: A script's ``Migration.version`` does not match the
                version in its filename (recording it would make the
                migration re-run on every invocation).
            Exception: Any error from loading or applying a migration is
                logged and re-raised; earlier migrations stay applied.
        """
        await self.initialize()

        applied_versions = set(await self.get_applied_migrations())

        pending: list[tuple[str, str]] = []
        for migration_name in self.discover_migrations():
            version = self._version_from_name(migration_name)
            if version in applied_versions:
                continue
            if target_version is not None and version > target_version:
                continue
            pending.append((migration_name, version))

        pending.sort(key=lambda item: item[1])

        applied = []
        for migration_name, version in pending:
            try:
                logger.info(f"Applying migration: {migration_name}")

                migration = await self.load_migration(migration_name)
                # Pending detection uses the filename version while the history
                # stores migration.version; they must agree or the migration
                # would be considered pending forever.
                if migration.version != version:
                    raise ValueError(
                        f"Migration {migration_name} declares version "
                        f"{migration.version!r} but its filename implies {version!r}"
                    )
                await migration.up()
                await self.record_migration(migration)

                applied.append(version)
                logger.info(f"Successfully applied migration: {version}")

            except Exception as e:
                logger.error(f"Failed to apply migration {migration_name}: {e}")
                raise

        return applied

    @trace_async_operation("migration_manager.migrate_down")
    async def migrate_down(self, target_version: str) -> list[str]:
        """
        Rollback migrations down to the target version.

        Every applied migration whose version is greater than
        ``target_version`` is rolled back, newest first. An applied version
        with no matching script is skipped with a warning and its history
        record is left in place.

        Args:
            target_version: Version to rollback to (kept applied).

        Returns:
            List of rolled back migration versions.
        """
        await self.initialize()

        applied_migrations = await self.get_applied_migrations()
        to_rollback = [v for v in reversed(applied_migrations) if v > target_version]
        if not to_rollback:
            return []

        # Build the version -> script map once, using an exact version match
        # (a substring test could pick the wrong file). setdefault keeps the
        # first script in sorted order if two share a version.
        name_by_version: dict[str, str] = {}
        for name in self.discover_migrations():
            name_by_version.setdefault(self._version_from_name(name), name)

        rolled_back = []
        for version in to_rollback:
            try:
                migration_name = name_by_version.get(version)
                if migration_name is None:
                    logger.warning(f"Migration file not found for version {version}")
                    continue

                logger.info(f"Rolling back migration: {migration_name}")

                migration = await self.load_migration(migration_name)
                await migration.down()
                await self.remove_migration_record(version)

                rolled_back.append(version)
                logger.info(f"Successfully rolled back migration: {version}")

            except Exception as e:
                logger.error(f"Failed to rollback migration {version}: {e}")
                raise

        return rolled_back

    async def get_migration_status(self) -> dict:
        """Get current migration status.

        Returns:
            Dict with ``total_migrations`` (scripts found), ``applied_migrations``
            (history records), ``pending_migrations`` (scripts whose version is
            not in history), ``latest_applied`` (highest applied version or
            None) and ``all_applied`` (applied versions, ascending).
        """
        await self.initialize()

        all_migrations = self.discover_migrations()
        applied_migrations = await self.get_applied_migrations()

        # Count scripts not yet applied rather than subtracting list lengths,
        # which is wrong (even negative) when history holds versions whose
        # scripts no longer exist.
        applied_set = set(applied_migrations)
        pending_count = sum(
            1
            for name in all_migrations
            if self._version_from_name(name) not in applied_set
        )

        return {
            "total_migrations": len(all_migrations),
            "applied_migrations": len(applied_migrations),
            "pending_migrations": pending_count,
            "latest_applied": applied_migrations[-1] if applied_migrations else None,
            "all_applied": applied_migrations,
        }
|