diff --git a/client.py b/client.py index 0ff5f2b..21ecd22 100644 --- a/client.py +++ b/client.py @@ -175,8 +175,60 @@ class FirewallClient: raise Fatal('cleanup: %r returned %d' % (self.argv, rv)) +def unpack_dns_name(buf, off): + name = '' + while True: + # get the next octet from buffer + n = ord(buf[off]) + + # zero octet terminates name + if n == 0: + off += 1 + break + + # top two bits on + # => a 2 octect pointer to another part of the buffer + elif (n & 0xc0) == 0xc0: + ptr = struct.unpack('>H', buf[off:off+2])[0] & 0x3fff + off = ptr + + # an octet representing the number of bytes to process. + else: + off += 1 + name = name + buf[off:off+n] + '.' + off += n + + return name.strip('.'), off + +class dnspkt: + def unpack(self, buf, off): + l = len(buf) + + (self.id, self.op, self.qdcount, self.ancount, self.nscount, self.arcount) = struct.unpack("!HHHHHH",buf[off:off+12]) + off += 12 + + self.q = [] + for i in range(self.qdcount): + qname, off = unpack_dns_name(buf, off) + qtype, qclass = struct.unpack('!HH', buf[off:off+4]) + off += 4 + self.q.append( (qname,qtype,qclass) ) + + return off + + def match_q_domain(self, domain): + l = len(domain) + for qname,qtype,qclass in self.q: + if qname[-l:] == domain: + if l==len(qname): + return True + elif qname[-l-1] == '.': + return True + return False + def _main(listener, fw, ssh_cmd, remotename, python, latency_control, - dnslistener, seed_hosts, auto_nets, + dnslistener, dnsforwarder, dns_domains, dns_to, + seed_hosts, auto_nets, syslog, daemon): handlers = [] if helpers.verbose >= 1: @@ -283,6 +335,7 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control, handlers.append(Handler([listener], onaccept)) dnsreqs = {} + dnsforwards = {} def dns_done(chan, data): peer,timeout = dnsreqs.get(chan) or (None,None) debug3('dns_done: channel=%r peer=%r\n' % (chan, peer)) @@ -295,16 +348,54 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control, now = time.time() if pkt: debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt))) - chan = mux.next_channel() - dnsreqs[chan] = peer,now+30 - mux.send(chan, ssnet.CMD_DNS_REQ, pkt) - mux.channels[chan] = lambda cmd,data: dns_done(chan,data) + dns = dnspkt() + dns.unpack(pkt, 0) + + match=False + if dns_domains is not None: + for domain in dns_domains: + if dns.match_q_domain(domain): + match=True + break + + if match: + debug3("We need to redirect this request remotely\n") + chan = mux.next_channel() + dnsreqs[chan] = peer,now+30 + mux.send(chan, ssnet.CMD_DNS_REQ, pkt) + mux.channels[chan] = lambda cmd,data: dns_done(chan,data) + else: + debug3("We need to forward this request locally\n") + dnsforwarder.sendto(pkt, dns_to) + dnsforwards[dns.id] = peer,now+30 for chan,(peer,timeout) in dnsreqs.items(): if timeout < now: del dnsreqs[chan] + for chan,(peer,timeout) in dnsforwards.items(): + if timeout < now: + del dnsforwards[chan] debug3('Remaining DNS requests: %d\n' % len(dnsreqs)) + debug3('Remaining DNS forwards: %d\n' % len(dnsforwards)) if dnslistener: handlers.append(Handler([dnslistener], ondns)) + def ondnsforward(): + debug1("We got a response.\n") + pkt,server = dnsforwarder.recvfrom(4096) + now = time.time() + if server[0] != dns_to[0] or server[1] != dns_to[1]: + debug1("Ooops. The response came from the wrong server. Ignoring\n") + else: + dns = dnspkt() + dns.unpack(pkt, 0) + chan=dns.id + peer,timeout = dnsforwards.get(chan) or (None,None) + debug3('dns_done: channel=%r peer=%r\n' % (chan, peer)) + if peer: + del dnsforwards[chan] + debug3('doing sendto %r\n' % (peer,)) + dnslistener.sendto(pkt, peer) + if dnsforwarder: + handlers.append(Handler([dnsforwarder], ondnsforward)) if seed_hosts != None: debug1('seed_hosts: %r\n' % seed_hosts) @@ -321,7 +412,8 @@ def _main(listener, fw, ssh_cmd, remotename, python, latency_control, mux.callback() -def main(listenip, ssh_cmd, remotename, python, latency_control, dns, +def main(listenip, ssh_cmd, remotename, python, latency_control, + dns, dns_domains, dns_to, seed_hosts, auto_nets, subnets_include, subnets_exclude, syslog, daemon, pidfile): if syslog: @@ -366,15 +458,21 @@ def main(listenip, ssh_cmd, remotename, python, latency_control, dns, dnsip = dnslistener.getsockname() debug1('DNS listening on %r.\n' % (dnsip,)) dnsport = dnsip[1] + + dnsforwarder = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + dnsforwarder.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + dnsforwarder.setsockopt(socket.SOL_IP, socket.IP_TTL, 42) else: dnsport = 0 dnslistener = None + dnsforwarder = None fw = FirewallClient(listenip[1], subnets_include, subnets_exclude, dnsport) try: return _main(listener, fw, ssh_cmd, remotename, - python, latency_control, dnslistener, + python, latency_control, + dnslistener, dnsforwarder, dns_domains, dns_to, seed_hosts, auto_nets, syslog, daemon) finally: try: diff --git a/main.py b/main.py index 1cf00af..6afeefe 100644 --- a/main.py +++ b/main.py @@ -54,6 +54,8 @@ l,listen= transproxy to this ip address and port number [127.0.0.1:0] H,auto-hosts scan for remote hostnames and update local /etc/hosts N,auto-nets automatically determine subnets to route dns capture local DNS requests and forward to the remote DNS server +dns-domains= comma seperated list of DNS domains for DNS forwarding +dns-to= forward any DNS requests that don't match domains to this address python= path to python interpreter on the remote server r,remote= ssh hostname (and optional username) of remote sshuttle server x,exclude= exclude this subnet (can be used more than once) @@ -110,12 +112,26 @@ try: sh = [] else: sh = None + if opt.dns and opt.dns_domains: + dns_domains = opt.dns_domains.split(",") + if opt.dns_to: + addr,colon,port = opt.dns_to.rpartition(":") + if colon == ":": + dns_to = ( addr, int(port) ) + else: + dns_to = ( port, 53 ) + else: + o.fatal('--dns-to=ip is required with --dns-domains=list') + else: + dns_domains = None + dns_to = None + sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), opt.ssh_cmd, remotename, opt.python, opt.latency_control, - opt.dns, + opt.dns, dns_domains, dns_to, sh, opt.auto_nets, parse_subnets(includes),