Add firewall tests.

This commit is contained in:
Brian May 2015-11-17 09:19:20 +11:00
parent ac723694bf
commit 54de23aae3
4 changed files with 160 additions and 39 deletions

12
.travis.yml Normal file
View File

@ -0,0 +1,12 @@
language: python
python:
- 2.6
- 2.7
- 3.5
- pypy
install:
- travis_retry pip install -q pytest mock
script:
- py.test

View File

@ -8,10 +8,10 @@ from sshuttle.helpers import debug1, debug2, Fatal
from sshuttle.methods import get_auto_method, get_method from sshuttle.methods import get_auto_method, get_method
hostmap = {} hostmap = {}
HOSTSFILE = '/etc/hosts'
def rewrite_etc_hosts(port): def rewrite_etc_hosts(port):
HOSTSFILE = '/etc/hosts'
BAKFILE = '%s.sbak' % HOSTSFILE BAKFILE = '%s.sbak' % HOSTSFILE
APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port
old_content = '' old_content = ''
@ -51,36 +51,11 @@ def restore_etc_hosts(port):
rewrite_etc_hosts(port) rewrite_etc_hosts(port)
# This is some voodoo for setting up the kernel's transparent # Isolate function that needs to be replaced for tests
# proxying stuff. If subnets is empty, we just delete our sshuttle rules; def setup_daemon():
# otherwise we delete it, then make them from scratch.
#
# This code is supposed to clean up after itself by deleting its rules on
# exit. In case that fails, it's not the end of the world; future runs will
# supercede it in the transproxy list, at least, so the leftover rules
# are hopefully harmless.
def main(method_name, syslog):
if os.getuid() != 0: if os.getuid() != 0:
raise Fatal('you must be root (or enable su/sudo) to set the firewall') raise Fatal('you must be root (or enable su/sudo) to set the firewall')
if method_name == "auto":
method = get_auto_method()
else:
method = get_method(method_name)
# because of limitations of the 'su' command, the *real* stdin/stdout
# are both attached to stdout initially. Clone stdout into stdin so we
# can read from it.
os.dup2(1, 0)
if syslog:
ssyslog.start_syslog()
ssyslog.stderr_to_syslog()
debug1('firewall manager ready method name %s.\n' % method.name)
sys.stdout.write('READY %s\n' % method.name)
sys.stdout.flush()
# don't disappear if our controlling terminal or stdout/stderr # don't disappear if our controlling terminal or stdout/stderr
# disappears; we still have to clean up. # disappears; we still have to clean up.
signal.signal(signal.SIGHUP, signal.SIG_IGN) signal.signal(signal.SIGHUP, signal.SIG_IGN)
@ -92,10 +67,42 @@ def main(method_name, syslog):
# I'll die automatically. # I'll die automatically.
os.setsid() os.setsid()
# because of limitations of the 'su' command, the *real* stdin/stdout
# are both attached to stdout initially. Clone stdout into stdin so we
# can read from it.
os.dup2(1, 0)
return sys.stdin, sys.stdout
# This is some voodoo for setting up the kernel's transparent
# proxying stuff. If subnets is empty, we just delete our sshuttle rules;
# otherwise we delete it, then make them from scratch.
#
# This code is supposed to clean up after itself by deleting its rules on
# exit. In case that fails, it's not the end of the world; future runs will
# supercede it in the transproxy list, at least, so the leftover rules
# are hopefully harmless.
def main(method_name, syslog):
stdin, stdout = setup_daemon()
if method_name == "auto":
method = get_auto_method()
else:
method = get_method(method_name)
if syslog:
ssyslog.start_syslog()
ssyslog.stderr_to_syslog()
debug1('firewall manager ready method name %s.\n' % method_name)
stdout.write('READY %s\n' % method_name)
stdout.flush()
# we wait until we get some input before creating the rules. That way, # we wait until we get some input before creating the rules. That way,
# sshuttle can launch us as early as possible (and get sudo password # sshuttle can launch us as early as possible (and get sudo password
# authentication as early in the startup process as possible). # authentication as early in the startup process as possible).
line = sys.stdin.readline(128) line = stdin.readline(128)
if not line: if not line:
return # parent died; nothing to do return # parent died; nothing to do
@ -103,7 +110,7 @@ def main(method_name, syslog):
if line != 'ROUTES\n': if line != 'ROUTES\n':
raise Fatal('firewall: expected ROUTES but got %r' % line) raise Fatal('firewall: expected ROUTES but got %r' % line)
while 1: while 1:
line = sys.stdin.readline(128) line = stdin.readline(128)
if not line: if not line:
raise Fatal('firewall: expected route but got %r' % line) raise Fatal('firewall: expected route but got %r' % line)
elif line.startswith("NSLIST\n"): elif line.startswith("NSLIST\n"):
@ -119,7 +126,7 @@ def main(method_name, syslog):
if line != 'NSLIST\n': if line != 'NSLIST\n':
raise Fatal('firewall: expected NSLIST but got %r' % line) raise Fatal('firewall: expected NSLIST but got %r' % line)
while 1: while 1:
line = sys.stdin.readline(128) line = stdin.readline(128)
if not line: if not line:
raise Fatal('firewall: expected nslist but got %r' % line) raise Fatal('firewall: expected nslist but got %r' % line)
elif line.startswith("PORTS "): elif line.startswith("PORTS "):
@ -155,7 +162,7 @@ def main(method_name, syslog):
debug2('Got ports: %d,%d,%d,%d\n' debug2('Got ports: %d,%d,%d,%d\n'
% (port_v6, port_v4, dnsport_v6, dnsport_v4)) % (port_v6, port_v4, dnsport_v6, dnsport_v4))
line = sys.stdin.readline(128) line = stdin.readline(128)
if not line: if not line:
raise Fatal('firewall: expected GO but got %r' % line) raise Fatal('firewall: expected GO but got %r' % line)
elif not line.startswith("GO "): elif not line.startswith("GO "):
@ -169,26 +176,28 @@ def main(method_name, syslog):
do_wait = None do_wait = None
debug1('firewall manager: starting transproxy.\n') debug1('firewall manager: starting transproxy.\n')
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6] subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
if port_v6 > 0: if port_v6 > 0:
do_wait = method.setup_firewall( do_wait = method.setup_firewall(
port_v6, dnsport_v6, nslist, port_v6, dnsport_v6, nslist_v6,
socket.AF_INET6, subnets_v6, udp) socket.AF_INET6, subnets_v6, udp)
elif len(subnets_v6) > 0: elif len(subnets_v6) > 0:
debug1("IPv6 subnets defined but IPv6 disabled\n") debug1("IPv6 subnets defined but IPv6 disabled\n")
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET] subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
if port_v4 > 0: if port_v4 > 0:
do_wait = method.setup_firewall( do_wait = method.setup_firewall(
port_v4, dnsport_v4, nslist, port_v4, dnsport_v4, nslist_v4,
socket.AF_INET, subnets_v4, udp) socket.AF_INET, subnets_v4, udp)
elif len(subnets_v4) > 0: elif len(subnets_v4) > 0:
debug1('IPv4 subnets defined but IPv4 disabled\n') debug1('IPv4 subnets defined but IPv4 disabled\n')
sys.stdout.write('STARTED\n') stdout.write('STARTED\n')
try: try:
sys.stdout.flush() stdout.flush()
except IOError: except IOError:
# the parent process died for some reason; he's surely been loud # the parent process died for some reason; he's surely been loud
# enough, so no reason to report another error # enough, so no reason to report another error
@ -198,9 +207,9 @@ def main(method_name, syslog):
# to stay running so that we don't need a *second* password # to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important! # authentication at shutdown time - that cleanup is important!
while 1: while 1:
if do_wait: if do_wait is not None:
do_wait() do_wait()
line = sys.stdin.readline(128) line = stdin.readline(128)
if line.startswith('HOST '): if line.startswith('HOST '):
(name, ip) = line[5:].strip().split(',', 1) (name, ip) = line[5:].strip().split(',', 1)
hostmap[name] = ip hostmap[name] = ip

View File

@ -0,0 +1,101 @@
from mock import Mock, patch, call
from contextlib import nested
import io
import os
import os.path
import shutil
import filecmp
import pytest
import sshuttle.firewall
def setup_daemon():
stdin = io.StringIO(u"""ROUTES
2,24,0,1.2.3.0
2,32,1,1.2.3.66
10,64,0,2404:6800:4004:80c::
10,128,1,2404:6800:4004:80c::101f
NSLIST
2,1.2.3.33
10,2404:6800:4004:80c::33
PORTS 1024,1025,1026,1027
GO 1
""")
stdout = Mock()
return stdin, stdout
def test_rewrite_etc_hosts():
if not os.path.isdir("tmp"):
os.mkdir("tmp")
with open("tmp/hosts.orig", "w") as f:
f.write("1.2.3.3 existing\n")
shutil.copyfile("tmp/hosts.orig", "tmp/hosts")
sshuttle.firewall.HOSTSFILE = "tmp/hosts"
sshuttle.firewall.hostmap = {
'myhost': '1.2.3.4',
'myotherhost': '1.2.3.5',
}
sshuttle.firewall.rewrite_etc_hosts(10)
with open("tmp/hosts") as f:
line = f.next()
s = line.split()
assert s == ['1.2.3.3', 'existing']
line = f.next()
s = line.split()
assert s == ['1.2.3.4', 'myhost',
'#', 'sshuttle-firewall-10', 'AUTOCREATED']
line = f.next()
s = line.split()
assert s == ['1.2.3.5', 'myotherhost',
'#', 'sshuttle-firewall-10', 'AUTOCREATED']
with pytest.raises(StopIteration):
line = f.next()
sshuttle.firewall.restore_etc_hosts(10)
assert filecmp.cmp("tmp/hosts.orig", "tmp/hosts", shallow=False) is True
def test_main():
with nested(
patch('sshuttle.firewall.setup_daemon'),
patch('sshuttle.firewall.get_method')
) as (mock_setup_daemon, mock_get_method):
stdin, stdout = setup_daemon()
mock_setup_daemon.return_value = stdin, stdout
sshuttle.firewall.main("test", False)
stdout.mock_calls == [
call.write('READY test\n'),
call.flush(),
call.write('STARTED\n'),
call.flush()
]
mock_setup_daemon.mock_calls == [call()]
mock_get_method.mock_calls == [
call('test'),
call().setup_firewall(
1024, 1026,
[(10, u'2404:6800:4004:80c::33')],
10,
[(10, 64, False, u'2404:6800:4004:80c::'),
(10, 128, True, u'2404:6800:4004:80c::101f')],
True),
call().setup_firewall(
1025, 1027,
[(2, u'1.2.3.33')],
2,
[(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')],
True),
call().setup_firewall()(),
call().setup_firewall(1024, 0, [], 10, [], True),
call().setup_firewall(1025, 0, [], 2, [], True),
]

View File

@ -1 +0,0 @@
..