Skip to content
connectionmanagers.py 6.05 KiB
Newer Older
Antoine Millet's avatar
Antoine Millet committed
#!/usr/bin/env python
#coding:utf8

import select
import threading

class ConnectionManager(object):
    '''
    Base class for all connection manager classes.

    The manager watches the file descriptors of its registered
    :class:`RpcConnection` objects with a single epoll object, and keeps the
    responses to asynchronous calls until a thread collects them with
    :meth:`wait`. Derived classes implement :meth:`all_connections` and
    :meth:`handle_event`.
    '''

    # Timeout (in seconds) after which the poll call gives the hand back when
    # no event occurred. It bounds how long run() takes to notice shutdown():
    POLL_TIMEOUT = 1

    # Masks for fd registration on poll object. EPOLLOUT is only requested
    # while a connection has data to send (see data_to_write and
    # nothing_to_write), otherwise a writable socket would wake the loop
    # continuously:
    MASK_NORMAL = (select.EPOLLIN | select.EPOLLPRI |
                   select.EPOLLERR | select.EPOLLHUP)
    MASK_WRITABLE = MASK_NORMAL | select.EPOLLOUT

    def __init__(self):
        self._poll = select.epoll()
        self._running = True
        # Responses received but not yet handed to a waiter, keyed by the
        # message 'id':
        self._received_msg = {}
        # Pending wait groups, keyed by the frozenset of awaited message ids:
        self._wait_groups = {}
        # Protects _received_msg and _wait_groups: they are modified both by
        # signal_arrival (called by the connections) and by the threads
        # blocked in wait(). Connection code is never called with it held.
        self._lock = threading.Lock()

    def register(self, connection):
        '''
        Register a :class:`RpcConnection` object on this manager.

        :param connection: the instance of :class:`RpcConnection` to register
        :type connection: instance of :class:`RpcConnection`
        '''

        self._poll.register(connection.get_fd(), ConnectionManager.MASK_NORMAL)

    def is_running(self):
        '''
        :return: False once :meth:`shutdown` has been called, else True
        '''
        return self._running

    def run(self):
        '''
        Run the main loop of the :class:`ConnectionManager`. It will catch
        events on registered :class:`RpcConnection` and process them.

        The loop returns at most :attr:`POLL_TIMEOUT` seconds after
        :meth:`shutdown` is called (plus the time to handle the current batch
        of events).
        '''

        while self._running:
            events = self._poll.poll(ConnectionManager.POLL_TIMEOUT)
            for fd, event in events:
                self.handle_event(fd, event)

    def start(self, daemonize=False):
        '''
        Run the main loop in a separated thread.

        :param daemonize: set the thread daemon state
        :return: the started :class:`threading.Thread`, so that the caller can
            join it after :meth:`shutdown`
        '''

        thread = threading.Thread(target=self.run)
        thread.daemon = daemonize
        thread.start()
        return thread

    def wait(self, msg_id_set, timeout=None, wait_all=True):
        '''
        Wait for the asynchronous messages in ``msg_id_set``.

        When the timeout argument is present and not ``None``, it should be a
        floating point number specifying a timeout for the operation in
        seconds (or fractions thereof). When the timeout expires, the call
        returns the messages of ``msg_id_set`` received so far (possibly none)
        instead of raising; compare the length of the result with the number
        of awaited ids to detect it.

        You can also set ``wait_all`` to False if you want to unlock the call
        when the first response is received. An empty ``msg_id_set`` is always
        completed immediately.

        :param msg_id_set: set of message ids to wait (any iterable of
            hashable ids; it is converted to a :class:`frozenset`)
        :param timeout: timeout value or None to disable timeout (default: None)
        :type timeout: :class:`float` or :class:`None`
        :param wait_all: wait for all messages (default: True)
        :type wait_all: :class:`bool`
        :return: a tuple of the received messages, in arrival order
        '''

        msg_id_set = frozenset(msg_id_set)

        with self._lock:
            # Reuse the existing waiter if another thread already waits for
            # the very same ids: a second waiter would never be signaled.
            waiter = self._wait_groups.setdefault(
                msg_id_set,
                {'event': threading.Event(), 'wait_all': wait_all},
            )

        try:
            if not self._check_waiter(msg_id_set):
                waiter['event'].wait(timeout=timeout)
        finally:
            with self._lock:
                # Unregister first so that signal_arrival can no longer
                # complete the group once we stop waiting (timeout/interrupt):
                self._wait_groups.pop(msg_id_set, None)
                responses = waiter.get('responses')
                timed_out = responses is None
                if timed_out:
                    # Hand over what already arrived instead of leaking it in
                    # _received_msg:
                    responses = self._pop_messages(msg_id_set)

        if timed_out:
            self._clean_calls(msg_id_set)

        return responses

    def signal_arrival(self, message):
        '''
        Signal the arrival of a new message to the :class:`ConnectionManager`.
        This method is ordinarily called by the :class:`RpcConnection` objects,
        when a response to an asynchronous call is received.

        :param message: the message received; its ``'id'`` item is the key
            matched against the id sets given to :meth:`wait`
        '''

        msg_id = message['id']
        with self._lock:
            self._received_msg[msg_id] = message
            # Snapshot: other threads add/remove wait groups concurrently.
            # Only groups containing this id can have changed state.
            waitsets = [ws for ws in self._wait_groups if msg_id in ws]
        for waitset in waitsets:
            self._check_waiter(waitset)

    def _check_waiter(self, waitset):
        '''
        Check if a waitset is completed and process it.

        On completion, the awaited messages are moved from the received
        messages to the waiter's ``'responses'``, the calls are cleaned on
        every attached connection, and the waiter's event is set.

        :param waitset: the waitset to check
        :return: True if waitset is completed else False
        '''

        with self._lock:
            waiter = self._wait_groups.get(waitset)
            if waiter is None:
                return False
            if 'responses' in waiter:
                # Already completed by another thread.
                return True
            if not self._is_completed(waitset, waiter['wait_all']):
                return False
            waiter['responses'] = self._pop_messages(waitset)

        # Outside of the lock: never call connection code while holding it.
        self._clean_calls(waitset)
        waiter['event'].set()
        return True

    def _is_completed(self, waitset, wait_all):
        '''
        Tell whether enough messages of ``waitset`` were received.
        The caller must hold ``self._lock``.
        '''

        received = self._received_msg
        if wait_all:
            return all(msg_id in received for msg_id in waitset)
        # An empty waitset would otherwise never complete in this mode:
        return not waitset or any(msg_id in received for msg_id in waitset)

    def _pop_messages(self, waitset):
        '''
        Remove from the received messages, and return as a tuple in arrival
        order, those whose id is in ``waitset``.
        The caller must hold ``self._lock``.
        '''

        # Collect the ids first: deleting while iterating the dict fails on
        # Python 3.
        msg_ids = [msg_id for msg_id in self._received_msg
                   if msg_id in waitset]
        return tuple(self._received_msg.pop(msg_id) for msg_id in msg_ids)

    def _clean_calls(self, waitset):
        '''
        Clean the call list of ``waitset`` on each attached connection.
        '''

        for connection in self.all_connections():
            connection.clean_all_calls(waitset)

    def all_connections(self):
        '''
        Return all connection attached to this :class:`ConnectionManager`.

        :return: a set of :class:`RpcConnection` attached
            to this :class:`ConnectionManager`
        '''

        raise NotImplementedError

    def shutdown(self):
        '''
        Shutdown the manager properly: ask the main loop to stop.
        '''

        self._running = False

    def data_to_write(self, connection):
        '''
        Method called by a connection to inform the manager that it has data
        to send.

        :param connection: the :class:`RpcConnection` object which informs the
            manager
        '''

        fd = connection.get_fd()
        self._poll.modify(fd, ConnectionManager.MASK_WRITABLE)

    def nothing_to_write(self, connection):
        '''
        Method called by a connection to inform the manager that it has no
        more data to send.

        :param connection: the :class:`RpcConnection` object which informs the
            manager
        '''

        fd = connection.get_fd()
        self._poll.modify(fd, ConnectionManager.MASK_NORMAL)

    def handle_event(self, fd, event):
        '''
        Handle an event and make associated action. This is an abstract method to
        overload on derived classes.

        :param fd: the fd that have generated the event
        :param event: the event as returned by the poller object
        '''
        pass