from __init__ import __version__
from common import LocalHost
from sjrpc.utils import RpcHandler
from sjrpc.utils import pure
from logging import debug, info
from exceptions import FeatureNotImplemented
from fnmatch import fnmatchcase

try:
    import kvm
except ImportError:
    _MOD_KVM = False
else:
    _MOD_KVM = True


class NodeHandler(RpcHandler):
    '''
    Main node handler that exports the host capabilities to the server.

    Remote calls go through :meth:`__getitem__`, which refuses private names
    and, unless command execution was allowed at construction time, the
    method names listed in ``EXEC_METHODS``.

    Tags are resolved through tables built in :meth:`__init__` (``HV_*`` for
    the local host, ``VM_*`` for virtual machines):

    * ``*_TAG_MAP`` maps a plain tag name to a ``(check, fetch)`` pair:
      ``check(obj)`` tells whether the tag can be resolved on ``obj`` and
      ``fetch(obj, tag)`` returns its value, a dict of several tags, or
      ``None``;
    * ``*_TAG_GLOB`` maps a glob pattern to a ``(check, fetch)`` pair built
      by :meth:`_tag_map_helper`; the helper returns a dict of every tag it
      can produce, or ``None``;
    * ``*_TAG_MANDATORY`` lists tags always added to an explicit request.
    '''

    def __init__(self, connection, detect_hv, allow_exec):
        '''
        :param connection: connection object, stored as ``self._connection``
        :param detect_hv: when true, use the KVM hypervisor wrapper if the
            ``kvm`` module could be imported; otherwise (or when false) the
            node runs as a plain ``LocalHost``
        :param allow_exec: when false, the names in ``EXEC_METHODS`` cannot
            be called remotely
        '''
        # start the MRO lookup at NodeHandler so that RpcHandler's own
        # initializer is not skipped
        super(NodeHandler, self).__init__()
        self._connection = connection
        self._allow_cmd_exec = allow_exec

        if detect_hv and _MOD_KVM:
            debug('Hypervisor detection...')
            debug('Initializing connection to the KVM hypervisor')
            self._host_handle = kvm.KvmHypervisor()
        else:
            if detect_hv:
                # without this fallback _host_handle would stay unset and
                # every remote call would fail with AttributeError
                debug('Hypervisor detection failed (kvm module unavailable),'
                  ' running as regular node')
            else:
                debug('Hypervisor detection disabled, running as regular'
                  ' node')
            self._host_handle = LocalHost()

        # remote names refused unless allow_exec is set; 'node_shutdown' is
        # the name under which the shutdown method is actually exported
        self.EXEC_METHODS = ['execute_command', 'shutdown', 'node_shutdown']

        self.HV_TAG_MANDATORY = ['h']
        self.HV_TAG_MAP = {
            'version'   : ( lambda o: True,
                            lambda o,t: str(__version__)),
            'h'         : self._tag_map_direct('get_name'),
            'htype'     : self._tag_map_direct('get_hv_type'),
            'status'    : self._tag_map_direct('get_status'),
            'hserial'   : self._tag_map_direct('get_hw_serial'),
            'hvendor'   : self._tag_map_direct('get_hw_vendor'),
            'hmachine'  : self._tag_map_direct('get_hw_product'),
            'hbios'     : self._tag_map_direct('get_hw_bios'),
            'arch'      : self._tag_map_direct('get_arch'),
            'platform'  : self._tag_map_direct('get_platform'),
            'uname'     : self._tag_map_direct('get_uname'),
            'uptime'    : self._tag_map_direct('get_uptime'),
            'hvm'       : self._tag_map_direct('get_hvm_available'),
            'libvirtver': self._tag_map_direct('get_hv_version'),
            'load'      : self._tag_map_direct('get_loadavg'),
            'cpu'       : self._tag_map_direct('get_cpu'),
            'cpucore'   : self._tag_map_direct('get_cpu_core'),
            'cputhread' : self._tag_map_direct('get_cpu_threads'),
            'cpufreq'   : self._tag_map_direct('get_cpu_frequency'),
            'cpuuse'    : self._tag_map_direct('get_cpu_usage'),
            'mem'       : self._tag_map_direct('get_mem'),
            'memfree'   : self._tag_map_direct('get_mem_free'),
            'memused'   : self._tag_map_direct('get_mem_used'),
            'disk'      : self._tag_map_keys('get_disks'),
            'sto'       : ( lambda o: hasattr(o, 'storage'),
                            lambda o,t: ' '.join(
                                        getattr(o, 'storage')().pool_list())),
            'nvm'       : self._tag_map_counter('vm_list'),
            'vmstarted' : self._tag_map_counter('vm_list_running'),
            'vmstopped' : self._tag_map_counter('vm_list_stopped'),
            'vmpaused'  : self._tag_map_counter('vm_list_paused'),
        }
        self.HV_TAG_GLOB = {
            'disk*'     : self._tag_map_helper(self._helper_hv_disk),
            'sto*'      : self._tag_map_helper(self._helper_hv_sto),
        }

        self.VM_TAG_MANDATORY = ['hv', 'h']
        self.VM_TAG_MAP = {
            'version'   : ( lambda o: True,
                            lambda o,t: str(__version__)),
            'hv'        : ( lambda o: hasattr(o, 'hypervisor'),
                            lambda o,t: o.hypervisor().get_name()),
            # FIXME crappy tag implementation
            # is_paused() is tested first: if is_active() wraps libvirt's
            # isActive(), it is also true for paused domains (TODO confirm
            # against the kvm wrapper)
            'status'    : ( lambda o: True,
                            lambda o,t: 'paused' if o.is_paused()
                                        else 'running' if o.is_active()
                                        else 'stopped'),
            'h'         : self._tag_map_direct('get_name'),
            'arch'      : self._tag_map_direct('get_arch'),
            'cpu'       : self._tag_map_direct('get_cpu'),
            'mem'       : self._tag_map_direct('get_mem'),
            'memused'   : self._tag_map_direct('get_mem_used'),
            'memfree'   : self._tag_map_direct('get_mem_free'),
        }
        self.VM_TAG_GLOB = {
            'disk*'     : self._tag_map_helper(self._helper_vm_disk),
        }

    def __getitem__(self, name):
        '''
        Return the remotely callable named ``name``.

        :raises KeyError: for private names (leading underscore), and for
            names in ``EXEC_METHODS`` when command execution is disabled
        '''
        # filter the private members access
        if name.startswith('_'):
            raise KeyError('Remote name `%s` is private' % repr(name))
        # filter command execution methods
        elif not self._allow_cmd_exec and name in self.EXEC_METHODS:
            raise KeyError('Remote name `%s` is disabled by configuration'
                                                                % repr(name))
        else:
            debug('Called %s.%s' % (self.__class__.__name__, name))
            return super(NodeHandler, self).__getitem__(name)

    def _tag_map_direct(self, method):
        '''
        Build a ``(check, fetch)`` pair returning ``obj.<method>()`` as is.
        '''
        return ( lambda o: hasattr(o, method),
                 lambda o,t: getattr(o, method)())

    def _tag_map_counter(self, method):
        '''
        Build a ``(check, fetch)`` pair returning ``len(obj.<method>())``.
        '''
        return ( lambda o: hasattr(o, method),
                 lambda o,t: len(getattr(o, method)()))

    def _tag_map_keys(self, method):
        '''
        Build a ``(check, fetch)`` pair returning the keys of the dict
        returned by ``obj.<method>()``, joined with spaces.
        '''
        return ( lambda o: hasattr(o, method),
                 lambda o,t: ' '.join(getattr(o, method)().keys()))

    def _tag_map_helper(self, helper):
        '''
        Build a ``(check, fetch)`` pair for a globbing helper.

        Both callables run ``helper`` and return its result (a dict of tags,
        or ``None`` when the object has nothing to report), so ``check`` is
        truthy exactly when the helper provides at least one tag.
        '''
        return ( lambda o, tag_name=None, resolve=False:
                                helper(o, tag_name=tag_name, resolve=resolve),
                 lambda o, tag_name=None, resolve=False:
                                helper(o, tag_name=tag_name, resolve=resolve))

    def _helper_hv_disk(self, hv, tag_name=None, resolve=True):
        '''
        Globbing helper for the host ``disk*`` tags.

        ``hv.get_disks()`` is read as a dict of disk name -> size.

        :return: dict with ``disk`` (space separated disk names) and
            ``disk<name>_size`` for every disk with a non-empty size, or
            ``None`` when there is nothing to report. ``tag_name`` and
            ``resolve`` are accepted for interface compatibility but ignored.
        '''
        result = {}
        disks = hv.get_disks()
        if disks:
            result['disk'] = ' '.join(disks.keys())
        for name, size in disks.iteritems():
            if size:
                result['disk%s_size' % name] = str(size)
        return result or None

    def _helper_hv_sto(self, hv, tag_name=None, resolve=True):
        '''
        Globbing helper for the host ``sto*`` tags.

        :return: dict with ``sto`` (space separated pool names) and, per
            pool, ``sto<pool>_size``, ``sto<pool>_free``, ``sto<pool>_used``
            and ``sto<pool>_vol`` (space separated volume names) when those
            values are non-empty; ``None`` when the host has no ``storage``
            attribute or nothing to report. ``tag_name`` and ``resolve`` are
            accepted for interface compatibility but ignored.
        '''
        result = {}
        if hasattr(hv, 'storage'):
            storage = hv.storage()
            pools = storage.pool_list()
            if pools:
                result['sto'] = ' '.join(pools)
            for pool_name in pools:
                pool = storage.pool_get(pool_name)
                capa = pool.get_space_capacity()
                if capa:
                    result['sto%s_size' % pool_name] = str(capa)
                free = pool.get_space_free()
                if free:
                    result['sto%s_free' % pool_name] = str(free)
                used = pool.get_space_used()
                if used:
                    result['sto%s_used' % pool_name] = str(used)
                vol = pool.volume_list()
                if vol:
                    result['sto%s_vol' % pool_name] = ' '.join(vol)
        return result or None

    def _helper_vm_disk(self, vm, tag_name=None, resolve=True):
        '''
        Globbing helper for the VM ``disk*`` tags.

        :return: dict with ``disk`` (space separated volume indexes) and, per
            volume index ``i``, ``disk<i>_path`` and ``disk<i>_size`` when
            non-empty; ``None`` when there is nothing to report. ``tag_name``
            and ``resolve`` are accepted for interface compatibility but
            ignored.
        '''
        result = {}
        volumes = vm.get_volumes()
        if volumes:
            result['disk'] = ' '.join([str(i) for i in range(len(volumes))])
        for vol_id, vol in enumerate(volumes):
            path = vol.get_path()
            if path:
                result['disk%i_path' % vol_id] = str(path)
            capa = vol.get_space_capacity()
            if capa:
                result['disk%i_size' % vol_id] = str(capa)
        return result or None

    def _tag_entry(self, value, resolve=True):
        '''
        Build the record returned to the server for one tag.

        :param value: tag value, converted with ``str()``
        :param resolve: when false, the ``'value'`` key is left out
        :return: ``{'value': str(value), 'ttl': -1}`` (without ``'value'``
            when ``resolve`` is false)
        '''
        entry = {}
        if resolve:
            entry['value'] = str(value)
        # FIXME really fast
        entry['ttl'] = -1
        return entry

    def _resolve_glob_tag(self, obj, tag, glob_map, cache):
        '''
        Look ``tag`` up through the globbing helpers of ``glob_map``.

        Patterns are tried in turn; the first one matching ``tag`` whose
        helper reports something for ``obj`` is processed and the search
        stops there, whether or not the tag was found (ideally no two
        patterns match the same tag). Full helper results are memoized in
        ``cache`` (keyed by pattern) so that every tag of the same family
        requested during one call reuses a single helper run.

        :return: the raw tag value, or ``None`` if no helper provides it
        '''
        for pattern, handler in glob_map.iteritems():
            if not fnmatchcase(tag, pattern):
                continue
            if pattern not in cache:
                # no tag_name: ask for the whole family so it can be reused
                cache[pattern] = handler[1](obj, resolve=True)
            htags = cache[pattern]
            # helper is not available on this object, try another pattern
            if not htags:
                continue
            debug('get_tags: processing tag `%s` with '
                                        'pattern `%s`' % (tag, pattern))
            # FIXME instead of extracting one tag, try not to build the
            # whole list. Maybe it's too difficult and not worth implementing
            if tag in htags:
                debug('get_tags: found tag in helper '
                                'result with value `%s`' % htags[tag])
                return htags[tag]
            return None
        return None

    @pure
    def node_tags(self, tags=None):
        '''
        Return tags describing the local host.

        :param tags: tag names to resolve; ``None`` means every plain tag of
            ``HV_TAG_MAP`` plus every tag reported by the ``HV_TAG_GLOB``
            helpers. Tags of ``HV_TAG_MANDATORY`` are added to an explicit
            list. The caller's list is never modified.
        :return: dict mapping tag name to ``{'value': str, 'ttl': -1}``;
            unknown tags and ``None`` values are left out
        '''
        host = self._host_handle
        result = {}
        debug('get_tags: server requested tags `%s`' % tags)
        # return all tags if server does not request a subset
        if tags is None:
            # add simple tags
            tags = list(self.HV_TAG_MAP.keys())
            # add globbing tags
            for pattern, handler in self.HV_TAG_GLOB.iteritems():
                htags = handler[1](host, resolve=False)
                # helper is available on the current host
                if htags:
                    debug('get_tags: host implements `%s`' % pattern)
                    tags.extend(htags.keys())
            debug('get_tags: no tag specified, expanded list to `%s`' % tags)
        # add mandatory tags if missing in the list
        else:
            # work on a copy: the caller's list must not be modified
            tags = list(tags)
            for mtag in self.HV_TAG_MANDATORY:
                if mtag not in tags:
                    debug('get_tags: add missing mandatory tag `%s`' % mtag)
                    tags.append(mtag)
        # query host
        debug('get_tags: query host with tag list `%s`' % tags)
        glob_cache = {}
        for tag in tags:
            # first, test tag name against list of plain names
            if tag in self.HV_TAG_MAP:
                debug('get_tags: plain mapping found for tag `%s`' % tag)
                check, fetch = self.HV_TAG_MAP[tag]
                if not check(host):
                    debug('get_tags: tag `%s` is NOT implemented' % tag)
                    continue
                # fetch tag data
                q = fetch(host, tag)
                debug('get_tags: host returned `%s`' % q)
                if q is None:
                    debug('get_tags: I wont return `%s`=`None`' % tag)
                # when a dict is returned, it may contain >1 tags: the real
                # tag names are its keys and may differ from the mapping name
                elif isinstance(q, dict):
                    for key, val in q.iteritems():
                        if val is None:
                            debug('get_tags: I wont return `%s`=`None`' % key)
                        else:
                            result[key] = self._tag_entry(val)
                # or there's only one value
                else:
                    result[tag] = self._tag_entry(q)
            # if no direct tag mapping exists, test name against globbing list
            else:
                debug('get_tags: searching for `%s` in globbing tags' % tag)
                value = self._resolve_glob_tag(host, tag, self.HV_TAG_GLOB,
                                                                    glob_cache)
                if value is not None:
                    result[tag] = self._tag_entry(value)
        return result

    @pure
    def node_shutdown(self, reboot=True, gracefull=True):
        '''
        Reboot or power down the local host.

        :param reboot: reboot when true, power down otherwise
        :param gracefull: use the graceful variant (``power_reboot`` /
            ``power_shutdown``) instead of the forced one
            (``power_force_reboot`` / ``power_off``)
        :return: whatever the host handle's power method returns
        :raises FeatureNotImplemented: if the host handle lacks that method
        '''
        info('shutdown: server requested shutdown of local host')
        debug('shutdown: reboot=%s gracefull=%s' % (reboot, gracefull))
        if reboot:
            method = 'power_reboot' if gracefull else 'power_force_reboot'
        else:
            method = 'power_shutdown' if gracefull else 'power_off'
        if hasattr(self._host_handle, method):
            result = getattr(self._host_handle, method)()
            debug('shutdown: in progress ... action returned `%s`' % result)
            return result
        else:
            debug('shutdown: unable to proceed, this feature is not available')
            raise FeatureNotImplemented('host handler has no method `%s`'
                                                                    % method)

    def _query_vm_tags(self, vm, tags, get_all, resolve):
        '''
        Resolve ``tags`` on a single VM wrapper.

        :param tags: plain/glob tag names to resolve (not modified)
        :param get_all: when true, also add every tag reported by the
            ``VM_TAG_GLOB`` helpers for this VM
        :param resolve: when false, records carry no ``'value'`` key
        :return: dict mapping tag name to its record; ``None`` values and
            tags that cannot be resolved on this VM are left out
        '''
        vm_tag = {}
        # per-VM copy: each VM may contribute its own globbing tags
        mytags = list(tags)
        # expand tag list with globbing tags
        if get_all:
            for pattern, handler in self.VM_TAG_GLOB.iteritems():
                htags = handler[1](vm, resolve=False)
                # helper is available on this VM
                if htags:
                    debug('vm_list: vm implements `%s`' % pattern)
                    mytags.extend(htags.keys())
        glob_cache = {}
        # query the VM with each tag
        for tag in mytags:
            # first, search tag in plain mappings
            if tag in self.VM_TAG_MAP:
                debug('vm_list: plain mapping found for tag `%s`' % tag)
                check, fetch = self.VM_TAG_MAP[tag]
                # proceed only if tag can be resolved on this VM
                if not check(vm):
                    continue
                debug('vm_list: resolving tag `%s`' % tag)
                q = fetch(vm, tag)
                debug('vm_list: query returned `%s`' % q)
                if q is None:
                    debug('vm_list: I wont return `%s`=`None`' % tag)
                # when a dict is returned, it may contain >1 tag: the real
                # tag names are its keys and may differ from the mapping name
                elif isinstance(q, dict):
                    for key, val in q.iteritems():
                        if val is None:
                            debug('vm_list: I wont return `%s`=`None`' % key)
                        else:
                            vm_tag[key] = self._tag_entry(val, resolve)
                # or there's only one value
                else:
                    vm_tag[tag] = self._tag_entry(q, resolve)
            # no tag mapping exist, test name against the globbing list
            else:
                debug('vm_list: searching for `%s` in globbing tags' % tag)
                value = self._resolve_glob_tag(vm, tag, self.VM_TAG_GLOB,
                                                                    glob_cache)
                if value is not None:
                    vm_tag[tag] = self._tag_entry(value, resolve)
        return vm_tag

    @pure
    def vm_tags(self, vm_names=None, tags=None, resolve=True):
        '''
        Return tags describing virtual machines of the local host.

        :param vm_names: names of the VMs to query; ``None`` means every VM
            returned by the host's ``vm_list()``
        :param tags: tag names to resolve; ``None`` means every plain tag of
            ``VM_TAG_MAP`` plus the tags each VM's globbing helpers report.
            Tags of ``VM_TAG_MANDATORY`` are added to an explicit list. The
            caller's list is never modified.
        :param resolve: when false, only tag names (with ``'ttl'``) are
            returned, without ``'value'``
        :return: dict mapping VM name to its tag dict; a VM whose query
            raised an exception is left out (logged at debug level)
        '''
        # list all VMs if the server did not provide names
        debug('vm_list: server requested list of vm `%s`' % vm_names)
        if vm_names is None:
            vm_names = self._host_handle.vm_list()
            debug('vm_list: no vm specified, expanded list to `%s`' % vm_names)
        # return all tags if server does not request a subset
        get_all = tags is None
        debug('vm_list: server requested tags `%s`' % tags)
        if get_all:
            # add simple tags
            tags = list(self.VM_TAG_MAP.keys())
            debug('vm_list: no tag specified, expanded list to `%s`' % tags)
        # add mandatory tags if missing in the list
        else:
            # work on a copy: the caller's list must not be modified
            tags = list(tags)
            for mtag in self.VM_TAG_MANDATORY:
                if mtag not in tags:
                    debug('vm_list: add missing mandatory tag `%s`' % mtag)
                    tags.append(mtag)
        # query each vm
        result = {}
        for vm_name in vm_names:
            try:
                # open a wrapper to the VM
                debug('vm_list: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                # FIXME: in case of exception, we won't return a single VM tag
                result[vm_name] = self._query_vm_tags(vm, tags, get_all,
                                                                    resolve)
            except Exception as e:
                debug('(%s) : %s' % (repr(e), e))
        return result

    @pure
    def vm_stop(self, vm_names=None, force=False):
        '''
        Stop VMs, best effort: a failure on one VM is logged and the others
        are still processed.

        :param vm_names: names of the VMs to stop; ``None`` means every VM
            returned by the host's ``vm_list_running()``
        :param force: call ``power_off()`` instead of ``power_shutdown()``
        '''
        info('vm_stop: server requested stop of `%s`' % vm_names)
        debug('vm_stop: force stop is `%s`' % force)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_running()
            debug('vm_stop: no vm specified, expanded list to `%s`' % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_stop: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                if force:
                    debug('vm_stop: powering off `%s`' % vm_name)
                    vm.power_off()
                else:
                    info('vm_stop: shutdown requested for `%s`' % vm_name)
                    vm.power_shutdown()
            except Exception as e:
                info('vm_stop: failed on `%s`: %s' % (vm_name, repr(e)))

    @pure
    def vm_start(self, vm_names=None):
        '''
        Start VMs, best effort: a failure on one VM is logged and the others
        are still processed.

        :param vm_names: names of the VMs to start; ``None`` means every VM
            returned by the host's ``vm_list_stopped()``
        '''
        info('vm_start: server requested start of `%s`' % vm_names)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_stopped()
            debug('vm_start: no vm specified, expanded list to `%s`' % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_start: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                info('vm_start: powering on `%s`' % vm_name)
                vm.power_on()
            except Exception as e:
                info('vm_start: failed on `%s`: %s' % (vm_name, repr(e)))

    @pure
    def vm_suspend(self, vm_names=None):
        '''
        Suspend VMs, best effort: a failure on one VM is logged and the
        others are still processed.

        :param vm_names: names of the VMs to suspend; ``None`` means every VM
            returned by the host's ``vm_list_running()``
        '''
        info('vm_suspend: server requested suspend of `%s`' % vm_names)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_running()
            debug('vm_suspend: no vm specified, expanded list to `%s`'
                                                                    % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_suspend: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                info('vm_suspend: pause execution of `%s`' % vm_name)
                vm.power_suspend()
            except Exception as e:
                info('vm_suspend: failed on `%s`: %s' % (vm_name, repr(e)))

    @pure
    def vm_resume(self, vm_names=None):
        '''
        Resume VMs, best effort: a failure on one VM is logged and the
        others are still processed.

        :param vm_names: names of the VMs to resume; ``None`` means every VM
            returned by the host's ``vm_list_paused()`` (only paused VMs can
            be resumed)
        '''
        info('vm_resume: server requested resume of `%s`' % vm_names)
        if vm_names is None:
            vm_names = self._host_handle.vm_list_paused()
            debug('vm_resume: no vm specified, expanded list to `%s`'
                                                                    % vm_names)
        for vm_name in vm_names:
            try:
                debug('vm_resume: fetching vm data for `%s`' % vm_name)
                vm = self._host_handle.vm_get(vm_name)
                info('vm_resume: resume execution of `%s`' % vm_name)
                vm.power_resume()
            except Exception as e:
                info('vm_resume: failed on `%s`: %s' % (vm_name, repr(e)))

    @pure
    def execute_command(self, command):
        '''
        Run ``command`` through the host handle's ``execute()`` and return
        its output. Refused by :meth:`__getitem__` unless command execution
        was allowed at construction time.
        '''
        info('execute_command: starting execution of `%s`' % command)
        output = self._host_handle.execute(command)
        info('execute_command: finished execution of `%s`' % command)
        return output