Fixup firewall tests.

This commit is contained in:
Brian May 2015-12-06 11:02:31 +11:00
parent 53c07f7d90
commit bd97506f7d

View File

@ -1,9 +1,5 @@
from mock import Mock, patch, call from mock import Mock, patch, call
import io import io
import os
import os.path
import shutil
import filecmp
import sshuttle.firewall import sshuttle.firewall
@ -19,27 +15,27 @@ NSLIST
10,2404:6800:4004:80c::33 10,2404:6800:4004:80c::33
PORTS 1024,1025,1026,1027 PORTS 1024,1025,1026,1027
GO 1 GO 1
HOST 1.2.3.3,existing
""") """)
stdout = Mock() stdout = Mock()
return stdin, stdout return stdin, stdout
@patch('sshuttle.firewall.HOSTSFILE', new='tmp/hosts') def test_rewrite_etc_hosts(tmpdir):
@patch('sshuttle.firewall.hostmap', new={ orig_hosts = tmpdir.join("hosts.orig")
orig_hosts.write("1.2.3.3 existing\n")
new_hosts = tmpdir.join("hosts")
orig_hosts.copy(new_hosts)
hostmap = {
'myhost': '1.2.3.4', 'myhost': '1.2.3.4',
'myotherhost': '1.2.3.5', 'myotherhost': '1.2.3.5',
}) }
def test_rewrite_etc_hosts(): with patch('sshuttle.firewall.HOSTSFILE', new=str(new_hosts)):
if not os.path.isdir("tmp"): sshuttle.firewall.rewrite_etc_hosts(hostmap, 10)
os.mkdir("tmp")
with open("tmp/hosts.orig", "w") as f: with new_hosts.open() as f:
f.write("1.2.3.3 existing\n")
shutil.copyfile("tmp/hosts.orig", "tmp/hosts")
sshuttle.firewall.rewrite_etc_hosts(10)
with open("tmp/hosts") as f:
line = f.readline() line = f.readline()
s = line.split() s = line.split()
assert s == ['1.2.3.3', 'existing'] assert s == ['1.2.3.3', 'existing']
@ -57,39 +53,37 @@ def test_rewrite_etc_hosts():
line = f.readline() line = f.readline()
assert line == "" assert line == ""
with patch('sshuttle.firewall.HOSTSFILE', new=str(new_hosts)):
sshuttle.firewall.restore_etc_hosts(10) sshuttle.firewall.restore_etc_hosts(10)
assert filecmp.cmp("tmp/hosts.orig", "tmp/hosts", shallow=False) is True assert orig_hosts.computehash() == new_hosts.computehash()
@patch('sshuttle.firewall.HOSTSFILE', new='tmp/hosts') @patch('sshuttle.firewall.rewrite_etc_hosts')
@patch('sshuttle.firewall.setup_daemon') @patch('sshuttle.firewall.setup_daemon')
@patch('sshuttle.firewall.get_method') @patch('sshuttle.firewall.get_method')
def test_main(mock_get_method, mock_setup_daemon): def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
stdin, stdout = setup_daemon() stdin, stdout = setup_daemon()
mock_setup_daemon.return_value = stdin, stdout mock_setup_daemon.return_value = stdin, stdout
if not os.path.isdir("tmp"): mock_get_method("not_auto").name = "test"
os.mkdir("tmp") mock_get_method.reset_mock()
sshuttle.firewall.main("test", False) sshuttle.firewall.main("not_auto", False)
with open("tmp/hosts") as f: assert mock_rewrite_etc_hosts.mock_calls == [
line = f.readline() call({'1.2.3.3': 'existing'}, 1024),
s = line.split() call({}, 1024),
assert s == ['1.2.3.3', 'existing'] ]
line = f.readline() assert stdout.mock_calls == [
assert line == ""
stdout.mock_calls == [
call.write('READY test\n'), call.write('READY test\n'),
call.flush(), call.flush(),
call.write('STARTED\n'), call.write('STARTED\n'),
call.flush() call.flush()
] ]
mock_setup_daemon.mock_calls == [call()] assert mock_setup_daemon.mock_calls == [call()]
mock_get_method.mock_calls == [ assert mock_get_method.mock_calls == [
call('test'), call('not_auto'),
call().setup_firewall( call().setup_firewall(
1024, 1026, 1024, 1026,
[(10, u'2404:6800:4004:80c::33')], [(10, u'2404:6800:4004:80c::33')],
@ -104,6 +98,7 @@ def test_main(mock_get_method, mock_setup_daemon):
[(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')], [(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')],
True), True),
call().setup_firewall()(), call().setup_firewall()(),
call().setup_firewall()(),
call().setup_firewall(1024, 0, [], 10, [], True), call().setup_firewall(1024, 0, [], 10, [], True),
call().setup_firewall(1025, 0, [], 2, [], True), call().setup_firewall(1025, 0, [], 2, [], True),
] ]