mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-21 15:33:23 +01:00
pass flake8 linting
This commit is contained in:
parent
7da3b024dd
commit
482e0cbd00
@ -3,6 +3,6 @@ import sys
|
||||
import os
|
||||
from sshuttle.cmdline import main
|
||||
from sshuttle.helpers import debug3
|
||||
exit_code=main()
|
||||
exit_code = main()
|
||||
debug3("Exiting process %r (pid:%s) with code %s" % (sys.argv, os.getpid(), exit_code,))
|
||||
sys.exit(exit_code)
|
||||
sys.exit(exit_code)
|
||||
|
@ -226,7 +226,7 @@ class FirewallClient:
|
||||
argv_tries.append(argvbase)
|
||||
# runas_path = which("runas")
|
||||
# if runas_path:
|
||||
# argv_tries.append(['runas' , '/noprofile', '/user:Administrator', 'python'])
|
||||
# argv_tries.append(['runas' , '/noprofile', '/user:Administrator', 'python'])
|
||||
else:
|
||||
# Linux typically uses sudo; OpenBSD uses doas. However, some
|
||||
# Linux distributions are starting to use doas.
|
||||
@ -248,8 +248,8 @@ class FirewallClient:
|
||||
# --no-sudo-pythonpath option.
|
||||
if sudo_pythonpath:
|
||||
pp_prefix = ['/usr/bin/env',
|
||||
'PYTHONPATH=%s' %
|
||||
os.path.dirname(os.path.dirname(__file__))]
|
||||
'PYTHONPATH=%s' %
|
||||
os.path.dirname(os.path.dirname(__file__))]
|
||||
sudo_cmd = sudo_cmd + pp_prefix
|
||||
doas_cmd = doas_cmd + pp_prefix
|
||||
|
||||
@ -260,8 +260,7 @@ class FirewallClient:
|
||||
|
||||
# If we can find doas and not sudo or if we are on
|
||||
# OpenBSD, try using doas first.
|
||||
if (doas_path and not sudo_path) or \
|
||||
platform.platform().startswith('OpenBSD'):
|
||||
if (doas_path and not sudo_path) or platform.platform().startswith('OpenBSD'):
|
||||
argv_tries = [doas_cmd, sudo_cmd, argvbase]
|
||||
else:
|
||||
argv_tries = [sudo_cmd, doas_cmd, argvbase]
|
||||
@ -282,9 +281,11 @@ class FirewallClient:
|
||||
pstdout = s1
|
||||
pstdin = s1
|
||||
penv = None
|
||||
|
||||
def preexec_fn():
|
||||
# run in the child process
|
||||
s2.close()
|
||||
|
||||
def get_pfile():
|
||||
s1.close()
|
||||
return s2.makefile('rwb')
|
||||
@ -295,7 +296,8 @@ class FirewallClient:
|
||||
pstdin = ssubprocess.PIPE
|
||||
preexec_fn = None
|
||||
penv = os.environ.copy()
|
||||
penv['PYTHONPATH'] = os.path.dirname(os.path.dirname(__file__))
|
||||
penv['PYTHONPATH'] = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
def get_pfile():
|
||||
import base64
|
||||
socket_share_data = s1.share(self.p.pid)
|
||||
@ -318,14 +320,13 @@ class FirewallClient:
|
||||
'Command=%r Exception=%s' % (argv, e))
|
||||
continue
|
||||
self.argv = argv
|
||||
|
||||
self.pfile = get_pfile()
|
||||
|
||||
try:
|
||||
line = self.pfile.readline()
|
||||
except IOError:
|
||||
# happens when firewall subprocess exists
|
||||
line=''
|
||||
line = ''
|
||||
|
||||
rv = self.p.poll() # Check if process is still running
|
||||
if rv:
|
||||
|
@ -130,11 +130,13 @@ def _setup_daemon_windows():
|
||||
sock.close()
|
||||
return sys.stdin, sys.stdout
|
||||
|
||||
|
||||
if sys.platform == 'win32':
|
||||
setup_daemon = _setup_daemon_windows
|
||||
else:
|
||||
setup_daemon = _setup_daemon_unix
|
||||
|
||||
|
||||
# Note that we're sorting in a very particular order:
|
||||
# we need to go from smaller, more specific, port ranges, to larger,
|
||||
# less-specific, port ranges. At each level, we order by subnet
|
||||
@ -216,8 +218,8 @@ def main(method_name, syslog):
|
||||
line = stdin.readline(128)
|
||||
if not line:
|
||||
# parent probably exited
|
||||
return
|
||||
except IOError:
|
||||
return
|
||||
except IOError as e:
|
||||
# On windows, this ConnectionResetError is thrown when parent process closes it's socket pair end
|
||||
debug3('read from stdin failed: %s' % (e,))
|
||||
return
|
||||
|
@ -16,7 +16,7 @@ def log(s):
|
||||
try:
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
except (IOError,ValueError):
|
||||
except (IOError, ValueError):
|
||||
pass
|
||||
# Put newline at end of string if line doesn't have one.
|
||||
if not s.endswith("\n"):
|
||||
@ -224,13 +224,14 @@ def which(file, mode=os.F_OK | os.X_OK):
|
||||
debug2("which() could not find '%s' in %s" % (file, path))
|
||||
return rv
|
||||
|
||||
|
||||
def is_admin_user():
|
||||
if sys.platform == 'win32':
|
||||
import ctypes
|
||||
# https://stackoverflow.com/questions/130763/request-uac-elevation-from-within-a-python-script/41930586#41930586
|
||||
try:
|
||||
return ctypes.windll.shell32.IsUserAnAdmin()
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return os.getuid() == 0
|
||||
|
@ -11,39 +11,42 @@ from enum import IntEnum
|
||||
import time
|
||||
import traceback
|
||||
|
||||
try:
|
||||
import pydivert
|
||||
except ImportError:
|
||||
raise Fatal('Could not import pydivert module. windivert requires https://pypi.org/project/pydivert')
|
||||
|
||||
from sshuttle.methods import BaseMethod
|
||||
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal
|
||||
|
||||
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
||||
try:
|
||||
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
||||
import pydivert
|
||||
except ImportError:
|
||||
raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert")
|
||||
|
||||
|
||||
ConnectionTuple = namedtuple(
|
||||
"ConnectionTuple", ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", 'state']
|
||||
"ConnectionTuple",
|
||||
["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"],
|
||||
)
|
||||
|
||||
|
||||
WINDIVERT_MAX_CONNECTIONS = 10_000
|
||||
|
||||
|
||||
class IPProtocol(IntEnum):
|
||||
TCP = socket.IPPROTO_TCP
|
||||
UDP = socket.IPPROTO_UDP
|
||||
|
||||
@property
|
||||
def filter(self):
|
||||
return 'tcp' if self == IPProtocol.TCP else 'udp'
|
||||
return "tcp" if self == IPProtocol.TCP else "udp"
|
||||
|
||||
|
||||
class IPFamily(IntEnum):
|
||||
IPv4 = socket.AF_INET
|
||||
IPv6 = socket.AF_INET6
|
||||
IPv4 = socket.AF_INET
|
||||
IPv6 = socket.AF_INET6
|
||||
|
||||
@property
|
||||
def filter(self):
|
||||
return 'ip' if self == socket.AF_INET else 'ipv6'
|
||||
return "ip" if self == socket.AF_INET else "ipv6"
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
@ -51,14 +54,14 @@ class IPFamily(IntEnum):
|
||||
|
||||
@property
|
||||
def loopback_addr(self):
|
||||
return '127.0.0.1' if self == socket.AF_INET else '::1'
|
||||
return "127.0.0.1" if self == socket.AF_INET else "::1"
|
||||
|
||||
|
||||
class ConnState(IntEnum):
|
||||
TCP_SYN_SENT = 11 # SYN sent
|
||||
TCP_ESTABLISHED = 12 # SYN+ACK received
|
||||
TCP_FIN_WAIT_1 = 91 # FIN sent
|
||||
TCP_CLOSE_WAIT = 92 # FIN received
|
||||
TCP_SYN_SENT = 11 # SYN sent
|
||||
TCP_ESTABLISHED = 12 # SYN+ACK received
|
||||
TCP_FIN_WAIT_1 = 91 # FIN sent
|
||||
TCP_CLOSE_WAIT = 92 # FIN received
|
||||
|
||||
@staticmethod
|
||||
def can_timeout(state):
|
||||
@ -70,27 +73,34 @@ def repr_pkt(p):
|
||||
if p.tcp:
|
||||
t = p.tcp
|
||||
r += f" {len(t.payload)}B ("
|
||||
r += '+'.join(f.upper() for f in ('fin','syn', "rst", "psh", 'ack', 'urg', 'ece', 'cwr', 'ns') if getattr(t, f))
|
||||
r += f') SEQ#{t.seq_num}'
|
||||
r += "+".join(
|
||||
f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f)
|
||||
)
|
||||
r += f") SEQ#{t.seq_num}"
|
||||
if t.ack:
|
||||
r += f' ACK#{t.ack_num}'
|
||||
r += f' WZ={t.window_size}'
|
||||
r += f" ACK#{t.ack_num}"
|
||||
r += f" WZ={t.window_size}"
|
||||
else:
|
||||
r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}"
|
||||
return f"<Pkt {r}>"
|
||||
|
||||
|
||||
def synchronized_method(lock):
|
||||
def decorator(method):
|
||||
@wraps(method)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
with getattr(self, lock):
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ConnTrack:
|
||||
|
||||
_instance =None
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
cls._instance = object.__new__(cls)
|
||||
@ -98,13 +108,15 @@ class ConnTrack:
|
||||
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
||||
|
||||
def __init__(self, name, max_connections=0) -> None:
|
||||
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B')))
|
||||
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H')))
|
||||
self.struct_state_tuple = struct.Struct('>' + ''.join(('L', 'B')))
|
||||
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
|
||||
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
|
||||
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
|
||||
|
||||
try:
|
||||
self.max_connections = max_connections
|
||||
self.shm_list = shared_memory.ShareableList([bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name)
|
||||
self.shm_list = shared_memory.ShareableList(
|
||||
[bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name
|
||||
)
|
||||
self.is_owner = True
|
||||
self.next_slot = 0
|
||||
self.used_slots = set()
|
||||
@ -114,9 +126,12 @@ class ConnTrack:
|
||||
self.shm_list = shared_memory.ShareableList(name=name)
|
||||
self.max_connections = len(self.shm_list)
|
||||
|
||||
debug2(f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} shm_size={self.shm_list.shm.size}B")
|
||||
debug2(
|
||||
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
|
||||
f"shm_size={self.shm_list.shm.size}B"
|
||||
)
|
||||
|
||||
@synchronized_method('rlock')
|
||||
@synchronized_method("rlock")
|
||||
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
|
||||
if not self.is_owner:
|
||||
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||
@ -129,23 +144,26 @@ class ConnTrack:
|
||||
for _ in range(self.max_connections):
|
||||
if self.next_slot not in self.used_slots:
|
||||
break
|
||||
self.next_slot = (self.next_slot +1) % self.max_connections
|
||||
self.next_slot = (self.next_slot + 1) % self.max_connections
|
||||
else:
|
||||
raise RuntimeError("No slot available in ConnTrack") # should not be here
|
||||
raise RuntimeError("No slot available in ConnTrack") # should not be here
|
||||
|
||||
src_addr = ipaddress.ip_address(src_addr)
|
||||
dst_addr = ipaddress.ip_address(dst_addr)
|
||||
assert src_addr.version == dst_addr.version
|
||||
ip_version = src_addr.version
|
||||
state_epoch = int(time.time())
|
||||
state_epoch = int(time.time())
|
||||
entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state)
|
||||
packed = self.struct_full_tuple.pack(*entry)
|
||||
self.shm_list[self.next_slot] = packed
|
||||
self.used_slots.add(self.next_slot)
|
||||
proto = IPProtocol(proto)
|
||||
debug3(f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(
|
||||
f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to "
|
||||
f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}"
|
||||
)
|
||||
|
||||
@synchronized_method('rlock')
|
||||
@synchronized_method("rlock")
|
||||
def update(self, proto, src_addr, src_port, state):
|
||||
if not self.is_owner:
|
||||
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||
@ -155,12 +173,18 @@ class ConnTrack:
|
||||
if self.shm_list[i].startswith(packed):
|
||||
state_epoch = int(time.time())
|
||||
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
||||
debug3(f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(
|
||||
f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | "
|
||||
f"#ActiveConn={len(self.used_slots)}"
|
||||
)
|
||||
return self._unpack(self.shm_list[i])
|
||||
else:
|
||||
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(
|
||||
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | "
|
||||
f"#ActiveConn={len(self.used_slots)}"
|
||||
)
|
||||
|
||||
@synchronized_method('rlock')
|
||||
@synchronized_method("rlock")
|
||||
def remove(self, proto, src_addr, src_port):
|
||||
if not self.is_owner:
|
||||
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||
@ -169,13 +193,18 @@ class ConnTrack:
|
||||
for i in self.used_slots:
|
||||
if self.shm_list[i].startswith(packed):
|
||||
conn = self._unpack(self.shm_list[i])
|
||||
self.shm_list[i] = b''
|
||||
self.shm_list[i] = b""
|
||||
self.used_slots.remove(i)
|
||||
debug3(f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(
|
||||
f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | "
|
||||
f"#ActiveConn={len(self.used_slots)}"
|
||||
)
|
||||
return conn
|
||||
else:
|
||||
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
|
||||
|
||||
debug3(
|
||||
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |"
|
||||
f" #ActiveConn={len(self.used_slots)}"
|
||||
)
|
||||
|
||||
def get(self, proto, src_addr, src_port):
|
||||
src_addr = ipaddress.ip_address(src_addr)
|
||||
@ -184,7 +213,7 @@ class ConnTrack:
|
||||
if entry and entry.startswith(packed):
|
||||
return self._unpack(entry)
|
||||
|
||||
@synchronized_method('rlock')
|
||||
@synchronized_method("rlock")
|
||||
def gc(self, connection_timeout_sec=15):
|
||||
now = int(time.time())
|
||||
n = 0
|
||||
@ -195,47 +224,62 @@ class ConnTrack:
|
||||
continue
|
||||
if ConnState.can_timeout(state):
|
||||
conn = self._unpack(self.shm_list[i])
|
||||
self.shm_list[i] = b''
|
||||
self.shm_list[i] = b""
|
||||
self.used_slots.remove(i)
|
||||
n += 1
|
||||
debug3(f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
|
||||
debug3(
|
||||
f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})"
|
||||
f" from slot={i} | #ActiveConn={len(self.used_slots)}"
|
||||
)
|
||||
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
|
||||
|
||||
def _unpack(self, packed):
|
||||
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed)
|
||||
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
||||
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
||||
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
|
||||
|
||||
(
|
||||
proto,
|
||||
ip_version,
|
||||
src_addr_packed,
|
||||
src_port,
|
||||
dst_addr_packed,
|
||||
dst_port,
|
||||
state_epoch,
|
||||
state,
|
||||
) = self.struct_full_tuple.unpack(packed)
|
||||
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
||||
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
||||
return ConnectionTuple(
|
||||
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
def conn_iter():
|
||||
for i in self.used_slots:
|
||||
yield self._unpack(self.shm_list[i])
|
||||
|
||||
return conn_iter()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
|
||||
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'},cap={len(self.shm_list)},owner={self.is_owner})>"
|
||||
|
||||
|
||||
class Method(BaseMethod):
|
||||
|
||||
network_config = {}
|
||||
proxy_port = None
|
||||
proxy_addr = { IPFamily.IPv4: None, IPFamily.IPv6: None }
|
||||
proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None}
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
|
||||
user, tmark):
|
||||
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark):
|
||||
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||
|
||||
if nslist or user or udp:
|
||||
raise NotImplementedError()
|
||||
|
||||
family = IPFamily(family)
|
||||
|
||||
# using loopback proxy address never worked. See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
|
||||
# using loopback proxy address never worked.
|
||||
# See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
|
||||
# As a workaround we use another interface ip instead.
|
||||
# self.proxy_addr[family] = family.loopback_addr
|
||||
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
|
||||
@ -252,37 +296,37 @@ class Method(BaseMethod):
|
||||
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
||||
if exclude:
|
||||
continue
|
||||
assert fport == 0, 'custom port range not supported'
|
||||
assert lport == 0, 'custom port range not supported'
|
||||
assert fport == 0, "custom port range not supported"
|
||||
assert lport == 0, "custom port range not supported"
|
||||
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
||||
|
||||
self.network_config[family] = {
|
||||
'subnets': subnet_addresses,
|
||||
self.network_config[family] = {
|
||||
"subnets": subnet_addresses,
|
||||
"nslist": nslist,
|
||||
}
|
||||
|
||||
|
||||
|
||||
def wait_for_firewall_ready(self):
|
||||
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
|
||||
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
|
||||
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
||||
ready_events = []
|
||||
for fn in methods:
|
||||
ev = threading.Event()
|
||||
ready_events.append(ev)
|
||||
|
||||
def _target():
|
||||
try:
|
||||
fn(ev.set)
|
||||
except:
|
||||
debug2(f'thread {fn.__name__} exiting due to: ' + traceback.format_exc())
|
||||
except Exception:
|
||||
debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc())
|
||||
sys.stdin.close() # this will exist main thread
|
||||
sys.stdout.close()
|
||||
|
||||
threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
|
||||
for ev in ready_events:
|
||||
if not ev.wait(5): # at most 5 sec
|
||||
raise Fatal(f"timeout in wait_for_firewall_ready()")
|
||||
|
||||
if not ev.wait(5): # at most 5 sec
|
||||
raise Fatal("timeout in wait_for_firewall_ready()")
|
||||
|
||||
def restore_firewall(self, port, family, udp, user):
|
||||
pass
|
||||
|
||||
@ -294,17 +338,17 @@ class Method(BaseMethod):
|
||||
return result
|
||||
|
||||
def get_tcp_dstip(self, sock):
|
||||
if not hasattr(self, 'conntrack'):
|
||||
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getpid()}')
|
||||
if not hasattr(self, "conntrack"):
|
||||
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
|
||||
|
||||
src_addr , src_port = sock.getpeername()
|
||||
c = self.conntrack.get(IPProtocol.TCP , src_addr, src_port)
|
||||
src_addr, src_port = sock.getpeername()
|
||||
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
|
||||
if not c:
|
||||
return (src_addr , src_port)
|
||||
return (src_addr, src_port)
|
||||
return (c.dst_addr, c.dst_port)
|
||||
|
||||
def is_supported(self):
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -315,8 +359,8 @@ class Method(BaseMethod):
|
||||
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
||||
family_filters = []
|
||||
for af, c in self.network_config.items():
|
||||
subnet_filters = []
|
||||
for cidr in c['subnets']:
|
||||
subnet_filters = []
|
||||
for cidr in c["subnets"]:
|
||||
ip_network = ipaddress.ip_network(cidr)
|
||||
first_ip = ip_network.network_address
|
||||
last_ip = ip_network.broadcast_address
|
||||
@ -333,11 +377,21 @@ class Method(BaseMethod):
|
||||
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
||||
for pkt in w:
|
||||
debug3(">>> " + repr_pkt(pkt))
|
||||
if pkt.tcp.syn and not pkt.tcp.ack: # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
|
||||
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SENT)
|
||||
if pkt.tcp.fin: # FIN sent (start of graceful close our side, and we wait for ACK)
|
||||
if pkt.tcp.syn and not pkt.tcp.ack:
|
||||
# SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
|
||||
self.conntrack.add(
|
||||
socket.IPPROTO_TCP,
|
||||
pkt.src_addr,
|
||||
pkt.src_port,
|
||||
pkt.dst_addr,
|
||||
pkt.dst_port,
|
||||
ConnState.TCP_SYN_SENT,
|
||||
)
|
||||
if pkt.tcp.fin:
|
||||
# FIN sent (start of graceful close our side, and we wait for ACK)
|
||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
|
||||
if pkt.tcp.rst : # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
|
||||
if pkt.tcp.rst:
|
||||
# RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
|
||||
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||
|
||||
# DNAT
|
||||
@ -347,18 +401,19 @@ class Method(BaseMethod):
|
||||
pkt.dst_addr = proxy_addr_ipv6
|
||||
pkt.tcp.dst_port = proxy_port
|
||||
|
||||
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well by setting src_addr to loopback address.
|
||||
# Otherwise injecting packet will be ignored by Windows network stack as teh packet has to cross public to private address space.
|
||||
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well
|
||||
# by setting src_addr to loopback address.
|
||||
# Otherwise injecting packet will be ignored by Windows network stack
|
||||
# as they packet has to cross public to private address space.
|
||||
# See: https://github.com/basil00/Divert/issues/82
|
||||
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
|
||||
# >>> pkt.dst_addr = proxy_addr_ipv4
|
||||
# >>> pkt.dst_addr = proxy_addr_ipv4
|
||||
|
||||
w.send(pkt, recalculate_checksum=True)
|
||||
|
||||
|
||||
def _ingress_divert(self, ready_cb):
|
||||
proto = IPProtocol.TCP
|
||||
direction = 'inbound' # only when proxy address is not loopback address (Useful for testing)
|
||||
direction = "inbound" # only when proxy address is not loopback address (Useful for testing)
|
||||
ip_filters = []
|
||||
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
|
||||
if addr.is_loopback: # Windivert treats all loopback traffic as outbound
|
||||
@ -374,19 +429,23 @@ class Method(BaseMethod):
|
||||
ready_cb()
|
||||
for pkt in w:
|
||||
debug3("<<< " + repr_pkt(pkt))
|
||||
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK received (connection established)
|
||||
if pkt.tcp.syn and pkt.tcp.ack:
|
||||
# SYN+ACK received (connection established)
|
||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
|
||||
elif pkt.tcp.rst: # RST received - Abrupt connection teardown initiated by other side. We don't expect anymore packets
|
||||
elif pkt.tcp.rst:
|
||||
# RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets
|
||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||
# https://wiki.wireshark.org/TCP-4-times-close.md
|
||||
elif pkt.tcp.fin and pkt.tcp.ack: # FIN+ACK received (Passive close by other side. We don't expect any more packets. Other side expects an ACK)
|
||||
elif pkt.tcp.fin and pkt.tcp.ack:
|
||||
# FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK)
|
||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||
elif pkt.tcp.fin: # FIN received (Other side initiated graceful close. We expects a final ACK for a FIN packet)
|
||||
elif pkt.tcp.fin:
|
||||
# FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet)
|
||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
|
||||
else:
|
||||
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
||||
if not conn:
|
||||
debug2("Unexpected packet: " + repr_pkt(pkt))
|
||||
debug2("Unexpected packet: " + repr_pkt(pkt))
|
||||
continue
|
||||
pkt.src_addr = conn.dst_addr
|
||||
pkt.tcp.src_port = conn.dst_port
|
||||
|
@ -175,7 +175,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
||||
# case, sshuttle might not work at all since it is not
|
||||
# possible to run python on the remote machine---even if
|
||||
# it is present.
|
||||
devnull='/dev/null'
|
||||
devnull = '/dev/null'
|
||||
pycmd = ("P=python3; $P -V 2>%s || P=python; "
|
||||
"exec \"$P\" -c %s; exit 97") % \
|
||||
(devnull, quote(pyscript))
|
||||
@ -204,9 +204,9 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
||||
raise Fatal("Failed to find '%s' in path %s" % (argv[0], get_path()))
|
||||
argv[0] = abs_path
|
||||
|
||||
|
||||
if sys.platform != 'win32':
|
||||
(s1, s2) = socket.socketpair()
|
||||
|
||||
def preexec_fn():
|
||||
# runs in the child process
|
||||
s2.close()
|
||||
@ -222,13 +222,14 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
||||
preexec_fn = None
|
||||
pstdin = ssubprocess.PIPE
|
||||
pstdout = ssubprocess.PIPE
|
||||
|
||||
def get_serversock():
|
||||
import threading
|
||||
|
||||
def stream_stdout_to_sock():
|
||||
try:
|
||||
fd = p.stdout.fileno()
|
||||
for data in iter(lambda:os.read(fd, 16384), b''):
|
||||
for data in iter(lambda: os.read(fd, 16384), b''):
|
||||
s1.sendall(data)
|
||||
debug3(f"<<<<< p.stdout.read() {len(data)} {data[:min(32,len(data))]}...")
|
||||
finally:
|
||||
@ -238,7 +239,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
||||
|
||||
def stream_sock_to_stdin():
|
||||
try:
|
||||
for data in iter(lambda:s1.recv(16384), b''):
|
||||
for data in iter(lambda: s1.recv(16384), b''):
|
||||
debug3(f">>>>> p.stdout.write() {len(data)} {data[:min(32,len(data))]}...")
|
||||
while data:
|
||||
n = p.stdin.write(data)
|
||||
@ -247,7 +248,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
||||
debug2("Thread 'stream_sock_to_stdin' exiting")
|
||||
s1.close()
|
||||
p.terminate()
|
||||
|
||||
|
||||
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
|
||||
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
|
||||
# s2.setblocking(False)
|
||||
@ -258,7 +259,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
||||
|
||||
debug2("executing: %r" % argv)
|
||||
p = ssubprocess.Popen(argv, stdin=pstdin, stdout=pstdout, preexec_fn=preexec_fn,
|
||||
close_fds=close_fds, stderr=stderr, bufsize=0)
|
||||
close_fds=close_fds, stderr=stderr, bufsize=0)
|
||||
|
||||
serversock = get_serversock()
|
||||
serversock.sendall(content)
|
||||
|
@ -477,11 +477,9 @@ class Mux(Handler):
|
||||
# If LATENCY_BUFFER_SIZE is inappropriately large, we will
|
||||
# get a MemoryError here. Read no more than 1MiB.
|
||||
if sys.platform == 'win32':
|
||||
read = _nb_clean(self.rfile.raw._sock.recv,
|
||||
min(1048576, LATENCY_BUFFER_SIZE))
|
||||
read = _nb_clean(self.rfile.raw._sock.recv, min(1048576, LATENCY_BUFFER_SIZE))
|
||||
else:
|
||||
read = _nb_clean(os.read, self.rfile.fileno(),
|
||||
min(1048576, LATENCY_BUFFER_SIZE))
|
||||
read = _nb_clean(os.read, self.rfile.fileno(), min(1048576, LATENCY_BUFFER_SIZE))
|
||||
except OSError:
|
||||
_, e = sys.exc_info()[:2]
|
||||
raise Fatal('other end: %r' % e)
|
||||
|
Loading…
Reference in New Issue
Block a user