# SPDX-FileCopyrightText: 2022-present Inria
# SPDX-FileCopyrightText: 2022-present Alexandre Abadie <alexandre.abadie@inria.fr>
# SPDX-FileCopyrightText: 2023-present Filip Maksimovic <filip.maksimovic@inria.fr>
# SPDX-FileCopyrightText: 2024-present Diego Badillo <diego.badillo@sansano.usm.cl>
#
# SPDX-License-Identifier: BSD-3-Clause
"""Module for the Dotbot protocol API."""
import dataclasses
from abc import ABC, abstractmethod
from binascii import hexlify
from dataclasses import dataclass
from enum import Enum, IntEnum
from itertools import chain
from typing import List
# Protocol version expected in every header; payloads with another version are rejected.
PROTOCOL_VERSION: int = 9
class PayloadType(Enum):
    """Identifier of each kind of DotBot payload.

    The member value is serialized as a single byte placed right after the
    protocol header.
    """

    CMD_MOVE_RAW = 0  # raw move command
    CMD_RGB_LED = 1  # RGB LED color command
    LH2_RAW_DATA = 2  # raw LH2 locations
    LH2_LOCATION = 3  # computed LH2 location
    ADVERTISEMENT = 4  # advertisement, no data fields
    GPS_POSITION = 5  # GPS position
    DOTBOT_DATA = 6  # DotBot direction + raw LH2 data
    CONTROL_MODE = 7  # control mode switch
    LH2_WAYPOINTS = 8  # list of LH2 waypoints
    GPS_WAYPOINTS = 9  # list of GPS waypoints
    SAILBOT_DATA = 10  # SailBot telemetry
    CMD_XGO_ACTION = 11  # XGO action command
    LH2_PROCESSED_DATA = 12  # processed LH2 location
    # Sentinel: bump its value each time a new sequential payload type is added.
    INVALID_PAYLOAD = 13
    # Simulator payload, deliberately outside the sequential range.
    DOTBOT_SIMULATOR_DATA = 250
class ApplicationType(IntEnum):
    """Enumerates the kinds of DotBot applications known to the protocol."""

    DotBot = 0  # pylint: disable=invalid-name
    SailBot = 1  # pylint: disable=invalid-name
    Freebot = 2  # pylint: disable=invalid-name
    XGO = 3
    LH2_mini_mote = 4  # pylint: disable=invalid-name
class ControlModeType(IntEnum):
    """Control modes carried by a control-mode message."""

    MANUAL = 0  # manual control (default mode of ControlMode)
    AUTO = 1  # automatic control
class ProtocolPayloadParserException(Exception):
    """Signals that a payload is malformed or of an unsupported kind."""
@dataclass
class ProtocolField:
    """One serialized field of a protocol payload.

    Attributes are kept in this order so positional construction keeps working:
    value, name, length, signed.
    """

    # Integer value written on the wire.
    value: int = 0
    # Short label shown when a payload is pretty-printed.
    name: str = ""
    # Number of bytes the value occupies when serialized.
    length: int = 1
    # Whether the value is encoded as a signed integer.
    signed: bool = False
@dataclass
class ProtocolData(ABC):
    """Abstract base for the typed body of a protocol payload.

    Concrete subclasses expose their serialized layout through ``fields`` and
    rebuild themselves from raw bytes through ``from_bytes``.
    """

    @property
    @abstractmethod
    def fields(self) -> List[ProtocolField]:
        """Ordered list of fields making up this data, as serialized."""

    @staticmethod
    @abstractmethod
    def from_bytes(bytes_):
        """Build an instance of the concrete data class from raw bytes."""
@dataclass
class CommandMoveRaw(ProtocolData):
    """Raw move command made of a left and a right (x, y) pair.

    Each of the four axes is serialized as one signed byte.
    """

    left_x: int = 0
    left_y: int = 0
    right_x: int = 0
    right_y: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        return [
            ProtocolField(self.left_x, name="lx", length=1, signed=True),
            ProtocolField(self.left_y, name="ly", length=1, signed=True),
            ProtocolField(self.right_x, name="rx", length=1, signed=True),
            ProtocolField(self.right_y, name="ry", length=1, signed=True),
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Decode the four axis bytes as signed 8-bit values.

        The axes are serialized with ``signed=True``, so they must be decoded
        as signed too: indexing the bytes directly would turn e.g. -1 (0xFF)
        into 255, which then cannot be re-serialized as a signed byte.
        Missing trailing bytes decode as 0, like the dataclass defaults.
        """
        return CommandMoveRaw(
            *(
                int.from_bytes(bytes_[idx : idx + 1], "little", signed=True)
                for idx in range(4)
            )
        )
@dataclass
class CommandRgbLed(ProtocolData):
    """RGB LED command: one unsigned byte per color channel."""

    red: int = 0
    green: int = 0
    blue: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        channels = (("red", self.red), ("green", self.green), ("blue", self.blue))
        return [ProtocolField(level, name=label) for label, level in channels]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        # At most the first three bytes are used; missing ones keep the default.
        channel_bytes = bytes_[0:3]
        return CommandRgbLed(*channel_bytes)
@dataclass
class CommandXgoAction(ProtocolData):
    """XGO action command holding a single action byte."""

    action: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        action_field = ProtocolField(self.action, name="action")
        return [action_field]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        first_byte = bytes_[0]
        return CommandXgoAction(first_byte)
@dataclass
class Lh2RawLocation(ProtocolData):
    """One LH2 raw location sample (10 bytes on the wire).

    Layout: 8-byte ``bits``, 1-byte polynomial index, 1-byte signed offset,
    all little-endian.
    """

    bits: int = 0
    polynomial_index: int = 0
    offset: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        return [
            ProtocolField(value=self.bits, name="bits", length=8),
            ProtocolField(value=self.polynomial_index, name="poly", length=1),
            ProtocolField(value=self.offset, name="off.", length=1, signed=True),
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        bits = int.from_bytes(bytes_[0:8], "little")
        polynomial = int.from_bytes(bytes_[8:9], "little")
        offset = int.from_bytes(bytes_[9:10], "little", signed=True)
        return Lh2RawLocation(bits, polynomial, offset)
@dataclass
class Lh2ProcessedLocation(ProtocolData):
    """One processed LH2 location (9 bytes on the wire).

    Layout: 1-byte polynomial index, 4-byte LFSR index, 4-byte timestamp in
    microseconds, all unsigned little-endian.
    """

    polynomial_index: int = 0
    lfsr_index: int = 0
    timestamp_us: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        layout = (
            (self.polynomial_index, "poly", 1),
            (self.lfsr_index, "lfsr_index", 4),
            (self.timestamp_us, "timestamp_us", 4),
        )
        return [ProtocolField(value, name=label, length=size) for value, label, size in layout]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        polynomial = int.from_bytes(bytes_[0:1], "little")
        lfsr = int.from_bytes(bytes_[1:5], "little")
        timestamp = int.from_bytes(bytes_[5:9], "little")
        return Lh2ProcessedLocation(polynomial, lfsr, timestamp)
@dataclass
class Lh2RawData(ProtocolData):
    """Two consecutive LH2 raw locations (2 x 10 bytes on the wire)."""

    locations: List[Lh2RawLocation] = dataclasses.field(default_factory=list)

    @property
    def fields(self) -> List[ProtocolField]:
        # Flatten the per-location fields in order.
        return [field for location in self.locations for field in location.fields]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        samples = [
            Lh2RawLocation.from_bytes(bytes_[start : start + 10]) for start in (0, 10)
        ]
        return Lh2RawData(samples)
@dataclass
class LH2Location(ProtocolData):
    """LH2 computed location: three unsigned 32-bit little-endian coordinates."""

    pos_x: int = 0
    pos_y: int = 0
    pos_z: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        coordinates = (("x", self.pos_x), ("y", self.pos_y), ("z", self.pos_z))
        return [ProtocolField(value, name=label, length=4) for label, value in coordinates]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        x, y, z = (
            int.from_bytes(bytes_[start : start + 4], "little") for start in (0, 4, 8)
        )
        return LH2Location(x, y, z)
@dataclass
class DotBotData(ProtocolData):
    """Direction and two LH2 raw locations sent by the DotBot application.

    Layout: 2-byte signed direction followed by two 10-byte raw locations.
    """

    # NOTE(review): 0xFFFF does not fit the signed 16-bit encoding used in
    # ``fields`` (serializing a default instance raises OverflowError), and
    # ``from_bytes`` decodes 0xFFFF on the wire as -1. Confirm the sentinel.
    direction: int = 0xFFFF
    locations: List[Lh2RawLocation] = dataclasses.field(default_factory=list)

    @property
    def fields(self) -> List[ProtocolField]:
        return [
            ProtocolField(self.direction, name="dir.", length=2, signed=True),
            *(field for location in self.locations for field in location.fields),
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        heading = int.from_bytes(bytes_[0:2], "little", signed=True)
        samples = [
            Lh2RawLocation.from_bytes(bytes_[start : start + 10]) for start in (2, 12)
        ]
        return DotBotData(direction=heading, locations=samples)
@dataclass
class GPSPosition(ProtocolData):
    """GPS position as two signed 32-bit little-endian integers."""

    latitude: int = 0
    longitude: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        return [
            ProtocolField(value, name=label, length=4, signed=True)
            for label, value in (("latitude", self.latitude), ("longitude", self.longitude))
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        lat = int.from_bytes(bytes_[0:4], "little", signed=True)
        lon = int.from_bytes(bytes_[4:8], "little", signed=True)
        return GPSPosition(latitude=lat, longitude=lon)
@dataclass
class SailBotData(ProtocolData):
    """Telemetry sent by the SailBot application (14 bytes on the wire)."""

    direction: int = 0xFFFF
    latitude: int = 0
    longitude: int = 0
    wind_angle: int = 0xFFFF  # unsigned angle from 0 to 359
    rudder_angle: int = 0
    sail_angle: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        # (value, label, byte length, signed) in wire order.
        layout = (
            (self.direction, "dir.", 2, False),
            (self.latitude, "latitude", 4, True),
            (self.longitude, "longitude", 4, True),
            (self.wind_angle, "wind ang", 2, False),
            (self.rudder_angle, "rud.", 1, True),
            (self.sail_angle, "sail.", 1, True),
        )
        return [
            ProtocolField(value, name=label, length=size, signed=is_signed)
            for value, label, size, is_signed in layout
        ]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        def _read(start, stop, is_signed):
            return int.from_bytes(bytes_[start:stop], "little", signed=is_signed)

        return SailBotData(
            direction=_read(0, 2, False),
            latitude=_read(2, 6, True),
            longitude=_read(6, 10, True),
            wind_angle=_read(10, 12, False),
            rudder_angle=_read(12, 13, True),
            sail_angle=_read(13, 14, True),
        )
@dataclass
class DotBotSimulatorData(ProtocolData):
    """Angle ``theta`` and position reported by the DotBot simulator.

    Layout: 2-byte ``theta`` then 4-byte ``pos_x`` and ``pos_y``, all unsigned
    little-endian.
    """

    theta: int = 0xFFFF
    pos_x: int = 0
    pos_y: int = 0

    @property
    def fields(self) -> List[ProtocolField]:
        layout = (("theta", self.theta, 2), ("pos_x", self.pos_x, 4), ("pos_y", self.pos_y, 4))
        return [ProtocolField(value, name=label, length=size) for label, value, size in layout]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        angle = int.from_bytes(bytes_[0:2], "little")
        x = int.from_bytes(bytes_[2:6], "little")
        y = int.from_bytes(bytes_[6:10], "little")
        return DotBotSimulatorData(theta=angle, pos_x=x, pos_y=y)
@dataclass
class Advertisement(ProtocolData):
    """Advertisement payload; it is empty and carries no value fields."""

    @property
    def fields(self) -> List[ProtocolField]:
        return list()

    @staticmethod
    def from_bytes(_: bytes) -> ProtocolData:
        # Nothing to decode: the input is ignored.
        empty_advertisement = Advertisement()
        return empty_advertisement
@dataclass
class ControlMode(ProtocolData):
    """Control mode message: a single byte holding the requested mode.

    ``from_bytes`` stores the raw byte value (an ``int``), not a
    ``ControlModeType`` member.
    """

    mode: ControlModeType = ControlModeType.MANUAL

    @property
    def fields(self) -> List[ProtocolField]:
        mode_field = ProtocolField(value=self.mode, name="mode")
        return [mode_field]

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        raw_mode = bytes_[0]
        return ControlMode(raw_mode)
@dataclass
class LH2Waypoints(ProtocolData):
    """List of LH2 waypoints together with a threshold.

    Wire layout: 1-byte waypoint count, 1-byte threshold, then ``count``
    waypoints of three unsigned 32-bit little-endian coordinates (x, y, z).
    """

    threshold: int
    # Either LH2Location objects, or (x, y, z) integer tuples as produced by
    # ``from_bytes``; ``fields`` accepts both.
    waypoints: List[LH2Location] = dataclasses.field(default_factory=lambda: [])

    @property
    def fields(self) -> List[ProtocolField]:
        _fields = [ProtocolField(len(self.waypoints), name="len.")]
        _fields += [ProtocolField(value=self.threshold, name="thr.")]
        for waypoint in self.waypoints:
            # Tuples (as returned by from_bytes) have no ``fields``; wrap them
            # so a decoded payload can be re-serialized or printed.
            location = LH2Location(*waypoint) if isinstance(waypoint, tuple) else waypoint
            _fields += location.fields
        return _fields

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Decode the waypoint list.

        Waypoints are returned as (x, y, z) tuples of ints, kept for backward
        compatibility with existing callers.
        """
        waypoints_count = int(bytes_[0])
        threshold = int(bytes_[1])
        waypoints_bytes = bytes_[2:]
        waypoints = [
            tuple(
                int.from_bytes(waypoints_bytes[offset : offset + 4], byteorder="little")
                for offset in range(12 * idx, 12 * idx + 12, 4)
            )
            for idx in range(waypoints_count)
        ]
        return LH2Waypoints(threshold=threshold, waypoints=waypoints)
@dataclass
class GPSWaypoints(ProtocolData):
    """List of GPS waypoints together with a threshold.

    Wire layout: 1-byte waypoint count, 1-byte threshold, then ``count``
    waypoints of two signed 32-bit little-endian integers (latitude,
    longitude), scaled by 1e6 (presumably micro-degrees).
    """

    threshold: int
    # Either GPSPosition objects (integer scaled values), or (latitude,
    # longitude) float tuples already divided by 1e6 as produced by
    # ``from_bytes``; ``fields`` accepts both.
    waypoints: List[GPSPosition] = dataclasses.field(default_factory=lambda: [])

    @property
    def fields(self) -> List[ProtocolField]:
        _fields = [ProtocolField(len(self.waypoints), name="len.")]
        _fields += [ProtocolField(value=self.threshold, name="thr.")]
        for waypoint in self.waypoints:
            if isinstance(waypoint, tuple):
                # Tuples from from_bytes hold values divided by 1e6 and have no
                # ``fields``; scale them back to the integer wire encoding.
                latitude, longitude = waypoint
                waypoint = GPSPosition(
                    latitude=round(latitude * 1e6), longitude=round(longitude * 1e6)
                )
            _fields += waypoint.fields
        return _fields

    @staticmethod
    def from_bytes(bytes_) -> ProtocolData:
        """Decode the waypoint list.

        Waypoints are returned as (latitude, longitude) float tuples divided by
        1e6, kept for backward compatibility with existing callers.
        """
        waypoints_count = int(bytes_[0])
        threshold = int(bytes_[1])
        waypoints_bytes = bytes_[2:]
        # Coordinates are encoded signed (see GPSPosition), so decode them
        # signed as well; otherwise negative latitudes/longitudes come out as
        # large positive values.
        coordinates = [
            float(
                int.from_bytes(
                    waypoints_bytes[4 * idx : 4 * (idx + 1)],
                    byteorder="little",
                    signed=True,
                )
                / 1e6
            )
            for idx in range(2 * waypoints_count)
        ]
        waypoints = list(zip(coordinates[0::2], coordinates[1::2]))
        return GPSWaypoints(threshold=threshold, waypoints=waypoints)
@dataclass
class ProtocolPayload:
    """Complete protocol payload: header, payload type and typed values.

    On the wire the header fields come first, then one byte holding the
    ``PayloadType`` value, then the fields of ``values``.
    """

    header: ProtocolHeader
    payload_type: PayloadType
    values: ProtocolData

    def to_bytes(self, endian="little") -> bytes:
        """Serialize the payload.

        Args:
            endian: byte order applied to every field.

        Returns:
            The serialized payload (a ``bytearray``).
        """
        buffer = bytearray()
        for field in self.header.fields:
            buffer += int(field.value).to_bytes(
                length=field.length, byteorder=endian, signed=field.signed
            )
        # The payload type is always a single unsigned byte.
        buffer += int(self.payload_type.value).to_bytes(length=1, byteorder=endian)
        for field in self.values.fields:
            buffer += int(field.value).to_bytes(
                length=field.length, byteorder=endian, signed=field.signed
            )
        return buffer

    @staticmethod
    def from_bytes(bytes_: bytes):
        """Parse raw bytes into a protocol payload instance.

        Raises:
            ProtocolPayloadParserException: if the header is invalid, its
                version differs from PROTOCOL_VERSION, the payload type byte is
                missing or unknown, or the payload type is not supported.
        """
        header_length = 24
        try:
            header = ProtocolHeader.from_bytes(bytes_[0:header_length])
        except ValueError as exc:
            raise ProtocolPayloadParserException(f"Invalid header: {exc}") from exc
        if header.version != PROTOCOL_VERSION:
            raise ProtocolPayloadParserException(
                f"Invalid header: Unsupported payload version '{header.version}' (expected: {PROTOCOL_VERSION})"
            )
        if len(bytes_) <= header_length:
            # Without this check a truncated packet would decode as type 0.
            raise ProtocolPayloadParserException("Missing payload type")
        raw_type = int(bytes_[header_length])
        try:
            payload_type = PayloadType(raw_type)
        except ValueError as exc:
            # Report unknown type bytes with the parser exception callers catch.
            raise ProtocolPayloadParserException(
                f"Unsupported payload type '{raw_type}'"
            ) from exc
        # payload type -> (data class, number of value bytes; None = rest of buffer)
        decoders = {
            PayloadType.CMD_MOVE_RAW: (CommandMoveRaw, 4),
            PayloadType.CMD_RGB_LED: (CommandRgbLed, 3),
            PayloadType.LH2_RAW_DATA: (Lh2RawData, 20),
            PayloadType.LH2_LOCATION: (LH2Location, 12),
            PayloadType.ADVERTISEMENT: (Advertisement, 0),
            PayloadType.GPS_POSITION: (GPSPosition, 8),
            PayloadType.DOTBOT_DATA: (DotBotData, 22),
            PayloadType.SAILBOT_DATA: (SailBotData, 14),
            PayloadType.DOTBOT_SIMULATOR_DATA: (DotBotSimulatorData, 10),
            PayloadType.CONTROL_MODE: (ControlMode, 1),
            PayloadType.LH2_WAYPOINTS: (LH2Waypoints, None),
            PayloadType.GPS_WAYPOINTS: (GPSWaypoints, None),
            PayloadType.CMD_XGO_ACTION: (CommandXgoAction, 1),
            PayloadType.LH2_PROCESSED_DATA: (Lh2ProcessedLocation, 9),
        }
        if payload_type not in decoders:
            raise ProtocolPayloadParserException(
                f"Unsupported payload type '{payload_type.value}'"
            )
        data_class, length = decoders[payload_type]
        start = header_length + 1
        stop = None if length is None else start + length
        values = data_class.from_bytes(bytes_[start:stop])
        return ProtocolPayload(header, payload_type, values)

    def __repr__(self):
        """Render the payload as an ASCII table of field names and hex values.

        Payloads longer than 32 bytes print the value fields in a second table
        below the header/type table.
        """
        # Each byte of a field gets 4 columns in the table.
        header_separators = [
            "-" * (4 * field.length + 2) for field in self.header.fields
        ]
        type_separators = ["-" * 6]  # type
        values_separators = [
            "-" * (4 * field.length + 2) for field in self.values.fields
        ]
        header_names = [
            f" {field.name:<{4 * field.length + 1}}" for field in self.header.fields
        ]
        type_name = [" type "]
        values_names = [
            f" {field.name:<{4 * field.length + 1}}" for field in self.values.fields
        ]
        header_values = [
            f" 0x{hexlify(int(field.value).to_bytes(field.length, 'big', signed=field.signed)).decode():<{4 * field.length - 1}}"
            for field in self.header.fields
        ]
        type_value = [
            f" 0x{hexlify(int(PayloadType(self.payload_type).value).to_bytes(1, 'big')).decode():<3}"
        ]
        values_values = [
            f" 0x{hexlify(int(field.value).to_bytes(field.length, 'big', signed=field.signed)).decode():<{4 * field.length - 1}}"
            for field in self.values.fields
        ]
        num_bytes = (
            sum(field.length for field in self.header.fields)
            + 1
            + sum(field.length for field in self.values.fields)
        )
        if num_bytes > 32:
            # put values on a separate row
            separators = header_separators + type_separators
            names = header_names + type_name
            values = header_values + type_value
            return (
                f" {' ' * 16}+{'+'.join(separators)}+\n"
                f" {PayloadType(self.payload_type).name:<16}|{'|'.join(names)}|\n"
                f" {f'({num_bytes} Bytes)':<16}|{'|'.join(values)}|\n"
                f" {' ' * 16}+{'+'.join(separators)}+\n"
                f" {' ' * 16}+{'+'.join(values_separators)}+\n"
                f" {' ' * 16}|{'|'.join(values_names)}|\n"
                f" {' ' * 16}|{'|'.join(values_values)}|\n"
                f" {' ' * 16}+{'+'.join(values_separators)}+\n"
            )
        # all in a row by default
        separators = header_separators + type_separators + values_separators
        names = header_names + type_name + values_names
        values = header_values + type_value + values_values
        return (
            f" {' ' * 16}+{'+'.join(separators)}+\n"
            f" {PayloadType(self.payload_type).name:<16}|{'|'.join(names)}|\n"
            f" {f'({num_bytes} Bytes)':<16}|{'|'.join(values)}|\n"
            f" {' ' * 16}+{'+'.join(separators)}+\n"
        )