Complete Flask → FastAPI migration with: - FastAPI app with session auth, Azure AD SSO, rate limiting - SQLite-backed session store (survives restarts) - Bulk AI metadata generation with SSE progress - Admin panel (user management, audit log, AI usage) - Subpath deployment support (ROOT_PATH config) - Docker + deploy.sh for production deployment - Test suite (auth, upload, templates, imports, admin, sessions) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
525 lines
20 KiB
Python
525 lines
20 KiB
Python
"""Database management for user authentication and sessions."""
|
|
|
|
import sqlite3
|
|
import os
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, List
|
|
from pathlib import Path
|
|
from .utils import get_logger
|
|
|
|
# Module-level logger; the Database methods below report info, warnings and
# errors through it.
logger = get_logger(__name__)
|
|
|
|
|
|
class Database:
    """SQLite database manager for Oliver Metadata Tool.

    Uses connection-per-operation pattern for thread safety with
    multiple uvicorn workers: each operation opens its own connection
    and closes it before returning.

    Timestamps written by this class are produced by SQLite itself
    (``CURRENT_TIMESTAMP`` / ``datetime('now', ...)``), i.e. UTC text in
    ``YYYY-MM-DD HH:MM:SS`` form, so they can be compared directly inside
    SQL queries.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open (creating if necessary) the database and ensure its schema.

        Args:
            db_path: Path of the SQLite file. When ``None`` the path is
                derived from the environment:
                ``/app/data/oliver_metadata.db`` if ``DOCKER_MODE`` equals
                ``true`` (case-insensitive), otherwise the relative path
                ``oliver_metadata.db``.
        """
        if db_path is None:
            docker_mode = os.getenv('DOCKER_MODE', 'false').lower() == 'true'
            if docker_mode:
                db_path = str(Path('/app/data') / 'oliver_metadata.db')
            else:
                db_path = 'oliver_metadata.db'

        self.db_path = db_path
        # Covers the Docker data directory as well as caller-supplied paths.
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self._create_tables()
        logger.info(f"Database initialized at {db_path}")

    def _get_conn(self) -> sqlite3.Connection:
        """Create a new connection per call (thread-safe).

        Rows are returned as ``sqlite3.Row`` so callers can index them by
        column name, and the journal mode is set to WAL.

        Raises:
            sqlite3.Error: If the database cannot be opened or configured.
        """
        conn = sqlite3.connect(self.db_path, timeout=10)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
        except Exception:
            # Do not leak the handle when configuration fails.
            conn.close()
            raise
        return conn

    def _create_tables(self):
        """Create database tables and indexes if they don't exist.

        After the schema is committed this also runs the ``role`` column
        migration, creates the test user when ``ENABLE_TEST_USER`` is
        ``true``, and creates/promotes the superadmin when
        ``SUPERADMIN_EMAIL`` is set. Schema errors propagate to the caller.
        """
        conn = self._get_conn()
        try:
            # Users table (with role column)
            conn.execute('''
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT UNIQUE NOT NULL,
                    password_hash TEXT,
                    email TEXT,
                    full_name TEXT,
                    role TEXT DEFAULT 'user',
                    auth_method TEXT DEFAULT 'local',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_login TIMESTAMP,
                    is_active BOOLEAN DEFAULT 1
                )
            ''')

            # Sessions table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    FOREIGN KEY (user_id) REFERENCES users (id)
                )
            ''')

            # Audit log table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS audit_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    action TEXT NOT NULL,
                    details TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (user_id) REFERENCES users (id)
                )
            ''')

            # AI usage table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS ai_usage (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    filename TEXT,
                    tokens_total INTEGER DEFAULT 0,
                    model TEXT DEFAULT '',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (user_id) REFERENCES users (id)
                )
            ''')

            # Indexes: (index name, table, column). Identifiers are fixed
            # constants, so formatting them into the statement is safe.
            indexes = (
                ('idx_sessions_user_id', 'sessions', 'user_id'),
                ('idx_sessions_expires_at', 'sessions', 'expires_at'),
                ('idx_audit_user_id', 'audit_log', 'user_id'),
                ('idx_audit_timestamp', 'audit_log', 'timestamp'),
                ('idx_ai_usage_user_id', 'ai_usage', 'user_id'),
                ('idx_ai_usage_created', 'ai_usage', 'created_at'),
            )
            for index_name, table, column in indexes:
                conn.execute(
                    f'CREATE INDEX IF NOT EXISTS {index_name} ON {table}({column})'
                )

            conn.commit()
            logger.info("Database tables created/verified")

            # Add role column to existing databases (migration)
            self._migrate_add_role_column(conn)

            # Create test user if enabled
            if os.getenv('ENABLE_TEST_USER', 'false').lower() == 'true':
                self._create_test_user(conn)

            # Create superadmin if configured
            superadmin_email = os.getenv('SUPERADMIN_EMAIL', '')
            if superadmin_email:
                self._create_superadmin(conn, superadmin_email)
        finally:
            conn.close()

    def _migrate_add_role_column(self, conn: sqlite3.Connection):
        """Add role column if it doesn't exist (for existing databases).

        Failures are logged and swallowed so start-up can continue.
        """
        try:
            table_info = conn.execute("PRAGMA table_info(users)").fetchall()
            existing_columns = {row['name'] for row in table_info}
            if 'role' not in existing_columns:
                conn.execute("ALTER TABLE users ADD COLUMN role TEXT DEFAULT 'user'")
                conn.commit()
                logger.info("Added 'role' column to users table")
        except Exception as e:
            logger.error(f"Error migrating role column: {e}")

    def _create_test_user(self, conn: sqlite3.Connection):
        """Create test user (tester/oliveradmin) if doesn't exist.

        The credentials are hard-coded; this is only invoked when
        ``ENABLE_TEST_USER`` is ``true``. Failures are logged, not raised.
        """
        try:
            existing = conn.execute(
                'SELECT id FROM users WHERE username = ?', ('tester',)
            ).fetchone()
            if existing:
                return

            # werkzeug is optional here; without it the user is skipped.
            try:
                from werkzeug.security import generate_password_hash
            except ImportError:
                logger.warning("werkzeug not available - test user not created")
                return

            password_hash = generate_password_hash('oliveradmin')
            conn.execute(
                'INSERT INTO users (username, password_hash, email, full_name, role, auth_method) VALUES (?, ?, ?, ?, ?, ?)',
                ('tester', password_hash, 'tester@oliver.local', 'Test User', 'user', 'local'),
            )
            conn.commit()
            logger.info("Test user 'tester' created")
        except Exception as e:
            logger.error(f"Error creating test user: {e}")

    def _create_superadmin(self, conn: sqlite3.Connection, email: str):
        """Create or promote superadmin user.

        The username is the part of ``email`` before the first ``@``. An
        existing row matching either that username or the email is promoted
        to ``admin``; otherwise a new ``sso`` admin row is inserted.
        Failures are logged, not raised.
        """
        try:
            username = email.split('@')[0]
            # NOTE(review): matching on the local part alone means an
            # unrelated account with the same username would be promoted.
            # Confirm this is intended.
            row = conn.execute(
                'SELECT id, role FROM users WHERE username = ? OR email = ?',
                (username, email),
            ).fetchone()

            if row is None:
                conn.execute(
                    'INSERT INTO users (username, email, full_name, role, auth_method) VALUES (?, ?, ?, ?, ?)',
                    (username, email, username, 'admin', 'sso'),
                )
                conn.commit()
                logger.info(f"Created superadmin user '{username}' ({email})")
            elif row['role'] != 'admin':
                conn.execute('UPDATE users SET role = ? WHERE id = ?', ('admin', row['id']))
                conn.commit()
                logger.info(f"Promoted user '{username}' to admin")
        except Exception as e:
            logger.error(f"Error creating superadmin: {e}")

    # --- User Operations ---

    def get_user_by_username(self, username: str) -> Optional[Dict]:
        """Get an active user by username.

        Returns:
            The user row as a dict, or ``None`` if no active user matches
            or the query fails (the failure is logged).
        """
        conn = self._get_conn()
        try:
            row = conn.execute(
                'SELECT * FROM users WHERE username = ? AND is_active = 1', (username,)
            ).fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"Error fetching user '{username}': {e}")
            return None
        finally:
            conn.close()

    def get_user_by_id(self, user_id: int) -> Optional[Dict]:
        """Get an active user by ID.

        Returns:
            The user row as a dict, or ``None`` if no active user matches
            or the query fails (the failure is logged).
        """
        conn = self._get_conn()
        try:
            row = conn.execute(
                'SELECT * FROM users WHERE id = ? AND is_active = 1', (user_id,)
            ).fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"Error fetching user ID {user_id}: {e}")
            return None
        finally:
            conn.close()

    def create_user(
        self,
        username: str,
        password_hash: Optional[str] = None,
        email: Optional[str] = None,
        full_name: Optional[str] = None,
        auth_method: str = 'local',
        role: str = 'user',
    ) -> Optional[int]:
        """Create a new user.

        Returns:
            The new user's ID, or ``None`` if the insert violates a
            constraint (logged as "already exists") or fails otherwise.
        """
        conn = self._get_conn()
        try:
            cursor = conn.execute(
                'INSERT INTO users (username, password_hash, email, full_name, role, auth_method) VALUES (?, ?, ?, ?, ?, ?)',
                (username, password_hash, email, full_name, role, auth_method),
            )
            conn.commit()
            user_id = cursor.lastrowid
            logger.info(f"Created user '{username}' (ID: {user_id})")
            return user_id
        except sqlite3.IntegrityError:
            logger.warning(f"User '{username}' already exists")
            return None
        except Exception as e:
            logger.error(f"Error creating user '{username}': {e}")
            return None
        finally:
            conn.close()

    def update_last_login(self, user_id: int):
        """Set the user's last login timestamp to the current time."""
        conn = self._get_conn()
        try:
            conn.execute('UPDATE users SET last_login = CURRENT_TIMESTAMP WHERE id = ?', (user_id,))
            conn.commit()
        except Exception as e:
            logger.error(f"Error updating last login for user {user_id}: {e}")
        finally:
            conn.close()

    # --- Session Operations ---

    def create_session(
        self,
        user_id: int,
        session_id: str,
        expires_in_hours: int = 24,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> bool:
        """Create new session for user.

        Args:
            user_id: Owner of the session.
            session_id: Primary key of the session row.
            expires_in_hours: Lifetime relative to now; a negative value
                creates an already-expired session.
            ip_address: Optional client address to record.
            user_agent: Optional client user-agent to record.

        Returns:
            ``True`` if the row was inserted, ``False`` on any error
            (the error is logged).
        """
        conn = self._get_conn()
        try:
            # Let SQLite compute the expiry so it is UTC text in the same
            # format as CURRENT_TIMESTAMP, which get_session(),
            # cleanup_expired_sessions() and get_stats() compare against.
            # A Python-side datetime.now() is local time and would shift
            # every session's lifetime by the host's UTC offset.
            expiry_modifier = f'{expires_in_hours:+} hours'
            conn.execute(
                "INSERT INTO sessions (session_id, user_id, expires_at, ip_address, user_agent) "
                "VALUES (?, ?, datetime('now', ?), ?, ?)",
                (session_id, user_id, expiry_modifier, ip_address, user_agent),
            )
            conn.commit()
            return True
        except Exception as e:
            logger.error(f"Error creating session: {e}")
            return False
        finally:
            conn.close()

    def get_session(self, session_id: str) -> Optional[Dict]:
        """Get session by ID. Returns None if expired or not found.

        The returned dict holds every ``sessions`` column plus the owning
        user's ``username``, ``email`` and ``full_name``.
        """
        # NOTE(review): the join does not filter on users.is_active, so a
        # deactivated user's unexpired session is still returned. Confirm
        # that callers re-check the user.
        conn = self._get_conn()
        try:
            row = conn.execute('''
                SELECT s.*, u.username, u.email, u.full_name
                FROM sessions s
                JOIN users u ON s.user_id = u.id
                WHERE s.session_id = ? AND s.expires_at > CURRENT_TIMESTAMP
            ''', (session_id,)).fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"Error fetching session: {e}")
            return None
        finally:
            conn.close()

    def delete_session(self, session_id: str) -> bool:
        """Delete session (logout).

        Returns:
            ``True`` if the statement ran (even when no row matched),
            ``False`` on error.
        """
        conn = self._get_conn()
        try:
            conn.execute('DELETE FROM sessions WHERE session_id = ?', (session_id,))
            conn.commit()
            return True
        except Exception as e:
            logger.error(f"Error deleting session: {e}")
            return False
        finally:
            conn.close()

    def cleanup_expired_sessions(self):
        """Remove expired sessions from database and log how many went."""
        conn = self._get_conn()
        try:
            cursor = conn.execute('DELETE FROM sessions WHERE expires_at < CURRENT_TIMESTAMP')
            conn.commit()
            deleted_count = cursor.rowcount
            if deleted_count > 0:
                logger.info(f"Cleaned up {deleted_count} expired sessions")
        except Exception as e:
            logger.error(f"Error cleaning up sessions: {e}")
        finally:
            conn.close()

    # --- Audit Log ---

    def log_action(self, user_id: int, action: str, details: Optional[str] = None):
        """Log user action to audit trail. Errors are logged, not raised."""
        conn = self._get_conn()
        try:
            conn.execute(
                'INSERT INTO audit_log (user_id, action, details) VALUES (?, ?, ?)',
                (user_id, action, details),
            )
            conn.commit()
        except Exception as e:
            logger.error(f"Error logging action: {e}")
        finally:
            conn.close()

    def get_user_activity(self, user_id: int, limit: int = 100, offset: int = 0) -> List[Dict]:
        """Get one user's audit entries, newest first.

        Returns:
            Up to ``limit`` rows starting at ``offset``; ``[]`` on error.
        """
        conn = self._get_conn()
        try:
            rows = conn.execute(
                'SELECT * FROM audit_log WHERE user_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?',
                (user_id, limit, offset),
            ).fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"Error fetching user activity: {e}")
            return []
        finally:
            conn.close()

    def get_all_users(self, include_inactive: bool = False) -> List[Dict]:
        """Get all users, newest first.

        Args:
            include_inactive: When ``False`` (default) only rows with
                ``is_active = 1`` are returned.

        Returns:
            A list of user dicts; ``[]`` on error.
        """
        conn = self._get_conn()
        try:
            where_clause = '' if include_inactive else ' WHERE is_active = 1'
            query = f'SELECT * FROM users{where_clause} ORDER BY created_at DESC'
            return [dict(row) for row in conn.execute(query).fetchall()]
        except Exception as e:
            logger.error(f"Error fetching users: {e}")
            return []
        finally:
            conn.close()

    def get_stats(self) -> Dict:
        """Get database statistics.

        Returns:
            A dict with ``active_users``, ``active_sessions``,
            ``audit_entries`` and ``recent_activity`` (audit rows from the
            last 24 hours); ``{}`` if any query fails.
        """
        counters = (
            ('active_users', 'SELECT COUNT(*) as count FROM users WHERE is_active = 1'),
            ('active_sessions', 'SELECT COUNT(*) as count FROM sessions WHERE expires_at > CURRENT_TIMESTAMP'),
            ('audit_entries', 'SELECT COUNT(*) as count FROM audit_log'),
            ('recent_activity', "SELECT COUNT(*) as count FROM audit_log WHERE timestamp > datetime('now', '-24 hours')"),
        )
        conn = self._get_conn()
        try:
            return {key: conn.execute(sql).fetchone()['count'] for key, sql in counters}
        except Exception as e:
            logger.error(f"Error fetching stats: {e}")
            return {}
        finally:
            conn.close()

    # --- User Update ---

    def update_user(self, user_id: int, updates: Dict) -> bool:
        """Update user fields.

        Only ``role``, ``is_active``, ``full_name`` and ``email`` may be
        changed; any other key in ``updates`` is ignored.

        Returns:
            ``True`` if a row was updated; ``False`` if nothing updatable
            was supplied, no row has that ID, or the update failed.
        """
        allowed = {'role', 'is_active', 'full_name', 'email'}
        filtered = {k: v for k, v in updates.items() if k in allowed}
        if not filtered:
            return False

        conn = self._get_conn()
        try:
            # Column names come from the allow-list above, so interpolating
            # them is safe; the values stay parameterized.
            set_clause = ', '.join(f'{column} = ?' for column in filtered)
            values = [*filtered.values(), user_id]
            cursor = conn.execute(f'UPDATE users SET {set_clause} WHERE id = ?', values)
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Error updating user {user_id}: {e}")
            return False
        finally:
            conn.close()

    # --- Audit Log (extended) ---

    def get_audit_log(
        self,
        user_id: Optional[int] = None,
        action: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """Get audit log with optional filters, newest first.

        Args:
            user_id: Restrict to this user when not ``None``.
            action: Restrict to this exact action when non-empty.
            limit: Maximum number of rows.
            offset: Number of rows to skip.

        Returns:
            Audit rows joined with ``username`` (``None`` when the user row
            is missing); ``[]`` on error.
        """
        conn = self._get_conn()
        try:
            conditions = []
            params = []
            if user_id is not None:
                conditions.append('a.user_id = ?')
                params.append(user_id)
            if action:
                conditions.append('a.action = ?')
                params.append(action)

            query = '''
                SELECT a.*, u.username
                FROM audit_log a
                LEFT JOIN users u ON a.user_id = u.id
            '''
            if conditions:
                query += ' WHERE ' + ' AND '.join(conditions)
            query += ' ORDER BY a.timestamp DESC LIMIT ? OFFSET ?'
            params.extend([limit, offset])

            return [dict(row) for row in conn.execute(query, params).fetchall()]
        except Exception as e:
            logger.error(f"Error fetching audit log: {e}")
            return []
        finally:
            conn.close()

    # --- AI Usage ---

    def log_ai_usage(
        self,
        user_id: int,
        filename: str = "",
        tokens_total: int = 0,
        model: str = "",
    ):
        """Log AI token usage for a file. Errors are logged, not raised."""
        conn = self._get_conn()
        try:
            conn.execute(
                'INSERT INTO ai_usage (user_id, filename, tokens_total, model) VALUES (?, ?, ?, ?)',
                (user_id, filename, tokens_total, model),
            )
            conn.commit()
        except Exception as e:
            logger.error(f"Error logging AI usage: {e}")
        finally:
            conn.close()

    def get_ai_usage_stats(self) -> Dict:
        """Get aggregate AI usage statistics.

        Returns:
            A dict with request counts and token sums for all time
            (``total_requests``/``total_tokens``), the last 24 hours
            (``requests_24h``/``tokens_24h``) and the last 7 days
            (``requests_7d``/``tokens_7d``); ``{}`` if any query fails.
        """
        base_query = (
            'SELECT COUNT(*) as count, COALESCE(SUM(tokens_total), 0) as tokens FROM ai_usage'
        )
        # (requests key, tokens key, optional WHERE clause)
        windows = (
            ('total_requests', 'total_tokens', ''),
            ('requests_24h', 'tokens_24h', " WHERE created_at > datetime('now', '-24 hours')"),
            ('requests_7d', 'tokens_7d', " WHERE created_at > datetime('now', '-7 days')"),
        )
        conn = self._get_conn()
        try:
            stats = {}
            for requests_key, tokens_key, where_clause in windows:
                row = conn.execute(base_query + where_clause).fetchone()
                stats[requests_key] = row['count']
                stats[tokens_key] = row['tokens']
            return stats
        except Exception as e:
            logger.error(f"Error fetching AI usage stats: {e}")
            return {}
        finally:
            conn.close()

    def get_ai_usage_by_user(self, limit: int = 50) -> List[Dict]:
        """Get AI usage broken down by user, heaviest token users first.

        Returns:
            Up to ``limit`` dicts with ``username``, ``user_id``,
            ``request_count``, ``total_tokens`` and ``last_used``;
            ``[]`` on error.
        """
        conn = self._get_conn()
        try:
            rows = conn.execute('''
                SELECT u.username, u.id as user_id,
                       COUNT(*) as request_count,
                       COALESCE(SUM(a.tokens_total), 0) as total_tokens,
                       MAX(a.created_at) as last_used
                FROM ai_usage a
                JOIN users u ON a.user_id = u.id
                GROUP BY u.id
                ORDER BY total_tokens DESC
                LIMIT ?
            ''', (limit,)).fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"Error fetching AI usage by user: {e}")
            return []
        finally:
            conn.close()

    def close(self):
        """No-op for connection-per-operation pattern."""
        pass
|