# This file is part of CloudControl.
#
# CloudControl is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# CloudControl is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with CloudControl.  If not, see <http://www.gnu.org/licenses/>.


""" Main class of cc-server.
"""

import errno
import logging
import os
import re
from fnmatch import fnmatch as glob

from sjrpc.server import SSLRpcServer
from sjrpc.utils import RpcHandler, pass_connection

from cloudcontrol.server.conf import CCConf
from cloudcontrol.server.exceptions import (AlreadyRegistered,
                                            NotConnectedAccountError,
                                            AuthenticationError,
                                            BadRoleError)
#from cloudcontrol.server.jobs import JobsManager
from cloudcontrol.server.clients import Client
from cloudcontrol.server.rights import RightManager
from cloudcontrol.server.riskgroups import RiskgroupsManager
from cloudcontrol.server.jobs import KillOldCliJob
from cloudcontrol.server.repository import Repository

from cloudcontrol.server.db import SObject, SRequestor
from cloudcontrol.common.tql.db.tag import StaticTag, CallbackTag
from cloudcontrol.common.tql.db.db import TqlDatabase, TqlResponse
from cloudcontrol.common.jobs import JobsManager, JobsStore
from cloudcontrol.server.jobsinterface import ServerJobsManagerInterface

# Import all enabled roles:
import cloudcontrol.server.clients.cli
import cloudcontrol.server.clients.host
import cloudcontrol.server.clients.hv
import cloudcontrol.server.clients.bootstrap
import cloudcontrol.server.clients.spv


# Recognizes rules whose match expression is exactly ``a=<login>``; group 1
# captures the login (ASCII letters, digits, '_' and '-').
RE_FAST_RULE_MATCHING = re.compile(r'^a=([a-zA-Z0-9_-]+)$')


class WelcomeHandler(RpcHandler):
    """ Default RPC handler given to incoming client connections.

    It exposes a single RPC method, ``authentify``, which hands the
    submitted credentials over to the server.
    """

    def __init__(self, server):
        # Server instance receiving the authentication requests:
        self._server = server

    @pass_connection
    def authentify(self, conn, login, password):
        """ Forward a login attempt received on ``conn`` to the server.

        :return: the value returned by the server's ``authenticate``
        """
        server = self._server
        return server.authenticate(conn, login, password)


class CCServer(object):
    """ CloudControl server main class.

    :param logger: logger of the server, child loggers are derived from it
        for each subsystem (conf, jobs, rights, riskgroups, ...)
    :param conf_dir: the directory holding the server configuration and
        state (accounts via CCConf, motd, jobs, ruleset, riskgroups, scripts
        and plugins)
    :param maxcon: ``maxcon`` setting handed to the KillOldCliJob
    :param maxidle: ``maxidle`` setting handed to the KillOldCliJob
    :param certfile: the path to the ssl certificate
    :param keyfile: the path to the ssl key
    :param address: the interface to bind
    :param port: the port to bind
    """

    # These tags are reserved and cannot be set by a user:
    RESERVED_TAGS = ('id', 'a', 'r', 'close', 'con', 'ip', 'p')

    def __init__(self, logger, conf_dir, maxcon, maxidle, certfile=None,
                 keyfile=None, address='0.0.0.0', port=1984):

        self.logger = logger

        self._clients = {}  # Connected clients, indexed by login

        # The interface object to the configuration directory:
        self.conf = CCConf(self.logger.getChild('conf'), conf_dir)
        # Settings handed to the KillOldCliJob spawned in run():
        self._maxcon = maxcon
        self._maxidle = maxidle
        # SSL configuration stuff:
        if certfile:
            self.logger.info('SSL Certificate: %s', certfile)
        if keyfile:
            self.logger.info('SSL Key: %s', keyfile)

        self.db = TqlDatabase(default_requestor=SRequestor())

        # Create the rpc server:
        self.logger.info('Listening on %s:%s', address, port)
        self.rpc = SSLRpcServer.from_addr(address, port, certfile=certfile,
                                          keyfile=keyfile,
                                          conn_kw=dict(handler=WelcomeHandler(self),
                                                       on_disconnect='on_disconnect',
                                                       timeout=120))

        self.motd_filename = os.path.join(conf_dir, 'motd')

        # The jobs manager:
        job_logger = self.logger.getChild('jobs')
        job_logger.setLevel(logging.DEBUG)
        self.jobs = JobsManager(job_logger,
                                ServerJobsManagerInterface(self),
                                JobsStore(os.path.join(conf_dir, 'jobs')))

        # The rights manager:
        self.rights = RightManager(self.logger.getChild('rights'),
                                   os.path.join(conf_dir, 'ruleset'))

        # The riskgroup manager:
        self.riskgroups = RiskgroupsManager(self.logger.getChild('riskgroups'),
                                            os.path.join(conf_dir, 'riskgroups'))

        self.logger.info('Server started to running')

        # Script repository:
        scripts_directory = os.path.join(conf_dir, 'scripts')
        if not os.path.isdir(scripts_directory):
            os.mkdir(scripts_directory)
        self.scripts = Repository(self.logger.getChild('scripts'), self,
                                  scripts_directory, role='script')

        # Plugin repository:
        plugins_directory = os.path.join(conf_dir, 'plugins')
        if not os.path.isdir(plugins_directory):
            os.mkdir(plugins_directory)
        self.plugins = Repository(self.logger.getChild('plugins'), self,
                                  plugins_directory, role='plugin')

    def _update_accounts(self):
        """ Synchronize the database with the configured accounts.

        An object is registered for every account present in the
        configuration but not in the database, and objects whose account
        disappeared from the configuration are unregistered. Accounts
        already present in the database are left untouched.
        """

        db_accounts = set(obj['a'].value for obj in list(self.db.objects)
                          if 'a' in obj)
        accounts = set(self.conf.list_accounts())

        to_register = accounts - db_accounts
        to_unregister = db_accounts - accounts

        for login in to_register:
            conf = self.conf.show(login)
            obj = SObject(login)
            obj.register(StaticTag('r', conf['role']), override=True)
            obj.register(StaticTag('a', login), override=True)
            # The login is passed through ``extra`` so each callback is
            # bound to its own account:
            obj.register(CallbackTag('lastcon', lambda l: self.conf.get_last_connection(l),
                                     ttl=0, extra={'l': login}), override=True)
            obj.register(CallbackTag('lastdcon', lambda l: self.conf.get_last_disconnection(l),
                                     ttl=0, extra={'l': login}), override=True)
            # Register static tags (reusing the account config fetched above):
            for tag, value in conf['tags'].items():
                obj.register(StaticTag(tag, value), override=True)
            self.db.register(obj)

        for login in to_unregister:
            self.db.unregister(login)

    def iterclients(self, role=None):
        """ Iterate over connected clients with an optional role filter.

        :param role: only yield clients whose ``ROLE`` equals this value,
            ``None`` yields every connected client
        """

        # Iterate over a snapshot so that the caller may register or
        # unregister (e.g. kill) clients while iterating:
        for client in list(self._clients.values()):
            if role is None or client.ROLE == role:
                yield client

    def authenticate(self, conn, login, password):
        """ Authenticate a client against provided login and password.

        If the authentication is a success, register the client on the server
        and return the client role, else, raise an exception.

        :raises AuthenticationError: unknown login, bad password or closed
            account
        :raises BadRoleError: the account is configured with an unknown role
        :raises AlreadyRegistered: see :meth:`register`
        """

        logmsg = 'Authentication error from %s: '
        with self.conf:
            try:
                role = self.conf.authentify(login, password)
            except CCConf.UnknownAccount:
                self.logger.warning(logmsg + 'unknown login (%s)',
                                    conn.getpeername(), login)
                raise AuthenticationError('Unknown login')
            # Only look at the account state once the credentials are
            # validated, so a wrong password never reveals that the account
            # is closed:
            closed = (role is not None
                      and 'close' in self.conf.show(login)['tags'])

        if role is None:
            self.logger.warning(logmsg + 'bad login/password (%s)', conn.getpeername(), login)
            raise AuthenticationError('Bad login/password')

        if closed:
            self.logger.warning(logmsg + 'account closed (%s)', conn.getpeername(), login)
            raise AuthenticationError('Account is closed')

        if role not in Client.roles:
            self.logger.warning(logmsg + 'bad role in account config (%s)', conn.getpeername(), login)
            raise BadRoleError('%r is not a legal role' % role)

        # If authentication is a success, try to register the client:
        client = self.register(login, role, conn)
        return client.role

    def wall(self, sender, message):
        """ Send a wall to all connected cli.
        """

        self.logger.info('Wall from %s: %s', sender, message)

        for client in self.iterclients('cli'):
            client.wall(sender, message)

    def register(self, login, role, connection):
        """ Register a new connected account on the server.

        :param login: login of the account
        :param role: role of the account, used to build the client instance
        :param connection: connection to register
        :return: the newly registered client
        :raises AlreadyRegistered: when a client is already connected with
            this login and the client class does not allow to kill it
        """
        client = Client.from_role(role, None, login, self, connection)
        client.logger = self.logger.getChild('clients.%s' % client.login)
        if client.login in self._clients:
            if client.KILL_ALREADY_CONNECTED:
                self.kill(client.login)
            else:
                raise AlreadyRegistered('A client is already connected with this account.')
        client.attach()
        self._clients[client.login] = client
        return client

    def unregister(self, client):
        """ Unregister a client.

        :raises KeyError: when no client is registered with this login
        """
        registered = self._clients.get(client.login)
        if registered is not None and registered is not client:
            # A newer client replaced this one (see KILL_ALREADY_CONNECTED
            # in register), do not drop the replacement:
            return
        del self._clients[client.login]

    def run(self):
        """ Run the server mainloop.
        """

        # Register accounts on the database:
        self._update_accounts()

        # Running server internal jobs:
        self.jobs.spawn(KillOldCliJob, None, system=True,
                        settings={'server': self,
                                  'maxcon': self._maxcon,
                                  'maxidle': self._maxidle})

        self.logger.debug('Running rpc mainloop')
        self.rpc.run()

    def get_client(self, login):
        """ Get a connected client by its login.

        :param login: login of the connection to get
        :return: the client instance
        :raises KeyError: when no client is connected with this login
        """
        return self._clients[login]

    def kill(self, login):
        """ Disconnect from the server the client identified by provided login.

        :param login: the login of the user to disconnect
        :raises NotConnectedAccountError: when provided account is not
            connected (or if account doesn't exist).
        """

        client = self._clients.get(login)
        if client is None:
            raise NotConnectedAccountError('The account %s is not '
                                           'connected' % login)
        client.shutdown()

    def save_motd(self, motd):
        """ Save a new message of the day.
        """
        with open(self.motd_filename, 'w') as fmotd:
            fmotd.write(motd)

    def load_motd(self):
        """ Get the current message of the day.

        :return: the MOTD content, or an empty string if no MOTD file exists
        """
        try:
            with open(self.motd_filename) as fmotd:
                return fmotd.read()
        except IOError as err:
            if err.errno == errno.ENOENT:
                return ''  # Return empty MOTD
            raise

    def _rule_applies(self, rule, requester):
        """ Tell whether the match expression of ``rule`` selects the
            requester object.
        """
        match = RE_FAST_RULE_MATCHING.match(rule.match)
        if match and match.group(1) != requester:
            # Fast path: an ``a=<login>`` rule naming another account cannot
            # apply, no need to query the database.
            return False
        # NOTE(review): the fast path only short-circuits negative results,
        # a matching ``a=<login>`` rule is still confirmed through the db.
        return requester in self.db.raw_query(rule.match)

    def filter(self, tql_response, requester, method):
        """ Filter the provided TqlResponse object using rules matching the
            provided requester.

        Objects selected by an applicable ACCEPT rule are kept, then objects
        selected by any applicable DENY rule are removed.
        """

        requestor = tql_response._requestor

        allowed_result = TqlResponse(requestor)
        deny_rules = []
        for rule in self.rights.iter_rules_method(method):
            if not self._rule_applies(rule, requester):
                continue
            if rule.action == rule.ACCEPT:
                allowed_result |= tql_response & self.db.raw_query(rule.tql)
            elif rule.action == rule.DENY:
                # Defer deny action to the end of process:
                deny_rules.append(rule)
        for rule in deny_rules:
            allowed_result = allowed_result - self.db.raw_query(rule.tql)

        return allowed_result

    def check(self, query, requester, method):
        """ Check if the requester have right to call method on objects
            matched by the provided query.
        """

        matched = self.db.raw_query(query)
        filtered = self.filter(matched, requester, method)
        # Comparing the lengths is sufficient since the result of a filter
        # operation is guaranteed to be a subset of its input:
        return len(filtered) == len(matched)

    def check_method(self, requester, method):
        """ Check if the requester have right to call method.

        Any applicable DENY rule rejects the call; otherwise at least one
        applicable ACCEPT rule is required.
        """
        ok = False  # Default policy is to reject
        for rule in self.rights.iter_rules_method(method):
            if not self._rule_applies(rule, requester):
                continue
            if rule.action == rule.ACCEPT:
                ok = True
            elif rule.action == rule.DENY:
                return False
        return ok

    def list(self, query, show=None, requester=None, method=None):
        """ List object authorized for requester object, using method.

        :param query: TQL query selecting the objects
        :param show: tag patterns to render in addition to each object's
            ``show_tags``
        :param requester: requester used to filter the result
        :param method: method used to filter the result (needs requester)
        :return: list of objects rendered as dicts
        :raises ValueError: when method is given without requester
        """

        if method is not None and requester is None:
            raise ValueError('Method defined but not requester')

        self._update_accounts()

        result = self.db.raw_query(query)
        if method:
            result = self.filter(result, requester, method)

        # Render the object to dict showing specified tags:
        objects = []
        if show is None:
            show = ()
        for obj in result:
            tags_to_show = set(obj.itermatchingtags(show)) | set(obj.show_tags)
            tag_names = [t.name for t in tags_to_show]
            result._requestor.fetch((obj,), tag_names)
            objects.append(obj.to_dict(tag_names))
        return objects
