mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-02-22 21:31:19 +01:00
Simplify selection of features
This commit is contained in:
parent
6b4e36c528
commit
90654b4fb9
@ -14,7 +14,7 @@ import platform
|
||||
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
|
||||
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
|
||||
resolvconf_nameservers
|
||||
from sshuttle.methods import get_method
|
||||
from sshuttle.methods import get_method, Features
|
||||
|
||||
_extra_fd = os.open('/dev/null', os.O_RDONLY)
|
||||
|
||||
@ -505,19 +505,44 @@ def main(listenip_v6, listenip_v4,
|
||||
|
||||
fw = FirewallClient(method_name)
|
||||
|
||||
features = fw.method.get_supported_features()
|
||||
# Get family specific subnet lists
|
||||
if dns:
|
||||
nslist += resolvconf_nameservers()
|
||||
|
||||
subnets = subnets_include + subnets_exclude # we don't care here
|
||||
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
||||
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||
|
||||
# Check features available
|
||||
avail = fw.method.get_supported_features()
|
||||
required = Features()
|
||||
|
||||
if listenip_v6 == "auto":
|
||||
if features.ipv6:
|
||||
if avail.ipv6:
|
||||
listenip_v6 = ('::1', 0)
|
||||
else:
|
||||
listenip_v6 = None
|
||||
|
||||
required.ipv6 = len(subnets_v6) > 0 or len(nslist_v6) > 0
|
||||
required.udp = avail.udp
|
||||
required.dns = len(nslist) > 0
|
||||
|
||||
fw.method.assert_features(required)
|
||||
|
||||
if required.ipv6 and listenip_v6 is None:
|
||||
raise Fatal("IPv6 required but not listening.")
|
||||
|
||||
# display features enabled
|
||||
debug1("IPv6 enabled: %r\n" % required.ipv6)
|
||||
debug1("UDP enabled: %r\n" % required.udp)
|
||||
debug1("DNS enabled: %r\n" % required.dns)
|
||||
|
||||
# bind to required ports
|
||||
if listenip_v4 == "auto":
|
||||
listenip_v4 = ('127.0.0.1', 0)
|
||||
|
||||
udp = features.udp
|
||||
debug1("UDP enabled: %r\n" % udp)
|
||||
|
||||
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
|
||||
# if both ports given, no need to search for a spare port
|
||||
ports = [0, ]
|
||||
@ -536,7 +561,7 @@ def main(listenip_v6, listenip_v4,
|
||||
tcp_listener = MultiListener()
|
||||
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
if udp:
|
||||
if required.udp:
|
||||
udp_listener = MultiListener(socket.SOCK_DGRAM)
|
||||
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
else:
|
||||
@ -584,10 +609,7 @@ def main(listenip_v6, listenip_v4,
|
||||
udp_listener.print_listening("UDP redirector")
|
||||
|
||||
bound = False
|
||||
if dns or nslist:
|
||||
if dns:
|
||||
nslist += resolvconf_nameservers()
|
||||
dns = True
|
||||
if required.dns:
|
||||
# search for spare port for DNS
|
||||
debug2('Binding DNS:')
|
||||
ports = range(12300, 9000, -1)
|
||||
@ -628,17 +650,41 @@ def main(listenip_v6, listenip_v4,
|
||||
dnsport_v4 = 0
|
||||
dns_listener = None
|
||||
|
||||
fw.method.check_settings(udp, dns)
|
||||
# Last minute sanity checks.
|
||||
# These should never fail.
|
||||
# If these do fail, something is broken above.
|
||||
if len(subnets_v6) > 0:
|
||||
assert required.ipv6
|
||||
if redirectport_v6 == 0:
|
||||
raise Fatal("IPv6 subnets defined but not listening")
|
||||
|
||||
if len(nslist_v6) > 0:
|
||||
assert required.dns
|
||||
assert required.ipv6
|
||||
if dnsport_v6 == 0:
|
||||
raise Fatal("IPv6 ns servers defined but not listening")
|
||||
|
||||
if len(subnets_v4) > 0:
|
||||
if redirectport_v4 == 0:
|
||||
raise Fatal("IPv4 subnets defined but not listening")
|
||||
|
||||
if len(nslist_v4) > 0:
|
||||
if dnsport_v4 == 0:
|
||||
raise Fatal("IPv4 ns servers defined but not listening")
|
||||
|
||||
# setup method specific stuff on listeners
|
||||
fw.method.setup_tcp_listener(tcp_listener)
|
||||
if udp_listener:
|
||||
fw.method.setup_udp_listener(udp_listener)
|
||||
if dns_listener:
|
||||
fw.method.setup_udp_listener(dns_listener)
|
||||
|
||||
# start the firewall
|
||||
fw.setup(subnets_include, subnets_exclude, nslist,
|
||||
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4,
|
||||
udp)
|
||||
required.udp)
|
||||
|
||||
# start the client process
|
||||
try:
|
||||
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
|
||||
python, latency_control, dns_listener,
|
||||
|
@ -178,26 +178,23 @@ def main(method_name, syslog):
|
||||
try:
|
||||
debug1('firewall manager: setting up.\n')
|
||||
|
||||
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||
if port_v6 > 0:
|
||||
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||
|
||||
if len(subnets_v6) > 0 or len(nslist_v6) > 0:
|
||||
debug2('firewall manager: setting up IPv6.\n')
|
||||
method.setup_firewall(
|
||||
port_v6, dnsport_v6, nslist_v6,
|
||||
socket.AF_INET6, subnets_v6, udp)
|
||||
elif len(subnets_v6) > 0:
|
||||
debug1("IPv6 subnets defined but IPv6 disabled\n")
|
||||
|
||||
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
||||
if port_v4 > 0:
|
||||
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||
|
||||
if len(subnets_v4) > 0 or len(nslist_v4) > 0:
|
||||
debug2('firewall manager: setting up IPv4.\n')
|
||||
method.setup_firewall(
|
||||
port_v4, dnsport_v4, nslist_v4,
|
||||
socket.AF_INET, subnets_v4, udp)
|
||||
elif len(subnets_v4) > 0:
|
||||
debug1('firewall manager: '
|
||||
'IPv4 subnets defined but IPv4 disabled\n')
|
||||
|
||||
stdout.write('STARTED\n')
|
||||
|
||||
|
@ -62,9 +62,13 @@ class BaseMethod(object):
|
||||
def setup_udp_listener(self, udp_listener):
|
||||
pass
|
||||
|
||||
def check_settings(self, udp, dns):
|
||||
if udp:
|
||||
Fatal("UDP support not supported with method %s.\n" % self.name)
|
||||
def assert_features(self, features):
|
||||
avail = self.get_supported_features()
|
||||
for key in ["udp", "dns", "ipv6"]:
|
||||
if getattr(features, key) and not getattr(avail, key):
|
||||
raise Fatal(
|
||||
"Feature %s not supported with method %s.\n" %
|
||||
(key, self.name))
|
||||
|
||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
|
||||
raise NotImplementedError()
|
||||
|
@ -55,13 +55,12 @@ class Method(BaseMethod):
|
||||
'-p', 'tcp',
|
||||
'--to-ports', str(port))
|
||||
|
||||
if dnsport:
|
||||
for f, ip in [i for i in nslist if i[0] == family]:
|
||||
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
|
||||
'--dest', '%s/32' % ip,
|
||||
'-p', 'udp',
|
||||
'--dport', '53',
|
||||
'--to-ports', str(dnsport))
|
||||
for f, ip in [i for i in nslist if i[0] == family]:
|
||||
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
|
||||
'--dest', '%s/32' % ip,
|
||||
'-p', 'udp',
|
||||
'--dport', '53',
|
||||
'--to-ports', str(dnsport))
|
||||
|
||||
def restore_firewall(self, port, family, udp):
|
||||
# only ipv4 supported with NAT
|
||||
|
@ -181,27 +181,28 @@ class Method(BaseMethod):
|
||||
if udp:
|
||||
raise Exception("UDP not supported by pf method_name")
|
||||
|
||||
includes = []
|
||||
# If a given subnet is both included and excluded, list the
|
||||
# exclusion first; the table will ignore the second, opposite
|
||||
# definition
|
||||
for f, swidth, sexclude, snet in sorted(
|
||||
subnets, key=lambda s: (s[1], s[2]), reverse=True):
|
||||
includes.append(b"%s%s/%d" %
|
||||
(b"!" if sexclude else b"",
|
||||
snet.encode("ASCII"),
|
||||
swidth))
|
||||
if len(subnets) > 0:
|
||||
includes = []
|
||||
# If a given subnet is both included and excluded, list the
|
||||
# exclusion first; the table will ignore the second, opposite
|
||||
# definition
|
||||
for f, swidth, sexclude, snet in sorted(
|
||||
subnets, key=lambda s: (s[1], s[2]), reverse=True):
|
||||
includes.append(b"%s%s/%d" %
|
||||
(b"!" if sexclude else b"",
|
||||
snet.encode("ASCII"),
|
||||
swidth))
|
||||
|
||||
tables.append(
|
||||
b'table <forward_subnets> {%s}' % b','.join(includes))
|
||||
translating_rules.append(
|
||||
b'rdr pass on lo0 proto tcp '
|
||||
b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
|
||||
filtering_rules.append(
|
||||
b'pass out route-to lo0 inet proto tcp '
|
||||
b'to <forward_subnets> keep state')
|
||||
tables.append(
|
||||
b'table <forward_subnets> {%s}' % b','.join(includes))
|
||||
translating_rules.append(
|
||||
b'rdr pass on lo0 proto tcp '
|
||||
b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
|
||||
filtering_rules.append(
|
||||
b'pass out route-to lo0 inet proto tcp '
|
||||
b'to <forward_subnets> keep state')
|
||||
|
||||
if dnsport:
|
||||
if len(nslist) > 0:
|
||||
tables.append(
|
||||
b'table <dns_servers> {%s}' %
|
||||
b','.join([ns[1].encode("ASCII") for ns in nslist]))
|
||||
|
@ -59,6 +59,7 @@ if recvmsg == "python":
|
||||
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
|
||||
dstip = (ip, port)
|
||||
break
|
||||
print("xxxxx", srcip, dstip)
|
||||
return (srcip, dstip, data)
|
||||
elif recvmsg == "socket_ext":
|
||||
def recv_udp(listener, bufsize):
|
||||
@ -187,16 +188,15 @@ class Method(BaseMethod):
|
||||
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
|
||||
'-m', 'udp', '-p', 'udp')
|
||||
|
||||
if dnsport:
|
||||
for f, ip in [i for i in nslist if i[0] == family]:
|
||||
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
|
||||
'--dest', '%s/32' % ip,
|
||||
'-m', 'udp', '-p', 'udp', '--dport', '53')
|
||||
_ipt('-A', tproxy_chain, '-j', 'TPROXY',
|
||||
'--tproxy-mark', '0x1/0x1',
|
||||
'--dest', '%s/32' % ip,
|
||||
'-m', 'udp', '-p', 'udp', '--dport', '53',
|
||||
'--on-port', str(dnsport))
|
||||
for f, ip in [i for i in nslist if i[0] == family]:
|
||||
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
|
||||
'--dest', '%s/32' % ip,
|
||||
'-m', 'udp', '-p', 'udp', '--dport', '53')
|
||||
_ipt('-A', tproxy_chain, '-j', 'TPROXY',
|
||||
'--tproxy-mark', '0x1/0x1',
|
||||
'--dest', '%s/32' % ip,
|
||||
'-m', 'udp', '-p', 'udp', '--dport', '53',
|
||||
'--on-port', str(dnsport))
|
||||
|
||||
for f, swidth, sexclude, snet \
|
||||
in sorted(subnets, key=lambda s: s[1], reverse=True):
|
||||
@ -267,16 +267,3 @@ class Method(BaseMethod):
|
||||
if ipt_chain_exists(family, table, divert_chain):
|
||||
_ipt('-F', divert_chain)
|
||||
_ipt('-X', divert_chain)
|
||||
|
||||
def check_settings(self, udp, dns):
|
||||
if udp and recvmsg is None:
|
||||
raise Fatal("tproxy UDP support requires recvmsg function.\n")
|
||||
|
||||
if dns and recvmsg is None:
|
||||
raise Fatal("tproxy DNS support requires recvmsg function.\n")
|
||||
|
||||
if udp:
|
||||
debug1("tproxy UDP support enabled.\n")
|
||||
|
||||
if dns:
|
||||
debug1("tproxy DNS support enabled.\n")
|
||||
|
@ -3,6 +3,7 @@ from mock import Mock, patch, call
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from sshuttle.helpers import Fatal
|
||||
from sshuttle.methods import get_method
|
||||
|
||||
|
||||
@ -11,6 +12,7 @@ def test_get_supported_features():
|
||||
features = method.get_supported_features()
|
||||
assert not features.ipv6
|
||||
assert not features.udp
|
||||
assert features.dns
|
||||
|
||||
|
||||
def test_get_tcp_dstip():
|
||||
@ -52,10 +54,18 @@ def test_setup_udp_listener():
|
||||
assert listener.mock_calls == []
|
||||
|
||||
|
||||
def test_check_settings():
|
||||
def test_assert_features():
|
||||
method = get_method('nat')
|
||||
method.check_settings(True, True)
|
||||
method.check_settings(False, True)
|
||||
features = method.get_supported_features()
|
||||
method.assert_features(features)
|
||||
|
||||
features.udp = True
|
||||
with pytest.raises(Fatal):
|
||||
method.assert_features(features)
|
||||
|
||||
features.ipv6 = True
|
||||
with pytest.raises(Fatal):
|
||||
method.assert_features(features)
|
||||
|
||||
|
||||
def test_firewall_command():
|
||||
|
@ -3,6 +3,7 @@ from mock import Mock, patch, call, ANY
|
||||
import socket
|
||||
|
||||
from sshuttle.methods import get_method
|
||||
from sshuttle.helpers import Fatal
|
||||
|
||||
|
||||
def test_get_supported_features():
|
||||
@ -10,6 +11,7 @@ def test_get_supported_features():
|
||||
features = method.get_supported_features()
|
||||
assert not features.ipv6
|
||||
assert not features.udp
|
||||
assert features.dns
|
||||
|
||||
|
||||
@patch('sshuttle.helpers.verbose', new=3)
|
||||
@ -68,10 +70,18 @@ def test_setup_udp_listener():
|
||||
assert listener.mock_calls == []
|
||||
|
||||
|
||||
def test_check_settings():
|
||||
def test_assert_features():
|
||||
method = get_method('pf')
|
||||
method.check_settings(True, True)
|
||||
method.check_settings(False, True)
|
||||
features = method.get_supported_features()
|
||||
method.assert_features(features)
|
||||
|
||||
features.udp = True
|
||||
with pytest.raises(Fatal):
|
||||
method.assert_features(features)
|
||||
|
||||
features.ipv6 = True
|
||||
with pytest.raises(Fatal):
|
||||
method.assert_features(features)
|
||||
|
||||
|
||||
@patch('sshuttle.methods.pf.sys.stdout')
|
||||
|
@ -3,11 +3,22 @@ from mock import Mock, patch, call
|
||||
from sshuttle.methods import get_method
|
||||
|
||||
|
||||
def test_get_supported_features():
|
||||
@patch("sshuttle.methods.tproxy.recvmsg")
|
||||
def test_get_supported_features_recvmsg(mock_recvmsg):
|
||||
method = get_method('tproxy')
|
||||
features = method.get_supported_features()
|
||||
assert features.ipv6
|
||||
assert features.udp
|
||||
assert features.dns
|
||||
|
||||
|
||||
@patch("sshuttle.methods.tproxy.recvmsg", None)
|
||||
def test_get_supported_features_norecvmsg():
|
||||
method = get_method('tproxy')
|
||||
features = method.get_supported_features()
|
||||
assert features.ipv6
|
||||
assert not features.udp
|
||||
assert not features.dns
|
||||
|
||||
|
||||
def test_get_tcp_dstip():
|
||||
@ -66,10 +77,10 @@ def test_setup_udp_listener():
|
||||
]
|
||||
|
||||
|
||||
def test_check_settings():
|
||||
def test_assert_features():
|
||||
method = get_method('tproxy')
|
||||
method.check_settings(True, True)
|
||||
method.check_settings(False, True)
|
||||
features = method.get_supported_features()
|
||||
method.assert_features(features)
|
||||
|
||||
|
||||
def test_firewall_command():
|
||||
|
Loading…
Reference in New Issue
Block a user