mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-21 23:43:18 +01:00
windivert: add ipv6 support and better thread handling
This commit is contained in:
parent
bd2f960743
commit
338486930f
@ -125,8 +125,9 @@ def _setup_daemon_windows():
|
||||
# debug3(f'FROM_SHARE ${socket_share_data_b64=}')
|
||||
socket_share_data = base64.b64decode(socket_share_data_b64)
|
||||
sock = socket.fromshare(socket_share_data)
|
||||
sys.stdin = io.TextIOWrapper(sock.makefile('rb'))
|
||||
sys.stdout = io.TextIOWrapper(sock.makefile('wb'))
|
||||
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
|
||||
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
|
||||
sock.close()
|
||||
return sys.stdin, sys.stdout
|
||||
|
||||
if sys.platform == 'win32':
|
||||
@ -324,10 +325,14 @@ def main(method_name, syslog):
|
||||
socket.AF_INET, subnets_v4, udp,
|
||||
user, group, tmark)
|
||||
|
||||
try:
|
||||
method.wait_for_firewall_ready()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
if sys.platform != 'win32':
|
||||
flush_systemd_dns_cache()
|
||||
|
||||
|
||||
try:
|
||||
stdout.write('STARTED\n')
|
||||
stdout.flush()
|
||||
@ -340,7 +345,9 @@ def main(method_name, syslog):
|
||||
# authentication at shutdown time - that cleanup is important!
|
||||
while 1:
|
||||
try:
|
||||
debug3("===================================================")
|
||||
line = stdin.readline(128)
|
||||
debug3("===================================================" + str(line))
|
||||
except IOError as e:
|
||||
debug3('read from stdin failed: %s' % (e,))
|
||||
return
|
||||
|
@ -14,7 +14,10 @@ def b(s):
|
||||
def log(s):
|
||||
global logprefix
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
except (IOError,ValueError):
|
||||
pass
|
||||
# Put newline at end of string if line doesn't have one.
|
||||
if not s.endswith("\n"):
|
||||
s = s+"\n"
|
||||
|
@ -97,6 +97,9 @@ class BaseMethod(object):
|
||||
def restore_firewall(self, port, family, udp, user, group):
|
||||
raise NotImplementedError()
|
||||
|
||||
def wait_for_firewall_ready(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def firewall_command(line):
|
||||
return False
|
||||
|
@ -9,6 +9,7 @@ import struct
|
||||
from functools import wraps
|
||||
from enum import IntEnum
|
||||
import time
|
||||
import traceback
|
||||
|
||||
try:
|
||||
import pydivert
|
||||
@ -32,6 +33,27 @@ class IPProtocol(IntEnum):
|
||||
TCP = socket.IPPROTO_TCP
|
||||
UDP = socket.IPPROTO_UDP
|
||||
|
||||
@property
|
||||
def filter(self):
|
||||
return 'tcp' if self == IPProtocol.TCP else 'udp'
|
||||
|
||||
class IPFamily(IntEnum):
|
||||
IPv4 = socket.AF_INET
|
||||
IPv6 = socket.AF_INET6
|
||||
|
||||
@property
|
||||
def filter(self):
|
||||
return 'ip' if self == socket.AF_INET else 'ipv6'
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return 4 if self == socket.AF_INET else 6
|
||||
|
||||
@property
|
||||
def loopback_addr(self):
|
||||
return '127.0.0.1' if self == socket.AF_INET else '::1'
|
||||
|
||||
|
||||
class ConnState(IntEnum):
|
||||
TCP_SYN_SEND = 10
|
||||
TCP_SYN_ACK_RECV = 11
|
||||
@ -62,6 +84,14 @@ def synchronized_method(lock):
|
||||
return decorator
|
||||
|
||||
class ConnTrack:
|
||||
|
||||
_instance =None
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
cls._instance = object.__new__(cls)
|
||||
return cls._instance
|
||||
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
||||
|
||||
def __init__(self, name, max_connections=0) -> None:
|
||||
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B')))
|
||||
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H')))
|
||||
@ -161,11 +191,34 @@ class ConnTrack:
|
||||
|
||||
class Method(BaseMethod):
|
||||
|
||||
network_config = {}
|
||||
proxy_port = None
|
||||
proxy_addr = { IPFamily.IPv4: None, IPFamily.IPv6: None }
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
|
||||
user, tmark):
|
||||
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
||||
proxy_addr = "10.0.2.15"
|
||||
|
||||
if nslist or user or udp:
|
||||
raise NotImplementedError()
|
||||
|
||||
family = IPFamily(family)
|
||||
|
||||
# using loopback proxy address never worked. See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
|
||||
# As a workaround we use another interface ip instead.
|
||||
# self.proxy_addr[family] = family.loopback_addr
|
||||
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
|
||||
if addr.is_loopback or addr.version != family.version:
|
||||
continue
|
||||
self.proxy_addr[family] = str(addr)
|
||||
break
|
||||
else:
|
||||
raise Fatal(f"Could not find a non loopback proxy address for {family.name}")
|
||||
|
||||
self.proxy_port = port
|
||||
|
||||
subnet_addresses = []
|
||||
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
||||
@ -175,15 +228,33 @@ class Method(BaseMethod):
|
||||
assert lport == 0, 'custom port range not supported'
|
||||
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
||||
|
||||
debug2("setup_firewall() subnet_addresses=%s proxy_addr=%s:%s" % (subnet_addresses,proxy_addr,port))
|
||||
self.network_config[family] = {
|
||||
'subnets': subnet_addresses,
|
||||
"nslist": nslist,
|
||||
}
|
||||
|
||||
# check permission
|
||||
with pydivert.WinDivert('false'):
|
||||
pass
|
||||
|
||||
threading.Thread(name='outbound_divert', target=self._outbound_divert, args=(subnet_addresses, proxy_addr, port), daemon=True).start()
|
||||
threading.Thread(name='inbound_divert', target=self._inbound_divert, args=(proxy_addr, port), daemon=True).start()
|
||||
|
||||
def wait_for_firewall_ready(self):
|
||||
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
||||
methods = (self._egress_divert, self._ingress_divert)
|
||||
ready_events = []
|
||||
for fn in methods:
|
||||
ev = threading.Event()
|
||||
ready_events.append(ev)
|
||||
def _target():
|
||||
try:
|
||||
fn(ev.set)
|
||||
except:
|
||||
debug2(f'thread {fn.__name__} exiting due to: ' + traceback.format_exc())
|
||||
sys.stdin.close() # this will exist main thread
|
||||
sys.stdout.close()
|
||||
threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
|
||||
for ev in ready_events:
|
||||
if not ev.wait(5): # at most 5 sec
|
||||
raise Fatal(f"timeout in wait_for_firewall_ready()")
|
||||
|
||||
def restore_firewall(self, port, family, udp, user):
|
||||
pass
|
||||
|
||||
@ -209,23 +280,29 @@ class Method(BaseMethod):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _egress_divert(self, ready_cb):
|
||||
proto = IPProtocol.TCP
|
||||
filter = f"outbound and {proto.filter}"
|
||||
|
||||
|
||||
|
||||
|
||||
def _outbound_divert(self, subnets, proxy_addr, proxy_port):
|
||||
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
||||
filter = "outbound and ip and tcp"
|
||||
subnet_selectors = []
|
||||
for cidr in subnets:
|
||||
ip_network = ipaddress.ip_network(cidr)
|
||||
first_ip = ip_network.network_address
|
||||
last_ip = ip_network.broadcast_address
|
||||
subnet_selectors.append(f"(ip.DstAddr >= {first_ip} and ip.DstAddr <= {last_ip})")
|
||||
filter = f"{filter} and ({'or'.join(subnet_selectors)}) "
|
||||
family_filters = []
|
||||
for af, c in self.network_config.items():
|
||||
subnet_filters = []
|
||||
for cidr in c['subnets']:
|
||||
ip_network = ipaddress.ip_network(cidr)
|
||||
first_ip = ip_network.network_address
|
||||
last_ip = ip_network.broadcast_address
|
||||
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
|
||||
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
|
||||
|
||||
filter = f"{filter} and ({' or '.join(family_filters)})"
|
||||
|
||||
debug1(f"[OUTBOUND] {filter=}")
|
||||
with pydivert.WinDivert(filter) as w:
|
||||
ready_cb()
|
||||
proxy_port = self.proxy_port
|
||||
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
|
||||
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
||||
for pkt in w:
|
||||
debug3(">>> " + repr_pkt(pkt))
|
||||
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
|
||||
@ -234,15 +311,39 @@ class Method(BaseMethod):
|
||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
|
||||
if pkt.tcp.rst : # RST
|
||||
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||
pkt.ipv4.dst_addr = proxy_addr
|
||||
|
||||
# DNAT
|
||||
if pkt.ipv4 and proxy_addr_ipv4:
|
||||
pkt.dst_addr = proxy_addr_ipv4
|
||||
if pkt.ipv6 and proxy_addr_ipv6:
|
||||
pkt.dst_addr = proxy_addr_ipv6
|
||||
pkt.tcp.dst_port = proxy_port
|
||||
|
||||
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well by setting src_addr to loopback address.
|
||||
# Otherwise injecting packet will be ignored by Windows network stack as teh packet has to cross public to private address space.
|
||||
# See: https://github.com/basil00/Divert/issues/82
|
||||
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
|
||||
# >>> pkt.dst_addr = proxy_addr_ipv4
|
||||
|
||||
w.send(pkt, recalculate_checksum=True)
|
||||
|
||||
|
||||
def _inbound_divert(self, proxy_addr, proxy_port):
|
||||
filter = f"inbound and ip and tcp and ip.SrcAddr == {proxy_addr} and tcp.SrcPort == {proxy_port}"
|
||||
debug2(f"[INBOUND] {filter=}")
|
||||
def _ingress_divert(self, ready_cb):
|
||||
proto = IPProtocol.TCP
|
||||
direction = 'inbound' # only when proxy address is not loopback address (Useful for testing)
|
||||
ip_filters = []
|
||||
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
|
||||
if addr.is_loopback: # Windivert treats all loopback traffic as outbound
|
||||
direction = "outbound"
|
||||
if addr.version == 4:
|
||||
ip_filters.append(f"ip.SrcAddr=={addr}")
|
||||
else:
|
||||
# ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this
|
||||
ip_filters.append(f"ipv6.SrcAddr=={addr}")
|
||||
filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}"
|
||||
debug2(f"[INGRESS] {filter=}")
|
||||
with pydivert.WinDivert(filter) as w:
|
||||
ready_cb()
|
||||
for pkt in w:
|
||||
debug3("<<< " + repr_pkt(pkt))
|
||||
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
|
||||
@ -254,7 +355,7 @@ class Method(BaseMethod):
|
||||
if not conn:
|
||||
debug2("Unexpected packet: " + repr_pkt(pkt))
|
||||
continue
|
||||
pkt.ipv4.src_addr = conn.dst_addr
|
||||
pkt.src_addr = conn.dst_addr
|
||||
pkt.tcp.src_port = conn.dst_port
|
||||
w.send(pkt, recalculate_checksum=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user