- Create FastAPI application with async I/O - Implement Redis session storage (fixes session loss on restart) - Add JWT authentication with refresh tokens - Add Microsoft SSO support via MSAL - Copy all processors from src/ (100% reused, no changes) - Create file upload/download endpoints - Create metadata update endpoints - Create template CRUD endpoints - Add SQLAlchemy async database models - Add Docker Compose configuration with Redis Solves critical issues: - Session management: Redis replaces in-memory dicts - Scalability: Async FastAPI + microservices architecture - File handling: Persistent storage with auto-cleanup Key files: - backend/app/main.py - FastAPI entry point - backend/app/core/redis_client.py - Session store - backend/app/core/auth.py - JWT authentication - backend/app/api/* - All REST endpoints - backend/app/processors/ - Reused from src/ Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
341 lines
9.3 KiB
Python
341 lines
9.3 KiB
Python
"""
Redis Session Store

Replaces in-memory session dictionaries with persistent Redis storage.
Solves the main problem: sessions lost on restart.
"""
|
|
|
|
from redis.asyncio import Redis
|
|
from typing import Optional, Dict, Any
|
|
import json
|
|
import secrets
|
|
|
|
|
|
class RedisSessionStore:
    """
    Redis-based session storage for:
    1. User login sessions (JWT refresh tokens)
    2. File processing sessions (uploaded files + metadata)
    3. Import sessions (Excel/CSV metadata lookups)

    Every session is stored as a JSON string under ``<prefix>:<session_id>``
    and written with ``SETEX`` so that Redis expires it on its own.
    """

    # Key prefixes, one per session kind.
    _USER_PREFIX = "user_session"
    _FILE_PREFIX = "file_session"
    _IMPORT_PREFIX = "import_session"

    # TTL (seconds) applied on update when the key reports no usable TTL.
    _FALLBACK_TTL = 3600

    def __init__(self, redis_url: str):
        """
        Initialize Redis connection.

        Args:
            redis_url: Redis connection string (e.g., "redis://localhost:6379/0")
        """
        self.redis = Redis.from_url(redis_url, decode_responses=True)

    async def close(self) -> None:
        """Close Redis connection."""
        # redis-py >= 5.0.1 deprecates close() in favour of aclose(); fall
        # back to close() so older client versions keep working.
        closer = getattr(self.redis, "aclose", None)
        if closer is None:
            closer = self.redis.close
        await closer()

    # ===== Internal helpers =====

    @staticmethod
    def _key(prefix: str, session_id: str) -> str:
        """Build the Redis key for a session of the given kind."""
        return f"{prefix}:{session_id}"

    async def _store(
        self,
        prefix: str,
        session_id: str,
        payload: Dict[str, Any],
        ttl: int
    ) -> None:
        """Serialize ``payload`` to JSON and write it with an expiry of ``ttl`` seconds."""
        await self.redis.setex(self._key(prefix, session_id), ttl, json.dumps(payload))

    async def _load(self, prefix: str, session_id: str) -> Optional[Dict[str, Any]]:
        """Read and decode a session; return None when the key holds no value."""
        raw = await self.redis.get(self._key(prefix, session_id))
        return json.loads(raw) if raw else None

    async def _remove(self, prefix: str, session_id: str) -> bool:
        """Delete a session key; return True when a key was actually removed."""
        removed = await self.redis.delete(self._key(prefix, session_id))
        return removed > 0

    async def _replace_field(
        self,
        prefix: str,
        session_id: str,
        field: str,
        value: Any
    ) -> bool:
        """
        Overwrite one field of a stored session while keeping its remaining TTL.

        Returns:
            True if the session was rewritten, False if it does not exist
            (including the case where it expired while being updated).
        """
        payload = await self._load(prefix, session_id)
        if payload is None:
            return False

        payload[field] = value

        remaining = await self.redis.ttl(self._key(prefix, session_id))
        if remaining == -2:
            # TTL reports -2 for a missing key: the session expired between
            # the read above and this lookup. Do not bring it back to life.
            return False
        if remaining <= 0:
            # -1 (no expiry set) or 0 (less than a second left): SETEX needs
            # a positive expiry, so use the one-hour fallback.
            remaining = self._FALLBACK_TTL

        await self._store(prefix, session_id, payload, remaining)
        return True

    # ===== User Session Methods =====

    async def create_user_session(
        self,
        user_id: int,
        refresh_token: str,
        ip_address: str,
        user_agent: str,
        ttl: int = 7 * 86400  # 7 days
    ) -> str:
        """
        Create a new user login session.

        Args:
            user_id: User ID from database
            refresh_token: JWT refresh token
            ip_address: Client IP address
            user_agent: Client user agent string
            ttl: Time to live in seconds (default: 7 days)

        Returns:
            session_id: Unique session identifier
        """
        session_id = secrets.token_urlsafe(32)

        session_data = {
            "user_id": user_id,
            "refresh_token": refresh_token,
            "ip_address": ip_address,
            "user_agent": user_agent
        }

        await self._store(self._USER_PREFIX, session_id, session_data, ttl)
        return session_id

    async def get_user_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve user session data.

        Args:
            session_id: Session identifier

        Returns:
            Session data dict or None if not found/expired
        """
        return await self._load(self._USER_PREFIX, session_id)

    async def delete_user_session(self, session_id: str) -> bool:
        """
        Delete user session (logout).

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        return await self._remove(self._USER_PREFIX, session_id)

    # ===== File Processing Session Methods =====

    async def create_file_session(
        self,
        user_id: int,
        files_data: list[Dict[str, Any]],
        metadata_source: str,
        ttl: int = 3600  # 1 hour
    ) -> str:
        """
        Create file processing session (replaces in-memory sessions dict).

        Args:
            user_id: User ID who uploaded files
            files_data: List of file info dicts (filename, filepath, metadata, etc.)
            metadata_source: Source of metadata ('excel', 'ai', 'manual', 'import', 'template')
            ttl: Time to live in seconds (default: 1 hour)

        Returns:
            session_id: Unique session identifier
        """
        session_id = secrets.token_urlsafe(16)

        session_data = {
            "user_id": user_id,
            "files": files_data,
            "metadata_source": metadata_source
        }

        await self._store(self._FILE_PREFIX, session_id, session_data, ttl)
        return session_id

    async def get_file_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve file processing session.

        Args:
            session_id: Session identifier

        Returns:
            Session data dict or None if not found/expired
        """
        return await self._load(self._FILE_PREFIX, session_id)

    async def update_file_session(
        self,
        session_id: str,
        files_data: list[Dict[str, Any]]
    ) -> bool:
        """
        Update file session with new metadata (after user edits).

        The remaining TTL of the session is preserved. A session that expires
        while the update is in flight is not re-created.

        Args:
            session_id: Session identifier
            files_data: Updated file data list

        Returns:
            True if updated, False if session not found
        """
        return await self._replace_field(
            self._FILE_PREFIX, session_id, "files", files_data
        )

    async def delete_file_session(self, session_id: str) -> bool:
        """
        Delete file processing session (cleanup after download).

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        return await self._remove(self._FILE_PREFIX, session_id)

    # ===== Import Session Methods =====

    async def create_import_session(
        self,
        user_id: int,
        import_type: str,  # 'excel' or 'csv' or 'json'
        filename: str,
        filepath: str,
        metadata: Optional[Dict[str, Any]] = None,
        ttl: int = 3600  # 1 hour
    ) -> str:
        """
        Create import session for Excel/CSV/JSON metadata lookup.

        Args:
            user_id: User ID who uploaded import file
            import_type: Type of import file
            filename: Original filename
            filepath: Path to uploaded file
            metadata: Optional metadata map (after configuration)
            ttl: Time to live in seconds (default: 1 hour)

        Returns:
            session_id: Unique session identifier
        """
        session_id = secrets.token_urlsafe(16)

        session_data = {
            "user_id": user_id,
            "import_type": import_type,
            "filename": filename,
            "filepath": filepath,
            "metadata": metadata or {}
        }

        await self._store(self._IMPORT_PREFIX, session_id, session_data, ttl)
        return session_id

    async def get_import_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve import session.

        Args:
            session_id: Session identifier

        Returns:
            Session data dict or None if not found/expired
        """
        return await self._load(self._IMPORT_PREFIX, session_id)

    async def update_import_metadata(
        self,
        session_id: str,
        metadata: Dict[str, Any]
    ) -> bool:
        """
        Update import session with configured metadata mappings.

        The remaining TTL of the session is preserved. A session that expires
        while the update is in flight is not re-created.

        Args:
            session_id: Session identifier
            metadata: Metadata lookup map (filename -> metadata dict)

        Returns:
            True if updated, False if session not found
        """
        return await self._replace_field(
            self._IMPORT_PREFIX, session_id, "metadata", metadata
        )

    # ===== Utility Methods =====

    async def ping(self) -> bool:
        """
        Check if Redis is connected.

        Returns:
            True if the PING call completes, False if it raises
        """
        try:
            await self.redis.ping()
            return True
        except Exception:
            # Deliberately broad: this is a health probe, so any failure to
            # reach Redis is reported as "not connected" instead of raised.
            return False

    async def get_all_sessions(self, pattern: str = "*") -> list[str]:
        """
        Get all session keys matching pattern (for debugging).

        Args:
            pattern: Redis key pattern (e.g., "file_session:*")

        Returns:
            List of session keys, each key listed once, in first-seen order
        """
        cursor = 0
        # SCAN may hand back the same key in more than one batch; a dict keeps
        # first-seen order while dropping the repeats.
        seen: Dict[str, None] = {}
        while True:
            cursor, batch = await self.redis.scan(cursor, match=pattern, count=100)
            seen.update(dict.fromkeys(batch))
            if cursor == 0:
                break
        return list(seen)

    async def cleanup_expired_sessions(self) -> int:
        """
        Cleanup expired sessions (Redis does this automatically with TTL,
        but this can be called for manual cleanup if needed).

        Session keys that still exist but report a TTL of 0 or -1 (no expiry
        set) are deleted. Keys that disappeared between the scan and the TTL
        lookup are skipped.

        Returns:
            Number of keys actually deleted by this call
        """
        patterns = [
            f"{self._USER_PREFIX}:*",
            f"{self._FILE_PREFIX}:*",
            f"{self._IMPORT_PREFIX}:*",
        ]
        total_cleaned = 0

        for pattern in patterns:
            for key in await self.get_all_sessions(pattern):
                remaining = await self.redis.ttl(key)
                if remaining == -2:
                    # Already gone (Redis expired it); nothing to clean.
                    continue
                if remaining <= 0:
                    # Count what DELETE really removed, not what we attempted.
                    total_cleaned += await self.redis.delete(key)

        return total_cleaned
|