import logging
import weakref
from itertools import chain, imap

import libvirt

from ccnode.host import Handler as HostHandler
from ccnode.tags import Tag, tag_inspector, get_tags
from ccnode.hypervisor import tags
from ccnode.hypervisor import lib as _libvirt
from ccnode.hypervisor.lib import (
    DOMAIN_STATES, EVENTS, STORAGE_STATES,
    EventLoop as VirEventLoop,
)
from ccnode.hypervisor.domains import VirtualMachine


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class Handler(HostHandler):
    """Host handler extended with hypervisor (libvirt) support.

    On top of whatever ``HostHandler`` sets up, the constructor creates the
    :class:`Hypervisor`, publishes storage and hypervisor tags in
    ``self.tag_db['__main__']``, issues one asynchronous ``'register'`` RPC
    call per known domain and adds the ``vm_*`` entries to
    ``self.rpc_handler``.
    """

    def __init__(self, *args, **kwargs):
        """
        :param loop: MainLoop instance
        :param hypervisor_name: hypervisor name (mandatory keyword argument,
            removed from ``kwargs`` before calling ``HostHandler.__init__``)
        """
        hypervisor_name = kwargs.pop('hypervisor_name')
        HostHandler.__init__(self, *args, **kwargs)

        #: keep index of asynchronous calls (call id -> domain name)
        self.async_calls = dict()

        # initialize hypervisor instance
        # FIXME this may block
        self.hypervisor = Hypervisor(
            name=hypervisor_name,
            loop=self.main,
        )

        # FIXME this may block
        # register hypervisor storage tags
        for name, storage in self.hypervisor.storage.storages.iteritems():
            # The storage object is bound through a default argument: a plain
            # closure would look ``storage`` up when the tag is evaluated,
            # i.e. after the loop has finished, and every tag would then
            # report the values of the last storage pool.
            for t in (
                Tag('sto%s_state' % name, lambda s=storage: s.state, 5),
                Tag('sto%s_size' % name, lambda s=storage: s.capacity, 5),
                Tag('sto%s_free' % name, lambda s=storage: s.available, 5),
                Tag('sto%s_used' % name,
                    lambda s=storage: s.capacity - s.available, 5),
            ):
                self.tag_db['__main__'][t.name] = t

        # register domains
        for dom in self.hypervisor.domains.itervalues():
            name = dom.name
            # proxy.register(name, 'vm')
            call_id = self.main.rpc_con.rpc.async_call_cb(
                self.register_domain_cb,
                'register',
                name,
                'vm',
            )
            # remember which domain the pending call is about
            self.async_calls[call_id] = name

        self.tag_db['__main__'].update(dict(
            (t.name, t) for t in tag_inspector(tags, self)
        ))

        self.rpc_handler.update(dict(
            vm_define=self.vm_define,
            vm_undefine=self.vm_undefine,
            vm_export=self.vm_export,
            vm_stop=self.vm_stop,
            vm_start=self.vm_start,
            vm_suspend=self.vm_suspend,
            vm_resume=self.vm_resume,
        ))

    def register_domain_cb(self, call_id, response=None, error=None):
        """Callback of the ``'register'`` calls issued in ``__init__``.

        :param call_id: id of the asynchronous call, key of ``async_calls``
        :param response: call result (unused)
        :param error: error reported for the call, ``None`` on success
        """
        name = self.async_calls.pop(call_id)
        if error is not None:
            logger.error('Error while registering domain, %s', error)
            return

        logger.debug('Registered domain %s', name)
        # The domain may have been removed while the call was in flight.
        domain = self.hypervisor.domains.get(name)
        if domain is None:
            logger.debug('Domain %s is gone, skipping tags', name)
            return

        for tag in domain.tags.itervalues():
            self.main.reset_sub_tag(domain.name, tag)

    def iter_vms(self, vm_names):
        """Utility function to iterate over VM objects using their names.

        Names that do not match a known domain are skipped silently; passing
        ``None`` yields nothing.
        """
        if vm_names is None:
            return
        get_domain = self.hypervisor.domains.get
        for name in vm_names:
            dom = get_domain(name)
            if dom is not None:
                yield dom

    def vm_define(self, data, format='xml'):
        """Define a domain from its description.

        :param data: domain description, passed to libvirt ``defineXML``
        :param format: description format, only ``'xml'`` is supported
        :return: name of the defined domain, or ``None`` when libvirt
            rejected the definition (the error is logged)
        :raise NotImplementedError: for any other format
        """
        logger.debug('VM define')
        if format != 'xml':
            raise NotImplementedError('Format not supported')

        try:
            return _libvirt.connection.defineXML(data).name()
        except libvirt.libvirtError:
            logger.exception('Error while creating domain')

    def vm_undefine(self, name):
        """Undefine the domain called ``name``; unknown names are ignored."""
        logger.debug('VM undefin')
        vm = self.hypervisor.domains.get(name)
        if vm is not None:
            vm.undefine()

    def vm_export(self, name, format='xml'):
        """Return the XML description of a domain.

        :param name: domain name
        :param format: export format, only ``'xml'`` is supported
        :return: result of libvirt ``XMLDesc(0)`` or ``None`` when the domain
            is unknown
        :raise NotImplementedError: for any other format
        """
        if format != 'xml':
            raise NotImplementedError('Format not supported')

        vm = self.hypervisor.domains.get(name)

        if vm is None:
            return

        return vm.lv_dom.XMLDesc(0)

    def vm_stop(self, vm_names=None, force=False):
        """Stop the named domains.

        :param vm_names: iterable of domain names, ``None`` means no domain
        :param force: call ``destroy()`` instead of ``stop()``

        A libvirt error on one domain is logged and does not prevent the
        remaining domains from being processed.
        """
        logger.debug('VM stop')
        for vm in self.iter_vms(vm_names):
            try:
                if force:
                    vm.destroy()
                else:
                    vm.stop()
            except libvirt.libvirtError:
                # should we return errors ?
                # At least leave a trace instead of dropping it silently.
                logger.exception('Error while stopping domain %s', vm.name)

    def vm_start(self, vm_names=None):
        """Call ``start()`` on every known domain listed in ``vm_names``."""
        logger.debug('VM start')
        for vm in self.iter_vms(vm_names):
            vm.start()

    def vm_suspend(self, vm_names=None):
        """Call ``suspend()`` on every known domain listed in ``vm_names``."""
        logger.debug('VM suspend')
        for vm in self.iter_vms(vm_names):
            vm.suspend()

    def vm_resume(self, vm_names=None):
        """Call ``resume()`` on every known domain listed in ``vm_names``."""
        logger.debug('VM resume')
        for vm in self.iter_vms(vm_names):
            vm.resume()


class Hypervisor(object):
    """Container for all hypervisor related state."""
    def __init__(self, name, loop):
        """
        :param str name: name of hypervisor instance
        :param loop: MainLoop instance

        Opens the libvirt connection (stored in the shared
        ``ccnode.hypervisor.lib`` module as ``connection``), indexes storage
        and domains, then subscribes :meth:`vir_cb` to libvirt events.
        """
        #: parent MainLoop
        self.main = weakref.proxy(loop)
        self.rpc_con = weakref.proxy(loop.rpc_con)  # FIXME do we need weakref ?
        #: pending asynchronous calls: call id -> domain name for
        #: ``'register'`` calls, call id -> removed VM object for
        #: ``'unregister'`` calls
        self.async_calls = dict()

        #: hv attributes
        self.name = name
        self.type = u'kvm'

        # libvirt event loop abstraction
        self.vir_event_loop = VirEventLoop(self.main.loop)
        # This tells libvirt what event loop implementation it
        # should use
        libvirt.virEventRegisterImpl(
            self.vir_event_loop.add_handle,
            self.vir_event_loop.update_handle,
            self.vir_event_loop.remove_handle,
            self.vir_event_loop.add_timer,
            self.vir_event_loop.update_timer,
            self.vir_event_loop.remove_timer,
        )

        # TODO cleanup connection on stop
        _libvirt.connection = libvirt.open('qemu:///system')  # currently only support KVM

        # find out storage
        self.storage = StorageIndex(_libvirt.connection)

        logger.debug('Storages: %s', self.storage.paths)

        #: domains: vms, containers...
        self.domains = dict()
        # find defined domains
        for dom_name in _libvirt.connection.listDefinedDomains():
            dom = _libvirt.connection.lookupByName(dom_name)
            self.domains[dom.name()] = VirtualMachine(dom, self)
        # find started domains
        for dom_id in _libvirt.connection.listDomainsID():
            dom = _libvirt.connection.lookupByID(dom_id)
            self.domains[dom.name()] = VirtualMachine(dom, self)

        logger.debug('Domains: %s', self.domains)

        self.vir_event_loop.register_callbacks(self.vir_cb)

    def stop(self):
        """Stop the libvirt event loop abstraction."""
        self.vir_event_loop.stop()
        # TODO delete objects

    def vir_cb(self, conn, dom, event, detail, opaque):
        """Callback for libvirt event loop.

        :param conn: libvirt connection (unused)
        :param dom: libvirt domain the event is about
        :param event: event code, translated through ``EVENTS``
        :param detail: event detail, only logged
        :param opaque: user data (unused)
        """
        logger.debug('Received event %s on domain %s, detail %s', event,
                     dom.name(), detail)

        event = EVENTS[event]

        if event == 'Added':
            vm = VirtualMachine(dom, self)
            self.domains[vm.name] = vm
            # self.sjproxy.register(vm.name, 'vm')
            call_id = self.rpc_con.rpc.async_call_cb(
                self.register_cb,
                'register',
                vm.name,
                'vm',
            )
            self.async_calls[call_id] = vm.name
        elif event == 'Removed':
            logger.debug('About to remove domain')
            vm = self.domains.pop(dom.name(), None)
            if vm is None:
                # never tracked (or already removed): nothing to unregister
                logger.debug('Removed domain %s was not tracked', dom.name())
                return
            # self.sjproxy.unregister(vm.name)
            call_id = self.rpc_con.rpc.async_call_cb(
                self.unregister_cb,
                'unregister',
                vm.name,
            )
            # The VM has just left ``self.domains``, so keep the object
            # itself for the callback instead of its name.
            self.async_calls[call_id] = vm
        elif event in ('Started', 'Suspended', 'Resumed', 'Stopped', 'Saved',
                       'Restored'):
            vm = self.domains.get(dom.name())
            # sometimes libvirt sent a start event before a created event so be
            # careful
            if vm is not None:
                state = DOMAIN_STATES[dom.info()[0]]
                logger.info('Domain change state from %s to %s', vm.state,
                            state)
                vm.state = state

    def register_cb(self, call_id, response=None, error=None):
        """Callback of the ``'register'`` call issued for an added domain.

        :param call_id: id of the asynchronous call, key of ``async_calls``
        :param response: call result (unused)
        :param error: error reported for the call, ``None`` on success
        """
        name = self.async_calls.pop(call_id)
        if error is not None:
            logger.error('Error while registering domain to server, %s', error)
            return

        # The domain may have been removed while the call was in flight.
        vm = self.domains.get(name)
        if vm is None:
            logger.debug('Domain %s is gone, skipping tags', name)
            return

        logger.info('Add domain: %s (%s)', vm.name, vm.uuid)
        # add tags
        for tag in vm.tags.itervalues():
            self.main.reset_sub_tag(vm.name, tag)

    def unregister_cb(self, call_id, response=None, error=None):
        """Callback of the ``'unregister'`` call issued for a removed domain.

        :param call_id: id of the asynchronous call, key of ``async_calls``
        :param response: call result (unused)
        :param error: error reported for the call, ``None`` on success

        The local sub object is dropped even when the call failed, since the
        domain no longer exists on this hypervisor.
        """
        vm = self.async_calls.pop(call_id)
        if error is not None:
            logger.error('Error while unregistering domain to server, %s', error)
        logger.info('Delete domain: %s (%s)', vm.name, vm.uuid)
        self.main.remove_sub_object(vm)

    def _count_domain(self, filter=lambda d: True):
        """Return how many domains satisfy the ``filter`` predicate."""
        return sum(1 for dom in self.domains.values() if filter(dom))

    @property
    def vm_started(self):
        """Number of VMs started."""
        return self._count_domain(lambda d: d.state == 'running')

    @property
    def vm_stopped(self):
        """Number of VMs stopped."""
        return self._count_domain(lambda d: d.state == 'stopped')

    @property
    def vm_paused(self):
        """Number of VMs paused."""
        return self._count_domain(lambda d: d.state == 'paused')

    @property
    def vm_total(self):
        """Total number of VMs on the hypervisor."""
        return self._count_domain()


class StorageIndex(object):
    """Keep an index of all storage pools and storage volume paths.

    ``storages`` maps a pool name to its :class:`Storage`; ``paths`` maps a
    volume path to its :class:`Volume`. Both are built once, at construction
    time.
    """
    def __init__(self, lv_con):
        """
        :param lv_con: Libvirt connection
        """
        # Both listings are needed to see every pool: the "defined" one and
        # the plain one return distinct sets of pool names.
        pool_names = list(lv_con.listDefinedStoragePools())
        pool_names.extend(lv_con.listStoragePools())

        self.storages = {}
        for pool_name in pool_names:
            storage = Storage(lv_con.storagePoolLookupByName(pool_name))
            self.storages[storage.name] = storage

        self.paths = {}
        for storage in self.storages.values():
            for volume in storage.volumes:
                self.paths[volume.path] = volume

    def get_volume(self, path):
        """Return the volume stored at ``path``, ``None`` if unknown."""
        return self.paths.get(path)

    def get_storage(self, name):
        """Return the storage pool called ``name``, ``None`` if unknown."""
        return self.storages.get(name)


class Storage(object):
    """Storage abstraction.

    Values are read from libvirt once, when the object is created; they are
    not refreshed afterwards.
    """
    def __init__(self, lv_storage):
        """
        :param lv_storage: Libvirt pool storage instance
        """
        self.uuid = lv_storage.UUID()
        self.name = lv_storage.name()

        state, self.capacity, self.allocation, self.available = lv_storage.info()
        # translate libvirt state code
        self.state = STORAGE_STATES[state]

        # Listing volumes fails on a pool that is defined but not active;
        # such a pool must still be indexed (with no volume) instead of
        # aborting the whole storage discovery.
        try:
            volume_names = lv_storage.listVolumes()
        except libvirt.libvirtError:
            logger.warning('Cannot list volumes of storage %s', self.name)
            volume_names = []

        self.volumes = [
            Volume(lv_storage.storageVolLookupByName(volume_name))
            for volume_name in volume_names
        ]


class Volume(object):
    """Volume abstraction.

    Plain record holding what libvirt reports about one storage volume at
    the time the object is built.
    """
    def __init__(self, lv_volume):
        """
        :param lv_volume: Libvirt volume instance
        """
        # name of the pool this volume belongs to
        parent_pool = lv_volume.storagePoolLookupByVolume()
        self.storage = parent_pool.name()

        self.path = lv_volume.path()
        self.name = lv_volume.name()

        # first item of info() is not kept, the other two are the sizes
        _unused, capacity, allocation = lv_volume.info()
        self.capacity = capacity
        self.allocation = allocation
