From e19fc0132410b1f3e7805e4ea257b52bf2688435 Mon Sep 17 00:00:00 2001 From: nom3ad <19239479+nom3ad@users.noreply.github.com> Date: Tue, 2 Jan 2024 17:53:20 +0530 Subject: [PATCH] !improved windrivert throughput --- hack/exec-sshuttle | 9 +++++- hack/run-benchmark | 5 +-- sshuttle/__main__.py | 5 ++- sshuttle/client.py | 59 +++++++++++++---------------------- sshuttle/firewall.py | 15 +++------ sshuttle/helpers.py | 21 +++++++++++++ sshuttle/methods/windivert.py | 40 +++++++++++++----------- sshuttle/ssh.py | 2 +- 8 files changed, 81 insertions(+), 75 deletions(-) diff --git a/hack/exec-sshuttle b/hack/exec-sshuttle index 4061908..0a82cfb 100755 --- a/hack/exec-sshuttle +++ b/hack/exec-sshuttle @@ -68,6 +68,13 @@ if [[ $ssh_copy_id == true ]]; then with_set_x ssh-copy-id -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p "$port" "$user@$host" fi -sshuttle_bin=${sshuttle_bin:-"$(dirname "$0")/../run"} +if [[ -z $sshuttle_bin || "$sshuttle_bin" == dev ]]; then + cd "$(dirname "$0")/.." + export PYTHONPATH="." + sshuttle_bin="./run" +fi + set -x + exec "${sshuttle_bin}" -r "$user@$host:$port" --ssh-cmd "$ssh_cmd" "${args[@]}" + diff --git a/hack/run-benchmark b/hack/run-benchmark index a17d634..72b2ca0 100755 --- a/hack/run-benchmark +++ b/hack/run-benchmark @@ -18,9 +18,6 @@ benchmark() { local sshuttle_bin="${1?:}" local node="${2:-'node-1'}" echo -e "\n======== Benchmarking sshuttle: $sshuttle_bin ========" - if [[ "$sshuttle_bin" == dev ]]; then - sshuttle_bin="../run" - fi ./exec-sshuttle "$node" --sshuttle-bin="$sshuttle_bin" --listen 55771 & sshuttle_pid=$! trap 'kill -0 $sshuttle_pid &>/dev/null && kill -15 $sshuttle_pid' EXIT @@ -34,6 +31,6 @@ benchmark() { if [[ "$1" ]]; then benchmark "$1" else - benchmark "${SSHUTTLE_BIN:-/bin/sshuttle}" node-1 + benchmark "${SSHUTTLE_BIN:-sshuttle}" node-1 benchmark dev node-1 fi diff --git a/sshuttle/__main__.py b/sshuttle/__main__.py index c756679..3b42093 100644 --- a/sshuttle/__main__.py +++ b/sshuttle/__main__.py @@ -3,9 +3,8 @@ import sys import os from sshuttle.cmdline import main from sshuttle.helpers import debug3 -from sshuttle import __version__ -debug3("Starting cmd %r (pid:%s) | sshuttle: %s | Python: %s" % (sys.argv, os.getpid(), __version__, sys.version)) +debug3("Start: (pid=%s, ppid=%s) %r" % (os.getpid(), os.getppid(), sys.argv)) exit_code = main() -debug3("Exiting cmd %r (pid:%s) with code %s" % (sys.argv, os.getpid(), exit_code,)) +debug3("Exit: (pid=%s, ppid=%s, code=%s) cmd %r" % (os.getpid(), os.getppid(), exit_code, sys.argv)) sys.exit(exit_code) diff --git a/sshuttle/client.py b/sshuttle/client.py index 232620a..c837b24 100644 --- a/sshuttle/client.py +++ b/sshuttle/client.py @@ -5,6 +5,7 @@ import time import subprocess as ssubprocess import os import sys +import base64 import platform import sshuttle.helpers as helpers @@ -14,7 +15,7 @@ import sshuttle.ssyslog as ssyslog import sshuttle.sdnotify as sdnotify from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \ - resolvconf_nameservers, which, is_admin_user + resolvconf_nameservers, which, is_admin_user, RWPair from sshuttle.methods import get_method, Features from sshuttle import __version__ try: @@ -294,48 +295,32 @@ class FirewallClient: return s2.makefile('rwb') else: - # In windows, if client/firewall processes is running as admin user, stdio can be used for communication. - # But if firewall process is run with elevated mode, access to stdio is lost. - # So we have to use a socketpair (as in unix). - # But socket need to be "shared" to child process as it can't be directly set as stdio in Windows + # In Windows CPython, BSD sockets are not supported as subprocess stdio. + # if client (and firewall) processes is running as admin user, pipe based stdio can be used for communication. + # But if firewall process is spwaned in elevated mode by non-admin client process, access to stdio is lost. + # To work around this, we can use a socketpair. + # But socket need to be "shared" to child process as it can't be directly set as stdio. can_use_stdio = is_admin_user() - pstdout = ssubprocess.PIPE if can_use_stdio else None - pstdin = ssubprocess.PIPE + preexec_fn = None penv = os.environ.copy() - penv['PYTHONPATH'] = os.path.dirname(os.path.dirname(__file__)) + if can_use_stdio: + pstdout = ssubprocess.PIPE + pstdin = ssubprocess.PIPE - def get_pfile(): - if can_use_stdio: - self.p.stdin.write(b'COM_STDIO:\n') - self.p.stdin.flush() - - class RWPair: - def __init__(self, r, w): - self.r = r - self.w = w - self.read = r.read - self.readline = r.readline - self.write = w.write - self.flush = w.flush - - def close(self): - for f in self.r, self.w: - try: - f.close() - except Exception: - pass + def get_pfile(): return RWPair(self.p.stdout, self.p.stdin) - # import io - # return io.BufferedRWPair(self.p.stdout, self.p.stdin, 1) - else: - import base64 - (s1, s2) = socket.socketpair() - socket_share_data = s1.share(self.p.pid) + penv['SSHUTTLE_FW_COM_CHANNEL'] = 'stdio' + else: + pstdout = None + pstdin = None + (s1, s2) = socket.socketpair() + socket_share_data = s1.share(self.p.pid) + socket_share_data_b64 = base64.b64encode(socket_share_data) + penv['SSHUTTLE_FW_COM_CHANNEL'] = socket_share_data_b64 + + def get_pfile(): s1.close() - socket_share_data_b64 = base64.b64encode(socket_share_data) - self.p.stdin.write(b'COM_SOCKETSHARE:' + socket_share_data_b64 + b'\n') - self.p.stdin.flush() return s2.makefile('rwb') try: debug1("Starting firewall manager with command: %r" % argv) diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py index 9532ab0..2ec9e25 100644 --- a/sshuttle/firewall.py +++ b/sshuttle/firewall.py @@ -125,21 +125,16 @@ def _setup_daemon_for_windows(): signal.signal(signal.SIGTERM, firewall_exit) signal.signal(signal.SIGINT, firewall_exit) - socket_share_data_prefix = b'COM_SOCKETSHARE:' - line = sys.stdin.buffer.readline().strip() - if line.startswith(socket_share_data_prefix): + com_chan = os.environ.get('SSHUTTLE_FW_COM_CHANNEL') + if com_chan == 'stdio': + debug3('Using inherited stdio for communicating with sshuttle client process') + else: debug3('Using shared socket for communicating with sshuttle client process') - socket_share_data_b64 = line[len(socket_share_data_prefix):] - socket_share_data = base64.b64decode(socket_share_data_b64) + socket_share_data = base64.b64decode(com_chan) sock = socket.fromshare(socket_share_data) # type: socket.socket sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0)) sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True) sock.close() - elif line.startswith(b"COM_STDIO:"): - debug3('Using inherited stdio for communicating with sshuttle client process') - else: - raise Fatal("Unexpected stdin: " + line) - return sys.stdin.buffer, sys.stdout.buffer diff --git a/sshuttle/helpers.py b/sshuttle/helpers.py index 6ad857d..c682150 100644 --- a/sshuttle/helpers.py +++ b/sshuttle/helpers.py @@ -15,6 +15,10 @@ def b(s): return s.encode("ASCII") +def get_verbose_level(): + return verbose + + def log(s): global logprefix try: @@ -254,3 +258,20 @@ def set_non_blocking_io(fd): else: _sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) _sock.setblocking(False) + + +class RWPair: + def __init__(self, r, w): + self.r = r + self.w = w + self.read = r.read + self.readline = r.readline + self.write = w.write + self.flush = w.flush + + def close(self): + for f in self.r, self.w: + try: + f.close() + except Exception: + pass diff --git a/sshuttle/methods/windivert.py b/sshuttle/methods/windivert.py index dc3b169..b17a316 100644 --- a/sshuttle/methods/windivert.py +++ b/sshuttle/methods/windivert.py @@ -1,6 +1,6 @@ import os import sys -import ipaddress +from ipaddress import ip_address, ip_network import threading from collections import namedtuple import socket @@ -15,7 +15,7 @@ import traceback from sshuttle.methods import BaseMethod -from sshuttle.helpers import debug3, log, debug1, debug2, Fatal +from sshuttle.helpers import debug3, log, debug1, debug2, get_verbose_level, Fatal try: # https://reqrypt.org/windivert-doc.html#divert_iphdr @@ -30,7 +30,7 @@ ConnectionTuple = namedtuple( ) -WINDIVERT_MAX_CONNECTIONS = 10_000 +WINDIVERT_MAX_CONNECTIONS = int(os.environ.get('WINDIVERT_MAX_CONNECTIONS', 1024)) class IPProtocol(IntEnum): @@ -150,8 +150,8 @@ class ConnTrack: else: raise RuntimeError("No slot available in ConnTrack") # should not be here - src_addr = ipaddress.ip_address(src_addr) - dst_addr = ipaddress.ip_address(dst_addr) + src_addr = ip_address(src_addr) + dst_addr = ip_address(dst_addr) assert src_addr.version == dst_addr.version ip_version = src_addr.version state_epoch = int(time.time()) @@ -169,7 +169,7 @@ class ConnTrack: def update(self, proto, src_addr, src_port, state): if not self.is_owner: raise RuntimeError("Only owner can mutate ConnTrack") - src_addr = ipaddress.ip_address(src_addr) + src_addr = ip_address(src_addr) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) for i in self.used_slots: if self.shm_list[i].startswith(packed): @@ -190,7 +190,7 @@ class ConnTrack: def remove(self, proto, src_addr, src_port): if not self.is_owner: raise RuntimeError("Only owner can mutate ConnTrack") - src_addr = ipaddress.ip_address(src_addr) + src_addr = ip_address(src_addr) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) for i in self.used_slots: if self.shm_list[i].startswith(packed): @@ -209,7 +209,7 @@ class ConnTrack: ) def get(self, proto, src_addr, src_port): - src_addr = ipaddress.ip_address(src_addr) + src_addr = ip_address(src_addr) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) for entry in self.shm_list: if entry and entry.startswith(packed): @@ -246,8 +246,8 @@ class ConnTrack: state_epoch, state, ) = self.struct_full_tuple.unpack(packed) - dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) - src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) + dst_addr = str(ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) + src_addr = str(ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) return ConnectionTuple( IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) ) @@ -281,7 +281,7 @@ class Method(BaseMethod): continue port_suffix = ":" + str(port) if state == "LISTENING" and local_addr.endswith(port_suffix): - return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]")) + return ip_address(local_addr[:-len(port_suffix)].strip("[]")) raise Fatal("Could not find listening address for {}/{}".format(port, proto)) def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark): @@ -298,7 +298,7 @@ class Method(BaseMethod): # As a workaround we use another interface ip instead. local_addr = self._get_local_proxy_listen_addr(port, family) - for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): + for addr in (ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): if addr.is_loopback or addr.version != family.version: continue if local_addr.is_unspecified or local_addr == addr: @@ -380,9 +380,9 @@ class Method(BaseMethod): for af, c in self.network_config.items(): subnet_filters = [] for cidr in c["subnets"]: - ip_network = ipaddress.ip_network(cidr) - first_ip = ip_network.network_address - last_ip = ip_network.broadcast_address + ip_net = ip_network(cidr) + first_ip = ip_net.network_address + last_ip = ip_net.broadcast_address subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})") family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ") @@ -394,8 +394,9 @@ class Method(BaseMethod): proxy_port = self.proxy_port proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4] proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6] + verbose = get_verbose_level() for pkt in w: - debug3(">>> " + repr_pkt(pkt)) + verbose >= 3 and debug3(">>> " + repr_pkt(pkt)) if pkt.tcp.syn and not pkt.tcp.ack: # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK) self.conntrack.add( @@ -434,7 +435,7 @@ class Method(BaseMethod): proto = IPProtocol.TCP direction = "inbound" # only when proxy address is not loopback address (Useful for testing) ip_filters = [] - for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a): + for addr in (ip_address(a) for a in self.proxy_addr.values() if a): if addr.is_loopback: # Windivert treats all loopback traffic as outbound direction = "outbound" if addr.version == 4: @@ -448,8 +449,9 @@ class Method(BaseMethod): debug1(f"[INGRESS] {filter=}") with pydivert.WinDivert(filter) as w: ready_cb() + verbose = get_verbose_level() for pkt in w: - debug3("<<< " + repr_pkt(pkt)) + verbose >= 3 and debug3("<<< " + repr_pkt(pkt)) if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK received (connection established) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) @@ -466,7 +468,7 @@ class Method(BaseMethod): else: conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) if not conn: - debug2("Unexpected packet: " + repr_pkt(pkt)) + verbose >= 2 and debug2("Unexpected packet: " + repr_pkt(pkt)) continue pkt.src_addr = conn.dst_addr pkt.tcp.src_port = conn.dst_port diff --git a/sshuttle/ssh.py b/sshuttle/ssh.py index 7b494d2..08ea5df 100644 --- a/sshuttle/ssh.py +++ b/sshuttle/ssh.py @@ -218,7 +218,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options): os.close(pstdout) return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0) else: - # In Windows CPython, we can't use BSD sockets as subprocess stdio + # In Windows CPython, BSD sockets are not supported as subprocess stdio # and select.select() used in ssnet.py won't work on Windows pipes. # So we have to use both socketpair (for select.select) and pipes (for subprocess.Popen) together # along with reader/writer threads to stream data between them