Simplify selection of features

This commit is contained in:
Brian May 2015-12-15 11:40:55 +11:00
parent 6b4e36c528
commit 90654b4fb9
9 changed files with 149 additions and 84 deletions

View File

@ -14,7 +14,7 @@ import platform
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \ from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
resolvconf_nameservers resolvconf_nameservers
from sshuttle.methods import get_method from sshuttle.methods import get_method, Features
_extra_fd = os.open('/dev/null', os.O_RDONLY) _extra_fd = os.open('/dev/null', os.O_RDONLY)
@ -505,19 +505,44 @@ def main(listenip_v6, listenip_v4,
fw = FirewallClient(method_name) fw = FirewallClient(method_name)
features = fw.method.get_supported_features() # Get family specific subnet lists
if dns:
nslist += resolvconf_nameservers()
subnets = subnets_include + subnets_exclude # we don't care here
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
# Check features available
avail = fw.method.get_supported_features()
required = Features()
if listenip_v6 == "auto": if listenip_v6 == "auto":
if features.ipv6: if avail.ipv6:
listenip_v6 = ('::1', 0) listenip_v6 = ('::1', 0)
else: else:
listenip_v6 = None listenip_v6 = None
required.ipv6 = len(subnets_v6) > 0 or len(nslist_v6) > 0
required.udp = avail.udp
required.dns = len(nslist) > 0
fw.method.assert_features(required)
if required.ipv6 and listenip_v6 is None:
raise Fatal("IPv6 required but not listening.")
# display features enabled
debug1("IPv6 enabled: %r\n" % required.ipv6)
debug1("UDP enabled: %r\n" % required.udp)
debug1("DNS enabled: %r\n" % required.dns)
# bind to required ports
if listenip_v4 == "auto": if listenip_v4 == "auto":
listenip_v4 = ('127.0.0.1', 0) listenip_v4 = ('127.0.0.1', 0)
udp = features.udp
debug1("UDP enabled: %r\n" % udp)
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]: if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
# if both ports given, no need to search for a spare port # if both ports given, no need to search for a spare port
ports = [0, ] ports = [0, ]
@ -536,7 +561,7 @@ def main(listenip_v6, listenip_v4,
tcp_listener = MultiListener() tcp_listener = MultiListener()
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if udp: if required.udp:
udp_listener = MultiListener(socket.SOCK_DGRAM) udp_listener = MultiListener(socket.SOCK_DGRAM)
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
else: else:
@ -584,10 +609,7 @@ def main(listenip_v6, listenip_v4,
udp_listener.print_listening("UDP redirector") udp_listener.print_listening("UDP redirector")
bound = False bound = False
if dns or nslist: if required.dns:
if dns:
nslist += resolvconf_nameservers()
dns = True
# search for spare port for DNS # search for spare port for DNS
debug2('Binding DNS:') debug2('Binding DNS:')
ports = range(12300, 9000, -1) ports = range(12300, 9000, -1)
@ -628,17 +650,41 @@ def main(listenip_v6, listenip_v4,
dnsport_v4 = 0 dnsport_v4 = 0
dns_listener = None dns_listener = None
fw.method.check_settings(udp, dns) # Last minute sanity checks.
# These should never fail.
# If these do fail, something is broken above.
if len(subnets_v6) > 0:
assert required.ipv6
if redirectport_v6 == 0:
raise Fatal("IPv6 subnets defined but not listening")
if len(nslist_v6) > 0:
assert required.dns
assert required.ipv6
if dnsport_v6 == 0:
raise Fatal("IPv6 ns servers defined but not listening")
if len(subnets_v4) > 0:
if redirectport_v4 == 0:
raise Fatal("IPv4 subnets defined but not listening")
if len(nslist_v4) > 0:
if dnsport_v4 == 0:
raise Fatal("IPv4 ns servers defined but not listening")
# setup method specific stuff on listeners
fw.method.setup_tcp_listener(tcp_listener) fw.method.setup_tcp_listener(tcp_listener)
if udp_listener: if udp_listener:
fw.method.setup_udp_listener(udp_listener) fw.method.setup_udp_listener(udp_listener)
if dns_listener: if dns_listener:
fw.method.setup_udp_listener(dns_listener) fw.method.setup_udp_listener(dns_listener)
# start the firewall
fw.setup(subnets_include, subnets_exclude, nslist, fw.setup(subnets_include, subnets_exclude, nslist,
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4, redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4,
udp) required.udp)
# start the client process
try: try:
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
python, latency_control, dns_listener, python, latency_control, dns_listener,

View File

@ -178,26 +178,23 @@ def main(method_name, syslog):
try: try:
debug1('firewall manager: setting up.\n') debug1('firewall manager: setting up.\n')
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6] subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
if port_v6 > 0: nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
if len(subnets_v6) > 0 or len(nslist_v6) > 0:
debug2('firewall manager: setting up IPv6.\n') debug2('firewall manager: setting up IPv6.\n')
method.setup_firewall( method.setup_firewall(
port_v6, dnsport_v6, nslist_v6, port_v6, dnsport_v6, nslist_v6,
socket.AF_INET6, subnets_v6, udp) socket.AF_INET6, subnets_v6, udp)
elif len(subnets_v6) > 0:
debug1("IPv6 subnets defined but IPv6 disabled\n")
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET] subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
if port_v4 > 0: nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
if len(subnets_v4) > 0 or len(nslist_v4) > 0:
debug2('firewall manager: setting up IPv4.\n') debug2('firewall manager: setting up IPv4.\n')
method.setup_firewall( method.setup_firewall(
port_v4, dnsport_v4, nslist_v4, port_v4, dnsport_v4, nslist_v4,
socket.AF_INET, subnets_v4, udp) socket.AF_INET, subnets_v4, udp)
elif len(subnets_v4) > 0:
debug1('firewall manager: '
'IPv4 subnets defined but IPv4 disabled\n')
stdout.write('STARTED\n') stdout.write('STARTED\n')

View File

@ -62,9 +62,13 @@ class BaseMethod(object):
def setup_udp_listener(self, udp_listener): def setup_udp_listener(self, udp_listener):
pass pass
def check_settings(self, udp, dns): def assert_features(self, features):
if udp: avail = self.get_supported_features()
Fatal("UDP support not supported with method %s.\n" % self.name) for key in ["udp", "dns", "ipv6"]:
if getattr(features, key) and not getattr(avail, key):
raise Fatal(
"Feature %s not supported with method %s.\n" %
(key, self.name))
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp): def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
raise NotImplementedError() raise NotImplementedError()

View File

@ -55,7 +55,6 @@ class Method(BaseMethod):
'-p', 'tcp', '-p', 'tcp',
'--to-ports', str(port)) '--to-ports', str(port))
if dnsport:
for f, ip in [i for i in nslist if i[0] == family]: for f, ip in [i for i in nslist if i[0] == family]:
_ipt_ttl('-A', chain, '-j', 'REDIRECT', _ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/32' % ip, '--dest', '%s/32' % ip,

View File

@ -181,6 +181,7 @@ class Method(BaseMethod):
if udp: if udp:
raise Exception("UDP not supported by pf method_name") raise Exception("UDP not supported by pf method_name")
if len(subnets) > 0:
includes = [] includes = []
# If a given subnet is both included and excluded, list the # If a given subnet is both included and excluded, list the
# exclusion first; the table will ignore the second, opposite # exclusion first; the table will ignore the second, opposite
@ -201,7 +202,7 @@ class Method(BaseMethod):
b'pass out route-to lo0 inet proto tcp ' b'pass out route-to lo0 inet proto tcp '
b'to <forward_subnets> keep state') b'to <forward_subnets> keep state')
if dnsport: if len(nslist) > 0:
tables.append( tables.append(
b'table <dns_servers> {%s}' % b'table <dns_servers> {%s}' %
b','.join([ns[1].encode("ASCII") for ns in nslist])) b','.join([ns[1].encode("ASCII") for ns in nslist]))

View File

@ -59,6 +59,7 @@ if recvmsg == "python":
ip = socket.inet_ntop(family, cmsg_data[start:start + length]) ip = socket.inet_ntop(family, cmsg_data[start:start + length])
dstip = (ip, port) dstip = (ip, port)
break break
print("xxxxx", srcip, dstip)
return (srcip, dstip, data) return (srcip, dstip, data)
elif recvmsg == "socket_ext": elif recvmsg == "socket_ext":
def recv_udp(listener, bufsize): def recv_udp(listener, bufsize):
@ -187,7 +188,6 @@ class Method(BaseMethod):
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain, _ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'udp', '-p', 'udp') '-m', 'udp', '-p', 'udp')
if dnsport:
for f, ip in [i for i in nslist if i[0] == family]: for f, ip in [i for i in nslist if i[0] == family]:
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1', _ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/32' % ip, '--dest', '%s/32' % ip,
@ -267,16 +267,3 @@ class Method(BaseMethod):
if ipt_chain_exists(family, table, divert_chain): if ipt_chain_exists(family, table, divert_chain):
_ipt('-F', divert_chain) _ipt('-F', divert_chain)
_ipt('-X', divert_chain) _ipt('-X', divert_chain)
def check_settings(self, udp, dns):
if udp and recvmsg is None:
raise Fatal("tproxy UDP support requires recvmsg function.\n")
if dns and recvmsg is None:
raise Fatal("tproxy DNS support requires recvmsg function.\n")
if udp:
debug1("tproxy UDP support enabled.\n")
if dns:
debug1("tproxy DNS support enabled.\n")

View File

@ -3,6 +3,7 @@ from mock import Mock, patch, call
import socket import socket
import struct import struct
from sshuttle.helpers import Fatal
from sshuttle.methods import get_method from sshuttle.methods import get_method
@ -11,6 +12,7 @@ def test_get_supported_features():
features = method.get_supported_features() features = method.get_supported_features()
assert not features.ipv6 assert not features.ipv6
assert not features.udp assert not features.udp
assert features.dns
def test_get_tcp_dstip(): def test_get_tcp_dstip():
@ -52,10 +54,18 @@ def test_setup_udp_listener():
assert listener.mock_calls == [] assert listener.mock_calls == []
def test_check_settings(): def test_assert_features():
method = get_method('nat') method = get_method('nat')
method.check_settings(True, True) features = method.get_supported_features()
method.check_settings(False, True) method.assert_features(features)
features.udp = True
with pytest.raises(Fatal):
method.assert_features(features)
features.ipv6 = True
with pytest.raises(Fatal):
method.assert_features(features)
def test_firewall_command(): def test_firewall_command():

View File

@ -3,6 +3,7 @@ from mock import Mock, patch, call, ANY
import socket import socket
from sshuttle.methods import get_method from sshuttle.methods import get_method
from sshuttle.helpers import Fatal
def test_get_supported_features(): def test_get_supported_features():
@ -10,6 +11,7 @@ def test_get_supported_features():
features = method.get_supported_features() features = method.get_supported_features()
assert not features.ipv6 assert not features.ipv6
assert not features.udp assert not features.udp
assert features.dns
@patch('sshuttle.helpers.verbose', new=3) @patch('sshuttle.helpers.verbose', new=3)
@ -68,10 +70,18 @@ def test_setup_udp_listener():
assert listener.mock_calls == [] assert listener.mock_calls == []
def test_check_settings(): def test_assert_features():
method = get_method('pf') method = get_method('pf')
method.check_settings(True, True) features = method.get_supported_features()
method.check_settings(False, True) method.assert_features(features)
features.udp = True
with pytest.raises(Fatal):
method.assert_features(features)
features.ipv6 = True
with pytest.raises(Fatal):
method.assert_features(features)
@patch('sshuttle.methods.pf.sys.stdout') @patch('sshuttle.methods.pf.sys.stdout')

View File

@ -3,11 +3,22 @@ from mock import Mock, patch, call
from sshuttle.methods import get_method from sshuttle.methods import get_method
def test_get_supported_features(): @patch("sshuttle.methods.tproxy.recvmsg")
def test_get_supported_features_recvmsg(mock_recvmsg):
method = get_method('tproxy') method = get_method('tproxy')
features = method.get_supported_features() features = method.get_supported_features()
assert features.ipv6 assert features.ipv6
assert features.udp assert features.udp
assert features.dns
@patch("sshuttle.methods.tproxy.recvmsg", None)
def test_get_supported_features_norecvmsg():
method = get_method('tproxy')
features = method.get_supported_features()
assert features.ipv6
assert not features.udp
assert not features.dns
def test_get_tcp_dstip(): def test_get_tcp_dstip():
@ -66,10 +77,10 @@ def test_setup_udp_listener():
] ]
def test_check_settings(): def test_assert_features():
method = get_method('tproxy') method = get_method('tproxy')
method.check_settings(True, True) features = method.get_supported_features()
method.check_settings(False, True) method.assert_features(features)
def test_firewall_command(): def test_firewall_command():