Commit 9d3610c5 authored by Antoine Millet's avatar Antoine Millet
Browse files

Rewrited inbound data processing in order to be compatible with ssl sockets

parent 736fd828
Loading
Loading
Loading
Loading
+50 −28
Original line number Diff line number Diff line
@@ -59,6 +59,7 @@ class RpcConnection(object):
    MESSAGE_HEADER = '!HL'
    MESSAGE_HEADER_FALLBACK = '!L'
    MAX_LABEL = 2 ** 16
    DEFAULT_RECV_SIZE = 1024 * 64 # 64kB
    SHORTCUTS_MAINRPC = ('call', 'async_call')

    def __init__(self, sock, loop=None, enable_tcp_keepalive=False,
@@ -93,7 +94,7 @@ class RpcConnection(object):
        self._sock_reader = self.create_watcher(pyev.Io,
                                                fd=self._sock,
                                                events=pyev.EV_READ,
                                                callback=self._dispatch)
                                                callback=self._reader)
        self._sock_reader.start()
        self._sock_writer = self.create_watcher(pyev.Io,
                                                fd=self._sock,
@@ -187,14 +188,14 @@ class RpcConnection(object):
        rpc0 = self.get_protocol(0)
        rpc0.send_special('capabilities', capabilities=cap)

    def _dispatch(self, watcher, revents):
    def _reader(self, watcher, revents):
        '''
        Read next message from socket and dispatch it to accoding protocol
        handler.
        Read socket and feed inbound buffer. Launch the dispatch when all
        data are buffered.
        '''
        # Try to received remaining data from the socket:
        # Read all possible data from the socket:
        try:
            buf = self._sock.recv(self._remains)
            buf = self._sock.recv(RpcConnection.DEFAULT_RECV_SIZE)
        except socket.error as err:
            if (isinstance(err, socket.error) and err.errno
                in RpcConnection.NONBLOCKING_ERRORS):
@@ -204,20 +205,40 @@ class RpcConnection(object):
                return
            else:
                raise
        if not buf:
        # For ssl socket, we need to fetch buffered ssl-side data:
        if isinstance(self._sock, ssl.SSLSocket):
            pending = self._sock.pending()
            if pending:
                buf += self._sock.recv(pending)
        # Empty data on non-blocking socket means that the connection
        # has been closed:
        if not buf:
            self.shutdown()

        self._remains -= len(buf)
        if self._proto_receiving is None:
        self._inbound_buffer += buf
            if self._remains == 0:

        # Process and dispatch all inbound data:
        while self._remains <= len(self._inbound_buffer):
            self._dispatch()

    def _dispatch(self):
        '''
        Read the inbound_buffer, parse and dispatch messages.
        '''

        if self._proto_receiving is None:
            if self.fallback:
                    pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER_FALLBACK, self._inbound_buffer)[0]
                size = struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK)
                buf = self._inbound_buffer[:size]
                self._inbound_buffer = self._inbound_buffer[size:]
                pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER_FALLBACK, buf)[0]
                label = 0
            else:
                    label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, self._inbound_buffer)
                size = struct.calcsize(RpcConnection.MESSAGE_HEADER)
                buf = self._inbound_buffer[:size]
                self._inbound_buffer = self._inbound_buffer[size:]
                label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf)

            # Get the registered protocol for the specified label:
            self._proto_receiving = self._protocols.get(label)
@@ -228,17 +249,18 @@ class RpcConnection(object):
                self._proto_receiving = Protocol(self, -1)

            self._proto_receiving.start_message(pl_size)
                self._inbound_buffer = ''
                self._remains = pl_size
            self._remains += pl_size
        else:
            size = len(self._inbound_buffer) + self._remains
            buf = self._inbound_buffer[:size]
            self._inbound_buffer = self._inbound_buffer[size:]
            self._proto_receiving.feed(buf)
            if self._remains == 0:
            if self._remains <= 0:
                self._proto_receiving.end_of_message()
                if self.fallback:
                    self._remains = struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK)
                    self._remains += struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK)
                else:
                    self._remains = struct.calcsize(RpcConnection.MESSAGE_HEADER)
                self._inbound_buffer = ''
                    self._remains += struct.calcsize(RpcConnection.MESSAGE_HEADER)
                self._proto_receiving = None

    def _writer(self, watcher, revent):