import io import socket import struct from collections import namedtuple from collections import OrderedDict from .compat import iteritems from .exception import DeserializationException class Transport(object): HEADER_LENGTH = 4 MAX_SEGMENT = 512 * 1024 def __init__(self, sock): self.socket = sock def send(self, packet): self.socket.sendall(struct.pack("!I", len(packet)) + packet) def receive(self): raw_length = self.socket.recv(self.HEADER_LENGTH) length, = struct.unpack("!I", raw_length) payload = self.socket.recv(length) return payload def close(self): self.socket.shutdown(socket.SHUT_RDWR) self.socket.close() class Packet(object): CMD_REQUEST = 0 # Named request message CMD_RESPONSE = 1 # Unnamed response message for a request CMD_UNKNOWN = 2 # Unnamed response if requested command is unknown EVENT_REGISTER = 3 # Named event registration request EVENT_UNREGISTER = 4 # Named event de-registration request EVENT_CONFIRM = 5 # Unnamed confirmation for event (de-)registration EVENT_UNKNOWN = 6 # Unnamed response if event (de-)registration failed EVENT = 7 # Named event message ParsedPacket = namedtuple( "ParsedPacket", ["response_type", "payload"] ) ParsedEventPacket = namedtuple( "ParsedEventPacket", ["response_type", "event_type", "payload"] ) @classmethod def _named_request(cls, request_type, request, message=None): request = request.encode() payload = struct.pack("!BB", request_type, len(request)) + request if message is not None: return payload + message else: return payload @classmethod def request(cls, command, message=None): return cls._named_request(cls.CMD_REQUEST, command, message) @classmethod def register_event(cls, event_type): return cls._named_request(cls.EVENT_REGISTER, event_type) @classmethod def unregister_event(cls, event_type): return cls._named_request(cls.EVENT_UNREGISTER, event_type) @classmethod def parse(cls, packet): stream = FiniteStream(packet) response_type, = struct.unpack("!B", stream.read(1)) if response_type == cls.EVENT: length, = struct.unpack("!B", stream.read(1)) event_type = stream.read(length) return cls.ParsedEventPacket(response_type, event_type, stream) else: return cls.ParsedPacket(response_type, stream) class Message(object): SECTION_START = 1 # Begin a new section having a name SECTION_END = 2 # End a previously started section KEY_VALUE = 3 # Define a value for a named key in the section LIST_START = 4 # Begin a named list for list items LIST_ITEM = 5 # Define an unnamed item value in the current list LIST_END = 6 # End a previously started list @classmethod def serialize(cls, message): def encode_named_type(marker, name): name = name.encode() return struct.pack("!BB", marker, len(name)) + name def encode_blob(value): if not isinstance(value, bytes): value = str(value).encode() return struct.pack("!H", len(value)) + value def serialize_list(lst): segment = bytes() for item in lst: segment += struct.pack("!B", cls.LIST_ITEM) + encode_blob(item) return segment def serialize_dict(d): segment = bytes() for key, value in iteritems(d): if isinstance(value, dict): segment += ( encode_named_type(cls.SECTION_START, key) + serialize_dict(value) + struct.pack("!B", cls.SECTION_END) ) elif isinstance(value, list): segment += ( encode_named_type(cls.LIST_START, key) + serialize_list(value) + struct.pack("!B", cls.LIST_END) ) else: segment += ( encode_named_type(cls.KEY_VALUE, key) + encode_blob(value) ) return segment return serialize_dict(message) @classmethod def deserialize(cls, stream): def decode_named_type(stream): length, = struct.unpack("!B", stream.read(1)) return stream.read(length).decode() def decode_blob(stream): length, = struct.unpack("!H", stream.read(2)) return stream.read(length) def decode_list_item(stream): marker, = struct.unpack("!B", stream.read(1)) while marker == cls.LIST_ITEM: yield decode_blob(stream) marker, = struct.unpack("!B", stream.read(1)) if marker != cls.LIST_END: raise DeserializationException( "Expected end of list at {pos}".format(pos=stream.tell()) ) section = OrderedDict() section_stack = [] while stream.has_more(): element_type, = struct.unpack("!B", stream.read(1)) if element_type == cls.SECTION_START: section_name = decode_named_type(stream) new_section = OrderedDict() section[section_name] = new_section section_stack.append(section) section = new_section elif element_type == cls.LIST_START: list_name = decode_named_type(stream) section[list_name] = [item for item in decode_list_item(stream)] elif element_type == cls.KEY_VALUE: key = decode_named_type(stream) section[key] = decode_blob(stream) elif element_type == cls.SECTION_END: if len(section_stack): section = section_stack.pop() else: raise DeserializationException( "Unexpected end of section at {pos}".format( pos=stream.tell() ) ) if len(section_stack): raise DeserializationException("Expected end of section") return section class FiniteStream(io.BytesIO): def __len__(self): return len(self.getvalue()) def has_more(self): return self.tell() < len(self)