import logging
import socket
from StringIO import StringIO
from xml.etree import cElementTree as et

import libvirt
from sjrpc.utils import threadless, pass_connection
from cloudcontrol.common.client.tags import Tag, tag_inspector

from cloudcontrol.node.host import Handler as HostHandler
from cloudcontrol.node.hypervisor import tags
from cloudcontrol.node.hypervisor.kvm import KVM, LiveMigration
from cloudcontrol.node.exc import (
    UndefinedDomain, DRBDError, PoolStorageError
)
from cloudcontrol.node.hypervisor.jobs import (
    ImportVolume, ExportVolume, TCPTunnel, DRBD,
)


logger = logging.getLogger(__name__)


# FIXME find a way to refactor Handler and Hypervisor class
class Handler(HostHandler):
    def __init__(self, *args, **kwargs):
        """Build the hypervisor handler on top of the host handler.

        All positional arguments and the remaining keyword arguments are
        forwarded unchanged to ``HostHandler.__init__``.

        :param hypervisor_name: (keyword only, required) hypervisor name; it
            is popped from ``kwargs`` before they are forwarded
        :raises KeyError: if ``hypervisor_name`` is not given
        """
        hypervisor_name = kwargs.pop('hypervisor_name')
        self.hypervisor_name = hypervisor_name
        HostHandler.__init__(self, *args, **kwargs)

        # no hypervisor instance exists until virt_connect_cb succeeds
        self.hypervisor = None
        self._virt_connected = False

        #: keep index of asynchronous calls
        self.async_calls = {}

        # timer driving virt_connect_cb; it is only started in start()
        self.timer = self.main.evloop.timer(0., 5., self.virt_connect_cb)

        # register tags
        self.tag_db.add_tags(tag_inspector(tags, self))

    @property
    def virt_connected(self):
        return self._virt_connected

    @virt_connected.setter
    def virt_connected(self, value):
        self._virt_connected = value
        # update tags
        for tag in ('vir_status', 'sto', 'nvm', 'vmpaused', 'vmstarted',
                    'vmstopped', 'hvver', 'libvirtver', 'hv'):
            self.tag_db['__main__'][tag].update_value()

    def start(self):
        """Start the libvirt connection timer, then start the host handler."""
        # the timer fires virt_connect_cb, which performs the actual connection
        self.timer.start()
        HostHandler.start(self)

    def stop(self):
        """Stop the connection timer, the hypervisor (if any) and the host
        handler, in that order.
        """
        self.timer.stop()
        hypervisor = self.hypervisor
        # hypervisor is None while no libvirt connection is established
        if hypervisor is not None:
            hypervisor.stop()
        HostHandler.stop(self)

    def virt_connect_cb(self, *args):
        """Timer callback: connect to libvirt and publish tags and handlers.

        When the ``KVM`` instance cannot be created because of a
        ``libvirt.libvirtError``, the error is logged and the method returns,
        leaving the timer untouched. Otherwise storage tags, domain sub
        objects and every libvirt related handler are registered and the
        timer is stopped.

        :param args: ignored (whatever the timer passes to its callback)
        """
        # initialize hypervisor instance
        try:
            self.hypervisor = KVM(name=self.hypervisor_name, handler=self)
        except libvirt.libvirtError:
            logger.exception('Error while connecting to libvirt')
            return

        self.virt_connected = True

        # register hypervisor storage tags
        for name, storage in self.hypervisor.storage.storages.iteritems():
            prefix = 'sto%s_' % name
            storage_tags = (
                Tag(prefix + 'state', lambda sto: sto.state, 5, 5, storage),
                Tag(prefix + 'size', lambda sto: sto.capacity, 5, 5, storage),
                Tag(prefix + 'free', lambda sto: sto.available, 5, 5, storage),
                Tag(prefix + 'used',
                    lambda sto: sto.capacity - sto.available, 5, 5, storage),
                Tag(prefix + 'type', lambda sto: sto.type, 5, 5, storage),
            )
            self.tag_db.add_tags(storage_tags)

        # register domains
        for dom in self.hypervisor.domains.itervalues():
            self.tag_db.add_sub_object(dom.name, dom.tags.itervalues(), 'vm')

        # we must refresh those tags only when domains tags are registered to
        # have the calculated values
        for tag_name in ('cpualloc', 'cpurunning', 'memalloc', 'memrunning'):
            self.tag_db['__main__'][tag_name].update_value()

        rpc_names = (
            'vm_define', 'vm_undefine', 'vm_export', 'vm_stop', 'vm_start',
            'vm_suspend', 'vm_resume',
        )
        self.rpc_handler.update(
            dict((rpc_name, getattr(self, rpc_name)) for rpc_name in rpc_names)
        )

        # every handler is a method of this object bearing the handler's name
        handler_names = (
            'vm_define', 'vm_undefine', 'vm_export', 'vm_stop', 'vm_destroy',
            'vm_start', 'vm_suspend', 'vm_resume', 'vm_migrate_tunneled',
            'vol_create', 'vol_delete', 'vol_import', 'vol_import_wait',
            'vol_export',
            'tun_setup', 'tun_connect', 'tun_connect_hv', 'tun_destroy',
            'drbd_setup', 'drbd_connect', 'drbd_role', 'drbd_takeover',
            'drbd_sync_status', 'drbd_shutdown',
            'vm_open_console', 'vm_disable_virtio_cache', 'vm_set_autostart',
        )
        for handler_name in handler_names:
            self.main.reset_handler(handler_name, getattr(self, handler_name))

        # if everything went fine, unregister the timer
        self.timer.stop()

    def virt_connect_restart(self):
        """Restart libvirt connection.

        This method might be called when libvirt connection is lost. It does
        nothing when the handler is not currently marked as connected;
        otherwise it drops the storage tags, the sub objects, the hypervisor
        instance and the libvirt related handlers, then starts the connection
        timer again.
        """
        if not self.virt_connected:
            return

        logger.error('Connection to libvirt lost, trying to restart')
        # update connection state
        self.virt_connected = False
        # refresh those tags
        for tag_name in ('cpualloc', 'cpurunning', 'memalloc', 'memrunning'):
            self.tag_db['__main__'][tag_name].update_value()

        # unregister tags that will be re registered later
        suffixes = ('state', 'size', 'free', 'used', 'type')
        for storage in self.hypervisor.storage.storages:
            prefix = 'sto%s_' % storage
            self.tag_db.remove_tags(
                tuple(prefix + suffix for suffix in suffixes)
            )
        # unregister sub objects (for the same reason)
        for sub_id in self.tag_db.keys():
            if sub_id != '__main__':
                self.tag_db.remove_sub_object(sub_id)
        # stop and delete hypervisor instance
        self.hypervisor.stop()
        self.hypervisor = None

        # remove handlers related to libvirt
        handler_names = (
            'vm_define', 'vm_undefine', 'vm_export', 'vm_stop', 'vm_destroy',
            'vm_start', 'vm_suspend', 'vm_resume', 'vm_migrate_tunneled',
            'vol_create', 'vol_delete', 'vol_import', 'vol_import_wait',
            'vol_export',
            'tun_setup', 'tun_connect', 'tun_connect_hv', 'tun_destroy',
            'drbd_setup', 'drbd_connect', 'drbd_role', 'drbd_takeover',
            'drbd_sync_status', 'drbd_shutdown',
            'vm_open_console', 'vm_disable_virtio_cache', 'vm_set_autostart',
        )
        for handler_name in handler_names:
            self.main.remove_handler(handler_name)
        # launch connection timer
        self.timer.start()

    def iter_vms(self, vm_names):
        """Utility function to iterate over VM objects using their names."""
        if vm_names is None:
            return
        get_domain = self.hypervisor.domains.get
        for name in vm_names:
            dom = get_domain(name)
            if dom is not None:
                yield dom

    def vm_define(self, data, format='xml'):
        """Define a VM from its description.

        :param data: VM description, handed to the hypervisor ``vm_define``
        :param format: description format, only ``'xml'`` is supported
        :return: whatever the hypervisor ``vm_define`` returns
        :raises NotImplementedError: for any other format
        """
        logger.debug('VM define')
        if format == 'xml':
            return self.hypervisor.vm_define(data)

        raise NotImplementedError('Format not supported')

    def vm_undefine(self, name):
        """Undefine the VM called ``name``; unknown names are ignored.

        :param name: VM name
        """
        logger.debug('VM undefine %s', name)
        domain = self.hypervisor.domains.get(name)
        if domain is None:
            return
        domain.undefine()

    def vm_export(self, name, format='xml'):
        """Return the libvirt XML description of a VM.

        :param name: VM name
        :param format: export format, only ``'xml'`` is supported
        :return: result of ``XMLDesc(0)`` on the libvirt domain, or None when
            the VM is unknown
        :raises NotImplementedError: for any other format
        """
        logger.debug('VM export %s', name)
        if format != 'xml':
            raise NotImplementedError('Format not supported')

        domain = self.hypervisor.domains.get(name)
        return None if domain is None else domain.lv_dom.XMLDesc(0)

    def vm_stop(self, name):
        """Call ``stop()`` on the VM called ``name``.

        :param name: VM name
        :raises UndefinedDomain: when a ``KeyError`` is raised (the name is
            not in the hypervisor domain index)
        :raises libvirt.libvirtError: logged, then re-raised unchanged
        """
        logger.debug('VM stop %s', name)
        try:
            domain = self.hypervisor.domains[name]
            domain.stop()
        except KeyError:
            error_msg = 'Cannot stop VM %s because it is not defined' % name
            logger.error(error_msg)
            raise UndefinedDomain(error_msg)
        except libvirt.libvirtError:
            logger.exception('Error while stopping VM %s', name)
            raise

    def vm_destroy(self, name):
        """Call ``destroy()`` on the VM called ``name``.

        A libvirt error whose text contains 'domain is not running' is
        swallowed when the domain state is nevertheless 'running'.

        :param name: VM name
        :raises UndefinedDomain: when a ``KeyError`` is raised by the lookup
            or the destroy call
        :raises libvirt.libvirtError: any other libvirt error, logged first
        """
        logger.debug('VM destroy %s', name)
        try:
            self.hypervisor.domains[name].destroy()
        except KeyError:
            error_msg = 'Cannot destroy VM %s because it is not defined' % name
            logger.error(error_msg)
            raise UndefinedDomain(error_msg)
        except libvirt.libvirtError as exc:
            # Libvirt raises exception 'domain is not running' even if domain
            # is running, might be a bug in libvirt
            spurious = 'domain is not running' in str(exc) and (
                self.hypervisor.domains[name].state == 'running')
            if not spurious:
                logger.exception('Error while destroying VM %s', name)
                raise

    def vm_start(self, name):
        """Call ``start()`` on the VM called ``name``.

        :param name: VM name
        :raises UndefinedDomain: when a ``KeyError`` is raised (the name is
            not in the hypervisor domain index)
        :raises libvirt.libvirtError: logged, then re-raised unchanged
        """
        logger.debug('VM start %s', name)
        try:
            domain = self.hypervisor.domains[name]
            domain.start()
        except KeyError:
            error_msg = 'Cannot start VM %s because it is not defined' % name
            logger.error(error_msg)
            raise UndefinedDomain(error_msg)
        except libvirt.libvirtError:
            logger.exception('Error while starting VM %s', name)
            raise

    def vm_suspend(self, name):
        """Call ``suspend()`` on the VM called ``name``.

        :param name: VM name
        :raises UndefinedDomain: when a ``KeyError`` is raised (the name is
            not in the hypervisor domain index)
        :raises libvirt.libvirtError: logged, then re-raised unchanged
        """
        logger.debug('VM suspend %s', name)
        try:
            domain = self.hypervisor.domains[name]
            domain.suspend()
        except KeyError:
            error_msg = 'Cannot suspend VM %s because it is not defined' % name
            logger.error(error_msg)
            raise UndefinedDomain(error_msg)
        except libvirt.libvirtError:
            logger.exception('Error while suspending VM %s', name)
            raise

    def vm_resume(self, name):
        """Call ``resume()`` on the VM called ``name``.

        :param name: VM name
        :raises UndefinedDomain: when a ``KeyError`` is raised (the name is
            not in the hypervisor domain index)
        :raises libvirt.libvirtError: logged, then re-raised unchanged
        """
        logger.debug('VM resume %s', name)
        try:
            domain = self.hypervisor.domains[name]
            domain.resume()
        except KeyError:
            error_msg = 'Cannot resume VM %s because it is not defined' % name
            logger.error(error_msg)
            raise UndefinedDomain(error_msg)
        except libvirt.libvirtError:
            logger.exception('Error while resuming VM %s', name)
            raise

    def vm_migrate_tunneled(self, name, tun_res, migtun_res, unsafe=False,
                            timeout=60.):
        """Live migrate VM through TCP tunnel.

        :param name: VM name to migrate
        :param tun_res: result of the tunnel setup handler for the tunnel
            used by the cc-node libvirt client; must contain a 'port' key
        :param migtun_res: result of the tunnel setup handler for the tunnel
            used by the local libvirtd; must contain a 'port' key
        :param bool unsafe: unsafe migration
        :param float timeout: timeout in seconds for libvirt migration
            (prevents libvirt from trying to acquire domain lock forever)
        :raises KeyError: if 'port' is missing from ``tun_res`` or
            ``migtun_res``, or if ``name`` is not a known domain
        :raises Exception: whatever ``LiveMigration.wait()`` raises, after
            logging it
        """
        logger.debug('VM live migrate %s', name)

        try:
            # this is the port used by our libvirt in the cc-node (client
            # libvirt) to connect to the remote libvirtd
            remote_virt_port = tun_res['port']
        except KeyError:
            logger.error('Invalid formatted argument tun_res for live'
                         ' migration')
            raise
        try:
            # this is the port used by local libvirtd to connect to the remote
            # libvirtd (see http://libvirt.org/migration.html)
            remote_virt_port2 = migtun_res['port']
        except KeyError:
            logger.error('Invalid formatted argument migtun_res for live'
                         ' migration')
            raise
        try:
            vm = self.hypervisor.domains[name]
        except KeyError:
            logger.exception('Cannot find domain %s on hypervisor for live'
                             ' migration', name)
            raise

        migration = LiveMigration(self.main, vm, remote_virt_port,
                                  remote_virt_port2, timeout, unsafe)
        try:
            migration.wait()
        except Exception:
            logger.exception('Error during live migration for vm %s', name)
            # %s instead of %d: a non integer status (e.g. None, presumably
            # possible when the migration aborts early -- TODO confirm) would
            # make the %d formatting fail inside logging and lose this record
            logger.debug('Exit status %s', migration.return_status)
            raise

        logger.info('Sucessfuly migrated vm %s', name)

    @threadless
    @pass_connection
    def vm_open_console(self, conn, name):
        """Open the console of a VM and expose it through an sjRPC tunnel.

        :param conn: sjRPC connection instance
        :param name: VM name
        :return: label of the created tunnel
        :raises KeyError: if ``name`` is not a known domain
        :raises socket.error: when the console endpoint cannot be created
        """
        vm = self.hypervisor.domains[name]

        def close_vm_console(tun):
            """Method of Tunnel protocol close callback."""
            vm.close_console()

        # create connection to the VM console
        try:
            console_endpoint = vm.open_console()
        except socket.error:
            # cannot create socketpair
            logger.error('Cannot create connection to VM console')
            raise

        # connect as tunnel endpoint
        tunnel = conn.create_tunnel(endpoint=console_endpoint,
                                    on_shutdown=close_vm_console)
        return tunnel.label

    def vm_disable_virtio_cache(self, name):
        """Set virtio cache to none on VM disks.

        The domain XML is fetched from libvirt, the ``cache`` attribute of the
        ``driver`` element of every disk whose target bus is ``virtio`` is set
        to ``none`` and the resulting XML is defined again.

        :param name: VM name
        :raises KeyError: if ``name`` is not a known domain
        :raises libvirt.libvirtError: when reading or defining the XML fails
            (logged, then re-raised)
        """
        vm = self.hypervisor.domains[name]

        # get VM XML
        try:
            domain_xml = vm.lv_dom.XMLDesc(0)
        except libvirt.libvirtError:
            logger.exception('Error while getting domain XML from libvirt, %s',
                             vm.name)
            raise

        tree = et.ElementTree()
        tree.parse(StringIO(domain_xml))
        for disk in tree.findall('devices/disk'):
            target = disk.find('target')
            # only virtio disks are modified
            if target is not None and target.get('bus') == 'virtio':
                # modify cache attr
                driver = disk.find('driver')
                assert driver is not None
                driver.set('cache', 'none')
                logger.debug('Set cache attribute for disk %s of VM %s',
                             target.get('dev'), name)

        # write back the XML tree
        buf = StringIO()
        tree.write(buf)  # check encoding is fine
        new_xml = buf.getvalue()
        try:
            self.hypervisor.vir_con.defineXML(new_xml)
        except libvirt.libvirtError:
            logger.exception('Cannot update XML file for domain %s', name)
            raise

    def vm_set_autostart(self, name, autostart=True):
        """Set autostart on VM.

        :param name: VM name
        :param bool autostart: autostart value to set
        """
        vm = self.hypervisor.domains[name]
        vm.lv_dom.setAutostart(int(bool(autostart)))
        # update autostart value now instead of 10 seconds lag
        vm.tags['autostart'].update_value()

    def vol_create(self, pool, name, size):
        """Create a volume through the hypervisor storage index.

        :param pool: pool name, passed as is to ``create_volume``
        :param name: name of the volume to create
        :param size: size of the volume, passed as is to ``create_volume``
        :raises Exception: any error from ``create_volume`` is logged and
            re-raised unchanged
        """
        logger.debug('Volume create %s, pool %s, size %s', name, pool, size)
        try:
            self.hypervisor.storage.create_volume(pool, name, size)
        except Exception:
            logger.exception('Error while creating volume')
            raise

    def vol_delete(self, pool, name):
        """Delete a volume through the hypervisor storage index.

        :param pool: pool name, passed as is to ``delete_volume``
        :param name: name of the volume to delete
        :raises Exception: any error from ``delete_volume`` is logged and
            re-raised unchanged
        """
        logger.debug('Volume delete %s, pool %s', name, pool)
        try:
            self.hypervisor.storage.delete_volume(pool, name)
        except Exception:
            logger.exception('Error while deleting volume')
            raise

    def vol_import(self, pool, name):
        """Create and start an import job for an existing volume.

        :param pool: pool name where the volume is
        :param name: name of the volume
        :return: dict with the ``id`` and ``port`` of the started job
        :raises PoolStorageError: when the pool or the volume does not exist
        :raises Exception: any error is logged and re-raised unchanged
        """
        logger.debug('Volume import pool = %s, volume = %s', pool, name)
        try:
            storage_pool = self.hypervisor.storage.get_storage(pool)
            if storage_pool is None:
                raise PoolStorageError('Pool storage does not exist')

            target_volume = storage_pool.volumes.get(name)
            if target_volume is None:
                raise PoolStorageError('Volume does not exist')

            # create the job
            import_job = self.main.job_manager.create(ImportVolume,
                                                      target_volume)
            import_job.start()
        except Exception:
            logger.exception('Error while starting import job')
            raise

        return dict(id=import_job.id, port=import_job.port)

    def vol_import_wait(self, job_id):
        """Block until completion of the given job id.

        :param job_id: id of a job known to the job manager
        :return: dict with the job ``id``, an empty ``log`` and the job
            ``checksum``
        """
        job = self.main.job_manager.get(job_id)
        logger.debug('Waiting for import job to terminate')
        job.wait()
        logger.debug('Import job terminated')

        # 'log' is always returned empty here
        return dict(id=job.id, log='', checksum=job.checksum)

    def vol_import_cancel(self, job_id):
        """Cancel import job.

        :param job_id: id of the import job to cancel
        """
        # NOTE(review): this method is not in the list of handlers registered
        # by virt_connect_cb -- confirm whether that is intentional
        logger.debug('Cancel import job')
        job = self.main.job_manager.get(job_id)
        self.main.job_manager.cancel(job_id)
        # wait for job to end
        job.join()  # join() rather than wait(): wait() is already called in
                    # the vol_import_wait handler

    def vol_export(self, pool, name, raddr, rport):
        """Export a volume to a remote destination and wait for completion.

        :param pool: pool name where the volume is
        :param name: name of the volume
        :param raddr: IP address of the destination to send the volume to
        :param rport: TCP port of the destination
        :return: dict with the job ``id``, an empty ``log`` and the job
            ``checksum``
        :raises PoolStorageError: when the pool or the volume does not exist
        :raises Exception: any job error is logged and re-raised unchanged
        """
        storage_pool = self.hypervisor.storage.get_storage(pool)
        if storage_pool is None:
            raise PoolStorageError('Pool storage does not exist')

        source_volume = storage_pool.volumes.get(name)
        if source_volume is None:
            raise PoolStorageError('Volume does not exist')

        try:
            export_job = self.main.job_manager.create(
                ExportVolume, source_volume, raddr, rport)
            export_job.start()
            export_job.wait()
        except Exception:
            logger.exception('Error while exporting volume')
            raise

        logger.debug('Export volume successfull')
        return dict(id=export_job.id, log='', checksum=export_job.checksum)

    @threadless
    def tun_setup(self, local=True):
        """Set up local tunnel and listen on a random port.

        :param local: indicate if we should listen on localhost or all
            interfaces
        :return: dict with the job id (``jid``), a placeholder ``key`` and
            the listening ``port``
        """
        logger.debug('Tunnel setup: local = %s', local)
        if local:
            listen_addr = '127.0.0.1'
        else:
            listen_addr = '0.0.0.0'
        # create job
        tunnel = self.main.job_manager.create(TCPTunnel)
        tunnel.setup_listen(listen_addr)
        return {
            'jid': tunnel.id,
            'key': 'FIXME',
            'port': tunnel.port,
        }

    @threadless
    def tun_connect(self, res, remote_res, remote_ip):
        """Connect tunnel to the other end.

        :param res: previous result of `tun_setup` handler
        :param remote_res: other end result of `tun_setup` handler
        :param remote_ip: where to connect
        """
        jid = res['jid']
        logger.debug('Tunnel connect %s %s', jid, remote_ip)
        tunnel = self.main.job_manager.get(jid)
        remote_endpoint = (remote_ip, remote_res['port'])
        tunnel.setup_connect(remote_endpoint)
        tunnel.start()

    @threadless
    def tun_connect_hv(self, res, migration=False):
        """Connect tunnel to local libvirt Unix socket.

        :param res: previous result of `tun_setup` handler
        :param migration: not used by this method
        """
        jid = res['jid']
        logger.debug('Tunnel connect hypervisor %s', jid)
        tunnel = self.main.job_manager.get(jid)
        tunnel.setup_connect('/var/run/libvirt/libvirt-sock')
        tunnel.start()

    @threadless
    def tun_destroy(self, res):
        """Close given tunnel.

        The job is cancelled through the job manager, then waited for.

        :param res: previous result as given by `tun_setup` handler
        """
        jid = res['jid']
        logger.debug('Tunnel destroy %s', jid)
        manager = self.main.job_manager
        tunnel = manager.get(jid)
        manager.cancel(tunnel.id)
        tunnel.wait()

    def drbd_setup(self, pool, name):
        """Create DRBD volumes.

        :param pool: storage pool
        :param name: storage volume name
        :return: dict with the job id (``jid``) and the DRBD ``port``
        :raises DRBDError: when the pool does not exist, is not of type
            'logical', or the volume does not exist
        :raises Exception: job creation errors are logged and re-raised
        """
        storage_pool = self.hypervisor.storage.get_storage(pool)
        if storage_pool is None:
            raise DRBDError('Cannot setup DRBD: pool storage does not exist')
        if storage_pool.type != 'logical':
            raise DRBDError('Cannot setup DRBD: pool storage is not LVM')

        volume = storage_pool.volumes.get(name)
        if volume is None:
            raise DRBDError('Cannot setup DRBD: volume does not exist')

        try:
            drbd_job = self.main.job_manager.create(
                DRBD, self.hypervisor.storage, storage_pool, volume)
        except Exception:
            logger.exception('Error while creating DRBD job')
            raise

        drbd_job.setup()

        logger.debug('DRBD setup successfull')
        return {
            'jid': drbd_job.id,
            'port': drbd_job.drbd_port,
        }

    def drbd_connect(self, res, remote_res, remote_ip):
        """Set up DRBD in connect mode. (Wait for connection and try to connect
        to the remote peer.

        :param res: previous result of `drbd_setup` handler
        :param remote_res: result of remote `drbd_setup` handler
        :param remote_ip: IP of remote peer
        """
        job = self.main.job_manager.get(res['jid'])
        job.connect(remote_ip, remote_res['port'])
        job.wait_connection()

    def drbd_role(self, res, primary):
        """Set up DRBD role.

        :param res: previous result of `drbd_setup` handler
        :param bool primary: if True, set up in primary mode else secondary
        """
        job = self.main.job_manager.get(res['jid'])
        if primary:
            job.switch_primary()
        else:
            job.switch_secondary()

    def drbd_takeover(self, res, state):
        """Set up DRBD device as the VM disk. FIXME

        :param res: previous result of `drbd_setup` handler
        :param state: FIXME
        """
        job = self.main.job_manager.get(res['jid'])
        job.takeover()

    def drbd_sync_status(self, res):
        """Return synchronization status of a current DRBD job.

        :param res: previous result of `drbd_setup` handler
        :return: dict with ``done`` (True when the job status reports its
            'disk' as 'UpToDate') and ``completion`` (the 'percent' value of
            the job status)
        """
        drbd_job = self.main.job_manager.get(res['jid'])
        job_status = drbd_job.status()
        up_to_date = job_status['disk'] == 'UpToDate'
        result = dict(done=up_to_date, completion=job_status['percent'])
        logger.debug('DRBD status %s', result)
        return result

    def drbd_shutdown(self, res):
        """Destroy DRBD related block devices.

        :param res: previous result of `drbd_setup` handler
        """
        logger.debug('DRBD shutdown')
        manager = self.main.job_manager
        drbd_job = manager.get(res['jid'])
        drbd_job.cleanup()

        # remove job from job_manager list
        manager.notify(drbd_job)
