diff --git a/sjrpc/core/protocols/rpc.py b/sjrpc/core/protocols/rpc.py index 79d457c2c82650aed1013cdbbb759324bd832c5a..f76291e40319e9110ef707e98e495a0eb4e52752 100644 --- a/sjrpc/core/protocols/rpc.py +++ b/sjrpc/core/protocols/rpc.py @@ -12,6 +12,7 @@ class RpcProtocol(object): REQUEST_MESSAGE = {'id': None, 'method': None, 'args': [], 'kwargs': {}} RESPONSE_MESSAGE = {'id': None, 'return': None, 'error': None} + SPECIAL_MESSAGE = {'special': None} ''' :param connection: the connection serving this :class:`RpcProtocol` @@ -115,6 +116,8 @@ class RpcProtocol(object): self._handle_request(message) elif set(RpcProtocol.RESPONSE_MESSAGE) <= set(message): self._handle_response(message) + elif set(RpcProtocol.SPECIAL_MESSAGE) <= set(message): + self._handle_special(message) else: self.logger.debug('Malformed message received: %s', message) @@ -163,6 +166,17 @@ class RpcProtocol(object): # Finally, delete the call from the current running call list: del self._calls[message['id']] + def _handle_special(self, message): + ''' + Handle special message. + ''' + if message['special'] == 'capabilities': + if self._label == 0: + self._connection.set_capabilities(message.get('capabilities')) + else: + self.logger.warning('Capabilities message received by non-zero' + ' rpc.') + def _send(self, message): ''' Low level method to encode a message in json, calculate it size, and @@ -218,6 +232,17 @@ class RpcProtocol(object): self._send(msg) + def send_special(self, special, **kwargs): + ''' + Send a "special" message to the peer. + + :param special: type of the special message + :param \*\*kwargs: fields of the special message + ''' + msg = {'special': special} + msg.update(kwargs) + self._send(msg) + def response(self, msg_id, returned): ''' Send an "return" response to the peer. diff --git a/sjrpc/core/rpcconnection.py b/sjrpc/core/rpcconnection.py index 8a9c910e00089335059db042a40a08fe9b3e54c9..470c2e98e12b87ddebf28ee3435ecf33a6ac25de 100644 --- a/sjrpc/core/rpcconnection.py +++ b/sjrpc/core/rpcconnection.py @@ -25,13 +25,13 @@ class RpcConnection(object): ''' MESSAGE_HEADER = '!HL' + MESSAGE_HEADER_FALLBACK = '!L' SHORTCUTS_MAINRPC = ('call', 'async_call') def __init__(self, sock, *args, **kwargs): - #super(RpcConnection, self).__init__() # Sock of this connection: self._sock = sock - self._sock.settimeout(None) + sock.setblocking(True) # Activate TCP keepalive on the connection: #self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) @@ -50,6 +50,13 @@ class RpcConnection(object): for name in RpcConnection.SHORTCUTS_MAINRPC: setattr(self, name, getattr(self.get_protocol(0), name)) + # By default, enter in fallback mode, no label, all frames are + # redirected on Rpc0: + self.fallback = True + # Send our capabilities to the peer + self._remote_capabilities = None + self._send_capabilities() + @classmethod def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs): ''' @@ -66,6 +73,7 @@ class RpcConnection(object): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(conn_timeout) sock.connect((addr, port)) + sock.setblocking(True) return cls(sock, *args, **kwargs) @classmethod @@ -77,12 +85,14 @@ class RpcConnection(object): :param cert: ssl certificate or None for ssl without certificat ''' - connection = cls.from_addr(addr, port, conn_timeout, *args, **kwargs) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(conn_timeout) + sock.connect((addr, port)) + sock.setblocking(True) req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED - connection._sock = ssl.wrap_socket(connection._sock, certfile=cert, - cert_reqs=req, - ssl_version=ssl.PROTOCOL_TLSv1) - return connection + sock = ssl.wrap_socket(sock, certfile=cert, cert_reqs=req, + ssl_version=ssl.PROTOCOL_TLSv1) + return cls(sock, *args, **kwargs) def __repr__(self): return '' @@ -103,6 +113,26 @@ class RpcConnection(object): if self._connected: self.shutdown() + def _enable_fallback(self): + pass + + def _disable_fallback(self): + pass + + def _send_capabilities(self): + ''' + Send capabilities to the peer, only work in fallback mode for + compatibility with old sjRpc. + + Send a special message through the Rpc0 with these fields: + - special: 'capabilities' + - capabilities: {'version': REMOTE_VERSION, 'capabilities': []} + ''' + from sjrpc import __version__ + cap = {'version': __version__, 'capabilities':['rpc', 'tunnel']} + rpc0 = self.get_protocol(0) + rpc0.send_special('capabilities', capabilities=cap) + def _dispatch(self): ''' Read next message from socket and dispatch it to accoding protocol @@ -110,8 +140,13 @@ class RpcConnection(object): ''' # Read the header: - buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER)) - label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf) + if self.fallback: + buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK)) + pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER_FALLBACK, buf)[0] + label = 0 + else: + buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER)) + label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf) # Get the registered protocol for the specified label proto = self._protocols.get(label) if proto is not None: @@ -125,9 +160,18 @@ class RpcConnection(object): if not self._connected: raise RpcError('RpcError', 'Not connected to the peer') size = len(payload) - header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size) + if self.fallback: + header = struct.pack(RpcConnection.MESSAGE_HEADER_FALLBACK, size) + else: + header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size) try: - self._sock.sendall(header + payload) + if self.fallback: + data = header + payload + while data: + self._sock.sendall(data[:4096]) + data = data[4096:] + else: + self._sock.sendall(header + payload) except socket.error as err: errmsg = 'Fatal error while sending through socket: %s' % err self.logger.error(errmsg) @@ -137,6 +181,15 @@ class RpcConnection(object): # Public API # + def set_capabilities(self, capabilities): + ''' + Set capabilities of remote host (and disable fallback mode). + + Should be called by Rpc0 when the peer send its capabilities message. + ''' + self._remote_capabilities = capabilities + self.fallback = False + def register_protocol(self, label, protocol_class, *args, **kwargs): ''' Register a new protocol for the specified label. @@ -234,6 +287,8 @@ class RpcConnection(object): except socket.error as err: if not self._connected: raise SocketRpcError('Not connected to the peer') + elif err.errno == 11: + continue errmsg = 'Fatal error while receiving from socket: %s' % err self.logger.error(errmsg) raise SocketRpcError(errmsg)