#!/usr/bin/env python
#coding:utf8

import select
import threading


class ConnectionManager(object):

    '''
    Base class for all connection manager classes.

    A manager owns an epoll object and dispatches poll events to per-fd
    callbacks from its main loop (:meth:`run`). It also lets client threads
    block in :meth:`wait` until the responses to asynchronous calls are
    signalled through :meth:`signal_arrival`.
    '''

    # The timeout (in seconds) after which the poll call returns even when no
    # event occurred, so the main loop can notice a shutdown request:
    POLL_TIMEOUT = 1

    # Masks for fd registration on poll object:
    MASK_NORMAL = (select.EPOLLIN | select.EPOLLPRI |
                   select.EPOLLERR | select.EPOLLHUP)
    MASK_WRITABLE = MASK_NORMAL | select.EPOLLOUT

    def __init__(self):
        self._poll = select.epoll()
        self._running = True
        # Received messages not yet collected by wait(), keyed by message id:
        self._received_msg = {}
        # Pending wait() calls, keyed by their frozenset of message ids:
        self._wait_groups = {}
        # Per-fd callback descriptors ({'func', 'extra', 'kwextra'}):
        self._poll_callbacks = {}

    def register(self, fd, callback, *args, **kwargs):
        '''
        Register an fd on the poll object with the specified callback.

        The callback will be called each time the poller reports an event
        for the specified fd, as ``callback(fd, event, *args, **kwargs)``.

        :param fd: the fd to register (an int or an object with a
            ``fileno()`` method)
        :param callback: the callable to use on event
        :param *args, **kwargs: extra arguments passed to the callback
        '''
        if hasattr(fd, 'fileno'):
            fd = fd.fileno()
        # The callback is stored before the fd is registered, so an event
        # reported right after registration already finds its callback.
        self._poll_callbacks[fd] = {'func': callback, 'extra': args,
                                    'kwextra': kwargs}
        self._poll.register(fd, ConnectionManager.MASK_NORMAL)

    def unregister(self, fd):
        '''
        Unregister the specified fd from the manager.

        :param fd: the fd to unregister (an int or an object with a
            ``fileno()`` method, exactly as accepted by :meth:`register`)
        '''
        # Normalize like register() does: callbacks are keyed by int fd.
        if hasattr(fd, 'fileno'):
            fd = fd.fileno()
        self._poll.unregister(fd)
        del self._poll_callbacks[fd]

    def is_running(self):
        '''
        :return: False once :meth:`shutdown` has been called, else True
        '''
        return self._running

    def run(self):
        '''
        Run the main loop of the :class:`ConnectionManager`.

        It will catch events on registered :class:`RpcConnection` and process
        them by calling the callback registered for each fd, until
        :meth:`shutdown` is called.
        '''
        while self._running:
            try:
                events = self._poll.poll(ConnectionManager.POLL_TIMEOUT)
            except IOError:
                # Best effort: presumably poll() interrupted by a signal
                # (EINTR on older Pythons); simply poll again.
                continue
            for fd, event in events:
                # The fd may have been unregistered (by a previous callback or
                # another thread) since poll() returned.
                cb = self._poll_callbacks.get(fd)
                if cb is not None:
                    cb['func'](fd, event, *cb['extra'], **cb['kwextra'])

    def start(self, daemonize=False):
        '''
        Run the main loop in a separated thread.

        :param daemonize: set the thread daemon state
        :return: the started :class:`threading.Thread`, so callers can
            ``join()`` it after :meth:`shutdown`
        '''
        thread = threading.Thread(target=self.run)
        thread.daemon = daemonize
        thread.start()
        return thread

    def wait(self, msg_id_set, timeout=None, wait_all=True):
        '''
        Wait for the asynchronous messages in ``msg_id_set``.

        When the timeout argument is present and not ``None``, it should be a
        floating point number specifying a timeout for the operation in
        seconds (or fractions thereof).

        You can also set ``wait_all`` to False if you want to unlock the call
        when the first response is received.

        Concurrent waits on the same ``msg_id_set`` share the waiter of the
        first caller (including its ``wait_all`` setting). Each message is
        returned to only one of them.

        :param msg_id_set: set of message to wait
        :type msg_id_set: :class:`frozenset`
        :param timeout: timeout value or None to disable timeout
            (default: None)
        :type timeout: :class:`float` or :class:`None`
        :param wait_all: wait for all messages (default: True)
        :type wait_all: :class:`bool`
        :return: a tuple of the messages of ``msg_id_set`` received so far
            (possibly incomplete if the timeout expired)

        .. warning::

           It is important that ``msg_id_set`` is a :class:`frozenset` and
           not a :class:`set`: it is used as a dictionary key.
        '''
        waiter = {'event': threading.Event(), 'wait_all': wait_all}
        # _check_waiter() only sets the event of the waiter stored in
        # _wait_groups. When another thread already waits on this set, we
        # must block on *that* event, or we would sleep until the timeout.
        waiter = self._wait_groups.setdefault(msg_id_set, waiter)
        already_completed = self._check_waiter(msg_id_set)
        if not already_completed:
            waiter['event'].wait(timeout=timeout)

        # Clean the call list on each attached RpcConnection
        for connection in self.all_connections():
            connection.clean_all_calls(msg_id_set)

        # Collect the messages of this set in the dict's iteration order
        # (insertion order on Python 3.7+). Iterate over a snapshot, because
        # removing entries while iterating the dict itself raises
        # RuntimeError on Python 3. pop() with a default keeps a concurrent
        # waiter from failing on a message that was already collected.
        messages = []
        for msg_id in list(self._received_msg):
            if msg_id in msg_id_set:
                msg = self._received_msg.pop(msg_id, None)
                if msg is not None:
                    messages.append(msg)
        waiter['responses'] = tuple(messages)

        # A concurrent waiter on the same set may already have removed it.
        self._wait_groups.pop(msg_id_set, None)
        return waiter['responses']

    def signal_arrival(self, message):
        '''
        Signal the arrival of a new message to the :class:`ConnectionManager`.

        This method is ordinarily called by the :class:`RpcConnection`
        objects, when a response to an asynchronous call is received.

        :param message: the message received (a mapping with an ``'id'`` key)
        '''
        self._received_msg[message['id']] = message
        # Iterate over a snapshot: wait() removes its group from its own
        # thread, possibly while this loop is running.
        for waitset in list(self._wait_groups):
            self._check_waiter(waitset)

    def _check_waiter(self, waitset):
        '''
        Check if a waitset is completed and process it.

        A waitset is completed when all of its messages have been received
        (``wait_all`` waiter), or at least one of them (otherwise). The
        waiter's event is then set, which releases the thread blocked in
        :meth:`wait`.

        :param waitset: the waitset to check
        :return: True if waitset is completed else False (also False when
            no waiter is registered for it)
        '''
        try:
            waiter = self._wait_groups[waitset]
        except KeyError:
            return False
        received = self._received_msg
        # Per-id membership tests: O(len(waitset)) instead of building a set
        # of every received message id.
        if waiter['wait_all']:
            is_ok = all(msg_id in received for msg_id in waitset)
        else:
            is_ok = any(msg_id in received for msg_id in waitset)
        if is_ok:
            # Unlock the event:
            waiter['event'].set()
        return is_ok

    def all_connections(self):
        '''
        Return all connection attached to this :class:`ConnectionManager`.

        This is an abstract method to overload on derived classes.

        :return: a set of :class:`RpcConnection` attached to this
            :class:`ConnectionManager`
        '''
        raise NotImplementedError

    def shutdown(self):
        '''
        Shutdown the manager properly.

        Only requests the main loop to stop. :meth:`run` returns after its
        current poll call, which lasts at most ``POLL_TIMEOUT`` seconds, plus
        the time spent in any pending callbacks.
        '''
        self._running = False

    def data_to_write(self, fd):
        '''
        Method called by a connection to inform the manager that it have data
        to send: the fd is also watched for writability (``MASK_WRITABLE``).

        :param fd: the fd which have data to write (``None`` is ignored)
        '''
        if fd is not None:
            self._poll.modify(fd, ConnectionManager.MASK_WRITABLE)

    def nothing_to_write(self, fd):
        '''
        Method called by a connection to inform the manager that it have no
        more data to send: the fd is watched with ``MASK_NORMAL`` again.

        :param fd: the fd which have no more data to write (``None`` is
            ignored)
        '''
        if fd is not None:
            self._poll.modify(fd, ConnectionManager.MASK_NORMAL)

    def handle_event(self, fd, event):
        '''
        Handle an event and make associated action. This is an abstract
        method to overload on derived classes.

        :param fd: the fd that have generated the event
        :param event: the event as returned by the poller object
        '''
        pass