import io
import os
import errno
import socket
import logging
from hashlib import md5
from collections import deque

import pyev

from ccnode.exc import TunnelError
from ccnode.jobs import BaseThreadedJob


# module-wide logger, named after this module
logger = logging.getLogger(name=__name__)


class ImportVolume(BaseThreadedJob):
    """Import volume job.

    Listens on a kernel-assigned TCP port (see :meth:`pre_job`), accepts a
    single connection from the exporting node and writes everything received
    to ``volume.path`` until EOF, computing a checksum of the received bytes.
    """
    BUFFER_LEN = 8192 * 16
    HASH = md5
    #: timeout (seconds) set on the listening socket, i.e. how long accept()
    #: waits for the exporting node to connect
    ACCEPT_TIMEOUT = 10.

    def __init__(self, job_manager, ev_loop, volume):
        """
        :param job_manager: job manager, passed to :class:`BaseThreadedJob`
        :param ev_loop: event loop, passed to :class:`BaseThreadedJob`
        :param volume: volume to import; its ``path`` is opened for writing
        """
        BaseThreadedJob.__init__(self, job_manager, ev_loop)

        # hex digest of the received data, set by run_job
        self.checksum = None
        self.volume = volume
        # where the other node will connect
        self.port = None

        # fds
        self.sock = None  # listening socket
        self.client_sock = None  # connection accepted from the exporter
        self.disk = None  # local volume, opened for writing

    def clean_fds(self):
        """Close every socket/file opened by the job; safe to call twice."""
        if self.sock is not None:
            self.sock.close()
            self.sock = None
        if self.client_sock is not None:
            self.client_sock.close()
            self.client_sock = None
        if self.disk is not None:
            self.disk.close()
            self.disk = None

    def pre_job(self):
        """Create the listening socket and open the local volume.

        On failure the error is logged, every fd opened so far is closed and
        the exception is re-raised.

        :returns: port number the socket is listening on
        :raises socket.error: if the listening socket cannot be set up
        :raises IOError: if the local volume cannot be opened
        """
        # create socket
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except socket.error:
            logger.exception('Error while creating socket for volume import')
            self.clean_fds()
            raise
        try:
            self.sock.settimeout(self.ACCEPT_TIMEOUT)
        except socket.error:
            logger.exception('Cannot set timeout on socket for volume import')
            self.clean_fds()
            raise
        try:
            # port 0: the kernel picks a free port, reported below
            self.sock.bind(('0.0.0.0', 0))
        except socket.error:
            logger.exception('Error while binding socket for volume import')
            self.clean_fds()
            raise
        try:
            # only one exporter is expected to connect
            self.sock.listen(1)
        except socket.error:
            logger.exception('Error while listening on socket')
            self.clean_fds()
            raise

        # open local disk without Python-level buffering (buffering=0)
        try:
            self.disk = io.open(self.volume.path, 'wb', 0)
        except IOError:
            logger.exception('Error while trying to open local disk')
            self.clean_fds()
            raise

        self.port = self.sock.getsockname()[1]
        return self.port

    def run_job(self):
        """Receive the volume from the exporting node and write it to disk.

        Data is received until EOF (or until the job stops running), then
        ``self.checksum`` is set to the digest of everything received. All
        fds are closed on return, whether the job succeeds or raises.

        :raises socket.timeout: if nobody connects within ACCEPT_TIMEOUT
        :raises socket.error: on accept/receive errors
        :raises IOError: on disk write errors
        """
        # FIXME raised exceptions in this functions will be in the context of a
        # thread that is not running in the sjRPC, therefore these won't be
        # caught
        try:
            self._accept_client()
            checksum = self.HASH()

            # start downloading disk image
            while self.running:
                chunk = self._recv_chunk()
                if not chunk:
                    logger.debug('Received EOF import job')
                    break
                checksum.update(chunk)
                self._write_chunk(chunk)

            # here we could not have received the full disk but we don't
            # consider this as an error in the import part
            self.checksum = checksum.hexdigest()
        finally:
            self.clean_fds()
        logger.debug('Volume import done')

    def _accept_client(self):
        """Accept the exporter's connection, then close the listening socket."""
        try:
            self.client_sock, _ = self.sock.accept()
        except socket.timeout:
            # must come before socket.error: timeout is a subclass of it
            logger.exception('Error for importing job: client did not connect')
            raise
        except socket.error:
            logger.exception('Error while accepting socket')
            raise

        # a single peer is expected: stop listening right away
        self.sock.close()
        self.sock = None

    def _recv_chunk(self):
        """Receive up to BUFFER_LEN bytes; return b'' only at EOF.

        Several recv() calls are made until BUFFER_LEN bytes are gathered or
        EOF is hit, so that data is hashed and written in large chunks.
        """
        # keep a list of received buffers in order to do only one
        # concatenation in the end
        received = []
        total_received = 0
        try:
            while total_received < self.BUFFER_LEN:
                recv_buf = self.client_sock.recv(
                    self.BUFFER_LEN - total_received)
                if not recv_buf:
                    # EOF: if some data was already gathered, the next call
                    # will hit EOF again and return b''
                    break
                received.append(recv_buf)
                total_received += len(recv_buf)
        except socket.error:
            logger.exception('Error while receiving disk image')
            raise
        return b''.join(received)

    def _write_chunk(self, chunk):
        """Write the whole of ``chunk`` to the local disk.

        The raw (unbuffered) file's write() may write fewer bytes than given,
        so the remainder is retried through a zero-copy memoryview.
        """
        to_send = memoryview(chunk)
        try:
            while to_send:
                written = self.disk.write(to_send)
                to_send = to_send[written:]
        except IOError:
            logger.exception('Error while writing image to disk')
            raise


class ExportVolume(BaseThreadedJob):
    """Export volume job.

    Connects to the importing node (``raddr``:``rport``) and streams the
    local volume to it until EOF, computing a checksum of the bytes read.
    """
    BUFFER_LEN = 8192 * 16
    HASH = md5

    def __init__(self, job_manager, ev_loop, volume, raddr, rport):
        """
        :param volume: :class:`Volume` instance
        :param raddr: remote IP address
        :param rport: remote TCP port
        """
        BaseThreadedJob.__init__(self, job_manager, ev_loop)

        # where to connect to send the volume
        self.raddr = raddr
        self.rport = rport

        self.volume = volume
        # hex digest of the data read from disk, set by run_job
        self.checksum = None

        # fds
        self.sock = None  # connection to the importing node
        self.disk = None  # local volume, opened for reading

    def clean_fds(self):
        """Close every socket/file opened by the job; safe to call twice."""
        if self.sock is not None:
            self.sock.close()
            self.sock = None
        if self.disk is not None:
            self.disk.close()
            self.disk = None

    def pre_job(self):
        """Connect to the remote host and open the local volume for reading.

        On failure the error is logged, every fd opened so far is closed and
        the exception is re-raised.

        :raises socket.error: if the socket cannot be created or connected
        :raises IOError: if the local volume cannot be opened
        """
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except socket.error:
            logger.exception('Error while creating socket for volume export')
            self.clean_fds()
            raise

        # connect to the remote host
        try:
            self.sock.connect((self.raddr, self.rport))
        except socket.error as exc:
            # errno/strerror may be None (e.g. for timeouts), so never pass
            # them to os.strerror: that would raise TypeError and hide exc
            logger.exception('Error while trying to connect to remote host'
                             ' %s:%s, %s', self.raddr, self.rport,
                             exc.strerror or exc)
            self.clean_fds()
            raise

        # open local volume without Python-level buffering (buffering=0)
        try:
            self.disk = io.open(self.volume.path, 'rb', 0)
        except IOError:
            logger.exception('Error while opening disk for export job')
            self.clean_fds()
            raise

    def run_job(self):
        """Stream the volume to the remote host, computing its checksum.

        Read/send errors are logged and stop the transfer without raising;
        ``self.checksum`` is then set to the digest of the data read so far.
        All fds are closed on return.
        """
        checksum = self.HASH()
        try:
            while self.running:
                try:
                    read = self.disk.read(self.BUFFER_LEN)
                except IOError:
                    logger.exception('Error while reading from disk')
                    break
                # read length may be less than BUFFER_LEN but we don't care
                # as it will go over TCP
                if not read:  # end of file
                    break
                checksum.update(read)
                try:
                    self.sock.sendall(read)
                except socket.error:
                    logger.exception('Error while sending through socket')
                    break

            self.checksum = checksum.hexdigest()
        finally:
            self.clean_fds()


class SocketBuffer(deque):
    """FIFO of byte chunks that tracks the total number of bytes it holds.

    ``current_len`` is the sum of ``len()`` of the stored chunks and is kept
    up to date by the supported mutators (``append``, ``appendleft``,
    ``pop``, ``popleft``, ``clear``). The other deque mutators (``extend``,
    ``extendleft``, ``+=``, ``remove``, ``reverse``, ``rotate``) are not
    supported and raise :exc:`NotImplementedError`.

    ``max_len`` is not enforced: callers check :meth:`is_full` and stop
    feeding the buffer themselves.
    """
    def __init__(self, max_len=8 * 64 * 1024):
        deque.__init__(self)
        self.max_len = max_len
        self.current_len = 0

    def append(self, x):
        deque.append(self, x)
        self.current_len += len(x)

    def appendleft(self, x):
        deque.appendleft(self, x)
        self.current_len += len(x)

    def clear(self):
        deque.clear(self)
        self.current_len = 0

    def extend(self, iterable):
        raise NotImplementedError

    def extendleft(self, iterable):
        raise NotImplementedError

    def __iadd__(self, iterable):
        # deque's C-level += extends directly, bypassing the overridden
        # extend() and therefore the current_len accounting
        raise NotImplementedError

    def pop(self):
        elt = deque.pop(self)
        self.current_len -= len(elt)
        return elt

    def popleft(self):
        elt = deque.popleft(self)
        self.current_len -= len(elt)
        return elt

    def remove(self, value):
        raise NotImplementedError

    def reverse(self):
        raise NotImplementedError

    def rotate(self, n=1):
        raise NotImplementedError

    def is_full(self):
        """Return True once at least ``max_len`` bytes are buffered."""
        return self.current_len >= self.max_len

    def is_empty(self):
        """Return True when no bytes are buffered."""
        return self.current_len == 0


class TCPTunnel(object):
    """Handles a TCP tunnel.

    ``source_sock`` listens on the ``listen`` interface (kernel-assigned
    port, stored in ``port``) and accepts a single client; ``dest_sock``
    connects to ``connect`` (a tuple means AF_INET, anything else AF_UNIX).
    Once both ends are connected, data is relayed in both directions with
    pyev I/O watchers created from ``ev_loop``, buffered in two
    :class:`SocketBuffer` instances. An unexpected error or an EOF on either
    end closes the whole tunnel.
    """

    #: maximum number of bytes requested by a single recv() call
    BUFFER_LEN = 8096

    def __init__(self, job_manager, ev_loop, connect=None, listen='0.0.0.0'):
        """
        :param job_manager: :class:`JobManager` instance
        :param ev_loop: pyev loop instance (to create watchers from)
        :param connect: where to connect one end of the tunnel (a tuple, as
            given to socket.connect)
        :param listen: which interface to listen to for the other end of the
            tunnel
        """
        #: job id
        self.id = job_manager.job_id.next()

        self.ev_loop = ev_loop
        self.connect = connect
        self.listen = listen
        #: port is assigned by the kernel
        self.port = None

        # keep state information for both ends
        self.listen_state = 'CLOSED'
        self.connect_state = 'CLOSED'
        #: very basic error report (an errno value)
        self.error = None

        # these are the watchers
        self.source_reader = None
        self.source_writer = None
        self.dest_reader = None
        self.dest_writer = None

        #: source_sock is the socket that will listen for remote|local to happen
        self.source_sock = None
        #: dest sock connects to an other setuped tunnel
        self.dest_sock = None

        # input buffer is used for data that is coming from source_sock and goes
        # to dest_sock
        self.input_buffer = SocketBuffer()
        # output_buffer is used for data that is coming from dest_sock and goes
        # to source_sock
        self.output_buffer = SocketBuffer()

    def close(self):
        """Stop every watcher, close both sockets and drop buffered data.

        Safe to call several times. ``self.error`` is left untouched.
        """
        logger.debug('Closing job %d', self.id)
        # stop watchers
        if self.source_reader is not None:
            self.source_reader.stop()
            self.source_reader = None
        if self.source_writer is not None:
            self.source_writer.stop()
            self.source_writer = None
        if self.dest_reader is not None:
            self.dest_reader.stop()
            self.dest_reader = None
        if self.dest_writer is not None:
            self.dest_writer.stop()
            self.dest_writer = None
        # close sockets
        if self.source_sock is not None:
            self.source_sock.close()
            self.source_sock = None
        if self.dest_sock is not None:
            self.dest_sock.close()
            self.dest_sock = None
        # clear buffers (this memory won't be needed anyway)
        self.input_buffer = None
        self.output_buffer = None
        # reset states
        self.listen_state = 'CLOSED'
        self.connect_state = 'CLOSED'

    def stop(self):
        """Alias of :meth:`close`."""
        self.close()

    def setup_listen(self, interface=None):
        """Setup source socket and start waiting for a client.

        :param interface: specify which interface to listen onto (overrides
            the ``listen`` constructor argument)
        :raises socket.error: if the socket cannot be set up; the tunnel is
            closed first
        """
        if interface is not None:
            self.listen = interface
        logger.debug('Setup listening %s %d', self.listen, self.id)
        try:
            self.source_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except socket.error:
            logger.exception('Error while creating source_sock for tunnel job'
                             ' %d', self.id)
            self.close()
            raise
        try:
            self.source_sock.setblocking(0)
        except socket.error:
            logger.exception('Cannot set source_sock in non blocking mode for'
                             ' tunnel job %d', self.id)
            self.close()
            raise
        try:
            # port 0: the kernel picks a free port
            self.source_sock.bind((self.listen, 0))
        except socket.error:
            logger.exception('Error while binding source_sock for tunnel job'
                             ' %d', self.id)
            self.close()
            raise
        self.port = self.source_sock.getsockname()[1]
        try:
            self.source_sock.listen(1)
        except socket.error:
            logger.exception('Error while listening on source_sock for tunnel'
                             ' job %d', self.id)
            self.close()
            raise

        self.listen_state = 'LISTENING'
        # ready to accept
        self.source_reader = self.ev_loop.io(self.source_sock,
                                             pyev.EV_READ, self.accept_cb)
        self.source_reader.start()

    def setup_connect(self, endpoint=None):
        """Start a non-blocking connection to the remote end.

        :param endpoint: specify where to connect (same as connect argument in
            constructor), can be specified in both places
        :raises TunnelError: if no endpoint was specified at all
        :raises socket.error: if the socket cannot be set up or the connection
            fails immediately; the tunnel is closed first
        """
        if endpoint is not None:
            self.connect = endpoint
        if self.connect is None:
            raise TunnelError('Remote endpoint to connect to was not specified')
        logger.debug('Connect to endpoint %s %d', self.connect, self.id)
        # a (host, port) tuple means TCP, anything else a UNIX socket address
        if isinstance(self.connect, tuple):
            addr_family = socket.AF_INET
        else:
            addr_family = socket.AF_UNIX
        try:
            self.dest_sock = socket.socket(addr_family, socket.SOCK_STREAM)
        except socket.error:
            logger.exception('Error while creating dest_sock for tunnel job'
                             ' %d', self.id)
            self.close()
            raise
        try:
            self.dest_sock.setblocking(0)
        except socket.error:
            logger.exception('Error while sitting non block mode on dest_sock'
                             ' for tunnel job %d', self.id)
            # don't leak dest_sock
            self.close()
            raise

        # non-blocking connect: EINPROGRESS means the result is reported later
        # through a write event (see connect_cb)
        error = self.dest_sock.connect_ex(self.connect)
        if error and error != errno.EINPROGRESS:
            self.close()
            self.error = error
            raise socket.error('Error during connect for tunnel job, %s' %
                               os.strerror(error))
        self.dest_writer = self.ev_loop.io(self.dest_sock,
                                           pyev.EV_WRITE, self.connect_cb)
        self.dest_writer.start()

        self.connect_state = 'CONNECTING'

    def accept_cb(self, watcher, revents):
        """Accept the single client on the listening socket.

        The listening socket is replaced by the accepted one; relaying starts
        here if the other end is already connected.
        """
        try:
            new_source, remote = self.source_sock.accept()
        except socket.error as exc:
            if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                # spurious wake up, we will come back
                return

            # else
            logger.exception('Error while accepting new connection on'
                             ' sock_source for tunnel job')
            self.close()
            self.error = exc.errno
            return

        # everything went fine
        self.source_sock.close()  # we won't accept connections
        self.source_sock = new_source
        # set new socket non blocking
        try:
            self.source_sock.setblocking(0)
        except socket.error as exc:
            logger.exception('Cannot set source socket in non blocking for'
                             ' tunnel job: %s', os.strerror(exc.errno))
            self.close()
            self.error = exc.errno
            return
        self.source_reader.stop()
        self.source_reader = self.ev_loop.io(new_source, pyev.EV_READ,
                                             self.read_cb)
        self.source_writer = self.ev_loop.io(new_source, pyev.EV_WRITE,
                                             self.write_cb)
        logger.debug('Successfully accepted remote client %s for tunnel job %d',
                     remote, self.id)
        self.listen_state = 'CONNECTED'
        if self.connect_state == 'CONNECTED':
            # start the watchers only if both ends are ready to accept data
            self.source_reader.start()
            self.dest_reader.start()

    def connect_cb(self, watcher, revents):
        """Check the outcome of the non-blocking connect on dest_sock.

        Relaying starts here if the client end is already accepted.
        """
        # check that connection was a success
        error = self.dest_sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
        if error:
            logger.error('Error during connect for tunnel job, %s' %
                         os.strerror(error))
            self.close()
            self.error = error
            return

        # else we setup watcher with proper events
        self.dest_reader = self.ev_loop.io(self.dest_sock, pyev.EV_READ,
                                           self.read_cb)
        self.dest_writer.stop()
        self.dest_writer = self.ev_loop.io(self.dest_sock, pyev.EV_WRITE,
                                           self.write_cb)
        logger.debug('Successfully connected to remote endpoint %s %d',
                     self.connect, self.id)
        self.connect_state = 'CONNECTED'
        if self.listen_state == 'CONNECTED':
            # start the watchers only if both ends are ready to accept data
            self.source_reader.start()
            self.dest_reader.start()

    def read_cb(self, watcher, revents):
        """Drain a readable socket into the buffer feeding the other end.

        Reading stops (watcher stopped) once the buffer is full; the other
        end's writer is started whenever there is data to forward.
        """
        if watcher == self.dest_reader:
            sock = self.dest_sock
            buffer_ = self.output_buffer
            other_watcher = self.source_writer
        else:
            sock = self.source_sock
            buffer_ = self.input_buffer
            other_watcher = self.dest_writer

        while True:
            try:
                incoming = sock.recv(self.BUFFER_LEN)
            except socket.error as exc:
                if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # nothing more to read for now
                    break
                # else: unexpected error
                logger.exception('Unexpected error while reading on socket'
                                 ' for tunnel job, %s', os.strerror(exc.errno))
                self.close()
                self.error = exc.errno
                return

            if not incoming:
                # EOF
                # NOTE(review): closing right away drops data still buffered
                # for the other end (including chunks read in this call);
                # a half-close that flushes first may be wanted
                self.close()
                return
            buffer_.append(incoming)
            if buffer_.is_full():
                # back-pressure: resume once the writer drains the buffer
                watcher.stop()
                break

        # we did read some bytes that we could write to the other end
        if not buffer_.is_empty():
            other_watcher.start()

    def write_cb(self, watcher, revents):
        """Flush the buffer feeding the socket that became writable.

        Chunks are sent until the buffer is empty (watcher stopped) or the
        kernel would block (watcher kept active to be called again). The
        other end's reader is resumed whenever the buffer is not full.
        """
        if watcher == self.dest_writer:
            sock = self.dest_sock
            buffer_ = self.input_buffer
            other_watcher = self.source_reader
        else:
            sock = self.source_sock
            buffer_ = self.output_buffer
            other_watcher = self.dest_reader

        while True:
            try:
                to_send = buffer_.popleft()
            except IndexError:
                # buffer is empty, we should stop write event
                watcher.stop()
                break
            try:
                written = sock.send(to_send)
            except socket.error as exc:
                if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                    # kernel send buffer is full: put the chunk back in front
                    # and wait for the next write event instead of spinning
                    buffer_.appendleft(to_send)
                    break
                # else: unexpected error
                logger.exception('Unexpected error while writting on socket'
                                 ' for tunnel job, %s',
                                 os.strerror(exc.errno))
                self.close()
                self.error = exc.errno
                return

            if written < len(to_send):
                # partial write: requeue the remainder in front to keep the
                # ordering; the next iteration retries it
                buffer_.appendleft(to_send[written:])

        # if we can read on the other end
        if not buffer_.is_full():
            other_watcher.start()