From 4a84ad3be60a4e09f0298cf10c8af44f5655d737 Mon Sep 17 00:00:00 2001 From: nom3ad <19239479+nom3ad@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:24:31 +0530 Subject: [PATCH] fix windows CRLF issue on stdin/stdout --- sshuttle/client.py | 4 +- sshuttle/firewall.py | 64 ++- sshuttle/methods/windivert.py | 958 +++++++++++++++++----------------- sshuttle/ssh.py | 2 +- tests/client/test_firewall.py | 6 +- 5 files changed, 520 insertions(+), 514 deletions(-) diff --git a/sshuttle/client.py b/sshuttle/client.py index e2179f4..232620a 100644 --- a/sshuttle/client.py +++ b/sshuttle/client.py @@ -307,7 +307,7 @@ class FirewallClient: def get_pfile(): if can_use_stdio: - self.p.stdin.write(b'STDIO:\n') + self.p.stdin.write(b'COM_STDIO:\n') self.p.stdin.flush() class RWPair: @@ -334,7 +334,7 @@ class FirewallClient: socket_share_data = s1.share(self.p.pid) s1.close() socket_share_data_b64 = base64.b64encode(socket_share_data) - self.p.stdin.write(b'SOCKETSHARE:' + socket_share_data_b64 + b'\n') + self.p.stdin.write(b'COM_SOCKETSHARE:' + socket_share_data_b64 + b'\n') self.p.stdin.flush() return s2.makefile('rwb') try: diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py index 44276b9..3e3bb64 100644 --- a/sshuttle/firewall.py +++ b/sshuttle/firewall.py @@ -84,7 +84,7 @@ def firewall_exit(signum, frame): # the typical exit process as described above. global sshuttle_pid if sshuttle_pid: - debug1("Relaying interupt signal to sshuttle process %d\n" % sshuttle_pid) + debug1("Relaying interupt signal to sshuttle process %d" % sshuttle_pid) if sys.platform == 'win32': sig = signal.CTRL_C_EVENT else: @@ -115,7 +115,7 @@ def _setup_daemon_for_unix_like(): # setsid() fails if sudo is configured with the use_pty option. pass - return sys.stdin, sys.stdout + return sys.stdin.buffer, sys.stdout.buffer def _setup_daemon_for_windows(): @@ -125,9 +125,9 @@ def _setup_daemon_for_windows(): signal.signal(signal.SIGTERM, firewall_exit) signal.signal(signal.SIGINT, firewall_exit) - socket_share_data_prefix = 'SOCKETSHARE:' - line = sys.stdin.readline().strip() - if line.startswith('SOCKETSHARE:'): + socket_share_data_prefix = b'COM_SOCKETSHARE:' + line = sys.stdin.buffer.readline().strip() + if line.startswith(socket_share_data_prefix): debug3('Using shared socket for communicating with sshuttle client process') socket_share_data_b64 = line[len(socket_share_data_prefix):] socket_share_data = base64.b64decode(socket_share_data_b64) @@ -135,12 +135,12 @@ def _setup_daemon_for_windows(): sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0)) sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True) sock.close() - elif line.startswith("STDIO:"): + elif line.startswith(b"COM_STDIO:"): debug3('Using inherited stdio for communicating with sshuttle client process') else: raise Fatal("Unexpected stdin: " + line) - return sys.stdin, sys.stdout + return sys.stdin.buffer, sys.stdout.buffer # Isolate function that needs to be replaced for tests @@ -221,33 +221,43 @@ def main(method_name, syslog): "PATH." % method_name) debug1('ready method name %s.' % method.name) - stdout.write('READY %s\n' % method.name) + stdout.write(('READY %s\n' % method.name).encode('ASCII')) stdout.flush() + + def _read_next_string_line(): + try: + line = stdin.readline(128) + if not line: + return # parent probably exited + return line.decode('ASCII').strip() + except IOError as e: + # On windows, ConnectionResetError is thrown when parent process closes it's socket pair end + debug3('read from stdin failed: %s' % (e,)) + return # we wait until we get some input before creating the rules. That way, # sshuttle can launch us as early as possible (and get sudo password # authentication as early in the startup process as possible). try: - line = stdin.readline(128) + line = _read_next_string_line() if not line: return # parent probably exited - except ConnectionResetError as e: + except IOError as e: # On windows, ConnectionResetError is thrown when parent process closes it's socket pair end debug3('read from stdin failed: %s' % (e,)) return subnets = [] - if line != 'ROUTES\n': + if line != 'ROUTES': raise Fatal('expected ROUTES but got %r' % line) while 1: - line = stdin.readline(128) + line = _read_next_string_line() if not line: raise Fatal('expected route but got %r' % line) - elif line.startswith("NSLIST\n"): + elif line.startswith("NSLIST"): break try: - (family, width, exclude, ip, fport, lport) = \ - line.strip().split(',', 5) + (family, width, exclude, ip, fport, lport) = line.split(',', 5) except Exception: raise Fatal('expected route or NSLIST but got %r' % line) subnets.append(( @@ -260,16 +270,16 @@ def main(method_name, syslog): debug2('Got subnets: %r' % subnets) nslist = [] - if line != 'NSLIST\n': + if line != 'NSLIST': raise Fatal('expected NSLIST but got %r' % line) while 1: - line = stdin.readline(128) + line = _read_next_string_line() if not line: raise Fatal('expected nslist but got %r' % line) elif line.startswith("PORTS "): break try: - (family, ip) = line.strip().split(',', 1) + (family, ip) = line.split(',', 1) except Exception: raise Fatal('expected nslist or PORTS but got %r' % line) nslist.append((int(family), ip)) @@ -299,15 +309,13 @@ def main(method_name, syslog): debug2('Got ports: %d,%d,%d,%d' % (port_v6, port_v4, dnsport_v6, dnsport_v4)) - line = stdin.readline(128) - if not line: - raise Fatal('expected GO but got %r' % line) - elif not line.startswith("GO "): + line = _read_next_string_line() + if not line or not line.startswith("GO "): raise Fatal('expected GO but got %r' % line) _, _, args = line.partition(" ") global sshuttle_pid - udp, user, group, tmark, sshuttle_pid = args.strip().split(" ", 4) + udp, user, group, tmark, sshuttle_pid = args.split(" ", 4) udp = bool(int(udp)) sshuttle_pid = int(sshuttle_pid) if user == '-': @@ -350,7 +358,7 @@ def main(method_name, syslog): flush_systemd_dns_cache() try: - stdout.write('STARTED\n') + stdout.write(b'STARTED\n') stdout.flush() except IOError as e: # the parent process probably died debug3('write to stdout failed: %s' % (e,)) @@ -360,13 +368,11 @@ def main(method_name, syslog): # to stay running so that we don't need a *second* password # authentication at shutdown time - that cleanup is important! while 1: - try: - line = stdin.readline(128) - except IOError as e: - debug3('read from stdin failed: %s' % (e,)) + line = _read_next_string_line() + if not line: return if line.startswith('HOST '): - (name, ip) = line[5:].strip().split(',', 1) + (name, ip) = line[5:].split(',', 1) hostmap[name] = ip debug2('setting up /etc/hosts.') rewrite_etc_hosts(hostmap, port_v6 or port_v4) diff --git a/sshuttle/methods/windivert.py b/sshuttle/methods/windivert.py index beb2e15..dc3b169 100644 --- a/sshuttle/methods/windivert.py +++ b/sshuttle/methods/windivert.py @@ -1,479 +1,479 @@ -import os -import sys -import ipaddress -import threading -from collections import namedtuple -import socket -import subprocess -import re -from multiprocessing import shared_memory -import struct -from functools import wraps -from enum import IntEnum -import time -import traceback - - -from sshuttle.methods import BaseMethod -from sshuttle.helpers import debug3, log, debug1, debug2, Fatal - -try: - # https://reqrypt.org/windivert-doc.html#divert_iphdr - import pydivert -except ImportError: - raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert") - - -ConnectionTuple = namedtuple( - "ConnectionTuple", - ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"], -) - - -WINDIVERT_MAX_CONNECTIONS = 10_000 - - -class IPProtocol(IntEnum): - TCP = socket.IPPROTO_TCP - UDP = socket.IPPROTO_UDP - - @property - def filter(self): - return "tcp" if self == IPProtocol.TCP else "udp" - - -class IPFamily(IntEnum): - IPv4 = socket.AF_INET - IPv6 = socket.AF_INET6 - - @property - def filter(self): - return "ip" if self == socket.AF_INET else "ipv6" - - @property - def version(self): - return 4 if self == socket.AF_INET else 6 - - @property - def loopback_addr(self): - return "127.0.0.1" if self == socket.AF_INET else "::1" - - -class ConnState(IntEnum): - TCP_SYN_SENT = 11 # SYN sent - TCP_ESTABLISHED = 12 # SYN+ACK received - TCP_FIN_WAIT_1 = 91 # FIN sent - TCP_CLOSE_WAIT = 92 # FIN received - - @staticmethod - def can_timeout(state): - return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT) - - -def repr_pkt(p): - r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}" - if p.tcp: - t = p.tcp - r += f" {len(t.payload)}B (" - r += "+".join( - f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f) - ) - r += f") SEQ#{t.seq_num}" - if t.ack: - r += f" ACK#{t.ack_num}" - r += f" WZ={t.window_size}" - else: - r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}" - return f"" - - -def synchronized_method(lock): - def decorator(method): - @wraps(method) - def wrapped(self, *args, **kwargs): - with getattr(self, lock): - return method(self, *args, **kwargs) - - return wrapped - - return decorator - - -class ConnTrack: - - _instance = None - - def __new__(cls, *args, **kwargs): - if not cls._instance: - cls._instance = object.__new__(cls) - return cls._instance - raise RuntimeError("ConnTrack can not be instantiated multiple times") - - def __init__(self, name, max_connections=0) -> None: - self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B"))) - self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H"))) - self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B"))) - - try: - self.max_connections = max_connections - self.shm_list = shared_memory.ShareableList( - [bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name - ) - self.is_owner = True - self.next_slot = 0 - self.used_slots = set() - self.rlock = threading.RLock() - except FileExistsError: - self.is_owner = False - self.shm_list = shared_memory.ShareableList(name=name) - self.max_connections = len(self.shm_list) - - debug2( - f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} " - f"shm_size={self.shm_list.shm.size}B" - ) - - @synchronized_method("rlock") - def add(self, proto, src_addr, src_port, dst_addr, dst_port, state): - if not self.is_owner: - raise RuntimeError("Only owner can mutate ConnTrack") - if len(self.used_slots) >= self.max_connections: - raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}") - - if self.get(proto, src_addr, src_port): - return - - for _ in range(self.max_connections): - if self.next_slot not in self.used_slots: - break - self.next_slot = (self.next_slot + 1) % self.max_connections - else: - raise RuntimeError("No slot available in ConnTrack") # should not be here - - src_addr = ipaddress.ip_address(src_addr) - dst_addr = ipaddress.ip_address(dst_addr) - assert src_addr.version == dst_addr.version - ip_version = src_addr.version - state_epoch = int(time.time()) - entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state) - packed = self.struct_full_tuple.pack(*entry) - self.shm_list[self.next_slot] = packed - self.used_slots.add(self.next_slot) - proto = IPProtocol(proto) - debug3( - f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to " - f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}" - ) - - @synchronized_method("rlock") - def update(self, proto, src_addr, src_port, state): - if not self.is_owner: - raise RuntimeError("Only owner can mutate ConnTrack") - src_addr = ipaddress.ip_address(src_addr) - packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) - for i in self.used_slots: - if self.shm_list[i].startswith(packed): - state_epoch = int(time.time()) - self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state) - debug3( - f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | " - f"#ActiveConn={len(self.used_slots)}" - ) - return self._unpack(self.shm_list[i]) - else: - debug3( - f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | " - f"#ActiveConn={len(self.used_slots)}" - ) - - @synchronized_method("rlock") - def remove(self, proto, src_addr, src_port): - if not self.is_owner: - raise RuntimeError("Only owner can mutate ConnTrack") - src_addr = ipaddress.ip_address(src_addr) - packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) - for i in self.used_slots: - if self.shm_list[i].startswith(packed): - conn = self._unpack(self.shm_list[i]) - self.shm_list[i] = b"" - self.used_slots.remove(i) - debug3( - f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | " - f"#ActiveConn={len(self.used_slots)}" - ) - return conn - else: - debug3( - f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |" - f" #ActiveConn={len(self.used_slots)}" - ) - - def get(self, proto, src_addr, src_port): - src_addr = ipaddress.ip_address(src_addr) - packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) - for entry in self.shm_list: - if entry and entry.startswith(packed): - return self._unpack(entry) - - @synchronized_method("rlock") - def gc(self, connection_timeout_sec=15): - now = int(time.time()) - n = 0 - for i in tuple(self.used_slots): - state_packed = self.shm_list[i][-5:] - (state_epoch, state) = self.struct_state_tuple.unpack(state_packed) - if (now - state_epoch) < connection_timeout_sec: - continue - if ConnState.can_timeout(state): - conn = self._unpack(self.shm_list[i]) - self.shm_list[i] = b"" - self.used_slots.remove(i) - n += 1 - debug3( - f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})" - f" from slot={i} | #ActiveConn={len(self.used_slots)}" - ) - debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}") - - def _unpack(self, packed): - ( - proto, - ip_version, - src_addr_packed, - src_port, - dst_addr_packed, - dst_port, - state_epoch, - state, - ) = self.struct_full_tuple.unpack(packed) - dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) - src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) - return ConnectionTuple( - IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) - ) - - def __iter__(self): - def conn_iter(): - for i in self.used_slots: - yield self._unpack(self.shm_list[i]) - - return conn_iter() - - def __repr__(self): - return f"" - - -class Method(BaseMethod): - - network_config = {} - proxy_port = None - proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None} - - def __init__(self, name): - super().__init__(name) - - def _get_local_proxy_listen_addr(self, port, family): - proto = "TCPv6" if family.version == 6 else "TCP" - for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines(): - try: - _, local_addr, _, state, *_ = re.split(r"\s+", line.strip()) - except ValueError: - continue - port_suffix = ":" + str(port) - if state == "LISTENING" and local_addr.endswith(port_suffix): - return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]")) - raise Fatal("Could not find listening address for {}/{}".format(port, proto)) - - def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark): - log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") - - if nslist or user or udp: - raise NotImplementedError() - - family = IPFamily(family) - - # using loopback proxy address never worked. - # >>> self.proxy_addr[family] = family.loopback_addr - # See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82) - # As a workaround we use another interface ip instead. - - local_addr = self._get_local_proxy_listen_addr(port, family) - for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): - if addr.is_loopback or addr.version != family.version: - continue - if local_addr.is_unspecified or local_addr == addr: - debug2("Found non loopback address to connect to proxy: " + str(addr)) - self.proxy_addr[family] = str(addr) - break - else: - raise Fatal("Windivert method requires proxy to listen on non loopback address") - - self.proxy_port = port - - subnet_addresses = [] - for (_, mask, exclude, network_addr, fport, lport) in subnets: - if exclude: - continue - assert fport == 0, "custom port range not supported" - assert lport == 0, "custom port range not supported" - subnet_addresses.append("%s/%s" % (network_addr, mask)) - - self.network_config[family] = { - "subnets": subnet_addresses, - "nslist": nslist, - } - - def wait_for_firewall_ready(self): - debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}") - self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS) - methods = (self._egress_divert, self._ingress_divert, self._connection_gc) - ready_events = [] - for fn in methods: - ev = threading.Event() - ready_events.append(ev) - - def _target(): - try: - fn(ev.set) - except Exception: - debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc()) - sys.stdin.close() # this will exist main thread - sys.stdout.close() - - threading.Thread(name=fn.__name__, target=_target, daemon=True).start() - for ev in ready_events: - if not ev.wait(5): # at most 5 sec - raise Fatal("timeout in wait_for_firewall_ready()") - - def restore_firewall(self, port, family, udp, user): - pass - - def get_supported_features(self): - result = super(Method, self).get_supported_features() - result.loopback_port = False - result.user = False - result.dns = False - result.ipv6 = False - return result - - def get_tcp_dstip(self, sock): - if not hasattr(self, "conntrack"): - self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}") - - src_addr, src_port = sock.getpeername() - c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port) - if not c: - return (src_addr, src_port) - return (c.dst_addr, c.dst_port) - - def is_supported(self): - if sys.platform == "win32": - return True - return False - - def _egress_divert(self, ready_cb): - proto = IPProtocol.TCP - filter = f"outbound and {proto.filter}" - - # with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w: - family_filters = [] - for af, c in self.network_config.items(): - subnet_filters = [] - for cidr in c["subnets"]: - ip_network = ipaddress.ip_network(cidr) - first_ip = ip_network.network_address - last_ip = ip_network.broadcast_address - subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})") - family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ") - - filter = f"{filter} and ({' or '.join(family_filters)})" - - debug1(f"[OUTBOUND] {filter=}") - with pydivert.WinDivert(filter) as w: - ready_cb() - proxy_port = self.proxy_port - proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4] - proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6] - for pkt in w: - debug3(">>> " + repr_pkt(pkt)) - if pkt.tcp.syn and not pkt.tcp.ack: - # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK) - self.conntrack.add( - socket.IPPROTO_TCP, - pkt.src_addr, - pkt.src_port, - pkt.dst_addr, - pkt.dst_port, - ConnState.TCP_SYN_SENT, - ) - if pkt.tcp.fin: - # FIN sent (start of graceful close our side, and we wait for ACK) - self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1) - if pkt.tcp.rst: - # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply) - self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port) - - # DNAT - if pkt.ipv4 and proxy_addr_ipv4: - pkt.dst_addr = proxy_addr_ipv4 - if pkt.ipv6 and proxy_addr_ipv6: - pkt.dst_addr = proxy_addr_ipv6 - pkt.tcp.dst_port = proxy_port - - # XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well - # by setting src_addr to loopback address. - # Otherwise injecting packet will be ignored by Windows network stack - # as they packet has to cross public to private address space. - # See: https://github.com/basil00/Divert/issues/82 - # Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets. - # >>> pkt.dst_addr = proxy_addr_ipv4 - - w.send(pkt, recalculate_checksum=True) - - def _ingress_divert(self, ready_cb): - proto = IPProtocol.TCP - direction = "inbound" # only when proxy address is not loopback address (Useful for testing) - ip_filters = [] - for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a): - if addr.is_loopback: # Windivert treats all loopback traffic as outbound - direction = "outbound" - if addr.version == 4: - ip_filters.append(f"ip.SrcAddr=={addr}") - else: - # ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this - ip_filters.append(f"ipv6.SrcAddr=={addr}") - if not ip_filters: - raise Fatal("At least ipv4 or ipv6 address is expected") - filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}" - debug1(f"[INGRESS] {filter=}") - with pydivert.WinDivert(filter) as w: - ready_cb() - for pkt in w: - debug3("<<< " + repr_pkt(pkt)) - if pkt.tcp.syn and pkt.tcp.ack: - # SYN+ACK received (connection established) - conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) - elif pkt.tcp.rst: - # RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets - conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) - # https://wiki.wireshark.org/TCP-4-times-close.md - elif pkt.tcp.fin and pkt.tcp.ack: - # FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK) - conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) - elif pkt.tcp.fin: - # FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet) - conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT) - else: - conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) - if not conn: - debug2("Unexpected packet: " + repr_pkt(pkt)) - continue - pkt.src_addr = conn.dst_addr - pkt.tcp.src_port = conn.dst_port - w.send(pkt, recalculate_checksum=True) - - def _connection_gc(self, ready_cb): - ready_cb() - while True: - time.sleep(5) - self.conntrack.gc() +import os +import sys +import ipaddress +import threading +from collections import namedtuple +import socket +import subprocess +import re +from multiprocessing import shared_memory +import struct +from functools import wraps +from enum import IntEnum +import time +import traceback + + +from sshuttle.methods import BaseMethod +from sshuttle.helpers import debug3, log, debug1, debug2, Fatal + +try: + # https://reqrypt.org/windivert-doc.html#divert_iphdr + import pydivert +except ImportError: + raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert") + + +ConnectionTuple = namedtuple( + "ConnectionTuple", + ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"], +) + + +WINDIVERT_MAX_CONNECTIONS = 10_000 + + +class IPProtocol(IntEnum): + TCP = socket.IPPROTO_TCP + UDP = socket.IPPROTO_UDP + + @property + def filter(self): + return "tcp" if self == IPProtocol.TCP else "udp" + + +class IPFamily(IntEnum): + IPv4 = socket.AF_INET + IPv6 = socket.AF_INET6 + + @property + def filter(self): + return "ip" if self == socket.AF_INET else "ipv6" + + @property + def version(self): + return 4 if self == socket.AF_INET else 6 + + @property + def loopback_addr(self): + return "127.0.0.1" if self == socket.AF_INET else "::1" + + +class ConnState(IntEnum): + TCP_SYN_SENT = 11 # SYN sent + TCP_ESTABLISHED = 12 # SYN+ACK received + TCP_FIN_WAIT_1 = 91 # FIN sent + TCP_CLOSE_WAIT = 92 # FIN received + + @staticmethod + def can_timeout(state): + return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT) + + +def repr_pkt(p): + r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}" + if p.tcp: + t = p.tcp + r += f" {len(t.payload)}B (" + r += "+".join( + f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f) + ) + r += f") SEQ#{t.seq_num}" + if t.ack: + r += f" ACK#{t.ack_num}" + r += f" WZ={t.window_size}" + else: + r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}" + return f"" + + +def synchronized_method(lock): + def decorator(method): + @wraps(method) + def wrapped(self, *args, **kwargs): + with getattr(self, lock): + return method(self, *args, **kwargs) + + return wrapped + + return decorator + + +class ConnTrack: + + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = object.__new__(cls) + return cls._instance + raise RuntimeError("ConnTrack can not be instantiated multiple times") + + def __init__(self, name, max_connections=0) -> None: + self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B"))) + self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H"))) + self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B"))) + + try: + self.max_connections = max_connections + self.shm_list = shared_memory.ShareableList( + [bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name + ) + self.is_owner = True + self.next_slot = 0 + self.used_slots = set() + self.rlock = threading.RLock() + except FileExistsError: + self.is_owner = False + self.shm_list = shared_memory.ShareableList(name=name) + self.max_connections = len(self.shm_list) + + debug2( + f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} " + f"shm_size={self.shm_list.shm.size}B" + ) + + @synchronized_method("rlock") + def add(self, proto, src_addr, src_port, dst_addr, dst_port, state): + if not self.is_owner: + raise RuntimeError("Only owner can mutate ConnTrack") + if len(self.used_slots) >= self.max_connections: + raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}") + + if self.get(proto, src_addr, src_port): + return + + for _ in range(self.max_connections): + if self.next_slot not in self.used_slots: + break + self.next_slot = (self.next_slot + 1) % self.max_connections + else: + raise RuntimeError("No slot available in ConnTrack") # should not be here + + src_addr = ipaddress.ip_address(src_addr) + dst_addr = ipaddress.ip_address(dst_addr) + assert src_addr.version == dst_addr.version + ip_version = src_addr.version + state_epoch = int(time.time()) + entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state) + packed = self.struct_full_tuple.pack(*entry) + self.shm_list[self.next_slot] = packed + self.used_slots.add(self.next_slot) + proto = IPProtocol(proto) + debug3( + f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to " + f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}" + ) + + @synchronized_method("rlock") + def update(self, proto, src_addr, src_port, state): + if not self.is_owner: + raise RuntimeError("Only owner can mutate ConnTrack") + src_addr = ipaddress.ip_address(src_addr) + packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) + for i in self.used_slots: + if self.shm_list[i].startswith(packed): + state_epoch = int(time.time()) + self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state) + debug3( + f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | " + f"#ActiveConn={len(self.used_slots)}" + ) + return self._unpack(self.shm_list[i]) + else: + debug3( + f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | " + f"#ActiveConn={len(self.used_slots)}" + ) + + @synchronized_method("rlock") + def remove(self, proto, src_addr, src_port): + if not self.is_owner: + raise RuntimeError("Only owner can mutate ConnTrack") + src_addr = ipaddress.ip_address(src_addr) + packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) + for i in self.used_slots: + if self.shm_list[i].startswith(packed): + conn = self._unpack(self.shm_list[i]) + self.shm_list[i] = b"" + self.used_slots.remove(i) + debug3( + f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | " + f"#ActiveConn={len(self.used_slots)}" + ) + return conn + else: + debug3( + f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |" + f" #ActiveConn={len(self.used_slots)}" + ) + + def get(self, proto, src_addr, src_port): + src_addr = ipaddress.ip_address(src_addr) + packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) + for entry in self.shm_list: + if entry and entry.startswith(packed): + return self._unpack(entry) + + @synchronized_method("rlock") + def gc(self, connection_timeout_sec=15): + now = int(time.time()) + n = 0 + for i in tuple(self.used_slots): + state_packed = self.shm_list[i][-5:] + (state_epoch, state) = self.struct_state_tuple.unpack(state_packed) + if (now - state_epoch) < connection_timeout_sec: + continue + if ConnState.can_timeout(state): + conn = self._unpack(self.shm_list[i]) + self.shm_list[i] = b"" + self.used_slots.remove(i) + n += 1 + debug3( + f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})" + f" from slot={i} | #ActiveConn={len(self.used_slots)}" + ) + debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}") + + def _unpack(self, packed): + ( + proto, + ip_version, + src_addr_packed, + src_port, + dst_addr_packed, + dst_port, + state_epoch, + state, + ) = self.struct_full_tuple.unpack(packed) + dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) + src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) + return ConnectionTuple( + IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) + ) + + def __iter__(self): + def conn_iter(): + for i in self.used_slots: + yield self._unpack(self.shm_list[i]) + + return conn_iter() + + def __repr__(self): + return f"" + + +class Method(BaseMethod): + + network_config = {} + proxy_port = None + proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None} + + def __init__(self, name): + super().__init__(name) + + def _get_local_proxy_listen_addr(self, port, family): + proto = "TCPv6" if family.version == 6 else "TCP" + for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines(): + try: + _, local_addr, _, state, *_ = re.split(r"\s+", line.strip()) + except ValueError: + continue + port_suffix = ":" + str(port) + if state == "LISTENING" and local_addr.endswith(port_suffix): + return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]")) + raise Fatal("Could not find listening address for {}/{}".format(port, proto)) + + def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark): + log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") + + if nslist or user or udp: + raise NotImplementedError() + + family = IPFamily(family) + + # using loopback proxy address never worked. + # >>> self.proxy_addr[family] = family.loopback_addr + # See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82) + # As a workaround we use another interface ip instead. + + local_addr = self._get_local_proxy_listen_addr(port, family) + for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): + if addr.is_loopback or addr.version != family.version: + continue + if local_addr.is_unspecified or local_addr == addr: + debug2("Found non loopback address to connect to proxy: " + str(addr)) + self.proxy_addr[family] = str(addr) + break + else: + raise Fatal("Windivert method requires proxy to listen on non loopback address") + + self.proxy_port = port + + subnet_addresses = [] + for (_, mask, exclude, network_addr, fport, lport) in subnets: + if exclude: + continue + assert fport == 0, "custom port range not supported" + assert lport == 0, "custom port range not supported" + subnet_addresses.append("%s/%s" % (network_addr, mask)) + + self.network_config[family] = { + "subnets": subnet_addresses, + "nslist": nslist, + } + + def wait_for_firewall_ready(self): + debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}") + self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS) + methods = (self._egress_divert, self._ingress_divert, self._connection_gc) + ready_events = [] + for fn in methods: + ev = threading.Event() + ready_events.append(ev) + + def _target(): + try: + fn(ev.set) + except Exception: + debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc()) + sys.stdin.close() # this will exist main thread + sys.stdout.close() + + threading.Thread(name=fn.__name__, target=_target, daemon=True).start() + for ev in ready_events: + if not ev.wait(5): # at most 5 sec + raise Fatal("timeout in wait_for_firewall_ready()") + + def restore_firewall(self, port, family, udp, user, group): + pass + + def get_supported_features(self): + result = super(Method, self).get_supported_features() + result.loopback_port = False + result.user = False + result.dns = False + result.ipv6 = False + return result + + def get_tcp_dstip(self, sock): + if not hasattr(self, "conntrack"): + self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}") + + src_addr, src_port = sock.getpeername() + c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port) + if not c: + return (src_addr, src_port) + return (c.dst_addr, c.dst_port) + + def is_supported(self): + if sys.platform == "win32": + return True + return False + + def _egress_divert(self, ready_cb): + proto = IPProtocol.TCP + filter = f"outbound and {proto.filter}" + + # with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w: + family_filters = [] + for af, c in self.network_config.items(): + subnet_filters = [] + for cidr in c["subnets"]: + ip_network = ipaddress.ip_network(cidr) + first_ip = ip_network.network_address + last_ip = ip_network.broadcast_address + subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})") + family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ") + + filter = f"{filter} and ({' or '.join(family_filters)})" + + debug1(f"[OUTBOUND] {filter=}") + with pydivert.WinDivert(filter) as w: + ready_cb() + proxy_port = self.proxy_port + proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4] + proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6] + for pkt in w: + debug3(">>> " + repr_pkt(pkt)) + if pkt.tcp.syn and not pkt.tcp.ack: + # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK) + self.conntrack.add( + socket.IPPROTO_TCP, + pkt.src_addr, + pkt.src_port, + pkt.dst_addr, + pkt.dst_port, + ConnState.TCP_SYN_SENT, + ) + if pkt.tcp.fin: + # FIN sent (start of graceful close our side, and we wait for ACK) + self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1) + if pkt.tcp.rst: + # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply) + self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port) + + # DNAT + if pkt.ipv4 and proxy_addr_ipv4: + pkt.dst_addr = proxy_addr_ipv4 + if pkt.ipv6 and proxy_addr_ipv6: + pkt.dst_addr = proxy_addr_ipv6 + pkt.tcp.dst_port = proxy_port + + # XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well + # by setting src_addr to loopback address. + # Otherwise injecting packet will be ignored by Windows network stack + # as they packet has to cross public to private address space. + # See: https://github.com/basil00/Divert/issues/82 + # Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets. + # >>> pkt.dst_addr = proxy_addr_ipv4 + + w.send(pkt, recalculate_checksum=True) + + def _ingress_divert(self, ready_cb): + proto = IPProtocol.TCP + direction = "inbound" # only when proxy address is not loopback address (Useful for testing) + ip_filters = [] + for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a): + if addr.is_loopback: # Windivert treats all loopback traffic as outbound + direction = "outbound" + if addr.version == 4: + ip_filters.append(f"ip.SrcAddr=={addr}") + else: + # ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this + ip_filters.append(f"ipv6.SrcAddr=={addr}") + if not ip_filters: + raise Fatal("At least ipv4 or ipv6 address is expected") + filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}" + debug1(f"[INGRESS] {filter=}") + with pydivert.WinDivert(filter) as w: + ready_cb() + for pkt in w: + debug3("<<< " + repr_pkt(pkt)) + if pkt.tcp.syn and pkt.tcp.ack: + # SYN+ACK received (connection established) + conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) + elif pkt.tcp.rst: + # RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets + conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) + # https://wiki.wireshark.org/TCP-4-times-close.md + elif pkt.tcp.fin and pkt.tcp.ack: + # FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK) + conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) + elif pkt.tcp.fin: + # FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet) + conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT) + else: + conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) + if not conn: + debug2("Unexpected packet: " + repr_pkt(pkt)) + continue + pkt.src_addr = conn.dst_addr + pkt.tcp.src_port = conn.dst_port + w.send(pkt, recalculate_checksum=True) + + def _connection_gc(self, ready_cb): + ready_cb() + while True: + time.sleep(5) + self.conntrack.gc() diff --git a/sshuttle/ssh.py b/sshuttle/ssh.py index 3942165..7b494d2 100644 --- a/sshuttle/ssh.py +++ b/sshuttle/ssh.py @@ -260,7 +260,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options): threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start() threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start() - return s2 + return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0) # https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows close_fds = False if sys.platform == 'win32' else True diff --git a/tests/client/test_firewall.py b/tests/client/test_firewall.py index e714c81..02a73f7 100644 --- a/tests/client/test_firewall.py +++ b/tests/client/test_firewall.py @@ -10,7 +10,7 @@ import sshuttle.firewall def setup_daemon(): - stdin = io.StringIO(u"""ROUTES + stdin = io.BytesIO(u"""ROUTES {inet},24,0,1.2.3.0,8000,9000 {inet},32,1,1.2.3.66,8080,8080 {inet6},64,0,2404:6800:4004:80c::,0,0 @@ -127,9 +127,9 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts): ] assert stdout.mock_calls == [ - call.write('READY test\n'), + call.write(b'READY test\n'), call.flush(), - call.write('STARTED\n'), + call.write(b'STARTED\n'), call.flush() ] assert mock_setup_daemon.mock_calls == [call()]