diff --git a/sshuttle/methods/windivert.py b/sshuttle/methods/windivert.py index a962feb..d9d1524 100644 --- a/sshuttle/methods/windivert.py +++ b/sshuttle/methods/windivert.py @@ -7,7 +7,7 @@ import socket import subprocess import re from multiprocessing import shared_memory -import struct +from struct import Struct from functools import wraps from enum import IntEnum import time @@ -61,7 +61,7 @@ class IPFamily(IntEnum): @property def loopback_addr(self): - return "127.0.0.1" if self == socket.AF_INET else "::1" + return ip_address("127.0.0.1" if self == socket.AF_INET else "::1") class ConnState(IntEnum): @@ -123,9 +123,9 @@ class ConnTrack: raise RuntimeError("ConnTrack can not be instantiated multiple times") def __init__(self, name, max_connections=0) -> None: - self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B"))) - self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H"))) - self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B"))) + self.struct_full_tuple = Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B"))) + self.struct_src_tuple = Struct(">" + "".join(("B", "B", "16s", "H"))) + self.struct_state_tuple = Struct(">" + "".join(("L", "B"))) try: self.max_connections = max_connections @@ -142,8 +142,8 @@ class ConnTrack: self.max_connections = len(self.shm_list) debug2( - f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} " - f"shm_size={self.shm_list.shm.size}B" + f"ConnTrack: is_owner={self.is_owner} cap={len(self.shm_list)} item_sz={self.struct_full_tuple.size}B" + f"shm_name={self.shm_list.shm.name} shm_sz={self.shm_list.shm.size}B" ) @synchronized_method("rlock") @@ -279,7 +279,6 @@ class ConnTrack: class Method(BaseMethod): network_config = {} - proxy_port = None def __init__(self, name): super().__init__(name) @@ -297,10 +296,10 @@ class Method(BaseMethod): raise Fatal("Could not find listening address for {}/{}".format(port, proto)) def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark): - debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") + debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {group=} {tmark=}") - if nslist or user or udp: - raise NotImplementedError() + if nslist or user or udp or group: + raise NotImplementedError("user, group, nslist, udp are not supported") family = IPFamily(family) @@ -312,24 +311,26 @@ class Method(BaseMethod): if proxy_bind_addr.is_loopback: raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.") if not proxy_bind_addr.is_unspecified: - proxy_ip = proxy_bind_addr.exploded + proxy_ip = proxy_bind_addr else: local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)] for addr in local_addresses: if not addr.is_loopback and not addr.is_link_local: - proxy_ip = addr.exploded + proxy_ip = addr break else: raise Fatal("Windivert method requires proxy to be reachable by a non loopback address." f"No address found for {family.name} in {local_addresses}") - debug2("Found non loopback address to connect to proxy: " + proxy_ip) + debug2(f"Found non loopback address to connect to proxy: {proxy_ip}") subnet_addresses = [] for (_, mask, exclude, network_addr, fport, lport) in subnets: - if exclude: - continue - assert fport == 0, "custom port range not supported" - assert lport == 0, "custom port range not supported" - subnet_addresses.append("%s/%s" % (network_addr, mask)) + if fport and lport: + if lport > fport: + raise Fatal("lport must be less than or equal to fport") + ports = (fport, lport) + else: + ports = None + subnet_addresses.append((ip_network(f"{network_addr}/{mask}"), ports, exclude)) self.network_config[family] = { "subnets": subnet_addresses, @@ -391,30 +392,48 @@ class Method(BaseMethod): """divert outgoing packets to proxy""" proto = IPProtocol.TCP filter = f"outbound and {proto.filter}" - - # with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w: - family_filters = [] + af_filters = [] for af, c in self.network_config.items(): - subnet_filters = [] - for cidr in c["subnets"]: - ip_net = ip_network(cidr) + subnet_include_filters = [] + subnet_exclude_filters = [] + for ip_net, ports, exclude in c["subnets"]: first_ip = ip_net.network_address.exploded last_ip = ip_net.broadcast_address.exploded - subnet_filters.append(f"({af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip})") - if not subnet_filters: - continue + if first_ip == last_ip: + _subney_filter = f"{af.filter}.DstAddr=={first_ip}" + else: + _subney_filter = f"{af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip}" + if ports: + if ports[0] == ports[1]: + _subney_filter += f" and {proto.filter}.DstPort=={ports[0]}" + else: + _subney_filter += f" and tcp.DstPort>={ports[0]} and tcp.DstPort<={ports[1]}" + (subnet_exclude_filters if exclude else subnet_include_filters).append(f'({_subney_filter})') + _af_filter = f"{af.filter}" + if subnet_include_filters: + _af_filter += f" and ({' or '.join(subnet_include_filters)})" + if subnet_exclude_filters: + # TODO(noma3ad) use not() operator with Windivert2 after upgrade + _af_filter += f" and (({' or '.join(subnet_exclude_filters)})? false : true)" proxy_ip, proxy_port = c["proxy_addr"] - proxy_guard_filter = f'({af.filter}.DstAddr!={proxy_ip} or tcp.DstPort!={proxy_port})' - family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) and {proxy_guard_filter}") - if not family_filters: + # Avoids proxy outbound traffic getting directed to itself + proxy_guard_filter = f'(({af.filter}.DstAddr=={proxy_ip.exploded} and tcp.DstPort=={proxy_port})? false : true)' + _af_filter += f" and {proxy_guard_filter}" + af_filters.append(_af_filter) + if not af_filters: raise Fatal("At least one ipv4 or ipv6 subnet is expected") - filter = f"{filter} and ({' or '.join(family_filters)})" + filter = f"{filter} and ({' or '.join(af_filters)})" debug1(f"[EGRESS] {filter=}") with pydivert.WinDivert(filter, layer=pydivert.Layer.NETWORK, flags=pydivert.Flag.DEFAULT) as w: + proxy_ipv4, proxy_ipv6 = None, None + if IPFamily.IPv4 in self.network_config: + proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"] + proxy_ipv4 = proxy_ipv4[0].exploded, proxy_ipv4[1] + if IPFamily.IPv6 in self.network_config: + proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"] + proxy_ipv6 = proxy_ipv6[0].exploded, proxy_ipv6[1] ready_cb() - proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"] if IPFamily.IPv4 in self.network_config else None - proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"] if IPFamily.IPv6 in self.network_config else None verbose = get_verbose_level() for pkt in w: verbose >= 3 and debug3("[EGRESS] " + repr_pkt(pkt)) @@ -461,7 +480,7 @@ class Method(BaseMethod): continue proxy_ip, proxy_port = c["proxy_addr"] # "ip.SrcAddr=={hex(int(proxy_ip))}" # only Windivert >=2 supports this - proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip} and tcp.SrcPort=={proxy_port}") + proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip.exploded} and tcp.SrcPort=={proxy_port}") if not proxy_addr_filters: raise Fatal("At least one ipv4 or ipv6 address is expected") filter = f"{direction} and {proto.filter} and ({' or '.join(proxy_addr_filters)})"