Basic implementation of a multiplex protocol - client side only.

Currently the 'server' is just a pipe to run 'hd' (hexdump) for looking at
the client-side results.  Lame, but true.
This commit is contained in:
Avery Pennarun 2010-05-01 23:14:42 -04:00
parent 9f514d7a15
commit 5f0bfb5d9e
4 changed files with 196 additions and 36 deletions

View File

@ -1,13 +1,13 @@
import struct, socket, select, subprocess, errno
from ssnet import SockWrapper, Handler, Proxy
import ssnet, ssh
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
def original_dst(sock):
SO_ORIGINAL_DST = 80
SOCKADDR_MIN = 16
sockaddr_in = sock.getsockopt(socket.SOL_IP, SO_ORIGINAL_DST, SOCKADDR_MIN)
(proto, port, a,b,c,d) = struct.unpack('!hhBBBB', sockaddr_in[:8])
(proto, port, a,b,c,d) = struct.unpack('!HHBBBB', sockaddr_in[:8])
assert(socket.htons(proto) == socket.AF_INET)
ip = '%d.%d.%d.%d' % (a,b,c,d)
return (ip,port)
@ -21,8 +21,17 @@ def iptables_setup(port, subnets):
raise Exception('%r returned %d' % (argv, rv))
def _main(listener, remotename, subnets):
def _main(listener, listenport, use_server, remotename, subnets):
handlers = []
if use_server:
(serverproc, serversock) = ssh.connect(remotename)
mux = Mux(serversock)
handlers.append(mux)
# we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection!
iptables_setup(listenport, subnets)
def onaccept():
sock,srcip = listener.accept()
dstip = original_dst(sock)
@ -31,10 +40,16 @@ def _main(listener, remotename, subnets):
log("-- ignored: that's my address!\n")
sock.close()
return
if use_server:
chan = mux.next_channel()
mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip)
outwrap = MuxWrapper(mux, chan)
else:
outsock = socket.socket()
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
outsock.connect(dstip)
handlers.append(Proxy(SockWrapper(sock), SockWrapper(outsock)))
outwrap = SockWrapper(outsock)
handlers.append(Proxy(SockWrapper(sock), outwrap))
handlers.append(Handler([listener], onaccept))
while 1:
@ -54,7 +69,7 @@ def _main(listener, remotename, subnets):
s.callback()
def main(listenip, remotename, subnets):
def main(listenip, use_server, remotename, subnets):
log('Starting sshuttle proxy.\n')
listener = socket.socket()
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@ -81,9 +96,7 @@ def main(listenip, remotename, subnets):
listenip = listener.getsockname()
log('Listening on %r.\n' % (listenip,))
iptables_setup(listenip[1], subnets)
try:
return _main(listener, remotename, subnets)
return _main(listener, listenip[1], use_server, remotename, subnets)
finally:
iptables_setup(listenip[1], [])

10
main.py
View File

@ -1,5 +1,5 @@
#!/usr/bin/env python
import sys, re
import sys, os, re
import options, client, iptables
@ -50,6 +50,7 @@ sshuttle --server
--
l,listen= transproxy to this ip address and port number [default=0]
r,remote= ssh hostname (and optional username) of remote sshuttle server
noserver don't use a separate server process (mostly for debugging)
server [internal use only]
iptables [internal use only]
"""
@ -57,7 +58,9 @@ o = options.Options('sshuttle', optspec)
(opt, flags, extra) = o.parse(sys.argv[1:])
if opt.server:
o.fatal('server mode not implemented yet')
#o.fatal('server mode not implemented yet')
os.dup2(2,1)
os.execvp('hd', ['hd'])
sys.exit(1)
elif opt.iptables:
if len(extra) < 1:
@ -67,9 +70,10 @@ elif opt.iptables:
else:
if len(extra) < 1:
o.fatal('at least one subnet expected')
remotename = extra[0]
remotename = opt.remote
if remotename == '' or remotename == '-':
remotename = None
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
not opt.noserver,
remotename,
parse_subnets(extra)))

24
ssh.py
View File

@ -1,14 +1,13 @@
import os, re, subprocess
import sys, os, re, subprocess, socket
def connect(rhost, subcmd):
assert(not re.search(r'[^\w-]', subcmd))
def connect(rhost):
main_exe = sys.argv[0]
nicedir = os.path.split(os.path.abspath(main_exe))[0]
nicedir = re.sub(r':', "_", nicedir)
if rhost == '-':
rhost = None
if not rhost:
argv = ['sshuttle', subcmd]
argv = ['sshuttle', '--server']
else:
# WARNING: shell quoting security holes are possible here, so we
# have to be super careful. We have to use 'sh -c' because
@ -19,14 +18,21 @@ def connect(rhost, subcmd):
# stuff here.
escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir)
cmd = r"""
sh -c PATH=%s:'$PATH sshuttle %s'
""" % (escapedir, subcmd)
argv = ['ssh', rhost, '--', cmd.strip()]
sh -c PATH=%s:'$PATH sshuttle --server'
""" % (escapedir,)
argv = ['ssh', '-v', rhost, '--', cmd.strip()]
print repr(argv)
(s1,s2) = socket.socketpair()
def setup():
# runs in the child process
s2.close()
if not rhost:
os.environ['PATH'] = ':'.join([nicedir,
os.environ.get('PATH', '')])
os.setsid()
return subprocess.Popen(argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
preexec_fn=setup)
s1a,s1b = os.dup(s1.fileno()), os.dup(s1.fileno())
s1.close()
p = subprocess.Popen(argv, stdin=s1a, stdout=s1b, preexec_fn=setup)
os.close(s1a)
os.close(s1b)
return p, s2

161
ssnet.py
View File

@ -1,6 +1,17 @@
import socket, errno, select
import struct, socket, errno, select
from helpers import *
HDR_LEN = 8
CMD_EXIT = 0x4200
CMD_PING = 0x4201
CMD_PONG = 0x4202
CMD_CONNECT = 0x4203
CMD_CLOSE = 0x4204
CMD_EOF = 0x4205
CMD_DATA = 0x4206
def _nb_clean(func, *args):
try:
return func(*args)
@ -35,28 +46,32 @@ class SockWrapper:
self.shut_write = True
self.sock.shutdown(socket.SHUT_WR)
def write(self, buf):
assert(buf)
def uwrite(self, buf):
self.sock.setblocking(False)
return _nb_clean(self.sock.send, buf)
def fill(self):
def write(self, buf):
assert(buf)
return self.uwrite(buf)
def uread(self):
if self.shut_read:
return
self.sock.setblocking(False)
rb = _nb_clean(self.sock.recv, 65536)
return _nb_clean(self.sock.recv, 65536)
def fill(self):
if self.buf:
return
rb = self.uread()
if rb:
self.buf.append(rb)
if rb == '': # empty string means EOF; None means temporarily empty
self.noread()
def maybe_fill(self):
if not self.buf:
self.fill()
def copy_to(self, outwrap):
if self.buf and self.buf[0]:
wrote = outwrap.sock.send(self.buf[0])
wrote = outwrap.write(self.buf[0])
self.buf[0] = self.buf[0][wrote:]
while self.buf and not self.buf[0]:
self.buf.pop(0)
@ -102,8 +117,8 @@ class Proxy(Handler):
r.add(self.wrap2.sock)
def callback(self):
self.wrap1.maybe_fill()
self.wrap2.maybe_fill()
self.wrap1.fill()
self.wrap2.fill()
self.wrap1.copy_to(self.wrap2)
self.wrap2.copy_to(self.wrap1)
if (self.wrap1.shut_read and self.wrap2.shut_read and
@ -111,3 +126,125 @@ class Proxy(Handler):
self.ok = False
class Mux(Handler):
def __init__(self, sock):
Handler.__init__(self, [sock])
self.sock = sock
self.channels = {}
self.chani = 0
self.want = 0
self.inbuf = ''
self.outbuf = []
self.send(0, CMD_PING, 'chicken')
def next_channel(self):
# channel 0 is special, so we never allocate it
for timeout in xrange(1024):
self.chani += 1
if self.chani > 65535:
self.chani = 1
if not self.channels.get(self.chani):
return self.chani
def send(self, channel, cmd, data):
data = str(data)
assert(len(data) <= 65535)
p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data
self.outbuf.append(p)
log('Mux: send queue is %d/%d\n'
% (len(self.outbuf), sum(len(b) for b in self.outbuf)))
def got_packet(self, channel, cmd, data):
log('--got-packet--\n')
if cmd == CMD_PING:
self.mux.send(0, CMD_PONG, data)
elif cmd == CMD_EXIT:
self.ok = False
else:
c = self.channels[channel]
c.got_packet(cmd, data)
def flush(self):
self.sock.setblocking(False)
if self.outbuf and self.outbuf[0]:
wrote = _nb_clean(self.sock.send, self.outbuf[0])
if wrote:
self.outbuf[0] = self.outbuf[0][wrote:]
while self.outbuf and not self.outbuf[0]:
self.outbuf.pop()
def fill(self):
self.sock.setblocking(False)
b = _nb_clean(self.sock.recv, 32768)
if b == '': # EOF
ok = False
if b:
self.inbuf += b
def handle(self):
log('inbuf is: %r\n' % self.inbuf)
if len(self.inbuf) >= (self.want or HDR_LEN):
(s1,s2,channel,cmd,datalen) = struct.unpack('!ccHHH',
self.inbuf[:HDR_LEN])
assert(s1 == 'S')
assert(s2 == 'S')
self.want = datalen + HDR_LEN
if self.want and len(self.inbuf) >= self.want:
data = self.inbuf[HDR_LEN:self.want]
self.inbuf = self.inbuf[self.want:]
self.got_packet(channel, cmd, data)
else:
self.fill()
def pre_select(self, r, w, x):
if self.inbuf < (self.want or HDR_LEN):
r.add(self.sock)
if self.outbuf:
w.add(self.sock)
def callback(self):
(r,w,x) = select.select([self.sock], [self.sock], [], 0)
if self.sock in r:
self.handle()
if self.outbuf and self.sock in w:
self.flush()
class MuxWrapper(SockWrapper):
def __init__(self, mux, channel):
SockWrapper.__init__(self, mux.sock)
self.mux = mux
self.channel = channel
self.mux.channels[channel] = self
log('Created MuxWrapper on channel %d\n' % channel)
def noread(self):
if not self.shut_read:
self.shut_read = True
def nowrite(self):
if not self.shut_write:
self.shut_write = True
self.mux.send(self.channel, CMD_EOF, '')
def uwrite(self, buf):
self.mux.send(self.channel, CMD_DATA, buf)
return len(buf)
def uread(self):
if self.shut_read:
return '' # EOF
else:
return None # no data available right now
def got_packet(self, cmd, data):
if cmd == CMD_CLOSE:
self.noread()
self.nowrite()
elif cmd == CMD_EOF:
self.noread()
elif cmd == CMD_DATA:
self.buf.append(data)
else:
raise Exception('unknown command %d (%d bytes)'
% (cmd, len(data)))