[client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632)

This commit is contained in:
Viktor Liu 2025-04-08 13:41:13 +02:00 committed by GitHub
parent 192c97aa63
commit 03f600b576
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 53 additions and 9 deletions

View File

@ -18,6 +18,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@ -107,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}() }()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil { if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
@ -337,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
_, _, err := u.upstreamClient.exchange(ctx, server, r) _, _, err := u.upstreamClient.exchange(ctx, server, r)
return err return err
} }
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
client.UDPSize = iface.DefaultMTU - (60 + 8)
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
if rm == nil || !rm.MsgHdr.Truncated {
return rm, t, nil
}
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
client.Net = "tcp"
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
return rm, t, nil
}

View File

@ -34,6 +34,5 @@ func newUpstreamResolver(
} }
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{} return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }

View File

@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} }
// Cannot use client.ExchangeContext because it overwrites our Dialer // Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream) return ExchangeWithFallback(nil, client, r, upstream)
} }
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface

View File

@ -162,9 +162,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil { if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
client := &dns.Client{ client := &dns.Client{
@ -172,7 +170,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
Net: "udp", Net: "udp",
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil { if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {

View File

@ -27,7 +27,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
startTime := time.Now() startTime := time.Now()
response, _, err := privateClient.Exchange(msg, r.resolverAddr) response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err) return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
} }