pass flake8 linting

This commit is contained in:
nom3ad 2022-09-07 12:26:21 +05:30 committed by Brian May
parent 7da3b024dd
commit 482e0cbd00
7 changed files with 172 additions and 110 deletions

View File

@ -260,8 +260,7 @@ class FirewallClient:
# If we can find doas and not sudo or if we are on # If we can find doas and not sudo or if we are on
# OpenBSD, try using doas first. # OpenBSD, try using doas first.
if (doas_path and not sudo_path) or \ if (doas_path and not sudo_path) or platform.platform().startswith('OpenBSD'):
platform.platform().startswith('OpenBSD'):
argv_tries = [doas_cmd, sudo_cmd, argvbase] argv_tries = [doas_cmd, sudo_cmd, argvbase]
else: else:
argv_tries = [sudo_cmd, doas_cmd, argvbase] argv_tries = [sudo_cmd, doas_cmd, argvbase]
@ -282,9 +281,11 @@ class FirewallClient:
pstdout = s1 pstdout = s1
pstdin = s1 pstdin = s1
penv = None penv = None
def preexec_fn(): def preexec_fn():
# run in the child process # run in the child process
s2.close() s2.close()
def get_pfile(): def get_pfile():
s1.close() s1.close()
return s2.makefile('rwb') return s2.makefile('rwb')
@ -296,6 +297,7 @@ class FirewallClient:
preexec_fn = None preexec_fn = None
penv = os.environ.copy() penv = os.environ.copy()
penv['PYTHONPATH'] = os.path.dirname(os.path.dirname(__file__)) penv['PYTHONPATH'] = os.path.dirname(os.path.dirname(__file__))
def get_pfile(): def get_pfile():
import base64 import base64
socket_share_data = s1.share(self.p.pid) socket_share_data = s1.share(self.p.pid)
@ -318,7 +320,6 @@ class FirewallClient:
'Command=%r Exception=%s' % (argv, e)) 'Command=%r Exception=%s' % (argv, e))
continue continue
self.argv = argv self.argv = argv
self.pfile = get_pfile() self.pfile = get_pfile()
try: try:

View File

@ -130,11 +130,13 @@ def _setup_daemon_windows():
sock.close() sock.close()
return sys.stdin, sys.stdout return sys.stdin, sys.stdout
if sys.platform == 'win32': if sys.platform == 'win32':
setup_daemon = _setup_daemon_windows setup_daemon = _setup_daemon_windows
else: else:
setup_daemon = _setup_daemon_unix setup_daemon = _setup_daemon_unix
# Note that we're sorting in a very particular order: # Note that we're sorting in a very particular order:
# we need to go from smaller, more specific, port ranges, to larger, # we need to go from smaller, more specific, port ranges, to larger,
# less-specific, port ranges. At each level, we order by subnet # less-specific, port ranges. At each level, we order by subnet
@ -217,7 +219,7 @@ def main(method_name, syslog):
if not line: if not line:
# parent probably exited # parent probably exited
return return
except IOError: except IOError as e:
# On windows, this ConnectionResetError is thrown when parent process closes it's socket pair end # On windows, this ConnectionResetError is thrown when parent process closes it's socket pair end
debug3('read from stdin failed: %s' % (e,)) debug3('read from stdin failed: %s' % (e,))
return return

View File

@ -224,13 +224,14 @@ def which(file, mode=os.F_OK | os.X_OK):
debug2("which() could not find '%s' in %s" % (file, path)) debug2("which() could not find '%s' in %s" % (file, path))
return rv return rv
def is_admin_user(): def is_admin_user():
if sys.platform == 'win32': if sys.platform == 'win32':
import ctypes import ctypes
# https://stackoverflow.com/questions/130763/request-uac-elevation-from-within-a-python-script/41930586#41930586 # https://stackoverflow.com/questions/130763/request-uac-elevation-from-within-a-python-script/41930586#41930586
try: try:
return ctypes.windll.shell32.IsUserAnAdmin() return ctypes.windll.shell32.IsUserAnAdmin()
except: except Exception:
return False return False
return os.getuid() == 0 return os.getuid() == 0

View File

@ -11,31 +11,34 @@ from enum import IntEnum
import time import time
import traceback import traceback
try:
import pydivert
except ImportError:
raise Fatal('Could not import pydivert module. windivert requires https://pypi.org/project/pydivert')
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal from sshuttle.helpers import debug3, log, debug1, debug2, Fatal
try:
# https://reqrypt.org/windivert-doc.html#divert_iphdr # https://reqrypt.org/windivert-doc.html#divert_iphdr
import pydivert
except ImportError:
raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert")
ConnectionTuple = namedtuple( ConnectionTuple = namedtuple(
"ConnectionTuple", ["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", 'state'] "ConnectionTuple",
["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"],
) )
WINDIVERT_MAX_CONNECTIONS = 10_000 WINDIVERT_MAX_CONNECTIONS = 10_000
class IPProtocol(IntEnum): class IPProtocol(IntEnum):
TCP = socket.IPPROTO_TCP TCP = socket.IPPROTO_TCP
UDP = socket.IPPROTO_UDP UDP = socket.IPPROTO_UDP
@property @property
def filter(self): def filter(self):
return 'tcp' if self == IPProtocol.TCP else 'udp' return "tcp" if self == IPProtocol.TCP else "udp"
class IPFamily(IntEnum): class IPFamily(IntEnum):
IPv4 = socket.AF_INET IPv4 = socket.AF_INET
@ -43,7 +46,7 @@ class IPFamily(IntEnum):
@property @property
def filter(self): def filter(self):
return 'ip' if self == socket.AF_INET else 'ipv6' return "ip" if self == socket.AF_INET else "ipv6"
@property @property
def version(self): def version(self):
@ -51,7 +54,7 @@ class IPFamily(IntEnum):
@property @property
def loopback_addr(self): def loopback_addr(self):
return '127.0.0.1' if self == socket.AF_INET else '::1' return "127.0.0.1" if self == socket.AF_INET else "::1"
class ConnState(IntEnum): class ConnState(IntEnum):
@ -70,27 +73,34 @@ def repr_pkt(p):
if p.tcp: if p.tcp:
t = p.tcp t = p.tcp
r += f" {len(t.payload)}B (" r += f" {len(t.payload)}B ("
r += '+'.join(f.upper() for f in ('fin','syn', "rst", "psh", 'ack', 'urg', 'ece', 'cwr', 'ns') if getattr(t, f)) r += "+".join(
r += f') SEQ#{t.seq_num}' f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f)
)
r += f") SEQ#{t.seq_num}"
if t.ack: if t.ack:
r += f' ACK#{t.ack_num}' r += f" ACK#{t.ack_num}"
r += f' WZ={t.window_size}' r += f" WZ={t.window_size}"
else: else:
r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}" r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}"
return f"<Pkt {r}>" return f"<Pkt {r}>"
def synchronized_method(lock): def synchronized_method(lock):
def decorator(method): def decorator(method):
@wraps(method) @wraps(method)
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
with getattr(self, lock): with getattr(self, lock):
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return wrapped return wrapped
return decorator return decorator
class ConnTrack: class ConnTrack:
_instance = None _instance = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if not cls._instance: if not cls._instance:
cls._instance = object.__new__(cls) cls._instance = object.__new__(cls)
@ -98,13 +108,15 @@ class ConnTrack:
raise RuntimeError("ConnTrack can not be instantiated multiple times") raise RuntimeError("ConnTrack can not be instantiated multiple times")
def __init__(self, name, max_connections=0) -> None: def __init__(self, name, max_connections=0) -> None:
self.struct_full_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H', '16s', 'H', 'L', 'B'))) self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
self.struct_src_tuple = struct.Struct('>' + ''.join(('B', 'B', '16s', 'H'))) self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
self.struct_state_tuple = struct.Struct('>' + ''.join(('L', 'B'))) self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
try: try:
self.max_connections = max_connections self.max_connections = max_connections
self.shm_list = shared_memory.ShareableList([bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name) self.shm_list = shared_memory.ShareableList(
[bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name
)
self.is_owner = True self.is_owner = True
self.next_slot = 0 self.next_slot = 0
self.used_slots = set() self.used_slots = set()
@ -114,9 +126,12 @@ class ConnTrack:
self.shm_list = shared_memory.ShareableList(name=name) self.shm_list = shared_memory.ShareableList(name=name)
self.max_connections = len(self.shm_list) self.max_connections = len(self.shm_list)
debug2(f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} shm_size={self.shm_list.shm.size}B") debug2(
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
f"shm_size={self.shm_list.shm.size}B"
)
@synchronized_method('rlock') @synchronized_method("rlock")
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state): def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
@ -143,9 +158,12 @@ class ConnTrack:
self.shm_list[self.next_slot] = packed self.shm_list[self.next_slot] = packed
self.used_slots.add(self.next_slot) self.used_slots.add(self.next_slot)
proto = IPProtocol(proto) proto = IPProtocol(proto)
debug3(f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}") debug3(
f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to "
f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}"
)
@synchronized_method('rlock') @synchronized_method("rlock")
def update(self, proto, src_addr, src_port, state): def update(self, proto, src_addr, src_port, state):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
@ -155,12 +173,18 @@ class ConnTrack:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
state_epoch = int(time.time()) state_epoch = int(time.time())
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state) self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
debug3(f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}") debug3(
f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | "
f"#ActiveConn={len(self.used_slots)}"
)
return self._unpack(self.shm_list[i]) return self._unpack(self.shm_list[i])
else: else:
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}") debug3(
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | "
f"#ActiveConn={len(self.used_slots)}"
)
@synchronized_method('rlock') @synchronized_method("rlock")
def remove(self, proto, src_addr, src_port): def remove(self, proto, src_addr, src_port):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
@ -169,13 +193,18 @@ class ConnTrack:
for i in self.used_slots: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
conn = self._unpack(self.shm_list[i]) conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b'' self.shm_list[i] = b""
self.used_slots.remove(i) self.used_slots.remove(i)
debug3(f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}") debug3(
f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | "
f"#ActiveConn={len(self.used_slots)}"
)
return conn return conn
else: else:
debug3(f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}") debug3(
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |"
f" #ActiveConn={len(self.used_slots)}"
)
def get(self, proto, src_addr, src_port): def get(self, proto, src_addr, src_port):
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
@ -184,7 +213,7 @@ class ConnTrack:
if entry and entry.startswith(packed): if entry and entry.startswith(packed):
return self._unpack(entry) return self._unpack(entry)
@synchronized_method('rlock') @synchronized_method("rlock")
def gc(self, connection_timeout_sec=15): def gc(self, connection_timeout_sec=15):
now = int(time.time()) now = int(time.time())
n = 0 n = 0
@ -195,22 +224,37 @@ class ConnTrack:
continue continue
if ConnState.can_timeout(state): if ConnState.can_timeout(state):
conn = self._unpack(self.shm_list[i]) conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b'' self.shm_list[i] = b""
self.used_slots.remove(i) self.used_slots.remove(i)
n += 1 n += 1
debug3(f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}") debug3(
f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})"
f" from slot={i} | #ActiveConn={len(self.used_slots)}"
)
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}") debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
def _unpack(self, packed): def _unpack(self, packed):
(proto, ip_version, src_addr_packed, src_port, dst_addr_packed, dst_port, state_epoch, state) = self.struct_full_tuple.unpack(packed) (
proto,
ip_version,
src_addr_packed,
src_port,
dst_addr_packed,
dst_port,
state_epoch,
state,
) = self.struct_full_tuple.unpack(packed)
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4])) dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4])) src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)) return ConnectionTuple(
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
)
def __iter__(self): def __iter__(self):
def conn_iter(): def conn_iter():
for i in self.used_slots: for i in self.used_slots:
yield self._unpack(self.shm_list[i]) yield self._unpack(self.shm_list[i])
return conn_iter() return conn_iter()
def __repr__(self): def __repr__(self):
@ -226,8 +270,7 @@ class Method(BaseMethod):
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark):
user, tmark):
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
if nslist or user or udp: if nslist or user or udp:
@ -235,7 +278,8 @@ class Method(BaseMethod):
family = IPFamily(family) family = IPFamily(family)
# using loopback proxy address never worked. See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82) # using loopback proxy address never worked.
# See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
# As a workaround we use another interface ip instead. # As a workaround we use another interface ip instead.
# self.proxy_addr[family] = family.loopback_addr # self.proxy_addr[family] = family.loopback_addr
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
@ -252,36 +296,36 @@ class Method(BaseMethod):
for (_, mask, exclude, network_addr, fport, lport) in subnets: for (_, mask, exclude, network_addr, fport, lport) in subnets:
if exclude: if exclude:
continue continue
assert fport == 0, 'custom port range not supported' assert fport == 0, "custom port range not supported"
assert lport == 0, 'custom port range not supported' assert lport == 0, "custom port range not supported"
subnet_addresses.append("%s/%s" % (network_addr, mask)) subnet_addresses.append("%s/%s" % (network_addr, mask))
self.network_config[family] = { self.network_config[family] = {
'subnets': subnet_addresses, "subnets": subnet_addresses,
"nslist": nslist, "nslist": nslist,
} }
def wait_for_firewall_ready(self): def wait_for_firewall_ready(self):
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}") debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS) self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
methods = (self._egress_divert, self._ingress_divert, self._connection_gc) methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
ready_events = [] ready_events = []
for fn in methods: for fn in methods:
ev = threading.Event() ev = threading.Event()
ready_events.append(ev) ready_events.append(ev)
def _target(): def _target():
try: try:
fn(ev.set) fn(ev.set)
except: except Exception:
debug2(f'thread {fn.__name__} exiting due to: ' + traceback.format_exc()) debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc())
sys.stdin.close() # this will exist main thread sys.stdin.close() # this will exist main thread
sys.stdout.close() sys.stdout.close()
threading.Thread(name=fn.__name__, target=_target, daemon=True).start() threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
for ev in ready_events: for ev in ready_events:
if not ev.wait(5): # at most 5 sec if not ev.wait(5): # at most 5 sec
raise Fatal(f"timeout in wait_for_firewall_ready()") raise Fatal("timeout in wait_for_firewall_ready()")
def restore_firewall(self, port, family, udp, user): def restore_firewall(self, port, family, udp, user):
pass pass
@ -294,8 +338,8 @@ class Method(BaseMethod):
return result return result
def get_tcp_dstip(self, sock): def get_tcp_dstip(self, sock):
if not hasattr(self, 'conntrack'): if not hasattr(self, "conntrack"):
self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getpid()}') self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
src_addr, src_port = sock.getpeername() src_addr, src_port = sock.getpeername()
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port) c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
@ -304,7 +348,7 @@ class Method(BaseMethod):
return (c.dst_addr, c.dst_port) return (c.dst_addr, c.dst_port)
def is_supported(self): def is_supported(self):
if sys.platform == 'win32': if sys.platform == "win32":
return True return True
return False return False
@ -316,7 +360,7 @@ class Method(BaseMethod):
family_filters = [] family_filters = []
for af, c in self.network_config.items(): for af, c in self.network_config.items():
subnet_filters = [] subnet_filters = []
for cidr in c['subnets']: for cidr in c["subnets"]:
ip_network = ipaddress.ip_network(cidr) ip_network = ipaddress.ip_network(cidr)
first_ip = ip_network.network_address first_ip = ip_network.network_address
last_ip = ip_network.broadcast_address last_ip = ip_network.broadcast_address
@ -333,11 +377,21 @@ class Method(BaseMethod):
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6] proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
for pkt in w: for pkt in w:
debug3(">>> " + repr_pkt(pkt)) debug3(">>> " + repr_pkt(pkt))
if pkt.tcp.syn and not pkt.tcp.ack: # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK) if pkt.tcp.syn and not pkt.tcp.ack:
self.conntrack.add(socket.IPPROTO_TCP, pkt.src_addr, pkt.src_port, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_SENT) # SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
if pkt.tcp.fin: # FIN sent (start of graceful close our side, and we wait for ACK) self.conntrack.add(
socket.IPPROTO_TCP,
pkt.src_addr,
pkt.src_port,
pkt.dst_addr,
pkt.dst_port,
ConnState.TCP_SYN_SENT,
)
if pkt.tcp.fin:
# FIN sent (start of graceful close our side, and we wait for ACK)
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1) self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
if pkt.tcp.rst : # RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply) if pkt.tcp.rst:
# RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port) self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
# DNAT # DNAT
@ -347,18 +401,19 @@ class Method(BaseMethod):
pkt.dst_addr = proxy_addr_ipv6 pkt.dst_addr = proxy_addr_ipv6
pkt.tcp.dst_port = proxy_port pkt.tcp.dst_port = proxy_port
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well by setting src_addr to loopback address. # XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well
# Otherwise injecting packet will be ignored by Windows network stack as teh packet has to cross public to private address space. # by setting src_addr to loopback address.
# Otherwise injecting packet will be ignored by Windows network stack
# as they packet has to cross public to private address space.
# See: https://github.com/basil00/Divert/issues/82 # See: https://github.com/basil00/Divert/issues/82
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets. # Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
# >>> pkt.dst_addr = proxy_addr_ipv4 # >>> pkt.dst_addr = proxy_addr_ipv4
w.send(pkt, recalculate_checksum=True) w.send(pkt, recalculate_checksum=True)
def _ingress_divert(self, ready_cb): def _ingress_divert(self, ready_cb):
proto = IPProtocol.TCP proto = IPProtocol.TCP
direction = 'inbound' # only when proxy address is not loopback address (Useful for testing) direction = "inbound" # only when proxy address is not loopback address (Useful for testing)
ip_filters = [] ip_filters = []
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a): for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
if addr.is_loopback: # Windivert treats all loopback traffic as outbound if addr.is_loopback: # Windivert treats all loopback traffic as outbound
@ -374,14 +429,18 @@ class Method(BaseMethod):
ready_cb() ready_cb()
for pkt in w: for pkt in w:
debug3("<<< " + repr_pkt(pkt)) debug3("<<< " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK received (connection established) if pkt.tcp.syn and pkt.tcp.ack:
# SYN+ACK received (connection established)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
elif pkt.tcp.rst: # RST received - Abrupt connection teardown initiated by other side. We don't expect anymore packets elif pkt.tcp.rst:
# RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
# https://wiki.wireshark.org/TCP-4-times-close.md # https://wiki.wireshark.org/TCP-4-times-close.md
elif pkt.tcp.fin and pkt.tcp.ack: # FIN+ACK received (Passive close by other side. We don't expect any more packets. Other side expects an ACK) elif pkt.tcp.fin and pkt.tcp.ack:
# FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK)
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
elif pkt.tcp.fin: # FIN received (Other side initiated graceful close. We expects a final ACK for a FIN packet) elif pkt.tcp.fin:
# FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
else: else:
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)

View File

@ -204,9 +204,9 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
raise Fatal("Failed to find '%s' in path %s" % (argv[0], get_path())) raise Fatal("Failed to find '%s' in path %s" % (argv[0], get_path()))
argv[0] = abs_path argv[0] = abs_path
if sys.platform != 'win32': if sys.platform != 'win32':
(s1, s2) = socket.socketpair() (s1, s2) = socket.socketpair()
def preexec_fn(): def preexec_fn():
# runs in the child process # runs in the child process
s2.close() s2.close()
@ -222,6 +222,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
preexec_fn = None preexec_fn = None
pstdin = ssubprocess.PIPE pstdin = ssubprocess.PIPE
pstdout = ssubprocess.PIPE pstdout = ssubprocess.PIPE
def get_serversock(): def get_serversock():
import threading import threading

View File

@ -477,11 +477,9 @@ class Mux(Handler):
# If LATENCY_BUFFER_SIZE is inappropriately large, we will # If LATENCY_BUFFER_SIZE is inappropriately large, we will
# get a MemoryError here. Read no more than 1MiB. # get a MemoryError here. Read no more than 1MiB.
if sys.platform == 'win32': if sys.platform == 'win32':
read = _nb_clean(self.rfile.raw._sock.recv, read = _nb_clean(self.rfile.raw._sock.recv, min(1048576, LATENCY_BUFFER_SIZE))
min(1048576, LATENCY_BUFFER_SIZE))
else: else:
read = _nb_clean(os.read, self.rfile.fileno(), read = _nb_clean(os.read, self.rfile.fileno(), min(1048576, LATENCY_BUFFER_SIZE))
min(1048576, LATENCY_BUFFER_SIZE))
except OSError: except OSError:
_, e = sys.exc_info()[:2] _, e = sys.exc_info()[:2]
raise Fatal('other end: %r' % e) raise Fatal('other end: %r' % e)