import io
import socket
import struct

from collections import namedtuple
from collections import OrderedDict

from .compat import iteritems
from .exception import DeserializationException


class Transport(object):
    """Length-prefixed packet transport over a connected stream socket.

    Every packet on the wire is a 4-byte, big-endian unsigned length header
    (``struct`` format ``"!I"``) followed by that many payload bytes.
    """
    HEADER_LENGTH = 4
    # Not referenced within this class; presumably the largest allowed
    # segment size (512 KiB) -- TODO confirm against callers.
    MAX_SEGMENT = 512 * 1024

    def __init__(self, sock):
        """Wrap the socket-like object ``sock`` (assumed already connected)."""
        self.socket = sock

    def send(self, packet):
        """Send ``packet`` (bytes) prefixed with its 4-byte length header."""
        self.socket.sendall(struct.pack("!I", len(packet)) + packet)

    def receive(self):
        """Receive one complete packet and return its payload bytes.

        :raises socket.error: if the connection is closed before the whole
            packet (header and payload) has been read
        """
        # recv() may legally return fewer bytes than requested, so both the
        # header and the payload have to be read until complete.
        raw_length = self._recv_all(self.HEADER_LENGTH)
        length, = struct.unpack("!I", raw_length)
        return self._recv_all(length)

    def _recv_all(self, count):
        """Read exactly ``count`` bytes from the socket.

        :raises socket.error: if the peer closes the connection first
        """
        data = bytearray()
        while len(data) < count:
            chunk = self.socket.recv(count - len(data))
            if not chunk:
                # An empty read means the peer closed the connection.
                raise socket.error("Connection closed")
            data.extend(chunk)
        return bytes(data)

    def close(self):
        """Shut down both directions of the socket, then close it.

        The socket is closed even if ``shutdown`` raises (e.g. when the
        socket is no longer connected); that error is still propagated.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        finally:
            self.socket.close()


class Packet(object):
    """Helpers that build outgoing packets and parse received ones.

    A named packet is: one type byte, one byte giving the name length, the
    encoded name, and optionally a trailing message appended verbatim.
    """
    CMD_REQUEST = 0         # Named request message
    CMD_RESPONSE = 1        # Unnamed response message for a request
    CMD_UNKNOWN = 2         # Unnamed response if requested command is unknown
    EVENT_REGISTER = 3      # Named event registration request
    EVENT_UNREGISTER = 4    # Named event de-registration request
    EVENT_CONFIRM = 5       # Unnamed confirmation for event (de-)registration
    EVENT_UNKNOWN = 6       # Unnamed response if event (de-)registration failed
    EVENT = 7               # Named event message

    # Result of parse() for every packet type except EVENT.
    ParsedPacket = namedtuple("ParsedPacket", ["response_type", "payload"])

    # Result of parse() for EVENT packets; event_type is the raw name bytes.
    ParsedEventPacket = namedtuple(
        "ParsedEventPacket", ["response_type", "event_type", "payload"]
    )

    @classmethod
    def _named_request(cls, request_type, request, message=None):
        """Build a named packet of ``request_type`` for the name ``request``.

        :param request_type: packet type byte (one of the constants above)
        :param request: name as text; it is encoded and its encoded length
            must fit in one byte (``struct`` format ``"!B"``)
        :param message: optional bytes appended unchanged after the name
        :return: the packet bytes
        """
        encoded_name = request.encode()
        header = struct.pack("!BB", request_type, len(encoded_name))
        packet = header + encoded_name
        if message is None:
            return packet
        return packet + message

    @classmethod
    def request(cls, command, message=None):
        """Build a CMD_REQUEST packet for ``command`` with optional ``message`` bytes."""
        return cls._named_request(cls.CMD_REQUEST, command, message)

    @classmethod
    def register_event(cls, event_type):
        """Build an EVENT_REGISTER packet for the event named ``event_type``."""
        return cls._named_request(cls.EVENT_REGISTER, event_type)

    @classmethod
    def unregister_event(cls, event_type):
        """Build an EVENT_UNREGISTER packet for the event named ``event_type``."""
        return cls._named_request(cls.EVENT_UNREGISTER, event_type)

    @classmethod
    def parse(cls, packet):
        """Split a received ``packet`` into its type, optional name and payload.

        The payload is returned as a ``FiniteStream`` positioned just after
        the header, ready to be deserialized.

        :return: ``ParsedEventPacket`` for EVENT packets, otherwise
            ``ParsedPacket``
        """
        stream = FiniteStream(packet)
        (response_type,) = struct.unpack("!B", stream.read(1))
        if response_type != cls.EVENT:
            return cls.ParsedPacket(response_type, stream)

        # Only EVENT packets carry a name between the type byte and payload.
        (name_length,) = struct.unpack("!B", stream.read(1))
        event_type = stream.read(name_length)
        return cls.ParsedEventPacket(response_type, event_type, stream)


class Message(object):
    """Codec between nested dicts and the binary message element stream.

    A message is a flat sequence of elements, each introduced by one of the
    marker bytes below. Names are length-prefixed with one byte; values
    (blobs) are length-prefixed with two big-endian bytes.
    """
    SECTION_START = 1       # Begin a new section having a name
    SECTION_END = 2         # End a previously started section
    KEY_VALUE = 3           # Define a value for a named key in the section
    LIST_START = 4          # Begin a named list for list items
    LIST_ITEM = 5           # Define an unnamed item value in the current list
    LIST_END = 6            # End a previously started list

    @classmethod
    def serialize(cls, message):
        """Encode the mapping ``message`` into message bytes.

        Dict values become nested sections, list values become lists whose
        items are encoded as blobs, and all other values become key/value
        pairs. Values that are not ``bytes`` are converted with ``str()`` and
        encoded. Keys are encoded text whose length must fit in one byte;
        blobs must fit in two bytes (``struct`` raises otherwise).

        :param message: mapping to serialize (may be nested)
        :return: the encoded message
        :rtype: bytes
        """
        def encode_named_type(marker, name):
            name = name.encode()
            return struct.pack("!BB", marker, len(name)) + name

        def encode_blob(value):
            if not isinstance(value, bytes):
                value = str(value).encode()
            return struct.pack("!H", len(value)) + value

        # Parts are collected in one list and joined once at the end; the
        # previous repeated ``bytes +=`` was quadratic in message size.
        def serialize_list(lst, out):
            for item in lst:
                out.append(struct.pack("!B", cls.LIST_ITEM))
                out.append(encode_blob(item))

        def serialize_dict(d, out):
            for key, value in iteritems(d):
                if isinstance(value, dict):
                    out.append(encode_named_type(cls.SECTION_START, key))
                    serialize_dict(value, out)
                    out.append(struct.pack("!B", cls.SECTION_END))
                elif isinstance(value, list):
                    out.append(encode_named_type(cls.LIST_START, key))
                    serialize_list(value, out)
                    out.append(struct.pack("!B", cls.LIST_END))
                else:
                    out.append(encode_named_type(cls.KEY_VALUE, key))
                    out.append(encode_blob(value))

        parts = []
        serialize_dict(message, parts)
        return b"".join(parts)

    @classmethod
    def deserialize(cls, stream):
        """Decode the elements remaining in ``stream`` into an OrderedDict.

        Sections become nested ``OrderedDict`` objects, lists become lists of
        bytes, and key/value pairs map to bytes values. Names are decoded
        to text with ``.decode()``.

        :param stream: byte stream providing ``read``, ``tell`` and
            ``has_more`` (e.g. ``FiniteStream``)
        :return: the decoded top-level section
        :raises DeserializationException: on truncated input, an unknown or
            misplaced element type, or unbalanced sections/lists
        """
        def read_exactly(length):
            # A short read means the message was cut off; fail loudly instead
            # of returning partial data or a bare struct.error.
            data = stream.read(length)
            if len(data) != length:
                raise DeserializationException(
                    "Unexpected end of message at {pos}".format(
                        pos=stream.tell()
                    )
                )
            return data

        def read_byte():
            value, = struct.unpack("!B", read_exactly(1))
            return value

        def decode_named_type():
            return read_exactly(read_byte()).decode()

        def decode_blob():
            length, = struct.unpack("!H", read_exactly(2))
            return read_exactly(length)

        def decode_list_items():
            items = []
            marker = read_byte()
            while marker == cls.LIST_ITEM:
                items.append(decode_blob())
                marker = read_byte()

            if marker != cls.LIST_END:
                raise DeserializationException(
                    "Expected end of list at {pos}".format(pos=stream.tell())
                )
            return items

        section = OrderedDict()
        # Enclosing sections, restored when SECTION_END is encountered.
        section_stack = []
        while stream.has_more():
            element_type = read_byte()
            if element_type == cls.SECTION_START:
                section_name = decode_named_type()
                new_section = OrderedDict()
                section[section_name] = new_section
                section_stack.append(section)
                section = new_section

            elif element_type == cls.LIST_START:
                list_name = decode_named_type()
                section[list_name] = decode_list_items()

            elif element_type == cls.KEY_VALUE:
                key = decode_named_type()
                section[key] = decode_blob()

            elif element_type == cls.SECTION_END:
                if not section_stack:
                    raise DeserializationException(
                        "Unexpected end of section at {pos}".format(
                            pos=stream.tell()
                        )
                    )
                section = section_stack.pop()

            else:
                # Unknown (or list-only) markers cannot be skipped safely
                # because their length is unknown; continuing would misparse
                # everything that follows.
                raise DeserializationException(
                    "Unexpected element type {type} at {pos}".format(
                        type=element_type, pos=stream.tell()
                    )
                )

        if section_stack:
            raise DeserializationException("Expected end of section")
        return section


class FiniteStream(io.BytesIO):
    """In-memory byte stream that can report its size and remaining data."""

    def __len__(self):
        """Return the total number of bytes held by the stream."""
        buffered = self.getvalue()
        return len(buffered)

    def has_more(self):
        """Return True while the read position has not reached the end."""
        remaining = len(self) - self.tell()
        return remaining > 0