Skip to content
Snippets Groups Projects
Commit 36dffbb1 authored by Antoine Millet's avatar Antoine Millet
Browse files

Added capabilities negociation with fallback to old sjRpc mode.

On connection, the new sjRpc work in compatibility mode with the old one. The
label is not defined on frames, and a capabilities special message is sent to the
peer. Old sjRpc ignore this kind on message (but log a warning "malformed
message"), unlike new version which disable this compatibility mode on
receipt.
parent 3e99529b
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ class RpcProtocol(object): ...@@ -12,6 +12,7 @@ class RpcProtocol(object):
REQUEST_MESSAGE = {'id': None, 'method': None, 'args': [], 'kwargs': {}} REQUEST_MESSAGE = {'id': None, 'method': None, 'args': [], 'kwargs': {}}
RESPONSE_MESSAGE = {'id': None, 'return': None, 'error': None} RESPONSE_MESSAGE = {'id': None, 'return': None, 'error': None}
SPECIAL_MESSAGE = {'special': None}
''' '''
:param connection: the connection serving this :class:`RpcProtocol` :param connection: the connection serving this :class:`RpcProtocol`
...@@ -115,6 +116,8 @@ class RpcProtocol(object): ...@@ -115,6 +116,8 @@ class RpcProtocol(object):
self._handle_request(message) self._handle_request(message)
elif set(RpcProtocol.RESPONSE_MESSAGE) <= set(message): elif set(RpcProtocol.RESPONSE_MESSAGE) <= set(message):
self._handle_response(message) self._handle_response(message)
elif set(RpcProtocol.SPECIAL_MESSAGE) <= set(message):
self._handle_special(message)
else: else:
self.logger.debug('Malformed message received: %s', message) self.logger.debug('Malformed message received: %s', message)
...@@ -163,6 +166,17 @@ class RpcProtocol(object): ...@@ -163,6 +166,17 @@ class RpcProtocol(object):
# Finally, delete the call from the current running call list: # Finally, delete the call from the current running call list:
del self._calls[message['id']] del self._calls[message['id']]
def _handle_special(self, message):
'''
Handle special message.
'''
if message['special'] == 'capabilities':
if self._label == 0:
self._connection.set_capabilities(message.get('capabilities'))
else:
self.logger.warning('Capabilities message received by non-zero'
' rpc.')
def _send(self, message): def _send(self, message):
''' '''
Low level method to encode a message in json, calculate it size, and Low level method to encode a message in json, calculate it size, and
...@@ -218,6 +232,17 @@ class RpcProtocol(object): ...@@ -218,6 +232,17 @@ class RpcProtocol(object):
self._send(msg) self._send(msg)
def send_special(self, special, **kwargs):
'''
Send a "special" message to the peer.
:param special: type of the special message
:param \*\*kwargs: fields of the special message
'''
msg = {'special': special}
msg.update(kwargs)
self._send(msg)
def response(self, msg_id, returned): def response(self, msg_id, returned):
''' '''
Send an "return" response to the peer. Send an "return" response to the peer.
......
...@@ -25,13 +25,13 @@ class RpcConnection(object): ...@@ -25,13 +25,13 @@ class RpcConnection(object):
''' '''
MESSAGE_HEADER = '!HL' MESSAGE_HEADER = '!HL'
MESSAGE_HEADER_FALLBACK = '!L'
SHORTCUTS_MAINRPC = ('call', 'async_call') SHORTCUTS_MAINRPC = ('call', 'async_call')
def __init__(self, sock, *args, **kwargs): def __init__(self, sock, *args, **kwargs):
#super(RpcConnection, self).__init__()
# Sock of this connection: # Sock of this connection:
self._sock = sock self._sock = sock
self._sock.settimeout(None) sock.setblocking(True)
# Activate TCP keepalive on the connection: # Activate TCP keepalive on the connection:
#self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) #self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
...@@ -50,6 +50,13 @@ class RpcConnection(object): ...@@ -50,6 +50,13 @@ class RpcConnection(object):
for name in RpcConnection.SHORTCUTS_MAINRPC: for name in RpcConnection.SHORTCUTS_MAINRPC:
setattr(self, name, getattr(self.get_protocol(0), name)) setattr(self, name, getattr(self.get_protocol(0), name))
# By default, enter in fallback mode, no label, all frames are
# redirected on Rpc0:
self.fallback = True
# Send our capabilities to the peer
self._remote_capabilities = None
self._send_capabilities()
@classmethod @classmethod
def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs): def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs):
''' '''
...@@ -66,6 +73,7 @@ class RpcConnection(object): ...@@ -66,6 +73,7 @@ class RpcConnection(object):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(conn_timeout) sock.settimeout(conn_timeout)
sock.connect((addr, port)) sock.connect((addr, port))
sock.setblocking(True)
return cls(sock, *args, **kwargs) return cls(sock, *args, **kwargs)
@classmethod @classmethod
...@@ -77,12 +85,14 @@ class RpcConnection(object): ...@@ -77,12 +85,14 @@ class RpcConnection(object):
:param cert: ssl certificate or None for ssl without certificat :param cert: ssl certificate or None for ssl without certificat
''' '''
connection = cls.from_addr(addr, port, conn_timeout, *args, **kwargs) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(conn_timeout)
sock.connect((addr, port))
sock.setblocking(True)
req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED
connection._sock = ssl.wrap_socket(connection._sock, certfile=cert, sock = ssl.wrap_socket(sock, certfile=cert, cert_reqs=req,
cert_reqs=req, ssl_version=ssl.PROTOCOL_TLSv1)
ssl_version=ssl.PROTOCOL_TLSv1) return cls(sock, *args, **kwargs)
return connection
def __repr__(self): def __repr__(self):
return '<RpcConnection object>' return '<RpcConnection object>'
...@@ -103,6 +113,26 @@ class RpcConnection(object): ...@@ -103,6 +113,26 @@ class RpcConnection(object):
if self._connected: if self._connected:
self.shutdown() self.shutdown()
def _enable_fallback(self):
pass
def _disable_fallback(self):
pass
def _send_capabilities(self):
'''
Send capabilities to the peer, only work in fallback mode for
compatibility with old sjRpc.
Send a special message through the Rpc0 with these fields:
- special: 'capabilities'
- capabilities: {'version': REMOTE_VERSION, 'capabilities': []}
'''
from sjrpc import __version__
cap = {'version': __version__, 'capabilities':['rpc', 'tunnel']}
rpc0 = self.get_protocol(0)
rpc0.send_special('capabilities', capabilities=cap)
def _dispatch(self): def _dispatch(self):
''' '''
Read next message from socket and dispatch it to accoding protocol Read next message from socket and dispatch it to accoding protocol
...@@ -110,8 +140,13 @@ class RpcConnection(object): ...@@ -110,8 +140,13 @@ class RpcConnection(object):
''' '''
# Read the header: # Read the header:
buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER)) if self.fallback:
label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf) buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK))
pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER_FALLBACK, buf)[0]
label = 0
else:
buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER))
label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf)
# Get the registered protocol for the specified label # Get the registered protocol for the specified label
proto = self._protocols.get(label) proto = self._protocols.get(label)
if proto is not None: if proto is not None:
...@@ -125,9 +160,18 @@ class RpcConnection(object): ...@@ -125,9 +160,18 @@ class RpcConnection(object):
if not self._connected: if not self._connected:
raise RpcError('RpcError', 'Not connected to the peer') raise RpcError('RpcError', 'Not connected to the peer')
size = len(payload) size = len(payload)
header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size) if self.fallback:
header = struct.pack(RpcConnection.MESSAGE_HEADER_FALLBACK, size)
else:
header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size)
try: try:
self._sock.sendall(header + payload) if self.fallback:
data = header + payload
while data:
self._sock.sendall(data[:4096])
data = data[4096:]
else:
self._sock.sendall(header + payload)
except socket.error as err: except socket.error as err:
errmsg = 'Fatal error while sending through socket: %s' % err errmsg = 'Fatal error while sending through socket: %s' % err
self.logger.error(errmsg) self.logger.error(errmsg)
...@@ -137,6 +181,15 @@ class RpcConnection(object): ...@@ -137,6 +181,15 @@ class RpcConnection(object):
# Public API # Public API
# #
def set_capabilities(self, capabilities):
'''
Set capabilities of remote host (and disable fallback mode).
Should be called by Rpc0 when the peer send its capabilities message.
'''
self._remote_capabilities = capabilities
self.fallback = False
def register_protocol(self, label, protocol_class, *args, **kwargs): def register_protocol(self, label, protocol_class, *args, **kwargs):
''' '''
Register a new protocol for the specified label. Register a new protocol for the specified label.
...@@ -234,6 +287,8 @@ class RpcConnection(object): ...@@ -234,6 +287,8 @@ class RpcConnection(object):
except socket.error as err: except socket.error as err:
if not self._connected: if not self._connected:
raise SocketRpcError('Not connected to the peer') raise SocketRpcError('Not connected to the peer')
elif err.errno == 11:
continue
errmsg = 'Fatal error while receiving from socket: %s' % err errmsg = 'Fatal error while receiving from socket: %s' % err
self.logger.error(errmsg) self.logger.error(errmsg)
raise SocketRpcError(errmsg) raise SocketRpcError(errmsg)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment