We now have a server that works... some of the time.

There still seem to be some weird timing and/or closing-related bugs, since
I can't load the eqldata project correctly unless I use --noserver.
This commit is contained in:
Avery Pennarun 2010-05-02 00:52:06 -04:00
parent d435c41bdb
commit 915a96b0ec
5 changed files with 146 additions and 38 deletions

View File

@ -1,5 +1,5 @@
import struct, socket, select, subprocess, errno import struct, socket, select, subprocess, errno
import ssnet, ssh import ssnet, ssh, helpers
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import * from helpers import *
@ -24,10 +24,21 @@ def iptables_setup(port, subnets):
def _main(listener, listenport, use_server, remotename, subnets): def _main(listener, listenport, use_server, remotename, subnets):
handlers = [] handlers = []
if use_server: if use_server:
helpers.logprefix = 'c : '
(serverproc, serversock) = ssh.connect(remotename) (serverproc, serversock) = ssh.connect(remotename)
mux = Mux(serversock, serversock) mux = Mux(serversock, serversock)
handlers.append(mux) handlers.append(mux)
expected = 'SSHUTTLE0001'
initstring = serversock.recv(len(expected))
if initstring != expected:
raise Exception('expected server init string %r; got %r'
% (expected, initstring))
rv = serverproc.poll()
if rv:
raise Exception('server died with error code %d' % rv)
# we definitely want to do this *after* starting ssh, or we might end # we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection! # up intercepting the ssh connection!
iptables_setup(listenport, subnets) iptables_setup(listenport, subnets)
@ -45,21 +56,24 @@ def _main(listener, listenport, use_server, remotename, subnets):
mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip) mux.send(chan, ssnet.CMD_CONNECT, '%s,%s' % dstip)
outwrap = MuxWrapper(mux, chan) outwrap = MuxWrapper(mux, chan)
else: else:
outsock = socket.socket() outwrap = ssnet.connect_dst(dstip[0], dstip[1])
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
outsock.connect(dstip)
outwrap = SockWrapper(outsock, outsock)
handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
handlers.append(Handler([listener], onaccept)) handlers.append(Handler([listener], onaccept))
while 1: while 1:
if use_server:
rv = serverproc.poll()
if rv:
raise Exception('server died with error code %d' % rv)
r = set() r = set()
w = set() w = set()
x = set() x = set()
handlers = filter(lambda s: s.ok, handlers) handlers = filter(lambda s: s.ok, handlers)
for s in handlers: for s in handlers:
s.pre_select(r,w,x) s.pre_select(r,w,x)
log('\nWaiting: %d[%d,%d,%d]...\n' log('\n')
log('Waiting: %d[%d,%d,%d]...\n'
% (len(handlers), len(r), len(w), len(x))) % (len(handlers), len(r), len(w), len(x)))
(r,w,x) = select.select(r,w,x) (r,w,x) = select.select(r,w,x)
log('r=%r w=%r x=%r\n' % (r,w,x)) log('r=%r w=%r x=%r\n' % (r,w,x))

View File

@ -1,6 +1,8 @@
import sys, os import sys, os
logprefix = ''
def log(s): def log(s):
sys.stdout.flush() sys.stdout.flush()
sys.stderr.write(s) sys.stderr.write(logprefix + s)
sys.stderr.flush() sys.stderr.flush()

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys, os, re import sys, os, re
import options, client, iptables import options, client, iptables, server
# list of: # list of:
@ -58,10 +58,7 @@ o = options.Options('sshuttle', optspec)
(opt, flags, extra) = o.parse(sys.argv[1:]) (opt, flags, extra) = o.parse(sys.argv[1:])
if opt.server: if opt.server:
#o.fatal('server mode not implemented yet') sys.exit(server.main())
os.dup2(2,1)
os.execvp('hd', ['hd'])
sys.exit(1)
elif opt.iptables: elif opt.iptables:
if len(extra) < 1: if len(extra) < 1:
o.fatal('at least one argument expected') o.fatal('at least one argument expected')

43
server.py Normal file
View File

@ -0,0 +1,43 @@
import struct, socket, select
import ssnet, helpers
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
def main():
# synchronization header
sys.stdout.write('SSHUTTLE0001')
sys.stdout.flush()
helpers.logprefix = ' s: '
handlers = []
mux = Mux(socket.fromfd(sys.stdin.fileno(),
socket.AF_INET, socket.SOCK_STREAM),
socket.fromfd(sys.stdout.fileno(),
socket.AF_INET, socket.SOCK_STREAM))
handlers.append(mux)
def new_channel(channel, data):
(dstip,dstport) = data.split(',', 1)
dstport = int(dstport)
outwrap = ssnet.connect_dst(dstip,dstport)
handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
mux.new_channel = new_channel
while mux.ok:
r = set()
w = set()
x = set()
handlers = filter(lambda s: s.ok, handlers)
for s in handlers:
s.pre_select(r,w,x)
log('\n')
log('Waiting: %d[%d,%d,%d]...\n'
% (len(handlers), len(r), len(w), len(x)))
(r,w,x) = select.select(r,w,x)
log('r=%r w=%r x=%r\n' % (r,w,x))
ready = set(r) | set(w) | set(x)
for s in handlers:
if s.socks & ready:
s.callback()

104
ssnet.py
View File

@ -16,16 +16,27 @@ def _nb_clean(func, *args):
try: try:
return func(*args) return func(*args)
except socket.error, e: except socket.error, e:
if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
else:
return None return None
raise
def _try_peername(sock):
try:
return sock.getpeername()
except socket.error, e:
if e.args[0] not in (errno.ENOTCONN,):
raise
else:
return ('0.0.0.0',0)
class SockWrapper: class SockWrapper:
def __init__(self, rsock, wsock): def __init__(self, rsock, wsock):
self.rsock = rsock self.rsock = rsock
self.wsock = wsock self.wsock = wsock
self.peername = self.rsock.getpeername() self.peername = _try_peername(self.rsock)
self.shut_read = self.shut_write = False self.shut_read = self.shut_write = False
self.buf = [] self.buf = []
@ -45,11 +56,20 @@ class SockWrapper:
if not self.shut_write: if not self.shut_write:
log('%r: done writing\n' % self) log('%r: done writing\n' % self)
self.shut_write = True self.shut_write = True
self.wsock.shutdown(socket.SHUT_WR) try:
self.wsock.shutdown(socket.SHUT_WR)
except socket.error:
pass
def uwrite(self, buf): def uwrite(self, buf):
self.wsock.setblocking(False) self.wsock.setblocking(False)
return _nb_clean(self.wsock.send, buf) try:
return _nb_clean(self.wsock.send, buf)
except socket.error:
# unexpected error... stream is dead
self.nowrite()
self.noread()
return 0
def write(self, buf): def write(self, buf):
assert(buf) assert(buf)
@ -59,7 +79,10 @@ class SockWrapper:
if self.shut_read: if self.shut_read:
return return
self.rsock.setblocking(False) self.rsock.setblocking(False)
return _nb_clean(self.rsock.recv, 65536) try:
return _nb_clean(self.rsock.recv, 65536)
except socket.error:
return '' # unexpected error... we'll call it EOF
def fill(self): def fill(self):
if self.buf: if self.buf:
@ -133,6 +156,7 @@ class Mux(Handler):
Handler.__init__(self, [rsock, wsock]) Handler.__init__(self, [rsock, wsock])
self.rsock = rsock self.rsock = rsock
self.wsock = wsock self.wsock = wsock
self.new_channel = None
self.channels = {} self.channels = {}
self.chani = 0 self.chani = 0
self.want = 0 self.want = 0
@ -160,12 +184,18 @@ class Mux(Handler):
def got_packet(self, channel, cmd, data): def got_packet(self, channel, cmd, data):
log('--got-packet--\n') log('--got-packet--\n')
if cmd == CMD_PING: if cmd == CMD_PING:
self.mux.send(0, CMD_PONG, data) self.send(0, CMD_PONG, data)
elif cmd == CMD_PONG:
log('received PING response\n')
elif cmd == CMD_EXIT: elif cmd == CMD_EXIT:
self.ok = False self.ok = False
elif cmd == CMD_CONNECT:
assert(not self.channels.get(channel))
if self.new_channel:
self.new_channel(channel, data)
else: else:
c = self.channels[channel] callback = self.channels[channel]
c.got_packet(cmd, data) callback(cmd, data)
def flush(self): def flush(self):
self.wsock.setblocking(False) self.wsock.setblocking(False)
@ -180,28 +210,30 @@ class Mux(Handler):
self.rsock.setblocking(False) self.rsock.setblocking(False)
b = _nb_clean(self.rsock.recv, 32768) b = _nb_clean(self.rsock.recv, 32768)
if b == '': # EOF if b == '': # EOF
ok = False self.ok = False
if b: if b:
self.inbuf += b self.inbuf += b
def handle(self): def handle(self):
log('inbuf is: %r\n' % self.inbuf) self.fill()
if len(self.inbuf) >= (self.want or HDR_LEN): log('inbuf is: (%d,%d) %r\n' % (self.want, len(self.inbuf), self.inbuf))
(s1,s2,channel,cmd,datalen) = struct.unpack('!ccHHH', while 1:
self.inbuf[:HDR_LEN]) if len(self.inbuf) >= (self.want or HDR_LEN):
assert(s1 == 'S') (s1,s2,channel,cmd,datalen) = \
assert(s2 == 'S') struct.unpack('!ccHHH', self.inbuf[:HDR_LEN])
self.want = datalen + HDR_LEN assert(s1 == 'S')
if self.want and len(self.inbuf) >= self.want: assert(s2 == 'S')
data = self.inbuf[HDR_LEN:self.want] self.want = datalen + HDR_LEN
self.inbuf = self.inbuf[self.want:] if self.want and len(self.inbuf) >= self.want:
self.got_packet(channel, cmd, data) data = self.inbuf[HDR_LEN:self.want]
else: self.inbuf = self.inbuf[self.want:]
self.fill() self.want = 0
self.got_packet(channel, cmd, data)
else:
break
def pre_select(self, r, w, x): def pre_select(self, r, w, x):
if self.inbuf < (self.want or HDR_LEN): r.add(self.rsock)
r.add(self.rsock)
if self.outbuf: if self.outbuf:
w.add(self.wsock) w.add(self.wsock)
@ -218,9 +250,16 @@ class MuxWrapper(SockWrapper):
SockWrapper.__init__(self, mux.rsock, mux.wsock) SockWrapper.__init__(self, mux.rsock, mux.wsock)
self.mux = mux self.mux = mux
self.channel = channel self.channel = channel
self.mux.channels[channel] = self self.mux.channels[channel] = self.got_packet
log('Created MuxWrapper on channel %d\n' % channel) log('Created MuxWrapper on channel %d\n' % channel)
def __del__(self):
self.nowrite()
SockWrapper.__del__(self)
def __repr__(self):
return 'SW%r:Mux#%d' % (self.peername,self.channel)
def noread(self): def noread(self):
if not self.shut_read: if not self.shut_read:
self.shut_read = True self.shut_read = True
@ -231,6 +270,8 @@ class MuxWrapper(SockWrapper):
self.mux.send(self.channel, CMD_EOF, '') self.mux.send(self.channel, CMD_EOF, '')
def uwrite(self, buf): def uwrite(self, buf):
if len(buf) > 65535:
buf = buf[:32768]
self.mux.send(self.channel, CMD_DATA, buf) self.mux.send(self.channel, CMD_DATA, buf)
return len(buf) return len(buf)
@ -251,3 +292,14 @@ class MuxWrapper(SockWrapper):
else: else:
raise Exception('unknown command %d (%d bytes)' raise Exception('unknown command %d (%d bytes)'
% (cmd, len(data))) % (cmd, len(data)))
def connect_dst(ip, port):
outsock = socket.socket()
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
try:
outsock.connect((ip,port))
except socket.error, e:
if e.args[0] not in [errno.ECONNREFUSED]:
raise
return SockWrapper(outsock,outsock)