diff --git a/sshuttle/cmdline.py b/sshuttle/cmdline.py index 11d6796..a32c6e7 100644 --- a/sshuttle/cmdline.py +++ b/sshuttle/cmdline.py @@ -9,7 +9,7 @@ import sshuttle.firewall as firewall import sshuttle.hostwatch as hostwatch import sshuttle.ssyslog as ssyslog from sshuttle.options import parser, parse_ipport -from sshuttle.helpers import family_ip_tuple, log, Fatal +from sshuttle.helpers import family_ip_tuple, log, Fatal, start_stdout_stderr_flush_thread from sshuttle.sudoers import sudoers diff --git a/sshuttle/hostwatch.py b/sshuttle/hostwatch.py index 35ab2cc..1884165 100644 --- a/sshuttle/hostwatch.py +++ b/sshuttle/hostwatch.py @@ -18,6 +18,8 @@ CACHEFILE = os.path.expanduser('~/.sshuttle.hosts') # Have we already failed to write CACHEFILE? CACHE_WRITE_FAILED = False +SHOULD_WRITE_CACHE = False + hostnames = {} queue = {} try: @@ -81,6 +83,11 @@ def read_host_cache(): ip = re.sub(r'[^0-9.]', '', ip).strip() if name and ip: found_host(name, ip) + f.close() + global SHOULD_WRITE_CACHE + if SHOULD_WRITE_CACHE: + write_host_cache() + SHOULD_WRITE_CACHE = False def found_host(name, ip): @@ -97,12 +104,13 @@ def found_host(name, ip): if hostname != name: found_host(hostname, ip) + global SHOULD_WRITE_CACHE oldip = hostnames.get(name) if oldip != ip: hostnames[name] = ip debug1('Found: %s: %s' % (name, ip)) sys.stdout.write('%s,%s\n' % (name, ip)) - write_host_cache() + SHOULD_WRITE_CACHE = True def _check_etc_hosts(): diff --git a/tests/client/test_helpers.py b/tests/client/test_helpers.py index bfbb145..794c284 100644 --- a/tests/client/test_helpers.py +++ b/tests/client/test_helpers.py @@ -2,6 +2,7 @@ import io import socket from socket import AF_INET, AF_INET6 import errno +import time from unittest.mock import patch, call import sshuttle.helpers