Basic implementation of a multiplex protocol - client side only.

Currently the 'server' is just a pipe to run 'hd' (hexdump) for looking at
the client-side results.  Lame, but true.
This commit is contained in:
Avery Pennarun 2010-05-01 23:14:42 -04:00
parent 9f514d7a15
commit 5f0bfb5d9e
4 changed files with 196 additions and 36 deletions

View File

@ -1,13 +1,13 @@
import struct, socket, select, subprocess, errno import struct, socket, select, subprocess, errno
from ssnet import SockWrapper, Handler, Proxy import ssnet, ssh
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import * from helpers import *
def original_dst(sock): def original_dst(sock):
SO_ORIGINAL_DST = 80 SO_ORIGINAL_DST = 80
SOCKADDR_MIN = 16 SOCKADDR_MIN = 16
sockaddr_in = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, SOCKADDR_MIN) sockaddr_in = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, SOCKADDR_MIN)
(proto, port, a,b,c,d) = struct.unpack('!hhBBBB', sockaddr_in[:8]) (proto, port, a,b,c,d) = struct.unpack('!HHBBBB', sockaddr_in[:8])
assert(socket.htons(proto) == socket.AF_INET) assert(socket.htons(proto) == socket.AF_INET)
ip = '%d.%d.%d.%d' % (a,b,c,d) ip = '%d.%d.%d.%d' % (a,b,c,d)
return (ip,port) return (ip,port)
@ -21,8 +21,17 @@ def iptables_setup(port, subnets):
raise Exception('%r returned %d' % (argv, rv)) raise Exception('%r returned %d' % (argv, rv))
def _main(listener, remotename, subnets): def _main(listener, listenport, use_server, remotename, subnets):
handlers = [] handlers = []
if use_server:
(serverproc, serversock) = ssh.connect(remotename)
mux = Mux(serversock)
handlers.append(mux)
# we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection!
iptables_setup(listenport, subnets)
def onaccept(): def onaccept():
sock,srcip = listener.accept() sock,srcip = listener.accept()
dstip = original_dst(sock) dstip = original_dst(sock)
@ -31,10 +40,16 @@ def _main(listener, remotename, subnets):
log("-- ignored: that's my address!\n") log("-- ignored: that's my address!\n")
sock.close() sock.close()
return return
outsock = socket.socket() if use_server:
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) chan = mux.next_channel()
outsock.connect(dstip) mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip)
handlers.append(Proxy(SockWrapper(sock), SockWrapper(outsock))) outwrap = MuxWrapper(mux, chan)
else:
outsock = socket.socket()
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
outsock.connect(dstip)
outwrap = SockWrapper(outsock)
handlers.append(Proxy(SockWrapper(sock), outwrap))
handlers.append(Handler([listener], onaccept)) handlers.append(Handler([listener], onaccept))
while 1: while 1:
@ -54,7 +69,7 @@ def _main(listener, remotename, subnets):
s.callback() s.callback()
def main(listenip, remotename, subnets): def main(listenip, use_server, remotename, subnets):
log('Starting sshuttle proxy.\n') log('Starting sshuttle proxy.\n')
listener = socket.socket() listener = socket.socket()
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -81,9 +96,7 @@ def main(listenip, remotename, subnets):
listenip = listener.getsockname() listenip = listener.getsockname()
log('Listening on %r.\n' % (listenip,)) log('Listening on %r.\n' % (listenip,))
iptables_setup(listenip[1], subnets)
try: try:
return _main(listener, remotename, subnets) return _main(listener, listenip[1], use_server, remotename, subnets)
finally: finally:
iptables_setup(listenip[1], []) iptables_setup(listenip[1], [])

10
main.py
View File

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys, re import sys, os, re
import options, client, iptables import options, client, iptables
@ -50,6 +50,7 @@ sshuttle --server
-- --
l,listen= transproxy to this ip address and port number [default=0] l,listen= transproxy to this ip address and port number [default=0]
r,remote= ssh hostname (and optional username) of remote sshuttle server r,remote= ssh hostname (and optional username) of remote sshuttle server
noserver don't use a separate server process (mostly for debugging)
server [internal use only] server [internal use only]
iptables [internal use only] iptables [internal use only]
""" """
@ -57,7 +58,9 @@ o = options.Options('sshuttle', optspec)
(opt, flags, extra) = o.parse(sys.argv[1:]) (opt, flags, extra) = o.parse(sys.argv[1:])
if opt.server: if opt.server:
o.fatal('server mode not implemented yet') #o.fatal('server mode not implemented yet')
os.dup2(2,1)
os.execvp('hd', ['hd'])
sys.exit(1) sys.exit(1)
elif opt.iptables: elif opt.iptables:
if len(extra) < 1: if len(extra) < 1:
@ -67,9 +70,10 @@ elif opt.iptables:
else: else:
if len(extra) < 1: if len(extra) < 1:
o.fatal('at least one subnet expected') o.fatal('at least one subnet expected')
remotename = extra[0] remotename = opt.remote
if remotename == '' or remotename == '-': if remotename == '' or remotename == '-':
remotename = None remotename = None
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
not opt.noserver,
remotename, remotename,
parse_subnets(extra))) parse_subnets(extra)))

24
ssh.py
View File

@ -1,14 +1,13 @@
import os, re, subprocess import sys, os, re, subprocess, socket
def connect(rhost, subcmd): def connect(rhost):
assert(not re.search(r'[^\w-]', subcmd))
main_exe = sys.argv[0] main_exe = sys.argv[0]
nicedir = os.path.split(os.path.abspath(main_exe))[0] nicedir = os.path.split(os.path.abspath(main_exe))[0]
nicedir = re.sub(r':', "_", nicedir) nicedir = re.sub(r':', "_", nicedir)
if rhost == '-': if rhost == '-':
rhost = None rhost = None
if not rhost: if not rhost:
argv = ['sshuttle', subcmd] argv = ['sshuttle', '--server']
else: else:
# WARNING: shell quoting security holes are possible here, so we # WARNING: shell quoting security holes are possible here, so we
# have to be super careful. We have to use 'sh -c' because # have to be super careful. We have to use 'sh -c' because
@ -19,14 +18,21 @@ def connect(rhost, subcmd):
# stuff here. # stuff here.
escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir) escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir)
cmd = r""" cmd = r"""
sh -c PATH=%s:'$PATH sshuttle %s' sh -c PATH=%s:'$PATH sshuttle --server'
""" % (escapedir, subcmd) """ % (escapedir,)
argv = ['ssh', rhost, '--', cmd.strip()] argv = ['ssh', '-v', rhost, '--', cmd.strip()]
print repr(argv)
(s1,s2) = socket.socketpair()
def setup(): def setup():
# runs in the child process # runs in the child process
s2.close()
if not rhost: if not rhost:
os.environ['PATH'] = ':'.join([nicedir, os.environ['PATH'] = ':'.join([nicedir,
os.environ.get('PATH', '')]) os.environ.get('PATH', '')])
os.setsid() os.setsid()
return subprocess.Popen(argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE, s1a,s1b = os.dup(s1.fileno()), os.dup(s1.fileno())
preexec_fn=setup) s1.close()
p = subprocess.Popen(argv, stdin=s1a, stdout=s1b, preexec_fn=setup)
os.close(s1a)
os.close(s1b)
return p, s2

161
ssnet.py
View File

@ -1,6 +1,17 @@
import socket, errno, select import struct, socket, errno, select
from helpers import * from helpers import *
HDR_LEN = 8
CMD_EXIT = 0x4200
CMD_PING = 0x4201
CMD_PONG = 0x4202
CMD_CONNECT = 0x4203
CMD_CLOSE = 0x4204
CMD_EOF = 0x4205
CMD_DATA = 0x4206
def _nb_clean(func, *args): def _nb_clean(func, *args):
try: try:
return func(*args) return func(*args)
@ -35,28 +46,32 @@ class SockWrapper:
self.shut_write = True self.shut_write = True
self.sock.shutdown(socket.SHUT_WR) self.sock.shutdown(socket.SHUT_WR)
def write(self, buf): def uwrite(self, buf):
assert(buf)
self.sock.setblocking(False) self.sock.setblocking(False)
return _nb_clean(self.sock.send, buf) return _nb_clean(self.sock.send, buf)
def fill(self): def write(self, buf):
assert(buf)
return self.uwrite(buf)
def uread(self):
if self.shut_read: if self.shut_read:
return return
self.sock.setblocking(False) self.sock.setblocking(False)
rb = _nb_clean(self.sock.recv, 65536) return _nb_clean(self.sock.recv, 65536)
def fill(self):
if self.buf:
return
rb = self.uread()
if rb: if rb:
self.buf.append(rb) self.buf.append(rb)
if rb == '': # empty string means EOF; None means temporarily empty if rb == '': # empty string means EOF; None means temporarily empty
self.noread() self.noread()
def maybe_fill(self):
if not self.buf:
self.fill()
def copy_to(self, outwrap): def copy_to(self, outwrap):
if self.buf and self.buf[0]: if self.buf and self.buf[0]:
wrote = outwrap.sock.send(self.buf[0]) wrote = outwrap.write(self.buf[0])
self.buf[0] = self.buf[0][wrote:] self.buf[0] = self.buf[0][wrote:]
while self.buf and not self.buf[0]: while self.buf and not self.buf[0]:
self.buf.pop(0) self.buf.pop(0)
@ -102,8 +117,8 @@ class Proxy(Handler):
r.add(self.wrap2.sock) r.add(self.wrap2.sock)
def callback(self): def callback(self):
self.wrap1.maybe_fill() self.wrap1.fill()
self.wrap2.maybe_fill() self.wrap2.fill()
self.wrap1.copy_to(self.wrap2) self.wrap1.copy_to(self.wrap2)
self.wrap2.copy_to(self.wrap1) self.wrap2.copy_to(self.wrap1)
if (self.wrap1.shut_read and self.wrap2.shut_read and if (self.wrap1.shut_read and self.wrap2.shut_read and
@ -111,3 +126,125 @@ class Proxy(Handler):
self.ok = False self.ok = False
class Mux(Handler):
def __init__(self, sock):
Handler.__init__(self, [sock])
self.sock = sock
self.channels = {}
self.chani = 0
self.want = 0
self.inbuf = ''
self.outbuf = []
self.send(0, CMD_PING, 'chicken')
def next_channel(self):
# channel 0 is special, so we never allocate it
for timeout in xrange(1024):
self.chani += 1
if self.chani > 65535:
self.chani = 1
if not self.channels.get(self.chani):
return self.chani
def send(self, channel, cmd, data):
data = str(data)
assert(len(data) <= 65535)
p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data
self.outbuf.append(p)
log('Mux: send queue is %d/%d\n'
% (len(self.outbuf), sum(len(b) for b in self.outbuf)))
def got_packet(self, channel, cmd, data):
log('--got-packet--\n')
if cmd == CMD_PING:
self.mux.send(0, CMD_PONG, data)
elif cmd == CMD_EXIT:
self.ok = False
else:
c = self.channels[channel]
c.got_packet(cmd, data)
def flush(self):
self.sock.setblocking(False)
if self.outbuf and self.outbuf[0]:
wrote = _nb_clean(self.sock.send, self.outbuf[0])
if wrote:
self.outbuf[0] = self.outbuf[0][wrote:]
while self.outbuf and not self.outbuf[0]:
self.outbuf.pop()
def fill(self):
self.sock.setblocking(False)
b = _nb_clean(self.sock.recv, 32768)
if b == '': # EOF
ok = False
if b:
self.inbuf += b
def handle(self):
log('inbuf is: %r\n' % self.inbuf)
if len(self.inbuf) >= (self.want or HDR_LEN):
(s1,s2,channel,cmd,datalen) = struct.unpack('!ccHHH',
self.inbuf[:HDR_LEN])
assert(s1 == 'S')
assert(s2 == 'S')
self.want = datalen + HDR_LEN
if self.want and len(self.inbuf) >= self.want:
data = self.inbuf[HDR_LEN:self.want]
self.inbuf = self.inbuf[self.want:]
self.got_packet(channel, cmd, data)
else:
self.fill()
def pre_select(self, r, w, x):
if self.inbuf < (self.want or HDR_LEN):
r.add(self.sock)
if self.outbuf:
w.add(self.sock)
def callback(self):
(r,w,x) = select.select([self.sock], [self.sock], [], 0)
if self.sock in r:
self.handle()
if self.outbuf and self.sock in w:
self.flush()
class MuxWrapper(SockWrapper):
def __init__(self, mux, channel):
SockWrapper.__init__(self, mux.sock)
self.mux = mux
self.channel = channel
self.mux.channels[channel] = self
log('Created MuxWrapper on channel %d\n' % channel)
def noread(self):
if not self.shut_read:
self.shut_read = True
def nowrite(self):
if not self.shut_write:
self.shut_write = True
self.mux.send(self.channel, CMD_EOF, '')
def uwrite(self, buf):
self.mux.send(self.channel, CMD_DATA, buf)
return len(buf)
def uread(self):
if self.shut_read:
return '' # EOF
else:
return None # no data available right now
def got_packet(self, cmd, data):
if cmd == CMD_CLOSE:
self.noread()
self.nowrite()
elif cmd == CMD_EOF:
self.noread()
elif cmd == CMD_DATA:
self.buf.append(data)
else:
raise Exception('unknown command %d (%d bytes)'
% (cmd, len(data)))