mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-05-07 01:24:27 +02:00
support port ranges and exclude subnets
This commit is contained in:
parent
72060abbef
commit
89a94ff150
@ -7,7 +7,7 @@ import socket
|
|||||||
import subprocess
|
import subprocess
|
||||||
import re
|
import re
|
||||||
from multiprocessing import shared_memory
|
from multiprocessing import shared_memory
|
||||||
import struct
|
from struct import Struct
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
import time
|
import time
|
||||||
@ -61,7 +61,7 @@ class IPFamily(IntEnum):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def loopback_addr(self):
|
def loopback_addr(self):
|
||||||
return "127.0.0.1" if self == socket.AF_INET else "::1"
|
return ip_address("127.0.0.1" if self == socket.AF_INET else "::1")
|
||||||
|
|
||||||
|
|
||||||
class ConnState(IntEnum):
|
class ConnState(IntEnum):
|
||||||
@ -123,9 +123,9 @@ class ConnTrack:
|
|||||||
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
||||||
|
|
||||||
def __init__(self, name, max_connections=0) -> None:
|
def __init__(self, name, max_connections=0) -> None:
|
||||||
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
|
self.struct_full_tuple = Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
|
||||||
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
|
self.struct_src_tuple = Struct(">" + "".join(("B", "B", "16s", "H")))
|
||||||
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
|
self.struct_state_tuple = Struct(">" + "".join(("L", "B")))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.max_connections = max_connections
|
self.max_connections = max_connections
|
||||||
@ -142,8 +142,8 @@ class ConnTrack:
|
|||||||
self.max_connections = len(self.shm_list)
|
self.max_connections = len(self.shm_list)
|
||||||
|
|
||||||
debug2(
|
debug2(
|
||||||
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
|
f"ConnTrack: is_owner={self.is_owner} cap={len(self.shm_list)} item_sz={self.struct_full_tuple.size}B"
|
||||||
f"shm_size={self.shm_list.shm.size}B"
|
f"shm_name={self.shm_list.shm.name} shm_sz={self.shm_list.shm.size}B"
|
||||||
)
|
)
|
||||||
|
|
||||||
@synchronized_method("rlock")
|
@synchronized_method("rlock")
|
||||||
@ -279,7 +279,6 @@ class ConnTrack:
|
|||||||
class Method(BaseMethod):
|
class Method(BaseMethod):
|
||||||
|
|
||||||
network_config = {}
|
network_config = {}
|
||||||
proxy_port = None
|
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
@ -297,10 +296,10 @@ class Method(BaseMethod):
|
|||||||
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
|
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
|
||||||
|
|
||||||
def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark):
|
def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark):
|
||||||
debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {group=} {tmark=}")
|
||||||
|
|
||||||
if nslist or user or udp:
|
if nslist or user or udp or group:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError("user, group, nslist, udp are not supported")
|
||||||
|
|
||||||
family = IPFamily(family)
|
family = IPFamily(family)
|
||||||
|
|
||||||
@ -312,24 +311,26 @@ class Method(BaseMethod):
|
|||||||
if proxy_bind_addr.is_loopback:
|
if proxy_bind_addr.is_loopback:
|
||||||
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.")
|
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.")
|
||||||
if not proxy_bind_addr.is_unspecified:
|
if not proxy_bind_addr.is_unspecified:
|
||||||
proxy_ip = proxy_bind_addr.exploded
|
proxy_ip = proxy_bind_addr
|
||||||
else:
|
else:
|
||||||
local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)]
|
local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)]
|
||||||
for addr in local_addresses:
|
for addr in local_addresses:
|
||||||
if not addr.is_loopback and not addr.is_link_local:
|
if not addr.is_loopback and not addr.is_link_local:
|
||||||
proxy_ip = addr.exploded
|
proxy_ip = addr
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address."
|
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address."
|
||||||
f"No address found for {family.name} in {local_addresses}")
|
f"No address found for {family.name} in {local_addresses}")
|
||||||
debug2("Found non loopback address to connect to proxy: " + proxy_ip)
|
debug2(f"Found non loopback address to connect to proxy: {proxy_ip}")
|
||||||
subnet_addresses = []
|
subnet_addresses = []
|
||||||
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
||||||
if exclude:
|
if fport and lport:
|
||||||
continue
|
if lport > fport:
|
||||||
assert fport == 0, "custom port range not supported"
|
raise Fatal("lport must be less than or equal to fport")
|
||||||
assert lport == 0, "custom port range not supported"
|
ports = (fport, lport)
|
||||||
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
else:
|
||||||
|
ports = None
|
||||||
|
subnet_addresses.append((ip_network(f"{network_addr}/{mask}"), ports, exclude))
|
||||||
|
|
||||||
self.network_config[family] = {
|
self.network_config[family] = {
|
||||||
"subnets": subnet_addresses,
|
"subnets": subnet_addresses,
|
||||||
@ -391,30 +392,48 @@ class Method(BaseMethod):
|
|||||||
"""divert outgoing packets to proxy"""
|
"""divert outgoing packets to proxy"""
|
||||||
proto = IPProtocol.TCP
|
proto = IPProtocol.TCP
|
||||||
filter = f"outbound and {proto.filter}"
|
filter = f"outbound and {proto.filter}"
|
||||||
|
af_filters = []
|
||||||
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
|
||||||
family_filters = []
|
|
||||||
for af, c in self.network_config.items():
|
for af, c in self.network_config.items():
|
||||||
subnet_filters = []
|
subnet_include_filters = []
|
||||||
for cidr in c["subnets"]:
|
subnet_exclude_filters = []
|
||||||
ip_net = ip_network(cidr)
|
for ip_net, ports, exclude in c["subnets"]:
|
||||||
first_ip = ip_net.network_address.exploded
|
first_ip = ip_net.network_address.exploded
|
||||||
last_ip = ip_net.broadcast_address.exploded
|
last_ip = ip_net.broadcast_address.exploded
|
||||||
subnet_filters.append(f"({af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip})")
|
if first_ip == last_ip:
|
||||||
if not subnet_filters:
|
_subney_filter = f"{af.filter}.DstAddr=={first_ip}"
|
||||||
continue
|
else:
|
||||||
|
_subney_filter = f"{af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip}"
|
||||||
|
if ports:
|
||||||
|
if ports[0] == ports[1]:
|
||||||
|
_subney_filter += f" and {proto.filter}.DstPort=={ports[0]}"
|
||||||
|
else:
|
||||||
|
_subney_filter += f" and tcp.DstPort>={ports[0]} and tcp.DstPort<={ports[1]}"
|
||||||
|
(subnet_exclude_filters if exclude else subnet_include_filters).append(f'({_subney_filter})')
|
||||||
|
_af_filter = f"{af.filter}"
|
||||||
|
if subnet_include_filters:
|
||||||
|
_af_filter += f" and ({' or '.join(subnet_include_filters)})"
|
||||||
|
if subnet_exclude_filters:
|
||||||
|
# TODO(noma3ad) use not() operator with Windivert2 after upgrade
|
||||||
|
_af_filter += f" and (({' or '.join(subnet_exclude_filters)})? false : true)"
|
||||||
proxy_ip, proxy_port = c["proxy_addr"]
|
proxy_ip, proxy_port = c["proxy_addr"]
|
||||||
proxy_guard_filter = f'({af.filter}.DstAddr!={proxy_ip} or tcp.DstPort!={proxy_port})'
|
# Avoids proxy outbound traffic getting directed to itself
|
||||||
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) and {proxy_guard_filter}")
|
proxy_guard_filter = f'(({af.filter}.DstAddr=={proxy_ip.exploded} and tcp.DstPort=={proxy_port})? false : true)'
|
||||||
if not family_filters:
|
_af_filter += f" and {proxy_guard_filter}"
|
||||||
|
af_filters.append(_af_filter)
|
||||||
|
if not af_filters:
|
||||||
raise Fatal("At least one ipv4 or ipv6 subnet is expected")
|
raise Fatal("At least one ipv4 or ipv6 subnet is expected")
|
||||||
|
|
||||||
filter = f"{filter} and ({' or '.join(family_filters)})"
|
filter = f"{filter} and ({' or '.join(af_filters)})"
|
||||||
debug1(f"[EGRESS] {filter=}")
|
debug1(f"[EGRESS] {filter=}")
|
||||||
with pydivert.WinDivert(filter, layer=pydivert.Layer.NETWORK, flags=pydivert.Flag.DEFAULT) as w:
|
with pydivert.WinDivert(filter, layer=pydivert.Layer.NETWORK, flags=pydivert.Flag.DEFAULT) as w:
|
||||||
|
proxy_ipv4, proxy_ipv6 = None, None
|
||||||
|
if IPFamily.IPv4 in self.network_config:
|
||||||
|
proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"]
|
||||||
|
proxy_ipv4 = proxy_ipv4[0].exploded, proxy_ipv4[1]
|
||||||
|
if IPFamily.IPv6 in self.network_config:
|
||||||
|
proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"]
|
||||||
|
proxy_ipv6 = proxy_ipv6[0].exploded, proxy_ipv6[1]
|
||||||
ready_cb()
|
ready_cb()
|
||||||
proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"] if IPFamily.IPv4 in self.network_config else None
|
|
||||||
proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"] if IPFamily.IPv6 in self.network_config else None
|
|
||||||
verbose = get_verbose_level()
|
verbose = get_verbose_level()
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
verbose >= 3 and debug3("[EGRESS] " + repr_pkt(pkt))
|
verbose >= 3 and debug3("[EGRESS] " + repr_pkt(pkt))
|
||||||
@ -461,7 +480,7 @@ class Method(BaseMethod):
|
|||||||
continue
|
continue
|
||||||
proxy_ip, proxy_port = c["proxy_addr"]
|
proxy_ip, proxy_port = c["proxy_addr"]
|
||||||
# "ip.SrcAddr=={hex(int(proxy_ip))}" # only Windivert >=2 supports this
|
# "ip.SrcAddr=={hex(int(proxy_ip))}" # only Windivert >=2 supports this
|
||||||
proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip} and tcp.SrcPort=={proxy_port}")
|
proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip.exploded} and tcp.SrcPort=={proxy_port}")
|
||||||
if not proxy_addr_filters:
|
if not proxy_addr_filters:
|
||||||
raise Fatal("At least one ipv4 or ipv6 address is expected")
|
raise Fatal("At least one ipv4 or ipv6 address is expected")
|
||||||
filter = f"{direction} and {proto.filter} and ({' or '.join(proxy_addr_filters)})"
|
filter = f"{direction} and {proto.filter} and ({' or '.join(proxy_addr_filters)})"
|
||||||
|
Loading…
Reference in New Issue
Block a user