Commit 36dffbb1 authored by Antoine Millet's avatar Antoine Millet
Browse files

Added capabilities negociation with fallback to old sjRpc mode.

On connection, the new sjRpc work in compatibility mode with the old one. The
label is not defined on frames, and a capabilities special message is sent to the
peer. Old sjRpc ignore this kind on message (but log a warning "malformed
message"), unlike new version which disable this compatibility mode on
receipt.
parent 3e99529b
Loading
Loading
Loading
Loading
+25 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ class RpcProtocol(object):

    REQUEST_MESSAGE = {'id': None, 'method': None, 'args': [], 'kwargs': {}}
    RESPONSE_MESSAGE = {'id': None, 'return': None, 'error': None}
    SPECIAL_MESSAGE = {'special': None}

    '''
    :param connection: the connection serving this :class:`RpcProtocol`
@@ -115,6 +116,8 @@ class RpcProtocol(object):
            self._handle_request(message)
        elif set(RpcProtocol.RESPONSE_MESSAGE) <= set(message):
            self._handle_response(message)
        elif set(RpcProtocol.SPECIAL_MESSAGE) <= set(message):
            self._handle_special(message)
        else:
            self.logger.debug('Malformed message received: %s', message)

@@ -163,6 +166,17 @@ class RpcProtocol(object):
            # Finally, delete the call from the current running call list:
            del self._calls[message['id']]

    def _handle_special(self, message):
        '''
        Handle special message.
        '''
        if message['special'] == 'capabilities':
            if self._label == 0:
                self._connection.set_capabilities(message.get('capabilities'))
            else:
                self.logger.warning('Capabilities message received by non-zero'
                                    ' rpc.')

    def _send(self, message):
        '''
        Low level method to encode a message in json, calculate it size, and
@@ -218,6 +232,17 @@ class RpcProtocol(object):

        self._send(msg)

    def send_special(self, special, **kwargs):
        '''
        Send a "special" message to the peer.

        :param special: type of the special message
        :param \*\*kwargs: fields of the special message
        '''
        msg = {'special': special}
        msg.update(kwargs)
        self._send(msg)

    def response(self, msg_id, returned):
        '''
        Send an "return" response to the peer.
+66 −11
Original line number Diff line number Diff line
@@ -25,13 +25,13 @@ class RpcConnection(object):
    '''

    MESSAGE_HEADER = '!HL'
    MESSAGE_HEADER_FALLBACK = '!L'
    SHORTCUTS_MAINRPC = ('call', 'async_call')

    def __init__(self, sock, *args, **kwargs):
        #super(RpcConnection, self).__init__()
        # Sock of this connection:
        self._sock = sock
        self._sock.settimeout(None)
        sock.setblocking(True)

        # Activate TCP keepalive on the connection:
        #self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
@@ -50,6 +50,13 @@ class RpcConnection(object):
        for name in RpcConnection.SHORTCUTS_MAINRPC:
            setattr(self, name, getattr(self.get_protocol(0), name))

        # By default, enter in fallback mode, no label, all frames are
        # redirected on Rpc0:
        self.fallback = True
        # Send our capabilities to the peer
        self._remote_capabilities = None
        self._send_capabilities()

    @classmethod
    def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs):
        '''
@@ -66,6 +73,7 @@ class RpcConnection(object):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(conn_timeout)
        sock.connect((addr, port))
        sock.setblocking(True)
        return cls(sock, *args, **kwargs)

    @classmethod
@@ -77,12 +85,14 @@ class RpcConnection(object):

        :param cert: ssl certificate or None for ssl without certificat
        '''
        connection = cls.from_addr(addr, port, conn_timeout, *args, **kwargs)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(conn_timeout)
        sock.connect((addr, port))
        sock.setblocking(True)
        req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED
        connection._sock = ssl.wrap_socket(connection._sock, certfile=cert,
                                           cert_reqs=req,
        sock = ssl.wrap_socket(sock, certfile=cert, cert_reqs=req,
                               ssl_version=ssl.PROTOCOL_TLSv1)
        return connection
        return cls(sock, *args, **kwargs)

    def __repr__(self):
        return '<RpcConnection object>'
@@ -103,6 +113,26 @@ class RpcConnection(object):
                if self._connected:
                    self.shutdown()

    def _enable_fallback(self):
        pass

    def _disable_fallback(self):
        pass

    def _send_capabilities(self):
        '''
        Send capabilities to the peer, only work in fallback mode for
        compatibility with old sjRpc.

        Send a special message through the Rpc0 with these fields:
        - special: 'capabilities'
        - capabilities: {'version': REMOTE_VERSION, 'capabilities': []}
        '''
        from sjrpc import __version__
        cap = {'version': __version__, 'capabilities':['rpc', 'tunnel']}
        rpc0 = self.get_protocol(0)
        rpc0.send_special('capabilities', capabilities=cap)

    def _dispatch(self):
        '''
        Read next message from socket and dispatch it to accoding protocol
@@ -110,6 +140,11 @@ class RpcConnection(object):
        '''

        # Read the header:
        if self.fallback:
            buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK))
            pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER_FALLBACK, buf)[0]
            label = 0
        else:
            buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER))
            label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf)
        # Get the registered protocol for the specified label
@@ -125,8 +160,17 @@ class RpcConnection(object):
        if not self._connected:
            raise RpcError('RpcError', 'Not connected to the peer')
        size = len(payload)
        if self.fallback:
            header = struct.pack(RpcConnection.MESSAGE_HEADER_FALLBACK, size)
        else:
            header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size)
        try:
            if self.fallback:
                data = header + payload
                while data:
                    self._sock.sendall(data[:4096])
                    data = data[4096:]
            else:
                self._sock.sendall(header + payload)
        except socket.error as err:
            errmsg = 'Fatal error while sending through socket: %s' % err
@@ -137,6 +181,15 @@ class RpcConnection(object):
# Public API
#

    def set_capabilities(self, capabilities):
        '''
        Set capabilities of remote host (and disable fallback mode).

        Should be called by Rpc0 when the peer send its capabilities message.
        '''
        self._remote_capabilities = capabilities
        self.fallback = False

    def register_protocol(self, label, protocol_class, *args, **kwargs):
        '''
        Register a new protocol for the specified label.
@@ -234,6 +287,8 @@ class RpcConnection(object):
            except socket.error as err:
                if not self._connected:
                    raise SocketRpcError('Not connected to the peer')
                elif err.errno == 11:
                    continue
                errmsg = 'Fatal error while receiving from socket: %s' % err
                self.logger.error(errmsg)
                raise SocketRpcError(errmsg)