Compare commits

..

4 Commits

Author SHA1 Message Date
9915d736fe Enhanced DNS support. Initial version. 2011-05-12 14:37:19 +10:00
783d33cada DNS: auto-retry if we get an error on send/recv to DNS server.
A few people have reported that they have one or more invalid DNS servers in
/etc/resolv.conf, which they don't notice because the normal resolver
library just skips the broken ones.  sshuttle would abort because it got an
unexpected socket error, which isn't so good.
2011-04-06 12:30:12 -04:00
94241b938b On FreeBSD, avoid a crash caused by buggy socket.connect() in python pre-2.5.
Bug reported by Ed Maste.  The fix in later versions of python is documented
here:
http://mail.python.org/pipermail/python-bugs-list/2006-August/034667.html

We're basically just doing the same thing when we see EINVAL.  Note that
this doesn't happen on Linux because connect() is more forgiving.
2011-03-21 03:15:11 -07:00
9031de1527 repr(socket.error) is useless in some versions of python.
So let's use %s instead of %r to print it, so that log messages can be more
useful.  This only affects one message at debug3 for now, so it's not too
exciting.
2011-03-21 03:15:11 -07:00
4 changed files with 170 additions and 18 deletions

112
client.py
View File

@ -175,8 +175,60 @@ class FirewallClient:
raise Fatal('cleanup: %r returned %d' % (self.argv, rv)) raise Fatal('cleanup: %r returned %d' % (self.argv, rv))
def unpack_dns_name(buf, off):
name = ''
while True:
# get the next octet from buffer
n = ord(buf[off])
# zero octet terminates name
if n == 0:
off += 1
break
# top two bits on
# => a 2 octect pointer to another part of the buffer
elif (n & 0xc0) == 0xc0:
ptr = struct.unpack('>H', buf[off:off+2])[0] & 0x3fff
off = ptr
# an octet representing the number of bytes to process.
else:
off += 1
name = name + buf[off:off+n] + '.'
off += n
return name.strip('.'), off
class dnspkt:
def unpack(self, buf, off):
l = len(buf)
(self.id, self.op, self.qdcount, self.ancount, self.nscount, self.arcount) = struct.unpack("!HHHHHH",buf[off:off+12])
off += 12
self.q = []
for i in range(self.qdcount):
qname, off = unpack_dns_name(buf, off)
qtype, qclass = struct.unpack('!HH', buf[off:off+4])
off += 4
self.q.append( (qname,qtype,qclass) )
return off
def match_q_domain(self, domain):
l = len(domain)
for qname,qtype,qclass in self.q:
if qname[-l:] == domain:
if l==len(qname):
return True
elif qname[-l-1] == '.':
return True
return False
def _main(listener, fw, ssh_cmd, remotename, python, latency_control, def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
dnslistener, seed_hosts, auto_nets, dnslistener, dnsforwarder, dns_domains, dns_to,
seed_hosts, auto_nets,
syslog, daemon): syslog, daemon):
handlers = [] handlers = []
if helpers.verbose >= 1: if helpers.verbose >= 1:
@ -283,6 +335,7 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
handlers.append(Handler([listener], onaccept)) handlers.append(Handler([listener], onaccept))
dnsreqs = {} dnsreqs = {}
dnsforwards = {}
def dns_done(chan, data): def dns_done(chan, data):
peer,timeout = dnsreqs.get(chan) or (None,None) peer,timeout = dnsreqs.get(chan) or (None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer)) debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
@ -295,16 +348,54 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
now = time.time() now = time.time()
if pkt: if pkt:
debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt))) debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt)))
chan = mux.next_channel() dns = dnspkt()
dnsreqs[chan] = peer,now+30 dns.unpack(pkt, 0)
mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
mux.channels[chan] = lambda cmd,data: dns_done(chan,data) match=False
if dns_domains is not None:
for domain in dns_domains:
if dns.match_q_domain(domain):
match=True
break
if match:
debug3("We need to redirect this request remotely\n")
chan = mux.next_channel()
dnsreqs[chan] = peer,now+30
mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
mux.channels[chan] = lambda cmd,data: dns_done(chan,data)
else:
debug3("We need to forward this request locally\n")
dnsforwarder.sendto(pkt, dns_to)
dnsforwards[dns.id] = peer,now+30
for chan,(peer,timeout) in dnsreqs.items(): for chan,(peer,timeout) in dnsreqs.items():
if timeout < now: if timeout < now:
del dnsreqs[chan] del dnsreqs[chan]
for chan,(peer,timeout) in dnsforwards.items():
if timeout < now:
del dnsforwards[chan]
debug3('Remaining DNS requests: %d\n' % len(dnsreqs)) debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
debug3('Remaining DNS forwards: %d\n' % len(dnsforwards))
if dnslistener: if dnslistener:
handlers.append(Handler([dnslistener], ondns)) handlers.append(Handler([dnslistener], ondns))
def ondnsforward():
debug1("We got a response.\n")
pkt,server = dnsforwarder.recvfrom(4096)
now = time.time()
if server[0] != dns_to[0] or server[1] != dns_to[1]:
debug1("Ooops. The response came from the wrong server. Ignoring\n")
else:
dns = dnspkt()
dns.unpack(pkt, 0)
chan=dns.id
peer,timeout = dnsforwards.get(chan) or (None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
if peer:
del dnsforwards[chan]
debug3('doing sendto %r\n' % (peer,))
dnslistener.sendto(pkt, peer)
if dnsforwarder:
handlers.append(Handler([dnsforwarder], ondnsforward))
if seed_hosts != None: if seed_hosts != None:
debug1('seed_hosts: %r\n' % seed_hosts) debug1('seed_hosts: %r\n' % seed_hosts)
@ -321,7 +412,8 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
mux.callback() mux.callback()
def main(listenip, ssh_cmd, remotename, python, latency_control, dns, def main(listenip, ssh_cmd, remotename, python, latency_control,
dns, dns_domains, dns_to,
seed_hosts, auto_nets, seed_hosts, auto_nets,
subnets_include, subnets_exclude, syslog, daemon, pidfile): subnets_include, subnets_exclude, syslog, daemon, pidfile):
if syslog: if syslog:
@ -366,15 +458,21 @@ def main(listenip, ssh_cmd, remotename, python, latency_control, dns,
dnsip = dnslistener.getsockname() dnsip = dnslistener.getsockname()
debug1('DNS listening on %r.\n' % (dnsip,)) debug1('DNS listening on %r.\n' % (dnsip,))
dnsport = dnsip[1] dnsport = dnsip[1]
dnsforwarder = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
dnsforwarder.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
dnsforwarder.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
else: else:
dnsport = 0 dnsport = 0
dnslistener = None dnslistener = None
dnsforwarder = None
fw = FirewallClient(listenip[1], subnets_include, subnets_exclude, dnsport) fw = FirewallClient(listenip[1], subnets_include, subnets_exclude, dnsport)
try: try:
return _main(listener, fw, ssh_cmd, remotename, return _main(listener, fw, ssh_cmd, remotename,
python, latency_control, dnslistener, python, latency_control,
dnslistener, dnsforwarder, dns_domains, dns_to,
seed_hosts, auto_nets, syslog, daemon) seed_hosts, auto_nets, syslog, daemon)
finally: finally:
try: try:

18
main.py
View File

@ -54,6 +54,8 @@ l,listen= transproxy to this ip address and port number [127.0.0.1:0]
H,auto-hosts scan for remote hostnames and update local /etc/hosts H,auto-hosts scan for remote hostnames and update local /etc/hosts
N,auto-nets automatically determine subnets to route N,auto-nets automatically determine subnets to route
dns capture local DNS requests and forward to the remote DNS server dns capture local DNS requests and forward to the remote DNS server
dns-domains= comma seperated list of DNS domains for DNS forwarding
dns-to= forward any DNS requests that don't match domains to this address
python= path to python interpreter on the remote server python= path to python interpreter on the remote server
r,remote= ssh hostname (and optional username) of remote sshuttle server r,remote= ssh hostname (and optional username) of remote sshuttle server
x,exclude= exclude this subnet (can be used more than once) x,exclude= exclude this subnet (can be used more than once)
@ -110,12 +112,26 @@ try:
sh = [] sh = []
else: else:
sh = None sh = None
if opt.dns and opt.dns_domains:
dns_domains = opt.dns_domains.split(",")
if opt.dns_to:
addr,colon,port = opt.dns_to.rpartition(":")
if colon == ":":
dns_to = ( addr, int(port) )
else:
dns_to = ( port, 53 )
else:
o.fatal('--dns-to=ip is required with --dns-domains=list')
else:
dns_domains = None
dns_to = None
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
opt.ssh_cmd, opt.ssh_cmd,
remotename, remotename,
opt.python, opt.python,
opt.latency_control, opt.latency_control,
opt.dns, opt.dns, dns_domains, dns_to,
sh, sh,
opt.auto_nets, opt.auto_nets,
parse_subnets(includes), parse_subnets(includes),

View File

@ -110,23 +110,51 @@ class DnsProxy(Handler):
def __init__(self, mux, chan, request): def __init__(self, mux, chan, request):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
Handler.__init__(self, [sock]) Handler.__init__(self, [sock])
self.sock = sock
self.timeout = time.time()+30 self.timeout = time.time()+30
self.mux = mux self.mux = mux
self.chan = chan self.chan = chan
self.tries = 0
self.peer = None
self.request = request
self.sock = sock
self.sock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) self.sock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
self.sock.connect((resolvconf_random_nameserver(), 53)) self.try_send()
self.sock.send(request)
def try_send(self):
if self.tries >= 3:
return
self.tries += 1
self.peer = resolvconf_random_nameserver()
self.sock.connect((self.peer, 53))
debug2('DNS: sending to %r\n' % self.peer)
try:
self.sock.send(self.request)
except socket.error, e:
if e.args[0] in [errno.ECONNREFUSED, errno.EHOSTUNREACH]:
# might have been spurious; try again.
# Note: these errors sometimes are reported by recv(),
# and sometimes by send(). We have to catch both.
debug2('DNS send to %r: %s\n' % (self.peer, e))
self.try_send()
return
else:
log('DNS send to %r: %s\n' % (self.peer, e))
return
def callback(self): def callback(self):
try: try:
data = self.sock.recv(4096) data = self.sock.recv(4096)
except socket.error, e: except socket.error, e:
if e.args[0] == errno.ECONNREFUSED: if e.args[0] in [errno.ECONNREFUSED, errno.EHOSTUNREACH]:
debug2('DNS response: ignoring ECONNREFUSED.\n') # might have been spurious; try again.
return # might have been spurious; wait for a real answer # Note: these errors sometimes are reported by recv(),
# and sometimes by send(). We have to catch both.
debug2('DNS recv from %r: %s\n' % (self.peer, e))
self.try_send()
return
else: else:
raise log('DNS recv from %r: %s\n' % (self.peer, e))
return
debug2('DNS response: %d bytes\n' % len(data)) debug2('DNS response: %d bytes\n' % len(data))
self.mux.send(self.chan, ssnet.CMD_DNS_RESPONSE, data) self.mux.send(self.chan, ssnet.CMD_DNS_RESPONSE, data)
self.ok = False self.ok = False

View File

@ -86,7 +86,7 @@ class SockWrapper:
def __init__(self, rsock, wsock, connect_to=None, peername=None): def __init__(self, rsock, wsock, connect_to=None, peername=None):
global _swcount global _swcount
_swcount += 1 _swcount += 1
debug3('creating new SockWrapper (%d now exist\n)' % _swcount) debug3('creating new SockWrapper (%d now exist)\n' % _swcount)
self.exc = None self.exc = None
self.rsock = rsock self.rsock = rsock
self.wsock = wsock self.wsock = wsock
@ -101,7 +101,7 @@ class SockWrapper:
_swcount -= 1 _swcount -= 1
debug1('%r: deleting (%d remain)\n' % (self, _swcount)) debug1('%r: deleting (%d remain)\n' % (self, _swcount))
if self.exc: if self.exc:
debug1('%r: error was: %r\n' % (self, self.exc)) debug1('%r: error was: %s\n' % (self, self.exc))
def __repr__(self): def __repr__(self):
if self.rsock == self.wsock: if self.rsock == self.wsock:
@ -129,7 +129,17 @@ class SockWrapper:
# connected successfully (Linux) # connected successfully (Linux)
self.connect_to = None self.connect_to = None
except socket.error, e: except socket.error, e:
debug3('%r: connect result: %r\n' % (self, e)) debug3('%r: connect result: %s\n' % (self, e))
if e.args[0] == errno.EINVAL:
# this is what happens when you call connect() on a socket
# that is now connected but returned EINPROGRESS last time,
# on BSD, on python pre-2.5.1. We need to use getsockopt()
# to get the "real" error. Later pythons do this
# automatically, so this code won't run.
realerr = self.rsock.getsockopt(socket.SOL_SOCKET,
socket.SO_ERROR)
e = socket.error(realerr, os.strerror(realerr))
debug3('%r: fixed connect result: %s\n' % (self, e))
if e.args[0] in [errno.EINPROGRESS, errno.EALREADY]: if e.args[0] in [errno.EINPROGRESS, errno.EALREADY]:
pass # not connected yet pass # not connected yet
elif e.args[0] == errno.EISCONN: elif e.args[0] == errno.EISCONN: