From 03f600b5764e49d290d1ffcd7305af588b7215cb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 8 Apr 2025 13:41:13 +0200 Subject: [PATCH] [client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632) --- client/internal/dns/upstream.go | 51 ++++++++++++++++++- client/internal/dns/upstream_general.go | 3 +- client/internal/dns/upstream_ios.go | 2 +- .../routemanager/dnsinterceptor/handler.go | 4 +- .../routemanager/dynamic/route_ios.go | 2 +- 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 53fa20f62..fa69d4934 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -18,6 +18,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" ) @@ -107,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { }() log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records if r.Extra == nil { - r.SetEdns0(4096, false) r.MsgHdr.AuthenticatedData = true } @@ -337,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati _, _, err := u.upstreamClient.exchange(ctx, server, r) return err } + +// ExchangeWithFallback exchanges a DNS message with the upstream server. +// It first tries to use UDP, and if it is truncated, it falls back to TCP. +// If the passed context is nil, this will use Exchange instead of ExchangeContext. +func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { + // MTU - ip + udp headers + // Note: this could be sent out on an interface that is not ours, but our MTU should always be lower. + client.UDPSize = iface.DefaultMTU - (60 + 8) + + var ( + rm *dns.Msg + t time.Duration + err error + ) + + if ctx == nil { + rm, t, err = client.Exchange(r, upstream) + } else { + rm, t, err = client.ExchangeContext(ctx, r, upstream) + } + + if err != nil { + return nil, t, fmt.Errorf("with udp: %w", err) + } + + if rm == nil || !rm.MsgHdr.Truncated { + return rm, t, nil + } + + log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.", + r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + client.Net = "tcp" + + if ctx == nil { + rm, t, err = client.Exchange(r, upstream) + } else { + rm, t, err = client.ExchangeContext(ctx, r, upstream) + } + + if err != nil { + return nil, t, fmt.Errorf("with tcp: %w", err) + } + + // TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP + + return rm, t, nil +} diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 51acbf7a6..9bb5feab0 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -34,6 +34,5 @@ func newUpstreamResolver( } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{} - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index c73079b92..ca5b31132 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } // Cannot use client.ExchangeContext because it overwrites our Dialer - return client.Exchange(r, upstream) + return ExchangeWithFallback(nil, client, r, upstream) } // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 42d740d90..68d81d968 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -162,9 +162,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - // set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records if r.Extra == nil { - r.SetEdns0(4096, false) r.MsgHdr.AuthenticatedData = true } client := &dns.Client{ @@ -172,7 +170,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { Net: "udp", } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := client.ExchangeContext(context.Background(), r, upstream) + reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) if err != nil { log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { diff --git a/client/internal/routemanager/dynamic/route_ios.go b/client/internal/routemanager/dynamic/route_ios.go index 145d395e6..34949b626 100644 --- a/client/internal/routemanager/dynamic/route_ios.go +++ b/client/internal/routemanager/dynamic/route_ios.go @@ -27,7 +27,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) { startTime := time.Now() - response, _, err := privateClient.Exchange(msg, r.resolverAddr) + response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr) if err != nil { return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) }