support port ranges and exclude subnets

This commit is contained in:
nom3ad 2024-01-08 22:22:50 +05:30 committed by Brian May
parent 72060abbef
commit 89a94ff150

View File

@ -7,7 +7,7 @@ import socket
import subprocess
import re
from multiprocessing import shared_memory
import struct
from struct import Struct
from functools import wraps
from enum import IntEnum
import time
@ -61,7 +61,7 @@ class IPFamily(IntEnum):
@property
def loopback_addr(self):
return "127.0.0.1" if self == socket.AF_INET else "::1"
return ip_address("127.0.0.1" if self == socket.AF_INET else "::1")
class ConnState(IntEnum):
@ -123,9 +123,9 @@ class ConnTrack:
raise RuntimeError("ConnTrack can not be instantiated multiple times")
def __init__(self, name, max_connections=0) -> None:
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
self.struct_full_tuple = Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
self.struct_src_tuple = Struct(">" + "".join(("B", "B", "16s", "H")))
self.struct_state_tuple = Struct(">" + "".join(("L", "B")))
try:
self.max_connections = max_connections
@ -142,8 +142,8 @@ class ConnTrack:
self.max_connections = len(self.shm_list)
debug2(
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
f"shm_size={self.shm_list.shm.size}B"
f"ConnTrack: is_owner={self.is_owner} cap={len(self.shm_list)} item_sz={self.struct_full_tuple.size}B"
f"shm_name={self.shm_list.shm.name} shm_sz={self.shm_list.shm.size}B"
)
@synchronized_method("rlock")
@ -279,7 +279,6 @@ class ConnTrack:
class Method(BaseMethod):
network_config = {}
proxy_port = None
def __init__(self, name):
super().__init__(name)
@ -297,10 +296,10 @@ class Method(BaseMethod):
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark):
debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {group=} {tmark=}")
if nslist or user or udp:
raise NotImplementedError()
if nslist or user or udp or group:
raise NotImplementedError("user, group, nslist, udp are not supported")
family = IPFamily(family)
@ -312,24 +311,26 @@ class Method(BaseMethod):
if proxy_bind_addr.is_loopback:
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.")
if not proxy_bind_addr.is_unspecified:
proxy_ip = proxy_bind_addr.exploded
proxy_ip = proxy_bind_addr
else:
local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)]
for addr in local_addresses:
if not addr.is_loopback and not addr.is_link_local:
proxy_ip = addr.exploded
proxy_ip = addr
break
else:
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address."
f"No address found for {family.name} in {local_addresses}")
debug2("Found non loopback address to connect to proxy: " + proxy_ip)
debug2(f"Found non loopback address to connect to proxy: {proxy_ip}")
subnet_addresses = []
for (_, mask, exclude, network_addr, fport, lport) in subnets:
if exclude:
continue
assert fport == 0, "custom port range not supported"
assert lport == 0, "custom port range not supported"
subnet_addresses.append("%s/%s" % (network_addr, mask))
if fport and lport:
if lport > fport:
raise Fatal("lport must be less than or equal to fport")
ports = (fport, lport)
else:
ports = None
subnet_addresses.append((ip_network(f"{network_addr}/{mask}"), ports, exclude))
self.network_config[family] = {
"subnets": subnet_addresses,
@ -391,30 +392,48 @@ class Method(BaseMethod):
"""divert outgoing packets to proxy"""
proto = IPProtocol.TCP
filter = f"outbound and {proto.filter}"
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
family_filters = []
af_filters = []
for af, c in self.network_config.items():
subnet_filters = []
for cidr in c["subnets"]:
ip_net = ip_network(cidr)
subnet_include_filters = []
subnet_exclude_filters = []
for ip_net, ports, exclude in c["subnets"]:
first_ip = ip_net.network_address.exploded
last_ip = ip_net.broadcast_address.exploded
subnet_filters.append(f"({af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip})")
if not subnet_filters:
continue
if first_ip == last_ip:
_subney_filter = f"{af.filter}.DstAddr=={first_ip}"
else:
_subney_filter = f"{af.filter}.DstAddr>={first_ip} and {af.filter}.DstAddr<={last_ip}"
if ports:
if ports[0] == ports[1]:
_subney_filter += f" and {proto.filter}.DstPort=={ports[0]}"
else:
_subney_filter += f" and tcp.DstPort>={ports[0]} and tcp.DstPort<={ports[1]}"
(subnet_exclude_filters if exclude else subnet_include_filters).append(f'({_subney_filter})')
_af_filter = f"{af.filter}"
if subnet_include_filters:
_af_filter += f" and ({' or '.join(subnet_include_filters)})"
if subnet_exclude_filters:
# TODO(noma3ad) use not() operator with Windivert2 after upgrade
_af_filter += f" and (({' or '.join(subnet_exclude_filters)})? false : true)"
proxy_ip, proxy_port = c["proxy_addr"]
proxy_guard_filter = f'({af.filter}.DstAddr!={proxy_ip} or tcp.DstPort!={proxy_port})'
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) and {proxy_guard_filter}")
if not family_filters:
# Avoids proxy outbound traffic getting directed to itself
proxy_guard_filter = f'(({af.filter}.DstAddr=={proxy_ip.exploded} and tcp.DstPort=={proxy_port})? false : true)'
_af_filter += f" and {proxy_guard_filter}"
af_filters.append(_af_filter)
if not af_filters:
raise Fatal("At least one ipv4 or ipv6 subnet is expected")
filter = f"{filter} and ({' or '.join(family_filters)})"
filter = f"{filter} and ({' or '.join(af_filters)})"
debug1(f"[EGRESS] {filter=}")
with pydivert.WinDivert(filter, layer=pydivert.Layer.NETWORK, flags=pydivert.Flag.DEFAULT) as w:
proxy_ipv4, proxy_ipv6 = None, None
if IPFamily.IPv4 in self.network_config:
proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"]
proxy_ipv4 = proxy_ipv4[0].exploded, proxy_ipv4[1]
if IPFamily.IPv6 in self.network_config:
proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"]
proxy_ipv6 = proxy_ipv6[0].exploded, proxy_ipv6[1]
ready_cb()
proxy_ipv4 = self.network_config[IPFamily.IPv4]["proxy_addr"] if IPFamily.IPv4 in self.network_config else None
proxy_ipv6 = self.network_config[IPFamily.IPv6]["proxy_addr"] if IPFamily.IPv6 in self.network_config else None
verbose = get_verbose_level()
for pkt in w:
verbose >= 3 and debug3("[EGRESS] " + repr_pkt(pkt))
@ -461,7 +480,7 @@ class Method(BaseMethod):
continue
proxy_ip, proxy_port = c["proxy_addr"]
# "ip.SrcAddr=={hex(int(proxy_ip))}" # only Windivert >=2 supports this
proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip} and tcp.SrcPort=={proxy_port}")
proxy_addr_filters.append(f"{af.filter}.SrcAddr=={proxy_ip.exploded} and tcp.SrcPort=={proxy_port}")
if not proxy_addr_filters:
raise Fatal("At least one ipv4 or ipv6 address is expected")
filter = f"{direction} and {proto.filter} and ({' or '.join(proxy_addr_filters)})"