# -*- coding: utf-8 -*-

from fnmatch import fnmatchcase
from threading import Timer, Lock, RLock
from sjrpc.core import RpcError
from itertools import chain
from __init__ import __version__
from sjrpc.utils import RpcHandler
from sjrpc.utils import pure
from logging import debug, critical, warning, info
from exceptions import CCException, FeatureNotImplemented
from common import LocalHost


# Optional hypervisor wrappers: each flag records whether the matching module
# could be imported, NodeHandler only tries the hypervisors flagged available.
try:
    import kvm
except ImportError:
    _MOD_KVM = False
else:
    _MOD_KVM = True

try:
    import xen
except ImportError:
    _MOD_XEN = False
else:
    _MOD_XEN = True


class NodeHandler(RpcHandler):
    '''
    Main node handler that exports the host capabilities to the server.

    The handler wraps a host object: a KVM or Xen hypervisor wrapper when one
    is available, a plain ``LocalHost`` otherwise. It exports:

    * the host tags (``get_tags``) and the tags of its VMs (``sub_tags``),
    * power management of the host (``node_shutdown``) and of its VMs
      (``vm_start``, ``vm_stop``, ``vm_suspend``, ``vm_resume``),
    * remote command execution (``execute_command``).

    Plain tag maps (``*_TAG_MAP``) associate a tag name with a
    ``(available, getter, ttl)`` tuple: ``available(obj)`` tells whether the
    tag exists for ``obj``, ``getter(obj, tag)`` computes its value and
    ``ttl`` is returned to the server with the tag (``-1`` is an infinite
    TTL). Globbing tag maps (``*_TAG_GLOB``) associate a ``fnmatch`` pattern
    with a tuple built by ``_tag_map_helper``.
    '''

    def __init__(self, connection, detect_hv=True, safe_mode=True,
                                                            force_xen=False):
        '''
        Create the handler and its interface with the local host.

        :param connection: connection to the server, used by
            ``scheduler_run`` to (un)register the VMs of this host
        :param detect_hv: when True, try to connect to a local hypervisor
            (KVM first, then Xen); otherwise run as a regular node
        :param safe_mode: when True, the methods listed in
            ``UNSAFE_METHODS`` cannot be called remotely
        :param force_xen: do not try KVM, only Xen
        '''
        # name this class (not RpcHandler) so RpcHandler.__init__ is run
        super(NodeHandler, self).__init__()
        self._connection = connection
        self._safe_mode = safe_mode

        # create interface with host
        self._host_handle = None
        if detect_hv:
            debug('Hypervisor detection in progress')
            if _MOD_KVM and not force_xen:
                debug('Initializing connection to the local KVM hypervisor')
                self._host_handle = kvm.KvmHypervisor()
            elif _MOD_XEN:
                debug('Initializing connection to the local Xen hypervisor')
                self._host_handle = xen.XenHypervisor()

            if self._host_handle is None:
                debug('Hypervisor detection failed')

        # fallback when detection is disabled or found no hypervisor
        if self._host_handle is None:
            debug('Hypervisor detection disabled, running as regular node')
            self._host_handle = LocalHost()

        # methods that execute administrative commands, banned when running
        # in safe mode ('node_shutdown' is the name actually exported)
        self.UNSAFE_METHODS = ['execute_command', 'shutdown', 'node_shutdown']

        # hypervisor tags
        self.HV_TAG_MANDATORY = ['h']
        self.HV_TAG_MAP = {
            # infinite TTL
            'version'   : ( lambda o: True,
                            lambda o,t: str(__version__),
                            -1),
            'libvirtver': self._tag_map_direct('get_libvirt_version', -1),
            'htype'     : self._tag_map_direct('get_hv_type', -1),
            'hserial'   : self._tag_map_direct('get_hw_serial', -1),
            'hvendor'   : self._tag_map_direct('get_hw_vendor', -1),
            'hmachine'  : self._tag_map_direct('get_hw_product', -1),
            'arch'      : self._tag_map_direct('get_arch', -1),
            'hvm'       : self._tag_map_direct('get_hvm_available', -1),
            'cpu'       : self._tag_map_direct('get_cpu', -1),
            'cpucore'   : self._tag_map_direct('get_cpu_core', -1),
            'cputhread' : self._tag_map_direct('get_cpu_threads', -1),
            # one day
            'hbios'     : self._tag_map_direct('get_hw_bios', 24*3600),
            'hvver'     : self._tag_map_direct('get_hv_version', 24*3600),
            'platform'  : self._tag_map_direct('get_platform', 24*3600),
            'uname'     : self._tag_map_direct('get_uname', 24*3600),
            'cpufreq'   : self._tag_map_direct('get_cpu_frequency', 24*3600),
            'mem'       : self._tag_map_direct('get_mem', 24*3600),
            'disk'      : self._tag_map_keys('get_disks', 24*3600),
            'h'         : self._tag_map_direct('get_name', 24*3600),
            # one minute
            'uptime'    : self._tag_map_direct('get_uptime', 60),
            'memfree'   : self._tag_map_direct('get_mem_free', 60),
            'memused'   : self._tag_map_direct('get_mem_used', 60),
            'sto'       : ( lambda o: hasattr(o, 'storage'),
                            lambda o,t: ' '.join(
                                        getattr(o, 'storage')().pool_list()),
                            60),
            # 5 seconds
            'cpuuse'    : self._tag_map_direct('get_cpu_usage', 5),
            'load'      : self._tag_map_direct('get_loadavg', 5),
            'nvm'       : self._tag_map_counter('vm_list', 5),
            'vmstarted' : self._tag_map_counter('vm_list_running', 5),
            'vmstopped' : self._tag_map_counter('vm_list_stopped', 5),
            'vmpaused'  : self._tag_map_counter('vm_list_paused', 5),
        }
        self.HV_TAG_GLOB = {
            'disk*'     : self._tag_map_helper(self._helper_hv_disk, 24*3600),
            'sto*'      : self._tag_map_helper(self._helper_hv_sto, 60),
        }

        # guest VM tags
        self.VM_TAG_MANDATORY = ['hv', 'h']
        self.VM_TAG_MAP = {
            # infinite TTL
            'version'   : ( lambda o: True,
                            lambda o,t: str(__version__),
                            -1),
            'hv'        : ( lambda o: hasattr(o, 'hypervisor'),
                            lambda o,t: o.hypervisor().get_name(),
                            -1),
            'htype'     : ( lambda o: hasattr(o, 'hypervisor'),
                            lambda o,t: o.hypervisor().get_hv_type(),
                            -1),
            'arch'      : self._tag_map_direct('get_arch', -1),
            'h'         : self._tag_map_direct('get_name', -1),
            # one hour
            'cpu'       : self._tag_map_direct('get_cpu_core', 3600),
            'mem'       : self._tag_map_direct('get_mem', 3600),
            'memused'   : self._tag_map_direct('get_mem_used', 3600),
            'memfree'   : self._tag_map_direct('get_mem_free', 3600),
            # 5 seconds
            'status'    : ( lambda o: True,
                            lambda o,t: 'running' if o.is_active()
                                        else 'paused' if o.is_paused()
                                        else 'stopped',
                            5), # FIXME crappy tag implementation
        }
        self.VM_TAG_GLOB = {
            'disk*'     : self._tag_map_helper(self._helper_vm_disk, 3600),
        }

        # names of the VMs currently registered on the server
        # FIXME
        self._register_vm = []

    def __getitem__(self, name):
        '''
        Resolve a remotely called name into a handler member.

        :raises KeyError: when the name is private (starts with ``_``), or
            when it is listed in ``UNSAFE_METHODS`` and the handler runs in
            safe mode
        '''
        # filter the private members access
        if name.startswith('_'):
            raise KeyError('Remote name `%s` is private' % repr(name))
        # filter command execution methods: banned in safe mode
        elif self._safe_mode and name in self.UNSAFE_METHODS:
            raise KeyError('Remote name `%s` is disabled by configuration'
                                                                % repr(name))
        else:
            debug('Called %s.%s' % (self.__class__.__name__, name))
            return super(NodeHandler, self).__getitem__(name)

    def _tag_map_direct(self, method, ttl):
        '''
        Build a plain tag entry whose value is ``obj.<method>()``; the tag
        is available when ``obj`` has that method.
        '''
        return ( lambda o: hasattr(o, method),
                 lambda o,t: getattr(o, method)(),
                 ttl)

    def _tag_map_counter(self, method, ttl):
        '''
        Build a plain tag entry whose value is ``len(obj.<method>())``; the
        tag is available when ``obj`` has that method.
        '''
        return ( lambda o: hasattr(o, method),
                 lambda o,t: len(getattr(o, method)()),
                 ttl)

    def _tag_map_keys(self, method, ttl):
        '''
        Build a plain tag entry whose value is the space separated keys of
        the mapping returned by ``obj.<method>()``; the tag is available
        when ``obj`` has that method.
        '''
        return ( lambda o: hasattr(o, method),
                 lambda o,t: ' '.join(getattr(o, method)().keys()),
                 ttl)

    def _tag_map_helper(self, helper, ttl):
        '''
        Build a globbing tag entry from a helper returning a dict of tag
        name to value (or None when it has no tag).

        Both callables of the entry run the helper: the first one is used as
        an availability test (its result is truthy when the helper provides
        tags), the second one returns the helper result.
        '''
        return ( lambda o, resolve=False: helper(o, resolve=resolve),
                 lambda o, tag_name=None, resolve=False:
                                helper(o, tag_name=tag_name, resolve=resolve),
                 ttl)

    def _helper_hv_disk(self, hv, tag_name=None, resolve=True):
        '''
        Build the ``disk*`` tags of the host from ``hv.get_disks()``, a
        mapping of disk name to size.

        ``tag_name`` and ``resolve`` are accepted for the helper calling
        convention and ignored.

        :returns: dict of tag name to value, or None when there is no tag
        '''
        if not hasattr(hv, 'get_disks'):
            return None
        result = {}
        disks = hv.get_disks() or {}
        if disks:
            result['disk'] = ' '.join(disks.keys())
        for name, size in disks.items():
            if size:
                result['disk%s_size' % name] = str(size)
        return result or None

    def _helper_hv_sto(self, hv, tag_name=None, resolve=True):
        '''
        Build the ``sto*`` tags of the host from its storage pools: pool
        list, then capacity, free and used space and volume list per pool.

        ``tag_name`` and ``resolve`` are accepted for the helper calling
        convention and ignored.

        :returns: dict of tag name to value, or None when there is no tag
        '''
        if not hasattr(hv, 'storage'):
            return None
        result = {}
        storage = hv.storage()
        pools = storage.pool_list() or []
        if pools:
            result['sto'] = ' '.join(pools)
        for pool_name in pools:
            pool = storage.pool_get(pool_name)
            capa = pool.get_space_capacity()
            if capa:
                result['sto%s_size' % pool_name] = str(capa)
            free = pool.get_space_free()
            if free:
                result['sto%s_free' % pool_name] = str(free)
            used = pool.get_space_used()
            if used:
                result['sto%s_used' % pool_name] = str(used)
            vol = pool.volume_list()
            if vol:
                result['sto%s_vol' % pool_name] = ' '.join(vol)
        return result or None

    def _helper_vm_disk(self, vm, tag_name=None, resolve=True):
        '''
        Build the ``disk*`` tags of a VM from ``vm.get_volumes()``; disks
        are numbered by their position in that list.

        ``tag_name`` and ``resolve`` are accepted for the helper calling
        convention and ignored.

        :returns: dict of tag name to value, or None when there is no tag
        '''
        if not hasattr(vm, 'get_volumes'):
            return None
        result = {}
        volumes = vm.get_volumes() or []
        if volumes:
            result['disk'] = ' '.join(str(i) for i in range(len(volumes)))
        for vol_id, vol in enumerate(volumes):
            path = vol.get_path()
            if path:
                result['disk%i_path' % vol_id] = str(path)
            capa = vol.get_space_capacity()
            if capa:
                result['disk%i_size' % vol_id] = str(capa)
        return result or None

    def scheduler_run(self):
        '''
        Run the host scheduler, then synchronize VM registrations with the
        server.

        Each VM listed by the host and not yet registered is registered;
        each registered VM no longer listed by the host is unregistered.
        An error aborts the synchronization and is logged; the VMs not
        processed yet are handled again on the next call.
        '''
        # call handler scheduler
        if hasattr(self._host_handle, 'scheduler_run'):
            self._host_handle.scheduler_run()
        # (un)register sub nodes if this host has the capability
        if not hasattr(self._host_handle, 'vm_list'):
            return
        try:
            vm_current = self._host_handle.vm_list()
            for vm in vm_current:
                if vm in self._register_vm:
                    continue
                try:
                    info('registering vm `%s`' % vm)
                    self._connection.get_server().register(vm, 'vm')
                except RpcError as e:
                    # FIXME placeholder: errors matching it count as success
                    if e.exception != '#FIXME':
                        raise
                self._register_vm.append(vm)
            # iterate over a copy: removing entries from the list being
            # iterated would skip the entry following each removed one
            for vm in list(self._register_vm):
                if vm in vm_current:
                    continue
                try:
                    info('unregistering vm `%s`' % vm)
                    self._connection.get_server().unregister(vm)
                except RpcError as e:
                    # FIXME placeholder: errors matching it count as success
                    if e.exception != '#FIXME':
                        raise
                self._register_vm.remove(vm)
        except Exception as e:
            debug("REGISTER except `%s`:`%s`" % (repr(e), e))

    @pure
    def get_tags(self, tags=None, noresolve_tags=None):
        '''
        Return the tags of the local host.

        :param tags: names of the tags to return with their value
        :param noresolve_tags: names of the tags to return without their
            value; names also present in ``tags`` stay resolved
        :returns: dict mapping each tag available on the host to a dict with
            a ``ttl`` key and, when the tag is resolved and its value is not
            None, a ``value`` key holding the value as a string

        When both lists are empty, every tag available on the host is
        returned resolved. Otherwise the mandatory tags
        (``HV_TAG_MANDATORY``) are always added and resolved.
        '''
        result = {}
        debug('get_tags: server requested tags=`%s` noresolve_tags=`%s`'
                                                % (tags, noresolve_tags))
        # build a single {tag name: resolve flag} mapping
        mytags = {}
        for t in tags or ():
            mytags[t] = True
        for t in noresolve_tags or ():
            mytags.setdefault(t, False)
        # return all tags if server does not request a subset
        if not mytags:
            # add simple tags
            for t in self.HV_TAG_MAP:
                mytags[t] = True
            # add globbing tags; the helper result is falsy when the helper
            # has no tag on this host
            for pattern, handler in self.HV_TAG_GLOB.items():
                htags = handler[0](self._host_handle, resolve=False)
                if htags:
                    debug('get_tags: host implements `%s`' % pattern)
                    for t in htags:
                        mytags[t] = True
            debug('get_tags: no tag specified, expanded list to `%s`'
                                                            % mytags.keys())
        # add mandatory tags if missing in the list, or set noresolve
        else:
            for t in self.HV_TAG_MANDATORY:
                if not mytags.get(t):
                    debug('get_tags: add/correct mandatory tag `%s`' % t)
                    mytags[t] = True
        # query host
        debug('get_tags: query host with tag list `%s`' % mytags.keys())
        for tag, resolve in mytags.items():
            # first, test tag name against the list of plain names
            if tag in self.HV_TAG_MAP:
                debug('get_tags: plain mapping found for tag `%s`' % tag)
                available, getter, ttl = self.HV_TAG_MAP[tag]
                if not available(self._host_handle):
                    debug('get_tags: tag `%s` is NOT implemented' % tag)
                    continue
                debug('get_tags: tag `%s` is available on host' % tag)
                result[tag] = {'ttl': ttl}
                if resolve:
                    debug('get_tags: resolving now tag `%s`' % tag)
                    q = getter(self._host_handle, tag)
                    debug('get_tags: host returned `%s`' % q)
                    if q is not None:
                        result[tag]['value'] = str(q)
                    else:
                        debug('get_tags: I wont return `%s`=`None`' % tag)
            # if no direct tag mapping exists, test name against globbing
            # patterns
            else:
                debug('get_tags: searching for `%s` in globbing tags' % tag)
                for pattern, handler in self.HV_TAG_GLOB.items():
                    debug('get_tags: testing pattern `%s`' % pattern)
                    # match the name first, so helpers of non matching
                    # patterns are never run
                    if not fnmatchcase(tag, pattern):
                        continue
                    htags = handler[1](self._host_handle, tag)
                    # helper has no tag on this host: try the next pattern
                    if not htags:
                        continue
                    debug('get_tags: processing tag `%s` with '
                                    'pattern `%s`' % (tag, pattern))
                    # FIXME instead of extracting one tag, try not to build
                    # the whole list. Maybe it's too difficult and not worth
                    # implementing
                    if tag in htags:
                        debug('get_tags: found tag in helper result'
                                    ' with value `%s`' % htags[tag])
                        result[tag] = {'ttl': handler[2]}
                        if resolve:
                            result[tag]['value'] = str(htags[tag])
                    # only the first matching pattern is processed: ideally
                    # no tag matches two patterns
                    break
        return result

    def _sub_tag_list(self, sub_obj):
        '''
        List every tag name available for a sub node: all plain VM tags plus
        the tags provided by the globbing helpers for ``sub_obj``.
        '''
        result = list(self.VM_TAG_MAP.keys())
        for pattern, handler in self.VM_TAG_GLOB.items():
            # the helper result is falsy when it has no tag for this node
            htags = handler[0](sub_obj, resolve=False)
            if htags:
                debug('sub_tags: sub node implements `%s`' % pattern)
                debug('sub_tags: handler provides `%s`' % htags)
                result.extend(htags)
        return result

    @pure
    def sub_tags(self, sub_id, tags=None, noresolve_tags=None):
        '''
        Return the tags of a VM hosted on this node.

        :param sub_id: name of the VM, as listed by the host ``vm_list()``
        :param tags: names of the tags to return with their value; when
            None, every available tag not in ``noresolve_tags`` is resolved
        :param noresolve_tags: names of the tags to return without their
            value; when None and ``tags`` is given, every available tag not
            in ``tags`` is listed without value
        :returns: same format as ``get_tags``; the mandatory tags
            (``VM_TAG_MANDATORY``) are always resolved
        :raises CCException: when ``sub_id`` is not a VM of this host
        '''
        debug('sub_tags: server requested tags for `%s`' % sub_id)
        if sub_id not in self._host_handle.vm_list():
            debug('sub_tags: sub node `%s` is unknown !' % sub_id)
            raise CCException('sub node `%s` is unknown' % sub_id)
        # open a wrapper to the VM
        debug('sub_tags: fetching vm data for `%s`' % sub_id)
        sub = self._host_handle.vm_get(sub_id)
        # build a single {tag name: resolve flag} mapping
        debug('sub_tags: server requested tags `%s` + `%s`'
                                                    % (tags, noresolve_tags))
        available_tags = self._sub_tag_list(sub)
        mytags = {}
        if tags is None:
            for t in available_tags:
                mytags[t] = True
        elif noresolve_tags is None:
            for t in available_tags:
                mytags[t] = False
        for t in noresolve_tags or ():
            mytags[t] = False
        # applied last: a tag in both lists is resolved
        for t in tags or ():
            mytags[t] = True
        debug('sub_tags: expanded list to `%s`' % mytags.keys())
        # add mandatory tags if missing in the list, or set noresolve
        for t in self.VM_TAG_MANDATORY:
            if not mytags.get(t):
                debug('sub_tags: add/correct mandatory tag `%s`' % t)
                mytags[t] = True
        # query the subnode; on error, the tags gathered so far are returned
        result = {}
        try:
            for tag, resolve in mytags.items():
                # first, search tag in plain mappings
                if tag in self.VM_TAG_MAP:
                    debug('sub_tags: plain mapping found for tag `%s`'
                                                                    % tag)
                    available, getter, ttl = self.VM_TAG_MAP[tag]
                    # proceed if tag can be resolved on this VM
                    if not available(sub):
                        continue
                    result[tag] = {'ttl': ttl}
                    if resolve:
                        debug('sub_tags: resolving tag `%s`' % tag)
                        q = getter(sub, tag)
                        debug('sub_tags: tag query returned `%s`' % q)
                        if q is not None:
                            result[tag]['value'] = str(q)
                        else:
                            debug('sub_tags: I wont return `%s`=`None`'
                                                                    % tag)
                # no tag mapping exist, test name against the globbing list
                else:
                    debug('sub_tags: searching for `%s` in globbing tags'
                                                                    % tag)
                    for pattern, handler in self.VM_TAG_GLOB.items():
                        # match the name first, so helpers of non matching
                        # patterns are never run
                        if not fnmatchcase(tag, pattern):
                            continue
                        htags = handler[1](sub, tag)
                        # helper has no tag for this VM: try next pattern
                        if not htags:
                            continue
                        debug('sub_tags: processing tag `%s` with '
                                        'pattern `%s`' % (tag, pattern))
                        # FIXME instead of extracting one tag, try not to
                        # build the whole list. Maybe it's too difficult and
                        # not worth implementing
                        if tag in htags:
                            debug('sub_tags: found tag in helper '
                                  'result with value `%s`' % htags[tag])
                            result[tag] = {'ttl': handler[2]}
                            if resolve:
                                result[tag]['value'] = str(htags[tag])
                        # only the first matching pattern is processed:
                        # ideally no tag matches two patterns
                        break
        except Exception as e:
            warning('sub_tags: `%s` --> `%s`' % (repr(e), e))
        return result

    @pure
    def node_shutdown(self, reboot=True, gracefull=True):
        '''
        Power off or reboot the local host.

        :param reboot: reboot the host instead of powering it off
        :param gracefull: use the clean shutdown/reboot method instead of
            the forced one
        :returns: whatever the host power method returns
        :raises FeatureNotImplemented: when the host handle lacks the
            selected power method
        '''
        info('shutdown: server requested shutdown of local host')
        debug('shutdown: reboot=%s gracefull=%s' % (reboot, gracefull))
        if reboot:
            method = 'power_reboot' if gracefull else 'power_force_reboot'
        else:
            method = 'power_shutdown' if gracefull else 'power_off'
        if hasattr(self._host_handle, method):
            result = getattr(self._host_handle, method)()
            debug('shutdown: in progress ... action returned `%s`' % result)
            return result
        else:
            debug('shutdown: unable to proceed, this feature is not available')
            raise FeatureNotImplemented('host handler has no method `%s`'
                                                                    % method)

    @pure
    def vm_stop(self, vm_names=None, force=False):
        '''
        Stop VMs (all running VMs when ``vm_names`` is None); ``force``
        powers them off instead of requesting a shutdown. Failures are
        logged per VM and do not stop the other VMs.
        '''
        info('vm_stop: server requested stop of `%s`' % vm_names)
        debug('vm_stop: force stop is `%s`' % force)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_running()
            debug('vm_stop: no vm specified, expanded list to `%s`' % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_stop: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                if force:
                    debug('vm_stop: powering off `%s`' % vm_name)
                    vm.power_off()
                else:
                    info('vm_stop: shutdown requested for `%s`' % vm_name)
                    vm.power_shutdown()
            except Exception as e:
                warning('vm_stop: failed on `%s`: %r' % (vm_name, e))

    @pure
    def vm_start(self, vm_names=None):
        '''
        Power on VMs (all stopped VMs when ``vm_names`` is None). Failures
        are logged per VM and do not stop the other VMs.
        '''
        info('vm_start: server requested start of `%s`' % vm_names)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_stopped()
            debug('vm_start: no vm specified, expanded list to `%s`' % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_start: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                info('vm_start: powering on `%s`' % vm_name)
                vm.power_on()
            except Exception as e:
                warning('vm_start: failed on `%s`: %r' % (vm_name, e))

    @pure
    def vm_suspend(self, vm_names=None):
        '''
        Pause VMs (all running VMs when ``vm_names`` is None). Failures are
        logged per VM and do not stop the other VMs.
        '''
        info('vm_suspend: server requested suspend of `%s`' % vm_names)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_running()
            debug('vm_suspend: no vm specified, expanded list to `%s`'
                                                                    % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_suspend: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                info('vm_suspend: pause execution of `%s`' % vm_name)
                vm.power_suspend()
            except Exception as e:
                warning('vm_suspend: failed on `%s`: %r' % (vm_name, e))

    @pure
    def vm_resume(self, vm_names=None):
        '''
        Resume paused VMs (all paused VMs when ``vm_names`` is None).
        Failures are logged per VM and do not stop the other VMs.
        '''
        info('vm_resume: server requested resume of `%s`' % vm_names)
        if vm_names is None:
            # only paused VMs can be resumed
            vm_names = self._host_handle.vm_list_paused()
            debug('vm_resume: no vm specified, expanded list to `%s`'% vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_resume: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                info('vm_resume: resume execution of `%s`' % vm_name)
                vm.power_resume()
            except Exception as e:
                warning('vm_resume: failed on `%s`: %r' % (vm_name, e))

    @pure
    def execute_command(self, command):
        '''
        Execute ``command`` through the host handle and return its output.
        Listed in ``UNSAFE_METHODS``: not callable remotely in safe mode.
        '''
        warning('execute_command: starting execution of `%s`' % command)
        output = self._host_handle.execute(command)
        warning('execute_command: finished execution of `%s`' % command)
        return output
