#####################################################################################
#
# Copyright (c) typedef int GmbH
# SPDX-License-Identifier: EUPL-1.2
#
#####################################################################################
import warnings
import attr
from attr.validators import instance_of, optional
from bitstring import pack
from ._utils import ParseFailure, SerialisationFailure, build_header, build_string, read_prefixed_data, read_string
@attr.s
class Failure(object):
    """
    Simple value object describing a failure.
    """

    # Explanation of the failure. No validator is attached, so any value is
    # accepted; it is None when no reason was supplied.
    reason = attr.ib(default=None)
@attr.s
class Disconnect(object):
    """
    Payload-free packet serialised with header code 14 (MQTT DISCONNECT).
    """

    def serialise(self):
        """
        Produce the on-wire bytes for this packet.

        The result is just the header built for code 14 with all four flags
        cleared and a payload length of zero.
        """
        cleared_flags = (False,) * 4
        return build_header(14, cleared_flags, 0)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from its header flags.

        ``data`` is accepted but never read, since there is no payload.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags == (False, False, False, False):
            return cls()
        raise ParseFailure(cls, "Bad flags")
@attr.s
class PingRESP(object):
    """
    Payload-free packet serialised with header code 13 (MQTT PINGRESP).
    """

    def serialise(self):
        """
        Produce the on-wire bytes for this packet.

        Only a header is emitted: code 13, every flag cleared, zero-length
        payload.
        """
        cleared_flags = (False,) * 4
        return build_header(13, cleared_flags, 0)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from its header flags.

        ``data`` is accepted but never read, since there is no payload.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags == (False, False, False, False):
            return cls()
        raise ParseFailure(cls, "Bad flags")
@attr.s
class PingREQ(object):
    """
    Payload-free packet serialised with header code 12 (MQTT PINGREQ).
    """

    def serialise(self):
        """
        Produce the on-wire bytes for this packet.

        Only a header is emitted: code 12, every flag cleared, zero-length
        payload.
        """
        cleared_flags = (False,) * 4
        return build_header(12, cleared_flags, 0)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from its header flags.

        ``data`` is accepted but never read, since there is no payload.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags == (False, False, False, False):
            return cls()
        raise ParseFailure(cls, "Bad flags")
@attr.s
class UnsubACK(object):
    """
    Packet serialised with header code 11 (MQTT UNSUBACK); its payload is a
    single packet identifier.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 11, all flags cleared)
        followed by the payload.
        """
        body = self._make_payload()
        return build_header(11, (False, False, False, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the packet identifier as a uint16.
        """
        return pack("uint:16", self.packet_identifier).bytes

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")
        return cls(packet_identifier=data.read("uint:16"))
@attr.s
class Unsubscribe(object):
    """
    Packet serialised with header code 10 (MQTT UNSUBSCRIBE): a packet
    identifier followed by one or more topic strings.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))
    # List of text topics; each element is type-checked at serialisation time.
    topics = attr.ib(validator=instance_of(list))

    def serialise(self):
        """
        Assemble this into an on-wire message.

        The header is built with code 10 and flags
        ``(False, False, True, False)``.
        """
        payload = self._make_payload()
        header = build_header(10, (False, False, True, False), len(payload))
        return header + payload

    def _make_payload(self):
        """
        Build the payload from its constituent parts.

        Raises SerialisationFailure if any topic is not a text string.
        """
        parts = [pack("uint:16", self.packet_identifier).bytes]
        for topic in self.topics:
            # ``str`` is the text type; the former ``unicode`` name is not
            # defined on Python 3 and raised NameError here.
            if not isinstance(topic, str):
                raise SerialisationFailure(self, "Topics must be Unicode")
            parts.append(build_string(topic))
        return b"".join(parts)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Disassemble from an on-wire message.

        Reads the packet identifier, then topic strings until the stream is
        exhausted.

        Raises ParseFailure if ``flags`` is not
        ``(False, False, True, False)`` or if no topic follows the packet
        identifier.
        """
        if flags != (False, False, True, False):
            raise ParseFailure(cls, "Bad flags")

        packet_identifier = data.read("uint:16")

        topics = []
        while data.bitpos != len(data):
            topics.append(read_string(data))

        if not topics:
            raise ParseFailure(cls, "Must contain a payload.")

        return cls(packet_identifier=packet_identifier, topics=topics)
@attr.s
class PubCOMP(object):
    """
    Packet serialised with header code 7 (MQTT PUBCOMP); its payload is a
    single packet identifier.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 7, all flags cleared)
        followed by the payload.
        """
        body = self._make_payload()
        return build_header(7, (False, False, False, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the packet identifier as a uint16.
        """
        return pack("uint:16", self.packet_identifier).bytes

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")
        return cls(packet_identifier=data.read("uint:16"))
@attr.s
class PubREL(object):
    """
    Packet serialised with header code 6 (MQTT PUBREL); its payload is a
    single packet identifier.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 6, flags
        ``(False, False, True, False)``) followed by the payload.
        """
        body = self._make_payload()
        return build_header(6, (False, False, True, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the packet identifier as a uint16.
        """
        return pack("uint:16", self.packet_identifier).bytes

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        Raises ParseFailure if ``flags`` is not
        ``(False, False, True, False)``.
        """
        if flags != (False, False, True, False):
            raise ParseFailure(cls, "Bad flags")
        return cls(packet_identifier=data.read("uint:16"))
@attr.s
class PubREC(object):
    """
    Packet serialised with header code 5 (MQTT PUBREC); its payload is a
    single packet identifier.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 5, all flags cleared)
        followed by the payload.
        """
        body = self._make_payload()
        return build_header(5, (False, False, False, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the packet identifier as a uint16.
        """
        return pack("uint:16", self.packet_identifier).bytes

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")
        return cls(packet_identifier=data.read("uint:16"))
@attr.s
class PubACK(object):
    """
    Packet serialised with header code 4 (MQTT PUBACK); its payload is a
    single packet identifier.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 4, all flags cleared)
        followed by the payload.
        """
        body = self._make_payload()
        return build_header(4, (False, False, False, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the packet identifier as a uint16.
        """
        return pack("uint:16", self.packet_identifier).bytes

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")
        return cls(packet_identifier=data.read("uint:16"))
@attr.s
class Publish(object):
    """
    Packet serialised with header code 3 (MQTT PUBLISH).

    The four header flags are, in order: duplicate, two QoS bits, retain.
    The payload is the topic name, an optional packet identifier (QoS 1/2
    only) and the raw application payload.
    """

    # First header flag.
    duplicate = attr.ib(validator=instance_of(bool))
    # Must be 0, 1 or 2 for serialisation to succeed.
    qos_level = attr.ib(validator=instance_of(int))
    # Last header flag.
    retain = attr.ib(validator=instance_of(bool))
    # Text topic; ``str`` replaces the ``unicode`` name, which is not defined
    # on Python 3 and made the class definition fail with NameError.
    topic_name = attr.ib(validator=instance_of(str))
    # Raw application payload, appended verbatim.
    payload = attr.ib(validator=instance_of(bytes))
    # Required when qos_level > 0 and forbidden otherwise (checked when the
    # payload is built).
    packet_identifier = attr.ib(validator=optional(instance_of(int)), default=None)

    def serialise(self):
        """
        Assemble this into an on-wire message.

        Raises SerialisationFailure if the QoS level is not 0, 1 or 2, or if
        the packet identifier is inconsistent with the QoS level.
        """
        flags = [self.duplicate]

        # Encode the QoS level as two flag bits, most significant first.
        if self.qos_level == 0:
            flags.extend([False, False])
        elif self.qos_level == 1:
            flags.extend([False, True])
        elif self.qos_level == 2:
            flags.extend([True, False])
        else:
            raise SerialisationFailure(self, "QoS must be 0, 1, or 2")

        flags.append(self.retain)

        payload = self._make_payload()
        header = build_header(3, flags, len(payload))
        return header + payload

    def _make_payload(self):
        """
        Build the payload from its constituent parts.

        Raises SerialisationFailure if a packet identifier is present on a
        QoS 0 packet, or absent on a QoS 1/2 packet.
        """
        parts = [build_string(self.topic_name)]

        if self.packet_identifier:
            if self.qos_level > 0:
                parts.append(pack("uint:16", self.packet_identifier).bytes)
            else:
                raise SerialisationFailure(self, "Packet Identifier on non-QoS 1/2 packet")
        elif self.qos_level > 0:
            raise SerialisationFailure(self, "QoS level > 0 but no Packet Identifier")

        # Payload (bytes)
        parts.append(self.payload)
        return b"".join(parts)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Disassemble from an on-wire message.

        ``flags`` is the four-element header flag sequence and ``data`` the
        readable bit stream holding the payload.

        Raises ParseFailure if the two QoS bits do not encode 0, 1 or 2.
        """
        total_length = len(data)
        duplicate = flags[0]

        # Normalise to a tuple so that a list of flags is decoded too; before,
        # an unmatched slice left qos_level unbound and crashed further down.
        qos_bits = tuple(flags[1:3])
        if qos_bits == (False, False):
            qos_level = 0
        elif qos_bits == (False, True):
            qos_level = 1
        elif qos_bits == (True, False):
            qos_level = 2
        else:
            raise ParseFailure(cls, "Invalid QoS value")

        retain = flags[3]

        topic_name = read_string(data)

        # Only QoS 1 and 2 packets carry a packet identifier.
        if qos_level in [1, 2]:
            packet_identifier = data.read("uint:16")
        else:
            packet_identifier = None

        # Everything left in the stream is the application payload.
        payload = data.read(total_length - data.bitpos).bytes

        return cls(
            duplicate=duplicate,
            qos_level=qos_level,
            retain=retain,
            topic_name=topic_name,
            packet_identifier=packet_identifier,
            payload=payload,
        )
@attr.s
class SubACK(object):
    """
    Packet serialised with header code 9 (MQTT SUBACK): a packet identifier
    followed by one return code per byte.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))
    # List of codes, each written to the wire as an unsigned 8-bit integer.
    return_codes = attr.ib(validator=instance_of(list))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 9, all flags cleared)
        followed by the payload.
        """
        body = self._make_payload()
        return build_header(9, (False, False, False, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the uint16 packet identifier, then each
        return code as a uint8, in list order.
        """
        chunks = [pack("uint:16", self.packet_identifier).bytes]
        chunks.extend(pack("uint:8", code).bytes for code in self.return_codes)
        return b"".join(chunks)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        After the packet identifier, return codes are read one byte at a time
        until the stream is exhausted.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")

        packet_identifier = data.read("uint:16")

        codes = []
        while data.bitpos != len(data):
            codes.append(data.read("uint:8"))

        return cls(packet_identifier=packet_identifier, return_codes=codes)
@attr.s
class SubscriptionTopicRequest(object):
    """
    One topic-filter / maximum-QoS pair, serialised as a fragment of a larger
    message rather than as a packet of its own.
    """

    # Text topic filter; ``str`` replaces the ``unicode`` name, which is not
    # defined on Python 3 and made the class definition fail with NameError.
    topic_filter = attr.ib(validator=instance_of(str))
    # Packed into the low two bits of the byte that follows the topic filter.
    max_qos = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Assemble this into an on-wire message part.

        The result is the encoded topic filter followed by one byte holding
        six zeroed reserved bits and the two-bit maximum QoS.
        """
        topic_bytes = build_string(self.topic_filter)
        # Reserved section + max QoS
        qos_byte = pack("uint:6, uint:2", 0, self.max_qos).bytes
        return b"".join([topic_bytes, qos_byte])
@attr.s
class Subscribe(object):
    """
    Packet serialised with header code 8 (MQTT SUBSCRIBE): a packet
    identifier followed by one or more topic requests.
    """

    # Written to and read from the wire as an unsigned 16-bit integer.
    packet_identifier = attr.ib(validator=instance_of(int))
    # List of objects whose ``serialise()`` output is appended to the payload.
    topic_requests = attr.ib(validator=instance_of(list))

    def serialise(self):
        """
        Produce the on-wire bytes: header (code 8, flags
        ``(False, False, True, False)``) followed by the payload.
        """
        body = self._make_payload()
        return build_header(8, (False, False, True, False), len(body)) + body

    def _make_payload(self):
        """
        Return the payload bytes: the uint16 packet identifier, then every
        topic request serialised in list order.
        """
        chunks = [pack("uint:16", self.packet_identifier).bytes]
        chunks.extend(request.serialise() for request in self.topic_requests)
        return b"".join(chunks)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Rebuild the packet from header flags and a readable bit stream.

        At least one topic request is always parsed; further requests are
        parsed until the stream is exhausted.

        Raises ParseFailure if ``flags`` is not
        ``(False, False, True, False)``, if the six reserved bits of a request
        are non-zero, or if a request's QoS is not 0, 1 or 2.
        """
        if flags != (False, False, True, False):
            raise ParseFailure(cls, "Bad flags")

        packet_identifier = data.read("uint:16")

        requests = []
        while True:
            topic_filter = read_string(data)
            reserved = data.read("uint:6")
            max_qos = data.read("uint:2")

            if reserved:
                raise ParseFailure(cls, "Data in QoS Reserved area")
            if max_qos not in [0, 1, 2]:
                raise ParseFailure(cls, "Invalid QoS")

            requests.append(SubscriptionTopicRequest(topic_filter=topic_filter, max_qos=max_qos))

            # Stop once every bit of the stream has been consumed.
            if data.bitpos == len(data):
                break

        return cls(packet_identifier=packet_identifier, topic_requests=requests)
@attr.s
class ConnACK(object):
    """
    Packet serialised with header code 2 (MQTT CONNACK): a flags byte holding
    the session-present bit, followed by a one-byte return code.
    """

    # Lowest bit of the first payload byte.
    session_present = attr.ib(validator=instance_of(bool))
    # Written to and read from the wire as an unsigned 8-bit integer.
    return_code = attr.ib(validator=instance_of(int))

    def serialise(self):
        """
        Assemble this into an on-wire message.
        """
        payload = self._make_payload()
        header = build_header(2, (False, False, False, False), len(payload))
        return header + payload

    def _make_payload(self):
        """
        Build the payload from its constituent parts.
        """
        # Flags -- 7 bit reserved + Session Present flag
        flag_byte = pack("uint:7, bool", 0, self.session_present).bytes
        # Return code
        code_byte = pack("uint:8", self.return_code).bytes
        return b"".join([flag_byte, code_byte])

    @classmethod
    def deserialise(cls, flags, data):
        """
        Take an on-wire message and turn it into an instance of this class.

        Emits a warning (rather than failing) when the stream holds more data
        than the two bytes this packet uses.

        Raises ParseFailure if ``flags`` is anything other than four cleared
        flags, or if any of the seven reserved bits is set.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")

        if data.read(7).uint:
            raise ParseFailure(cls, "Reserved flag used.")

        session_present = data.read(1).bool
        return_code = data.read(8).uint
        built = cls(session_present=session_present, return_code=return_code)

        # XXX: Do some more verification, re conn flags

        total_bits = len(data)
        if data.bitpos != total_bits:
            # There's some wacky stuff going on here -- data they included, but
            # didn't put flags for, maybe?
            # Stream length and position are counted in bits, so convert to
            # bytes, and report the total length first to match the message.
            warnings.warn(
                "Quirky server CONNACK -- packet length was %d bytes but only had %d bytes of useful data"
                % (total_bits // 8, data.bitpos // 8)
            )

        return built
@attr.s
class ConnectFlags(object):
    """
    The single flags byte of a CONNECT message.

    Wire order, most significant bit first: username, password, will_retain,
    will_qos (two bits), will, clean_session, reserved.
    """

    username = attr.ib(validator=instance_of(bool), default=False)
    password = attr.ib(validator=instance_of(bool), default=False)
    will = attr.ib(validator=instance_of(bool), default=False)
    will_retain = attr.ib(validator=instance_of(bool), default=False)
    # Packed as a two-bit unsigned integer.
    will_qos = attr.ib(validator=instance_of(int), default=0)
    clean_session = attr.ib(validator=instance_of(bool), default=False)
    # Rejected by deserialise() when set.
    reserved = attr.ib(validator=instance_of(bool), default=False)

    def serialise(self):
        """
        Pack the flags into their one-byte on-wire form.
        """
        layout = "bool, bool, bool, uint:2, bool, bool, bool"
        values = (
            self.username,
            self.password,
            self.will_retain,
            self.will_qos,
            self.will,
            self.clean_session,
            self.reserved,
        )
        return pack(layout, *values).bytes

    @classmethod
    def deserialise(cls, data):
        """
        Rebuild the flags from a readable bit stream holding the flags byte.

        Raises ParseFailure if the reserved bit is set.
        """
        # Bits must be consumed in wire order, so read them one at a time.
        username = data.read(1).bool
        password = data.read(1).bool
        will_retain = data.read(1).bool
        will_qos = data.read(2).uint
        will = data.read(1).bool
        clean_session = data.read(1).bool
        reserved = data.read(1).bool

        built = cls(
            username=username,
            password=password,
            will_retain=will_retain,
            will_qos=will_qos,
            will=will,
            clean_session=clean_session,
            reserved=reserved,
        )

        # XXX: Do some more conformance checking here
        # Need to worry about invalid flag combinations

        if built.reserved:
            # MQTT-3.1.2-3, reserved flag must not be used
            raise ParseFailure(cls, "Reserved flag in CONNECT used")

        return built
@attr.s
class Connect(object):
    """
    Packet serialised with header code 1 (MQTT CONNECT).

    The payload carries the protocol name and level, the connect flags, the
    keep-alive time and the client ID, followed by the will, username and
    password fields when the corresponding flag is set.

    Text fields are validated against ``str``; the ``unicode`` name used
    before is not defined on Python 3 and made the class definition fail
    with NameError.
    """

    client_id = attr.ib(validator=instance_of(str))
    # Decides which of the optional fields below are written to the wire.
    flags = attr.ib(validator=instance_of(ConnectFlags))
    # Written to and read from the wire as an unsigned 16-bit integer.
    keep_alive = attr.ib(validator=instance_of(int), default=0)
    will_topic = attr.ib(validator=optional(instance_of(str)), default=None)
    will_message = attr.ib(validator=optional(instance_of(bytes)), default=None)
    username = attr.ib(validator=optional(instance_of(str)), default=None)
    password = attr.ib(validator=optional(instance_of(str)), default=None)

    def serialise(self):
        """
        Assemble this into an on-wire message.
        """
        payload = self._make_payload()
        header = build_header(1, (False, False, False, False), len(payload))
        return header + payload

    def _make_payload(self):
        """
        Build the payload from its constituent parts.

        The will, username and password fields are emitted only when the
        matching attribute of ``self.flags`` is set.
        """
        parts = [
            # Protocol name (MQTT)
            build_string("MQTT"),
            # Protocol Level (4 == 3.1.1)
            pack("uint:8", 4).bytes,
            # CONNECT flags
            self.flags.serialise(),
            # Keep Alive time
            pack("uint:16", self.keep_alive).bytes,
            # Client ID
            build_string(self.client_id),
        ]

        if self.flags.will:
            parts.append(build_string(self.will_topic))
            # Will message is a uint16 prefixed bytestring
            parts.append(pack("uint:16", len(self.will_message)).bytes)
            parts.append(self.will_message)

        if self.flags.username:
            parts.append(build_string(self.username))

        # Technically this should be binary data but we will only accept UTF-8
        if self.flags.password:
            parts.append(build_string(self.password))

        return b"".join(parts)

    @classmethod
    def deserialise(cls, flags, data):
        """
        Disassemble from an on-wire message.

        ``flags`` is the four-element header flag tuple and ``data`` the
        readable bit stream holding the payload. A warning is emitted (rather
        than failing) when data remains after all expected fields were read.

        Raises ParseFailure if the header flags are not all cleared, if the
        protocol name is not "MQTT", or if the protocol level is not 4.
        """
        if flags != (False, False, False, False):
            raise ParseFailure(cls, "Bad flags")

        protocol = read_string(data)
        if protocol != "MQTT":
            raise ParseFailure(cls, "Bad protocol name")

        protocol_level = data.read("uint:8")
        if protocol_level != 4:
            raise ParseFailure(cls, "Bad protocol level")

        # Kept under its own name so the header ``flags`` argument is not
        # shadowed.
        connect_flags = ConnectFlags.deserialise(data.read(8))

        # Keep alive, in seconds
        keep_alive = data.read("uint:16")

        # The client ID
        client_id = read_string(data)

        will_topic = None
        will_message = None
        if connect_flags.will:
            # MQTT-3.1.3-10, topic must be UTF-8
            will_topic = read_string(data)
            will_message = read_prefixed_data(data)

        # Username
        username = read_string(data) if connect_flags.username else None

        # Password
        password = read_string(data) if connect_flags.password else None

        total_bits = len(data)
        if data.bitpos != total_bits:
            # There's some wacky stuff going on here -- data they included, but
            # didn't put flags for, maybe?
            # Stream length and position are counted in bits, so convert to
            # bytes, and report the total length first to match the message.
            warnings.warn(
                "Quirky client CONNECT -- packet length was %d bytes but only had %d bytes of useful data"
                % (total_bits // 8, data.bitpos // 8)
            )

        # The event
        return cls(
            flags=connect_flags,
            keep_alive=keep_alive,
            client_id=client_id,
            will_topic=will_topic,
            will_message=will_message,
            username=username,
            password=password,
        )