diff --git a/client.py b/client.py index db4334d..1298277 100644 --- a/client.py +++ b/client.py @@ -95,6 +95,35 @@ def original_dst(sock): raise +class MultiListener: + + def __init__(self, type=socket.SOCK_STREAM, proto=0): + self.v4 = socket.socket(socket.AF_INET, type, proto) + + def setsockopt(self, level, optname, value): + if self.v4: + self.v4.setsockopt(level, optname, value) + + def add_handler(self, handlers, callback, mux): + if self.v4: + handlers.append(Handler([self.v4], lambda: callback(self.v4, mux, handlers))) + + def listen(self, backlog): + if self.v4: + self.v4.listen(backlog) + + def bind(self, address_v4): + if address_v4 and self.v4: + self.v4.bind(address_v4) + else: + self.v4 = None + + def print_listening(self, what): + if self.v4: + listenip = self.v4.getsockname() + debug1('%s listening on %r.\n' % (what, listenip)) + + class FirewallClient: def __init__(self, port, subnets_include, subnets_exclude, dnsport): self.port = port @@ -322,10 +351,10 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, fw.sethostip(name, ip) mux.got_host_list = onhostlist - handlers.append(Handler([tcp_listener], lambda: onaccept_tcp(tcp_listener, mux, handlers))) + tcp_listener.add_handler(handlers, onaccept_tcp, mux) if dns_listener: - handlers.append(Handler([dns_listener], lambda: ondns(dns_listener, mux, handlers))) + dns_listener.add_handler(handlers, ondns, mux) if seed_hosts != None: debug1('seed_hosts: %r\n' % seed_hosts) @@ -364,9 +393,9 @@ def main(listenip, ssh_cmd, remotename, python, latency_control, dns, debug2('Binding:') for port in ports: debug2(' %d' % port) - tcp_listener = socket.socket() + tcp_listener = MultiListener() tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - dns_listener = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + dns_listener = MultiListener(socket.SOCK_DGRAM) dns_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: tcp_listener.bind((listenip[0], port)) @@ -380,11 +409,11 @@ def main(listenip, ssh_cmd, remotename, python, latency_control, dns, assert(last_e) raise last_e tcp_listener.listen(10) - listenip = tcp_listener.getsockname() + listenip = tcp_listener.v4.getsockname() debug1('Listening on %r.\n' % (listenip,)) if dns: - dnsip = dns_listener.getsockname() + dnsip = dns_listener.v4.getsockname() debug1('DNS listening on %r.\n' % (dnsip,)) dnsport = dnsip[1] else: