video-accessibility/backend/app/middleware/validation.py
2025-08-24 16:28:33 -05:00

324 lines
No EOL
12 KiB
Python

"""Enhanced request validation middleware."""
import json
import logging
import re
import time
from typing import Any, Dict, List, Optional, Set
from urllib.parse import unquote

import magic
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ValidationError as PydanticValidationError

from app.telemetry.metrics import track_validation_metrics
class ValidationError(Exception):
    """Signal that request data is malformed or exceeds a configured limit.

    ValidationMiddleware converts this into an HTTP 422 JSON response whose
    ``detail`` field is ``str(exc)``.
    """
class SecurityValidationError(Exception):
    """Signal that input matched one of the blocked security patterns.

    ValidationMiddleware converts this into an HTTP 400 JSON response with a
    generic message, so the exception text is not returned to the client.
    """
class RequestValidator:
    """Request validation helpers with heuristic security checks.

    Holds MIME allow-lists, size limits and a list of case-insensitive
    "malicious content" regexes. The ``validate_*`` methods raise
    ``ValidationError`` for malformed or oversized input and
    ``SecurityValidationError`` when a blocked pattern is found.
    """

    def __init__(self):
        # MIME types accepted by validate_file_type when libmagic detection succeeds.
        self.allowed_video_types = {
            "video/mp4",
            "video/quicktime",
            "video/x-msvideo",  # AVI
        }
        self.allowed_subtitle_types = {
            "text/vtt",
            "text/plain",
        }
        # Security patterns to block (compiled case-insensitively below).
        # Keyword alternations carry a leading word boundary so that words which
        # merely *contain* a keyword ("description", "transcript", "form ",
        # "finish ") are not rejected.
        self.malicious_patterns = [
            # SQL injection patterns
            r"\b(union|select|insert|update|delete|drop|create|alter)\s+",
            # Script markers / inline event handlers. No trailing boundary, so
            # handlers such as "onloadstart" still match the "onload" prefix.
            r"\b(script|javascript|vbscript|onload|onerror|onclick)",
            r"<\s*script[^>]*>",
            r"javascript:",
            r"data:.*base64",
            # Path traversal (literal and percent-encoded)
            r"\.\./",
            r"\.\.\\",
            r"%2e%2e%2f",
            r"%2e%2e\\",
            # Command injection. NOTE(review): the metacharacter class rejects any
            # string containing ; & | ` or $ -- including ordinary text such as
            # "Q&A" -- confirm that is intended for free-text fields.
            r"[;&|`$]",
            r"\b(rm|wget|curl|nc|bash|sh|cmd|powershell)\s+",
            # MongoDB injection
            r"\$where|\$ne|\$gt|\$lt|\$regex",
        ]
        self.compiled_patterns = [
            re.compile(pattern, re.IGNORECASE) for pattern in self.malicious_patterns
        ]
        # Max file sizes (in bytes)
        self.max_video_size = 2 * 1024 * 1024 * 1024  # 2GB
        self.max_subtitle_size = 10 * 1024 * 1024  # 10MB
        # Request size limits
        self.max_json_size = 1024 * 1024  # 1MB
        self.max_form_fields = 50  # max keys in any single JSON object
        self.max_array_items = 1000  # max elements in any single JSON array

    def validate_string_content(self, content: str, field_name: str = "input") -> None:
        """Reject *content* if it matches any blocked pattern.

        Non-string values are ignored.

        Raises:
            SecurityValidationError: on the first matching pattern; the message
                names *field_name* but not the matched text.
        """
        if not isinstance(content, str):
            return
        for pattern in self.compiled_patterns:
            if pattern.search(content):
                raise SecurityValidationError(
                    f"Potentially malicious content detected in {field_name}"
                )

    def validate_filename(self, filename: str) -> str:
        """Validate *filename* and return a sanitized version of it.

        The name is percent-decoded and screened with validate_string_content.
        Every character that is not a word character, ``-`` or ``.`` becomes
        ``_``. A leading dot is replaced with ``file_``. The result is capped at
        255 characters, keeping the extension when it fits.

        Raises:
            ValidationError: if *filename* is empty.
            SecurityValidationError: if the decoded name matches a blocked pattern.
        """
        if not filename:
            raise ValidationError("Filename cannot be empty")
        # Decode URL encoding so encoded traversal sequences are screened too.
        filename = unquote(filename)
        self.validate_string_content(filename, "filename")
        # Remove dangerous characters
        safe_filename = re.sub(r'[^\w\-_\.]', '_', filename)
        # Prevent hidden files
        if safe_filename.startswith('.'):
            safe_filename = 'file_' + safe_filename[1:]
        # Limit length, preserving the extension where possible.
        if len(safe_filename) > 255:
            stem, dot, ext = safe_filename.rpartition('.')
            if not dot:
                stem, ext = safe_filename, ''
            suffix = '.' + ext if ext else ''
            stem_budget = min(250, 255 - len(suffix))
            if stem_budget > 0:
                safe_filename = stem[:stem_budget] + suffix
            else:
                # The extension alone is too long to keep; hard-truncate instead.
                safe_filename = safe_filename[:255]
        return safe_filename

    def validate_file_type(self, content: bytes, expected_type: str, filename: str) -> None:
        """Check that *content* is an allowed type for *expected_type*.

        *expected_type* is "video" or "subtitle"; any other value is not
        checked. Detection uses ``magic.from_buffer(..., mime=True)``. If that
        call raises for any reason, the check falls back to the extension of
        *filename*.

        Raises:
            ValidationError: if the detected MIME type (or fallback extension)
                is not allowed for *expected_type*.
        """
        try:
            detected_type = magic.from_buffer(content, mime=True)
        except Exception:
            # Fallback to extension-based validation
            ext = filename.lower().rsplit('.', 1)[-1] if '.' in filename else ''
            # NOTE(review): 'mkv' is accepted here, but no Matroska MIME type is in
            # allowed_video_types, so MKV uploads are likely rejected whenever
            # libmagic detection succeeds -- confirm which behavior is intended.
            video_extensions = {'mp4', 'mov', 'avi', 'mkv'}
            subtitle_extensions = {'vtt', 'srt', 'txt'}
            if expected_type == "video" and ext not in video_extensions:
                raise ValidationError(f"Invalid video file extension: {ext}")
            if expected_type == "subtitle" and ext not in subtitle_extensions:
                raise ValidationError(f"Invalid subtitle file extension: {ext}")
            return

        # Allowed types are sorted so the error message is stable across runs.
        if expected_type == "video" and detected_type not in self.allowed_video_types:
            raise ValidationError(
                f"Invalid video file type: {detected_type}. "
                f"Allowed types: {', '.join(sorted(self.allowed_video_types))}"
            )
        if expected_type == "subtitle" and detected_type not in self.allowed_subtitle_types:
            raise ValidationError(
                f"Invalid subtitle file type: {detected_type}. "
                f"Allowed types: {', '.join(sorted(self.allowed_subtitle_types))}"
            )

    def validate_file_size(self, size: int, file_type: str) -> None:
        """Enforce the per-type size limit for "video" or "subtitle" uploads.

        Other *file_type* values are not checked.

        Raises:
            ValidationError: if *size* (bytes) exceeds the limit for *file_type*.
        """
        if file_type == "video" and size > self.max_video_size:
            raise ValidationError(
                f"Video file too large: {size} bytes. "
                f"Maximum allowed: {self.max_video_size} bytes"
            )
        elif file_type == "subtitle" and size > self.max_subtitle_size:
            raise ValidationError(
                f"Subtitle file too large: {size} bytes. "
                f"Maximum allowed: {self.max_subtitle_size} bytes"
            )

    async def validate_json_payload(self, request: "Request") -> Optional[Dict[str, Any]]:
        """Read, parse and screen a JSON request body.

        Returns ``None`` when the Content-Type is not ``application/json``
        (compared case-insensitively), ``{}`` for an empty body, and the
        decoded payload otherwise. Every key and string value in the payload
        is passed through validate_string_content.

        Raises:
            ValidationError: for an unparsable Content-Length header, a body
                larger than ``max_json_size``, invalid JSON or text encoding,
                nesting too deep to process, or too many fields/items.
            SecurityValidationError: if any key or string matches a blocked pattern.
        """
        content_type = request.headers.get("content-type", "")
        if not content_type.lower().startswith("application/json"):
            return None

        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_length = int(content_length)
            except ValueError:
                raise ValidationError(
                    f"Invalid Content-Length header: {content_length!r}"
                ) from None
            if declared_length > self.max_json_size:
                raise ValidationError(f"JSON payload too large: {content_length} bytes")

        # Reuse a body stored on this request object by an earlier call, if any.
        # NOTE(review): nothing visible here reads `_cached_body` other than this
        # method; whether the route can still read the body depends on the
        # framework's middleware implementation -- confirm.
        if hasattr(request, '_cached_body'):
            body = request._cached_body
        else:
            body = await request.body()
            request._cached_body = body

        if len(body) > self.max_json_size:
            raise ValidationError(f"JSON payload too large: {len(body)} bytes")
        if not body:
            return {}

        try:
            payload = json.loads(body)
            # Recursively validate all keys and string values
            self._validate_json_values(payload)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise ValidationError(f"Invalid JSON: {e}") from e
        except RecursionError:
            raise ValidationError("Invalid JSON: nesting too deep") from None
        return payload

    def _validate_json_values(self, obj: Any, path: str = "root") -> None:
        """Recursively screen a decoded JSON value.

        *path* is the dotted/indexed location (e.g. ``root.items[2].name``)
        reported in error messages.

        Raises:
            ValidationError: if an object has more than ``max_form_fields``
                keys or an array more than ``max_array_items`` elements.
            SecurityValidationError: if a key or string value matches a
                blocked pattern.
        """
        if isinstance(obj, dict):
            if len(obj) > self.max_form_fields:
                raise ValidationError(f"Too many fields in object at {path}")
            for key, value in obj.items():
                child_path = f"{path}.{key}"
                if isinstance(key, str):
                    self.validate_string_content(key, child_path)
                self._validate_json_values(value, child_path)
        elif isinstance(obj, list):
            if len(obj) > self.max_array_items:
                raise ValidationError(f"Array too large at {path}")
            for i, item in enumerate(obj):
                self._validate_json_values(item, f"{path}[{i}]")
        elif isinstance(obj, str):
            self.validate_string_content(obj, path)

    def validate_query_params(self, request: "Request") -> None:
        """Screen every query-parameter name and value for blocked patterns.

        Raises:
            SecurityValidationError: if any name or value matches.
        """
        for key, value in request.query_params.items():
            self.validate_string_content(key, f"query.{key}")
            self.validate_string_content(str(value), f"query.{key}")

    def validate_headers(self, request: "Request") -> None:
        """Screen selected request headers.

        Values of host-rewriting headers are checked for blocked patterns, and
        the User-Agent header is limited to 500 characters.

        Raises:
            SecurityValidationError: on a blocked pattern or an over-long User-Agent.
        """
        suspicious_headers = {
            "x-forwarded-host",
            "x-original-host",
            "x-rewrite-url",
        }
        for header_name, header_value in request.headers.items():
            lowered = header_name.lower()
            # Check for suspicious headers
            if lowered in suspicious_headers:
                self.validate_string_content(header_value, f"header.{header_name}")
            # Validate user-agent length
            if lowered == "user-agent" and len(header_value) > 500:
                raise SecurityValidationError("User-Agent header too long")
class ValidationMiddleware:
    """FastAPI HTTP middleware that validates requests before routing them.

    Used as an async callable taking ``(request, call_next)``.
    - A security failure returns a 400 JSON response.
    - A format failure returns a 422 JSON response.
    - Outcomes are reported via ``track_validation_metrics``.
    - If validation itself fails unexpectedly, the error is logged and the
      request is still forwarded (fail-open).
    """

    def __init__(self):
        self.validator = RequestValidator()

    @staticmethod
    def _track(request: Request, elapsed: float, error_types: List[str]) -> None:
        """Report one validation outcome; an empty *error_types* means valid."""
        track_validation_metrics(
            endpoint=request.url.path,
            method=request.method,
            is_valid=not error_types,
            validation_time=elapsed,
            error_types=error_types,
        )

    async def __call__(self, request: Request, call_next):
        """Validate *request*, then forward it to *call_next*.

        Returns the downstream response, or a JSONResponse describing the
        validation failure.
        """
        # Skip validation for timing adjustment endpoint temporarily
        if "/vtt/adjust-timing" in request.url.path:
            return await call_next(request)

        start_time = time.time()
        # Seconds spent validating. Stays None until validation has passed,
        # which also tells the handlers below whether call_next was reached.
        validation_time: Optional[float] = None
        try:
            self.validator.validate_headers(request)
            self.validator.validate_query_params(request)
            if request.method in ("POST", "PUT", "PATCH"):
                await self.validator.validate_json_payload(request)
            validation_time = time.time() - start_time
            # call_next stays inside the try so that ValidationError /
            # SecurityValidationError raised downstream are still converted to
            # 422/400 responses, as before.
            response = await call_next(request)
        except SecurityValidationError:
            self._track(request, time.time() - start_time, ["security"])
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "detail": "Security validation failed",
                    "error_code": "SECURITY_VALIDATION_ERROR",
                },
            )
        except ValidationError as e:
            self._track(request, time.time() - start_time, ["format"])
            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content={
                    "detail": str(e),
                    "error_code": "VALIDATION_ERROR",
                },
            )
        except Exception:
            if validation_time is not None:
                # The downstream app raised after validation passed. Propagate
                # the error instead of calling the route a second time, which
                # would repeat its side effects.
                self._track(request, validation_time, [])
                raise
            self._track(request, time.time() - start_time, ["unknown"])
            # Log unexpected validator error but continue processing (fail-open).
            logging.getLogger(__name__).exception("Validation middleware error")
            return await call_next(request)

        self._track(request, validation_time, [])
        return response
async def create_validation_middleware() -> ValidationMiddleware:
    """Build and return a new ``ValidationMiddleware`` instance.

    This is a coroutine function, so callers must ``await`` it; construction
    itself performs no asynchronous work.
    """
    middleware = ValidationMiddleware()
    return middleware