diff --git a/docs/requirements.rst b/docs/requirements.rst index 798cdcb..f070361 100644 --- a/docs/requirements.rst +++ b/docs/requirements.rst @@ -49,6 +49,8 @@ Supports: * IPv4 TCP * IPv4 DNS +* IPv6 TCP +* IPv6 DNS Requires: @@ -72,3 +74,19 @@ Additional Suggested Software - You may want to use autossh, available in various package management systems +- If you are using systemd, sshuttle can notify it when the connection to + the remote end is established and the firewall rules are installed + +.. code-block:: ini + :emphasize-lines: 6 + + [Unit] + Description=sshuttle + After=network.target + + [Service] + Type=notify + ExecStart=/usr/bin/sshuttle --dns --remote @ + + [Install] + WantedBy=multi-user.target diff --git a/sshuttle/sdnotify.py b/sshuttle/sdnotify.py index 665d953..ac478f4 100644 --- a/sshuttle/sdnotify.py +++ b/sshuttle/sdnotify.py @@ -1,6 +1,6 @@ import socket import os -from sshuttle.helpers import debug1, debug2, debug3 +from sshuttle.helpers import debug1 def _notify(message): addr = os.environ.get("NOTIFY_SOCKET", None) @@ -12,17 +12,18 @@ def _notify(message): try: sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) - except socket.error as e: - debug1("Error creating socket to notify to systemd: %s\n" % e) + except (OSError, IOError) as e: + debug1("Error creating socket to notify systemd: %s\n" % e) + return False - if not (sock and message): + if not message: return False assert isinstance(message, bytes) try: return (sock.sendto(message, addr) > 0) - except socket.error as e: + except (OSError, IOError) as e: debug1("Error notifying systemd: %s\n" % e) return False diff --git a/sshuttle/tests/client/test_sdnotify.py b/sshuttle/tests/client/test_sdnotify.py new file mode 100644 index 0000000..2a062cf --- /dev/null +++ b/sshuttle/tests/client/test_sdnotify.py @@ -0,0 +1,66 @@ +from mock import Mock, patch, call +import sys +import io +import socket + +import sshuttle.sdnotify + + +@patch('sshuttle.sdnotify.os.environ.get') +def test_notify_invalid_socket_path(mock_get): + mock_get.return_value = 'invalid_path' + assert not sshuttle.sdnotify.send(sshuttle.sdnotify.ready()) + + +@patch('sshuttle.sdnotify.os.environ.get') +def test_notify_socket_not_there(mock_get): + mock_get.return_value = '/run/valid_nonexistent_path' + assert not sshuttle.sdnotify.send(sshuttle.sdnotify.ready()) + + +@patch('sshuttle.sdnotify.os.environ.get') +def test_notify_no_message(mock_get): + mock_get.return_value = '/run/valid_path' + assert not sshuttle.sdnotify.send() + + +@patch('sshuttle.sdnotify.socket.socket') +@patch('sshuttle.sdnotify.os.environ.get') +def test_notify_socket_error(mock_get, mock_socket): + mock_get.return_value = '/run/valid_path' + mock_socket.side_effect = socket.error('test error') + assert not sshuttle.sdnotify.send(sshuttle.sdnotify.ready()) + + +@patch('sshuttle.sdnotify.socket.socket') +@patch('sshuttle.sdnotify.os.environ.get') +def test_notify_sendto_error(mock_get, mock_socket): + message = sshuttle.sdnotify.ready() + socket_path = '/run/valid_path' + + sock = Mock() + sock.sendto.side_effect = socket.error('test error') + mock_get.return_value = '/run/valid_path' + mock_socket.return_value = sock + + assert not sshuttle.sdnotify.send(message) + assert sock.sendto.mock_calls == [ + call(message, socket_path), + ] + + +@patch('sshuttle.sdnotify.socket.socket') +@patch('sshuttle.sdnotify.os.environ.get') +def test_notify(mock_get, mock_socket): + messages = [sshuttle.sdnotify.ready(), sshuttle.sdnotify.status('Running')] + socket_path = '/run/valid_path' + + sock = Mock() + sock.sendto.return_value = 1 + mock_get.return_value = '/run/valid_path' + mock_socket.return_value = sock + + assert sshuttle.sdnotify.send(*messages) + assert sock.sendto.mock_calls == [ + call(b'\n'.join(messages), socket_path), + ]