"""Database management for user authentication and sessions."""

import sqlite3
import os
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List
from pathlib import Path

from .utils import get_logger

logger = get_logger(__name__)

# SQLite's CURRENT_TIMESTAMP and datetime('now') yield UTC text in exactly this
# format. Timestamps written from Python use the same clock and format so the
# plain text comparisons in the session/stats queries below are correct.
_SQLITE_TS_FORMAT = '%Y-%m-%d %H:%M:%S'


class Database:
    """SQLite database manager for Oliver Metadata Tool.

    Uses connection-per-operation pattern for thread safety with multiple
    uvicorn workers: every public method opens its own connection via
    ``_get_conn`` and closes it before returning.

    Most query methods log errors and return a neutral value (None, False,
    [] or {}) instead of raising.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open the database, creating the file, tables and seed users if needed.

        Args:
            db_path: Path to the SQLite file. When omitted, uses
                ``/app/data/oliver_metadata.db`` if the ``DOCKER_MODE``
                environment variable equals ``"true"`` (case-insensitive),
                otherwise ``oliver_metadata.db`` relative to the working directory.
        """
        if db_path is None:
            docker_mode = os.getenv('DOCKER_MODE', 'false').lower() == 'true'
            if docker_mode:
                db_dir = Path('/app/data')
                db_dir.mkdir(parents=True, exist_ok=True)
                db_path = str(db_dir / 'oliver_metadata.db')
            else:
                db_path = 'oliver_metadata.db'

        self.db_path = db_path
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self._create_tables()
        logger.info(f"Database initialized at {db_path}")

    def _get_conn(self) -> sqlite3.Connection:
        """Create a new connection per call (thread-safe).

        Rows come back as ``sqlite3.Row`` so callers can ``dict(row)`` them.
        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self.db_path, timeout=10)
        try:
            conn.row_factory = sqlite3.Row
            # WAL mode lets readers run concurrently with a single writer.
            conn.execute("PRAGMA journal_mode=WAL")
        except Exception:
            # Don't leak the handle if configuring it fails (e.g. database locked).
            conn.close()
            raise
        return conn

    def _create_tables(self):
        """Create tables and indexes if missing, run migrations, seed users.

        The test user is created only when ENABLE_TEST_USER is "true"; a
        superadmin is created/promoted only when SUPERADMIN_EMAIL is set.
        """
        conn = self._get_conn()
        try:
            # Users table (with role column)
            conn.execute('''
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE NOT NULL,
                    password_hash TEXT,
                    email TEXT,
                    full_name TEXT,
                    role TEXT DEFAULT 'user',
                    auth_method TEXT DEFAULT 'local',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_login TIMESTAMP,
                    is_active BOOLEAN DEFAULT 1
                )
            ''')

            # Sessions table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    FOREIGN KEY (user_id) REFERENCES users (id)
                )
            ''')

            # Audit log table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS audit_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    action TEXT NOT NULL,
                    details TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (user_id) REFERENCES users (id)
                )
            ''')

            # AI usage table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS ai_usage (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    filename TEXT,
                    tokens_total INTEGER DEFAULT 0,
                    model TEXT DEFAULT '',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (user_id) REFERENCES users (id)
                )
            ''')

            # Indexes
            conn.execute('CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_audit_user_id ON audit_log(user_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_ai_usage_user_id ON ai_usage(user_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_ai_usage_created ON ai_usage(created_at)')

            conn.commit()
            logger.info("Database tables created/verified")

            # Add role column to existing databases (migration)
            self._migrate_add_role_column(conn)

            # Create test user if enabled
            enable_test = os.getenv('ENABLE_TEST_USER', 'false').lower() == 'true'
            if enable_test:
                self._create_test_user(conn)

            # Create superadmin if configured
            superadmin_email = os.getenv('SUPERADMIN_EMAIL', '')
            if superadmin_email:
                self._create_superadmin(conn, superadmin_email)
        finally:
            conn.close()

    def _migrate_add_role_column(self, conn: sqlite3.Connection):
        """Add role column if it doesn't exist (for existing databases).

        Best-effort: failures are logged, not raised.
        """
        try:
            cursor = conn.execute("PRAGMA table_info(users)")
            columns = [row['name'] for row in cursor.fetchall()]
            if 'role' not in columns:
                conn.execute("ALTER TABLE users ADD COLUMN role TEXT DEFAULT 'user'")
                conn.commit()
                logger.info("Added 'role' column to users table")
        except Exception as e:
            logger.error(f"Error migrating role column: {e}")

    def _create_test_user(self, conn: sqlite3.Connection):
        """Create test user (tester/oliveradmin) if doesn't exist.

        Requires werkzeug for password hashing; if it is not installed the user
        is skipped with a warning. Other failures are logged, not raised.
        """
        try:
            cursor = conn.execute('SELECT id FROM users WHERE username = ?', ('tester',))
            if not cursor.fetchone():
                try:
                    from werkzeug.security import generate_password_hash
                    password_hash = generate_password_hash('oliveradmin')
                    conn.execute(
                        'INSERT INTO users (username, password_hash, email, full_name, role, auth_method) VALUES (?, ?, ?, ?, ?, ?)',
                        ('tester', password_hash, 'tester@oliver.local', 'Test User', 'user', 'local'),
                    )
                    conn.commit()
                    logger.info("Test user 'tester' created")
                except ImportError:
                    logger.warning("werkzeug not available - test user not created")
        except Exception as e:
            logger.error(f"Error creating test user: {e}")

    def _create_superadmin(self, conn: sqlite3.Connection, email: str):
        """Create or promote superadmin user.

        The existing account is found by exact email first. The username
        derived from the email's local part is only used as a fallback for a
        row that has no email recorded; a row with that username but a
        *different* email belongs to someone else and is never promoted.
        Failures are logged, not raised.
        """
        try:
            email = email.strip()
            if not email:
                return
            username = email.split('@')[0]

            row = conn.execute(
                'SELECT id, role FROM users WHERE email = ? ORDER BY id LIMIT 1', (email,)
            ).fetchone()
            if row is None:
                candidate = conn.execute(
                    'SELECT id, role, email FROM users WHERE username = ?', (username,)
                ).fetchone()
                if candidate is not None:
                    if candidate['email']:
                        # Same username, different email: not the configured superadmin.
                        logger.warning(
                            f"Username '{username}' belongs to a different account; "
                            f"not promoting it to admin for {email}"
                        )
                        return
                    row = candidate

            if row:
                if row['role'] != 'admin':
                    conn.execute('UPDATE users SET role = ? WHERE id = ?', ('admin', row['id']))
                    conn.commit()
                    logger.info(f"Promoted user '{username}' to admin")
            else:
                conn.execute(
                    'INSERT INTO users (username, email, full_name, role, auth_method) VALUES (?, ?, ?, ?, ?)',
                    (username, email, username, 'admin', 'sso'),
                )
                conn.commit()
                logger.info(f"Created superadmin user '{username}' ({email})")
        except Exception as e:
            logger.error(f"Error creating superadmin: {e}")

    # --- User Operations ---

    def get_user_by_username(self, username: str) -> Optional[Dict]:
        """Get an active user by username, or None if missing/inactive/on error."""
        conn = self._get_conn()
        try:
            cursor = conn.execute('SELECT * FROM users WHERE username = ? AND is_active = 1', (username,))
            row = cursor.fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"Error fetching user '{username}': {e}")
            return None
        finally:
            conn.close()

    def get_user_by_id(self, user_id: int) -> Optional[Dict]:
        """Get an active user by ID, or None if missing/inactive/on error."""
        conn = self._get_conn()
        try:
            cursor = conn.execute('SELECT * FROM users WHERE id = ? AND is_active = 1', (user_id,))
            row = cursor.fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"Error fetching user ID {user_id}: {e}")
            return None
        finally:
            conn.close()

    def create_user(
        self,
        username: str,
        password_hash: Optional[str] = None,
        email: Optional[str] = None,
        full_name: Optional[str] = None,
        auth_method: str = 'local',
        role: str = 'user',
    ) -> Optional[int]:
        """Create a new user.

        Returns:
            The new user's ID, or None if the username already exists or the
            insert fails.
        """
        conn = self._get_conn()
        try:
            cursor = conn.execute(
                'INSERT INTO users (username, password_hash, email, full_name, role, auth_method) VALUES (?, ?, ?, ?, ?, ?)',
                (username, password_hash, email, full_name, role, auth_method),
            )
            conn.commit()
            user_id = cursor.lastrowid
            logger.info(f"Created user '{username}' (ID: {user_id})")
            return user_id
        except sqlite3.IntegrityError:
            # username is the only UNIQUE column on users.
            logger.warning(f"User '{username}' already exists")
            return None
        except Exception as e:
            logger.error(f"Error creating user '{username}': {e}")
            return None
        finally:
            conn.close()

    def update_last_login(self, user_id: int):
        """Set the user's last_login to the current (UTC) timestamp."""
        conn = self._get_conn()
        try:
            conn.execute('UPDATE users SET last_login = CURRENT_TIMESTAMP WHERE id = ?', (user_id,))
            conn.commit()
        except Exception as e:
            logger.error(f"Error updating last login for user {user_id}: {e}")
        finally:
            conn.close()

    # --- Session Operations ---

    def create_session(
        self,
        user_id: int,
        session_id: str,
        expires_in_hours: int = 24,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> bool:
        """Create new session for user.

        The expiry is stored as UTC text in SQLite's CURRENT_TIMESTAMP format,
        because get_session/cleanup_expired_sessions compare it against
        CURRENT_TIMESTAMP as text.

        Returns:
            True on success, False if the insert fails.
        """
        conn = self._get_conn()
        try:
            expires_at = (
                datetime.now(timezone.utc) + timedelta(hours=expires_in_hours)
            ).strftime(_SQLITE_TS_FORMAT)
            conn.execute(
                'INSERT INTO sessions (session_id, user_id, expires_at, ip_address, user_agent) VALUES (?, ?, ?, ?, ?)',
                (session_id, user_id, expires_at, ip_address, user_agent),
            )
            conn.commit()
            return True
        except Exception as e:
            logger.error(f"Error creating session: {e}")
            return False
        finally:
            conn.close()

    def get_session(self, session_id: str) -> Optional[Dict]:
        """Get session by ID.

        Returns the session row joined with the owner's username, email and
        full_name, or None if the session is expired, not found, belongs to a
        deactivated user, or the query fails.
        """
        conn = self._get_conn()
        try:
            cursor = conn.execute('''
                SELECT s.*, u.username, u.email, u.full_name
                FROM sessions s
                JOIN users u ON s.user_id = u.id
                WHERE s.session_id = ?
                  AND s.expires_at > CURRENT_TIMESTAMP
                  AND u.is_active = 1
            ''', (session_id,))
            row = cursor.fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"Error fetching session: {e}")
            return None
        finally:
            conn.close()

    def delete_session(self, session_id: str) -> bool:
        """Delete session (logout). Returns False only if the delete fails."""
        conn = self._get_conn()
        try:
            conn.execute('DELETE FROM sessions WHERE session_id = ?', (session_id,))
            conn.commit()
            return True
        except Exception as e:
            logger.error(f"Error deleting session: {e}")
            return False
        finally:
            conn.close()

    def cleanup_expired_sessions(self):
        """Remove expired sessions from database."""
        conn = self._get_conn()
        try:
            cursor = conn.execute('DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP')
            conn.commit()
            deleted_count = cursor.rowcount
            if deleted_count > 0:
                logger.info(f"Cleaned up {deleted_count} expired sessions")
        except Exception as e:
            logger.error(f"Error cleaning up sessions: {e}")
        finally:
            conn.close()

    # --- Audit Log ---

    def log_action(self, user_id: int, action: str, details: Optional[str] = None):
        """Log user action to audit trail. Failures are logged, not raised."""
        conn = self._get_conn()
        try:
            conn.execute(
                'INSERT INTO audit_log (user_id, action, details) VALUES (?, ?, ?)',
                (user_id, action, details),
            )
            conn.commit()
        except Exception as e:
            logger.error(f"Error logging action: {e}")
        finally:
            conn.close()

    def get_user_activity(self, user_id: int, limit: int = 100, offset: int = 0) -> List[Dict]:
        """Get one user's audit entries, newest first ([] on error)."""
        conn = self._get_conn()
        try:
            cursor = conn.execute(
                'SELECT * FROM audit_log WHERE user_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?',
                (user_id, limit, offset),
            )
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Error fetching user activity: {e}")
            return []
        finally:
            conn.close()

    def get_all_users(self, include_inactive: bool = False) -> List[Dict]:
        """Get all users, newest first; only active ones unless include_inactive."""
        conn = self._get_conn()
        try:
            query = 'SELECT * FROM users'
            if not include_inactive:
                query += ' WHERE is_active = 1'
            query += ' ORDER BY created_at DESC'
            cursor = conn.execute(query)
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Error fetching users: {e}")
            return []
        finally:
            conn.close()

    def get_stats(self) -> Dict:
        """Get counts of active users/sessions and audit entries ({} on error)."""
        conn = self._get_conn()
        try:
            stats = {}
            cursor = conn.execute('SELECT COUNT(*) as count FROM users WHERE is_active = 1')
            stats['active_users'] = cursor.fetchone()['count']
            cursor = conn.execute('SELECT COUNT(*) as count FROM sessions WHERE expires_at > CURRENT_TIMESTAMP')
            stats['active_sessions'] = cursor.fetchone()['count']
            cursor = conn.execute('SELECT COUNT(*) as count FROM audit_log')
            stats['audit_entries'] = cursor.fetchone()['count']
            cursor = conn.execute("SELECT COUNT(*) as count FROM audit_log WHERE timestamp > datetime('now', '-24 hours')")
            stats['recent_activity'] = cursor.fetchone()['count']
            return stats
        except Exception as e:
            logger.error(f"Error fetching stats: {e}")
            return {}
        finally:
            conn.close()

    # --- User Update ---

    def update_user(self, user_id: int, updates: Dict) -> bool:
        """Update user fields.

        Only the keys role, is_active, full_name and email are applied; any
        other keys are ignored.

        Returns:
            True if a row was updated; False if nothing applicable was given,
            no such user exists, or the update fails.
        """
        allowed = {'role', 'is_active', 'full_name', 'email'}
        filtered = {k: v for k, v in updates.items() if k in allowed}
        if not filtered:
            return False

        conn = self._get_conn()
        try:
            # Column names come only from the fixed whitelist above, so the
            # f-string cannot inject SQL; values are bound as parameters.
            set_clause = ', '.join(f'{k} = ?' for k in filtered)
            values = list(filtered.values()) + [user_id]
            cursor = conn.execute(f'UPDATE users SET {set_clause} WHERE id = ?', values)
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Error updating user {user_id}: {e}")
            return False
        finally:
            conn.close()

    # --- Audit Log (extended) ---

    def get_audit_log(
        self,
        user_id: Optional[int] = None,
        action: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """Get audit entries (with username), newest first, optionally filtered."""
        conn = self._get_conn()
        try:
            query = '''
                SELECT a.*, u.username
                FROM audit_log a
                LEFT JOIN users u ON a.user_id = u.id
            '''
            conditions = []
            params = []
            if user_id is not None:
                conditions.append('a.user_id = ?')
                params.append(user_id)
            if action:
                conditions.append('a.action = ?')
                params.append(action)
            if conditions:
                query += ' WHERE ' + ' AND '.join(conditions)
            query += ' ORDER BY a.timestamp DESC LIMIT ? OFFSET ?'
            params.extend([limit, offset])
            cursor = conn.execute(query, params)
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Error fetching audit log: {e}")
            return []
        finally:
            conn.close()

    # --- AI Usage ---

    def log_ai_usage(
        self,
        user_id: int,
        filename: str = "",
        tokens_total: int = 0,
        model: str = "",
    ):
        """Log AI token usage for a file. Failures are logged, not raised."""
        conn = self._get_conn()
        try:
            conn.execute(
                'INSERT INTO ai_usage (user_id, filename, tokens_total, model) VALUES (?, ?, ?, ?)',
                (user_id, filename, tokens_total, model),
            )
            conn.commit()
        except Exception as e:
            logger.error(f"Error logging AI usage: {e}")
        finally:
            conn.close()

    def get_ai_usage_stats(self) -> Dict:
        """Get request/token totals overall, last 24h and last 7 days ({} on error)."""
        conn = self._get_conn()
        try:
            stats = {}
            cursor = conn.execute('SELECT COUNT(*) as count, COALESCE(SUM(tokens_total), 0) as total_tokens FROM ai_usage')
            row = cursor.fetchone()
            stats['total_requests'] = row['count']
            stats['total_tokens'] = row['total_tokens']

            cursor = conn.execute(
                "SELECT COUNT(*) as count, COALESCE(SUM(tokens_total), 0) as tokens FROM ai_usage WHERE created_at > datetime('now', '-24 hours')"
            )
            row = cursor.fetchone()
            stats['requests_24h'] = row['count']
            stats['tokens_24h'] = row['tokens']

            cursor = conn.execute(
                "SELECT COUNT(*) as count, COALESCE(SUM(tokens_total), 0) as tokens FROM ai_usage WHERE created_at > datetime('now', '-7 days')"
            )
            row = cursor.fetchone()
            stats['requests_7d'] = row['count']
            stats['tokens_7d'] = row['tokens']
            return stats
        except Exception as e:
            logger.error(f"Error fetching AI usage stats: {e}")
            return {}
        finally:
            conn.close()

    def get_ai_usage_by_user(self, limit: int = 50) -> List[Dict]:
        """Get per-user AI request counts and token totals, highest usage first."""
        conn = self._get_conn()
        try:
            cursor = conn.execute('''
                SELECT u.username, u.id as user_id,
                       COUNT(*) as request_count,
                       COALESCE(SUM(a.tokens_total), 0) as total_tokens,
                       MAX(a.created_at) as last_used
                FROM ai_usage a
                JOIN users u ON a.user_id = u.id
                GROUP BY u.id
                ORDER BY total_tokens DESC
                LIMIT ?
            ''', (limit,))
            return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Error fetching AI usage by user: {e}")
            return []
        finally:
            conn.close()

    def close(self):
        """No-op for connection-per-operation pattern."""
        pass