mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-01-22 05:49:09 +01:00
Add firewall tests.
This commit is contained in:
parent
ac723694bf
commit
54de23aae3
12
.travis.yml
Normal file
12
.travis.yml
Normal file
@ -0,0 +1,12 @@
|
||||
language: python
|
||||
python:
|
||||
- 2.6
|
||||
- 2.7
|
||||
- 3.5
|
||||
- pypy
|
||||
|
||||
install:
|
||||
- travis_retry pip install -q pytest mock
|
||||
|
||||
script:
|
||||
- py.test
|
@ -8,10 +8,10 @@ from sshuttle.helpers import debug1, debug2, Fatal
|
||||
from sshuttle.methods import get_auto_method, get_method
|
||||
|
||||
hostmap = {}
|
||||
HOSTSFILE = '/etc/hosts'
|
||||
|
||||
|
||||
def rewrite_etc_hosts(port):
|
||||
HOSTSFILE = '/etc/hosts'
|
||||
BAKFILE = '%s.sbak' % HOSTSFILE
|
||||
APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port
|
||||
old_content = ''
|
||||
@ -51,36 +51,11 @@ def restore_etc_hosts(port):
|
||||
rewrite_etc_hosts(port)
|
||||
|
||||
|
||||
# This is some voodoo for setting up the kernel's transparent
|
||||
# proxying stuff. If subnets is empty, we just delete our sshuttle rules;
|
||||
# otherwise we delete it, then make them from scratch.
|
||||
#
|
||||
# This code is supposed to clean up after itself by deleting its rules on
|
||||
# exit. In case that fails, it's not the end of the world; future runs will
|
||||
# supercede it in the transproxy list, at least, so the leftover rules
|
||||
# are hopefully harmless.
|
||||
def main(method_name, syslog):
|
||||
# Isolate function that needs to be replaced for tests
|
||||
def setup_daemon():
|
||||
if os.getuid() != 0:
|
||||
raise Fatal('you must be root (or enable su/sudo) to set the firewall')
|
||||
|
||||
if method_name == "auto":
|
||||
method = get_auto_method()
|
||||
else:
|
||||
method = get_method(method_name)
|
||||
|
||||
# because of limitations of the 'su' command, the *real* stdin/stdout
|
||||
# are both attached to stdout initially. Clone stdout into stdin so we
|
||||
# can read from it.
|
||||
os.dup2(1, 0)
|
||||
|
||||
if syslog:
|
||||
ssyslog.start_syslog()
|
||||
ssyslog.stderr_to_syslog()
|
||||
|
||||
debug1('firewall manager ready method name %s.\n' % method.name)
|
||||
sys.stdout.write('READY %s\n' % method.name)
|
||||
sys.stdout.flush()
|
||||
|
||||
# don't disappear if our controlling terminal or stdout/stderr
|
||||
# disappears; we still have to clean up.
|
||||
signal.signal(signal.SIGHUP, signal.SIG_IGN)
|
||||
@ -92,10 +67,42 @@ def main(method_name, syslog):
|
||||
# I'll die automatically.
|
||||
os.setsid()
|
||||
|
||||
# because of limitations of the 'su' command, the *real* stdin/stdout
|
||||
# are both attached to stdout initially. Clone stdout into stdin so we
|
||||
# can read from it.
|
||||
os.dup2(1, 0)
|
||||
|
||||
return sys.stdin, sys.stdout
|
||||
|
||||
|
||||
# This is some voodoo for setting up the kernel's transparent
|
||||
# proxying stuff. If subnets is empty, we just delete our sshuttle rules;
|
||||
# otherwise we delete it, then make them from scratch.
|
||||
#
|
||||
# This code is supposed to clean up after itself by deleting its rules on
|
||||
# exit. In case that fails, it's not the end of the world; future runs will
|
||||
# supercede it in the transproxy list, at least, so the leftover rules
|
||||
# are hopefully harmless.
|
||||
def main(method_name, syslog):
|
||||
stdin, stdout = setup_daemon()
|
||||
|
||||
if method_name == "auto":
|
||||
method = get_auto_method()
|
||||
else:
|
||||
method = get_method(method_name)
|
||||
|
||||
if syslog:
|
||||
ssyslog.start_syslog()
|
||||
ssyslog.stderr_to_syslog()
|
||||
|
||||
debug1('firewall manager ready method name %s.\n' % method_name)
|
||||
stdout.write('READY %s\n' % method_name)
|
||||
stdout.flush()
|
||||
|
||||
# we wait until we get some input before creating the rules. That way,
|
||||
# sshuttle can launch us as early as possible (and get sudo password
|
||||
# authentication as early in the startup process as possible).
|
||||
line = sys.stdin.readline(128)
|
||||
line = stdin.readline(128)
|
||||
if not line:
|
||||
return # parent died; nothing to do
|
||||
|
||||
@ -103,7 +110,7 @@ def main(method_name, syslog):
|
||||
if line != 'ROUTES\n':
|
||||
raise Fatal('firewall: expected ROUTES but got %r' % line)
|
||||
while 1:
|
||||
line = sys.stdin.readline(128)
|
||||
line = stdin.readline(128)
|
||||
if not line:
|
||||
raise Fatal('firewall: expected route but got %r' % line)
|
||||
elif line.startswith("NSLIST\n"):
|
||||
@ -119,7 +126,7 @@ def main(method_name, syslog):
|
||||
if line != 'NSLIST\n':
|
||||
raise Fatal('firewall: expected NSLIST but got %r' % line)
|
||||
while 1:
|
||||
line = sys.stdin.readline(128)
|
||||
line = stdin.readline(128)
|
||||
if not line:
|
||||
raise Fatal('firewall: expected nslist but got %r' % line)
|
||||
elif line.startswith("PORTS "):
|
||||
@ -155,7 +162,7 @@ def main(method_name, syslog):
|
||||
debug2('Got ports: %d,%d,%d,%d\n'
|
||||
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
||||
|
||||
line = sys.stdin.readline(128)
|
||||
line = stdin.readline(128)
|
||||
if not line:
|
||||
raise Fatal('firewall: expected GO but got %r' % line)
|
||||
elif not line.startswith("GO "):
|
||||
@ -169,26 +176,28 @@ def main(method_name, syslog):
|
||||
do_wait = None
|
||||
debug1('firewall manager: starting transproxy.\n')
|
||||
|
||||
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||
if port_v6 > 0:
|
||||
do_wait = method.setup_firewall(
|
||||
port_v6, dnsport_v6, nslist,
|
||||
port_v6, dnsport_v6, nslist_v6,
|
||||
socket.AF_INET6, subnets_v6, udp)
|
||||
elif len(subnets_v6) > 0:
|
||||
debug1("IPv6 subnets defined but IPv6 disabled\n")
|
||||
|
||||
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
||||
if port_v4 > 0:
|
||||
do_wait = method.setup_firewall(
|
||||
port_v4, dnsport_v4, nslist,
|
||||
port_v4, dnsport_v4, nslist_v4,
|
||||
socket.AF_INET, subnets_v4, udp)
|
||||
elif len(subnets_v4) > 0:
|
||||
debug1('IPv4 subnets defined but IPv4 disabled\n')
|
||||
|
||||
sys.stdout.write('STARTED\n')
|
||||
stdout.write('STARTED\n')
|
||||
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
stdout.flush()
|
||||
except IOError:
|
||||
# the parent process died for some reason; he's surely been loud
|
||||
# enough, so no reason to report another error
|
||||
@ -198,9 +207,9 @@ def main(method_name, syslog):
|
||||
# to stay running so that we don't need a *second* password
|
||||
# authentication at shutdown time - that cleanup is important!
|
||||
while 1:
|
||||
if do_wait:
|
||||
if do_wait is not None:
|
||||
do_wait()
|
||||
line = sys.stdin.readline(128)
|
||||
line = stdin.readline(128)
|
||||
if line.startswith('HOST '):
|
||||
(name, ip) = line[5:].strip().split(',', 1)
|
||||
hostmap[name] = ip
|
||||
|
101
sshuttle/tests/test_firewall.py
Normal file
101
sshuttle/tests/test_firewall.py
Normal file
@ -0,0 +1,101 @@
|
||||
from mock import Mock, patch, call
|
||||
from contextlib import nested
|
||||
import io
|
||||
import os
|
||||
import os.path
|
||||
import shutil
|
||||
import filecmp
|
||||
import pytest
|
||||
|
||||
import sshuttle.firewall
|
||||
|
||||
|
||||
def setup_daemon():
|
||||
stdin = io.StringIO(u"""ROUTES
|
||||
2,24,0,1.2.3.0
|
||||
2,32,1,1.2.3.66
|
||||
10,64,0,2404:6800:4004:80c::
|
||||
10,128,1,2404:6800:4004:80c::101f
|
||||
NSLIST
|
||||
2,1.2.3.33
|
||||
10,2404:6800:4004:80c::33
|
||||
PORTS 1024,1025,1026,1027
|
||||
GO 1
|
||||
""")
|
||||
stdout = Mock()
|
||||
return stdin, stdout
|
||||
|
||||
|
||||
def test_rewrite_etc_hosts():
|
||||
if not os.path.isdir("tmp"):
|
||||
os.mkdir("tmp")
|
||||
|
||||
with open("tmp/hosts.orig", "w") as f:
|
||||
f.write("1.2.3.3 existing\n")
|
||||
|
||||
shutil.copyfile("tmp/hosts.orig", "tmp/hosts")
|
||||
|
||||
sshuttle.firewall.HOSTSFILE = "tmp/hosts"
|
||||
sshuttle.firewall.hostmap = {
|
||||
'myhost': '1.2.3.4',
|
||||
'myotherhost': '1.2.3.5',
|
||||
}
|
||||
sshuttle.firewall.rewrite_etc_hosts(10)
|
||||
with open("tmp/hosts") as f:
|
||||
line = f.next()
|
||||
s = line.split()
|
||||
assert s == ['1.2.3.3', 'existing']
|
||||
|
||||
line = f.next()
|
||||
s = line.split()
|
||||
assert s == ['1.2.3.4', 'myhost',
|
||||
'#', 'sshuttle-firewall-10', 'AUTOCREATED']
|
||||
|
||||
line = f.next()
|
||||
s = line.split()
|
||||
assert s == ['1.2.3.5', 'myotherhost',
|
||||
'#', 'sshuttle-firewall-10', 'AUTOCREATED']
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
line = f.next()
|
||||
|
||||
sshuttle.firewall.restore_etc_hosts(10)
|
||||
assert filecmp.cmp("tmp/hosts.orig", "tmp/hosts", shallow=False) is True
|
||||
|
||||
|
||||
def test_main():
|
||||
with nested(
|
||||
patch('sshuttle.firewall.setup_daemon'),
|
||||
patch('sshuttle.firewall.get_method')
|
||||
) as (mock_setup_daemon, mock_get_method):
|
||||
stdin, stdout = setup_daemon()
|
||||
mock_setup_daemon.return_value = stdin, stdout
|
||||
|
||||
sshuttle.firewall.main("test", False)
|
||||
|
||||
stdout.mock_calls == [
|
||||
call.write('READY test\n'),
|
||||
call.flush(),
|
||||
call.write('STARTED\n'),
|
||||
call.flush()
|
||||
]
|
||||
mock_setup_daemon.mock_calls == [call()]
|
||||
mock_get_method.mock_calls == [
|
||||
call('test'),
|
||||
call().setup_firewall(
|
||||
1024, 1026,
|
||||
[(10, u'2404:6800:4004:80c::33')],
|
||||
10,
|
||||
[(10, 64, False, u'2404:6800:4004:80c::'),
|
||||
(10, 128, True, u'2404:6800:4004:80c::101f')],
|
||||
True),
|
||||
call().setup_firewall(
|
||||
1025, 1027,
|
||||
[(2, u'1.2.3.33')],
|
||||
2,
|
||||
[(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')],
|
||||
True),
|
||||
call().setup_firewall()(),
|
||||
call().setup_firewall(1024, 0, [], 10, [], True),
|
||||
call().setup_firewall(1025, 0, [], 2, [], True),
|
||||
]
|
@ -1 +0,0 @@
|
||||
..
|
Loading…
Reference in New Issue
Block a user