mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-05-13 04:24:59 +02:00
windivert: add ipv6 support and better thread handling
This commit is contained in:
parent
bd2f960743
commit
338486930f
@ -125,8 +125,9 @@ def _setup_daemon_windows():
|
|||||||
# debug3(f'FROM_SHARE ${socket_share_data_b64=}')
|
# debug3(f'FROM_SHARE ${socket_share_data_b64=}')
|
||||||
socket_share_data = base64.b64decode(socket_share_data_b64)
|
socket_share_data = base64.b64decode(socket_share_data_b64)
|
||||||
sock = socket.fromshare(socket_share_data)
|
sock = socket.fromshare(socket_share_data)
|
||||||
sys.stdin = io.TextIOWrapper(sock.makefile('rb'))
|
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
|
||||||
sys.stdout = io.TextIOWrapper(sock.makefile('wb'))
|
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
|
||||||
|
sock.close()
|
||||||
return sys.stdin, sys.stdout
|
return sys.stdin, sys.stdout
|
||||||
|
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
@ -324,10 +325,14 @@ def main(method_name, syslog):
|
|||||||
socket.AF_INET, subnets_v4, udp,
|
socket.AF_INET, subnets_v4, udp,
|
||||||
user, group, tmark)
|
user, group, tmark)
|
||||||
|
|
||||||
|
try:
|
||||||
|
method.wait_for_firewall_ready()
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
|
|
||||||
if sys.platform != 'win32':
|
if sys.platform != 'win32':
|
||||||
flush_systemd_dns_cache()
|
flush_systemd_dns_cache()
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stdout.write('STARTED\n')
|
stdout.write('STARTED\n')
|
||||||
stdout.flush()
|
stdout.flush()
|
||||||
@ -340,7 +345,9 @@ def main(method_name, syslog):
|
|||||||
# authentication at shutdown time - that cleanup is important!
|
# authentication at shutdown time - that cleanup is important!
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
try:
|
||||||
|
debug3("===================================================")
|
||||||
line = stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
|
debug3("===================================================" + str(line))
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
debug3('read from stdin failed: %s' % (e,))
|
debug3('read from stdin failed: %s' % (e,))
|
||||||
return
|
return
|
||||||
|
@ -13,8 +13,11 @@ def b(s):
|
|||||||
|
|
||||||
def log(s):
|
def log(s):
|
||||||
global logprefix
|
global logprefix
|
||||||
|
try:
|
||||||
try:
|
try:
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
except (IOError,ValueError):
|
||||||
|
pass
|
||||||
# Put newline at end of string if line doesn't have one.
|
# Put newline at end of string if line doesn't have one.
|
||||||
if not s.endswith("\n"):
|
if not s.endswith("\n"):
|
||||||
s = s+"\n"
|
s = s+"\n"
|
||||||
|
@ -97,6 +97,9 @@ class BaseMethod(object):
|
|||||||
def restore_firewall(self, port, family, udp, user, group):
|
def restore_firewall(self, port, family, udp, user, group):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def wait_for_firewall_ready(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def firewall_command(line):
|
def firewall_command(line):
|
||||||
return False
|
return False
|
||||||
|
@ -9,6 +9,7 @@ import struct
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pydivert
|
import pydivert
|
||||||
@ -32,6 +33,27 @@ class IPProtocol(IntEnum):
|
|||||||
TCP = socket.IPPROTO_TCP
|
TCP = socket.IPPROTO_TCP
|
||||||
UDP = socket.IPPROTO_UDP
|
UDP = socket.IPPROTO_UDP
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filter(self):
|
||||||
|
return 'tcp' if self == IPProtocol.TCP else 'udp'
|
||||||
|
|
||||||
|
class IPFamily(IntEnum):
|
||||||
|
IPv4 = socket.AF_INET
|
||||||
|
IPv6 = socket.AF_INET6
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filter(self):
|
||||||
|
return 'ip' if self == socket.AF_INET else 'ipv6'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version(self):
|
||||||
|
return 4 if self == socket.AF_INET else 6
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loopback_addr(self):
|
||||||
|
return '127.0.0.1' if self == socket.AF_INET else '::1'
|
||||||
|
|
||||||
|
|
||||||
class ConnState(IntEnum):
|
class ConnState(IntEnum):
|
||||||
TCP_SYN_SEND = 10
|
TCP_SYN_SEND = 10
|
||||||
TCP_SYN_ACK_RECV = 11
|
TCP_SYN_ACK_RECV = 11
|
||||||
@ -62,6 +84,14 @@ def synchronized_method(lock):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
class ConnTrack:
|
class ConnTrack:
|
||||||
|
|
||||||
|
_instance =None
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if not cls._instance:
|
||||||
|
cls._instance = object.__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
||||||
|
|
||||||
def __init__(self, name, max_connections=0) -> None:
|
def __init__(self, name, max_connections=0) -> None:
|
||||||
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B')))
|
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B')))
|
||||||
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H')))
|
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H')))
|
||||||
@ -161,11 +191,34 @@ class ConnTrack:
|
|||||||
|
|
||||||
class Method(BaseMethod):
|
class Method(BaseMethod):
|
||||||
|
|
||||||
|
network_config = {}
|
||||||
|
proxy_port = None
|
||||||
|
proxy_addr = { IPFamily.IPv4: None, IPFamily.IPv6: None }
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
|
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
|
||||||
user, tmark):
|
user, tmark):
|
||||||
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
|
||||||
proxy_addr = "10.0.2.15"
|
if nslist or user or udp:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
family = IPFamily(family)
|
||||||
|
|
||||||
|
# using loopback proxy address never worked. See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
|
||||||
|
# As a workaround we use another interface ip instead.
|
||||||
|
# self.proxy_addr[family] = family.loopback_addr
|
||||||
|
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
|
||||||
|
if addr.is_loopback or addr.version != family.version:
|
||||||
|
continue
|
||||||
|
self.proxy_addr[family] = str(addr)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Fatal(f"Could not find a non loopback proxy address for {family.name}")
|
||||||
|
|
||||||
|
self.proxy_port = port
|
||||||
|
|
||||||
subnet_addresses = []
|
subnet_addresses = []
|
||||||
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
||||||
@ -175,14 +228,32 @@ class Method(BaseMethod):
|
|||||||
assert lport == 0, 'custom port range not supported'
|
assert lport == 0, 'custom port range not supported'
|
||||||
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
||||||
|
|
||||||
debug2("setup_firewall() subnet_addresses=%s proxy_addr=%s:%s" % (subnet_addresses,proxy_addr,port))
|
self.network_config[family] = {
|
||||||
|
'subnets': subnet_addresses,
|
||||||
|
"nslist": nslist,
|
||||||
|
}
|
||||||
|
|
||||||
# check permission
|
|
||||||
with pydivert.WinDivert('false'):
|
|
||||||
pass
|
|
||||||
|
|
||||||
threading.Thread(name='outbound_divert', target=self._outbound_divert, args=(subnet_addresses, proxy_addr, port), daemon=True).start()
|
|
||||||
threading.Thread(name='inbound_divert', target=self._inbound_divert, args=(proxy_addr, port), daemon=True).start()
|
def wait_for_firewall_ready(self):
|
||||||
|
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
||||||
|
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
||||||
|
methods = (self._egress_divert, self._ingress_divert)
|
||||||
|
ready_events = []
|
||||||
|
for fn in methods:
|
||||||
|
ev = threading.Event()
|
||||||
|
ready_events.append(ev)
|
||||||
|
def _target():
|
||||||
|
try:
|
||||||
|
fn(ev.set)
|
||||||
|
except:
|
||||||
|
debug2(f'thread {fn.__name__} exiting due to: ' + traceback.format_exc())
|
||||||
|
sys.stdin.close() # this will exist main thread
|
||||||
|
sys.stdout.close()
|
||||||
|
threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
|
||||||
|
for ev in ready_events:
|
||||||
|
if not ev.wait(5): # at most 5 sec
|
||||||
|
raise Fatal(f"timeout in wait_for_firewall_ready()")
|
||||||
|
|
||||||
def restore_firewall(self, port, family, udp, user):
|
def restore_firewall(self, port, family, udp, user):
|
||||||
pass
|
pass
|
||||||
@ -209,23 +280,29 @@ class Method(BaseMethod):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _egress_divert(self, ready_cb):
|
||||||
|
proto = IPProtocol.TCP
|
||||||
|
filter = f"outbound and {proto.filter}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _outbound_divert(self, subnets, proxy_addr, proxy_port):
|
|
||||||
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
||||||
filter = "outbound and ip and tcp"
|
family_filters = []
|
||||||
subnet_selectors = []
|
for af, c in self.network_config.items():
|
||||||
for cidr in subnets:
|
subnet_filters = []
|
||||||
|
for cidr in c['subnets']:
|
||||||
ip_network = ipaddress.ip_network(cidr)
|
ip_network = ipaddress.ip_network(cidr)
|
||||||
first_ip = ip_network.network_address
|
first_ip = ip_network.network_address
|
||||||
last_ip = ip_network.broadcast_address
|
last_ip = ip_network.broadcast_address
|
||||||
subnet_selectors.append(f"(ip.DstAddr >= {first_ip} and ip.DstAddr <= {last_ip})")
|
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
|
||||||
filter = f"{filter} and ({'or'.join(subnet_selectors)}) "
|
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
|
||||||
|
|
||||||
|
filter = f"{filter} and ({' or '.join(family_filters)})"
|
||||||
|
|
||||||
debug1(f"[OUTBOUND] {filter=}")
|
debug1(f"[OUTBOUND] {filter=}")
|
||||||
with pydivert.WinDivert(filter) as w:
|
with pydivert.WinDivert(filter) as w:
|
||||||
|
ready_cb()
|
||||||
|
proxy_port = self.proxy_port
|
||||||
|
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
|
||||||
|
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
debug3(">>> " + repr_pkt(pkt))
|
debug3(">>> " + repr_pkt(pkt))
|
||||||
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
|
if pkt.tcp.syn and not pkt.tcp.ack: # SYN (start of 3-way handshake connection establishment)
|
||||||
@ -234,15 +311,39 @@ class Method(BaseMethod):
|
|||||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
|
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_SEND)
|
||||||
if pkt.tcp.rst : # RST
|
if pkt.tcp.rst : # RST
|
||||||
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||||
pkt.ipv4.dst_addr = proxy_addr
|
|
||||||
|
# DNAT
|
||||||
|
if pkt.ipv4 and proxy_addr_ipv4:
|
||||||
|
pkt.dst_addr = proxy_addr_ipv4
|
||||||
|
if pkt.ipv6 and proxy_addr_ipv6:
|
||||||
|
pkt.dst_addr = proxy_addr_ipv6
|
||||||
pkt.tcp.dst_port = proxy_port
|
pkt.tcp.dst_port = proxy_port
|
||||||
|
|
||||||
|
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well by setting src_addr to loopback address.
|
||||||
|
# Otherwise injecting packet will be ignored by Windows network stack as teh packet has to cross public to private address space.
|
||||||
|
# See: https://github.com/basil00/Divert/issues/82
|
||||||
|
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
|
||||||
|
# >>> pkt.dst_addr = proxy_addr_ipv4
|
||||||
|
|
||||||
w.send(pkt, recalculate_checksum=True)
|
w.send(pkt, recalculate_checksum=True)
|
||||||
|
|
||||||
|
|
||||||
def _inbound_divert(self, proxy_addr, proxy_port):
|
def _ingress_divert(self, ready_cb):
|
||||||
filter = f"inbound and ip and tcp and ip.SrcAddr == {proxy_addr} and tcp.SrcPort == {proxy_port}"
|
proto = IPProtocol.TCP
|
||||||
debug2(f"[INBOUND] {filter=}")
|
direction = 'inbound' # only when proxy address is not loopback address (Useful for testing)
|
||||||
|
ip_filters = []
|
||||||
|
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
|
||||||
|
if addr.is_loopback: # Windivert treats all loopback traffic as outbound
|
||||||
|
direction = "outbound"
|
||||||
|
if addr.version == 4:
|
||||||
|
ip_filters.append(f"ip.SrcAddr=={addr}")
|
||||||
|
else:
|
||||||
|
# ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this
|
||||||
|
ip_filters.append(f"ipv6.SrcAddr=={addr}")
|
||||||
|
filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}"
|
||||||
|
debug2(f"[INGRESS] {filter=}")
|
||||||
with pydivert.WinDivert(filter) as w:
|
with pydivert.WinDivert(filter) as w:
|
||||||
|
ready_cb()
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
debug3("<<< " + repr_pkt(pkt))
|
debug3("<<< " + repr_pkt(pkt))
|
||||||
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
|
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
|
||||||
@ -254,7 +355,7 @@ class Method(BaseMethod):
|
|||||||
if not conn:
|
if not conn:
|
||||||
debug2("Unexpected packet: " + repr_pkt(pkt))
|
debug2("Unexpected packet: " + repr_pkt(pkt))
|
||||||
continue
|
continue
|
||||||
pkt.ipv4.src_addr = conn.dst_addr
|
pkt.src_addr = conn.dst_addr
|
||||||
pkt.tcp.src_port = conn.dst_port
|
pkt.tcp.src_port = conn.dst_port
|
||||||
w.send(pkt, recalculate_checksum=True)
|
w.send(pkt, recalculate_checksum=True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user