windivert: garbage collect timed put connections from tracker

This commit is contained in:
nom3ad 2022-09-07 12:26:21 +05:30 committed by Brian May
parent 338486930f
commit c01794f232

View File

@ -55,10 +55,15 @@ class IPFamily(IntEnum):
class ConnState(IntEnum):
TCP_SYN_SEND = 10
TCP_SYN_ACK_RECV = 11
TCP_FIN_SEND = 20
TCP_FIN_RECV = 21
TCP_SYN_SENT = 11 # SYN sent
TCP_ESTABLISHED = 12 # SYN+ACK received
TCP_FIN_WAIT_1 = 91 # FIN sent
TCP_CLOSE_WAIT = 92 # FIN received
@staticmethod
def can_timeout(state):
return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT)
def repr_pkt(p):
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
@ -138,7 +143,7 @@ class ConnTrack:
self.shm_list[self.next_slot] = packed
self.used_slots.add(self.next_slot)
proto = IPProtocol(proto)
debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
debug3(f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
@synchronized_method('rlock')
def update(self, proto, src_addr, src_port, state):
@ -150,10 +155,10 @@ class ConnTrack:
if self.shm_list[i].startswith(packed):
state_epoch = int(time.time())
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
debug3(f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
return self._unpack(self.shm_list[i])
else:
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
@synchronized_method('rlock')
def remove(self, proto, src_addr, src_port):
@ -166,10 +171,10 @@ class ConnTrack:
conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b''
self.used_slots.remove(i)
debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slots)}")
debug3(f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
return conn
else:
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
def get(self, proto, src_addr, src_port):
@ -179,12 +184,35 @@ class ConnTrack:
if entry and entry.startswith(packed):
return self._unpack(entry)
@synchronized_method('rlock')
def gc(self, connection_timeout_sec=15):
now = int(time.time())
n = 0
for i in tuple(self.used_slots):
state_packed = self.shm_list[i][-5:]
(state_epoch, state) = self.struct_state_tuple.unpack(state_packed)
if (now - state_epoch) < connection_timeout_sec:
continue
if ConnState.can_timeout(state):
conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b''
self.used_slots.remove(i)
n += 1
debug3(f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
def _unpack(self, packed):
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed)
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
def __iter__(self):
def conn_iter():
for i in self.used_slots:
yield self._unpack(self.shm_list[i])
return conn_iter()
def __repr__(self):
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
@ -238,7 +266,7 @@ class Method(BaseMethod):
def wait_for_firewall_ready(self):
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
methods = (self._egress_divert, self._ingress_divert)
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
ready_events = []
for fn in methods:
ev = threading.Event()
@ -305,11 +333,11 @@ class Method(BaseMethod):
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
for pkt in w:
debug3(">>> " + repr_pkt(pkt))
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SEND)
if pkt.tcp.fin: # FIN (start of graceful close)
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
if pkt.tcp.rst : # RST
if pkt.tcp.syn and not pkt.tcp.ack: # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SENT)
if pkt.tcp.fin: # FIN sent (start of graceful close our side, and we wait for ACK)
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
if pkt.tcp.rst : # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
# DNAT
@ -346,10 +374,15 @@ class Method(BaseMethod):
ready_cb()
for pkt in w:
debug3("<<< " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV)
elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK received (connection established)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
elif pkt.tcp.rst: # RST received - Abrupt connection teardown initiated by other side. We don't expect anymore packets
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
# https://wiki.wireshark.org/TCP-4-times-close.md
elif pkt.tcp.fin and pkt.tcp.ack: # FIN+ACK received (Passive close by other side. We don't expect any more packets. Other side expects an ACK)
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
elif pkt.tcp.fin: # FIN received (Other side initiated graceful close. We expects a final ACK for a FIN packet)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
else:
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
if not conn:
@ -359,4 +392,8 @@ class Method(BaseMethod):
pkt.tcp.src_port = conn.dst_port
w.send(pkt, recalculate_checksum=True)
def _connection_gc(self, ready_cb):
ready_cb()
while True:
time.sleep(5)
self.conntrack.gc()