''' This module contains the RpcConnection class, more informations about this class are located in it docstring. ''' from __future__ import absolute_import import ssl import struct import socket import logging from sjrpc.core.protocols.rpc import RpcProtocol from sjrpc.core.exceptions import RpcError, SocketRpcError class RpcConnection(object): ''' This class manage a single peer connection. :param sock: the socket object of this newly created :class:`RpcConnection` :param *args, **kwargs: arguments to pass to the default rpc protocol automatically registered on label 0. ''' MESSAGE_HEADER = '!HL' SHORTCUTS_MAINRPC = ('call', 'async_call') def __init__(self, sock, *args, **kwargs): #super(RpcConnection, self).__init__() # Sock of this connection: self._sock = sock self._sock.settimeout(None) # Activate TCP keepalive on the connection: #self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # Is the RpcConnection connected to its peer: self._connected = True # Setup self.logger.facility: self.logger = logging.getLogger('sjrpc.%s' % self.getpeername()) # Protocols registered on this connection: self._protocols = {} self.register_protocol(0, RpcProtocol, *args, **kwargs) # Create shortcuts to the main rpc (protocol 0) for convenience: for name in RpcConnection.SHORTCUTS_MAINRPC: setattr(self, name, getattr(self.get_protocol(0), name)) @classmethod def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs): ''' Construct the instance of :class:`RpcConnection` without providing the :class:`socket` object. Socket is automatically created and passed to the standard constructor before to return the new instance. :param addr: the target ip address :param port: the target port :param conn_timeout: the connection operation timeout :param *args, **kwargs: extra argument to pass to the constructor (see constructor doctring) ''' sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(conn_timeout) sock.connect((addr, port)) return cls(sock, *args, **kwargs) @classmethod def from_addr_ssl(cls, addr, port, cert=None, conn_timeout=30, *args, **kwargs): ''' Construct :class:`RpcConnection` instance like :meth:`from_addr`, but enable ssl on socket. :param cert: ssl certificate or None for ssl without certificat ''' connection = cls.from_addr(addr, port, conn_timeout, *args, **kwargs) req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED connection._sock = ssl.wrap_socket(connection._sock, certfile=cert, cert_reqs=req, ssl_version=ssl.PROTOCOL_TLSv1) return connection def __repr__(self): return '' def __hash__(self): return self._sock.__hash__() def run(self): ''' Inbound message processing loop. ''' while self._connected: try: self._dispatch() except SocketRpcError: # If SocketRpcError occurs while dispatching, shutdown the # connection if it not already shutdown: if self._connected: self.shutdown() def _dispatch(self): ''' Read next message from socket and dispatch it to accoding protocol handler. ''' # Read the header: buf = self.recv_until(struct.calcsize(RpcConnection.MESSAGE_HEADER)) label, pl_size = struct.unpack(RpcConnection.MESSAGE_HEADER, buf) # Get the registered protocol for the specified label proto = self._protocols.get(label) if proto is not None: proto.handle(label, pl_size) def send(self, label, payload): ''' Low level method to send a message through the socket, generally used by protocols. ''' if not self._connected: raise RpcError('RpcError', 'Not connected to the peer') size = len(payload) header = struct.pack(RpcConnection.MESSAGE_HEADER, label, size) try: self._sock.sendall(header + payload) except socket.error as err: errmsg = 'Fatal error while sending through socket: %s' % err self.logger.error(errmsg) raise RpcError('SocketError', errmsg) # # Public API # def register_protocol(self, label, protocol_class, *args, **kwargs): ''' Register a new protocol for the specified label. ''' if label in self._protocols: raise KeyError('A protocol is already registered for this label') elif not isinstance(label, int): raise ValueError('Label must be an integer') self._protocols[label] = protocol_class(self, label, *args, **kwargs) return self._protocols[label] def unregister_protocol(self, label): ''' Unregister the specified protocol label for this connection. ''' if label in self._protocols and label != 0: del self._protocols[label] else: raise KeyError('No protocol registered for this label') def get_protocol(self, label): ''' Get the protocol registered for specified label. ''' proto = self._protocols.get(label) if proto is None: raise KeyError('No protocol registered for this label') return proto def shutdown(self): ''' Shutdown this connection. ''' # Shutdown each registered protocols: for proto in self._protocols.itervalues(): proto.shutdown() # Close the connection socket: self._connected = False try: self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() except socket.error as err: #self.logger.warn('Error while socket close: %s', err) pass def get_handler(self): ''' Return the handler binded to the :class:`RpcConnection`. :return: binded handler ''' return self._handler def set_handler(self, handler): ''' Define a new handler for this connection. :param handler: the new handler to define. ''' self._handler = handler def get_fd(self): ''' Get the file descriptor of the socket managed by this connection. :return: the file descriptor number of the socket ''' try: return self._sock.fileno() except socket.error: return None def getpeername(self): ''' Get the peer name. :return: string representing the peer name ''' return '%s:%s' % self._sock.getpeername() def recv_until(self, bufsize, flags=None): ''' Read socket until bufsize is received. ''' buf = '' while len(buf) < bufsize: remains = bufsize - len(buf) try: received = self._sock.recv(remains) except socket.error as err: if not self._connected: raise SocketRpcError('Not connected to the peer') errmsg = 'Fatal error while receiving from socket: %s' % err self.logger.error(errmsg) raise SocketRpcError(errmsg) # Handle peer disconnection: if not received: self.logger.info('Connection reset by peer') self.shutdown() buf += received return buf class GreenRpcConnection(RpcConnection): ''' Cooperative RpcConnection to use with Gevent. ''' def __init__(self, *args, **kwargs): super(GreenRpcConnection, self).__init__(*args, **kwargs) self._greenlet = None @classmethod def from_addr(cls, addr, port, conn_timeout=30.0, *args, **kwargs): ''' Construct the instance of :class:`RpcConnection` without providing the :class:`socket` object. Socket is automatically created and passed to the standard constructor before to return the new instance. :param addr: the target ip address :param port: the target port :param conn_timeout: the connection operation timeout :param *args, **kwargs: extra argument to pass to the constructor (see constructor doctring) ''' import gevent.socket sock = gevent.socket.create_connection((addr, port), conn_timeout) return cls(sock, *args, **kwargs) @classmethod def from_addr_ssl(cls, addr, port, cert, conn_timeout=30, *args, **kwargs): ''' Construct :class:`RpcConnection` instance like :meth:`from_addr`, but enable ssl on socket. :param cert: ssl certificate or None for ssl without certificat ''' import gevent.socket sock = gevent.socket.create_connection((addr, port), conn_timeout) req = ssl.CERT_NONE if cert is None else ssl.CERT_REQUIRED sock = gevent.ssl.SSLSocket(sock, certfile=None, cert_reqs=req, ssl_version=ssl.PROTOCOL_TLSv1) return cls(sock, *args, **kwargs) def run(self): import gevent self._greenlet = gevent.spawn(self.run) self._greenlet.join() def shutdown(self): super(GreenRpcConnection, self).shutdown() if self._greenlet is not None: self._greenlet.kill()