mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
[client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632)
This commit is contained in:
parent
192c97aa63
commit
03f600b576
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
@ -107,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -337,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
|
|||||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||||
|
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||||
|
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||||
|
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||||
|
// MTU - ip + udp headers
|
||||||
|
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
|
||||||
|
client.UDPSize = iface.DefaultMTU - (60 + 8)
|
||||||
|
|
||||||
|
var (
|
||||||
|
rm *dns.Msg
|
||||||
|
t time.Duration
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
|
} else {
|
||||||
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rm == nil || !rm.MsgHdr.Truncated {
|
||||||
|
return rm, t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||||
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
|
client.Net = "tcp"
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
|
} else {
|
||||||
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||||
|
|
||||||
|
return rm, t, nil
|
||||||
|
}
|
||||||
|
@ -34,6 +34,5 @@ func newUpstreamResolver(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
upstreamExchangeClient := &dns.Client{}
|
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||||
return client.Exchange(r, upstream)
|
return ExchangeWithFallback(nil, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
|
@ -162,9 +162,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
@ -172,7 +170,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
Net: "udp",
|
Net: "udp",
|
||||||
}
|
}
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
|
@ -27,7 +27,7 @@ func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
|||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
|
response, _, err := nbdns.ExchangeWithFallback(nil, privateClient, msg, r.resolverAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user