diff --git a/sshuttle/__main__.py b/sshuttle/__main__.py index c885caa..e3663b5 100644 --- a/sshuttle/__main__.py +++ b/sshuttle/__main__.py @@ -5,4 +5,4 @@ from sshuttle.cmdline import main from sshuttle.helpers import debug3 exit_code=main() debug3("Exiting process %r (pid:%s) with code %s" % (sys.argv, os.getpid(), exit_code,)) -sys.exit(exit_code) +sys.exit(exit_code) \ No newline at end of file diff --git a/sshuttle/methods/windivert.py b/sshuttle/methods/windivert.py index 75bca5a..6d10f83 100644 --- a/sshuttle/methods/windivert.py +++ b/sshuttle/methods/windivert.py @@ -1,8 +1,14 @@ +import os import sys import ipaddress import threading from collections import namedtuple - +import socket +from multiprocessing import shared_memory +import struct +from functools import wraps +from enum import IntEnum +import time try: import pydivert @@ -10,33 +16,147 @@ except ImportError: raise Fatal('Could not import pydivert module. windivert requires https://pypi.org/project/pydivert') from sshuttle.methods import BaseMethod -from sshuttle.helpers import log, debug1, debug2, Fatal +from sshuttle.helpers import debug3, log, debug1, debug2, Fatal # https://reqrypt.org/windivert-doc.html#divert_iphdr ConnectionTuple = namedtuple( - "ConnectionTuple", ["protocol", "src_addr", "src_port", "dst_addr", "dst_port"] + "ConnectionTuple", ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", 'state'] ) -class ConnectionTracker: - def __init__(self) -> None: - self.d = {} - def add_tcp(self, src_addr, src_port, dst_addr, dst_port): - k = ("TCP", src_addr, src_port) - v = (dst_addr, dst_port) - if self.d.get(k) != v: - debug1("Adding tcp connection to tracker:" + repr((src_addr, src_port, dst_addr, dst_port))) - self.d[k] = v +MAX_CONNECTIONS = 2 #_000 + +class IPProtocol(IntEnum): + TCP = socket.IPPROTO_TCP + UDP = socket.IPPROTO_UDP + +class ConnState(IntEnum): + TCP_SYN_SEND = 10 + TCP_SYN_ACK_RECV = 11 + TCP_FIN_SEND = 20 + TCP_FIN_RECV = 21 + +def repr_pkt(p): + r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}" + if p.tcp: + t = p.tcp + r += f" {len(t.payload)}B (" + r += '+'.join(f.upper() for f in ('fin','syn', "rst", "psh", 'ack', 'urg', 'ece', 'cwr', 'ns') if getattr(t, f)) + r += f') SEQ#{t.seq_num}' + if t.ack: + r += f' ACK#{t.ack_num}' + r += f' WZ={t.window_size}' + else: + r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}" + return f"" + +def synchronized_method(lock): + def decorator(method): + @wraps(method) + def wrapped(self, *args, **kwargs): + with getattr(self, lock): + return method(self, *args, **kwargs) + return wrapped + return decorator + +class ConnTrack: + def __init__(self, name, max_connections=0) -> None: + self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B'))) + self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H'))) + self.struct_state_tuple = struct.Struct('>' + ''.join(('L', 'B'))) - def get_tcp(self, src_addr, src_port): try: - return ConnectionTuple( - "TCP", src_addr, src_port, *self.d[("TCP", src_addr, src_port)] - ) - except KeyError: - return None + self.max_connections = max_connections + self.shm_list = shared_memory.ShareableList([bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name) + self.is_owner = True + self.next_slot = 0 + self.used_slotes = set() + self.rlock = threading.RLock() + except FileExistsError: + self.is_owner = False + self.shm_list = shared_memory.ShareableList(name=name) + self.max_connections = len(self.shm_list) + + debug2(f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} shm_size={self.shm_list.shm.size}B") + + @synchronized_method('rlock') + def add(self, proto, src_addr, src_port, dst_addr, dst_port, state): + if not self.is_owner: + raise RuntimeError("Only owner can mutate ConnTrack") + if len(self.used_slotes) >= self.max_connections: + raise RuntimeError(f"No slot avaialble in ConnTrack {len(self.used_slotes)}/{self.max_connections}") + + if self.get(proto, src_addr, src_port): + return + + for _ in range(self.max_connections): + if self.next_slot not in self.used_slotes: + break + self.next_slot = (self.next_slot +1) % self.max_connections + else: + raise RuntimeError("No slot avaialble in ConnTrack") # should not be here + + src_addr = ipaddress.ip_address(src_addr) + dst_addr = ipaddress.ip_address(dst_addr) + assert src_addr.version == dst_addr.version + ip_version = src_addr.version + state_epoch = int(time.time()) + entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state) + packed = self.struct_full_tuple.pack(*entry) + self.shm_list[self.next_slot] = packed + self.used_slotes.add(self.next_slot) + proto = IPProtocol(proto) + debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slotes)}") + + @synchronized_method('rlock') + def update(self, proto, src_addr, src_port, state): + if not self.is_owner: + raise RuntimeError("Only owner can mutate ConnTrack") + src_addr = ipaddress.ip_address(src_addr) + packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) + for i in self.used_slotes: + if self.shm_list[i].startswith(packed): + state_epoch = int(time.time()) + self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state) + debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slotes)}") + return self._unpack(self.shm_list[i]) + else: + debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slotes)}") + + @synchronized_method('rlock') + def remove(self, proto, src_addr, src_port): + if not self.is_owner: + raise RuntimeError("Only owner can mutate ConnTrack") + src_addr = ipaddress.ip_address(src_addr) + packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) + for i in self.used_slotes: + if self.shm_list[i].startswith(packed): + conn = self._unpack(self.shm_list[i]) + self.shm_list[i] = b'' + self.used_slotes.remove(i) + debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slotes)}") + return conn + else: + debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slotes)}") + + + def get(self, proto, src_addr, src_port): + src_addr = ipaddress.ip_address(src_addr) + packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) + for entry in self.shm_list: + if entry and entry.startswith(packed): + return self._unpack(entry) + + def _unpack(self, packed): + (proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed) + dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) + src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) + return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)) + + def __repr__(self): + return f"" class Method(BaseMethod): @@ -44,10 +164,7 @@ class Method(BaseMethod): def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark): log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") - # port=12300, dnsport=0, nslist=[], family=, - # subnets=[(2, 24, False, '10.111.10.0', 0, 0), (2, 16, False, '169.254.0.0', 0, 0), (2, 24, False, '172.31.0.0', 0, 0), (2, 16, False, '192.168.0.0', 0, 0), (2, 32, True, '0.0.0.0', 0, 0)], - # udp=False, user=None, tmark='0x01' - self.conntrack = ConnectionTracker() + self.conntrack = ConnTrack(f'sshttle-windiver-{os.getppid()}', MAX_CONNECTIONS) proxy_addr = "10.0.2.15" subnet_addreses = [] @@ -58,7 +175,7 @@ class Method(BaseMethod): assert lport == 0, 'custom port range not supported' subnet_addreses.append("%s/%s" % (network_addr, mask)) - debug2("subnet_addreses=%s proxy_addr=%s:%s" % (subnet_addreses,proxy_addr,port)) + debug2("setup_firewall() subnet_addreses=%s proxy_addr=%s:%s" % (subnet_addreses,proxy_addr,port)) # check permission with pydivert.WinDivert('false'): @@ -78,7 +195,14 @@ class Method(BaseMethod): return result def get_tcp_dstip(self, sock): - return ('172.31.0.141', 80) + if not hasattr(self, 'conntrack'): + self.conntrack = ConnTrack(f'sshttle-windiver-{os.getpid()}') + + src_addr , src_port = sock.getpeername() + c = self.conntrack.get(IPProtocol.TCP , src_addr, src_port) + if not c: + return (src_addr , src_port) + return (c.dst_addr, c.dst_port) def is_supported(self): if sys.platform == 'win32': @@ -87,6 +211,8 @@ class Method(BaseMethod): + + def _outbound_divert(self, subnets, proxy_addr, proxy_port): # with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w: filter = "outbound and ip and tcp" @@ -101,8 +227,13 @@ class Method(BaseMethod): debug1(f"[OUTBOUND] {filter=}") with pydivert.WinDivert(filter) as w: for pkt in w: - # debug3(repr(pkt)) - self.conntrack.add_tcp(pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port) + debug3(">>> " + repr_pkt(pkt)) + if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment) + self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SEND) + if pkt.tcp.fin: # FIN (start of graceful close) + self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND) + if pkt.tcp.rst : # RST + self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port) pkt.ipv4.dst_addr = proxy_addr pkt.tcp.dst_port = proxy_port w.send(pkt, recalculate_checksum=True) @@ -113,11 +244,15 @@ class Method(BaseMethod): debug2(f"[INBOUND] {filter=}") with pydivert.WinDivert(filter) as w: for pkt in w: - # debug2(repr(conntrack.d)) - # debug2(repr((pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port))) - conn = self.conntrack.get_tcp(pkt.dst_addr, pkt.dst_port) + debug3("<<< " + repr_pkt(pkt)) + if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK Conenction established + conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV) + elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown + conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) + else: + conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) if not conn: - debug2("Unexpcted packet:" + repr((pkt.protocol,pkt.src_addr,pkt.src_port,pkt.dst_addr,pkt.dst_port))) + debug2("Unexpcted packet: " + repr_pkt(pkt)) continue pkt.ipv4.src_addr = conn.dst_addr pkt.tcp.src_port = conn.dst_port