brief-extractor/backend/venv/lib/python3.10/site-packages/quart/wrappers/request.py
2026-03-06 18:42:46 +00:00

414 lines
13 KiB
Python
Executable file

from __future__ import annotations
import asyncio
from collections.abc import Awaitable
from collections.abc import Generator
from typing import Any
from typing import Callable
from typing import Literal
from typing import NoReturn
from typing import overload
from hypercorn.typing import HTTPScope
from werkzeug.datastructures import CombinedMultiDict
from werkzeug.datastructures import Headers
from werkzeug.datastructures import iter_multi_items
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import RequestEntityTooLarge
from werkzeug.exceptions import RequestTimeout
from ..formparser import FormDataParser
from ..globals import current_app
from .base import BaseRequestWebsocket
# Request headers that Request.send_push_promise copies onto an HTTP/2
# server push, so the pushed resource is requested with the same
# content-negotiation, caching and client-identification headers as the
# originating request.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept", "accept-encoding", "accept-language",  # content negotiation
    "cache-control",
    "user-agent",
}
class Body:
    """A request body container.

    The request body can either be iterated over and consumed in parts
    (without building up memory usage) or awaited.

    .. code-block:: python

        async for data in body:
            ...

        # or simply
        complete = await body

    Note: It is not possible to iterate over the data and then await
    it (iteration drains the internal buffer, so a later await only
    sees whatever has not yet been consumed).

    The ``max_content_length`` limit applies to the *total* number of
    bytes received, so it is enforced for iterated (streamed) bodies as
    well as awaited ones.
    """

    def __init__(
        self, expected_content_length: int | None, max_content_length: int | None
    ) -> None:
        """Create an empty body.

        Arguments:
            expected_content_length: The declared length of the body, if
                known; if it already exceeds ``max_content_length`` the
                body is flagged as too large up front.
            max_content_length: Maximum total number of bytes accepted,
                or None for no limit.
        """
        self._data = bytearray()
        self._complete: asyncio.Event = asyncio.Event()
        self._has_data: asyncio.Event = asyncio.Event()
        self._max_content_length = max_content_length
        # Running total of bytes appended. Kept separately from
        # ``len(self._data)`` because iteration clears ``_data`` after
        # each chunk, which would otherwise let a streamed body grow past
        # the limit.
        self._received = 0
        # Exceptions must be raised within application (not ASGI)
        # calls, this is achieved by having the ASGI methods set this
        # to an exception on error.
        self._must_raise: Exception | None = None
        if (
            expected_content_length is not None
            and max_content_length is not None
            and expected_content_length > max_content_length
        ):
            self._must_raise = RequestEntityTooLarge()

    def __aiter__(self) -> Body:
        return self

    async def __anext__(self) -> bytes:
        """Return the data received since the previous call.

        Raises:
            StopAsyncIteration: Once the body is complete and drained.
            Exception: Whatever error was recorded (e.g.
                RequestEntityTooLarge) before or while waiting.
        """
        if self._must_raise is not None:
            raise self._must_raise

        # if we got all of the data in the first shot, then self._complete is
        # set and self._has_data will not get set again, so skip the await
        # if we already have completed everything
        if not self._complete.is_set():
            await self._has_data.wait()

        # Re-check after waiting (as __await__ does): an error may have
        # been recorded while this coroutine was suspended.
        if self._must_raise is not None:
            raise self._must_raise

        if self._complete.is_set() and len(self._data) == 0:
            raise StopAsyncIteration()

        data = bytes(self._data)
        self._data.clear()
        self._has_data.clear()
        return data

    def __await__(self) -> Generator[Any, None, Any]:
        # Must check the _must_raise before and after waiting on the
        # completion event as it may change whilst waiting and the
        # event may not be set if there is already an issue.
        if self._must_raise is not None:
            raise self._must_raise
        yield from self._complete.wait().__await__()
        if self._must_raise is not None:
            raise self._must_raise
        return bytes(self._data)

    def append(self, data: bytes) -> None:
        """Add a chunk of received body data.

        Empty chunks, and anything arriving after an error has been
        recorded, are ignored. If the running total exceeds the maximum
        content length, RequestEntityTooLarge is recorded for the
        consumer and the body is marked complete.
        """
        if data == b"" or self._must_raise is not None:
            return
        self._received += len(data)
        if (
            self._max_content_length is not None
            and self._received > self._max_content_length
        ):
            # Don't buffer the overflowing chunk; just wake any waiter so
            # it observes the error.
            self._must_raise = RequestEntityTooLarge()
            self.set_complete()
            return
        self._data.extend(data)
        self._has_data.set()

    def set_complete(self) -> None:
        """Mark the body as fully received and wake any waiting consumer."""
        self._complete.set()
        self._has_data.set()

    def set_result(self, data: bytes) -> None:
        """Convenience method, mainly for testing."""
        self.append(data)
        self.set_complete()

    def clear(self) -> None:
        """Drop any buffered data (the received-byte total is unchanged)."""
        self._data.clear()
class Request(BaseRequestWebsocket):
    """This class represents a request.

    It can be subclassed and the subclassed used in preference by
    replacing the :attr:`~quart.Quart.request_class` with your
    subclass.

    Attributes:
        body_class: The class to store the body data within.
        form_data_parser_class: Can be overridden to implement a
            different form data parsing.
        lock_class: The lock type used to serialise form parsing.
    """

    body_class = Body
    form_data_parser_class = FormDataParser
    lock_class = asyncio.Lock
    _max_content_length: int | None = None
    _max_form_memory_size: int | None = None
    _max_form_parts: int | None = None

    def __init__(
        self,
        method: str,
        scheme: str,
        path: str,
        query_string: bytes,
        headers: Headers,
        root_path: str,
        http_version: str,
        scope: HTTPScope,
        *,
        max_content_length: int | None = None,
        body_timeout: int | None = None,
        send_push_promise: Callable[[str, Headers], Awaitable[None]],
    ) -> None:
        """Create a request object.

        The request body is created here as an instance of
        :attr:`body_class`, sized by the Content-Length header.

        Arguments:
            method: The HTTP verb.
            scheme: The scheme used for the request.
            path: The full unquoted path of the request.
            query_string: The raw bytes for the query string part.
            headers: The request headers.
            root_path: The root path that should be prepended to all
                routes.
            http_version: The HTTP version of the request.
            scope: Underlying ASGI scope dictionary.
            max_content_length: The maximum length in bytes of the
                body (None implies no limit in Quart).
            body_timeout: The maximum time (seconds) to wait for the
                body before timing out.
            send_push_promise: An awaitable to send a push promise based
                off of this request (HTTP/2 feature).
        """
        super().__init__(
            method, scheme, path, query_string, headers, root_path, http_version, scope
        )
        self.body_timeout = body_timeout
        self.body = self.body_class(self.content_length, max_content_length)
        # Parsed JSON cached per ``silent`` flag; Ellipsis means "not parsed".
        self._cached_json: dict[bool, Any] = {False: Ellipsis, True: Ellipsis}
        self._form: MultiDict | None = None
        self._files: MultiDict | None = None
        self._parsing_lock = self.lock_class()
        self._send_push_promise = send_push_promise

    @property
    def max_content_length(self) -> int | None:
        """Explicit override if set, else the app's MAX_CONTENT_LENGTH."""
        if self._max_content_length is not None:
            return self._max_content_length
        if current_app:
            return current_app.config["MAX_CONTENT_LENGTH"]
        return None

    @max_content_length.setter
    def max_content_length(self, value: int | None) -> None:
        self._max_content_length = value

    @property
    def max_form_memory_size(self) -> int | None:
        """Explicit override if set, else the app's MAX_FORM_MEMORY_SIZE."""
        if self._max_form_memory_size is not None:
            return self._max_form_memory_size
        if current_app:
            return current_app.config["MAX_FORM_MEMORY_SIZE"]
        return None

    @max_form_memory_size.setter
    def max_form_memory_size(self, value: int | None) -> None:
        self._max_form_memory_size = value

    @property
    def max_form_parts(self) -> int | None:
        """Explicit override if set, else the app's MAX_FORM_PARTS."""
        if self._max_form_parts is not None:
            return self._max_form_parts
        if current_app:
            return current_app.config["MAX_FORM_PARTS"]
        return None

    @max_form_parts.setter
    def max_form_parts(self, value: int | None) -> None:
        self._max_form_parts = value

    @property
    async def stream(self) -> NoReturn:
        """Not supported in Quart; use :attr:`body` instead."""
        raise NotImplementedError("Use body instead")

    @property
    async def data(self) -> bytes:
        """The body bytes left over after any form data has been parsed."""
        return await self.get_data(as_text=False, parse_form_data=True)  # type: ignore

    @overload
    async def get_data(
        self, cache: bool, as_text: Literal[False], parse_form_data: bool
    ) -> bytes: ...

    @overload
    async def get_data(
        self, cache: bool, as_text: Literal[True], parse_form_data: bool
    ) -> str: ...

    @overload
    async def get_data(
        self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False
    ) -> str | bytes: ...

    async def get_data(
        self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False
    ) -> str | bytes:
        """Get the request body data.

        Arguments:
            cache: If False the body data will be cleared, so any
                subsequent calls return empty data, reducing memory
                usage.
            as_text: If True the data is returned as a decoded (UTF-8)
                string, otherwise raw bytes are returned.
            parse_form_data: Parse the data as form data first, return any
                remaining data.

        Raises:
            RequestTimeout: If the body is not received within
                :attr:`body_timeout` seconds.
        """
        if parse_form_data:
            await self._load_form_data()

        try:
            raw_data = await asyncio.wait_for(self.body, timeout=self.body_timeout)
        except asyncio.TimeoutError as e:
            raise RequestTimeout() from e
        else:
            if not cache:
                self.body.clear()
            if as_text:
                return raw_data.decode()
            else:
                return raw_data

    @property
    async def values(self) -> CombinedMultiDict:
        """The query-string args combined with the form (non-GET only)."""
        sources = [self.args]
        if self.method != "GET":
            # Whilst GET requests are allowed to have a body, most
            # implementations do not allow this hence this
            # inconsistency may result in confusing values.
            form = await self.form
            sources.append(form)

        multidict_sources: list[MultiDict] = []
        for source in sources:
            if not isinstance(source, MultiDict):
                multidict_sources.append(MultiDict(source))
            else:
                multidict_sources.append(source)
        return CombinedMultiDict(multidict_sources)

    @property
    async def form(self) -> MultiDict:
        """The parsed form encoded data.

        Note file data is present in the :attr:`files`.
        """
        await self._load_form_data()
        return self._form

    @property
    async def files(self) -> MultiDict:
        """The parsed files.

        This will return an empty multidict unless the request
        mimetype was ``enctype="multipart/form-data"`` and the method
        POST, PUT, or PATCH.
        """
        await self._load_form_data()
        return self._files

    def make_form_data_parser(self) -> FormDataParser:
        """Build a form parser configured with this request's limits."""
        return self.form_data_parser_class(
            max_content_length=self.max_content_length,
            max_form_memory_size=self.max_form_memory_size,
            max_form_parts=self.max_form_parts,
            cls=self.parameter_storage_class,
        )

    async def _load_form_data(self) -> None:
        """Parse the body into ``_form``/``_files`` once.

        The lock ensures concurrent callers do not parse the body twice.

        Raises:
            RequestTimeout: If parsing exceeds :attr:`body_timeout`.
        """
        async with self._parsing_lock:
            if self._form is None:
                parser = self.make_form_data_parser()
                try:
                    self._form, self._files = await asyncio.wait_for(
                        parser.parse(
                            self.body,
                            self.mimetype,
                            self.content_length,
                            self.mimetype_params,
                        ),
                        timeout=self.body_timeout,
                    )
                except asyncio.TimeoutError as e:
                    raise RequestTimeout() from e

    @property
    async def json(self) -> Any:
        """The parsed JSON body (see :meth:`get_json`)."""
        return await self.get_json()

    async def get_json(
        self, force: bool = False, silent: bool = False, cache: bool = True
    ) -> Any:
        """Parses the body data as JSON and returns it.

        Arguments:
            force: Force JSON parsing even if the mimetype is not JSON.
            silent: Do not trigger error handling if parsing fails, without
                this the :meth:`on_json_loading_failed` will be called on
                error.
            cache: Cache the parsed JSON on this request object.

        A body that is not valid UTF-8 is treated as a parsing failure
        (handled according to ``silent``) rather than escaping as an
        unhandled UnicodeDecodeError.
        """
        if cache and self._cached_json[silent] is not Ellipsis:
            return self._cached_json[silent]

        if not (force or self.is_json):
            return None

        # Decode inside the ``try`` so a UnicodeDecodeError (a ValueError
        # subclass) follows the same failure path as malformed JSON.
        raw_data = await self.get_data(cache=cache, as_text=False)
        try:
            result = self.json_module.loads(raw_data.decode())
        except ValueError as error:
            if silent:
                result = None
            else:
                result = self.on_json_loading_failed(error)

        if cache:
            self._cached_json[silent] = result
        return result

    def on_json_loading_failed(self, error: Exception) -> Any:
        """Handle a JSON parsing error.

        Arguments:
            error: The exception raised during parsing.

        Returns:
            Any value returned (if overridden) will be used as the
            default for any get_json calls.

        Raises:
            BadRequest: Always, in this default implementation; the
                error detail is included only in debug mode.
        """
        if current_app and current_app.debug:
            raise BadRequest(f"Failed to decode JSON: {error}") from error
        raise BadRequest() from error

    async def send_push_promise(self, path: str) -> None:
        """Push ``path`` to the client, copying selected request headers."""
        headers = Headers()
        for name in SERVER_PUSH_HEADERS_TO_COPY:
            for value in self.headers.getlist(name):
                headers.add(name, value)
        await self._send_push_promise(path, headers)

    async def close(self) -> None:
        """Close any uploaded files parsed from this request."""
        for _key, value in iter_multi_items(self._files or ()):
            value.close()