diff --git a/sshuttle/tests/test_firewall.py b/sshuttle/tests/test_firewall.py index 540a5bf..10734d9 100644 --- a/sshuttle/tests/test_firewall.py +++ b/sshuttle/tests/test_firewall.py @@ -1,9 +1,5 @@ from mock import Mock, patch, call import io -import os -import os.path -import shutil -import filecmp import sshuttle.firewall @@ -19,27 +15,27 @@ NSLIST 10,2404:6800:4004:80c::33 PORTS 1024,1025,1026,1027 GO 1 +HOST 1.2.3.3,existing """) stdout = Mock() return stdin, stdout -@patch('sshuttle.firewall.HOSTSFILE', new='tmp/hosts') -@patch('sshuttle.firewall.hostmap', new={ - 'myhost': '1.2.3.4', - 'myotherhost': '1.2.3.5', -}) -def test_rewrite_etc_hosts(): - if not os.path.isdir("tmp"): - os.mkdir("tmp") +def test_rewrite_etc_hosts(tmpdir): + orig_hosts = tmpdir.join("hosts.orig") + orig_hosts.write("1.2.3.3 existing\n") - with open("tmp/hosts.orig", "w") as f: - f.write("1.2.3.3 existing\n") + new_hosts = tmpdir.join("hosts") + orig_hosts.copy(new_hosts) - shutil.copyfile("tmp/hosts.orig", "tmp/hosts") + hostmap = { + 'myhost': '1.2.3.4', + 'myotherhost': '1.2.3.5', + } + with patch('sshuttle.firewall.HOSTSFILE', new=str(new_hosts)): + sshuttle.firewall.rewrite_etc_hosts(hostmap, 10) - sshuttle.firewall.rewrite_etc_hosts(10) - with open("tmp/hosts") as f: + with new_hosts.open() as f: line = f.readline() s = line.split() assert s == ['1.2.3.3', 'existing'] @@ -57,39 +53,37 @@ def test_rewrite_etc_hosts(): line = f.readline() assert line == "" - sshuttle.firewall.restore_etc_hosts(10) - assert filecmp.cmp("tmp/hosts.orig", "tmp/hosts", shallow=False) is True + with patch('sshuttle.firewall.HOSTSFILE', new=str(new_hosts)): + sshuttle.firewall.restore_etc_hosts(10) + assert orig_hosts.computehash() == new_hosts.computehash() -@patch('sshuttle.firewall.HOSTSFILE', new='tmp/hosts') +@patch('sshuttle.firewall.rewrite_etc_hosts') @patch('sshuttle.firewall.setup_daemon') @patch('sshuttle.firewall.get_method') -def test_main(mock_get_method, mock_setup_daemon): +def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts): stdin, stdout = setup_daemon() mock_setup_daemon.return_value = stdin, stdout - if not os.path.isdir("tmp"): - os.mkdir("tmp") + mock_get_method("not_auto").name = "test" + mock_get_method.reset_mock() - sshuttle.firewall.main("test", False) + sshuttle.firewall.main("not_auto", False) - with open("tmp/hosts") as f: - line = f.readline() - s = line.split() - assert s == ['1.2.3.3', 'existing'] + assert mock_rewrite_etc_hosts.mock_calls == [ + call({'1.2.3.3': 'existing'}, 1024), + call({}, 1024), + ] - line = f.readline() - assert line == "" - - stdout.mock_calls == [ + assert stdout.mock_calls == [ call.write('READY test\n'), call.flush(), call.write('STARTED\n'), call.flush() ] - mock_setup_daemon.mock_calls == [call()] - mock_get_method.mock_calls == [ - call('test'), + assert mock_setup_daemon.mock_calls == [call()] + assert mock_get_method.mock_calls == [ + call('not_auto'), call().setup_firewall( 1024, 1026, [(10, u'2404:6800:4004:80c::33')], @@ -104,6 +98,7 @@ def test_main(mock_get_method, mock_setup_daemon): [(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')], True), call().setup_firewall()(), + call().setup_firewall()(), call().setup_firewall(1024, 0, [], 10, [], True), call().setup_firewall(1025, 0, [], 2, [], True), ]