Commit a57f92dc authored by Antoine Millet's avatar Antoine Millet
Browse files

Added a closing feature to tunnels

parent ec1d396f
Loading
Loading
Loading
Loading
+35 −8
Original line number Diff line number Diff line
@@ -27,6 +27,8 @@ class TunnelProtocol(Protocol):
            self._endpoint = endpoint
            self._socket = None

        self._cb_on_close = kwargs.pop('on_close', self._cb_default_on_close)
        self._is_closed = False
        self._from_tun_to_endpoint_buf = ''
        self._asked = 0 # Data asked to the peer
        self._ok_to_send = 0 # Data I can send to the peer
@@ -80,6 +82,12 @@ class TunnelProtocol(Protocol):
                                          type='get', payload=dict(size=size))
        self._asked += size

    def _cb_default_on_close(self, tun):
        '''
        Action to do on the endpoint when the connection is closed.
        '''
        tun.endpoint.close()

#
# Public methods:
#
@@ -95,6 +103,21 @@ class TunnelProtocol(Protocol):
        '''
        return self._endpoint

    def close(self):
        '''
        Close the tunnel and unregister it from connection.
        '''
        if not self._is_closed:
            self._is_closed = True
            # Stop watchers:
            self._endpoint_reader.stop()
            self._endpoint_writer.stop()
            # Send the end of stream message to the peer:
            self._connection.rpc.send_special('protoctl', label=self._label, type='eos')
            # Unregister the tunnel:
            self._connection.unregister_protocol(self._label)
            self._cb_on_close(self)

    def end_of_message(self):
        '''
        Handle inbound data from the :class:`RpcConnection` peer and place it
@@ -104,6 +127,7 @@ class TunnelProtocol(Protocol):
        self._endpoint_writer.start()

    def handle_control(self, control_type, payload):
        if not self._is_closed:
            if control_type == 'get':
                size = payload.get('size', TunnelProtocol.DEFAULT_GET_SIZE)
                self._ok_to_send += size
@@ -112,3 +136,6 @@ class TunnelProtocol(Protocol):
                self._send_get(TunnelProtocol.GET_SIZE)
                self._ok_to_send += TunnelProtocol.GET_SIZE
                self._endpoint_reader.start()
            elif control_type == 'eos':
                self.logger.debug('Received EOS event')
                self.close()