Keep track of address family address belongs too.

This commit is contained in:
Brian May 2011-06-06 11:57:42 +10:00
parent 50849b86b0
commit 061e6a0933
6 changed files with 74 additions and 39 deletions

View File

@ -180,10 +180,10 @@ class FirewallClient:
def start(self):
self.pfile.write('ROUTES\n')
for (ip,width) in self.subnets_include+self.auto_nets:
self.pfile.write('%d,0,%s\n' % (width, ip))
for (ip,width) in self.subnets_exclude:
self.pfile.write('%d,1,%s\n' % (width, ip))
for (family,ip,width) in self.subnets_include+self.auto_nets:
self.pfile.write('%d,%d,0,%s\n' % (family, width, ip))
for (family,ip,width) in self.subnets_exclude:
self.pfile.write('%d,%d,1,%s\n' % (family, width, ip))
self.pfile.write('GO\n')
self.pfile.flush()
line = self.pfile.readline()
@ -234,7 +234,7 @@ def onaccept_tcp(listener, mux, handlers):
dstip = original_dst(sock)
debug1('Accept: %s:%r -> %s:%r.\n' % (srcip[0],srcip[1],
dstip[0],dstip[1]))
if dstip[1] == listener.getsockname()[1] and islocal(dstip[0]):
if dstip[1] == listener.getsockname()[1] and islocal(dstip[0], sock.family):
debug1("-- ignored: that's my address!\n")
sock.close()
return
@ -243,7 +243,7 @@ def onaccept_tcp(listener, mux, handlers):
log('warning: too many open channels. Discarded connection.\n')
sock.close()
return
mux.send(chan, ssnet.CMD_TCP_CONNECT, '%s,%s' % dstip)
mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' % (sock.family, dstip[0], dstip[1]))
outwrap = MuxWrapper(mux, chan)
handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
expire_connections(time.time(), mux)
@ -329,8 +329,8 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
def onroutes(routestr):
if auto_nets:
for line in routestr.strip().split('\n'):
(ip,width) = line.split(',', 1)
fw.auto_nets.append((ip,int(width)))
(family,ip,width) = line.split(',', 2)
fw.auto_nets.append((family,ip,int(width)))
# we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection!

View File

@ -14,8 +14,12 @@ def nonfatal(func, *args):
log('error: %s\n' % e)
def ipt_chain_exists(name):
argv = ['iptables', '-t', 'nat', '-nL']
def ipt_chain_exists(family, name):
if family == socket.AF_INET:
cmd = 'iptables'
else:
raise Exception('Unsupported family "%s"'%family_to_string(family))
argv = [cmd, '-t', 'nat', '-nL']
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE)
for line in p.stdout:
if line.startswith('Chain %s ' % name):
@ -25,8 +29,11 @@ def ipt_chain_exists(name):
raise Fatal('%r returned %d' % (argv, rv))
def ipt(*args):
argv = ['iptables', '-t', 'nat'] + list(args)
def _ipt(family, *args):
if family == socket.AF_INET:
argv = ['iptables', '-t', 'nat'] + list(args)
else:
raise Exception('Unsupported family "%s"'%family_to_string(family))
debug1('>> %s\n' % ' '.join(argv))
rv = ssubprocess.call(argv)
if rv:
@ -34,7 +41,7 @@ def ipt(*args):
_no_ttl_module = False
def ipt_ttl(*args):
def _ipt_ttl(family, *args):
global _no_ttl_module
if not _no_ttl_module:
# we avoid infinite loops by generating server-side connections
@ -42,16 +49,15 @@ def ipt_ttl(*args):
# connections, in case client == server.
try:
argsplus = list(args) + ['-m', 'ttl', '!', '--ttl', '42']
ipt(*argsplus)
_ipt(family, *argsplus)
except Fatal:
ipt(*args)
_ipt(family, *args)
# we only get here if the non-ttl attempt succeeds
log('sshuttle: warning: your iptables is missing '
'the ttl module.\n')
_no_ttl_module = True
else:
ipt(*args)
_ipt(family, *args)
# We name the chain based on the transproxy port number so that it's possible
@ -59,11 +65,20 @@ def ipt_ttl(*args):
# multiple copies shouldn't have overlapping subnets, or only the most-
# recently-started one will win (because we use "-I OUTPUT 1" instead of
# "-A OUTPUT").
def do_iptables(port, dnsport, subnets):
def do_iptables(port, dnsport, family, subnets):
# only ipv4 supported with NAT
if family != socket.AF_INET:
raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family))
def ipt(*args):
return _ipt(family, *args)
def ipt_ttl(*args):
return _ipt_ttl(family, *args)
chain = 'sshuttle-%s' % port
# basic cleanup/setup of chains
if ipt_chain_exists(chain):
if ipt_chain_exists(family, chain):
nonfatal(ipt, '-D', 'OUTPUT', '-j', chain)
nonfatal(ipt, '-D', 'PREROUTING', '-j', chain)
nonfatal(ipt, '-F', chain)
@ -81,7 +96,7 @@ def do_iptables(port, dnsport, subnets):
# to least-specific, and at any given level of specificity, we want
# excludes to come first. That's why the columns are in such a non-
# intuitive order.
for swidth,sexclude,snet in sorted(subnets, reverse=True):
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
ipt('-A', chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
@ -207,7 +222,11 @@ def ipfw(*args):
raise Fatal('%r returned %d' % (argv, rv))
def do_ipfw(port, dnsport, subnets):
def do_ipfw(port, dnsport, family, subnets):
# IPv6 not supported
if family not in [socket.AF_INET, ]:
raise Exception('Address family "%s" unsupported by ipfw method'%family_to_string(family))
sport = str(port)
xsport = str(port+1)
@ -240,7 +259,7 @@ def do_ipfw(port, dnsport, subnets):
if subnets:
# create new subnet entries
for swidth,sexclude,snet in sorted(subnets, reverse=True):
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
ipfw('add', sport, 'skipto', xsport,
'log', 'tcp',
@ -419,15 +438,21 @@ def main(port, dnsport, syslog):
elif line == 'GO\n':
break
try:
(width,exclude,ip) = line.strip().split(',', 2)
(family,width,exclude,ip) = line.strip().split(',', 3)
except:
raise Fatal('firewall: expected route or GO but got %r' % line)
subnets.append((int(width), bool(int(exclude)), ip))
subnets.append((int(family), int(width), bool(int(exclude)), ip))
try:
if line:
debug1('firewall manager: starting transproxy.\n')
do_wait = do_it(port, dnsport, subnets)
subnets_v4 = filter(lambda i: i[0]==socket.AF_INET, subnets)
if port:
do_wait = do_it(port, dnsport, socket.AF_INET, subnets_v4)
elif len(subnets_v4) > 0:
debug1('IPv4 subnets defined but IPv4 disabled\n')
sys.stdout.write('STARTED\n')
try:
@ -456,5 +481,6 @@ def main(port, dnsport, syslog):
debug1('firewall manager: undoing changes.\n')
except:
pass
do_it(port, 0, [])
if port:
do_it(port, 0, socket.AF_INET, [])
restore_etc_hosts(port)

View File

@ -58,8 +58,8 @@ def resolvconf_random_nameserver():
return '127.0.0.1'
def islocal(ip):
sock = socket.socket()
def islocal(ip,family):
sock = socket.socket(family)
try:
try:
sock.bind((ip, 0))
@ -73,3 +73,11 @@ def islocal(ip):
return True # it's a local IP, or there would have been an error
def family_to_string(family):
if family == socket.AF_INET6:
return "AF_INET6"
elif family == socket.AF_INET:
return "AF_INET"
else:
return str(family)

View File

@ -1,4 +1,4 @@
import sys, os, re
import sys, os, re, socket
import helpers, options, client, server, firewall, hostwatch
import compat.ssubprocess as ssubprocess
from helpers import *
@ -22,7 +22,7 @@ def parse_subnets(subnets_str):
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d))
if width > 32:
raise Fatal('*/%d is greater than the maximum of 32' % width)
subnets.append(('%d.%d.%d.%d' % (a,b,c,d), width))
subnets.append((socket.AF_INET, '%d.%d.%d.%d' % (a,b,c,d), width))
return subnets

View File

@ -59,7 +59,7 @@ def _list_routes():
mask = _maskbits(maskw) # returns 32 if maskw is null
width = min(ipw[1], mask)
ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width)
routes.append((socket.inet_ntoa(struct.pack('!I', ip)), width))
routes.append((socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width))
rv = p.wait()
if rv != 0:
log('WARNING: %r returned %d\n' % (argv, rv))
@ -68,9 +68,9 @@ def _list_routes():
def list_routes():
for (ip,width) in _list_routes():
for (family, ip,width) in _list_routes():
if not ip.startswith('0.') and not ip.startswith('127.'):
yield (ip,width)
yield (family, ip,width)
def _exc_dump():
@ -170,7 +170,7 @@ def main():
routes = list(list_routes())
debug1('available routes:\n')
for r in routes:
debug1(' %s/%d\n' % r)
debug1(' %d/%s/%d\n' % r)
# synchronization header
sys.stdout.write('\0\0SSHUTTLE0001')
@ -184,7 +184,7 @@ def main():
handlers.append(mux)
routepkt = ''
for r in routes:
routepkt += '%s,%d\n' % r
routepkt += '%d,%s,%d\n' % r
mux.send(0, ssnet.CMD_ROUTES, routepkt)
hw = Hostwatch()
@ -213,9 +213,10 @@ def main():
mux.got_host_req = got_host_req
def new_channel(channel, data):
(dstip,dstport) = data.split(',', 1)
(family,dstip,dstport) = data.split(',', 2)
family = int(family)
dstport = int(dstport)
outwrap = ssnet.connect_dst(dstip,dstport)
outwrap = ssnet.connect_dst(family, dstip, dstport)
handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
mux.new_channel = new_channel

View File

@ -523,9 +523,9 @@ class MuxWrapper(SockWrapper):
% (cmd, len(data)))
def connect_dst(ip, port):
def connect_dst(family, ip, port):
debug2('Connecting to %s:%d\n' % (ip, port))
outsock = socket.socket()
outsock = socket.socket(family)
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
return SockWrapper(outsock, outsock,
connect_to = (ip,port),