457 lines
14 KiB
Python
Executable file
457 lines
14 KiB
Python
Executable file
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
from http.cookiejar import CookieJar
|
|
from types import TracebackType
|
|
from typing import Any
|
|
from typing import AnyStr
|
|
from typing import TYPE_CHECKING
|
|
from urllib.request import Request as U2Request
|
|
|
|
from werkzeug.datastructures import Authorization
|
|
from werkzeug.datastructures import Headers
|
|
from werkzeug.http import dump_cookie
|
|
|
|
from ..datastructures import FileStorage
|
|
from ..globals import _cv_request
|
|
from ..sessions import SessionMixin
|
|
from ..typing import TestHTTPConnectionProtocol
|
|
from ..typing import TestWebsocketConnectionProtocol
|
|
from ..wrappers import Response
|
|
from .connections import TestHTTPConnection
|
|
from .connections import TestWebsocketConnection
|
|
from .utils import make_test_body_with_headers
|
|
from .utils import make_test_headers_path_and_query_string
|
|
from .utils import make_test_scope
|
|
from .utils import sentinel
|
|
|
|
if TYPE_CHECKING:
    # Needed only for the ``app: Quart`` annotation below, so it is kept out
    # of the runtime import path (presumably to avoid a circular import
    # with the app module -- confirm before moving it).
    from ..app import Quart  # noqa
|
|
|
|
|
|
class _TestWrapper:
    """Header adapter exposing a ``get_all`` lookup over ``(key, value)`` pairs.

    Instances are handed out by ``_TestCookieJarResponse.info()``.
    """

    def __init__(self, headers: Headers) -> None:
        # Iterating ``headers`` must yield ``(key, value)`` pairs.
        self.headers = headers

    def get_all(self, name: str, default: Any | None = None) -> list[str]:
        """Return every header value stored under *name*.

        Header names are compared case-insensitively.

        Arguments:
            name: The header name to look up.
            default: Returned when no header matches; if it is falsy
                as well, an empty list is returned instead.
        """
        wanted = name.lower()
        matches = [value for key, value in self.headers if key.lower() == wanted]
        return matches or default or []
|
|
|
|
|
|
class _TestCookieJarResponse:
    """Response stand-in passed to ``CookieJar.extract_cookies``.

    It carries only the response headers and exposes them through
    ``info()``, wrapped in a ``_TestWrapper``.
    """

    def __init__(self, headers: Headers) -> None:
        self.headers = headers

    def info(self) -> _TestWrapper:
        """Return the stored headers wrapped in a ``_TestWrapper``."""
        return _TestWrapper(self.headers)
|
|
|
|
|
|
class QuartClient:
    """Test client that drives a Quart app through in-process connections.

    Complete HTTP requests are made with :meth:`open` (or the verb
    shortcuts such as :meth:`get` and :meth:`post`). :meth:`request`
    and :meth:`websocket` instead return a connection object without
    sending anything. When created with ``use_cookies=True`` a
    ``CookieJar`` collects cookies from responses and replays them on
    later requests.

    Using the client as an async context manager sets
    ``preserve_context`` for the duration of the block.
    """

    http_connection_class: type[TestHTTPConnectionProtocol]
    websocket_connection_class: type[TestWebsocketConnectionProtocol]

    http_connection_class = TestHTTPConnection
    websocket_connection_class = TestWebsocketConnection

    def __init__(self, app: Quart, use_cookies: bool = True) -> None:
        """Create a client for *app*.

        Arguments:
            app: The application under test.
            use_cookies: When true a ``CookieJar`` is created,
                otherwise ``cookie_jar`` is ``None`` and no cookies
                are stored or sent.
        """
        self.app = app
        self.cookie_jar: CookieJar | None
        if use_cookies:
            self.cookie_jar = CookieJar()
        else:
            self.cookie_jar = None
        self.preserve_context = False
        self.push_promises: list[tuple[str, Headers]] = []

    async def open(
        self,
        path: str,
        *,
        method: str = "GET",
        headers: dict | Headers | None = None,
        data: AnyStr | None = None,
        form: dict | None = None,
        files: dict[str, FileStorage] | None = None,
        query_string: dict | None = None,
        json: Any = sentinel,
        scheme: str = "http",
        follow_redirects: bool = False,
        root_path: str = "",
        http_version: str = "1.1",
        scope_base: dict | None = None,
        auth: Authorization | tuple[str, str] | None = None,
        subdomain: str | None = None,
    ) -> Response:
        """Make a request to the app and return its response.

        ``push_promises`` is reset at the start of every call and then
        filled with the promises of each request made.

        Arguments:
            path: The request path.
            method: The HTTP verb to use.
            headers: Headers to include in the request.
            data: Raw data for the request body.
            form: Form data for the request body.
            files: Files to include in the request body.
            query_string: Values to send in the query string.
            json: Data to send as a JSON body; left as ``sentinel``
                when no JSON is supplied.
            scheme: The URL scheme, also used for the cookie URL.
            follow_redirects: When true, 3xx responses carrying a
                ``Location`` are followed until another response is
                received.
            root_path: Passed to ``make_test_scope``.
            http_version: Passed to ``make_test_scope``.
            scope_base: Passed to ``make_test_scope``.
            auth: Passed to ``make_test_headers_path_and_query_string``.
            subdomain: Passed to
                ``make_test_headers_path_and_query_string``.
        """
        self.push_promises = []
        response = await self._make_request(
            path,
            method,
            headers,
            data,
            form,
            files,
            query_string,
            json,
            scheme,
            root_path,
            http_version,
            scope_base,
            auth,
            subdomain,
        )
        if follow_redirects:
            # A 3xx response without a Location header (e.g. a 304) has
            # nowhere to redirect to, so it is returned as the result
            # rather than requesting a ``None`` path.
            while (
                300 <= response.status_code <= 399
                and response.location is not None
            ):
                # Most browsers respond to an HTTP 302 with a GET request to the
                # new location, despite what the HTTP spec says. HTTP 303 should
                # always be responded to with a GET request.
                if response.status_code in (302, 303):
                    method = "GET"
                response = await self._make_request(
                    response.location,
                    method,
                    headers,
                    data,
                    form,
                    files,
                    query_string,
                    json,
                    scheme,
                    root_path,
                    http_version,
                    scope_base,
                    auth,
                    subdomain,
                )
        if self.preserve_context:
            _cv_request.set(self.app._preserved_context)  # type: ignore
        return response

    def request(
        self,
        path: str,
        *,
        method: str = "GET",
        headers: dict | Headers | None = None,
        query_string: dict | None = None,
        scheme: str = "http",
        root_path: str = "",
        http_version: str = "1.1",
        scope_base: dict | None = None,
        auth: Authorization | tuple[str, str] | None = None,
        subdomain: str | None = None,
    ) -> TestHTTPConnectionProtocol:
        """Build an HTTP connection to the app without sending a body.

        The returned ``http_connection_class`` instance is created
        from a scope built with the given arguments, with any stored
        cookies added to the headers. Nothing is sent by this method.
        """
        headers, path, query_string_bytes = make_test_headers_path_and_query_string(
            self.app,
            path,
            headers,
            query_string,
            auth,
            subdomain,
        )
        self._add_cookie_headers(headers)
        scope = make_test_scope(
            "http",
            path,
            method,
            headers,
            query_string_bytes,
            scheme,
            root_path,
            http_version,
            scope_base,
            _preserve_context=self.preserve_context,
        )
        return self.http_connection_class(
            self.app, scope, _preserve_context=self.preserve_context
        )

    def websocket(
        self,
        path: str,
        *,
        headers: dict | Headers | None = None,
        query_string: dict | None = None,
        scheme: str = "ws",
        subprotocols: list[str] | None = None,
        root_path: str = "",
        http_version: str = "1.1",
        scope_base: dict | None = None,
        auth: Authorization | tuple[str, str] | None = None,
        subdomain: str | None = None,
    ) -> TestWebsocketConnectionProtocol:
        """Build a websocket connection to the app.

        The returned ``websocket_connection_class`` instance is
        created from a ``"websocket"`` scope using the ``GET`` method,
        with any stored cookies added to the headers.

        NOTE(review): ``subprotocols`` is accepted but not used by
        this method.
        """
        headers, path, query_string_bytes = make_test_headers_path_and_query_string(
            self.app,
            path,
            headers,
            query_string,
            auth,
            subdomain,
        )
        self._add_cookie_headers(headers)
        scope = make_test_scope(
            "websocket",
            path,
            "GET",
            headers,
            query_string_bytes,
            scheme,
            root_path,
            http_version,
            scope_base,
            _preserve_context=self.preserve_context,
        )
        return self.websocket_connection_class(self.app, scope)

    async def delete(self, *args: Any, **kwargs: Any) -> Response:
        """Make a DELETE request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="DELETE", **kwargs)

    async def get(self, *args: Any, **kwargs: Any) -> Response:
        """Make a GET request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="GET", **kwargs)

    async def head(self, *args: Any, **kwargs: Any) -> Response:
        """Make a HEAD request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="HEAD", **kwargs)

    async def options(self, *args: Any, **kwargs: Any) -> Response:
        """Make an OPTIONS request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="OPTIONS", **kwargs)

    async def patch(self, *args: Any, **kwargs: Any) -> Response:
        """Make a PATCH request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="PATCH", **kwargs)

    async def post(self, *args: Any, **kwargs: Any) -> Response:
        """Make a POST request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="POST", **kwargs)

    async def put(self, *args: Any, **kwargs: Any) -> Response:
        """Make a PUT request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="PUT", **kwargs)

    async def trace(self, *args: Any, **kwargs: Any) -> Response:
        """Make a TRACE request.

        See :meth:`~quart.testing.QuartClient.open` for argument
        details.
        """
        return await self.open(*args, method="TRACE", **kwargs)

    def set_cookie(
        self,
        server_name: str,
        key: str,
        value: str = "",
        max_age: int | timedelta | None = None,
        expires: int | float | datetime | None = None,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: str | None = None,
    ) -> None:
        """Set a cookie in the cookie jar.

        The cookie is serialised with ``werkzeug.http.dump_cookie``
        and stored in the jar as though it had been received in a
        ``set-cookie`` header from ``http://{server_name}{path}``.

        Arguments:
            server_name: Host the cookie is recorded as coming from.
            key: The cookie name.
            value: The cookie value.
            max_age, expires, path, domain, secure, httponly, samesite:
                The standard cookie attributes, passed to
                ``dump_cookie``.

        Raises:
            RuntimeError: If the client was created with
                ``use_cookies=False``.
        """
        if self.cookie_jar is None:
            # Without this check the call fails with an opaque
            # AttributeError on ``None``.
            raise RuntimeError(
                "Cookies are disabled. Create the client with use_cookies=True."
            )
        cookie = dump_cookie(
            key,
            value=value,
            max_age=max_age,
            expires=expires,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )
        self.cookie_jar.extract_cookies(
            _TestCookieJarResponse(Headers([("set-cookie", cookie)])),  # type: ignore
            U2Request(f"http://{server_name}{path}"),
        )

    def delete_cookie(
        self, server_name: str, key: str, path: str = "/", domain: str | None = None
    ) -> None:
        """Delete a cookie (set to expire immediately).

        Raises:
            RuntimeError: If the client was created with
                ``use_cookies=False``.
        """
        self.set_cookie(
            server_name, key, expires=0, max_age=0, path=path, domain=domain
        )

    @asynccontextmanager
    async def session_transaction(
        self,
        path: str = "/",
        *,
        method: str = "GET",
        headers: dict | Headers | None = None,
        query_string: dict | None = None,
        scheme: str = "http",
        data: AnyStr | None = None,
        form: dict | None = None,
        json: Any = sentinel,
        root_path: str = "",
        http_version: str = "1.1",
        auth: Authorization | tuple[str, str] | None = None,
    ) -> AsyncGenerator[SessionMixin, None]:
        """Yield the session for a test request so that it can be modified.

        The session is opened within a test request context that is
        sent the stored cookies. After the ``async with`` body
        completes the session is saved (unless it is a null session)
        and the resulting cookies are stored in the cookie jar.

        Raises:
            RuntimeError: If cookies are disabled, or the session
                interface returns no session.
        """
        if self.cookie_jar is None:
            raise RuntimeError(
                "Session transactions only make sense with cookies enabled."
            )

        if headers is None:
            headers = Headers()
        elif not isinstance(headers, Headers):
            headers = Headers(headers)
        self._add_cookie_headers(headers)

        original_request_ctx = _cv_request.get(None)
        async with self.app.test_request_context(
            path,
            method=method,
            headers=headers,
            query_string=query_string,
            scheme=scheme,
            data=data,
            form=form,
            json=json,
            root_path=root_path,
            http_version=http_version,
            auth=auth,
        ) as ctx:
            session_interface = self.app.session_interface
            session = await session_interface.open_session(self.app, ctx.request)
            if session is None:
                raise RuntimeError("Error opening the session. Check the secret_key?")

            # While the caller's body runs, expose the request context
            # that was current before this transaction, then switch
            # back to the transaction's own context.
            token = _cv_request.set(original_request_ctx)
            try:
                yield session
            finally:
                _cv_request.reset(token)

            response = self.app.response_class(b"")
            if not session_interface.is_null_session(session):
                await session_interface.save_session(self.app, session, response)
            self.cookie_jar.extract_cookies(
                _TestCookieJarResponse(response.headers),  # type: ignore
                U2Request(ctx.request.url),
            )

    async def __aenter__(self) -> QuartClient:
        """Enable ``preserve_context`` and return the client.

        Raises:
            RuntimeError: If ``preserve_context`` is already set, i.e.
                the client is being entered while already entered.
        """
        if self.preserve_context:
            raise RuntimeError("Cannot nest client invocations")
        self.preserve_context = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Disable ``preserve_context`` and pop any preserved contexts."""
        self.preserve_context = False

        # Keep popping until the current request context is missing or
        # is not marked as preserved.
        while True:
            top = _cv_request.get(None)

            if top is not None and top.preserved:
                await top.pop(None)
            else:
                break

    def _add_cookie_headers(self, headers: Headers) -> None:
        """Add a ``cookie`` header for each stored cookie, if cookies are enabled."""
        if self.cookie_jar is None:
            return
        for cookie in self.cookie_jar:
            headers.add("cookie", f"{cookie.name}={cookie.value}")

    async def _make_request(
        self,
        path: str,
        method: str,
        headers: dict | Headers | None,
        data: AnyStr | None,
        form: dict | None,
        files: dict[str, FileStorage] | None,
        query_string: dict | None,
        json: Any,
        scheme: str,
        root_path: str,
        http_version: str,
        scope_base: dict | None,
        auth: Authorization | tuple[str, str] | None = None,
        subdomain: str | None = None,
    ) -> Response:
        """Send one request through a new HTTP connection.

        Builds the headers, body and scope, sends the body, then
        records the response cookies (when cookies are enabled) and
        the connection's push promises before returning the response.
        """
        headers, path, query_string_bytes = make_test_headers_path_and_query_string(
            self.app, path, headers, query_string, auth, subdomain
        )
        request_data, body_headers = make_test_body_with_headers(
            data=data, form=form, files=files, json=json, app=self.app
        )
        headers.update(**body_headers)

        self._add_cookie_headers(headers)

        scope = make_test_scope(
            "http",
            path,
            method,
            headers,
            query_string_bytes,
            scheme,
            root_path,
            http_version,
            scope_base,
            _preserve_context=self.preserve_context,
        )
        async with self.http_connection_class(
            self.app, scope, _preserve_context=self.preserve_context
        ) as connection:
            await connection.send(request_data)
            await connection.send_complete()
        # The response is read after the connection context has exited.
        response = await connection.as_response()
        if self.cookie_jar is not None:
            self.cookie_jar.extract_cookies(
                _TestCookieJarResponse(response.headers),  # type: ignore
                U2Request(f"{scheme}://{headers['host']}{path}"),
            )
        self.push_promises.extend(connection.push_promises)
        return response
|