Skip to content
connectionmanagers.py 6.78 KiB
Newer Older
Antoine Millet's avatar
Antoine Millet committed
#!/usr/bin/env python
#coding:utf8

import select
import threading

class ConnectionManager(object):
    '''
    Base class for all connection manager classes.

    A connection manager owns an epoll object and dispatches its events to
    per-fd callbacks (see :meth:`register` and :meth:`run`). It also lets
    callers block until responses to asynchronous calls have arrived (see
    :meth:`wait` and :meth:`signal_arrival`).

    Subclasses must implement :meth:`all_connections`.
    '''

    # The timeout (in seconds) to wait before the poll call releases the hand
    # with no events, so that run() regularly re-checks self._running:
    POLL_TIMEOUT = 1

    # Masks for fd registration on the poll object:
    MASK_NORMAL = (select.EPOLLIN | select.EPOLLPRI |
                   select.EPOLLERR | select.EPOLLHUP)
    MASK_WRITABLE = MASK_NORMAL | select.EPOLLOUT

    def __init__(self):
        self._poll = select.epoll()
        self._running = True
        # Received messages not yet collected by wait(), keyed by message id:
        self._received_msg = {}
        # Pending waiters, keyed by the frozenset of message ids they wait for:
        self._wait_groups = {}
        # Callback descriptions ('func', 'extra', 'kwextra'), keyed by int fd:
        self._poll_callbacks = {}

    @staticmethod
    def _fileno(fd):
        '''
        Normalize ``fd`` to the value used as key for the poller.

        :param fd: an integer fd or an object exposing a ``fileno()`` method
        :return: ``fd.fileno()`` when available, else ``fd`` unchanged
        '''

        if hasattr(fd, 'fileno'):
            return fd.fileno()
        return fd

    def register(self, fd, callback, *args, **kwargs):
        '''
        Register an fd on the poll object with the specified callback. The
        callback will be called each time the poller drops an event for the
        specified fd. Extra args will be passed to the callback after fd and
        events.

        :param fd: the fd to register (an int or an object with ``fileno()``)
        :param callback: the callable to use on event
        :param *args, **kwargs: extra arguments passed to the callback
        '''

        fd = self._fileno(fd)
        self._poll_callbacks[fd] = {'func': callback,
                                    'extra': args,
                                    'kwextra': kwargs}
        self._poll.register(fd, ConnectionManager.MASK_NORMAL)

    def unregister(self, fd):
        '''
        Unregister the specified fd from the manager.

        :param fd: the fd to unregister (an int or an object with
            ``fileno()``, consistent with :meth:`register`)
        :raises KeyError: if the fd has no registered callback
        '''

        fd = self._fileno(fd)
        self._poll.unregister(fd)
        del self._poll_callbacks[fd]

    def is_running(self):
        '''
        :return: False once :meth:`shutdown` has been called, True before.
        '''

        return self._running

    def run(self):
        '''
        Run the main loop of the :class:`ConnectionManager`. It polls the
        registered fds and calls, for each event, the callback given to
        :meth:`register` as ``callback(fd, event, *args, **kwargs)``. The loop
        ends at most ``POLL_TIMEOUT`` seconds after :meth:`shutdown`.
        '''

        while self._running:
            try:
                events = self._poll.poll(ConnectionManager.POLL_TIMEOUT)
            except IOError:
                # Presumably tolerates interrupted poll calls (EINTR): retry.
                continue
            for fd, event in events:
                # The fd may have been unregistered by another thread since
                # poll() returned, so do a single defensive lookup:
                cb = self._poll_callbacks.get(fd)
                if cb is not None:
                    cb['func'](fd, event, *cb['extra'], **cb['kwextra'])

    def start(self, daemonize=False):
        '''
        Run the main loop in a separate thread.

        :param daemonize: set the thread daemon state
        :return: the started :class:`threading.Thread` (e.g. to join it)
        '''

        t = threading.Thread(target=self.run)
        t.daemon = daemonize
        t.start()
        return t

    def wait(self, msg_id_set, timeout=None, wait_all=True):
        '''
        Wait for the asynchronous messages in ``msg_id_set``.

        When the timeout argument is present and not ``None``, it should be a
        floating point number specifying a timeout for the operation in
        seconds (or fractions thereof). On timeout, the messages received so
        far are returned.

        You can also set ``wait_all`` to False if you want to unlock the call
        when the first response is received.

        :param msg_id_set: ids of the messages to wait for
        :type msg_id_set: :class:`frozenset` (any iterable of hashable ids is
            accepted and converted to a :class:`frozenset`)
        :param timeout: timeout value or None to disable timeout (default: None)
        :type timeout: :class:`int`, :class:`float` or :class:`None`
        :param wait_all: wait for all messages (default: True)
        :type wait_all: :class:`bool`
        :return: tuple of the received messages whose id is in ``msg_id_set``
        '''

        # Waiters are keyed by the id set, which must therefore be hashable:
        msg_id_set = frozenset(msg_id_set)

        # If another caller already waits for the same ids, share its waiter:
        # its event is the one signal_arrival() will set.
        waiter = self._wait_groups.setdefault(
            msg_id_set, {'event': threading.Event(), 'wait_all': wait_all})

        if not self._check_waiter(msg_id_set):
            waiter['event'].wait(timeout=timeout)

        # Clean the call list on each attached RpcConnection
        for connection in self.all_connections():
            connection.clean_all_calls(msg_id_set)

        # Collect the messages. Iterate over a snapshot of the keys because
        # entries are removed inside the loop (mutating a dict while
        # iterating it raises RuntimeError on Python 3).
        messages = []
        for msg_id in list(self._received_msg):
            if msg_id in msg_id_set:
                msg = self._received_msg.pop(msg_id, None)
                if msg is not None:
                    messages.append(msg)

        # A sharing caller may already have removed the group:
        self._wait_groups.pop(msg_id_set, None)

        return tuple(messages)

    def signal_arrival(self, message):
        '''
        Signal the arrival of a new message to the :class:`ConnectionManager`.
        This method is ordinarily called by the :class:`RpcConnection`
        objects when a response to an asynchronous call is received.

        :param message: the message received; it is indexed by its ``'id'``
            key
        '''

        self._received_msg[message['id']] = message
        # Snapshot the keys: wait() may remove groups from another thread.
        for waitset in list(self._wait_groups):
            self._check_waiter(waitset)

    def _check_waiter(self, waitset):
        '''
        Check if a waitset is completed and, if so, set its event.

        :param waitset: the waitset (frozenset of message ids) to check
        :return: True if the waitset is completed, else False (also False when
            no waiter is registered for it)
        '''

        try:
            waiter = self._wait_groups[waitset]
        except KeyError:
            return False

        # Ids of the messages received so far:
        recv_msg = set(self._received_msg)

        if waiter['wait_all']:
            is_ok = waitset <= recv_msg
        else:
            is_ok = not recv_msg.isdisjoint(waitset)

        if is_ok:
            # Unlock the thread(s) blocked in wait():
            waiter['event'].set()

        return is_ok

    def all_connections(self):
        '''
        Return all connections attached to this :class:`ConnectionManager`.

        :return: a set of :class:`RpcConnection` attached
            to this :class:`ConnectionManager`
        '''

        raise NotImplementedError

    def shutdown(self):
        '''
        Shutdown the manager properly: the :meth:`run` loop stops at its next
        iteration.
        '''

        self._running = False

    def data_to_write(self, fd):
        '''
        Method called by a connection to inform the manager that it has data
        to send; the fd is then also polled for writability.

        :param fd: the fd which has data to write (an int or an object with
            ``fileno()``); ``None`` is ignored
        '''

        fd = self._fileno(fd)
        if fd is not None:
            self._poll.modify(fd, ConnectionManager.MASK_WRITABLE)

    def nothing_to_write(self, fd):
        '''
        Method called by a connection to inform the manager that it has no
        more data to send; the fd is then no longer polled for writability.

        :param fd: the fd which has no more data to write (an int or an object
            with ``fileno()``); ``None`` is ignored
        '''

        fd = self._fileno(fd)
        if fd is not None:
            self._poll.modify(fd, ConnectionManager.MASK_NORMAL)

    def handle_event(self, fd, event):
        '''
        Handle an event and make associated action. This is an abstract method
        to overload on derived classes.

        :param fd: the fd that has generated the event
        :param event: the event as returned by the poller object
        '''
        pass