solventum-image-metadata/backend/app/core/redis_client.py
SamoilenkoVadym 563d476a94 feat(backend): migrate from Flask to FastAPI with Redis sessions
- Create FastAPI application with async I/O
- Implement Redis session storage (fixes session loss on restart)
- Add JWT authentication with refresh tokens
- Add Microsoft SSO support via MSAL
- Copy all processors from src/ (100% reused, no changes)
- Create file upload/download endpoints
- Create metadata update endpoints
- Create template CRUD endpoints
- Add SQLAlchemy async database models
- Add Docker Compose configuration with Redis

Solves critical issues:
- Session management: Redis replaces in-memory dicts
- Scalability: Async FastAPI + microservices architecture
- File handling: Persistent storage with auto-cleanup

Key files:
- backend/app/main.py - FastAPI entry point
- backend/app/core/redis_client.py - Session store
- backend/app/core/auth.py - JWT authentication
- backend/app/api/* - All REST endpoints
- backend/app/processors/ - Reused from src/

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
2026-02-09 13:14:37 +00:00

341 lines
9.3 KiB
Python

"""
Redis Session Store
Replaces in-memory session dictionaries with persistent Redis storage.
Solves the main problem: sessions lost on restart.
"""
from redis.asyncio import Redis
from typing import Optional, Dict, Any
import json
import secrets
class RedisSessionStore:
    """
    Redis-based session storage for:
    1. User login sessions (JWT refresh tokens)
    2. File processing sessions (uploaded files + metadata)
    3. Import sessions (Excel/CSV metadata lookups)

    Each session is stored as a JSON string under a prefixed key
    (``user_session:``, ``file_session:``, ``import_session:``) and is
    always written with a TTL, so Redis expires it without further work.
    """

    # TTL (seconds) applied by the update methods when the key being
    # rewritten reports no usable remaining lifetime.
    _FALLBACK_TTL = 3600

    def __init__(self, redis_url: str):
        """
        Initialize Redis connection.

        Args:
            redis_url: Redis connection string (e.g., "redis://localhost:6379/0")
        """
        # decode_responses=True makes the client return str instead of bytes,
        # which is what json.loads is fed in _load().
        self.redis = Redis.from_url(redis_url, decode_responses=True)

    async def close(self):
        """Close Redis connection."""
        # redis-py deprecated the async client's ``close()`` in favour of
        # ``aclose()``; prefer the new name but keep working with clients
        # that only provide ``close()``.
        closer = getattr(self.redis, "aclose", None) or self.redis.close
        await closer()

    # ===== Internal helpers =====

    async def _store(self, key: str, ttl: int, payload: Dict[str, Any]) -> None:
        """Serialize ``payload`` to JSON and write it under ``key`` with ``ttl`` seconds."""
        await self.redis.setex(key, ttl, json.dumps(payload))

    async def _load(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the decoded JSON stored under ``key``, or None if nothing is stored."""
        raw = await self.redis.get(key)
        return json.loads(raw) if raw else None

    async def _delete(self, key: str) -> bool:
        """Delete ``key``; True when Redis reports that a key was removed."""
        removed = await self.redis.delete(key)
        return removed > 0

    async def _update_field(self, key: str, field: str, value: Any) -> bool:
        """
        Replace one field of a stored session while keeping its remaining TTL.

        Returns:
            True if the session was rewritten, False if it does not exist
            (including when it disappears while the update is in progress).
        """
        payload = await self._load(key)
        if not payload:
            return False
        payload[field] = value

        ttl = await self.redis.ttl(key)
        if ttl == -2:
            # TTL reports -2 for a missing key: the session expired or was
            # deleted after the read above. Writing now would bring a dead
            # session back to life for an hour, so report "not found".
            return False
        if ttl <= 0:
            # -1 means the key has no expiry and 0 means less than a second
            # is left; SETEX needs a positive value, so use the fallback.
            ttl = self._FALLBACK_TTL

        # NOTE: this is still a read-modify-write and not atomic; on
        # Redis >= 6.0 ``SET key value XX KEEPTTL`` would close the window.
        await self._store(key, ttl, payload)
        return True

    # ===== User Session Methods =====

    async def create_user_session(
        self,
        user_id: int,
        refresh_token: str,
        ip_address: str,
        user_agent: str,
        ttl: int = 7 * 86400  # 7 days
    ) -> str:
        """
        Create a new user login session.

        Args:
            user_id: User ID from database
            refresh_token: JWT refresh token
            ip_address: Client IP address
            user_agent: Client user agent string
            ttl: Time to live in seconds (default: 7 days)

        Returns:
            session_id: Unique session identifier
        """
        session_id = secrets.token_urlsafe(32)
        await self._store(
            f"user_session:{session_id}",
            ttl,
            {
                "user_id": user_id,
                "refresh_token": refresh_token,
                "ip_address": ip_address,
                "user_agent": user_agent,
            },
        )
        return session_id

    async def get_user_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve user session data.

        Args:
            session_id: Session identifier

        Returns:
            Session data dict or None if not found/expired
        """
        return await self._load(f"user_session:{session_id}")

    async def delete_user_session(self, session_id: str) -> bool:
        """
        Delete user session (logout).

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        return await self._delete(f"user_session:{session_id}")

    # ===== File Processing Session Methods =====

    async def create_file_session(
        self,
        user_id: int,
        files_data: list[Dict[str, Any]],
        metadata_source: str,
        ttl: int = 3600  # 1 hour
    ) -> str:
        """
        Create file processing session (replaces in-memory sessions dict).

        Args:
            user_id: User ID who uploaded files
            files_data: List of file info dicts (filename, filepath, metadata, etc.)
            metadata_source: Source of metadata ('excel', 'ai', 'manual', 'import', 'template')
            ttl: Time to live in seconds (default: 1 hour)

        Returns:
            session_id: Unique session identifier
        """
        session_id = secrets.token_urlsafe(16)
        await self._store(
            f"file_session:{session_id}",
            ttl,
            {
                "user_id": user_id,
                "files": files_data,
                "metadata_source": metadata_source,
            },
        )
        return session_id

    async def get_file_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve file processing session.

        Args:
            session_id: Session identifier

        Returns:
            Session data dict or None if not found/expired
        """
        return await self._load(f"file_session:{session_id}")

    async def update_file_session(
        self,
        session_id: str,
        files_data: list[Dict[str, Any]]
    ) -> bool:
        """
        Update file session with new metadata (after user edits).

        The remaining TTL of the session is preserved. A session that
        vanishes while the update is running is not recreated.

        Args:
            session_id: Session identifier
            files_data: Updated file data list

        Returns:
            True if updated, False if session not found
        """
        return await self._update_field(
            f"file_session:{session_id}", "files", files_data
        )

    async def delete_file_session(self, session_id: str) -> bool:
        """
        Delete file processing session (cleanup after download).

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        return await self._delete(f"file_session:{session_id}")

    # ===== Import Session Methods =====

    async def create_import_session(
        self,
        user_id: int,
        import_type: str,  # 'excel' or 'csv' or 'json'
        filename: str,
        filepath: str,
        metadata: Optional[Dict[str, Any]] = None,
        ttl: int = 3600  # 1 hour
    ) -> str:
        """
        Create import session for Excel/CSV/JSON metadata lookup.

        Args:
            user_id: User ID who uploaded import file
            import_type: Type of import file
            filename: Original filename
            filepath: Path to uploaded file
            metadata: Optional metadata map (after configuration)
            ttl: Time to live in seconds (default: 1 hour)

        Returns:
            session_id: Unique session identifier
        """
        session_id = secrets.token_urlsafe(16)
        await self._store(
            f"import_session:{session_id}",
            ttl,
            {
                "user_id": user_id,
                "import_type": import_type,
                "filename": filename,
                "filepath": filepath,
                # A missing or empty map is stored as an empty dict.
                "metadata": metadata or {},
            },
        )
        return session_id

    async def get_import_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve import session.

        Args:
            session_id: Session identifier

        Returns:
            Session data dict or None if not found/expired
        """
        return await self._load(f"import_session:{session_id}")

    async def update_import_metadata(
        self,
        session_id: str,
        metadata: Dict[str, Any]
    ) -> bool:
        """
        Update import session with configured metadata mappings.

        The remaining TTL of the session is preserved. A session that
        vanishes while the update is running is not recreated.

        Args:
            session_id: Session identifier
            metadata: Metadata lookup map (filename -> metadata dict)

        Returns:
            True if updated, False if session not found
        """
        return await self._update_field(
            f"import_session:{session_id}", "metadata", metadata
        )

    # ===== Utility Methods =====

    async def ping(self) -> bool:
        """
        Check if Redis is connected.

        Returns:
            True if connected, False otherwise
        """
        try:
            await self.redis.ping()
        except Exception:
            # Health-check boundary: any failure simply means "not connected".
            return False
        return True

    async def get_all_sessions(self, pattern: str = "*") -> list[str]:
        """
        Get all session keys matching pattern (for debugging).

        Args:
            pattern: Redis key pattern (e.g., "file_session:*")

        Returns:
            List of session keys
        """
        keys: list[str] = []
        cursor = 0
        while True:
            # SCAN walks the keyspace in batches; a returned cursor of 0
            # signals that the iteration is complete.
            cursor, batch = await self.redis.scan(cursor, match=pattern, count=100)
            keys.extend(batch)
            if cursor == 0:
                return keys

    async def cleanup_expired_sessions(self) -> int:
        """
        Cleanup expired sessions (Redis does this automatically with TTL,
        but this can be called for manual cleanup if needed).

        Session keys whose TTL is reported as <= 0 are deleted. That
        includes keys with no expiry at all (TTL -1); every write in this
        class sets a TTL, so such keys are not expected here.

        Returns:
            Number of keys that were actually removed
        """
        patterns = ("user_session:*", "file_session:*", "import_session:*")
        total_cleaned = 0
        for pattern in patterns:
            for key in await self.get_all_sessions(pattern):
                if await self.redis.ttl(key) <= 0:
                    # DEL returns how many keys it removed, so a key that
                    # expired after being listed (TTL -2) is not counted.
                    total_cleaned += await self.redis.delete(key)
        return total_cleaned