mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-25 01:13:37 +01:00
windows: better connection tracker
This commit is contained in:
parent
81a598a4cc
commit
7a92183f59
@ -348,7 +348,7 @@ def main(method_name, syslog):
|
|||||||
try:
|
try:
|
||||||
# For some methods (eg: windivert) firewall setup will be differed / will run asynchronously.
|
# For some methods (eg: windivert) firewall setup will be differed / will run asynchronously.
|
||||||
# Such method implements wait_for_firewall_ready() to wait until firewall is up and running.
|
# Such method implements wait_for_firewall_ready() to wait until firewall is up and running.
|
||||||
method.wait_for_firewall_ready()
|
method.wait_for_firewall_ready(sshuttle_pid)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ class BaseMethod(object):
|
|||||||
def restore_firewall(self, port, family, udp, user, group):
|
def restore_firewall(self, port, family, udp, user, group):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def wait_for_firewall_ready(self):
|
def wait_for_firewall_ready(self, sshuttle_pid):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -15,7 +15,7 @@ import traceback
|
|||||||
|
|
||||||
|
|
||||||
from sshuttle.methods import BaseMethod
|
from sshuttle.methods import BaseMethod
|
||||||
from sshuttle.helpers import debug3, debug1, debug2, get_verbose_level, Fatal
|
from sshuttle.helpers import log, debug3, debug1, debug2, get_verbose_level, Fatal
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
||||||
@ -228,8 +228,17 @@ class ConnTrack:
|
|||||||
if entry and entry.startswith(packed):
|
if entry and entry.startswith(packed):
|
||||||
return self._unpack(entry)
|
return self._unpack(entry)
|
||||||
|
|
||||||
|
def dump(self):
|
||||||
|
for entry in self.shm_list:
|
||||||
|
if not entry:
|
||||||
|
continue
|
||||||
|
conn = self._unpack(entry)
|
||||||
|
proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state = conn
|
||||||
|
log(f"{proto.name}/{ip_version} {src_addr}:{src_port} -> {dst_addr}:{dst_port} {state.name}@{state_epoch}")
|
||||||
|
|
||||||
@synchronized_method("rlock")
|
@synchronized_method("rlock")
|
||||||
def gc(self, connection_timeout_sec=15):
|
def gc(self, connection_timeout_sec=15):
|
||||||
|
# self.dump()
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
n = 0
|
n = 0
|
||||||
for i in tuple(self.used_slots):
|
for i in tuple(self.used_slots):
|
||||||
@ -261,9 +270,9 @@ class ConnTrack:
|
|||||||
) = self.struct_full_tuple.unpack(packed)
|
) = self.struct_full_tuple.unpack(packed)
|
||||||
dst_addr = ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]).exploded
|
dst_addr = ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]).exploded
|
||||||
src_addr = ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]).exploded
|
src_addr = ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]).exploded
|
||||||
return ConnectionTuple(
|
proto = IPProtocol(proto)
|
||||||
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
|
state = ConnState(state)
|
||||||
)
|
return ConnectionTuple(proto, ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, state)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
def conn_iter():
|
def conn_iter():
|
||||||
@ -338,12 +347,14 @@ class Method(BaseMethod):
|
|||||||
"proxy_addr": (proxy_ip, proxy_port)
|
"proxy_addr": (proxy_ip, proxy_port)
|
||||||
}
|
}
|
||||||
|
|
||||||
def wait_for_firewall_ready(self):
|
def wait_for_firewall_ready(self, sshuttle_pid):
|
||||||
debug2(f"network_config={self.network_config}")
|
debug2(f"network_config={self.network_config}")
|
||||||
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
|
self.conntrack = ConnTrack(f"sshuttle-windivert-{sshuttle_pid}", WINDIVERT_MAX_CONNECTIONS)
|
||||||
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
if not self.conntrack.is_owner:
|
||||||
|
raise Fatal("ConnTrack should be owner in wait_for_firewall_ready()")
|
||||||
|
thread_target_funcs = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
||||||
ready_events = []
|
ready_events = []
|
||||||
for fn in methods:
|
for fn in thread_target_funcs:
|
||||||
ev = threading.Event()
|
ev = threading.Event()
|
||||||
ready_events.append(ev)
|
ready_events.append(ev)
|
||||||
|
|
||||||
@ -376,6 +387,8 @@ class Method(BaseMethod):
|
|||||||
def get_tcp_dstip(self, sock):
|
def get_tcp_dstip(self, sock):
|
||||||
if not hasattr(self, "conntrack"):
|
if not hasattr(self, "conntrack"):
|
||||||
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
|
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
|
||||||
|
if self.conntrack.is_owner:
|
||||||
|
raise Fatal("ConnTrack should not be owner in get_tcp_dstip()")
|
||||||
|
|
||||||
src_addr, src_port = sock.getpeername()
|
src_addr, src_port = sock.getpeername()
|
||||||
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
|
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
|
||||||
|
@ -157,7 +157,7 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
'0x01'),
|
'0x01'),
|
||||||
call().wait_for_firewall_ready(),
|
call().wait_for_firewall_ready(os.getpid()),
|
||||||
call().restore_firewall(1024, AF_INET6, True, None, None),
|
call().restore_firewall(1024, AF_INET6, True, None, None),
|
||||||
call().restore_firewall(1025, AF_INET, True, None, None),
|
call().restore_firewall(1025, AF_INET, True, None, None),
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user