872 lines
32 KiB
Python
Executable file
872 lines
32 KiB
Python
Executable file
# SPDX-License-Identifier: BSD-3-Clause
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import functools
|
|
import logging
|
|
import socket
|
|
import ssl
|
|
import sys
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from enum import IntEnum
|
|
from types import TracebackType
|
|
from typing import (
|
|
Any,
|
|
AsyncGenerator,
|
|
Awaitable,
|
|
Callable,
|
|
Coroutine,
|
|
Generator,
|
|
Iterable,
|
|
Iterator,
|
|
cast,
|
|
)
|
|
|
|
if sys.version_info >= (3, 10):
|
|
from typing import Concatenate, ParamSpec, TypeAlias
|
|
else:
|
|
from typing_extensions import Concatenate, ParamSpec, TypeAlias
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
import paho.mqtt.client as mqtt
|
|
from paho.mqtt.properties import Properties
|
|
|
|
from .error import MqttCodeError, MqttConnectError, MqttError
|
|
from .types import PayloadType, T
|
|
|
|
MQTT_LOGGER = logging.getLogger("mqtt")
|
|
MQTT_LOGGER.setLevel(logging.WARNING)
|
|
|
|
MQTT_LOGGER.warning(
|
|
"We renamed asyncio-mqtt to aiomqtt and released a version 1.0.0 in the process. "
|
|
"This is the last release under the asyncio-mqtt name. You can find the new "
|
|
"repository at https://github.com/sbtinstruments/aiomqtt"
|
|
)
|
|
|
|
_PahoSocket: TypeAlias = "socket.socket | ssl.SSLSocket | mqtt.WebsocketWrapper | Any"
|
|
|
|
WebSocketHeaders: TypeAlias = (
|
|
"dict[str, str] | Callable[[dict[str, str]], dict[str, str]]"
|
|
)
|
|
|
|
|
|
class ProtocolVersion(IntEnum):
|
|
"""A mapping of paho-mqtt protocol versions to an Enum for use in type hints."""
|
|
|
|
V31 = mqtt.MQTTv31
|
|
V311 = mqtt.MQTTv311
|
|
V5 = mqtt.MQTTv5
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Will:
|
|
topic: str
|
|
payload: PayloadType | None = None
|
|
qos: int = 0
|
|
retain: bool = False
|
|
properties: mqtt.Properties | None = None
|
|
|
|
|
|
# TLS set parameter class
|
|
@dataclass(frozen=True)
|
|
class TLSParameters:
|
|
ca_certs: str | None = None
|
|
certfile: str | None = None
|
|
keyfile: str | None = None
|
|
cert_reqs: ssl.VerifyMode | None = None
|
|
tls_version: Any | None = None
|
|
ciphers: str | None = None
|
|
keyfile_password: str | None = None
|
|
|
|
|
|
# Proxy parameters class
|
|
class ProxySettings:
|
|
def __init__(
|
|
self,
|
|
*,
|
|
proxy_type: int,
|
|
proxy_addr: str,
|
|
proxy_rdns: bool | None = True,
|
|
proxy_username: str | None = None,
|
|
proxy_password: str | None = None,
|
|
):
|
|
self.proxy_args = {
|
|
"proxy_type": proxy_type,
|
|
"proxy_addr": proxy_addr,
|
|
"proxy_rdns": proxy_rdns,
|
|
"proxy_username": proxy_username,
|
|
"proxy_password": proxy_password,
|
|
}
|
|
|
|
|
|
# See the overloads of `socket.setsockopt` for details.
|
|
SocketOption: TypeAlias = "tuple[int, int, int | bytes] | tuple[int, int, None, int]"
|
|
|
|
SubscribeTopic: TypeAlias = "str | tuple[str, mqtt.SubscribeOptions] | list[tuple[str, mqtt.SubscribeOptions]] | list[tuple[str, int]]"
|
|
|
|
P = ParamSpec("P")
|
|
|
|
# TODO: Simplify the logic that surrounds `self._outgoing_calls_sem` with
|
|
# `nullcontext` when we support Python 3.10 (`nullcontext` becomes async-aware in
|
|
# 3.10). See: https://docs.python.org/3/library/contextlib.html#contextlib.nullcontext
|
|
def _outgoing_call(
|
|
method: Callable[Concatenate[Client, P], Coroutine[Any, Any, T]]
|
|
) -> Callable[Concatenate[Client, P], Coroutine[Any, Any, T]]:
|
|
@functools.wraps(method)
|
|
async def decorated(self: Client, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
if not self._outgoing_calls_sem:
|
|
return await method(self, *args, **kwargs)
|
|
|
|
async with self._outgoing_calls_sem:
|
|
return await method(self, *args, **kwargs)
|
|
|
|
return decorated
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Wildcard:
|
|
"""A topic, optionally with wildcards (+ and #). Can only be subscribed to."""
|
|
|
|
value: str
|
|
|
|
def __str__(self) -> str:
|
|
return self.value
|
|
|
|
def __post_init__(self) -> None:
|
|
"""Validate the wildcard."""
|
|
if not isinstance(self.value, str):
|
|
raise TypeError("wildcard must be a string")
|
|
if (
|
|
len(self.value) == 0
|
|
or len(self.value) > 65535
|
|
or "#/" in self.value
|
|
or any(
|
|
"+" in level or "#" in level
|
|
for level in self.value.split("/")
|
|
if len(level) > 1
|
|
)
|
|
):
|
|
raise ValueError(f"Invalid wildcard: {self.value}")
|
|
|
|
|
|
WildcardLike: TypeAlias = "str | Wildcard"
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Topic(Wildcard):
|
|
"""A topic that can be published and subscribed to."""
|
|
|
|
def __post_init__(self) -> None:
|
|
"""Validate the topic."""
|
|
if not isinstance(self.value, str):
|
|
raise TypeError("topic must be a string")
|
|
if (
|
|
len(self.value) == 0
|
|
or len(self.value) > 65535
|
|
or "+" in self.value
|
|
or "#" in self.value
|
|
):
|
|
raise ValueError(f"Invalid topic: {self.value}")
|
|
|
|
def matches(self, wildcard: WildcardLike) -> bool:
|
|
"""Check if the topic is matched by a given wildcard."""
|
|
if not isinstance(wildcard, Wildcard):
|
|
wildcard = Wildcard(wildcard)
|
|
# Split topics into levels to compare them one by one
|
|
topic_levels = self.value.split("/")
|
|
wildcard_levels = str(wildcard).split("/")
|
|
if wildcard_levels[0] == "$share":
|
|
# Shared subscriptions use the topic structure: $share/<group_id>/<topic>
|
|
wildcard_levels = wildcard_levels[2:]
|
|
|
|
def recurse(x: list[str], y: list[str]) -> bool:
|
|
if not x:
|
|
if not y:
|
|
return True
|
|
return False
|
|
if not y:
|
|
return False
|
|
if y[0] == "#":
|
|
return True
|
|
if x[0] == y[0] or y[0] == "+":
|
|
return recurse(x[1:], y[1:])
|
|
return False
|
|
|
|
return recurse(topic_levels, wildcard_levels)
|
|
|
|
|
|
TopicLike: TypeAlias = "str | Topic"
|
|
|
|
|
|
class Message:
|
|
"""Custom message class that allows us to use our own Topic class."""
|
|
|
|
def __init__(
|
|
self,
|
|
topic: TopicLike,
|
|
payload: PayloadType,
|
|
qos: int,
|
|
retain: bool,
|
|
mid: int,
|
|
properties: Properties | None,
|
|
):
|
|
self.topic = Topic(topic) if not isinstance(topic, Topic) else topic
|
|
self.payload = payload
|
|
self.qos = qos
|
|
self.retain = retain
|
|
self.mid = mid
|
|
self.properties = properties
|
|
|
|
@classmethod
|
|
def _from_paho_message(cls, message: mqtt.MQTTMessage) -> Message:
|
|
return cls(
|
|
topic=message.topic,
|
|
payload=message.payload,
|
|
qos=message.qos,
|
|
retain=message.retain,
|
|
mid=message.mid,
|
|
properties=message.properties if hasattr(message, "properties") else None,
|
|
)
|
|
|
|
|
|
class Client:
|
|
def __init__( # noqa: C901
|
|
self,
|
|
hostname: str,
|
|
port: int = 1883,
|
|
*,
|
|
username: str | None = None,
|
|
password: str | None = None,
|
|
logger: logging.Logger | None = None,
|
|
client_id: str | None = None,
|
|
tls_context: ssl.SSLContext | None = None,
|
|
tls_params: TLSParameters | None = None,
|
|
proxy: ProxySettings | None = None,
|
|
protocol: ProtocolVersion | None = None,
|
|
will: Will | None = None,
|
|
clean_session: bool | None = None,
|
|
transport: str = "tcp",
|
|
keepalive: int = 60,
|
|
bind_address: str = "",
|
|
bind_port: int = 0,
|
|
clean_start: int = mqtt.MQTT_CLEAN_START_FIRST_ONLY,
|
|
properties: Properties | None = None,
|
|
message_retry_set: int = 20,
|
|
socket_options: Iterable[SocketOption] | None = None,
|
|
max_concurrent_outgoing_calls: int | None = None,
|
|
websocket_path: str | None = None,
|
|
websocket_headers: WebSocketHeaders | None = None,
|
|
):
|
|
self._hostname = hostname
|
|
self._port = port
|
|
self._keepalive = keepalive
|
|
self._bind_address = bind_address
|
|
self._bind_port = bind_port
|
|
self._clean_start = clean_start
|
|
self._properties = properties
|
|
self._loop = asyncio.get_event_loop()
|
|
self._connected: asyncio.Future[int | mqtt.ReasonCodes] = asyncio.Future()
|
|
self._disconnected: asyncio.Future[
|
|
int | mqtt.ReasonCodes | None
|
|
] = asyncio.Future()
|
|
# Pending subscribe, unsubscribe, and publish calls
|
|
self._pending_subscribes: dict[
|
|
int, asyncio.Future[tuple[int] | list[mqtt.ReasonCodes]]
|
|
] = {}
|
|
self._pending_unsubscribes: dict[int, asyncio.Event] = {}
|
|
self._pending_publishes: dict[int, asyncio.Event] = {}
|
|
self._pending_calls_threshold: int = 10
|
|
self._misc_task: asyncio.Task[None] | None = None
|
|
|
|
# List of all callbacks to call when a message is received
|
|
self._on_message_callbacks: list[Callable[[Message], None]] = []
|
|
self._unfiltered_messages_callback: Callable[
|
|
[mqtt.Client, Any, mqtt.MQTTMessage], None
|
|
] | None = None
|
|
|
|
self._outgoing_calls_sem: asyncio.Semaphore | None
|
|
if max_concurrent_outgoing_calls is not None:
|
|
self._outgoing_calls_sem = asyncio.Semaphore(max_concurrent_outgoing_calls)
|
|
else:
|
|
self._outgoing_calls_sem = None
|
|
|
|
if protocol is None:
|
|
protocol = ProtocolVersion.V311
|
|
|
|
self._client: mqtt.Client = mqtt.Client(
|
|
client_id=client_id,
|
|
protocol=protocol,
|
|
clean_session=clean_session,
|
|
transport=transport,
|
|
reconnect_on_failure=False,
|
|
)
|
|
self._client.on_connect = self._on_connect
|
|
self._client.on_disconnect = self._on_disconnect
|
|
self._client.on_subscribe = self._on_subscribe
|
|
self._client.on_unsubscribe = self._on_unsubscribe
|
|
self._client.on_message = self._on_message
|
|
self._client.on_publish = self._on_publish
|
|
# Callbacks for custom event loop
|
|
self._client.on_socket_open = self._on_socket_open
|
|
self._client.on_socket_close = self._on_socket_close
|
|
self._client.on_socket_register_write = self._on_socket_register_write
|
|
self._client.on_socket_unregister_write = self._on_socket_unregister_write
|
|
|
|
if logger is None:
|
|
logger = MQTT_LOGGER
|
|
self._client.enable_logger(logger)
|
|
|
|
if username is not None:
|
|
self._client.username_pw_set(username=username, password=password)
|
|
|
|
if tls_context is not None:
|
|
self._client.tls_set_context(tls_context)
|
|
|
|
if tls_params is not None:
|
|
self._client.tls_set(
|
|
ca_certs=tls_params.ca_certs,
|
|
certfile=tls_params.certfile,
|
|
keyfile=tls_params.keyfile,
|
|
cert_reqs=tls_params.cert_reqs,
|
|
tls_version=tls_params.tls_version,
|
|
ciphers=tls_params.ciphers,
|
|
keyfile_password=tls_params.keyfile_password,
|
|
)
|
|
|
|
if proxy is not None:
|
|
self._client.proxy_set(**proxy.proxy_args)
|
|
|
|
if websocket_path is not None:
|
|
self._client.ws_set_options(path=websocket_path, headers=websocket_headers)
|
|
|
|
if will is not None:
|
|
self._client.will_set(
|
|
will.topic, will.payload, will.qos, will.retain, will.properties
|
|
)
|
|
|
|
self._client.message_retry_set(message_retry_set)
|
|
if socket_options is None:
|
|
socket_options = ()
|
|
self._socket_options = tuple(socket_options)
|
|
|
|
@property
|
|
def id( # noqa: A003 # TODO: When doing BREAKING CHANGES rename to avoid shadowing builtin id
|
|
self,
|
|
) -> str:
|
|
"""Return the client ID.
|
|
|
|
Note that paho-mqtt stores the client ID as `bytes` internally.
|
|
We assume that the client ID is a UTF8-encoded string and decode
|
|
it first.
|
|
"""
|
|
return cast(bytes, self._client._client_id).decode() # type: ignore[attr-defined]
|
|
|
|
@property
|
|
def _pending_calls(self) -> Generator[int, None, None]:
|
|
"""Yield all message IDs with pending calls."""
|
|
yield from self._pending_subscribes.keys()
|
|
yield from self._pending_unsubscribes.keys()
|
|
yield from self._pending_publishes.keys()
|
|
|
|
async def connect(self, *, timeout: int = 10) -> None:
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
|
|
# [3] Run connect() within an executor thread, since it blocks on socket
|
|
# connection for up to `keepalive` seconds: https://git.io/Jt5Yc
|
|
await loop.run_in_executor(
|
|
None,
|
|
self._client.connect,
|
|
self._hostname,
|
|
self._port,
|
|
self._keepalive,
|
|
self._bind_address,
|
|
self._bind_port,
|
|
self._clean_start,
|
|
self._properties,
|
|
)
|
|
client_socket = self._client.socket()
|
|
_set_client_socket_defaults(client_socket, self._socket_options)
|
|
# paho.mqtt.Client.connect may raise one of several exceptions.
|
|
# We convert all of them to the common MqttError for user convenience.
|
|
# See: https://github.com/eclipse/paho.mqtt.python/blob/v1.5.0/src/paho/mqtt/client.py#L1770
|
|
except (OSError, mqtt.WebsocketConnectionError) as error:
|
|
raise MqttError(str(error)) from None
|
|
await self._wait_for(self._connected, timeout=timeout)
|
|
|
|
async def disconnect(self, *, timeout: int = 10) -> None:
|
|
rc = self._client.disconnect()
|
|
# Early out on error
|
|
if rc != mqtt.MQTT_ERR_SUCCESS:
|
|
raise MqttCodeError(rc, "Could not disconnect")
|
|
# Wait for acknowledgement
|
|
await self._wait_for(self._disconnected, timeout=timeout)
|
|
|
|
async def force_disconnect(self) -> None:
|
|
if not self._disconnected.done():
|
|
self._disconnected.set_result(None)
|
|
|
|
@_outgoing_call
|
|
async def subscribe(
|
|
self,
|
|
topic: SubscribeTopic,
|
|
qos: int = 0,
|
|
options: mqtt.SubscribeOptions | None = None,
|
|
properties: Properties | None = None,
|
|
*args: Any,
|
|
timeout: int = 10,
|
|
**kwargs: Any,
|
|
) -> tuple[int] | list[mqtt.ReasonCodes]:
|
|
result, mid = self._client.subscribe(
|
|
topic, qos, options, properties, *args, **kwargs
|
|
)
|
|
# Early out on error
|
|
if result != mqtt.MQTT_ERR_SUCCESS:
|
|
raise MqttCodeError(result, "Could not subscribe to topic")
|
|
# Create future for when the on_subscribe callback is called
|
|
callback_result: asyncio.Future[
|
|
tuple[int] | list[mqtt.ReasonCodes]
|
|
] = asyncio.Future()
|
|
with self._pending_call(mid, callback_result, self._pending_subscribes):
|
|
# Wait for callback_result
|
|
return await self._wait_for(callback_result, timeout=timeout)
|
|
|
|
@_outgoing_call
|
|
async def unsubscribe(
|
|
self,
|
|
topic: str | list[str],
|
|
properties: Properties | None = None,
|
|
*args: Any,
|
|
timeout: int = 10,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
result, mid = self._client.unsubscribe(topic, properties, *args, **kwargs)
|
|
# Early out on error
|
|
if result != mqtt.MQTT_ERR_SUCCESS:
|
|
raise MqttCodeError(result, "Could not unsubscribe from topic")
|
|
# Create event for when the on_unsubscribe callback is called
|
|
confirmation = asyncio.Event()
|
|
with self._pending_call(mid, confirmation, self._pending_unsubscribes):
|
|
# Wait for confirmation
|
|
await self._wait_for(confirmation.wait(), timeout=timeout)
|
|
|
|
@_outgoing_call
|
|
async def publish(
|
|
self,
|
|
topic: str,
|
|
payload: PayloadType = None,
|
|
qos: int = 0,
|
|
retain: bool = False,
|
|
properties: Properties | None = None,
|
|
*args: Any,
|
|
timeout: int = 10,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
info = self._client.publish(
|
|
topic, payload, qos, retain, properties, *args, **kwargs
|
|
) # [2]
|
|
# Early out on error
|
|
if info.rc != mqtt.MQTT_ERR_SUCCESS:
|
|
raise MqttCodeError(info.rc, "Could not publish message")
|
|
# Early out on immediate success
|
|
if info.is_published():
|
|
return
|
|
# Create event for when the on_publish callback is called
|
|
confirmation = asyncio.Event()
|
|
with self._pending_call(info.mid, confirmation, self._pending_publishes):
|
|
# Wait for confirmation
|
|
await self._wait_for(confirmation.wait(), timeout=timeout)
|
|
|
|
@asynccontextmanager
|
|
async def filtered_messages(
|
|
self, topic_filter: str, *, queue_maxsize: int = 0
|
|
) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]:
|
|
"""Return async generator of messages that match the given filter."""
|
|
MQTT_LOGGER.warning(
|
|
"filtered_messages() is deprecated and will be removed in a future version."
|
|
" Use messages() together with Topic.matches() instead."
|
|
)
|
|
callback, generator = self._deprecated_callback_and_generator(
|
|
log_context=f'topic_filter="{topic_filter}"', queue_maxsize=queue_maxsize
|
|
)
|
|
try:
|
|
self._client.message_callback_add(topic_filter, callback)
|
|
# Back to the caller (run whatever is inside the with statement)
|
|
yield generator
|
|
finally:
|
|
# We are exiting the with statement. Remove the topic filter.
|
|
self._client.message_callback_remove(topic_filter)
|
|
|
|
@asynccontextmanager
|
|
async def unfiltered_messages(
|
|
self, *, queue_maxsize: int = 0
|
|
) -> AsyncGenerator[AsyncGenerator[mqtt.MQTTMessage, None], None]:
|
|
"""Return async generator of all messages that are not caught in filters."""
|
|
MQTT_LOGGER.warning(
|
|
"unfiltered_messages() is deprecated and will be removed in a future"
|
|
" version. Use messages() instead."
|
|
)
|
|
# Early out
|
|
if self._unfiltered_messages_callback is not None:
|
|
raise RuntimeError(
|
|
"Only a single unfiltered_messages generator can be used at a time."
|
|
)
|
|
callback, generator = self._deprecated_callback_and_generator(
|
|
log_context="unfiltered", queue_maxsize=queue_maxsize
|
|
)
|
|
try:
|
|
self._unfiltered_messages_callback = callback
|
|
# Back to the caller (run whatever is inside the with statement)
|
|
yield generator
|
|
finally:
|
|
# We are exiting the with statement. Unset the callback.
|
|
self._unfiltered_messages_callback = None
|
|
|
|
@asynccontextmanager
|
|
async def messages(
|
|
self, *, queue_maxsize: int = 0
|
|
) -> AsyncGenerator[AsyncGenerator[Message, None], None]:
|
|
"""Return async generator of incoming messages.
|
|
|
|
Use queue_maxsize to restrict the queue size. If the queue is full,
|
|
incoming messages will be discarded (and a warning is logged).
|
|
If queue_maxsize is less than or equal to zero, the queue size is infinite.
|
|
"""
|
|
callback, generator = self._callback_and_generator(queue_maxsize=queue_maxsize)
|
|
try:
|
|
# Add to the list of callbacks to call when a message is received
|
|
self._on_message_callbacks.append(callback)
|
|
# Back to the caller (run whatever is inside the with statement)
|
|
yield generator
|
|
finally:
|
|
# We are exiting the with statement. Remove the callback from the list.
|
|
self._on_message_callbacks.remove(callback)
|
|
|
|
def _deprecated_callback_and_generator(
|
|
self, *, log_context: str, queue_maxsize: int = 0
|
|
) -> tuple[
|
|
Callable[[mqtt.Client, Any, mqtt.MQTTMessage], None],
|
|
AsyncGenerator[mqtt.MQTTMessage, None],
|
|
]:
|
|
# Queue to hold the incoming messages
|
|
messages: asyncio.Queue[mqtt.MQTTMessage] = asyncio.Queue(maxsize=queue_maxsize)
|
|
# Callback for the underlying API
|
|
def _put_in_queue(
|
|
client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage
|
|
) -> None:
|
|
try:
|
|
messages.put_nowait(message)
|
|
except asyncio.QueueFull:
|
|
MQTT_LOGGER.warning(
|
|
f"[{log_context}] Message queue is full. Discarding message."
|
|
)
|
|
|
|
# The generator that we give to the caller
|
|
async def _message_generator() -> AsyncGenerator[mqtt.MQTTMessage, None]:
|
|
# Forward all messages from the queue
|
|
while True:
|
|
# Wait until we either:
|
|
# 1. Receive a message
|
|
# 2. Disconnect from the broker
|
|
get: asyncio.Task[mqtt.MQTTMessage] = self._loop.create_task(
|
|
messages.get()
|
|
)
|
|
try:
|
|
done, _ = await asyncio.wait(
|
|
(get, self._disconnected), return_when=asyncio.FIRST_COMPLETED
|
|
)
|
|
except asyncio.CancelledError:
|
|
# If the asyncio.wait is cancelled, we must make sure
|
|
# to also cancel the underlying tasks.
|
|
get.cancel()
|
|
raise
|
|
if get in done:
|
|
# We received a message. Return the result.
|
|
yield get.result()
|
|
else:
|
|
# We got disconnected from the broker. Cancel the "get" task.
|
|
get.cancel()
|
|
# Stop the generator with the following exception
|
|
raise MqttError("Disconnected during message iteration")
|
|
|
|
return _put_in_queue, _message_generator()
|
|
|
|
def _callback_and_generator(
|
|
self, *, queue_maxsize: int = 0
|
|
) -> tuple[Callable[[Message], None], AsyncGenerator[Message, None]]:
|
|
# Queue to hold the incoming messages
|
|
messages: asyncio.Queue[Message] = asyncio.Queue(maxsize=queue_maxsize)
|
|
|
|
def _callback(message: Message) -> None:
|
|
"""Put the new message in the queue."""
|
|
try:
|
|
messages.put_nowait(message)
|
|
except asyncio.QueueFull:
|
|
MQTT_LOGGER.warning("Message queue is full. Discarding message.")
|
|
|
|
async def _generator() -> AsyncGenerator[Message, None]:
|
|
"""Forward all messages from the message queue."""
|
|
while True:
|
|
# Wait until we either:
|
|
# 1. Receive a message
|
|
# 2. Disconnect from the broker
|
|
get: asyncio.Task[Message] = self._loop.create_task(messages.get())
|
|
try:
|
|
done, _ = await asyncio.wait(
|
|
(get, self._disconnected), return_when=asyncio.FIRST_COMPLETED
|
|
)
|
|
except asyncio.CancelledError:
|
|
# If the asyncio.wait is cancelled, we must make sure
|
|
# to also cancel the underlying tasks.
|
|
get.cancel()
|
|
raise
|
|
if get in done:
|
|
# We received a message. Return the result.
|
|
yield get.result()
|
|
else:
|
|
# We got disconnected from the broker. Cancel the "get" task.
|
|
get.cancel()
|
|
# Stop the generator with the following exception
|
|
raise MqttError("Disconnected during message iteration")
|
|
|
|
return _callback, _generator()
|
|
|
|
async def _wait_for(
|
|
self, fut: Awaitable[T], timeout: float | None, **kwargs: Any
|
|
) -> T:
|
|
try:
|
|
return await asyncio.wait_for(fut, timeout=timeout, **kwargs)
|
|
except asyncio.TimeoutError:
|
|
raise MqttError("Operation timed out") from None
|
|
|
|
@contextmanager
|
|
def _pending_call(
|
|
self, mid: int, value: T, pending_dict: dict[int, T]
|
|
) -> Iterator[None]:
|
|
if mid in self._pending_calls:
|
|
raise RuntimeError(
|
|
f'There already exists a pending call for message ID "{mid}"'
|
|
)
|
|
pending_dict[mid] = value # [1]
|
|
try:
|
|
# Log a warning if there is a concerning number of pending calls
|
|
pending = len(list(self._pending_calls))
|
|
if pending > self._pending_calls_threshold:
|
|
MQTT_LOGGER.warning(f"There are {pending} pending publish calls.")
|
|
# Back to the caller (run whatever is inside the with statement)
|
|
yield
|
|
finally:
|
|
# The normal procedure is:
|
|
# * We add the item at [1]
|
|
# * A callback will remove the item
|
|
#
|
|
# However, if the callback doesn't get called (e.g., due to a
|
|
# network error) we still need to remove the item from the dict.
|
|
try:
|
|
del pending_dict[mid]
|
|
except KeyError:
|
|
pass
|
|
|
|
def _on_connect(
|
|
self,
|
|
client: mqtt.Client,
|
|
userdata: Any,
|
|
flags: dict[str, int],
|
|
rc: int | mqtt.ReasonCodes,
|
|
properties: mqtt.Properties | None = None,
|
|
) -> None:
|
|
# Return early if already connected. Sometimes, paho-mqtt calls _on_connect
|
|
# multiple times. Maybe because we receive multiple CONNACK messages
|
|
# from the server. In any case, we return early so that we don't set
|
|
# self._connected twice (as it raises an asyncio.InvalidStateError).
|
|
if self._connected.done():
|
|
return
|
|
|
|
if rc == mqtt.CONNACK_ACCEPTED:
|
|
self._connected.set_result(rc)
|
|
else:
|
|
self._connected.set_exception(MqttConnectError(rc))
|
|
|
|
def _on_disconnect(
|
|
self,
|
|
client: mqtt.Client,
|
|
userdata: Any,
|
|
rc: int | mqtt.ReasonCodes | None,
|
|
properties: mqtt.Properties | None = None,
|
|
) -> None:
|
|
# Return early if the disconnect is already acknowledged.
|
|
# Sometimes (e.g., due to timeouts), paho-mqtt calls _on_disconnect
|
|
# twice. We return early to avoid setting self._disconnected twice
|
|
# (as it raises an asyncio.InvalidStateError).
|
|
if self._disconnected.done():
|
|
return
|
|
# Return early if we are not connected yet. This avoids calling
|
|
# `_disconnected.set_exception` with an exception that will never
|
|
# be retrieved (since `__aexit__` won't get called if `__aenter__`
|
|
# fails). In turn, this avoids asyncio debug messages like the
|
|
# following:
|
|
#
|
|
# "[asyncio] Future exception was never retrieved"
|
|
#
|
|
# See also: https://docs.python.org/3/library/asyncio-dev.html#detect-never-retrieved-exceptions
|
|
if not self._connected.done() or self._connected.exception() is not None:
|
|
return
|
|
if rc == mqtt.MQTT_ERR_SUCCESS:
|
|
self._disconnected.set_result(rc)
|
|
else:
|
|
self._disconnected.set_exception(MqttCodeError(rc, "Unexpected disconnect"))
|
|
|
|
def _on_subscribe(
|
|
self,
|
|
client: mqtt.Client,
|
|
userdata: Any,
|
|
mid: int,
|
|
granted_qos: tuple[int] | list[mqtt.ReasonCodes],
|
|
properties: mqtt.Properties | None = None,
|
|
) -> None:
|
|
try:
|
|
fut = self._pending_subscribes.pop(mid)
|
|
if not fut.done():
|
|
fut.set_result(granted_qos)
|
|
except KeyError:
|
|
MQTT_LOGGER.error(f'Unexpected message ID "{mid}" in on_subscribe callback')
|
|
|
|
def _on_unsubscribe(
|
|
self,
|
|
client: mqtt.Client,
|
|
userdata: Any,
|
|
mid: int,
|
|
properties: mqtt.Properties | None = None,
|
|
reason_codes: list[mqtt.ReasonCodes] | mqtt.ReasonCodes | None = None,
|
|
) -> None:
|
|
try:
|
|
self._pending_unsubscribes.pop(mid).set()
|
|
except KeyError:
|
|
MQTT_LOGGER.error(
|
|
f'Unexpected message ID "{mid}" in on_unsubscribe callback'
|
|
)
|
|
|
|
def _on_message(
|
|
self, client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage
|
|
) -> None:
|
|
# Call the deprecated unfiltered_messages callback
|
|
if self._unfiltered_messages_callback is not None:
|
|
self._unfiltered_messages_callback(client, userdata, message)
|
|
# Convert the paho.mqtt message into our own Message type
|
|
m = Message._from_paho_message(message)
|
|
for callback in self._on_message_callbacks:
|
|
callback(m)
|
|
|
|
def _on_publish(self, client: mqtt.Client, userdata: Any, mid: int) -> None:
|
|
try:
|
|
self._pending_publishes.pop(mid).set()
|
|
except KeyError:
|
|
# Do nothing since [2] may call on_publish before it even returns.
|
|
# That is, the message may already be published before we even get a
|
|
# chance to set up the 'pending_call' logic.
|
|
pass
|
|
|
|
def _on_socket_open(
|
|
self, client: mqtt.Client, userdata: Any, sock: _PahoSocket
|
|
) -> None:
|
|
def callback() -> None:
|
|
# client.loop_read() may raise an exception, such as BadPipe. It's
|
|
# usually a sign that the underlaying connection broke, therefore we
|
|
# disconnect straight away
|
|
try:
|
|
client.loop_read()
|
|
except Exception as exc:
|
|
if not self._disconnected.done():
|
|
self._disconnected.set_exception(exc)
|
|
|
|
self._loop.add_reader(sock.fileno(), callback)
|
|
# paho-mqtt calls this function from the executor thread on which we've called
|
|
# `self._client.connect()` (see [3]), so we create a callback function to
|
|
# schedule `_misc_loop()` and run it on the loop thread-safely.
|
|
def create_task_callback() -> None:
|
|
self._misc_task = self._loop.create_task(self._misc_loop())
|
|
|
|
self._loop.call_soon_threadsafe(create_task_callback)
|
|
|
|
def _on_socket_close(
|
|
self, client: mqtt.Client, userdata: Any, sock: _PahoSocket
|
|
) -> None:
|
|
fileno = sock.fileno()
|
|
if fileno > -1:
|
|
self._loop.remove_reader(fileno)
|
|
if self._misc_task is not None and not self._misc_task.done():
|
|
self._loop.call_soon_threadsafe(self._misc_task.cancel)
|
|
|
|
def _on_socket_register_write(
|
|
self, client: mqtt.Client, userdata: Any, sock: _PahoSocket
|
|
) -> None:
|
|
def callback() -> None:
|
|
# client.loop_write() may raise an exception, such as BadPipe. It's
|
|
# usually a sign that the underlaying connection broke, therefore we
|
|
# disconnect straight away
|
|
try:
|
|
client.loop_write()
|
|
except Exception as exc:
|
|
if not self._disconnected.done():
|
|
self._disconnected.set_exception(exc)
|
|
|
|
self._loop.add_writer(sock, callback)
|
|
|
|
def _on_socket_unregister_write(
|
|
self, client: mqtt.Client, userdata: Any, sock: _PahoSocket
|
|
) -> None:
|
|
self._loop.remove_writer(sock)
|
|
|
|
async def _misc_loop(self) -> None:
|
|
while self._client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
|
|
await asyncio.sleep(1)
|
|
|
|
async def __aenter__(self) -> "Client":
|
|
"""Connect to the broker."""
|
|
await self.connect()
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc: BaseException | None,
|
|
tb: TracebackType | None,
|
|
) -> None:
|
|
"""Disconnect from the broker."""
|
|
# Early out if already disconnected...
|
|
if self._disconnected.done():
|
|
disc_exc = self._disconnected.exception()
|
|
if disc_exc is not None:
|
|
# ...by raising the error that caused the disconnect
|
|
raise disc_exc
|
|
# ...by returning since the disconnect was intentional
|
|
return
|
|
# Try to gracefully disconnect from the broker
|
|
try:
|
|
await self.disconnect()
|
|
except MqttError as error:
|
|
# We tried to be graceful. Now there is no mercy.
|
|
MQTT_LOGGER.warning(
|
|
f'Could not gracefully disconnect due to "{error}". Forcing'
|
|
" disconnection."
|
|
)
|
|
await self.force_disconnect()
|
|
|
|
|
|
def _set_client_socket_defaults(
|
|
client_socket: _PahoSocket | None, socket_options: Iterable[SocketOption]
|
|
) -> None:
|
|
# Note that socket may be None if, e.g., the username and
|
|
# password combination didn't work. In this case, we return early.
|
|
if client_socket is None:
|
|
return
|
|
# Furthermore, paho sometimes gives us a socket wrapper instead of
|
|
# the raw socket. E.g., for WebSocket-based connections.
|
|
if not isinstance(client_socket, socket.socket):
|
|
return
|
|
# At this point, we know that we got an actual socket. We change
|
|
# some of the default options.
|
|
for socket_option in socket_options:
|
|
client_socket.setsockopt(*socket_option)
|