diff --git a/client.py b/client.py index 8d0e8c4..3074697 100644 --- a/client.py +++ b/client.py @@ -104,9 +104,9 @@ class MultiListener: if self.v4: self.v4.setsockopt(level, optname, value) - def add_handler(self, handlers, callback, mux): + def add_handler(self, handlers, callback, method, mux): if self.v4: - handlers.append(Handler([self.v4], lambda: callback(self.v4, mux, handlers))) + handlers.append(Handler([self.v4], lambda: callback(self.v4, method, mux, handlers))) def listen(self, backlog): if self.v4: @@ -125,7 +125,7 @@ class MultiListener: class FirewallClient: - def __init__(self, port, subnets_include, subnets_exclude, dnsport): + def __init__(self, port, subnets_include, subnets_exclude, dnsport, method): self.port = port self.auto_nets = [] self.subnets_include = subnets_include @@ -133,7 +133,9 @@ class FirewallClient: self.dnsport = dnsport argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] + ['-v'] * (helpers.verbose or 0) + - ['--firewall', str(port), str(dnsport)]) + ['--firewall', str(port), + str(dnsport), + method]) if ssyslog._p: argvbase += ['--syslog'] argv_tries = [ @@ -170,8 +172,9 @@ class FirewallClient: raise Fatal(e) line = self.pfile.readline() self.check() - if line != 'READY\n': + if line[0:5] != 'READY': raise Fatal('%r expected READY, got %r' % (self.argv, line)) + self.method = line[6:-1] def check(self): rv = self.p.poll() @@ -214,7 +217,7 @@ def expire_connections(now, mux): debug3('Remaining DNS requests: %d\n' % len(dnsreqs)) -def onaccept_tcp(listener, mux, handlers): +def onaccept_tcp(listener, method, mux, handlers): global _extra_fd try: sock,srcip = listener.accept() @@ -249,7 +252,7 @@ def onaccept_tcp(listener, mux, handlers): expire_connections(time.time(), mux) -def dns_done(chan, mux, data): +def dns_done(chan, mux, data, method): peer,sock,timeout = dnsreqs.get(chan) or (None,None,None) debug3('dns_done: channel=%r peer=%r\n' % (chan, peer)) if peer: @@ -259,7 +262,7 @@ def dns_done(chan, mux, data): sock.sendto(data, peer) -def ondns(listener, mux, handlers): +def ondns(listener, method, mux, handlers): pkt,peer = listener.recvfrom(4096) now = time.time() if pkt: @@ -267,12 +270,12 @@ def ondns(listener, mux, handlers): chan = mux.next_channel() dnsreqs[chan] = peer,listener,now+30 mux.send(chan, ssnet.CMD_DNS_REQ, pkt) - mux.channels[chan] = lambda cmd,data: dns_done(chan, mux, data) + mux.channels[chan] = lambda cmd,data: dns_done(chan, mux, data, method) expire_connections(now, mux) def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, - dns_listener, seed_hosts, auto_nets, + dns_listener, method, seed_hosts, auto_nets, syslog, daemon): handlers = [] if helpers.verbose >= 1: @@ -351,10 +354,10 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, fw.sethostip(name, ip) mux.got_host_list = onhostlist - tcp_listener.add_handler(handlers, onaccept_tcp, mux) + tcp_listener.add_handler(handlers, onaccept_tcp, method, mux) if dns_listener: - dns_listener.add_handler(handlers, ondns, mux) + dns_listener.add_handler(handlers, ondns, method, mux) if seed_hosts != None: debug1('seed_hosts: %r\n' % seed_hosts) @@ -371,8 +374,9 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, mux.callback() -def main(listenip_v4, ssh_cmd, remotename, python, latency_control, dns, - seed_hosts, auto_nets, +def main(listenip_v4, + ssh_cmd, remotename, python, latency_control, dns, + method, seed_hosts, auto_nets, subnets_include, subnets_exclude, syslog, daemon, pidfile): if syslog: ssyslog.start_syslog() @@ -462,12 +466,13 @@ def main(listenip_v4, ssh_cmd, remotename, python, latency_control, dns, dnsport_v4 = 0 dns_listener = None - fw = FirewallClient(redirectport_v4, subnets_include, subnets_exclude, dnsport_v4) - + fw = FirewallClient(redirectport_v4, subnets_include, subnets_exclude, dnsport_v4, method) + try: return _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, dns_listener, - seed_hosts, auto_nets, syslog, daemon) + fw.method, seed_hosts, auto_nets, syslog, + daemon) finally: try: if daemon: diff --git a/firewall.py b/firewall.py index b68d7a8..2136f21 100644 --- a/firewall.py +++ b/firewall.py @@ -65,7 +65,7 @@ def _ipt_ttl(family, *args): # multiple copies shouldn't have overlapping subnets, or only the most- # recently-started one will win (because we use "-I OUTPUT 1" instead of # "-A OUTPUT"). -def do_iptables(port, dnsport, family, subnets): +def do_iptables_nat(port, dnsport, family, subnets): # only ipv4 supported with NAT if family != socket.AF_INET: raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family)) @@ -389,7 +389,7 @@ def restore_etc_hosts(port): # exit. In case that fails, it's not the end of the world; future runs will # supercede it in the transproxy list, at least, so the leftover rules # are hopefully harmless. -def main(port_v4, dnsport_v4, syslog): +def main(port_v4, dnsport_v4, method, syslog): assert(port_v4 > 0) assert(port_v4 <= 65535) assert(dnsport_v4 >= 0) @@ -398,12 +398,20 @@ def main(port_v4, dnsport_v4, syslog): if os.getuid() != 0: raise Fatal('you must be root (or enable su/sudo) to set the firewall') - if program_exists('ipfw'): + if method == "auto": + if program_exists('ipfw'): + method = "ipfw" + elif program_exists('iptables'): + method = "nat" + else: + raise Fatal("can't find either ipfw or iptables; check your PATH") + + if method == "nat": + do_it = do_iptables_nat + elif method == "ipfw": do_it = do_ipfw - elif program_exists('iptables'): - do_it = do_iptables else: - raise Fatal("can't find either ipfw or iptables; check your PATH") + raise Exception('Unknown method "%s"'%method) # because of limitations of the 'su' command, the *real* stdin/stdout # are both attached to stdout initially. Clone stdout into stdin so we @@ -414,8 +422,8 @@ def main(port_v4, dnsport_v4, syslog): ssyslog.start_syslog() ssyslog.stderr_to_syslog() - debug1('firewall manager ready.\n') - sys.stdout.write('READY\n') + debug1('firewall manager ready method %s.\n'%method) + sys.stdout.write('READY %s\n'%method) sys.stdout.flush() # ctrl-c shouldn't be passed along to me. When the main sshuttle dies, diff --git a/main.py b/main.py index 2dec9df..a08bf80 100644 --- a/main.py +++ b/main.py @@ -54,6 +54,7 @@ l,listen= transproxy to this ip address and port number [127.0.0.1:0] H,auto-hosts scan for remote hostnames and update local /etc/hosts N,auto-nets automatically determine subnets to route dns capture local DNS requests and forward to the remote DNS server +method= auto, nat, or ipfw python= path to python interpreter on the remote server r,remote= ssh hostname (and optional username) of remote sshuttle server x,exclude= exclude this subnet (can be used more than once) @@ -86,9 +87,10 @@ try: server.latency_control = opt.latency_control sys.exit(server.main()) elif opt.firewall: - if len(extra) != 2: - o.fatal('exactly two arguments expected') - sys.exit(firewall.main(int(extra[0]), int(extra[1]), opt.syslog)) + if len(extra) != 3: + o.fatal('exactly three arguments expected') + sys.exit(firewall.main(int(extra[0]), int(extra[1]), + extra[2], opt.syslog)) elif opt.hostwatch: sys.exit(hostwatch.hw_main(extra)) else: @@ -110,12 +112,20 @@ try: sh = [] else: sh = None - sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), + if not opt.method: + method = "auto" + elif opt.method in [ "auto", "nat", "ipfw" ]: + method = opt.method + else: + o.fatal("method %s not supported"%opt.method) + ipport_v4 = parse_ipport(opt.listen or '0.0.0.0:0') + sys.exit(client.main(ipport_v4, opt.ssh_cmd, remotename, opt.python, opt.latency_control, opt.dns, + method, sh, opt.auto_nets, parse_subnets(includes),