fix windows CRLF issue on stdin/stdout

This commit is contained in:
nom3ad 2024-01-02 00:24:31 +05:30 committed by Brian May
parent 900acc3ac7
commit 4a84ad3be6
5 changed files with 520 additions and 514 deletions

View File

@ -307,7 +307,7 @@ class FirewallClient:
def get_pfile(): def get_pfile():
if can_use_stdio: if can_use_stdio:
self.p.stdin.write(b'STDIO:\n') self.p.stdin.write(b'COM_STDIO:\n')
self.p.stdin.flush() self.p.stdin.flush()
class RWPair: class RWPair:
@ -334,7 +334,7 @@ class FirewallClient:
socket_share_data = s1.share(self.p.pid) socket_share_data = s1.share(self.p.pid)
s1.close() s1.close()
socket_share_data_b64 = base64.b64encode(socket_share_data) socket_share_data_b64 = base64.b64encode(socket_share_data)
self.p.stdin.write(b'SOCKETSHARE:' + socket_share_data_b64 + b'\n') self.p.stdin.write(b'COM_SOCKETSHARE:' + socket_share_data_b64 + b'\n')
self.p.stdin.flush() self.p.stdin.flush()
return s2.makefile('rwb') return s2.makefile('rwb')
try: try:

View File

@ -84,7 +84,7 @@ def firewall_exit(signum, frame):
# the typical exit process as described above. # the typical exit process as described above.
global sshuttle_pid global sshuttle_pid
if sshuttle_pid: if sshuttle_pid:
debug1("Relaying interupt signal to sshuttle process %d\n" % sshuttle_pid) debug1("Relaying interupt signal to sshuttle process %d" % sshuttle_pid)
if sys.platform == 'win32': if sys.platform == 'win32':
sig = signal.CTRL_C_EVENT sig = signal.CTRL_C_EVENT
else: else:
@ -115,7 +115,7 @@ def _setup_daemon_for_unix_like():
# setsid() fails if sudo is configured with the use_pty option. # setsid() fails if sudo is configured with the use_pty option.
pass pass
return sys.stdin, sys.stdout return sys.stdin.buffer, sys.stdout.buffer
def _setup_daemon_for_windows(): def _setup_daemon_for_windows():
@ -125,9 +125,9 @@ def _setup_daemon_for_windows():
signal.signal(signal.SIGTERM, firewall_exit) signal.signal(signal.SIGTERM, firewall_exit)
signal.signal(signal.SIGINT, firewall_exit) signal.signal(signal.SIGINT, firewall_exit)
socket_share_data_prefix = 'SOCKETSHARE:' socket_share_data_prefix = b'COM_SOCKETSHARE:'
line = sys.stdin.readline().strip() line = sys.stdin.buffer.readline().strip()
if line.startswith('SOCKETSHARE:'): if line.startswith(socket_share_data_prefix):
debug3('Using shared socket for communicating with sshuttle client process') debug3('Using shared socket for communicating with sshuttle client process')
socket_share_data_b64 = line[len(socket_share_data_prefix):] socket_share_data_b64 = line[len(socket_share_data_prefix):]
socket_share_data = base64.b64decode(socket_share_data_b64) socket_share_data = base64.b64decode(socket_share_data_b64)
@ -135,12 +135,12 @@ def _setup_daemon_for_windows():
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0)) sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True) sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
sock.close() sock.close()
elif line.startswith("STDIO:"): elif line.startswith(b"COM_STDIO:"):
debug3('Using inherited stdio for communicating with sshuttle client process') debug3('Using inherited stdio for communicating with sshuttle client process')
else: else:
raise Fatal("Unexpected stdin: " + line) raise Fatal("Unexpected stdin: " + line)
return sys.stdin, sys.stdout return sys.stdin.buffer, sys.stdout.buffer
# Isolate function that needs to be replaced for tests # Isolate function that needs to be replaced for tests
@ -221,33 +221,43 @@ def main(method_name, syslog):
"PATH." % method_name) "PATH." % method_name)
debug1('ready method name %s.' % method.name) debug1('ready method name %s.' % method.name)
stdout.write('READY %s\n' % method.name) stdout.write(('READY %s\n' % method.name).encode('ASCII'))
stdout.flush() stdout.flush()
def _read_next_string_line():
try:
line = stdin.readline(128)
if not line:
return # parent probably exited
return line.decode('ASCII').strip()
except IOError as e:
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
debug3('read from stdin failed: %s' % (e,))
return
# we wait until we get some input before creating the rules. That way, # we wait until we get some input before creating the rules. That way,
# sshuttle can launch us as early as possible (and get sudo password # sshuttle can launch us as early as possible (and get sudo password
# authentication as early in the startup process as possible). # authentication as early in the startup process as possible).
try: try:
line = stdin.readline(128) line = _read_next_string_line()
if not line: if not line:
return # parent probably exited return # parent probably exited
except ConnectionResetError as e: except IOError as e:
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end # On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
debug3('read from stdin failed: %s' % (e,)) debug3('read from stdin failed: %s' % (e,))
return return
subnets = [] subnets = []
if line != 'ROUTES\n': if line != 'ROUTES':
raise Fatal('expected ROUTES but got %r' % line) raise Fatal('expected ROUTES but got %r' % line)
while 1: while 1:
line = stdin.readline(128) line = _read_next_string_line()
if not line: if not line:
raise Fatal('expected route but got %r' % line) raise Fatal('expected route but got %r' % line)
elif line.startswith("NSLIST\n"): elif line.startswith("NSLIST"):
break break
try: try:
(family, width, exclude, ip, fport, lport) = \ (family, width, exclude, ip, fport, lport) = line.split(',', 5)
line.strip().split(',', 5)
except Exception: except Exception:
raise Fatal('expected route or NSLIST but got %r' % line) raise Fatal('expected route or NSLIST but got %r' % line)
subnets.append(( subnets.append((
@ -260,16 +270,16 @@ def main(method_name, syslog):
debug2('Got subnets: %r' % subnets) debug2('Got subnets: %r' % subnets)
nslist = [] nslist = []
if line != 'NSLIST\n': if line != 'NSLIST':
raise Fatal('expected NSLIST but got %r' % line) raise Fatal('expected NSLIST but got %r' % line)
while 1: while 1:
line = stdin.readline(128) line = _read_next_string_line()
if not line: if not line:
raise Fatal('expected nslist but got %r' % line) raise Fatal('expected nslist but got %r' % line)
elif line.startswith("PORTS "): elif line.startswith("PORTS "):
break break
try: try:
(family, ip) = line.strip().split(',', 1) (family, ip) = line.split(',', 1)
except Exception: except Exception:
raise Fatal('expected nslist or PORTS but got %r' % line) raise Fatal('expected nslist or PORTS but got %r' % line)
nslist.append((int(family), ip)) nslist.append((int(family), ip))
@ -299,15 +309,13 @@ def main(method_name, syslog):
debug2('Got ports: %d,%d,%d,%d' debug2('Got ports: %d,%d,%d,%d'
% (port_v6, port_v4, dnsport_v6, dnsport_v4)) % (port_v6, port_v4, dnsport_v6, dnsport_v4))
line = stdin.readline(128) line = _read_next_string_line()
if not line: if not line or not line.startswith("GO "):
raise Fatal('expected GO but got %r' % line)
elif not line.startswith("GO "):
raise Fatal('expected GO but got %r' % line) raise Fatal('expected GO but got %r' % line)
_, _, args = line.partition(" ") _, _, args = line.partition(" ")
global sshuttle_pid global sshuttle_pid
udp, user, group, tmark, sshuttle_pid = args.strip().split(" ", 4) udp, user, group, tmark, sshuttle_pid = args.split(" ", 4)
udp = bool(int(udp)) udp = bool(int(udp))
sshuttle_pid = int(sshuttle_pid) sshuttle_pid = int(sshuttle_pid)
if user == '-': if user == '-':
@ -350,7 +358,7 @@ def main(method_name, syslog):
flush_systemd_dns_cache() flush_systemd_dns_cache()
try: try:
stdout.write('STARTED\n') stdout.write(b'STARTED\n')
stdout.flush() stdout.flush()
except IOError as e: # the parent process probably died except IOError as e: # the parent process probably died
debug3('write to stdout failed: %s' % (e,)) debug3('write to stdout failed: %s' % (e,))
@ -360,13 +368,11 @@ def main(method_name, syslog):
# to stay running so that we don't need a *second* password # to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important! # authentication at shutdown time - that cleanup is important!
while 1: while 1:
try: line = _read_next_string_line()
line = stdin.readline(128) if not line:
except IOError as e:
debug3('read from stdin failed: %s' % (e,))
return return
if line.startswith('HOST '): if line.startswith('HOST '):
(name, ip) = line[5:].strip().split(',', 1) (name, ip) = line[5:].split(',', 1)
hostmap[name] = ip hostmap[name] = ip
debug2('setting up /etc/hosts.') debug2('setting up /etc/hosts.')
rewrite_etc_hosts(hostmap, port_v6 or port_v4) rewrite_etc_hosts(hostmap, port_v6 or port_v4)

View File

@ -1,479 +1,479 @@
import os import os
import sys import sys
import ipaddress import ipaddress
import threading import threading
from collections import namedtuple from collections import namedtuple
import socket import socket
import subprocess import subprocess
import re import re
from multiprocessing import shared_memory from multiprocessing import shared_memory
import struct import struct
from functools import wraps from functools import wraps
from enum import IntEnum from enum import IntEnum
import time import time
import traceback import traceback
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal from sshuttle.helpers import debug3, log, debug1, debug2, Fatal
try: try:
# https://reqrypt.org/windivert-doc.html#divert_iphdr # https://reqrypt.org/windivert-doc.html#divert_iphdr
import pydivert import pydivert
except ImportError: except ImportError:
raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert") raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert")
ConnectionTuple = namedtuple( ConnectionTuple = namedtuple(
"ConnectionTuple", "ConnectionTuple",
["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"], ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"],
) )
WINDIVERT_MAX_CONNECTIONS = 10_000 WINDIVERT_MAX_CONNECTIONS = 10_000
class IPProtocol(IntEnum): class IPProtocol(IntEnum):
TCP = socket.IPPROTO_TCP TCP = socket.IPPROTO_TCP
UDP = socket.IPPROTO_UDP UDP = socket.IPPROTO_UDP
@property @property
def filter(self): def filter(self):
return "tcp" if self == IPProtocol.TCP else "udp" return "tcp" if self == IPProtocol.TCP else "udp"
class IPFamily(IntEnum): class IPFamily(IntEnum):
IPv4 = socket.AF_INET IPv4 = socket.AF_INET
IPv6 = socket.AF_INET6 IPv6 = socket.AF_INET6
@property @property
def filter(self): def filter(self):
return "ip" if self == socket.AF_INET else "ipv6" return "ip" if self == socket.AF_INET else "ipv6"
@property @property
def version(self): def version(self):
return 4 if self == socket.AF_INET else 6 return 4 if self == socket.AF_INET else 6
@property @property
def loopback_addr(self): def loopback_addr(self):
return "127.0.0.1" if self == socket.AF_INET else "::1" return "127.0.0.1" if self == socket.AF_INET else "::1"
class ConnState(IntEnum): class ConnState(IntEnum):
TCP_SYN_SENT = 11 # SYN sent TCP_SYN_SENT = 11 # SYN sent
TCP_ESTABLISHED = 12 # SYN+ACK received TCP_ESTABLISHED = 12 # SYN+ACK received
TCP_FIN_WAIT_1 = 91 # FIN sent TCP_FIN_WAIT_1 = 91 # FIN sent
TCP_CLOSE_WAIT = 92 # FIN received TCP_CLOSE_WAIT = 92 # FIN received
@staticmethod @staticmethod
def can_timeout(state): def can_timeout(state):
return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT) return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT)
def repr_pkt(p): def repr_pkt(p):
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}" r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
if p.tcp: if p.tcp:
t = p.tcp t = p.tcp
r += f" {len(t.payload)}B (" r += f" {len(t.payload)}B ("
r += "+".join( r += "+".join(
f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f) f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f)
) )
r += f") SEQ#{t.seq_num}" r += f") SEQ#{t.seq_num}"
if t.ack: if t.ack:
r += f" ACK#{t.ack_num}" r += f" ACK#{t.ack_num}"
r += f" WZ={t.window_size}" r += f" WZ={t.window_size}"
else: else:
r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}" r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}"
return f"<Pkt {r}>" return f"<Pkt {r}>"
def synchronized_method(lock): def synchronized_method(lock):
def decorator(method): def decorator(method):
@wraps(method) @wraps(method)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
with getattr(self, lock): with getattr(self, lock):
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return wrapped return wrapped
return decorator return decorator
class ConnTrack: class ConnTrack:
_instance = None _instance = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if not cls._instance: if not cls._instance:
cls._instance = object.__new__(cls) cls._instance = object.__new__(cls)
return cls._instance return cls._instance
raise RuntimeError("ConnTrack can not be instantiated multiple times") raise RuntimeError("ConnTrack can not be instantiated multiple times")
def __init__(self, name, max_connections=0) -> None: def __init__(self, name, max_connections=0) -> None:
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B"))) self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H"))) self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B"))) self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
try: try:
self.max_connections = max_connections self.max_connections = max_connections
self.shm_list = shared_memory.ShareableList( self.shm_list = shared_memory.ShareableList(
[bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name [bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name
) )
self.is_owner = True self.is_owner = True
self.next_slot = 0 self.next_slot = 0
self.used_slots = set() self.used_slots = set()
self.rlock = threading.RLock() self.rlock = threading.RLock()
except FileExistsError: except FileExistsError:
self.is_owner = False self.is_owner = False
self.shm_list = shared_memory.ShareableList(name=name) self.shm_list = shared_memory.ShareableList(name=name)
self.max_connections = len(self.shm_list) self.max_connections = len(self.shm_list)
debug2( debug2(
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} " f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
f"shm_size={self.shm_list.shm.size}B" f"shm_size={self.shm_list.shm.size}B"
) )
@synchronized_method("rlock") @synchronized_method("rlock")
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state): def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
if len(self.used_slots) >= self.max_connections: if len(self.used_slots) >= self.max_connections:
raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}") raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}")
if self.get(proto, src_addr, src_port): if self.get(proto, src_addr, src_port):
return return
for _ in range(self.max_connections): for _ in range(self.max_connections):
if self.next_slot not in self.used_slots: if self.next_slot not in self.used_slots:
break break
self.next_slot = (self.next_slot + 1) % self.max_connections self.next_slot = (self.next_slot + 1) % self.max_connections
else: else:
raise RuntimeError("No slot available in ConnTrack") # should not be here raise RuntimeError("No slot available in ConnTrack") # should not be here
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
dst_addr = ipaddress.ip_address(dst_addr) dst_addr = ipaddress.ip_address(dst_addr)
assert src_addr.version == dst_addr.version assert src_addr.version == dst_addr.version
ip_version = src_addr.version ip_version = src_addr.version
state_epoch = int(time.time()) state_epoch = int(time.time())
entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state) entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state)
packed = self.struct_full_tuple.pack(*entry) packed = self.struct_full_tuple.pack(*entry)
self.shm_list[self.next_slot] = packed self.shm_list[self.next_slot] = packed
self.used_slots.add(self.next_slot) self.used_slots.add(self.next_slot)
proto = IPProtocol(proto) proto = IPProtocol(proto)
debug3( debug3(
f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to " f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to "
f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}" f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}"
) )
@synchronized_method("rlock") @synchronized_method("rlock")
def update(self, proto, src_addr, src_port, state): def update(self, proto, src_addr, src_port, state):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for i in self.used_slots: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
state_epoch = int(time.time()) state_epoch = int(time.time())
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state) self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
debug3( debug3(
f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | " f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | "
f"#ActiveConn={len(self.used_slots)}" f"#ActiveConn={len(self.used_slots)}"
) )
return self._unpack(self.shm_list[i]) return self._unpack(self.shm_list[i])
else: else:
debug3( debug3(
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | " f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | "
f"#ActiveConn={len(self.used_slots)}" f"#ActiveConn={len(self.used_slots)}"
) )
@synchronized_method("rlock") @synchronized_method("rlock")
def remove(self, proto, src_addr, src_port): def remove(self, proto, src_addr, src_port):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for i in self.used_slots: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
conn = self._unpack(self.shm_list[i]) conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b"" self.shm_list[i] = b""
self.used_slots.remove(i) self.used_slots.remove(i)
debug3( debug3(
f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | " f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | "
f"#ActiveConn={len(self.used_slots)}" f"#ActiveConn={len(self.used_slots)}"
) )
return conn return conn
else: else:
debug3( debug3(
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |" f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |"
f" #ActiveConn={len(self.used_slots)}" f" #ActiveConn={len(self.used_slots)}"
) )
def get(self, proto, src_addr, src_port): def get(self, proto, src_addr, src_port):
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for entry in self.shm_list: for entry in self.shm_list:
if entry and entry.startswith(packed): if entry and entry.startswith(packed):
return self._unpack(entry) return self._unpack(entry)
@synchronized_method("rlock") @synchronized_method("rlock")
def gc(self, connection_timeout_sec=15): def gc(self, connection_timeout_sec=15):
now = int(time.time()) now = int(time.time())
n = 0 n = 0
for i in tuple(self.used_slots): for i in tuple(self.used_slots):
state_packed = self.shm_list[i][-5:] state_packed = self.shm_list[i][-5:]
(state_epoch, state) = self.struct_state_tuple.unpack(state_packed) (state_epoch, state) = self.struct_state_tuple.unpack(state_packed)
if (now - state_epoch) < connection_timeout_sec: if (now - state_epoch) < connection_timeout_sec:
continue continue
if ConnState.can_timeout(state): if ConnState.can_timeout(state):
conn = self._unpack(self.shm_list[i]) conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b"" self.shm_list[i] = b""
self.used_slots.remove(i) self.used_slots.remove(i)
n += 1 n += 1
debug3( debug3(
f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})" f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})"
f" from slot={i} | #ActiveConn={len(self.used_slots)}" f" from slot={i} | #ActiveConn={len(self.used_slots)}"
) )
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}") debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
def _unpack(self, packed): def _unpack(self, packed):
( (
proto, proto,
ip_version, ip_version,
src_addr_packed, src_addr_packed,
src_port, src_port,
dst_addr_packed, dst_addr_packed,
dst_port, dst_port,
state_epoch, state_epoch,
state, state,
) = self.struct_full_tuple.unpack(packed) ) = self.struct_full_tuple.unpack(packed)
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
return ConnectionTuple( return ConnectionTuple(
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state) IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
) )
def __iter__(self): def __iter__(self):
def conn_iter(): def conn_iter():
for i in self.used_slots: for i in self.used_slots:
yield self._unpack(self.shm_list[i]) yield self._unpack(self.shm_list[i])
return conn_iter() return conn_iter()
def __repr__(self): def __repr__(self):
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'},cap={len(self.shm_list)},owner={self.is_owner})>" return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'},cap={len(self.shm_list)},owner={self.is_owner})>"
class Method(BaseMethod): class Method(BaseMethod):
network_config = {} network_config = {}
proxy_port = None proxy_port = None
proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None} proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None}
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
def _get_local_proxy_listen_addr(self, port, family): def _get_local_proxy_listen_addr(self, port, family):
proto = "TCPv6" if family.version == 6 else "TCP" proto = "TCPv6" if family.version == 6 else "TCP"
for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines(): for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines():
try: try:
_, local_addr, _, state, *_ = re.split(r"\s+", line.strip()) _, local_addr, _, state, *_ = re.split(r"\s+", line.strip())
except ValueError: except ValueError:
continue continue
port_suffix = ":" + str(port) port_suffix = ":" + str(port)
if state == "LISTENING" and local_addr.endswith(port_suffix): if state == "LISTENING" and local_addr.endswith(port_suffix):
return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]")) return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]"))
raise Fatal("Could not find listening address for {}/{}".format(port, proto)) raise Fatal("Could not find listening address for {}/{}".format(port, proto))
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark): def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark):
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
if nslist or user or udp: if nslist or user or udp:
raise NotImplementedError() raise NotImplementedError()
family = IPFamily(family) family = IPFamily(family)
# using loopback proxy address never worked. # using loopback proxy address never worked.
# >>> self.proxy_addr[family] = family.loopback_addr # >>> self.proxy_addr[family] = family.loopback_addr
# See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82) # See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
# As a workaround we use another interface ip instead. # As a workaround we use another interface ip instead.
local_addr = self._get_local_proxy_listen_addr(port, family) local_addr = self._get_local_proxy_listen_addr(port, family)
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
if addr.is_loopback or addr.version != family.version: if addr.is_loopback or addr.version != family.version:
continue continue
if local_addr.is_unspecified or local_addr == addr: if local_addr.is_unspecified or local_addr == addr:
debug2("Found non loopback address to connect to proxy: " + str(addr)) debug2("Found non loopback address to connect to proxy: " + str(addr))
self.proxy_addr[family] = str(addr) self.proxy_addr[family] = str(addr)
break break
else: else:
raise Fatal("Windivert method requires proxy to listen on non loopback address") raise Fatal("Windivert method requires proxy to listen on non loopback address")
self.proxy_port = port self.proxy_port = port
subnet_addresses = [] subnet_addresses = []
for (_, mask, exclude, network_addr, fport, lport) in subnets: for (_, mask, exclude, network_addr, fport, lport) in subnets:
if exclude: if exclude:
continue continue
assert fport == 0, "custom port range not supported" assert fport == 0, "custom port range not supported"
assert lport == 0, "custom port range not supported" assert lport == 0, "custom port range not supported"
subnet_addresses.append("%s/%s" % (network_addr, mask)) subnet_addresses.append("%s/%s" % (network_addr, mask))
self.network_config[family] = { self.network_config[family] = {
"subnets": subnet_addresses, "subnets": subnet_addresses,
"nslist": nslist, "nslist": nslist,
} }
def wait_for_firewall_ready(self): def wait_for_firewall_ready(self):
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}") debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS) self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
methods = (self._egress_divert, self._ingress_divert, self._connection_gc) methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
ready_events = [] ready_events = []
for fn in methods: for fn in methods:
ev = threading.Event() ev = threading.Event()
ready_events.append(ev) ready_events.append(ev)
def _target(): def _target():
try: try:
fn(ev.set) fn(ev.set)
except Exception: except Exception:
debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc()) debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc())
sys.stdin.close() # this will exist main thread sys.stdin.close() # this will exist main thread
sys.stdout.close() sys.stdout.close()
threading.Thread(name=fn.__name__, target=_target, daemon=True).start() threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
for ev in ready_events: for ev in ready_events:
if not ev.wait(5): # at most 5 sec if not ev.wait(5): # at most 5 sec
raise Fatal("timeout in wait_for_firewall_ready()") raise Fatal("timeout in wait_for_firewall_ready()")
def restore_firewall(self, port, family, udp, user): def restore_firewall(self, port, family, udp, user, group):
pass pass
def get_supported_features(self): def get_supported_features(self):
result = super(Method, self).get_supported_features() result = super(Method, self).get_supported_features()
result.loopback_port = False result.loopback_port = False
result.user = False result.user = False
result.dns = False result.dns = False
result.ipv6 = False result.ipv6 = False
return result return result
def get_tcp_dstip(self, sock): def get_tcp_dstip(self, sock):
if not hasattr(self, "conntrack"): if not hasattr(self, "conntrack"):
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}") self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
src_addr, src_port = sock.getpeername() src_addr, src_port = sock.getpeername()
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port) c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
if not c: if not c:
return (src_addr, src_port) return (src_addr, src_port)
return (c.dst_addr, c.dst_port) return (c.dst_addr, c.dst_port)
def is_supported(self): def is_supported(self):
if sys.platform == "win32": if sys.platform == "win32":
return True return True
return False return False
def _egress_divert(self, ready_cb): def _egress_divert(self, ready_cb):
proto = IPProtocol.TCP proto = IPProtocol.TCP
filter = f"outbound and {proto.filter}" filter = f"outbound and {proto.filter}"
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w: # with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
family_filters = [] family_filters = []
for af, c in self.network_config.items(): for af, c in self.network_config.items():
subnet_filters = [] subnet_filters = []
for cidr in c["subnets"]: for cidr in c["subnets"]:
ip_network = ipaddress.ip_network(cidr) ip_network = ipaddress.ip_network(cidr)
first_ip = ip_network.network_address first_ip = ip_network.network_address
last_ip = ip_network.broadcast_address last_ip = ip_network.broadcast_address
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})") subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ") family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
filter = f"{filter} and ({' or '.join(family_filters)})" filter = f"{filter} and ({' or '.join(family_filters)})"
debug1(f"[OUTBOUND] {filter=}") debug1(f"[OUTBOUND] {filter=}")
with pydivert.WinDivert(filter) as w: with pydivert.WinDivert(filter) as w:
ready_cb() ready_cb()
proxy_port = self.proxy_port proxy_port = self.proxy_port
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4] proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6] proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
for pkt in w: for pkt in w:
debug3(">>> " + repr_pkt(pkt)) debug3(">>> " + repr_pkt(pkt))
if pkt.tcp.syn and not pkt.tcp.ack: if pkt.tcp.syn and not pkt.tcp.ack:
# SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK) # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
self.conntrack.add( self.conntrack.add(
socket.IPPROTO_TCP, socket.IPPROTO_TCP,
pkt.src_addr, pkt.src_addr,
pkt.src_port, pkt.src_port,
pkt.dst_addr, pkt.dst_addr,
pkt.dst_port, pkt.dst_port,
ConnState.TCP_SYN_SENT, ConnState.TCP_SYN_SENT,
) )
if pkt.tcp.fin: if pkt.tcp.fin:
# FIN sent (start of graceful close our side, and we wait for ACK) # FIN sent (start of graceful close our side, and we wait for ACK)
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1) self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
if pkt.tcp.rst: if pkt.tcp.rst:
# RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply) # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port) self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
# DNAT # DNAT
if pkt.ipv4 and proxy_addr_ipv4: if pkt.ipv4 and proxy_addr_ipv4:
pkt.dst_addr = proxy_addr_ipv4 pkt.dst_addr = proxy_addr_ipv4
if pkt.ipv6 and proxy_addr_ipv6: if pkt.ipv6 and proxy_addr_ipv6:
pkt.dst_addr = proxy_addr_ipv6 pkt.dst_addr = proxy_addr_ipv6
pkt.tcp.dst_port = proxy_port pkt.tcp.dst_port = proxy_port
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well # XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well
# by setting src_addr to loopback address. # by setting src_addr to loopback address.
# Otherwise injecting packet will be ignored by Windows network stack # Otherwise injecting packet will be ignored by Windows network stack
# as they packet has to cross public to private address space. # as they packet has to cross public to private address space.
# See: https://github.com/basil00/Divert/issues/82 # See: https://github.com/basil00/Divert/issues/82
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets. # Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
# >>> pkt.dst_addr = proxy_addr_ipv4 # >>> pkt.dst_addr = proxy_addr_ipv4
w.send(pkt, recalculate_checksum=True) w.send(pkt, recalculate_checksum=True)
def _ingress_divert(self, ready_cb): def _ingress_divert(self, ready_cb):
proto = IPProtocol.TCP proto = IPProtocol.TCP
direction = "inbound" # only when proxy address is not loopback address (Useful for testing) direction = "inbound" # only when proxy address is not loopback address (Useful for testing)
ip_filters = [] ip_filters = []
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a): for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
if addr.is_loopback: # Windivert treats all loopback traffic as outbound if addr.is_loopback: # Windivert treats all loopback traffic as outbound
direction = "outbound" direction = "outbound"
if addr.version == 4: if addr.version == 4:
ip_filters.append(f"ip.SrcAddr=={addr}") ip_filters.append(f"ip.SrcAddr=={addr}")
else: else:
# ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this # ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this
ip_filters.append(f"ipv6.SrcAddr=={addr}") ip_filters.append(f"ipv6.SrcAddr=={addr}")
if not ip_filters: if not ip_filters:
raise Fatal("At least ipv4 or ipv6 address is expected") raise Fatal("At least ipv4 or ipv6 address is expected")
filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}" filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}"
debug1(f"[INGRESS] {filter=}") debug1(f"[INGRESS] {filter=}")
with pydivert.WinDivert(filter) as w: with pydivert.WinDivert(filter) as w:
ready_cb() ready_cb()
for pkt in w: for pkt in w:
debug3("<<< " + repr_pkt(pkt)) debug3("<<< " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: if pkt.tcp.syn and pkt.tcp.ack:
# SYN+ACK received (connection established) # SYN+ACK received (connection established)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
elif pkt.tcp.rst: elif pkt.tcp.rst:
# RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets # RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
# https://wiki.wireshark.org/TCP-4-times-close.md # https://wiki.wireshark.org/TCP-4-times-close.md
elif pkt.tcp.fin and pkt.tcp.ack: elif pkt.tcp.fin and pkt.tcp.ack:
# FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK) # FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK)
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
elif pkt.tcp.fin: elif pkt.tcp.fin:
# FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet) # FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
else: else:
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
if not conn: if not conn:
debug2("Unexpected packet: " + repr_pkt(pkt)) debug2("Unexpected packet: " + repr_pkt(pkt))
continue continue
pkt.src_addr = conn.dst_addr pkt.src_addr = conn.dst_addr
pkt.tcp.src_port = conn.dst_port pkt.tcp.src_port = conn.dst_port
w.send(pkt, recalculate_checksum=True) w.send(pkt, recalculate_checksum=True)
def _connection_gc(self, ready_cb): def _connection_gc(self, ready_cb):
ready_cb() ready_cb()
while True: while True:
time.sleep(5) time.sleep(5)
self.conntrack.gc() self.conntrack.gc()

View File

@ -260,7 +260,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start() threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start() threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
return s2 return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0)
# https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows # https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows
close_fds = False if sys.platform == 'win32' else True close_fds = False if sys.platform == 'win32' else True

View File

@ -10,7 +10,7 @@ import sshuttle.firewall
def setup_daemon(): def setup_daemon():
stdin = io.StringIO(u"""ROUTES stdin = io.BytesIO(u"""ROUTES
{inet},24,0,1.2.3.0,8000,9000 {inet},24,0,1.2.3.0,8000,9000
{inet},32,1,1.2.3.66,8080,8080 {inet},32,1,1.2.3.66,8080,8080
{inet6},64,0,2404:6800:4004:80c::,0,0 {inet6},64,0,2404:6800:4004:80c::,0,0
@ -127,9 +127,9 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
] ]
assert stdout.mock_calls == [ assert stdout.mock_calls == [
call.write('READY test\n'), call.write(b'READY test\n'),
call.flush(), call.flush(),
call.write('STARTED\n'), call.write(b'STARTED\n'),
call.flush() call.flush()
] ]
assert mock_setup_daemon.mock_calls == [call()] assert mock_setup_daemon.mock_calls == [call()]