sshuttle/client.py
2011-05-12 14:37:19 +10:00

486 lines
15 KiB
Python

import struct, socket, select, errno, re, signal, time
import compat.ssubprocess as ssubprocess
import helpers, ssnet, ssh, ssyslog
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
# Spare file descriptor held in reserve.  When accept() fails because the
# process has run out of descriptors, the accept handler closes this one so
# that it can accept and immediately drop the pending connection, and then
# reopens it.
_extra_fd = os.open('/dev/null', os.O_RDONLY)
def got_signal(signum, frame):
    """Signal handler: log the signal number, then exit with status 1.

    Leaving through SystemExit (instead of dying from the signal) lets the
    surrounding try/finally blocks run, so the pidfile can be removed.
    """
    message = 'exiting on signal %d\n' % signum
    log(message)
    raise SystemExit(1)
# Absolute path of the daemon pidfile.  Set by check_daemon(); read by
# daemonize() (which creates the file) and daemon_cleanup() (which removes it).
_pidname = None
def check_daemon(pidfile):
    """Make sure no other sshuttle daemon owns *pidfile*.

    Records the absolute path of *pidfile* in the module-level ``_pidname``
    and then inspects the file:

    * missing file: nothing to do;
    * empty, non-numeric or non-positive content: the stale file is removed;
    * pid of a process that no longer exists: the stale file is removed;
    * pid of a live process (including one we may not signal): Fatal.

    Raises Fatal if the pidfile cannot be read or if the recorded process
    is still running.
    """
    global _pidname
    _pidname = os.path.abspath(pidfile)
    try:
        f = open(_pidname)
        try:
            oldpid = f.read(1024)
        finally:
            # close explicitly instead of relying on garbage collection
            f.close()
    except IOError:
        e = sys.exc_info()[1]
        if e.errno == errno.ENOENT:
            return  # no pidfile, ok
        raise Fatal("can't read %s: %s" % (_pidname, e))

    try:
        oldpid = int(oldpid.strip() or 0)
    except ValueError:
        # garbage in the file is just another kind of invalid pidfile
        oldpid = 0
    if oldpid <= 0:
        os.unlink(_pidname)
        return  # invalid pidfile, ok

    try:
        # signal 0 only checks whether the process exists
        os.kill(oldpid, 0)
    except OSError:
        e = sys.exc_info()[1]
        if e.errno == errno.ESRCH:
            os.unlink(_pidname)
            return  # outdated pidfile, ok
        elif e.errno != errno.EPERM:
            raise
        # EPERM: the process exists but we are not allowed to signal it
    raise Fatal("%s: sshuttle is already running (pid=%d)"
                % (_pidname, oldpid))
def daemonize():
    """Detach into the background and write the pidfile.

    Forks twice with a setsid() in between (each parent exits at once),
    creates the pidfile named by ``_pidname`` exclusively, changes to "/",
    installs the SIGTERM handler, points stdin and stdout at /dev/null and
    finally sends stderr to syslog.
    """
    if os.fork() != 0:
        os._exit(0)
    os.setsid()
    if os.fork() != 0:
        os._exit(0)

    # octal 666, spelled so that every Python version parses it
    pidfile_mode = int('666', 8)
    create_flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
    pidfd = os.open(_pidname, create_flags, pidfile_mode)
    try:
        os.write(pidfd, '%d\n' % os.getpid())
    finally:
        os.close(pidfd)
    os.chdir("/")

    # Normal exit when killed, or try/finally won't work and the pidfile
    # won't be deleted.
    signal.signal(signal.SIGTERM, got_signal)

    devnull = open('/dev/null', 'r+')
    for stdfd in (0, 1):
        os.dup2(devnull.fileno(), stdfd)
    devnull.close()

    ssyslog.stderr_to_syslog()
def daemon_cleanup():
    """Delete the pidfile; a pidfile that is already gone is not an error.

    Any OSError other than ENOENT is re-raised.
    """
    try:
        os.unlink(_pidname)
    except OSError:
        err = sys.exc_info()[1]
        if err.errno != errno.ENOENT:
            raise
def original_dst(sock):
    """Return the (ip, port) a redirected connection was originally sent to.

    Reads socket option 80 (SO_ORIGINAL_DST) at level SOL_IP and decodes the
    first eight bytes of the result as family, port and four address octets.
    If the option is not supported (ENOPROTOOPT) the socket's own local
    address is returned instead; any other socket error propagates.
    """
    SO_ORIGINAL_DST = 80
    SOCKADDR_MIN = 16
    try:
        raw = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, SOCKADDR_MIN)
    except socket.error:
        err = sys.exc_info()[1]
        if err.args[0] != errno.ENOPROTOOPT:
            raise
        return sock.getsockname()

    fields = struct.unpack('!HHBBBB', raw[:8])
    family = fields[0]
    port = fields[1]
    # the family was unpacked as big-endian; htons() undoes that so it can
    # be compared with the host-order AF_INET constant
    assert(socket.htons(family) == socket.AF_INET)
    ip = '.'.join([str(octet) for octet in fields[2:]])
    return (ip, port)
class FirewallClient:
    """Talks to the privileged firewall-manager subprocess.

    The constructor spawns this same program again with ``--firewall`` (via
    sudo, then su, then directly) and waits for its READY line.  Afterwards
    the object speaks a small line-based protocol over ``self.pfile``:
    ROUTES ... GO to install the rules, HOST lines for name/address pairs.
    """

    def __init__(self, port, subnets_include, subnets_exclude, dnsport):
        """Spawn the firewall manager and wait until it reports READY.

        port and dnsport are passed to the subprocess on its command line;
        subnets_include and subnets_exclude are lists of (ip, width) pairs
        sent later by start().  Raises Fatal if no spawn attempt succeeds,
        if the subprocess exits with an error, or if its first line is not
        READY.
        """
        self.port = port
        self.auto_nets = []
        self.subnets_include = subnets_include
        self.subnets_exclude = subnets_exclude
        self.dnsport = dnsport
        argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] +
                    ['-v'] * (helpers.verbose or 0) +
                    ['--firewall', str(port), str(dnsport)])
        if ssyslog._p:
            argvbase += ['--syslog']
        argv_tries = [
            ['sudo', '-p', '[local sudo] Password: '] + argvbase,
            ['su', '-c', ' '.join(argvbase)],
            argvbase
        ]

        # we can't use stdin/stdout=subprocess.PIPE here, as we normally
        # would, because stupid Linux 'su' requires that stdin be attached
        # to a tty.  Instead, attach a *bidirectional* socket to its stdout,
        # and use that for talking in both directions.
        (s1, s2) = socket.socketpair()

        def setup():
            # run in the child process
            s2.close()

        err = None
        if os.getuid() == 0:
            argv_tries = argv_tries[-1:]  # last entry only
        for argv in argv_tries:
            try:
                if argv[0] == 'su':
                    sys.stderr.write('[local su] ')
                self.p = ssubprocess.Popen(argv, stdout=s1, preexec_fn=setup)
                err = None
                break
            except OSError:
                # remember the failure and fall through to the next attempt
                err = sys.exc_info()[1]
        self.argv = argv
        s1.close()
        self.pfile = s2.makefile('wb+')
        if err:
            log('Spawning firewall manager: %r\n' % self.argv)
            raise Fatal(err)
        line = self.pfile.readline()
        self.check()
        if line != 'READY\n':
            raise Fatal('%r expected READY, got %r' % (self.argv, line))

    def check(self):
        """Raise Fatal if the subprocess has exited with a nonzero status."""
        rv = self.p.poll()
        if rv:
            raise Fatal('%r returned %d' % (self.argv, rv))

    def start(self):
        """Send the route list and wait for the STARTED acknowledgement.

        Included subnets (configured plus auto-discovered) are sent with
        flag 0 and excluded subnets with flag 1, one "width,flag,ip" line
        each, bracketed by ROUTES and GO.  Raises Fatal if the reply is not
        STARTED or the subprocess has failed.
        """
        self.pfile.write('ROUTES\n')
        for (ip, width) in self.subnets_include + self.auto_nets:
            self.pfile.write('%d,0,%s\n' % (width, ip))
        for (ip, width) in self.subnets_exclude:
            self.pfile.write('%d,1,%s\n' % (width, ip))
        self.pfile.write('GO\n')
        self.pfile.flush()
        line = self.pfile.readline()
        self.check()
        if line != 'STARTED\n':
            raise Fatal('%r expected STARTED, got %r' % (self.argv, line))

    def sethostip(self, hostname, ip):
        """Tell the firewall manager that *hostname* has address *ip*.

        Both values are validated with real checks rather than ``assert``
        (which ``python -O`` removes): a newline or comma slipping through
        would corrupt the line-based protocol spoken to the privileged
        subprocess.  Raises Fatal on an invalid value; nothing is written
        in that case.
        """
        if re.search(r'[^-\w]', hostname):
            raise Fatal('refusing invalid hostname %r' % (hostname,))
        if re.search(r'[^0-9.]', ip):
            raise Fatal('refusing invalid IP address %r' % (ip,))
        self.pfile.write('HOST %s,%s\n' % (hostname, ip))
        self.pfile.flush()

    def done(self):
        """Close the control channel and wait for the subprocess to exit.

        Raises Fatal if it exits with a nonzero status.
        """
        self.pfile.close()
        rv = self.p.wait()
        if rv:
            raise Fatal('cleanup: %r returned %d' % (self.argv, rv))
def unpack_dns_name(buf, off):
    """Decode the DNS name that starts at offset *off* of packet *buf*.

    Returns ``(name, end)`` where *name* is the dotted name without a
    trailing dot and *end* is the offset of the first byte after the name
    as it appears at *off*.  When the name uses a compression pointer, *end*
    is the offset just past that two-byte pointer, not a position inside
    the data the pointer refers to, so the caller can keep parsing the
    record that contained the name.

    Raises IndexError if the packet ends in the middle of the name and
    ValueError if the compression pointers form a loop.
    """
    labels = []
    end = None      # set when the first compression pointer is followed
    seen = set()    # offsets of the pointers already followed
    while True:
        # get the next octet from buffer
        n = ord(buf[off])
        if n == 0:
            # zero octet terminates name
            off += 1
            break
        elif (n & 0xc0) == 0xc0:
            # top two bits on
            #  => a 2 octet pointer to another part of the buffer
            if off in seen:
                raise ValueError('DNS name compression loop at offset %d'
                                 % off)
            seen.add(off)
            ptr = ((n & 0x3f) << 8) | ord(buf[off + 1])
            if end is None:
                # the name, as stored here, ends right after this pointer
                end = off + 2
            off = ptr
        else:
            # an octet representing the number of bytes to process.
            off += 1
            labels.append(buf[off:off + n])
            off += n
    if end is None:
        end = off
    return '.'.join(labels).strip('.'), end
class dnspkt:
    """Minimal DNS message parser: the header plus the question section.

    After unpack() the instance carries ``id``, ``op`` (the second 16-bit
    header word), the four section counts and ``q``, a list of
    ``(qname, qtype, qclass)`` tuples.
    """

    def __init__(self):
        # lets match_q_domain() be called safely before unpack()
        self.q = []

    def unpack(self, buf, off):
        """Parse the header and questions of *buf* starting at *off*.

        Returns the offset of the first byte after the question section.
        A truncated packet makes struct.unpack (or the name decoder) raise.
        """
        header = struct.unpack("!HHHHHH", buf[off:off+12])
        (self.id, self.op, self.qdcount, self.ancount,
         self.nscount, self.arcount) = header
        off += 12
        self.q = []
        for _ in range(self.qdcount):
            qname, off = unpack_dns_name(buf, off)
            qtype, qclass = struct.unpack('!HH', buf[off:off+4])
            off += 4
            self.q.append((qname, qtype, qclass))
        return off

    def match_q_domain(self, domain):
        """Return True if any question names *domain* or a subdomain of it.

        The comparison works on whole labels ("ample.com" does not match
        "example.com") and ignores letter case, because DNS names are
        case-insensitive and resolvers may send mixed-case queries.  A
        trailing dot on *domain* is ignored.
        """
        domain = domain.lower().rstrip('.')
        suffix = '.' + domain
        for qname, qtype, qclass in self.q:
            qname = qname.lower()
            if qname == domain:
                return True
            if domain and qname.endswith(suffix):
                return True
        return False
def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
          dnslistener, dnsforwarder, dns_domains, dns_to,
          seed_hosts, auto_nets,
          syslog, daemon):
    """Connect to the server and run the proxy event loop.

    listener      -- listening TCP socket receiving redirected connections
    fw            -- FirewallClient; started once the server sent its routes
    ssh_cmd, remotename, python -- passed through to ssh.connect()
    latency_control -- passed to the server as an option; when true the
                     mux fullness check runs on every loop iteration
    dnslistener   -- UDP socket receiving DNS queries, or None
    dnsforwarder  -- UDP socket used to relay queries to dns_to, or None
    dns_domains   -- iterable of domain names, or None; queries for these
                     domains are resolved through the server, all other
                     queries are relayed to dns_to
    dns_to        -- (ip, port) of the resolver used for other queries
    seed_hosts    -- list of host names sent to the server, or None
    auto_nets     -- if true, add the routes announced by the server
    syslog, daemon -- select how logging is set up once connected

    Does not return normally: the loop only ends with an exception, e.g.
    Fatal when the ssh session cannot be established or the server process
    exits with an error code.
    """
    handlers = []
    if helpers.verbose >= 1:
        helpers.logprefix = 'c : '
    else:
        helpers.logprefix = 'client: '
    debug1('connecting to server...\n')

    try:
        (serverproc, serversock) = ssh.connect(ssh_cmd, remotename, python,
                        stderr=ssyslog._p and ssyslog._p.stdin,
                        options=dict(latency_control=latency_control))
    except socket.error:
        e = sys.exc_info()[1]
        if e.args[0] == errno.EPIPE:
            raise Fatal("failed to establish ssh session (1)")
        else:
            raise
    mux = Mux(serversock, serversock)
    handlers.append(mux)

    expected = 'SSHUTTLE0001'

    try:
        initstring = serversock.recv(len(expected))
    except socket.error:
        e = sys.exc_info()[1]
        if e.args[0] == errno.ECONNRESET:
            raise Fatal("failed to establish ssh session (2)")
        else:
            raise

    rv = serverproc.poll()
    if rv:
        raise Fatal('server died with error code %d' % rv)

    if initstring != expected:
        raise Fatal('expected server init string %r; got %r'
                    % (expected, initstring))
    debug1('connected.\n')
    sys.stdout.write('Connected.\n')
    sys.stdout.flush()
    if daemon:
        daemonize()
        log('daemonizing (%s).\n' % _pidname)
    elif syslog:
        debug1('switching to syslog.\n')
        ssyslog.stderr_to_syslog()

    def onroutes(routestr):
        if auto_nets:
            for line in routestr.strip().split('\n'):
                (ip, width) = line.split(',', 1)
                fw.auto_nets.append((ip, int(width)))

        # we definitely want to do this *after* starting ssh, or we might
        # end up intercepting the ssh connection!
        #
        # Moreover, now that we have the --auto-nets option, we have to wait
        # for the server to send us that message anyway.  Even if we haven't
        # set --auto-nets, we might as well wait for the message first, then
        # ignore its contents.
        mux.got_routes = None
        fw.start()
    mux.got_routes = onroutes

    def onhostlist(hostlist):
        debug2('got host list: %r\n' % hostlist)
        for line in hostlist.strip().split():
            if line:
                name, ip = line.split(',', 1)
                fw.sethostip(name, ip)
    mux.got_host_list = onhostlist

    def onaccept():
        global _extra_fd
        try:
            sock, srcip = listener.accept()
        except socket.error:
            e = sys.exc_info()[1]
            if e.args[0] in [errno.EMFILE, errno.ENFILE]:
                debug1('Rejected incoming connection: too many open files!\n')
                # free up an fd so we can eat the connection
                os.close(_extra_fd)
                try:
                    sock, srcip = listener.accept()
                    sock.close()
                finally:
                    _extra_fd = os.open('/dev/null', os.O_RDONLY)
                return
            else:
                raise
        dstip = original_dst(sock)
        debug1('Accept: %s:%r -> %s:%r.\n' % (srcip[0], srcip[1],
                                              dstip[0], dstip[1]))
        if dstip[1] == listener.getsockname()[1] and islocal(dstip[0]):
            debug1("-- ignored: that's my address!\n")
            sock.close()
            return
        chan = mux.next_channel()
        if not chan:
            log('warning: too many open channels. Discarded connection.\n')
            sock.close()
            return
        mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip)
        outwrap = MuxWrapper(mux, chan)
        handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
    handlers.append(Handler([listener], onaccept))

    # chan -> (peer, deadline) for queries resolved through the server
    dnsreqs = {}
    # DNS transaction id -> (peer, deadline) for queries relayed to dns_to
    dnsforwards = {}

    def parse_dns(pkt):
        # Return the parsed packet, or None when it is malformed.  Anything
        # can arrive on a UDP socket; garbage must not take the client down.
        dns = dnspkt()
        try:
            dns.unpack(pkt, 0)
        except (struct.error, IndexError, ValueError):
            return None
        return dns

    def dns_done(chan, data):
        peer, timeout = dnsreqs.get(chan) or (None, None)
        debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
        if peer:
            del dnsreqs[chan]
            # drop the reply callback too, or every remote DNS query would
            # leave an entry behind in mux.channels for good
            mux.channels.pop(chan, None)
            debug3('doing sendto %r\n' % (peer,))
            dnslistener.sendto(data, peer)

    def ondns():
        pkt, peer = dnslistener.recvfrom(4096)
        now = time.time()
        if pkt:
            debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt)))
            dns = parse_dns(pkt)
            match = False
            if dns is not None and dns_domains is not None:
                for domain in dns_domains:
                    if dns.match_q_domain(domain):
                        match = True
                        break
            if dns is None:
                debug1('-- ignored: malformed DNS request.\n')
            elif match:
                debug3("We need to redirect this request remotely\n")
                chan = mux.next_channel()
                if not chan:
                    log('warning: too many open channels. '
                        'Discarded DNS request.\n')
                else:
                    dnsreqs[chan] = peer, now+30
                    mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
                    # bind chan as a default argument: a plain closure would
                    # see whatever value the name holds when the reply
                    # arrives, not the channel of this request
                    mux.channels[chan] = (
                        lambda cmd, data, chan=chan: dns_done(chan, data))
            else:
                debug3("We need to forward this request locally\n")
                dnsforwarder.sendto(pkt, dns_to)
                dnsforwards[dns.id] = peer, now+30
        # expire requests that never got an answer
        for oldchan, (oldpeer, deadline) in list(dnsreqs.items()):
            if deadline < now:
                del dnsreqs[oldchan]
                mux.channels.pop(oldchan, None)
        for dnsid, (oldpeer, deadline) in list(dnsforwards.items()):
            if deadline < now:
                del dnsforwards[dnsid]
        debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
        debug3('Remaining DNS forwards: %d\n' % len(dnsforwards))
    if dnslistener:
        handlers.append(Handler([dnslistener], ondns))

    def ondnsforward():
        debug1("We got a response.\n")
        pkt, server = dnsforwarder.recvfrom(4096)
        if server[0] != dns_to[0] or server[1] != dns_to[1]:
            debug1("Ooops. The response came from the wrong server. Ignoring\n")
        else:
            dns = parse_dns(pkt)
            if dns is None:
                debug1('-- ignored: malformed DNS response.\n')
                return
            chan = dns.id
            peer, timeout = dnsforwards.get(chan) or (None, None)
            debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
            if peer:
                del dnsforwards[chan]
                debug3('doing sendto %r\n' % (peer,))
                dnslistener.sendto(pkt, peer)
    if dnsforwarder:
        handlers.append(Handler([dnsforwarder], ondnsforward))

    if seed_hosts is not None:
        debug1('seed_hosts: %r\n' % seed_hosts)
        mux.send(0, ssnet.CMD_HOST_REQ, '\n'.join(seed_hosts))

    while 1:
        rv = serverproc.poll()
        if rv:
            raise Fatal('server died with error code %d' % rv)

        ssnet.runonce(handlers, mux)
        if latency_control:
            mux.check_fullness()
        mux.callback()
def main(listenip, ssh_cmd, remotename, python, latency_control,
         dns, dns_domains, dns_to,
         seed_hosts, auto_nets,
         subnets_include, subnets_exclude, syslog, daemon, pidfile):
    """Set up the listening sockets and firewall client, then run _main().

    listenip is an (ip, port) pair; with a port of 0 the ports 12300 down
    to 9001 are tried in turn until both the TCP listener and the UDP DNS
    listener can be bound to the same port.  The remaining arguments are
    handed on to FirewallClient and _main().

    Returns 5 if daemon mode was requested and check_daemon() refused to
    start.  Re-raises the last socket error if no port could be bound.
    On the way out the firewall client is shut down and, in daemon mode,
    the pidfile is removed.
    """
    if syslog:
        ssyslog.start_syslog()
    if daemon:
        try:
            check_daemon(pidfile)
        except Fatal:
            e = sys.exc_info()[1]
            log("%s\n" % e)
            return 5
    debug1('Starting sshuttle proxy.\n')

    if listenip[1]:
        ports = [listenip[1]]
    else:
        ports = xrange(12300, 9000, -1)
    last_e = None
    bound = False
    debug2('Binding:')
    for port in ports:
        debug2(' %d' % port)
        listener = socket.socket()
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        dnslistener = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        dnslistener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            listener.bind((listenip[0], port))
            dnslistener.bind((listenip[0], port))
        except socket.error:
            last_e = sys.exc_info()[1]
            # release this attempt's sockets right away instead of leaving
            # that to the garbage collector
            listener.close()
            dnslistener.close()
        else:
            bound = True
            break
    debug2('\n')
    if not bound:
        assert(last_e)
        raise last_e
    listener.listen(10)
    listenip = listener.getsockname()
    debug1('Listening on %r.\n' % (listenip,))

    if dns:
        dnsip = dnslistener.getsockname()
        debug1('DNS listening on %r.\n' % (dnsip,))
        dnsport = dnsip[1]
        dnsforwarder = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        dnsforwarder.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        dnsforwarder.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
    else:
        dnsport = 0
        # DNS handling is off: give the UDP port back explicitly
        dnslistener.close()
        dnslistener = None
        dnsforwarder = None

    fw = FirewallClient(listenip[1], subnets_include, subnets_exclude, dnsport)

    try:
        return _main(listener, fw, ssh_cmd, remotename,
                     python, latency_control,
                     dnslistener, dnsforwarder, dns_domains, dns_to,
                     seed_hosts, auto_nets, syslog, daemon)
    finally:
        try:
            if daemon:
                # it's not our child anymore; can't waitpid
                fw.p.returncode = 0
            fw.done()
        finally:
            if daemon:
                daemon_cleanup()