Enhanced DNS support. Initial version.

This commit is contained in:
Brian May 2011-05-12 14:37:19 +10:00
parent 783d33cada
commit 9915d736fe
2 changed files with 122 additions and 8 deletions

112
client.py
View File

@ -175,8 +175,60 @@ class FirewallClient:
raise Fatal('cleanup: %r returned %d' % (self.argv, rv))
def unpack_dns_name(buf, off):
name = ''
while True:
# get the next octet from buffer
n = ord(buf[off])
# zero octet terminates name
if n == 0:
off += 1
break
# top two bits on
# => a 2 octect pointer to another part of the buffer
elif (n & 0xc0) == 0xc0:
ptr = struct.unpack('>H', buf[off:off+2])[0] & 0x3fff
off = ptr
# an octet representing the number of bytes to process.
else:
off += 1
name = name + buf[off:off+n] + '.'
off += n
return name.strip('.'), off
class dnspkt:
def unpack(self, buf, off):
l = len(buf)
(self.id, self.op, self.qdcount, self.ancount, self.nscount, self.arcount) = struct.unpack("!HHHHHH",buf[off:off+12])
off += 12
self.q = []
for i in range(self.qdcount):
qname, off = unpack_dns_name(buf, off)
qtype, qclass = struct.unpack('!HH', buf[off:off+4])
off += 4
self.q.append( (qname,qtype,qclass) )
return off
def match_q_domain(self, domain):
l = len(domain)
for qname,qtype,qclass in self.q:
if qname[-l:] == domain:
if l==len(qname):
return True
elif qname[-l-1] == '.':
return True
return False
def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
dnslistener, seed_hosts, auto_nets,
dnslistener, dnsforwarder, dns_domains, dns_to,
seed_hosts, auto_nets,
syslog, daemon):
handlers = []
if helpers.verbose >= 1:
@ -283,6 +335,7 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
handlers.append(Handler([listener], onaccept))
dnsreqs = {}
dnsforwards = {}
def dns_done(chan, data):
peer,timeout = dnsreqs.get(chan) or (None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
@ -295,16 +348,54 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
now = time.time()
if pkt:
debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt)))
chan = mux.next_channel()
dnsreqs[chan] = peer,now+30
mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
mux.channels[chan] = lambda cmd,data: dns_done(chan,data)
dns = dnspkt()
dns.unpack(pkt, 0)
match=False
if dns_domains is not None:
for domain in dns_domains:
if dns.match_q_domain(domain):
match=True
break
if match:
debug3("We need to redirect this request remotely\n")
chan = mux.next_channel()
dnsreqs[chan] = peer,now+30
mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
mux.channels[chan] = lambda cmd,data: dns_done(chan,data)
else:
debug3("We need to forward this request locally\n")
dnsforwarder.sendto(pkt, dns_to)
dnsforwards[dns.id] = peer,now+30
for chan,(peer,timeout) in dnsreqs.items():
if timeout < now:
del dnsreqs[chan]
for chan,(peer,timeout) in dnsforwards.items():
if timeout < now:
del dnsforwards[chan]
debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
debug3('Remaining DNS forwards: %d\n' % len(dnsforwards))
if dnslistener:
handlers.append(Handler([dnslistener], ondns))
def ondnsforward():
debug1("We got a response.\n")
pkt,server = dnsforwarder.recvfrom(4096)
now = time.time()
if server[0] != dns_to[0] or server[1] != dns_to[1]:
debug1("Ooops. The response came from the wrong server. Ignoring\n")
else:
dns = dnspkt()
dns.unpack(pkt, 0)
chan=dns.id
peer,timeout = dnsforwards.get(chan) or (None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
if peer:
del dnsforwards[chan]
debug3('doing sendto %r\n' % (peer,))
dnslistener.sendto(pkt, peer)
if dnsforwarder:
handlers.append(Handler([dnsforwarder], ondnsforward))
if seed_hosts != None:
debug1('seed_hosts: %r\n' % seed_hosts)
@ -321,7 +412,8 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control,
mux.callback()
def main(listenip, ssh_cmd, remotename, python, latency_control, dns,
def main(listenip, ssh_cmd, remotename, python, latency_control,
dns, dns_domains, dns_to,
seed_hosts, auto_nets,
subnets_include, subnets_exclude, syslog, daemon, pidfile):
if syslog:
@ -366,15 +458,21 @@ def main(listenip, ssh_cmd, remotename, python, latency_control, dns,
dnsip = dnslistener.getsockname()
debug1('DNS listening on %r.\n' % (dnsip,))
dnsport = dnsip[1]
dnsforwarder = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
dnsforwarder.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
dnsforwarder.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
else:
dnsport = 0
dnslistener = None
dnsforwarder = None
fw = FirewallClient(listenip[1], subnets_include, subnets_exclude, dnsport)
try:
return _main(listener, fw, ssh_cmd, remotename,
python, latency_control, dnslistener,
python, latency_control,
dnslistener, dnsforwarder, dns_domains, dns_to,
seed_hosts, auto_nets, syslog, daemon)
finally:
try:

18
main.py
View File

@ -54,6 +54,8 @@ l,listen= transproxy to this ip address and port number [127.0.0.1:0]
H,auto-hosts scan for remote hostnames and update local /etc/hosts
N,auto-nets automatically determine subnets to route
dns capture local DNS requests and forward to the remote DNS server
dns-domains= comma seperated list of DNS domains for DNS forwarding
dns-to= forward any DNS requests that don't match domains to this address
python= path to python interpreter on the remote server
r,remote= ssh hostname (and optional username) of remote sshuttle server
x,exclude= exclude this subnet (can be used more than once)
@ -110,12 +112,26 @@ try:
sh = []
else:
sh = None
if opt.dns and opt.dns_domains:
dns_domains = opt.dns_domains.split(",")
if opt.dns_to:
addr,colon,port = opt.dns_to.rpartition(":")
if colon == ":":
dns_to = ( addr, int(port) )
else:
dns_to = ( port, 53 )
else:
o.fatal('--dns-to=ip is required with --dns-domains=list')
else:
dns_domains = None
dns_to = None
sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'),
opt.ssh_cmd,
remotename,
opt.python,
opt.latency_control,
opt.dns,
opt.dns, dns_domains, dns_to,
sh,
opt.auto_nets,
parse_subnets(includes),