mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-21 23:43:18 +01:00
windivert: garbage collect timed put connections from tracker
This commit is contained in:
parent
338486930f
commit
c01794f232
@ -55,10 +55,15 @@ class IPFamily(IntEnum):
|
||||
|
||||
|
||||
class ConnState(IntEnum):
|
||||
TCP_SYN_SEND = 10
|
||||
TCP_SYN_ACK_RECV = 11
|
||||
TCP_FIN_SEND = 20
|
||||
TCP_FIN_RECV = 21
|
||||
TCP_SYN_SENT = 11 # SYN sent
|
||||
TCP_ESTABLISHED = 12 # SYN+ACK received
|
||||
TCP_FIN_WAIT_1 = 91 # FIN sent
|
||||
TCP_CLOSE_WAIT = 92 # FIN received
|
||||
|
||||
@staticmethod
|
||||
def can_timeout(state):
|
||||
return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT)
|
||||
|
||||
|
||||
def repr_pkt(p):
|
||||
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
|
||||
@ -138,7 +143,7 @@ class ConnTrack:
|
||||
self.shm_list[self.next_slot] = packed
|
||||
self.used_slots.add(self.next_slot)
|
||||
proto = IPProtocol(proto)
|
||||
debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
|
||||
|
||||
@synchronized_method('rlock')
|
||||
def update(self, proto, src_addr, src_port, state):
|
||||
@ -150,10 +155,10 @@ class ConnTrack:
|
||||
if self.shm_list[i].startswith(packed):
|
||||
state_epoch = int(time.time())
|
||||
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
||||
debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
return self._unpack(self.shm_list[i])
|
||||
else:
|
||||
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
|
||||
|
||||
@synchronized_method('rlock')
|
||||
def remove(self, proto, src_addr, src_port):
|
||||
@ -166,10 +171,10 @@ class ConnTrack:
|
||||
conn = self._unpack(self.shm_list[i])
|
||||
self.shm_list[i] = b''
|
||||
self.used_slots.remove(i)
|
||||
debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
return conn
|
||||
else:
|
||||
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
|
||||
|
||||
|
||||
def get(self, proto, src_addr, src_port):
|
||||
@ -179,12 +184,35 @@ class ConnTrack:
|
||||
if entry and entry.startswith(packed):
|
||||
return self._unpack(entry)
|
||||
|
||||
@synchronized_method('rlock')
|
||||
def gc(self, connection_timeout_sec=15):
|
||||
now = int(time.time())
|
||||
n = 0
|
||||
for i in tuple(self.used_slots):
|
||||
state_packed = self.shm_list[i][-5:]
|
||||
(state_epoch, state) = self.struct_state_tuple.unpack(state_packed)
|
||||
if (now - state_epoch) < connection_timeout_sec:
|
||||
continue
|
||||
if ConnState.can_timeout(state):
|
||||
conn = self._unpack(self.shm_list[i])
|
||||
self.shm_list[i] = b''
|
||||
self.used_slots.remove(i)
|
||||
n += 1
|
||||
debug3(f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
|
||||
|
||||
def _unpack(self, packed):
|
||||
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed)
|
||||
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
||||
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
||||
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
def conn_iter():
|
||||
for i in self.used_slots:
|
||||
yield self._unpack(self.shm_list[i])
|
||||
return conn_iter()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
|
||||
|
||||
@ -238,7 +266,7 @@ class Method(BaseMethod):
|
||||
def wait_for_firewall_ready(self):
|
||||
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
||||
methods = (self._egress_divert, self._ingress_divert)
|
||||
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
||||
ready_events = []
|
||||
for fn in methods:
|
||||
ev = threading.Event()
|
||||
@ -305,11 +333,11 @@ class Method(BaseMethod):
|
||||
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
||||
for pkt in w:
|
||||
debug3(">>> " + repr_pkt(pkt))
|
||||
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
|
||||
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SEND)
|
||||
if pkt.tcp.fin: # FIN (start of graceful close)
|
||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
|
||||
if pkt.tcp.rst : # RST
|
||||
if pkt.tcp.syn and not pkt.tcp.ack: # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
|
||||
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SENT)
|
||||
if pkt.tcp.fin: # FIN sent (start of graceful close our side, and we wait for ACK)
|
||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
|
||||
if pkt.tcp.rst : # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
|
||||
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||
|
||||
# DNAT
|
||||
@ -346,10 +374,15 @@ class Method(BaseMethod):
|
||||
ready_cb()
|
||||
for pkt in w:
|
||||
debug3("<<< " + repr_pkt(pkt))
|
||||
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
|
||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV)
|
||||
elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown
|
||||
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK received (connection established)
|
||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
|
||||
elif pkt.tcp.rst: # RST received - Abrupt connection teardown initiated by other side. We don't expect anymore packets
|
||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||
# https://wiki.wireshark.org/TCP-4-times-close.md
|
||||
elif pkt.tcp.fin and pkt.tcp.ack: # FIN+ACK received (Passive close by other side. We don't expect any more packets. Other side expects an ACK)
|
||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||
elif pkt.tcp.fin: # FIN received (Other side initiated graceful close. We expects a final ACK for a FIN packet)
|
||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
|
||||
else:
|
||||
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
||||
if not conn:
|
||||
@ -359,4 +392,8 @@ class Method(BaseMethod):
|
||||
pkt.tcp.src_port = conn.dst_port
|
||||
w.send(pkt, recalculate_checksum=True)
|
||||
|
||||
|
||||
def _connection_gc(self, ready_cb):
|
||||
ready_cb()
|
||||
while True:
|
||||
time.sleep(5)
|
||||
self.conntrack.gc()
|
||||
|
Loading…
Reference in New Issue
Block a user