mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-24 06:48:40 +01:00
ddc365f7a0
--------- Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
158 lines
3.4 KiB
Go
158 lines
3.4 KiB
Go
package dnsfwd
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
)
|
|
|
|
// errResolveFailed is the log format used when resolving a forwarded query
// fails; its arguments are the queried domain and the resolver error.
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
|
|
|
type DNSForwarder struct {
|
|
listenAddress string
|
|
ttl uint32
|
|
domains []string
|
|
|
|
dnsServer *dns.Server
|
|
mux *dns.ServeMux
|
|
}
|
|
|
|
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
|
return &DNSForwarder{
|
|
listenAddress: listenAddress,
|
|
ttl: ttl,
|
|
}
|
|
}
|
|
|
|
func (f *DNSForwarder) Listen(domains []string) error {
|
|
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
|
mux := dns.NewServeMux()
|
|
|
|
dnsServer := &dns.Server{
|
|
Addr: f.listenAddress,
|
|
Net: "udp",
|
|
Handler: mux,
|
|
}
|
|
f.dnsServer = dnsServer
|
|
f.mux = mux
|
|
|
|
f.UpdateDomains(domains)
|
|
|
|
return dnsServer.ListenAndServe()
|
|
}
|
|
|
|
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
|
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
|
|
|
for _, d := range f.domains {
|
|
f.mux.HandleRemove(d)
|
|
}
|
|
|
|
newDomains := filterDomains(domains)
|
|
for _, d := range newDomains {
|
|
f.mux.HandleFunc(d, f.handleDNSQuery)
|
|
}
|
|
f.domains = newDomains
|
|
}
|
|
|
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
|
if f.dnsServer == nil {
|
|
return nil
|
|
}
|
|
return f.dnsServer.ShutdownContext(ctx)
|
|
}
|
|
|
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|
if len(query.Question) == 0 {
|
|
return
|
|
}
|
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
|
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
|
|
|
|
question := query.Question[0]
|
|
domain := question.Name
|
|
|
|
resp := query.SetReply(query)
|
|
|
|
ips, err := net.LookupIP(domain)
|
|
if err != nil {
|
|
var dnsErr *net.DNSError
|
|
|
|
switch {
|
|
case errors.As(err, &dnsErr):
|
|
resp.Rcode = dns.RcodeServerFailure
|
|
if dnsErr.IsNotFound {
|
|
// Pass through NXDOMAIN
|
|
resp.Rcode = dns.RcodeNameError
|
|
}
|
|
|
|
if dnsErr.Server != "" {
|
|
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
|
} else {
|
|
log.Warnf(errResolveFailed, domain, err)
|
|
}
|
|
default:
|
|
resp.Rcode = dns.RcodeServerFailure
|
|
log.Warnf(errResolveFailed, domain, err)
|
|
}
|
|
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write failure DNS response: %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
for _, ip := range ips {
|
|
var respRecord dns.RR
|
|
if ip.To4() == nil {
|
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
|
rr := dns.AAAA{
|
|
AAAA: ip,
|
|
Hdr: dns.RR_Header{
|
|
Name: domain,
|
|
Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET,
|
|
Ttl: f.ttl,
|
|
},
|
|
}
|
|
respRecord = &rr
|
|
} else {
|
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
|
rr := dns.A{
|
|
A: ip,
|
|
Hdr: dns.RR_Header{
|
|
Name: domain,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: f.ttl,
|
|
},
|
|
}
|
|
respRecord = &rr
|
|
}
|
|
resp.Answer = append(resp.Answer, respRecord)
|
|
}
|
|
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write DNS response: %v", err)
|
|
}
|
|
}
|
|
|
|
// filterDomains returns a list of normalized domains
|
|
func filterDomains(domains []string) []string {
|
|
newDomains := make([]string, 0, len(domains))
|
|
for _, d := range domains {
|
|
if d == "" {
|
|
log.Warn("empty domain in DNS forwarder")
|
|
continue
|
|
}
|
|
newDomains = append(newDomains, nbdns.NormalizeZone(d))
|
|
}
|
|
return newDomains
|
|
}
|