PEP8 fixes.

This commit is contained in:
Brian May 2014-09-16 10:24:16 +10:00
parent 5529a04cc9
commit f1c79c7e92
16 changed files with 569 additions and 377 deletions

View File

@ -1,4 +1,5 @@
import sys, zlib
import sys
import zlib
z = zlib.decompressobj()
mainmod = sys.modules[__name__]

View File

@ -1,21 +1,30 @@
import struct, select, errno, re, signal, time
import struct
import errno
import re
import signal
import time
import compat.ssubprocess as ssubprocess
import helpers, ssnet, ssh, ssyslog
import helpers
import os
import ssnet
import ssh
import ssyslog
import sys
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
from helpers import log, debug1, debug2, debug3, Fatal, islocal
recvmsg = None
try:
# try getting recvmsg from python
import socket as pythonsocket
getattr(pythonsocket.socket,"recvmsg")
getattr(pythonsocket.socket, "recvmsg")
socket = pythonsocket
recvmsg = "python"
except AttributeError:
# try getting recvmsg from socket_ext library
try:
import socket_ext
getattr(socket_ext.socket,"recvmsg")
getattr(socket_ext.socket, "recvmsg")
socket = socket_ext
recvmsg = "socket_ext"
except ImportError:
@ -23,6 +32,7 @@ except AttributeError:
_extra_fd = os.open('/dev/null', os.O_RDONLY)
def got_signal(signum, frame):
log('exiting on signal %d\n' % signum)
sys.exit(1)
@ -40,60 +50,64 @@ IPV6_RECVORIGDSTADDR = IPV6_ORIGDSTADDR
if recvmsg == "python":
def recv_udp(listener, bufsize):
debug3('Accept UDP python using recvmsg.\n')
data, ancdata, msg_flags, srcip = listener.recvmsg(4096,socket.CMSG_SPACE(24))
data, ancdata, msg_flags, srcip = listener.recvmsg(
4096, socket.CMSG_SPACE(24))
dstip = None
family = None
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_IP and cmsg_type == IP_ORIGDSTADDR:
family,port = struct.unpack('=HH', cmsg_data[0:4])
family, port = struct.unpack('=HH', cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET:
start = 4
length = 4
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, cmsg_data[start:start+length])
raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
dstip = (ip, port)
break
elif cmsg_level == SOL_IPV6 and cmsg_type == IPV6_ORIGDSTADDR:
family,port = struct.unpack('=HH', cmsg_data[0:4])
family, port = struct.unpack('=HH', cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET6:
start = 8
length = 16
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, cmsg_data[start:start+length])
raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
dstip = (ip, port)
break
return (srcip, dstip, data)
elif recvmsg == "socket_ext":
def recv_udp(listener, bufsize):
debug3('Accept UDP using socket_ext recvmsg.\n')
srcip, data, adata, flags = listener.recvmsg((bufsize,),socket.CMSG_SPACE(24))
srcip, data, adata, flags = listener.recvmsg(
(bufsize,), socket.CMSG_SPACE(24))
dstip = None
family = None
for a in adata:
if a.cmsg_level == socket.SOL_IP and a.cmsg_type == IP_ORIGDSTADDR:
family,port = struct.unpack('=HH', a.cmsg_data[0:4])
family, port = struct.unpack('=HH', a.cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET:
start = 4
length = 4
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, a.cmsg_data[start:start+length])
raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(
family, a.cmsg_data[start:start + length])
dstip = (ip, port)
break
elif a.cmsg_level == SOL_IPV6 and a.cmsg_type == IPV6_ORIGDSTADDR:
family,port = struct.unpack('=HH', a.cmsg_data[0:4])
family, port = struct.unpack('=HH', a.cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET6:
start = 8
length = 16
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, a.cmsg_data[start:start+length])
raise Fatal("Unsupported socket type '%s'" % family)
ip = socket.inet_ntop(
family, a.cmsg_data[start:start + length])
dstip = (ip, port)
break
return (srcip, dstip, data[0])
@ -142,7 +156,7 @@ def daemonize():
if os.fork():
os._exit(0)
outfd = os.open(_pidname, os.O_WRONLY|os.O_CREAT|os.O_EXCL, 0666)
outfd = os.open(_pidname, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0666)
try:
os.write(outfd, '%d\n' % os.getpid())
finally:
@ -152,7 +166,7 @@ def daemonize():
# Normal exit when killed, or try/finally won't work and the pidfile won't
# be deleted.
signal.signal(signal.SIGTERM, got_signal)
si = open('/dev/null', 'r+')
os.dup2(si.fileno(), 0)
os.dup2(si.fileno(), 1)
@ -177,10 +191,10 @@ def original_dst(sock):
SOCKADDR_MIN = 16
sockaddr_in = sock.getsockopt(socket.SOL_IP,
SO_ORIGINAL_DST, SOCKADDR_MIN)
(proto, port, a,b,c,d) = struct.unpack('!HHBBBB', sockaddr_in[:8])
(proto, port, a, b, c, d) = struct.unpack('!HHBBBB', sockaddr_in[:8])
assert(socket.htons(proto) == socket.AF_INET)
ip = '%d.%d.%d.%d' % (a,b,c,d)
return (ip,port)
ip = '%d.%d.%d.%d' % (a, b, c, d)
return (ip, port)
except socket.error, e:
if e.args[0] == errno.ENOPROTOOPT:
return sock.getsockname()
@ -201,9 +215,15 @@ class MultiListener:
def add_handler(self, handlers, callback, method, mux):
if self.v6:
handlers.append(Handler([self.v6], lambda: callback(self.v6, method, mux, handlers)))
handlers.append(
Handler(
[self.v6],
lambda: callback(self.v6, method, mux, handlers)))
if self.v4:
handlers.append(Handler([self.v4], lambda: callback(self.v4, method, mux, handlers)))
handlers.append(
Handler(
[self.v4],
lambda: callback(self.v4, method, mux, handlers)))
def listen(self, backlog):
if self.v6:
@ -239,15 +259,17 @@ class MultiListener:
class FirewallClient:
def __init__(self, port_v6, port_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method, udp):
def __init__(self, port_v6, port_v4, subnets_include, subnets_exclude,
dnsport_v6, dnsport_v4, method, udp):
self.auto_nets = []
self.subnets_include = subnets_include
self.subnets_exclude = subnets_exclude
argvbase = ([sys.argv[1], sys.argv[0], sys.argv[1]] +
['-v'] * (helpers.verbose or 0) +
['--firewall', str(port_v6), str(port_v4),
str(dnsport_v6), str(dnsport_v4),
method, str(int(udp))])
str(dnsport_v6), str(dnsport_v4),
method, str(int(udp))])
if ssyslog._p:
argvbase += ['--syslog']
argv_tries = [
@ -260,7 +282,8 @@ class FirewallClient:
# because stupid Linux 'su' requires that stdin be attached to a tty.
# Instead, attach a *bidirectional* socket to its stdout, and use
# that for talking in both directions.
(s1,s2) = socket.socketpair()
(s1, s2) = socket.socketpair()
def setup():
# run in the child process
s2.close()
@ -295,9 +318,9 @@ class FirewallClient:
def start(self):
self.pfile.write('ROUTES\n')
for (family,ip,width) in self.subnets_include+self.auto_nets:
for (family, ip, width) in self.subnets_include + self.auto_nets:
self.pfile.write('%d,%d,0,%s\n' % (family, width, ip))
for (family,ip,width) in self.subnets_exclude:
for (family, ip, width) in self.subnets_exclude:
self.pfile.write('%d,%d,1,%s\n' % (family, width, ip))
self.pfile.write('GO\n')
self.pfile.flush()
@ -321,14 +344,16 @@ class FirewallClient:
dnsreqs = {}
udp_by_src = {}
def expire_connections(now, mux):
for chan,timeout in dnsreqs.items():
for chan, timeout in dnsreqs.items():
if timeout < now:
debug3('expiring dnsreqs channel=%d\n' % chan)
del mux.channels[chan]
del dnsreqs[chan]
debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
for peer,(chan,timeout) in udp_by_src.items():
for peer, (chan, timeout) in udp_by_src.items():
if timeout < now:
debug3('expiring UDP channel channel=%d peer=%r\n' % (chan, peer))
mux.send(chan, ssnet.CMD_UDP_CLOSE, '')
@ -340,14 +365,14 @@ def expire_connections(now, mux):
def onaccept_tcp(listener, method, mux, handlers):
global _extra_fd
try:
sock,srcip = listener.accept()
sock, srcip = listener.accept()
except socket.error, e:
if e.args[0] in [errno.EMFILE, errno.ENFILE]:
debug1('Rejected incoming connection: too many open files!\n')
# free up an fd so we can eat the connection
os.close(_extra_fd)
try:
sock,srcip = listener.accept()
sock, srcip = listener.accept()
sock.close()
finally:
_extra_fd = os.open('/dev/null', os.O_RDONLY)
@ -355,11 +380,11 @@ def onaccept_tcp(listener, method, mux, handlers):
else:
raise
if method == "tproxy":
dstip = sock.getsockname();
dstip = sock.getsockname()
else:
dstip = original_dst(sock)
debug1('Accept TCP: %s:%r -> %s:%r.\n' % (srcip[0],srcip[1],
dstip[0],dstip[1]))
debug1('Accept TCP: %s:%r -> %s:%r.\n' % (srcip[0], srcip[1],
dstip[0], dstip[1]))
if dstip[1] == sock.getsockname()[1] and islocal(dstip[0], sock.family):
debug1("-- ignored: that's my address!\n")
sock.close()
@ -369,16 +394,17 @@ def onaccept_tcp(listener, method, mux, handlers):
log('warning: too many open channels. Discarded connection.\n')
sock.close()
return
mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' % (sock.family, dstip[0], dstip[1]))
mux.send(chan, ssnet.CMD_TCP_CONNECT, '%d,%s,%s' %
(sock.family, dstip[0], dstip[1]))
outwrap = MuxWrapper(mux, chan)
handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
expire_connections(time.time(), mux)
def udp_done(chan, data, method, family, dstip):
(src,srcport,data) = data.split(",",2)
srcip = (src,int(srcport))
debug3('doing send from %r to %r\n' % (srcip,dstip,))
(src, srcport, data) = data.split(",", 2)
srcip = (src, int(srcport))
debug3('doing send from %r to %r\n' % (srcip, dstip,))
try:
sender = socket.socket(family, socket.SOCK_DGRAM)
@ -388,36 +414,39 @@ def udp_done(chan, data, method, family, dstip):
sender.sendto(data, dstip)
sender.close()
except socket.error, e:
debug1('-- ignored socket error sending UDP data: %r\n'%e)
debug1('-- ignored socket error sending UDP data: %r\n' % e)
def onaccept_udp(listener, method, mux, handlers):
now = time.time()
srcip, dstip, data = recv_udp(listener, 4096)
if not dstip:
debug1("-- ignored UDP from %r: couldn't determine destination IP address\n" % (srcip,))
debug1(
"-- ignored UDP from %r: "
"couldn't determine destination IP address\n" % (srcip,))
return
debug1('Accept UDP: %r -> %r.\n' % (srcip,dstip,))
debug1('Accept UDP: %r -> %r.\n' % (srcip, dstip,))
if srcip in udp_by_src:
chan,timeout = udp_by_src[srcip]
chan, timeout = udp_by_src[srcip]
else:
chan = mux.next_channel()
mux.channels[chan] = lambda cmd,data: udp_done(chan, data, method, listener.family, dstip=srcip)
mux.channels[chan] = lambda cmd, data: udp_done(
chan, data, method, listener.family, dstip=srcip)
mux.send(chan, ssnet.CMD_UDP_OPEN, listener.family)
udp_by_src[srcip] = chan,now+30
udp_by_src[srcip] = chan, now + 30
hdr = "%s,%r,"%(dstip[0], dstip[1])
mux.send(chan, ssnet.CMD_UDP_DATA, hdr+data)
hdr = "%s,%r," % (dstip[0], dstip[1])
mux.send(chan, ssnet.CMD_UDP_DATA, hdr + data)
expire_connections(now, mux)
def dns_done(chan, data, method, sock, srcip, dstip, mux):
debug3('dns_done: channel=%d src=%r dst=%r\n' % (chan,srcip,dstip))
debug3('dns_done: channel=%d src=%r dst=%r\n' % (chan, srcip, dstip))
del mux.channels[chan]
del dnsreqs[chan]
if method == "tproxy":
debug3('doing send from %r to %r\n' % (srcip,dstip,))
debug3('doing send from %r to %r\n' % (srcip, dstip,))
sender = socket.socket(sock.family, socket.SOCK_DGRAM)
sender.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sender.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
@ -433,17 +462,21 @@ def ondns(listener, method, mux, handlers):
now = time.time()
srcip, dstip, data = recv_udp(listener, 4096)
if method == "tproxy" and not dstip:
debug1("-- ignored UDP from %r: couldn't determine destination IP address\n" % (srcip,))
debug1(
"-- ignored UDP from %r: "
"couldn't determine destination IP address\n" % (srcip,))
return
debug1('DNS request from %r to %r: %d bytes\n' % (srcip,dstip,len(data)))
debug1('DNS request from %r to %r: %d bytes\n' % (srcip, dstip, len(data)))
chan = mux.next_channel()
dnsreqs[chan] = now+30
dnsreqs[chan] = now + 30
mux.send(chan, ssnet.CMD_DNS_REQ, data)
mux.channels[chan] = lambda cmd,data: dns_done(chan, data, method, listener, srcip=dstip, dstip=srcip, mux=mux)
mux.channels[chan] = lambda cmd, data: dns_done(
chan, data, method, listener, srcip=dstip, dstip=srcip, mux=mux)
expire_connections(now, mux)
def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_control,
def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
python, latency_control,
dns_listener, method, seed_hosts, auto_nets,
syslog, daemon):
handlers = []
@ -454,9 +487,10 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
debug1('connecting to server...\n')
try:
(serverproc, serversock) = ssh.connect(ssh_cmd, remotename, python,
stderr=ssyslog._p and ssyslog._p.stdin,
options=dict(latency_control=latency_control, method=method))
(serverproc, serversock) = ssh.connect(
ssh_cmd, remotename, python,
stderr=ssyslog._p and ssyslog._p.stdin,
options=dict(latency_control=latency_control, method=method))
except socket.error, e:
if e.args[0] == errno.EPIPE:
raise Fatal("failed to establish ssh session (1)")
@ -466,7 +500,7 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
handlers.append(mux)
expected = 'SSHUTTLE0001'
try:
v = 'x'
while v and v != '\0':
@ -480,14 +514,14 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
raise Fatal("failed to establish ssh session (2)")
else:
raise
rv = serverproc.poll()
if rv:
raise Fatal('server died with error code %d' % rv)
if initstring != expected:
raise Fatal('expected server init string %r; got %r'
% (expected, initstring))
% (expected, initstring))
debug1('connected.\n')
print 'Connected.'
sys.stdout.flush()
@ -501,8 +535,8 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
def onroutes(routestr):
if auto_nets:
for line in routestr.strip().split('\n'):
(family,ip,width) = line.split(',', 2)
fw.auto_nets.append((family,ip,int(width)))
(family, ip, width) = line.split(',', 2)
fw.auto_nets.append((family, ip, int(width)))
# we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection!
@ -519,7 +553,7 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
debug2('got host list: %r\n' % hostlist)
for line in hostlist.strip().split():
if line:
name,ip = line.split(',', 1)
name, ip = line.split(',', 1)
fw.sethostip(name, ip)
mux.got_host_list = onhostlist
@ -531,15 +565,15 @@ def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_c
if dns_listener:
dns_listener.add_handler(handlers, ondns, method, mux)
if seed_hosts != None:
if seed_hosts is not None:
debug1('seed_hosts: %r\n' % seed_hosts)
mux.send(0, ssnet.CMD_HOST_REQ, '\n'.join(seed_hosts))
while 1:
rv = serverproc.poll()
if rv:
raise Fatal('server died with error code %d' % rv)
ssnet.runonce(handlers, mux)
if latency_control:
mux.check_fullness()
@ -562,7 +596,7 @@ def main(listenip_v6, listenip_v4,
debug1('Starting sshuttle proxy.\n')
if recvmsg is not None:
debug1("recvmsg %s support enabled.\n"%recvmsg)
debug1("recvmsg %s support enabled.\n" % recvmsg)
if method == "tproxy":
if recvmsg is not None:
@ -580,10 +614,10 @@ def main(listenip_v6, listenip_v4,
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
# if both ports given, no need to search for a spare port
ports = [ 0, ]
ports = [0, ]
else:
# if at least one port missing, we have to search
ports = xrange(12300,9000,-1)
ports = xrange(12300, 9000, -1)
# search for free ports and try to bind
last_e = None
@ -606,7 +640,7 @@ def main(listenip_v6, listenip_v4,
lv6 = listenip_v6
redirectport_v6 = lv6[1]
elif listenip_v6:
lv6 = (listenip_v6[0],port)
lv6 = (listenip_v6[0], port)
redirectport_v6 = port
else:
lv6 = None
@ -616,7 +650,7 @@ def main(listenip_v6, listenip_v4,
lv4 = listenip_v4
redirectport_v4 = lv4[1]
elif listenip_v4:
lv4 = (listenip_v4[0],port)
lv4 = (listenip_v4[0], port)
redirectport_v4 = port
else:
lv4 = None
@ -646,20 +680,20 @@ def main(listenip_v6, listenip_v4,
if dns:
# search for spare port for DNS
debug2('Binding DNS:')
ports = xrange(12300,9000,-1)
ports = xrange(12300, 9000, -1)
for port in ports:
debug2(' %d' % port)
dns_listener = MultiListener(socket.SOCK_DGRAM)
if listenip_v6:
lv6 = (listenip_v6[0],port)
lv6 = (listenip_v6[0], port)
dnsport_v6 = port
else:
lv6 = None
dnsport_v6 = 0
if listenip_v4:
lv4 = (listenip_v4[0],port)
lv4 = (listenip_v4[0], port)
dnsport_v4 = port
else:
lv4 = None
@ -684,20 +718,23 @@ def main(listenip_v6, listenip_v4,
dnsport_v4 = 0
dns_listener = None
fw = FirewallClient(redirectport_v6, redirectport_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method, udp)
fw = FirewallClient(redirectport_v6, redirectport_v4, subnets_include,
subnets_exclude, dnsport_v6, dnsport_v4, method, udp)
if fw.method == "tproxy":
tcp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if udp_listener:
udp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if udp_listener.v4 is not None:
udp_listener.v4.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
udp_listener.v4.setsockopt(
socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
if udp_listener.v6 is not None:
udp_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1)
if dns_listener:
dns_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if dns_listener.v4 is not None:
dns_listener.v4.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
dns_listener.v4.setsockopt(
socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
if dns_listener.v6 is not None:
dns_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1)

View File

@ -1,7 +1,13 @@
import re, errno, socket, select, struct
import errno
import socket
import select
import struct
import compat.ssubprocess as ssubprocess
import helpers, ssyslog
from helpers import *
import ssyslog
import sys
import os
from helpers import log, debug1, debug3, islocal, Fatal, family_to_string, \
resolvconf_nameservers
# python doesn't have a definition for this
IPPROTO_DIVERT = 254
@ -20,9 +26,9 @@ def ipt_chain_exists(family, table, name):
elif family == socket.AF_INET:
cmd = 'iptables'
else:
raise Exception('Unsupported family "%s"'%family_to_string(family))
raise Exception('Unsupported family "%s"' % family_to_string(family))
argv = [cmd, '-t', table, '-nL']
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE)
p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
for line in p.stdout:
if line.startswith('Chain %s ' % name):
return True
@ -37,7 +43,7 @@ def _ipt(family, table, *args):
elif family == socket.AF_INET:
argv = ['iptables', '-t', table] + list(args)
else:
raise Exception('Unsupported family "%s"'%family_to_string(family))
raise Exception('Unsupported family "%s"' % family_to_string(family))
debug1('>> %s\n' % ' '.join(argv))
rv = ssubprocess.call(argv)
if rv:
@ -45,6 +51,8 @@ def _ipt(family, table, *args):
_no_ttl_module = False
def _ipt_ttl(family, *args):
global _no_ttl_module
if not _no_ttl_module:
@ -72,13 +80,17 @@ def _ipt_ttl(family, *args):
def do_iptables_nat(port, dnsport, family, subnets, udp):
# only ipv4 supported with NAT
if family != socket.AF_INET:
raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family))
raise Exception(
'Address family "%s" unsupported by nat method'
% family_to_string(family))
if udp:
raise Exception("UDP not supported by nat method")
table = "nat"
def ipt(*args):
return _ipt(family, table, *args)
def ipt_ttl(*args):
return _ipt_ttl(family, table, *args)
@ -103,20 +115,21 @@ def do_iptables_nat(port, dnsport, family, subnets, udp):
# to least-specific, and at any given level of specificity, we want
# excludes to come first. That's why the columns are in such a non-
# intuitive order.
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
ipt('-A', chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp')
else:
ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/%s' % (snet,swidth),
'--dest', '%s/%s' % (snet, swidth),
'-p', 'tcp',
'--to-ports', str(port))
if dnsport:
nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist):
for f, ip in filter(lambda i: i[0] == family, nslist):
ipt_ttl('-A', chain, '-j', 'REDIRECT',
'--dest', '%s/32' % ip,
'-p', 'udp',
@ -126,15 +139,19 @@ def do_iptables_nat(port, dnsport, family, subnets, udp):
def do_iptables_tproxy(port, dnsport, family, subnets, udp):
if family not in [socket.AF_INET, socket.AF_INET6]:
raise Exception('Address family "%s" unsupported by tproxy method'%family_to_string(family))
raise Exception(
'Address family "%s" unsupported by tproxy method'
% family_to_string(family))
table = "mangle"
def ipt(*args):
return _ipt(family, table, *args)
def ipt_ttl(*args):
return _ipt_ttl(family, table, *args)
mark_chain = 'sshuttle-m-%s' % port
mark_chain = 'sshuttle-m-%s' % port
tproxy_chain = 'sshuttle-t-%s' % port
divert_chain = 'sshuttle-d-%s' % port
@ -165,65 +182,70 @@ def do_iptables_tproxy(port, dnsport, family, subnets, udp):
ipt('-A', divert_chain, '-j', 'MARK', '--set-mark', '1')
ipt('-A', divert_chain, '-j', 'ACCEPT')
ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'tcp', '-p', 'tcp')
'-m', 'tcp', '-p', 'tcp')
if subnets and udp:
ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'udp', '-p', 'udp')
'-m', 'udp', '-p', 'udp')
if dnsport:
nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist):
for f, ip in filter(lambda i: i[0] == family, nslist):
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53')
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1',
'--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53',
'--on-port', str(dnsport))
'--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53',
'--on-port', str(dnsport))
if subnets:
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
else:
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'tcp', '-p', 'tcp')
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'tcp', '-p', 'tcp',
'--on-port', str(port))
ipt('-A', mark_chain, '-j', 'MARK',
'--set-mark', '1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp')
ipt('-A', tproxy_chain, '-j', 'TPROXY',
'--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'tcp', '-p', 'tcp',
'--on-port', str(port))
if sexclude and udp:
ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp')
ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp')
elif udp:
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'udp', '-p', 'udp')
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'udp', '-p', 'udp',
'--on-port', str(port))
ipt('-A', mark_chain, '-j', 'MARK',
'--set-mark', '1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp')
ipt('-A', tproxy_chain, '-j', 'TPROXY',
'--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet, swidth),
'-m', 'udp', '-p', 'udp',
'--on-port', str(port))
def ipfw_rule_exists(n):
argv = ['ipfw', 'list']
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE)
p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
found = False
for line in p.stdout:
if line.startswith('%05d ' % n):
if not ('ipttl 42' in line
or ('skipto %d' % (n+1)) in line
or ('skipto %d' % (n + 1)) in line
or 'check-state' in line):
log('non-sshuttle ipfw rule: %r\n' % line.strip())
raise Fatal('non-sshuttle ipfw rule #%d already exists!' % n)
@ -235,12 +257,14 @@ def ipfw_rule_exists(n):
_oldctls = {}
def _fill_oldctls(prefix):
argv = ['sysctl', prefix]
p = ssubprocess.Popen(argv, stdout = ssubprocess.PIPE)
p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
for line in p.stdout:
assert(line[-1] == '\n')
(k,v) = line[:-1].split(': ', 1)
(k, v) = line[:-1].split(': ', 1)
_oldctls[k] = v
rv = p.wait()
if rv:
@ -252,10 +276,12 @@ def _fill_oldctls(prefix):
def _sysctl_set(name, val):
argv = ['sysctl', '-w', '%s=%s' % (name, val)]
debug1('>> %s\n' % ' '.join(argv))
return ssubprocess.call(argv, stdout = open('/dev/null', 'w'))
return ssubprocess.call(argv, stdout=open('/dev/null', 'w'))
_changedctls = []
def sysctl_set(name, val, permanent=False):
PREFIX = 'net.inet.ip'
assert(name.startswith(PREFIX + '.'))
@ -268,7 +294,7 @@ def sysctl_set(name, val, permanent=False):
oldval = _oldctls[name]
if val != oldval:
rv = _sysctl_set(name, val)
if rv==0 and permanent:
if rv == 0 and permanent:
debug1('>> ...saving permanently in /etc/sysctl.conf\n')
f = open('/etc/sysctl.conf', 'a')
f.write('\n'
@ -293,9 +319,11 @@ def _udp_repack(p, src, dst):
_real_dns_server = [None]
def _handle_diversion(divertsock, dnsport):
p,tag = divertsock.recvfrom(4096)
src,dst = _udp_unpack(p)
p, tag = divertsock.recvfrom(4096)
src, dst = _udp_unpack(p)
debug3('got diverted packet from %r to %r\n' % (src, dst))
if dst[1] == 53:
# outgoing DNS
@ -311,7 +339,7 @@ def _handle_diversion(divertsock, dnsport):
assert(0)
newp = _udp_repack(p, src, dst)
divertsock.sendto(newp, tag)
def ipfw(*args):
argv = ['ipfw', '-q'] + list(args)
@ -324,12 +352,14 @@ def ipfw(*args):
def do_ipfw(port, dnsport, family, subnets, udp):
# IPv6 not supported
if family not in [socket.AF_INET, ]:
raise Exception('Address family "%s" unsupported by ipfw method'%family_to_string(family))
raise Exception(
'Address family "%s" unsupported by ipfw method'
% family_to_string(family))
if udp:
raise Exception("UDP not supported by ipfw method")
sport = str(port)
xsport = str(port+1)
xsport = str(port + 1)
# cleanup any existing rules
if ipfw_rule_exists(port):
@ -360,15 +390,16 @@ def do_ipfw(port, dnsport, family, subnets, udp):
if subnets:
# create new subnet entries
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
for f, swidth, sexclude, snet \
in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
ipfw('add', sport, 'skipto', xsport,
'log', 'tcp',
'from', 'any', 'to', '%s/%s' % (snet,swidth))
'from', 'any', 'to', '%s/%s' % (snet, swidth))
else:
ipfw('add', sport, 'fwd', '127.0.0.1,%d' % port,
'log', 'tcp',
'from', 'any', 'to', '%s/%s' % (snet,swidth),
'from', 'any', 'to', '%s/%s' % (snet, swidth),
'not', 'ipttl', '42', 'keep-state', 'setup')
# This part is much crazier than it is on Linux, because MacOS (at least
@ -403,10 +434,10 @@ def do_ipfw(port, dnsport, family, subnets, udp):
if dnsport:
divertsock = socket.socket(socket.AF_INET, socket.SOCK_RAW,
IPPROTO_DIVERT)
divertsock.bind(('0.0.0.0', port)) # IP field is ignored
divertsock.bind(('0.0.0.0', port)) # IP field is ignored
nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist):
for f, ip in filter(lambda i: i[0] == family, nslist):
# relabel and then catch outgoing DNS requests
ipfw('add', sport, 'divert', sport,
'log', 'udp',
@ -420,14 +451,14 @@ def do_ipfw(port, dnsport, family, subnets, udp):
def do_wait():
while 1:
r,w,x = select.select([sys.stdin, divertsock], [], [])
r, w, x = select.select([sys.stdin, divertsock], [], [])
if divertsock in r:
_handle_diversion(divertsock, dnsport)
if sys.stdin in r:
return
else:
do_wait = None
return do_wait
@ -440,10 +471,12 @@ def program_exists(name):
hostmap = {}
def rewrite_etc_hosts(port):
HOSTSFILE='/etc/hosts'
BAKFILE='%s.sbak' % HOSTSFILE
APPEND='# sshuttle-firewall-%d AUTOCREATED' % port
HOSTSFILE = '/etc/hosts'
BAKFILE = '%s.sbak' % HOSTSFILE
APPEND = '# sshuttle-firewall-%d AUTOCREATED' % port
old_content = ''
st = None
try:
@ -462,8 +495,8 @@ def rewrite_etc_hosts(port):
if line.find(APPEND) >= 0:
continue
f.write('%s\n' % line)
for (name,ip) in sorted(hostmap.items()):
f.write('%-30s %s\n' % ('%s %s' % (ip,name), APPEND))
for (name, ip) in sorted(hostmap.items()):
f.write('%-30s %s\n' % ('%s %s' % (ip, name), APPEND))
f.close()
if st:
@ -517,7 +550,7 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
elif method == "ipfw":
do_it = do_ipfw
else:
raise Exception('Unknown method "%s"'%method)
raise Exception('Unknown method "%s"' % method)
# because of limitations of the 'su' command, the *real* stdin/stdout
# are both attached to stdout initially. Clone stdout into stdin so we
@ -528,8 +561,8 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
ssyslog.start_syslog()
ssyslog.stderr_to_syslog()
debug1('firewall manager ready method %s.\n'%method)
sys.stdout.write('READY %s\n'%method)
debug1('firewall manager ready method %s.\n' % method)
sys.stdout.write('READY %s\n' % method)
sys.stdout.flush()
# ctrl-c shouldn't be passed along to me. When the main sshuttle dies,
@ -553,29 +586,31 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
elif line == 'GO\n':
break
try:
(family,width,exclude,ip) = line.strip().split(',', 3)
(family, width, exclude, ip) = line.strip().split(',', 3)
except:
raise Fatal('firewall: expected route or GO but got %r' % line)
subnets.append((int(family), int(width), bool(int(exclude)), ip))
try:
if line:
debug1('firewall manager: starting transproxy.\n')
subnets_v6 = filter(lambda i: i[0]==socket.AF_INET6, subnets)
subnets_v6 = filter(lambda i: i[0] == socket.AF_INET6, subnets)
if port_v6:
do_wait = do_it(port_v6, dnsport_v6, socket.AF_INET6, subnets_v6, udp)
do_wait = do_it(
port_v6, dnsport_v6, socket.AF_INET6, subnets_v6, udp)
elif len(subnets_v6) > 0:
debug1("IPv6 subnets defined but IPv6 disabled\n")
subnets_v4 = filter(lambda i: i[0]==socket.AF_INET, subnets)
subnets_v4 = filter(lambda i: i[0] == socket.AF_INET, subnets)
if port_v4:
do_wait = do_it(port_v4, dnsport_v4, socket.AF_INET, subnets_v4, udp)
do_wait = do_it(
port_v4, dnsport_v4, socket.AF_INET, subnets_v4, udp)
elif len(subnets_v4) > 0:
debug1('IPv4 subnets defined but IPv4 disabled\n')
sys.stdout.write('STARTED\n')
try:
sys.stdout.flush()
except IOError:
@ -587,10 +622,11 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
# to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important!
while 1:
if do_wait: do_wait()
if do_wait:
do_wait()
line = sys.stdin.readline(128)
if line.startswith('HOST '):
(name,ip) = line[5:].strip().split(',', 1)
(name, ip) = line[5:].strip().split(',', 1)
hostmap[name] = ip
rewrite_etc_hosts(port_v6 or port_v4)
elif line:

View File

@ -1,8 +1,11 @@
import sys, os, socket, errno
import sys
import socket
import errno
logprefix = ''
verbose = 0
def log(s):
try:
sys.stdout.flush()
@ -13,14 +16,17 @@ def log(s):
# our tty closes. That sucks, but it's no reason to abort the program.
pass
def debug1(s):
if verbose >= 1:
log(s)
def debug2(s):
if verbose >= 2:
log(s)
def debug3(s):
if verbose >= 3:
log(s)
@ -43,9 +49,9 @@ def resolvconf_nameservers():
words = line.lower().split()
if len(words) >= 2 and words[0] == 'nameserver':
if ':' in words[1]:
l.append((socket.AF_INET6,words[1]))
l.append((socket.AF_INET6, words[1]))
else:
l.append((socket.AF_INET,words[1]))
l.append((socket.AF_INET, words[1]))
return l
@ -58,10 +64,10 @@ def resolvconf_random_nameserver():
random.shuffle(l)
return l[0]
else:
return (socket.AF_INET,'127.0.0.1')
return (socket.AF_INET, '127.0.0.1')
def islocal(ip,family):
def islocal(ip, family):
sock = socket.socket(family)
try:
try:
@ -83,4 +89,3 @@ def family_to_string(family):
return "AF_INET"
else:
return str(family)

View File

@ -1,12 +1,18 @@
import time, socket, re, select, errno
import time
import socket
import re
import select
import errno
import os
import sys
if not globals().get('skip_imports'):
import compat.ssubprocess as ssubprocess
import helpers
from helpers import *
from helpers import log, debug1, debug2, debug3
POLL_TIME = 60*15
POLL_TIME = 60 * 15
NETSTAT_POLL_TIME = 30
CACHEFILE=os.path.expanduser('~/.sshuttle.hosts')
CACHEFILE = os.path.expanduser('~/.sshuttle.hosts')
_nmb_ok = True
@ -28,7 +34,7 @@ def write_host_cache():
tmpname = '%s.%d.tmp' % (CACHEFILE, os.getpid())
try:
f = open(tmpname, 'wb')
for name,ip in sorted(hostnames.items()):
for name, ip in sorted(hostnames.items()):
f.write('%s,%s\n' % (name, ip))
f.close()
os.rename(tmpname, CACHEFILE)
@ -50,18 +56,18 @@ def read_host_cache():
for line in f:
words = line.strip().split(',')
if len(words) == 2:
(name,ip) = words
(name, ip) = words
name = re.sub(r'[^-\w]', '-', name).strip()
ip = re.sub(r'[^0-9.]', '', ip).strip()
if name and ip:
found_host(name, ip)
def found_host(hostname, ip):
hostname = re.sub(r'\..*', '', hostname)
hostname = re.sub(r'[^-\w]', '_', hostname)
if (ip.startswith('127.') or ip.startswith('255.')
or hostname == 'localhost'):
if (ip.startswith('127.') or ip.startswith('255.')
or hostname == 'localhost'):
return
oldip = hostnames.get(hostname)
if oldip != ip:
@ -94,7 +100,7 @@ def _check_revdns(ip):
debug3('< %s\n' % r[0])
check_host(r[0])
found_host(r[0], ip)
except socket.herror, e:
except socket.herror:
pass
@ -105,7 +111,7 @@ def _check_dns(hostname):
debug3('< %s\n' % ip)
check_host(ip)
found_host(hostname, ip)
except socket.gaierror, e:
except socket.gaierror:
pass
@ -123,7 +129,7 @@ def _check_netstat():
for ip in re.findall(r'\d+\.\d+\.\d+\.\d+', content):
debug3('< %s\n' % ip)
check_host(ip)
def _check_smb(hostname):
return
@ -187,7 +193,7 @@ def _check_nmb(hostname, is_workgroup, is_master):
global _nmb_ok
if not _nmb_ok:
return
argv = ['nmblookup'] + ['-M']*is_master + ['--', hostname]
argv = ['nmblookup'] + ['-M'] * is_master + ['--', hostname]
debug2(' > n%d%d: %s\n' % (is_workgroup, is_master, hostname))
try:
p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE, stderr=null)
@ -228,13 +234,13 @@ def check_workgroup(hostname):
def _enqueue(op, *args):
t = (op,args)
if queue.get(t) == None:
t = (op, args)
if queue.get(t) is None:
queue[t] = 0
def _stdin_still_ok(timeout):
r,w,x = select.select([sys.stdin.fileno()], [], [], timeout)
r, w, x = select.select([sys.stdin.fileno()], [], [], timeout)
if r:
b = os.read(sys.stdin.fileno(), 4096)
if not b:
@ -249,7 +255,7 @@ def hw_main(seed_hosts):
helpers.logprefix = 'hostwatch: '
read_host_cache()
_enqueue(_check_etc_hosts)
_enqueue(_check_netstat)
check_host('localhost')
@ -261,8 +267,8 @@ def hw_main(seed_hosts):
while 1:
now = time.time()
for t,last_polled in queue.items():
(op,args) = t
for t, last_polled in queue.items():
(op, args) = t
if not _stdin_still_ok(0):
break
maxtime = POLL_TIME
@ -275,7 +281,7 @@ def hw_main(seed_hosts):
sys.stdout.flush()
except IOError:
break
# FIXME: use a smarter timeout based on oldest last_polled
if not _stdin_still_ok(1):
break

View File

@ -1,7 +1,13 @@
import sys, os, re, socket
import helpers, options, client, server, firewall, hostwatch
import compat.ssubprocess as ssubprocess
from helpers import *
import sys
import re
import socket
import helpers
import options
import client
import server
import firewall
import hostwatch
from helpers import log, Fatal
# 1.2.3.4/5 or just 1.2.3.4
@ -9,17 +15,17 @@ def parse_subnet4(s):
m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s)
if not m:
raise Fatal('%r is not a valid IP subnet format' % s)
(a,b,c,d,width) = m.groups()
(a,b,c,d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0))
if width == None:
(a, b, c, d, width) = m.groups()
(a, b, c, d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0))
if width is None:
width = 32
else:
width = int(width)
if a > 255 or b > 255 or c > 255 or d > 255:
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d))
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a, b, c, d))
if width > 32:
raise Fatal('*/%d is greater than the maximum of 32' % width)
return(socket.AF_INET, '%d.%d.%d.%d' % (a,b,c,d), width)
return(socket.AF_INET, '%d.%d.%d.%d' % (a, b, c, d), width)
# 1:2::3/64 or just 1:2::3
@ -27,8 +33,8 @@ def parse_subnet6(s):
m = re.match(r'(?:([a-fA-F\d:]+))?(?:/(\d+))?$', s)
if not m:
raise Fatal('%r is not a valid IP subnet format' % s)
(net,width) = m.groups()
if width == None:
(net, width) = m.groups()
if width is None:
width = 128
else:
width = int(width)
@ -41,7 +47,7 @@ def parse_subnet6(s):
def parse_subnet_file(s):
try:
handle = open(s, 'r')
except OSError, e:
except OSError:
raise Fatal('Unable to open subnet file: %s' % s)
raw_config_lines = handle.readlines()
@ -77,16 +83,16 @@ def parse_ipport4(s):
m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s)
if not m:
raise Fatal('%r is not a valid IP:port format' % s)
(a,b,c,d,port) = m.groups()
(a,b,c,d,port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0),
int(port or 0))
(a, b, c, d, port) = m.groups()
(a, b, c, d, port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0),
int(port or 0))
if a > 255 or b > 255 or c > 255 or d > 255:
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d))
raise Fatal('%d.%d.%d.%d has numbers > 255' % (a, b, c, d))
if port > 65535:
raise Fatal('*:%d is greater than the maximum of 65535' % port)
if a == None:
if a is None:
a = b = c = d = 0
return ('%d.%d.%d.%d' % (a,b,c,d), port)
return ('%d.%d.%d.%d' % (a, b, c, d), port)
# [1:2::3]:456 or [1:2::3] or 456
@ -95,8 +101,8 @@ def parse_ipport6(s):
m = re.match(r'(?:\[([^]]*)])?(?::)?(?:(\d+))?$', s)
if not m:
raise Fatal('%s is not a valid IP:port format' % s)
(ip,port) = m.groups()
(ip,port) = (ip or '::', int(port or 0))
(ip, port) = m.groups()
(ip, port) = (ip or '::', int(port or 0))
return (ip, port)
@ -156,8 +162,8 @@ try:
o.fatal('at least one subnet, subnet file, or -N expected')
includes = extra
excludes = ['127.0.0.0/8']
for k,v in flags:
if k in ('-x','--exclude'):
for k, v in flags:
if k in ('-x', '--exclude'):
excludes.append(v)
remotename = opt.remote
if remotename == '' or remotename == '-':
@ -174,10 +180,10 @@ try:
includes = parse_subnet_file(opt.subnets)
if not opt.method:
method = "auto"
elif opt.method in [ "auto", "nat", "tproxy", "ipfw" ]:
elif opt.method in ["auto", "nat", "tproxy", "ipfw"]:
method = opt.method
else:
o.fatal("method %s not supported"%opt.method)
o.fatal("method %s not supported" % opt.method)
if not opt.listen:
if opt.method == "tproxy":
ipport_v6 = parse_ipport6('[::1]:0')
@ -194,23 +200,23 @@ try:
else:
ipport_v4 = parse_ipport4(ip)
return_code = client.main(ipport_v6, ipport_v4,
opt.ssh_cmd,
remotename,
opt.python,
opt.latency_control,
opt.dns,
method,
sh,
opt.auto_nets,
parse_subnets(includes),
parse_subnets(excludes),
opt.syslog, opt.daemon, opt.pidfile)
opt.ssh_cmd,
remotename,
opt.python,
opt.latency_control,
opt.dns,
method,
sh,
opt.auto_nets,
parse_subnets(includes),
parse_subnets(excludes),
opt.syslog, opt.daemon, opt.pidfile)
if return_code == 0:
log('Normal exit code, exiting...')
else:
log('Abnormal exit code detected, failing...' % return_code)
sys.exit(return_code)
log('Normal exit code, exiting...')
else:
log('Abnormal exit code detected, failing...' % return_code)
sys.exit(return_code)
except Fatal, e:
log('fatal: %s\n' % e)

View File

@ -1,9 +1,16 @@
"""Command-line options parser.
With the help of an options spec string, easily parse command-line options.
"""
import sys, os, textwrap, getopt, re, struct
import sys
import os
import textwrap
import getopt
import re
import struct
class OptDict:
def __init__(self):
self._opts = {}
@ -46,7 +53,8 @@ def _atoi(v):
def _remove_negative_kv(k, v):
if k.startswith('no-') or k.startswith('no_'):
return k[3:], not v
return k,v
return k, v
def _remove_negative_k(k):
return _remove_negative_kv(k, None)[0]
@ -55,15 +63,17 @@ def _remove_negative_k(k):
def _tty_width():
s = struct.pack("HHHH", 0, 0, 0, 0)
try:
import fcntl, termios
import fcntl
import termios
s = fcntl.ioctl(sys.stderr.fileno(), termios.TIOCGWINSZ, s)
except (IOError, ImportError):
return _atoi(os.environ.get('WIDTH')) or 70
(ysize,xsize,ypix,xpix) = struct.unpack('HHHH', s)
(ysize, xsize, ypix, xpix) = struct.unpack('HHHH', s)
return xsize or 70
class Options:
"""Option parser.
When constructed, two strings are mandatory. The first one is the command
name showed before error messages. The second one is a string called an
@ -76,6 +86,7 @@ class Options:
By default, the parser function is getopt.gnu_getopt, and the abort
behaviour is to exit the program.
"""
def __init__(self, optspec, optfunc=getopt.gnu_getopt,
onabort=_default_onabort):
self.optspec = optspec
@ -95,7 +106,8 @@ class Options:
first_syn = True
while lines:
l = lines.pop()
if l == '--': break
if l == '--':
break
out.append('%s: %s\n' % (first_syn and 'usage' or ' or', l))
first_syn = False
out.append('\n')
@ -122,7 +134,7 @@ class Options:
flagl = flags.split(',')
flagl_nice = []
for _f in flagl:
f,dvi = _remove_negative_kv(_f, _intify(defval))
f, dvi = _remove_negative_kv(_f, _intify(defval))
self._aliases[f] = _remove_negative_k(flagl[0])
self._hasparms[f] = has_parm
self._defaults[f] = dvi
@ -140,8 +152,8 @@ class Options:
flags_nice += ' ...'
prefix = ' %-20s ' % flags_nice
argtext = '\n'.join(textwrap.wrap(extra, width=_tty_width(),
initial_indent=prefix,
subsequent_indent=' '*28))
initial_indent=prefix,
subsequent_indent=' ' * 28))
out.append(argtext + '\n')
last_was_option = True
else:
@ -170,17 +182,18 @@ class Options:
and "extra" is a list of positional arguments.
"""
try:
(flags,extra) = self.optfunc(args, self._shortopts, self._longopts)
(flags, extra) = self.optfunc(
args, self._shortopts, self._longopts)
except getopt.GetoptError, e:
self.fatal(e)
opt = OptDict()
for k,v in self._defaults.iteritems():
for k, v in self._defaults.iteritems():
k = self._aliases[k]
opt[k] = v
for (k,v) in flags:
for (k, v) in flags:
k = k.lstrip('-')
if k in ('h', '?', 'help'):
self.usage()
@ -195,6 +208,6 @@ class Options:
else:
v = _intify(v)
opt[k] = v
for (f1,f2) in self._aliases.iteritems():
for (f1, f2) in self._aliases.iteritems():
opt[f1] = opt._opts.get(f2)
return (opt,flags,extra)
return (opt, flags, extra)

View File

@ -1,9 +1,22 @@
import re, struct, socket, select, traceback, time
import re
import struct
import socket
import traceback
import time
import sys
import os
if not globals().get('skip_imports'):
import ssnet, helpers, hostwatch
import ssnet
import helpers
import hostwatch
import compat.ssubprocess as ssubprocess
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
from ssnet import Handler, Proxy, Mux, MuxWrapper
from helpers import log, debug1, debug2, debug3, Fatal, \
resolvconf_random_nameserver
if not globals().get('latency_control'):
latency_control = None
def _ipmatch(ipstr):
@ -14,13 +27,13 @@ def _ipmatch(ipstr):
g = m.groups()
ips = g[0]
width = int(g[4] or 32)
if g[1] == None:
if g[1] is None:
ips += '.0.0.0'
width = min(width, 8)
elif g[2] == None:
elif g[2] is None:
ips += '.0.0'
width = min(width, 16)
elif g[3] == None:
elif g[3] is None:
ips += '.0'
width = min(width, 24)
return (struct.unpack('!I', socket.inet_aton(ips))[0], width)
@ -38,12 +51,12 @@ def _maskbits(netmask):
return 32
for i in range(32):
if netmask[0] & _shl(1, i):
return 32-i
return 32 - i
return 0
def _shl(n, bits):
return n * int(2**bits)
return n * int(2 ** bits)
def _list_routes():
@ -58,8 +71,9 @@ def _list_routes():
maskw = _ipmatch(cols[2]) # linux only
mask = _maskbits(maskw) # returns 32 if maskw is null
width = min(ipw[1], mask)
ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width)
routes.append((socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width))
ip = ipw[0] & _shl(_shl(1, width) - 1, 32 - width)
routes.append(
(socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width))
rv = p.wait()
if rv != 0:
log('WARNING: %r returned %d\n' % (argv, rv))
@ -68,9 +82,9 @@ def _list_routes():
def list_routes():
for (family, ip,width) in _list_routes():
for (family, ip, width) in _list_routes():
if not ip.startswith('0.') and not ip.startswith('127.'):
yield (family, ip,width)
yield (family, ip, width)
def _exc_dump():
@ -79,7 +93,7 @@ def _exc_dump():
def start_hostwatch(seed_hosts):
s1,s2 = socket.socketpair()
s1, s2 = socket.socketpair()
pid = os.fork()
if not pid:
# child
@ -91,27 +105,29 @@ def start_hostwatch(seed_hosts):
os.dup2(s1.fileno(), 0)
s1.close()
rv = hostwatch.hw_main(seed_hosts) or 0
except Exception, e:
except Exception:
log('%s\n' % _exc_dump())
rv = 98
finally:
os._exit(rv)
s1.close()
return pid,s2
return pid, s2
class Hostwatch:
def __init__(self):
self.pid = 0
self.sock = None
class DnsProxy(Handler):
def __init__(self, mux, chan, request):
# FIXME! IPv4 specific
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
Handler.__init__(self, [sock])
self.timeout = time.time()+30
self.timeout = time.time() + 30
self.mux = mux
self.chan = chan
self.tries = 0
@ -164,10 +180,11 @@ class DnsProxy(Handler):
class UdpProxy(Handler):
def __init__(self, mux, chan, family):
sock = socket.socket(family, socket.SOCK_DGRAM)
Handler.__init__(self, [sock])
self.timeout = time.time()+30
self.timeout = time.time() + 30
self.mux = mux
self.chan = chan
self.sock = sock
@ -177,33 +194,35 @@ class UdpProxy(Handler):
def send(self, dstip, data):
debug2('UDP: sending to %r port %d\n' % dstip)
try:
self.sock.sendto(data,dstip)
self.sock.sendto(data, dstip)
except socket.error, e:
log('UDP send to %r port %d: %s\n' % (dstip[0], dstip[1], e))
return
def callback(self):
try:
data,peer = self.sock.recvfrom(4096)
data, peer = self.sock.recvfrom(4096)
except socket.error, e:
log('UDP recv from %r port %d: %s\n' % (peer[0], peer[1], e))
return
debug2('UDP response: %d bytes\n' % len(data))
hdr = "%s,%r,"%(peer[0], peer[1])
self.mux.send(self.chan, ssnet.CMD_UDP_DATA, hdr+data)
hdr = "%s,%r," % (peer[0], peer[1])
self.mux.send(self.chan, ssnet.CMD_UDP_DATA, hdr + data)
def main():
if helpers.verbose >= 1:
helpers.logprefix = ' s: '
else:
helpers.logprefix = 'server: '
assert latency_control is not None
debug1('latency control setting = %r\n' % latency_control)
routes = list(list_routes())
debug1('available routes:\n')
for r in routes:
debug1(' %d/%s/%d\n' % r)
# synchronization header
sys.stdout.write('\0\0SSHUTTLE0001')
sys.stdout.flush()
@ -221,7 +240,7 @@ def main():
hw = Hostwatch()
hw.leftover = ''
def hostwatch_ready():
assert(hw.pid)
content = hw.sock.recv(4096)
@ -239,13 +258,13 @@ def main():
def got_host_req(data):
if not hw.pid:
(hw.pid,hw.sock) = start_hostwatch(data.strip().split())
handlers.append(Handler(socks = [hw.sock],
callback = hostwatch_ready))
(hw.pid, hw.sock) = start_hostwatch(data.strip().split())
handlers.append(Handler(socks=[hw.sock],
callback=hostwatch_ready))
mux.got_host_req = got_host_req
def new_channel(channel, data):
(family,dstip,dstport) = data.split(',', 2)
(family, dstip, dstport) = data.split(',', 2)
family = int(family)
dstport = int(dstport)
outwrap = ssnet.connect_dst(family, dstip, dstport)
@ -253,6 +272,7 @@ def main():
mux.new_channel = new_channel
dnshandlers = {}
def dns_req(channel, data):
debug2('Incoming DNS request channel=%d.\n' % channel)
h = DnsProxy(mux, channel, data)
@ -261,14 +281,15 @@ def main():
mux.got_dns_req = dns_req
udphandlers = {}
def udp_req(channel, cmd, data):
debug2('Incoming UDP request channel=%d, cmd=%d\n' % (channel,cmd))
debug2('Incoming UDP request channel=%d, cmd=%d\n' % (channel, cmd))
if cmd == ssnet.CMD_UDP_DATA:
(dstip,dstport,data) = data.split(",",2)
(dstip, dstport, data) = data.split(",", 2)
dstport = int(dstport)
debug2('is incoming UDP data. %r %d.\n' % (dstip,dstport))
debug2('is incoming UDP data. %r %d.\n' % (dstip, dstport))
h = udphandlers[channel]
h.send((dstip,dstport),data)
h.send((dstip, dstport), data)
elif cmd == ssnet.CMD_UDP_CLOSE:
debug2('is incoming UDP close\n')
h = udphandlers[channel]
@ -280,21 +301,21 @@ def main():
family = int(data)
mux.channels[channel] = lambda cmd, data: udp_req(channel, cmd, data)
if channel in udphandlers:
raise Fatal('UDP connection channel %d already open'%channel)
raise Fatal('UDP connection channel %d already open' % channel)
else:
h = UdpProxy(mux, channel, family)
handlers.append(h)
udphandlers[channel] = h
mux.got_udp_open = udp_open
while mux.ok:
if hw.pid:
assert(hw.pid > 0)
(rpid, rv) = os.waitpid(hw.pid, os.WNOHANG)
if rpid:
raise Fatal('hostwatch exited unexpectedly: code 0x%04x\n' % rv)
raise Fatal(
'hostwatch exited unexpectedly: code 0x%04x\n' % rv)
ssnet.runonce(handlers, mux)
if latency_control:
mux.check_fullness()
@ -302,12 +323,12 @@ def main():
if dnshandlers:
now = time.time()
for channel,h in dnshandlers.items():
for channel, h in dnshandlers.items():
if h.timeout < now or not h.ok:
debug3('expiring dnsreqs channel=%d\n' % channel)
del dnshandlers[channel]
h.ok = False
for channel,h in udphandlers.items():
for channel, h in udphandlers.items():
if not h.ok:
debug3('expiring UDP channel=%d\n' % channel)
del udphandlers[channel]

View File

@ -1,7 +1,11 @@
import sys, os, re, socket, zlib
import sys
import os
import re
import socket
import zlib
import compat.ssubprocess as ssubprocess
import helpers
from helpers import *
from helpers import debug2
def readfile(name):
@ -15,7 +19,7 @@ def readfile(name):
def empackage(z, filename, data=None):
(path,basename) = os.path.split(filename)
(path, basename) = os.path.split(filename)
if not data:
data = readfile(filename)
content = z.compress(data)
@ -24,7 +28,6 @@ def empackage(z, filename, data=None):
def connect(ssh_cmd, rhostport, python, stderr, options):
main_exe = sys.argv[0]
portl = []
if (rhostport or '').count(':') > 1:
@ -35,9 +38,11 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
result[1] = result[1].strip(':')
if result[1] is not '':
portl = ['-p', str(int(result[1]))]
else: # can't disambiguate IPv6 colons and a port number. pass the hostname through.
# can't disambiguate IPv6 colons and a port number. pass the hostname
# through.
else:
rhost = rhostport
else: # IPv4
else: # IPv4
l = (rhostport or '').split(':', 1)
rhost = l[0]
if len(l) > 1:
@ -48,7 +53,7 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
z = zlib.compressobj(1)
content = readfile('assembler.py')
optdata = ''.join("%s=%r\n" % (k,v) for (k,v) in options.items())
optdata = ''.join("%s=%r\n" % (k, v) for (k, v) in options.items())
content2 = (empackage(z, 'cmdline_options.py', optdata) +
empackage(z, 'helpers.py') +
empackage(z, 'compat/ssubprocess.py') +
@ -56,7 +61,7 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
empackage(z, 'hostwatch.py') +
empackage(z, 'server.py') +
"\n")
pyscript = r"""
import sys;
skip_imports=1;
@ -65,7 +70,6 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
""" % (helpers.verbose or 0, len(content))
pyscript = re.sub(r'\s+', ' ', pyscript.strip())
if not rhost:
# ignore the --python argument when running locally; we already know
# which python version works.
@ -80,14 +84,15 @@ def connect(ssh_cmd, rhostport, python, stderr, options):
else:
pycmd = ("P=python2; $P -V 2>/dev/null || P=python; "
"exec \"$P\" -c '%s'") % pyscript
argv = (sshl +
portl +
argv = (sshl +
portl +
[rhost, '--', pycmd])
(s1,s2) = socket.socketpair()
(s1, s2) = socket.socketpair()
def setup():
# runs in the child process
s2.close()
s1a,s1b = os.dup(s1.fileno()), os.dup(s1.fileno())
s1a, s1b = os.dup(s1.fileno()), os.dup(s1.fileno())
s1.close()
debug2('executing: %r\n' % argv)
p = ssubprocess.Popen(argv, stdin=s1a, stdout=s1b, preexec_fn=setup,

View File

@ -1,9 +1,13 @@
import struct, socket, errno, select
import struct
import socket
import errno
import select
import os
if not globals().get('skip_imports'):
from helpers import *
from helpers import log, debug1, debug2, debug3, Fatal
MAX_CHANNEL = 65535
# these don't exist in the socket module in python 2.3!
SHUT_RD = 0
SHUT_WR = 1
@ -92,7 +96,10 @@ def _try_peername(sock):
_swcount = 0
class SockWrapper:
def __init__(self, rsock, wsock, connect_to=None, peername=None):
global _swcount
_swcount += 1
@ -177,8 +184,8 @@ class SockWrapper:
if not self.shut_read:
debug2('%r: done reading\n' % self)
self.shut_read = True
#self.rsock.shutdown(SHUT_RD) # doesn't do anything anyway
# self.rsock.shutdown(SHUT_RD) # doesn't do anything anyway
def nowrite(self):
if not self.shut_write:
debug2('%r: done writing\n' % self)
@ -206,7 +213,7 @@ class SockWrapper:
# unexpected error... stream is dead
self.seterr('uwrite: %s' % e)
return 0
def write(self, buf):
assert(buf)
return self.uwrite(buf)
@ -221,7 +228,7 @@ class SockWrapper:
return _nb_clean(os.read, self.rsock.fileno(), 65536)
except OSError, e:
self.seterr('uread: %s' % e)
return '' # unexpected error... we'll call it EOF
return '' # unexpected error... we'll call it EOF
def fill(self):
if self.buf:
@ -243,7 +250,8 @@ class SockWrapper:
class Handler:
def __init__(self, socks = None, callback = None):
def __init__(self, socks=None, callback=None):
self.ok = True
self.socks = socks or []
if callback:
@ -255,7 +263,7 @@ class Handler:
def callback(self):
log('--no callback defined-- %r\n' % self)
(r,w,x) = select.select(self.socks, [], [], 0)
(r, w, x) = select.select(self.socks, [], [], 0)
for s in r:
v = s.recv(4096)
if not v:
@ -265,6 +273,7 @@ class Handler:
class Proxy(Handler):
def __init__(self, wrap1, wrap2):
Handler.__init__(self, [wrap1.rsock, wrap1.wsock,
wrap2.rsock, wrap2.wsock])
@ -272,9 +281,11 @@ class Proxy(Handler):
self.wrap2 = wrap2
def pre_select(self, r, w, x):
if self.wrap1.shut_write: self.wrap2.noread()
if self.wrap2.shut_write: self.wrap1.noread()
if self.wrap1.shut_write:
self.wrap2.noread()
if self.wrap2.shut_write:
self.wrap1.noread()
if self.wrap1.connect_to:
_add(w, self.wrap1.rsock)
elif self.wrap1.buf:
@ -305,13 +316,14 @@ class Proxy(Handler):
self.wrap2.buf = []
self.wrap2.noread()
if (self.wrap1.shut_read and self.wrap2.shut_read and
not self.wrap1.buf and not self.wrap2.buf):
not self.wrap1.buf and not self.wrap2.buf):
self.ok = False
self.wrap1.nowrite()
self.wrap2.nowrite()
class Mux(Handler):
def __init__(self, rsock, wsock):
Handler.__init__(self, [rsock, wsock])
self.rsock = rsock
@ -342,31 +354,31 @@ class Mux(Handler):
for b in self.outbuf:
total += len(b)
return total
def check_fullness(self):
if self.fullness > 32768:
if not self.too_full:
self.send(0, CMD_PING, 'rttest')
self.too_full = True
#ob = []
#for b in self.outbuf:
# for b in self.outbuf:
# (s1,s2,c) = struct.unpack('!ccH', b[:4])
# ob.append(c)
#log('outbuf: %d %r\n' % (self.amount_queued(), ob))
def send(self, channel, cmd, data):
data = str(data)
assert(len(data) <= 65535)
p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data
self.outbuf.append(p)
debug2(' > channel=%d cmd=%s len=%d (fullness=%d)\n'
% (channel, cmd_to_name.get(cmd,hex(cmd)),
% (channel, cmd_to_name.get(cmd, hex(cmd)),
len(data), self.fullness))
self.fullness += len(data)
def got_packet(self, channel, cmd, data):
debug2('< channel=%d cmd=%s len=%d\n'
% (channel, cmd_to_name.get(cmd,hex(cmd)), len(data)))
debug2('< channel=%d cmd=%s len=%d\n'
% (channel, cmd_to_name.get(cmd, hex(cmd)), len(data)))
if cmd == CMD_PING:
self.send(0, CMD_PONG, data)
elif cmd == CMD_PONG:
@ -405,8 +417,8 @@ class Mux(Handler):
else:
callback = self.channels.get(channel)
if not callback:
log('warning: closed channel %d got cmd=%s len=%d\n'
% (channel, cmd_to_name.get(cmd,hex(cmd)), len(data)))
log('warning: closed channel %d got cmd=%s len=%d\n'
% (channel, cmd_to_name.get(cmd, hex(cmd)), len(data)))
else:
callback(cmd, data)
@ -427,18 +439,18 @@ class Mux(Handler):
except OSError, e:
raise Fatal('other end: %r' % e)
#log('<<< %r\n' % b)
if b == '': # EOF
if b == '': # EOF
self.ok = False
if b:
self.inbuf += b
def handle(self):
self.fill()
#log('inbuf is: (%d,%d) %r\n'
# log('inbuf is: (%d,%d) %r\n'
# % (self.want, len(self.inbuf), self.inbuf))
while 1:
if len(self.inbuf) >= (self.want or HDR_LEN):
(s1,s2,channel,cmd,datalen) = \
(s1, s2, channel, cmd, datalen) = \
struct.unpack('!ccHHH', self.inbuf[:HDR_LEN])
assert(s1 == 'S')
assert(s2 == 'S')
@ -457,7 +469,7 @@ class Mux(Handler):
_add(w, self.wsock)
def callback(self):
(r,w,x) = select.select([self.rsock], [self.wsock], [], 0)
(r, w, x) = select.select([self.rsock], [self.wsock], [], 0)
if self.rsock in r:
self.handle()
if self.outbuf and self.wsock in w:
@ -465,6 +477,7 @@ class Mux(Handler):
class MuxWrapper(SockWrapper):
def __init__(self, mux, channel):
SockWrapper.__init__(self, mux.rsock, mux.wsock)
self.mux = mux
@ -478,7 +491,7 @@ class MuxWrapper(SockWrapper):
SockWrapper.__del__(self)
def __repr__(self):
return 'SW%r:Mux#%d' % (self.peername,self.channel)
return 'SW%r:Mux#%d' % (self.peername, self.channel)
def noread(self):
if not self.shut_read:
@ -511,7 +524,7 @@ class MuxWrapper(SockWrapper):
def uread(self):
if self.shut_read:
return '' # EOF
return '' # EOF
else:
return None # no data available right now
@ -523,7 +536,7 @@ class MuxWrapper(SockWrapper):
elif cmd == CMD_TCP_DATA:
self.buf.append(data)
else:
raise Exception('unknown command %d (%d bytes)'
raise Exception('unknown command %d (%d bytes)'
% (cmd, len(data)))
@ -532,8 +545,8 @@ def connect_dst(family, ip, port):
outsock = socket.socket(family)
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
return SockWrapper(outsock, outsock,
connect_to = (ip,port),
peername = '%s:%d' % (ip,port))
connect_to=(ip, port),
peername = '%s:%d' % (ip, port))
def runonce(handlers, mux):
@ -545,14 +558,14 @@ def runonce(handlers, mux):
handlers.remove(h)
for s in handlers:
s.pre_select(r,w,x)
debug2('Waiting: %d r=%r w=%r x=%r (fullness=%d/%d)\n'
% (len(handlers), _fds(r), _fds(w), _fds(x),
s.pre_select(r, w, x)
debug2('Waiting: %d r=%r w=%r x=%r (fullness=%d/%d)\n'
% (len(handlers), _fds(r), _fds(w), _fds(x),
mux.fullness, mux.too_full))
(r,w,x) = select.select(r,w,x)
debug2(' Ready: %d r=%r w=%r x=%r\n'
% (len(handlers), _fds(r), _fds(w), _fds(x)))
ready = r+w+x
(r, w, x) = select.select(r, w, x)
debug2(' Ready: %d r=%r w=%r x=%r\n'
% (len(handlers), _fds(r), _fds(w), _fds(x)))
ready = r + w + x
did = {}
for h in handlers:
for s in h.socks:

View File

@ -1,8 +1,11 @@
import sys, os
import sys
import os
from compat import ssubprocess
_p = None
def start_syslog():
global _p
_p = ssubprocess.Popen(['logger',

View File

@ -1,5 +1,8 @@
#!/usr/bin/env python
import sys, os, socket, select, struct, time
import socket
import select
import struct
import time
listener = socket.socket()
listener.bind(('127.0.0.1', 0))
@ -23,7 +26,7 @@ while 1:
if count >= 16384:
count = 1
print 'cli CREATING %d' % count
b = struct.pack('I', count) + 'x'*count
b = struct.pack('I', count) + 'x' * count
remain[c] = count
print 'cli >> %r' % len(b)
c.send(b)
@ -32,13 +35,13 @@ while 1:
r = [listener]
time.sleep(0.1)
else:
r = [listener]+servers+clients
r = [listener] + servers + clients
print 'select(%d)' % len(r)
r,w,x = select.select(r, [], [], 5)
r, w, x = select.select(r, [], [], 5)
assert(r)
for i in r:
if i == listener:
s,addr = listener.accept()
s, addr = listener.accept()
servers.append(s)
elif i in servers:
b = i.recv(4096)
@ -47,7 +50,7 @@ while 1:
assert(len(b) >= 4)
want = struct.unpack('I', b[:4])[0]
b = b[4:]
#i.send('y'*want)
# i.send('y'*want)
else:
want = remain[i]
if want < len(b):
@ -64,7 +67,7 @@ while 1:
del remain[i]
else:
print 'srv >> %r' % len(b)
i.send('y'*len(b))
i.send('y' * len(b))
if not want:
i.shutdown(socket.SHUT_WR)
elif i in clients:

View File

@ -1,4 +1,6 @@
import sys, os, re, subprocess
import re
import subprocess
def askpass(prompt):
prompt = prompt.replace('"', "'")
@ -6,7 +8,7 @@ def askpass(prompt):
if 'yes/no' in prompt:
return "yes"
script="""
script = """
tell application "Finder"
activate
display dialog "%s" \

View File

@ -1,6 +1,11 @@
import sys, os, pty
import sys
import os
import pty
from AppKit import *
import my, models, askpass
import my
import models
import askpass
def sshuttle_args(host, auto_nets, auto_hosts, dns, nets, debug,
no_latency_control):
@ -21,21 +26,25 @@ def sshuttle_args(host, auto_nets, auto_hosts, dns, nets, debug,
class _Callback(NSObject):
def initWithFunc_(self, func):
self = super(_Callback, self).init()
self.func = func
return self
def func_(self, obj):
return self.func(obj)
class Callback:
def __init__(self, func):
self.obj = _Callback.alloc().initWithFunc_(func)
self.sel = self.obj.func_
class Runner:
def __init__(self, argv, logfunc, promptfunc, serverobj):
print 'in __init__'
self.id = argv
@ -49,7 +58,7 @@ class Runner:
self.logfunc('\nConnecting to %s.\n' % self.serverobj.host())
print 'will run: %r' % argv
self.serverobj.setConnected_(False)
pid,fd = pty.fork()
pid, fd = pty.fork()
if pid == 0:
# child
try:
@ -62,19 +71,20 @@ class Runner:
# parent
self.pid = pid
self.file = NSFileHandle.alloc()\
.initWithFileDescriptor_closeOnDealloc_(fd, True)
.initWithFileDescriptor_closeOnDealloc_(fd, True)
self.cb = Callback(self.gotdata)
NSNotificationCenter.defaultCenter()\
.addObserver_selector_name_object_(self.cb.obj, self.cb.sel,
NSFileHandleDataAvailableNotification, self.file)
.addObserver_selector_name_object_(
self.cb.obj, self.cb.sel,
NSFileHandleDataAvailableNotification, self.file)
self.file.waitForDataInBackgroundAndNotify()
def __del__(self):
self.wait()
def _try_wait(self, options):
if self.rv == None and self.pid > 0:
pid,code = os.waitpid(self.pid, options)
if self.rv is None and self.pid > 0:
pid, code = os.waitpid(self.pid, options)
if pid == self.pid:
if os.WIFEXITED(code):
self.rv = os.WEXITSTATUS(code)
@ -88,14 +98,14 @@ class Runner:
def wait(self):
return self._try_wait(0)
def poll(self):
return self._try_wait(os.WNOHANG)
def kill(self):
assert(self.pid > 0)
print 'killing: pid=%r rv=%r' % (self.pid, self.rv)
if self.rv == None:
if self.rv is None:
self.logfunc('Disconnecting from %s.\n' % self.serverobj.host())
os.kill(self.pid, 15)
self.wait()
@ -118,12 +128,13 @@ class Runner:
self.file.writeData_(my.Data(resp + '\n'))
self.file.waitForDataInBackgroundAndNotify()
self.poll()
#print 'gotdata done!'
# print 'gotdata done!'
class SshuttleApp(NSObject):
def initialize(self):
d = my.PList('UserDefaults')
d = my.PList('UserDefaults')
my.Defaults().registerDefaults_(d)
@ -137,7 +148,7 @@ class SshuttleController(NSObject):
serversController = objc.IBOutlet()
logField = objc.IBOutlet()
latencyControlField = objc.IBOutlet()
servers = []
conns = {}
@ -145,12 +156,14 @@ class SshuttleController(NSObject):
host = server.host()
print 'connecting %r' % host
self.fill_menu()
def logfunc(msg):
print 'log! (%d bytes)' % len(msg)
self.logField.textStorage()\
.appendAttributedString_(NSAttributedString.alloc()\
.appendAttributedString_(NSAttributedString.alloc()
.initWithString_(msg))
self.logField.didChangeText()
def promptfunc(prompt):
print 'prompt! %r' % prompt
return askpass.askpass(prompt)
@ -164,12 +177,12 @@ class SshuttleController(NSObject):
manual_nets = []
noLatencyControl = (server.latencyControl() != models.LAT_INTERACTIVE)
conn = Runner(sshuttle_args(host,
auto_nets = nets_mode == models.NET_AUTO,
auto_hosts = server.autoHosts(),
dns = server.useDns(),
nets = manual_nets,
debug = self.debugField.state(),
no_latency_control = noLatencyControl),
auto_nets=nets_mode == models.NET_AUTO,
auto_hosts=server.autoHosts(),
dns=server.useDns(),
nets=manual_nets,
debug=self.debugField.state(),
no_latency_control=noLatencyControl),
logfunc=logfunc, promptfunc=promptfunc,
serverobj=server)
self.conns[host] = conn
@ -182,8 +195,8 @@ class SshuttleController(NSObject):
conn.kill()
self.fill_menu()
self.logField.textStorage().setAttributedString_(
NSAttributedString.alloc().initWithString_(''))
NSAttributedString.alloc().initWithString_(''))
@objc.IBAction
def cmd_connect(self, sender):
server = sender.representedObject()
@ -213,6 +226,7 @@ class SshuttleController(NSObject):
it.setRepresentedObject_(obj)
it.setTarget_(self)
it.setAction_(func)
def addnote(name):
additem(name, None, None)
@ -271,8 +285,9 @@ class SshuttleController(NSObject):
sl = []
for s in l:
host = s.get('host', None)
if not host: continue
if not host:
continue
nets = s.get('nets', [])
nl = []
for n in nets:
@ -282,7 +297,7 @@ class SshuttleController(NSObject):
net.setSubnet_(subnet)
net.setWidth_(width)
nl.append(net)
autoNets = s.get('autoNets', models.NET_AUTO)
autoHosts = s.get('autoHosts', True)
useDns = s.get('useDns', autoNets == models.NET_ALL)
@ -302,11 +317,13 @@ class SshuttleController(NSObject):
l = []
for s in self.servers:
host = s.host()
if not host: continue
if not host:
continue
nets = []
for n in s.nets():
subnet = n.subnet()
if not subnet: continue
if not subnet:
continue
nets.append((subnet, n.width()))
d = dict(host=s.host(),
nets=nets,
@ -352,9 +369,9 @@ class SshuttleController(NSObject):
statusitem.setHighlightMode_(True)
statusitem.setMenu_(self.menu)
self.fill_menu()
models.configchange_callback = my.DelayedCallback(self.save_servers)
def sc(server):
if server.wantConnect():
self._connect(server)

View File

@ -35,24 +35,29 @@ def _validate_width(v):
class SshuttleNet(NSObject):
def subnet(self):
return getattr(self, '_k_subnet', None)
def setSubnet_(self, v):
self._k_subnet = v
config_changed()
@objc_validator
def validateSubnet_error_(self, value, error):
#print 'validateSubnet!'
# print 'validateSubnet!'
return True, _validate_ip(value), error
def width(self):
return getattr(self, '_k_width', 24)
def setWidth_(self, v):
self._k_width = v
config_changed()
@objc_validator
def validateWidth_error_(self, value, error):
#print 'validateWidth!'
# print 'validateWidth!'
return True, _validate_width(value), error
NET_ALL = 0
@ -62,30 +67,37 @@ NET_MANUAL = 2
LAT_BANDWIDTH = 0
LAT_INTERACTIVE = 1
class SshuttleServer(NSObject):
def init(self):
self = super(SshuttleServer, self).init()
config_changed()
return self
def wantConnect(self):
return getattr(self, '_k_wantconnect', False)
def setWantConnect_(self, v):
self._k_wantconnect = v
self.setError_(None)
config_changed()
if setconnect_callback: setconnect_callback(self)
if setconnect_callback:
setconnect_callback(self)
def connected(self):
return getattr(self, '_k_connected', False)
def setConnected_(self, v):
print 'setConnected of %r to %r' % (self, v)
self._k_connected = v
if v: self.setError_(None) # connected ok, so no error
if v:
self.setError_(None) # connected ok, so no error
config_changed()
def error(self):
return getattr(self, '_k_error', None)
def setError_(self, v):
self._k_error = v
config_changed()
@ -107,40 +119,47 @@ class SshuttleServer(NSObject):
suffix = " (all traffic)"
elif an == NET_MANUAL:
n = self.nets()
suffix = ' (%d subnet%s)' % (len(n), len(n)!=1 and 's' or '')
suffix = ' (%d subnet%s)' % (len(n), len(n) != 1 and 's' or '')
return self.host() + suffix
def setTitle_(self, v):
# title is always auto-generated
config_changed()
def host(self):
return getattr(self, '_k_host', None)
def setHost_(self, v):
self._k_host = v
self.setTitle_(None)
config_changed()
@objc_validator
def validateHost_error_(self, value, error):
#print 'validatehost! %r %r %r' % (self, value, error)
# print 'validatehost! %r %r %r' % (self, value, error)
while value.startswith('-'):
value = value[1:]
return True, value, error
def nets(self):
return getattr(self, '_k_nets', [])
def setNets_(self, v):
self._k_nets = v
self.setTitle_(None)
config_changed()
def netsHidden(self):
#print 'checking netsHidden'
# print 'checking netsHidden'
return self.autoNets() != NET_MANUAL
def setNetsHidden_(self, v):
config_changed()
#print 'setting netsHidden to %r' % v
# print 'setting netsHidden to %r' % v
def autoNets(self):
return getattr(self, '_k_autoNets', NET_AUTO)
def setAutoNets_(self, v):
self._k_autoNets = v
self.setNetsHidden_(-1)
@ -150,18 +169,21 @@ class SshuttleServer(NSObject):
def autoHosts(self):
return getattr(self, '_k_autoHosts', True)
def setAutoHosts_(self, v):
self._k_autoHosts = v
config_changed()
def useDns(self):
return getattr(self, '_k_useDns', False)
def setUseDns_(self, v):
self._k_useDns = v
config_changed()
def latencyControl(self):
return getattr(self, '_k_latencyControl', LAT_INTERACTIVE)
def setLatencyControl_(self, v):
self._k_latencyControl = v
config_changed()

View File

@ -1,4 +1,4 @@
import sys, os
import os
from AppKit import *
import PyObjCTools.AppHelper
@ -44,11 +44,13 @@ def Defaults():
#
def DelayedCallback(func, *args, **kwargs):
flag = [0]
def _go():
if flag[0]:
print 'running %r (flag=%r)' % (func, flag)
flag[0] = 0
func(*args, **kwargs)
def call():
flag[0] += 1
PyObjCTools.AppHelper.callAfter(_go)