mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-21 23:43:18 +01:00
fix windows CRLF issue on stdin/stdout
This commit is contained in:
parent
900acc3ac7
commit
4a84ad3be6
@ -307,7 +307,7 @@ class FirewallClient:
|
|||||||
|
|
||||||
def get_pfile():
|
def get_pfile():
|
||||||
if can_use_stdio:
|
if can_use_stdio:
|
||||||
self.p.stdin.write(b'STDIO:\n')
|
self.p.stdin.write(b'COM_STDIO:\n')
|
||||||
self.p.stdin.flush()
|
self.p.stdin.flush()
|
||||||
|
|
||||||
class RWPair:
|
class RWPair:
|
||||||
@ -334,7 +334,7 @@ class FirewallClient:
|
|||||||
socket_share_data = s1.share(self.p.pid)
|
socket_share_data = s1.share(self.p.pid)
|
||||||
s1.close()
|
s1.close()
|
||||||
socket_share_data_b64 = base64.b64encode(socket_share_data)
|
socket_share_data_b64 = base64.b64encode(socket_share_data)
|
||||||
self.p.stdin.write(b'SOCKETSHARE:' + socket_share_data_b64 + b'\n')
|
self.p.stdin.write(b'COM_SOCKETSHARE:' + socket_share_data_b64 + b'\n')
|
||||||
self.p.stdin.flush()
|
self.p.stdin.flush()
|
||||||
return s2.makefile('rwb')
|
return s2.makefile('rwb')
|
||||||
try:
|
try:
|
||||||
|
@ -84,7 +84,7 @@ def firewall_exit(signum, frame):
|
|||||||
# the typical exit process as described above.
|
# the typical exit process as described above.
|
||||||
global sshuttle_pid
|
global sshuttle_pid
|
||||||
if sshuttle_pid:
|
if sshuttle_pid:
|
||||||
debug1("Relaying interupt signal to sshuttle process %d\n" % sshuttle_pid)
|
debug1("Relaying interupt signal to sshuttle process %d" % sshuttle_pid)
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
sig = signal.CTRL_C_EVENT
|
sig = signal.CTRL_C_EVENT
|
||||||
else:
|
else:
|
||||||
@ -115,7 +115,7 @@ def _setup_daemon_for_unix_like():
|
|||||||
# setsid() fails if sudo is configured with the use_pty option.
|
# setsid() fails if sudo is configured with the use_pty option.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return sys.stdin, sys.stdout
|
return sys.stdin.buffer, sys.stdout.buffer
|
||||||
|
|
||||||
|
|
||||||
def _setup_daemon_for_windows():
|
def _setup_daemon_for_windows():
|
||||||
@ -125,9 +125,9 @@ def _setup_daemon_for_windows():
|
|||||||
signal.signal(signal.SIGTERM, firewall_exit)
|
signal.signal(signal.SIGTERM, firewall_exit)
|
||||||
signal.signal(signal.SIGINT, firewall_exit)
|
signal.signal(signal.SIGINT, firewall_exit)
|
||||||
|
|
||||||
socket_share_data_prefix = 'SOCKETSHARE:'
|
socket_share_data_prefix = b'COM_SOCKETSHARE:'
|
||||||
line = sys.stdin.readline().strip()
|
line = sys.stdin.buffer.readline().strip()
|
||||||
if line.startswith('SOCKETSHARE:'):
|
if line.startswith(socket_share_data_prefix):
|
||||||
debug3('Using shared socket for communicating with sshuttle client process')
|
debug3('Using shared socket for communicating with sshuttle client process')
|
||||||
socket_share_data_b64 = line[len(socket_share_data_prefix):]
|
socket_share_data_b64 = line[len(socket_share_data_prefix):]
|
||||||
socket_share_data = base64.b64decode(socket_share_data_b64)
|
socket_share_data = base64.b64decode(socket_share_data_b64)
|
||||||
@ -135,12 +135,12 @@ def _setup_daemon_for_windows():
|
|||||||
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
|
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
|
||||||
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
|
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
|
||||||
sock.close()
|
sock.close()
|
||||||
elif line.startswith("STDIO:"):
|
elif line.startswith(b"COM_STDIO:"):
|
||||||
debug3('Using inherited stdio for communicating with sshuttle client process')
|
debug3('Using inherited stdio for communicating with sshuttle client process')
|
||||||
else:
|
else:
|
||||||
raise Fatal("Unexpected stdin: " + line)
|
raise Fatal("Unexpected stdin: " + line)
|
||||||
|
|
||||||
return sys.stdin, sys.stdout
|
return sys.stdin.buffer, sys.stdout.buffer
|
||||||
|
|
||||||
|
|
||||||
# Isolate function that needs to be replaced for tests
|
# Isolate function that needs to be replaced for tests
|
||||||
@ -221,33 +221,43 @@ def main(method_name, syslog):
|
|||||||
"PATH." % method_name)
|
"PATH." % method_name)
|
||||||
|
|
||||||
debug1('ready method name %s.' % method.name)
|
debug1('ready method name %s.' % method.name)
|
||||||
stdout.write('READY %s\n' % method.name)
|
stdout.write(('READY %s\n' % method.name).encode('ASCII'))
|
||||||
stdout.flush()
|
stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def _read_next_string_line():
|
||||||
|
try:
|
||||||
|
line = stdin.readline(128)
|
||||||
|
if not line:
|
||||||
|
return # parent probably exited
|
||||||
|
return line.decode('ASCII').strip()
|
||||||
|
except IOError as e:
|
||||||
|
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
|
||||||
|
debug3('read from stdin failed: %s' % (e,))
|
||||||
|
return
|
||||||
# we wait until we get some input before creating the rules. That way,
|
# we wait until we get some input before creating the rules. That way,
|
||||||
# sshuttle can launch us as early as possible (and get sudo password
|
# sshuttle can launch us as early as possible (and get sudo password
|
||||||
# authentication as early in the startup process as possible).
|
# authentication as early in the startup process as possible).
|
||||||
try:
|
try:
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line:
|
||||||
return # parent probably exited
|
return # parent probably exited
|
||||||
except ConnectionResetError as e:
|
except IOError as e:
|
||||||
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
|
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
|
||||||
debug3('read from stdin failed: %s' % (e,))
|
debug3('read from stdin failed: %s' % (e,))
|
||||||
return
|
return
|
||||||
|
|
||||||
subnets = []
|
subnets = []
|
||||||
if line != 'ROUTES\n':
|
if line != 'ROUTES':
|
||||||
raise Fatal('expected ROUTES but got %r' % line)
|
raise Fatal('expected ROUTES but got %r' % line)
|
||||||
while 1:
|
while 1:
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('expected route but got %r' % line)
|
raise Fatal('expected route but got %r' % line)
|
||||||
elif line.startswith("NSLIST\n"):
|
elif line.startswith("NSLIST"):
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
(family, width, exclude, ip, fport, lport) = \
|
(family, width, exclude, ip, fport, lport) = line.split(',', 5)
|
||||||
line.strip().split(',', 5)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Fatal('expected route or NSLIST but got %r' % line)
|
raise Fatal('expected route or NSLIST but got %r' % line)
|
||||||
subnets.append((
|
subnets.append((
|
||||||
@ -260,16 +270,16 @@ def main(method_name, syslog):
|
|||||||
debug2('Got subnets: %r' % subnets)
|
debug2('Got subnets: %r' % subnets)
|
||||||
|
|
||||||
nslist = []
|
nslist = []
|
||||||
if line != 'NSLIST\n':
|
if line != 'NSLIST':
|
||||||
raise Fatal('expected NSLIST but got %r' % line)
|
raise Fatal('expected NSLIST but got %r' % line)
|
||||||
while 1:
|
while 1:
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('expected nslist but got %r' % line)
|
raise Fatal('expected nslist but got %r' % line)
|
||||||
elif line.startswith("PORTS "):
|
elif line.startswith("PORTS "):
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
(family, ip) = line.strip().split(',', 1)
|
(family, ip) = line.split(',', 1)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Fatal('expected nslist or PORTS but got %r' % line)
|
raise Fatal('expected nslist or PORTS but got %r' % line)
|
||||||
nslist.append((int(family), ip))
|
nslist.append((int(family), ip))
|
||||||
@ -299,15 +309,13 @@ def main(method_name, syslog):
|
|||||||
debug2('Got ports: %d,%d,%d,%d'
|
debug2('Got ports: %d,%d,%d,%d'
|
||||||
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
||||||
|
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line or not line.startswith("GO "):
|
||||||
raise Fatal('expected GO but got %r' % line)
|
|
||||||
elif not line.startswith("GO "):
|
|
||||||
raise Fatal('expected GO but got %r' % line)
|
raise Fatal('expected GO but got %r' % line)
|
||||||
|
|
||||||
_, _, args = line.partition(" ")
|
_, _, args = line.partition(" ")
|
||||||
global sshuttle_pid
|
global sshuttle_pid
|
||||||
udp, user, group, tmark, sshuttle_pid = args.strip().split(" ", 4)
|
udp, user, group, tmark, sshuttle_pid = args.split(" ", 4)
|
||||||
udp = bool(int(udp))
|
udp = bool(int(udp))
|
||||||
sshuttle_pid = int(sshuttle_pid)
|
sshuttle_pid = int(sshuttle_pid)
|
||||||
if user == '-':
|
if user == '-':
|
||||||
@ -350,7 +358,7 @@ def main(method_name, syslog):
|
|||||||
flush_systemd_dns_cache()
|
flush_systemd_dns_cache()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stdout.write('STARTED\n')
|
stdout.write(b'STARTED\n')
|
||||||
stdout.flush()
|
stdout.flush()
|
||||||
except IOError as e: # the parent process probably died
|
except IOError as e: # the parent process probably died
|
||||||
debug3('write to stdout failed: %s' % (e,))
|
debug3('write to stdout failed: %s' % (e,))
|
||||||
@ -360,13 +368,11 @@ def main(method_name, syslog):
|
|||||||
# to stay running so that we don't need a *second* password
|
# to stay running so that we don't need a *second* password
|
||||||
# authentication at shutdown time - that cleanup is important!
|
# authentication at shutdown time - that cleanup is important!
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
line = _read_next_string_line()
|
||||||
line = stdin.readline(128)
|
if not line:
|
||||||
except IOError as e:
|
|
||||||
debug3('read from stdin failed: %s' % (e,))
|
|
||||||
return
|
return
|
||||||
if line.startswith('HOST '):
|
if line.startswith('HOST '):
|
||||||
(name, ip) = line[5:].strip().split(',', 1)
|
(name, ip) = line[5:].split(',', 1)
|
||||||
hostmap[name] = ip
|
hostmap[name] = ip
|
||||||
debug2('setting up /etc/hosts.')
|
debug2('setting up /etc/hosts.')
|
||||||
rewrite_etc_hosts(hostmap, port_v6 or port_v4)
|
rewrite_etc_hosts(hostmap, port_v6 or port_v4)
|
||||||
|
@ -1,479 +1,479 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import re
|
import re
|
||||||
from multiprocessing import shared_memory
|
from multiprocessing import shared_memory
|
||||||
import struct
|
import struct
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
from sshuttle.methods import BaseMethod
|
from sshuttle.methods import BaseMethod
|
||||||
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal
|
from sshuttle.helpers import debug3, log, debug1, debug2, Fatal
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
# https://reqrypt.org/windivert-doc.html#divert_iphdr
|
||||||
import pydivert
|
import pydivert
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert")
|
raise Exception("Could not import pydivert module. windivert requires https://pypi.org/project/pydivert")
|
||||||
|
|
||||||
|
|
||||||
ConnectionTuple = namedtuple(
|
ConnectionTuple = namedtuple(
|
||||||
"ConnectionTuple",
|
"ConnectionTuple",
|
||||||
["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"],
|
["protocol", "ip_version", "src_addr", "src_port", "dst_addr", "dst_port", "state_epoch", "state"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
WINDIVERT_MAX_CONNECTIONS = 10_000
|
WINDIVERT_MAX_CONNECTIONS = 10_000
|
||||||
|
|
||||||
|
|
||||||
class IPProtocol(IntEnum):
|
class IPProtocol(IntEnum):
|
||||||
TCP = socket.IPPROTO_TCP
|
TCP = socket.IPPROTO_TCP
|
||||||
UDP = socket.IPPROTO_UDP
|
UDP = socket.IPPROTO_UDP
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filter(self):
|
def filter(self):
|
||||||
return "tcp" if self == IPProtocol.TCP else "udp"
|
return "tcp" if self == IPProtocol.TCP else "udp"
|
||||||
|
|
||||||
|
|
||||||
class IPFamily(IntEnum):
|
class IPFamily(IntEnum):
|
||||||
IPv4 = socket.AF_INET
|
IPv4 = socket.AF_INET
|
||||||
IPv6 = socket.AF_INET6
|
IPv6 = socket.AF_INET6
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filter(self):
|
def filter(self):
|
||||||
return "ip" if self == socket.AF_INET else "ipv6"
|
return "ip" if self == socket.AF_INET else "ipv6"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def version(self):
|
def version(self):
|
||||||
return 4 if self == socket.AF_INET else 6
|
return 4 if self == socket.AF_INET else 6
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loopback_addr(self):
|
def loopback_addr(self):
|
||||||
return "127.0.0.1" if self == socket.AF_INET else "::1"
|
return "127.0.0.1" if self == socket.AF_INET else "::1"
|
||||||
|
|
||||||
|
|
||||||
class ConnState(IntEnum):
|
class ConnState(IntEnum):
|
||||||
TCP_SYN_SENT = 11 # SYN sent
|
TCP_SYN_SENT = 11 # SYN sent
|
||||||
TCP_ESTABLISHED = 12 # SYN+ACK received
|
TCP_ESTABLISHED = 12 # SYN+ACK received
|
||||||
TCP_FIN_WAIT_1 = 91 # FIN sent
|
TCP_FIN_WAIT_1 = 91 # FIN sent
|
||||||
TCP_CLOSE_WAIT = 92 # FIN received
|
TCP_CLOSE_WAIT = 92 # FIN received
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def can_timeout(state):
|
def can_timeout(state):
|
||||||
return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT)
|
return state in (ConnState.TCP_SYN_SENT, ConnState.TCP_FIN_WAIT_1, ConnState.TCP_CLOSE_WAIT)
|
||||||
|
|
||||||
|
|
||||||
def repr_pkt(p):
|
def repr_pkt(p):
|
||||||
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
|
r = f"{p.direction.name} {p.src_addr}:{p.src_port}->{p.dst_addr}:{p.dst_port}"
|
||||||
if p.tcp:
|
if p.tcp:
|
||||||
t = p.tcp
|
t = p.tcp
|
||||||
r += f" {len(t.payload)}B ("
|
r += f" {len(t.payload)}B ("
|
||||||
r += "+".join(
|
r += "+".join(
|
||||||
f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f)
|
f.upper() for f in ("fin", "syn", "rst", "psh", "ack", "urg", "ece", "cwr", "ns") if getattr(t, f)
|
||||||
)
|
)
|
||||||
r += f") SEQ#{t.seq_num}"
|
r += f") SEQ#{t.seq_num}"
|
||||||
if t.ack:
|
if t.ack:
|
||||||
r += f" ACK#{t.ack_num}"
|
r += f" ACK#{t.ack_num}"
|
||||||
r += f" WZ={t.window_size}"
|
r += f" WZ={t.window_size}"
|
||||||
else:
|
else:
|
||||||
r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}"
|
r += f" {p.udp=} {p.icmpv4=} {p.icmpv6=}"
|
||||||
return f"<Pkt {r}>"
|
return f"<Pkt {r}>"
|
||||||
|
|
||||||
|
|
||||||
def synchronized_method(lock):
|
def synchronized_method(lock):
|
||||||
def decorator(method):
|
def decorator(method):
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def wrapped(self, *args, **kwargs):
|
def wrapped(self, *args, **kwargs):
|
||||||
with getattr(self, lock):
|
with getattr(self, lock):
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class ConnTrack:
|
class ConnTrack:
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
if not cls._instance:
|
if not cls._instance:
|
||||||
cls._instance = object.__new__(cls)
|
cls._instance = object.__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
raise RuntimeError("ConnTrack can not be instantiated multiple times")
|
||||||
|
|
||||||
def __init__(self, name, max_connections=0) -> None:
|
def __init__(self, name, max_connections=0) -> None:
|
||||||
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
|
self.struct_full_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H", "16s", "H", "L", "B")))
|
||||||
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
|
self.struct_src_tuple = struct.Struct(">" + "".join(("B", "B", "16s", "H")))
|
||||||
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
|
self.struct_state_tuple = struct.Struct(">" + "".join(("L", "B")))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.max_connections = max_connections
|
self.max_connections = max_connections
|
||||||
self.shm_list = shared_memory.ShareableList(
|
self.shm_list = shared_memory.ShareableList(
|
||||||
[bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name
|
[bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name
|
||||||
)
|
)
|
||||||
self.is_owner = True
|
self.is_owner = True
|
||||||
self.next_slot = 0
|
self.next_slot = 0
|
||||||
self.used_slots = set()
|
self.used_slots = set()
|
||||||
self.rlock = threading.RLock()
|
self.rlock = threading.RLock()
|
||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
self.is_owner = False
|
self.is_owner = False
|
||||||
self.shm_list = shared_memory.ShareableList(name=name)
|
self.shm_list = shared_memory.ShareableList(name=name)
|
||||||
self.max_connections = len(self.shm_list)
|
self.max_connections = len(self.shm_list)
|
||||||
|
|
||||||
debug2(
|
debug2(
|
||||||
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
|
f"ConnTrack: is_owner={self.is_owner} entry_size={self.struct_full_tuple.size} shm_name={self.shm_list.shm.name} "
|
||||||
f"shm_size={self.shm_list.shm.size}B"
|
f"shm_size={self.shm_list.shm.size}B"
|
||||||
)
|
)
|
||||||
|
|
||||||
@synchronized_method("rlock")
|
@synchronized_method("rlock")
|
||||||
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
|
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
|
||||||
if not self.is_owner:
|
if not self.is_owner:
|
||||||
raise RuntimeError("Only owner can mutate ConnTrack")
|
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||||
if len(self.used_slots) >= self.max_connections:
|
if len(self.used_slots) >= self.max_connections:
|
||||||
raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}")
|
raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}")
|
||||||
|
|
||||||
if self.get(proto, src_addr, src_port):
|
if self.get(proto, src_addr, src_port):
|
||||||
return
|
return
|
||||||
|
|
||||||
for _ in range(self.max_connections):
|
for _ in range(self.max_connections):
|
||||||
if self.next_slot not in self.used_slots:
|
if self.next_slot not in self.used_slots:
|
||||||
break
|
break
|
||||||
self.next_slot = (self.next_slot + 1) % self.max_connections
|
self.next_slot = (self.next_slot + 1) % self.max_connections
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("No slot available in ConnTrack") # should not be here
|
raise RuntimeError("No slot available in ConnTrack") # should not be here
|
||||||
|
|
||||||
src_addr = ipaddress.ip_address(src_addr)
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
dst_addr = ipaddress.ip_address(dst_addr)
|
dst_addr = ipaddress.ip_address(dst_addr)
|
||||||
assert src_addr.version == dst_addr.version
|
assert src_addr.version == dst_addr.version
|
||||||
ip_version = src_addr.version
|
ip_version = src_addr.version
|
||||||
state_epoch = int(time.time())
|
state_epoch = int(time.time())
|
||||||
entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state)
|
entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state)
|
||||||
packed = self.struct_full_tuple.pack(*entry)
|
packed = self.struct_full_tuple.pack(*entry)
|
||||||
self.shm_list[self.next_slot] = packed
|
self.shm_list[self.next_slot] = packed
|
||||||
self.used_slots.add(self.next_slot)
|
self.used_slots.add(self.next_slot)
|
||||||
proto = IPProtocol(proto)
|
proto = IPProtocol(proto)
|
||||||
debug3(
|
debug3(
|
||||||
f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to "
|
f"ConnTrack: added ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to "
|
||||||
f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}"
|
f"slot={self.next_slot} | #ActiveConn={len(self.used_slots)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@synchronized_method("rlock")
|
@synchronized_method("rlock")
|
||||||
def update(self, proto, src_addr, src_port, state):
|
def update(self, proto, src_addr, src_port, state):
|
||||||
if not self.is_owner:
|
if not self.is_owner:
|
||||||
raise RuntimeError("Only owner can mutate ConnTrack")
|
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||||
src_addr = ipaddress.ip_address(src_addr)
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
||||||
for i in self.used_slots:
|
for i in self.used_slots:
|
||||||
if self.shm_list[i].startswith(packed):
|
if self.shm_list[i].startswith(packed):
|
||||||
state_epoch = int(time.time())
|
state_epoch = int(time.time())
|
||||||
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
|
||||||
debug3(
|
debug3(
|
||||||
f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | "
|
f"ConnTrack: updated ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | "
|
||||||
f"#ActiveConn={len(self.used_slots)}"
|
f"#ActiveConn={len(self.used_slots)}"
|
||||||
)
|
)
|
||||||
return self._unpack(self.shm_list[i])
|
return self._unpack(self.shm_list[i])
|
||||||
else:
|
else:
|
||||||
debug3(
|
debug3(
|
||||||
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | "
|
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | "
|
||||||
f"#ActiveConn={len(self.used_slots)}"
|
f"#ActiveConn={len(self.used_slots)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@synchronized_method("rlock")
|
@synchronized_method("rlock")
|
||||||
def remove(self, proto, src_addr, src_port):
|
def remove(self, proto, src_addr, src_port):
|
||||||
if not self.is_owner:
|
if not self.is_owner:
|
||||||
raise RuntimeError("Only owner can mutate ConnTrack")
|
raise RuntimeError("Only owner can mutate ConnTrack")
|
||||||
src_addr = ipaddress.ip_address(src_addr)
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
||||||
for i in self.used_slots:
|
for i in self.used_slots:
|
||||||
if self.shm_list[i].startswith(packed):
|
if self.shm_list[i].startswith(packed):
|
||||||
conn = self._unpack(self.shm_list[i])
|
conn = self._unpack(self.shm_list[i])
|
||||||
self.shm_list[i] = b""
|
self.shm_list[i] = b""
|
||||||
self.used_slots.remove(i)
|
self.used_slots.remove(i)
|
||||||
debug3(
|
debug3(
|
||||||
f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | "
|
f"ConnTrack: removed ({proto.name} src={src_addr}:{src_port} state={conn.state.name}) from slot={i} | "
|
||||||
f"#ActiveConn={len(self.used_slots)}"
|
f"#ActiveConn={len(self.used_slots)}"
|
||||||
)
|
)
|
||||||
return conn
|
return conn
|
||||||
else:
|
else:
|
||||||
debug3(
|
debug3(
|
||||||
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |"
|
f"ConnTrack: ({proto.name} src={src_addr}:{src_port}) is not found to remove |"
|
||||||
f" #ActiveConn={len(self.used_slots)}"
|
f" #ActiveConn={len(self.used_slots)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, proto, src_addr, src_port):
|
def get(self, proto, src_addr, src_port):
|
||||||
src_addr = ipaddress.ip_address(src_addr)
|
src_addr = ipaddress.ip_address(src_addr)
|
||||||
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
|
||||||
for entry in self.shm_list:
|
for entry in self.shm_list:
|
||||||
if entry and entry.startswith(packed):
|
if entry and entry.startswith(packed):
|
||||||
return self._unpack(entry)
|
return self._unpack(entry)
|
||||||
|
|
||||||
@synchronized_method("rlock")
|
@synchronized_method("rlock")
|
||||||
def gc(self, connection_timeout_sec=15):
|
def gc(self, connection_timeout_sec=15):
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
n = 0
|
n = 0
|
||||||
for i in tuple(self.used_slots):
|
for i in tuple(self.used_slots):
|
||||||
state_packed = self.shm_list[i][-5:]
|
state_packed = self.shm_list[i][-5:]
|
||||||
(state_epoch, state) = self.struct_state_tuple.unpack(state_packed)
|
(state_epoch, state) = self.struct_state_tuple.unpack(state_packed)
|
||||||
if (now - state_epoch) < connection_timeout_sec:
|
if (now - state_epoch) < connection_timeout_sec:
|
||||||
continue
|
continue
|
||||||
if ConnState.can_timeout(state):
|
if ConnState.can_timeout(state):
|
||||||
conn = self._unpack(self.shm_list[i])
|
conn = self._unpack(self.shm_list[i])
|
||||||
self.shm_list[i] = b""
|
self.shm_list[i] = b""
|
||||||
self.used_slots.remove(i)
|
self.used_slots.remove(i)
|
||||||
n += 1
|
n += 1
|
||||||
debug3(
|
debug3(
|
||||||
f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})"
|
f"ConnTrack: GC: removed ({conn.protocol.name} src={conn.src_addr}:{conn.src_port} state={conn.state.name})"
|
||||||
f" from slot={i} | #ActiveConn={len(self.used_slots)}"
|
f" from slot={i} | #ActiveConn={len(self.used_slots)}"
|
||||||
)
|
)
|
||||||
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
|
debug3(f"ConnTrack: GC: collected {n} connections | #ActiveConn={len(self.used_slots)}")
|
||||||
|
|
||||||
def _unpack(self, packed):
|
def _unpack(self, packed):
|
||||||
(
|
(
|
||||||
proto,
|
proto,
|
||||||
ip_version,
|
ip_version,
|
||||||
src_addr_packed,
|
src_addr_packed,
|
||||||
src_port,
|
src_port,
|
||||||
dst_addr_packed,
|
dst_addr_packed,
|
||||||
dst_port,
|
dst_port,
|
||||||
state_epoch,
|
state_epoch,
|
||||||
state,
|
state,
|
||||||
) = self.struct_full_tuple.unpack(packed)
|
) = self.struct_full_tuple.unpack(packed)
|
||||||
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
dst_addr = str(ipaddress.ip_address(dst_addr_packed if ip_version == 6 else dst_addr_packed[:4]))
|
||||||
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
src_addr = str(ipaddress.ip_address(src_addr_packed if ip_version == 6 else src_addr_packed[:4]))
|
||||||
return ConnectionTuple(
|
return ConnectionTuple(
|
||||||
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
|
IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
def conn_iter():
|
def conn_iter():
|
||||||
for i in self.used_slots:
|
for i in self.used_slots:
|
||||||
yield self._unpack(self.shm_list[i])
|
yield self._unpack(self.shm_list[i])
|
||||||
|
|
||||||
return conn_iter()
|
return conn_iter()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'},cap={len(self.shm_list)},owner={self.is_owner})>"
|
return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'},cap={len(self.shm_list)},owner={self.is_owner})>"
|
||||||
|
|
||||||
|
|
||||||
class Method(BaseMethod):
|
class Method(BaseMethod):
|
||||||
|
|
||||||
network_config = {}
|
network_config = {}
|
||||||
proxy_port = None
|
proxy_port = None
|
||||||
proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None}
|
proxy_addr = {IPFamily.IPv4: None, IPFamily.IPv6: None}
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
|
||||||
def _get_local_proxy_listen_addr(self, port, family):
|
def _get_local_proxy_listen_addr(self, port, family):
|
||||||
proto = "TCPv6" if family.version == 6 else "TCP"
|
proto = "TCPv6" if family.version == 6 else "TCP"
|
||||||
for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines():
|
for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines():
|
||||||
try:
|
try:
|
||||||
_, local_addr, _, state, *_ = re.split(r"\s+", line.strip())
|
_, local_addr, _, state, *_ = re.split(r"\s+", line.strip())
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
port_suffix = ":" + str(port)
|
port_suffix = ":" + str(port)
|
||||||
if state == "LISTENING" and local_addr.endswith(port_suffix):
|
if state == "LISTENING" and local_addr.endswith(port_suffix):
|
||||||
return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]"))
|
return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]"))
|
||||||
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
|
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
|
||||||
|
|
||||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark):
|
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark):
|
||||||
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||||
|
|
||||||
if nslist or user or udp:
|
if nslist or user or udp:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
family = IPFamily(family)
|
family = IPFamily(family)
|
||||||
|
|
||||||
# using loopback proxy address never worked.
|
# using loopback proxy address never worked.
|
||||||
# >>> self.proxy_addr[family] = family.loopback_addr
|
# >>> self.proxy_addr[family] = family.loopback_addr
|
||||||
# See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
|
# See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 ,https://github.com/basil00/Divert/issues/82)
|
||||||
# As a workaround we use another interface ip instead.
|
# As a workaround we use another interface ip instead.
|
||||||
|
|
||||||
local_addr = self._get_local_proxy_listen_addr(port, family)
|
local_addr = self._get_local_proxy_listen_addr(port, family)
|
||||||
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
|
for addr in (ipaddress.ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)):
|
||||||
if addr.is_loopback or addr.version != family.version:
|
if addr.is_loopback or addr.version != family.version:
|
||||||
continue
|
continue
|
||||||
if local_addr.is_unspecified or local_addr == addr:
|
if local_addr.is_unspecified or local_addr == addr:
|
||||||
debug2("Found non loopback address to connect to proxy: " + str(addr))
|
debug2("Found non loopback address to connect to proxy: " + str(addr))
|
||||||
self.proxy_addr[family] = str(addr)
|
self.proxy_addr[family] = str(addr)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise Fatal("Windivert method requires proxy to listen on non loopback address")
|
raise Fatal("Windivert method requires proxy to listen on non loopback address")
|
||||||
|
|
||||||
self.proxy_port = port
|
self.proxy_port = port
|
||||||
|
|
||||||
subnet_addresses = []
|
subnet_addresses = []
|
||||||
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
for (_, mask, exclude, network_addr, fport, lport) in subnets:
|
||||||
if exclude:
|
if exclude:
|
||||||
continue
|
continue
|
||||||
assert fport == 0, "custom port range not supported"
|
assert fport == 0, "custom port range not supported"
|
||||||
assert lport == 0, "custom port range not supported"
|
assert lport == 0, "custom port range not supported"
|
||||||
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
subnet_addresses.append("%s/%s" % (network_addr, mask))
|
||||||
|
|
||||||
self.network_config[family] = {
|
self.network_config[family] = {
|
||||||
"subnets": subnet_addresses,
|
"subnets": subnet_addresses,
|
||||||
"nslist": nslist,
|
"nslist": nslist,
|
||||||
}
|
}
|
||||||
|
|
||||||
def wait_for_firewall_ready(self):
|
def wait_for_firewall_ready(self):
|
||||||
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
debug2(f"network_config={self.network_config} proxy_addr={self.proxy_addr}")
|
||||||
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
|
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getppid()}", WINDIVERT_MAX_CONNECTIONS)
|
||||||
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
methods = (self._egress_divert, self._ingress_divert, self._connection_gc)
|
||||||
ready_events = []
|
ready_events = []
|
||||||
for fn in methods:
|
for fn in methods:
|
||||||
ev = threading.Event()
|
ev = threading.Event()
|
||||||
ready_events.append(ev)
|
ready_events.append(ev)
|
||||||
|
|
||||||
def _target():
|
def _target():
|
||||||
try:
|
try:
|
||||||
fn(ev.set)
|
fn(ev.set)
|
||||||
except Exception:
|
except Exception:
|
||||||
debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc())
|
debug2(f"thread {fn.__name__} exiting due to: " + traceback.format_exc())
|
||||||
sys.stdin.close() # this will exist main thread
|
sys.stdin.close() # this will exist main thread
|
||||||
sys.stdout.close()
|
sys.stdout.close()
|
||||||
|
|
||||||
threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
|
threading.Thread(name=fn.__name__, target=_target, daemon=True).start()
|
||||||
for ev in ready_events:
|
for ev in ready_events:
|
||||||
if not ev.wait(5): # at most 5 sec
|
if not ev.wait(5): # at most 5 sec
|
||||||
raise Fatal("timeout in wait_for_firewall_ready()")
|
raise Fatal("timeout in wait_for_firewall_ready()")
|
||||||
|
|
||||||
def restore_firewall(self, port, family, udp, user):
|
def restore_firewall(self, port, family, udp, user, group):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_supported_features(self):
|
def get_supported_features(self):
|
||||||
result = super(Method, self).get_supported_features()
|
result = super(Method, self).get_supported_features()
|
||||||
result.loopback_port = False
|
result.loopback_port = False
|
||||||
result.user = False
|
result.user = False
|
||||||
result.dns = False
|
result.dns = False
|
||||||
result.ipv6 = False
|
result.ipv6 = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_tcp_dstip(self, sock):
|
def get_tcp_dstip(self, sock):
|
||||||
if not hasattr(self, "conntrack"):
|
if not hasattr(self, "conntrack"):
|
||||||
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
|
self.conntrack = ConnTrack(f"sshuttle-windivert-{os.getpid()}")
|
||||||
|
|
||||||
src_addr, src_port = sock.getpeername()
|
src_addr, src_port = sock.getpeername()
|
||||||
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
|
c = self.conntrack.get(IPProtocol.TCP, src_addr, src_port)
|
||||||
if not c:
|
if not c:
|
||||||
return (src_addr, src_port)
|
return (src_addr, src_port)
|
||||||
return (c.dst_addr, c.dst_port)
|
return (c.dst_addr, c.dst_port)
|
||||||
|
|
||||||
def is_supported(self):
|
def is_supported(self):
|
||||||
if sys.platform == "win32":
|
if sys.platform == "win32":
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _egress_divert(self, ready_cb):
|
def _egress_divert(self, ready_cb):
|
||||||
proto = IPProtocol.TCP
|
proto = IPProtocol.TCP
|
||||||
filter = f"outbound and {proto.filter}"
|
filter = f"outbound and {proto.filter}"
|
||||||
|
|
||||||
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
# with pydivert.WinDivert(f"outbound and tcp and ip.DstAddr == {subnet}") as w:
|
||||||
family_filters = []
|
family_filters = []
|
||||||
for af, c in self.network_config.items():
|
for af, c in self.network_config.items():
|
||||||
subnet_filters = []
|
subnet_filters = []
|
||||||
for cidr in c["subnets"]:
|
for cidr in c["subnets"]:
|
||||||
ip_network = ipaddress.ip_network(cidr)
|
ip_network = ipaddress.ip_network(cidr)
|
||||||
first_ip = ip_network.network_address
|
first_ip = ip_network.network_address
|
||||||
last_ip = ip_network.broadcast_address
|
last_ip = ip_network.broadcast_address
|
||||||
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
|
subnet_filters.append(f"(ip.DstAddr>={first_ip} and ip.DstAddr<={last_ip})")
|
||||||
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
|
family_filters.append(f"{af.filter} and ({' or '.join(subnet_filters)}) ")
|
||||||
|
|
||||||
filter = f"{filter} and ({' or '.join(family_filters)})"
|
filter = f"{filter} and ({' or '.join(family_filters)})"
|
||||||
|
|
||||||
debug1(f"[OUTBOUND] {filter=}")
|
debug1(f"[OUTBOUND] {filter=}")
|
||||||
with pydivert.WinDivert(filter) as w:
|
with pydivert.WinDivert(filter) as w:
|
||||||
ready_cb()
|
ready_cb()
|
||||||
proxy_port = self.proxy_port
|
proxy_port = self.proxy_port
|
||||||
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
|
proxy_addr_ipv4 = self.proxy_addr[IPFamily.IPv4]
|
||||||
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
proxy_addr_ipv6 = self.proxy_addr[IPFamily.IPv6]
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
debug3(">>> " + repr_pkt(pkt))
|
debug3(">>> " + repr_pkt(pkt))
|
||||||
if pkt.tcp.syn and not pkt.tcp.ack:
|
if pkt.tcp.syn and not pkt.tcp.ack:
|
||||||
# SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
|
# SYN sent (start of 3-way handshake connection establishment from our side, we wait for SYN+ACK)
|
||||||
self.conntrack.add(
|
self.conntrack.add(
|
||||||
socket.IPPROTO_TCP,
|
socket.IPPROTO_TCP,
|
||||||
pkt.src_addr,
|
pkt.src_addr,
|
||||||
pkt.src_port,
|
pkt.src_port,
|
||||||
pkt.dst_addr,
|
pkt.dst_addr,
|
||||||
pkt.dst_port,
|
pkt.dst_port,
|
||||||
ConnState.TCP_SYN_SENT,
|
ConnState.TCP_SYN_SENT,
|
||||||
)
|
)
|
||||||
if pkt.tcp.fin:
|
if pkt.tcp.fin:
|
||||||
# FIN sent (start of graceful close our side, and we wait for ACK)
|
# FIN sent (start of graceful close our side, and we wait for ACK)
|
||||||
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
|
self.conntrack.update(IPProtocol.TCP, pkt.src_addr, pkt.src_port, ConnState.TCP_FIN_WAIT_1)
|
||||||
if pkt.tcp.rst:
|
if pkt.tcp.rst:
|
||||||
# RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
|
# RST sent (initiate abrupt connection teardown from our side, so we don't expect any reply)
|
||||||
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
self.conntrack.remove(IPProtocol.TCP, pkt.src_addr, pkt.src_port)
|
||||||
|
|
||||||
# DNAT
|
# DNAT
|
||||||
if pkt.ipv4 and proxy_addr_ipv4:
|
if pkt.ipv4 and proxy_addr_ipv4:
|
||||||
pkt.dst_addr = proxy_addr_ipv4
|
pkt.dst_addr = proxy_addr_ipv4
|
||||||
if pkt.ipv6 and proxy_addr_ipv6:
|
if pkt.ipv6 and proxy_addr_ipv6:
|
||||||
pkt.dst_addr = proxy_addr_ipv6
|
pkt.dst_addr = proxy_addr_ipv6
|
||||||
pkt.tcp.dst_port = proxy_port
|
pkt.tcp.dst_port = proxy_port
|
||||||
|
|
||||||
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well
|
# XXX: If we set loopback proxy address (DNAT), then we should do SNAT as well
|
||||||
# by setting src_addr to loopback address.
|
# by setting src_addr to loopback address.
|
||||||
# Otherwise injecting packet will be ignored by Windows network stack
|
# Otherwise injecting packet will be ignored by Windows network stack
|
||||||
# as they packet has to cross public to private address space.
|
# as they packet has to cross public to private address space.
|
||||||
# See: https://github.com/basil00/Divert/issues/82
|
# See: https://github.com/basil00/Divert/issues/82
|
||||||
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
|
# Managing SNAT is more trickier, as we have to restore the original source IP address for reply packets.
|
||||||
# >>> pkt.dst_addr = proxy_addr_ipv4
|
# >>> pkt.dst_addr = proxy_addr_ipv4
|
||||||
|
|
||||||
w.send(pkt, recalculate_checksum=True)
|
w.send(pkt, recalculate_checksum=True)
|
||||||
|
|
||||||
def _ingress_divert(self, ready_cb):
|
def _ingress_divert(self, ready_cb):
|
||||||
proto = IPProtocol.TCP
|
proto = IPProtocol.TCP
|
||||||
direction = "inbound" # only when proxy address is not loopback address (Useful for testing)
|
direction = "inbound" # only when proxy address is not loopback address (Useful for testing)
|
||||||
ip_filters = []
|
ip_filters = []
|
||||||
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
|
for addr in (ipaddress.ip_address(a) for a in self.proxy_addr.values() if a):
|
||||||
if addr.is_loopback: # Windivert treats all loopback traffic as outbound
|
if addr.is_loopback: # Windivert treats all loopback traffic as outbound
|
||||||
direction = "outbound"
|
direction = "outbound"
|
||||||
if addr.version == 4:
|
if addr.version == 4:
|
||||||
ip_filters.append(f"ip.SrcAddr=={addr}")
|
ip_filters.append(f"ip.SrcAddr=={addr}")
|
||||||
else:
|
else:
|
||||||
# ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this
|
# ip_checks.append(f"ip.SrcAddr=={hex(int(addr))}") # only Windivert >=2 supports this
|
||||||
ip_filters.append(f"ipv6.SrcAddr=={addr}")
|
ip_filters.append(f"ipv6.SrcAddr=={addr}")
|
||||||
if not ip_filters:
|
if not ip_filters:
|
||||||
raise Fatal("At least ipv4 or ipv6 address is expected")
|
raise Fatal("At least ipv4 or ipv6 address is expected")
|
||||||
filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}"
|
filter = f"{direction} and {proto.filter} and ({' or '.join(ip_filters)}) and tcp.SrcPort=={self.proxy_port}"
|
||||||
debug1(f"[INGRESS] {filter=}")
|
debug1(f"[INGRESS] {filter=}")
|
||||||
with pydivert.WinDivert(filter) as w:
|
with pydivert.WinDivert(filter) as w:
|
||||||
ready_cb()
|
ready_cb()
|
||||||
for pkt in w:
|
for pkt in w:
|
||||||
debug3("<<< " + repr_pkt(pkt))
|
debug3("<<< " + repr_pkt(pkt))
|
||||||
if pkt.tcp.syn and pkt.tcp.ack:
|
if pkt.tcp.syn and pkt.tcp.ack:
|
||||||
# SYN+ACK received (connection established)
|
# SYN+ACK received (connection established)
|
||||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
|
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
|
||||||
elif pkt.tcp.rst:
|
elif pkt.tcp.rst:
|
||||||
# RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets
|
# RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets
|
||||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
# https://wiki.wireshark.org/TCP-4-times-close.md
|
# https://wiki.wireshark.org/TCP-4-times-close.md
|
||||||
elif pkt.tcp.fin and pkt.tcp.ack:
|
elif pkt.tcp.fin and pkt.tcp.ack:
|
||||||
# FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK)
|
# FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK)
|
||||||
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
elif pkt.tcp.fin:
|
elif pkt.tcp.fin:
|
||||||
# FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet)
|
# FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet)
|
||||||
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
|
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
|
||||||
else:
|
else:
|
||||||
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
|
||||||
if not conn:
|
if not conn:
|
||||||
debug2("Unexpected packet: " + repr_pkt(pkt))
|
debug2("Unexpected packet: " + repr_pkt(pkt))
|
||||||
continue
|
continue
|
||||||
pkt.src_addr = conn.dst_addr
|
pkt.src_addr = conn.dst_addr
|
||||||
pkt.tcp.src_port = conn.dst_port
|
pkt.tcp.src_port = conn.dst_port
|
||||||
w.send(pkt, recalculate_checksum=True)
|
w.send(pkt, recalculate_checksum=True)
|
||||||
|
|
||||||
def _connection_gc(self, ready_cb):
|
def _connection_gc(self, ready_cb):
|
||||||
ready_cb()
|
ready_cb()
|
||||||
while True:
|
while True:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
self.conntrack.gc()
|
self.conntrack.gc()
|
||||||
|
@ -260,7 +260,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
|||||||
|
|
||||||
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
|
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
|
||||||
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
|
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
|
||||||
return s2
|
return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0)
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows
|
# https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows
|
||||||
close_fds = False if sys.platform == 'win32' else True
|
close_fds = False if sys.platform == 'win32' else True
|
||||||
|
@ -10,7 +10,7 @@ import sshuttle.firewall
|
|||||||
|
|
||||||
|
|
||||||
def setup_daemon():
|
def setup_daemon():
|
||||||
stdin = io.StringIO(u"""ROUTES
|
stdin = io.BytesIO(u"""ROUTES
|
||||||
{inet},24,0,1.2.3.0,8000,9000
|
{inet},24,0,1.2.3.0,8000,9000
|
||||||
{inet},32,1,1.2.3.66,8080,8080
|
{inet},32,1,1.2.3.66,8080,8080
|
||||||
{inet6},64,0,2404:6800:4004:80c::,0,0
|
{inet6},64,0,2404:6800:4004:80c::,0,0
|
||||||
@ -127,9 +127,9 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert stdout.mock_calls == [
|
assert stdout.mock_calls == [
|
||||||
call.write('READY test\n'),
|
call.write(b'READY test\n'),
|
||||||
call.flush(),
|
call.flush(),
|
||||||
call.write('STARTED\n'),
|
call.write(b'STARTED\n'),
|
||||||
call.flush()
|
call.flush()
|
||||||
]
|
]
|
||||||
assert mock_setup_daemon.mock_calls == [call()]
|
assert mock_setup_daemon.mock_calls == [call()]
|
||||||
|
Loading…
Reference in New Issue
Block a user