Split setup_firewall method.

* setup_firewall sets the firewall up.
* restore_firewall restores the firewall to initial state.
This commit is contained in:
Brian May 2015-12-13 11:56:18 +11:00
parent 2eeea9536a
commit 2b235331d0
9 changed files with 182 additions and 131 deletions

View File

@ -232,7 +232,7 @@ def main(method_name, syslog):
try:
if port_v6:
debug2('firewall manager: undoing IPv6 changes.\n')
method.setup_firewall(port_v6, 0, [], socket.AF_INET6, [], udp)
method.restore_firewall(port_v6, socket.AF_INET6, udp)
except:
try:
debug1("firewall manager: "
@ -245,7 +245,7 @@ def main(method_name, syslog):
try:
if port_v4:
debug2('firewall manager: undoing IPv4 changes.\n')
method.setup_firewall(port_v4, 0, [], socket.AF_INET, [], udp)
method.restore_firewall(port_v4, socket.AF_INET, udp)
except:
try:
debug1("firewall manager: "

View File

@ -68,6 +68,9 @@ class BaseMethod(object):
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
raise NotImplementedError()
def restore_firewall(self, port, family, udp):
raise NotImplementedError()
def firewall_command(self, line):
return False

View File

@ -31,35 +31,29 @@ class Method(BaseMethod):
chain = 'sshuttle-%s' % port
# basic cleanup/setup of chains
if ipt_chain_exists(family, table, chain):
nonfatal(_ipt, '-D', 'OUTPUT', '-j', chain)
nonfatal(_ipt, '-D', 'PREROUTING', '-j', chain)
nonfatal(_ipt, '-F', chain)
_ipt('-X', chain)
self.restore_firewall(port, family, udp)
if subnets or dnsport:
_ipt('-N', chain)
_ipt('-F', chain)
_ipt('-I', 'OUTPUT', '1', '-j', chain)
_ipt('-I', 'PREROUTING', '1', '-j', chain)
_ipt('-N', chain)
_ipt('-F', chain)
_ipt('-I', 'OUTPUT', '1', '-j', chain)
_ipt('-I', 'PREROUTING', '1', '-j', chain)
if subnets:
# create new subnet entries. Note that we're sorting in a very
# particular order: we need to go from most-specific (largest
# swidth) to least-specific, and at any given level of specificity,
# we want excludes to come first. That's why the columns are in
# such a non- intuitive order.
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
_ipt('-A', chain, '-j', 'RETURN',
# create new subnet entries. Note that we're sorting in a very
# particular order: we need to go from most-specific (largest
# swidth) to least-specific, and at any given level of specificity,
# we want excludes to come first. That's why the columns are in
# such a non- intuitive order.
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
_ipt('-A', chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp')
else:
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp')
else:
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp',
'--to-ports', str(port))
'-p', 'tcp',
'--to-ports', str(port))
if dnsport:
for f, ip in [i for i in nslist if i[0] == family]:
@ -68,3 +62,29 @@ class Method(BaseMethod):
'-p', 'udp',
'--dport', '53',
'--to-ports', str(dnsport))
def restore_firewall(self, port, family, udp):
# only ipv4 supported with NAT
if family != socket.AF_INET:
raise Exception(
'Address family "%s" unsupported by nat method_name'
% family_to_string(family))
if udp:
raise Exception("UDP not supported by nat method_name")
table = "nat"
def _ipt(*args):
return ipt(family, table, *args)
def _ipt_ttl(*args):
return ipt_ttl(family, table, *args)
chain = 'sshuttle-%s' % port
# basic cleanup/setup of chains
if ipt_chain_exists(family, table, chain):
nonfatal(_ipt, '-D', 'OUTPUT', '-j', chain)
nonfatal(_ipt, '-D', 'PREROUTING', '-j', chain)
nonfatal(_ipt, '-F', chain)
_ipt('-X', chain)

View File

@ -181,63 +181,70 @@ class Method(BaseMethod):
if udp:
raise Exception("UDP not supported by pf method_name")
if subnets:
includes = []
# If a given subnet is both included and excluded, list the
# exclusion first; the table will ignore the second, opposite
# definition
for f, swidth, sexclude, snet in sorted(
subnets, key=lambda s: (s[1], s[2]), reverse=True):
includes.append(b"%s%s/%d" %
(b"!" if sexclude else b"",
snet.encode("ASCII"),
swidth))
includes = []
# If a given subnet is both included and excluded, list the
# exclusion first; the table will ignore the second, opposite
# definition
for f, swidth, sexclude, snet in sorted(
subnets, key=lambda s: (s[1], s[2]), reverse=True):
includes.append(b"%s%s/%d" %
(b"!" if sexclude else b"",
snet.encode("ASCII"),
swidth))
tables.append(
b'table <forward_subnets> {%s}' % b','.join(includes))
translating_rules.append(
b'rdr pass on lo0 proto tcp '
b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
filtering_rules.append(
b'pass out route-to lo0 inet proto tcp '
b'to <forward_subnets> keep state')
if dnsport:
tables.append(
b'table <forward_subnets> {%s}' % b','.join(includes))
b'table <dns_servers> {%s}' %
b','.join([ns[1].encode("ASCII") for ns in nslist]))
translating_rules.append(
b'rdr pass on lo0 proto tcp '
b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
b'rdr pass on lo0 proto udp to '
b'<dns_servers> port 53 -> 127.0.0.1 port %r' % dnsport)
filtering_rules.append(
b'pass out route-to lo0 inet proto tcp '
b'to <forward_subnets> keep state')
b'pass out route-to lo0 inet proto udp to '
b'<dns_servers> port 53 keep state')
if dnsport:
tables.append(
b'table <dns_servers> {%s}' %
b','.join([ns[1].encode("ASCII") for ns in nslist]))
translating_rules.append(
b'rdr pass on lo0 proto udp to '
b'<dns_servers> port 53 -> 127.0.0.1 port %r' % dnsport)
filtering_rules.append(
b'pass out route-to lo0 inet proto udp to '
b'<dns_servers> port 53 keep state')
rules = b'\n'.join(tables + translating_rules + filtering_rules) \
+ b'\n'
assert isinstance(rules, bytes)
rules = b'\n'.join(tables + translating_rules + filtering_rules) \
+ b'\n'
assert isinstance(rules, bytes)
pf_status = pfctl('-s all')[0]
if b'\nrdr-anchor "sshuttle" all\n' not in pf_status:
pf_add_anchor_rule(PF_RDR, "sshuttle")
if b'\nanchor "sshuttle" all\n' not in pf_status:
pf_add_anchor_rule(PF_PASS, "sshuttle")
pf_status = pfctl('-s all')[0]
if b'\nrdr-anchor "sshuttle" all\n' not in pf_status:
pf_add_anchor_rule(PF_RDR, "sshuttle")
if b'\nanchor "sshuttle" all\n' not in pf_status:
pf_add_anchor_rule(PF_PASS, "sshuttle")
pfctl('-a sshuttle -f /dev/stdin', rules)
if sys.platform == "darwin":
o = pfctl('-E')
_pf_context['Xtoken'] = \
re.search(b'Token : (.+)', o[1]).group(1)
elif b'INFO:\nStatus: Disabled' in pf_status:
pfctl('-e')
_pf_context['started_by_sshuttle'] = True
pfctl('-a sshuttle -f /dev/stdin', rules)
if sys.platform == "darwin":
o = pfctl('-E')
_pf_context['Xtoken'] = \
re.search(b'Token : (.+)', o[1]).group(1)
elif b'INFO:\nStatus: Disabled' in pf_status:
pfctl('-e')
_pf_context['started_by_sshuttle'] = True
else:
pfctl('-a sshuttle -F all')
if sys.platform == "darwin":
if _pf_context['Xtoken'] is not None:
pfctl('-X %s' % _pf_context['Xtoken'].decode("ASCII"))
elif _pf_context['started_by_sshuttle']:
pfctl('-d')
def restore_firewall(self, port, family, udp):
if family != socket.AF_INET:
raise Exception(
'Address family "%s" unsupported by pf method_name'
% family_to_string(family))
if udp:
raise Exception("UDP not supported by pf method_name")
pfctl('-a sshuttle -F all')
if sys.platform == "darwin":
if _pf_context['Xtoken'] is not None:
pfctl('-X %s' % _pf_context['Xtoken'].decode("ASCII"))
elif _pf_context['started_by_sshuttle']:
pfctl('-d')
def firewall_command(self, line):
if line.startswith('QUERY_PF_NAT '):

View File

@ -163,34 +163,22 @@ class Method(BaseMethod):
divert_chain = 'sshuttle-d-%s' % port
# basic cleanup/setup of chains
if ipt_chain_exists(family, table, mark_chain):
_ipt('-D', 'OUTPUT', '-j', mark_chain)
_ipt('-F', mark_chain)
_ipt('-X', mark_chain)
self.restore_firewall(port, family, udp)
if ipt_chain_exists(family, table, tproxy_chain):
_ipt('-D', 'PREROUTING', '-j', tproxy_chain)
_ipt('-F', tproxy_chain)
_ipt('-X', tproxy_chain)
_ipt('-N', mark_chain)
_ipt('-F', mark_chain)
_ipt('-N', divert_chain)
_ipt('-F', divert_chain)
_ipt('-N', tproxy_chain)
_ipt('-F', tproxy_chain)
_ipt('-I', 'OUTPUT', '1', '-j', mark_chain)
_ipt('-I', 'PREROUTING', '1', '-j', tproxy_chain)
_ipt('-A', divert_chain, '-j', 'MARK', '--set-mark', '1')
_ipt('-A', divert_chain, '-j', 'ACCEPT')
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'tcp', '-p', 'tcp')
if ipt_chain_exists(family, table, divert_chain):
_ipt('-F', divert_chain)
_ipt('-X', divert_chain)
if subnets or dnsport:
_ipt('-N', mark_chain)
_ipt('-F', mark_chain)
_ipt('-N', divert_chain)
_ipt('-F', divert_chain)
_ipt('-N', tproxy_chain)
_ipt('-F', tproxy_chain)
_ipt('-I', 'OUTPUT', '1', '-j', mark_chain)
_ipt('-I', 'PREROUTING', '1', '-j', tproxy_chain)
_ipt('-A', divert_chain, '-j', 'MARK', '--set-mark', '1')
_ipt('-A', divert_chain, '-j', 'ACCEPT')
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'tcp', '-p', 'tcp')
if subnets and udp:
if udp:
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'udp', '-p', 'udp')
@ -205,34 +193,34 @@ class Method(BaseMethod):
'-m', 'udp', '-p', 'udp', '--dport', '53',
'--on-port', str(dnsport))
if subnets:
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
_ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
_ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
else:
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
_ipt('-A', tproxy_chain, '-j', 'TPROXY',
'--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp',
'--on-port', str(port))
if udp:
if sexclude:
_ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
'-m', 'udp', '-p', 'udp')
_ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
'-m', 'udp', '-p', 'udp')
else:
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
_ipt('-A', tproxy_chain, '-j', 'TPROXY',
'--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp',
'--on-port', str(port))
if sexclude and udp:
_ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp')
_ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp')
elif udp:
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp')
@ -242,6 +230,39 @@ class Method(BaseMethod):
'-m', 'udp', '-p', 'udp',
'--on-port', str(port))
def restore_firewall(self, port, family, udp):
if family not in [socket.AF_INET, socket.AF_INET6]:
raise Exception(
'Address family "%s" unsupported by tproxy method'
% family_to_string(family))
table = "mangle"
def _ipt(*args):
return ipt(family, table, *args)
def _ipt_ttl(*args):
return ipt_ttl(family, table, *args)
mark_chain = 'sshuttle-m-%s' % port
tproxy_chain = 'sshuttle-t-%s' % port
divert_chain = 'sshuttle-d-%s' % port
# basic cleanup/setup of chains
if ipt_chain_exists(family, table, mark_chain):
_ipt('-D', 'OUTPUT', '-j', mark_chain)
_ipt('-F', mark_chain)
_ipt('-X', mark_chain)
if ipt_chain_exists(family, table, tproxy_chain):
_ipt('-D', 'PREROUTING', '-j', tproxy_chain)
_ipt('-F', tproxy_chain)
_ipt('-X', tproxy_chain)
if ipt_chain_exists(family, table, divert_chain):
_ipt('-F', divert_chain)
_ipt('-X', divert_chain)
def check_settings(self, udp, dns):
if udp and recvmsg is None:
Fatal("tproxy UDP support requires recvmsg function.\n")

View File

@ -97,6 +97,6 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
2,
[(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')],
True),
call().setup_firewall(1024, 0, [], 10, [], True),
call().setup_firewall(1025, 0, [], 2, [], True),
call().restore_firewall(1024, 10, True),
call().restore_firewall(1025, 2, True),
]

View File

@ -129,7 +129,7 @@ def test_setup_firewall(mock_ipt_chain_exists, mock_ipt_ttl, mock_ipt):
mock_ipt_ttl.reset_mock()
mock_ipt.reset_mock()
method.setup_firewall(1025, 0, [], 2, [], False)
method.restore_firewall(1025, 2, False)
assert mock_ipt_chain_exists.mock_calls == [
call(2, 'nat', 'sshuttle-1025')
]

View File

@ -173,7 +173,7 @@ def test_setup_firewall(mock_pf_get_dev, mock_ioctl, mock_pfctl):
mock_ioctl.reset_mock()
mock_pfctl.reset_mock()
method.setup_firewall(1025, 0, [], 2, [], False)
method.restore_firewall(1025, 2, False)
assert mock_ioctl.mock_calls == []
assert mock_pfctl.mock_calls == [
call('-a sshuttle -F all'),

View File

@ -160,7 +160,7 @@ def test_setup_firewall(mock_ipt_chain_exists, mock_ipt_ttl, mock_ipt):
mock_ipt_ttl.reset_mock()
mock_ipt.reset_mock()
method.setup_firewall(1025, 0, [], 10, [], True)
method.restore_firewall(1025, 10, True)
assert mock_ipt_chain_exists.mock_calls == [
call(10, 'mangle', 'sshuttle-m-1025'),
call(10, 'mangle', 'sshuttle-t-1025'),
@ -250,7 +250,7 @@ def test_setup_firewall(mock_ipt_chain_exists, mock_ipt_ttl, mock_ipt):
mock_ipt_ttl.reset_mock()
mock_ipt.reset_mock()
method.setup_firewall(1025, 0, [], 2, [], True)
method.restore_firewall(1025, 2, True)
assert mock_ipt_chain_exists.mock_calls == [
call(2, 'mangle', 'sshuttle-m-1025'),
call(2, 'mangle', 'sshuttle-t-1025'),