Added new --auto-hosts and --seed-hosts options to the client.

Now if you use --auto-hosts (-H), the client will ask the server to spawn a
hostwatcher to add names.  That, in turn, will send names back to the
server, which sends them back to the client, which sends them to the
firewall subprocess, which will write them to /etc/hosts.  Whew!

Only the firewall process can write to /etc/hosts, of course, because only
he's running as root.

Since the name discovery process is kind of slow, we cache the names in
~/.sshuttle.hosts on the remote server.

Right now, most of the names are discovered using nmblookup and smbclient,
as well as by reading the existing entries in /etc/hosts.  What would really
be nice would be to query active directory or mdns somehow... but I don't
really know how those work, so this is what you get for now :)  It's pretty
neat, at least.
This commit is contained in:
Avery Pennarun 2010-05-08 03:03:12 -04:00
parent a2ea5ab455
commit 33efa5ac62
7 changed files with 193 additions and 17 deletions

View File

@ -1,4 +1,4 @@
import struct, socket, select, subprocess, errno import struct, socket, select, subprocess, errno, re
import helpers, ssnet, ssh import helpers, ssnet, ssh
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import * from helpers import *
@ -76,6 +76,12 @@ class FirewallClient:
if line != 'STARTED\n': if line != 'STARTED\n':
raise Fatal('%r expected STARTED, got %r' % (self.argv, line)) raise Fatal('%r expected STARTED, got %r' % (self.argv, line))
def sethostip(self, hostname, ip):
assert(not re.search(r'[^-\w]', hostname))
assert(not re.search(r'[^0-9.]', ip))
self.pfile.write('HOST %s,%s\n' % (hostname, ip))
self.pfile.flush()
def done(self): def done(self):
self.pfile.close() self.pfile.close()
rv = self.p.wait() rv = self.p.wait()
@ -83,7 +89,7 @@ class FirewallClient:
raise Fatal('cleanup: %r returned %d' % (self.argv, rv)) raise Fatal('cleanup: %r returned %d' % (self.argv, rv))
def _main(listener, fw, use_server, remotename, auto_nets): def _main(listener, fw, use_server, remotename, seed_hosts, auto_nets):
handlers = [] handlers = []
if use_server: if use_server:
if helpers.verbose >= 1: if helpers.verbose >= 1:
@ -122,6 +128,14 @@ def _main(listener, fw, use_server, remotename, auto_nets):
fw.start() fw.start()
mux.got_routes = onroutes mux.got_routes = onroutes
def onhostlist(hostlist):
debug2('got host list: %r\n' % hostlist)
for line in hostlist.strip().split():
if line:
name,ip = line.split(',', 1)
fw.sethostip(name, ip)
mux.got_host_list = onhostlist
def onaccept(): def onaccept():
sock,srcip = listener.accept() sock,srcip = listener.accept()
dstip = original_dst(sock) dstip = original_dst(sock)
@ -140,6 +154,10 @@ def _main(listener, fw, use_server, remotename, auto_nets):
handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
handlers.append(Handler([listener], onaccept)) handlers.append(Handler([listener], onaccept))
if seed_hosts != None:
debug1('seed_hosts: %r\n' % seed_hosts)
mux.send(0, ssnet.CMD_HOST_REQ, '\n'.join(seed_hosts))
while 1: while 1:
if use_server: if use_server:
rv = serverproc.poll() rv = serverproc.poll()
@ -165,7 +183,7 @@ def _main(listener, fw, use_server, remotename, auto_nets):
mux.check_fullness() mux.check_fullness()
def main(listenip, use_server, remotename, auto_nets, subnets): def main(listenip, use_server, remotename, seed_hosts, auto_nets, subnets):
debug1('Starting sshuttle proxy.\n') debug1('Starting sshuttle proxy.\n')
listener = socket.socket() listener = socket.socket()
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -195,6 +213,7 @@ def main(listenip, use_server, remotename, auto_nets, subnets):
fw = FirewallClient(listenip[1], subnets) fw = FirewallClient(listenip[1], subnets)
try: try:
return _main(listener, fw, use_server, remotename, auto_nets) return _main(listener, fw, use_server, remotename,
seed_hosts, auto_nets)
finally: finally:
fw.done() fw.done()

View File

@ -1,4 +1,4 @@
import subprocess, re import subprocess, re, errno
import helpers import helpers
from helpers import * from helpers import *
@ -134,6 +134,38 @@ def program_exists(name):
if os.path.exists(fn): if os.path.exists(fn):
return not os.path.isdir(fn) and os.access(fn, os.X_OK) return not os.path.isdir(fn) and os.access(fn, os.X_OK)
hostmap = {}
def rewrite_etc_hosts(port):
HOSTSFILE='/etc/hosts'
BAKFILE='%s.sbak' % HOSTSFILE
APPEND='# sshuttle-firewall-%d AUTOCREATED' % port
old_content = ''
try:
old_content = open(HOSTSFILE).read()
except IOError, e:
if e.errno == errno.ENOENT:
pass
else:
raise
if old_content.strip() and not os.path.exists(BAKFILE):
open(BAKFILE, 'w').write(old_content)
tmpname = "%s.%d.tmp" % (HOSTSFILE, port)
f = open(tmpname, 'w')
for line in old_content.rstrip().split('\n'):
if line.find(APPEND) >= 0:
continue
f.write('%s\n' % line)
for (name,ip) in sorted(hostmap.items()):
f.write('%-30s %s\n' % ('%s %s' % (ip,name), APPEND))
f.close()
os.rename(tmpname, HOSTSFILE)
def restore_etc_hosts(port):
global hostmap
hostmap = {}
rewrite_etc_hosts(port)
# This is some voodoo for setting up the kernel's transparent # This is some voodoo for setting up the kernel's transparent
# proxying stuff. If subnets is empty, we just delete our sshuttle rules; # proxying stuff. If subnets is empty, we just delete our sshuttle rules;
@ -199,20 +231,29 @@ def main(port):
try: try:
sys.stdout.flush() sys.stdout.flush()
# Now we wait until EOF or any other kind of exception. We need
# to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important!
while sys.stdin.readline(128):
pass
except IOError: except IOError:
# the parent process died for some reason; he's surely been loud # the parent process died for some reason; he's surely been loud
# enough, so no reason to report another error # enough, so no reason to report another error
return return
# Now we wait until EOF or any other kind of exception. We need
# to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important!
while 1:
line = sys.stdin.readline(128)
if line.startswith('HOST '):
(name,ip) = line[5:].strip().split(',', 1)
hostmap[name] = ip
rewrite_etc_hosts(port)
elif line:
raise Fatal('expected EOF, got %r' % line)
else:
break
finally: finally:
try: try:
debug1('firewall manager: undoing changes.\n') debug1('firewall manager: undoing changes.\n')
except: except:
pass pass
do_it(port, []) do_it(port, [])
restore_etc_hosts(port)

View File

@ -4,6 +4,7 @@ if not globals().get('skip_imports'):
from helpers import * from helpers import *
POLL_TIME = 60*15 POLL_TIME = 60*15
CACHEFILE=os.path.expanduser('~/.sshuttle.hosts')
_nmb_ok = True _nmb_ok = True
@ -17,6 +18,39 @@ def _is_ip(s):
return re.match(r'\d+\.\d+\.\d+\.\d+$', s) return re.match(r'\d+\.\d+\.\d+\.\d+$', s)
def write_host_cache():
tmpname = '%s.%d.tmp' % (CACHEFILE, os.getpid())
try:
f = open(tmpname, 'wb')
for name,ip in sorted(hostnames.items()):
f.write('%s,%s\n' % (name, ip))
f.close()
os.rename(tmpname, CACHEFILE)
finally:
try:
os.unlink(tmpname)
except:
pass
def read_host_cache():
try:
f = open(CACHEFILE)
except IOError, e:
if e.errno == errno.ENOENT:
return
else:
raise
for line in f:
words = line.strip().split(',')
if len(words) == 2:
(name,ip) = words
name = re.sub(r'[^-\w]', '-', name).strip()
ip = re.sub(r'[^0-9.]', '', ip).strip()
if name and ip:
found_host(name, ip)
def found_host(hostname, ip): def found_host(hostname, ip):
hostname = re.sub(r'\..*', '', hostname) hostname = re.sub(r'\..*', '', hostname)
hostname = re.sub(r'[^-\w]', '_', hostname) hostname = re.sub(r'[^-\w]', '_', hostname)
@ -27,6 +61,7 @@ def found_host(hostname, ip):
hostnames[hostname] = ip hostnames[hostname] = ip
debug1('Found: %s: %s\n' % (hostname, ip)) debug1('Found: %s: %s\n' % (hostname, ip))
sys.stdout.write('%s,%s\n' % (hostname, ip)) sys.stdout.write('%s,%s\n' % (hostname, ip))
write_host_cache()
def _check_etc_hosts(): def _check_etc_hosts():
@ -188,6 +223,8 @@ def hw_main(seed_hosts):
else: else:
helpers.logprefix = 'hostwatch: ' helpers.logprefix = 'hostwatch: '
read_host_cache()
_enqueue(_check_etc_hosts) _enqueue(_check_etc_hosts)
check_host('localhost') check_host('localhost')
check_host(socket.gethostname()) check_host(socket.gethostname())

13
main.py
View File

@ -50,9 +50,11 @@ sshuttle --firewall <port> <subnets...>
sshuttle --server sshuttle --server
-- --
l,listen= transproxy to this ip address and port number [default=0] l,listen= transproxy to this ip address and port number [default=0]
N,auto-nets automatically determine subnets to route H,auto-hosts scan for remote hostnames and update local /etc/hosts
N,auto-nets automatically determine subnets to route
r,remote= ssh hostname (and optional username) of remote sshuttle server r,remote= ssh hostname (and optional username) of remote sshuttle server
v,verbose increase debug message verbosity v,verbose increase debug message verbosity
seed-hosts= with -H, use these hostnames for initial scan (comma-separated)
noserver don't use a separate server process (mostly for debugging) noserver don't use a separate server process (mostly for debugging)
server [internal use only] server [internal use only]
firewall [internal use only] firewall [internal use only]
@ -80,9 +82,18 @@ try:
remotename = opt.remote remotename = opt.remote
if remotename == '' or remotename == '-': if remotename == '' or remotename == '-':
remotename = None remotename = None
if opt.seed_hosts and not opt.auto_hosts:
o.fatal('--seed-hosts only works if you also use -H')
if opt.seed_hosts:
sh = re.split(r'[\s,]+', (opt.seed_hosts or "").strip())
elif opt.auto_hosts:
sh = []
else:
sh = None
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
not opt.noserver, not opt.noserver,
remotename, remotename,
sh,
opt.auto_nets, opt.auto_nets,
parse_subnets(extra))) parse_subnets(extra)))
except Fatal, e: except Fatal, e:

View File

@ -1,6 +1,6 @@
import re, struct, socket, select, subprocess import re, struct, socket, select, subprocess, traceback
if not globals().get('skip_imports'): if not globals().get('skip_imports'):
import ssnet, helpers import ssnet, helpers, hostwatch
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import * from helpers import *
@ -67,6 +67,37 @@ def list_routes():
yield (ip,width) yield (ip,width)
def _exc_dump():
exc_info = sys.exc_info()
return ''.join(traceback.format_exception(*exc_info))
def start_hostwatch(seed_hosts):
s1,s2 = socket.socketpair()
pid = os.fork()
if not pid:
# child
rv = 99
try:
s2.close()
os.dup2(s1.fileno(), 1)
os.dup2(s1.fileno(), 0)
s1.close()
rv = hostwatch.hw_main(seed_hosts) or 0
except Exception, e:
log('%s\n' % _exc_dump())
rv = 98
finally:
os._exit(rv)
s1.close()
return pid,s2
class Hostwatch:
def __init__(self):
self.pid = 0
self.sock = None
def main(): def main():
if helpers.verbose >= 1: if helpers.verbose >= 1:
@ -93,15 +124,36 @@ def main():
for r in routes) for r in routes)
mux.send(0, ssnet.CMD_ROUTES, routepkt) mux.send(0, ssnet.CMD_ROUTES, routepkt)
hw = Hostwatch()
def hostwatch_ready():
assert(hw.pid)
content = hw.sock.recv(4096)
if content:
mux.send(0, ssnet.CMD_HOST_LIST, content)
else:
raise Fatal('hostwatch process died')
def got_host_req(data):
if not hw.pid:
(hw.pid,hw.sock) = start_hostwatch(data.strip().split())
handlers.append(Handler(socks = [hw.sock],
callback = hostwatch_ready))
mux.got_host_req = got_host_req
def new_channel(channel, data): def new_channel(channel, data):
(dstip,dstport) = data.split(',', 1) (dstip,dstport) = data.split(',', 1)
dstport = int(dstport) dstport = int(dstport)
outwrap = ssnet.connect_dst(dstip,dstport) outwrap = ssnet.connect_dst(dstip,dstport)
handlers.append(Proxy(MuxWrapper(mux, channel), outwrap)) handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
mux.new_channel = new_channel mux.new_channel = new_channel
while mux.ok: while mux.ok:
if hw.pid:
(rpid, rv) = os.waitpid(hw.pid, os.WNOHANG)
if rpid:
raise Fatal('hostwatch exited unexpectedly: code 0x%04x\n' % rv)
r = set() r = set()
w = set() w = set()
x = set() x = set()

1
ssh.py
View File

@ -30,6 +30,7 @@ def connect(rhostport):
content = readfile('assembler.py') content = readfile('assembler.py')
content2 = (empackage(z, 'helpers.py') + content2 = (empackage(z, 'helpers.py') +
empackage(z, 'ssnet.py') + empackage(z, 'ssnet.py') +
empackage(z, 'hostwatch.py') +
empackage(z, 'server.py') + empackage(z, 'server.py') +
"\n") "\n")

View File

@ -13,6 +13,8 @@ CMD_CLOSE = 0x4204
CMD_EOF = 0x4205 CMD_EOF = 0x4205
CMD_DATA = 0x4206 CMD_DATA = 0x4206
CMD_ROUTES = 0x4207 CMD_ROUTES = 0x4207
CMD_HOST_REQ = 0x4208
CMD_HOST_LIST = 0x4209
cmd_to_name = { cmd_to_name = {
CMD_EXIT: 'EXIT', CMD_EXIT: 'EXIT',
@ -23,6 +25,8 @@ cmd_to_name = {
CMD_EOF: 'EOF', CMD_EOF: 'EOF',
CMD_DATA: 'DATA', CMD_DATA: 'DATA',
CMD_ROUTES: 'ROUTES', CMD_ROUTES: 'ROUTES',
CMD_HOST_REQ: 'HOST_REQ',
CMD_HOST_LIST: 'HOST_LIST',
} }
@ -223,6 +227,7 @@ class Mux(Handler):
self.rsock = rsock self.rsock = rsock
self.wsock = wsock self.wsock = wsock
self.new_channel = self.got_routes = None self.new_channel = self.got_routes = None
self.got_host_req = self.got_host_list = None
self.channels = {} self.channels = {}
self.chani = 0 self.chani = 0
self.want = 0 self.want = 0
@ -284,7 +289,17 @@ class Mux(Handler):
if self.got_routes: if self.got_routes:
self.got_routes(data) self.got_routes(data)
else: else:
raise Exception('weird: got CMD_ROUTES without got_routes?') raise Exception('got CMD_ROUTES without got_routes?')
elif cmd == CMD_HOST_REQ:
if self.got_host_req:
self.got_host_req(data)
else:
raise Exception('got CMD_HOST_REQ without got_host_req?')
elif cmd == CMD_HOST_LIST:
if self.got_host_list:
self.got_host_list(data)
else:
raise Exception('got CMD_HOST_LIST without got_host_list?')
else: else:
callback = self.channels[channel] callback = self.channels[channel]
callback(cmd, data) callback(cmd, data)