diff --git a/sshuttle/cmdline.py b/sshuttle/cmdline.py index 5f1ba10..bda4c39 100644 --- a/sshuttle/cmdline.py +++ b/sshuttle/cmdline.py @@ -47,8 +47,15 @@ def main(): elif opt.hostwatch: return hostwatch.hw_main(opt.subnets, opt.auto_hosts) else: - includes = opt.subnets + opt.subnets_file - excludes = opt.exclude + # parse_subnetports() is used to create a list of includes + # and excludes. It returns a list of one or more items for + # each subnet. It returns a list since "-x example.com" + # might find that example.com has multiple IP addresses. Here, + # we flatten these lists. + includes = [item for sublist in opt.subnets+opt.subnets_file + for item in sublist] + excludes = [item for sublist in opt.exclude for item in sublist] + if not includes and not opt.auto_nets: parser.error('at least one subnet, subnet file, ' 'or -N expected') diff --git a/sshuttle/options.py b/sshuttle/options.py index 12ce55d..e104c79 100644 --- a/sshuttle/options.py +++ b/sshuttle/options.py @@ -28,7 +28,14 @@ def parse_subnetport_file(s): # 1.2.3.4/5:678, 1.2.3.4:567, 1.2.3.4/16 or just 1.2.3.4 # [1:2::3/64]:456, [1:2::3]:456, 1:2::3/64 or just 1:2::3 # example.com:123 or just example.com +# +# In addition, the port number can be specified as a range: +# 1.2.3.4:8000-8080. +# +# Can return multiple matches if the domain name used in the request +# has multiple IP addresses. def parse_subnetport(s): + if s.count(':') > 1: rx = r'(?:\[?([\w\:]+)(?:/(\d+))?]?)(?::(\d+)(?:-(\d+))?)?$' else: @@ -38,19 +45,57 @@ def parse_subnetport(s): if not m: raise Fatal('%r is not a valid address/mask:port format' % s) - addr, width, fport, lport = m.groups() + # Ports range from fport to lport. If only one port is specified, + # fport is defined and lport is None. + # + # cidr is the mask defined with the slash notation + host, cidr, fport, lport = m.groups() try: - addrinfo = socket.getaddrinfo(addr, 0, 0, socket.SOCK_STREAM) + addrinfo = socket.getaddrinfo(host, 0, 0, socket.SOCK_STREAM) except socket.gaierror: - raise Fatal('Unable to resolve address: %s' % addr) + raise Fatal('Unable to resolve address: %s' % host) - family, _, _, _, addr = min(addrinfo) - max_width = 32 if family == socket.AF_INET else 128 - width = int(width or max_width) - if not 0 <= width <= max_width: - raise Fatal('width %d is not between 0 and %d' % (width, max_width)) + # If the address is a domain with multiple IPs and a mask is also + # provided, proceed cautiously: + if cidr is not None: + addr_v6 = [a for a in addrinfo if a[0] == socket.AF_INET6] + addr_v4 = [a for a in addrinfo if a[0] == socket.AF_INET] - return (family, addr[0], width, int(fport or 0), int(lport or fport or 0)) + # Refuse to proceed if IPv4 and IPv6 addresses are present: + if len(addr_v6) > 0 and len(addr_v4) > 0: + raise Fatal("%s has IPv4 and IPv6 addresses, so the mask " + "of /%s is not supported. Specify the IP " + "addresses directly if you wish to specify " + "a mask." % (host, cidr)) + + # Warn if a domain has multiple IPs of the same type (IPv4 vs + # IPv6) and the mask is applied to all of the IPs. + if len(addr_v4) > 1 or len(addr_v6) > 1: + print("WARNING: %s has multiple IP addresses. The " + "mask of /%s is applied to all of the addresses." + % (host, cidr)) + + rv = [] + for a in addrinfo: + family, _, _, _, addr = a + + # Largest possible slash value we can use with this IP: + max_cidr = 32 if family == socket.AF_INET else 128 + + if cidr is None: # if no mask, use largest mask + cidr_to_use = max_cidr + else: # verify user-provided mask is appropriate + cidr_to_use = int(cidr) + if not 0 <= cidr_to_use <= max_cidr: + raise Fatal('Slash in CIDR notation (/%d) is ' + 'not between 0 and %d' + % (cidr_to_use, max_cidr)) + + rv.append((family, addr[0], cidr_to_use, + int(fport or 0), int(lport or fport or 0))) + + + return rv # 1.2.3.4:567 or just 1.2.3.4 or just 567 @@ -69,16 +114,20 @@ def parse_ipport(s): if not m: raise Fatal('%r is not a valid IP:port format' % s) - ip, port = m.groups() - ip = ip or '0.0.0.0' + host, port = m.groups() + host = host or '0.0.0.0' port = int(port or 0) try: - addrinfo = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM) + addrinfo = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM) except socket.gaierror: - raise Fatal('%r is not a valid IP:port format' % s) + raise Fatal('Unable to resolve address: %s' % host) + + if len(addrinfo) > 1: + print("WARNING: Host %s has more than one IP, only using one of them." % host) family, _, _, _, addr = min(addrinfo) + # Note: addr contains (ip, port) return (family,) + addr[:2]