# This file is part of CloudControl.
#
# CloudControl is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# CloudControl is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with CloudControl.  If not, see <http://www.gnu.org/licenses/>.

""" This module contains the hypervisor allocation algorithm.
"""


from collections import defaultdict

from cloudcontrol.server.utils import itercounter


class AllocationError(Exception):

    """ Raised when no hypervisor can be allocated to a virtual machine, or
        when the allocation request itself is rejected (e.g. duplicate VM
        title).
    """


# Target filters:

class Filter(object):

    """ Base class of the allocation filters.

        A filter can take part in three stages of an allocation: narrowing the
        TQL query (tql_filter), dropping candidates (filter) and providing a
        sort key (sorter). This base implementation is neutral for all three.
    """

    def __init__(self, logger, vmspec, server, client):
        self.logger, self.vmspec = logger, vmspec
        self.server, self.client = server, client

    def tql_filter(self, query):
        """ Return the TQL query, unchanged here. """
        return query

    def filter(self, candidates):
        """ Lazily yield the candidates that pass; all of them here. """
        for item in candidates:
            yield item

    def sorter(self, candidate):
        """ Return the sort key of candidate (lower sorts first); None here. """
        return None


class IsVmUnique(Filter):

    """ Raise an allocation error if a VM with the same title already exists.

        The check runs when the candidate generator is first iterated; the
        candidates themselves are passed through unchanged.
    """

    def filter(self, candidates):
        title = self.vmspec.get('title')
        # NOTE(review): the title is interpolated into the TQL string without
        # any escaping; a title containing '"' would presumably break the
        # query — confirm how TQL quotes values.
        if title and self.server.list('r=vm&t="%s"' % title):
            raise AllocationError('A virtual machine with the same title already exists')
        for hv in candidates:
            yield hv


class TargetFilter(Filter):

    """ Restrict candidates with the TQL expression found under the VM spec's
        'target' key, when that key is present.
    """

    def tql_filter(self, query):
        if 'target' not in self.vmspec:
            return query
        target = self.vmspec['target']
        return '(%s)&%s' % (query, target)


class IsAllocatable(Filter):

    """ Keep only hypervisors administratively open to allocation (alloc=yes),
        unless the VM spec carries the 'ignore_alloc' flag.
    """

    def tql_filter(self, query):
        flags = self.vmspec.get('flags', [])
        if 'ignore_alloc' in flags:
            # The caller explicitly asked to bypass the alloc restriction.
            return query
        return '(%s)&alloc=yes' % query


class IsConnected(Filter):

    """ Keep only the connected hypervisors.
    """

    def tql_filter(self, query):
        # Presumably 'con' is the tag set on connected hypervisors.
        connected_only = '(%s)&con' % query
        return connected_only


class HaveEnoughMemory(Filter):

    """ Filter hypervisors with not enough allocatable memory for the VM.

        Unless the 'do_not_check_memory' flag is set, the TQL query requires
        memremaining to be greater than the VM spec's 'memory'. Candidates with
        the most remaining memory sort first.
    """

    def tql_filter(self, query):
        if 'do_not_check_memory' not in self.vmspec.get('flags', []):
            return '(%s)&memremaining>%s' % (query, self.vmspec['memory'])
        else:
            # No threshold, but presumably still restricts to hypervisors
            # exposing memremaining, which sorter() reads.
            return '(%s)&memremaining' % query

    def sorter(self, candidate):
        try:
            return -float(candidate.get('memremaining'))
        except (TypeError, ValueError):
            # Missing (None) or malformed value: sort the candidate last
            # instead of aborting the whole sort with a TypeError.
            return float('inf')


class HaveEnoughStorage(Filter):

    """ Filter hypervisors with not enough storage for the VM.

        Volume sizes are summed per pool (DEFAULT_VG when a volume names no
        pool). For each pool, a hypervisor is kept when its sto<pool>_shared
        tag is 'yes' or its sto<pool>_free tag is at least the summed size.
    """

    DEFAULT_VG = 'local'

    def tql_filter(self, query):
        # Compute the total size requested per pool (volume group):
        size_by_vg = defaultdict(int)
        for volume in self.vmspec.get('volumes', []):
            size_by_vg[volume.get('pool', self.DEFAULT_VG)] += volume['size']
        # One clause per pool. items() (rather than iteritems()) works on both
        # Python 2 and Python 3:
        tql = ''.join('&(sto%s_shared=yes|sto%s_free>=%s)' % (vg, vg, size)
                      for vg, size in size_by_vg.items())
        return '(%s)%s' % (query, tql)


class HaveEnoughCPU(Filter):

    """ Filter hypervisors with not enough CPU for the VM, according to the
        overcommit policy.

        A candidate is kept when (cpualloc + VM cpu) / (cpu - cpureserved) does
        not exceed its cpuallowedratio. Candidates whose CPU tags cannot be
        read, or that have no CPU left once reserved CPUs are removed, are
        dropped with a warning instead of aborting the whole allocation.
        Candidates with the highest cpuallocratio sort first.
    """

    DEFAULT_ALLOWED_RATIO = 1
    DEFAULT_RESERVED_CPU = 0

    @staticmethod
    def _to_number(value, cast, default):
        """ Return cast(value), or default when value is missing (None) or
            malformed.
        """
        try:
            return cast(value)
        except (TypeError, ValueError):
            return default

    def tql_filter(self, query):
        # Only request the CPU tags read below; nothing is excluded here.
        return '(%s)$cpualloc$cpu$cpuallowedratio$cpuremaining$cpuallocratio$cpureserved' % query

    def filter(self, candidates):
        cpu = int(self.vmspec['cpu'])
        for candidate in candidates:
            reserved = self._to_number(candidate.get('cpureserved'), int,
                                       self.DEFAULT_RESERVED_CPU)
            allowed_ratio = self._to_number(candidate.get('cpuallowedratio'),
                                            float, self.DEFAULT_ALLOWED_RATIO)
            try:
                allocated = float(candidate.get('cpualloc'))
                usable = float(candidate.get('cpu')) - reserved
            except (TypeError, ValueError):
                self.logger.warning('Ignoring %s: unreadable CPU tags',
                                    candidate.get('id'))
                continue
            if usable <= 0:
                # Dividing would raise ZeroDivisionError (usable == 0) or give a
                # negative ratio that always passes the check (usable < 0).
                self.logger.warning('Ignoring %s: no CPU left after reservation',
                                    candidate.get('id'))
                continue
            if (allocated + cpu) / usable <= allowed_ratio:
                yield candidate

    def sorter(self, candidate):
        try:
            return -float(candidate.get('cpuallocratio'))
        except (TypeError, ValueError):
            # Missing (None) or malformed ratio: sort the candidate last.
            return float('inf')


class SatisfyRiskGroups(Filter):

    """ Complies with risk groups.

        The VM's 'riskgroup' tag names an entry of server.riskgroups, which
        maps hypervisor tag names to limits. A candidate is rejected when, for
        any of those tags, the VMs of the riskgroup already hosted on
        hypervisors sharing the candidate's tag value reach the limit.
        Candidates with fewer such VMs sort first.
    """

    # Class-level defaults so that filter() and sorter() act as a pass-through
    # if tql_filter() has not been called on this instance.
    riskgroup_name = None
    riskgroup = None

    def _tags_tql(self):
        """ Return the TQL suffix showing every tag of the riskgroup. """
        return ''.join('$%s' % tag for tag in self.riskgroup)

    def tql_filter(self, query):
        self.riskgroup_name = self.vmspec.get('tags', {}).get('riskgroup')
        self.riskgroup = self.server.riskgroups.get(self.riskgroup_name)

        if self.riskgroup_name and self.riskgroup is None:
            self.logger.warning('Unknown riskgroup used, ignoring')

        if self.riskgroup:
            # Show the riskgroup tags on candidates; filter() and sorter()
            # read them.
            return '(%s)%s' % (query, self._tags_tql())
        return query

    def filter(self, candidates):
        if self.riskgroup is None:
            for candidate in candidates:
                yield candidate
            return

        # Number of riskgroup VMs per tag value, for each riskgroup tag. Kept
        # on the instance because sorter() reuses it: the Allocator exhausts
        # the filters before computing sort keys.
        self.count_per_riskgroup = dict((tag, defaultdict(int))
                                        for tag in self.riskgroup)

        vms = self.server.list('r=vm&riskgroup="%s"$p' % self.riskgroup_name)
        if vms:
            # Number of riskgroup VMs hosted by each hypervisor id:
            count_per_hv = defaultdict(int)
            for vm in vms:
                count_per_hv[vm['p']] += 1

            # Fetch those hypervisors along with their riskgroup tags:
            hv_tql = '|'.join('id=%s' % hv_id for hv_id in count_per_hv)
            hosts = self.server.list('(%s)%s' % (hv_tql, self._tags_tql()))

            # Add each hypervisor's VM count to its tag values. A hypervisor
            # lacking a tag is counted under None instead of raising KeyError.
            for host in hosts:
                for tag in self.riskgroup:
                    self.count_per_riskgroup[tag][host.get(tag)] += count_per_hv[host['id']]

        # Yield only the hypervisors which have not reached a riskgroup limit:
        for candidate in candidates:
            if not any(self.count_per_riskgroup[tag].get(candidate.get(tag), 0) >= limit
                       for tag, limit in self.riskgroup.items()):
                yield candidate

    def sorter(self, candidate):
        if not self.riskgroup:
            return []
        return [self.count_per_riskgroup[tag].get(candidate.get(tag), 0)
                for tag in self.riskgroup]


class Allocator(object):

    """ Select the hypervisors able to host a virtual machine.

        Allocation runs a pipeline of filters (DEFAULT_FILTERS by default).
        Each filter first narrows the TQL query sent to the client, then
        filters the returned candidates, and finally contributes a sort key.
    """

    BASE_TARGET_TQL = 'r=hv'
    DEFAULT_FILTERS = [IsAllocatable, TargetFilter, IsConnected,
                       SatisfyRiskGroups, HaveEnoughCPU, HaveEnoughMemory,
                       HaveEnoughStorage]

    def __init__(self, logger, server, client, filters=DEFAULT_FILTERS):
        self.logger = logger
        self.server = server
        self.client = client
        # Copy, so that mutating self.filters never alters the shared
        # DEFAULT_FILTERS list used as the default argument value.
        self.filters = list(filters)

    def allocate(self, vmspec, tql_target):
        """ Return the ids of the hypervisors able to host vmspec, best first.

            :param vmspec: mapping describing the VM; its keys are read by the
                filters (e.g. 'title', 'cpu', 'memory', 'volumes', 'flags')
            :param tql_target: optional TQL expression further restricting
                the hypervisors to consider
            :raises AllocationError: when no candidate is left (filters may
                also raise it)
        """
        # Instantiate filters, each with a child logger named after its class:
        filters = [f(self.logger.getChild(f.__name__), vmspec, self.server, self.client)
                   for f in self.filters]

        self.logger.info('Looking for candidates for vmspec: %r', vmspec)

        # Generate the TQL query to select target hypervisors:
        tql = self.BASE_TARGET_TQL
        if tql_target:
            tql = '(%s)&%s' % (tql, tql_target)
        for filter_ in filters:
            tql = filter_.tql_filter(tql)

        # Get the list of candidates according to the TQL query:
        self.logger.debug('Querying candidates: %s', tql)
        candidates = self.client.list(tql, method='allocate')
        self.logger.debug('Got %d candidates to filter', len(candidates))

        def _cb_logger_debug(count, message):
            self.logger.debug(message, count)

        # Chain the lazy filters, counting candidates in and out of each one:
        for filter_ in filters:
            name = filter_.__class__.__name__
            candidates = itercounter(candidates, _cb_logger_debug,
                                     '%s: %%d candidates In' % name)
            candidates = filter_.filter(candidates)
            candidates = itercounter(candidates, _cb_logger_debug,
                                     '%s: %%d candidates Out' % name)

        # Exhaust the filter chain before computing any sort key (sorters may
        # rely on state gathered by their filter), then compute each key once
        # and reuse it for logging. The sort is stable and keyed on the sort
        # keys only, so candidates themselves are never compared.
        candidates = list(candidates)
        ranked = sorted([([f.sorter(c) for f in filters], c) for c in candidates],
                        key=lambda pair: pair[0])

        for i, (keys, candidate) in enumerate(ranked):
            sorters = ['%s: %s' % (f.__class__.__name__, k) for f, k in zip(filters, keys)]
            self.logger.debug('Candidate %d %s -> %s', i, candidate['id'], ' '.join(sorters))

        # Return the ids of every remaining candidate, best first:
        if ranked:
            return [candidate['id'] for _, candidate in ranked]
        else:
            raise AllocationError('No candidate found for %s' % vmspec.get('title', repr(vmspec)))
