!improved windivert throughput

This commit is contained in:
nom3ad 2024-01-02 17:53:20 +05:30 committed by Brian May
parent 371258991f
commit e19fc01324
8 changed files with 81 additions and 75 deletions

View File

@ -68,6 +68,13 @@ if [[ $ssh_copy_id == true ]]; then
with_set_x ssh-copy-id -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p "$port" "$user@$host" with_set_x ssh-copy-id -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p "$port" "$user@$host"
fi fi
sshuttle_bin=${sshuttle_bin:-"$(dirname "$0")/../run"} if [[ -z $sshuttle_bin || "$sshuttle_bin" == dev ]]; then
cd "$(dirname "$0")/.."
export PYTHONPATH="."
sshuttle_bin="./run"
fi
set -x set -x
exec "${sshuttle_bin}" -r "$user@$host:$port" --ssh-cmd "$ssh_cmd" "${args[@]}" exec "${sshuttle_bin}" -r "$user@$host:$port" --ssh-cmd "$ssh_cmd" "${args[@]}"

View File

@ -18,9 +18,6 @@ benchmark() {
local sshuttle_bin="${1?:}" local sshuttle_bin="${1?:}"
local node="${2:-'node-1'}" local node="${2:-'node-1'}"
echo -e "\n======== Benchmarking sshuttle: $sshuttle_bin ========" echo -e "\n======== Benchmarking sshuttle: $sshuttle_bin ========"
if [[ "$sshuttle_bin" == dev ]]; then
sshuttle_bin="../run"
fi
./exec-sshuttle "$node" --sshuttle-bin="$sshuttle_bin" --listen 55771 & ./exec-sshuttle "$node" --sshuttle-bin="$sshuttle_bin" --listen 55771 &
sshuttle_pid=$! sshuttle_pid=$!
trap 'kill -0 $sshuttle_pid &>/dev/null && kill -15 $sshuttle_pid' EXIT trap 'kill -0 $sshuttle_pid &>/dev/null && kill -15 $sshuttle_pid' EXIT
@ -34,6 +31,6 @@ benchmark() {
if [[ "$1" ]]; then if [[ "$1" ]]; then
benchmark "$1" benchmark "$1"
else else
benchmark "${SSHUTTLE_BIN:-/bin/sshuttle}" node-1 benchmark "${SSHUTTLE_BIN:-sshuttle}" node-1
benchmark dev node-1 benchmark dev node-1
fi fi

View File

@ -3,9 +3,8 @@ import sys
import os import os
from sshuttle.cmdline import main from sshuttle.cmdline import main
from sshuttle.helpers import debug3 from sshuttle.helpers import debug3
from sshuttle import __version__
debug3("Starting cmd %r (pid:%s) | sshuttle: %s | Python: %s" % (sys.argv, os.getpid(), __version__, sys.version)) debug3("Start: (pid=%s, ppid=%s) %r" % (os.getpid(), os.getppid(), sys.argv))
exit_code = main() exit_code = main()
debug3("Exiting cmd %r (pid:%s) with code %s" % (sys.argv, os.getpid(), exit_code,)) debug3("Exit: (pid=%s, ppid=%s, code=%s) cmd %r" % (os.getpid(), os.getppid(), exit_code, sys.argv))
sys.exit(exit_code) sys.exit(exit_code)

View File

@ -5,6 +5,7 @@ import time
import subprocess as ssubprocess import subprocess as ssubprocess
import os import os
import sys import sys
import base64
import platform import platform
import sshuttle.helpers as helpers import sshuttle.helpers as helpers
@ -14,7 +15,7 @@ import sshuttle.ssyslog as ssyslog
import sshuttle.sdnotify as sdnotify import sshuttle.sdnotify as sdnotify
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \ from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
resolvconf_nameservers, which, is_admin_user resolvconf_nameservers, which, is_admin_user, RWPair
from sshuttle.methods import get_method, Features from sshuttle.methods import get_method, Features
from sshuttle import __version__ from sshuttle import __version__
try: try:
@ -294,48 +295,32 @@ class FirewallClient:
return s2.makefile('rwb') return s2.makefile('rwb')
else: else:
# In windows, if client/firewall processes is running as admin user, stdio can be used for communication. # In Windows CPython, BSD sockets are not supported as subprocess stdio.
# But if firewall process is run with elevated mode, access to stdio is lost. # If the client (and firewall) processes are running as an admin user, pipe-based stdio can be used for communication.
# So we have to use a socketpair (as in unix). # But if the firewall process is spawned in elevated mode by a non-admin client process, access to stdio is lost.
# But socket need to be "shared" to child process as it can't be directly set as stdio in Windows # To work around this, we can use a socketpair.
# But the socket needs to be "shared" with the child process, as it can't be directly set as stdio.
can_use_stdio = is_admin_user() can_use_stdio = is_admin_user()
pstdout = ssubprocess.PIPE if can_use_stdio else None
pstdin = ssubprocess.PIPE
preexec_fn = None preexec_fn = None
penv = os.environ.copy() penv = os.environ.copy()
penv['PYTHONPATH'] = os.path.dirname(os.path.dirname(__file__)) if can_use_stdio:
pstdout = ssubprocess.PIPE
pstdin = ssubprocess.PIPE
def get_pfile(): def get_pfile():
if can_use_stdio:
self.p.stdin.write(b'COM_STDIO:\n')
self.p.stdin.flush()
class RWPair:
def __init__(self, r, w):
self.r = r
self.w = w
self.read = r.read
self.readline = r.readline
self.write = w.write
self.flush = w.flush
def close(self):
for f in self.r, self.w:
try:
f.close()
except Exception:
pass
return RWPair(self.p.stdout, self.p.stdin) return RWPair(self.p.stdout, self.p.stdin)
# import io penv['SSHUTTLE_FW_COM_CHANNEL'] = 'stdio'
# return io.BufferedRWPair(self.p.stdout, self.p.stdin, 1)
else: else:
import base64 pstdout = None
pstdin = None
(s1, s2) = socket.socketpair() (s1, s2) = socket.socketpair()
socket_share_data = s1.share(self.p.pid) socket_share_data = s1.share(self.p.pid)
s1.close()
socket_share_data_b64 = base64.b64encode(socket_share_data) socket_share_data_b64 = base64.b64encode(socket_share_data)
self.p.stdin.write(b'COM_SOCKETSHARE:' + socket_share_data_b64 + b'\n') penv['SSHUTTLE_FW_COM_CHANNEL'] = socket_share_data_b64
self.p.stdin.flush()
def get_pfile():
s1.close()
return s2.makefile('rwb') return s2.makefile('rwb')
try: try:
debug1("Starting firewall manager with command: %r" % argv) debug1("Starting firewall manager with command: %r" % argv)

View File

@ -125,21 +125,16 @@ def _setup_daemon_for_windows():
signal.signal(signal.SIGTERM, firewall_exit) signal.signal(signal.SIGTERM, firewall_exit)
signal.signal(signal.SIGINT, firewall_exit) signal.signal(signal.SIGINT, firewall_exit)
socket_share_data_prefix = b'COM_SOCKETSHARE:' com_chan = os.environ.get('SSHUTTLE_FW_COM_CHANNEL')
line = sys.stdin.buffer.readline().strip() if com_chan == 'stdio':
if line.startswith(socket_share_data_prefix): debug3('Using inherited stdio for communicating with sshuttle client process')
else:
debug3('Using shared socket for communicating with sshuttle client process') debug3('Using shared socket for communicating with sshuttle client process')
socket_share_data_b64 = line[len(socket_share_data_prefix):] socket_share_data = base64.b64decode(com_chan)
socket_share_data = base64.b64decode(socket_share_data_b64)
sock = socket.fromshare(socket_share_data) # type: socket.socket sock = socket.fromshare(socket_share_data) # type: socket.socket
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0)) sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True) sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
sock.close() sock.close()
elif line.startswith(b"COM_STDIO:"):
debug3('Using inherited stdio for communicating with sshuttle client process')
else:
raise Fatal("Unexpected stdin: " + line)
return sys.stdin.buffer, sys.stdout.buffer return sys.stdin.buffer, sys.stdout.buffer

View File

@ -15,6 +15,10 @@ def b(s):
return s.encode("ASCII") return s.encode("ASCII")
def get_verbose_level():
    """Return the current value of the global ``verbose`` setting.

    The value is looked up at call time, so callers that need a stable
    snapshot (e.g. before entering a loop) should store the result.
    """
    return verbose
def log(s): def log(s):
global logprefix global logprefix
try: try:
@ -254,3 +258,20 @@ def set_non_blocking_io(fd):
else: else:
_sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) _sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
_sock.setblocking(False) _sock.setblocking(False)
class RWPair:
    """Combine a readable object and a writable object into one file-like pair.

    ``read`` and ``readline`` are forwarded to the reader ``r``; ``write``
    and ``flush`` are forwarded to the writer ``w``. The two underlying
    objects remain reachable as ``self.r`` and ``self.w``.

    The pair can also be used as a context manager; leaving the ``with``
    block calls :meth:`close`.

    Args:
        r: object providing ``read``, ``readline`` and ``close``.
        w: object providing ``write``, ``flush`` and ``close``.
    """

    def __init__(self, r, w):
        self.r = r
        self.w = w
        # Expose the bound methods directly so the pair can be handed to code
        # that expects a single object with read/readline/write/flush.
        self.read = r.read
        self.readline = r.readline
        self.write = w.write
        self.flush = w.flush

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def close(self):
        """Close both ends, ignoring errors raised by either ``close()``.

        Deliberately best-effort: a failure while closing the reader must
        not prevent the writer from being closed, and calling this method
        more than once is harmless as long as the underlying objects
        tolerate repeated ``close()`` calls.
        """
        for f in self.r, self.w:
            try:
                f.close()
            except Exception:
                pass

View File

@ -1,6 +1,6 @@
import os import os
import sys import sys
import ipaddress from ipaddress import ip_address, ip_network
import threading import threading
from collections import namedtuple from collections import namedtuple
import socket import socket
@ -15,7 +15,7 @@ import traceback
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal from sshuttle.helpers import debug3, log, debug1, debug2, get_verbose_level, Fatal
try: try:
# https://reqrypt.org/windivert-doc.html#divert_iphdr # https://reqrypt.org/windivert-doc.html#divert_iphdr
@ -30,7 +30,7 @@ ConnectionTuple = namedtuple(
) )
WINDIVERT_MAX_CONNECTIONS = 10_000 WINDIVERT_MAX_CONNECTIONS = int(os.environ.get('WINDIVERT_MAX_CONNECTIONS', 1024))
class IPProtocol(IntEnum): class IPProtocol(IntEnum):
@ -150,8 +150,8 @@ class ConnTrack:
else: else:
raise RuntimeError("No slot available in ConnTrack") # should not be here raise RuntimeError("No slot available in ConnTrack") # should not be here
src_addr = ipaddress.ip_address(src_addr) src_addr = ip_address(src_addr)
dst_addr = ipaddress.ip_address(dst_addr) dst_addr = ip_address(dst_addr)
assert src_addr.version == dst_addr.version assert src_addr.version == dst_addr.version
ip_version = src_addr.version ip_version = src_addr.version
state_epoch = int(time.time()) state_epoch = int(time.time())
@ -169,7 +169,7 @@ class ConnTrack:
def update(self, proto, src_addr, src_port, state): def update(self, proto, src_addr, src_port, state):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
src_addr = ipaddress.ip_address(src_addr) src_addr = ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for i in self.used_slots: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
@ -190,7 +190,7 @@ class ConnTrack:
def remove(self, proto, src_addr, src_port): def remove(self, proto, src_addr, src_port):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
src_addr = ipaddress.ip_address(src_addr) src_addr = ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for i in self.used_slots: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
@ -209,7 +209,7 @@ class ConnTrack:
) )
def get(self, proto, src_addr, src_port): def get(self, proto, src_addr, src_port):
src_addr = ipaddress.ip_address(src_addr) src_addr = ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for entry in self.shm_list: for entry in self.shm_list:
if entry and entry.startswith(packed): if entry and entry.startswith(packed):
@ -246,8 +246,8 @@ class ConnTrack:
state_epoch, state_epoch,
state, state,
) = self.struct_full_tuple.unpack(packed) ) = self.struct_full_tuple.unpack(packed)
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) dst_addr = str(ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) src_addr = str(ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
return ConnectionTuple( return ConnectionTuple(
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
) )
@ -281,7 +281,7 @@ class Method(BaseMethod):
continue continue
port_suffix = ":" + str(port) port_suffix = ":" + str(port)
if state == "LISTENING" and local_addr.endswith(port_suffix): if state == "LISTENING" and local_addr.endswith(port_suffix):
return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]")) return ip_address(local_addr[:-len(port_suffix)].strip("[]"))
raise Fatal("Could not find listening address for {}/{}".format(port, proto)) raise Fatal("Could not find listening address for {}/{}".format(port, proto))
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark): def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark):
@ -298,7 +298,7 @@ class Method(BaseMethod):
# As a workaround we use another interface ip instead. # As a workaround we use another interface ip instead.
local_addr = self._get_local_proxy_listen_addr(port, family) local_addr = self._get_local_proxy_listen_addr(port, family)
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): for addr in (ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
if addr.is_loopback or addr.version != family.version: if addr.is_loopback or addr.version != family.version:
continue continue
if local_addr.is_unspecified or local_addr == addr: if local_addr.is_unspecified or local_addr == addr:
@ -380,9 +380,9 @@ class Method(BaseMethod):
for af, c in self.network_config.items(): for af, c in self.network_config.items():
subnet_filters = [] subnet_filters = []
for cidr in c["subnets"]: for cidr in c["subnets"]:
ip_network = ipaddress.ip_network(cidr) ip_net = ip_network(cidr)
first_ip = ip_network.network_address first_ip = ip_net.network_address
last_ip = ip_network.broadcast_address last_ip = ip_net.broadcast_address
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})") subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ") family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
@ -394,8 +394,9 @@ class Method(BaseMethod):
proxy_port = self.proxy_port proxy_port = self.proxy_port
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4] proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6] proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
verbose = get_verbose_level()
for pkt in w: for pkt in w:
debug3(">>> " + repr_pkt(pkt)) verbose >= 3 and debug3(">>> " + repr_pkt(pkt))
if pkt.tcp.syn and not pkt.tcp.ack: if pkt.tcp.syn and not pkt.tcp.ack:
# SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK) # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
self.conntrack.add( self.conntrack.add(
@ -434,7 +435,7 @@ class Method(BaseMethod):
proto = IPProtocol.TCP proto = IPProtocol.TCP
direction = "inbound" # only when proxy address is not loopback address (Useful for testing) direction = "inbound" # only when proxy address is not loopback address (Useful for testing)
ip_filters = [] ip_filters = []
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a): for addr in (ip_address(a) for a in self.proxy_addr.values() if a):
if addr.is_loopback: # Windivert treats all loopback traffic as outbound if addr.is_loopback: # Windivert treats all loopback traffic as outbound
direction = "outbound" direction = "outbound"
if addr.version == 4: if addr.version == 4:
@ -448,8 +449,9 @@ class Method(BaseMethod):
debug1(f"[INGRESS] {filter=}") debug1(f"[INGRESS] {filter=}")
with pydivert.WinDivert(filter) as w: with pydivert.WinDivert(filter) as w:
ready_cb() ready_cb()
verbose = get_verbose_level()
for pkt in w: for pkt in w:
debug3("<<< " + repr_pkt(pkt)) verbose >= 3 and debug3("<<< " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: if pkt.tcp.syn and pkt.tcp.ack:
# SYN+ACK received (connection established) # SYN+ACK received (connection established)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
@ -466,7 +468,7 @@ class Method(BaseMethod):
else: else:
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
if not conn: if not conn:
debug2("Unexpected packet: " + repr_pkt(pkt)) verbose >= 2 and debug2("Unexpected packet: " + repr_pkt(pkt))
continue continue
pkt.src_addr = conn.dst_addr pkt.src_addr = conn.dst_addr
pkt.tcp.src_port = conn.dst_port pkt.tcp.src_port = conn.dst_port

View File

@ -218,7 +218,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
os.close(pstdout) os.close(pstdout)
return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0) return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0)
else: else:
# In Windows CPython, we can't use BSD sockets as subprocess stdio # In Windows CPython, BSD sockets are not supported as subprocess stdio
# and select.select() used in ssnet.py won't work on Windows pipes. # and select.select() used in ssnet.py won't work on Windows pipes.
# So we have to use both socketpair (for select.select) and pipes (for subprocess.Popen) together # So we have to use both socketpair (for select.select) and pipes (for subprocess.Popen) together
# along with reader/writer threads to stream data between them # along with reader/writer threads to stream data between them