PEP8 fixes.

This commit is contained in:
Brian May 2014-09-16 10:24:16 +10:00
parent 5529a04cc9
commit f1c79c7e92
16 changed files with 569 additions and 377 deletions

View File

@ -1,4 +1,5 @@
import sys, zlib import sys
import zlib
z = zlib.decompressobj() z = zlib.decompressobj()
mainmod = sys.modules[__name__] mainmod = sys.modules[__name__]

View File

@ -1,21 +1,30 @@
import struct, select, errno, re, signal, time import struct
import errno
import re
import signal
import time
import compat.ssubprocess as ssubprocess import compat.ssubprocess as ssubprocess
import helpers, ssnet, ssh, ssyslog import helpers
import os
import ssnet
import ssh
import ssyslog
import sys
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import * from helpers import log, debug1, debug2, debug3, Fatal, islocal
recvmsg = None recvmsg = None
try: try:
# try getting recvmsg from python # try getting recvmsg from python
import socket as pythonsocket import socket as pythonsocket
getattr(pythonsocket.socket,"recvmsg") getattr(pythonsocket.socket, "recvmsg")
socket = pythonsocket socket = pythonsocket
recvmsg = "python" recvmsg = "python"
except AttributeError: except AttributeError:
# try getting recvmsg from socket_ext library # try getting recvmsg from socket_ext library
try: try:
import socket_ext import socket_ext
getattr(socket_ext.socket,"recvmsg") getattr(socket_ext.socket, "recvmsg")
socket = socket_ext socket = socket_ext
recvmsg = "socket_ext" recvmsg = "socket_ext"
except ImportError: except ImportError:
@ -23,6 +32,7 @@ except AttributeError:
_extra_fd = os.open('/dev/null', os.O_RDONLY) _extra_fd = os.open('/dev/null', os.O_RDONLY)
def got_signal(signum, frame): def got_signal(signum, frame):
log('exiting on signal %d\n' % signum) log('exiting on signal %d\n' % signum)
sys.exit(1) sys.exit(1)
@ -40,60 +50,64 @@ IPV6_RECVORIGDSTADDR = IPV6_ORIGDSTADDR
if recvmsg == "python": if recvmsg == "python":
def recv_udp(listener, bufsize): def recv_udp(listener, bufsize):
debug3('Accept UDP python using recvmsg.\n') debug3('Accept UDP python using recvmsg.\n')
data, ancdata, msg_flags, srcip = listener.recvmsg(4096,socket.CMSG_SPACE(24)) data, ancdata, msg_flags, srcip = listener.recvmsg(
4096, socket.CMSG_SPACE(24))
dstip = None dstip = None
family = None family = None
for cmsg_level, cmsg_type, cmsg_data in ancdata: for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_IP and cmsg_type == IP_ORIGDSTADDR: if cmsg_level == socket.SOL_IP and cmsg_type == IP_ORIGDSTADDR:
family,port = struct.unpack('=HH', cmsg_data[0:4]) family, port = struct.unpack('=HH', cmsg_data[0:4])
port = socket.htons(port) port = socket.htons(port)
if family == socket.AF_INET: if family == socket.AF_INET:
start = 4 start = 4
length = 4 length = 4
else: else:
raise Fatal("Unsupported socket type '%s'"%family) raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(family, cmsg_data[start:start+length]) ip = socket.inet_ntop(family, cmsg_data[start:start + length])
dstip = (ip, port) dstip = (ip, port)
break break
elif cmsg_level == SOL_IPV6 and cmsg_type == IPV6_ORIGDSTADDR: elif cmsg_level == SOL_IPV6 and cmsg_type == IPV6_ORIGDSTADDR:
family,port = struct.unpack('=HH', cmsg_data[0:4]) family, port = struct.unpack('=HH', cmsg_data[0:4])
port = socket.htons(port) port = socket.htons(port)
if family == socket.AF_INET6: if family == socket.AF_INET6:
start = 8 start = 8
length = 16 length = 16
else: else:
raise Fatal("Unsupported socket type '%s'"%family) raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(family, cmsg_data[start:start+length]) ip = socket.inet_ntop(family, cmsg_data[start:start + length])
dstip = (ip, port) dstip = (ip, port)
break break
return (srcip, dstip, data) return (srcip, dstip, data)
elif recvmsg == "socket_ext": elif recvmsg == "socket_ext":
def recv_udp(listener, bufsize): def recv_udp(listener, bufsize):
debug3('Accept UDP using socket_ext recvmsg.\n') debug3('Accept UDP using socket_ext recvmsg.\n')
srcip, data, adata, flags = listener.recvmsg((bufsize,),socket.CMSG_SPACE(24)) srcip, data, adata, flags = listener.recvmsg(
(bufsize,), socket.CMSG_SPACE(24))
dstip = None dstip = None
family = None family = None
for a in adata: for a in adata:
if a.cmsg_level == socket.SOL_IP and a.cmsg_type == IP_ORIGDSTADDR: if a.cmsg_level == socket.SOL_IP and a.cmsg_type == IP_ORIGDSTADDR:
family,port = struct.unpack('=HH', a.cmsg_data[0:4]) family, port = struct.unpack('=HH', a.cmsg_data[0:4])
port = socket.htons(port) port = socket.htons(port)
if family == socket.AF_INET: if family == socket.AF_INET:
start = 4 start = 4
length = 4 length = 4
else: else:
raise Fatal("Unsupported socket type '%s'"%family) raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(family, a.cmsg_data[start:start+length]) ip = socket.inet_ntop(
family, a.cmsg_data[start:start + length])
dstip = (ip, port) dstip = (ip, port)
break break
elif a.cmsg_level == SOL_IPV6 and a.cmsg_type == IPV6_ORIGDSTADDR: elif a.cmsg_level == SOL_IPV6 and a.cmsg_type == IPV6_ORIGDSTADDR:
family,port = struct.unpack('=HH', a.cmsg_data[0:4]) family, port = struct.unpack('=HH', a.cmsg_data[0:4])
port = socket.htons(port) port = socket.htons(port)
if family == socket.AF_INET6: if family == socket.AF_INET6:
start = 8 start = 8
length = 16 length = 16
else: else:
raise Fatal("Unsupported socket type '%s'"%family) raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(family, a.cmsg_data[start:start+length]) ip = socket.inet_ntop(
family, a.cmsg_data[start:start + length])
dstip = (ip, port) dstip = (ip, port)
break break
return (srcip, dstip, data[0]) return (srcip, dstip, data[0])
@ -142,7 +156,7 @@ def daemonize():
if os.fork(): if os.fork():
os._exit(0) os._exit(0)
outfd = os.open(_pidname, os.O_WRONLY|os.O_CREAT|os.O_EXCL, 0666) outfd = os.open(_pidname, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0666)
try: try:
os.write(outfd, '%d\n' % os.getpid()) os.write(outfd, '%d\n' % os.getpid())
finally: finally:
@ -152,7 +166,7 @@ def daemonize():
# Normal exit when killed, or try/finally won't work and the pidfile won't # Normal exit when killed, or try/finally won't work and the pidfile won't
# be deleted. # be deleted.
signal.signal(signal.SIGTERM, got_signal) signal.signal(signal.SIGTERM, got_signal)
si = open('/dev/null', 'r+') si = open('/dev/null', 'r+')
os.dup2(si.fileno(), 0) os.dup2(si.fileno(), 0)
os.dup2(si.fileno(), 1) os.dup2(si.fileno(), 1)
@ -177,10 +191,10 @@ def original_dst(sock):
SOCKADDR_MIN = 16 SOCKADDR_MIN = 16
sockaddr_in = sock.getsockopt(socket.SOL_IP, sockaddr_in = sock.getsockopt(socket.SOL_IP,
SO_ORIGINAL_DST, SOCKADDR_MIN) SO_ORIGINAL_DST, SOCKADDR_MIN)
(proto, port, a,b,c,d) = struct.unpack('!HHBBBB', sockaddr_in[:8]) (proto, port, a, b, c, d) = struct.unpack('!HHBBBB', sockaddr_in[:8])
assert(socket.htons(proto) == socket.AF_INET) assert(socket.htons(proto) == socket.AF_INET)
ip = '%d.%d.%d.%d' % (a,b,c,d) ip = '%d.%d.%d.%d' % (a, b, c, d)
return (ip,port) return (ip, port)
except socket.error, e: except socket.error, e:
if e.args[0] == errno.ENOPROTOOPT: if e.args[0] == errno.ENOPROTOOPT:
return sock.getsockname() return sock.getsockname()
@ -201,9 +215,15 @@ class MultiListener:
def add_handler(self, handlers, callback, method, mux): def add_handler(self, handlers, callback, method, mux):
if self.v6: if self.v6:
handlers.append(Handler([self.v6], lambda: callback(self.v6, method, mux, handlers))) handlers.append(
Handler(
[self.v6],
lambda: callback(self.v6, method, mux, handlers)))
if self.v4: if self.v4:
handlers.append(Handler([self.v4], lambda: callback(self.v4, method, mux, handlers))) handlers.append(
Handler(
[self.v4],
lambda: callback(self.v4, method, mux, handlers)))
def listen(self, backlog): def listen(self, backlog):
if self.v6: if self.v6:
@ -239,15 +259,17 @@ class MultiListener:
class FirewallClient: class FirewallClient:
def __init__(self, port_v6, port_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method, udp):
def __init__(self, port_v6, port_v4, subnets_include, subnets_exclude,
dnsport_v6, dnsport_v4, method, udp):
self.auto_nets = [] self.auto_nets = []
self.subnets_include = subnets_include self.subnets_include = subnets_include
self.subnets_exclude = subnets_exclude self.subnets_exclude = subnets_exclude
argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] + argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] +
['-v'] * (helpers.verbose or 0) + ['-v'] * (helpers.verbose or 0) +
['--firewall', str(port_v6), str(port_v4), ['--firewall', str(port_v6), str(port_v4),
str(dnsport_v6), str(dnsport_v4), str(dnsport_v6), str(dnsport_v4),
method, str(int(udp))]) method, str(int(udp))])
if ssyslog._p: if ssyslog._p:
argvbase += ['--syslog'] argvbase += ['--syslog']
argv_tries = [ argv_tries = [
@ -260,7 +282,8 @@ class FirewallClient:
# because stupid Linux 'su' requires that stdin be attached to a tty. # because stupid Linux 'su' requires that stdin be attached to a tty.
# Instead, attach a *bidirectional* socket to its stdout, and use # Instead, attach a *bidirectional* socket to its stdout, and use
# that for talking in both directions. # that for talking in both directions.
(s1,s2) = socket.socketpair() (s1, s2) = socket.socketpair()
def setup(): def setup():
# run in the child process # run in the child process
s2.close() s2.close()
@ -295,9 +318,9 @@ class FirewallClient:
def start(self): def start(self):
self.pfile.write('ROUTES\n') self.pfile.write('ROUTES\n')
for (family,ip,width) in self.subnets_include+self.auto_nets: for (family, ip, width) in self.subnets_include + self.auto_nets:
self.pfile.write('%d,%d,0,%s\n' % (family, width, ip)) self.pfile.write('%d,%d,0,%s\n' % (family, width, ip))
for (family,ip,width) in self.subnets_exclude: for (family, ip, width) in self.subnets_exclude:
self.pfile.write('%d,%d,1,%s\n' % (family, width, ip)) self.pfile.write('%d,%d,1,%s\n' % (family, width, ip))
self.pfile.write('GO\n') self.pfile.write('GO\n')
self.pfile.flush() self.pfile.flush()
@ -321,14 +344,16 @@ class FirewallClient:
dnsreqs = {} dnsreqs = {}
udp_by_src = {} udp_by_src = {}
def expire_connections(now, mux): def expire_connections(now, mux):
for chan,timeout in dnsreqs.items(): for chan, timeout in dnsreqs.items():
if timeout < now: if timeout < now:
debug3('expiring dnsreqs channel=%d\n' % chan) debug3('expiring dnsreqs channel=%d\n' % chan)
del mux.channels[chan] del mux.channels[chan]
del dnsreqs[chan] del dnsreqs[chan]
debug3('Remaining DNS requests: %d\n' % len(dnsreqs)) debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
for peer,(chan,timeout) in udp_by_src.items(): for peer, (chan, timeout) in udp_by_src.items():
if timeout < now: if timeout < now:
debug3('expiring UDP channel channel=%d peer=%r\n' % (chan, peer)) debug3('expiring UDP channel channel=%d peer=%r\n' % (chan, peer))
mux.send(chan, ssnet.CMD_UDP_CLOSE, '') mux.send(chan, ssnet.CMD_UDP_CLOSE, '')
@ -340,14 +365,14 @@ def expire_connections(now, mux):
def onaccept_tcp(listener, method, mux, handlers): def onaccept_tcp(listener, method, mux, handlers):
global _extra_fd global _extra_fd
try: try:
sock,srcip = listener.accept() sock, srcip = listener.accept()
except socket.error, e: except socket.error, e:
if e.args[0] in [errno.EMFILE, errno.ENFILE]: if e.args[0] in [errno.EMFILE, errno.ENFILE]:
debug1('Rejected incoming connection: too many open files!\n') debug1('Rejected incoming connection: too many open files!\n')
# free up an fd so we can eat the connection # free up an fd so we can eat the connection
os.close(_extra_fd) os.close(_extra_fd)
try: try:
sock,srcip = listener.accept() sock, srcip = listener.accept()
sock.close() sock.close()
finally: finally:
_extra_fd = os.open('/dev/null', os.O_RDONLY) _extra_fd = os.open('/dev/null', os.O_RDONLY)
@ -355,11 +380,11 @@ def onaccept_tcp(listener, method, mux, handlers):
else: else:
raise raise
if method == "tproxy": if method == "tproxy":
dstip = sock.getsockname(); dstip = sock.getsockname()
else: else:
dstip = original_dst(sock) dstip = original_dst(sock)
debug1('Accept TCP: %s:%r -> %s:%r.\n' % (srcip[0],srcip[1], debug1('Accept TCP: %s:%r -> %s:%r.\n' % (srcip[0], srcip[1],
dstip[0],dstip[1])) dstip[0], dstip[1]))
if dstip[1] == sock.getsockname()[1] and islocal(dstip[0], sock.family): if dstip[1] == sock.getsockname()[1] and islocal(dstip[0], sock.family):
debug1("-- ignored: that's my address!\n") debug1("-- ignored: that's my address!\n")
sock.close() sock.close()
@ -369,16 +394,17 @@ def onaccept_tcp(listener, method, mux, handlers):
log('warning: too many open channels. Discarded connection.\n') log('warning: too many open channels. Discarded connection.\n')
sock.close() sock.close()
return return
mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' % (sock.family, dstip[0], dstip[1])) mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' %
(sock.family, dstip[0], dstip[1]))
outwrap = MuxWrapper(mux, chan) outwrap = MuxWrapper(mux, chan)
handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
expire_connections(time.time(), mux) expire_connections(time.time(), mux)
def udp_done(chan, data, method, family, dstip): def udp_done(chan, data, method, family, dstip):
(src,srcport,data) = data.split(",",2) (src, srcport, data) = data.split(",", 2)
srcip = (src,int(srcport)) srcip = (src, int(srcport))
debug3('doing send from %r to %r\n' % (srcip,dstip,)) debug3('doing send from %r to %r\n' % (srcip, dstip,))
try: try:
sender = socket.socket(family, socket.SOCK_DGRAM) sender = socket.socket(family, socket.SOCK_DGRAM)
@ -388,36 +414,39 @@ def udp_done(chan, data, method, family, dstip):
sender.sendto(data, dstip) sender.sendto(data, dstip)
sender.close() sender.close()
except socket.error, e: except socket.error, e:
debug1('-- ignored socket error sending UDP data: %r\n'%e) debug1('-- ignored socket error sending UDP data: %r\n' % e)
def onaccept_udp(listener, method, mux, handlers): def onaccept_udp(listener, method, mux, handlers):
now = time.time() now = time.time()
srcip, dstip, data = recv_udp(listener, 4096) srcip, dstip, data = recv_udp(listener, 4096)
if not dstip: if not dstip:
debug1("-- ignored UDP from %r: couldn't determine destination IP address\n" % (srcip,)) debug1(
"-- ignored UDP from %r: "
"couldn't determine destination IP address\n" % (srcip,))
return return
debug1('Accept UDP: %r -> %r.\n' % (srcip,dstip,)) debug1('Accept UDP: %r -> %r.\n' % (srcip, dstip,))
if srcip in udp_by_src: if srcip in udp_by_src:
chan,timeout = udp_by_src[srcip] chan, timeout = udp_by_src[srcip]
else: else:
chan = mux.next_channel() chan = mux.next_channel()
mux.channels[chan] = lambda cmd,data: udp_done(chan, data, method, listener.family, dstip=srcip) mux.channels[chan] = lambda cmd, data: udp_done(
chan, data, method, listener.family, dstip=srcip)
mux.send(chan, ssnet.CMD_UDP_OPEN, listener.family) mux.send(chan, ssnet.CMD_UDP_OPEN, listener.family)
udp_by_src[srcip] = chan,now+30 udp_by_src[srcip] = chan, now + 30
hdr = "%s,%r,"%(dstip[0], dstip[1]) hdr = "%s,%r," % (dstip[0], dstip[1])
mux.send(chan, ssnet.CMD_UDP_DATA, hdr+data) mux.send(chan, ssnet.CMD_UDP_DATA, hdr + data)
expire_connections(now, mux) expire_connections(now, mux)
def dns_done(chan, data, method, sock, srcip, dstip, mux): def dns_done(chan, data, method, sock, srcip, dstip, mux):
debug3('dns_done: channel=%d src=%r dst=%r\n' % (chan,srcip,dstip)) debug3('dns_done: channel=%d src=%r dst=%r\n' % (chan, srcip, dstip))
del mux.channels[chan] del mux.channels[chan]
del dnsreqs[chan] del dnsreqs[chan]
if method == "tproxy": if method == "tproxy":
debug3('doing send from %r to %r\n' % (srcip,dstip,)) debug3('doing send from %r to %r\n' % (srcip, dstip,))
sender = socket.socket(sock.family, socket.SOCK_DGRAM) sender = socket.socket(sock.family, socket.SOCK_DGRAM)
sender.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sender.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sender.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1) sender.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
@ -433,17 +462,21 @@ def ondns(listener, method, mux, handlers):
now = time.time() now = time.time()
srcip, dstip, data = recv_udp(listener, 4096) srcip, dstip, data = recv_udp(listener, 4096)
if method == "tproxy" and not dstip: if method == "tproxy" and not dstip:
debug1("-- ignored UDP from %r: couldn't determine destination IP address\n" % (srcip,)) debug1(
"-- ignored UDP from %r: "
"couldn't determine destination IP address\n" % (srcip,))
return return
debug1('DNS request from %r to %r: %d bytes\n' % (srcip,dstip,len(data))) debug1('DNS request from %r to %r: %d bytes\n' % (srcip, dstip, len(data)))
chan = mux.next_channel() chan = mux.next_channel()
dnsreqs[chan] = now+30 dnsreqs[chan] = now + 30
mux.send(chan, ssnet.CMD_DNS_REQ, data) mux.send(chan, ssnet.CMD_DNS_REQ, data)
mux.channels[chan] = lambda cmd,data: dns_done(chan, data, method, listener, srcip=dstip, dstip=srcip, mux=mux) mux.channels[chan] = lambda cmd, data: dns_done(
chan, data, method, listener, srcip=dstip, dstip=srcip, mux=mux)
expire_connections(now, mux) expire_connections(now, mux)
def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_control, def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
python, latency_control,
dns_listener, method, seed_hosts, auto_nets, dns_listener, method, seed_hosts, auto_nets,
syslog, daemon): syslog, daemon):
handlers = [] handlers = []
@ -454,9 +487,10 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
debug1('connecting to server...\n') debug1('connecting to server...\n')
try: try:
(serverproc, serversock) = ssh.connect(ssh_cmd, remotename, python, (serverproc, serversock) = ssh.connect(
stderr=ssyslog._p and ssyslog._p.stdin, ssh_cmd, remotename, python,
options=dict(latency_control=latency_control, method=method)) stderr=ssyslog._p and ssyslog._p.stdin,
options=dict(latency_control=latency_control, method=method))
except socket.error, e: except socket.error, e:
if e.args[0] == errno.EPIPE: if e.args[0] == errno.EPIPE:
raise Fatal("failed to establish ssh session (1)") raise Fatal("failed to establish ssh session (1)")
@ -466,7 +500,7 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
handlers.append(mux) handlers.append(mux)
expected = 'SSHUTTLE0001' expected = 'SSHUTTLE0001'
try: try:
v = 'x' v = 'x'
while v and v != '\0': while v and v != '\0':
@ -480,14 +514,14 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
raise Fatal("failed to establish ssh session (2)") raise Fatal("failed to establish ssh session (2)")
else: else:
raise raise
rv = serverproc.poll() rv = serverproc.poll()
if rv: if rv:
raise Fatal('server died with error code %d' % rv) raise Fatal('server died with error code %d' % rv)
if initstring != expected: if initstring != expected:
raise Fatal('expected server init string %r; got %r' raise Fatal('expected server init string %r; got %r'
% (expected, initstring)) % (expected, initstring))
debug1('connected.\n') debug1('connected.\n')
print 'Connected.' print 'Connected.'
sys.stdout.flush() sys.stdout.flush()
@ -501,8 +535,8 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
def onroutes(routestr): def onroutes(routestr):
if auto_nets: if auto_nets:
for line in routestr.strip().split('\n'): for line in routestr.strip().split('\n'):
(family,ip,width) = line.split(',', 2) (family, ip, width) = line.split(',', 2)
fw.auto_nets.append((family,ip,int(width))) fw.auto_nets.append((family, ip, int(width)))
# we definitely want to do this *after* starting ssh, or we might end # we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection! # up intercepting the ssh connection!
@ -519,7 +553,7 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
debug2('got host list: %r\n' % hostlist) debug2('got host list: %r\n' % hostlist)
for line in hostlist.strip().split(): for line in hostlist.strip().split():
if line: if line:
name,ip = line.split(',', 1) name, ip = line.split(',', 1)
fw.sethostip(name, ip) fw.sethostip(name, ip)
mux.got_host_list = onhostlist mux.got_host_list = onhostlist
@ -531,15 +565,15 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
if dns_listener: if dns_listener:
dns_listener.add_handler(handlers, ondns, method, mux) dns_listener.add_handler(handlers, ondns, method, mux)
if seed_hosts != None: if seed_hosts is not None:
debug1('seed_hosts: %r\n' % seed_hosts) debug1('seed_hosts: %r\n' % seed_hosts)
mux.send(0, ssnet.CMD_HOST_REQ, '\n'.join(seed_hosts)) mux.send(0, ssnet.CMD_HOST_REQ, '\n'.join(seed_hosts))
while 1: while 1:
rv = serverproc.poll() rv = serverproc.poll()
if rv: if rv:
raise Fatal('server died with error code %d' % rv) raise Fatal('server died with error code %d' % rv)
ssnet.runonce(handlers, mux) ssnet.runonce(handlers, mux)
if latency_control: if latency_control:
mux.check_fullness() mux.check_fullness()
@ -562,7 +596,7 @@ def main(listenip_v6, listenip_v4,
debug1('Starting sshuttle proxy.\n') debug1('Starting sshuttle proxy.\n')
if recvmsg is not None: if recvmsg is not None:
debug1("recvmsg %s support enabled.\n"%recvmsg) debug1("recvmsg %s support enabled.\n" % recvmsg)
if method == "tproxy": if method == "tproxy":
if recvmsg is not None: if recvmsg is not None:
@ -580,10 +614,10 @@ def main(listenip_v6, listenip_v4,
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]: if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
# if both ports given, no need to search for a spare port # if both ports given, no need to search for a spare port
ports = [ 0, ] ports = [0, ]
else: else:
# if at least one port missing, we have to search # if at least one port missing, we have to search
ports = xrange(12300,9000,-1) ports = xrange(12300, 9000, -1)
# search for free ports and try to bind # search for free ports and try to bind
last_e = None last_e = None
@ -606,7 +640,7 @@ def main(listenip_v6, listenip_v4,
lv6 = listenip_v6 lv6 = listenip_v6
redirectport_v6 = lv6[1] redirectport_v6 = lv6[1]
elif listenip_v6: elif listenip_v6:
lv6 = (listenip_v6[0],port) lv6 = (listenip_v6[0], port)
redirectport_v6 = port redirectport_v6 = port
else: else:
lv6 = None lv6 = None
@ -616,7 +650,7 @@ def main(listenip_v6, listenip_v4,
lv4 = listenip_v4 lv4 = listenip_v4
redirectport_v4 = lv4[1] redirectport_v4 = lv4[1]
elif listenip_v4: elif listenip_v4:
lv4 = (listenip_v4[0],port) lv4 = (listenip_v4[0], port)
redirectport_v4 = port redirectport_v4 = port
else: else:
lv4 = None lv4 = None
@ -646,20 +680,20 @@ def main(listenip_v6, listenip_v4,
if dns: if dns:
# search for spare port for DNS # search for spare port for DNS
debug2('Binding DNS:') debug2('Binding DNS:')
ports = xrange(12300,9000,-1) ports = xrange(12300, 9000, -1)
for port in ports: for port in ports:
debug2(' %d' % port) debug2(' %d' % port)
dns_listener = MultiListener(socket.SOCK_DGRAM) dns_listener = MultiListener(socket.SOCK_DGRAM)
if listenip_v6: if listenip_v6:
lv6 = (listenip_v6[0],port) lv6 = (listenip_v6[0], port)
dnsport_v6 = port dnsport_v6 = port
else: else:
lv6 = None lv6 = None
dnsport_v6 = 0 dnsport_v6 = 0
if listenip_v4: if listenip_v4:
lv4 = (listenip_v4[0],port) lv4 = (listenip_v4[0], port)
dnsport_v4 = port dnsport_v4 = port
else: else:
lv4 = None lv4 = None
@ -684,20 +718,23 @@ def main(listenip_v6, listenip_v4,
dnsport_v4 = 0 dnsport_v4 = 0
dns_listener = None dns_listener = None
fw = FirewallClient(redirectport_v6, redirectport_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method, udp) fw = FirewallClient(redirectport_v6, redirectport_v4, subnets_include,
subnets_exclude, dnsport_v6, dnsport_v4, method, udp)
if fw.method == "tproxy": if fw.method == "tproxy":
tcp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1) tcp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if udp_listener: if udp_listener:
udp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1) udp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if udp_listener.v4 is not None: if udp_listener.v4 is not None:
udp_listener.v4.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1) udp_listener.v4.setsockopt(
socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
if udp_listener.v6 is not None: if udp_listener.v6 is not None:
udp_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1) udp_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1)
if dns_listener: if dns_listener:
dns_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1) dns_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if dns_listener.v4 is not None: if dns_listener.v4 is not None:
dns_listener.v4.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1) dns_listener.v4.setsockopt(
socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
if dns_listener.v6 is not None: if dns_listener.v6 is not None:
dns_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1) dns_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1)

View File

@ -1,7 +1,13 @@
import re, errno, socket, select, struct import errno
import socket
import select
import struct
import compat.ssubprocess as ssubprocess import compat.ssubprocess as ssubprocess
import helpers, ssyslog import ssyslog
from helpers import * import sys
import os
from helpers import log, debug1, debug3, islocal, Fatal, family_to_string, \
resolvconf_nameservers
# python doesn't have a definition for this # python doesn't have a definition for this
IPPROTO_DIVERT = 254 IPPROTO_DIVERT = 254
@ -20,9 +26,9 @@ def ipt_chain_exists(family, table, name):
elif family == socket.AF_INET: elif family == socket.AF_INET:
cmd = 'iptables' cmd = 'iptables'
else: else:
raise Exception('Unsupported family "%s"'%family_to_string(family)) raise Exception('Unsupported family "%s"' % family_to_string(family))
argv = [cmd, '-t', table, '-nL'] argv = [cmd, '-t', table, '-nL']
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE) p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
for line in p.stdout: for line in p.stdout:
if line.startswith('Chain %s ' % name): if line.startswith('Chain %s ' % name):
return True return True
@ -37,7 +43,7 @@ def _ipt(family, table, *args):
elif family == socket.AF_INET: elif family == socket.AF_INET:
argv = ['iptables', '-t', table] + list(args) argv = ['iptables', '-t', table] + list(args)
else: else:
raise Exception('Unsupported family "%s"'%family_to_string(family)) raise Exception('Unsupported family "%s"' % family_to_string(family))
debug1('>> %s\n' % ' '.join(argv)) debug1('>> %s\n' % ' '.join(argv))
rv = ssubprocess.call(argv) rv = ssubprocess.call(argv)
if rv: if rv:
@ -45,6 +51,8 @@ def _ipt(family, table, *args):
_no_ttl_module = False _no_ttl_module = False
def _ipt_ttl(family, *args): def _ipt_ttl(family, *args):
global _no_ttl_module global _no_ttl_module
if not _no_ttl_module: if not _no_ttl_module:
@ -72,13 +80,17 @@ def _ipt_ttl(family, *args):
def do_iptables_nat(port, dnsport, family, subnets, udp): def do_iptables_nat(port, dnsport, family, subnets, udp):
# only ipv4 supported with NAT # only ipv4 supported with NAT
if family != socket.AF_INET: if family != socket.AF_INET:
raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family)) raise Exception(
'Address family "%s" unsupported by nat method'
% family_to_string(family))
if udp: if udp:
raise Exception("UDP not supported by nat method") raise Exception("UDP not supported by nat method")
table = "nat" table = "nat"
def ipt(*args): def ipt(*args):
return _ipt(family, table, *args) return _ipt(family, table, *args)
def ipt_ttl(*args): def ipt_ttl(*args):
return _ipt_ttl(family, table, *args) return _ipt_ttl(family, table, *args)
@ -103,20 +115,21 @@ def do_iptables_nat(port, dnsport, family, subnets, udp):
# to least-specific, and at any given level of specificity, we want # to least-specific, and at any given level of specificity, we want
# excludes to come first. That's why the columns are in such a non- # excludes to come first. That's why the columns are in such a non-
# intuitive order. # intuitive order.
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True): for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude: if sexclude:
ipt('-A', chain, '-j', 'RETURN', ipt('-A', chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp') '-p', 'tcp')
else: else:
ipt_ttl('-A', chain, '-j', 'REDIRECT', ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp', '-p', 'tcp',
'--to-ports', str(port)) '--to-ports', str(port))
if dnsport: if dnsport:
nslist = resolvconf_nameservers() nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist): for f, ip in filter(lambda i: i[0] == family, nslist):
ipt_ttl('-A', chain, '-j', 'REDIRECT', ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/32' % ip, '--dest', '%s/32' % ip,
'-p', 'udp', '-p', 'udp',
@ -126,15 +139,19 @@ def do_iptables_nat(port, dnsport, family, subnets, udp):
def do_iptables_tproxy(port, dnsport, family, subnets, udp): def do_iptables_tproxy(port, dnsport, family, subnets, udp):
if family not in [socket.AF_INET, socket.AF_INET6]: if family not in [socket.AF_INET, socket.AF_INET6]:
raise Exception('Address family "%s" unsupported by tproxy method'%family_to_string(family)) raise Exception(
'Address family "%s" unsupported by tproxy method'
% family_to_string(family))
table = "mangle" table = "mangle"
def ipt(*args): def ipt(*args):
return _ipt(family, table, *args) return _ipt(family, table, *args)
def ipt_ttl(*args): def ipt_ttl(*args):
return _ipt_ttl(family, table, *args) return _ipt_ttl(family, table, *args)
mark_chain = 'sshuttle-m-%s' % port mark_chain = 'sshuttle-m-%s' % port
tproxy_chain = 'sshuttle-t-%s' % port tproxy_chain = 'sshuttle-t-%s' % port
divert_chain = 'sshuttle-d-%s' % port divert_chain = 'sshuttle-d-%s' % port
@ -165,65 +182,70 @@ def do_iptables_tproxy(port, dnsport, family, subnets, udp):
ipt('-A', divert_chain, '-j', 'MARK', '--set-mark', '1') ipt('-A', divert_chain, '-j', 'MARK', '--set-mark', '1')
ipt('-A', divert_chain, '-j', 'ACCEPT') ipt('-A', divert_chain, '-j', 'ACCEPT')
ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain, ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'tcp', '-p', 'tcp') '-m', 'tcp', '-p', 'tcp')
if subnets and udp: if subnets and udp:
ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain, ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'udp', '-p', 'udp') '-m', 'udp', '-p', 'udp')
if dnsport: if dnsport:
nslist = resolvconf_nameservers() nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist): for f, ip in filter(lambda i: i[0] == family, nslist):
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1', ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/32' % ip, '--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53') '-m', 'udp', '-p', 'udp', '--dport', '53')
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1', ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1',
'--dest', '%s/32' % ip, '--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53', '-m', 'udp', '-p', 'udp', '--dport', '53',
'--on-port', str(dnsport)) '--on-port', str(dnsport))
if subnets: if subnets:
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True): for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude: if sexclude:
ipt('-A', mark_chain, '-j', 'RETURN', ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp') '-m', 'tcp', '-p', 'tcp')
ipt('-A', tproxy_chain, '-j', 'RETURN', ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp') '-m', 'tcp', '-p', 'tcp')
else: else:
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1', ipt('-A', mark_chain, '-j', 'MARK',
'--dest', '%s/%s' % (snet,swidth), '--set-mark', '1',
'-m', 'tcp', '-p', 'tcp') '--dest', '%s/%s' % (snet, swidth),
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1', '-m', 'tcp', '-p', 'tcp')
'--dest', '%s/%s' % (snet,swidth), ipt('-A', tproxy_chain, '-j', 'TPROXY',
'-m', 'tcp', '-p', 'tcp', '--tproxy-mark', '0x1/0x1',
'--on-port', str(port)) '--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp',
'--on-port', str(port))
if sexclude and udp: if sexclude and udp:
ipt('-A', mark_chain, '-j', 'RETURN', ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp') '-m', 'udp', '-p', 'udp')
ipt('-A', tproxy_chain, '-j', 'RETURN', ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth), '--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp') '-m', 'udp', '-p', 'udp')
elif udp: elif udp:
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1', ipt('-A', mark_chain, '-j', 'MARK',
'--dest', '%s/%s' % (snet,swidth), '--set-mark', '1',
'-m', 'udp', '-p', 'udp') '--dest', '%s/%s' % (snet, swidth),
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1', '-m', 'udp', '-p', 'udp')
'--dest', '%s/%s' % (snet,swidth), ipt('-A', tproxy_chain, '-j', 'TPROXY',
'-m', 'udp', '-p', 'udp', '--tproxy-mark', '0x1/0x1',
'--on-port', str(port)) '--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp',
'--on-port', str(port))
def ipfw_rule_exists(n): def ipfw_rule_exists(n):
argv = ['ipfw', 'list'] argv = ['ipfw', 'list']
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE) p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
found = False found = False
for line in p.stdout: for line in p.stdout:
if line.startswith('%05d ' % n): if line.startswith('%05d ' % n):
if not ('ipttl 42' in line if not ('ipttl 42' in line
or ('skipto %d' % (n+1)) in line or ('skipto %d' % (n + 1)) in line
or 'check-state' in line): or 'check-state' in line):
log('non-sshuttle ipfw rule: %r\n' % line.strip()) log('non-sshuttle ipfw rule: %r\n' % line.strip())
raise Fatal('non-sshuttle ipfw rule #%d already exists!' % n) raise Fatal('non-sshuttle ipfw rule #%d already exists!' % n)
@ -235,12 +257,14 @@ def ipfw_rule_exists(n):
_oldctls = {} _oldctls = {}
def _fill_oldctls(prefix): def _fill_oldctls(prefix):
argv = ['sysctl', prefix] argv = ['sysctl', prefix]
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE) p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
for line in p.stdout: for line in p.stdout:
assert(line[-1] == '\n') assert(line[-1] == '\n')
(k,v) = line[:-1].split(': ', 1) (k, v) = line[:-1].split(': ', 1)
_oldctls[k] = v _oldctls[k] = v
rv = p.wait() rv = p.wait()
if rv: if rv:
@ -252,10 +276,12 @@ def _fill_oldctls(prefix):
def _sysctl_set(name, val): def _sysctl_set(name, val):
argv = ['sysctl', '-w', '%s=%s' % (name, val)] argv = ['sysctl', '-w', '%s=%s' % (name, val)]
debug1('>> %s\n' % ' '.join(argv)) debug1('>> %s\n' % ' '.join(argv))
return ssubprocess.call(argv, stdout = open('/dev/null', 'w')) return ssubprocess.call(argv, stdout=open('/dev/null', 'w'))
_changedctls = [] _changedctls = []
def sysctl_set(name, val, permanent=False): def sysctl_set(name, val, permanent=False):
PREFIX = 'net.inet.ip' PREFIX = 'net.inet.ip'
assert(name.startswith(PREFIX + '.')) assert(name.startswith(PREFIX + '.'))
@ -268,7 +294,7 @@ def sysctl_set(name, val, permanent=False):
oldval = _oldctls[name] oldval = _oldctls[name]
if val != oldval: if val != oldval:
rv = _sysctl_set(name, val) rv = _sysctl_set(name, val)
if rv==0 and permanent: if rv == 0 and permanent:
debug1('>> ...saving permanently in /etc/sysctl.conf\n') debug1('>> ...saving permanently in /etc/sysctl.conf\n')
f = open('/etc/sysctl.conf', 'a') f = open('/etc/sysctl.conf', 'a')
f.write('\n' f.write('\n'
@ -293,9 +319,11 @@ def _udp_repack(p, src, dst):
_real_dns_server = [None] _real_dns_server = [None]
def _handle_diversion(divertsock, dnsport): def _handle_diversion(divertsock, dnsport):
p,tag = divertsock.recvfrom(4096) p, tag = divertsock.recvfrom(4096)
src,dst = _udp_unpack(p) src, dst = _udp_unpack(p)
debug3('got diverted packet from %r to %r\n' % (src, dst)) debug3('got diverted packet from %r to %r\n' % (src, dst))
if dst[1] == 53: if dst[1] == 53:
# outgoing DNS # outgoing DNS
@ -311,7 +339,7 @@ def _handle_diversion(divertsock, dnsport):
assert(0) assert(0)
newp = _udp_repack(p, src, dst) newp = _udp_repack(p, src, dst)
divertsock.sendto(newp, tag) divertsock.sendto(newp, tag)
def ipfw(*args): def ipfw(*args):
argv = ['ipfw', '-q'] + list(args) argv = ['ipfw', '-q'] + list(args)
@ -324,12 +352,14 @@ def ipfw(*args):
def do_ipfw(port, dnsport, family, subnets, udp): def do_ipfw(port, dnsport, family, subnets, udp):
# IPv6 not supported # IPv6 not supported
if family not in [socket.AF_INET, ]: if family not in [socket.AF_INET, ]:
raise Exception('Address family "%s" unsupported by ipfw method'%family_to_string(family)) raise Exception(
'Address family "%s" unsupported by ipfw method'
% family_to_string(family))
if udp: if udp:
raise Exception("UDP not supported by ipfw method") raise Exception("UDP not supported by ipfw method")
sport = str(port) sport = str(port)
xsport = str(port+1) xsport = str(port + 1)
# cleanup any existing rules # cleanup any existing rules
if ipfw_rule_exists(port): if ipfw_rule_exists(port):
@ -360,15 +390,16 @@ def do_ipfw(port, dnsport, family, subnets, udp):
if subnets: if subnets:
# create new subnet entries # create new subnet entries
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True): for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude: if sexclude:
ipfw('add', sport, 'skipto', xsport, ipfw('add', sport, 'skipto', xsport,
'log', 'tcp', 'log', 'tcp',
'from', 'any', 'to', '%s/%s' % (snet,swidth)) 'from', 'any', 'to', '%s/%s' % (snet, swidth))
else: else:
ipfw('add', sport, 'fwd', '127.0.0.1,%d' % port, ipfw('add', sport, 'fwd', '127.0.0.1,%d' % port,
'log', 'tcp', 'log', 'tcp',
'from', 'any', 'to', '%s/%s' % (snet,swidth), 'from', 'any', 'to', '%s/%s' % (snet, swidth),
'not', 'ipttl', '42', 'keep-state', 'setup') 'not', 'ipttl', '42', 'keep-state', 'setup')
# This part is much crazier than it is on Linux, because MacOS (at least # This part is much crazier than it is on Linux, because MacOS (at least
@ -403,10 +434,10 @@ def do_ipfw(port, dnsport, family, subnets, udp):
if dnsport: if dnsport:
divertsock = socket.socket(socket.AF_INET, socket.SOCK_RAW, divertsock = socket.socket(socket.AF_INET, socket.SOCK_RAW,
IPPROTO_DIVERT) IPPROTO_DIVERT)
divertsock.bind(('0.0.0.0', port)) # IP field is ignored divertsock.bind(('0.0.0.0', port)) # IP field is ignored
nslist = resolvconf_nameservers() nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist): for f, ip in filter(lambda i: i[0] == family, nslist):
# relabel and then catch outgoing DNS requests # relabel and then catch outgoing DNS requests
ipfw('add', sport, 'divert', sport, ipfw('add', sport, 'divert', sport,
'log', 'udp', 'log', 'udp',
@ -420,14 +451,14 @@ def do_ipfw(port, dnsport, family, subnets, udp):
def do_wait(): def do_wait():
while 1: while 1:
r,w,x = select.select([sys.stdin, divertsock], [], []) r, w, x = select.select([sys.stdin, divertsock], [], [])
if divertsock in r: if divertsock in r:
_handle_diversion(divertsock, dnsport) _handle_diversion(divertsock, dnsport)
if sys.stdin in r: if sys.stdin in r:
return return
else: else:
do_wait = None do_wait = None
return do_wait return do_wait
@ -440,10 +471,12 @@ def program_exists(name):
hostmap = {} hostmap = {}
def rewrite_etc_hosts(port): def rewrite_etc_hosts(port):
HOSTSFILE='/etc/hosts' HOSTSFILE = '/etc/hosts'
BAKFILE='%s.sbak' % HOSTSFILE BAKFILE = '%s.sbak' % HOSTSFILE
APPEND='# sshuttle-firewall-%d AUTOCREATED' % port APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port
old_content = '' old_content = ''
st = None st = None
try: try:
@ -462,8 +495,8 @@ def rewrite_etc_hosts(port):
if line.find(APPEND) >= 0: if line.find(APPEND) >= 0:
continue continue
f.write('%s\n' % line) f.write('%s\n' % line)
for (name,ip) in sorted(hostmap.items()): for (name, ip) in sorted(hostmap.items()):
f.write('%-30s %s\n' % ('%s %s' % (ip,name), APPEND)) f.write('%-30s %s\n' % ('%s %s' % (ip, name), APPEND))
f.close() f.close()
if st: if st:
@ -517,7 +550,7 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
elif method == "ipfw": elif method == "ipfw":
do_it = do_ipfw do_it = do_ipfw
else: else:
raise Exception('Unknown method "%s"'%method) raise Exception('Unknown method "%s"' % method)
# because of limitations of the 'su' command, the *real* stdin/stdout # because of limitations of the 'su' command, the *real* stdin/stdout
# are both attached to stdout initially. Clone stdout into stdin so we # are both attached to stdout initially. Clone stdout into stdin so we
@ -528,8 +561,8 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
ssyslog.start_syslog() ssyslog.start_syslog()
ssyslog.stderr_to_syslog() ssyslog.stderr_to_syslog()
debug1('firewall manager ready method %s.\n'%method) debug1('firewall manager ready method %s.\n' % method)
sys.stdout.write('READY %s\n'%method) sys.stdout.write('READY %s\n' % method)
sys.stdout.flush() sys.stdout.flush()
# ctrl-c shouldn't be passed along to me. When the main sshuttle dies, # ctrl-c shouldn't be passed along to me. When the main sshuttle dies,
@ -553,29 +586,31 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
elif line == 'GO\n': elif line == 'GO\n':
break break
try: try:
(family,width,exclude,ip) = line.strip().split(',', 3) (family, width, exclude, ip) = line.strip().split(',', 3)
except: except:
raise Fatal('firewall: expected route or GO but got %r' % line) raise Fatal('firewall: expected route or GO but got %r' % line)
subnets.append((int(family), int(width), bool(int(exclude)), ip)) subnets.append((int(family), int(width), bool(int(exclude)), ip))
try: try:
if line: if line:
debug1('firewall manager: starting transproxy.\n') debug1('firewall manager: starting transproxy.\n')
subnets_v6 = filter(lambda i: i[0]==socket.AF_INET6, subnets) subnets_v6 = filter(lambda i: i[0] == socket.AF_INET6, subnets)
if port_v6: if port_v6:
do_wait = do_it(port_v6, dnsport_v6, socket.AF_INET6, subnets_v6, udp) do_wait = do_it(
port_v6, dnsport_v6, socket.AF_INET6, subnets_v6, udp)
elif len(subnets_v6) > 0: elif len(subnets_v6) > 0:
debug1("IPv6 subnets defined but IPv6 disabled\n") debug1("IPv6 subnets defined but IPv6 disabled\n")
subnets_v4 = filter(lambda i: i[0]==socket.AF_INET, subnets) subnets_v4 = filter(lambda i: i[0] == socket.AF_INET, subnets)
if port_v4: if port_v4:
do_wait = do_it(port_v4, dnsport_v4, socket.AF_INET, subnets_v4, udp) do_wait = do_it(
port_v4, dnsport_v4, socket.AF_INET, subnets_v4, udp)
elif len(subnets_v4) > 0: elif len(subnets_v4) > 0:
debug1('IPv4 subnets defined but IPv4 disabled\n') debug1('IPv4 subnets defined but IPv4 disabled\n')
sys.stdout.write('STARTED\n') sys.stdout.write('STARTED\n')
try: try:
sys.stdout.flush() sys.stdout.flush()
except IOError: except IOError:
@ -587,10 +622,11 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
# to stay running so that we don't need a *second* password # to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important! # authentication at shutdown time - that cleanup is important!
while 1: while 1:
if do_wait: do_wait() if do_wait:
do_wait()
line = sys.stdin.readline(128) line = sys.stdin.readline(128)
if line.startswith('HOST '): if line.startswith('HOST '):
(name,ip) = line[5:].strip().split(',', 1) (name, ip) = line[5:].strip().split(',', 1)
hostmap[name] = ip hostmap[name] = ip
rewrite_etc_hosts(port_v6 or port_v4) rewrite_etc_hosts(port_v6 or port_v4)
elif line: elif line:

View File

@ -1,8 +1,11 @@
import sys, os, socket, errno import sys
import socket
import errno
logprefix = '' logprefix = ''
verbose = 0 verbose = 0
def log(s): def log(s):
try: try:
sys.stdout.flush() sys.stdout.flush()
@ -13,14 +16,17 @@ def log(s):
# our tty closes. That sucks, but it's no reason to abort the program. # our tty closes. That sucks, but it's no reason to abort the program.
pass pass
def debug1(s): def debug1(s):
if verbose >= 1: if verbose >= 1:
log(s) log(s)
def debug2(s): def debug2(s):
if verbose >= 2: if verbose >= 2:
log(s) log(s)
def debug3(s): def debug3(s):
if verbose >= 3: if verbose >= 3:
log(s) log(s)
@ -43,9 +49,9 @@ def resolvconf_nameservers():
words = line.lower().split() words = line.lower().split()
if len(words) >= 2 and words[0] == 'nameserver': if len(words) >= 2 and words[0] == 'nameserver':
if ':' in words[1]: if ':' in words[1]:
l.append((socket.AF_INET6,words[1])) l.append((socket.AF_INET6, words[1]))
else: else:
l.append((socket.AF_INET,words[1])) l.append((socket.AF_INET, words[1]))
return l return l
@ -58,10 +64,10 @@ def resolvconf_random_nameserver():
random.shuffle(l) random.shuffle(l)
return l[0] return l[0]
else: else:
return (socket.AF_INET,'127.0.0.1') return (socket.AF_INET, '127.0.0.1')
def islocal(ip,family):
def islocal(ip, family):
sock = socket.socket(family) sock = socket.socket(family)
try: try:
try: try:
@ -83,4 +89,3 @@ def family_to_string(family):
return "AF_INET" return "AF_INET"
else: else:
return str(family) return str(family)

View File

@ -1,12 +1,18 @@
import time, socket, re, select, errno import time
import socket
import re
import select
import errno
import os
import sys
if not globals().get('skip_imports'): if not globals().get('skip_imports'):
import compat.ssubprocess as ssubprocess import compat.ssubprocess as ssubprocess
import helpers import helpers
from helpers import * from helpers import log, debug1, debug2, debug3
POLL_TIME = 60*15 POLL_TIME = 60 * 15
NETSTAT_POLL_TIME = 30 NETSTAT_POLL_TIME = 30
CACHEFILE=os.path.expanduser('~/.sshuttle.hosts') CACHEFILE = os.path.expanduser('~/.sshuttle.hosts')
_nmb_ok = True _nmb_ok = True
@ -28,7 +34,7 @@ def write_host_cache():
tmpname = '%s.%d.tmp' % (CACHEFILE, os.getpid()) tmpname = '%s.%d.tmp' % (CACHEFILE, os.getpid())
try: try:
f = open(tmpname, 'wb') f = open(tmpname, 'wb')
for name,ip in sorted(hostnames.items()): for name, ip in sorted(hostnames.items()):
f.write('%s,%s\n' % (name, ip)) f.write('%s,%s\n' % (name, ip))
f.close() f.close()
os.rename(tmpname, CACHEFILE) os.rename(tmpname, CACHEFILE)
@ -50,18 +56,18 @@ def read_host_cache():
for line in f: for line in f:
words = line.strip().split(',') words = line.strip().split(',')
if len(words) == 2: if len(words) == 2:
(name,ip) = words (name, ip) = words
name = re.sub(r'[^-\w]', '-', name).strip() name = re.sub(r'[^-\w]', '-', name).strip()
ip = re.sub(r'[^0-9.]', '', ip).strip() ip = re.sub(r'[^0-9.]', '', ip).strip()
if name and ip: if name and ip:
found_host(name, ip) found_host(name, ip)
def found_host(hostname, ip): def found_host(hostname, ip):
hostname = re.sub(r'\..*', '', hostname) hostname = re.sub(r'\..*', '', hostname)
hostname = re.sub(r'[^-\w]', '_', hostname) hostname = re.sub(r'[^-\w]', '_', hostname)
if (ip.startswith('127.') or ip.startswith('255.') if (ip.startswith('127.') or ip.startswith('255.')
or hostname == 'localhost'): or hostname == 'localhost'):
return return
oldip = hostnames.get(hostname) oldip = hostnames.get(hostname)
if oldip != ip: if oldip != ip:
@ -94,7 +100,7 @@ def _check_revdns(ip):
debug3('< %s\n' % r[0]) debug3('< %s\n' % r[0])
check_host(r[0]) check_host(r[0])
found_host(r[0], ip) found_host(r[0], ip)
except socket.herror, e: except socket.herror:
pass pass
@ -105,7 +111,7 @@ def _check_dns(hostname):
debug3('< %s\n' % ip) debug3('< %s\n' % ip)
check_host(ip) check_host(ip)
found_host(hostname, ip) found_host(hostname, ip)
except socket.gaierror, e: except socket.gaierror:
pass pass
@ -123,7 +129,7 @@ def _check_netstat():
for ip in re.findall(r'\d+\.\d+\.\d+\.\d+', content): for ip in re.findall(r'\d+\.\d+\.\d+\.\d+', content):
debug3('< %s\n' % ip) debug3('< %s\n' % ip)
check_host(ip) check_host(ip)
def _check_smb(hostname): def _check_smb(hostname):
return return
@ -187,7 +193,7 @@ def _check_nmb(hostname, is_workgroup, is_master):
global _nmb_ok global _nmb_ok
if not _nmb_ok: if not _nmb_ok:
return return
argv = ['nmblookup'] + ['-M']*is_master + ['--', hostname] argv = ['nmblookup'] + ['-M'] * is_master + ['--', hostname]
debug2(' > n%d%d: %s\n' % (is_workgroup, is_master, hostname)) debug2(' > n%d%d: %s\n' % (is_workgroup, is_master, hostname))
try: try:
p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE, stderr=null) p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE, stderr=null)
@ -228,13 +234,13 @@ def check_workgroup(hostname):
def _enqueue(op, *args): def _enqueue(op, *args):
t = (op,args) t = (op, args)
if queue.get(t) == None: if queue.get(t) is None:
queue[t] = 0 queue[t] = 0
def _stdin_still_ok(timeout): def _stdin_still_ok(timeout):
r,w,x = select.select([sys.stdin.fileno()], [], [], timeout) r, w, x = select.select([sys.stdin.fileno()], [], [], timeout)
if r: if r:
b = os.read(sys.stdin.fileno(), 4096) b = os.read(sys.stdin.fileno(), 4096)
if not b: if not b:
@ -249,7 +255,7 @@ def hw_main(seed_hosts):
helpers.logprefix = 'hostwatch: ' helpers.logprefix = 'hostwatch: '
read_host_cache() read_host_cache()
_enqueue(_check_etc_hosts) _enqueue(_check_etc_hosts)
_enqueue(_check_netstat) _enqueue(_check_netstat)
check_host('localhost') check_host('localhost')
@ -261,8 +267,8 @@ def hw_main(seed_hosts):
while 1: while 1:
now = time.time() now = time.time()
for t,last_polled in queue.items(): for t, last_polled in queue.items():
(op,args) = t (op, args) = t
if not _stdin_still_ok(0): if not _stdin_still_ok(0):
break break
maxtime = POLL_TIME maxtime = POLL_TIME
@ -275,7 +281,7 @@ def hw_main(seed_hosts):
sys.stdout.flush() sys.stdout.flush()
except IOError: except IOError:
break break
# FIXME: use a smarter timeout based on oldest last_polled # FIXME: use a smarter timeout based on oldest last_polled
if not _stdin_still_ok(1): if not _stdin_still_ok(1):
break break

View File

@ -1,7 +1,13 @@
import sys, os, re, socket import sys
import helpers, options, client, server, firewall, hostwatch import re
import compat.ssubprocess as ssubprocess import socket
from helpers import * import helpers
import options
import client
import server
import firewall
import hostwatch
from helpers import log, Fatal
# 1.2.3.4/5 or just 1.2.3.4 # 1.2.3.4/5 or just 1.2.3.4
@ -9,17 +15,17 @@ def parse_subnet4(s):
m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s) m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s)
if not m: if not m:
raise Fatal('%r is not a valid IP subnet format' % s) raise Fatal('%r is not a valid IP subnet format' % s)
(a,b,c,d,width) = m.groups() (a, b, c, d, width) = m.groups()
(a,b,c,d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0)) (a, b, c, d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0))
if width == None: if width is None:
width = 32 width = 32
else: else:
width = int(width) width = int(width)
if a > 255 or b > 255 or c > 255 or d > 255: if a > 255 or b > 255 or c > 255 or d > 255:
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) raise Fatal('%d.%d.%d.%d has numbers > 255' % (a, b, c, d))
if width > 32: if width > 32:
raise Fatal('*/%d is greater than the maximum of 32' % width) raise Fatal('*/%d is greater than the maximum of 32' % width)
return(socket.AF_INET, '%d.%d.%d.%d' % (a,b,c,d), width) return(socket.AF_INET, '%d.%d.%d.%d' % (a, b, c, d), width)
# 1:2::3/64 or just 1:2::3 # 1:2::3/64 or just 1:2::3
@ -27,8 +33,8 @@ def parse_subnet6(s):
m = re.match(r'(?:([a-fA-F\d:]+))?(?:/(\d+))?$', s) m = re.match(r'(?:([a-fA-F\d:]+))?(?:/(\d+))?$', s)
if not m: if not m:
raise Fatal('%r is not a valid IP subnet format' % s) raise Fatal('%r is not a valid IP subnet format' % s)
(net,width) = m.groups() (net, width) = m.groups()
if width == None: if width is None:
width = 128 width = 128
else: else:
width = int(width) width = int(width)
@ -41,7 +47,7 @@ def parse_subnet6(s):
def parse_subnet_file(s): def parse_subnet_file(s):
try: try:
handle = open(s, 'r') handle = open(s, 'r')
except OSError, e: except OSError:
raise Fatal('Unable to open subnet file: %s' % s) raise Fatal('Unable to open subnet file: %s' % s)
raw_config_lines = handle.readlines() raw_config_lines = handle.readlines()
@ -77,16 +83,16 @@ def parse_ipport4(s):
m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s) m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s)
if not m: if not m:
raise Fatal('%r is not a valid IP:port format' % s) raise Fatal('%r is not a valid IP:port format' % s)
(a,b,c,d,port) = m.groups() (a, b, c, d, port) = m.groups()
(a,b,c,d,port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0), (a, b, c, d, port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0),
int(port or 0)) int(port or 0))
if a > 255 or b > 255 or c > 255 or d > 255: if a > 255 or b > 255 or c > 255 or d > 255:
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) raise Fatal('%d.%d.%d.%d has numbers > 255' % (a, b, c, d))
if port > 65535: if port > 65535:
raise Fatal('*:%d is greater than the maximum of 65535' % port) raise Fatal('*:%d is greater than the maximum of 65535' % port)
if a == None: if a is None:
a = b = c = d = 0 a = b = c = d = 0
return ('%d.%d.%d.%d' % (a,b,c,d), port) return ('%d.%d.%d.%d' % (a, b, c, d), port)
# [1:2::3]:456 or [1:2::3] or 456 # [1:2::3]:456 or [1:2::3] or 456
@ -95,8 +101,8 @@ def parse_ipport6(s):
m = re.match(r'(?:\[([^]]*)])?(?::)?(?:(\d+))?$', s) m = re.match(r'(?:\[([^]]*)])?(?::)?(?:(\d+))?$', s)
if not m: if not m:
raise Fatal('%s is not a valid IP:port format' % s) raise Fatal('%s is not a valid IP:port format' % s)
(ip,port) = m.groups() (ip, port) = m.groups()
(ip,port) = (ip or '::', int(port or 0)) (ip, port) = (ip or '::', int(port or 0))
return (ip, port) return (ip, port)
@ -156,8 +162,8 @@ try:
o.fatal('at least one subnet, subnet file, or -N expected') o.fatal('at least one subnet, subnet file, or -N expected')
includes = extra includes = extra
excludes = ['127.0.0.0/8'] excludes = ['127.0.0.0/8']
for k,v in flags: for k, v in flags:
if k in ('-x','--exclude'): if k in ('-x', '--exclude'):
excludes.append(v) excludes.append(v)
remotename = opt.remote remotename = opt.remote
if remotename == '' or remotename == '-': if remotename == '' or remotename == '-':
@ -174,10 +180,10 @@ try:
includes = parse_subnet_file(opt.subnets) includes = parse_subnet_file(opt.subnets)
if not opt.method: if not opt.method:
method = "auto" method = "auto"
elif opt.method in [ "auto", "nat", "tproxy", "ipfw" ]: elif opt.method in ["auto", "nat", "tproxy", "ipfw"]:
method = opt.method method = opt.method
else: else:
o.fatal("method %s not supported"%opt.method) o.fatal("method %s not supported" % opt.method)
if not opt.listen: if not opt.listen:
if opt.method == "tproxy": if opt.method == "tproxy":
ipport_v6 = parse_ipport6('[::1]:0') ipport_v6 = parse_ipport6('[::1]:0')
@ -194,23 +200,23 @@ try:
else: else:
ipport_v4 = parse_ipport4(ip) ipport_v4 = parse_ipport4(ip)
return_code = client.main(ipport_v6, ipport_v4, return_code = client.main(ipport_v6, ipport_v4,
opt.ssh_cmd, opt.ssh_cmd,
remotename, remotename,
opt.python, opt.python,
opt.latency_control, opt.latency_control,
opt.dns, opt.dns,
method, method,
sh, sh,
opt.auto_nets, opt.auto_nets,
parse_subnets(includes), parse_subnets(includes),
parse_subnets(excludes), parse_subnets(excludes),
opt.syslog, opt.daemon, opt.pidfile) opt.syslog, opt.daemon, opt.pidfile)
if return_code == 0: if return_code == 0:
log('Normal exit code, exiting...') log('Normal exit code, exiting...')
else: else:
log('Abnormal exit code detected, failing...' % return_code) log('Abnormal exit code detected, failing...' % return_code)
sys.exit(return_code) sys.exit(return_code)
except Fatal, e: except Fatal, e:
log('fatal: %s\n' % e) log('fatal: %s\n' % e)

View File

@ -1,9 +1,16 @@
"""Command-line options parser. """Command-line options parser.
With the help of an options spec string, easily parse command-line options. With the help of an options spec string, easily parse command-line options.
""" """
import sys, os, textwrap, getopt, re, struct import sys
import os
import textwrap
import getopt
import re
import struct
class OptDict: class OptDict:
def __init__(self): def __init__(self):
self._opts = {} self._opts = {}
@ -46,7 +53,8 @@ def _atoi(v):
def _remove_negative_kv(k, v): def _remove_negative_kv(k, v):
if k.startswith('no-') or k.startswith('no_'): if k.startswith('no-') or k.startswith('no_'):
return k[3:], not v return k[3:], not v
return k,v return k, v
def _remove_negative_k(k): def _remove_negative_k(k):
return _remove_negative_kv(k, None)[0] return _remove_negative_kv(k, None)[0]
@ -55,15 +63,17 @@ def _remove_negative_k(k):
def _tty_width(): def _tty_width():
s = struct.pack("HHHH", 0, 0, 0, 0) s = struct.pack("HHHH", 0, 0, 0, 0)
try: try:
import fcntl, termios import fcntl
import termios
s = fcntl.ioctl(sys.stderr.fileno(), termios.TIOCGWINSZ, s) s = fcntl.ioctl(sys.stderr.fileno(), termios.TIOCGWINSZ, s)
except (IOError, ImportError): except (IOError, ImportError):
return _atoi(os.environ.get('WIDTH')) or 70 return _atoi(os.environ.get('WIDTH')) or 70
(ysize,xsize,ypix,xpix) = struct.unpack('HHHH', s) (ysize, xsize, ypix, xpix) = struct.unpack('HHHH', s)
return xsize or 70 return xsize or 70
class Options: class Options:
"""Option parser. """Option parser.
When constructed, two strings are mandatory. The first one is the command When constructed, two strings are mandatory. The first one is the command
name showed before error messages. The second one is a string called an name showed before error messages. The second one is a string called an
@ -76,6 +86,7 @@ class Options:
By default, the parser function is getopt.gnu_getopt, and the abort By default, the parser function is getopt.gnu_getopt, and the abort
behaviour is to exit the program. behaviour is to exit the program.
""" """
def __init__(self, optspec, optfunc=getopt.gnu_getopt, def __init__(self, optspec, optfunc=getopt.gnu_getopt,
onabort=_default_onabort): onabort=_default_onabort):
self.optspec = optspec self.optspec = optspec
@ -95,7 +106,8 @@ class Options:
first_syn = True first_syn = True
while lines: while lines:
l = lines.pop() l = lines.pop()
if l == '--': break if l == '--':
break
out.append('%s: %s\n' % (first_syn and 'usage' or ' or', l)) out.append('%s: %s\n' % (first_syn and 'usage' or ' or', l))
first_syn = False first_syn = False
out.append('\n') out.append('\n')
@ -122,7 +134,7 @@ class Options:
flagl = flags.split(',') flagl = flags.split(',')
flagl_nice = [] flagl_nice = []
for _f in flagl: for _f in flagl:
f,dvi = _remove_negative_kv(_f, _intify(defval)) f, dvi = _remove_negative_kv(_f, _intify(defval))
self._aliases[f] = _remove_negative_k(flagl[0]) self._aliases[f] = _remove_negative_k(flagl[0])
self._hasparms[f] = has_parm self._hasparms[f] = has_parm
self._defaults[f] = dvi self._defaults[f] = dvi
@ -140,8 +152,8 @@ class Options:
flags_nice += ' ...' flags_nice += ' ...'
prefix = ' %-20s ' % flags_nice prefix = ' %-20s ' % flags_nice
argtext = '\n'.join(textwrap.wrap(extra, width=_tty_width(), argtext = '\n'.join(textwrap.wrap(extra, width=_tty_width(),
initial_indent=prefix, initial_indent=prefix,
subsequent_indent=' '*28)) subsequent_indent=' ' * 28))
out.append(argtext + '\n') out.append(argtext + '\n')
last_was_option = True last_was_option = True
else: else:
@ -170,17 +182,18 @@ class Options:
and "extra" is a list of positional arguments. and "extra" is a list of positional arguments.
""" """
try: try:
(flags,extra) = self.optfunc(args, self._shortopts, self._longopts) (flags, extra) = self.optfunc(
args, self._shortopts, self._longopts)
except getopt.GetoptError, e: except getopt.GetoptError, e:
self.fatal(e) self.fatal(e)
opt = OptDict() opt = OptDict()
for k,v in self._defaults.iteritems(): for k, v in self._defaults.iteritems():
k = self._aliases[k] k = self._aliases[k]
opt[k] = v opt[k] = v
for (k,v) in flags: for (k, v) in flags:
k = k.lstrip('-') k = k.lstrip('-')
if k in ('h', '?', 'help'): if k in ('h', '?', 'help'):
self.usage() self.usage()
@ -195,6 +208,6 @@ class Options:
else: else:
v = _intify(v) v = _intify(v)
opt[k] = v opt[k] = v
for (f1,f2) in self._aliases.iteritems(): for (f1, f2) in self._aliases.iteritems():
opt[f1] = opt._opts.get(f2) opt[f1] = opt._opts.get(f2)
return (opt,flags,extra) return (opt, flags, extra)

View File

@ -1,9 +1,22 @@
import re, struct, socket, select, traceback, time import re
import struct
import socket
import traceback
import time
import sys
import os
if not globals().get('skip_imports'): if not globals().get('skip_imports'):
import ssnet, helpers, hostwatch import ssnet
import helpers
import hostwatch
import compat.ssubprocess as ssubprocess import compat.ssubprocess as ssubprocess
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from ssnet import Handler, Proxy, Mux, MuxWrapper
from helpers import * from helpers import log, debug1, debug2, debug3, Fatal, \
resolvconf_random_nameserver
if not globals().get('latency_control'):
latency_control = None
def _ipmatch(ipstr): def _ipmatch(ipstr):
@ -14,13 +27,13 @@ def _ipmatch(ipstr):
g = m.groups() g = m.groups()
ips = g[0] ips = g[0]
width = int(g[4] or 32) width = int(g[4] or 32)
if g[1] == None: if g[1] is None:
ips += '.0.0.0' ips += '.0.0.0'
width = min(width, 8) width = min(width, 8)
elif g[2] == None: elif g[2] is None:
ips += '.0.0' ips += '.0.0'
width = min(width, 16) width = min(width, 16)
elif g[3] == None: elif g[3] is None:
ips += '.0' ips += '.0'
width = min(width, 24) width = min(width, 24)
return (struct.unpack('!I', socket.inet_aton(ips))[0], width) return (struct.unpack('!I', socket.inet_aton(ips))[0], width)
@ -38,12 +51,12 @@ def _maskbits(netmask):
return 32 return 32
for i in range(32): for i in range(32):
if netmask[0] & _shl(1, i): if netmask[0] & _shl(1, i):
return 32-i return 32 - i
return 0 return 0
def _shl(n, bits): def _shl(n, bits):
return n * int(2**bits) return n * int(2 ** bits)
def _list_routes(): def _list_routes():
@ -58,8 +71,9 @@ def _list_routes():
maskw = _ipmatch(cols[2]) # linux only maskw = _ipmatch(cols[2]) # linux only
mask = _maskbits(maskw) # returns 32 if maskw is null mask = _maskbits(maskw) # returns 32 if maskw is null
width = min(ipw[1], mask) width = min(ipw[1], mask)
ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width) ip = ipw[0] & _shl(_shl(1, width) - 1, 32 - width)
routes.append((socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width)) routes.append(
(socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width))
rv = p.wait() rv = p.wait()
if rv != 0: if rv != 0:
log('WARNING: %r returned %d\n' % (argv, rv)) log('WARNING: %r returned %d\n' % (argv, rv))
@ -68,9 +82,9 @@ def _list_routes():
def list_routes(): def list_routes():
for (family, ip,width) in _list_routes(): for (family, ip, width) in _list_routes():
if not ip.startswith('0.') and not ip.startswith('127.'): if not ip.startswith('0.') and not ip.startswith('127.'):
yield (family, ip,width) yield (family, ip, width)
def _exc_dump(): def _exc_dump():
@ -79,7 +93,7 @@ def _exc_dump():
def start_hostwatch(seed_hosts): def start_hostwatch(seed_hosts):
s1,s2 = socket.socketpair() s1, s2 = socket.socketpair()
pid = os.fork() pid = os.fork()
if not pid: if not pid:
# child # child
@ -91,27 +105,29 @@ def start_hostwatch(seed_hosts):
os.dup2(s1.fileno(), 0) os.dup2(s1.fileno(), 0)
s1.close() s1.close()
rv = hostwatch.hw_main(seed_hosts) or 0 rv = hostwatch.hw_main(seed_hosts) or 0
except Exception, e: except Exception:
log('%s\n' % _exc_dump()) log('%s\n' % _exc_dump())
rv = 98 rv = 98
finally: finally:
os._exit(rv) os._exit(rv)
s1.close() s1.close()
return pid,s2 return pid, s2
class Hostwatch: class Hostwatch:
def __init__(self): def __init__(self):
self.pid = 0 self.pid = 0
self.sock = None self.sock = None
class DnsProxy(Handler): class DnsProxy(Handler):
def __init__(self, mux, chan, request): def __init__(self, mux, chan, request):
# FIXME! IPv4 specific # FIXME! IPv4 specific
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
Handler.__init__(self, [sock]) Handler.__init__(self, [sock])
self.timeout = time.time()+30 self.timeout = time.time() + 30
self.mux = mux self.mux = mux
self.chan = chan self.chan = chan
self.tries = 0 self.tries = 0
@ -164,10 +180,11 @@ class DnsProxy(Handler):
class UdpProxy(Handler): class UdpProxy(Handler):
def __init__(self, mux, chan, family): def __init__(self, mux, chan, family):
sock = socket.socket(family, socket.SOCK_DGRAM) sock = socket.socket(family, socket.SOCK_DGRAM)
Handler.__init__(self, [sock]) Handler.__init__(self, [sock])
self.timeout = time.time()+30 self.timeout = time.time() + 30
self.mux = mux self.mux = mux
self.chan = chan self.chan = chan
self.sock = sock self.sock = sock
@ -177,33 +194,35 @@ class UdpProxy(Handler):
def send(self, dstip, data): def send(self, dstip, data):
debug2('UDP: sending to %r port %d\n' % dstip) debug2('UDP: sending to %r port %d\n' % dstip)
try: try:
self.sock.sendto(data,dstip) self.sock.sendto(data, dstip)
except socket.error, e: except socket.error, e:
log('UDP send to %r port %d: %s\n' % (dstip[0], dstip[1], e)) log('UDP send to %r port %d: %s\n' % (dstip[0], dstip[1], e))
return return
def callback(self): def callback(self):
try: try:
data,peer = self.sock.recvfrom(4096) data, peer = self.sock.recvfrom(4096)
except socket.error, e: except socket.error, e:
log('UDP recv from %r port %d: %s\n' % (peer[0], peer[1], e)) log('UDP recv from %r port %d: %s\n' % (peer[0], peer[1], e))
return return
debug2('UDP response: %d bytes\n' % len(data)) debug2('UDP response: %d bytes\n' % len(data))
hdr = "%s,%r,"%(peer[0], peer[1]) hdr = "%s,%r," % (peer[0], peer[1])
self.mux.send(self.chan, ssnet.CMD_UDP_DATA, hdr+data) self.mux.send(self.chan, ssnet.CMD_UDP_DATA, hdr + data)
def main(): def main():
if helpers.verbose >= 1: if helpers.verbose >= 1:
helpers.logprefix = ' s: ' helpers.logprefix = ' s: '
else: else:
helpers.logprefix = 'server: ' helpers.logprefix = 'server: '
assert latency_control is not None
debug1('latency control setting = %r\n' % latency_control) debug1('latency control setting = %r\n' % latency_control)
routes = list(list_routes()) routes = list(list_routes())
debug1('available routes:\n') debug1('available routes:\n')
for r in routes: for r in routes:
debug1(' %d/%s/%d\n' % r) debug1(' %d/%s/%d\n' % r)
# synchronization header # synchronization header
sys.stdout.write('\0\0SSHUTTLE0001') sys.stdout.write('\0\0SSHUTTLE0001')
sys.stdout.flush() sys.stdout.flush()
@ -221,7 +240,7 @@ def main():
hw = Hostwatch() hw = Hostwatch()
hw.leftover = '' hw.leftover = ''
def hostwatch_ready(): def hostwatch_ready():
assert(hw.pid) assert(hw.pid)
content = hw.sock.recv(4096) content = hw.sock.recv(4096)
@ -239,13 +258,13 @@ def main():
def got_host_req(data): def got_host_req(data):
if not hw.pid: if not hw.pid:
(hw.pid,hw.sock) = start_hostwatch(data.strip().split()) (hw.pid, hw.sock) = start_hostwatch(data.strip().split())
handlers.append(Handler(socks = [hw.sock], handlers.append(Handler(socks=[hw.sock],
callback = hostwatch_ready)) callback=hostwatch_ready))
mux.got_host_req = got_host_req mux.got_host_req = got_host_req
def new_channel(channel, data): def new_channel(channel, data):
(family,dstip,dstport) = data.split(',', 2) (family, dstip, dstport) = data.split(',', 2)
family = int(family) family = int(family)
dstport = int(dstport) dstport = int(dstport)
outwrap = ssnet.connect_dst(family, dstip, dstport) outwrap = ssnet.connect_dst(family, dstip, dstport)
@ -253,6 +272,7 @@ def main():
mux.new_channel = new_channel mux.new_channel = new_channel
dnshandlers = {} dnshandlers = {}
def dns_req(channel, data): def dns_req(channel, data):
debug2('Incoming DNS request channel=%d.\n' % channel) debug2('Incoming DNS request channel=%d.\n' % channel)
h = DnsProxy(mux, channel, data) h = DnsProxy(mux, channel, data)
@ -261,14 +281,15 @@ def main():
mux.got_dns_req = dns_req mux.got_dns_req = dns_req
udphandlers = {} udphandlers = {}
def udp_req(channel, cmd, data): def udp_req(channel, cmd, data):
debug2('Incoming UDP request channel=%d, cmd=%d\n' % (channel,cmd)) debug2('Incoming UDP request channel=%d, cmd=%d\n' % (channel, cmd))
if cmd == ssnet.CMD_UDP_DATA: if cmd == ssnet.CMD_UDP_DATA:
(dstip,dstport,data) = data.split(",",2) (dstip, dstport, data) = data.split(",", 2)
dstport = int(dstport) dstport = int(dstport)
debug2('is incoming UDP data. %r %d.\n' % (dstip,dstport)) debug2('is incoming UDP data. %r %d.\n' % (dstip, dstport))
h = udphandlers[channel] h = udphandlers[channel]
h.send((dstip,dstport),data) h.send((dstip, dstport), data)
elif cmd == ssnet.CMD_UDP_CLOSE: elif cmd == ssnet.CMD_UDP_CLOSE:
debug2('is incoming UDP close\n') debug2('is incoming UDP close\n')
h = udphandlers[channel] h = udphandlers[channel]
@ -280,21 +301,21 @@ def main():
family = int(data) family = int(data)
mux.channels[channel] = lambda cmd, data: udp_req(channel, cmd, data) mux.channels[channel] = lambda cmd, data: udp_req(channel, cmd, data)
if channel in udphandlers: if channel in udphandlers:
raise Fatal('UDP connection channel %d already open'%channel) raise Fatal('UDP connection channel %d already open' % channel)
else: else:
h = UdpProxy(mux, channel, family) h = UdpProxy(mux, channel, family)
handlers.append(h) handlers.append(h)
udphandlers[channel] = h udphandlers[channel] = h
mux.got_udp_open = udp_open mux.got_udp_open = udp_open
while mux.ok: while mux.ok:
if hw.pid: if hw.pid:
assert(hw.pid > 0) assert(hw.pid > 0)
(rpid, rv) = os.waitpid(hw.pid, os.WNOHANG) (rpid, rv) = os.waitpid(hw.pid, os.WNOHANG)
if rpid: if rpid:
raise Fatal('hostwatch exited unexpectedly: code 0x%04x\n' % rv) raise Fatal(
'hostwatch exited unexpectedly: code 0x%04x\n' % rv)
ssnet.runonce(handlers, mux) ssnet.runonce(handlers, mux)
if latency_control: if latency_control:
mux.check_fullness() mux.check_fullness()
@ -302,12 +323,12 @@ def main():
if dnshandlers: if dnshandlers:
now = time.time() now = time.time()
for channel,h in dnshandlers.items(): for channel, h in dnshandlers.items():
if h.timeout < now or not h.ok: if h.timeout < now or not h.ok:
debug3('expiring dnsreqs channel=%d\n' % channel) debug3('expiring dnsreqs channel=%d\n' % channel)
del dnshandlers[channel] del dnshandlers[channel]
h.ok = False h.ok = False
for channel,h in udphandlers.items(): for channel, h in udphandlers.items():
if not h.ok: if not h.ok:
debug3('expiring UDP channel=%d\n' % channel) debug3('expiring UDP channel=%d\n' % channel)
del udphandlers[channel] del udphandlers[channel]

View File

@ -1,7 +1,11 @@
import sys, os, re, socket, zlib import sys
import os
import re
import socket
import zlib
import compat.ssubprocess as ssubprocess import compat.ssubprocess as ssubprocess
import helpers import helpers
from helpers import * from helpers import debug2
def readfile(name): def readfile(name):
@ -15,7 +19,7 @@ def readfile(name):
def empackage(z, filename, data=None): def empackage(z, filename, data=None):
(path,basename) = os.path.split(filename) (path, basename) = os.path.split(filename)
if not data: if not data:
data = readfile(filename) data = readfile(filename)
content = z.compress(data) content = z.compress(data)
@ -24,7 +28,6 @@ def empackage(z, filename, data=None):
def connect(ssh_cmd, rhostport, python, stderr, options): def connect(ssh_cmd, rhostport, python, stderr, options):
main_exe = sys.argv[0]
portl = [] portl = []
if (rhostport or '').count(':') > 1: if (rhostport or '').count(':') > 1:
@ -35,9 +38,11 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
result[1] = result[1].strip(':') result[1] = result[1].strip(':')
if result[1] is not '': if result[1] is not '':
portl = ['-p', str(int(result[1]))] portl = ['-p', str(int(result[1]))]
else: # can't disambiguate IPv6 colons and a port number. pass the hostname through. # can't disambiguate IPv6 colons and a port number. pass the hostname
# through.
else:
rhost = rhostport rhost = rhostport
else: # IPv4 else: # IPv4
l = (rhostport or '').split(':', 1) l = (rhostport or '').split(':', 1)
rhost = l[0] rhost = l[0]
if len(l) > 1: if len(l) > 1:
@ -48,7 +53,7 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
z = zlib.compressobj(1) z = zlib.compressobj(1)
content = readfile('assembler.py') content = readfile('assembler.py')
optdata = ''.join("%s=%r\n" % (k,v) for (k,v) in options.items()) optdata = ''.join("%s=%r\n" % (k, v) for (k, v) in options.items())
content2 = (empackage(z, 'cmdline_options.py', optdata) + content2 = (empackage(z, 'cmdline_options.py', optdata) +
empackage(z, 'helpers.py') + empackage(z, 'helpers.py') +
empackage(z, 'compat/ssubprocess.py') + empackage(z, 'compat/ssubprocess.py') +
@ -56,7 +61,7 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
empackage(z, 'hostwatch.py') + empackage(z, 'hostwatch.py') +
empackage(z, 'server.py') + empackage(z, 'server.py') +
"\n") "\n")
pyscript = r""" pyscript = r"""
import sys; import sys;
skip_imports=1; skip_imports=1;
@ -65,7 +70,6 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
""" % (helpers.verbose or 0, len(content)) """ % (helpers.verbose or 0, len(content))
pyscript = re.sub(r'\s+', ' ', pyscript.strip()) pyscript = re.sub(r'\s+', ' ', pyscript.strip())
if not rhost: if not rhost:
# ignore the --python argument when running locally; we already know # ignore the --python argument when running locally; we already know
# which python version works. # which python version works.
@ -80,14 +84,15 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
else: else:
pycmd = ("P=python2; $P -V 2>/dev/null || P=python; " pycmd = ("P=python2; $P -V 2>/dev/null || P=python; "
"exec \"$P\" -c '%s'") % pyscript "exec \"$P\" -c '%s'") % pyscript
argv = (sshl + argv = (sshl +
portl + portl +
[rhost, '--', pycmd]) [rhost, '--', pycmd])
(s1,s2) = socket.socketpair() (s1, s2) = socket.socketpair()
def setup(): def setup():
# runs in the child process # runs in the child process
s2.close() s2.close()
s1a,s1b = os.dup(s1.fileno()), os.dup(s1.fileno()) s1a, s1b = os.dup(s1.fileno()), os.dup(s1.fileno())
s1.close() s1.close()
debug2('executing: %r\n' % argv) debug2('executing: %r\n' % argv)
p = ssubprocess.Popen(argv, stdin=s1a, stdout=s1b, preexec_fn=setup, p = ssubprocess.Popen(argv, stdin=s1a, stdout=s1b, preexec_fn=setup,

View File

@ -1,9 +1,13 @@
import struct, socket, errno, select import struct
import socket
import errno
import select
import os
if not globals().get('skip_imports'): if not globals().get('skip_imports'):
from helpers import * from helpers import log, debug1, debug2, debug3, Fatal
MAX_CHANNEL = 65535 MAX_CHANNEL = 65535
# these don't exist in the socket module in python 2.3! # these don't exist in the socket module in python 2.3!
SHUT_RD = 0 SHUT_RD = 0
SHUT_WR = 1 SHUT_WR = 1
@ -92,7 +96,10 @@ def _try_peername(sock):
_swcount = 0 _swcount = 0
class SockWrapper: class SockWrapper:
def __init__(self, rsock, wsock, connect_to=None, peername=None): def __init__(self, rsock, wsock, connect_to=None, peername=None):
global _swcount global _swcount
_swcount += 1 _swcount += 1
@ -177,8 +184,8 @@ class SockWrapper:
if not self.shut_read: if not self.shut_read:
debug2('%r: done reading\n' % self) debug2('%r: done reading\n' % self)
self.shut_read = True self.shut_read = True
#self.rsock.shutdown(SHUT_RD) # doesn't do anything anyway # self.rsock.shutdown(SHUT_RD) # doesn't do anything anyway
def nowrite(self): def nowrite(self):
if not self.shut_write: if not self.shut_write:
debug2('%r: done writing\n' % self) debug2('%r: done writing\n' % self)
@ -206,7 +213,7 @@ class SockWrapper:
# unexpected error... stream is dead # unexpected error... stream is dead
self.seterr('uwrite: %s' % e) self.seterr('uwrite: %s' % e)
return 0 return 0
def write(self, buf): def write(self, buf):
assert(buf) assert(buf)
return self.uwrite(buf) return self.uwrite(buf)
@ -221,7 +228,7 @@ class SockWrapper:
return _nb_clean(os.read, self.rsock.fileno(), 65536) return _nb_clean(os.read, self.rsock.fileno(), 65536)
except OSError, e: except OSError, e:
self.seterr('uread: %s' % e) self.seterr('uread: %s' % e)
return '' # unexpected error... we'll call it EOF return '' # unexpected error... we'll call it EOF
def fill(self): def fill(self):
if self.buf: if self.buf:
@ -243,7 +250,8 @@ class SockWrapper:
class Handler: class Handler:
def __init__(self, socks = None, callback = None):
def __init__(self, socks=None, callback=None):
self.ok = True self.ok = True
self.socks = socks or [] self.socks = socks or []
if callback: if callback:
@ -255,7 +263,7 @@ class Handler:
def callback(self): def callback(self):
log('--no callback defined-- %r\n' % self) log('--no callback defined-- %r\n' % self)
(r,w,x) = select.select(self.socks, [], [], 0) (r, w, x) = select.select(self.socks, [], [], 0)
for s in r: for s in r:
v = s.recv(4096) v = s.recv(4096)
if not v: if not v:
@ -265,6 +273,7 @@ class Handler:
class Proxy(Handler): class Proxy(Handler):
def __init__(self, wrap1, wrap2): def __init__(self, wrap1, wrap2):
Handler.__init__(self, [wrap1.rsock, wrap1.wsock, Handler.__init__(self, [wrap1.rsock, wrap1.wsock,
wrap2.rsock, wrap2.wsock]) wrap2.rsock, wrap2.wsock])
@ -272,9 +281,11 @@ class Proxy(Handler):
self.wrap2 = wrap2 self.wrap2 = wrap2
def pre_select(self, r, w, x): def pre_select(self, r, w, x):
if self.wrap1.shut_write: self.wrap2.noread() if self.wrap1.shut_write:
if self.wrap2.shut_write: self.wrap1.noread() self.wrap2.noread()
if self.wrap2.shut_write:
self.wrap1.noread()
if self.wrap1.connect_to: if self.wrap1.connect_to:
_add(w, self.wrap1.rsock) _add(w, self.wrap1.rsock)
elif self.wrap1.buf: elif self.wrap1.buf:
@ -305,13 +316,14 @@ class Proxy(Handler):
self.wrap2.buf = [] self.wrap2.buf = []
self.wrap2.noread() self.wrap2.noread()
if (self.wrap1.shut_read and self.wrap2.shut_read and if (self.wrap1.shut_read and self.wrap2.shut_read and
not self.wrap1.buf and not self.wrap2.buf): not self.wrap1.buf and not self.wrap2.buf):
self.ok = False self.ok = False
self.wrap1.nowrite() self.wrap1.nowrite()
self.wrap2.nowrite() self.wrap2.nowrite()
class Mux(Handler): class Mux(Handler):
def __init__(self, rsock, wsock): def __init__(self, rsock, wsock):
Handler.__init__(self, [rsock, wsock]) Handler.__init__(self, [rsock, wsock])
self.rsock = rsock self.rsock = rsock
@ -342,31 +354,31 @@ class Mux(Handler):
for b in self.outbuf: for b in self.outbuf:
total += len(b) total += len(b)
return total return total
def check_fullness(self): def check_fullness(self):
if self.fullness > 32768: if self.fullness > 32768:
if not self.too_full: if not self.too_full:
self.send(0, CMD_PING, 'rttest') self.send(0, CMD_PING, 'rttest')
self.too_full = True self.too_full = True
#ob = [] #ob = []
#for b in self.outbuf: # for b in self.outbuf:
# (s1,s2,c) = struct.unpack('!ccH', b[:4]) # (s1,s2,c) = struct.unpack('!ccH', b[:4])
# ob.append(c) # ob.append(c)
#log('outbuf: %d %r\n' % (self.amount_queued(), ob)) #log('outbuf: %d %r\n' % (self.amount_queued(), ob))
def send(self, channel, cmd, data): def send(self, channel, cmd, data):
data = str(data) data = str(data)
assert(len(data) <= 65535) assert(len(data) <= 65535)
p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data
self.outbuf.append(p) self.outbuf.append(p)
debug2(' > channel=%d cmd=%s len=%d (fullness=%d)\n' debug2(' > channel=%d cmd=%s len=%d (fullness=%d)\n'
% (channel, cmd_to_name.get(cmd,hex(cmd)), % (channel, cmd_to_name.get(cmd, hex(cmd)),
len(data), self.fullness)) len(data), self.fullness))
self.fullness += len(data) self.fullness += len(data)
def got_packet(self, channel, cmd, data): def got_packet(self, channel, cmd, data):
debug2('< channel=%d cmd=%s len=%d\n' debug2('< channel=%d cmd=%s len=%d\n'
% (channel, cmd_to_name.get(cmd,hex(cmd)), len(data))) % (channel, cmd_to_name.get(cmd, hex(cmd)), len(data)))
if cmd == CMD_PING: if cmd == CMD_PING:
self.send(0, CMD_PONG, data) self.send(0, CMD_PONG, data)
elif cmd == CMD_PONG: elif cmd == CMD_PONG:
@ -405,8 +417,8 @@ class Mux(Handler):
else: else:
callback = self.channels.get(channel) callback = self.channels.get(channel)
if not callback: if not callback:
log('warning: closed channel %d got cmd=%s len=%d\n' log('warning: closed channel %d got cmd=%s len=%d\n'
% (channel, cmd_to_name.get(cmd,hex(cmd)), len(data))) % (channel, cmd_to_name.get(cmd, hex(cmd)), len(data)))
else: else:
callback(cmd, data) callback(cmd, data)
@ -427,18 +439,18 @@ class Mux(Handler):
except OSError, e: except OSError, e:
raise Fatal('other end: %r' % e) raise Fatal('other end: %r' % e)
#log('<<< %r\n' % b) #log('<<< %r\n' % b)
if b == '': # EOF if b == '': # EOF
self.ok = False self.ok = False
if b: if b:
self.inbuf += b self.inbuf += b
def handle(self): def handle(self):
self.fill() self.fill()
#log('inbuf is: (%d,%d) %r\n' # log('inbuf is: (%d,%d) %r\n'
# % (self.want, len(self.inbuf), self.inbuf)) # % (self.want, len(self.inbuf), self.inbuf))
while 1: while 1:
if len(self.inbuf) >= (self.want or HDR_LEN): if len(self.inbuf) >= (self.want or HDR_LEN):
(s1,s2,channel,cmd,datalen) = \ (s1, s2, channel, cmd, datalen) = \
struct.unpack('!ccHHH', self.inbuf[:HDR_LEN]) struct.unpack('!ccHHH', self.inbuf[:HDR_LEN])
assert(s1 == 'S') assert(s1 == 'S')
assert(s2 == 'S') assert(s2 == 'S')
@ -457,7 +469,7 @@ class Mux(Handler):
_add(w, self.wsock) _add(w, self.wsock)
def callback(self): def callback(self):
(r,w,x) = select.select([self.rsock], [self.wsock], [], 0) (r, w, x) = select.select([self.rsock], [self.wsock], [], 0)
if self.rsock in r: if self.rsock in r:
self.handle() self.handle()
if self.outbuf and self.wsock in w: if self.outbuf and self.wsock in w:
@ -465,6 +477,7 @@ class Mux(Handler):
class MuxWrapper(SockWrapper): class MuxWrapper(SockWrapper):
def __init__(self, mux, channel): def __init__(self, mux, channel):
SockWrapper.__init__(self, mux.rsock, mux.wsock) SockWrapper.__init__(self, mux.rsock, mux.wsock)
self.mux = mux self.mux = mux
@ -478,7 +491,7 @@ class MuxWrapper(SockWrapper):
SockWrapper.__del__(self) SockWrapper.__del__(self)
def __repr__(self): def __repr__(self):
return 'SW%r:Mux#%d' % (self.peername,self.channel) return 'SW%r:Mux#%d' % (self.peername, self.channel)
def noread(self): def noread(self):
if not self.shut_read: if not self.shut_read:
@ -511,7 +524,7 @@ class MuxWrapper(SockWrapper):
def uread(self): def uread(self):
if self.shut_read: if self.shut_read:
return '' # EOF return '' # EOF
else: else:
return None # no data available right now return None # no data available right now
@ -523,7 +536,7 @@ class MuxWrapper(SockWrapper):
elif cmd == CMD_TCP_DATA: elif cmd == CMD_TCP_DATA:
self.buf.append(data) self.buf.append(data)
else: else:
raise Exception('unknown command %d (%d bytes)' raise Exception('unknown command %d (%d bytes)'
% (cmd, len(data))) % (cmd, len(data)))
@ -532,8 +545,8 @@ def connect_dst(family, ip, port):
outsock = socket.socket(family) outsock = socket.socket(family)
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
return SockWrapper(outsock, outsock, return SockWrapper(outsock, outsock,
connect_to = (ip,port), connect_to=(ip, port),
peername = '%s:%d' % (ip,port)) peername = '%s:%d' % (ip, port))
def runonce(handlers, mux): def runonce(handlers, mux):
@ -545,14 +558,14 @@ def runonce(handlers, mux):
handlers.remove(h) handlers.remove(h)
for s in handlers: for s in handlers:
s.pre_select(r,w,x) s.pre_select(r, w, x)
debug2('Waiting: %d r=%r w=%r x=%r (fullness=%d/%d)\n' debug2('Waiting: %d r=%r w=%r x=%r (fullness=%d/%d)\n'
% (len(handlers), _fds(r), _fds(w), _fds(x), % (len(handlers), _fds(r), _fds(w), _fds(x),
mux.fullness, mux.too_full)) mux.fullness, mux.too_full))
(r,w,x) = select.select(r,w,x) (r, w, x) = select.select(r, w, x)
debug2(' Ready: %d r=%r w=%r x=%r\n' debug2(' Ready: %d r=%r w=%r x=%r\n'
% (len(handlers), _fds(r), _fds(w), _fds(x))) % (len(handlers), _fds(r), _fds(w), _fds(x)))
ready = r+w+x ready = r + w + x
did = {} did = {}
for h in handlers: for h in handlers:
for s in h.socks: for s in h.socks:

View File

@ -1,8 +1,11 @@
import sys, os import sys
import os
from compat import ssubprocess from compat import ssubprocess
_p = None _p = None
def start_syslog(): def start_syslog():
global _p global _p
_p = ssubprocess.Popen(['logger', _p = ssubprocess.Popen(['logger',

View File

@ -1,5 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys, os, socket, select, struct, time import socket
import select
import struct
import time
listener = socket.socket() listener = socket.socket()
listener.bind(('127.0.0.1', 0)) listener.bind(('127.0.0.1', 0))
@ -23,7 +26,7 @@ while 1:
if count >= 16384: if count >= 16384:
count = 1 count = 1
print 'cli CREATING %d' % count print 'cli CREATING %d' % count
b = struct.pack('I', count) + 'x'*count b = struct.pack('I', count) + 'x' * count
remain[c] = count remain[c] = count
print 'cli >> %r' % len(b) print 'cli >> %r' % len(b)
c.send(b) c.send(b)
@ -32,13 +35,13 @@ while 1:
r = [listener] r = [listener]
time.sleep(0.1) time.sleep(0.1)
else: else:
r = [listener]+servers+clients r = [listener] + servers + clients
print 'select(%d)' % len(r) print 'select(%d)' % len(r)
r,w,x = select.select(r, [], [], 5) r, w, x = select.select(r, [], [], 5)
assert(r) assert(r)
for i in r: for i in r:
if i == listener: if i == listener:
s,addr = listener.accept() s, addr = listener.accept()
servers.append(s) servers.append(s)
elif i in servers: elif i in servers:
b = i.recv(4096) b = i.recv(4096)
@ -47,7 +50,7 @@ while 1:
assert(len(b) >= 4) assert(len(b) >= 4)
want = struct.unpack('I', b[:4])[0] want = struct.unpack('I', b[:4])[0]
b = b[4:] b = b[4:]
#i.send('y'*want) # i.send('y'*want)
else: else:
want = remain[i] want = remain[i]
if want < len(b): if want < len(b):
@ -64,7 +67,7 @@ while 1:
del remain[i] del remain[i]
else: else:
print 'srv >> %r' % len(b) print 'srv >> %r' % len(b)
i.send('y'*len(b)) i.send('y' * len(b))
if not want: if not want:
i.shutdown(socket.SHUT_WR) i.shutdown(socket.SHUT_WR)
elif i in clients: elif i in clients:

View File

@ -1,4 +1,6 @@
import sys, os, re, subprocess import re
import subprocess
def askpass(prompt): def askpass(prompt):
prompt = prompt.replace('"', "'") prompt = prompt.replace('"', "'")
@ -6,7 +8,7 @@ def askpass(prompt):
if 'yes/no' in prompt: if 'yes/no' in prompt:
return "yes" return "yes"
script=""" script = """
tell application "Finder" tell application "Finder"
activate activate
display dialog "%s" \ display dialog "%s" \

View File

@ -1,6 +1,11 @@
import sys, os, pty import sys
import os
import pty
from AppKit import * from AppKit import *
import my, models, askpass import my
import models
import askpass
def sshuttle_args(host, auto_nets, auto_hosts, dns, nets, debug, def sshuttle_args(host, auto_nets, auto_hosts, dns, nets, debug,
no_latency_control): no_latency_control):
@ -21,21 +26,25 @@ def sshuttle_args(host, auto_nets, auto_hosts, dns, nets, debug,
class _Callback(NSObject): class _Callback(NSObject):
def initWithFunc_(self, func): def initWithFunc_(self, func):
self = super(_Callback, self).init() self = super(_Callback, self).init()
self.func = func self.func = func
return self return self
def func_(self, obj): def func_(self, obj):
return self.func(obj) return self.func(obj)
class Callback: class Callback:
def __init__(self, func): def __init__(self, func):
self.obj = _Callback.alloc().initWithFunc_(func) self.obj = _Callback.alloc().initWithFunc_(func)
self.sel = self.obj.func_ self.sel = self.obj.func_
class Runner: class Runner:
def __init__(self, argv, logfunc, promptfunc, serverobj): def __init__(self, argv, logfunc, promptfunc, serverobj):
print 'in __init__' print 'in __init__'
self.id = argv self.id = argv
@ -49,7 +58,7 @@ class Runner:
self.logfunc('\nConnecting to %s.\n' % self.serverobj.host()) self.logfunc('\nConnecting to %s.\n' % self.serverobj.host())
print 'will run: %r' % argv print 'will run: %r' % argv
self.serverobj.setConnected_(False) self.serverobj.setConnected_(False)
pid,fd = pty.fork() pid, fd = pty.fork()
if pid == 0: if pid == 0:
# child # child
try: try:
@ -62,19 +71,20 @@ class Runner:
# parent # parent
self.pid = pid self.pid = pid
self.file = NSFileHandle.alloc()\ self.file = NSFileHandle.alloc()\
.initWithFileDescriptor_closeOnDealloc_(fd, True) .initWithFileDescriptor_closeOnDealloc_(fd, True)
self.cb = Callback(self.gotdata) self.cb = Callback(self.gotdata)
NSNotificationCenter.defaultCenter()\ NSNotificationCenter.defaultCenter()\
.addObserver_selector_name_object_(self.cb.obj, self.cb.sel, .addObserver_selector_name_object_(
NSFileHandleDataAvailableNotification, self.file) self.cb.obj, self.cb.sel,
NSFileHandleDataAvailableNotification, self.file)
self.file.waitForDataInBackgroundAndNotify() self.file.waitForDataInBackgroundAndNotify()
def __del__(self): def __del__(self):
self.wait() self.wait()
def _try_wait(self, options): def _try_wait(self, options):
if self.rv == None and self.pid > 0: if self.rv is None and self.pid > 0:
pid,code = os.waitpid(self.pid, options) pid, code = os.waitpid(self.pid, options)
if pid == self.pid: if pid == self.pid:
if os.WIFEXITED(code): if os.WIFEXITED(code):
self.rv = os.WEXITSTATUS(code) self.rv = os.WEXITSTATUS(code)
@ -88,14 +98,14 @@ class Runner:
def wait(self): def wait(self):
return self._try_wait(0) return self._try_wait(0)
def poll(self): def poll(self):
return self._try_wait(os.WNOHANG) return self._try_wait(os.WNOHANG)
def kill(self): def kill(self):
assert(self.pid > 0) assert(self.pid > 0)
print 'killing: pid=%r rv=%r' % (self.pid, self.rv) print 'killing: pid=%r rv=%r' % (self.pid, self.rv)
if self.rv == None: if self.rv is None:
self.logfunc('Disconnecting from %s.\n' % self.serverobj.host()) self.logfunc('Disconnecting from %s.\n' % self.serverobj.host())
os.kill(self.pid, 15) os.kill(self.pid, 15)
self.wait() self.wait()
@ -118,12 +128,13 @@ class Runner:
self.file.writeData_(my.Data(resp + '\n')) self.file.writeData_(my.Data(resp + '\n'))
self.file.waitForDataInBackgroundAndNotify() self.file.waitForDataInBackgroundAndNotify()
self.poll() self.poll()
#print 'gotdata done!' # print 'gotdata done!'
class SshuttleApp(NSObject): class SshuttleApp(NSObject):
def initialize(self): def initialize(self):
d = my.PList('UserDefaults') d = my.PList('UserDefaults')
my.Defaults().registerDefaults_(d) my.Defaults().registerDefaults_(d)
@ -137,7 +148,7 @@ class SshuttleController(NSObject):
serversController = objc.IBOutlet() serversController = objc.IBOutlet()
logField = objc.IBOutlet() logField = objc.IBOutlet()
latencyControlField = objc.IBOutlet() latencyControlField = objc.IBOutlet()
servers = [] servers = []
conns = {} conns = {}
@ -145,12 +156,14 @@ class SshuttleController(NSObject):
host = server.host() host = server.host()
print 'connecting %r' % host print 'connecting %r' % host
self.fill_menu() self.fill_menu()
def logfunc(msg): def logfunc(msg):
print 'log! (%d bytes)' % len(msg) print 'log! (%d bytes)' % len(msg)
self.logField.textStorage()\ self.logField.textStorage()\
.appendAttributedString_(NSAttributedString.alloc()\ .appendAttributedString_(NSAttributedString.alloc()
.initWithString_(msg)) .initWithString_(msg))
self.logField.didChangeText() self.logField.didChangeText()
def promptfunc(prompt): def promptfunc(prompt):
print 'prompt! %r' % prompt print 'prompt! %r' % prompt
return askpass.askpass(prompt) return askpass.askpass(prompt)
@ -164,12 +177,12 @@ class SshuttleController(NSObject):
manual_nets = [] manual_nets = []
noLatencyControl = (server.latencyControl() != models.LAT_INTERACTIVE) noLatencyControl = (server.latencyControl() != models.LAT_INTERACTIVE)
conn = Runner(sshuttle_args(host, conn = Runner(sshuttle_args(host,
auto_nets = nets_mode == models.NET_AUTO, auto_nets=nets_mode == models.NET_AUTO,
auto_hosts = server.autoHosts(), auto_hosts=server.autoHosts(),
dns = server.useDns(), dns=server.useDns(),
nets = manual_nets, nets=manual_nets,
debug = self.debugField.state(), debug=self.debugField.state(),
no_latency_control = noLatencyControl), no_latency_control=noLatencyControl),
logfunc=logfunc, promptfunc=promptfunc, logfunc=logfunc, promptfunc=promptfunc,
serverobj=server) serverobj=server)
self.conns[host] = conn self.conns[host] = conn
@ -182,8 +195,8 @@ class SshuttleController(NSObject):
conn.kill() conn.kill()
self.fill_menu() self.fill_menu()
self.logField.textStorage().setAttributedString_( self.logField.textStorage().setAttributedString_(
NSAttributedString.alloc().initWithString_('')) NSAttributedString.alloc().initWithString_(''))
@objc.IBAction @objc.IBAction
def cmd_connect(self, sender): def cmd_connect(self, sender):
server = sender.representedObject() server = sender.representedObject()
@ -213,6 +226,7 @@ class SshuttleController(NSObject):
it.setRepresentedObject_(obj) it.setRepresentedObject_(obj)
it.setTarget_(self) it.setTarget_(self)
it.setAction_(func) it.setAction_(func)
def addnote(name): def addnote(name):
additem(name, None, None) additem(name, None, None)
@ -271,8 +285,9 @@ class SshuttleController(NSObject):
sl = [] sl = []
for s in l: for s in l:
host = s.get('host', None) host = s.get('host', None)
if not host: continue if not host:
continue
nets = s.get('nets', []) nets = s.get('nets', [])
nl = [] nl = []
for n in nets: for n in nets:
@ -282,7 +297,7 @@ class SshuttleController(NSObject):
net.setSubnet_(subnet) net.setSubnet_(subnet)
net.setWidth_(width) net.setWidth_(width)
nl.append(net) nl.append(net)
autoNets = s.get('autoNets', models.NET_AUTO) autoNets = s.get('autoNets', models.NET_AUTO)
autoHosts = s.get('autoHosts', True) autoHosts = s.get('autoHosts', True)
useDns = s.get('useDns', autoNets == models.NET_ALL) useDns = s.get('useDns', autoNets == models.NET_ALL)
@ -302,11 +317,13 @@ class SshuttleController(NSObject):
l = [] l = []
for s in self.servers: for s in self.servers:
host = s.host() host = s.host()
if not host: continue if not host:
continue
nets = [] nets = []
for n in s.nets(): for n in s.nets():
subnet = n.subnet() subnet = n.subnet()
if not subnet: continue if not subnet:
continue
nets.append((subnet, n.width())) nets.append((subnet, n.width()))
d = dict(host=s.host(), d = dict(host=s.host(),
nets=nets, nets=nets,
@ -352,9 +369,9 @@ class SshuttleController(NSObject):
statusitem.setHighlightMode_(True) statusitem.setHighlightMode_(True)
statusitem.setMenu_(self.menu) statusitem.setMenu_(self.menu)
self.fill_menu() self.fill_menu()
models.configchange_callback = my.DelayedCallback(self.save_servers) models.configchange_callback = my.DelayedCallback(self.save_servers)
def sc(server): def sc(server):
if server.wantConnect(): if server.wantConnect():
self._connect(server) self._connect(server)

View File

@ -35,24 +35,29 @@ def _validate_width(v):
class SshuttleNet(NSObject): class SshuttleNet(NSObject):
def subnet(self): def subnet(self):
return getattr(self, '_k_subnet', None) return getattr(self, '_k_subnet', None)
def setSubnet_(self, v): def setSubnet_(self, v):
self._k_subnet = v self._k_subnet = v
config_changed() config_changed()
@objc_validator @objc_validator
def validateSubnet_error_(self, value, error): def validateSubnet_error_(self, value, error):
#print 'validateSubnet!' # print 'validateSubnet!'
return True, _validate_ip(value), error return True, _validate_ip(value), error
def width(self): def width(self):
return getattr(self, '_k_width', 24) return getattr(self, '_k_width', 24)
def setWidth_(self, v): def setWidth_(self, v):
self._k_width = v self._k_width = v
config_changed() config_changed()
@objc_validator @objc_validator
def validateWidth_error_(self, value, error): def validateWidth_error_(self, value, error):
#print 'validateWidth!' # print 'validateWidth!'
return True, _validate_width(value), error return True, _validate_width(value), error
NET_ALL = 0 NET_ALL = 0
@ -62,30 +67,37 @@ NET_MANUAL = 2
LAT_BANDWIDTH = 0 LAT_BANDWIDTH = 0
LAT_INTERACTIVE = 1 LAT_INTERACTIVE = 1
class SshuttleServer(NSObject): class SshuttleServer(NSObject):
def init(self): def init(self):
self = super(SshuttleServer, self).init() self = super(SshuttleServer, self).init()
config_changed() config_changed()
return self return self
def wantConnect(self): def wantConnect(self):
return getattr(self, '_k_wantconnect', False) return getattr(self, '_k_wantconnect', False)
def setWantConnect_(self, v): def setWantConnect_(self, v):
self._k_wantconnect = v self._k_wantconnect = v
self.setError_(None) self.setError_(None)
config_changed() config_changed()
if setconnect_callback: setconnect_callback(self) if setconnect_callback:
setconnect_callback(self)
def connected(self): def connected(self):
return getattr(self, '_k_connected', False) return getattr(self, '_k_connected', False)
def setConnected_(self, v): def setConnected_(self, v):
print 'setConnected of %r to %r' % (self, v) print 'setConnected of %r to %r' % (self, v)
self._k_connected = v self._k_connected = v
if v: self.setError_(None) # connected ok, so no error if v:
self.setError_(None) # connected ok, so no error
config_changed() config_changed()
def error(self): def error(self):
return getattr(self, '_k_error', None) return getattr(self, '_k_error', None)
def setError_(self, v): def setError_(self, v):
self._k_error = v self._k_error = v
config_changed() config_changed()
@ -107,40 +119,47 @@ class SshuttleServer(NSObject):
suffix = " (all traffic)" suffix = " (all traffic)"
elif an == NET_MANUAL: elif an == NET_MANUAL:
n = self.nets() n = self.nets()
suffix = ' (%d subnet%s)' % (len(n), len(n)!=1 and 's' or '') suffix = ' (%d subnet%s)' % (len(n), len(n) != 1 and 's' or '')
return self.host() + suffix return self.host() + suffix
def setTitle_(self, v): def setTitle_(self, v):
# title is always auto-generated # title is always auto-generated
config_changed() config_changed()
def host(self): def host(self):
return getattr(self, '_k_host', None) return getattr(self, '_k_host', None)
def setHost_(self, v): def setHost_(self, v):
self._k_host = v self._k_host = v
self.setTitle_(None) self.setTitle_(None)
config_changed() config_changed()
@objc_validator @objc_validator
def validateHost_error_(self, value, error): def validateHost_error_(self, value, error):
#print 'validatehost! %r %r %r' % (self, value, error) # print 'validatehost! %r %r %r' % (self, value, error)
while value.startswith('-'): while value.startswith('-'):
value = value[1:] value = value[1:]
return True, value, error return True, value, error
def nets(self): def nets(self):
return getattr(self, '_k_nets', []) return getattr(self, '_k_nets', [])
def setNets_(self, v): def setNets_(self, v):
self._k_nets = v self._k_nets = v
self.setTitle_(None) self.setTitle_(None)
config_changed() config_changed()
def netsHidden(self): def netsHidden(self):
#print 'checking netsHidden' # print 'checking netsHidden'
return self.autoNets() != NET_MANUAL return self.autoNets() != NET_MANUAL
def setNetsHidden_(self, v): def setNetsHidden_(self, v):
config_changed() config_changed()
#print 'setting netsHidden to %r' % v # print 'setting netsHidden to %r' % v
def autoNets(self): def autoNets(self):
return getattr(self, '_k_autoNets', NET_AUTO) return getattr(self, '_k_autoNets', NET_AUTO)
def setAutoNets_(self, v): def setAutoNets_(self, v):
self._k_autoNets = v self._k_autoNets = v
self.setNetsHidden_(-1) self.setNetsHidden_(-1)
@ -150,18 +169,21 @@ class SshuttleServer(NSObject):
def autoHosts(self): def autoHosts(self):
return getattr(self, '_k_autoHosts', True) return getattr(self, '_k_autoHosts', True)
def setAutoHosts_(self, v): def setAutoHosts_(self, v):
self._k_autoHosts = v self._k_autoHosts = v
config_changed() config_changed()
def useDns(self): def useDns(self):
return getattr(self, '_k_useDns', False) return getattr(self, '_k_useDns', False)
def setUseDns_(self, v): def setUseDns_(self, v):
self._k_useDns = v self._k_useDns = v
config_changed() config_changed()
def latencyControl(self): def latencyControl(self):
return getattr(self, '_k_latencyControl', LAT_INTERACTIVE) return getattr(self, '_k_latencyControl', LAT_INTERACTIVE)
def setLatencyControl_(self, v): def setLatencyControl_(self, v):
self._k_latencyControl = v self._k_latencyControl = v
config_changed() config_changed()

View File

@ -1,4 +1,4 @@
import sys, os import os
from AppKit import * from AppKit import *
import PyObjCTools.AppHelper import PyObjCTools.AppHelper
@ -44,11 +44,13 @@ def Defaults():
# #
def DelayedCallback(func, *args, **kwargs): def DelayedCallback(func, *args, **kwargs):
flag = [0] flag = [0]
def _go(): def _go():
if flag[0]: if flag[0]:
print 'running %r (flag=%r)' % (func, flag) print 'running %r (flag=%r)' % (func, flag)
flag[0] = 0 flag[0] = 0
func(*args, **kwargs) func(*args, **kwargs)
def call(): def call():
flag[0] += 1 flag[0] += 1
PyObjCTools.AppHelper.callAfter(_go) PyObjCTools.AppHelper.callAfter(_go)