Complete Flask → FastAPI migration with: - FastAPI app with session auth, Azure AD SSO, rate limiting - SQLite-backed session store (survives restarts) - Bulk AI metadata generation with SSE progress - Admin panel (user management, audit log, AI usage) - Subpath deployment support (ROOT_PATH config) - Docker + deploy.sh for production deployment - Test suite (auth, upload, templates, imports, admin, sessions) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
298 lines
11 KiB
Python
298 lines
11 KiB
Python
"""SQLite-backed session store for file processing and import sessions."""
|
|
|
|
import json
|
|
import sqlite3
|
|
import secrets
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, List, Any
|
|
from pathlib import Path
|
|
|
|
logger = logging.getLogger(__name__)


class SessionStore:
    """Persistent session store replacing in-memory dicts.

    Stores file processing sessions and imported metadata maps in SQLite,
    surviving server restarts and supporting multi-worker deployments.

    Every method opens its own short-lived connection, so one instance can be
    shared between threads. All timestamps are produced and compared by SQLite
    itself in UTC (``datetime('now')``), which keeps expiry correct whatever
    the host's local timezone is.
    """

    def __init__(self, db_path: str):
        """Open (creating if needed) the store at ``db_path``.

        Args:
            db_path: Filesystem path of the SQLite database. Missing parent
                directories are created.
        """
        self.db_path = db_path
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self._init_tables()

    def _get_conn(self) -> sqlite3.Connection:
        """Create a new connection per call (thread-safe).

        Returns:
            A connection with ``sqlite3.Row`` rows and WAL journaling. The
            caller is responsible for closing it.
        """
        conn = sqlite3.connect(self.db_path, timeout=10)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
        except Exception:
            # Do not leak the handle if configuring the connection fails.
            conn.close()
            raise
        return conn

    @staticmethod
    def _expiry_modifier(expires_hours: float) -> str:
        """Return a SQLite date modifier such as ``'+86400 seconds'``.

        The modifier is applied to ``datetime('now')`` inside the INSERT so
        that the stored expiry uses the same clock, timezone (UTC) and text
        format as the ``datetime('now')`` comparisons done on read.
        """
        return f"{round(expires_hours * 3600):+d} seconds"

    def _init_tables(self):
        """Create the session tables and their indexes if they do not exist."""
        conn = self._get_conn()
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS file_sessions (
                    session_id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    metadata_source TEXT DEFAULT 'manual',
                    import_session_id TEXT DEFAULT '',
                    files_json TEXT DEFAULT '[]',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS import_sessions (
                    session_id TEXT PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    session_type TEXT DEFAULT 'import',
                    metadata_json TEXT DEFAULT '{}',
                    file_info_json TEXT DEFAULT '{}',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP NOT NULL
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_fs_user ON file_sessions(user_id)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_fs_expires ON file_sessions(expires_at)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_is_user ON import_sessions(user_id)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_is_expires ON import_sessions(expires_at)")
            conn.commit()
            logger.info("Session store initialized at %s", self.db_path)
        finally:
            conn.close()

    # --- File Sessions ---

    def create_file_session(
        self,
        user_id: int,
        metadata_source: str = "manual",
        import_session_id: str = "",
        expires_hours: int = 24,
    ) -> str:
        """Create a new file processing session with a secure random ID.

        Args:
            user_id: Owner of the session.
            metadata_source: Stored verbatim in the ``metadata_source`` column.
            import_session_id: Stored verbatim in ``import_session_id``.
            expires_hours: Lifetime in hours, counted from now. A value of
                zero or less creates a session that is already expired.

        Returns:
            The generated session ID.
        """
        session_id = secrets.token_urlsafe(32)
        conn = self._get_conn()
        try:
            conn.execute(
                "INSERT INTO file_sessions (session_id, user_id, metadata_source, import_session_id, expires_at) "
                "VALUES (?,?,?,?,datetime('now', ?))",
                (
                    session_id,
                    user_id,
                    metadata_source,
                    import_session_id,
                    self._expiry_modifier(expires_hours),
                ),
            )
            conn.commit()
            # Only a prefix is logged so the full ID never reaches the logs.
            logger.info("Created file session %s... for user %s", session_id[:8], user_id)
            return session_id
        finally:
            conn.close()

    def get_file_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get file session by ID. Returns None if expired or not found.

        Returns:
            The row as a dict with ``files_json`` decoded into a ``files``
            list, or ``None``. ``created_at`` / ``expires_at`` are UTC strings.
        """
        conn = self._get_conn()
        try:
            row = conn.execute(
                "SELECT * FROM file_sessions WHERE session_id = ? AND expires_at > datetime('now')",
                (session_id,),
            ).fetchone()
        finally:
            conn.close()
        if row is None:
            return None
        result = dict(row)
        result["files"] = json.loads(result.pop("files_json"))
        return result

    def _rewrite_files(self, session_id: str, mutate) -> bool:
        """Atomically read, modify and write back a session's file list.

        ``BEGIN IMMEDIATE`` takes the write lock before the SELECT, so two
        workers cannot both read the same list and overwrite each other's
        change (lost update).

        Args:
            session_id: Session whose ``files_json`` is rewritten.
            mutate: Callable receiving the decoded list; it edits the list in
                place and returns True if the result should be persisted.

        Returns:
            True if the row was updated, False if the session does not exist
            or ``mutate`` declined the change.
        """
        conn = self._get_conn()
        try:
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT files_json FROM file_sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone()
            if row is None:
                logger.warning("File session %s... not found; nothing written", session_id[:8])
                return False
            files = json.loads(row["files_json"])
            if not mutate(files):
                return False
            conn.execute(
                "UPDATE file_sessions SET files_json = ? WHERE session_id = ?",
                (json.dumps(files, ensure_ascii=False), session_id),
            )
            conn.commit()
            return True
        finally:
            # Closing without commit rolls back the open transaction.
            conn.close()

    def add_file_to_session(self, session_id: str, file_entry: Dict[str, Any]):
        """Add a processed file entry to a session.

        Does nothing (apart from logging a warning) if the session does not
        exist. ``file_entry`` must be JSON-serializable.
        """

        def append_entry(files: List[Dict[str, Any]]) -> bool:
            files.append(file_entry)
            return True

        self._rewrite_files(session_id, append_entry)

    def update_file_in_session(
        self, session_id: str, file_index: int, updates: Dict[str, Any]
    ):
        """Update specific fields of a file entry within a session.

        Args:
            session_id: Session holding the file list.
            file_index: Zero-based position of the entry; out-of-range values
                (including negatives) leave the session untouched.
            updates: Keys merged into the existing entry.
        """

        def merge_updates(files: List[Dict[str, Any]]) -> bool:
            if not 0 <= file_index < len(files):
                logger.warning(
                    "File index %s out of range for session %s...; nothing written",
                    file_index,
                    session_id[:8],
                )
                return False
            files[file_index].update(updates)
            return True

        self._rewrite_files(session_id, merge_updates)

    def get_file_session_files(self, session_id: str) -> List[Dict[str, Any]]:
        """Get just the files list from a session.

        Returns an empty list when the session is missing or expired.
        """
        session = self.get_file_session(session_id)
        return session["files"] if session else []

    def delete_file_session(self, session_id: str):
        """Delete a file session. Unknown IDs are ignored."""
        conn = self._get_conn()
        try:
            conn.execute("DELETE FROM file_sessions WHERE session_id = ?", (session_id,))
            conn.commit()
        finally:
            conn.close()

    def get_user_file_sessions(self, user_id: int) -> List[str]:
        """Get all active (non-expired) session IDs for a user."""
        conn = self._get_conn()
        try:
            rows = conn.execute(
                "SELECT session_id FROM file_sessions WHERE user_id = ? AND expires_at > datetime('now')",
                (user_id,),
            ).fetchall()
            return [row["session_id"] for row in rows]
        finally:
            conn.close()

    # --- Import Sessions ---

    def create_import_session(
        self,
        user_id: int,
        session_type: str = "import",
        metadata_map: Optional[Dict] = None,
        file_info: Optional[Dict] = None,
        expires_hours: int = 24,
    ) -> str:
        """Create an import/excel session.

        Args:
            user_id: Owner of the session.
            session_type: Stored in ``session_type`` and used as the prefix of
                the generated ID (``"<session_type>_<random>"``).
            metadata_map: JSON-serializable mapping; ``None`` is stored as ``{}``.
            file_info: JSON-serializable mapping; ``None`` is stored as ``{}``.
            expires_hours: Lifetime in hours, counted from now.

        Returns:
            The generated session ID.
        """
        token = secrets.token_urlsafe(8)
        session_id = f"{session_type}_{token}"
        conn = self._get_conn()
        try:
            conn.execute(
                "INSERT INTO import_sessions (session_id, user_id, session_type, metadata_json, file_info_json, expires_at) "
                "VALUES (?,?,?,?,?,datetime('now', ?))",
                (
                    session_id,
                    user_id,
                    session_type,
                    json.dumps(metadata_map or {}, ensure_ascii=False),
                    json.dumps(file_info or {}, ensure_ascii=False),
                    self._expiry_modifier(expires_hours),
                ),
            )
            conn.commit()
            # Log only part of the random token, matching create_file_session.
            logger.info(
                "Created %s session %s_%s... for user %s",
                session_type,
                session_type,
                token[:4],
                user_id,
            )
            return session_id
        finally:
            conn.close()

    def get_import_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get import session by ID. Returns None if expired or not found.

        Returns:
            The row as a dict with the JSON columns decoded into
            ``metadata_map`` and ``file_info``, or ``None``.
        """
        conn = self._get_conn()
        try:
            row = conn.execute(
                "SELECT * FROM import_sessions WHERE session_id = ? AND expires_at > datetime('now')",
                (session_id,),
            ).fetchone()
        finally:
            conn.close()
        if row is None:
            return None
        result = dict(row)
        result["metadata_map"] = json.loads(result.pop("metadata_json"))
        result["file_info"] = json.loads(result.pop("file_info_json"))
        return result

    def update_import_session(
        self,
        session_id: str,
        metadata_map: Optional[Dict] = None,
        file_info: Optional[Dict] = None,
    ):
        """Update an import session's metadata map or file info.

        An argument left as ``None`` keeps the stored value; passing an empty
        dict overwrites it with ``{}``. With both ``None`` nothing is written.
        """
        assignments = []
        params: List[Any] = []
        if metadata_map is not None:
            assignments.append("metadata_json = ?")
            params.append(json.dumps(metadata_map, ensure_ascii=False))
        if file_info is not None:
            assignments.append("file_info_json = ?")
            params.append(json.dumps(file_info, ensure_ascii=False))
        if not assignments:
            return
        params.append(session_id)

        conn = self._get_conn()
        try:
            # The SET clause is assembled from the fixed fragments above only;
            # every value is still bound as a parameter.
            conn.execute(
                f"UPDATE import_sessions SET {', '.join(assignments)} WHERE session_id = ?",
                params,
            )
            conn.commit()
        finally:
            conn.close()

    def delete_import_session(self, session_id: str):
        """Delete an import session. Unknown IDs are ignored."""
        conn = self._get_conn()
        try:
            conn.execute("DELETE FROM import_sessions WHERE session_id = ?", (session_id,))
            conn.commit()
        finally:
            conn.close()

    # --- Cleanup ---

    def cleanup_expired(self) -> int:
        """Remove all expired sessions. Returns count of deleted rows."""
        conn = self._get_conn()
        try:
            # "<=" mirrors the "expires_at > now" test used by the getters, so
            # every row a getter treats as expired is also purged here.
            c1 = conn.execute("DELETE FROM file_sessions WHERE expires_at <= datetime('now')")
            c2 = conn.execute("DELETE FROM import_sessions WHERE expires_at <= datetime('now')")
            conn.commit()
            total = c1.rowcount + c2.rowcount
            if total > 0:
                logger.info("Cleaned up %s expired sessions", total)
            return total
        finally:
            conn.close()

    def cleanup_user_sessions(self, user_id: int) -> List[str]:
        """Delete all sessions for a user. Returns file paths for cleanup.

        Returns:
            The truthy ``"filepath"`` values found in the user's file
            sessions (expired ones included), so the caller can remove the
            files from disk.
        """
        conn = self._get_conn()
        try:
            # Hold the write lock across collect + delete so a file added by
            # another worker in between cannot be deleted without being
            # reported to the caller.
            conn.execute("BEGIN IMMEDIATE")
            rows = conn.execute(
                "SELECT files_json FROM file_sessions WHERE user_id = ?",
                (user_id,),
            ).fetchall()
            file_paths = [
                entry["filepath"]
                for row in rows
                for entry in json.loads(row["files_json"])
                if entry.get("filepath")
            ]

            conn.execute("DELETE FROM file_sessions WHERE user_id = ?", (user_id,))
            conn.execute("DELETE FROM import_sessions WHERE user_id = ?", (user_id,))
            conn.commit()
            return file_paths
        finally:
            conn.close()
|