# -*- coding: utf-8 -*-

import libvirt
import psutil
import xml.dom.minidom
from logging import error, warning, info, debug
from time import sleep
from common import Hypervisor, VM, Storage, StoragePool, StorageVolume
from utils import RWLock
from errors import (HypervisorError, VMError, StorageError, StoragePoolError,
                    StorageVolumeError)


# libvirt connection URIs used by LibvirtHypervisor, per hypervisor type.
KVM_LIBVIRT_SESSION = 'qemu:///system'
XEN_LIBVIRT_SESSION = 'xen:///'

# Binary unit divisors, each one 1024 times the previous.
KILOBYTE_DIV = 1024
MEGABYTE_DIV = KILOBYTE_DIV * 1024
GIGABYTE_DIV = MEGABYTE_DIV * 1024


#### hypervisor

class LibvirtHypervisor(Hypervisor):
    '''
    Hypervisor implementation backed by a libvirt connection.

    Supports local KVM (``qemu:///system``) and Xen (``xen:///``) hosts.
    Domains are cached in three dicts keyed by VM name -- running, defined
    (inactive) and the union of both -- guarded by a read/write lock. The
    cache is refreshed by :meth:`scheduler_run` and after :meth:`vm_define`.
    '''

    def __init__(self, hv_type):
        '''
        Open the libvirt connection for `hv_type` and set up empty caches.

        :param hv_type: ``'kvm'`` or ``'xen'``
        :raises NotImplementedError: when `hv_type` is not supported
        :raises HypervisorError: when libvirt cannot open the connection
        '''
        super(LibvirtHypervisor, self).__init__()
        try:
            if hv_type == 'kvm':
                warning("LibvirtHypervisor: initialized as KVM")
                self._lvcon_handle = libvirt.open(KVM_LIBVIRT_SESSION)
            elif hv_type == 'xen':
                warning("LibvirtHypervisor: initialized as Xen")
                self._lvcon_handle = libvirt.open(XEN_LIBVIRT_SESSION)
            else:
                # `NotImplemented` is a constant, not an exception class:
                # calling it would raise an unrelated TypeError.
                raise NotImplementedError('Unknown hypervisor type')
        except libvirt.libvirtError:
            raise HypervisorError('libvirt cannot connect to hypervisor')

        self._hv_type = hv_type
        self._sto_handle = LibvirtStorage(self)
        self._vm_cache_running = {}
        self._vm_cache_defined = {}
        self._vm_cache = {}
        self._vm_cache_lock = RWLock()

    def scheduler_run(self):
        '''
        Scheduler entry point: refresh the VM cache.
        '''
        self._cache_vm_rebuild()

    def _cache_vm_rebuild(self):
        '''
        Rebuild the VM caches from libvirt.

        The new caches are built without holding the lock, then swapped in
        under the write lock. If anything fails while querying libvirt, the
        previous caches are kept and the error is only logged (debug level).
        '''
        running = {}
        defined = {}

        try:
            for dom_id in self._lvcon_handle.listDomainsID():
                vm = LibvirtVm(self, self._lvcon_handle.lookupByID(dom_id))
                running[vm.get_name()] = vm
            for dom_name in self._lvcon_handle.listDefinedDomains():
                vm = LibvirtVm(self, self._lvcon_handle.lookupByName(dom_name))
                defined[vm.get_name()] = vm
        except Exception as err:
            debug("_cache_vm_rebuild: abort, caught exception `%r`:`%s`", err,
                                                                        err)
        else:
            # The union must be a separate dict: aliasing it to `running`
            # and updating it in place would leak stopped VMs into the
            # running cache.
            merged = dict(running)
            merged.update(defined)
            with self._vm_cache_lock.write:
                self._vm_cache_running = running
                self._vm_cache_defined = defined
                self._vm_cache = merged

    def get_hv_type(self):
        '''
        Return the hypervisor type given at construction (``'kvm'`` or
        ``'xen'``).
        '''
        return self._hv_type

    def get_hv_version(self):
        '''
        Return the hypervisor version reported by libvirt ``getVersion()``,
        or None when it is unavailable (libvirt error or falsy value).
        '''
        try:
            data = self._lvcon_handle.getVersion()
        except libvirt.libvirtError:
            return None
        return data or None

    def get_libvirt_version(self):
        '''
        Return the libvirt library version (``getLibVersion()``), or None
        when it is unavailable (libvirt error or falsy value).
        '''
        try:
            data = self._lvcon_handle.getLibVersion()
        except libvirt.libvirtError:
            return None
        return data or None

    def get_cpu_threads(self):
        '''
        Return the number of hardware threads: threads per core
        (``getInfo()[7]``) multiplied by ``self.get_cpu()``.
        '''
        # NOTE(review): assumes get_cpu() (inherited) returns the total
        # number of cores -- confirm in the base class.
        return self._lvcon_handle.getInfo()[7] * self.get_cpu()

    def get_cpu_frequency(self):
        '''
        Return the CPU frequency in MHz (``getInfo()[3]``).
        '''
        return self._lvcon_handle.getInfo()[3]

    def storage(self):
        '''
        Return the :class:`LibvirtStorage` handle of this hypervisor.
        '''
        return self._sto_handle

    def vm_define(self, data):
        '''
        Define a new domain from its XML description and refresh the cache.

        :param data: libvirt domain XML
        :returns: the name of the defined domain
        :raises HypervisorError: when libvirt returns an unexpected object
        '''
        # libvirt.libvirtError raised by defineXML propagates to the caller.
        vm = self._lvcon_handle.defineXML(data)
        self._cache_vm_rebuild()
        if hasattr(vm, 'name'):
            return vm.name()
        else:
            raise HypervisorError('VM not defined properly')

    def vm_list(self):
        '''
        Return the names of all cached VMs (running and defined).
        '''
        with self._vm_cache_lock.read:
            return list(self._vm_cache)

    def vm_list_running(self):
        '''
        Return the names of cached running VMs whose state is active.
        '''
        with self._vm_cache_lock.read:
            return [vm_name
                    for vm_name, vm in self._vm_cache_running.items()
                    if vm.is_active()]

    def vm_list_stopped(self):
        '''
        Return the names of cached defined (inactive) VMs.
        '''
        with self._vm_cache_lock.read:
            return list(self._vm_cache_defined)

    def vm_list_paused(self):
        '''
        Return the names of cached running VMs that are paused.
        '''
        with self._vm_cache_lock.read:
            return [vm_name
                    for vm_name, vm in self._vm_cache_running.items()
                    if vm.is_paused()]

    def vm_get(self, name):
        '''
        Return the cached :class:`LibvirtVm` named `name`.

        :raises HypervisorError: when no such VM is cached
        '''
        # Single lookup under the lock: no window between the membership
        # test and the fetch.
        with self._vm_cache_lock.read:
            vm = self._vm_cache.get(name)
        if vm is None:
            raise HypervisorError('host has no VM named `%s`' % name)
        return vm


#### storage

class LibvirtStorage(Storage):
    '''
    Storage backend exposing the libvirt storage pools of a hypervisor.

    Pools are cached in three dicts keyed by pool name -- active, defined
    (inactive) and the union of both -- built lazily on first use and
    guarded by a read/write lock.
    '''

    def __init__(self, hypervisor):
        '''
        :param hypervisor: owning :class:`LibvirtHypervisor`
        :raises TypeError: when `hypervisor` is not a LibvirtHypervisor
        '''
        if isinstance(hypervisor, LibvirtHypervisor):
            self._hv_handle = hypervisor
        else:
            raise TypeError('Expected `%s` given `%s`' % (LibvirtHypervisor,
                                                        hypervisor))

        self._pool_cache_running = {}
        self._pool_cache_defined = {}
        self._pool_cache = {}
        self._pool_cache_lock = RWLock()

    def _pool_cache_rebuild(self):
        '''
        Rebuild the pool caches from libvirt, under the write lock.

        libvirt errors propagate; in that case the previous caches are left
        untouched.
        '''
        lvcon = self._hv_handle._lvcon_handle
        with self._pool_cache_lock.write:
            running = {}
            for name in lvcon.listStoragePools():
                pool = LibvirtStoragePool(self,
                                          lvcon.storagePoolLookupByName(name))
                running[pool.get_name()] = pool

            defined = {}
            for name in lvcon.listDefinedStoragePools():
                pool = LibvirtStoragePool(self,
                                          lvcon.storagePoolLookupByName(name))
                defined[pool.get_name()] = pool

            # The union must be a separate dict: aliasing it to `running`
            # and updating it in place would leak inactive pools into the
            # running cache.
            merged = dict(running)
            merged.update(defined)

            self._pool_cache_running = running
            self._pool_cache_defined = defined
            self._pool_cache = merged

    def pool_list(self):
        '''
        Return the names of all pools, building the cache when it is empty.
        '''
        if not self._pool_cache:
            self._pool_cache_rebuild()
        with self._pool_cache_lock.read:
            return list(self._pool_cache)

    def pool_get(self, name):
        '''
        Return the :class:`LibvirtStoragePool` named `name`.

        :raises StorageError: when there is no such pool, or it disappeared
            from the cache between the two lookups
        '''
        if name not in self.pool_list():
            raise StorageError('no storage pool with name `%s`' % name)
        try:
            with self._pool_cache_lock.read:
                return self._pool_cache[name]
        except KeyError:
            raise StorageError('storage pool `%s` vanished' % name)

    def capacity(self):
        '''
        Return the sum of the capacities of all pools.
        '''
        capacity = 0
        for pool_name in self.pool_list():
            capacity += self.pool_get(pool_name).get_space_capacity()
        return capacity

    def find_volumes(self, path=None, name=None):
        '''
        Return the volumes, across all pools, matching `path` or `name`.

        A volume matches when its path equals `path` (if given) or its name
        equals `name` (if given). With neither criterion, returns ``[]``.
        '''
        volumes = []
        if path is None and name is None:
            return volumes
        for pool_name in self.pool_list():
            pool = self.pool_get(pool_name)
            for vol_name in pool.volume_list():
                vol = pool.volume_get(vol_name)
                if ((path is not None and vol.get_path() == path)
                        or (name is not None and vol.get_name() == name)):
                    volumes.append(vol)
        return volumes


class LibvirtStoragePool(StoragePool):
    '''
    Wrapper around a libvirt storage pool, with a lazily built cache of its
    volumes keyed by volume name.
    '''

    def __init__(self, storage, libvirt_pool):
        '''
        :param storage: owning :class:`LibvirtStorage`
        :param libvirt_pool: underlying ``libvirt.virStoragePool``
        :raises TypeError: when an argument has the wrong type
        '''
        if isinstance(storage, LibvirtStorage):
            self._sto_handle = storage
        else:
            raise TypeError('Expected `%s` given `%s`' % (LibvirtStorage,
                                                        storage))
        if isinstance(libvirt_pool, libvirt.virStoragePool):
            self._lvpool_handle = libvirt_pool
        else:
            raise TypeError('Expected `%s` given `%s`' % (libvirt.virStoragePool
                                                        , libvirt_pool))

        self._vol_cache = {}
        self._vol_cache_lock = RWLock()

    def _get_xml_pool(self):
        '''
        Parse the pool XML description and return its ``<pool>`` element.

        :raises libvirt.libvirtError: when libvirt cannot describe the pool
        :raises IndexError: when the description has no ``<pool>`` element
        '''
        # Storage pools only accept storage XML flags; 0 is always valid.
        xroot = xml.dom.minidom.parseString(self._lvpool_handle.XMLDesc(0))
        return xroot.getElementsByTagName('pool').pop()

    def _vol_cache_rebuild(self):
        '''
        Rebuild the volume cache under the write lock; an inactive pool ends
        up with an empty cache.
        '''
        with self._vol_cache_lock.write:
            self._vol_cache = {}
            if self._lvpool_handle.isActive():
                for name in self._lvpool_handle.listVolumes():
                    vol = LibvirtStorageVolume(self,
                            self._lvpool_handle.storageVolLookupByName(name))
                    self._vol_cache[vol.get_name()] = vol

    def volume_list(self):
        '''
        Return the names of the pool's volumes, building the cache when it
        is empty.
        '''
        if not self._vol_cache:
            self._vol_cache_rebuild()
        with self._vol_cache_lock.read:
            return list(self._vol_cache)

    def volume_get(self, name):
        '''
        Return the :class:`LibvirtStorageVolume` named `name`.

        :raises StoragePoolError: when there is no such volume, or it
            disappeared from the cache between the two lookups
        '''
        if name not in self.volume_list():
            raise StoragePoolError('pool `%s` has no volume `%s`' % (
                                                        self.get_name(), name))
        try:
            with self._vol_cache_lock.read:
                return self._vol_cache[name]
        except KeyError:
            raise StoragePoolError('volume `%s` has vanished from pool `%s`'
                                                    %(name, self.get_name()))

    def volume_create(self, name, size):
        '''
        Create a volume in this pool and refresh the volume cache.

        :param name: volume name (XML-escaped before being embedded)
        :param size: capacity, formatted as an integer into ``<capacity>``
            (no unit attribute is set; libvirt's default unit is bytes)
        :returns: the created ``libvirt.virStorageVol``
        :raises StoragePoolError: when the creation fails
        '''
        from xml.sax.saxutils import escape

        # `name` is embedded in XML: escape it so that characters like `<`
        # or `&` cannot break the document or inject extra elements.
        vol_xml = '''
            <volume>
                <name>%(name)s</name>
                <capacity>%(capacity)i</capacity>
            </volume>
        ''' % {
                'name' : escape(name),
                'capacity' : size
            }
        try:
            vol = self._lvpool_handle.createXML(vol_xml, 0)
            if isinstance(vol, libvirt.virStorageVol):
                self._vol_cache_rebuild()
                return vol
            else:
                raise StoragePoolError('volume creation failed for an unknown reason')
        except libvirt.libvirtError as err:
            raise StoragePoolError('volume creation failed : `%r` : `%s`' %
                                                                    (err, err))

    def get_name(self):
        '''
        Return the pool name, or None when libvirt cannot provide it.
        '''
        try:
            data = self._lvpool_handle.name()
        except libvirt.libvirtError:
            return None
        return data or None

    def get_source_name(self):
        '''
        Return the text of ``<pool><source><name>``, or None when the pool
        has no source name or libvirt cannot describe the pool.
        '''
        try:
            xsource = self._get_xml_pool().getElementsByTagName('source').pop()
            xname = xsource.getElementsByTagName('name').pop()
            return xname.childNodes[0].nodeValue
        except (libvirt.libvirtError, IndexError):
            return None

    def get_source_format(self):
        '''
        Return the ``type`` attribute of ``<pool><source><format>``, or None
        when the pool has no source format or libvirt cannot describe it.
        '''
        try:
            xsource = self._get_xml_pool().getElementsByTagName('source').pop()
            xformat = xsource.getElementsByTagName('format').pop()
            return xformat.getAttribute('type')
        except (libvirt.libvirtError, IndexError):
            return None

    def get_type(self):
        '''
        Return the pool type (``<pool type=...>``), or None when libvirt
        cannot describe the pool.
        '''
        # The former VIR_DOMAIN_XML_INACTIVE flag is a *domain* flag that
        # libvirt rejects for storage pools, which made this always None.
        try:
            return self._get_xml_pool().getAttribute('type')
        except (libvirt.libvirtError, IndexError):
            return None

    def get_space_capacity(self):
        '''
        Return the pool capacity (``info()[1]``).

        :raises StoragePoolError: when libvirt cannot provide pool info
        '''
        try:
            return self._lvpool_handle.info()[1]
        except libvirt.libvirtError as e:
            raise StoragePoolError("can't get pool information (%s)" % e)

    def get_space_free(self):
        '''
        Return the pool free space (``info()[3]``).

        :raises StoragePoolError: when libvirt cannot provide pool info
        '''
        try:
            return self._lvpool_handle.info()[3]
        except libvirt.libvirtError as e:
            raise StoragePoolError("can't get pool information (%s)" % e)

    def get_space_used(self):
        '''
        Return the pool allocated space (``info()[2]``).

        :raises StoragePoolError: when libvirt cannot provide pool info
        '''
        try:
            return self._lvpool_handle.info()[2]
        except libvirt.libvirtError as e:
            raise StoragePoolError("can't get pool information (%s)" % e)


class LibvirtStorageVolume(StorageVolume):
    '''
    Wrapper around a libvirt storage volume owned by a
    :class:`LibvirtStoragePool`. Getters return None when libvirt fails.
    '''

    def __init__(self, pool, libvirt_vol):
        '''
        :param pool: owning :class:`LibvirtStoragePool`
        :param libvirt_vol: underlying ``libvirt.virStorageVol``
        :raises TypeError: when an argument has the wrong type
        '''
        if not isinstance(pool, LibvirtStoragePool):
            raise TypeError('Expected `%s` given `%s`' % (LibvirtStoragePool,
                                                        pool))
        self._pool_handle = pool

        if not isinstance(libvirt_vol, libvirt.virStorageVol):
            raise TypeError('Expected `%s` given `%s`' % (libvirt.virStorageVol,
                                                                libvirt_vol))
        self._lvvol_handle = libvirt_vol

    def wipe(self):
        '''
        Wipe the volume content.

        :raises StorageVolumeError: on libvirt error or a truthy status
        '''
        try:
            status = self._lvvol_handle.wipe(0)
            if status:
                raise StorageVolumeError(
                    'volume wipe failed for an unknown reason')
        except libvirt.libvirtError as exc:
            raise StorageVolumeError('volume wipe failed : `%r` : `%s`'
                                     % (exc, exc))

    def delete(self):
        '''
        Delete the volume, then refresh the owning pool's volume cache.

        :raises StorageVolumeError: on libvirt error (including during the
            cache refresh) or a truthy status
        '''
        try:
            status = self._lvvol_handle.delete(0)
            if status:
                raise StorageVolumeError(
                    'volume deletion failed for an unknown reason')
            self._pool_handle._vol_cache_rebuild()
        except libvirt.libvirtError as exc:
            raise StorageVolumeError('volume deletion failed : `%r` : `%s`'
                                     % (exc, exc))

    def get_pool(self):
        '''
        Return the raw ``libvirt.virStoragePool`` holding this volume, or
        None.
        '''
        try:
            found = self._lvvol_handle.storagePoolLookupByVolume()
        except libvirt.libvirtError:
            return None
        return found if found else None

    def get_name(self):
        '''
        Return the volume name, or None.
        '''
        try:
            found = self._lvvol_handle.name()
        except libvirt.libvirtError:
            return None
        return found if found else None

    def get_space_capacity(self):
        '''
        Return the volume capacity (``info()[1]``), or None.
        '''
        try:
            return self._lvvol_handle.info()[1]
        except libvirt.libvirtError:
            return None

    def get_space_allocation(self):
        '''
        Return the volume allocation (``info()[2]``), or None.
        '''
        try:
            return self._lvvol_handle.info()[2]
        except libvirt.libvirtError:
            return None

    def get_path(self):
        '''
        Return the volume path on the host, or None.
        '''
        try:
            return self._lvvol_handle.path()
        except libvirt.libvirtError:
            return None

#### vm

class LibvirtVm(VM):
    '''
    A libvirt domain exposed through the generic VM interface.
    '''

    # Value of `<os><type arch=...>` -> architecture name from get_arch().
    ARCH = {
        'i686' : 'x86',
        'x86_64' : 'x64',
    }

    # Human-readable labels; presumably indexed by the domain state code
    # returned as `info()[0]` (libvirt virDomainState) -- confirm.
    STATUS = (
        'No state',
        'Running',
        'Blocked on resource',
        'Paused',
        'Shutting down ...',
        'Shutdown',
        'Crashed',
    )
    STATUS_STOPPED = [0, 5, 6]
    # "Running" here includes paused (3) and shutting down (4).
    STATUS_RUNNING = [1, 2, 3, 4]
    STATUS_PAUSED = [3]

    def __init__(self, hypervisor, domain):
        '''
        :param hypervisor: owning :class:`LibvirtHypervisor`
        :param domain: underlying ``libvirt.virDomain``
        :raises TypeError: when `domain` is not a ``libvirt.virDomain``
        '''
        super(LibvirtVm, self).__init__()
        if isinstance(domain, libvirt.virDomain):
            self._domain = domain
        else:
            raise TypeError('Need virDomain object given %s' % type(domain))

        self._hv_handle = hypervisor
        self._find_pid()

    def _find_pid(self):
        '''
        Store in ``self._pid`` the PID of a process having the domain UUID
        as a command-line argument (the last one found), or None.
        '''
        result = find_process_id(self.get_uuid())
        if result:
            self._pid = int(result.pop())
        else:
            self._pid = None

    def hypervisor(self):
        '''
        Return the owning hypervisor.
        '''
        return self._hv_handle

    def undefine(self):
        '''
        Remove the domain definition and refresh the hypervisor VM cache.

        :raises VMError: when libvirt reports a non-zero status
        '''
        if self._domain.undefine():
            raise VMError('deletion of VM `%s` failed' % self.get_name())
        self._hv_handle._cache_vm_rebuild()

    def migrate(self, host, port):
        '''
        Live-migrate the domain to `host`:`port`, persisting it on the
        destination and undefining it here.

        KVM migrations are additionally peer-to-peer and tunnelled.

        :raises libvirt.libvirtError: on migration failure, except for the
            "domain not found" errors that are deliberately ignored
        '''
        # Migration flags are independent bits: combine them with `|`.
        flags = (libvirt.VIR_MIGRATE_LIVE |
                 libvirt.VIR_MIGRATE_PERSIST_DEST |
                 libvirt.VIR_MIGRATE_UNDEFINE_SOURCE)
        if self.hypervisor().get_hv_type() == 'xen':
            uri = 'xenmigr://%s:%d' % (host, port)
        else:
            flags |= (libvirt.VIR_MIGRATE_PEER2PEER |
                      libvirt.VIR_MIGRATE_TUNNELLED)
            uri = 'qemu+tcp://%s:%d/system' % (host, port)

        try:
            self._domain.migrate(self._hv_handle._lvcon_handle, flags, None, uri, 0)
        except libvirt.libvirtError as err:
            # FIXME ignore bogus exception properly
            # str(err) instead of the deprecated (py3: absent) `.message`.
            message = str(err)
            if not ('no domain with matching name' in message
                                    or 'Domain not found' in message):
                raise

    def power_on(self):
        '''
        Start the domain.

        :raises VMError: when libvirt refuses (e.g. already running)
        '''
        try:
            self._domain.create()
        except libvirt.libvirtError:
            raise VMError('`%s` is already running' % self.get_name())

    def power_off(self):
        '''
        Forcefully stop the domain (libvirt ``destroy``).

        :raises VMError: when libvirt refuses (e.g. not running)
        '''
        try:
            self._domain.destroy()
        except libvirt.libvirtError:
            raise VMError('`%s` is not running' % self.get_name())

    def power_shutdown(self):
        '''
        Request a shutdown of the domain (libvirt ``shutdown``).

        :raises VMError: when libvirt refuses (e.g. not running)
        '''
        try:
            self._domain.shutdown()
        except libvirt.libvirtError:
            raise VMError('`%s` is not running' % self.get_name())

    def power_suspend(self):
        '''
        Pause the domain.

        :raises VMError: when libvirt refuses (not running or paused)
        '''
        try:
            self._domain.suspend()
        except libvirt.libvirtError:
            raise VMError('`%s` is not running, or already paused'
                                                            % self.get_name())

    def power_resume(self):
        '''
        Resume a paused domain.

        :raises VMError: when libvirt refuses (not paused or not running)
        '''
        try:
            self._domain.resume()
        except libvirt.libvirtError:
            raise VMError('`%s` is not paused, or not running'
                                                            % self.get_name())

    def is_active(self):
        '''
        Return True when the domain state is in STATUS_RUNNING, False
        otherwise, None on libvirt error.
        '''
        try:
            return self._domain.info()[0] in self.STATUS_RUNNING
        except libvirt.libvirtError:
            return None

    def is_paused(self):
        '''
        Return True when the domain state is in STATUS_PAUSED, False
        otherwise, None on libvirt error.
        '''
        try:
            return self._domain.info()[0] in self.STATUS_PAUSED
        except libvirt.libvirtError:
            return None

    def get_config(self):
        '''
        Return the current domain XML description.
        '''
        return self._domain.XMLDesc(0)

    def get_uuid(self):
        '''
        Return the domain UUID as a string.
        '''
        return self._domain.UUIDString()

    def get_name(self):
        '''
        Return the domain name.
        '''
        return self._domain.name()

    def get_pid(self):
        '''
        Return the PID found at construction time, or None.
        '''
        return self._pid

    def get_arch(self):
        '''
        Return the domain architecture mapped through ARCH, or None when it
        is unknown or cannot be determined.

        On Xen the hypervisor's architecture is returned instead (bug #4020).
        '''
        arch = None
        try:
            # bug #4020
            if self.hypervisor().get_hv_type() == 'xen':
                arch = self.hypervisor().get_arch()
            else:
                xroot = xml.dom.minidom.parseString(
                        self._domain.XMLDesc(libvirt.VIR_DOMAIN_XML_INACTIVE))
                xdomain = xroot.getElementsByTagName('domain').pop()
                xos = xdomain.getElementsByTagName('os').pop()
                xtype = xos.getElementsByTagName('type').pop()
                arch = self.ARCH.get(xtype.getAttribute('arch'))
        except Exception:
            pass
        return arch

    def get_cpu_core(self):
        '''
        Return the number of virtual CPUs (``info()[3]``).
        '''
        return self._domain.info()[3]

    def get_cpu_usage(self):
        '''
        Return the CPU usage percentage of the domain process measured by
        psutil, or None when the PID is unknown or the measure fails.
        Blocks for at least 0.2 s.
        '''
        usage = None
        if self._pid is not None:
            try:
                p = psutil.Process(self._pid)
                sleep(0.2)
                usage = p.get_cpu_percent()
            except Exception:
                pass
        return usage

    def get_mem(self):
        '''
        Return the current memory in bytes (libvirt reports KiB).
        '''
        return self._domain.info()[2] * KILOBYTE_DIV

    def get_mem_max(self):
        '''
        Return the maximum memory in bytes (libvirt reports KiB).
        '''
        return self._domain.info()[1] * KILOBYTE_DIV

    def get_vnc_port(self):
        '''
        Return the ``port`` of the last ``<graphics>`` element when it is in
        1..65535, otherwise None.
        '''
        port = None
        try:
            xroot = xml.dom.minidom.parseString(
                        self._domain.XMLDesc(libvirt.VIR_DOMAIN_XML_SECURE))
            xdomain = xroot.getElementsByTagName('domain').pop()
            xgraphics = xdomain.getElementsByTagName('graphics').pop()
            data = int(xgraphics.getAttribute('port'))
            if 0 < data <= 65535:
                port = data
        except Exception:
            pass
        return port

    def get_volumes(self):
        '''
        Return the storage volumes backing the domain's disks.

        Only ``<disk device='disk'>`` entries of type ``file`` or ``block``
        are considered; each is resolved by path through the hypervisor
        storage's ``find_volumes``. Unresolvable disks are skipped and
        logged at debug level.
        '''
        volumes = []
        try:
            xroot = xml.dom.minidom.parseString(
                        self._domain.XMLDesc(libvirt.VIR_DOMAIN_XML_INACTIVE))
            xdomain = xroot.getElementsByTagName('domain').pop()
            xdevices = xdomain.getElementsByTagName('devices').pop()
            # iter on "disk" devices
            for xdisk in xdevices.getElementsByTagName('disk'):
                try:
                    # disks we can handle
                    if xdisk.getAttribute('device') == 'disk':
                        # get type
                        d_type = xdisk.getAttribute('type')
                        # get backend path
                        if d_type == 'file':
                            d_path = xdisk.getElementsByTagName('source').pop()\
                                                        .getAttribute('file')
                        elif d_type == 'block':
                            d_path = xdisk.getElementsByTagName('source').pop()\
                                                            .getAttribute('dev')
                            # FIXME sometimes xen do not report '/dev/' at the
                            # beginning of block devices, and relative paths
                            # are non-sense
                            # Example: vg/myvm instead of /dev/vg/myvm
                            if d_path[0] != '/':
                                d_path = '/dev/' + d_path
                        # search the volume object
                        if d_type in ['file', 'block']:
                            volumes.append(self._hv_handle._sto_handle
                                            .find_volumes(path=d_path).pop())
                except Exception as err:
                    # log instead of printing to stdout from library code
                    debug("get_volumes: skipping disk, caught exception "
                          "`%r`:`%s`", err, err)
        except Exception:
            pass
        return volumes

    def get_nics(self):
        '''
        Return one dict (``mac``, ``model``, ``source``) per bridged
        interface of the domain.

        Interfaces of other types are not reported; bridged interfaces
        missing one of those elements are skipped.
        '''
        nics = []
        try:
            xroot = xml.dom.minidom.parseString(
                        self._domain.XMLDesc(libvirt.VIR_DOMAIN_XML_INACTIVE))
            xdomain = xroot.getElementsByTagName('domain').pop()
            xdevices = xdomain.getElementsByTagName('devices').pop()
            # iter on "interface" devices
            for xint in xdevices.getElementsByTagName('interface'):
                # only bridged interfaces are reported (others previously
                # produced empty dicts)
                if xint.getAttribute('type') not in ['bridge']:
                    continue
                try:
                    nic = {
                        'mac': xint.getElementsByTagName('mac').pop()
                                        .getAttribute('address'),
                        'model': xint.getElementsByTagName('model').pop()
                                        .getAttribute('type'),
                        'source': xint.getElementsByTagName('source').pop()
                                        .getAttribute('bridge'),
                    }
                except Exception:
                    continue
                nics.append(nic)
        except Exception:
            pass
        return nics


#### helpers

def find_process_id(cmd_subchain):
    '''
    Return the PIDs of the processes having `cmd_subchain` as one of their
    command-line arguments.

    The match is membership in the ``cmdline`` argument list (whole
    argument equality), not a substring search. Processes that exit or
    cannot be inspected while scanning are ignored.

    :param cmd_subchain: command-line argument to look for
    :returns: list of matching PIDs, possibly empty
    '''
    pids = []
    for proc in psutil.get_process_list():
        try:
            cmdline = proc.cmdline
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # The process vanished (or is unreadable) after being listed;
            # letting this propagate would abort the caller's whole scan.
            continue
        if cmd_subchain in cmdline:
            pids.append(proc.pid)
    return pids