diff --git a/client.py b/client.py index b3c993e..2ca75ba 100644 --- a/client.py +++ b/client.py @@ -25,7 +25,7 @@ def _main(listener, listenport, use_server, remotename, subnets): handlers = [] if use_server: (serverproc, serversock) = ssh.connect(remotename) - mux = Mux(serversock) + mux = Mux(serversock, serversock) handlers.append(mux) # we definitely want to do this *after* starting ssh, or we might end @@ -48,8 +48,8 @@ def _main(listener, listenport, use_server, remotename, subnets): outsock = socket.socket() outsock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) outsock.connect(dstip) - outwrap = SockWrapper(outsock) - handlers.append(Proxy(SockWrapper(sock), outwrap)) + outwrap = SockWrapper(outsock, outsock) + handlers.append(Proxy(SockWrapper(sock, sock), outwrap)) handlers.append(Handler([listener], onaccept)) while 1: diff --git a/ssnet.py b/ssnet.py index 60cc1e5..0966cec 100644 --- a/ssnet.py +++ b/ssnet.py @@ -22,9 +22,10 @@ def _nb_clean(func, *args): class SockWrapper: - def __init__(self, sock): - self.sock = sock - self.peername = self.sock.getpeername() + def __init__(self, rsock, wsock): + self.rsock = rsock + self.wsock = wsock + self.peername = self.rsock.getpeername() self.shut_read = self.shut_write = False self.buf = [] @@ -38,17 +39,17 @@ class SockWrapper: if not self.shut_read: log('%r: done reading\n' % self) self.shut_read = True - #self.sock.shutdown(socket.SHUT_RD) # doesn't do anything anyway + #self.rsock.shutdown(socket.SHUT_RD) # doesn't do anything anyway def nowrite(self): if not self.shut_write: log('%r: done writing\n' % self) self.shut_write = True - self.sock.shutdown(socket.SHUT_WR) + self.wsock.shutdown(socket.SHUT_WR) def uwrite(self, buf): - self.sock.setblocking(False) - return _nb_clean(self.sock.send, buf) + self.wsock.setblocking(False) + return _nb_clean(self.wsock.send, buf) def write(self, buf): assert(buf) @@ -57,8 +58,8 @@ class SockWrapper: def uread(self): if self.shut_read: return - self.sock.setblocking(False) - return _nb_clean(self.sock.recv, 65536) + self.rsock.setblocking(False) + return _nb_clean(self.rsock.recv, 65536) def fill(self): if self.buf: @@ -102,19 +103,20 @@ class Handler: class Proxy(Handler): def __init__(self, wrap1, wrap2): - Handler.__init__(self, [wrap1.sock, wrap2.sock]) + Handler.__init__(self, [wrap1.rsock, wrap1.wsock, + wrap2.rsock, wrap2.wsock]) self.wrap1 = wrap1 self.wrap2 = wrap2 def pre_select(self, r, w, x): if self.wrap1.buf: - w.add(self.wrap2.sock) + w.add(self.wrap2.wsock) elif not self.wrap1.shut_read: - r.add(self.wrap1.sock) + r.add(self.wrap1.rsock) if self.wrap2.buf: - w.add(self.wrap1.sock) + w.add(self.wrap1.wsock) elif not self.wrap2.shut_read: - r.add(self.wrap2.sock) + r.add(self.wrap2.rsock) def callback(self): self.wrap1.fill() @@ -127,9 +129,10 @@ class Proxy(Handler): class Mux(Handler): - def __init__(self, sock): - Handler.__init__(self, [sock]) - self.sock = sock + def __init__(self, rsock, wsock): + Handler.__init__(self, [rsock, wsock]) + self.rsock = rsock + self.wsock = wsock self.channels = {} self.chani = 0 self.want = 0 @@ -165,17 +168,17 @@ class Mux(Handler): c.got_packet(cmd, data) def flush(self): - self.sock.setblocking(False) + self.wsock.setblocking(False) if self.outbuf and self.outbuf[0]: - wrote = _nb_clean(self.sock.send, self.outbuf[0]) + wrote = _nb_clean(self.wsock.send, self.outbuf[0]) if wrote: self.outbuf[0] = self.outbuf[0][wrote:] while self.outbuf and not self.outbuf[0]: self.outbuf.pop() def fill(self): - self.sock.setblocking(False) - b = _nb_clean(self.sock.recv, 32768) + self.rsock.setblocking(False) + b = _nb_clean(self.rsock.recv, 32768) if b == '': # EOF ok = False if b: @@ -198,21 +201,21 @@ class Mux(Handler): def pre_select(self, r, w, x): if self.inbuf < (self.want or HDR_LEN): - r.add(self.sock) + r.add(self.rsock) if self.outbuf: - w.add(self.sock) + w.add(self.wsock) def callback(self): - (r,w,x) = select.select([self.sock], [self.sock], [], 0) - if self.sock in r: + (r,w,x) = select.select([self.rsock], [self.wsock], [], 0) + if self.rsock in r: self.handle() - if self.outbuf and self.sock in w: + if self.outbuf and self.wsock in w: self.flush() class MuxWrapper(SockWrapper): def __init__(self, mux, channel): - SockWrapper.__init__(self, mux.sock) + SockWrapper.__init__(self, mux.rsock, mux.wsock) self.mux = mux self.channel = channel self.mux.channels[channel] = self