video-accessibility/backend/app/middleware/rate_limiting.py
Vadym Samoilenko 31199f8705 chore: push all session changes — backend hardening, tests, apache config, deploy scripts
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-30 15:52:14 +01:00

263 lines
9.3 KiB
Python

"""Rate limiting middleware for API endpoints."""
import logging
import re
import time
import uuid

import redis.asyncio as aioredis
from fastapi import Request, status
from fastapi.responses import JSONResponse

from app.core.config import get_settings
from app.telemetry.metrics import track_rate_limit_metrics
class RateLimiter:
    """Redis-based rate limiter with sliding window algorithm.

    Every request is stored as a member of a Redis sorted set, scored by its
    arrival time. Members older than the window are trimmed on each call and
    the remaining cardinality is compared against the limit.
    """

    def __init__(self, redis_client: aioredis.Redis):
        self.redis = redis_client

    async def is_allowed(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        identifier: str = "",
    ) -> tuple[bool, dict[str, int]]:
        """Check if request is allowed under rate limit.

        Args:
            key: Redis key of the sorted set holding this client's requests.
            limit: Maximum number of requests allowed per window.
            window_seconds: Length of the sliding window in seconds.
            identifier: Client label forwarded to the metrics tracker.

        Returns:
            Tuple of (is_allowed, rate_limit_info). ``rate_limit_info`` holds
            ``limit``, ``remaining`` (requests left after this one),
            ``reset_time`` (epoch seconds) and ``retry_after`` (seconds, 0
            when the request is allowed).

        Note:
            The request is recorded whether or not it is allowed, so a client
            that keeps retrying while throttled keeps its window full.
        """
        now = time.time()

        pipeline = self.redis.pipeline()
        # Remove expired entries
        pipeline.zremrangebyscore(key, 0, now - window_seconds)
        # Count requests already in the window (the current one is NOT included)
        pipeline.zcard(key)
        # Add current request. The random suffix keeps members unique: two
        # requests with the same timestamp would otherwise overwrite each
        # other in the sorted set and be counted once.
        pipeline.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
        # Set expiry so idle keys are cleaned up by Redis
        pipeline.expire(key, window_seconds)
        results = await pipeline.execute()

        current_requests = results[1]

        # `current_requests` excludes this request, so it is allowed only while
        # strictly fewer than `limit` requests are already in the window.
        is_allowed = current_requests < limit

        rate_limit_info = {
            "limit": limit,
            # Account for the request being processed right now.
            "remaining": max(0, limit - current_requests - 1) if is_allowed else 0,
            "reset_time": int(now + window_seconds),
            "retry_after": 0 if is_allowed else window_seconds,
        }

        # Track metrics
        track_rate_limit_metrics(
            identifier=identifier,
            is_allowed=is_allowed,
            current_requests=current_requests,
            limit=limit,
        )
        return is_allowed, rate_limit_info
class RateLimitMiddleware:
    """FastAPI middleware for rate limiting.

    Each request is mapped to a (client, endpoint pattern) pair; the pair is
    checked against a per-endpoint limit, or a default limit when no endpoint
    rule matches. If the limiter itself fails, the request is let through
    (fail-open).
    """

    _log = logging.getLogger(__name__)

    # Paths that are never rate limited (health checks and metrics only).
    _EXEMPT_PATHS = ("/health", "/metrics")

    # Resource IDs are collapsed to "*" so all IDs share one pattern/key.
    # The lookahead anchors the match to a whole path segment: it covers IDs
    # at the end of the path (e.g. ".../jobs/<id>") and avoids rewriting the
    # hex-looking prefix of a longer word (e.g. ".../admin/users/export").
    _JOB_ID_RE = re.compile(r"/jobs/[a-f0-9-]+(?=/|$)")
    _ADMIN_USER_ID_RE = re.compile(r"/admin/users/[a-f0-9-]+(?=/|$)")

    def __init__(self, redis_client: aioredis.Redis):
        self.limiter = RateLimiter(redis_client)
        self.settings = get_settings()

        # Rate limit configurations by endpoint pattern: (max requests, window seconds)
        self.rate_limits = {
            # Authentication endpoints
            "POST:/api/v1/auth/login": (5, 300),  # 5 requests per 5 minutes
            "POST:/api/v1/auth/register": (3, 3600),  # 3 requests per hour
            "POST:/api/v1/auth/refresh": (10, 300),  # 10 requests per 5 minutes
            "POST:/api/v1/auth/forgot-password": (3, 3600),  # 3 requests per hour
            # File upload endpoints
            "POST:/api/v1/files/upload": (10, 3600),  # 10 uploads per hour
            "POST:/api/v1/jobs": (20, 3600),  # 20 job creations per hour
            # Job management endpoints
            "GET:/api/v1/jobs": (100, 300),  # 100 requests per 5 minutes
            "PATCH:/api/v1/jobs/*/approve": (50, 3600),  # 50 approvals per hour
            "PATCH:/api/v1/jobs/*/reject": (50, 3600),  # 50 rejections per hour
            # VTT editing endpoints
            "PATCH:/api/v1/jobs/*/vtt": (100, 3600),  # 100 VTT edits per hour
            # Admin endpoints (more restrictive)
            "GET:/api/v1/admin/*": (50, 300),  # 50 requests per 5 minutes
            "POST:/api/v1/admin/*": (20, 3600),  # 20 admin actions per hour
            "PATCH:/api/v1/admin/*": (20, 3600),  # 20 admin updates per hour
            "DELETE:/api/v1/admin/*": (10, 3600),  # 10 admin deletions per hour
        }

        # Default rate limits
        self.default_limits = {
            "authenticated": (1000, 3600),  # 1000 requests per hour for authenticated users
            "anonymous": (100, 3600),  # 100 requests per hour for anonymous users
        }

    def _get_client_identifier(self, request: Request) -> str:
        """Get client identifier for rate limiting.

        Returns ``user:<id>`` when ``request.state.user`` is set, otherwise
        ``ip:<address>`` taken from the proxy header or the socket peer.
        """
        user = getattr(request.state, "user", None)
        if user:
            return f"user:{user.id}"

        # Only trust X-Forwarded-For when the request arrived via HTTPS (i.e. through
        # the Apache/nginx reverse proxy). On plain HTTP (direct connections, local
        # dev) the header can be forged, so we fall back to the socket IP.
        # NOTE(review): X-Forwarded-Proto is itself a request header; this check is
        # only sound if the app port is reachable exclusively through the proxy.
        if request.headers.get("X-Forwarded-Proto") == "https":
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # Take the right-most IP added by the trusted proxy, not client-supplied ones.
                proxy_reported_ip = forwarded_for.split(",")[-1].strip()
                # A trailing comma or blank value would yield the shared key "ip:";
                # fall through to the socket address instead.
                if proxy_reported_ip:
                    return f"ip:{proxy_reported_ip}"

        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"

    def _get_endpoint_key(self, request: Request) -> str:
        """Get endpoint pattern for rate limiting.

        Returns ``"<METHOD>:<path>"`` with job IDs and admin user IDs replaced
        by ``*`` so that they line up with the keys in ``self.rate_limits``.
        """
        path = request.url.path
        path = self._JOB_ID_RE.sub("/jobs/*", path)
        path = self._ADMIN_USER_ID_RE.sub("/admin/users/*", path)
        return f"{request.method}:{path}"

    def _get_rate_limit(self, request: Request) -> tuple[int, int]:
        """Get rate limit for the current request as ``(limit, window_seconds)``."""
        endpoint_key = self._get_endpoint_key(request)

        # Check for specific endpoint limits
        if endpoint_key in self.rate_limits:
            return self.rate_limits[endpoint_key]

        # Check for wildcard matches: a pattern ending in "*" matches any key
        # sharing its prefix; the first matching pattern wins.
        for pattern, limits in self.rate_limits.items():
            if pattern.endswith("*") and endpoint_key.startswith(pattern[:-1]):
                return limits

        # Use default limits based on authentication
        user = getattr(request.state, "user", None)
        if user:
            return self.default_limits["authenticated"]
        return self.default_limits["anonymous"]

    async def __call__(self, request: Request, call_next):
        """Process rate limiting for the request.

        Returns a 429 JSON response when the limit is exceeded; otherwise the
        downstream response with ``X-RateLimit-*`` headers attached.
        Exceptions raised by the downstream handler propagate unchanged.
        """
        # Skip rate limiting for health checks and metrics only
        if request.url.path in self._EXEMPT_PATHS:
            return await call_next(request)

        client_id = self._get_client_identifier(request)
        endpoint_key = self._get_endpoint_key(request)
        limit, window = self._get_rate_limit(request)

        # Create rate limit key
        rate_limit_key = f"rate_limit:{client_id}:{endpoint_key}"

        # Only the limiter call is guarded. Wrapping `call_next` as well would
        # turn any handler exception into a "rate limiting error" and run the
        # request a second time.
        try:
            is_allowed, rate_info = await self.limiter.is_allowed(
                key=rate_limit_key,
                limit=limit,
                window_seconds=window,
                identifier=client_id,
            )
        except Exception:
            # Log error but don't block request if rate limiting fails
            self._log.exception("Rate limiting error; allowing request (fail-open)")
            return await call_next(request)

        if not is_allowed:
            # Return rate limit exceeded response
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": "Rate limit exceeded",
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "rate_limit": rate_info,
                },
                headers={
                    "X-RateLimit-Limit": str(rate_info["limit"]),
                    "X-RateLimit-Remaining": str(rate_info["remaining"]),
                    "X-RateLimit-Reset": str(rate_info["reset_time"]),
                    "Retry-After": str(rate_info["retry_after"]),
                },
            )

        # Process the request
        response = await call_next(request)

        # Add rate limit headers to response
        response.headers["X-RateLimit-Limit"] = str(rate_info["limit"])
        response.headers["X-RateLimit-Remaining"] = str(rate_info["remaining"])
        response.headers["X-RateLimit-Reset"] = str(rate_info["reset_time"])
        return response
class IPWhitelist:
    """IP whitelist for bypassing rate limits.

    Permanent entries live in a Redis set (``whitelist_key``). Temporary
    entries live in one expiring key per IP (``<whitelist_key>:temp:<ip>``),
    so they drop off the whitelist when Redis expires the key.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, redis_client: aioredis.Redis):
        self.redis = redis_client
        self.whitelist_key = "ip_whitelist"

        # Default whitelisted IPs (health checks, monitoring)
        self.default_whitelist = {
            "127.0.0.1",
            "::1",
            "169.254.169.254",  # GCP metadata server
        }

    def _temp_key(self, ip: str) -> str:
        """Return the Redis key holding the temporary whitelist entry for ``ip``."""
        return f"{self.whitelist_key}:temp:{ip}"

    async def is_whitelisted(self, ip: str) -> bool:
        """Check if IP is whitelisted.

        Returns True for built-in defaults, permanent entries, and unexpired
        temporary entries. Returns False if Redis cannot be queried.
        """
        if ip in self.default_whitelist:
            return True

        try:
            if await self.redis.sismember(self.whitelist_key, ip):
                return True
            # Temporary entries exist only as an expiring key.
            return bool(await self.redis.exists(self._temp_key(ip)))
        except Exception:
            self._log.exception("IP whitelist lookup failed; treating IP as not whitelisted")
            return False

    async def add_ip(self, ip: str, ttl_seconds: int | None = None) -> bool:
        """Add IP to whitelist.

        Args:
            ip: Address to whitelist.
            ttl_seconds: When given (and non-zero), the entry is temporary and
                expires after this many seconds. ``None`` or ``0`` adds a
                permanent entry.

        Returns:
            True on success, False if the Redis operation failed.
        """
        try:
            if ttl_seconds:
                # Store ONLY the expiring key. Also adding the IP to the
                # permanent set would keep it whitelisted after the TTL.
                await self.redis.setex(self._temp_key(ip), ttl_seconds, "1")
            else:
                await self.redis.sadd(self.whitelist_key, ip)
            return True
        except Exception:
            self._log.exception("Failed to add IP to whitelist")
            return False

    async def remove_ip(self, ip: str) -> bool:
        """Remove IP from whitelist (both permanent and temporary entries).

        Returns:
            True on success, False if the Redis operation failed.
        """
        try:
            await self.redis.srem(self.whitelist_key, ip)
            await self.redis.delete(self._temp_key(ip))
            return True
        except Exception:
            self._log.exception("Failed to remove IP from whitelist")
            return False
async def create_rate_limit_middleware(redis_client: aioredis.Redis) -> RateLimitMiddleware:
    """Build a ``RateLimitMiddleware`` bound to the given Redis client.

    This is a coroutine even though it awaits nothing, so callers must
    ``await`` it to obtain the middleware instance.
    """
    middleware = RateLimitMiddleware(redis_client)
    return middleware