windows: better connection tracker

This commit is contained in:
nom3ad 2024-01-09 23:04:04 +05:30 committed by Brian May
parent 81a598a4cc
commit 7a92183f59
4 changed files with 24 additions and 11 deletions

View File

@ -348,7 +348,7 @@ def main(method_name, syslog):
try: try:
# For some methods (eg: windivert) firewall setup will be differed / will run asynchronously. # For some methods (eg: windivert) firewall setup will be differed / will run asynchronously.
# Such method implements wait_for_firewall_ready() to wait until firewall is up and running. # Such method implements wait_for_firewall_ready() to wait until firewall is up and running.
method.wait_for_firewall_ready() method.wait_for_firewall_ready(sshuttle_pid)
except NotImplementedError: except NotImplementedError:
pass pass

View File

@ -98,7 +98,7 @@ class BaseMethod(object):
def restore_firewall(self, port, family, udp, user, group): def restore_firewall(self, port, family, udp, user, group):
raise NotImplementedError() raise NotImplementedError()
def wait_for_firewall_ready(self): def wait_for_firewall_ready(self, sshuttle_pid):
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod

View File

@ -15,7 +15,7 @@ import traceback
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug3, debug1, debug2, get_verbose_level, Fatal from sshuttle.helpers import log, debug3, debug1, debug2, get_verbose_level, Fatal
try: try:
# https://reqrypt.org/windivert-doc.html#divert_iphdr # https://reqrypt.org/windivert-doc.html#divert_iphdr
@ -228,8 +228,17 @@ class ConnTrack:
if entry and entry.startswith(packed): if entry and entry.startswith(packed):
return self._unpack(entry) return self._unpack(entry)
def dump(self):
for entry in self.shm_list:
if not entry:
continue
conn = self._unpack(entry)
proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state = conn
log(f"{proto.name}/{ip_version} {src_addr}:{src_port} -> {dst_addr}:{dst_port} {state.name}@{state_epoch}")
@synchronized_method("rlock") @synchronized_method("rlock")
def gc(self, connection_timeout_sec=15): def gc(self, connection_timeout_sec=15):
# self.dump()
now = int(time.time()) now = int(time.time())
n = 0 n = 0
for i in tuple(self.used_slots): for i in tuple(self.used_slots):
@ -261,9 +270,9 @@ class ConnTrack:
) = self.struct_full_tuple.unpack(packed) ) = self.struct_full_tuple.unpack(packed)
dst_addr = ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]).exploded dst_addr = ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]).exploded
src_addr = ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]).exploded src_addr = ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]).exploded
return ConnectionTuple( proto = IPProtocol(proto)
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) state = ConnState(state)
) return ConnectionTuple(proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state)
def __iter__(self): def __iter__(self):
def conn_iter(): def conn_iter():
@ -338,12 +347,14 @@ class Method(BaseMethod):
"proxy_addr": (proxy_ip, proxy_port) "proxy_addr": (proxy_ip, proxy_port)
} }
def wait_for_firewall_ready(self): def wait_for_firewall_ready(self, sshuttle_pid):
debug2(f"network_config={self.network_config}") debug2(f"network_config={self.network_config}")
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS) self.conntrack = ConnTrack(f"sshuttle-windivert-{sshuttle_pid}", WINDIVERT_MAX_CONNECTIONS)
methods = (self._egress_divert, self._ingress_divert, self._connection_gc) if not self.conntrack.is_owner:
raise Fatal("ConnTrack should be owner in wait_for_firewall_ready()")
thread_target_funcs = (self._egress_divert, self._ingress_divert, self._connection_gc)
ready_events = [] ready_events = []
for fn in methods: for fn in thread_target_funcs:
ev = threading.Event() ev = threading.Event()
ready_events.append(ev) ready_events.append(ev)
@ -376,6 +387,8 @@ class Method(BaseMethod):
def get_tcp_dstip(self, sock): def get_tcp_dstip(self, sock):
if not hasattr(self, "conntrack"): if not hasattr(self, "conntrack"):
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}") self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
if self.conntrack.is_owner:
raise Fatal("ConnTrack should not be owner in get_tcp_dstip()")
src_addr, src_port = sock.getpeername() src_addr, src_port = sock.getpeername()
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port) c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)

View File

@ -157,7 +157,7 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
None, None,
None, None,
'0x01'), '0x01'),
call().wait_for_firewall_ready(), call().wait_for_firewall_ready(os.getpid()),
call().restore_firewall(1024, AF_INET6, True, None, None), call().restore_firewall(1024, AF_INET6, True, None, None),
call().restore_firewall(1025, AF_INET, True, None, None), call().restore_firewall(1025, AF_INET, True, None, None),
] ]