protocol.py 6.57 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
import io
import socket
import struct

from collections import namedtuple
from collections import OrderedDict

from .compat import iteritems
from .exception import DeserializationException


class Transport(object):
    """Length-prefixed packet framing over a connected socket.

    Each packet on the wire is a 4-byte big-endian unsigned length followed
    by that many payload bytes.
    """

    HEADER_LENGTH = 4
    # NOTE(review): not referenced within this class; presumably an upper
    # bound on segment size used elsewhere -- confirm.
    MAX_SEGMENT = 512 * 1024

    def __init__(self, sock):
        self.socket = sock

    def send(self, packet):
        """Send ``packet`` prefixed with its 4-byte big-endian length."""
        self.socket.sendall(struct.pack("!I", len(packet)) + packet)

    def receive(self):
        """Read one complete packet and return its payload.

        :return: payload bytes (without the length header)
        :raises socket.error: if the peer closes the connection before a
            complete packet has been received
        """
        raw_length = self._recvall(self.HEADER_LENGTH)
        length, = struct.unpack("!I", raw_length)
        return self._recvall(length)

    def _recvall(self, count):
        """Read exactly ``count`` bytes from the socket.

        A single ``recv`` call on a stream socket may return fewer bytes than
        requested, so keep reading until the full amount has arrived.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if not chunk:
                # An empty read means the peer closed the connection.
                raise socket.error("Connection closed")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def close(self):
        """Shut down both directions of the socket, then close it."""
        self.socket.shutdown(socket.SHUT_RDWR)
        self.socket.close()


class Packet(object):
    """Builders and a parser for the packets carried by a Transport.

    Every packet begins with a one-byte packet type.  "Named" packets follow
    the type byte with a one-byte name length and the name itself.
    """

    # Packet type identifiers (the first byte of each packet).
    CMD_REQUEST = 0         # named: issue a command
    CMD_RESPONSE = 1        # unnamed: reply to a command request
    CMD_UNKNOWN = 2         # unnamed: the requested command is unknown
    EVENT_REGISTER = 3      # named: register for an event
    EVENT_UNREGISTER = 4    # named: de-register from an event
    EVENT_CONFIRM = 5       # unnamed: event (de-)registration confirmed
    EVENT_UNKNOWN = 6       # unnamed: event (de-)registration failed
    EVENT = 7               # named: event message

    ParsedPacket = namedtuple("ParsedPacket", ["response_type", "payload"])
    ParsedEventPacket = namedtuple(
        "ParsedEventPacket", ["response_type", "event_type", "payload"]
    )

    @classmethod
    def _named_request(cls, request_type, request, message=None):
        """Build a named packet: type byte, name length byte, name, body.

        :param request_type: packet type identifier
        :param request: name (str) to encode into the packet
        :param message: optional already-encoded body appended verbatim
        """
        encoded_name = request.encode()
        header = struct.pack("!BB", request_type, len(encoded_name))
        if message is None:
            return header + encoded_name
        return header + encoded_name + message

    @classmethod
    def request(cls, command, message=None):
        """Build a CMD_REQUEST packet for ``command`` with an optional body."""
        return cls._named_request(
            request_type=cls.CMD_REQUEST, request=command, message=message
        )

    @classmethod
    def register_event(cls, event_type):
        """Build an EVENT_REGISTER packet for ``event_type``."""
        return cls._named_request(
            request_type=cls.EVENT_REGISTER, request=event_type
        )

    @classmethod
    def unregister_event(cls, event_type):
        """Build an EVENT_UNREGISTER packet for ``event_type``."""
        return cls._named_request(
            request_type=cls.EVENT_UNREGISTER, request=event_type
        )

    @classmethod
    def parse(cls, packet):
        """Split a received packet into its type, optional event name and body.

        :return: ParsedEventPacket for EVENT packets (event_type is left as
            raw bytes), otherwise ParsedPacket; ``payload`` is a FiniteStream
            positioned just past the header
        """
        stream = FiniteStream(packet)
        (response_type,) = struct.unpack("!B", stream.read(1))

        if response_type != cls.EVENT:
            return cls.ParsedPacket(response_type, stream)

        (name_length,) = struct.unpack("!B", stream.read(1))
        event_name = stream.read(name_length)
        return cls.ParsedEventPacket(response_type, event_name, stream)


class Message(object):
    """Encoder/decoder for the nested key/value message format.

    A message is a flat sequence of elements, each introduced by a one-byte
    marker:

    * SECTION_START, LIST_START and KEY_VALUE are followed by a name
      (1-byte length + bytes);
    * KEY_VALUE and LIST_ITEM are followed by a value (2-byte big-endian
      length + bytes);
    * SECTION_END and LIST_END carry no further data.
    """

    SECTION_START = 1       # Begin a new section having a name
    SECTION_END = 2         # End a previously started section
    KEY_VALUE = 3           # Define a value for a named key in the section
    LIST_START = 4          # Begin a named list for list items
    LIST_ITEM = 5           # Define an unnamed item value in the current list
    LIST_END = 6            # End a previously started list

    @classmethod
    def serialize(cls, message):
        """Encode a (possibly nested) mapping into message bytes.

        dict values become sections, list values become lists of items, and
        any other value becomes a key/value pair.  Non-bytes values are
        converted with ``str()`` and then encoded.

        :param message: mapping to encode; its iteration order is preserved
        :return: encoded message
        :rtype: bytes
        """
        def encode_named_type(marker, name):
            name = name.encode()
            return struct.pack("!BB", marker, len(name)) + name

        def encode_blob(value):
            if not isinstance(value, bytes):
                value = str(value).encode()
            return struct.pack("!H", len(value)) + value

        # Parts are collected in a list and joined once at the end;
        # repeated ``bytes +=`` would copy the output over and over.
        def serialize_list(lst, parts):
            for item in lst:
                parts.append(struct.pack("!B", cls.LIST_ITEM))
                parts.append(encode_blob(item))

        def serialize_dict(d, parts):
            for key, value in iteritems(d):
                if isinstance(value, dict):
                    parts.append(encode_named_type(cls.SECTION_START, key))
                    serialize_dict(value, parts)
                    parts.append(struct.pack("!B", cls.SECTION_END))
                elif isinstance(value, list):
                    parts.append(encode_named_type(cls.LIST_START, key))
                    serialize_list(value, parts)
                    parts.append(struct.pack("!B", cls.LIST_END))
                else:
                    parts.append(encode_named_type(cls.KEY_VALUE, key))
                    parts.append(encode_blob(value))

        parts = []
        serialize_dict(message, parts)
        return b"".join(parts)

    @classmethod
    def deserialize(cls, stream):
        """Decode message bytes from ``stream`` into nested OrderedDicts.

        Sections become OrderedDicts with str keys, key/value pairs map to
        bytes values, and lists become lists of bytes.

        :param stream: binary stream providing ``read``, ``tell`` and
            ``has_more`` (e.g. FiniteStream)
        :return: decoded message
        :rtype: OrderedDict
        :raises DeserializationException: on unknown element types,
            unbalanced sections/lists, or truncated input
        """
        def read_exact(stream, count):
            # Reject truncated input instead of silently returning short
            # values (or failing later with an opaque struct.error).
            data = stream.read(count)
            if len(data) != count:
                raise DeserializationException(
                    "Unexpected end of message at {pos}".format(
                        pos=stream.tell()
                    )
                )
            return data

        def read_byte(stream):
            value, = struct.unpack("!B", read_exact(stream, 1))
            return value

        def decode_named_type(stream):
            length = read_byte(stream)
            return read_exact(stream, length).decode()

        def decode_blob(stream):
            length, = struct.unpack("!H", read_exact(stream, 2))
            return read_exact(stream, length)

        def decode_list(stream):
            items = []
            marker = read_byte(stream)
            while marker == cls.LIST_ITEM:
                items.append(decode_blob(stream))
                marker = read_byte(stream)

            if marker != cls.LIST_END:
                raise DeserializationException(
                    "Expected end of list at {pos}".format(pos=stream.tell())
                )
            return items

        section = OrderedDict()
        section_stack = []
        while stream.has_more():
            element_type = read_byte(stream)
            if element_type == cls.SECTION_START:
                section_name = decode_named_type(stream)
                new_section = OrderedDict()
                section[section_name] = new_section
                section_stack.append(section)
                section = new_section

            elif element_type == cls.LIST_START:
                list_name = decode_named_type(stream)
                section[list_name] = decode_list(stream)

            elif element_type == cls.KEY_VALUE:
                key = decode_named_type(stream)
                section[key] = decode_blob(stream)

            elif element_type == cls.SECTION_END:
                if section_stack:
                    section = section_stack.pop()
                else:
                    raise DeserializationException(
                        "Unexpected end of section at {pos}".format(
                            pos=stream.tell()
                        )
                    )

            else:
                # Skipping an unknown marker would misinterpret every
                # following byte, so treat it as a malformed message.
                raise DeserializationException(
                    "Unknown element type {element_type} at {pos}".format(
                        element_type=element_type, pos=stream.tell()
                    )
                )

        if section_stack:
            raise DeserializationException("Expected end of section")
        return section


class FiniteStream(io.BytesIO):
    """In-memory byte stream that reports its total size and unread state."""

    def __len__(self):
        """Return the total number of bytes held, regardless of position."""
        contents = self.getvalue()
        return len(contents)

    def has_more(self):
        """Return True while unread bytes remain after the current position."""
        remaining = len(self) - self.tell()
        return remaining > 0