200 lines
6.9 KiB
Python
Executable file
200 lines
6.9 KiB
Python
Executable file
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Awaitable
|
|
from types import TracebackType
|
|
from typing import Any
|
|
from typing import AnyStr
|
|
from typing import TYPE_CHECKING
|
|
|
|
from hypercorn.typing import ASGIReceiveEvent
|
|
from hypercorn.typing import ASGISendEvent
|
|
from hypercorn.typing import HTTPScope
|
|
from hypercorn.typing import WebsocketScope
|
|
from werkzeug.datastructures import Headers
|
|
|
|
from ..json import dumps
|
|
from ..json import loads
|
|
from ..utils import decode_headers
|
|
from ..wrappers import Response
|
|
|
|
if TYPE_CHECKING:
|
|
from ..app import Quart # noqa
|
|
|
|
|
|
class HTTPDisconnectError(Exception):
|
|
pass
|
|
|
|
|
|
class WebsocketDisconnectError(Exception):
|
|
pass
|
|
|
|
|
|
class WebsocketResponseError(Exception):
|
|
def __init__(self, response: Response) -> None:
|
|
super().__init__(response)
|
|
self.response = response
|
|
|
|
|
|
class TestHTTPConnection:
|
|
def __init__(
|
|
self, app: Quart, scope: HTTPScope, _preserve_context: bool = False
|
|
) -> None:
|
|
self.app = app
|
|
self.headers: Headers | None = None
|
|
self.push_promises: list[tuple[str, Headers]] = []
|
|
self.response_data = bytearray()
|
|
self.scope = scope
|
|
self.status_code: int | None = None
|
|
self._preserve_context = _preserve_context
|
|
self._send_queue: asyncio.Queue = asyncio.Queue()
|
|
self._receive_queue: asyncio.Queue = asyncio.Queue()
|
|
self._task: Awaitable[None] = None
|
|
|
|
async def send(self, data: bytes) -> None:
|
|
await self._send_queue.put(
|
|
{"type": "http.request", "body": data, "more_body": True}
|
|
)
|
|
|
|
async def send_complete(self) -> None:
|
|
await self._send_queue.put(
|
|
{"type": "http.request", "body": b"", "more_body": False}
|
|
)
|
|
|
|
async def receive(self) -> bytes:
|
|
data = await self._receive_queue.get()
|
|
if isinstance(data, Exception):
|
|
raise data
|
|
else:
|
|
return data
|
|
|
|
async def disconnect(self) -> None:
|
|
await self._send_queue.put({"type": "http.disconnect"})
|
|
|
|
async def __aenter__(self) -> TestHTTPConnection:
|
|
self._task = asyncio.ensure_future(
|
|
self.app(self.scope, self._asgi_receive, self._asgi_send)
|
|
)
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self, exc_type: type, exc_value: BaseException, tb: TracebackType
|
|
) -> None:
|
|
if exc_type is not None:
|
|
await self.disconnect()
|
|
await self._task
|
|
while not self._receive_queue.empty():
|
|
data = await self._receive_queue.get()
|
|
if isinstance(data, bytes):
|
|
self.response_data.extend(data)
|
|
elif not isinstance(data, HTTPDisconnectError):
|
|
raise data
|
|
|
|
async def as_response(self) -> Response:
|
|
while not self._receive_queue.empty():
|
|
data = await self._receive_queue.get()
|
|
if isinstance(data, bytes):
|
|
self.response_data.extend(data)
|
|
return self.app.response_class(
|
|
bytes(self.response_data), self.status_code, self.headers
|
|
)
|
|
|
|
async def _asgi_receive(self) -> ASGIReceiveEvent:
|
|
return await self._send_queue.get()
|
|
|
|
async def _asgi_send(self, message: ASGISendEvent) -> None:
|
|
if message["type"] == "http.response.start":
|
|
self.headers = decode_headers(message["headers"])
|
|
self.status_code = message["status"]
|
|
elif message["type"] == "http.response.body":
|
|
await self._receive_queue.put(message["body"])
|
|
elif message["type"] == "http.response.push":
|
|
self.push_promises.append(
|
|
(message["path"], decode_headers(message["headers"]))
|
|
)
|
|
elif message["type"] == "http.disconnect":
|
|
await self._receive_queue.put(HTTPDisconnectError())
|
|
|
|
|
|
class TestWebsocketConnection:
|
|
def __init__(self, app: Quart, scope: WebsocketScope) -> None:
|
|
self.accepted = False
|
|
self.app = app
|
|
self.headers: Headers | None = None
|
|
self.response_data = bytearray()
|
|
self.scope = scope
|
|
self.status_code: int | None = None
|
|
self._send_queue: asyncio.Queue = asyncio.Queue()
|
|
self._receive_queue: asyncio.Queue = asyncio.Queue()
|
|
self._task: Awaitable[None] = None
|
|
|
|
async def __aenter__(self) -> TestWebsocketConnection:
|
|
self._task = asyncio.ensure_future(
|
|
self.app(self.scope, self._asgi_receive, self._asgi_send)
|
|
)
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self, exc_type: type, exc_value: BaseException, tb: TracebackType
|
|
) -> None:
|
|
await self.disconnect()
|
|
await self._task
|
|
while not self._receive_queue.empty():
|
|
data = await self._receive_queue.get()
|
|
if isinstance(data, Exception) and not isinstance(
|
|
data, WebsocketDisconnectError
|
|
):
|
|
raise data
|
|
|
|
async def receive(self) -> AnyStr:
|
|
data = await self._receive_queue.get()
|
|
if isinstance(data, Exception):
|
|
raise data
|
|
else:
|
|
return data
|
|
|
|
async def send(self, data: AnyStr) -> None:
|
|
if isinstance(data, str):
|
|
await self._send_queue.put({"type": "websocket.receive", "text": data})
|
|
else:
|
|
await self._send_queue.put({"type": "websocket.receive", "bytes": data})
|
|
|
|
async def receive_json(self) -> Any:
|
|
data = await self.receive()
|
|
return loads(data)
|
|
|
|
async def send_json(self, data: Any) -> None:
|
|
raw = dumps(data)
|
|
await self.send(raw)
|
|
|
|
async def close(self, code: int) -> None:
|
|
await self._send_queue.put({"type": "websocket.close", "code": code})
|
|
|
|
async def disconnect(self) -> None:
|
|
await self._send_queue.put({"type": "websocket.disconnect"})
|
|
|
|
async def _asgi_receive(self) -> ASGIReceiveEvent:
|
|
return await self._send_queue.get()
|
|
|
|
async def _asgi_send(self, message: ASGISendEvent) -> None:
|
|
if message["type"] == "websocket.accept":
|
|
self.accepted = True
|
|
elif message["type"] == "websocket.send":
|
|
await self._receive_queue.put(message.get("bytes") or message.get("text"))
|
|
elif message["type"] == "websocket.http.response.start":
|
|
self.headers = decode_headers(message["headers"])
|
|
self.status_code = message["status"]
|
|
elif message["type"] == "websocket.http.response.body":
|
|
self.response_data.extend(message["body"])
|
|
if not message.get("more_body", False):
|
|
await self._receive_queue.put(
|
|
WebsocketResponseError(
|
|
self.app.response_class(
|
|
bytes(self.response_data), self.status_code, self.headers
|
|
)
|
|
)
|
|
)
|
|
elif message["type"] == "websocket.close":
|
|
await self._receive_queue.put(
|
|
WebsocketDisconnectError(message.get("code", 1000))
|
|
)
|