Use backoff.retry to check if upstreams are responsive (#901)

Retry, in an exponential interval, querying the upstream servers until it gets a positive response
This commit is contained in:
Maycon Santos 2023-05-26 17:13:59 +02:00 committed by GitHub
parent f237e8bd30
commit b6105e9d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 135 additions and 57 deletions

View File

@ -2,10 +2,12 @@ package dns
import ( import (
"fmt" "fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"sync" "sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
) )
type registrationMap map[string]struct{} type registrationMap map[string]struct{}
@ -15,6 +17,9 @@ type localResolver struct {
records sync.Map records sync.Map
} }
func (d *localResolver) stop() {
}
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v\n", r.Question[0]) log.Tracef("received question: %#v\n", r.Question[0])

View File

@ -26,15 +26,16 @@ const (
customIP = "127.0.0.153" customIP = "127.0.0.153"
) )
type registeredHandlerMap map[string]handlerWithStop
// DefaultServer dns server object // DefaultServer dns server object
type DefaultServer struct { type DefaultServer struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
upstreamCtxCancel context.CancelFunc
mux sync.Mutex mux sync.Mutex
server *dns.Server server *dns.Server
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
dnsMuxMap registrationMap dnsMuxMap registeredHandlerMap
localResolver *localResolver localResolver *localResolver
wgInterface *iface.WGIface wgInterface *iface.WGIface
hostManager hostManager hostManager hostManager
@ -47,9 +48,14 @@ type DefaultServer struct {
customAddress *netip.AddrPort customAddress *netip.AddrPort
} }
type handlerWithStop interface {
dns.Handler
stop()
}
type muxUpdate struct { type muxUpdate struct {
domain string domain string
handler dns.Handler handler handlerWithStop
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
@ -79,7 +85,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
ctxCancel: stop, ctxCancel: stop,
server: dnsServer, server: dnsServer,
dnsMux: mux, dnsMux: mux,
dnsMuxMap: make(registrationMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@ -297,10 +303,6 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
} }
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
// clean up the previous upstream resolver
if s.upstreamCtxCancel != nil {
s.upstreamCtxCancel()
}
var muxUpdates []muxUpdate var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups { for _, nsGroup := range nameServerGroups {
@ -309,10 +311,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue continue
} }
var ctx context.Context handler := newUpstreamResolver(s.ctx)
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
handler := newUpstreamResolver(ctx)
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType { if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
@ -323,6 +322,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
} }
if len(handler.upstreamServers) == 0 { if len(handler.upstreamServers) == 0 {
handler.stop()
log.Errorf("received a nameserver group with an invalid nameserver list") log.Errorf("received a nameserver group with an invalid nameserver list")
continue continue
} }
@ -346,11 +346,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
} }
if len(nsGroup.Domains) == 0 { if len(nsGroup.Domains) == 0 {
handler.stop()
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
} }
for _, domain := range nsGroup.Domains { for _, domain := range nsGroup.Domains {
if domain == "" { if domain == "" {
handler.stop()
return nil, fmt.Errorf("received a nameserver group with an empty domain element") return nil, fmt.Errorf("received a nameserver group with an empty domain element")
} }
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, muxUpdate{
@ -363,16 +365,20 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
} }
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registrationMap) muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates { for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler) s.registerMux(update.domain, update.handler)
muxUpdateMap[update.domain] = struct{}{} muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
}
} }
for key := range s.dnsMuxMap { for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
existingHandler.stop()
s.deregisterMux(key) s.deregisterMux(key)
} }
} }

View File

@ -41,21 +41,23 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
} }
dummyHandler := &localResolver{}
testCases := []struct { testCases := []struct {
name string name string
initUpstreamMap registrationMap initUpstreamMap registeredHandlerMap
initLocalMap registrationMap initLocalMap registrationMap
initSerial uint64 initSerial uint64
inputSerial uint64 inputSerial uint64
inputUpdate nbdns.Config inputUpdate nbdns.Config
shouldFail bool shouldFail bool
expectedUpstreamMap registrationMap expectedUpstreamMap registeredHandlerMap
expectedLocalMap registrationMap expectedLocalMap registrationMap
}{ }{
{ {
name: "Initial Config Should Succeed", name: "Initial Config Should Succeed",
initLocalMap: make(registrationMap), initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@ -77,13 +79,13 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}}, expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
}, },
{ {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@ -101,13 +103,13 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}}, expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
}, },
{ {
name: "Smaller Config Serial Should Be Skipped", name: "Smaller Config Serial Should Be Skipped",
initLocalMap: make(registrationMap), initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 2, initSerial: 2,
inputSerial: 1, inputSerial: 1,
shouldFail: true, shouldFail: true,
@ -115,7 +117,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Empty NS Group Domain Or Not Primary Element Should Fail", name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalMap: make(registrationMap), initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@ -137,7 +139,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Invalid NS Group Nameservers list Should Fail", name: "Invalid NS Group Nameservers list Should Fail",
initLocalMap: make(registrationMap), initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@ -159,7 +161,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Invalid Custom Zone Records list Should Fail", name: "Invalid Custom Zone Records list Should Fail",
initLocalMap: make(registrationMap), initLocalMap: make(registrationMap),
initUpstreamMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@ -181,21 +183,21 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true}, inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registrationMap), expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalMap: make(registrationMap), expectedLocalMap: make(registrationMap),
}, },
{ {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false}, inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registrationMap), expectedUpstreamMap: make(registeredHandlerMap),
expectedLocalMap: make(registrationMap), expectedLocalMap: make(registrationMap),
}, },
} }
@ -431,7 +433,7 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
ctxCancel: cancel, ctxCancel: cancel,
server: dnsServer, server: dnsServer,
dnsMux: mux, dnsMux: mux,
dnsMuxMap: make(registrationMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },

View File

@ -3,24 +3,31 @@ package dns
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/cenkalti/backoff/v4"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
failsTillDeact = int32(3) failsTillDeact = int32(5)
reactivatePeriod = time.Minute reactivatePeriod = 30 * time.Second
upstreamTimeout = 15 * time.Second upstreamTimeout = 15 * time.Second
) )
type upstreamClient interface {
ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error)
}
type upstreamResolver struct { type upstreamResolver struct {
ctx context.Context ctx context.Context
upstreamClient *dns.Client cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []string upstreamServers []string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
@ -33,9 +40,11 @@ type upstreamResolver struct {
reactivate func() reactivate func()
} }
func newUpstreamResolver(ctx context.Context) *upstreamResolver { func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
ctx, cancel := context.WithCancel(parentCTX)
return &upstreamResolver{ return &upstreamResolver{
ctx: ctx, ctx: ctx,
cancel: cancel,
upstreamClient: &dns.Client{}, upstreamClient: &dns.Client{},
upstreamTimeout: upstreamTimeout, upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
@ -43,6 +52,11 @@ func newUpstreamResolver(ctx context.Context) *upstreamResolver {
} }
} }
func (u *upstreamResolver) stop() {
log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
}
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails() defer u.checkUpstreamFails()
@ -107,28 +121,57 @@ func (u *upstreamResolver) checkUpstreamFails() {
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
u.deactivate() u.deactivate()
u.disabled = true u.disabled = true
go u.waitUntilReactivation() go u.waitUntilResponse()
} }
} }
// waitUntilReactivation reset fails counter and activates upstream resolving // waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolver) waitUntilReactivation() { func (u *upstreamResolver) waitUntilResponse() {
timer := time.NewTimer(u.reactivatePeriod) exponentialBackOff := &backoff.ExponentialBackOff{
defer func() { InitialInterval: 500 * time.Millisecond,
if !timer.Stop() { RandomizationFactor: 0.5,
<-timer.C Multiplier: 1.1,
MaxInterval: u.reactivatePeriod,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
} }
}()
r := new(dns.Msg).SetQuestion("netbird.io.", dns.TypeA)
operation := func() error {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
default:
}
var err error
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
_, _, err = u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err == nil {
return nil
}
}
log.Tracef("checking connectivity with upstreams %s failed with error: %s. Retrying in %s", err, u.upstreamServers, exponentialBackOff.NextBackOff())
return fmt.Errorf("got an error from upstream check call")
}
err := backoff.Retry(operation, exponentialBackOff)
if err != nil {
log.Warn(err)
return return
case <-timer.C: }
log.Info("upstream resolving is reactivated")
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0) u.failsCount.Store(0)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
}
} }
// isTimeout returns true if the given error is a network timeout error. // isTimeout returns true if the given error is a network timeout error.

View File

@ -2,10 +2,11 @@ package dns
import ( import (
"context" "context"
"github.com/miekg/dns"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/miekg/dns"
) )
func TestUpstreamResolver_ServeDNS(t *testing.T) { func TestUpstreamResolver_ServeDNS(t *testing.T) {
@ -106,8 +107,29 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
} }
} }
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// ExchangeContext mock implementation of ExchangeContext from upstreamResolver
func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) {
return c.r, c.rtt, c.err
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver := newUpstreamResolver(context.TODO()) resolver := &upstreamResolver{
ctx: context.TODO(),
upstreamClient: &mockUpstreamResolver{
err: nil,
r: new(dns.Msg),
rtt: time.Millisecond,
},
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
resolver.upstreamServers = []string{"0.0.0.0:-1"} resolver.upstreamServers = []string{"0.0.0.0:-1"}
resolver.failsTillDeact = 0 resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100 resolver.reactivatePeriod = time.Microsecond * 100