263 lines
9.3 KiB
Python
263 lines
9.3 KiB
Python
"""Rate limiting middleware for API endpoints."""
|
|
|
|
import logging
import re
import time
import uuid

import redis.asyncio as aioredis
from fastapi import Request, status
from fastapi.responses import JSONResponse

from app.core.config import get_settings
from app.telemetry.metrics import track_rate_limit_metrics
|
|
|
|
|
|
class RateLimiter:
    """Redis-based rate limiter with sliding window algorithm.

    Every rate-limit key is a Redis sorted set that holds one member per
    recorded request, scored by the Unix timestamp at which it was seen.
    """

    def __init__(self, redis_client: aioredis.Redis):
        """Keep a reference to the async Redis client used for bookkeeping."""
        self.redis = redis_client

    async def is_allowed(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        identifier: str = "",
    ) -> tuple[bool, dict[str, int]]:
        """Record the current request and check it against the rate limit.

        Args:
            key: Redis key of the sorted set tracking this client/endpoint.
            limit: Maximum number of requests permitted inside one window.
            window_seconds: Length of the sliding window, in seconds.
            identifier: Label forwarded to ``track_rate_limit_metrics`` only.

        Returns:
            Tuple of (is_allowed, rate_limit_info). ``rate_limit_info`` has
            the keys ``limit``, ``remaining`` (requests left after this one),
            ``reset_time`` (Unix timestamp, now + window) and ``retry_after``
            (``window_seconds`` when the request is denied, otherwise 0).
        """
        now = time.time()
        pipeline = self.redis.pipeline()

        # Drop entries that have fallen out of the window.
        pipeline.zremrangebyscore(key, 0, now - window_seconds)

        # Count the requests already in the window. This runs BEFORE the
        # current request is added, so the count excludes it.
        pipeline.zcard(key)

        # Record the current request. The UUID suffix keeps members unique so
        # two requests sharing a timestamp are not merged into one entry.
        # NOTE: the request is recorded even when it ends up being denied, so
        # rejected attempts also occupy the window.
        pipeline.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})

        # Let idle keys expire on their own.
        pipeline.expire(key, window_seconds)

        results = await pipeline.execute()
        current_requests = results[1]

        # ``current_requests`` does not include this request, so it is allowed
        # only while fewer than ``limit`` requests precede it in the window.
        allowed = current_requests < limit

        rate_limit_info = {
            "limit": limit,
            # Subtract one for the request being admitted right now.
            "remaining": max(0, limit - current_requests - 1),
            "reset_time": int(now + window_seconds),
            "retry_after": 0 if allowed else window_seconds,
        }

        # Track metrics
        track_rate_limit_metrics(
            identifier=identifier,
            is_allowed=allowed,
            current_requests=current_requests,
            limit=limit,
        )

        return allowed, rate_limit_info
|
|
|
|
|
|
class RateLimitMiddleware:
    """FastAPI middleware for rate limiting.

    For every request it derives a client identifier and a normalised
    endpoint key, looks up the matching (limit, window) rule and asks
    ``RateLimiter`` whether the request may proceed. Denied requests get a
    429 JSON response; allowed ones are forwarded and have ``X-RateLimit-*``
    headers attached to the response.
    """

    # Paths that bypass rate limiting entirely.
    _EXEMPT_PATHS = frozenset({"/health", "/metrics"})

    # Resource IDs are collapsed to "*" so all jobs / admin users share a rule.
    _JOB_ID_RE = re.compile(r'/jobs/[a-f0-9-]+/')
    _ADMIN_USER_ID_RE = re.compile(r'/admin/users/[a-f0-9-]+')

    _logger = logging.getLogger(__name__)

    def __init__(self, redis_client: aioredis.Redis):
        """Create the underlying limiter and the rate-limit rule tables."""
        self.limiter = RateLimiter(redis_client)
        self.settings = get_settings()

        # Rate limit configurations by endpoint pattern.
        # Keys are "METHOD:path"; values are (max requests, window seconds).
        self.rate_limits = {
            # Authentication endpoints
            "POST:/api/v1/auth/login": (5, 300),  # 5 requests per 5 minutes
            "POST:/api/v1/auth/register": (3, 3600),  # 3 requests per hour
            "POST:/api/v1/auth/refresh": (10, 300),  # 10 requests per 5 minutes
            "POST:/api/v1/auth/forgot-password": (3, 3600),  # 3 requests per hour

            # File upload endpoints
            "POST:/api/v1/files/upload": (10, 3600),  # 10 uploads per hour
            "POST:/api/v1/jobs": (20, 3600),  # 20 job creations per hour

            # Job management endpoints
            "GET:/api/v1/jobs": (100, 300),  # 100 requests per 5 minutes
            "PATCH:/api/v1/jobs/*/approve": (50, 3600),  # 50 approvals per hour
            "PATCH:/api/v1/jobs/*/reject": (50, 3600),  # 50 rejections per hour

            # VTT editing endpoints
            "PATCH:/api/v1/jobs/*/vtt": (100, 3600),  # 100 VTT edits per hour

            # Admin endpoints (more restrictive)
            "GET:/api/v1/admin/*": (50, 300),  # 50 requests per 5 minutes
            "POST:/api/v1/admin/*": (20, 3600),  # 20 admin actions per hour
            "PATCH:/api/v1/admin/*": (20, 3600),  # 20 admin updates per hour
            "DELETE:/api/v1/admin/*": (10, 3600),  # 10 admin deletions per hour
        }

        # Default rate limits, used when no endpoint rule matches.
        self.default_limits = {
            "authenticated": (1000, 3600),  # 1000 requests per hour for authenticated users
            "anonymous": (100, 3600),  # 100 requests per hour for anonymous users
        }

    def _get_client_identifier(self, request: Request) -> str:
        """Get client identifier for rate limiting.

        Returns ``"user:<id>"`` when ``request.state.user`` is set, otherwise
        ``"ip:<address>"`` taken from the proxy header or the socket peer.
        """
        user = getattr(request.state, 'user', None)
        if user:
            return f"user:{user.id}"

        # Only trust X-Forwarded-For when the request arrived via HTTPS (i.e. through
        # the Apache/nginx reverse proxy). On plain HTTP (direct connections, local
        # dev) the header can be forged, so we fall back to the socket IP.
        # NOTE(review): X-Forwarded-Proto is itself a request header, so a client
        # that can reach the app directly could set it too -- confirm that direct
        # access is blocked at the network level.
        if request.headers.get("X-Forwarded-Proto") == "https":
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # Take the right-most IP added by the trusted proxy, not client-supplied ones.
                proxy_ip = forwarded_for.split(',')[-1].strip()
                # An empty trailing entry would bucket unrelated clients under
                # "ip:"; fall through to the socket address instead.
                if proxy_ip:
                    return f"ip:{proxy_ip}"

        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"

    def _get_endpoint_key(self, request: Request) -> str:
        """Get endpoint pattern for rate limiting.

        Returns ``"METHOD:path"`` with job IDs and admin user IDs replaced by
        ``*`` so the result can be looked up in ``self.rate_limits``.
        """
        path = self._JOB_ID_RE.sub('/jobs/*/', request.url.path)
        path = self._ADMIN_USER_ID_RE.sub('/admin/users/*', path)
        return f"{request.method}:{path}"

    def _get_rate_limit(self, request: Request) -> tuple[int, int]:
        """Get rate limit for the current request as (limit, window_seconds).

        Lookup order: exact endpoint key, then the first rule ending in ``*``
        whose prefix matches, then the authenticated/anonymous default.
        """
        endpoint_key = self._get_endpoint_key(request)

        # Check for specific endpoint limits
        exact = self.rate_limits.get(endpoint_key)
        if exact is not None:
            return exact

        # Check for wildcard matches (only trailing-"*" rules act as prefixes)
        for pattern, limits in self.rate_limits.items():
            if pattern.endswith("*") and endpoint_key.startswith(pattern[:-1]):
                return limits

        # Use default limits based on authentication
        user = getattr(request.state, 'user', None)
        return self.default_limits["authenticated" if user else "anonymous"]

    async def __call__(self, request: Request, call_next):
        """Process rate limiting for the request.

        Args:
            request: Incoming request.
            call_next: Awaitable callable that forwards the request downstream
                and returns the response.

        Returns:
            The downstream response (with rate-limit headers when the limiter
            succeeded), or a 429 ``JSONResponse`` when the limit is exceeded.
        """
        # Skip rate limiting for health checks and metrics only
        if request.url.path in self._EXEMPT_PATHS:
            return await call_next(request)

        client_id = self._get_client_identifier(request)
        endpoint_key = self._get_endpoint_key(request)
        limit, window = self._get_rate_limit(request)

        # Create rate limit key
        rate_limit_key = f"rate_limit:{client_id}:{endpoint_key}"

        # Only the limiter call is guarded. Wrapping ``call_next`` as well
        # would swallow handler exceptions and run the handler a second time.
        try:
            is_allowed, rate_info = await self.limiter.is_allowed(
                key=rate_limit_key,
                limit=limit,
                window_seconds=window,
                identifier=client_id,
            )
        except Exception:
            # Fail open: log the error but don't block the request if rate
            # limiting itself fails.
            self._logger.exception("Rate limiting error")
            return await call_next(request)

        rate_headers = {
            "X-RateLimit-Limit": str(rate_info["limit"]),
            "X-RateLimit-Remaining": str(rate_info["remaining"]),
            "X-RateLimit-Reset": str(rate_info["reset_time"]),
        }

        if not is_allowed:
            # Return rate limit exceeded response
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": "Rate limit exceeded",
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "rate_limit": rate_info,
                },
                headers={
                    **rate_headers,
                    "Retry-After": str(rate_info["retry_after"]),
                },
            )

        # Process the request; exceptions from the handler propagate unchanged.
        response = await call_next(request)

        # Add rate limit headers to response
        for header_name, header_value in rate_headers.items():
            response.headers[header_name] = header_value

        return response
|
|
|
|
|
|
class IPWhitelist:
    """IP whitelist for bypassing rate limits.

    Permanent entries live in the Redis set ``self.whitelist_key``. Temporary
    entries live in their own expiring key so that Redis removes them when
    the TTL elapses.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, redis_client: aioredis.Redis):
        """Store the Redis client and set up the built-in whitelist."""
        self.redis = redis_client
        self.whitelist_key = "ip_whitelist"

        # Default whitelisted IPs (health checks, monitoring)
        self.default_whitelist = {
            "127.0.0.1",
            "::1",
            "169.254.169.254",  # GCP metadata server
        }

    def _temp_key(self, ip: str) -> str:
        """Return the Redis key holding the temporary whitelist entry for *ip*."""
        return f"{self.whitelist_key}:temp:{ip}"

    async def is_whitelisted(self, ip: str) -> bool:
        """Check if IP is whitelisted.

        Returns True for a built-in IP, a member of the permanent set, or an
        IP with an unexpired temporary entry. Returns False if Redis fails.
        """
        if ip in self.default_whitelist:
            return True

        try:
            if await self.redis.sismember(self.whitelist_key, ip):
                return True
            # Temporary entries are separate keys that Redis expires itself.
            return bool(await self.redis.exists(self._temp_key(ip)))
        except Exception:
            # Fail closed: an unreachable Redis must not grant a bypass.
            self._logger.warning("IP whitelist lookup failed", exc_info=True)
            return False

    async def add_ip(self, ip: str, ttl_seconds: int | None = None) -> bool:
        """Add IP to whitelist.

        Args:
            ip: Address to whitelist.
            ttl_seconds: When given (and non-zero), the entry is temporary and
                expires after this many seconds; otherwise it is permanent.

        Returns:
            True on success, False if the Redis operation raised.
        """
        try:
            if ttl_seconds:
                # Temporary entry: stored ONLY as an expiring key. Adding it to
                # the permanent set as well would make it outlive its TTL.
                await self.redis.setex(self._temp_key(ip), ttl_seconds, "1")
            else:
                await self.redis.sadd(self.whitelist_key, ip)
            return True
        except Exception:
            self._logger.warning("Failed to add IP to whitelist", exc_info=True)
            return False

    async def remove_ip(self, ip: str) -> bool:
        """Remove IP from whitelist (both permanent and temporary entries).

        Returns:
            True on success, False if the Redis operation raised.
        """
        try:
            await self.redis.srem(self.whitelist_key, ip)
            await self.redis.delete(self._temp_key(ip))
            return True
        except Exception:
            self._logger.warning("Failed to remove IP from whitelist", exc_info=True)
            return False
|
|
|
|
|
|
async def create_rate_limit_middleware(redis_client: aioredis.Redis) -> RateLimitMiddleware:
    """Build a ``RateLimitMiddleware`` bound to the given Redis client.

    Args:
        redis_client: Async Redis client handed to the middleware.

    Returns:
        A newly constructed ``RateLimitMiddleware`` instance.
    """
    middleware = RateLimitMiddleware(redis_client)
    return middleware
|