code cleanup and small refactoring

This commit is contained in:
nom3ad 2024-01-07 21:32:50 +05:30 committed by Brian May
parent de8a19ce69
commit 72060abbef
6 changed files with 73 additions and 63 deletions

View File

@ -58,7 +58,7 @@ else
host=$node host=$node
fi fi
if [[ "${args[$((${#args[@]} - 1))]}" != *.* && "${args[$((${#args[@]} - 1))]}" != *:* ]]; then if [[ "${#args[@]}" -ne 0 && "${args[$((${#args[@]} - 1))]}" != *.* && "${args[$((${#args[@]} - 1))]}" != *:* ]]; then
echo "No subnet specified. Using -N" >&2 echo "No subnet specified. Using -N" >&2
args+=('-N') args+=('-N')
fi fi

View File

@ -211,8 +211,8 @@ class FirewallClient:
self.auto_nets = [] self.auto_nets = []
argv0 = sys.argv[0] argv0 = sys.argv[0]
# if argv0 is a not python script, it shall be an executable. # argv0 is either be a normal python file or an executable.
# In windows it will be a .exe file and other platforms it will be a shebang script # After installed as a package, sshuttle command points to an .exe in Windows and python shebang script elsewhere.
argvbase = (([sys.executable, sys.argv[0]] if argv0.endswith('.py') else [argv0]) + argvbase = (([sys.executable, sys.argv[0]] if argv0.endswith('.py') else [argv0]) +
['-v'] * (helpers.verbose or 0) + ['-v'] * (helpers.verbose or 0) +
['--method', method_name] + ['--method', method_name] +
@ -234,7 +234,7 @@ class FirewallClient:
# Because underlying ShellExecute() Windows api does not allow child process to inherit stdio. # Because underlying ShellExecute() Windows api does not allow child process to inherit stdio.
# TODO(nom3ad): Try to implement another way to achieve this. # TODO(nom3ad): Try to implement another way to achieve this.
raise Fatal("Privilege elevation for Windows is not yet implemented. Please run from an administrator shell") raise Fatal("Privilege elevation for Windows is not yet implemented. Please run from an administrator shell")
else:
# Linux typically uses sudo; OpenBSD uses doas. However, some # Linux typically uses sudo; OpenBSD uses doas. However, some
# Linux distributions are starting to use doas. # Linux distributions are starting to use doas.
sudo_cmd = ['sudo', '-p', '[local sudo] Password: '] sudo_cmd = ['sudo', '-p', '[local sudo] Password: ']
@ -874,7 +874,7 @@ def main(listenip_v6, listenip_v4,
# listenip_v4 contains user specified value or it is set to "auto". # listenip_v4 contains user specified value or it is set to "auto".
if listenip_v4 == "auto": if listenip_v4 == "auto":
listenip_v4 = ('127.0.0.1' if avail.loopback_port else '0.0.0.0', 0) listenip_v4 = ('127.0.0.1' if avail.loopback_proxy_port else '0.0.0.0', 0)
debug1("Using default IPv4 listen address " + listenip_v4[0]) debug1("Using default IPv4 listen address " + listenip_v4[0])
# listenip_v6 is... # listenip_v6 is...
@ -885,7 +885,7 @@ def main(listenip_v6, listenip_v4,
debug1("IPv6 disabled by --disable-ipv6") debug1("IPv6 disabled by --disable-ipv6")
if listenip_v6 == "auto": if listenip_v6 == "auto":
if avail.ipv6: if avail.ipv6:
listenip_v6 = ('::1' if avail.loopback_port else '::', 0) listenip_v6 = ('::1' if avail.loopback_proxy_port else '::', 0)
debug1("IPv6 enabled: Using default IPv6 listen address " + listenip_v6[0]) debug1("IPv6 enabled: Using default IPv6 listen address " + listenip_v6[0])
else: else:
debug1("IPv6 disabled since it isn't supported by method " debug1("IPv6 disabled since it isn't supported by method "

View File

@ -242,7 +242,7 @@ def is_admin_user():
except Exception: except Exception:
return False return False
# TODO(nom3ad): for sys.platform == 'linux', support capabilities check for non-root users. (CAP_NET_ADMIN might be enough?) # TODO(nom3ad): for sys.platform == 'linux', check capabilities for non-root users. (CAP_NET_ADMIN might be enough?)
return os.getuid() == 0 return os.getuid() == 0

View File

@ -46,7 +46,7 @@ class BaseMethod(object):
@staticmethod @staticmethod
def get_supported_features(): def get_supported_features():
result = Features() result = Features()
result.loopback_port = True result.loopback_proxy_port = True
result.ipv4 = True result.ipv4 = True
result.ipv6 = False result.ipv6 = False
result.udp = False result.udp = False

View File

@ -15,7 +15,7 @@ import traceback
from sshuttle.methods import BaseMethod from sshuttle.methods import BaseMethod
from sshuttle.helpers import debug3, log, debug1, debug2, get_verbose_level, Fatal from sshuttle.helpers import debug3, debug1, debug2, get_verbose_level, Fatal
try: try:
# https://reqrypt.org/windivert-doc.html#divert_iphdr # https://reqrypt.org/windivert-doc.html#divert_iphdr
@ -47,6 +47,10 @@ class IPFamily(IntEnum):
IPv4 = socket.AF_INET IPv4 = socket.AF_INET
IPv6 = socket.AF_INET6 IPv6 = socket.AF_INET6
@staticmethod
def from_ip_version(version):
return IPFamily.IPv6 if version == 4 else IPFamily.IPv4
@property @property
def filter(self): def filter(self):
return "ip" if self == socket.AF_INET else "ipv6" return "ip" if self == socket.AF_INET else "ipv6"
@ -280,7 +284,7 @@ class Method(BaseMethod):
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
def _get_bind_addresses_for_port(self, port, family): def _get_bind_address_for_port(self, port, family):
proto = "TCPv6" if family.version == 6 else "TCP" proto = "TCPv6" if family.version == 6 else "TCP"
for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines(): for line in subprocess.check_output(["netstat", "-a", "-n", "-p", proto]).decode().splitlines():
try: try:
@ -293,7 +297,7 @@ class Method(BaseMethod):
raise Fatal("Could not find listening address for {}/{}".format(port, proto)) raise Fatal("Could not find listening address for {}/{}".format(port, proto))
def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark): def setup_firewall(self, proxy_port, dnsport, nslist, family, subnets, udp, user, group, tmark):
log(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") debug2(f"{proxy_port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
if nslist or user or udp: if nslist or user or udp:
raise NotImplementedError() raise NotImplementedError()
@ -304,18 +308,21 @@ class Method(BaseMethod):
# using loopback only proxy binding won't work with windivert. # using loopback only proxy binding won't work with windivert.
# See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 https://github.com/basil00/Divert/issues/82) # See: https://github.com/basil00/Divert/issues/17#issuecomment-341100167 https://github.com/basil00/Divert/issues/82)
# As a workaround, finding another interface ip instead. (client should not bind proxy to loopback address) # As a workaround, finding another interface ip instead. (client should not bind proxy to loopback address)
local_addr = self._get_bind_addresses_for_port(proxy_port, family) proxy_bind_addr = self._get_bind_address_for_port(proxy_port, family)
for addr in (ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), None)): if proxy_bind_addr.is_loopback:
if addr.version != family.version or addr.is_loopback or addr.is_link_local: raise Fatal("Windivert method requires proxy to be reachable by a non loopback address.")
continue if not proxy_bind_addr.is_unspecified:
if local_addr.is_unspecified or local_addr == addr: proxy_ip = proxy_bind_addr.exploded
else:
local_addresses = [ip_address(info[4][0]) for info in socket.getaddrinfo(socket.gethostname(), 0, family=family)]
for addr in local_addresses:
if not addr.is_loopback and not addr.is_link_local:
proxy_ip = addr.exploded proxy_ip = addr.exploded
debug2("Found non loopback address to connect to proxy: " + proxy_ip)
break break
else: else:
raise Fatal("Windivert method requires proxy to be reachable by a non loopback address." raise Fatal("Windivert method requires proxy to be reachable by a non loopback address."
f"No addersss found for {family.name}") f"No address found for {family.name} in {local_addresses}")
debug2("Found non loopback address to connect to proxy: " + proxy_ip)
subnet_addresses = [] subnet_addresses = []
for (_, mask, exclude, network_addr, fport, lport) in subnets: for (_, mask, exclude, network_addr, fport, lport) in subnets:
if exclude: if exclude:
@ -357,9 +364,11 @@ class Method(BaseMethod):
def get_supported_features(self): def get_supported_features(self):
result = super(Method, self).get_supported_features() result = super(Method, self).get_supported_features()
result.loopback_port = False result.loopback_proxy_port = False
result.user = False result.user = False
result.dns = False result.dns = False
# ipv6 only able to support with Windivert 2.x due to bugs in filter parsing
# TODO(nom3ad): Enable ipv6 once https://github.com/ffalcinelli/pydivert/pull/57 merged
result.ipv6 = False result.ipv6 = False
return result return result
@ -463,19 +472,20 @@ class Method(BaseMethod):
for pkt in w: for pkt in w:
verbose >= 3 and debug3("[INGRESS] " + repr_pkt(pkt)) verbose >= 3 and debug3("[INGRESS] " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: if pkt.tcp.syn and pkt.tcp.ack:
# SYN+ACK received (connection established) # SYN+ACK received (connection established from proxy
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_ESTABLISHED)
elif pkt.tcp.rst: elif pkt.tcp.rst:
# RST received - Abrupt connection teardown initiated by otherside. We don't expect anymore packets # RST received - Abrupt connection teardown initiated by proxy. Don't expect anymore packets
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
# https://wiki.wireshark.org/TCP-4-times-close.md # https://wiki.wireshark.org/TCP-4-times-close.md
elif pkt.tcp.fin and pkt.tcp.ack: elif pkt.tcp.fin and pkt.tcp.ack:
# FIN+ACK received (Passive close by otherside. We don't expect any more packets. Otherside expects an ACK) # FIN+ACK received (Passive close by proxy. Don't expect any more packets. proxy expects an ACK)
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
elif pkt.tcp.fin: elif pkt.tcp.fin:
# FIN received (Otherside initiated graceful close. We expects a final ACK for a FIN packet) # FIN received (proxy initiated graceful close. Expect a final ACK for a FIN packet)
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_CLOSE_WAIT)
else: else:
# data fragments and ACKs
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
if not conn: if not conn:
verbose >= 2 and debug2("Unexpected packet: " + repr_pkt(pkt)) verbose >= 2 and debug2("Unexpected packet: " + repr_pkt(pkt))

View File

@ -262,7 +262,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start() threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0) return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0)
# https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows # See: stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows
close_fds = False if sys.platform == 'win32' else True close_fds = False if sys.platform == 'win32' else True
debug2("executing: %r" % argv) debug2("executing: %r" % argv)