mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-21 15:33:23 +01:00
783d33cada
A few people have reported that they have one or more invalid DNS servers in /etc/resolv.conf, which they don't notice because the normal resolver library just skips the broken ones. sshuttle would abort because it got an unexpected socket error, which isn't so good.
248 lines
7.3 KiB
Python
248 lines
7.3 KiB
Python
import re, struct, socket, select, traceback, time
|
|
if not globals().get('skip_imports'):
|
|
import ssnet, helpers, hostwatch
|
|
import compat.ssubprocess as ssubprocess
|
|
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
|
|
from helpers import *
|
|
|
|
|
|
def _ipmatch(ipstr):
|
|
if ipstr == 'default':
|
|
ipstr = '0.0.0.0/0'
|
|
m = re.match(r'^(\d+(\.\d+(\.\d+(\.\d+)?)?)?)(?:/(\d+))?$', ipstr)
|
|
if m:
|
|
g = m.groups()
|
|
ips = g[0]
|
|
width = int(g[4] or 32)
|
|
if g[1] == None:
|
|
ips += '.0.0.0'
|
|
width = min(width, 8)
|
|
elif g[2] == None:
|
|
ips += '.0.0'
|
|
width = min(width, 16)
|
|
elif g[3] == None:
|
|
ips += '.0'
|
|
width = min(width, 24)
|
|
return (struct.unpack('!I', socket.inet_aton(ips))[0], width)
|
|
|
|
|
|
def _ipstr(ip, width):
|
|
if width >= 32:
|
|
return ip
|
|
else:
|
|
return "%s/%d" % (ip, width)
|
|
|
|
|
|
def _maskbits(netmask):
|
|
if not netmask:
|
|
return 32
|
|
for i in range(32):
|
|
if netmask[0] & _shl(1, i):
|
|
return 32-i
|
|
return 0
|
|
|
|
|
|
def _shl(n, bits):
|
|
return n * int(2**bits)
|
|
|
|
|
|
def _list_routes():
    """Return the IPv4 routes printed by 'netstat -rn'.

    Each output line is split on whitespace.  The first column is parsed
    as the destination and, where it parses as an address, the third
    column is used as the netmask.  Lines whose first column is not an
    address are skipped.

    Returns a list of (ip, width) tuples: the masked network address as a
    dotted-quad string and its prefix length.  A non-zero exit status from
    netstat is logged as a warning; the routes parsed so far are still
    returned.
    """
    argv = ['netstat', '-rn']
    p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
    routes = []
    for line in p.stdout:
        cols = re.split(r'\s+', line)
        ipw = _ipmatch(cols[0])
        if not ipw:
            continue  # some lines won't be parseable; never mind
        if len(cols) > 2:
            maskw = _ipmatch(cols[2])  # linux only
        else:
            # A line with fewer than three columns has no mask to read;
            # treat it like an unparseable mask instead of raising.
            maskw = None
        mask = _maskbits(maskw)  # returns 32 if maskw is null
        width = min(ipw[1], mask)
        # Keep only the top `width` bits of the address.
        ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width)
        routes.append((socket.inet_ntoa(struct.pack('!I', ip)), width))
    rv = p.wait()
    if rv != 0:
        log('WARNING: %r returned %d\n' % (argv, rv))
        log('WARNING: That prevents --auto-nets from working.\n')
    return routes
|
|
|
|
|
|
def list_routes():
    """Yield the (ip, width) routes from _list_routes(), leaving out any
    whose address starts with '0.' or '127.'.
    """
    for (addr, bits) in _list_routes():
        if addr.startswith('0.') or addr.startswith('127.'):
            continue
        yield (addr, bits)
|
|
|
|
|
|
def _exc_dump():
|
|
exc_info = sys.exc_info()
|
|
return ''.join(traceback.format_exception(*exc_info))
|
|
|
|
|
|
def start_hostwatch(seed_hosts):
    """Fork a child process that runs hostwatch.hw_main(seed_hosts).

    A socket pair connects the two processes: the child duplicates its
    end onto file descriptors 0 and 1, and the parent keeps the other end.
    The child never returns from this function.  It exits with hw_main's
    result (or 0), with 98 if hw_main raised an Exception (the traceback
    is logged), or with 99 if it was interrupted some other way.

    Returns (pid, sock) in the parent: the child's pid and the parent's
    end of the socket pair.
    """
    child_end, parent_end = socket.socketpair()
    pid = os.fork()
    if pid:
        # parent: drop the child's end and hand back ours
        child_end.close()
        return pid, parent_end

    # child
    rv = 99
    try:
        try:
            parent_end.close()
            os.dup2(child_end.fileno(), 1)
            os.dup2(child_end.fileno(), 0)
            child_end.close()
            rv = hostwatch.hw_main(seed_hosts) or 0
        except Exception:
            log('%s\n' % _exc_dump())
            rv = 98
    finally:
        # Always leave via _exit so the child cannot fall back into the
        # parent's code path.
        os._exit(rv)
|
|
|
|
|
|
class Hostwatch:
    """Record of the hostwatch child process: its pid and the socket used
    to talk to it.  Starts out with pid 0 and no socket.
    """

    def __init__(self):
        self.sock = None
        self.pid = 0
|
|
|
|
|
|
class DnsProxy(Handler):
    """Handler that relays one DNS request to a nameserver over UDP.

    The request is sent to a nameserver chosen by
    resolvconf_random_nameserver().  If sending or receiving fails with
    ECONNREFUSED or EHOSTUNREACH, another send is attempted, up to three
    sends in total.  The first response received is passed to the mux as
    CMD_DNS_RESPONSE on the originating channel, after which the handler
    marks itself finished by setting ok to False.
    """

    def __init__(self, mux, chan, request):
        """Create the UDP socket and immediately make the first send.

        mux     -- mux object the response is sent back through
        chan    -- channel number the response is sent on
        request -- raw DNS request data to forward
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        Handler.__init__(self, [sock])
        # Absolute deadline, 30 seconds from creation.
        self.timeout = time.time()+30
        self.mux = mux
        self.chan = chan
        self.tries = 0
        self.peer = None
        self.request = request
        self.sock = sock
        self.sock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
        self.try_send()

    def try_send(self):
        """Send the request to a randomly chosen nameserver.

        Does nothing once three sends have been attempted.  Socket errors
        are never raised to the caller: ECONNREFUSED and EHOSTUNREACH
        trigger another attempt, anything else is logged.
        """
        if self.tries >= 3:
            return
        self.tries += 1
        self.peer = resolvconf_random_nameserver()
        debug2('DNS: sending to %r\n' % self.peer)
        try:
            # connect() is inside the try as well: for an unusable
            # nameserver it can fail just like send(), and an unhandled
            # socket error here would propagate to whoever created us.
            self.sock.connect((self.peer, 53))
            self.sock.send(self.request)
        except socket.error:
            # sys.exc_info() instead of a bound name keeps this syntax
            # valid on both old and new Python versions.
            e = sys.exc_info()[1]
            if e.args[0] in [errno.ECONNREFUSED, errno.EHOSTUNREACH]:
                # might have been spurious; try again.
                # Note: these errors sometimes are reported by recv(),
                # and sometimes by send().  We have to catch both.
                debug2('DNS send to %r: %s\n' % (self.peer, e))
                self.try_send()
                return
            else:
                log('DNS send to %r: %s\n' % (self.peer, e))
                return

    def callback(self):
        """Read the nameserver's reply and forward it through the mux.

        On ECONNREFUSED or EHOSTUNREACH the request is re-sent via
        try_send(); any other socket error is logged and the handler is
        left as it is.
        """
        try:
            data = self.sock.recv(4096)
        except socket.error:
            e = sys.exc_info()[1]
            if e.args[0] in [errno.ECONNREFUSED, errno.EHOSTUNREACH]:
                # might have been spurious; try again.
                # Note: these errors sometimes are reported by recv(),
                # and sometimes by send().  We have to catch both.
                debug2('DNS recv from %r: %s\n' % (self.peer, e))
                self.try_send()
                return
            else:
                log('DNS recv from %r: %s\n' % (self.peer, e))
                return
        debug2('DNS response: %d bytes\n' % len(data))
        self.mux.send(self.chan, ssnet.CMD_DNS_RESPONSE, data)
        # One response is all we need; mark this handler as done.
        self.ok = False
|
|
|
|
|
|
def main():
    """Server entry point: announce the local routes, then service the mux.

    Writes the 'SSHUTTLE0001' synchronization header to stdout, wraps
    stdin and stdout in a Mux, sends the route list as CMD_ROUTES, and
    installs the mux callbacks for host-list requests, new channels and
    DNS requests.  It then loops for as long as mux.ok is true.

    Raises Fatal when the hostwatch socket returns no data, or when
    waitpid reports that the hostwatch child has exited.
    """
    prefix = 'server: '
    if helpers.verbose >= 1:
        prefix = ' s: '
    helpers.logprefix = prefix
    debug1('latency control setting = %r\n' % latency_control)

    routes = list(list_routes())
    debug1('available routes:\n')
    for route in routes:
        debug1(' %s/%d\n' % route)

    # synchronization header
    sys.stdout.write('SSHUTTLE0001')
    sys.stdout.flush()

    rsock = socket.fromfd(sys.stdin.fileno(),
                          socket.AF_INET, socket.SOCK_STREAM)
    wsock = socket.fromfd(sys.stdout.fileno(),
                          socket.AF_INET, socket.SOCK_STREAM)
    mux = Mux(rsock, wsock)
    handlers = [mux]
    mux.send(0, ssnet.CMD_ROUTES,
             ''.join(['%s,%d\n' % route for route in routes]))

    hw = Hostwatch()
    hw.leftover = ''

    def hostwatch_ready():
        # Forward complete host-list lines to the client and hold back
        # any unterminated tail until more data arrives.
        assert(hw.pid)
        content = hw.sock.recv(4096)
        if not content:
            raise Fatal('hostwatch process died')
        lines = (hw.leftover + content).split('\n')
        # The last element is either '' (data ended on a newline) or an
        # incomplete entry.  Either way it becomes the new leftover, and
        # what is sent ends with an empty element.
        hw.leftover = lines[-1]
        lines[-1] = ''
        mux.send(0, ssnet.CMD_HOST_LIST, '\n'.join(lines))

    def got_host_req(data):
        # Start hostwatch once only; later requests are ignored.
        if hw.pid:
            return
        (hw.pid, hw.sock) = start_hostwatch(data.strip().split())
        handlers.append(Handler(socks=[hw.sock],
                                callback=hostwatch_ready))
    mux.got_host_req = got_host_req

    def new_channel(channel, data):
        # data is 'dstip,dstport'
        (dstip, dstport) = data.split(',', 1)
        outwrap = ssnet.connect_dst(dstip, int(dstport))
        handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
    mux.new_channel = new_channel

    dnshandlers = {}

    def dns_req(channel, data):
        debug2('Incoming DNS request.\n')
        proxy = DnsProxy(mux, channel, data)
        handlers.append(proxy)
        dnshandlers[channel] = proxy
    mux.got_dns_req = dns_req

    while mux.ok:
        if hw.pid:
            assert(hw.pid > 0)
            (rpid, rv) = os.waitpid(hw.pid, os.WNOHANG)
            if rpid:
                raise Fatal('hostwatch exited unexpectedly: code 0x%04x\n' % rv)

        ssnet.runonce(handlers, mux)
        if latency_control:
            mux.check_fullness()
        mux.callback()

        if dnshandlers:
            now = time.time()
            # Collect the expired or finished DNS handlers first, then
            # drop them and flag them as no longer ok.
            stale = [(channel, h) for (channel, h) in dnshandlers.items()
                     if h.timeout < now or not h.ok]
            for (channel, h) in stale:
                del dnshandlers[channel]
                h.ok = False
|