video-accessibility/backend/app/middleware/validation.py
Vadym Samoilenko f22d568fc5 fix(security): fix false-positive injection blocks on French/multilingual VTT content
- Remove ';' from command-injection pattern — semicolons are common in French
  and other European languages, not a shell injection risk in JSON context
- Skip security pattern scanning for free-text fields (captions_vtt,
  audio_description_vtt, notes, etc.) — natural language always generates
  false positives against injection regexes
- Add GET/HEAD to GCS CORS config so browsers can load signed VTT URLs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-13 19:11:01 +01:00

334 lines
12 KiB
Python

"""Enhanced request validation middleware."""
import json
import logging
import re
import time
from typing import Any
from urllib.parse import unquote

import magic
from fastapi import Request, status
from fastapi.responses import JSONResponse

from app.telemetry.metrics import track_validation_metrics

from ..core.config import settings
class ValidationError(Exception):
    """Request or upload failed a format, size or type check.

    ValidationMiddleware turns this into a 422 response whose ``detail`` is
    the exception message.
    """
class SecurityValidationError(Exception):
    """Request content matched a blocked security pattern.

    ValidationMiddleware answers this with a generic 400 response; the
    message itself is not sent to the client.
    """
class RequestValidator:
    """Enhanced request validation with security checks.

    Screens request strings (selected headers, query parameters, JSON keys and
    values, filenames) against a list of injection regexes and enforces file
    type, file size and JSON size/shape limits. Pattern matches raise
    ``SecurityValidationError``; structural problems raise ``ValidationError``.
    """

    # JSON keys that carry free-form natural language (VTT cues, notes, ...).
    # String values beneath these keys are exempt from the injection-pattern
    # scan because prose routinely trips the regexes. Keys, object size, array
    # length and nesting depth beneath them are still checked, so these keys
    # cannot be used to smuggle operator objects past validation.
    _FREETEXT_FIELDS = frozenset({
        "captions_vtt",
        "audio_description_vtt",
        "text",
        "notes",
        "change_note",
        "description",
    })

    # Proxy/rewrite headers whose values are screened for malicious patterns.
    _SUSPICIOUS_HEADERS = frozenset({
        "x-forwarded-host",
        "x-original-host",
        "x-rewrite-url",
    })

    def __init__(self):
        # File type restrictions (MIME types as reported by magic.from_buffer)
        self.allowed_video_types = {
            "video/mp4",
            "video/quicktime",
            "video/x-msvideo",  # AVI
        }
        self.allowed_subtitle_types = {
            "text/vtt",
            "text/plain",
        }
        # Security patterns to block
        self.malicious_patterns = [
            # SQL injection patterns
            r"\b(union|select|insert|update|delete|drop|create|alter)\b\s+",
            # Script / markup injection
            r"vbscript:",  # vbscript protocol injection
            r"\b(onload|onerror|onclick)\s*=",  # HTML event handler attribute injection
            r"<\s*script[^>]*>",
            r"javascript:",
            r"data:.*base64",
            # Path traversal
            r"\.\./",
            r"\.\.\\",
            r"%2e%2e%2f",
            r"%2e%2e\\",
            # Command injection (removed $ and ; — semicolons are common in natural language)
            r"[&|`](?!\s*$)",
            r"\b(rm|wget|curl|nc|bash|sh|cmd|powershell)\b\s+",
            # MongoDB injection — NoSQL operator abuse
            r"\$where|\$expr|\$function|\$accumulator"
            r"|\$ne|\$nin|\$not"
            r"|\$gt|\$gte|\$lt|\$lte"
            r"|\$regex|\$jsonSchema|\$mod",
        ]
        self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.malicious_patterns]
        # Max file sizes (in bytes) — driven by central config (T-14)
        self.max_video_size = settings.upload_max_video_bytes
        self.max_subtitle_size = 10 * 1024 * 1024  # 10MB
        # Request size limits
        self.max_json_size = 1024 * 1024  # 1MB
        self.max_form_fields = 50
        # Maximum number of elements accepted in one JSON array.
        self.max_json_array_length = 1000
        # Maximum JSON nesting depth. Deeper payloads are rejected with a
        # ValidationError instead of surfacing as a RecursionError.
        self.max_json_depth = 64

    def validate_string_content(self, content: str, field_name: str = "input") -> None:
        """Raise if *content* matches any blocked pattern.

        Non-string values are ignored.

        Raises:
            SecurityValidationError: naming *field_name*, on the first match.
        """
        if not isinstance(content, str):
            return
        if any(pattern.search(content) for pattern in self.compiled_patterns):
            raise SecurityValidationError(
                f"Potentially malicious content detected in {field_name}"
            )

    def validate_filename(self, filename: str) -> str:
        """Validate and sanitize an uploaded filename.

        The name is URL-decoded and pattern-checked. Every character that is
        not a word character, hyphen or dot is then replaced with ``_``. A
        leading dot is replaced so the result is never a hidden file. The
        result is capped at 255 characters, keeping the extension when possible.

        Raises:
            ValidationError: if *filename* is empty.
            SecurityValidationError: if the decoded name matches a blocked pattern.
        """
        if not filename:
            raise ValidationError("Filename cannot be empty")
        # Decode URL encoding so encoded traversal sequences are caught below.
        filename = unquote(filename)
        self.validate_string_content(filename, "filename")
        # Remove dangerous characters
        safe_filename = re.sub(r'[^\w\-_\.]', '_', filename)
        # Prevent hidden files
        if safe_filename.startswith('.'):
            safe_filename = 'file_' + safe_filename[1:]
        # Limit length
        if len(safe_filename) > 255:
            name, ext = safe_filename.rsplit('.', 1) if '.' in safe_filename else (safe_filename, '')
            safe_filename = name[:250] + ('.' + ext if ext else '')
            # A very long extension could still exceed the limit; hard-cap it.
            safe_filename = safe_filename[:255]
        return safe_filename

    def validate_file_type(self, content: bytes, expected_type: str, filename: str) -> None:
        """Check that *content* is an allowed file type for *expected_type*.

        The MIME type is detected from *content* with ``magic.from_buffer``.
        If detection raises, the check falls back to the extension of
        *filename*. Values of *expected_type* other than ``"video"`` or
        ``"subtitle"`` are not checked.

        Raises:
            ValidationError: if the detected type or the fallback extension
                is not allowed.
        """
        try:
            detected_type = magic.from_buffer(content, mime=True)
        except Exception:
            # Fallback to extension-based validation.
            # NOTE(review): this fallback accepts mkv/srt, which the MIME
            # allowlists below would not — confirm that is intended.
            ext = filename.lower().split('.')[-1] if '.' in filename else ''
            video_extensions = {'mp4', 'mov', 'avi', 'mkv'}
            subtitle_extensions = {'vtt', 'srt', 'txt'}
            if expected_type == "video" and ext not in video_extensions:
                raise ValidationError(f"Invalid video file extension: {ext}") from None
            elif expected_type == "subtitle" and ext not in subtitle_extensions:
                raise ValidationError(f"Invalid subtitle file extension: {ext}") from None
            return
        # Sort the allowlists so error messages are deterministic (set
        # iteration order varies between processes).
        if expected_type == "video" and detected_type not in self.allowed_video_types:
            raise ValidationError(
                f"Invalid video file type: {detected_type}. "
                f"Allowed types: {', '.join(sorted(self.allowed_video_types))}"
            )
        elif expected_type == "subtitle" and detected_type not in self.allowed_subtitle_types:
            raise ValidationError(
                f"Invalid subtitle file type: {detected_type}. "
                f"Allowed types: {', '.join(sorted(self.allowed_subtitle_types))}"
            )

    def validate_file_size(self, size: int, file_type: str) -> None:
        """Check *size* (bytes) against the limit for *file_type*.

        Only ``"video"`` and ``"subtitle"`` have limits; other types pass.

        Raises:
            ValidationError: if *size* exceeds the applicable maximum.
        """
        if file_type == "video" and size > self.max_video_size:
            raise ValidationError(
                f"Video file too large: {size} bytes. "
                f"Maximum allowed: {self.max_video_size} bytes"
            )
        elif file_type == "subtitle" and size > self.max_subtitle_size:
            raise ValidationError(
                f"Subtitle file too large: {size} bytes. "
                f"Maximum allowed: {self.max_subtitle_size} bytes"
            )

    @staticmethod
    def _is_json_content_type(content_type: str) -> bool:
        """Return True for ``application/json`` and ``application/*+json``.

        Media types are compared case-insensitively and parameters such as
        ``charset`` are ignored. This also covers structured-syntax JSON types
        (e.g. ``application/merge-patch+json``), which FastAPI parses as JSON.
        """
        media_type = content_type.split(";", 1)[0].strip().lower()
        return media_type == "application/json" or (
            media_type.startswith("application/") and media_type.endswith("+json")
        )

    async def validate_json_payload(self, request: Request) -> Any:
        """Read, parse and screen a JSON request body.

        Returns:
            ``None`` when the request is not JSON, ``{}`` for an empty body,
            otherwise the parsed payload (object, array or scalar) after every
            key and value has been validated.

        Raises:
            ValidationError: if the Content-Length header is malformed, the
                body is too large, is not valid JSON/UTF-8, or violates the
                size/depth limits.
            SecurityValidationError: if a key or value matches a blocked pattern.
        """
        if not self._is_json_content_type(request.headers.get("content-type", "")):
            return None

        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_length = int(content_length)
            except ValueError:
                raise ValidationError(
                    f"Invalid Content-Length header: {content_length!r}"
                ) from None
            if declared_length > self.max_json_size:
                raise ValidationError(f"JSON payload too large: {content_length} bytes")

        # Reuse a body cached by an earlier read; otherwise read it and cache it
        # on the request object for later readers.
        if hasattr(request, '_cached_body'):
            body = request._cached_body
        else:
            body = await request.body()
            request._cached_body = body

        # Content-Length may be absent or wrong, so re-check the actual size.
        if len(body) > self.max_json_size:
            raise ValidationError(f"JSON payload too large: {len(body)} bytes")
        if not body:
            return {}

        try:
            payload = json.loads(body)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise ValidationError(f"Invalid JSON: {e}") from e
        except RecursionError:
            raise ValidationError("Invalid JSON: nesting too deep") from None

        # Recursively validate all keys and string values
        self._validate_json_values(payload)
        return payload

    def _validate_json_values(
        self,
        obj: Any,
        path: str = "root",
        *,
        depth: int = 0,
        in_freetext: bool = False,
    ) -> None:
        """Recursively validate a parsed JSON value.

        Every dict key is pattern-checked. String values are pattern-checked
        unless they sit beneath one of ``_FREETEXT_FIELDS``. Object size, array
        length and nesting depth are enforced everywhere, including beneath
        free-text keys. So ``{"notes": {"$ne": null}}`` is still rejected.

        Args:
            obj: Parsed JSON value.
            path: Dotted location of *obj*, used in error messages.
            depth: Current nesting depth (internal).
            in_freetext: True when *obj* is nested under a free-text key (internal).
        """
        if isinstance(obj, (dict, list)) and depth > self.max_json_depth:
            raise ValidationError(f"JSON nested too deeply at {path}")
        if isinstance(obj, dict):
            if len(obj) > self.max_form_fields:
                raise ValidationError(f"Too many fields in object at {path}")
            for key, value in obj.items():
                self.validate_string_content(key, f"{path}.key")
                self._validate_json_values(
                    value,
                    f"{path}.{key}",
                    depth=depth + 1,
                    in_freetext=in_freetext or key in self._FREETEXT_FIELDS,
                )
        elif isinstance(obj, list):
            if len(obj) > self.max_json_array_length:  # Prevent large arrays
                raise ValidationError(f"Array too large at {path}")
            for i, item in enumerate(obj):
                self._validate_json_values(
                    item, f"{path}[{i}]", depth=depth + 1, in_freetext=in_freetext
                )
        elif isinstance(obj, str) and not in_freetext:
            self.validate_string_content(obj, path)

    def validate_query_params(self, request: Request) -> None:
        """Pattern-check every query parameter name and value.

        Uses ``multi_items()`` so repeated keys (``?q=a&q=b``) have every value
        checked; ``items()`` would only yield the last value per key.
        """
        for key, value in request.query_params.multi_items():
            self.validate_string_content(key, f"query.{key}")
            self.validate_string_content(str(value), f"query.{key}")

    def validate_headers(self, request: Request) -> None:
        """Screen proxy/rewrite header values and cap the User-Agent length.

        Raises:
            SecurityValidationError: if a suspicious header matches a blocked
                pattern or the User-Agent exceeds 500 characters.
        """
        for header_name, header_value in request.headers.items():
            lowered = header_name.lower()
            if lowered in self._SUSPICIOUS_HEADERS:
                self.validate_string_content(header_value, f"header.{header_name}")
            if lowered == "user-agent" and len(header_value) > 500:
                raise SecurityValidationError("User-Agent header too long")
class ValidationMiddleware:
    """FastAPI middleware for enhanced request validation.

    Runs RequestValidator's header, query-string and (for POST/PUT/PATCH)
    JSON-body checks before forwarding the request. Each outcome is recorded
    with ``track_validation_metrics``.
    """

    def __init__(self):
        self.validator = RequestValidator()

    async def _run_checks(self, request: Request) -> None:
        """Apply all request-level checks; raises on the first failure."""
        self.validator.validate_headers(request)
        self.validator.validate_query_params(request)
        # Only methods that normally carry a body are JSON-checked.
        if request.method in ["POST", "PUT", "PATCH"]:
            await self.validator.validate_json_payload(request)

    @staticmethod
    def _track(request: Request, start_time: float, is_valid: bool, error_types: list[str]) -> None:
        """Record one validation outcome for *request*."""
        track_validation_metrics(
            endpoint=request.url.path,
            method=request.method,
            is_valid=is_valid,
            validation_time=time.time() - start_time,
            error_types=error_types,
        )

    async def __call__(self, request: Request, call_next):
        """Validate *request*, then forward it to *call_next*.

        Returns:
            A 400 JSON response for ``SecurityValidationError``, a 422 JSON
            response for ``ValidationError``, otherwise the downstream response.
            Unexpected errors raised by the checks themselves are logged and
            the request is forwarded anyway (fail-open).
        """
        start_time = time.time()

        # Skip validation for timing adjustment endpoint temporarily
        if "/vtt/adjust-timing" in request.url.path:
            return await call_next(request)

        # Becomes True once the request is handed downstream, so the fail-open
        # branch never dispatches the same request a second time.
        dispatched = False
        try:
            await self._run_checks(request)
            dispatched = True
            response = await call_next(request)
            # NOTE(review): tracked after the handler returns, so on the success
            # path validation_time also includes downstream handling time.
            self._track(request, start_time, True, [])
            return response
        # The two handlers below also apply if the downstream handler raises
        # these exception types (behavior kept from the original middleware).
        except SecurityValidationError:
            self._track(request, start_time, False, ["security"])
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "detail": "Security validation failed",
                    "error_code": "SECURITY_VALIDATION_ERROR",
                },
            )
        except ValidationError as e:
            self._track(request, start_time, False, ["format"])
            return JSONResponse(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                content={
                    "detail": str(e),
                    "error_code": "VALIDATION_ERROR",
                },
            )
        except Exception as e:
            if dispatched:
                # The handler (or metrics tracking after it) failed. Calling
                # call_next again would re-run the endpoint and repeat its side
                # effects, so propagate instead.
                raise
            self._track(request, start_time, False, ["unknown"])
            # Log unexpected validator error but continue processing (fail-open)
            logging.getLogger(__name__).exception("Validation middleware error: %s", e)
            return await call_next(request)
async def create_validation_middleware() -> ValidationMiddleware:
    """Build a fresh ValidationMiddleware (with its own RequestValidator).

    Declared as a coroutine, so callers must ``await`` it.
    """
    middleware = ValidationMiddleware()
    return middleware