mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-04-23 10:49:35 +02:00
windivert: garbage collect timed put connections from tracker
This commit is contained in:
parent
338486930f
commit
c01794f232
@ -55,10 +55,15 @@ class IPFamily(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class ConnState(IntEnum):
|
class ConnState(IntEnum):
|
||||||
TCP_SYN_SEND = 10
|
TCP_SYN_SENT = 11 # SYN sent
|
||||||
TCP_SYN_ACK_RECV = 11
|
TCP_ESTABLISHED = 12 # SYN+ACK received
|
||||||
TCP_FIN_SEND = 20
|
TCP_FIN_WAIT_1 = 91 # FIN sent
|
||||||
TCP_FIN_RECV = 21
|
TCP_CLOSE_WAIT = 92 # FIN received
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def can_timeout(state):
|
||||||
|
return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT)
|
||||||
|
|
||||||
|
|
||||||
def repr_pkt(p):
|
def repr_pkt(p):
|
||||||
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
|
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
|
||||||
@ -138,7 +143,7 @@ class ConnTrack:
|
|||||||
self.shm_list[self.next_slot] = packed
|
self.shm_list[self.next_slot] = packed
|
||||||
self.used_slots.add(self.next_slot)
|
self.used_slots.add(self.next_slot)
|
||||||
proto = IPProtocol(proto)
|
proto = IPProtocol(proto)
|
||||||
debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
|
debug3(f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
|
||||||
|
|
||||||
@synchronized_method('rlock')
|
@synchronized_method('rlock')
|
||||||
def update(self, proto, src_addr, src_port, state):
|
def update(self, proto, src_addr, src_port, state):
|
||||||
@ -150,10 +155,10 @@ class ConnTrack:
|
|||||||
if self.shm_list[i].startswith(packed):
|
if self.shm_list[i].startswith(packed):
|
||||||
state_epoch = int(time.time())
|
state_epoch = int(time.time())
|
||||||
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
||||||
debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
debug3(f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||||
return self._unpack(self.shm_list[i])
|
return self._unpack(self.shm_list[i])
|
||||||
else:
|
else:
|
||||||
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
|
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
|
||||||
|
|
||||||
@synchronized_method('rlock')
|
@synchronized_method('rlock')
|
||||||
def remove(self, proto, src_addr, src_port):
|
def remove(self, proto, src_addr, src_port):
|
||||||
@ -166,10 +171,10 @@ class ConnTrack:
|
|||||||
conn = self._unpack(self.shm_list[i])
|
conn = self._unpack(self.shm_list[i])
|
||||||
self.shm_list[i] = b''
|
self.shm_list[i] = b''
|
||||||
self.used_slots.remove(i)
|
self.used_slots.remove(i)
|
||||||
debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
debug3(f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||||
return conn
|
return conn
|
||||||
else:
|
else:
|
||||||
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
|
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
|
||||||
|
|
||||||
|
|
||||||
def get(self, proto, src_addr, src_port):
|
def get(self, proto, src_addr, src_port):
|
||||||
@ -179,12 +184,35 @@ class ConnTrack:
|
|||||||
if entry and entry.startswith(packed):
|
if entry and entry.startswith(packed):
|
||||||
return self._unpack(entry)
|
return self._unpack(entry)
|
||||||
|
|
||||||
|
@synchronized_method('rlock')
|
||||||
|
def gc(self, connection_timeout_sec=15):
|
||||||
|
now = int(time.time())
|
||||||
|
n = 0
|
||||||
|
for i in tuple(self.used_slots):
|
||||||
|
state_packed = self.shm_list[i][-5:]
|
||||||
|
(state_epoch, state) = self.struct_state_tuple.unpack(state_packed)
|
||||||
|
if (now - state_epoch) < connection_timeout_sec:
|
||||||
|
continue
|
||||||
|
if ConnState.can_timeout(state):
|
||||||
|
conn = self._unpack(self.shm_list[i])
|
||||||
|
self.shm_list[i] = b''
|
||||||
|
self.used_slots.remove(i)
|
||||||
|
n += 1
|
||||||
|
debug3(f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||||
|
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
|
||||||
|
|
||||||
def _unpack(self, packed):
|
def _unpack(self, packed):
|
||||||
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed)
|
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed)
|
||||||
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
||||||
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
||||||
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
|
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
def conn_iter():
|
||||||
|
for i in self.used_slots:
|
||||||
|
yield self._unpack(self.shm_list[i])
|
||||||
|
return conn_iter()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
|
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
|
||||||
|
|
||||||
@ -238,7 +266,7 @@ class Method(BaseMethod):
|
|||||||
def wait_for_firewall_ready(self):
|
def wait_for_firewall_ready(self):
|
||||||
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
||||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
||||||
methods = (self._egress_divert, self._ingress_divert)
|
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
||||||
ready_events = []
|
ready_events = []
|
||||||
for fn in methods:
|
for fn in methods:
|
||||||
ev = threading.Event()
|
ev = threading.Event()
|
||||||
@ -305,11 +333,11 @@ class Method(BaseMethod):
|
|||||||
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
debug3(">>> " + repr_pkt(pkt))
|
debug3(">>> " + repr_pkt(pkt))
|
||||||
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
|
if pkt.tcp.syn and not pkt.tcp.ack: # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
|
||||||
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SEND)
|
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SENT)
|
||||||
if pkt.tcp.fin: # FIN (start of graceful close)
|
if pkt.tcp.fin: # FIN sent (start of graceful close our side, and we wait for ACK)
|
||||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
|
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
|
||||||
if pkt.tcp.rst : # RST
|
if pkt.tcp.rst : # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
|
||||||
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||||
|
|
||||||
# DNAT
|
# DNAT
|
||||||
@ -346,10 +374,15 @@ class Method(BaseMethod):
|
|||||||
ready_cb()
|
ready_cb()
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
debug3("<<< " + repr_pkt(pkt))
|
debug3("<<< " + repr_pkt(pkt))
|
||||||
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
|
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK received (connection established)
|
||||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV)
|
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
|
||||||
elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown
|
elif pkt.tcp.rst: # RST received - Abrupt connection teardown initiated by other side. We don't expect anymore packets
|
||||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
|
# https://wiki.wireshark.org/TCP-4-times-close.md
|
||||||
|
elif pkt.tcp.fin and pkt.tcp.ack: # FIN+ACK received (Passive close by other side. We don't expect any more packets. Other side expects an ACK)
|
||||||
|
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
|
elif pkt.tcp.fin: # FIN received (Other side initiated graceful close. We expects a final ACK for a FIN packet)
|
||||||
|
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
|
||||||
else:
|
else:
|
||||||
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
if not conn:
|
if not conn:
|
||||||
@ -359,4 +392,8 @@ class Method(BaseMethod):
|
|||||||
pkt.tcp.src_port = conn.dst_port
|
pkt.tcp.src_port = conn.dst_port
|
||||||
w.send(pkt, recalculate_checksum=True)
|
w.send(pkt, recalculate_checksum=True)
|
||||||
|
|
||||||
|
def _connection_gc(self, ready_cb):
|
||||||
|
ready_cb()
|
||||||
|
while True:
|
||||||
|
time.sleep(5)
|
||||||
|
self.conntrack.gc()
|
||||||
|
Loading…
Reference in New Issue
Block a user