mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-04-21 17:58:39 +02:00
Simplify selection of features
This commit is contained in:
parent
6b4e36c528
commit
90654b4fb9
@ -14,7 +14,7 @@ import platform
|
|||||||
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
|
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
|
||||||
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
|
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
|
||||||
resolvconf_nameservers
|
resolvconf_nameservers
|
||||||
from sshuttle.methods import get_method
|
from sshuttle.methods import get_method, Features
|
||||||
|
|
||||||
_extra_fd = os.open('/dev/null', os.O_RDONLY)
|
_extra_fd = os.open('/dev/null', os.O_RDONLY)
|
||||||
|
|
||||||
@ -505,19 +505,44 @@ def main(listenip_v6, listenip_v4,
|
|||||||
|
|
||||||
fw = FirewallClient(method_name)
|
fw = FirewallClient(method_name)
|
||||||
|
|
||||||
features = fw.method.get_supported_features()
|
# Get family specific subnet lists
|
||||||
|
if dns:
|
||||||
|
nslist += resolvconf_nameservers()
|
||||||
|
|
||||||
|
subnets = subnets_include + subnets_exclude # we don't care here
|
||||||
|
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||||
|
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||||
|
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
||||||
|
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||||
|
|
||||||
|
# Check features available
|
||||||
|
avail = fw.method.get_supported_features()
|
||||||
|
required = Features()
|
||||||
|
|
||||||
if listenip_v6 == "auto":
|
if listenip_v6 == "auto":
|
||||||
if features.ipv6:
|
if avail.ipv6:
|
||||||
listenip_v6 = ('::1', 0)
|
listenip_v6 = ('::1', 0)
|
||||||
else:
|
else:
|
||||||
listenip_v6 = None
|
listenip_v6 = None
|
||||||
|
|
||||||
|
required.ipv6 = len(subnets_v6) > 0 or len(nslist_v6) > 0
|
||||||
|
required.udp = avail.udp
|
||||||
|
required.dns = len(nslist) > 0
|
||||||
|
|
||||||
|
fw.method.assert_features(required)
|
||||||
|
|
||||||
|
if required.ipv6 and listenip_v6 is None:
|
||||||
|
raise Fatal("IPv6 required but not listening.")
|
||||||
|
|
||||||
|
# display features enabled
|
||||||
|
debug1("IPv6 enabled: %r\n" % required.ipv6)
|
||||||
|
debug1("UDP enabled: %r\n" % required.udp)
|
||||||
|
debug1("DNS enabled: %r\n" % required.dns)
|
||||||
|
|
||||||
|
# bind to required ports
|
||||||
if listenip_v4 == "auto":
|
if listenip_v4 == "auto":
|
||||||
listenip_v4 = ('127.0.0.1', 0)
|
listenip_v4 = ('127.0.0.1', 0)
|
||||||
|
|
||||||
udp = features.udp
|
|
||||||
debug1("UDP enabled: %r\n" % udp)
|
|
||||||
|
|
||||||
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
|
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
|
||||||
# if both ports given, no need to search for a spare port
|
# if both ports given, no need to search for a spare port
|
||||||
ports = [0, ]
|
ports = [0, ]
|
||||||
@ -536,7 +561,7 @@ def main(listenip_v6, listenip_v4,
|
|||||||
tcp_listener = MultiListener()
|
tcp_listener = MultiListener()
|
||||||
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
|
||||||
if udp:
|
if required.udp:
|
||||||
udp_listener = MultiListener(socket.SOCK_DGRAM)
|
udp_listener = MultiListener(socket.SOCK_DGRAM)
|
||||||
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
else:
|
else:
|
||||||
@ -584,10 +609,7 @@ def main(listenip_v6, listenip_v4,
|
|||||||
udp_listener.print_listening("UDP redirector")
|
udp_listener.print_listening("UDP redirector")
|
||||||
|
|
||||||
bound = False
|
bound = False
|
||||||
if dns or nslist:
|
if required.dns:
|
||||||
if dns:
|
|
||||||
nslist += resolvconf_nameservers()
|
|
||||||
dns = True
|
|
||||||
# search for spare port for DNS
|
# search for spare port for DNS
|
||||||
debug2('Binding DNS:')
|
debug2('Binding DNS:')
|
||||||
ports = range(12300, 9000, -1)
|
ports = range(12300, 9000, -1)
|
||||||
@ -628,17 +650,41 @@ def main(listenip_v6, listenip_v4,
|
|||||||
dnsport_v4 = 0
|
dnsport_v4 = 0
|
||||||
dns_listener = None
|
dns_listener = None
|
||||||
|
|
||||||
fw.method.check_settings(udp, dns)
|
# Last minute sanity checks.
|
||||||
|
# These should never fail.
|
||||||
|
# If these do fail, something is broken above.
|
||||||
|
if len(subnets_v6) > 0:
|
||||||
|
assert required.ipv6
|
||||||
|
if redirectport_v6 == 0:
|
||||||
|
raise Fatal("IPv6 subnets defined but not listening")
|
||||||
|
|
||||||
|
if len(nslist_v6) > 0:
|
||||||
|
assert required.dns
|
||||||
|
assert required.ipv6
|
||||||
|
if dnsport_v6 == 0:
|
||||||
|
raise Fatal("IPv6 ns servers defined but not listening")
|
||||||
|
|
||||||
|
if len(subnets_v4) > 0:
|
||||||
|
if redirectport_v4 == 0:
|
||||||
|
raise Fatal("IPv4 subnets defined but not listening")
|
||||||
|
|
||||||
|
if len(nslist_v4) > 0:
|
||||||
|
if dnsport_v4 == 0:
|
||||||
|
raise Fatal("IPv4 ns servers defined but not listening")
|
||||||
|
|
||||||
|
# setup method specific stuff on listeners
|
||||||
fw.method.setup_tcp_listener(tcp_listener)
|
fw.method.setup_tcp_listener(tcp_listener)
|
||||||
if udp_listener:
|
if udp_listener:
|
||||||
fw.method.setup_udp_listener(udp_listener)
|
fw.method.setup_udp_listener(udp_listener)
|
||||||
if dns_listener:
|
if dns_listener:
|
||||||
fw.method.setup_udp_listener(dns_listener)
|
fw.method.setup_udp_listener(dns_listener)
|
||||||
|
|
||||||
|
# start the firewall
|
||||||
fw.setup(subnets_include, subnets_exclude, nslist,
|
fw.setup(subnets_include, subnets_exclude, nslist,
|
||||||
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4,
|
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4,
|
||||||
udp)
|
required.udp)
|
||||||
|
|
||||||
|
# start the client process
|
||||||
try:
|
try:
|
||||||
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
|
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
|
||||||
python, latency_control, dns_listener,
|
python, latency_control, dns_listener,
|
||||||
|
@ -178,26 +178,23 @@ def main(method_name, syslog):
|
|||||||
try:
|
try:
|
||||||
debug1('firewall manager: setting up.\n')
|
debug1('firewall manager: setting up.\n')
|
||||||
|
|
||||||
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
|
||||||
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||||
if port_v6 > 0:
|
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||||
|
|
||||||
|
if len(subnets_v6) > 0 or len(nslist_v6) > 0:
|
||||||
debug2('firewall manager: setting up IPv6.\n')
|
debug2('firewall manager: setting up IPv6.\n')
|
||||||
method.setup_firewall(
|
method.setup_firewall(
|
||||||
port_v6, dnsport_v6, nslist_v6,
|
port_v6, dnsport_v6, nslist_v6,
|
||||||
socket.AF_INET6, subnets_v6, udp)
|
socket.AF_INET6, subnets_v6, udp)
|
||||||
elif len(subnets_v6) > 0:
|
|
||||||
debug1("IPv6 subnets defined but IPv6 disabled\n")
|
|
||||||
|
|
||||||
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
|
||||||
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
||||||
if port_v4 > 0:
|
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||||
|
|
||||||
|
if len(subnets_v4) > 0 or len(nslist_v4) > 0:
|
||||||
debug2('firewall manager: setting up IPv4.\n')
|
debug2('firewall manager: setting up IPv4.\n')
|
||||||
method.setup_firewall(
|
method.setup_firewall(
|
||||||
port_v4, dnsport_v4, nslist_v4,
|
port_v4, dnsport_v4, nslist_v4,
|
||||||
socket.AF_INET, subnets_v4, udp)
|
socket.AF_INET, subnets_v4, udp)
|
||||||
elif len(subnets_v4) > 0:
|
|
||||||
debug1('firewall manager: '
|
|
||||||
'IPv4 subnets defined but IPv4 disabled\n')
|
|
||||||
|
|
||||||
stdout.write('STARTED\n')
|
stdout.write('STARTED\n')
|
||||||
|
|
||||||
|
@ -62,9 +62,13 @@ class BaseMethod(object):
|
|||||||
def setup_udp_listener(self, udp_listener):
|
def setup_udp_listener(self, udp_listener):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def check_settings(self, udp, dns):
|
def assert_features(self, features):
|
||||||
if udp:
|
avail = self.get_supported_features()
|
||||||
Fatal("UDP support not supported with method %s.\n" % self.name)
|
for key in ["udp", "dns", "ipv6"]:
|
||||||
|
if getattr(features, key) and not getattr(avail, key):
|
||||||
|
raise Fatal(
|
||||||
|
"Feature %s not supported with method %s.\n" %
|
||||||
|
(key, self.name))
|
||||||
|
|
||||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
|
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -55,7 +55,6 @@ class Method(BaseMethod):
|
|||||||
'-p', 'tcp',
|
'-p', 'tcp',
|
||||||
'--to-ports', str(port))
|
'--to-ports', str(port))
|
||||||
|
|
||||||
if dnsport:
|
|
||||||
for f, ip in [i for i in nslist if i[0] == family]:
|
for f, ip in [i for i in nslist if i[0] == family]:
|
||||||
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
|
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
|
||||||
'--dest', '%s/32' % ip,
|
'--dest', '%s/32' % ip,
|
||||||
|
@ -181,6 +181,7 @@ class Method(BaseMethod):
|
|||||||
if udp:
|
if udp:
|
||||||
raise Exception("UDP not supported by pf method_name")
|
raise Exception("UDP not supported by pf method_name")
|
||||||
|
|
||||||
|
if len(subnets) > 0:
|
||||||
includes = []
|
includes = []
|
||||||
# If a given subnet is both included and excluded, list the
|
# If a given subnet is both included and excluded, list the
|
||||||
# exclusion first; the table will ignore the second, opposite
|
# exclusion first; the table will ignore the second, opposite
|
||||||
@ -201,7 +202,7 @@ class Method(BaseMethod):
|
|||||||
b'pass out route-to lo0 inet proto tcp '
|
b'pass out route-to lo0 inet proto tcp '
|
||||||
b'to <forward_subnets> keep state')
|
b'to <forward_subnets> keep state')
|
||||||
|
|
||||||
if dnsport:
|
if len(nslist) > 0:
|
||||||
tables.append(
|
tables.append(
|
||||||
b'table <dns_servers> {%s}' %
|
b'table <dns_servers> {%s}' %
|
||||||
b','.join([ns[1].encode("ASCII") for ns in nslist]))
|
b','.join([ns[1].encode("ASCII") for ns in nslist]))
|
||||||
|
@ -59,6 +59,7 @@ if recvmsg == "python":
|
|||||||
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
|
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
|
||||||
dstip = (ip, port)
|
dstip = (ip, port)
|
||||||
break
|
break
|
||||||
|
print("xxxxx", srcip, dstip)
|
||||||
return (srcip, dstip, data)
|
return (srcip, dstip, data)
|
||||||
elif recvmsg == "socket_ext":
|
elif recvmsg == "socket_ext":
|
||||||
def recv_udp(listener, bufsize):
|
def recv_udp(listener, bufsize):
|
||||||
@ -187,7 +188,6 @@ class Method(BaseMethod):
|
|||||||
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
|
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
|
||||||
'-m', 'udp', '-p', 'udp')
|
'-m', 'udp', '-p', 'udp')
|
||||||
|
|
||||||
if dnsport:
|
|
||||||
for f, ip in [i for i in nslist if i[0] == family]:
|
for f, ip in [i for i in nslist if i[0] == family]:
|
||||||
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
|
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
|
||||||
'--dest', '%s/32' % ip,
|
'--dest', '%s/32' % ip,
|
||||||
@ -267,16 +267,3 @@ class Method(BaseMethod):
|
|||||||
if ipt_chain_exists(family, table, divert_chain):
|
if ipt_chain_exists(family, table, divert_chain):
|
||||||
_ipt('-F', divert_chain)
|
_ipt('-F', divert_chain)
|
||||||
_ipt('-X', divert_chain)
|
_ipt('-X', divert_chain)
|
||||||
|
|
||||||
def check_settings(self, udp, dns):
|
|
||||||
if udp and recvmsg is None:
|
|
||||||
raise Fatal("tproxy UDP support requires recvmsg function.\n")
|
|
||||||
|
|
||||||
if dns and recvmsg is None:
|
|
||||||
raise Fatal("tproxy DNS support requires recvmsg function.\n")
|
|
||||||
|
|
||||||
if udp:
|
|
||||||
debug1("tproxy UDP support enabled.\n")
|
|
||||||
|
|
||||||
if dns:
|
|
||||||
debug1("tproxy DNS support enabled.\n")
|
|
||||||
|
@ -3,6 +3,7 @@ from mock import Mock, patch, call
|
|||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
from sshuttle.helpers import Fatal
|
||||||
from sshuttle.methods import get_method
|
from sshuttle.methods import get_method
|
||||||
|
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ def test_get_supported_features():
|
|||||||
features = method.get_supported_features()
|
features = method.get_supported_features()
|
||||||
assert not features.ipv6
|
assert not features.ipv6
|
||||||
assert not features.udp
|
assert not features.udp
|
||||||
|
assert features.dns
|
||||||
|
|
||||||
|
|
||||||
def test_get_tcp_dstip():
|
def test_get_tcp_dstip():
|
||||||
@ -52,10 +54,18 @@ def test_setup_udp_listener():
|
|||||||
assert listener.mock_calls == []
|
assert listener.mock_calls == []
|
||||||
|
|
||||||
|
|
||||||
def test_check_settings():
|
def test_assert_features():
|
||||||
method = get_method('nat')
|
method = get_method('nat')
|
||||||
method.check_settings(True, True)
|
features = method.get_supported_features()
|
||||||
method.check_settings(False, True)
|
method.assert_features(features)
|
||||||
|
|
||||||
|
features.udp = True
|
||||||
|
with pytest.raises(Fatal):
|
||||||
|
method.assert_features(features)
|
||||||
|
|
||||||
|
features.ipv6 = True
|
||||||
|
with pytest.raises(Fatal):
|
||||||
|
method.assert_features(features)
|
||||||
|
|
||||||
|
|
||||||
def test_firewall_command():
|
def test_firewall_command():
|
||||||
|
@ -3,6 +3,7 @@ from mock import Mock, patch, call, ANY
|
|||||||
import socket
|
import socket
|
||||||
|
|
||||||
from sshuttle.methods import get_method
|
from sshuttle.methods import get_method
|
||||||
|
from sshuttle.helpers import Fatal
|
||||||
|
|
||||||
|
|
||||||
def test_get_supported_features():
|
def test_get_supported_features():
|
||||||
@ -10,6 +11,7 @@ def test_get_supported_features():
|
|||||||
features = method.get_supported_features()
|
features = method.get_supported_features()
|
||||||
assert not features.ipv6
|
assert not features.ipv6
|
||||||
assert not features.udp
|
assert not features.udp
|
||||||
|
assert features.dns
|
||||||
|
|
||||||
|
|
||||||
@patch('sshuttle.helpers.verbose', new=3)
|
@patch('sshuttle.helpers.verbose', new=3)
|
||||||
@ -68,10 +70,18 @@ def test_setup_udp_listener():
|
|||||||
assert listener.mock_calls == []
|
assert listener.mock_calls == []
|
||||||
|
|
||||||
|
|
||||||
def test_check_settings():
|
def test_assert_features():
|
||||||
method = get_method('pf')
|
method = get_method('pf')
|
||||||
method.check_settings(True, True)
|
features = method.get_supported_features()
|
||||||
method.check_settings(False, True)
|
method.assert_features(features)
|
||||||
|
|
||||||
|
features.udp = True
|
||||||
|
with pytest.raises(Fatal):
|
||||||
|
method.assert_features(features)
|
||||||
|
|
||||||
|
features.ipv6 = True
|
||||||
|
with pytest.raises(Fatal):
|
||||||
|
method.assert_features(features)
|
||||||
|
|
||||||
|
|
||||||
@patch('sshuttle.methods.pf.sys.stdout')
|
@patch('sshuttle.methods.pf.sys.stdout')
|
||||||
|
@ -3,11 +3,22 @@ from mock import Mock, patch, call
|
|||||||
from sshuttle.methods import get_method
|
from sshuttle.methods import get_method
|
||||||
|
|
||||||
|
|
||||||
def test_get_supported_features():
|
@patch("sshuttle.methods.tproxy.recvmsg")
|
||||||
|
def test_get_supported_features_recvmsg(mock_recvmsg):
|
||||||
method = get_method('tproxy')
|
method = get_method('tproxy')
|
||||||
features = method.get_supported_features()
|
features = method.get_supported_features()
|
||||||
assert features.ipv6
|
assert features.ipv6
|
||||||
assert features.udp
|
assert features.udp
|
||||||
|
assert features.dns
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sshuttle.methods.tproxy.recvmsg", None)
|
||||||
|
def test_get_supported_features_norecvmsg():
|
||||||
|
method = get_method('tproxy')
|
||||||
|
features = method.get_supported_features()
|
||||||
|
assert features.ipv6
|
||||||
|
assert not features.udp
|
||||||
|
assert not features.dns
|
||||||
|
|
||||||
|
|
||||||
def test_get_tcp_dstip():
|
def test_get_tcp_dstip():
|
||||||
@ -66,10 +77,10 @@ def test_setup_udp_listener():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_check_settings():
|
def test_assert_features():
|
||||||
method = get_method('tproxy')
|
method = get_method('tproxy')
|
||||||
method.check_settings(True, True)
|
features = method.get_supported_features()
|
||||||
method.check_settings(False, True)
|
method.assert_features(features)
|
||||||
|
|
||||||
|
|
||||||
def test_firewall_command():
|
def test_firewall_command():
|
||||||
|
Loading…
Reference in New Issue
Block a user