Choose which method to use for intercepting traffic.

This commit is contained in:
Brian May 2011-07-07 09:16:14 +10:00
parent 55f86a8b3f
commit c6200eecdc
3 changed files with 52 additions and 29 deletions

View File

@ -104,9 +104,9 @@ class MultiListener:
if self.v4: if self.v4:
self.v4.setsockopt(level, optname, value) self.v4.setsockopt(level, optname, value)
def add_handler(self, handlers, callback, mux): def add_handler(self, handlers, callback, method, mux):
if self.v4: if self.v4:
handlers.append(Handler([self.v4], lambda: callback(self.v4, mux, handlers))) handlers.append(Handler([self.v4], lambda: callback(self.v4, method, mux, handlers)))
def listen(self, backlog): def listen(self, backlog):
if self.v4: if self.v4:
@ -125,7 +125,7 @@ class MultiListener:
class FirewallClient: class FirewallClient:
def __init__(self, port, subnets_include, subnets_exclude, dnsport): def __init__(self, port, subnets_include, subnets_exclude, dnsport, method):
self.port = port self.port = port
self.auto_nets = [] self.auto_nets = []
self.subnets_include = subnets_include self.subnets_include = subnets_include
@ -133,7 +133,9 @@ class FirewallClient:
self.dnsport = dnsport self.dnsport = dnsport
argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] + argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] +
['-v'] * (helpers.verbose or 0) + ['-v'] * (helpers.verbose or 0) +
['--firewall', str(port), str(dnsport)]) ['--firewall', str(port),
str(dnsport),
method])
if ssyslog._p: if ssyslog._p:
argvbase += ['--syslog'] argvbase += ['--syslog']
argv_tries = [ argv_tries = [
@ -170,8 +172,9 @@ class FirewallClient:
raise Fatal(e) raise Fatal(e)
line = self.pfile.readline() line = self.pfile.readline()
self.check() self.check()
if line != 'READY\n': if line[0:5] != 'READY':
raise Fatal('%r expected READY, got %r' % (self.argv, line)) raise Fatal('%r expected READY, got %r' % (self.argv, line))
self.method = line[6:-1]
def check(self): def check(self):
rv = self.p.poll() rv = self.p.poll()
@ -214,7 +217,7 @@ def expire_connections(now, mux):
debug3('Remaining DNS requests: %d\n' % len(dnsreqs)) debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
def onaccept_tcp(listener, mux, handlers): def onaccept_tcp(listener, method, mux, handlers):
global _extra_fd global _extra_fd
try: try:
sock,srcip = listener.accept() sock,srcip = listener.accept()
@ -249,7 +252,7 @@ def onaccept_tcp(listener, mux, handlers):
expire_connections(time.time(), mux) expire_connections(time.time(), mux)
def dns_done(chan, mux, data): def dns_done(chan, mux, data, method):
peer,sock,timeout = dnsreqs.get(chan) or (None,None,None) peer,sock,timeout = dnsreqs.get(chan) or (None,None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer)) debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
if peer: if peer:
@ -259,7 +262,7 @@ def dns_done(chan, mux, data):
sock.sendto(data, peer) sock.sendto(data, peer)
def ondns(listener, mux, handlers): def ondns(listener, method, mux, handlers):
pkt,peer = listener.recvfrom(4096) pkt,peer = listener.recvfrom(4096)
now = time.time() now = time.time()
if pkt: if pkt:
@ -267,12 +270,12 @@ def ondns(listener, mux, handlers):
chan = mux.next_channel() chan = mux.next_channel()
dnsreqs[chan] = peer,listener,now+30 dnsreqs[chan] = peer,listener,now+30
mux.send(chan, ssnet.CMD_DNS_REQ, pkt) mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
mux.channels[chan] = lambda cmd,data: dns_done(chan, mux, data) mux.channels[chan] = lambda cmd,data: dns_done(chan, mux, data, method)
expire_connections(now, mux) expire_connections(now, mux)
def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control, def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
dns_listener, seed_hosts, auto_nets, dns_listener, method, seed_hosts, auto_nets,
syslog, daemon): syslog, daemon):
handlers = [] handlers = []
if helpers.verbose >= 1: if helpers.verbose >= 1:
@ -351,10 +354,10 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
fw.sethostip(name, ip) fw.sethostip(name, ip)
mux.got_host_list = onhostlist mux.got_host_list = onhostlist
tcp_listener.add_handler(handlers, onaccept_tcp, mux) tcp_listener.add_handler(handlers, onaccept_tcp, method, mux)
if dns_listener: if dns_listener:
dns_listener.add_handler(handlers, ondns, mux) dns_listener.add_handler(handlers, ondns, method, mux)
if seed_hosts != None: if seed_hosts != None:
debug1('seed_hosts: %r\n' % seed_hosts) debug1('seed_hosts: %r\n' % seed_hosts)
@ -371,8 +374,9 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
mux.callback() mux.callback()
def main(listenip_v4, ssh_cmd, remotename, python, latency_control, dns, def main(listenip_v4,
seed_hosts, auto_nets, ssh_cmd, remotename, python, latency_control, dns,
method, seed_hosts, auto_nets,
subnets_include, subnets_exclude, syslog, daemon, pidfile): subnets_include, subnets_exclude, syslog, daemon, pidfile):
if syslog: if syslog:
ssyslog.start_syslog() ssyslog.start_syslog()
@ -462,12 +466,13 @@ def main(listenip_v4, ssh_cmd, remotename, python, latency_control, dns,
dnsport_v4 = 0 dnsport_v4 = 0
dns_listener = None dns_listener = None
fw = FirewallClient(redirectport_v4, subnets_include, subnets_exclude, dnsport_v4) fw = FirewallClient(redirectport_v4, subnets_include, subnets_exclude, dnsport_v4, method)
try: try:
return _main(tcp_listener, fw, ssh_cmd, remotename, return _main(tcp_listener, fw, ssh_cmd, remotename,
python, latency_control, dns_listener, python, latency_control, dns_listener,
seed_hosts, auto_nets, syslog, daemon) fw.method, seed_hosts, auto_nets, syslog,
daemon)
finally: finally:
try: try:
if daemon: if daemon:

View File

@ -65,7 +65,7 @@ def _ipt_ttl(family, *args):
# multiple copies shouldn't have overlapping subnets, or only the most- # multiple copies shouldn't have overlapping subnets, or only the most-
# recently-started one will win (because we use "-I OUTPUT 1" instead of # recently-started one will win (because we use "-I OUTPUT 1" instead of
# "-A OUTPUT"). # "-A OUTPUT").
def do_iptables(port, dnsport, family, subnets): def do_iptables_nat(port, dnsport, family, subnets):
# only ipv4 supported with NAT # only ipv4 supported with NAT
if family != socket.AF_INET: if family != socket.AF_INET:
raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family)) raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family))
@ -389,7 +389,7 @@ def restore_etc_hosts(port):
# exit. In case that fails, it's not the end of the world; future runs will # exit. In case that fails, it's not the end of the world; future runs will
# supercede it in the transproxy list, at least, so the leftover rules # supercede it in the transproxy list, at least, so the leftover rules
# are hopefully harmless. # are hopefully harmless.
def main(port_v4, dnsport_v4, syslog): def main(port_v4, dnsport_v4, method, syslog):
assert(port_v4 > 0) assert(port_v4 > 0)
assert(port_v4 <= 65535) assert(port_v4 <= 65535)
assert(dnsport_v4 >= 0) assert(dnsport_v4 >= 0)
@ -398,12 +398,20 @@ def main(port_v4, dnsport_v4, syslog):
if os.getuid() != 0: if os.getuid() != 0:
raise Fatal('you must be root (or enable su/sudo) to set the firewall') raise Fatal('you must be root (or enable su/sudo) to set the firewall')
if program_exists('ipfw'): if method == "auto":
if program_exists('ipfw'):
method = "ipfw"
elif program_exists('iptables'):
method = "nat"
else:
raise Fatal("can't find either ipfw or iptables; check your PATH")
if method == "nat":
do_it = do_iptables_nat
elif method == "ipfw":
do_it = do_ipfw do_it = do_ipfw
elif program_exists('iptables'):
do_it = do_iptables
else: else:
raise Fatal("can't find either ipfw or iptables; check your PATH") raise Exception('Unknown method "%s"'%method)
# because of limitations of the 'su' command, the *real* stdin/stdout # because of limitations of the 'su' command, the *real* stdin/stdout
# are both attached to stdout initially. Clone stdout into stdin so we # are both attached to stdout initially. Clone stdout into stdin so we
@ -414,8 +422,8 @@ def main(port_v4, dnsport_v4, syslog):
ssyslog.start_syslog() ssyslog.start_syslog()
ssyslog.stderr_to_syslog() ssyslog.stderr_to_syslog()
debug1('firewall manager ready.\n') debug1('firewall manager ready method %s.\n'%method)
sys.stdout.write('READY\n') sys.stdout.write('READY %s\n'%method)
sys.stdout.flush() sys.stdout.flush()
# ctrl-c shouldn't be passed along to me. When the main sshuttle dies, # ctrl-c shouldn't be passed along to me. When the main sshuttle dies,

18
main.py
View File

@ -54,6 +54,7 @@ l,listen= transproxy to this ip address and port number [127.0.0.1:0]
H,auto-hosts scan for remote hostnames and update local /etc/hosts H,auto-hosts scan for remote hostnames and update local /etc/hosts
N,auto-nets automatically determine subnets to route N,auto-nets automatically determine subnets to route
dns capture local DNS requests and forward to the remote DNS server dns capture local DNS requests and forward to the remote DNS server
method= auto, nat, or ipfw
python= path to python interpreter on the remote server python= path to python interpreter on the remote server
r,remote= ssh hostname (and optional username) of remote sshuttle server r,remote= ssh hostname (and optional username) of remote sshuttle server
x,exclude= exclude this subnet (can be used more than once) x,exclude= exclude this subnet (can be used more than once)
@ -86,9 +87,10 @@ try:
server.latency_control = opt.latency_control server.latency_control = opt.latency_control
sys.exit(server.main()) sys.exit(server.main())
elif opt.firewall: elif opt.firewall:
if len(extra) != 2: if len(extra) != 3:
o.fatal('exactly two arguments expected') o.fatal('exactly three arguments expected')
sys.exit(firewall.main(int(extra[0]), int(extra[1]), opt.syslog)) sys.exit(firewall.main(int(extra[0]), int(extra[1]),
extra[2], opt.syslog))
elif opt.hostwatch: elif opt.hostwatch:
sys.exit(hostwatch.hw_main(extra)) sys.exit(hostwatch.hw_main(extra))
else: else:
@ -110,12 +112,20 @@ try:
sh = [] sh = []
else: else:
sh = None sh = None
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), if not opt.method:
method = "auto"
elif opt.method in [ "auto", "nat", "ipfw" ]:
method = opt.method
else:
o.fatal("method %s not supported"%opt.method)
ipport_v4 = parse_ipport(opt.listen or '0.0.0.0:0')
sys.exit(client.main(ipport_v4,
opt.ssh_cmd, opt.ssh_cmd,
remotename, remotename,
opt.python, opt.python,
opt.latency_control, opt.latency_control,
opt.dns, opt.dns,
method,
sh, sh,
opt.auto_nets, opt.auto_nets,
parse_subnets(includes), parse_subnets(includes),