Commit 36dffbb1 authored by Antoine Millet's avatar Antoine Millet

Added capabilities negociation with fallback to old sjRpc mode.

On connection, the new sjRpc work in compatibility mode with the old one. The
label is not defined on frames, and a capabilities special message is sent to the
peer. Old sjRpc ignore this kind on message (but log a warning "malformed
message"), unlike new version which disable this compatibility mode on
receipt.
parent 3e99529b
......@@ -12,6 +12,7 @@ class RpcProtocol(object):
REQUEST_MESSAGE = {'id': None, 'method': None, 'args': [], 'kwargs': {}}
RESPONSE_MESSAGE = {'id': None, 'return': None, 'error': None}
SPECIAL_MESSAGE = {'special': None}
'''
:param connection: the connection serving this :class:`RpcProtocol`
......@@ -115,6 +116,8 @@ class RpcProtocol(object):
self._handle_request(message)
elif set(RpcProtocol.RESPONSE_MESSAGE) <= set(message):
self._handle_response(message)
elif set(RpcProtocol.SPECIAL_MESSAGE) <= set(message):
self._handle_special(message)
else:
self.logger.debug('Malformed message received: %s', message)
......@@ -163,6 +166,17 @@ class RpcProtocol(object):
# Finally, delete the call from the current running call list:
del self._calls[message['id']]
def _handle_special(self, message):
'''
Handle special message.
'''
if message['special'] == 'capabilities':
if self._label == 0:
self._connection.set_capabilities(message.get('capabilities'))
else:
self.logger.warning('Capabilities message received by non-zero'
' rpc.')
def _send(self, message):
'''
Low level method to encode a message in json, calculate it size, and
......@@ -218,6 +232,17 @@ class RpcProtocol(object):
self._send(msg)
def send_special(self, special, **kwargs):
'''
Send a "special" message to the peer.
:param special: type of the special message
:param \*\*kwargs: fields of the special message
'''
msg = {'special': special}
msg.update(kwargs)
self._send(msg)
def response(self, msg_id, returned):
'''
Send an "return" response to the peer.
......
......@@ -25,13 +25,13 @@ class RpcConnection(object):
'''
MESSAGE_HEADER = '!HL'
MESSAGE_HEADER_FALLBACK = '!L'
SHORTCUTS_MAINRPC = ('call', 'async_call')
def __init__(self, sock, *args, **kwargs):
#super(RpcConnection, self).__init__()
# Sock of this connection:
self._sock = sock
self._sock.settimeout(None)
sock.setblocking(True)
# Activate TCP keepalive on the connection:
#self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
......@@ -50,6 +50,13 @@ class RpcConnection(object):
for name in RpcConnection.SHORTCUTS_MAINRPC:
setattr(self, name, getattr(self.get_protocol(0), name))
# By default, enter in fallback mode, no label, all frames are
# redirected on Rpc0:
self.fallback = True
# Send our capabilities to the peer
self._remote_capabilities = None
self._send_capabilities()
@classmethod
def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs):
'''
......@@ -66,6 +73,7 @@ class RpcConnection(object):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(conn_timeout)
sock.connect((addr, port))
sock.setblocking(True)
return cls(sock, *args, **kwargs)
@classmethod
......@@ -77,12 +85,14 @@ class RpcConnection(object):
:param cert: ssl certificate or None for ssl without certificat
'''
connection = cls.from_addr(addr, port, conn_timeout, *args, **kwargs)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(conn_timeout)
sock.connect((addr, port))
sock.setblocking(True)
req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED
connection._sock = ssl.wrap_socket(connection._sock, certfile=cert,
cert_reqs=req,
ssl_version=ssl.PROTOCOL_TLSv1)
return connection
sock = ssl.wrap_socket(sock, certfile=cert, cert_reqs=req,
ssl_version=ssl.PROTOCOL_TLSv1)
return cls(sock, *args, **kwargs)
def __repr__(self):
return '<RpcConnection object>'
......@@ -103,6 +113,26 @@ class RpcConnection(object):
if self._connected:
self.shutdown()
def _enable_fallback(self):
pass
def _disable_fallback(self):
pass
def _send_capabilities(self):
'''
Send capabilities to the peer, only work in fallback mode for
compatibility with old sjRpc.
Send a special message through the Rpc0 with these fields:
- special: 'capabilities'
- capabilities: {'version': REMOTE_VERSION, 'capabilities': []}
'''
from sjrpc import __version__
cap = {'version': __version__, 'capabilities':['rpc', 'tunnel']}
rpc0 = self.get_protocol(0)
rpc0.send_special('capabilities', capabilities=cap)
def _dispatch(self):
'''
Read next message from socket and dispatch it to accoding protocol
......@@ -110,8 +140,13 @@ class RpcConnection(object):
'''
# Read the header:
buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER))
label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf)
if self.fallback:
buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER_FALLBACK))
pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER_FALLBACK, buf)[0]
label = 0
else:
buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER))
label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf)
# Get the registered protocol for the specified label
proto = self._protocols.get(label)
if proto is not None:
......@@ -125,9 +160,18 @@ class RpcConnection(object):
if not self._connected:
raise RpcError('RpcError', 'Not connected to the peer')
size = len(payload)
header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size)
if self.fallback:
header = struct.pack(RpcConnection.MESSAGE_HEADER_FALLBACK, size)
else:
header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size)
try:
self._sock.sendall(header + payload)
if self.fallback:
data = header + payload
while data:
self._sock.sendall(data[:4096])
data = data[4096:]
else:
self._sock.sendall(header + payload)
except socket.error as err:
errmsg = 'Fatal error while sending through socket: %s' % err
self.logger.error(errmsg)
......@@ -137,6 +181,15 @@ class RpcConnection(object):
# Public API
#
def set_capabilities(self, capabilities):
'''
Set capabilities of remote host (and disable fallback mode).
Should be called by Rpc0 when the peer send its capabilities message.
'''
self._remote_capabilities = capabilities
self.fallback = False
def register_protocol(self, label, protocol_class, *args, **kwargs):
'''
Register a new protocol for the specified label.
......@@ -234,6 +287,8 @@ class RpcConnection(object):
except socket.error as err:
if not self._connected:
raise SocketRpcError('Not connected to the peer')
elif err.errno == 11:
continue
errmsg = 'Fatal error while receiving from socket: %s' % err
self.logger.error(errmsg)
raise SocketRpcError(errmsg)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment