windows: better connection tracker

This commit is contained in:
nom3ad 2024-01-09 23:04:04 +05:30 committed by Brian May
parent 81a598a4cc
commit 7a92183f59
4 changed files with 24 additions and 11 deletions

View File

@ -348,7 +348,7 @@ def main(method_name, syslog):
try:
# For some methods (eg: windivert) firewall setup will be differed / will run asynchronously.
# Such method implements wait_for_firewall_ready() to wait until firewall is up and running.
method.wait_for_firewall_ready()
method.wait_for_firewall_ready(sshuttle_pid)
except NotImplementedError:
pass

View File

@ -98,7 +98,7 @@ class BaseMethod(object):
def restore_firewall(self, port, family, udp, user, group):
raise NotImplementedError()
def wait_for_firewall_ready(self):
def wait_for_firewall_ready(self, sshuttle_pid):
raise NotImplementedError()
@staticmethod

View File

@ -15,7 +15,7 @@ import traceback
from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug3, debug1, debug2, get_verbose_level, Fatal
from sshuttle.helpers import log, debug3, debug1, debug2, get_verbose_level, Fatal
try:
# https://reqrypt.org/windivert-doc.html#divert_iphdr
@ -228,8 +228,17 @@ class ConnTrack:
if entry and entry.startswith(packed):
return self._unpack(entry)
def dump(self):
for entry in self.shm_list:
if not entry:
continue
conn = self._unpack(entry)
proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state = conn
log(f"{proto.name}/{ip_version} {src_addr}:{src_port} -> {dst_addr}:{dst_port} {state.name}@{state_epoch}")
@synchronized_method("rlock")
def gc(self, connection_timeout_sec=15):
# self.dump()
now = int(time.time())
n = 0
for i in tuple(self.used_slots):
@ -261,9 +270,9 @@ class ConnTrack:
) = self.struct_full_tuple.unpack(packed)
dst_addr = ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]).exploded
src_addr = ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]).exploded
return ConnectionTuple(
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
)
proto = IPProtocol(proto)
state = ConnState(state)
return ConnectionTuple(proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state)
def __iter__(self):
def conn_iter():
@ -338,12 +347,14 @@ class Method(BaseMethod):
"proxy_addr": (proxy_ip, proxy_port)
}
def wait_for_firewall_ready(self):
def wait_for_firewall_ready(self, sshuttle_pid):
debug2(f"network_config={self.network_config}")
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
self.conntrack = ConnTrack(f"sshuttle-windivert-{sshuttle_pid}", WINDIVERT_MAX_CONNECTIONS)
if not self.conntrack.is_owner:
raise Fatal("ConnTrack should be owner in wait_for_firewall_ready()")
thread_target_funcs = (self._egress_divert, self._ingress_divert, self._connection_gc)
ready_events = []
for fn in methods:
for fn in thread_target_funcs:
ev = threading.Event()
ready_events.append(ev)
@ -376,6 +387,8 @@ class Method(BaseMethod):
def get_tcp_dstip(self, sock):
if not hasattr(self, "conntrack"):
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
if self.conntrack.is_owner:
raise Fatal("ConnTrack should not be owner in get_tcp_dstip()")
src_addr, src_port = sock.getpeername()
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)

View File

@ -157,7 +157,7 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
None,
None,
'0x01'),
call().wait_for_firewall_ready(),
call().wait_for_firewall_ready(os.getpid()),
call().restore_firewall(1024, AF_INET6, True, None, None),
call().restore_firewall(1025, AF_INET, True, None, None),
]