mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-01-22 13:58:50 +01:00
Add firewall tests.
This commit is contained in:
parent
ac723694bf
commit
54de23aae3
12
.travis.yml
Normal file
12
.travis.yml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
language: python
|
||||||
|
python:
|
||||||
|
- 2.6
|
||||||
|
- 2.7
|
||||||
|
- 3.5
|
||||||
|
- pypy
|
||||||
|
|
||||||
|
install:
|
||||||
|
- travis_retry pip install -q pytest mock
|
||||||
|
|
||||||
|
script:
|
||||||
|
- py.test
|
@ -8,10 +8,10 @@ from sshuttle.helpers import debug1, debug2, Fatal
|
|||||||
from sshuttle.methods import get_auto_method, get_method
|
from sshuttle.methods import get_auto_method, get_method
|
||||||
|
|
||||||
hostmap = {}
|
hostmap = {}
|
||||||
|
HOSTSFILE = '/etc/hosts'
|
||||||
|
|
||||||
|
|
||||||
def rewrite_etc_hosts(port):
|
def rewrite_etc_hosts(port):
|
||||||
HOSTSFILE = '/etc/hosts'
|
|
||||||
BAKFILE = '%s.sbak' % HOSTSFILE
|
BAKFILE = '%s.sbak' % HOSTSFILE
|
||||||
APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port
|
APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port
|
||||||
old_content = ''
|
old_content = ''
|
||||||
@ -51,36 +51,11 @@ def restore_etc_hosts(port):
|
|||||||
rewrite_etc_hosts(port)
|
rewrite_etc_hosts(port)
|
||||||
|
|
||||||
|
|
||||||
# This is some voodoo for setting up the kernel's transparent
|
# Isolate function that needs to be replaced for tests
|
||||||
# proxying stuff. If subnets is empty, we just delete our sshuttle rules;
|
def setup_daemon():
|
||||||
# otherwise we delete it, then make them from scratch.
|
|
||||||
#
|
|
||||||
# This code is supposed to clean up after itself by deleting its rules on
|
|
||||||
# exit. In case that fails, it's not the end of the world; future runs will
|
|
||||||
# supercede it in the transproxy list, at least, so the leftover rules
|
|
||||||
# are hopefully harmless.
|
|
||||||
def main(method_name, syslog):
|
|
||||||
if os.getuid() != 0:
|
if os.getuid() != 0:
|
||||||
raise Fatal('you must be root (or enable su/sudo) to set the firewall')
|
raise Fatal('you must be root (or enable su/sudo) to set the firewall')
|
||||||
|
|
||||||
if method_name == "auto":
|
|
||||||
method = get_auto_method()
|
|
||||||
else:
|
|
||||||
method = get_method(method_name)
|
|
||||||
|
|
||||||
# because of limitations of the 'su' command, the *real* stdin/stdout
|
|
||||||
# are both attached to stdout initially. Clone stdout into stdin so we
|
|
||||||
# can read from it.
|
|
||||||
os.dup2(1, 0)
|
|
||||||
|
|
||||||
if syslog:
|
|
||||||
ssyslog.start_syslog()
|
|
||||||
ssyslog.stderr_to_syslog()
|
|
||||||
|
|
||||||
debug1('firewall manager ready method name %s.\n' % method.name)
|
|
||||||
sys.stdout.write('READY %s\n' % method.name)
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
# don't disappear if our controlling terminal or stdout/stderr
|
# don't disappear if our controlling terminal or stdout/stderr
|
||||||
# disappears; we still have to clean up.
|
# disappears; we still have to clean up.
|
||||||
signal.signal(signal.SIGHUP, signal.SIG_IGN)
|
signal.signal(signal.SIGHUP, signal.SIG_IGN)
|
||||||
@ -92,10 +67,42 @@ def main(method_name, syslog):
|
|||||||
# I'll die automatically.
|
# I'll die automatically.
|
||||||
os.setsid()
|
os.setsid()
|
||||||
|
|
||||||
|
# because of limitations of the 'su' command, the *real* stdin/stdout
|
||||||
|
# are both attached to stdout initially. Clone stdout into stdin so we
|
||||||
|
# can read from it.
|
||||||
|
os.dup2(1, 0)
|
||||||
|
|
||||||
|
return sys.stdin, sys.stdout
|
||||||
|
|
||||||
|
|
||||||
|
# This is some voodoo for setting up the kernel's transparent
|
||||||
|
# proxying stuff. If subnets is empty, we just delete our sshuttle rules;
|
||||||
|
# otherwise we delete it, then make them from scratch.
|
||||||
|
#
|
||||||
|
# This code is supposed to clean up after itself by deleting its rules on
|
||||||
|
# exit. In case that fails, it's not the end of the world; future runs will
|
||||||
|
# supercede it in the transproxy list, at least, so the leftover rules
|
||||||
|
# are hopefully harmless.
|
||||||
|
def main(method_name, syslog):
|
||||||
|
stdin, stdout = setup_daemon()
|
||||||
|
|
||||||
|
if method_name == "auto":
|
||||||
|
method = get_auto_method()
|
||||||
|
else:
|
||||||
|
method = get_method(method_name)
|
||||||
|
|
||||||
|
if syslog:
|
||||||
|
ssyslog.start_syslog()
|
||||||
|
ssyslog.stderr_to_syslog()
|
||||||
|
|
||||||
|
debug1('firewall manager ready method name %s.\n' % method_name)
|
||||||
|
stdout.write('READY %s\n' % method_name)
|
||||||
|
stdout.flush()
|
||||||
|
|
||||||
# we wait until we get some input before creating the rules. That way,
|
# we wait until we get some input before creating the rules. That way,
|
||||||
# sshuttle can launch us as early as possible (and get sudo password
|
# sshuttle can launch us as early as possible (and get sudo password
|
||||||
# authentication as early in the startup process as possible).
|
# authentication as early in the startup process as possible).
|
||||||
line = sys.stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
if not line:
|
if not line:
|
||||||
return # parent died; nothing to do
|
return # parent died; nothing to do
|
||||||
|
|
||||||
@ -103,7 +110,7 @@ def main(method_name, syslog):
|
|||||||
if line != 'ROUTES\n':
|
if line != 'ROUTES\n':
|
||||||
raise Fatal('firewall: expected ROUTES but got %r' % line)
|
raise Fatal('firewall: expected ROUTES but got %r' % line)
|
||||||
while 1:
|
while 1:
|
||||||
line = sys.stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('firewall: expected route but got %r' % line)
|
raise Fatal('firewall: expected route but got %r' % line)
|
||||||
elif line.startswith("NSLIST\n"):
|
elif line.startswith("NSLIST\n"):
|
||||||
@ -119,7 +126,7 @@ def main(method_name, syslog):
|
|||||||
if line != 'NSLIST\n':
|
if line != 'NSLIST\n':
|
||||||
raise Fatal('firewall: expected NSLIST but got %r' % line)
|
raise Fatal('firewall: expected NSLIST but got %r' % line)
|
||||||
while 1:
|
while 1:
|
||||||
line = sys.stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('firewall: expected nslist but got %r' % line)
|
raise Fatal('firewall: expected nslist but got %r' % line)
|
||||||
elif line.startswith("PORTS "):
|
elif line.startswith("PORTS "):
|
||||||
@ -155,7 +162,7 @@ def main(method_name, syslog):
|
|||||||
debug2('Got ports: %d,%d,%d,%d\n'
|
debug2('Got ports: %d,%d,%d,%d\n'
|
||||||
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
||||||
|
|
||||||
line = sys.stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('firewall: expected GO but got %r' % line)
|
raise Fatal('firewall: expected GO but got %r' % line)
|
||||||
elif not line.startswith("GO "):
|
elif not line.startswith("GO "):
|
||||||
@ -169,26 +176,28 @@ def main(method_name, syslog):
|
|||||||
do_wait = None
|
do_wait = None
|
||||||
debug1('firewall manager: starting transproxy.\n')
|
debug1('firewall manager: starting transproxy.\n')
|
||||||
|
|
||||||
|
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||||
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||||
if port_v6 > 0:
|
if port_v6 > 0:
|
||||||
do_wait = method.setup_firewall(
|
do_wait = method.setup_firewall(
|
||||||
port_v6, dnsport_v6, nslist,
|
port_v6, dnsport_v6, nslist_v6,
|
||||||
socket.AF_INET6, subnets_v6, udp)
|
socket.AF_INET6, subnets_v6, udp)
|
||||||
elif len(subnets_v6) > 0:
|
elif len(subnets_v6) > 0:
|
||||||
debug1("IPv6 subnets defined but IPv6 disabled\n")
|
debug1("IPv6 subnets defined but IPv6 disabled\n")
|
||||||
|
|
||||||
|
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
|
||||||
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
|
||||||
if port_v4 > 0:
|
if port_v4 > 0:
|
||||||
do_wait = method.setup_firewall(
|
do_wait = method.setup_firewall(
|
||||||
port_v4, dnsport_v4, nslist,
|
port_v4, dnsport_v4, nslist_v4,
|
||||||
socket.AF_INET, subnets_v4, udp)
|
socket.AF_INET, subnets_v4, udp)
|
||||||
elif len(subnets_v4) > 0:
|
elif len(subnets_v4) > 0:
|
||||||
debug1('IPv4 subnets defined but IPv4 disabled\n')
|
debug1('IPv4 subnets defined but IPv4 disabled\n')
|
||||||
|
|
||||||
sys.stdout.write('STARTED\n')
|
stdout.write('STARTED\n')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sys.stdout.flush()
|
stdout.flush()
|
||||||
except IOError:
|
except IOError:
|
||||||
# the parent process died for some reason; he's surely been loud
|
# the parent process died for some reason; he's surely been loud
|
||||||
# enough, so no reason to report another error
|
# enough, so no reason to report another error
|
||||||
@ -198,9 +207,9 @@ def main(method_name, syslog):
|
|||||||
# to stay running so that we don't need a *second* password
|
# to stay running so that we don't need a *second* password
|
||||||
# authentication at shutdown time - that cleanup is important!
|
# authentication at shutdown time - that cleanup is important!
|
||||||
while 1:
|
while 1:
|
||||||
if do_wait:
|
if do_wait is not None:
|
||||||
do_wait()
|
do_wait()
|
||||||
line = sys.stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
if line.startswith('HOST '):
|
if line.startswith('HOST '):
|
||||||
(name, ip) = line[5:].strip().split(',', 1)
|
(name, ip) = line[5:].strip().split(',', 1)
|
||||||
hostmap[name] = ip
|
hostmap[name] = ip
|
||||||
|
101
sshuttle/tests/test_firewall.py
Normal file
101
sshuttle/tests/test_firewall.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
from mock import Mock, patch, call
|
||||||
|
from contextlib import nested
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import os.path
|
||||||
|
import shutil
|
||||||
|
import filecmp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import sshuttle.firewall
|
||||||
|
|
||||||
|
|
||||||
|
def setup_daemon():
|
||||||
|
stdin = io.StringIO(u"""ROUTES
|
||||||
|
2,24,0,1.2.3.0
|
||||||
|
2,32,1,1.2.3.66
|
||||||
|
10,64,0,2404:6800:4004:80c::
|
||||||
|
10,128,1,2404:6800:4004:80c::101f
|
||||||
|
NSLIST
|
||||||
|
2,1.2.3.33
|
||||||
|
10,2404:6800:4004:80c::33
|
||||||
|
PORTS 1024,1025,1026,1027
|
||||||
|
GO 1
|
||||||
|
""")
|
||||||
|
stdout = Mock()
|
||||||
|
return stdin, stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_rewrite_etc_hosts():
|
||||||
|
if not os.path.isdir("tmp"):
|
||||||
|
os.mkdir("tmp")
|
||||||
|
|
||||||
|
with open("tmp/hosts.orig", "w") as f:
|
||||||
|
f.write("1.2.3.3 existing\n")
|
||||||
|
|
||||||
|
shutil.copyfile("tmp/hosts.orig", "tmp/hosts")
|
||||||
|
|
||||||
|
sshuttle.firewall.HOSTSFILE = "tmp/hosts"
|
||||||
|
sshuttle.firewall.hostmap = {
|
||||||
|
'myhost': '1.2.3.4',
|
||||||
|
'myotherhost': '1.2.3.5',
|
||||||
|
}
|
||||||
|
sshuttle.firewall.rewrite_etc_hosts(10)
|
||||||
|
with open("tmp/hosts") as f:
|
||||||
|
line = f.next()
|
||||||
|
s = line.split()
|
||||||
|
assert s == ['1.2.3.3', 'existing']
|
||||||
|
|
||||||
|
line = f.next()
|
||||||
|
s = line.split()
|
||||||
|
assert s == ['1.2.3.4', 'myhost',
|
||||||
|
'#', 'sshuttle-firewall-10', 'AUTOCREATED']
|
||||||
|
|
||||||
|
line = f.next()
|
||||||
|
s = line.split()
|
||||||
|
assert s == ['1.2.3.5', 'myotherhost',
|
||||||
|
'#', 'sshuttle-firewall-10', 'AUTOCREATED']
|
||||||
|
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
line = f.next()
|
||||||
|
|
||||||
|
sshuttle.firewall.restore_etc_hosts(10)
|
||||||
|
assert filecmp.cmp("tmp/hosts.orig", "tmp/hosts", shallow=False) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_main():
|
||||||
|
with nested(
|
||||||
|
patch('sshuttle.firewall.setup_daemon'),
|
||||||
|
patch('sshuttle.firewall.get_method')
|
||||||
|
) as (mock_setup_daemon, mock_get_method):
|
||||||
|
stdin, stdout = setup_daemon()
|
||||||
|
mock_setup_daemon.return_value = stdin, stdout
|
||||||
|
|
||||||
|
sshuttle.firewall.main("test", False)
|
||||||
|
|
||||||
|
stdout.mock_calls == [
|
||||||
|
call.write('READY test\n'),
|
||||||
|
call.flush(),
|
||||||
|
call.write('STARTED\n'),
|
||||||
|
call.flush()
|
||||||
|
]
|
||||||
|
mock_setup_daemon.mock_calls == [call()]
|
||||||
|
mock_get_method.mock_calls == [
|
||||||
|
call('test'),
|
||||||
|
call().setup_firewall(
|
||||||
|
1024, 1026,
|
||||||
|
[(10, u'2404:6800:4004:80c::33')],
|
||||||
|
10,
|
||||||
|
[(10, 64, False, u'2404:6800:4004:80c::'),
|
||||||
|
(10, 128, True, u'2404:6800:4004:80c::101f')],
|
||||||
|
True),
|
||||||
|
call().setup_firewall(
|
||||||
|
1025, 1027,
|
||||||
|
[(2, u'1.2.3.33')],
|
||||||
|
2,
|
||||||
|
[(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')],
|
||||||
|
True),
|
||||||
|
call().setup_firewall()(),
|
||||||
|
call().setup_firewall(1024, 0, [], 10, [], True),
|
||||||
|
call().setup_firewall(1025, 0, [], 2, [], True),
|
||||||
|
]
|
@ -1 +0,0 @@
|
|||||||
..
|
|
Loading…
Reference in New Issue
Block a user