diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py index 1e8d325..5a3b0e7 100644 --- a/sshuttle/firewall.py +++ b/sshuttle/firewall.py @@ -114,6 +114,11 @@ def main(method_name, syslog): ssyslog.start_syslog() ssyslog.stderr_to_syslog() + if not method.is_supported(): + raise Fatal("The %s method is not supported on this machine. " + "Check that the appropriate programs are in your " + "PATH." % method_name) + debug1('ready method name %s.\n' % method.name) stdout.write('READY %s\n' % method.name) stdout.flush() diff --git a/sshuttle/methods/__init__.py b/sshuttle/methods/__init__.py index 2c59904..a8fe938 100644 --- a/sshuttle/methods/__init__.py +++ b/sshuttle/methods/__init__.py @@ -3,7 +3,7 @@ import socket import struct import errno import ipaddress -from sshuttle.helpers import Fatal, debug3, which +from sshuttle.helpers import Fatal, debug3 def original_dst(sock): @@ -54,6 +54,12 @@ class BaseMethod(object): result.user = False return result + @staticmethod + def is_supported(): + """Returns true if it appears that this method will work on this + machine.""" + return False + @staticmethod def get_tcp_dstip(sock): return original_dst(sock) @@ -102,16 +108,15 @@ def get_method(method_name): def get_auto_method(): - if which('iptables'): - method_name = "nat" - elif which('nft'): - method_name = "nft" - elif which('pfctl'): - method_name = "pf" - elif which('ipfw'): - method_name = "ipfw" - else: - raise Fatal( - "can't find either iptables, nft or pfctl; check your PATH") + debug3("Selecting a method automatically...\n") + # Try these methods, in order: + methods_to_try = ["nat", "nft", "pf", "ipfw"] + for m in methods_to_try: + method = get_method(m) + if method.is_supported(): + debug3("Method '%s' was automatically selected.\n" % m) + return method - return get_method(method_name) + raise Fatal("Unable to automatically find a supported method. Check that " + "the appropriate programs are in your PATH. We tried " + "methods: %s" % str(methods_to_try)) diff --git a/sshuttle/methods/ipfw.py b/sshuttle/methods/ipfw.py index 8b4aebc..bedaf3c 100644 --- a/sshuttle/methods/ipfw.py +++ b/sshuttle/methods/ipfw.py @@ -1,8 +1,8 @@ import os import subprocess as ssubprocess from sshuttle.methods import BaseMethod -from sshuttle.helpers import log, debug1, debug3, \ - Fatal, family_to_string, get_env +from sshuttle.helpers import log, debug1, debug2, debug3, \ + Fatal, family_to_string, get_env, which recvmsg = None try: @@ -253,3 +253,10 @@ class Method(BaseMethod): ipfw_noexit('table', '124', 'flush') ipfw_noexit('table', '125', 'flush') ipfw_noexit('table', '126', 'flush') + + def is_supported(self): + if which("ipfw"): + return True + debug2("ipfw method not supported because 'ipfw' command is " + "missing.\n") + return False diff --git a/sshuttle/methods/nat.py b/sshuttle/methods/nat.py index 18ec1fd..f8c9149 100644 --- a/sshuttle/methods/nat.py +++ b/sshuttle/methods/nat.py @@ -1,6 +1,6 @@ import socket from sshuttle.firewall import subnet_weight -from sshuttle.helpers import family_to_string +from sshuttle.helpers import family_to_string, which, debug2 from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists, nonfatal from sshuttle.methods import BaseMethod @@ -124,3 +124,10 @@ class Method(BaseMethod): result = super(Method, self).get_supported_features() result.user = True return result + + def is_supported(self): + if which("iptables"): + return True + debug2("nat method not supported because 'iptables' command " + "is missing.\n") + return False diff --git a/sshuttle/methods/nft.py b/sshuttle/methods/nft.py index 1b127ef..eb287f0 100644 --- a/sshuttle/methods/nft.py +++ b/sshuttle/methods/nft.py @@ -2,6 +2,7 @@ import socket from sshuttle.firewall import subnet_weight from sshuttle.linux import nft, nonfatal from sshuttle.methods import BaseMethod +from sshuttle.helpers import debug2, which class Method(BaseMethod): @@ -113,3 +114,9 @@ class Method(BaseMethod): result = super(Method, self).get_supported_features() result.ipv6 = True return result + + def is_supported(self): + if which("nft"): + return True + debug2("nft method not supported because 'nft' command is missing.\n") + return False diff --git a/sshuttle/methods/pf.py b/sshuttle/methods/pf.py index 0ef0f46..d686904 100644 --- a/sshuttle/methods/pf.py +++ b/sshuttle/methods/pf.py @@ -12,7 +12,7 @@ from ctypes import c_char, c_uint8, c_uint16, c_uint32, Union, Structure, \ sizeof, addressof, memmove from sshuttle.firewall import subnet_weight from sshuttle.helpers import debug1, debug2, debug3, Fatal, family_to_string, \ - get_env + get_env, which from sshuttle.methods import BaseMethod @@ -491,3 +491,9 @@ class Method(BaseMethod): return True else: return False + + def is_supported(self): + if which("pfctl"): + return True + debug2("pf method not supported because 'pfctl' command is missing.\n") + return False diff --git a/sshuttle/methods/tproxy.py b/sshuttle/methods/tproxy.py index e7dba5e..6c3cf61 100644 --- a/sshuttle/methods/tproxy.py +++ b/sshuttle/methods/tproxy.py @@ -3,7 +3,7 @@ from sshuttle.firewall import subnet_weight from sshuttle.helpers import family_to_string from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists from sshuttle.methods import BaseMethod -from sshuttle.helpers import debug1, debug3, Fatal +from sshuttle.helpers import debug1, debug2, debug3, Fatal, which recvmsg = None try: @@ -294,3 +294,10 @@ class Method(BaseMethod): if ipt_chain_exists(family, table, divert_chain): _ipt('-F', divert_chain) _ipt('-X', divert_chain) + + def is_supported(self): + if which("iptables") and which("ip6tables"): + return True + debug2("tproxy method not supported because 'iptables' " + "or 'ip6tables' commands are missing.\n") + return False diff --git a/tests/client/test_firewall.py b/tests/client/test_firewall.py index d19839b..71f1940 100644 --- a/tests/client/test_firewall.py +++ b/tests/client/test_firewall.py @@ -116,6 +116,8 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts): assert mock_setup_daemon.mock_calls == [call()] assert mock_get_method.mock_calls == [ call('not_auto'), + call().is_supported(), + call().is_supported().__bool__(), call().setup_firewall( 1024, 1026, [(AF_INET6, u'2404:6800:4004:80c::33')],