mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-22 07:53:43 +01:00
6c2dc54b9e
It seems ssh is kind of stupid and uses a really big SO_SNDBUF (hundreds of kbytes). Thus, we can't depend on the socket's output buffer to limit our latency down to something reasonable. Instead, limit the amount of data we can send in a single round trip.
382 lines
11 KiB
Python
382 lines
11 KiB
Python
import struct, socket, errno, select
|
|
from helpers import *
|
|
|
|
# Size in bytes of the fixed header that precedes every mux packet.  It
# matches the '!ccHHH' struct layout used by Mux.send() and Mux.handle():
# two 1-byte magic characters followed by three unsigned 16-bit fields
# (channel, command, payload length).
HDR_LEN = 8

# Command codes carried in the header's command field.
CMD_EXIT = 0x4200     # receiving Mux marks itself as no longer ok
CMD_PING = 0x4201     # receiver answers with CMD_PONG echoing the payload
CMD_PONG = 0x4202     # reply to CMD_PING; resets the sender's fullness count
CMD_CONNECT = 0x4203  # handed to Mux.new_channel(channel, data) if set
CMD_CLOSE = 0x4204    # channel is finished in both directions
CMD_EOF = 0x4205      # sender will not write more data on this channel
CMD_DATA = 0x4206     # payload bytes for this channel

# Human-readable command names, used when formatting debug messages.
cmd_to_name = {
    CMD_EXIT: 'EXIT',
    CMD_PING: 'PING',
    CMD_PONG: 'PONG',
    CMD_CONNECT: 'CONNECT',
    CMD_CLOSE: 'CLOSE',
    CMD_EOF: 'EOF',
    CMD_DATA: 'DATA',
}
|
|
|
|
|
|
|
|
def _nb_clean(func, *args):
    """Call func(*args), turning a "would block" error into None.

    Returns whatever func returns.  If func raises OSError with errno
    EWOULDBLOCK or EAGAIN, None is returned instead; any other OSError
    propagates unchanged.
    """
    try:
        result = func(*args)
    except OSError as e:
        if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
            return None
        raise
    return result
|
|
|
|
|
|
def _try_peername(sock):
    """Return sock's peer as an 'addr:port' string, or 'unknown'.

    'unknown' is returned when getpeername() yields an empty value or
    fails with ENOTCONN / ENOTSOCK.  Any other socket.error propagates.
    """
    try:
        pn = sock.getpeername()
    except socket.error as e:
        if e.args[0] not in (errno.ENOTCONN, errno.ENOTSOCK):
            raise
        return 'unknown'
    if pn:
        return '%s:%s' % (pn[0], pn[1])
    return 'unknown'
|
|
|
|
|
|
class SockWrapper:
    """Non-blocking wrapper around a read socket and a write socket.

    Holds a small list of pending input chunks (self.buf) that copy_to()
    drains into another wrapper, tracks half-close state in shut_read /
    shut_write, and remembers the first error seen in self.exc.

    If connect_to is given, the wrapper drives a non-blocking connect on
    rsock; connect_to is reset to None once the attempt has finished
    (successfully or with a recorded error).
    """

    def __init__(self, rsock, wsock, connect_to=None, peername=None):
        self.exc = None
        self.rsock = rsock
        self.wsock = wsock
        self.shut_read = self.shut_write = False
        self.buf = []
        self.connect_to = connect_to
        self.peername = peername or _try_peername(self.rsock)
        self.try_connect()

    def __del__(self):
        debug1('%r: deleting\n' % self)
        if self.exc:
            debug1('%r: error was: %r\n' % (self, self.exc))

    def __repr__(self):
        return 'SW:%s' % (self.peername,)

    def seterr(self, e):
        """Record e as this wrapper's error unless one is already stored."""
        if not self.exc:
            self.exc = e

    def try_connect(self):
        """Start or continue the non-blocking connect to self.connect_to.

        Does nothing when no connect is pending.  Raises socket.error for
        any failure that is not one of the explicitly handled errno values.
        """
        if not self.connect_to:
            return  # already connected
        # NOTE(review): the reason for the specific TTL value 42 is not
        # visible in this file -- confirm against the firewall setup code.
        self.rsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
        self.rsock.setblocking(False)
        try:
            self.rsock.connect(self.connect_to)
            self.connect_to = None
        except socket.error as e:
            code = e.args[0]
            if code in (errno.EINPROGRESS, errno.EALREADY):
                pass  # not connected yet
            elif code == errno.EISCONN:
                # This method calls connect() again on every poll; once the
                # earlier attempt has completed, the OS may report "already
                # connected", which means success rather than failure.
                self.connect_to = None
            elif code in (errno.ECONNREFUSED, errno.ETIMEDOUT):
                # a "normal" kind of error
                self.connect_to = None
                self.seterr(e)
            else:
                raise  # error we've never heard of?!  barf completely.

    def noread(self):
        """Mark the read side as finished (idempotent)."""
        if not self.shut_read:
            debug2('%r: done reading\n' % self)
            self.shut_read = True
            #self.rsock.shutdown(socket.SHUT_RD)  # doesn't do anything anyway

    def nowrite(self):
        """Shut down the write side once; a shutdown error is recorded."""
        if not self.shut_write:
            debug2('%r: done writing\n' % self)
            self.shut_write = True
            try:
                self.wsock.shutdown(socket.SHUT_WR)
            except socket.error as e:
                self.seterr(e)

    def uwrite(self, buf):
        """Write as much of buf as possible without blocking.

        Returns the number of bytes written.  0 is returned while still
        connecting, when the write would block, or after an unexpected
        OSError (which is recorded and closes both directions).
        """
        if self.connect_to:
            return 0  # still connecting
        self.wsock.setblocking(False)
        try:
            # _nb_clean() yields None when the write would block; report
            # that as 0 bytes so callers always receive an integer count.
            return _nb_clean(os.write, self.wsock.fileno(), buf) or 0
        except OSError as e:
            # unexpected error... stream is dead
            self.seterr(e)
            self.nowrite()
            self.noread()
            return 0

    def write(self, buf):
        """Like uwrite(), but buf must be non-empty."""
        assert(buf)
        return self.uwrite(buf)

    def uread(self):
        """Read up to 65536 bytes without blocking.

        Returns the data read, '' for EOF (also used after an unexpected
        OSError, which is recorded), or None when nothing is available:
        still connecting, read side already shut, or the read would block.
        """
        if self.connect_to:
            return None  # still connecting
        if self.shut_read:
            return
        self.rsock.setblocking(False)
        try:
            return _nb_clean(os.read, self.rsock.fileno(), 65536)
        except OSError as e:
            self.seterr(e)
            return ''  # unexpected error... we'll call it EOF

    def fill(self):
        """Pull one chunk into self.buf if the buffer is currently empty."""
        if self.buf:
            return
        rb = self.uread()
        if rb:
            self.buf.append(rb)
        if rb == '':  # empty string means EOF; None means temporarily empty
            self.noread()

    def copy_to(self, outwrap):
        """Push buffered data to outwrap; propagate EOF once drained."""
        if self.buf and self.buf[0]:
            wrote = outwrap.write(self.buf[0])
            self.buf[0] = self.buf[0][wrote:]
        # Drop chunks that have been written out completely.
        while self.buf and not self.buf[0]:
            self.buf[0:1] = []
        if not self.buf and self.shut_read:
            outwrap.nowrite()
|
|
|
|
|
|
class Handler:
    """Base class for objects polled by a select() loop.

    socks is the set of sockets to watch for readability.  ok is cleared
    when the handler is finished.  A callback passed to the constructor
    replaces the default callback() method on that instance.
    """

    def __init__(self, socks = None, callback = None):
        self.ok = True
        self.socks = set(socks or [])
        if callback:
            self.callback = callback

    def pre_select(self, r, w, x):
        """Add this handler's sockets to the read set r (modified in place)."""
        r.update(self.socks)

    def callback(self):
        """Default callback: drain readable sockets and detect closure."""
        log('--no callback defined-- %r\n' % self)
        readable = select.select(self.socks, [], [], 0)[0]
        for sock in readable:
            if sock.recv(4096):
                continue
            # An empty read means the peer closed: stop watching everything.
            log('--closed-- %r\n' % self)
            self.socks = set()
            self.ok = False
|
|
|
|
|
|
class Proxy(Handler):
    """Handler that shuttles data in both directions between two wrappers."""

    def __init__(self, wrap1, wrap2):
        Handler.__init__(self, [wrap1.rsock, wrap1.wsock,
                                wrap2.rsock, wrap2.wsock])
        self.wrap1 = wrap1
        self.wrap2 = wrap2

    def pre_select(self, r, w, x):
        """Register, per direction, the one socket worth waiting on.

        For each source wrapper: while it is still connecting, wait for its
        socket to become writable; if it holds buffered data, wait for the
        other side's write socket; otherwise, unless its read side is shut,
        wait for it to become readable.
        """
        directions = ((self.wrap1, self.wrap2), (self.wrap2, self.wrap1))
        for src, dst in directions:
            if src.connect_to:
                w.add(src.rsock)
            elif src.buf:
                w.add(dst.wsock)
            elif not src.shut_read:
                r.add(src.rsock)

    def callback(self):
        """Advance connects, read, forward, then check for completion."""
        first, second = self.wrap1, self.wrap2
        first.try_connect()
        second.try_connect()
        first.fill()
        second.fill()
        first.copy_to(second)
        second.copy_to(first)
        # Finished once both sides hit EOF and nothing is left to forward.
        both_eof = first.shut_read and second.shut_read
        drained = not first.buf and not second.buf
        if both_eof and drained:
            self.ok = False
|
|
|
|
|
|
class Mux(Handler):
    """Multiplexes numbered channels over one read/write socket pair.

    Each packet is an 8-byte header ('S', 'S', channel, command, payload
    length; see HDR_LEN) followed by the payload.  Outgoing packets are
    queued in self.outbuf and written by flush(); incoming bytes collect
    in self.inbuf and are split into packets by handle().

    self.fullness counts payload bytes queued since the last CMD_PONG was
    received.  A CMD_PING is sent when the count reaches 32768, and
    check_fullness() sets self.too_full once it exceeds 65536.
    """

    def __init__(self, rsock, wsock):
        Handler.__init__(self, [rsock, wsock])
        self.rsock = rsock
        self.wsock = wsock
        self.new_channel = None  # optional callable(channel, data) for CONNECT
        self.channels = {}       # channel id -> callable(cmd, data)
        self.chani = 0           # most recently allocated channel id
        self.want = 0            # full size of the packet being received, or 0
        self.inbuf = ''
        self.outbuf = []
        self.fullness = 0
        self.too_full = False
        self.send(0, CMD_PING, 'chicken')

    def next_channel(self):
        """Return an unused channel id in 1..65535.

        Returns None if none of the next 1024 candidate ids is free.
        """
        # channel 0 is special, so we never allocate it
        for _attempt in xrange(1024):
            self.chani += 1
            if self.chani > 65535:
                self.chani = 1
            if not self.channels.get(self.chani):
                return self.chani
        return None

    def amount_queued(self):
        """Return the total number of bytes waiting in the output queue."""
        return sum(len(b) for b in self.outbuf)

    def check_fullness(self):
        """Set too_full once more than 65536 bytes are unacknowledged."""
        if self.fullness > 65536:
            self.too_full = True

    def send(self, channel, cmd, data):
        """Queue one packet for the given channel.

        data is converted with str() and must be at most 65535 bytes long.
        When this packet brings fullness up to or past 32768, a CMD_PING is
        queued right behind it; the matching CMD_PONG resets fullness.
        """
        data = str(data)
        assert(len(data) <= 65535)
        p = struct.pack('!ccHHH', 'S', 'S', channel, cmd, len(data)) + data
        self.outbuf.append(p)
        debug2(' > channel=%d cmd=%s len=%d (fullness=%d)\n'
               % (channel, cmd_to_name[cmd], len(data), self.fullness))
        # Update the counter BEFORE queueing the ping: the nested send()
        # then sees the threshold as already reached and cannot re-trigger
        # itself.  "<=" also catches landing exactly on the threshold,
        # which would otherwise never send a ping at all.
        before = self.fullness
        self.fullness = before + len(data)
        if before < 32768 <= self.fullness:
            self.send(0, CMD_PING, 'rttest')

    def got_packet(self, channel, cmd, data):
        """Dispatch one fully received packet."""
        debug2('< channel=%d cmd=%s len=%d\n'
               % (channel, cmd_to_name[cmd], len(data)))
        if cmd == CMD_PING:
            self.send(0, CMD_PONG, data)
        elif cmd == CMD_PONG:
            debug2('received PING response\n')
            self.too_full = False
            self.fullness = 0
        elif cmd == CMD_EXIT:
            self.ok = False
        elif cmd == CMD_CONNECT:
            assert(not self.channels.get(channel))
            if self.new_channel:
                self.new_channel(channel, data)
        else:
            # Everything else belongs to a channel registered in channels.
            callback = self.channels[channel]
            callback(cmd, data)

    def flush(self):
        """Write as much of the first queued packet as possible."""
        self.wsock.setblocking(False)
        if self.outbuf and self.outbuf[0]:
            wrote = _nb_clean(os.write, self.wsock.fileno(), self.outbuf[0])
            # %r rather than %d: wrote is None when the write would block,
            # and formatting None with %d raises TypeError.
            debug2('wrote: %r/%d\n' % (wrote, len(self.outbuf[0])))
            if wrote:
                self.outbuf[0] = self.outbuf[0][wrote:]
        # Drop packets that have been written out completely.
        while self.outbuf and not self.outbuf[0]:
            self.outbuf[0:1] = []

    def fill(self):
        """Read up to 32768 bytes into inbuf; EOF clears self.ok."""
        self.rsock.setblocking(False)
        b = _nb_clean(os.read, self.rsock.fileno(), 32768)
        if b == '':  # EOF
            self.ok = False
        if b:
            self.inbuf += b

    def handle(self):
        """Read available input and dispatch every complete packet."""
        self.fill()
        while True:
            if len(self.inbuf) >= (self.want or HDR_LEN):
                (s1, s2, channel, cmd, datalen) = \
                    struct.unpack('!ccHHH', self.inbuf[:HDR_LEN])
                assert(s1 == 'S')
                assert(s2 == 'S')
                self.want = datalen + HDR_LEN
            if self.want and len(self.inbuf) >= self.want:
                data = self.inbuf[HDR_LEN:self.want]
                self.inbuf = self.inbuf[self.want:]
                self.want = 0
                self.got_packet(channel, cmd, data)
            else:
                break  # need more input before the next packet is complete

    def pre_select(self, r, w, x):
        """Always wait for input; wait for writability only if data is queued."""
        r.add(self.rsock)
        if self.outbuf:
            w.add(self.wsock)

    def callback(self):
        """Handle input and flush output, whichever is ready right now."""
        (r, w, x) = select.select([self.rsock], [self.wsock], [], 0)
        if self.rsock in r:
            self.handle()
        if self.outbuf and self.wsock in w:
            self.flush()
|
|
|
|
|
|
class MuxWrapper(SockWrapper):
    """SockWrapper whose traffic is carried on one channel of a Mux.

    Writes become CMD_DATA packets sent through the mux; data arrives via
    got_packet(), which the mux calls for packets on this channel.
    """

    def __init__(self, mux, channel):
        SockWrapper.__init__(self, mux.rsock, mux.wsock)
        self.mux = mux
        self.channel = channel
        # Have the mux hand this channel's packets to got_packet().
        self.mux.channels[channel] = self.got_packet
        self.socks = []
        debug2('new channel: %d\n' % channel)

    def __del__(self):
        self.nowrite()
        SockWrapper.__del__(self)

    def __repr__(self):
        return 'SW%r:Mux#%d' % (self.peername,self.channel)

    def noread(self):
        """Mark the read side as finished; no socket call is made."""
        if self.shut_read:
            return
        self.shut_read = True

    def nowrite(self):
        """Mark the write side as finished and send CMD_EOF once."""
        if self.shut_write:
            return
        self.shut_write = True
        self.mux.send(self.channel, CMD_EOF, '')

    def uwrite(self, buf):
        """Queue at most 2048 bytes of buf; return how many were taken."""
        if self.mux.too_full:
            return 0  # too much already enqueued
        chunk = buf[:2048]
        self.mux.send(self.channel, CMD_DATA, chunk)
        return len(chunk)

    def uread(self):
        """Return '' once the read side is shut, otherwise None.

        Data never comes from here; it is appended to self.buf by
        got_packet().
        """
        return '' if self.shut_read else None

    def got_packet(self, cmd, data):
        """Apply one packet the mux received for this channel.

        Raises Exception for any command other than CLOSE, EOF or DATA.
        """
        if cmd == CMD_DATA:
            self.buf.append(data)
            return
        if cmd == CMD_EOF:
            self.noread()
            return
        if cmd == CMD_CLOSE:
            self.noread()
            self.nowrite()
            return
        raise Exception('unknown command %d (%d bytes)'
                        % (cmd, len(data)))
|
|
|
|
|
|
def connect_dst(ip, port):
    """Return a SockWrapper that is connecting to ip:port.

    A new socket serves as both the read and the write side; the
    non-blocking connect is started by the SockWrapper constructor.
    """
    debug2('Connecting to %s:%d\n' % (ip, port))
    sock = socket.socket()
    target = (ip, port)
    label = '%s:%d' % (ip, port)
    return SockWrapper(sock, sock, connect_to=target, peername=label)
|