mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-04-23 18:58:59 +02:00
windivert - basic working connection tracker
This commit is contained in:
parent
5a64c81b5b
commit
2c74476124
@ -1,8 +1,14 @@
|
|||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
import socket
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
import struct
|
||||||
|
from functools import wraps
|
||||||
|
from enum import IntEnum
|
||||||
|
import time
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pydivert
|
import pydivert
|
||||||
@ -10,33 +16,147 @@ except ImportError:
|
|||||||
raise Fatal('Could not import pydivert module. windivert requires https://pypi.org/project/pydivert')
|
raise Fatal('Could not import pydivert module. windivert requires https://pypi.org/project/pydivert')
|
||||||
|
|
||||||
from sshuttle.methods import BaseMethod
|
from sshuttle.methods import BaseMethod
|
||||||
from sshuttle.helpers import log, debug1, debug2, Fatal
|
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal
|
||||||
|
|
||||||
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
||||||
|
|
||||||
|
|
||||||
ConnectionTuple = namedtuple(
|
ConnectionTuple = namedtuple(
|
||||||
"ConnectionTuple", ["protocol", "src_addr", "src_port", "dst_addr", "dst_port"]
|
"ConnectionTuple", ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", 'state']
|
||||||
)
|
)
|
||||||
|
|
||||||
class ConnectionTracker:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.d = {}
|
|
||||||
|
|
||||||
def add_tcp(self, src_addr, src_port, dst_addr, dst_port):
|
MAX_CONNECTIONS = 2 #_000
|
||||||
k = ("TCP", src_addr, src_port)
|
|
||||||
v = (dst_addr, dst_port)
|
class IPProtocol(IntEnum):
|
||||||
if self.d.get(k) != v:
|
TCP = socket.IPPROTO_TCP
|
||||||
debug1("Adding tcp connection to tracker:" + repr((src_addr, src_port, dst_addr, dst_port)))
|
UDP = socket.IPPROTO_UDP
|
||||||
self.d[k] = v
|
|
||||||
|
class ConnState(IntEnum):
|
||||||
|
TCP_SYN_SEND = 10
|
||||||
|
TCP_SYN_ACK_RECV = 11
|
||||||
|
TCP_FIN_SEND = 20
|
||||||
|
TCP_FIN_RECV = 21
|
||||||
|
|
||||||
|
def repr_pkt(p):
|
||||||
|
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
|
||||||
|
if p.tcp:
|
||||||
|
t = p.tcp
|
||||||
|
r += f" {len(t.payload)}B ("
|
||||||
|
r += '+'.join(f.upper() for f in ('fin','syn', "rst", "psh", 'ack', 'urg', 'ece', 'cwr', 'ns') if getattr(t, f))
|
||||||
|
r += f') SEQ#{t.seq_num}'
|
||||||
|
if t.ack:
|
||||||
|
r += f' ACK#{t.ack_num}'
|
||||||
|
r += f' WZ={t.window_size}'
|
||||||
|
else:
|
||||||
|
r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}"
|
||||||
|
return f"<Pkt {r}>"
|
||||||
|
|
||||||
|
def synchronized_method(lock):
|
||||||
|
def decorator(method):
|
||||||
|
@wraps(method)
|
||||||
|
def wrapped(self, *args, **kwargs):
|
||||||
|
with getattr(self, lock):
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
return wrapped
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
class ConnTrack:
|
||||||
|
def __init__(self, name, max_connections=0) -> None:
|
||||||
|
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B')))
|
||||||
|
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H')))
|
||||||
|
self.struct_state_tuple = struct.Struct('>' + ''.join(('L', 'B')))
|
||||||
|
|
||||||
def get_tcp(self, src_addr, src_port):
|
|
||||||
try:
|
try:
|
||||||
return ConnectionTuple(
|
self.max_connections = max_connections
|
||||||
"TCP", src_addr, src_port, *self.d[("TCP", src_addr, src_port)]
|
self.shm_list = shared_memory.ShareableList([bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name)
|
||||||
)
|
self.is_owner = True
|
||||||
except KeyError:
|
self.next_slot = 0
|
||||||
return None
|
self.used_slotes = set()
|
||||||
|
self.rlock = threading.RLock()
|
||||||
|
except FileExistsError:
|
||||||
|
self.is_owner = False
|
||||||
|
self.shm_list = shared_memory.ShareableList(name=name)
|
||||||
|
self.max_connections = len(self.shm_list)
|
||||||
|
|
||||||
|
debug2(f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} shm_size={self.shm_list.shm.size}B")
|
||||||
|
|
||||||
|
@synchronized_method('rlock')
|
||||||
|
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
|
||||||
|
if not self.is_owner:
|
||||||
|
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||||
|
if len(self.used_slotes) >= self.max_connections:
|
||||||
|
raise RuntimeError(f"No slot avaialble in ConnTrack {len(self.used_slotes)}/{self.max_connections}")
|
||||||
|
|
||||||
|
if self.get(proto, src_addr, src_port):
|
||||||
|
return
|
||||||
|
|
||||||
|
for _ in range(self.max_connections):
|
||||||
|
if self.next_slot not in self.used_slotes:
|
||||||
|
break
|
||||||
|
self.next_slot = (self.next_slot +1) % self.max_connections
|
||||||
|
else:
|
||||||
|
raise RuntimeError("No slot avaialble in ConnTrack") # should not be here
|
||||||
|
|
||||||
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
|
dst_addr = ipaddress.ip_address(dst_addr)
|
||||||
|
assert src_addr.version == dst_addr.version
|
||||||
|
ip_version = src_addr.version
|
||||||
|
state_epoch = int(time.time())
|
||||||
|
entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state)
|
||||||
|
packed = self.struct_full_tuple.pack(*entry)
|
||||||
|
self.shm_list[self.next_slot] = packed
|
||||||
|
self.used_slotes.add(self.next_slot)
|
||||||
|
proto = IPProtocol(proto)
|
||||||
|
debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slotes)}")
|
||||||
|
|
||||||
|
@synchronized_method('rlock')
|
||||||
|
def update(self, proto, src_addr, src_port, state):
|
||||||
|
if not self.is_owner:
|
||||||
|
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||||
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
|
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
||||||
|
for i in self.used_slotes:
|
||||||
|
if self.shm_list[i].startswith(packed):
|
||||||
|
state_epoch = int(time.time())
|
||||||
|
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
||||||
|
debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slotes)}")
|
||||||
|
return self._unpack(self.shm_list[i])
|
||||||
|
else:
|
||||||
|
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slotes)}")
|
||||||
|
|
||||||
|
@synchronized_method('rlock')
|
||||||
|
def remove(self, proto, src_addr, src_port):
|
||||||
|
if not self.is_owner:
|
||||||
|
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||||
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
|
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
||||||
|
for i in self.used_slotes:
|
||||||
|
if self.shm_list[i].startswith(packed):
|
||||||
|
conn = self._unpack(self.shm_list[i])
|
||||||
|
self.shm_list[i] = b''
|
||||||
|
self.used_slotes.remove(i)
|
||||||
|
debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slotes)}")
|
||||||
|
return conn
|
||||||
|
else:
|
||||||
|
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slotes)}")
|
||||||
|
|
||||||
|
|
||||||
|
def get(self, proto, src_addr, src_port):
|
||||||
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
|
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
||||||
|
for entry in self.shm_list:
|
||||||
|
if entry and entry.startswith(packed):
|
||||||
|
return self._unpack(entry)
|
||||||
|
|
||||||
|
def _unpack(self, packed):
|
||||||
|
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed)
|
||||||
|
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
||||||
|
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
||||||
|
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<ConnTrack(n={len(self.used_slotes) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
|
||||||
|
|
||||||
|
|
||||||
class Method(BaseMethod):
|
class Method(BaseMethod):
|
||||||
@ -44,10 +164,7 @@ class Method(BaseMethod):
|
|||||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
|
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
|
||||||
user, tmark):
|
user, tmark):
|
||||||
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||||
# port=12300, dnsport=0, nslist=[], family=<AddressFamily.AF_INET: 2>,
|
self.conntrack = ConnTrack(f'sshttle-windiver-{os.getppid()}', MAX_CONNECTIONS)
|
||||||
# subnets=[(2, 24, False, '10.111.10.0', 0, 0), (2, 16, False, '169.254.0.0', 0, 0), (2, 24, False, '172.31.0.0', 0, 0), (2, 16, False, '192.168.0.0', 0, 0), (2, 32, True, '0.0.0.0', 0, 0)],
|
|
||||||
# udp=False, user=None, tmark='0x01'
|
|
||||||
self.conntrack = ConnectionTracker()
|
|
||||||
proxy_addr = "10.0.2.15"
|
proxy_addr = "10.0.2.15"
|
||||||
|
|
||||||
subnet_addreses = []
|
subnet_addreses = []
|
||||||
@ -58,7 +175,7 @@ class Method(BaseMethod):
|
|||||||
assert lport == 0, 'custom port range not supported'
|
assert lport == 0, 'custom port range not supported'
|
||||||
subnet_addreses.append("%s/%s" % (network_addr, mask))
|
subnet_addreses.append("%s/%s" % (network_addr, mask))
|
||||||
|
|
||||||
debug2("subnet_addreses=%s proxy_addr=%s:%s" % (subnet_addreses,proxy_addr,port))
|
debug2("setup_firewall() subnet_addreses=%s proxy_addr=%s:%s" % (subnet_addreses,proxy_addr,port))
|
||||||
|
|
||||||
# check permission
|
# check permission
|
||||||
with pydivert.WinDivert('false'):
|
with pydivert.WinDivert('false'):
|
||||||
@ -78,7 +195,14 @@ class Method(BaseMethod):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def get_tcp_dstip(self, sock):
|
def get_tcp_dstip(self, sock):
|
||||||
return ('172.31.0.141', 80)
|
if not hasattr(self, 'conntrack'):
|
||||||
|
self.conntrack = ConnTrack(f'sshttle-windiver-{os.getpid()}')
|
||||||
|
|
||||||
|
src_addr , src_port = sock.getpeername()
|
||||||
|
c = self.conntrack.get(IPProtocol.TCP , src_addr, src_port)
|
||||||
|
if not c:
|
||||||
|
return (src_addr , src_port)
|
||||||
|
return (c.dst_addr, c.dst_port)
|
||||||
|
|
||||||
def is_supported(self):
|
def is_supported(self):
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
@ -87,6 +211,8 @@ class Method(BaseMethod):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _outbound_divert(self, subnets, proxy_addr, proxy_port):
|
def _outbound_divert(self, subnets, proxy_addr, proxy_port):
|
||||||
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
||||||
filter = "outbound and ip and tcp"
|
filter = "outbound and ip and tcp"
|
||||||
@ -101,8 +227,13 @@ class Method(BaseMethod):
|
|||||||
debug1(f"[OUTBOUND] {filter=}")
|
debug1(f"[OUTBOUND] {filter=}")
|
||||||
with pydivert.WinDivert(filter) as w:
|
with pydivert.WinDivert(filter) as w:
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
# debug3(repr(pkt))
|
debug3(">>> " + repr_pkt(pkt))
|
||||||
self.conntrack.add_tcp(pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port)
|
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
|
||||||
|
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SEND)
|
||||||
|
if pkt.tcp.fin: # FIN (start of graceful close)
|
||||||
|
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
|
||||||
|
if pkt.tcp.rst : # RST
|
||||||
|
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||||
pkt.ipv4.dst_addr = proxy_addr
|
pkt.ipv4.dst_addr = proxy_addr
|
||||||
pkt.tcp.dst_port = proxy_port
|
pkt.tcp.dst_port = proxy_port
|
||||||
w.send(pkt, recalculate_checksum=True)
|
w.send(pkt, recalculate_checksum=True)
|
||||||
@ -113,11 +244,15 @@ class Method(BaseMethod):
|
|||||||
debug2(f"[INBOUND] {filter=}")
|
debug2(f"[INBOUND] {filter=}")
|
||||||
with pydivert.WinDivert(filter) as w:
|
with pydivert.WinDivert(filter) as w:
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
# debug2(repr(conntrack.d))
|
debug3("<<< " + repr_pkt(pkt))
|
||||||
# debug2(repr((pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port)))
|
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK Conenction established
|
||||||
conn = self.conntrack.get_tcp(pkt.dst_addr, pkt.dst_port)
|
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV)
|
||||||
|
elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown
|
||||||
|
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
|
else:
|
||||||
|
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
if not conn:
|
if not conn:
|
||||||
debug2("Unexpcted packet:" + repr((pkt.protocol,pkt.src_addr,pkt.src_port,pkt.dst_addr,pkt.dst_port)))
|
debug2("Unexpcted packet: " + repr_pkt(pkt))
|
||||||
continue
|
continue
|
||||||
pkt.ipv4.src_addr = conn.dst_addr
|
pkt.ipv4.src_addr = conn.dst_addr
|
||||||
pkt.tcp.src_port = conn.dst_port
|
pkt.tcp.src_port = conn.dst_port
|
||||||
|
Loading…
Reference in New Issue
Block a user