"""Enhanced request validation middleware.""" import json import re import time from typing import Any, Dict, List, Optional, Set from fastapi import HTTPException, Request, status from fastapi.responses import JSONResponse from pydantic import BaseModel, ValidationError as PydanticValidationError import magic from urllib.parse import unquote from app.telemetry.metrics import track_validation_metrics class ValidationError(Exception): """Custom validation error.""" pass class SecurityValidationError(Exception): """Raised when security validation fails.""" pass class RequestValidator: """Enhanced request validation with security checks.""" def __init__(self): # File type restrictions self.allowed_video_types = { "video/mp4", "video/quicktime", "video/x-msvideo" # AVI } self.allowed_subtitle_types = { "text/vtt", "text/plain" } # Security patterns to block self.malicious_patterns = [ # SQL injection patterns r"(union|select|insert|update|delete|drop|create|alter)\s+", r"(script|javascript|vbscript|onload|onerror|onclick)", r"<\s*script[^>]*>", r"javascript:", r"data:.*base64", # Path traversal r"\.\./", r"\.\.\\", r"%2e%2e%2f", r"%2e%2e\\", # Command injection (removed $ to allow MongoDB operators in controlled contexts) r"[;&|`](?!\s*$)", # Allow $ but not as command separator r"(rm|wget|curl|nc|bash|sh|cmd|powershell)\s+", # MongoDB injection r"\$where|\$ne|\$gt|\$lt|\$regex", ] self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.malicious_patterns] # Max file sizes (in bytes) self.max_video_size = 2 * 1024 * 1024 * 1024 # 2GB self.max_subtitle_size = 10 * 1024 * 1024 # 10MB # Request size limits self.max_json_size = 1024 * 1024 # 1MB self.max_form_fields = 50 def validate_string_content(self, content: str, field_name: str = "input") -> None: """Validate string content for malicious patterns.""" if not isinstance(content, str): return for pattern in self.compiled_patterns: if pattern.search(content): raise SecurityValidationError( f"Potentially malicious content detected in {field_name}" ) def validate_filename(self, filename: str) -> str: """Validate and sanitize filename.""" if not filename: raise ValidationError("Filename cannot be empty") # Decode URL encoding filename = unquote(filename) # Check for malicious patterns self.validate_string_content(filename, "filename") # Remove dangerous characters safe_filename = re.sub(r'[^\w\-_\.]', '_', filename) # Prevent hidden files if safe_filename.startswith('.'): safe_filename = 'file_' + safe_filename[1:] # Limit length if len(safe_filename) > 255: name, ext = safe_filename.rsplit('.', 1) if '.' in safe_filename else (safe_filename, '') safe_filename = name[:250] + ('.' + ext if ext else '') return safe_filename def validate_file_type(self, content: bytes, expected_type: str, filename: str) -> None: """Validate file type using magic numbers.""" try: detected_type = magic.from_buffer(content, mime=True) except Exception: # Fallback to extension-based validation ext = filename.lower().split('.')[-1] if '.' in filename else '' video_extensions = {'mp4', 'mov', 'avi', 'mkv'} subtitle_extensions = {'vtt', 'srt', 'txt'} if expected_type == "video" and ext not in video_extensions: raise ValidationError(f"Invalid video file extension: {ext}") elif expected_type == "subtitle" and ext not in subtitle_extensions: raise ValidationError(f"Invalid subtitle file extension: {ext}") return if expected_type == "video" and detected_type not in self.allowed_video_types: raise ValidationError( f"Invalid video file type: {detected_type}. " f"Allowed types: {', '.join(self.allowed_video_types)}" ) elif expected_type == "subtitle" and detected_type not in self.allowed_subtitle_types: raise ValidationError( f"Invalid subtitle file type: {detected_type}. " f"Allowed types: {', '.join(self.allowed_subtitle_types)}" ) def validate_file_size(self, size: int, file_type: str) -> None: """Validate file size limits.""" if file_type == "video" and size > self.max_video_size: raise ValidationError( f"Video file too large: {size} bytes. " f"Maximum allowed: {self.max_video_size} bytes" ) elif file_type == "subtitle" and size > self.max_subtitle_size: raise ValidationError( f"Subtitle file too large: {size} bytes. " f"Maximum allowed: {self.max_subtitle_size} bytes" ) async def validate_json_payload(self, request: Request) -> Optional[Dict[str, Any]]: """Validate JSON request payload.""" if not request.headers.get("content-type", "").startswith("application/json"): return None content_length = request.headers.get("content-length") if content_length and int(content_length) > self.max_json_size: raise ValidationError(f"JSON payload too large: {content_length} bytes") try: # Check if body has already been read if hasattr(request, '_cached_body'): body = request._cached_body else: body = await request.body() # Cache the body so FastAPI can read it later request._cached_body = body if len(body) > self.max_json_size: raise ValidationError(f"JSON payload too large: {len(body)} bytes") if not body: return {} payload = json.loads(body) # Recursively validate all string values self._validate_json_values(payload) return payload except json.JSONDecodeError as e: raise ValidationError(f"Invalid JSON: {e}") def _validate_json_values(self, obj: Any, path: str = "root") -> None: """Recursively validate JSON values.""" if isinstance(obj, dict): if len(obj) > self.max_form_fields: raise ValidationError(f"Too many fields in object at {path}") for key, value in obj.items(): if isinstance(key, str): self.validate_string_content(key, f"{path}.{key}") self._validate_json_values(value, f"{path}.{key}") elif isinstance(obj, list): if len(obj) > 1000: # Prevent large arrays raise ValidationError(f"Array too large at {path}") for i, item in enumerate(obj): self._validate_json_values(item, f"{path}[{i}]") elif isinstance(obj, str): self.validate_string_content(obj, path) def validate_query_params(self, request: Request) -> None: """Validate query parameters.""" for key, value in request.query_params.items(): self.validate_string_content(key, f"query.{key}") self.validate_string_content(str(value), f"query.{key}") def validate_headers(self, request: Request) -> None: """Validate request headers.""" suspicious_headers = { "x-forwarded-host", "x-original-host", "x-rewrite-url" } for header_name, header_value in request.headers.items(): # Check for suspicious headers if header_name.lower() in suspicious_headers: self.validate_string_content(header_value, f"header.{header_name}") # Validate user-agent length if header_name.lower() == "user-agent" and len(header_value) > 500: raise SecurityValidationError("User-Agent header too long") class ValidationMiddleware: """FastAPI middleware for enhanced request validation.""" def __init__(self): self.validator = RequestValidator() async def __call__(self, request: Request, call_next): """Process validation for the request.""" start_time = time.time() validation_errors = [] # Skip validation for timing adjustment endpoint temporarily if "/vtt/adjust-timing" in request.url.path: return await call_next(request) try: # Validate headers self.validator.validate_headers(request) # Validate query parameters self.validator.validate_query_params(request) # Validate JSON payload if present if request.method in ["POST", "PUT", "PATCH"]: await self.validator.validate_json_payload(request) # Process the request response = await call_next(request) # Track successful validation track_validation_metrics( endpoint=request.url.path, method=request.method, is_valid=True, validation_time=time.time() - start_time, error_types=[] ) return response except SecurityValidationError as e: validation_errors.append("security") track_validation_metrics( endpoint=request.url.path, method=request.method, is_valid=False, validation_time=time.time() - start_time, error_types=validation_errors ) return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content={ "detail": "Security validation failed", "error_code": "SECURITY_VALIDATION_ERROR" } ) except ValidationError as e: validation_errors.append("format") track_validation_metrics( endpoint=request.url.path, method=request.method, is_valid=False, validation_time=time.time() - start_time, error_types=validation_errors ) return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={ "detail": str(e), "error_code": "VALIDATION_ERROR" } ) except Exception as e: validation_errors.append("unknown") track_validation_metrics( endpoint=request.url.path, method=request.method, is_valid=False, validation_time=time.time() - start_time, error_types=validation_errors ) # Log unexpected error but continue processing print(f"Validation middleware error: {e}") return await call_next(request) async def create_validation_middleware() -> ValidationMiddleware: """Factory function to create validation middleware.""" return ValidationMiddleware()