Refactor automatic method selection.

Add an "is_supported()" function to the different methods so that each
method can include whatever logic they wish to indicate if they are
supported on a particular machine. Previously, methods/__init__.py
contained all of the logic for selecting individual methods. Now, it
iterates through a list of possible options and stops on the first
method that it finds that is_supported().

Currently, the decision is made based on the presence of programs in
the PATH. In the future, things such as the platform sshuttle is
running on could be considered.
This commit is contained in:
Scott Kuhl 2020-12-26 15:25:33 -05:00 committed by Brian May
parent 7c338866bf
commit 6d4261e3f9
8 changed files with 64 additions and 18 deletions

View File

@ -114,6 +114,11 @@ def main(method_name, syslog):
ssyslog.start_syslog() ssyslog.start_syslog()
ssyslog.stderr_to_syslog() ssyslog.stderr_to_syslog()
if not method.is_supported():
raise Fatal("The %s method is not supported on this machine. "
"Check that the appropriate programs are in your "
"PATH." % method_name)
debug1('ready method name %s.\n' % method.name) debug1('ready method name %s.\n' % method.name)
stdout.write('READY %s\n' % method.name) stdout.write('READY %s\n' % method.name)
stdout.flush() stdout.flush()

View File

@ -3,7 +3,7 @@ import socket
import struct import struct
import errno import errno
import ipaddress import ipaddress
from sshuttle.helpers import Fatal, debug3, which from sshuttle.helpers import Fatal, debug3
def original_dst(sock): def original_dst(sock):
@ -54,6 +54,12 @@ class BaseMethod(object):
result.user = False result.user = False
return result return result
@staticmethod
def is_supported():
"""Returns true if it appears that this method will work on this
machine."""
return False
@staticmethod @staticmethod
def get_tcp_dstip(sock): def get_tcp_dstip(sock):
return original_dst(sock) return original_dst(sock)
@ -102,16 +108,15 @@ def get_method(method_name):
def get_auto_method(): def get_auto_method():
if which('iptables'): debug3("Selecting a method automatically...\n")
method_name = "nat" # Try these methods, in order:
elif which('nft'): methods_to_try = ["nat", "nft", "pf", "ipfw"]
method_name = "nft" for m in methods_to_try:
elif which('pfctl'): method = get_method(m)
method_name = "pf" if method.is_supported():
elif which('ipfw'): debug3("Method '%s' was automatically selected.\n" % m)
method_name = "ipfw" return method
else:
raise Fatal(
"can't find either iptables, nft or pfctl; check your PATH")
return get_method(method_name) raise Fatal("Unable to automatically find a supported method. Check that "
"the appropriate programs are in your PATH. We tried "
"methods: %s" % str(methods_to_try))

View File

@ -1,8 +1,8 @@
import os import os
import subprocess as ssubprocess import subprocess as ssubprocess
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import log, debug1, debug3, \ from sshuttle.helpers import log, debug1, debug2, debug3, \
Fatal, family_to_string, get_env Fatal, family_to_string, get_env, which
recvmsg = None recvmsg = None
try: try:
@ -253,3 +253,10 @@ class Method(BaseMethod):
ipfw_noexit('table', '124', 'flush') ipfw_noexit('table', '124', 'flush')
ipfw_noexit('table', '125', 'flush') ipfw_noexit('table', '125', 'flush')
ipfw_noexit('table', '126', 'flush') ipfw_noexit('table', '126', 'flush')
def is_supported(self):
if which("ipfw"):
return True
debug2("ipfw method not supported because 'ipfw' command is "
"missing.\n")
return False

View File

@ -1,6 +1,6 @@
import socket import socket
from sshuttle.firewall import subnet_weight from sshuttle.firewall import subnet_weight
from sshuttle.helpers import family_to_string from sshuttle.helpers import family_to_string, which, debug2
from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists, nonfatal from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists, nonfatal
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
@ -124,3 +124,10 @@ class Method(BaseMethod):
result = super(Method, self).get_supported_features() result = super(Method, self).get_supported_features()
result.user = True result.user = True
return result return result
def is_supported(self):
if which("iptables"):
return True
debug2("nat method not supported because 'iptables' command "
"is missing.\n")
return False

View File

@ -2,6 +2,7 @@ import socket
from sshuttle.firewall import subnet_weight from sshuttle.firewall import subnet_weight
from sshuttle.linux import nft, nonfatal from sshuttle.linux import nft, nonfatal
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug2, which
class Method(BaseMethod): class Method(BaseMethod):
@ -113,3 +114,9 @@ class Method(BaseMethod):
result = super(Method, self).get_supported_features() result = super(Method, self).get_supported_features()
result.ipv6 = True result.ipv6 = True
return result return result
def is_supported(self):
if which("nft"):
return True
debug2("nft method not supported because 'nft' command is missing.\n")
return False

View File

@ -12,7 +12,7 @@ from ctypes import c_char, c_uint8, c_uint16, c_uint32, Union, Structure, \
sizeof, addressof, memmove sizeof, addressof, memmove
from sshuttle.firewall import subnet_weight from sshuttle.firewall import subnet_weight
from sshuttle.helpers import debug1, debug2, debug3, Fatal, family_to_string, \ from sshuttle.helpers import debug1, debug2, debug3, Fatal, family_to_string, \
get_env get_env, which
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
@ -491,3 +491,9 @@ class Method(BaseMethod):
return True return True
else: else:
return False return False
def is_supported(self):
if which("pfctl"):
return True
debug2("pf method not supported because 'pfctl' command is missing.\n")
return False

View File

@ -3,7 +3,7 @@ from sshuttle.firewall import subnet_weight
from sshuttle.helpers import family_to_string from sshuttle.helpers import family_to_string
from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug1, debug3, Fatal from sshuttle.helpers import debug1, debug2, debug3, Fatal, which
recvmsg = None recvmsg = None
try: try:
@ -294,3 +294,10 @@ class Method(BaseMethod):
if ipt_chain_exists(family, table, divert_chain): if ipt_chain_exists(family, table, divert_chain):
_ipt('-F', divert_chain) _ipt('-F', divert_chain)
_ipt('-X', divert_chain) _ipt('-X', divert_chain)
def is_supported(self):
if which("iptables") and which("ip6tables"):
return True
debug2("tproxy method not supported because 'iptables' "
"or 'ip6tables' commands are missing.\n")
return False

View File

@ -116,6 +116,8 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
assert mock_setup_daemon.mock_calls == [call()] assert mock_setup_daemon.mock_calls == [call()]
assert mock_get_method.mock_calls == [ assert mock_get_method.mock_calls == [
call('not_auto'), call('not_auto'),
call().is_supported(),
call().is_supported().__bool__(),
call().setup_firewall( call().setup_firewall(
1024, 1026, 1024, 1026,
[(AF_INET6, u'2404:6800:4004:80c::33')], [(AF_INET6, u'2404:6800:4004:80c::33')],