mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-21 23:43:18 +01:00
Refactor automatic method selection.
Add an "is_supported()" function to the different methods so that each method can include whatever logic they wish to indicate if they are supported on a particular machine. Previously, methods/__init__.py contained all of the logic for selecting individual methods. Now, it iterates through a list of possible options and stops on the first method that it finds that is_supported(). Currently, the decision is made based on the presence of programs in the PATH. In the future, things such as the platform sshuttle is running on could be considered.
This commit is contained in:
parent
7c338866bf
commit
6d4261e3f9
@ -114,6 +114,11 @@ def main(method_name, syslog):
|
||||
ssyslog.start_syslog()
|
||||
ssyslog.stderr_to_syslog()
|
||||
|
||||
if not method.is_supported():
|
||||
raise Fatal("The %s method is not supported on this machine. "
|
||||
"Check that the appropriate programs are in your "
|
||||
"PATH." % method_name)
|
||||
|
||||
debug1('ready method name %s.\n' % method.name)
|
||||
stdout.write('READY %s\n' % method.name)
|
||||
stdout.flush()
|
||||
|
@ -3,7 +3,7 @@ import socket
|
||||
import struct
|
||||
import errno
|
||||
import ipaddress
|
||||
from sshuttle.helpers import Fatal, debug3, which
|
||||
from sshuttle.helpers import Fatal, debug3
|
||||
|
||||
|
||||
def original_dst(sock):
|
||||
@ -54,6 +54,12 @@ class BaseMethod(object):
|
||||
result.user = False
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def is_supported():
|
||||
"""Returns true if it appears that this method will work on this
|
||||
machine."""
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_tcp_dstip(sock):
|
||||
return original_dst(sock)
|
||||
@ -102,16 +108,15 @@ def get_method(method_name):
|
||||
|
||||
|
||||
def get_auto_method():
|
||||
if which('iptables'):
|
||||
method_name = "nat"
|
||||
elif which('nft'):
|
||||
method_name = "nft"
|
||||
elif which('pfctl'):
|
||||
method_name = "pf"
|
||||
elif which('ipfw'):
|
||||
method_name = "ipfw"
|
||||
else:
|
||||
raise Fatal(
|
||||
"can't find either iptables, nft or pfctl; check your PATH")
|
||||
debug3("Selecting a method automatically...\n")
|
||||
# Try these methods, in order:
|
||||
methods_to_try = ["nat", "nft", "pf", "ipfw"]
|
||||
for m in methods_to_try:
|
||||
method = get_method(m)
|
||||
if method.is_supported():
|
||||
debug3("Method '%s' was automatically selected.\n" % m)
|
||||
return method
|
||||
|
||||
return get_method(method_name)
|
||||
raise Fatal("Unable to automatically find a supported method. Check that "
|
||||
"the appropriate programs are in your PATH. We tried "
|
||||
"methods: %s" % str(methods_to_try))
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import subprocess as ssubprocess
|
||||
from sshuttle.methods import BaseMethod
|
||||
from sshuttle.helpers import log, debug1, debug3, \
|
||||
Fatal, family_to_string, get_env
|
||||
from sshuttle.helpers import log, debug1, debug2, debug3, \
|
||||
Fatal, family_to_string, get_env, which
|
||||
|
||||
recvmsg = None
|
||||
try:
|
||||
@ -253,3 +253,10 @@ class Method(BaseMethod):
|
||||
ipfw_noexit('table', '124', 'flush')
|
||||
ipfw_noexit('table', '125', 'flush')
|
||||
ipfw_noexit('table', '126', 'flush')
|
||||
|
||||
def is_supported(self):
|
||||
if which("ipfw"):
|
||||
return True
|
||||
debug2("ipfw method not supported because 'ipfw' command is "
|
||||
"missing.\n")
|
||||
return False
|
||||
|
@ -1,6 +1,6 @@
|
||||
import socket
|
||||
from sshuttle.firewall import subnet_weight
|
||||
from sshuttle.helpers import family_to_string
|
||||
from sshuttle.helpers import family_to_string, which, debug2
|
||||
from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists, nonfatal
|
||||
from sshuttle.methods import BaseMethod
|
||||
|
||||
@ -124,3 +124,10 @@ class Method(BaseMethod):
|
||||
result = super(Method, self).get_supported_features()
|
||||
result.user = True
|
||||
return result
|
||||
|
||||
def is_supported(self):
|
||||
if which("iptables"):
|
||||
return True
|
||||
debug2("nat method not supported because 'iptables' command "
|
||||
"is missing.\n")
|
||||
return False
|
||||
|
@ -2,6 +2,7 @@ import socket
|
||||
from sshuttle.firewall import subnet_weight
|
||||
from sshuttle.linux import nft, nonfatal
|
||||
from sshuttle.methods import BaseMethod
|
||||
from sshuttle.helpers import debug2, which
|
||||
|
||||
|
||||
class Method(BaseMethod):
|
||||
@ -113,3 +114,9 @@ class Method(BaseMethod):
|
||||
result = super(Method, self).get_supported_features()
|
||||
result.ipv6 = True
|
||||
return result
|
||||
|
||||
def is_supported(self):
|
||||
if which("nft"):
|
||||
return True
|
||||
debug2("nft method not supported because 'nft' command is missing.\n")
|
||||
return False
|
||||
|
@ -12,7 +12,7 @@ from ctypes import c_char, c_uint8, c_uint16, c_uint32, Union, Structure, \
|
||||
sizeof, addressof, memmove
|
||||
from sshuttle.firewall import subnet_weight
|
||||
from sshuttle.helpers import debug1, debug2, debug3, Fatal, family_to_string, \
|
||||
get_env
|
||||
get_env, which
|
||||
from sshuttle.methods import BaseMethod
|
||||
|
||||
|
||||
@ -491,3 +491,9 @@ class Method(BaseMethod):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_supported(self):
|
||||
if which("pfctl"):
|
||||
return True
|
||||
debug2("pf method not supported because 'pfctl' command is missing.\n")
|
||||
return False
|
||||
|
@ -3,7 +3,7 @@ from sshuttle.firewall import subnet_weight
|
||||
from sshuttle.helpers import family_to_string
|
||||
from sshuttle.linux import ipt, ipt_ttl, ipt_chain_exists
|
||||
from sshuttle.methods import BaseMethod
|
||||
from sshuttle.helpers import debug1, debug3, Fatal
|
||||
from sshuttle.helpers import debug1, debug2, debug3, Fatal, which
|
||||
|
||||
recvmsg = None
|
||||
try:
|
||||
@ -294,3 +294,10 @@ class Method(BaseMethod):
|
||||
if ipt_chain_exists(family, table, divert_chain):
|
||||
_ipt('-F', divert_chain)
|
||||
_ipt('-X', divert_chain)
|
||||
|
||||
def is_supported(self):
|
||||
if which("iptables") and which("ip6tables"):
|
||||
return True
|
||||
debug2("tproxy method not supported because 'iptables' "
|
||||
"or 'ip6tables' commands are missing.\n")
|
||||
return False
|
||||
|
@ -116,6 +116,8 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
|
||||
assert mock_setup_daemon.mock_calls == [call()]
|
||||
assert mock_get_method.mock_calls == [
|
||||
call('not_auto'),
|
||||
call().is_supported(),
|
||||
call().is_supported().__bool__(),
|
||||
call().setup_firewall(
|
||||
1024, 1026,
|
||||
[(AF_INET6, u'2404:6800:4004:80c::33')],
|
||||
|
Loading…
Reference in New Issue
Block a user