Keep track of address family address belongs too.

This commit is contained in:
Brian May 2011-06-06 11:57:42 +10:00
parent 50849b86b0
commit 061e6a0933
6 changed files with 74 additions and 39 deletions

View File

@ -180,10 +180,10 @@ class FirewallClient:
def start(self): def start(self):
self.pfile.write('ROUTES\n') self.pfile.write('ROUTES\n')
for (ip,width) in self.subnets_include+self.auto_nets: for (family,ip,width) in self.subnets_include+self.auto_nets:
self.pfile.write('%d,0,%s\n' % (width, ip)) self.pfile.write('%d,%d,0,%s\n' % (family, width, ip))
for (ip,width) in self.subnets_exclude: for (family,ip,width) in self.subnets_exclude:
self.pfile.write('%d,1,%s\n' % (width, ip)) self.pfile.write('%d,%d,1,%s\n' % (family, width, ip))
self.pfile.write('GO\n') self.pfile.write('GO\n')
self.pfile.flush() self.pfile.flush()
line = self.pfile.readline() line = self.pfile.readline()
@ -234,7 +234,7 @@ def onaccept_tcp(listener, mux, handlers):
dstip = original_dst(sock) dstip = original_dst(sock)
debug1('Accept: %s:%r -> %s:%r.\n' % (srcip[0],srcip[1], debug1('Accept: %s:%r -> %s:%r.\n' % (srcip[0],srcip[1],
dstip[0],dstip[1])) dstip[0],dstip[1]))
if dstip[1] == listener.getsockname()[1] and islocal(dstip[0]): if dstip[1] == listener.getsockname()[1] and islocal(dstip[0], sock.family):
debug1("-- ignored: that's my address!\n") debug1("-- ignored: that's my address!\n")
sock.close() sock.close()
return return
@ -243,7 +243,7 @@ def onaccept_tcp(listener, mux, handlers):
log('warning: too many open channels. Discarded connection.\n') log('warning: too many open channels. Discarded connection.\n')
sock.close() sock.close()
return return
mux.send(chan, ssnet.CMD_TCP_CONNECT, '%s,%s' % dstip) mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' % (sock.family, dstip[0], dstip[1]))
outwrap = MuxWrapper(mux, chan) outwrap = MuxWrapper(mux, chan)
handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
expire_connections(time.time(), mux) expire_connections(time.time(), mux)
@ -329,8 +329,8 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
def onroutes(routestr): def onroutes(routestr):
if auto_nets: if auto_nets:
for line in routestr.strip().split('\n'): for line in routestr.strip().split('\n'):
(ip,width) = line.split(',', 1) (family,ip,width) = line.split(',', 2)
fw.auto_nets.append((ip,int(width))) fw.auto_nets.append((family,ip,int(width)))
# we definitely want to do this *after* starting ssh, or we might end # we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection! # up intercepting the ssh connection!

View File

@ -14,8 +14,12 @@ def nonfatal(func, *args):
log('error: %s\n' % e) log('error: %s\n' % e)
def ipt_chain_exists(name): def ipt_chain_exists(family, name):
argv = ['iptables', '-t', 'nat', '-nL'] if family == socket.AF_INET:
cmd = 'iptables'
else:
raise Exception('Unsupported family "%s"'%family_to_string(family))
argv = [cmd, '-t', 'nat', '-nL']
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE) p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE)
for line in p.stdout: for line in p.stdout:
if line.startswith('Chain %s ' % name): if line.startswith('Chain %s ' % name):
@ -25,8 +29,11 @@ def ipt_chain_exists(name):
raise Fatal('%r returned %d' % (argv, rv)) raise Fatal('%r returned %d' % (argv, rv))
def ipt(*args): def _ipt(family, *args):
argv = ['iptables', '-t', 'nat'] + list(args) if family == socket.AF_INET:
argv = ['iptables', '-t', 'nat'] + list(args)
else:
raise Exception('Unsupported family "%s"'%family_to_string(family))
debug1('>> %s\n' % ' '.join(argv)) debug1('>> %s\n' % ' '.join(argv))
rv = ssubprocess.call(argv) rv = ssubprocess.call(argv)
if rv: if rv:
@ -34,7 +41,7 @@ def ipt(*args):
_no_ttl_module = False _no_ttl_module = False
def ipt_ttl(*args): def _ipt_ttl(family, *args):
global _no_ttl_module global _no_ttl_module
if not _no_ttl_module: if not _no_ttl_module:
# we avoid infinite loops by generating server-side connections # we avoid infinite loops by generating server-side connections
@ -42,16 +49,15 @@ def ipt_ttl(*args):
# connections, in case client == server. # connections, in case client == server.
try: try:
argsplus = list(args) + ['-m', 'ttl', '!', '--ttl', '42'] argsplus = list(args) + ['-m', 'ttl', '!', '--ttl', '42']
ipt(*argsplus) _ipt(family, *argsplus)
except Fatal: except Fatal:
ipt(*args) _ipt(family, *args)
# we only get here if the non-ttl attempt succeeds # we only get here if the non-ttl attempt succeeds
log('sshuttle: warning: your iptables is missing ' log('sshuttle: warning: your iptables is missing '
'the ttl module.\n') 'the ttl module.\n')
_no_ttl_module = True _no_ttl_module = True
else: else:
ipt(*args) _ipt(family, *args)
# We name the chain based on the transproxy port number so that it's possible # We name the chain based on the transproxy port number so that it's possible
@ -59,11 +65,20 @@ def ipt_ttl(*args):
# multiple copies shouldn't have overlapping subnets, or only the most- # multiple copies shouldn't have overlapping subnets, or only the most-
# recently-started one will win (because we use "-I OUTPUT 1" instead of # recently-started one will win (because we use "-I OUTPUT 1" instead of
# "-A OUTPUT"). # "-A OUTPUT").
def do_iptables(port, dnsport, subnets): def do_iptables(port, dnsport, family, subnets):
# only ipv4 supported with NAT
if family != socket.AF_INET:
raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family))
def ipt(*args):
return _ipt(family, *args)
def ipt_ttl(*args):
return _ipt_ttl(family, *args)
chain = 'sshuttle-%s' % port chain = 'sshuttle-%s' % port
# basic cleanup/setup of chains # basic cleanup/setup of chains
if ipt_chain_exists(chain): if ipt_chain_exists(family, chain):
nonfatal(ipt, '-D', 'OUTPUT', '-j', chain) nonfatal(ipt, '-D', 'OUTPUT', '-j', chain)
nonfatal(ipt, '-D', 'PREROUTING', '-j', chain) nonfatal(ipt, '-D', 'PREROUTING', '-j', chain)
nonfatal(ipt, '-F', chain) nonfatal(ipt, '-F', chain)
@ -81,7 +96,7 @@ def do_iptables(port, dnsport, subnets):
# to least-specific, and at any given level of specificity, we want # to least-specific, and at any given level of specificity, we want
# excludes to come first. That's why the columns are in such a non- # excludes to come first. That's why the columns are in such a non-
# intuitive order. # intuitive order.
for swidth,sexclude,snet in sorted(subnets, reverse=True): for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude: if sexclude:
ipt('-A', chain, '-j', 'RETURN', ipt('-A', chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet,swidth),
@ -207,7 +222,11 @@ def ipfw(*args):
raise Fatal('%r returned %d' % (argv, rv)) raise Fatal('%r returned %d' % (argv, rv))
def do_ipfw(port, dnsport, subnets): def do_ipfw(port, dnsport, family, subnets):
# IPv6 not supported
if family not in [socket.AF_INET, ]:
raise Exception('Address family "%s" unsupported by ipfw method'%family_to_string(family))
sport = str(port) sport = str(port)
xsport = str(port+1) xsport = str(port+1)
@ -240,7 +259,7 @@ def do_ipfw(port, dnsport, subnets):
if subnets: if subnets:
# create new subnet entries # create new subnet entries
for swidth,sexclude,snet in sorted(subnets, reverse=True): for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude: if sexclude:
ipfw('add', sport, 'skipto', xsport, ipfw('add', sport, 'skipto', xsport,
'log', 'tcp', 'log', 'tcp',
@ -419,15 +438,21 @@ def main(port, dnsport, syslog):
elif line == 'GO\n': elif line == 'GO\n':
break break
try: try:
(width,exclude,ip) = line.strip().split(',', 2) (family,width,exclude,ip) = line.strip().split(',', 3)
except: except:
raise Fatal('firewall: expected route or GO but got %r' % line) raise Fatal('firewall: expected route or GO but got %r' % line)
subnets.append((int(width), bool(int(exclude)), ip)) subnets.append((int(family), int(width), bool(int(exclude)), ip))
try: try:
if line: if line:
debug1('firewall manager: starting transproxy.\n') debug1('firewall manager: starting transproxy.\n')
do_wait = do_it(port, dnsport, subnets)
subnets_v4 = filter(lambda i: i[0]==socket.AF_INET, subnets)
if port:
do_wait = do_it(port, dnsport, socket.AF_INET, subnets_v4)
elif len(subnets_v4) > 0:
debug1('IPv4 subnets defined but IPv4 disabled\n')
sys.stdout.write('STARTED\n') sys.stdout.write('STARTED\n')
try: try:
@ -456,5 +481,6 @@ def main(port, dnsport, syslog):
debug1('firewall manager: undoing changes.\n') debug1('firewall manager: undoing changes.\n')
except: except:
pass pass
do_it(port, 0, []) if port:
do_it(port, 0, socket.AF_INET, [])
restore_etc_hosts(port) restore_etc_hosts(port)

View File

@ -58,8 +58,8 @@ def resolvconf_random_nameserver():
return '127.0.0.1' return '127.0.0.1'
def islocal(ip): def islocal(ip,family):
sock = socket.socket() sock = socket.socket(family)
try: try:
try: try:
sock.bind((ip, 0)) sock.bind((ip, 0))
@ -73,3 +73,11 @@ def islocal(ip):
return True # it's a local IP, or there would have been an error return True # it's a local IP, or there would have been an error
def family_to_string(family):
if family == socket.AF_INET6:
return "AF_INET6"
elif family == socket.AF_INET:
return "AF_INET"
else:
return str(family)

View File

@ -1,4 +1,4 @@
import sys, os, re import sys, os, re, socket
import helpers, options, client, server, firewall, hostwatch import helpers, options, client, server, firewall, hostwatch
import compat.ssubprocess as ssubprocess import compat.ssubprocess as ssubprocess
from helpers import * from helpers import *
@ -22,7 +22,7 @@ def parse_subnets(subnets_str):
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d))
if width > 32: if width > 32:
raise Fatal('*/%d is greater than the maximum of 32' % width) raise Fatal('*/%d is greater than the maximum of 32' % width)
subnets.append(('%d.%d.%d.%d' % (a,b,c,d), width)) subnets.append((socket.AF_INET, '%d.%d.%d.%d' % (a,b,c,d), width))
return subnets return subnets

View File

@ -59,7 +59,7 @@ def _list_routes():
mask = _maskbits(maskw) # returns 32 if maskw is null mask = _maskbits(maskw) # returns 32 if maskw is null
width = min(ipw[1], mask) width = min(ipw[1], mask)
ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width) ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width)
routes.append((socket.inet_ntoa(struct.pack('!I', ip)), width)) routes.append((socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width))
rv = p.wait() rv = p.wait()
if rv != 0: if rv != 0:
log('WARNING: %r returned %d\n' % (argv, rv)) log('WARNING: %r returned %d\n' % (argv, rv))
@ -68,9 +68,9 @@ def _list_routes():
def list_routes(): def list_routes():
for (ip,width) in _list_routes(): for (family, ip,width) in _list_routes():
if not ip.startswith('0.') and not ip.startswith('127.'): if not ip.startswith('0.') and not ip.startswith('127.'):
yield (ip,width) yield (family, ip,width)
def _exc_dump(): def _exc_dump():
@ -170,7 +170,7 @@ def main():
routes = list(list_routes()) routes = list(list_routes())
debug1('available routes:\n') debug1('available routes:\n')
for r in routes: for r in routes:
debug1(' %s/%d\n' % r) debug1(' %d/%s/%d\n' % r)
# synchronization header # synchronization header
sys.stdout.write('\0\0SSHUTTLE0001') sys.stdout.write('\0\0SSHUTTLE0001')
@ -184,7 +184,7 @@ def main():
handlers.append(mux) handlers.append(mux)
routepkt = '' routepkt = ''
for r in routes: for r in routes:
routepkt += '%s,%d\n' % r routepkt += '%d,%s,%d\n' % r
mux.send(0, ssnet.CMD_ROUTES, routepkt) mux.send(0, ssnet.CMD_ROUTES, routepkt)
hw = Hostwatch() hw = Hostwatch()
@ -213,9 +213,10 @@ def main():
mux.got_host_req = got_host_req mux.got_host_req = got_host_req
def new_channel(channel, data): def new_channel(channel, data):
(dstip,dstport) = data.split(',', 1) (family,dstip,dstport) = data.split(',', 2)
family = int(family)
dstport = int(dstport) dstport = int(dstport)
outwrap = ssnet.connect_dst(dstip,dstport) outwrap = ssnet.connect_dst(family, dstip, dstport)
handlers.append(Proxy(MuxWrapper(mux, channel), outwrap)) handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
mux.new_channel = new_channel mux.new_channel = new_channel

View File

@ -523,9 +523,9 @@ class MuxWrapper(SockWrapper):
% (cmd, len(data))) % (cmd, len(data)))
def connect_dst(ip, port): def connect_dst(family, ip, port):
debug2('Connecting to %s:%d\n' % (ip, port)) debug2('Connecting to %s:%d\n' % (ip, port))
outsock = socket.socket() outsock = socket.socket(family)
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
return SockWrapper(outsock, outsock, return SockWrapper(outsock, outsock,
connect_to = (ip,port), connect_to = (ip,port),