246 lines
8.2 KiB
Python
Executable file
246 lines
8.2 KiB
Python
Executable file
from __future__ import annotations

from collections.abc import AsyncIterator
from collections.abc import Awaitable
from typing import Any
from typing import Callable
from typing import cast
from typing import IO
from typing import NoReturn
from typing import Optional
from typing import TYPE_CHECKING
from urllib.parse import parse_qsl

from werkzeug.datastructures import Headers
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.formparser import default_stream_factory
from werkzeug.http import parse_options_header
from werkzeug.sansio.multipart import Data
from werkzeug.sansio.multipart import Epilogue
from werkzeug.sansio.multipart import Field
from werkzeug.sansio.multipart import File
from werkzeug.sansio.multipart import MultipartDecoder
from werkzeug.sansio.multipart import NeedData

from .datastructures import FileStorage
|
|
|
|
if TYPE_CHECKING:
|
|
from .wrappers.request import Body
|
|
|
|
# Callable that returns a writable binary stream for one uploaded file.
# MultiPartParser.start_file_streaming calls it positionally as
# (total_content_length, content_type, filename, content_length).
StreamFactory = Callable[
    [
        Optional[int],  # total_content_length
        Optional[str],  # content_type
        Optional[str],  # filename
        Optional[int],  # content_length
    ],
    IO[bytes],
]
|
|
|
|
# Signature of an entry in FormDataParser.parse_functions. Entries are
# stored unbound, so FormDataParser.parse passes the parser explicitly:
# (parser, body, mimetype, content_length, options) -> awaitable (form, files).
ParserFunc = Callable[
    [
        "FormDataParser",  # parser instance, passed explicitly as ``self``
        "Body",  # request body
        str,  # mimetype
        Optional[int],  # content_length
        dict[str, str],  # mimetype options
    ],
    Awaitable[tuple[MultiDict, MultiDict]],
]
|
|
|
|
|
|
class FormDataParser:
    """Dispatch a request body to a parser selected by its mimetype.

    Every parser returns a ``(form, files)`` pair of ``cls`` instances. A
    mimetype without a registered parser yields two empty instances.
    """

    file_storage_class = FileStorage

    def __init__(
        self,
        *,
        cls: type[MultiDict] | None = MultiDict,
        max_content_length: int | None = None,
        max_form_memory_size: int | None = None,
        max_form_parts: int | None = None,
        silent: bool = True,
        stream_factory: StreamFactory = default_stream_factory,
    ) -> None:
        # The annotation admits None; fall back to MultiDict so the
        # ``self.cls()`` calls below are always valid.
        self.cls: type[MultiDict] = MultiDict if cls is None else cls
        self.max_content_length = max_content_length
        self.max_form_memory_size = max_form_memory_size
        self.max_form_parts = max_form_parts
        # When true, a ValueError from a parse function is swallowed by
        # ``parse`` and empty results are returned instead.
        self.silent = silent
        self.stream_factory = stream_factory

    def get_parse_func(
        self, mimetype: str, options: dict[str, str]
    ) -> ParserFunc | None:
        """Return the parse function registered for *mimetype*, or ``None``.

        *options* is not consulted by this implementation.
        """
        return self.parse_functions.get(mimetype)

    async def parse(
        self,
        body: Body,
        mimetype: str,
        content_length: int | None,
        options: dict[str, str] | None = None,
    ) -> tuple[MultiDict, MultiDict]:
        """Parse *body* according to *mimetype*.

        :param body: The request body to parse.
        :param mimetype: The body's mimetype, used to select a parser.
        :param content_length: The declared body length, if known.
        :param options: Mimetype parameters (e.g. the multipart boundary).
        :returns: A ``(form, files)`` pair. Two empty ``cls`` instances are
            returned when no parser matches *mimetype*, or when the parser
            raises ``ValueError`` and ``silent`` is true.
        :raises ValueError: From the parser, when ``silent`` is false. Only
            ``ValueError`` is handled here; other exceptions propagate.
        """
        if options is None:
            options = {}

        parse_func = self.get_parse_func(mimetype, options)

        if parse_func is not None:
            try:
                return await parse_func(self, body, mimetype, content_length, options)
            except ValueError:
                if not self.silent:
                    raise

        return self.cls(), self.cls()

    async def _parse_multipart(
        self,
        body: Body,
        mimetype: str,
        content_length: int | None,
        options: dict[str, str],
    ) -> tuple[MultiDict, MultiDict]:
        """Parse a ``multipart/form-data`` body.

        :raises ValueError: If the ``boundary`` option is missing/empty or
            is not ASCII.
        """
        # Validate the boundary before building the parser; a non-ASCII
        # boundary raises UnicodeEncodeError, which is a ValueError.
        boundary = options.get("boundary", "").encode("ascii")

        if not boundary:
            raise ValueError("Missing boundary")

        parser = MultiPartParser(
            cls=self.cls,
            file_storage_cls=self.file_storage_class,
            max_content_length=self.max_content_length,
            max_form_memory_size=self.max_form_memory_size,
            max_form_parts=self.max_form_parts,
            stream_factory=self.stream_factory,
        )
        return await parser.parse(body, boundary, content_length)

    async def _parse_urlencoded(
        self,
        body: Body,
        mimetype: str,
        content_length: int | None,
        options: dict[str, str],
    ) -> tuple[MultiDict, MultiDict]:
        """Parse an ``application/x-www-form-urlencoded`` body.

        The whole body is held in memory, so it is subject to
        ``max_form_memory_size``.

        :raises RequestEntityTooLarge: If the body exceeds
            ``max_form_memory_size`` or has more than ``max_form_parts``
            fields.
        :raises ValueError: If the body is not valid UTF-8.
        """
        limit = self.max_form_memory_size

        # Reject early on the declared length to avoid reading the body.
        if limit is not None and content_length is not None and content_length > limit:
            raise RequestEntityTooLarge()

        raw = await body

        # Re-check the actual size: the declared length may be absent.
        if limit is not None and len(raw) > limit:
            raise RequestEntityTooLarge()

        # Decode outside the try below: UnicodeDecodeError is a ValueError
        # and must not be misreported as "entity too large". Leaving it to
        # propagate lets ``parse`` apply its silent/raise policy.
        text = raw.decode()

        try:
            form = parse_qsl(
                text,
                keep_blank_values=True,
                max_num_fields=self.max_form_parts,
            )
        except ValueError:
            # With strict_parsing off, parse_qsl raises ValueError when the
            # number of fields exceeds max_num_fields.
            raise RequestEntityTooLarge() from None

        return self.cls(form), self.cls()

    # Mimetype -> unbound parse method; ``parse`` passes ``self`` explicitly.
    parse_functions: dict[str, ParserFunc] = {
        "multipart/form-data": _parse_multipart,
        "application/x-www-form-urlencoded": _parse_urlencoded,
        "application/x-url-encoded": _parse_urlencoded,
    }
|
|
|
|
|
|
class MultiPartParser:
    """Decode a ``multipart/form-data`` body into ``(fields, files)``.

    Non-file fields are buffered in memory and decoded to ``str``. File parts
    are written into containers created by ``stream_factory`` and wrapped in
    ``file_storage_cls``.
    """

    def __init__(
        self,
        *,
        buffer_size: int = 64 * 1024,
        cls: type[MultiDict] = MultiDict,
        file_storage_cls: type[FileStorage] = FileStorage,
        max_content_length: int | None = None,
        max_form_memory_size: int | None = None,
        max_form_parts: int | None = None,
        stream_factory: StreamFactory = default_stream_factory,
    ) -> None:
        # NOTE(review): buffer_size is stored but not read by this class;
        # chunk sizes are dictated by the body iterator.
        self.buffer_size = buffer_size
        self.cls = cls
        self.file_storage_cls = file_storage_cls
        self.max_content_length = max_content_length
        self.max_form_memory_size = max_form_memory_size
        self.max_form_parts = max_form_parts
        self.stream_factory = stream_factory

    def fail(self, message: str) -> NoReturn:
        """Raise ``ValueError`` with *message*."""
        raise ValueError(message)

    def get_part_charset(self, headers: Headers) -> str:
        """Return the charset used to decode a non-file field.

        Only the ``charset`` parameter of the part's ``Content-Type`` is
        honoured, and only if it is in a small allow-list; otherwise UTF-8.
        """
        content_type = headers.get("content-type")

        if content_type:
            parameters = parse_options_header(content_type)[1]
            ct_charset = parameters.get("charset", "").lower()

            # A safe list of encodings. Modern clients should only send ASCII or UTF-8.
            # This list will not be extended further.
            if ct_charset in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}:
                return ct_charset

        return "utf-8"

    def start_file_streaming(
        self, event: File, total_content_length: int | None
    ) -> IO[bytes]:
        """Create the container that a file part's data is written into.

        The part's own ``Content-Length`` header is forwarded to
        ``stream_factory`` when it parses as an int, otherwise 0 is passed.
        """
        content_type = event.headers.get("content-type")

        try:
            content_length = int(event.headers["content-length"])
        except (KeyError, ValueError):
            content_length = 0

        return self.stream_factory(
            total_content_length,
            content_type,
            event.filename,
            content_length,
        )

    @staticmethod
    async def _iter_chunks(body: Body) -> AsyncIterator[bytes | None]:
        """Yield every chunk of *body*, then ``None`` to mark end of input."""
        async for chunk in body:
            yield chunk
        # MultipartDecoder.receive_data(None) marks the input as complete,
        # letting the decoder emit Epilogue for a well-formed body or raise
        # ValueError for a body that ends before its closing boundary.
        yield None

    async def parse(
        self, body: Body, boundary: bytes, content_length: int | None
    ) -> tuple[MultiDict, MultiDict]:
        """Decode *body* using *boundary*.

        :param body: The request body, iterated asynchronously in chunks.
        :param boundary: The multipart boundary (without leading dashes).
        :param content_length: The total declared body length, if known;
            forwarded to ``stream_factory`` for file parts.
        :returns: ``(fields, files)`` as two ``cls`` instances.
        :raises RequestEntityTooLarge: If a non-file field's accumulated data
            exceeds ``max_form_memory_size``.
        :raises ValueError: If the decoder rejects the data, including a body
            that ends before the closing boundary. The decoder may also
            raise for other limits (e.g. ``max_parts``).
        """
        container: IO[bytes] | list[bytes]
        _write: Callable[[bytes], Any]

        # NOTE(review): werkzeug names this positional parameter
        # ``max_form_memory_size`` (a cap on the decoder's internal buffer);
        # max_content_length is passed here, presumably because ASGI chunks
        # may be large. Confirm this is intentional.
        parser = MultipartDecoder(
            boundary, self.max_content_length, max_parts=self.max_form_parts
        )

        fields: list[tuple[str, str]] = []
        files: list[tuple[str, FileStorage]] = []

        current_part: Field | File
        # Bytes accumulated for the current non-file field; None for files,
        # which are not subject to max_form_memory_size here.
        field_size: int | None = None

        async for data in self._iter_chunks(body):
            parser.receive_data(data)
            event = parser.next_event()

            while not isinstance(event, (Epilogue, NeedData)):
                if isinstance(event, Field):
                    current_part = event
                    field_size = 0
                    container = []
                    _write = container.append
                elif isinstance(event, File):
                    current_part = event
                    field_size = None
                    container = self.start_file_streaming(event, content_length)
                    _write = container.write
                elif isinstance(event, Data):
                    if self.max_form_memory_size is not None and field_size is not None:
                        field_size += len(event.data)

                        if field_size > self.max_form_memory_size:
                            raise RequestEntityTooLarge()

                    _write(event.data)

                    # The last Data event of a part finalises it.
                    if not event.more_data:
                        if isinstance(current_part, Field):
                            value = b"".join(container).decode(
                                self.get_part_charset(current_part.headers), "replace"
                            )
                            fields.append((current_part.name, value))
                        else:
                            container = cast(IO[bytes], container)
                            # Rewind so consumers read the file from the start.
                            container.seek(0)
                            files.append(
                                (
                                    current_part.name,
                                    self.file_storage_cls(
                                        container,
                                        current_part.filename,
                                        current_part.name,
                                        headers=current_part.headers,
                                    ),
                                )
                            )

                event = parser.next_event()

        return self.cls(fields), self.cls(files)
|