brief-extractor/backend/venv/lib/python3.10/site-packages/quart/asgi.py
2026-03-06 18:42:46 +00:00

408 lines
15 KiB
Python
Executable file

from __future__ import annotations
import asyncio
import warnings
from functools import partial
from typing import AnyStr
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from hypercorn.typing import ASGIReceiveCallable
from hypercorn.typing import ASGISendCallable
from hypercorn.typing import HTTPResponseBodyEvent
from hypercorn.typing import HTTPResponseStartEvent
from hypercorn.typing import HTTPScope
from hypercorn.typing import LifespanScope
from hypercorn.typing import LifespanShutdownCompleteEvent
from hypercorn.typing import LifespanShutdownFailedEvent
from hypercorn.typing import LifespanStartupCompleteEvent
from hypercorn.typing import LifespanStartupFailedEvent
from hypercorn.typing import WebsocketAcceptEvent
from hypercorn.typing import WebsocketCloseEvent
from hypercorn.typing import WebsocketResponseBodyEvent
from hypercorn.typing import WebsocketResponseStartEvent
from hypercorn.typing import WebsocketScope
from werkzeug.datastructures import Headers
from werkzeug.wrappers import Response as WerkzeugResponse
from .debug import traceback_response
from .signals import websocket_received
from .signals import websocket_sent
from .typing import ResponseTypes
from .utils import cancel_tasks
from .utils import encode_headers
from .utils import raise_task_exceptions
from .wrappers import Request # noqa: F401
from .wrappers import Response # noqa: F401
from .wrappers import Websocket # noqa: F401
if TYPE_CHECKING:
from .app import Quart # noqa: F401
class ASGIHTTPConnection:
    """Handle one HTTP request/response exchange between an ASGI server and the app.

    The instance is created with the app and the connection's HTTP scope and
    is then awaited with the server's ``receive`` and ``send`` callables.
    """

    def __init__(self, app: Quart, scope: HTTPScope) -> None:
        self.app = app
        self.scope = scope

    async def __call__(
        self, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        """Run the request handler and the message receiver concurrently.

        Both coroutines are scheduled as tasks; as soon as either one
        finishes, the still-pending task is handed to ``cancel_tasks`` and
        the finished one to ``raise_task_exceptions``.
        """
        request = self._create_request_from_scope(send)
        receiver_task = asyncio.ensure_future(self.handle_messages(request, receive))
        handler_task = asyncio.ensure_future(self.handle_request(request, send))
        done, pending = await asyncio.wait(
            [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED
        )
        await cancel_tasks(pending)
        raise_task_exceptions(done)

    async def handle_messages(
        self, request: Request, receive: ASGIReceiveCallable
    ) -> None:
        """Feed incoming ASGI messages into the request body.

        ``http.request`` messages append their ``body`` (empty bytes when
        absent) to ``request.body`` and mark it complete once ``more_body``
        is falsy. The loop returns on ``http.disconnect``; any other message
        type is ignored.
        """
        while True:
            message = await receive()
            if message["type"] == "http.request":
                request.body.append(message.get("body", b""))
                if not message.get("more_body", False):
                    request.body.set_complete()
            elif message["type"] == "http.disconnect":
                return

    @staticmethod
    def _strip_root_path(path: str, root_path: str) -> str:
        """Return *path* with the *root_path* mount prefix removed.

        An empty *root_path* leaves *path* untouched. A path that does not
        start with *root_path*, or that is exactly *root_path*, becomes
        ``" "``; a space is invalid in paths, hence will result in a 404.
        Only a leading occurrence is stripped, so a path that merely
        contains *root_path* further along is not silently re-routed.
        """
        if not root_path:
            return path
        if not path.startswith(root_path):
            return " "
        remainder = path[len(root_path):]
        return remainder if remainder != "" else " "

    def _create_request_from_scope(self, send: ASGISendCallable) -> Request:
        """Build the app's request object from the ASGI scope.

        Header names and values are decoded as latin1 and the names are
        title-cased. A ``Remote-Addr`` header is set from the first item of
        the scope's ``client`` (``"<local>"`` when there is none), and for
        HTTP versions that compare below ``"1.1"`` a default ``Host`` header
        is filled in from the ``SERVER_NAME`` config.
        """
        headers = Headers()
        headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
        for name, value in self.scope["headers"]:
            headers.add(name.decode("latin1").title(), value.decode("latin1"))
        if self.scope["http_version"] < "1.1":
            headers.setdefault("Host", self.app.config["SERVER_NAME"] or "")

        path = self.scope["path"]
        # A target that does not start with "/" is parsed as a URL and only
        # its path component is kept.
        path = path if path[0] == "/" else urlparse(path).path
        path = self._strip_root_path(path, self.scope.get("root_path", ""))

        return self.app.request_class(
            self.scope["method"],
            self.scope["scheme"],
            path,
            self.scope["query_string"],
            headers,
            self.scope.get("root_path", ""),
            self.scope["http_version"],
            max_content_length=self.app.config["MAX_CONTENT_LENGTH"],
            body_timeout=self.app.config["BODY_TIMEOUT"],
            send_push_promise=partial(self._send_push_promise, send),
            scope=self.scope,
        )

    async def handle_request(self, request: Request, send: ASGISendCallable) -> None:
        """Let the app handle *request* and send the resulting response.

        Exceptions raised by the app are passed to ``_handle_exception``.
        Sending is bounded by the response's own ``timeout`` when the
        response is a ``Response`` whose timeout is not ``Ellipsis``, and by
        the ``RESPONSE_TIMEOUT`` config otherwise.
        """
        try:
            response = await self.app.handle_request(request)
        except Exception as error:
            response = await _handle_exception(self.app, error)

        if isinstance(response, Response) and response.timeout is not Ellipsis:
            timeout = cast(Optional[float], response.timeout)
        else:
            timeout = self.app.config["RESPONSE_TIMEOUT"]

        try:
            await asyncio.wait_for(self._send_response(send, response), timeout=timeout)
        except asyncio.TimeoutError:
            # Deliberate best effort: when sending exceeds the timeout the
            # send is abandoned and no error is raised to the server.
            pass

    async def _send_response(
        self, send: ASGISendCallable, response: ResponseTypes
    ) -> None:
        """Send *response* as ASGI ``http.response.*`` messages.

        The start message carries the status code and encoded headers. Each
        body chunk is sent with ``more_body`` set (``str`` chunks are encoded
        first), followed by a final empty chunk that closes the body.
        """
        await send(
            cast(
                HTTPResponseStartEvent,
                {
                    "type": "http.response.start",
                    "status": response.status_code,
                    "headers": encode_headers(response.headers),
                },
            )
        )

        if isinstance(response, WerkzeugResponse):
            # Werkzeug responses expose their body as a plain iterable.
            for data in response.response:
                body = data.encode() if isinstance(data, str) else data
                await send(
                    cast(
                        HTTPResponseBodyEvent,
                        {"type": "http.response.body", "body": body, "more_body": True},
                    )
                )
        else:
            # Other responses expose an async context manager that yields an
            # async iterable of chunks.
            async with response.response as response_body:
                async for data in response_body:
                    body = data.encode() if isinstance(data, str) else data
                    await send(
                        cast(
                            HTTPResponseBodyEvent,
                            {
                                "type": "http.response.body",
                                "body": body,
                                "more_body": True,
                            },
                        )
                    )

        await send(
            cast(
                HTTPResponseBodyEvent,
                {"type": "http.response.body", "body": b"", "more_body": False},
            )
        )

    async def _send_push_promise(
        self, send: ASGISendCallable, path: str, headers: Headers
    ) -> None:
        """Send an ``http.response.push`` message for *path*.

        Nothing is sent unless the scope's ``extensions`` contain
        ``http.response.push``.
        """
        extensions = self.scope.get("extensions", {}) or {}
        if "http.response.push" in extensions:
            await send(
                {
                    "type": "http.response.push",
                    "path": path,
                    "headers": encode_headers(headers),
                }
            )
class ASGIWebsocketConnection:
    """Handle one websocket connection between an ASGI server and the app.

    The instance is created with the app and the connection's websocket
    scope and is then awaited with the server's ``receive`` and ``send``
    callables.
    """

    def __init__(self, app: Quart, scope: WebsocketScope) -> None:
        self.app = app
        self.scope = scope
        # Received messages wait here; the websocket object reads them
        # through ``queue.get``.
        self.queue: asyncio.Queue = asyncio.Queue()
        # Set by accept_connection / close_connection respectively.
        self._accepted = False
        self._closed = False

    async def __call__(
        self, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        """Run the websocket handler and the message receiver concurrently.

        Both coroutines are scheduled as tasks; as soon as either one
        finishes, the still-pending task is handed to ``cancel_tasks`` and
        the finished one to ``raise_task_exceptions``.
        """
        websocket = self._create_websocket_from_scope(send)
        receiver_task = asyncio.ensure_future(self.handle_messages(receive))
        handler_task = asyncio.ensure_future(self.handle_websocket(websocket, send))
        done, pending = await asyncio.wait(
            [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED
        )
        await cancel_tasks(pending)
        raise_task_exceptions(done)

    @staticmethod
    def _message_from_event(event):
        """Return the payload of a ``websocket.receive`` event.

        The ``bytes`` value is used whenever it is present and not ``None``,
        so an empty binary frame (``b""``) is kept as-is instead of being
        mistaken for a missing payload; otherwise the ``text`` value is used.
        """
        data = event.get("bytes")
        if data is not None:
            return data
        return event["text"]

    async def handle_messages(self, receive: ASGIReceiveCallable) -> None:
        """Queue incoming websocket messages until the client disconnects.

        Each ``websocket.receive`` payload is announced through the
        ``websocket_received`` signal and then put on ``self.queue``. The
        loop returns on ``websocket.disconnect``; other event types are
        ignored.
        """
        while True:
            event = await receive()
            if event["type"] == "websocket.receive":
                message = self._message_from_event(event)
                await websocket_received.send_async(message)
                await self.queue.put(message)
            elif event["type"] == "websocket.disconnect":
                return

    @staticmethod
    def _strip_root_path(path: str, root_path: str) -> str:
        """Return *path* with the *root_path* mount prefix removed.

        An empty *root_path* leaves *path* untouched. A path that does not
        start with *root_path*, or that is exactly *root_path*, becomes
        ``" "``; a space is invalid in paths, hence will result in a 404.
        Only a leading occurrence is stripped, so a path that merely
        contains *root_path* further along is not silently re-routed.
        """
        if not root_path:
            return path
        if not path.startswith(root_path):
            return " "
        remainder = path[len(root_path):]
        return remainder if remainder != "" else " "

    def _create_websocket_from_scope(self, send: ASGISendCallable) -> Websocket:
        """Build the app's websocket object from the ASGI scope.

        Header names and values are decoded as latin1 and the names are
        title-cased; a ``Remote-Addr`` header is set from the first item of
        the scope's ``client`` (``"<local>"`` when there is none). The
        websocket is wired to this connection's queue and to its send,
        accept and close coroutines, each bound to *send*.
        """
        headers = Headers()
        headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
        for name, value in self.scope["headers"]:
            headers.add(name.decode("latin1").title(), value.decode("latin1"))

        path = self.scope["path"]
        # A target that does not start with "/" is parsed as a URL and only
        # its path component is kept.
        path = path if path[0] == "/" else urlparse(path).path
        path = self._strip_root_path(path, self.scope.get("root_path", ""))

        return self.app.websocket_class(
            path,
            self.scope["query_string"],
            self.scope["scheme"],
            headers,
            self.scope.get("root_path", ""),
            self.scope.get("http_version", "1.1"),
            list(self.scope.get("subprotocols", [])),
            self.queue.get,
            partial(self.send_data, send),
            partial(self.accept_connection, send),
            partial(self.close_connection, send),
            scope=self.scope,
        )

    async def handle_websocket(
        self, websocket: Websocket, send: ASGISendCallable
    ) -> None:
        """Let the app handle *websocket* and finish the connection.

        Exceptions raised by the app are passed to ``_handle_exception``.
        Afterwards:

        * a response returned before the connection was accepted is sent as
          an HTTP response when the scope's ``extensions`` contain
          ``websocket.http.response``; without that extension a close with
          code 1000 is sent unless the connection is already closed;
        * an accepted connection that is not yet closed is closed with code
          1000;
        * otherwise nothing is sent.
        """
        try:
            response = await self.app.handle_websocket(websocket)
        except Exception as error:
            response = await _handle_exception(self.app, error)

        if response is not None and not self._accepted:
            extensions = self.scope.get("extensions", {}) or {}
            if "websocket.http.response" in extensions:
                await self._send_http_response(send, response)
            elif not self._closed:
                await send(
                    cast(WebsocketCloseEvent, {"type": "websocket.close", "code": 1000})
                )
        elif self._accepted and not self._closed:
            await send(
                cast(WebsocketCloseEvent, {"type": "websocket.close", "code": 1000})
            )

    @staticmethod
    def _http_body_event(data, more_body: bool) -> WebsocketResponseBodyEvent:
        """Build a ``websocket.http.response.body`` event for one chunk.

        ``str`` chunks are encoded to bytes so the body is sent in the same
        form as in the plain HTTP response path.
        """
        event: WebsocketResponseBodyEvent = {
            "type": "websocket.http.response.body",
            "body": data.encode() if isinstance(data, str) else data,
            "more_body": more_body,
        }
        return event

    async def _send_http_response(
        self, send: ASGISendCallable, response: ResponseTypes
    ) -> None:
        """Send *response* as ``websocket.http.response.*`` messages.

        The start message carries the status code and the lower-cased,
        encoded headers. Body chunks follow with ``more_body`` set, and a
        final empty chunk closes the body. A response that is neither a
        Werkzeug nor a Quart response contributes no body chunks.
        """
        headers = [
            (key.lower().encode(), value.encode())
            for key, value in response.headers.items()
        ]
        await send(
            cast(
                WebsocketResponseStartEvent,
                {
                    "type": "websocket.http.response.start",
                    "status": response.status_code,
                    "headers": headers,
                },
            )
        )
        if isinstance(response, WerkzeugResponse):
            for data in response.response:
                await send(self._http_body_event(data, more_body=True))
        elif isinstance(response, Response):
            async with response.response as body:
                async for data in body:
                    await send(self._http_body_event(data, more_body=True))
        await send(self._http_body_event(b"", more_body=False))

    async def send_data(self, send: ASGISendCallable, data: AnyStr) -> None:
        """Send *data* to the client and announce it via ``websocket_sent``.

        ``str`` data goes out in the ``text`` field and anything else in the
        ``bytes`` field; the unused field is set to ``None``.
        """
        if isinstance(data, str):
            await send({"type": "websocket.send", "bytes": None, "text": data})
        else:
            await send({"type": "websocket.send", "bytes": data, "text": None})
        await websocket_sent.send_async(data)

    async def accept_connection(
        self, send: ASGISendCallable, headers: Headers, subprotocol: str | None
    ) -> None:
        """Send ``websocket.accept`` once; later calls do nothing.

        Accept headers are only included when the scope's ASGI
        ``spec_version`` (default ``"2.0"``) is above 2.0. For older
        versions a warning is issued if *headers* is non-empty, and the
        headers are not sent.
        """
        if not self._accepted:
            message: WebsocketAcceptEvent = {
                "headers": [],
                "subprotocol": subprotocol,
                "type": "websocket.accept",
            }
            spec_version = _convert_version(
                self.scope.get("asgi", {}).get("spec_version", "2.0")
            )
            if spec_version > [2, 0]:
                message["headers"] = encode_headers(headers)
            elif headers:
                warnings.warn(
                    "The ASGI Server does not support accept headers, headers not sent",
                    stacklevel=1,
                )
            # Marked as accepted before sending.
            self._accepted = True
            await send(message)

    async def close_connection(
        self, send: ASGISendCallable, code: int, reason: str
    ) -> None:
        """Send ``websocket.close`` with *code*.

        The *reason* is only included when the scope's ASGI ``spec_version``
        (default ``"2.0"``) is at least 2.3.

        Raises:
            RuntimeError: If this connection was already closed through
                this method.
        """
        if self._closed:
            raise RuntimeError("Cannot close websocket multiple times")
        spec_version = _convert_version(
            self.scope.get("asgi", {}).get("spec_version", "2.0")
        )
        if spec_version >= [2, 3]:
            await send({"type": "websocket.close", "code": code, "reason": reason})
        else:
            await send({"type": "websocket.close", "code": code})  # type: ignore
        # Only marked closed once the close message has been sent.
        self._closed = True
class ASGILifespan:
    """Translate ASGI lifespan events into app startup and shutdown calls."""

    def __init__(self, app: Quart, scope: LifespanScope) -> None:
        # Only the app is kept; the scope argument is not stored.
        self.app = app

    async def __call__(
        self, receive: ASGIReceiveCallable, send: ASGISendCallable
    ) -> None:
        """Process lifespan events until shutdown has been handled.

        ``lifespan.startup`` awaits ``app.startup()`` and ``lifespan.shutdown``
        awaits ``app.shutdown()``. Each outcome is reported with the matching
        ``.complete`` event, or with a ``.failed`` event carrying the error
        text when the call raised. The loop ends after a shutdown event has
        been answered; other event types are ignored.
        """
        while True:
            message = await receive()
            kind = message["type"]

            if kind == "lifespan.startup":
                try:
                    await self.app.startup()
                except Exception as exc:
                    startup_failed: LifespanStartupFailedEvent = {
                        "type": "lifespan.startup.failed",
                        "message": str(exc),
                    }
                    await send(startup_failed)
                else:
                    startup_complete: LifespanStartupCompleteEvent = {
                        "type": "lifespan.startup.complete"
                    }
                    await send(startup_complete)
            elif kind == "lifespan.shutdown":
                try:
                    await self.app.shutdown()
                except Exception as exc:
                    shutdown_failed: LifespanShutdownFailedEvent = {
                        "type": "lifespan.shutdown.failed",
                        "message": str(exc),
                    }
                    await send(shutdown_failed)
                else:
                    shutdown_complete: LifespanShutdownCompleteEvent = {
                        "type": "lifespan.shutdown.complete"
                    }
                    await send(shutdown_complete)
                # Shutdown answered, successfully or not: stop listening.
                return
def _convert_version(raw: str) -> list[int]:
return list(map(int, raw.split(".")))
async def _handle_exception(app: Quart, error: Exception) -> Response:
if not app.testing and app.config["PROPAGATE_EXCEPTIONS"]:
return await traceback_response(error)
else:
raise error