Source code for dotbot.protocol

# SPDX-FileCopyrightText: 2022-present Inria
# SPDX-FileCopyrightText: 2022-present Alexandre Abadie <alexandre.abadie@inria.fr>
# SPDX-FileCopyrightText: 2023-present Filip Maksimovic <filip.maksimovic@inria.fr>
# SPDX-FileCopyrightText: 2024-present Diego Badillo <diego.badillo@sansano.usm.cl>
#
# SPDX-License-Identifier: BSD-3-Clause

"""Module for the Dotbot protocol API."""

import dataclasses
from abc import ABC, abstractmethod
from binascii import hexlify
from dataclasses import dataclass
from enum import Enum, IntEnum
from itertools import chain
from typing import List

# Protocol version: used as the default `version` of ProtocolHeader and is the
# only version accepted by ProtocolPayload.from_bytes.
PROTOCOL_VERSION = 8


class PayloadType(Enum):
    """Identifiers of the payload kinds carried by the DotBot protocol.

    The numeric value of each member is the single type byte that follows
    the header in a serialized payload.
    """

    CMD_MOVE_RAW = 0
    CMD_RGB_LED = 1
    LH2_RAW_DATA = 2
    LH2_LOCATION = 3
    ADVERTISEMENT = 4
    GPS_POSITION = 5
    DOTBOT_DATA = 6
    CONTROL_MODE = 7
    LH2_WAYPOINTS = 8
    GPS_WAYPOINTS = 9
    SAILBOT_DATA = 10
    CMD_XGO_ACTION = 11
    INVALID_PAYLOAD = 12  # Increase each time a new payload type is added
    DOTBOT_SIMULATOR_DATA = 250
class ApplicationType(IntEnum):
    """Application identifiers carried in the protocol header."""

    # Member names are mixed-case on purpose, hence the pylint overrides.
    DotBot = 0  # pylint: disable=invalid-name
    SailBot = 1  # pylint: disable=invalid-name
    Freebot = 2  # pylint: disable=invalid-name
    XGO = 3
class ControlModeType(IntEnum):
    """Control modes a DotBot can be switched to."""

    MANUAL = 0
    AUTO = 1
class ProtocolPayloadParserException(Exception):
    """Raised when a payload is invalid or of an unsupported kind."""
@dataclass
class ProtocolField:
    """Description of a single field of a payload."""

    # Integer value of the field.
    value: int = 0
    # Short label of the field.
    name: str = ""
    # Size of the field in bytes.
    length: int = 1
    # Whether the value is serialized as a signed integer.
    signed: bool = False
@dataclass
class ProtocolData(ABC):
    """Abstract base shared by all protocol payload data classes.

    Concrete subclasses describe their content through `fields` and provide
    a `from_bytes` static constructor.
    """

    @property
    @abstractmethod
    def fields(self) -> List[ProtocolField]:
        """Returns the fields of this data as a list of ProtocolField."""

    @staticmethod
    @abstractmethod
    def from_bytes(bytes_):
        """Builds a ProtocolData instance out of a bytearray."""
@dataclass
class ProtocolHeader(ProtocolData):
    """Header placed in front of every payload.

    The fields add up to 24 bytes: destination (8), source (8), swarm id (2),
    application (1), version (1) and message id (4).
    """

    destination: int = 0xFFFFFFFFFFFFFFFF
    source: int = 0x0000000000000000
    swarm_id: int = 0x0000
    application: ApplicationType = ApplicationType.DotBot
    version: int = PROTOCOL_VERSION
    msg_id: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the header fields in the order they are declared above."""
        layout = (
            (self.destination, "dst", 8),
            (self.source, "src", 8),
            (self.swarm_id, "swarm id", 2),
            (self.application, "app.", 1),
            (self.version, "ver.", 1),
            (self.msg_id, "msg id", 4),
        )
        return [
            ProtocolField(value, name=label, length=size)
            for value, label, size in layout
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds a header from its little-endian byte representation.

        Raises ValueError (from the ApplicationType lookup) when the
        application byte does not match a known application.
        """

        def _uint(start, stop):
            # All header fields are unsigned little-endian integers.
            return int.from_bytes(bytes_[start:stop], "little")

        return ProtocolHeader(
            destination=_uint(0, 8),
            source=_uint(8, 16),
            swarm_id=_uint(16, 18),
            application=ApplicationType(_uint(18, 19)),
            version=_uint(19, 20),
            msg_id=_uint(20, 24),
        )
@dataclass
class CommandMoveRaw(ProtocolData):
    """Dataclass that holds move raw command data fields.

    Each of the four values is described as a one-byte signed field.
    """

    left_x: int = 0
    left_y: int = 0
    right_x: int = 0
    right_y: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the four values as one-byte signed fields."""
        return [
            ProtocolField(self.left_x, name="lx", length=1, signed=True),
            ProtocolField(self.left_y, name="ly", length=1, signed=True),
            ProtocolField(self.right_x, name="rx", length=1, signed=True),
            ProtocolField(self.right_y, name="ry", length=1, signed=True),
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds a move raw command from its first four bytes.

        Every byte is decoded as a signed integer, matching the signed
        fields declared in `fields`; decoding them unsigned would turn
        negative values into 128..255, which cannot be serialized back.
        Missing trailing bytes decode to 0, i.e. the field default.
        """

        def _int8(index):
            # An empty slice (short input) converts to 0.
            return int.from_bytes(bytes_[index : index + 1], "little", signed=True)

        return CommandMoveRaw(
            left_x=_int8(0),
            left_y=_int8(1),
            right_x=_int8(2),
            right_y=_int8(3),
        )
@dataclass
class CommandRgbLed(ProtocolData):
    """RGB LED command made of one byte per color component."""

    red: int = 0
    green: int = 0
    blue: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the red, green and blue components as one-byte fields."""
        components = (
            (self.red, "red"),
            (self.green, "green"),
            (self.blue, "blue"),
        )
        return [ProtocolField(level, name=label) for level, label in components]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds an RGB LED command from the first three bytes."""
        red_green_blue = bytes_[:3]
        return CommandRgbLed(*red_green_blue)
@dataclass
class CommandXgoAction(ProtocolData):
    """XGO action command: a single one-byte action identifier."""

    action: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the action as a single one-byte field."""
        action_field = ProtocolField(self.action, name="action")
        return [action_field]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds an XGO action command from the first byte."""
        action = bytes_[0]
        return CommandXgoAction(action)
@dataclass
class Lh2RawLocation(ProtocolData):
    """One LH2 raw location: 8 bytes of bits, a polynomial index and an offset."""

    bits: int = 0x0000000000000000
    polynomial_index: int = 0x00
    offset: int = 0x00

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the bits, polynomial index and (signed) offset fields."""
        bits_field = ProtocolField(self.bits, name="bits", length=8)
        poly_field = ProtocolField(self.polynomial_index, name="poly", length=1)
        offset_field = ProtocolField(self.offset, name="off.", length=1, signed=True)
        return [bits_field, poly_field, offset_field]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds a raw location from its 10 little-endian bytes."""
        return Lh2RawLocation(
            bits=int.from_bytes(bytes_[0:8], "little"),
            polynomial_index=int.from_bytes(bytes_[8:9], "little"),
            # Only the offset is a signed value.
            offset=int.from_bytes(bytes_[9:10], "little", signed=True),
        )
@dataclass
class Lh2RawData(ProtocolData):
    """LH2 raw data: a list of raw locations."""

    locations: List[Lh2RawLocation] = dataclasses.field(default_factory=list)

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the fields of every location, concatenated in order."""
        return list(
            chain.from_iterable(location.fields for location in self.locations)
        )

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds LH2 raw data from two consecutive 10-byte raw locations."""
        return Lh2RawData(
            [
                Lh2RawLocation.from_bytes(bytes_[start : start + 10])
                for start in (0, 10)
            ]
        )
@dataclass
class LH2Location(ProtocolData):
    """LH2 computed location: x, y and z, each a 4-byte unsigned integer."""

    pos_x: int = 0
    pos_y: int = 0
    pos_z: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the x, y and z coordinates as 4-byte fields."""
        coordinates = ((self.pos_x, "x"), (self.pos_y, "y"), (self.pos_z, "z"))
        return [
            ProtocolField(coordinate, name=axis, length=4)
            for coordinate, axis in coordinates
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds a location from three consecutive little-endian 4-byte values."""
        pos_x, pos_y, pos_z = (
            int.from_bytes(bytes_[start : start + 4], "little")
            for start in (0, 4, 8)
        )
        return LH2Location(pos_x, pos_y, pos_z)
@dataclass
class DotBotData(ProtocolData):
    """Direction followed by LH2 raw locations, as sent by the DotBot application."""

    # NOTE(review): 0xFFFF is outside the signed 16-bit range that `fields`
    # declares for the direction - confirm the intended default (e.g. -1).
    direction: int = 0xFFFF
    locations: List[Lh2RawLocation] = dataclasses.field(default_factory=list)

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the signed 2-byte direction followed by each location's fields."""
        direction_field = ProtocolField(
            self.direction, name="dir.", length=2, signed=True
        )
        location_fields = chain.from_iterable(
            location.fields for location in self.locations
        )
        return [direction_field, *location_fields]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds DotBot data from a signed direction and two 10-byte raw locations."""
        direction = int.from_bytes(bytes_[0:2], "little", signed=True)
        locations = [
            Lh2RawLocation.from_bytes(bytes_[start : start + 10])
            for start in (2, 12)
        ]
        return DotBotData(direction=direction, locations=locations)
@dataclass
class GPSPosition(ProtocolData):
    """GPS position: latitude and longitude, each a signed 4-byte integer."""

    latitude: int = 0
    longitude: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the latitude and longitude as signed 4-byte fields."""
        coordinates = (
            (self.latitude, "latitude"),
            (self.longitude, "longitude"),
        )
        return [
            ProtocolField(coordinate, name=label, length=4, signed=True)
            for coordinate, label in coordinates
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds a GPS position from two little-endian signed 4-byte values."""

        def _int32(start):
            return int.from_bytes(bytes_[start : start + 4], "little", signed=True)

        return GPSPosition(latitude=_int32(0), longitude=_int32(4))
@dataclass
class SailBotData(ProtocolData):
    """Data sent by the SailBot application.

    Holds the direction, the GPS latitude/longitude, the wind angle and the
    rudder and sail angles.
    """

    direction: int = 0xFFFF
    latitude: int = 0
    longitude: int = 0
    wind_angle: int = 0xFFFF  # uint angles from 0 to 359
    rudder_angle: int = 0
    sail_angle: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the six SailBot fields with their byte size and signedness."""
        layout = (
            # (value, name, length, signed)
            (self.direction, "dir.", 2, False),
            (self.latitude, "latitude", 4, True),
            (self.longitude, "longitude", 4, True),
            (self.wind_angle, "wind ang", 2, False),
            (self.rudder_angle, "rud.", 1, True),
            (self.sail_angle, "sail.", 1, True),
        )
        return [
            ProtocolField(value, name=label, length=size, signed=is_signed)
            for value, label, size, is_signed in layout
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds SailBot data from its 14 little-endian bytes."""

        def _read(start, stop, signed):
            return int.from_bytes(bytes_[start:stop], "little", signed=signed)

        return SailBotData(
            direction=_read(0, 2, False),
            latitude=_read(2, 6, True),
            longitude=_read(6, 10, True),
            wind_angle=_read(10, 12, False),
            rudder_angle=_read(12, 13, True),
            sail_angle=_read(13, 14, True),
        )
@dataclass
class DotBotSimulatorData(ProtocolData):
    """Dataclass that holds DotBot simulator data: theta and an x/y position."""

    theta: int = 0xFFFF
    pos_x: int = 0
    pos_y: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns theta (2 bytes) followed by pos_x and pos_y (4 bytes each)."""
        layout = (
            (self.theta, "theta", 2),
            (self.pos_x, "pos_x", 4),
            (self.pos_y, "pos_y", 4),
        )
        return [
            ProtocolField(value, name=label, length=size)
            for value, label, size in layout
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds simulator data from its 10 little-endian unsigned bytes."""

        def _uint(start, stop):
            return int.from_bytes(bytes_[start:stop], "little")

        return DotBotSimulatorData(
            theta=_uint(0, 2),
            pos_x=_uint(2, 6),
            pos_y=_uint(6, 10),
        )
@dataclass
class ControlMode(ProtocolData):
    """Control mode message: a single one-byte mode."""

    mode: ControlModeType = ControlModeType.MANUAL

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the mode as a single one-byte field."""
        mode_field = ProtocolField(self.mode, name="mode")
        return [mode_field]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds a control mode message from the first byte.

        The byte is stored as is; it is not converted to a ControlModeType.
        """
        mode = bytes_[0]
        return ControlMode(mode)
@dataclass
class LH2Waypoints(ProtocolData):
    """Dataclass that holds a list of LH2 waypoints.

    Serialized as a one-byte waypoint count, a one-byte threshold and then
    the fields of each waypoint.
    """

    threshold: int
    waypoints: List[LH2Location] = dataclasses.field(default_factory=list)

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the count and threshold fields followed by every waypoint's fields."""
        return [
            ProtocolField(len(self.waypoints), name="len."),
            ProtocolField(value=self.threshold, name="thr."),
            *chain.from_iterable(waypoint.fields for waypoint in self.waypoints),
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds LH2 waypoints from a bytearray.

        The first byte is the number of waypoints, the second one the
        threshold; each waypoint then takes 12 bytes. Waypoints are returned
        as LH2Location instances, as declared by the `waypoints` annotation,
        so that `fields` works on a parsed instance.
        """
        waypoints_count = int(bytes_[0])
        threshold = int(bytes_[1])
        waypoints_bytes = bytes_[2:]
        waypoints = [
            LH2Location.from_bytes(waypoints_bytes[12 * idx : 12 * (idx + 1)])
            for idx in range(waypoints_count)
        ]
        return LH2Waypoints(threshold=threshold, waypoints=waypoints)
@dataclass
class GPSWaypoints(ProtocolData):
    """Dataclass that holds a list of GPS waypoints.

    Serialized as a one-byte waypoint count, a one-byte threshold and then
    the fields of each waypoint.
    """

    threshold: int
    waypoints: List[GPSPosition] = dataclasses.field(default_factory=list)

    @property
    def fields(self) -> List[ProtocolField]:
        """Returns the count and threshold fields followed by every waypoint's fields."""
        return [
            ProtocolField(len(self.waypoints), name="len."),
            ProtocolField(value=self.threshold, name="thr."),
            *chain.from_iterable(waypoint.fields for waypoint in self.waypoints),
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Builds GPS waypoints from a bytearray.

        The first byte is the number of waypoints, the second one the
        threshold; each waypoint then takes two 4-byte little-endian values.
        The values are decoded as signed integers, matching the signed
        latitude/longitude fields of GPSPosition, then divided by 1e6.

        NOTE(review): the parsed waypoints are (float, float) tuples, not
        GPSPosition instances, so `fields` cannot be used on the returned
        object - confirm which representation callers expect.
        """
        waypoints_count = int(bytes_[0])
        threshold = int(bytes_[1])
        waypoints_bytes = bytes_[2:]
        coordinates = [
            float(
                int.from_bytes(
                    waypoints_bytes[4 * idx : 4 * (idx + 1)],
                    byteorder="little",
                    signed=True,
                )
                / 1e6
            )
            for idx in range(2 * waypoints_count)
        ]
        # Pair consecutive values: first of each pair, then second.
        waypoints = [
            (coordinates[i], coordinates[i + 1])
            for i in range(0, 2 * waypoints_count, 2)
        ]
        return GPSWaypoints(threshold=threshold, waypoints=waypoints)
@dataclass
class ProtocolPayload:
    """Manage a protocol complete payload (header + type + values)."""

    header: ProtocolHeader
    payload_type: PayloadType
    values: ProtocolData

    def to_bytes(self, endian="little") -> bytes:
        """Converts a payload to a bytearray.

        The header fields come first, then the one-byte payload type, then
        the fields of the values. `endian` is the byte order applied to
        every field.
        """
        buffer = bytearray()
        for field in self.header.fields:
            buffer += int(field.value).to_bytes(
                length=field.length, byteorder=endian, signed=field.signed
            )
        buffer += int(self.payload_type.value).to_bytes(length=1, byteorder=endian)
        for field in self.values.fields:
            buffer += int(field.value).to_bytes(
                length=field.length, byteorder=endian, signed=field.signed
            )
        return buffer

    @staticmethod
    def from_bytes(bytes_: bytes):
        """Parse a bytearray to return a protocol payload instance.

        Bytes 0-23 hold the header, byte 24 the payload type and the rest
        the values.

        Raises ProtocolPayloadParserException when the header is invalid,
        when the protocol version is not PROTOCOL_VERSION, or when the
        payload type is unknown or has no parser.
        """
        try:
            header = ProtocolHeader.from_bytes(bytes_[0:24])
        except ValueError as exc:
            raise ProtocolPayloadParserException(f"Invalid header: {exc}") from exc
        if header.version != PROTOCOL_VERSION:
            raise ProtocolPayloadParserException(
                f"Invalid header: Unsupported payload version '{header.version}' (expected: {PROTOCOL_VERSION})"
            )
        type_value = int.from_bytes(bytes_[24:25], "little")
        try:
            payload_type = PayloadType(type_value)
        except ValueError as exc:
            # A type byte with no PayloadType member is reported like any
            # other unsupported payload instead of leaking a bare ValueError.
            raise ProtocolPayloadParserException(
                f"Unsupported payload type '{type_value}'"
            ) from exc
        if payload_type == PayloadType.CMD_MOVE_RAW:
            values = CommandMoveRaw.from_bytes(bytes_[25:30])
        elif payload_type == PayloadType.CMD_RGB_LED:
            values = CommandRgbLed.from_bytes(bytes_[25:29])
        elif payload_type == PayloadType.CMD_XGO_ACTION:
            values = CommandXgoAction.from_bytes(bytes_[25:26])
        elif payload_type == PayloadType.LH2_RAW_DATA:
            values = Lh2RawData.from_bytes(bytes_[25:45])
        elif payload_type == PayloadType.LH2_LOCATION:
            values = LH2Location.from_bytes(bytes_[25:37])
        elif payload_type == PayloadType.ADVERTISEMENT:
            # NOTE(review): `Advertisement` is not defined in the visible part
            # of this module - confirm it exists, otherwise this branch
            # raises NameError.
            values = Advertisement.from_bytes(None)
        elif payload_type == PayloadType.GPS_POSITION:
            values = GPSPosition.from_bytes(bytes_[25:33])
        elif payload_type == PayloadType.DOTBOT_DATA:
            values = DotBotData.from_bytes(bytes_[25:47])
        elif payload_type == PayloadType.SAILBOT_DATA:
            values = SailBotData.from_bytes(bytes_[25:39])
        elif payload_type == PayloadType.DOTBOT_SIMULATOR_DATA:
            values = DotBotSimulatorData.from_bytes(bytes_[25:35])
        elif payload_type == PayloadType.CONTROL_MODE:
            values = ControlMode.from_bytes(bytes_[25:26])
        elif payload_type == PayloadType.LH2_WAYPOINTS:
            values = LH2Waypoints.from_bytes(bytes_[25:])
        elif payload_type == PayloadType.GPS_WAYPOINTS:
            values = GPSWaypoints.from_bytes(bytes_[25:])
        else:
            raise ProtocolPayloadParserException(
                f"Unsupported payload type '{payload_type.value}'"
            )
        return ProtocolPayload(header, payload_type, values)

    def __repr__(self):
        """Returns an ASCII table showing each field name and its hex value.

        Every field gets a column whose width grows with its byte length.
        When the payload is longer than 32 bytes, the values are drawn in a
        second table below the header and type.
        """

        def _separators(fields):
            return ["-" * (4 * field.length + 2) for field in fields]

        def _names(fields):
            return [f" {field.name:<{4 * field.length + 1}}" for field in fields]

        def _hex_values(fields):
            # Values are displayed big-endian, whatever the wire byte order.
            return [
                f" 0x{hexlify(int(field.value).to_bytes(field.length, 'big', signed=field.signed)).decode():<{4 * field.length - 1}}"
                for field in fields
            ]

        header_fields = self.header.fields
        values_fields = self.values.fields

        header_separators = _separators(header_fields)
        type_separators = ["-" * 6]  # type
        values_separators = _separators(values_fields)
        header_names = _names(header_fields)
        type_name = [" type "]
        values_names = _names(values_fields)
        header_values = _hex_values(header_fields)
        type_value = [
            f" 0x{hexlify(int(PayloadType(self.payload_type).value).to_bytes(1, 'big')).decode():<3}"
        ]
        values_values = _hex_values(values_fields)
        num_bytes = (
            sum(field.length for field in header_fields)
            + 1
            + sum(field.length for field in values_fields)
        )

        if num_bytes > 32:
            # put values on a separate row
            separators = header_separators + type_separators
            names = header_names + type_name
            values = header_values + type_value
            return (
                f" {' ' * 16}+{'+'.join(separators)}+\n"
                f" {PayloadType(self.payload_type).name:<16}|{'|'.join(names)}|\n"
                f" {f'({num_bytes} Bytes)':<16}|{'|'.join(values)}|\n"
                f" {' ' * 16}+{'+'.join(separators)}+\n"
                f" {' ' * 16}+{'+'.join(values_separators)}+\n"
                f" {' ' * 16}|{'|'.join(values_names)}|\n"
                f" {' ' * 16}|{'|'.join(values_values)}|\n"
                f" {' ' * 16}+{'+'.join(values_separators)}+\n"
            )

        # all in a row by default
        separators = header_separators + type_separators + values_separators
        names = header_names + type_name + values_names
        values = header_values + type_value + values_values
        return (
            f" {' ' * 16}+{'+'.join(separators)}+\n"
            f" {PayloadType(self.payload_type).name:<16}|{'|'.join(names)}|\n"
            f" {f'({num_bytes} Bytes)':<16}|{'|'.join(values)}|\n"
            f" {' ' * 16}+{'+'.join(separators)}+\n"
        )