- Remove ';' from the command-injection pattern: semicolons are common in French and other European languages and are not a shell-injection risk in a JSON context.
- Skip security pattern scanning for free-text fields (captions_vtt, audio_description_vtt, notes, etc.): natural-language text always generates false positives against the injection regexes.
- Add GET/HEAD to the GCS CORS config so browsers can load signed VTT URLs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
334 lines
12 KiB
Python
334 lines
12 KiB
Python
"""Enhanced request validation middleware."""
|
|
|
|
import json
|
|
import re
|
|
import time
|
|
from typing import Any
|
|
from urllib.parse import unquote
|
|
|
|
import magic
|
|
from fastapi import Request, status
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from app.telemetry.metrics import track_validation_metrics
|
|
|
|
from ..core.config import settings
|
|
|
|
|
|
class ValidationError(Exception):
    """Signals that a request failed a format or size validation check.

    Raised by ``RequestValidator`` for problems such as an empty filename,
    an oversized payload, malformed JSON, or a disallowed file type.
    """
|
|
|
|
|
|
class SecurityValidationError(Exception):
    """Signals that request content matched a security blocklist check.

    Deliberately a direct ``Exception`` subclass (not a ``ValidationError``)
    so handlers can tell security rejections apart from format errors.
    """
|
|
|
|
|
|
class RequestValidator:
    """Enhanced request validation with security checks.

    Bundles the individual checks used by the validation middleware and by
    upload handlers: blocklist scanning of strings, filename sanitisation,
    file type/size limits, and JSON / query-string / header inspection.

    Format and size problems raise ``ValidationError``; blocklist hits raise
    ``SecurityValidationError``.
    """

    # Fields that contain free-form natural language — skip injection pattern checks
    _FREETEXT_FIELDS = {"captions_vtt", "audio_description_vtt", "text", "notes", "change_note", "description"}

    # Upper bound on the number of items in any single JSON array.
    _MAX_ARRAY_ITEMS = 1000

    # Upper bound on the length of a sanitised filename.
    _MAX_FILENAME_LENGTH = 255

    def __init__(self):
        """Set up allowed MIME types, the compiled blocklist and size limits."""
        # File type restrictions
        self.allowed_video_types = {
            "video/mp4",
            "video/quicktime",
            "video/x-msvideo",  # AVI
        }

        self.allowed_subtitle_types = {
            "text/vtt",
            "text/plain",
        }

        # Security patterns to block
        self.malicious_patterns = [
            # SQL injection patterns
            r"\b(union|select|insert|update|delete|drop|create|alter)\b\s+",

            # Script / markup injection
            r"vbscript:",  # vbscript protocol injection
            r"\b(onload|onerror|onclick)\s*=",  # HTML event handler attribute injection
            r"<\s*script[^>]*>",
            r"javascript:",
            r"data:.*base64",

            # Path traversal
            r"\.\./",
            r"\.\.\\",
            r"%2e%2e%2f",
            r"%2e%2e\\",

            # Command injection (removed $ and ; — semicolons are common in natural language)
            r"[&|`](?!\s*$)",
            r"\b(rm|wget|curl|nc|bash|sh|cmd|powershell)\b\s+",

            # MongoDB injection — NoSQL operator abuse.
            # The adjacent literals are joined into ONE alternation pattern.
            (
                r"\$where|\$expr|\$function|\$accumulator"
                r"|\$ne|\$nin|\$not"
                r"|\$gt|\$gte|\$lt|\$lte"
                r"|\$regex|\$jsonSchema|\$mod"
            ),
        ]

        self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.malicious_patterns]

        # Max file sizes (in bytes) — driven by central config (T-14)
        self.max_video_size = settings.upload_max_video_bytes
        self.max_subtitle_size = 10 * 1024 * 1024  # 10MB

        # Request size limits
        self.max_json_size = 1024 * 1024  # 1MB
        self.max_form_fields = 50

    def validate_string_content(self, content: str, field_name: str = "input") -> None:
        """Validate string content for malicious patterns.

        Non-string values are ignored.

        Raises:
            SecurityValidationError: if any compiled blocklist pattern matches.
        """
        if not isinstance(content, str):
            return

        if any(pattern.search(content) for pattern in self.compiled_patterns):
            raise SecurityValidationError(
                f"Potentially malicious content detected in {field_name}"
            )

    def validate_filename(self, filename: str) -> str:
        """Validate and sanitize filename.

        The name is URL-decoded once, scanned against the blocklist, stripped
        of every character other than word characters, ``-``, ``_`` and ``.``,
        prevented from starting with a dot, and limited to 255 characters.

        Returns:
            The sanitised filename.

        Raises:
            ValidationError: if ``filename`` is empty.
            SecurityValidationError: if the decoded name matches the blocklist.
        """
        if not filename:
            raise ValidationError("Filename cannot be empty")

        # Decode URL encoding
        filename = unquote(filename)

        # Check for malicious patterns
        self.validate_string_content(filename, "filename")

        # Remove dangerous characters
        safe_filename = re.sub(r'[^\w\-_\.]', '_', filename)

        # Prevent hidden files
        if safe_filename.startswith('.'):
            safe_filename = 'file_' + safe_filename[1:]

        # Limit length
        max_length = self._MAX_FILENAME_LENGTH
        if len(safe_filename) > max_length:
            name, ext = safe_filename.rsplit('.', 1) if '.' in safe_filename else (safe_filename, '')
            suffix = '.' + ext if ext else ''
            # Keep at most 250 characters of the stem, and fewer when a long
            # extension would otherwise push the result past the limit.
            stem_budget = min(250, max_length - len(suffix))
            if stem_budget < 1:
                # The "extension" alone overflows the limit; hard-truncate.
                safe_filename = safe_filename[:max_length]
            else:
                safe_filename = name[:stem_budget] + suffix

        return safe_filename

    def validate_file_type(self, content: bytes, expected_type: str, filename: str) -> None:
        """Validate file type using magic numbers.

        The MIME type is detected from ``content``. If detection itself fails,
        the check falls back to the extension of ``filename``.

        Args:
            content: leading bytes of the uploaded file.
            expected_type: ``"video"`` or ``"subtitle"``; any other value is
                accepted without checks.
            filename: used only by the extension fallback.

        Raises:
            ValidationError: if the detected type (or fallback extension) is
                not allowed for ``expected_type``.
        """
        try:
            detected_type = magic.from_buffer(content, mime=True)
        except Exception:
            # Deliberate best-effort: fall back to extension-based validation
            # when content sniffing is unavailable or fails.
            # NOTE(review): this fallback accepts 'mkv' and 'srt', which have
            # no counterpart in the MIME allow-lists above — confirm intended.
            ext = filename.lower().split('.')[-1] if '.' in filename else ''
            video_extensions = {'mp4', 'mov', 'avi', 'mkv'}
            subtitle_extensions = {'vtt', 'srt', 'txt'}

            if expected_type == "video" and ext not in video_extensions:
                raise ValidationError(f"Invalid video file extension: {ext}") from None
            elif expected_type == "subtitle" and ext not in subtitle_extensions:
                raise ValidationError(f"Invalid subtitle file extension: {ext}") from None
            return

        if expected_type == "video" and detected_type not in self.allowed_video_types:
            raise ValidationError(
                f"Invalid video file type: {detected_type}. "
                f"Allowed types: {', '.join(self.allowed_video_types)}"
            )
        elif expected_type == "subtitle" and detected_type not in self.allowed_subtitle_types:
            raise ValidationError(
                f"Invalid subtitle file type: {detected_type}. "
                f"Allowed types: {', '.join(self.allowed_subtitle_types)}"
            )

    def validate_file_size(self, size: int, file_type: str) -> None:
        """Validate file size limits.

        Raises:
            ValidationError: if ``size`` exceeds the limit configured for
                ``file_type`` (``"video"`` or ``"subtitle"``).
        """
        if file_type == "video" and size > self.max_video_size:
            raise ValidationError(
                f"Video file too large: {size} bytes. "
                f"Maximum allowed: {self.max_video_size} bytes"
            )
        elif file_type == "subtitle" and size > self.max_subtitle_size:
            raise ValidationError(
                f"Subtitle file too large: {size} bytes. "
                f"Maximum allowed: {self.max_subtitle_size} bytes"
            )

    async def validate_json_payload(self, request: Request) -> dict[str, Any] | None:
        """Validate JSON request payload.

        Returns:
            ``None`` when the request is not declared as JSON, ``{}`` for an
            empty body, otherwise the parsed payload.

        Raises:
            ValidationError: for an oversized body, an unparsable
                Content-Length header, or a body that is not valid JSON
                (including undecodable bytes or excessive nesting).
            SecurityValidationError: if a key or string value matches the
                blocklist.
        """
        # Media types are case-insensitive, so normalise before comparing.
        content_type = request.headers.get("content-type", "").strip().lower()
        if not content_type.startswith("application/json"):
            return None

        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                raise ValidationError("Invalid Content-Length header") from None
            if declared_size > self.max_json_size:
                raise ValidationError(f"JSON payload too large: {content_length} bytes")

        # Check if body has already been read
        if hasattr(request, '_cached_body'):
            body = request._cached_body
        else:
            body = await request.body()
            # Cache the body so FastAPI can read it later
            request._cached_body = body

        # The header can under-report, so check the actual size as well.
        if len(body) > self.max_json_size:
            raise ValidationError(f"JSON payload too large: {len(body)} bytes")

        if not body:
            return {}

        # Every failure mode of parsing/walking is turned into a
        # ValidationError here so a hostile body is rejected rather than
        # surfacing as an unexpected exception in the caller.
        try:
            payload = json.loads(body)
            # Recursively validate all string values
            self._validate_json_values(payload)
        except json.JSONDecodeError as e:
            raise ValidationError(f"Invalid JSON: {e}") from e
        except UnicodeDecodeError as e:
            raise ValidationError("Invalid JSON: body is not valid Unicode text") from e
        except RecursionError as e:
            raise ValidationError("Invalid JSON: payload is nested too deeply") from e

        return payload

    def _validate_json_values(self, obj: Any, path: str = "root", _freetext: bool = False) -> None:
        """Recursively validate JSON values.

        Object keys are always scanned, and object/array size limits are
        always enforced. String values are scanned unless they sit directly
        under (or in a list under) a key listed in ``_FREETEXT_FIELDS``.
        Objects nested under a free-text key get no exemption of their own.

        Args:
            obj: the decoded JSON value to inspect.
            path: dotted location used in error messages.
            _freetext: internal flag; true while walking the value of a
                free-text field.

        Raises:
            ValidationError: if an object or array exceeds its size limit.
            SecurityValidationError: if a key or scanned string matches the
                blocklist.
        """
        if isinstance(obj, dict):
            if len(obj) > self.max_form_fields:
                raise ValidationError(f"Too many fields in object at {path}")

            for key, value in obj.items():
                self.validate_string_content(key, f"{path}.key")
                # Skip pattern scanning for free-text fields (VTT content, notes, etc.),
                # but keep walking so nested keys and sizes are still checked.
                self._validate_json_values(value, f"{path}.{key}", key in self._FREETEXT_FIELDS)

        elif isinstance(obj, list):
            if len(obj) > self._MAX_ARRAY_ITEMS:  # Prevent large arrays
                raise ValidationError(f"Array too large at {path}")

            for i, item in enumerate(obj):
                self._validate_json_values(item, f"{path}[{i}]", _freetext)

        elif isinstance(obj, str) and not _freetext:
            self.validate_string_content(obj, path)

    def validate_query_params(self, request: Request) -> None:
        """Validate query parameters.

        Every key and every value is scanned, including all values of a
        repeated key.

        Raises:
            SecurityValidationError: if a key or value matches the blocklist.
        """
        params = request.query_params
        # A multi-dict's items() yields only one value per repeated key;
        # multi_items() (when available) yields every pair.
        pairs = getattr(params, "multi_items", params.items)()
        for key, value in pairs:
            self.validate_string_content(key, f"query.{key}")
            self.validate_string_content(str(value), f"query.{key}")

    def validate_headers(self, request: Request) -> None:
        """Validate request headers.

        Only the host/URL override headers listed below are scanned against
        the blocklist; in addition the User-Agent is limited to 500 chars.

        Raises:
            SecurityValidationError: on a blocklist match or an over-long
                User-Agent header.
        """
        suspicious_headers = {
            "x-forwarded-host",
            "x-original-host",
            "x-rewrite-url",
        }

        for header_name, header_value in request.headers.items():
            lowered_name = header_name.lower()

            # Check for suspicious headers
            if lowered_name in suspicious_headers:
                self.validate_string_content(header_value, f"header.{header_name}")

            # Validate user-agent length
            if lowered_name == "user-agent" and len(header_value) > 500:
                raise SecurityValidationError("User-Agent header too long")
|
|
|
|
|
|
class ValidationMiddleware:
    """FastAPI middleware for enhanced request validation.

    Runs header, query-string and (for POST/PUT/PATCH) JSON body checks
    before handing the request to the next handler, and reports the outcome
    through ``track_validation_metrics``.
    """

    def __init__(self):
        """Create the middleware with its own ``RequestValidator``."""
        self.validator = RequestValidator()

    async def __call__(self, request: Request, call_next):
        """Process validation for the request.

        Returns:
            The downstream response, or a JSON error response:
            400 for ``SecurityValidationError`` and 422 for
            ``ValidationError`` (whether raised by the checks here or by the
            downstream handler).

        Any other exception raised by the downstream handler propagates
        unchanged; the handler is invoked at most once per request.
        """
        # Skip validation for timing adjustment endpoint temporarily
        if "/vtt/adjust-timing" in request.url.path:
            return await call_next(request)

        start_time = time.perf_counter()
        checks_completed = True

        try:
            try:
                await self._run_checks(request)
            except (SecurityValidationError, ValidationError):
                raise
            except Exception:
                # Deliberate fail-open: an unexpected validator failure is
                # logged and counted, and the request is still processed.
                import logging

                checks_completed = False
                self._record(request, time.perf_counter() - start_time, "unknown")
                logging.getLogger(__name__).exception("Validation middleware error")

            # Measured before the handler runs so the metric covers
            # validation only, not downstream processing time.
            validation_time = time.perf_counter() - start_time

            # Process the request
            response = await call_next(request)

        except SecurityValidationError:
            self._record(request, time.perf_counter() - start_time, "security")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "detail": "Security validation failed",
                    "error_code": "SECURITY_VALIDATION_ERROR",
                },
            )

        except ValidationError as e:
            self._record(request, time.perf_counter() - start_time, "format")
            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content={
                    "detail": str(e),
                    "error_code": "VALIDATION_ERROR",
                },
            )

        # Track successful validation
        if checks_completed:
            self._record(request, validation_time)

        return response

    async def _run_checks(self, request: Request) -> None:
        """Run every request check in order; raises on the first failure."""
        # Validate headers
        self.validator.validate_headers(request)

        # Validate query parameters
        self.validator.validate_query_params(request)

        # Validate JSON payload if present
        if request.method in ("POST", "PUT", "PATCH"):
            await self.validator.validate_json_payload(request)

    @staticmethod
    def _record(request: Request, elapsed: float, error_type: str | None = None) -> None:
        """Report one validation outcome; ``error_type=None`` means valid."""
        track_validation_metrics(
            endpoint=request.url.path,
            method=request.method,
            is_valid=error_type is None,
            validation_time=elapsed,
            error_types=[] if error_type is None else [error_type],
        )
|
|
|
|
|
|
async def create_validation_middleware() -> ValidationMiddleware:
    """Build and return a new ``ValidationMiddleware``.

    This is a coroutine function, so callers must ``await`` it to obtain
    the middleware instance.
    """
    middleware = ValidationMiddleware()
    return middleware
|