diff --git a/client.py b/client.py index 1298277..d1bd6f7 100644 --- a/client.py +++ b/client.py @@ -180,10 +180,10 @@ class FirewallClient: def start(self): self.pfile.write('ROUTES\n') - for (ip,width) in self.subnets_include+self.auto_nets: - self.pfile.write('%d,0,%s\n' % (width, ip)) - for (ip,width) in self.subnets_exclude: - self.pfile.write('%d,1,%s\n' % (width, ip)) + for (family,ip,width) in self.subnets_include+self.auto_nets: + self.pfile.write('%d,%d,0,%s\n' % (family, width, ip)) + for (family,ip,width) in self.subnets_exclude: + self.pfile.write('%d,%d,1,%s\n' % (family, width, ip)) self.pfile.write('GO\n') self.pfile.flush() line = self.pfile.readline() @@ -234,7 +234,7 @@ def onaccept_tcp(listener, mux, handlers): dstip = original_dst(sock) debug1('Accept: %s:%r -> %s:%r.\n' % (srcip[0],srcip[1], dstip[0],dstip[1])) - if dstip[1] == listener.getsockname()[1] and islocal(dstip[0]): + if dstip[1] == listener.getsockname()[1] and islocal(dstip[0], sock.family): debug1("-- ignored: that's my address!\n") sock.close() return @@ -243,7 +243,7 @@ def onaccept_tcp(listener, mux, handlers): log('warning: too many open channels. Discarded connection.\n') sock.close() return - mux.send(chan, ssnet.CMD_TCP_CONNECT, '%s,%s' % dstip) + mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' % (sock.family, dstip[0], dstip[1])) outwrap = MuxWrapper(mux, chan) handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) expire_connections(time.time(), mux) @@ -329,8 +329,8 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, def onroutes(routestr): if auto_nets: for line in routestr.strip().split('\n'): - (ip,width) = line.split(',', 1) - fw.auto_nets.append((ip,int(width))) + (family,ip,width) = line.split(',', 2) + fw.auto_nets.append((family,ip,int(width))) # we definitely want to do this *after* starting ssh, or we might end # up intercepting the ssh connection! diff --git a/firewall.py b/firewall.py index 4fd8c79..08bff15 100644 --- a/firewall.py +++ b/firewall.py @@ -14,8 +14,12 @@ def nonfatal(func, *args): log('error: %s\n' % e) -def ipt_chain_exists(name): - argv = ['iptables', '-t', 'nat', '-nL'] +def ipt_chain_exists(family, name): + if family == socket.AF_INET: + cmd = 'iptables' + else: + raise Exception('Unsupported family "%s"'%family_to_string(family)) + argv = [cmd, '-t', 'nat', '-nL'] p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE) for line in p.stdout: if line.startswith('Chain %s ' % name): @@ -25,8 +29,11 @@ def ipt_chain_exists(name): raise Fatal('%r returned %d' % (argv, rv)) -def ipt(*args): - argv = ['iptables', '-t', 'nat'] + list(args) +def _ipt(family, *args): + if family == socket.AF_INET: + argv = ['iptables', '-t', 'nat'] + list(args) + else: + raise Exception('Unsupported family "%s"'%family_to_string(family)) debug1('>> %s\n' % ' '.join(argv)) rv = ssubprocess.call(argv) if rv: @@ -34,7 +41,7 @@ def ipt(*args): _no_ttl_module = False -def ipt_ttl(*args): +def _ipt_ttl(family, *args): global _no_ttl_module if not _no_ttl_module: # we avoid infinite loops by generating server-side connections @@ -42,16 +49,15 @@ def ipt_ttl(*args): # connections, in case client == server. try: argsplus = list(args) + ['-m', 'ttl', '!', '--ttl', '42'] - ipt(*argsplus) + _ipt(family, *argsplus) except Fatal: - ipt(*args) + _ipt(family, *args) # we only get here if the non-ttl attempt succeeds log('sshuttle: warning: your iptables is missing ' 'the ttl module.\n') _no_ttl_module = True else: - ipt(*args) - + _ipt(family, *args) # We name the chain based on the transproxy port number so that it's possible @@ -59,11 +65,20 @@ def ipt_ttl(*args): # multiple copies shouldn't have overlapping subnets, or only the most- # recently-started one will win (because we use "-I OUTPUT 1" instead of # "-A OUTPUT"). -def do_iptables(port, dnsport, subnets): +def do_iptables(port, dnsport, family, subnets): + # only ipv4 supported with NAT + if family != socket.AF_INET: + raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family)) + + def ipt(*args): + return _ipt(family, *args) + def ipt_ttl(*args): + return _ipt_ttl(family, *args) + chain = 'sshuttle-%s' % port # basic cleanup/setup of chains - if ipt_chain_exists(chain): + if ipt_chain_exists(family, chain): nonfatal(ipt, '-D', 'OUTPUT', '-j', chain) nonfatal(ipt, '-D', 'PREROUTING', '-j', chain) nonfatal(ipt, '-F', chain) @@ -81,7 +96,7 @@ def do_iptables(port, dnsport, subnets): # to least-specific, and at any given level of specificity, we want # excludes to come first. That's why the columns are in such a non- # intuitive order. - for swidth,sexclude,snet in sorted(subnets, reverse=True): + for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True): if sexclude: ipt('-A', chain, '-j', 'RETURN', '--dest', '%s/%s' % (snet,swidth), @@ -207,7 +222,11 @@ def ipfw(*args): raise Fatal('%r returned %d' % (argv, rv)) -def do_ipfw(port, dnsport, subnets): +def do_ipfw(port, dnsport, family, subnets): + # IPv6 not supported + if family not in [socket.AF_INET, ]: + raise Exception('Address family "%s" unsupported by ipfw method'%family_to_string(family)) + sport = str(port) xsport = str(port+1) @@ -240,7 +259,7 @@ def do_ipfw(port, dnsport, subnets): if subnets: # create new subnet entries - for swidth,sexclude,snet in sorted(subnets, reverse=True): + for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True): if sexclude: ipfw('add', sport, 'skipto', xsport, 'log', 'tcp', @@ -419,15 +438,21 @@ def main(port, dnsport, syslog): elif line == 'GO\n': break try: - (width,exclude,ip) = line.strip().split(',', 2) + (family,width,exclude,ip) = line.strip().split(',', 3) except: raise Fatal('firewall: expected route or GO but got %r' % line) - subnets.append((int(width), bool(int(exclude)), ip)) + subnets.append((int(family), int(width), bool(int(exclude)), ip)) try: if line: debug1('firewall manager: starting transproxy.\n') - do_wait = do_it(port, dnsport, subnets) + + subnets_v4 = filter(lambda i: i[0]==socket.AF_INET, subnets) + if port: + do_wait = do_it(port, dnsport, socket.AF_INET, subnets_v4) + elif len(subnets_v4) > 0: + debug1('IPv4 subnets defined but IPv4 disabled\n') + sys.stdout.write('STARTED\n') try: @@ -456,5 +481,6 @@ def main(port, dnsport, syslog): debug1('firewall manager: undoing changes.\n') except: pass - do_it(port, 0, []) + if port: + do_it(port, 0, socket.AF_INET, []) restore_etc_hosts(port) diff --git a/helpers.py b/helpers.py index af49788..2025348 100644 --- a/helpers.py +++ b/helpers.py @@ -58,8 +58,8 @@ def resolvconf_random_nameserver(): return '127.0.0.1' -def islocal(ip): - sock = socket.socket() +def islocal(ip,family): + sock = socket.socket(family) try: try: sock.bind((ip, 0)) @@ -73,3 +73,11 @@ def islocal(ip): return True # it's a local IP, or there would have been an error +def family_to_string(family): + if family == socket.AF_INET6: + return "AF_INET6" + elif family == socket.AF_INET: + return "AF_INET" + else: + return str(family) + diff --git a/main.py b/main.py index 1cf00af..2dec9df 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -import sys, os, re +import sys, os, re, socket import helpers, options, client, server, firewall, hostwatch import compat.ssubprocess as ssubprocess from helpers import * @@ -22,7 +22,7 @@ def parse_subnets(subnets_str): raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) if width > 32: raise Fatal('*/%d is greater than the maximum of 32' % width) - subnets.append(('%d.%d.%d.%d' % (a,b,c,d), width)) + subnets.append((socket.AF_INET, '%d.%d.%d.%d' % (a,b,c,d), width)) return subnets diff --git a/server.py b/server.py index bdf7ae2..84d2e89 100644 --- a/server.py +++ b/server.py @@ -59,7 +59,7 @@ def _list_routes(): mask = _maskbits(maskw) # returns 32 if maskw is null width = min(ipw[1], mask) ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width) - routes.append((socket.inet_ntoa(struct.pack('!I', ip)), width)) + routes.append((socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width)) rv = p.wait() if rv != 0: log('WARNING: %r returned %d\n' % (argv, rv)) @@ -68,9 +68,9 @@ def _list_routes(): def list_routes(): - for (ip,width) in _list_routes(): + for (family, ip,width) in _list_routes(): if not ip.startswith('0.') and not ip.startswith('127.'): - yield (ip,width) + yield (family, ip,width) def _exc_dump(): @@ -170,7 +170,7 @@ def main(): routes = list(list_routes()) debug1('available routes:\n') for r in routes: - debug1(' %s/%d\n' % r) + debug1(' %d/%s/%d\n' % r) # synchronization header sys.stdout.write('\0\0SSHUTTLE0001') @@ -184,7 +184,7 @@ def main(): handlers.append(mux) routepkt = '' for r in routes: - routepkt += '%s,%d\n' % r + routepkt += '%d,%s,%d\n' % r mux.send(0, ssnet.CMD_ROUTES, routepkt) hw = Hostwatch() @@ -213,9 +213,10 @@ def main(): mux.got_host_req = got_host_req def new_channel(channel, data): - (dstip,dstport) = data.split(',', 1) + (family,dstip,dstport) = data.split(',', 2) + family = int(family) dstport = int(dstport) - outwrap = ssnet.connect_dst(dstip,dstport) + outwrap = ssnet.connect_dst(family, dstip, dstport) handlers.append(Proxy(MuxWrapper(mux, channel), outwrap)) mux.new_channel = new_channel diff --git a/ssnet.py b/ssnet.py index ed84348..904ea3e 100644 --- a/ssnet.py +++ b/ssnet.py @@ -523,9 +523,9 @@ class MuxWrapper(SockWrapper): % (cmd, len(data))) -def connect_dst(ip, port): +def connect_dst(family, ip, port): debug2('Connecting to %s:%d\n' % (ip, port)) - outsock = socket.socket() + outsock = socket.socket(family) outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) return SockWrapper(outsock, outsock, connect_to = (ip,port),