windivert: add ipv6 support and better thread handling

This commit is contained in:
nom3ad 2022-09-07 12:26:21 +05:30 committed by Brian May
parent bd2f960743
commit 338486930f
4 changed files with 143 additions and 29 deletions

View File

@ -125,8 +125,9 @@ def _setup_daemon_windows():
# debug3(f'FROM_SHARE ${socket_share_data_b64=}')
socket_share_data = base64.b64decode(socket_share_data_b64)
sock = socket.fromshare(socket_share_data)
sys.stdin = io.TextIOWrapper(sock.makefile('rb'))
sys.stdout = io.TextIOWrapper(sock.makefile('wb'))
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
sock.close()
return sys.stdin, sys.stdout
if sys.platform == 'win32':
@ -324,10 +325,14 @@ def main(method_name, syslog):
socket.AF_INET, subnets_v4, udp,
user, group, tmark)
try:
method.wait_for_firewall_ready()
except NotImplementedError:
pass
if sys.platform != 'win32':
flush_systemd_dns_cache()
try:
stdout.write('STARTED\n')
stdout.flush()
@ -340,7 +345,9 @@ def main(method_name, syslog):
# authentication at shutdown time - that cleanup is important!
while 1:
try:
debug3("===================================================")
line = stdin.readline(128)
debug3("===================================================" + str(line))
except IOError as e:
debug3('read from stdin failed: %s' % (e,))
return

View File

@ -14,7 +14,10 @@ def b(s):
def log(s):
global logprefix
try:
sys.stdout.flush()
try:
sys.stdout.flush()
except (IOError,ValueError):
pass
# Put newline at end of string if line doesn't have one.
if not s.endswith("\n"):
s = s+"\n"

View File

@ -97,6 +97,9 @@ class BaseMethod(object):
def restore_firewall(self, port, family, udp, user, group):
raise NotImplementedError()
def wait_for_firewall_ready(self):
raise NotImplementedError()
@staticmethod
def firewall_command(line):
return False

View File

@ -9,6 +9,7 @@ import struct
from functools import wraps
from enum import IntEnum
import time
import traceback
try:
import pydivert
@ -32,6 +33,27 @@ class IPProtocol(IntEnum):
TCP = socket.IPPROTO_TCP
UDP = socket.IPPROTO_UDP
@property
def filter(self):
return 'tcp' if self == IPProtocol.TCP else 'udp'
class IPFamily(IntEnum):
IPv4 = socket.AF_INET
IPv6 = socket.AF_INET6
@property
def filter(self):
return 'ip' if self == socket.AF_INET else 'ipv6'
@property
def version(self):
return 4 if self == socket.AF_INET else 6
@property
def loopback_addr(self):
return '127.0.0.1' if self == socket.AF_INET else '::1'
class ConnState(IntEnum):
TCP_SYN_SEND = 10
TCP_SYN_ACK_RECV = 11
@ -62,6 +84,14 @@ def synchronized_method(lock):
return decorator
class ConnTrack:
_instance =None
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = object.__new__(cls)
return cls._instance
raise RuntimeError("ConnTrack can not be instantiated multiple times")
def __init__(self, name, max_connections=0) -> None:
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B')))
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H')))
@ -161,11 +191,34 @@ class ConnTrack:
class Method(BaseMethod):
network_config = {}
proxy_port = None
proxy_addr = { IPFamily.IPv4: None, IPFamily.IPv6: None }
def __init__(self, name):
super().__init__(name)
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
user, tmark):
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
proxy_addr = "10.0.2.15"
if nslist or user or udp:
raise NotImplementedError()
family = IPFamily(family)
# using loopback proxy address never worked. See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
# As a workaround we use another interface ip instead.
# self.proxy_addr[family] = family.loopback_addr
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
if addr.is_loopback or addr.version != family.version:
continue
self.proxy_addr[family] = str(addr)
break
else:
raise Fatal(f"Could not find a non loopback proxy address for {family.name}")
self.proxy_port = port
subnet_addresses = []
for (_, mask, exclude, network_addr, fport, lport) in subnets:
@ -175,15 +228,33 @@ class Method(BaseMethod):
assert lport == 0, 'custom port range not supported'
subnet_addresses.append("%s/%s" % (network_addr, mask))
debug2("setup_firewall() subnet_addresses=%s proxy_addr=%s:%s" % (subnet_addresses,proxy_addr,port))
self.network_config[family] = {
'subnets': subnet_addresses,
"nslist": nslist,
}
# check permission
with pydivert.WinDivert('false'):
pass
threading.Thread(name='outbound_divert', target=self._outbound_divert, args=(subnet_addresses, proxy_addr, port), daemon=True).start()
threading.Thread(name='inbound_divert', target=self._inbound_divert, args=(proxy_addr, port), daemon=True).start()
def wait_for_firewall_ready(self):
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
methods = (self._egress_divert, self._ingress_divert)
ready_events = []
for fn in methods:
ev = threading.Event()
ready_events.append(ev)
def _target():
try:
fn(ev.set)
except:
debug2(f'thread {fn.__name__} exiting due to: ' + traceback.format_exc())
sys.stdin.close() # this will exist main thread
sys.stdout.close()
threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
for ev in ready_events:
if not ev.wait(5): # at most 5 sec
raise Fatal(f"timeout in wait_for_firewall_ready()")
def restore_firewall(self, port, family, udp, user):
pass
@ -209,23 +280,29 @@ class Method(BaseMethod):
return True
return False
def _egress_divert(self, ready_cb):
proto = IPProtocol.TCP
filter = f"outbound and {proto.filter}"
def _outbound_divert(self, subnets, proxy_addr, proxy_port):
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
filter = "outbound and ip and tcp"
subnet_selectors = []
for cidr in subnets:
ip_network = ipaddress.ip_network(cidr)
first_ip = ip_network.network_address
last_ip = ip_network.broadcast_address
subnet_selectors.append(f"(ip.DstAddr >= {first_ip} and ip.DstAddr <= {last_ip})")
filter = f"{filter} and ({'or'.join(subnet_selectors)}) "
family_filters = []
for af, c in self.network_config.items():
subnet_filters = []
for cidr in c['subnets']:
ip_network = ipaddress.ip_network(cidr)
first_ip = ip_network.network_address
last_ip = ip_network.broadcast_address
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
filter = f"{filter} and ({' or '.join(family_filters)})"
debug1(f"[OUTBOUND] {filter=}")
with pydivert.WinDivert(filter) as w:
ready_cb()
proxy_port = self.proxy_port
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
for pkt in w:
debug3(">>> " + repr_pkt(pkt))
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
@ -234,15 +311,39 @@ class Method(BaseMethod):
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
if pkt.tcp.rst : # RST
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
pkt.ipv4.dst_addr = proxy_addr
# DNAT
if pkt.ipv4 and proxy_addr_ipv4:
pkt.dst_addr = proxy_addr_ipv4
if pkt.ipv6 and proxy_addr_ipv6:
pkt.dst_addr = proxy_addr_ipv6
pkt.tcp.dst_port = proxy_port
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well by setting src_addr to loopback address.
# Otherwise injecting packet will be ignored by Windows network stack as teh packet has to cross public to private address space.
# See: https://github.com/basil00/Divert/issues/82
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
# >>> pkt.dst_addr = proxy_addr_ipv4
w.send(pkt, recalculate_checksum=True)
def _inbound_divert(self, proxy_addr, proxy_port):
filter = f"inbound and ip and tcp and ip.SrcAddr == {proxy_addr} and tcp.SrcPort == {proxy_port}"
debug2(f"[INBOUND] {filter=}")
def _ingress_divert(self, ready_cb):
proto = IPProtocol.TCP
direction = 'inbound' # only when proxy address is not loopback address (Useful for testing)
ip_filters = []
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
if addr.is_loopback: # Windivert treats all loopback traffic as outbound
direction = "outbound"
if addr.version == 4:
ip_filters.append(f"ip.SrcAddr=={addr}")
else:
# ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this
ip_filters.append(f"ipv6.SrcAddr=={addr}")
filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}"
debug2(f"[INGRESS] {filter=}")
with pydivert.WinDivert(filter) as w:
ready_cb()
for pkt in w:
debug3("<<< " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
@ -254,7 +355,7 @@ class Method(BaseMethod):
if not conn:
debug2("Unexpected packet: " + repr_pkt(pkt))
continue
pkt.ipv4.src_addr = conn.dst_addr
pkt.src_addr = conn.dst_addr
pkt.tcp.src_port = conn.dst_port
w.send(pkt, recalculate_checksum=True)