From 7a92183f5950262102b68aadf23c0cd35d15afc0 Mon Sep 17 00:00:00 2001 From: nom3ad <19239479+nom3ad@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:04:04 +0530 Subject: [PATCH] windows: better connection tracker --- sshuttle/firewall.py | 2 +- sshuttle/methods/__init__.py | 2 +- sshuttle/methods/windivert.py | 29 +++++++++++++++++++++-------- tests/client/test_firewall.py | 2 +- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/sshuttle/firewall.py b/sshuttle/firewall.py index bbeaaa4..0b34f68 100644 --- a/sshuttle/firewall.py +++ b/sshuttle/firewall.py @@ -348,7 +348,7 @@ def main(method_name, syslog): try: # For some methods (eg: windivert) firewall setup will be differed / will run asynchronously. # Such method implements wait_for_firewall_ready() to wait until firewall is up and running. - method.wait_for_firewall_ready() + method.wait_for_firewall_ready(sshuttle_pid) except NotImplementedError: pass diff --git a/sshuttle/methods/__init__.py b/sshuttle/methods/__init__.py index 49da095..0f56e59 100644 --- a/sshuttle/methods/__init__.py +++ b/sshuttle/methods/__init__.py @@ -98,7 +98,7 @@ class BaseMethod(object): def restore_firewall(self, port, family, udp, user, group): raise NotImplementedError() - def wait_for_firewall_ready(self): + def wait_for_firewall_ready(self, sshuttle_pid): raise NotImplementedError() @staticmethod diff --git a/sshuttle/methods/windivert.py b/sshuttle/methods/windivert.py index d9d1524..2bf4674 100644 --- a/sshuttle/methods/windivert.py +++ b/sshuttle/methods/windivert.py @@ -15,7 +15,7 @@ import traceback from sshuttle.methods import BaseMethod -from sshuttle.helpers import debug3, debug1, debug2, get_verbose_level, Fatal +from sshuttle.helpers import log, debug3, debug1, debug2, get_verbose_level, Fatal try: # https://reqrypt.org/windivert-doc.html#divert_iphdr @@ -228,8 +228,17 @@ class ConnTrack: if entry and entry.startswith(packed): return self._unpack(entry) + def dump(self): + for entry in self.shm_list: + if not entry: + continue + conn = self._unpack(entry) + proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state = conn + log(f"{proto.name}/{ip_version} {src_addr}:{src_port} -> {dst_addr}:{dst_port} {state.name}@{state_epoch}") + @synchronized_method("rlock") def gc(self, connection_timeout_sec=15): + # self.dump() now = int(time.time()) n = 0 for i in tuple(self.used_slots): @@ -261,9 +270,9 @@ class ConnTrack: ) = self.struct_full_tuple.unpack(packed) dst_addr = ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]).exploded src_addr = ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]).exploded - return ConnectionTuple( - IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) - ) + proto = IPProtocol(proto) + state = ConnState(state) + return ConnectionTuple(proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state) def __iter__(self): def conn_iter(): @@ -338,12 +347,14 @@ class Method(BaseMethod): "proxy_addr": (proxy_ip, proxy_port) } - def wait_for_firewall_ready(self): + def wait_for_firewall_ready(self, sshuttle_pid): debug2(f"network_config={self.network_config}") - self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS) - methods = (self._egress_divert, self._ingress_divert, self._connection_gc) + self.conntrack = ConnTrack(f"sshuttle-windivert-{sshuttle_pid}", WINDIVERT_MAX_CONNECTIONS) + if not self.conntrack.is_owner: + raise Fatal("ConnTrack should be owner in wait_for_firewall_ready()") + thread_target_funcs = (self._egress_divert, self._ingress_divert, self._connection_gc) ready_events = [] - for fn in methods: + for fn in thread_target_funcs: ev = threading.Event() ready_events.append(ev) @@ -376,6 +387,8 @@ class Method(BaseMethod): def get_tcp_dstip(self, sock): if not hasattr(self, "conntrack"): self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}") + if self.conntrack.is_owner: + raise Fatal("ConnTrack should not be owner in get_tcp_dstip()") src_addr, src_port = sock.getpeername() c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port) diff --git a/tests/client/test_firewall.py b/tests/client/test_firewall.py index f9d4fdb..a953527 100644 --- a/tests/client/test_firewall.py +++ b/tests/client/test_firewall.py @@ -157,7 +157,7 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts): None, None, '0x01'), - call().wait_for_firewall_ready(), + call().wait_for_firewall_ready(os.getpid()), call().restore_firewall(1024, AF_INET6, True, None, None), call().restore_firewall(1025, AF_INET, True, None, None), ]