# This file is part of CloudControl.
#
# CloudControl is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# CloudControl is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with CloudControl.  If not, see <http://www.gnu.org/licenses/>.


import os
import gc
import sys
import types
import errno
import signal
import pickle
import logging
import resource
import threading
import traceback
import subprocess
from collections import deque
from functools import wraps
from subprocess import _eintr_retry_call

import pyev

from cloudcontrol.common.client.utils import main_thread


logger = logging.getLogger(__name__)


def and_(iter_):
    """Do an "and" logic condition over the iterable elements.

    Equivalent to the builtin :func:`all`: stops at the first falsy element.

    :param iterable iter_: elements to test for truthiness
    :returns: ``False`` if any element is falsy, ``True`` otherwise (including
        for an empty iterable)
    """
    return all(iter_)


def _main_thread(func):
    """EvPopen constructor decorator."""
    @wraps(func)
    def decorated(self, main_loop, *args, **kwargs):
        return main_loop.call_in_main_thread(func, self, main_loop, *args, **kwargs)
    return decorated


class EvPopen(subprocess.Popen):
    """`subprocess.Popen` subclass that relies on a pyev loop to reap the
    child and read its output instead of blocking ``waitpid``/``read`` calls.

    Other threads wait for completion through the ``stdout_done``,
    ``stderr_done`` and ``process_done`` events, which are set by the libev
    callbacks below.
    """

    @_main_thread
    def __init__(self, main_loop, *args, **kwargs):
        r"""Class that acts as `subprocess.Popen` but uses libev child handling.

        The constructor body is dispatched through
        ``main_loop.call_in_main_thread`` (see ``_main_thread``).

        :param main_loop: `NodeLoop` instance
        :param \*args: arguments for :py:class:`subprocess.Popen`
        :param \*\*kwargs: keyword arguments for :py:class:`subprocess.Popen`
        """
        self.main = main_loop

        # Output chunks collected by stdout_cb/stderr_cb.
        self._stdout_output = list()
        self._stderr_output = list()
        # Events for other threads to wait for stream EOF and process
        # termination. Created before any watcher is started so the
        # callbacks can always set them.
        self.stdout_done = threading.Event()
        self.stderr_done = threading.Event()
        self.process_done = threading.Event()

        # I/O watchers are only created later, by create_std_watchers().
        self.stdout_watcher = self.stderr_watcher = None
        self.child_watcher = None

        # this could raise but don't worry about zombies, they will be collected
        # by libev
        subprocess.Popen.__init__(self, *args, **kwargs)

        self.child_watcher = self.main.evloop.child(self.pid, False,
                                                    self.child_cb)
        self.child_watcher.start()

    @main_thread
    def create_std_watchers(self):
        """Start read watchers on the child's stdout/stderr pipes.

        A stream without a pipe has its "done" event set immediately so that
        `_communicate` does not wait on it.
        """
        if self.stdout is not None:
            self.stdout_watcher = self.main.evloop.io(self.stdout,
                                                      pyev.EV_READ,
                                                      self.stdout_cb)
            self.stdout_watcher.start()
        else:
            self.stdout_done.set()
        # Only watch stderr when it is a pipe distinct from stdout; stdout may
        # be None here, so guard before calling fileno() on it.
        if self.stderr is not None and (
                self.stdout is None or
                self.stderr.fileno() != self.stdout.fileno()):
            self.stderr_watcher = self.main.evloop.io(self.stderr,
                                                      pyev.EV_READ,
                                                      self.stderr_cb)
            self.stderr_watcher.start()
        else:
            self.stderr_done.set()

    def stdout_cb(self, watcher, revents):
        """Read up to 1024 bytes of stdout; on EOF stop, close and signal."""
        data = os.read(watcher.fd, 1024)
        if data:
            self._stdout_output.append(data)
        else:
            self.stdout_watcher.stop()
            self.stdout_watcher = None
            self.stdout.close()
            self.stdout = None
            self.stdout_done.set()

    def stderr_cb(self, watcher, revents):
        """Read up to 1024 bytes of stderr; on EOF stop, close and signal."""
        data = os.read(watcher.fd, 1024)
        if data:
            self._stderr_output.append(data)
        else:
            self.stderr_watcher.stop()
            self.stderr_watcher = None
            self.stderr.close()
            self.stderr = None
            self.stderr_done.set()

    def child_cb(self, watcher, revents):
        """Record the child's exit status and signal process termination."""
        self._handle_exitstatus(self.child_watcher.rstatus)
        self.child_watcher.stop()
        self.child_watcher = None
        self.process_done.set()

    # overiding parent methods
    def _internal_poll(self, *args, **kwargs):
        """Return the exit status recorded by `child_cb` (None if running).

        Never calls waitpid, which would race with libev.
        """
        # ignore all parameters
        return self.returncode

    def _communicate(self, stdin=None):
        """Send *stdin* to the child, then wait for output EOF and exit.

        :param stdin: data for the child's stdin, or None
        :returns: ``(stdout_data, stderr_data)`` byte strings (empty when the
            corresponding stream was not a pipe)
        """
        self.create_std_watchers()

        if self.stdin is not None:
            if stdin:
                self._write_stdin(stdin)
            # Always close our end, even without input, so the child sees
            # EOF; otherwise a child reading stdin would block forever and so
            # would we.
            self.stdin.close()
            self.stdin = None
        elif stdin:
            logger.warning('Ignoring stdin input for %s', self)

        self.stdout_done.wait()
        self.stderr_done.wait()
        # FIXME handle universal newlines
        self.process_done.wait()

        # Chunks come from os.read() and are byte strings: join them with a
        # byte separator (a unicode one would ASCII-decode and fail on
        # non-ASCII output).
        return tuple(map(b''.join, (self._stdout_output, self._stderr_output)))

    def _write_stdin(self, data):
        """Write all of *data* to the child's stdin pipe.

        A child that exits or closes its stdin early makes the write fail
        with EPIPE (or EINVAL); like the stdlib ``communicate`` this is
        ignored and the remaining data is dropped.
        """
        fd = self.stdin.fileno()
        while data:
            try:
                count = _eintr_retry_call(os.write, fd, data)
            except OSError as exc:
                if exc.errno in (errno.EPIPE, errno.EINVAL):
                    break
                raise
            # keep only the part that has not been written yet
            data = data[count:]

    # This is basically a copy-paste from stdlib subprocess module to
    # prevent calling waitpid which would race with libev and would raise
    # ECHILD
    # NOTE(review): the parameter list must match what the installed
    # subprocess.Popen.__init__ passes (some later 2.7 releases add a
    # `to_close` argument) — confirm against the target interpreter.
    def _execute_child(self, args, executable, preexec_fn, close_fds,
                       cwd, env, universal_newlines,
                       startupinfo, creationflags, shell,
                       p2cread, p2cwrite,
                       c2pread, c2pwrite,
                       errread, errwrite):
        """Execute program (POSIX version)"""

        if isinstance(args, types.StringTypes):
            args = [args]
        else:
            args = list(args)

        if shell:
            args = ["/bin/sh", "-c"] + args

        if executable is None:
            executable = args[0]

        # For transferring possible exec failure from child to parent
        # The first char specifies the exception type: 0 means
        # OSError, 1 means some other error.
        errpipe_read, errpipe_write = os.pipe()
        try:
            try:
                self._set_cloexec_flag(errpipe_write)

                gc_was_enabled = gc.isenabled()
                # Disable gc to avoid bug where gc -> file_dealloc ->
                # write to stderr -> hang.  http://bugs.python.org/issue1336
                gc.disable()
                try:
                    self.pid = os.fork()
                except:
                    if gc_was_enabled:
                        gc.enable()
                    raise
                self._child_created = True
                if self.pid == 0:
                    # Child
                    try:
                        # Close parent's pipe ends
                        if p2cwrite is not None:
                            os.close(p2cwrite)
                        if c2pread is not None:
                            os.close(c2pread)
                        if errread is not None:
                            os.close(errread)
                        os.close(errpipe_read)

                        # Dup fds for child
                        if p2cread is not None:
                            os.dup2(p2cread, 0)
                        if c2pwrite is not None:
                            os.dup2(c2pwrite, 1)
                        if errwrite is not None:
                            os.dup2(errwrite, 2)

                        # Close pipe fds.  Make sure we don't close the same
                        # fd more than once, or standard fds.
                        if p2cread is not None and p2cread not in (0,):
                            os.close(p2cread)
                        if c2pwrite is not None and c2pwrite not in (p2cread, 1):
                            os.close(c2pwrite)
                        if errwrite is not None and errwrite not in (p2cread, c2pwrite, 2):
                            os.close(errwrite)

                        # Close all other fds, if asked for
                        if close_fds:
                            self._close_fds(but=errpipe_write)

                        if cwd is not None:
                            os.chdir(cwd)

                        if preexec_fn:
                            preexec_fn()

                        if env is None:
                            os.execvp(executable, args)
                        else:
                            os.execvpe(executable, args, env)

                    except:
                        exc_type, exc_value, tb = sys.exc_info()
                        # Save the traceback and attach it to the exception object
                        exc_lines = traceback.format_exception(exc_type,
                                                               exc_value,
                                                               tb)
                        exc_value.child_traceback = ''.join(exc_lines)
                        os.write(errpipe_write, pickle.dumps(exc_value))

                    # This exitcode won't be reported to applications, so it
                    # really doesn't matter what we return.
                    os._exit(255)

                # Parent
                if gc_was_enabled:
                    gc.enable()
            finally:
                # be sure the FD is closed no matter what
                os.close(errpipe_write)

            if p2cread is not None and p2cwrite is not None:
                os.close(p2cread)
            if c2pwrite is not None and c2pread is not None:
                os.close(c2pwrite)
            if errwrite is not None and errread is not None:
                os.close(errwrite)

            # Wait for exec to fail or succeed; possibly raising exception
            # Exception limited to 1M
            data = _eintr_retry_call(os.read, errpipe_read, 1048576)
        finally:
            # be sure the FD is closed no matter what
            os.close(errpipe_read)

        if data != "":
            # Unlike the stdlib, do not waitpid() here: the child is reaped
            # through the libev child watcher.
            # _eintr_retry_call(os.waitpid, self.pid, 0)
            child_exception = pickle.loads(data)
            for fd in (p2cwrite, c2pread, errread):
                if fd is not None:
                    os.close(fd)
            raise child_exception

    def wait(self):
        """Block until `child_cb` has run, then return the exit status."""
        self.process_done.wait()
        return self.returncode
    # end overiding

    def close(self):
        """Release watchers and pipes; kill the child if still running.

        The child is killed when no exit status has been recorded yet.
        """
        # stop std* watchers
        if self.stdout_watcher is not None:
            self.stdout_watcher.stop()
            self.stdout_watcher = None
        if self.stderr_watcher is not None:
            self.stderr_watcher.stop()
            self.stderr_watcher = None
        # close std* file objects if needed
        if self.stdin is not None:
            self.stdin.close()
            self.stdin = None
        if self.stdout is not None:
            self.stdout.close()
            self.stdout = None
        if self.stderr is not None:
            self.stderr.close()
            self.stderr = None
        if self.child_watcher is not None:
            self.child_watcher.stop()
            self.child_watcher = None
        if self.returncode is None:
            # we must kill the child
            self.kill()
            # libev handles zombies


def subproc_call(main_loop, args, stdin=None):
    """Run a command with `EvPopen` and return its output.

    The child's stderr is merged into its stdout.

    :param main_loop: `NodeLoop` instance given to `EvPopen`
    :param args: arguments for subprocess call
    :param stdin: stdin data as string
    :returns: the command output (stdout and stderr combined)
    :raises subprocess.CalledProcessError: when the command exits with a
        non-zero status; ``cmd`` is *args* and ``output`` holds the captured
        output
    """
    proc = EvPopen(main_loop, args, bufsize=4096, stdin=subprocess.PIPE,
                   stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                   close_fds=True)
    result, _ = proc.communicate(stdin)
    if proc.returncode != 0:
        # report the actual command and keep its output for diagnostics
        raise subprocess.CalledProcessError(proc.returncode, args,
                                            output=result)
    return result


class SocketBuffer(deque):
    """Deque of byte strings that keeps a running total of their length.

    This class does not enforce ``max_len``. It only keeps ``current_len``
    up to date, so callers can check `is_full` / `is_empty`. The operations
    ``extend``, ``extendleft``, ``remove``, ``reverse`` and ``rotate`` are
    not supported and raise NotImplementedError.
    """
    def __init__(self, max_len=8 * 64 * 1024):
        deque.__init__(self)
        # byte count at which is_full() becomes true (not enforced)
        self.max_len = max_len
        # total len() of all elements currently held
        self.current_len = 0

    def append(self, x):
        deque.append(self, x)
        self.current_len += len(x)

    def appendleft(self, x):
        deque.appendleft(self, x)
        self.current_len += len(x)

    def clear(self):
        deque.clear(self)
        self.current_len = 0

    def extend(self, iterable):
        raise NotImplementedError

    def extendleft(self, iterable):
        raise NotImplementedError

    def pop(self):
        elt = deque.pop(self)
        self.current_len -= len(elt)
        return elt

    def popleft(self):
        elt = deque.popleft(self)
        self.current_len -= len(elt)
        return elt

    def remove(self, value):
        raise NotImplementedError

    def reverse(self):
        raise NotImplementedError

    def rotate(self, n=1):
        raise NotImplementedError

    def is_full(self):
        """Return True once the held byte count reaches ``max_len``."""
        return self.current_len >= self.max_len

    def is_empty(self):
        """Return True when no bytes are held."""
        return self.current_len == 0


class Singleton(type):
    """Singleton metaclass.

    Each class created with this metaclass (subclasses included) gets its
    own ``_instance`` slot. The first call builds the instance; later calls
    return it and ignore their arguments.
    """
    def __init__(cls, name, bases, dict):
        # forward the real class name (the original passed `cls` here)
        super(Singleton, cls).__init__(name, bases, dict)
        cls._instance = None

    def __call__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(Singleton, cls).__call__(*args, **kwargs)

        return cls._instance


def close_fds(exclude_fds=None, debug=False):
    """Close all unneeded fds in a child process after fork.

    When *debug* is false, ``sys.stdin``/``sys.stdout``/``sys.stderr`` are
    then reopened on ``os.devnull`` and are expected to land on fds 0, 1, 2.

    :param exclude_fds: iterable of file descriptors that should not be
        closed (0, 1, 2 must not be set here, see debug). The caller's
        object is not modified.
    :param bool debug: indicates if std in/out should be left open (usually for
        debuging purpose)
    :raises OSError: if closing an fd fails for a reason other than EBADF
    """
    # upper bound on fd numbers: the hard RLIMIT_NOFILE, or 2048 if unlimited
    limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
    if limit == resource.RLIM_INFINITY:
        max_fd = 2048
    else:
        max_fd = limit

    # Work on a private set: never mutate the caller's list (the previous
    # `+=` did), and keep the per-fd membership test O(1).
    keep = set(exclude_fds or ())
    if debug:
        keep.update((0, 1, 2))  # debug

    for fd in xrange(max_fd, -1, -1):
        if fd in keep:
            continue
        try:
            os.close(fd)
        except OSError as exc:
            if exc.errno != errno.EBADF:
                raise
            # wasn't open

    if not debug:
        # fds 0-2 are closed now, so these opens reuse the lowest free fds
        sys.stdin = open(os.devnull)
        sys.stdout = open(os.devnull, 'w')
        sys.stderr = open(os.devnull, 'w')
        assert sys.stdin.fileno() == 0
        assert sys.stdout.fileno() == 1
        assert sys.stderr.fileno() == 2


def set_signal_map(map_):
    """Install several signal handlers at once (meant for fork children).

    :param mapping map_: signal number -> handler, walked with
        ``iteritems()``
    :returns: dict mapping every signal number to the handler it replaced
    """
    return dict((sig, signal.signal(sig, handler))
                for sig, handler in map_.iteritems())


# Signal number -> name, built from the signal module's ``SIG*`` constants.
# ``SIG_*`` names (SIG_DFL, SIG_IGN, SIG_BLOCK, ...) are handler/mask values,
# not signals, and would otherwise overwrite real signals sharing their
# number (e.g. SIG_IGN and SIGHUP are both 1). Names are visited in reverse
# alphabetical order so that, for aliases of one number (SIGABRT/SIGIOT,
# SIGCHLD/SIGCLD), the alphabetically first name wins deterministically.
sig_names = dict(
    (num, name)
    for name, num in sorted(signal.__dict__.items(), reverse=True)
    if name.startswith('SIG') and not name.startswith('SIG_') and
    isinstance(num, int))


def num_to_sig(num):
    """Returns signal name.

    :param num: signal number
    :returns: the signal's name, or ``'Unknown signal'`` if not found
    """
    return sig_names.get(num, 'Unknown signal')