From 2f34e984b0051c5b0ad1992d88b93c62165bf765 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Fri, 9 May 2025 15:06:34 +0300 Subject: [PATCH] [client] Add TCP support to DNS forwarder service listener (#3790) [client] Add TCP support to DNS forwarder service listener --- client/internal/dnsfwd/forwarder.go | 104 ++++++++++++++++++++++------ client/internal/dnsfwd/manager.go | 14 ++++ 2 files changed, 98 insertions(+), 20 deletions(-) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 8f6a31f47..45b479632 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -33,6 +33,8 @@ type DNSForwarder struct { dnsServer *dns.Server mux *dns.ServeMux + tcpServer *dns.Server + tcpMux *dns.ServeMux mutex sync.RWMutex fwdEntries []*ForwarderEntry @@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("listen DNS forwarder on address=%s", f.listenAddress) - mux := dns.NewServeMux() + log.Infof("starting DNS forwarder on address=%s", f.listenAddress) - dnsServer := &dns.Server{ + // UDP server + mux := dns.NewServeMux() + f.mux = mux + f.dnsServer = &dns.Server{ Addr: f.listenAddress, Net: "udp", Handler: mux, } - f.dnsServer = dnsServer - f.mux = mux + // TCP server + tcpMux := dns.NewServeMux() + f.tcpMux = tcpMux + f.tcpServer = &dns.Server{ + Addr: f.listenAddress, + Net: "tcp", + Handler: tcpMux, + } f.UpdateDomains(entries) - return dnsServer.ListenAndServe() -} + errCh := make(chan error, 2) + go func() { + log.Infof("DNS UDP listener running on %s", f.listenAddress) + errCh <- f.dnsServer.ListenAndServe() + }() + go func() { + log.Infof("DNS TCP listener running on %s", f.listenAddress) + errCh <- f.tcpServer.ListenAndServe() + }() + + // return the first error we get (e.g. bind failure or shutdown) + return <-errCh +} func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() @@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { } oldDomains := filterDomains(f.fwdEntries) - for _, d := range oldDomains { f.mux.HandleRemove(d.PunycodeString()) + f.tcpMux.HandleRemove(d.PunycodeString()) } newDomains := filterDomains(entries) for _, d := range newDomains { - f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery) + f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP) + f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP) } f.fwdEntries = entries - log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) } func (f *DNSForwarder) Close(ctx context.Context) error { - if f.dnsServer == nil { - return nil + var result *multierror.Error + + if f.dnsServer != nil { + if err := f.dnsServer.ShutdownContext(ctx); err != nil { + result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err)) + } } - return f.dnsServer.ShutdownContext(ctx) + if f.tcpServer != nil { + if err := f.tcpServer.ShutdownContext(ctx); err != nil { + result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg { if len(query.Question) == 0 { - return + return nil } question := query.Question[0] log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", @@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } - return + return nil } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) if err != nil { - f.handleDNSError(w, resp, domain, err) - return + f.handleDNSError(w, query, resp, domain, err) + return nil } f.updateInternalState(domain, ips) f.addIPsToResponse(resp, domain, ips) + return resp +} + +func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { + + resp := f.handleDNSQuery(w, query) + if resp == nil { + return + } + + opt := query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + // client advertised a larger EDNS0 buffer + maxSize = int(opt.UDPSize()) + } + + // if our response is too big, truncate and set the TC bit + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { + resp := f.handleDNSQuery(w, query) + if resp == nil { + return + } + if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } @@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe } // handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { +func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { var dnsErr *net.DNSError switch { @@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai } if dnsErr.Server != "" { - log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) } else { log.Warnf(errResolveFailed, domain, err) } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index e4a23450f..91abce823 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -33,6 +33,7 @@ type Manager struct { statusRecorder *peer.Status fwRules []firewall.Rule + tcpRules []firewall.Rule dnsForwarder *DNSForwarder } @@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error { } m.fwRules = dnsRules + tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + m.tcpRules = tcpRules + return nil } @@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) } } + for _, rule := range m.tcpRules { + if err := m.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } m.fwRules = nil + m.tcpRules = nil return nberrors.FormatErrorOrNil(mErr) }