1
1
mirror of https://github.com/sshuttle/sshuttle.git synced 2025-05-07 01:24:27 +02:00

support port ranges and exclude subnets

This commit is contained in:
nom3ad 2024-01-08 22:22:50 +05:30 committed by Brian May
parent 72060abbef
commit 89a94ff150

View File

@ -7,7 +7,7 @@ import socket
import subprocess import subprocess
import re import re
from multiprocessing import shared_memory from multiprocessing import shared_memory
import struct from struct import Struct
from functools import wraps from functools import wraps
from enum import IntEnum from enum import IntEnum
import time import time
@ -61,7 +61,7 @@ class IPFamily(IntEnum):
@property @property
def loopback_addr(self): def loopback_addr(self):
return "127.0.0.1" if self == socket.AF_INET else "::1" return ip_address("127.0.0.1" if self == socket.AF_INET else "::1")
class ConnState(IntEnum): class ConnState(IntEnum):
@ -123,9 +123,9 @@ class ConnTrack:
raise RuntimeError("ConnTrack can not be instantiated multiple times") raise RuntimeError("ConnTrack can not be instantiated multiple times")
def __init__(self, name, max_connections=0) -> None: def __init__(self, name, max_connections=0) -> None:
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B"))) self.struct_full_tuple = Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H"))) self.struct_src_tuple = Struct(">" + "".join(("B", "B", "16s", "H")))
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B"))) self.struct_state_tuple = Struct(">" + "".join(("L", "B")))
try: try:
self.max_connections = max_connections self.max_connections = max_connections
@ -142,8 +142,8 @@ class ConnTrack:
self.max_connections = len(self.shm_list) self.max_connections = len(self.shm_list)
debug2( debug2(
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} " f"ConnTrack: is_owner={self.is_owner} cap={len(self.shm_list)} item_sz={self.struct_full_tuple.size}B"
f"shm_size={self.shm_list.shm.size}B" f"shm_name={self.shm_list.shm.name} shm_sz={self.shm_list.shm.size}B"
) )
@synchronized_method("rlock") @synchronized_method("rlock")
@ -279,7 +279,6 @@ class ConnTrack:
class Method(BaseMethod): class Method(BaseMethod):
network_config = {} network_config = {}
proxy_port = None
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
@ -297,10 +296,10 @@ class Method(BaseMethod):
raise Fatal("Could not find listening address for {}/{}".format(port, proto)) raise Fatal("Could not find listening address for {}/{}".format(port, proto))
def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark): def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark):
debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {group=} {tmark=}")
if nslist or user or udp: if nslist or user or udp or group:
raise NotImplementedError() raise NotImplementedError("user, group, nslist, udp are not supported")
family = IPFamily(family) family = IPFamily(family)
@ -312,24 +311,26 @@ class Method(BaseMethod):
if proxy_bind_addr.is_loopback: if proxy_bind_addr.is_loopback:
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.") raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.")
if not proxy_bind_addr.is_unspecified: if not proxy_bind_addr.is_unspecified:
proxy_ip = proxy_bind_addr.exploded proxy_ip = proxy_bind_addr
else: else:
local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)] local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)]
for addr in local_addresses: for addr in local_addresses:
if not addr.is_loopback and not addr.is_link_local: if not addr.is_loopback and not addr.is_link_local:
proxy_ip = addr.exploded proxy_ip = addr
break break
else: else:
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address." raise Fatal("Windivert method requires proxy to be reachable by a non loopback address."
f"No address found for {family.name} in {local_addresses}") f"No address found for {family.name} in {local_addresses}")
debug2("Found non loopback address to connect to proxy: " + proxy_ip) debug2(f"Found non loopback address to connect to proxy: {proxy_ip}")
subnet_addresses = [] subnet_addresses = []
for (_, mask, exclude, network_addr, fport, lport) in subnets: for (_, mask, exclude, network_addr, fport, lport) in subnets:
if exclude: if fport and lport:
continue if lport > fport:
assert fport == 0, "custom port range not supported" raise Fatal("lport must be less than or equal to fport")
assert lport == 0, "custom port range not supported" ports = (fport, lport)
subnet_addresses.append("%s/%s" % (network_addr, mask)) else:
ports = None
subnet_addresses.append((ip_network(f"{network_addr}/{mask}"), ports, exclude))
self.network_config[family] = { self.network_config[family] = {
"subnets": subnet_addresses, "subnets": subnet_addresses,
@ -391,30 +392,48 @@ class Method(BaseMethod):
"""divert outgoing packets to proxy""" """divert outgoing packets to proxy"""
proto = IPProtocol.TCP proto = IPProtocol.TCP
filter = f"outbound and {proto.filter}" filter = f"outbound and {proto.filter}"
af_filters = []
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
family_filters = []
for af, c in self.network_config.items(): for af, c in self.network_config.items():
subnet_filters = [] subnet_include_filters = []
for cidr in c["subnets"]: subnet_exclude_filters = []
ip_net = ip_network(cidr) for ip_net, ports, exclude in c["subnets"]:
first_ip = ip_net.network_address.exploded first_ip = ip_net.network_address.exploded
last_ip = ip_net.broadcast_address.exploded last_ip = ip_net.broadcast_address.exploded
subnet_filters.append(f"({af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip})") if first_ip == last_ip:
if not subnet_filters: _subney_filter = f"{af.filter}.DstAddr=={first_ip}"
continue else:
_subney_filter = f"{af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip}"
if ports:
if ports[0] == ports[1]:
_subney_filter += f" and {proto.filter}.DstPort=={ports[0]}"
else:
_subney_filter += f" and tcp.DstPort>={ports[0]} and tcp.DstPort<={ports[1]}"
(subnet_exclude_filters if exclude else subnet_include_filters).append(f'({_subney_filter})')
_af_filter = f"{af.filter}"
if subnet_include_filters:
_af_filter += f" and ({' or '.join(subnet_include_filters)})"
if subnet_exclude_filters:
# TODO(noma3ad) use not() operator with Windivert2 after upgrade
_af_filter += f" and (({' or '.join(subnet_exclude_filters)})? false : true)"
proxy_ip, proxy_port = c["proxy_addr"] proxy_ip, proxy_port = c["proxy_addr"]
proxy_guard_filter = f'({af.filter}.DstAddr!={proxy_ip} or tcp.DstPort!={proxy_port})' # Avoids proxy outbound traffic getting directed to itself
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) and {proxy_guard_filter}") proxy_guard_filter = f'(({af.filter}.DstAddr=={proxy_ip.exploded} and tcp.DstPort=={proxy_port})? false : true)'
if not family_filters: _af_filter += f" and {proxy_guard_filter}"
af_filters.append(_af_filter)
if not af_filters:
raise Fatal("At least one ipv4 or ipv6 subnet is expected") raise Fatal("At least one ipv4 or ipv6 subnet is expected")
filter = f"{filter} and ({' or '.join(family_filters)})" filter = f"{filter} and ({' or '.join(af_filters)})"
debug1(f"[EGRESS] {filter=}") debug1(f"[EGRESS] {filter=}")
with pydivert.WinDivert(filter, layer=pydivert.Layer.NETWORK, flags=pydivert.Flag.DEFAULT) as w: with pydivert.WinDivert(filter, layer=pydivert.Layer.NETWORK, flags=pydivert.Flag.DEFAULT) as w:
proxy_ipv4, proxy_ipv6 = None, None
if IPFamily.IPv4 in self.network_config:
proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"]
proxy_ipv4 = proxy_ipv4[0].exploded, proxy_ipv4[1]
if IPFamily.IPv6 in self.network_config:
proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"]
proxy_ipv6 = proxy_ipv6[0].exploded, proxy_ipv6[1]
ready_cb() ready_cb()
proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"] if IPFamily.IPv4 in self.network_config else None
proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"] if IPFamily.IPv6 in self.network_config else None
verbose = get_verbose_level() verbose = get_verbose_level()
for pkt in w: for pkt in w:
verbose >= 3 and debug3("[EGRESS] " + repr_pkt(pkt)) verbose >= 3 and debug3("[EGRESS] " + repr_pkt(pkt))
@ -461,7 +480,7 @@ class Method(BaseMethod):
continue continue
proxy_ip, proxy_port = c["proxy_addr"] proxy_ip, proxy_port = c["proxy_addr"]
# "ip.SrcAddr=={hex(int(proxy_ip))}" # only Windivert >=2 supports this # "ip.SrcAddr=={hex(int(proxy_ip))}" # only Windivert >=2 supports this
proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip} and tcp.SrcPort=={proxy_port}") proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip.exploded} and tcp.SrcPort=={proxy_port}")
if not proxy_addr_filters: if not proxy_addr_filters:
raise Fatal("At least one ipv4 or ipv6 address is expected") raise Fatal("At least one ipv4 or ipv6 address is expected")
filter = f"{direction} and {proto.filter} and ({' or '.join(proxy_addr_filters)})" filter = f"{direction} and {proto.filter} and ({' or '.join(proxy_addr_filters)})"