netbird/client/internal/dnsfwd/forwarder.go

116 lines
2.3 KiB
Go
Raw Normal View History

2024-12-10 19:14:09 +01:00
package dnsfwd
import (
"context"
"net"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
2024-12-10 19:14:09 +01:00
dnsServer *dns.Server
mux *dns.ServeMux
2024-12-10 19:14:09 +01:00
}
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
domains: domains,
}
}
2024-12-10 19:14:09 +01:00
func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
2024-12-10 19:14:09 +01:00
mux := dns.NewServeMux()
for _, d := range f.domains {
mux.HandleFunc(d, f.handleDNSQuery)
}
2024-12-10 19:14:09 +01:00
dnsServer := &dns.Server{
Addr: f.listenAddress,
2024-12-10 19:14:09 +01:00
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
f.mux = mux
2024-12-10 19:14:09 +01:00
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
for _, d := range f.domains {
f.mux.HandleRemove(d)
}
for _, d := range domains {
f.mux.HandleFunc(d, f.handleDNSQuery)
}
f.domains = domains
}
2024-12-10 19:14:09 +01:00
func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
}
return f.dnsServer.ShutdownContext(ctx)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
log.Tracef("received DNS query for DNS forwarder: %v", query)
2024-12-10 19:14:09 +01:00
if len(query.Question) == 0 {
return
}
question := query.Question[0]
domain := question.Name
resp := query.SetReply(query)
ips, err := net.LookupIP(domain)
if err != nil {
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeRefused
2024-12-10 19:14:09 +01:00
_ = w.WriteMsg(resp)
return
}
for _, ip := range ips {
log.Infof("resolved domain %s to IP %s", domain, ip)
var respRecord dns.RR
if ip.To4() == nil {
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
2024-12-10 19:14:09 +01:00
},
}
respRecord = &rr
} else {
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
2024-12-10 19:14:09 +01:00
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}