stdin and stdout have different fds, so make SockWrapper take *two* socks.

We'll need this when we have a SockWrapper pointing at a Mux on a subprocess
pipe.
This commit is contained in:
Avery Pennarun 2010-05-01 23:32:30 -04:00
parent 5f0bfb5d9e
commit d435c41bdb
2 changed files with 33 additions and 30 deletions

View File

@ -25,7 +25,7 @@ def _main(listener, listenport, use_server, remotename, subnets):
handlers = [] handlers = []
if use_server: if use_server:
(serverproc, serversock) = ssh.connect(remotename) (serverproc, serversock) = ssh.connect(remotename)
mux = Mux(serversock) mux = Mux(serversock, serversock)
handlers.append(mux) handlers.append(mux)
# we definitely want to do this *after* starting ssh, or we might end # we definitely want to do this *after* starting ssh, or we might end
@ -48,8 +48,8 @@ def _main(listener, listenport, use_server, remotename, subnets):
outsock = socket.socket() outsock = socket.socket()
outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
outsock.connect(dstip) outsock.connect(dstip)
outwrap = SockWrapper(outsock) outwrap = SockWrapper(outsock, outsock)
handlers.append(Proxy(SockWrapper(sock), outwrap)) handlers.append(Proxy(SockWrapper(sock, sock), outwrap))
handlers.append(Handler([listener], onaccept)) handlers.append(Handler([listener], onaccept))
while 1: while 1:

View File

@ -22,9 +22,10 @@ def _nb_clean(func, *args):
class SockWrapper: class SockWrapper:
def __init__(self, sock): def __init__(self, rsock, wsock):
self.sock = sock self.rsock = rsock
self.peername = self.sock.getpeername() self.wsock = wsock
self.peername = self.rsock.getpeername()
self.shut_read = self.shut_write = False self.shut_read = self.shut_write = False
self.buf = [] self.buf = []
@ -38,17 +39,17 @@ class SockWrapper:
if not self.shut_read: if not self.shut_read:
log('%r: done reading\n' % self) log('%r: done reading\n' % self)
self.shut_read = True self.shut_read = True
#self.sock.shutdown(socket.SHUT_RD) # doesn't do anything anyway #self.rsock.shutdown(socket.SHUT_RD) # doesn't do anything anyway
def nowrite(self): def nowrite(self):
if not self.shut_write: if not self.shut_write:
log('%r: done writing\n' % self) log('%r: done writing\n' % self)
self.shut_write = True self.shut_write = True
self.sock.shutdown(socket.SHUT_WR) self.wsock.shutdown(socket.SHUT_WR)
def uwrite(self, buf): def uwrite(self, buf):
self.sock.setblocking(False) self.wsock.setblocking(False)
return _nb_clean(self.sock.send, buf) return _nb_clean(self.wsock.send, buf)
def write(self, buf): def write(self, buf):
assert(buf) assert(buf)
@ -57,8 +58,8 @@ class SockWrapper:
def uread(self): def uread(self):
if self.shut_read: if self.shut_read:
return return
self.sock.setblocking(False) self.rsock.setblocking(False)
return _nb_clean(self.sock.recv, 65536) return _nb_clean(self.rsock.recv, 65536)
def fill(self): def fill(self):
if self.buf: if self.buf:
@ -102,19 +103,20 @@ class Handler:
class Proxy(Handler): class Proxy(Handler):
def __init__(self, wrap1, wrap2): def __init__(self, wrap1, wrap2):
Handler.__init__(self, [wrap1.sock, wrap2.sock]) Handler.__init__(self, [wrap1.rsock, wrap1.wsock,
wrap2.rsock, wrap2.wsock])
self.wrap1 = wrap1 self.wrap1 = wrap1
self.wrap2 = wrap2 self.wrap2 = wrap2
def pre_select(self, r, w, x): def pre_select(self, r, w, x):
if self.wrap1.buf: if self.wrap1.buf:
w.add(self.wrap2.sock) w.add(self.wrap2.wsock)
elif not self.wrap1.shut_read: elif not self.wrap1.shut_read:
r.add(self.wrap1.sock) r.add(self.wrap1.rsock)
if self.wrap2.buf: if self.wrap2.buf:
w.add(self.wrap1.sock) w.add(self.wrap1.wsock)
elif not self.wrap2.shut_read: elif not self.wrap2.shut_read:
r.add(self.wrap2.sock) r.add(self.wrap2.rsock)
def callback(self): def callback(self):
self.wrap1.fill() self.wrap1.fill()
@ -127,9 +129,10 @@ class Proxy(Handler):
class Mux(Handler): class Mux(Handler):
def __init__(self, sock): def __init__(self, rsock, wsock):
Handler.__init__(self, [sock]) Handler.__init__(self, [rsock, wsock])
self.sock = sock self.rsock = rsock
self.wsock = wsock
self.channels = {} self.channels = {}
self.chani = 0 self.chani = 0
self.want = 0 self.want = 0
@ -165,17 +168,17 @@ class Mux(Handler):
c.got_packet(cmd, data) c.got_packet(cmd, data)
def flush(self): def flush(self):
self.sock.setblocking(False) self.wsock.setblocking(False)
if self.outbuf and self.outbuf[0]: if self.outbuf and self.outbuf[0]:
wrote = _nb_clean(self.sock.send, self.outbuf[0]) wrote = _nb_clean(self.wsock.send, self.outbuf[0])
if wrote: if wrote:
self.outbuf[0] = self.outbuf[0][wrote:] self.outbuf[0] = self.outbuf[0][wrote:]
while self.outbuf and not self.outbuf[0]: while self.outbuf and not self.outbuf[0]:
self.outbuf.pop() self.outbuf.pop()
def fill(self): def fill(self):
self.sock.setblocking(False) self.rsock.setblocking(False)
b = _nb_clean(self.sock.recv, 32768) b = _nb_clean(self.rsock.recv, 32768)
if b == '': # EOF if b == '': # EOF
ok = False ok = False
if b: if b:
@ -198,21 +201,21 @@ class Mux(Handler):
def pre_select(self, r, w, x): def pre_select(self, r, w, x):
if self.inbuf < (self.want or HDR_LEN): if self.inbuf < (self.want or HDR_LEN):
r.add(self.sock) r.add(self.rsock)
if self.outbuf: if self.outbuf:
w.add(self.sock) w.add(self.wsock)
def callback(self): def callback(self):
(r,w,x) = select.select([self.sock], [self.sock], [], 0) (r,w,x) = select.select([self.rsock], [self.wsock], [], 0)
if self.sock in r: if self.rsock in r:
self.handle() self.handle()
if self.outbuf and self.sock in w: if self.outbuf and self.wsock in w:
self.flush() self.flush()
class MuxWrapper(SockWrapper): class MuxWrapper(SockWrapper):
def __init__(self, mux, channel): def __init__(self, mux, channel):
SockWrapper.__init__(self, mux.sock) SockWrapper.__init__(self, mux.rsock, mux.wsock)
self.mux = mux self.mux = mux
self.channel = channel self.channel = channel
self.mux.channels[channel] = self self.mux.channels[channel] = self