Newer
Older
'''
This module contains the RpcConnection class; more information about this
class is located in its docstring.
'''
from __future__ import absolute_import

import errno
import logging
import socket
import ssl
import struct
import threading

import pyev

from sjrpc.core.protocols import Protocol, RpcProtocol, TunnelProtocol
from sjrpc.core.exceptions import (RpcConnectionError, NoFreeLabelError,
                                   FallbackModeEnabledError, SocketError)
class RpcConnection(object):
    r'''
    This class manages a single peer connection.

    You can wrap an existing socket with the default constructor::

        >>> conn = RpcConnection(mysocket)

    Or create a new socket automatically with the :meth:`from_addr`
    constructor::

        >>> conn = RpcConnection.from_addr(host, port)

    If you prefer an SSL connection, you can use the :meth:`from_addr_ssl`
    constructor::

        >>> conn = RpcConnection.from_addr_ssl(host, port)

    By default, an :class:`RpcProtocol` is created on label 0; you can access
    this rpc through the `conn.rpc` shortcut::

        >>> conn.rpc.call('ping')

    Also, the connection object exposes the :meth:`call` and
    :meth:`async_call` methods of the default rpc, so you can use them
    directly on the connection::

        >>> conn.call('ping')  # Equivalent to the example before

    .. seealso::
        You can read the :ref:`Default rpc, aka Rpc0` section to know more
        about the default rpc.

    :param sock: the socket object of this newly created
        :class:`RpcConnection`
    :param loop: the pyev loop to use; a new ``pyev.Loop`` is created when
        None
    :param enable_tcp_keepalive: set ``SO_KEEPALIVE`` on the socket
    :param fallback_timeout: set the maximum time to wait for the
        "capabilities" message before sending anything to the peer, -1 to
        disable the fallback mode. NOTE(review): former docs said 0 waits
        indefinitely, but the code arms a pyev Timer with ``after=0`` in that
        case; confirm the intended semantics.
    :param \*args,\*\*kwargs: arguments to pass to the default rpc protocol
        automatically registered on label 0.
    '''

    NONBLOCKING_ERRORS = (errno.EAGAIN, errno.EWOULDBLOCK)
    NONBLOCKING_SSL_ERRORS = (ssl.SSL_ERROR_WANT_READ, )

    # Frame header: 16-bit label + 32-bit payload size (network byte order).
    MESSAGE_HEADER = '!HL'
    # Old sjRpc frame header: only the 32-bit payload size, no label.
    MESSAGE_HEADER_FALLBACK = '!L'
    MAX_LABEL = 2 ** 16

    DEFAULT_RECV_SIZE = 1024 * 64  # 64kB

    SHORTCUTS_MAINRPC = ('call', 'async_call')

    def __init__(self, sock, loop=None, enable_tcp_keepalive=False,
                 fallback_timeout=1.0, *args, **kwargs):
        # Socket of this connection, always used in non-blocking mode:
        self._sock = sock
        sock.setblocking(False)

        # Initialization requires fallback mode disabled (register_protocol
        # refuses to work in fallback mode):
        self.fallback = False

        # Get the pyev loop:
        if loop is None:
            self.loop = pyev.Loop()
        else:
            self.loop = loop

        # Activate TCP keepalive on the connection:
        if enable_tcp_keepalive:
            self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        # Watcher list (every watcher is stopped on shutdown):
        self._watchers = set()

        # Socket inbound/outbound buffers:
        self._inbound_buffer = ''
        self._outbound_buffer = ''

        # Bytes still needed for the current receive stage minus bytes
        # already buffered (see _dispatch). First stage is a frame header:
        if fallback_timeout == -1:
            self._remains = struct.calcsize(RpcConnection.MESSAGE_HEADER)
        else:
            self._remains = struct.calcsize(
                RpcConnection.MESSAGE_HEADER_FALLBACK)
        self._proto_receiving = None

        # Initialize main read/write watchers:
        self._sock_reader = self.create_watcher(pyev.Io,
                                                fd=self._sock,
                                                events=pyev.EV_READ,
                                                callback=self._reader)
        self._sock_reader.start()
        self._sock_writer = self.create_watcher(pyev.Io,
                                                fd=self._sock,
                                                events=pyev.EV_WRITE,
                                                callback=self._writer)

        # Is the RpcConnection connected to its peer:
        self._connected = True

        # "Need to send" loop signal (started by send() to wake the writer):
        self._need_to_send = self.create_watcher(pyev.Async,
                                                 callback=self._cb_need_to_send)
        self._need_to_send.start()

        self.logger = logging.getLogger('sjrpc.%s' % self.getpeername())

        # Protocols registered on this connection:
        self._protocols = {}
        self.register_protocol(0, RpcProtocol, *args, **kwargs)

        # Create shortcuts to the main rpc (protocol 0) for convenience:
        for name in RpcConnection.SHORTCUTS_MAINRPC:
            setattr(self, name, getattr(self.get_protocol(0), name))

        # send() blocks on this event until the fallback timer fires or the
        # peer capabilities are received (set_capabilities):
        self._event_fallback = threading.Event()

        # By default, enter in fallback mode: no label, all frames are
        # redirected on Rpc0:
        if fallback_timeout != -1:
            self.fallback = True
            self.create_watcher(pyev.Timer, after=fallback_timeout, repeat=0,
                                callback=self._cb_set_event_fallback).start()

        # Set the event fallback just to send the capability message:
        self._event_fallback.set()
        self._remote_capabilities = None
        self._send_capabilities()
        # And clear it after:
        self._event_fallback.clear()

    @classmethod
    def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs):
        r'''
        Construct the instance of :class:`RpcConnection` without providing
        the :class:`socket` object. The socket is automatically created,
        connected and passed to the standard constructor before returning
        the new instance.

        :param addr: the target ip address
        :param port: the target port
        :param conn_timeout: the connection operation timeout
        :param \*args,\*\*kwargs: extra arguments to pass to the constructor
            (see constructor docstring)
        :raises socket.error: if the connection cannot be established (the
            socket is closed before re-raising)
        '''
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(conn_timeout)
        try:
            sock.connect((addr, port))
        except socket.error:
            # Don't leak the file descriptor on connection failure:
            sock.close()
            raise
        sock.setblocking(True)
        return cls(sock, *args, **kwargs)

    @classmethod
    def from_addr_ssl(cls, addr, port, cert=None,
                      conn_timeout=30, *args, **kwargs):
        '''
        Construct :class:`RpcConnection` instance like :meth:`from_addr`, but
        enable ssl on socket.

        :param cert: ssl client certificate or None for ssl without
            certificate
        :raises socket.error: if the connection cannot be established (the
            socket is closed before re-raising)
        '''
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(conn_timeout)
        try:
            sock.connect((addr, port))
        except socket.error:
            # Don't leak the file descriptor on connection failure:
            sock.close()
            raise
        # Blocking mode for the SSL handshake done by wrap_socket:
        sock.setblocking(True)
        # NOTE(review): CERT_REQUIRED is set without ca_certs, which the ssl
        # module may reject; also PROTOCOL_TLSv1 / wrap_socket are deprecated.
        # Confirm the intended verification policy.
        req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED
        sock = ssl.wrap_socket(sock, certfile=cert, cert_reqs=req,
                               ssl_version=ssl.PROTOCOL_TLSv1)
        return cls(sock, *args, **kwargs)

    def __repr__(self):
        return '<RpcConnection object>'

    def __hash__(self):
        return self._sock.__hash__()

    def __nonzero__(self):
        return self._connected

    @staticmethod
    def _is_nonblocking_error(err):
        '''
        Tell whether a socket error only means "retry later" on a
        non-blocking (possibly SSL) socket.

        :param err: the caught :class:`socket.error` instance
        :return: True if the error is one of the non-blocking errno values
        '''
        if (isinstance(err, socket.error)
                and err.errno in RpcConnection.NONBLOCKING_ERRORS):
            return True
        return (isinstance(err, ssl.SSLError)
                and err.errno in RpcConnection.NONBLOCKING_SSL_ERRORS)

    def _send_capabilities(self):
        '''
        Send capabilities to the peer, only work in fallback mode for
        compatibility with old sjRpc.

        Send a special message through the Rpc0 with these fields:
         - special: 'capabilities'
         - capabilities: {'version': REMOTE_VERSION, 'capabilities': []}
        '''
        from sjrpc import __version__
        cap = {'version': __version__, 'capabilities': []}
        rpc0 = self.get_protocol(0)
        rpc0.send_special('capabilities', capabilities=cap)

    def _reader(self, watcher, revents):
        '''
        Read socket and feed inbound buffer, then dispatch the buffered
        data. An empty read means the peer closed the connection.
        '''
        # Read all possible data from the socket:
        try:
            buf = self._sock.recv(RpcConnection.DEFAULT_RECV_SIZE)
        except socket.error as err:
            if self._is_nonblocking_error(err):
                return
            # If any fatal error is triggered, the connection is shutdown:
            self.shutdown()
            self.logger.debug('Unexpected socket error: %s', err)
            return

        # For ssl socket, we need to fetch buffered ssl-side data:
        if isinstance(self._sock, ssl.SSLSocket):
            pending = self._sock.pending()
            if pending:
                buf += self._sock.recv(pending)

        # Empty data on non-blocking socket means that the connection
        # has been closed:
        if not buf:
            self.shutdown()
            return

        self._inbound_buffer += buf
        # Keep the _remains accounting in sync with the buffer (see
        # _dispatch):
        self._remains -= len(buf)

        # Process and dispatch all inbound data:
        self._dispatch()

    def _dispatch(self):
        '''
        Read the inbound buffer, parse and dispatch messages.

        ``self._remains`` holds the number of bytes still needed to complete
        the current stage (frame header or payload) minus the number of
        bytes already waiting in the inbound buffer; a value <= 0 means the
        stage can be completed with buffered data. Headers are parsed only
        once complete; payload bytes are fed to the receiving protocol as
        soon as they arrive. Every complete frame in the buffer is handled.
        '''
        while True:
            if self._proto_receiving is None:
                # Wait until the whole frame header is buffered:
                if self._remains > 0:
                    break
                if self.fallback:
                    # Old sjRpc frames have no label, everything goes to Rpc0:
                    size = struct.calcsize(
                        RpcConnection.MESSAGE_HEADER_FALLBACK)
                    buf = self._inbound_buffer[:size]
                    self._inbound_buffer = self._inbound_buffer[size:]
                    pl_size = struct.unpack(
                        RpcConnection.MESSAGE_HEADER_FALLBACK, buf)[0]
                    label = 0
                else:
                    size = struct.calcsize(RpcConnection.MESSAGE_HEADER)
                    buf = self._inbound_buffer[:size]
                    self._inbound_buffer = self._inbound_buffer[size:]
                    label, pl_size = struct.unpack(
                        RpcConnection.MESSAGE_HEADER, buf)

                # Get the registered protocol for the specified label:
                self._proto_receiving = self._protocols.get(label)

                # If frame's label is not bound to a protocol, we create a
                # dummy protocol to consume the payload:
                if self._proto_receiving is None:
                    self._proto_receiving = Protocol(self, -1)

                self._proto_receiving.start_message(pl_size)
                self._remains += pl_size

            # Feed the available payload bytes, but never more than the
            # current message still needs:
            size = len(self._inbound_buffer) + self._remains
            buf = self._inbound_buffer[:size]
            self._inbound_buffer = self._inbound_buffer[size:]
            self._proto_receiving.feed(buf)

            if self._remains > 0:
                # Message is incomplete, wait for more data:
                break

            self._proto_receiving.end_of_message()
            # Next stage is a header; fallback may have been disabled while
            # processing the message, so it is checked after end_of_message:
            if self.fallback:
                self._remains += struct.calcsize(
                    RpcConnection.MESSAGE_HEADER_FALLBACK)
            else:
                self._remains += struct.calcsize(RpcConnection.MESSAGE_HEADER)
            self._proto_receiving = None

    def _writer(self, watcher, revent):
        '''
        Write buffered outbound data on the socket and stop the writer
        watcher once the outbound buffer is empty. A fatal send error shuts
        the connection down.
        '''
        if self._outbound_buffer:
            try:
                if self.fallback:
                    sent = self._sock.send(self._outbound_buffer[:4096])
                else:
                    sent = self._sock.send(self._outbound_buffer)
            except socket.error as err:
                if self._is_nonblocking_error(err):
                    return
                self.logger.error('Fatal error while sending through '
                                  'socket: %s', err)
                # Previously fell through with `sent` unbound; the
                # connection is unusable anyway:
                self.shutdown()
                return
            # NOTE(review): send() may append from another thread while this
            # slice is reassigned; no lock protects _outbound_buffer.
            self._outbound_buffer = self._outbound_buffer[sent:]
        if not self._outbound_buffer:
            watcher.stop()

    def _cb_need_to_send(self, watcher, revents):
        '''
        Async callback triggered by send(): start the socket writer watcher.
        '''
        self._sock_writer.start()

    def _cb_set_event_fallback(self, watcher, revents):
        '''
        Fallback timer callback: unblock send() callers.
        '''
        self._event_fallback.set()

    #
    # Public API
    #

    @property
    def rpc(self):
        '''
        The default rpc protocol (label 0).
        '''
        return self.get_protocol(0)

    def run(self):
        '''
        Main loop execution.
        '''
        self.loop.start()

    def create_watcher(self, watcher_class, **kwargs):
        '''
        Create a new pyev watcher bound to this connection's loop and return
        it. The watcher is tracked so that :meth:`shutdown` stops it.
        '''
        kwargs['loop'] = self.loop
        watcher = watcher_class(**kwargs)
        self._watchers.add(watcher)
        return watcher

    def send(self, label, payload):
        '''
        Low level method to send a message through the socket, generally
        used by protocols.

        Blocks until the fallback event is set, then frames the payload and
        queues it in the outbound buffer.

        :raises RpcConnectionError: if the connection is not connected
        '''
        self._event_fallback.wait()

        if not self._connected:
            raise RpcConnectionError('Not connected to the peer')
        size = len(payload)
        if self.fallback:
            header = struct.pack(RpcConnection.MESSAGE_HEADER_FALLBACK, size)
        else:
            header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size)
        self._outbound_buffer += header + payload
        self._need_to_send.send()

    def set_capabilities(self, capabilities):
        '''
        Set capabilities of remote host (and disable fallback mode).

        Should be called by Rpc0 when the peer send its capabilities message.
        '''
        self._remote_capabilities = capabilities
        self.fallback = False
        self._event_fallback.set()

    def register_protocol(self, label, protocol_class, *args, **kwargs):
        '''
        Register a new protocol for the specified label.

        :param label: the label number, or None to pick the first free one
        :return: the newly created protocol instance
        :raises FallbackModeEnabledError: if fallback mode is enabled
        :raises NoFreeLabelError: if no label is available
        :raises KeyError: if a protocol is already registered for the label
        :raises ValueError: if the label is not an integer
        '''
        if self.fallback:
            raise FallbackModeEnabledError('Fallback mode is not compatible '
                                           'with protocols')
        if label is None:
            for label in xrange(0, RpcConnection.MAX_LABEL):
                if label not in self._protocols:
                    break
            else:
                raise NoFreeLabelError('No more label number are availables')

        if label in self._protocols:
            raise KeyError('A protocol is already registered for this label')
        elif not isinstance(label, int):
            raise ValueError('Label must be an integer')
        self._protocols[label] = protocol_class(self, label, *args, **kwargs)
        return self._protocols[label]

    def unregister_protocol(self, label):
        '''
        Unregister the specified protocol label for this connection.

        The default rpc (label 0) cannot be unregistered.

        :raises FallbackModeEnabledError: if fallback mode is enabled
        :raises KeyError: if no removable protocol is registered on the label
        '''
        if self.fallback:
            raise FallbackModeEnabledError('Fallback mode is not compatible '
                                           'with protocols')
        if label in self._protocols and label != 0:
            del self._protocols[label]
        else:
            raise KeyError('No protocol registered for this label')

    def create_rpc(self, label=None, *args, **kwargs):
        '''
        Shortcut which can be used to create rpc protocols.
        '''
        return self.register_protocol(label, RpcProtocol, *args, **kwargs)

    def create_tunnel(self, label=None, *args, **kwargs):
        '''
        Shortcut which can be used to create tunnels protocols.
        '''
        return self.register_protocol(label, TunnelProtocol, *args, **kwargs)

    def get_protocol(self, label):
        '''
        Get the protocol registered for specified label.

        :raises KeyError: if no protocol is registered for the label
        '''
        proto = self._protocols.get(label)
        if proto is None:
            raise KeyError('No protocol registered for this label')
        return proto

    def shutdown(self):
        '''
        Shut the connection down: stop every watcher, shut down every
        registered protocol and close the socket. Repeated calls are ignored.
        '''
        # Ignore repeated calls to shutdown:
        if not self._connected:
            return

        # Unset connected state:
        self._connected = False
        self.logger.info('Connection shutdown.')

        # Shutdown each registered watcher:
        for watcher in self._watchers:
            watcher.stop()

        # Shutdown each registered protocols:
        for proto in self._protocols.values():
            proto.shutdown()

        # Close the connection socket. shutdown() may fail (e.g. peer already
        # gone) but close() must still run to release the descriptor:
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
        except socket.error as err:
            self.logger.debug('Error while socket shutdown: %s', err)
        try:
            self._sock.close()
        except socket.error as err:
            self.logger.debug('Error while socket close: %s', err)

    def get_fd(self):
        '''
        Get the file descriptor of the socket managed by this connection.

        :return: the file descriptor number of the socket, or None if it
            cannot be obtained
        '''
        try:
            return self._sock.fileno()
        except socket.error:
            return None

    def getpeername(self):
        '''
        Get the peer name.

        :return: string representing the peer name ("host:port")
        '''
        return '%s:%s' % self._sock.getpeername()