From f582a36e4df16d5709943f7df17a900c8bcc12ab Mon Sep 17 00:00:00 2001 From: Kevin O'Connor Date: Wed, 25 May 2016 11:37:40 -0400 Subject: Initial commit of source code. Signed-off-by: Kevin O'Connor --- klippy/msgproto.py | 313 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 klippy/msgproto.py (limited to 'klippy/msgproto.py') diff --git a/klippy/msgproto.py b/klippy/msgproto.py new file mode 100644 index 00000000..73c8d2d5 --- /dev/null +++ b/klippy/msgproto.py @@ -0,0 +1,313 @@ +# Protocol definitions for firmware communication +# +# Copyright (C) 2016 Kevin O'Connor +# +# This file may be distributed under the terms of the GNU GPLv3 license. +import json, zlib, logging + +DefaultMessages = { + 0: "identify_response offset=%u data=%.*s", + 1: "identify offset=%u count=%c", +} + +MESSAGE_MIN = 5 +MESSAGE_MAX = 64 +MESSAGE_HEADER_SIZE = 2 +MESSAGE_TRAILER_SIZE = 3 +MESSAGE_POS_LEN = 0 +MESSAGE_POS_SEQ = 1 +MESSAGE_TRAILER_CRC = 3 +MESSAGE_TRAILER_SYNC = 1 +MESSAGE_PAYLOAD_MAX = MESSAGE_MAX - MESSAGE_MIN +MESSAGE_SEQ_MASK = 0x0f +MESSAGE_DEST = 0x10 +MESSAGE_SYNC = '\x7E' + +class error(Exception): + def __init__(self, msg): + self.msg = msg + def __str__(self): + return self.msg + +def crc16_ccitt(buf): + crc = 0xffff + for data in buf: + data = ord(data) + data ^= crc & 0xff + data ^= (data & 0x0f) << 4 + crc = ((data << 8) | (crc >> 8)) ^ (data >> 4) ^ (data << 3) + crc = chr(crc >> 8) + chr(crc & 0xff) + return crc + +class PT_uint32: + is_int = 1 + max_length = 5 + signed = 0 + def encode(self, out, v): + if v >= 0xc000000 or v < -0x4000000: out.append((v>>28) & 0x7f | 0x80) + if v >= 0x180000 or v < -0x80000: out.append((v>>21) & 0x7f | 0x80) + if v >= 0x3000 or v < -0x1000: out.append((v>>14) & 0x7f | 0x80) + if v >= 0x60 or v < -0x20: out.append((v>>7) & 0x7f | 0x80) + out.append(v & 0x7f) + def parse(self, s, pos): + c = s[pos] + pos += 1 + v = c & 0x7f + if (c & 0x60) == 0x60: + v |= -0x20 + while c & 0x80: + c = s[pos] + pos += 1 + v = (v<<7) | (c & 0x7f) + if not self.signed: + v = int(v & 0xffffffff) + return v, pos + +class PT_int32(PT_uint32): + signed = 1 +class PT_uint16(PT_uint32): + max_length = 3 +class PT_int16(PT_int32): + signed = 1 + max_length = 3 +class PT_byte(PT_uint32): + max_length = 2 + +class PT_string: + is_int = 0 + max_length = 64 + def encode(self, out, v): + out.append(len(v)) + out.extend(bytearray(v)) + def parse(self, s, pos): + l = s[pos] + return str(bytearray(s[pos+1:pos+l+1])), pos+l+1 +class PT_progmem_buffer(PT_string): + pass +class PT_buffer(PT_string): + pass + +MessageTypes = { + '%u': PT_uint32(), '%i': PT_int32(), + '%hu': PT_uint16(), '%hi': PT_int16(), + '%c': PT_byte(), + '%s': PT_string(), '%.*s': PT_progmem_buffer(), '%*s': PT_buffer(), +} + +# Update the message format to be compatible with python's % operator +def convert_msg_format(msgformat): + mf = msgformat.replace('%c', '%u') + mf = mf.replace('%.*s', '%s').replace('%*s', '%s') + return mf + +class MessageFormat: + def __init__(self, msgid, msgformat): + self.msgid = msgid + self.msgformat = msgformat + self.debugformat = convert_msg_format(msgformat) + parts = msgformat.split() + self.name = parts[0] + argparts = [arg.split('=') for arg in parts[1:]] + self.param_types = [MessageTypes[fmt] for name, fmt in argparts] + self.param_names = [name for name, fmt in argparts] + self.name_to_type = dict(zip(self.param_names, self.param_types)) + def encode(self, *params): + out = [] + out.append(self.msgid) + for i, t in enumerate(self.param_types): + t.encode(out, params[i]) + return out + def encode_by_name(self, **params): + out = [] + out.append(self.msgid) + for name, t in zip(self.param_names, self.param_types): + t.encode(out, params[name]) + return out + def parse(self, s, pos): + pos += 1 + out = {} + for t, name in zip(self.param_types, self.param_names): + v, pos = t.parse(s, pos) + out[name] = v + return out, pos + def dump(self, s, pos): + pos += 1 + out = [] + for t in self.param_types: + v, pos = t.parse(s, pos) + if not t.is_int: + v = repr(v) + out.append(v) + outmsg = self.debugformat % tuple(out) + return outmsg, pos + +class OutputFormat: + name = '#output' + def __init__(self, msgid, msgformat): + self.msgid = msgid + self.msgformat = msgformat + self.debugformat = convert_msg_format(msgformat) + self.param_types = [] + args = msgformat + while 1: + pos = args.find('%') + if pos < 0: + break + if pos+1 >= len(args) or args[pos+1] != '%': + for i in range(4): + t = MessageTypes.get(args[pos:pos+1+i]) + if t is not None: + self.param_types.append(t) + break + else: + raise error("Invalid output format for '%s'" % (msg,)) + args = args[pos+1:] + def parse(self, s, pos): + pos += 1 + out = [] + for t in self.param_types: + v, pos = t.parse(s, pos) + out.append(v) + outmsg = self.debugformat % tuple(out) + return {'#msg': outmsg}, pos + def dump(self, s, pos): + pos += 1 + out = [] + for t in self.param_types: + v, pos = t.parse(s, pos) + out.append(v) + outmsg = self.debugformat % tuple(out) + return outmsg, pos + +class UnknownFormat: + name = '#unknown' + def parse(self, s, pos): + msgid = s[pos] + msg = str(bytearray(s)) + return {'#msgid': msgid, '#msg': msg}, len(s)-MESSAGE_TRAILER_SIZE + +class MessageParser: + def __init__(self): + self.unknown = UnknownFormat() + self.messages_by_id = {} + self.messages_by_name = {} + self.static_strings = [] + self.config = {} + self.version = "" + self.raw_identify_data = "" + self._init_messages(DefaultMessages, DefaultMessages.keys()) + def check_packet(self, s): + if len(s) < MESSAGE_MIN: + return 0 + msglen = ord(s[MESSAGE_POS_LEN]) + if msglen < MESSAGE_MIN or msglen > MESSAGE_MAX: + return -1 + msgseq = ord(s[MESSAGE_POS_SEQ]) + if (msgseq & ~MESSAGE_SEQ_MASK) != MESSAGE_DEST: + return -1 + if len(s) < msglen: + # Need more data + return 0 + if s[msglen-MESSAGE_TRAILER_SYNC] != MESSAGE_SYNC: + return -1 + msgcrc = s[msglen-MESSAGE_TRAILER_CRC:msglen-MESSAGE_TRAILER_CRC+2] + crc = crc16_ccitt(s[:msglen-MESSAGE_TRAILER_SIZE]) + if crc != msgcrc: + #logging.debug("got crc %s vs %s" % (repr(crc), repr(msgcrc))) + return -1 + return msglen + def dump(self, s): + msgseq = s[MESSAGE_POS_SEQ] + out = ["seq: %02x" % (msgseq,)] + pos = MESSAGE_HEADER_SIZE + while 1: + msgid = s[pos] + mid = self.messages_by_id.get(msgid, self.unknown) + params, pos = mid.dump(s, pos) + out.append("%s" % (params,)) + if pos >= len(s)-MESSAGE_TRAILER_SIZE: + break + return out + def parse(self, s): + msgid = s[MESSAGE_HEADER_SIZE] + mid = self.messages_by_id.get(msgid, self.unknown) + params, pos = mid.parse(s, MESSAGE_HEADER_SIZE) + if pos != len(s)-MESSAGE_TRAILER_SIZE: + raise error("Extra data at end of message") + params['#name'] = mid.name + static_string_id = params.get('static_string_id') + if static_string_id is not None: + params['#msg'] = self.static_strings[static_string_id] + return params + def encode(self, seq, cmd): + msglen = MESSAGE_MIN + len(cmd) + seq = (seq & MESSAGE_SEQ_MASK) | MESSAGE_DEST + out = [chr(msglen), chr(seq), cmd] + out.append(crc16_ccitt(''.join(out))) + out.append(MESSAGE_SYNC) + return ''.join(out) + def _parse_buffer(self, value): + tval = int(value, 16) + out = [] + for i in range(len(value)/2): + out.append(tval & 0xff) + tval >>= 8 + out.reverse() + return ''.join([chr(i) for i in out]) + def lookup_command(self, msgformat): + parts = msgformat.strip().split() + msgname = parts[0] + mp = self.messages_by_name.get(msgname) + if mp is None: + raise error("Unknown command: %s" % (msgname,)) + if msgformat != mp.msgformat: + raise error("Command format mismatch: %s vs %s" % ( + msgformat, mp.msgformat)) + return mp + def create_command(self, msg): + parts = msg.strip().split() + if not parts: + return "" + msgname = parts[0] + mp = self.messages_by_name.get(msgname) + if mp is None: + raise error("Unknown command: %s" % (msgname,)) + try: + argparts = dict(arg.split('=', 1) for arg in parts[1:]) + for name, value in argparts.items(): + t = mp.name_to_type[name] + if t.is_int: + tval = int(value, 0) + else: + tval = self._parse_buffer(value) + argparts[name] = tval + except: + #traceback.print_exc() + raise error("Unable to extract params from: %s" % (msgname,)) + try: + cmd = mp.encode_by_name(**argparts) + except: + #traceback.print_exc() + raise error("Unable to encode: %s" % (msgname,)) + return cmd + def _init_messages(self, messages, parsers): + for msgid, msgformat in messages.items(): + msgid = int(msgid) + if msgid not in parsers: + self.messages_by_id[msgid] = OutputFormat(msgid, msgformat) + continue + msg = MessageFormat(msgid, msgformat) + self.messages_by_id[msgid] = msg + self.messages_by_name[msg.name] = msg + def process_identify(self, data, decompress=True): + if decompress: + data = zlib.decompress(data) + self.raw_identify_data = data + data = json.loads(data) + messages = data.get('messages') + commands = data.get('commands') + responses = data.get('responses') + self._init_messages(messages, commands+responses) + self.static_strings = data.get('static_strings', []) + self.config.update(data.get('config', {})) + self.version = data.get('version', '') -- cgit v1.2.3-70-g09d2