brief-extractor/backend/venv/lib/python3.10/site-packages/quart/testing/connections.py
2026-03-06 18:42:46 +00:00

200 lines
6.9 KiB
Python
Executable file

from __future__ import annotations
import asyncio
from collections.abc import Awaitable
from types import TracebackType
from typing import Any
from typing import AnyStr
from typing import TYPE_CHECKING
from hypercorn.typing import ASGIReceiveEvent
from hypercorn.typing import ASGISendEvent
from hypercorn.typing import HTTPScope
from hypercorn.typing import WebsocketScope
from werkzeug.datastructures import Headers
from ..json import dumps
from ..json import loads
from ..utils import decode_headers
from ..wrappers import Response
if TYPE_CHECKING:
from ..app import Quart # noqa
class HTTPDisconnectError(Exception):
    """Signals that the app sent an ``http.disconnect`` event.

    An instance is queued by ``TestHTTPConnection._asgi_send`` and raised from
    ``TestHTTPConnection.receive`` when the test reads it.
    """
class WebsocketDisconnectError(Exception):
    """Signals that the app closed the websocket (``websocket.close``).

    ``args[0]`` holds the close code sent by the app, or 1000 when the app
    omitted one.
    """
class WebsocketResponseError(Exception):
    """Signals that the app answered a websocket with a plain HTTP response.

    This is raised once the app has finished sending its
    ``websocket.http.response.*`` events. The assembled response object is
    exposed as :attr:`response` and is also the exception's sole argument.
    """

    def __init__(self, response: Response) -> None:
        # Keep a direct handle so callers need not dig through ``args``.
        self.response = response
        super().__init__(response)
class TestHTTPConnection:
    """In-memory ASGI HTTP connection used to drive a Quart app in tests.

    Entering the connection as an async context manager schedules
    ``app(scope, receive, send)`` as an asyncio task. Queue names are from
    the test's point of view:

    * ``_send_queue`` carries request events *to* the app. The app reads
      them through :meth:`_asgi_receive`.
    * ``_receive_queue`` carries what the app sends back: body chunks
      (``bytes``) or an :class:`HTTPDisconnectError` instance.

    Status and headers are recorded on the instance when the app sends
    ``http.response.start``. Server-push requests are appended to
    :attr:`push_promises` as ``(path, headers)`` tuples.
    """

    def __init__(
        self, app: Quart, scope: HTTPScope, _preserve_context: bool = False
    ) -> None:
        self.app = app
        self.headers: Headers | None = None
        self.push_promises: list[tuple[str, Headers]] = []
        self.response_data = bytearray()
        self.scope = scope
        self.status_code: int | None = None
        self._preserve_context = _preserve_context
        # Test -> app: ASGI request events consumed via ``_asgi_receive``.
        self._send_queue: asyncio.Queue = asyncio.Queue()
        # App -> test: body chunks (bytes) or an HTTPDisconnectError.
        self._receive_queue: asyncio.Queue = asyncio.Queue()
        # Set by ``__aenter__``; None until the connection is entered.
        self._task: Awaitable[None] | None = None

    async def send(self, data: bytes) -> None:
        """Queue ``data`` for the app as a request body chunk with more to follow."""
        await self._send_queue.put(
            {"type": "http.request", "body": data, "more_body": True}
        )

    async def send_complete(self) -> None:
        """Queue an empty, final request body chunk (``more_body`` is False)."""
        await self._send_queue.put(
            {"type": "http.request", "body": b"", "more_body": False}
        )

    async def receive(self) -> bytes:
        """Wait for and return the next response body chunk sent by the app.

        Raises:
            HTTPDisconnectError: if the app sent ``http.disconnect`` instead.
        """
        data = await self._receive_queue.get()
        if isinstance(data, Exception):
            raise data
        return data

    async def disconnect(self) -> None:
        """Queue an ``http.disconnect`` event for the app."""
        await self._send_queue.put({"type": "http.disconnect"})

    async def __aenter__(self) -> TestHTTPConnection:
        """Start the app as a background task and return this connection."""
        self._task = asyncio.ensure_future(
            self.app(self.scope, self._asgi_receive, self._asgi_send)
        )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Wait for the app task to finish and collect remaining body data.

        If the ``async with`` body raised, the app is first sent
        ``http.disconnect``. Any exception raised by the app task propagates
        from ``await self._task``. Leftover body chunks are appended to
        :attr:`response_data`. Queued :class:`HTTPDisconnectError` markers are
        discarded, and any other queued exception is re-raised.
        """
        if exc_type is not None:
            # NOTE(review): presumably so an app still awaiting request
            # events is not left blocked forever — confirm.
            await self.disconnect()
        await self._task
        while not self._receive_queue.empty():
            data = await self._receive_queue.get()
            if isinstance(data, bytes):
                self.response_data.extend(data)
            elif not isinstance(data, HTTPDisconnectError):
                raise data

    async def as_response(self) -> Response:
        """Drain queued body chunks and build a response from what was seen.

        The response is built with ``self.app.response_class`` from
        :attr:`response_data`, :attr:`status_code` and :attr:`headers`. The
        latter two remain None if the app never sent
        ``http.response.start``. Unlike ``__aexit__``, exceptions left in the
        queue are silently dropped here.
        """
        while not self._receive_queue.empty():
            data = await self._receive_queue.get()
            if isinstance(data, bytes):
                self.response_data.extend(data)
        return self.app.response_class(
            bytes(self.response_data), self.status_code, self.headers
        )

    async def _asgi_receive(self) -> ASGIReceiveEvent:
        """ASGI ``receive`` callable handed to the app: next test-queued event."""
        return await self._send_queue.get()

    async def _asgi_send(self, message: ASGISendEvent) -> None:
        """ASGI ``send`` callable handed to the app; records what it emits.

        The ASGI HTTP spec makes ``headers`` (on ``http.response.start``) and
        ``body`` (on ``http.response.body``) optional. They therefore default
        to empty instead of raising ``KeyError``. Event types not listed
        below are ignored.
        """
        event_type = message["type"]
        if event_type == "http.response.start":
            self.headers = decode_headers(message.get("headers", []))
            self.status_code = message["status"]
        elif event_type == "http.response.body":
            await self._receive_queue.put(message.get("body", b""))
        elif event_type == "http.response.push":
            self.push_promises.append(
                (message["path"], decode_headers(message["headers"]))
            )
        elif event_type == "http.disconnect":
            await self._receive_queue.put(HTTPDisconnectError())
class TestWebsocketConnection:
    """In-memory ASGI websocket connection used to drive a Quart app in tests.

    Entering the connection as an async context manager schedules
    ``app(scope, receive, send)`` as an asyncio task. Queue names are from
    the test's point of view:

    * ``_send_queue`` carries events *to* the app (``websocket.receive``,
      ``websocket.close``, ``websocket.disconnect``). The app reads them
      through :meth:`_asgi_receive`.
    * ``_receive_queue`` carries what the app sends back: message payloads
      (``str`` or ``bytes``) or an exception instance, either
      :class:`WebsocketDisconnectError` or :class:`WebsocketResponseError`.

    If the app answers with an HTTP response instead of accepting, its
    status, headers and body are collected on the instance.
    """

    def __init__(self, app: Quart, scope: WebsocketScope) -> None:
        self.accepted = False
        self.app = app
        self.headers: Headers | None = None
        self.response_data = bytearray()
        self.scope = scope
        self.status_code: int | None = None
        # Test -> app: ASGI events consumed via ``_asgi_receive``.
        self._send_queue: asyncio.Queue = asyncio.Queue()
        # App -> test: message payloads or exception instances.
        self._receive_queue: asyncio.Queue = asyncio.Queue()
        # Set by ``__aenter__``; None until the connection is entered.
        self._task: Awaitable[None] | None = None

    async def __aenter__(self) -> TestWebsocketConnection:
        """Start the app as a background task and return this connection."""
        self._task = asyncio.ensure_future(
            self.app(self.scope, self._asgi_receive, self._asgi_send)
        )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Disconnect, wait for the app task, then surface queued errors.

        A ``websocket.disconnect`` event is always sent first. Any exception
        raised by the app task propagates from ``await self._task``. The
        queue is then drained: payloads and
        :class:`WebsocketDisconnectError` markers are discarded, and any
        other queued exception (e.g. :class:`WebsocketResponseError`) is
        re-raised.
        """
        await self.disconnect()
        await self._task
        while not self._receive_queue.empty():
            data = await self._receive_queue.get()
            if isinstance(data, Exception) and not isinstance(
                data, WebsocketDisconnectError
            ):
                raise data

    async def receive(self) -> AnyStr:
        """Wait for and return the next message payload sent by the app.

        Raises:
            WebsocketDisconnectError: if the app closed the websocket.
            WebsocketResponseError: if the app replied with an HTTP response.
        """
        data = await self._receive_queue.get()
        if isinstance(data, Exception):
            raise data
        return data

    async def send(self, data: AnyStr) -> None:
        """Deliver ``data`` to the app as a ``websocket.receive`` event.

        A ``str`` is sent under the ``"text"`` key and anything else under
        ``"bytes"``.
        """
        if isinstance(data, str):
            await self._send_queue.put({"type": "websocket.receive", "text": data})
        else:
            await self._send_queue.put({"type": "websocket.receive", "bytes": data})

    async def receive_json(self) -> Any:
        """Receive the next message and decode it with Quart's ``loads``."""
        data = await self.receive()
        return loads(data)

    async def send_json(self, data: Any) -> None:
        """Encode ``data`` with Quart's ``dumps`` and pass the result to :meth:`send`."""
        raw = dumps(data)
        await self.send(raw)

    async def close(self, code: int) -> None:
        """Queue a ``websocket.close`` event carrying ``code`` for the app."""
        await self._send_queue.put({"type": "websocket.close", "code": code})

    async def disconnect(self) -> None:
        """Queue a ``websocket.disconnect`` event for the app."""
        await self._send_queue.put({"type": "websocket.disconnect"})

    async def _asgi_receive(self) -> ASGIReceiveEvent:
        """ASGI ``receive`` callable handed to the app: next test-queued event."""
        return await self._send_queue.get()

    async def _asgi_send(self, message: ASGISendEvent) -> None:
        """ASGI ``send`` callable handed to the app; records what it emits.

        Event types not listed below are ignored.
        """
        event_type = message["type"]
        if event_type == "websocket.accept":
            self.accepted = True
        elif event_type == "websocket.send":
            # ASGI: exactly one of "bytes"/"text" is non-None. Compare with
            # None explicitly; a truthiness test (``bytes or text``) would
            # turn an empty binary frame b"" into None or the text value.
            data = message.get("bytes")
            if data is None:
                data = message.get("text")
            await self._receive_queue.put(data)
        elif event_type == "websocket.http.response.start":
            # "headers" is optional in ASGI; treat a missing key as none.
            self.headers = decode_headers(message.get("headers", []))
            self.status_code = message["status"]
        elif event_type == "websocket.http.response.body":
            # "body" is optional in ASGI; treat a missing key as empty.
            self.response_data.extend(message.get("body", b""))
            if not message.get("more_body", False):
                # Final chunk: hand the assembled response to the test as an
                # exception so ``receive`` / ``__aexit__`` raise it.
                await self._receive_queue.put(
                    WebsocketResponseError(
                        self.app.response_class(
                            bytes(self.response_data), self.status_code, self.headers
                        )
                    )
                )
        elif event_type == "websocket.close":
            await self._receive_queue.put(
                WebsocketDisconnectError(message.get("code", 1000))
            )