"""Rate limiting middleware for API endpoints.""" import time from collections import defaultdict from typing import Dict, Optional, Tuple import redis.asyncio as aioredis from fastapi import HTTPException, Request, status from fastapi.responses import JSONResponse import json import asyncio from datetime import datetime, timedelta from app.core.config import get_settings from app.telemetry.metrics import track_rate_limit_metrics class RateLimiter: """Redis-based rate limiter with sliding window algorithm.""" def __init__(self, redis_client: aioredis.Redis): self.redis = redis_client async def is_allowed( self, key: str, limit: int, window_seconds: int, identifier: str = "" ) -> Tuple[bool, Dict[str, int]]: """ Check if request is allowed under rate limit. Returns: Tuple of (is_allowed, rate_limit_info) """ now = time.time() pipeline = self.redis.pipeline() # Remove expired entries pipeline.zremrangebyscore(key, 0, now - window_seconds) # Count current requests in window pipeline.zcard(key) # Add current request pipeline.zadd(key, {str(now): now}) # Set expiry pipeline.expire(key, window_seconds) results = await pipeline.execute() current_requests = results[1] rate_limit_info = { "limit": limit, "remaining": max(0, limit - current_requests), "reset_time": int(now + window_seconds), "retry_after": window_seconds if current_requests >= limit else 0 } is_allowed = current_requests <= limit # Track metrics track_rate_limit_metrics( identifier=identifier, is_allowed=is_allowed, current_requests=current_requests, limit=limit ) return is_allowed, rate_limit_info class RateLimitMiddleware: """FastAPI middleware for rate limiting.""" def __init__(self, redis_client: aioredis.Redis): self.limiter = RateLimiter(redis_client) self.settings = get_settings() # Rate limit configurations by endpoint pattern self.rate_limits = { # Authentication endpoints "POST:/api/v1/auth/login": (5, 300), # 5 requests per 5 minutes "POST:/api/v1/auth/register": (3, 3600), # 3 requests per hour "POST:/api/v1/auth/refresh": (10, 300), # 10 requests per 5 minutes "POST:/api/v1/auth/forgot-password": (3, 3600), # 3 requests per hour # File upload endpoints "POST:/api/v1/files/upload": (10, 3600), # 10 uploads per hour "POST:/api/v1/jobs": (20, 3600), # 20 job creations per hour # Job management endpoints "GET:/api/v1/jobs": (100, 300), # 100 requests per 5 minutes "PATCH:/api/v1/jobs/*/approve": (50, 3600), # 50 approvals per hour "PATCH:/api/v1/jobs/*/reject": (50, 3600), # 50 rejections per hour # VTT editing endpoints "PATCH:/api/v1/jobs/*/vtt": (100, 3600), # 100 VTT edits per hour # Admin endpoints (more restrictive) "GET:/api/v1/admin/*": (50, 300), # 50 requests per 5 minutes "POST:/api/v1/admin/*": (20, 3600), # 20 admin actions per hour "PATCH:/api/v1/admin/*": (20, 3600), # 20 admin updates per hour "DELETE:/api/v1/admin/*": (10, 3600), # 10 admin deletions per hour } # Default rate limits self.default_limits = { "authenticated": (1000, 3600), # 1000 requests per hour for authenticated users "anonymous": (100, 3600), # 100 requests per hour for anonymous users } def _get_client_identifier(self, request: Request) -> str: """Get client identifier for rate limiting.""" # Try to get user ID from JWT token user = getattr(request.state, 'user', None) if user: return f"user:{user.id}" # Fall back to IP address forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: return f"ip:{forwarded_for.split(',')[0].strip()}" client_ip = request.client.host if request.client else "unknown" return f"ip:{client_ip}" def _get_endpoint_key(self, request: Request) -> str: """Get endpoint pattern for rate limiting.""" method = request.method path = request.url.path # Replace job IDs with wildcard for pattern matching import re path = re.sub(r'/jobs/[a-f0-9-]+/', '/jobs/*/', path) path = re.sub(r'/admin/users/[a-f0-9-]+', '/admin/users/*', path) return f"{method}:{path}" def _get_rate_limit(self, request: Request) -> Tuple[int, int]: """Get rate limit for the current request.""" endpoint_key = self._get_endpoint_key(request) # Check for specific endpoint limits if endpoint_key in self.rate_limits: return self.rate_limits[endpoint_key] # Check for wildcard matches for pattern, limits in self.rate_limits.items(): if pattern.endswith("*") and endpoint_key.startswith(pattern[:-1]): return limits # Use default limits based on authentication user = getattr(request.state, 'user', None) if user: return self.default_limits["authenticated"] else: return self.default_limits["anonymous"] async def __call__(self, request: Request, call_next): """Process rate limiting for the request.""" # Skip rate limiting for health checks and login (temporary for debugging) if request.url.path in ["/health", "/metrics", "/api/v1/auth/login"]: return await call_next(request) client_id = self._get_client_identifier(request) endpoint_key = self._get_endpoint_key(request) limit, window = self._get_rate_limit(request) # Create rate limit key rate_limit_key = f"rate_limit:{client_id}:{endpoint_key}" try: is_allowed, rate_info = await self.limiter.is_allowed( key=rate_limit_key, limit=limit, window_seconds=window, identifier=client_id ) if not is_allowed: # Return rate limit exceeded response return JSONResponse( status_code=status.HTTP_429_TOO_MANY_REQUESTS, content={ "detail": "Rate limit exceeded", "error_code": "RATE_LIMIT_EXCEEDED", "rate_limit": rate_info }, headers={ "X-RateLimit-Limit": str(rate_info["limit"]), "X-RateLimit-Remaining": str(rate_info["remaining"]), "X-RateLimit-Reset": str(rate_info["reset_time"]), "Retry-After": str(rate_info["retry_after"]) } ) # Process the request response = await call_next(request) # Add rate limit headers to response response.headers["X-RateLimit-Limit"] = str(rate_info["limit"]) response.headers["X-RateLimit-Remaining"] = str(rate_info["remaining"]) response.headers["X-RateLimit-Reset"] = str(rate_info["reset_time"]) return response except Exception as e: # Log error but don't block request if rate limiting fails print(f"Rate limiting error: {e}") return await call_next(request) class IPWhitelist: """IP whitelist for bypassing rate limits.""" def __init__(self, redis_client: aioredis.Redis): self.redis = redis_client self.whitelist_key = "ip_whitelist" # Default whitelisted IPs (health checks, monitoring) self.default_whitelist = { "127.0.0.1", "::1", "169.254.169.254", # GCP metadata server } async def is_whitelisted(self, ip: str) -> bool: """Check if IP is whitelisted.""" if ip in self.default_whitelist: return True try: is_member = await self.redis.sismember(self.whitelist_key, ip) return bool(is_member) except Exception: return False async def add_ip(self, ip: str, ttl_seconds: Optional[int] = None) -> bool: """Add IP to whitelist.""" try: await self.redis.sadd(self.whitelist_key, ip) if ttl_seconds: # Create temporary whitelist entry temp_key = f"{self.whitelist_key}:temp:{ip}" await self.redis.setex(temp_key, ttl_seconds, "1") return True except Exception: return False async def remove_ip(self, ip: str) -> bool: """Remove IP from whitelist.""" try: await self.redis.srem(self.whitelist_key, ip) return True except Exception: return False async def create_rate_limit_middleware(redis_client: aioredis.Redis) -> RateLimitMiddleware: """Factory function to create rate limit middleware.""" return RateLimitMiddleware(redis_client)