From b6105e9d7c05431f689895ce00e11e20392b18ca Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 26 May 2023 17:13:59 +0200 Subject: [PATCH] Use backoff.retry to check if upstreams are responsive (#901) Retry, in an exponential interval, querying the upstream servers until it gets a positive response --- client/internal/dns/local.go | 11 ++- client/internal/dns/server_nonandroid.go | 36 ++++++---- client/internal/dns/server_test.go | 32 +++++---- client/internal/dns/upstream.go | 87 ++++++++++++++++++------ client/internal/dns/upstream_test.go | 26 ++++++- 5 files changed, 135 insertions(+), 57 deletions(-) diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index e9fcc37eb..18fec812a 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -2,10 +2,12 @@ package dns import ( "fmt" - "github.com/miekg/dns" - nbdns "github.com/netbirdio/netbird/dns" - log "github.com/sirupsen/logrus" "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" ) type registrationMap map[string]struct{} @@ -15,6 +17,9 @@ type localResolver struct { records sync.Map } +func (d *localResolver) stop() { +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { log.Tracef("received question: %#v\n", r.Question[0]) diff --git a/client/internal/dns/server_nonandroid.go b/client/internal/dns/server_nonandroid.go index d1bf5c970..ca4b25708 100644 --- a/client/internal/dns/server_nonandroid.go +++ b/client/internal/dns/server_nonandroid.go @@ -26,15 +26,16 @@ const ( customIP = "127.0.0.153" ) +type registeredHandlerMap map[string]handlerWithStop + // DefaultServer dns server object type DefaultServer struct { ctx context.Context ctxCancel context.CancelFunc - upstreamCtxCancel context.CancelFunc mux sync.Mutex server *dns.Server dnsMux *dns.ServeMux - dnsMuxMap registrationMap + dnsMuxMap registeredHandlerMap localResolver *localResolver wgInterface *iface.WGIface hostManager hostManager @@ -47,9 +48,14 @@ type DefaultServer struct { customAddress *netip.AddrPort } +type handlerWithStop interface { + dns.Handler + stop() +} + type muxUpdate struct { domain string - handler dns.Handler + handler handlerWithStop } // NewDefaultServer returns a new dns server @@ -79,7 +85,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd ctxCancel: stop, server: dnsServer, dnsMux: mux, - dnsMuxMap: make(registrationMap), + dnsMuxMap: make(registeredHandlerMap), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -297,10 +303,6 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) } func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { - // clean up the previous upstream resolver - if s.upstreamCtxCancel != nil { - s.upstreamCtxCancel() - } var muxUpdates []muxUpdate for _, nsGroup := range nameServerGroups { @@ -309,10 +311,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam continue } - var ctx context.Context - ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx) - - handler := newUpstreamResolver(ctx) + handler := newUpstreamResolver(s.ctx) for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", @@ -323,6 +322,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam } if len(handler.upstreamServers) == 0 { + handler.stop() log.Errorf("received a nameserver group with an invalid nameserver list") continue } @@ -346,11 +346,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam } if len(nsGroup.Domains) == 0 { + handler.stop() return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") } for _, domain := range nsGroup.Domains { if domain == "" { + handler.stop() return nil, fmt.Errorf("received a nameserver group with an empty domain element") } muxUpdates = append(muxUpdates, muxUpdate{ @@ -363,16 +365,20 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam } func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { - muxUpdateMap := make(registrationMap) + muxUpdateMap := make(registeredHandlerMap) for _, update := range muxUpdates { s.registerMux(update.domain, update.handler) - muxUpdateMap[update.domain] = struct{}{} + muxUpdateMap[update.domain] = update.handler + if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { + existingHandler.stop() + } } - for key := range s.dnsMuxMap { + for key, existingHandler := range s.dnsMuxMap { _, found := muxUpdateMap[key] if !found { + existingHandler.stop() s.deregisterMux(key) } } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 03e3ddc6e..9227bf0c2 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -41,21 +41,23 @@ func TestUpdateDNSServer(t *testing.T) { }, } + dummyHandler := &localResolver{} + testCases := []struct { name string - initUpstreamMap registrationMap + initUpstreamMap registeredHandlerMap initLocalMap registrationMap initSerial uint64 inputSerial uint64 inputUpdate nbdns.Config shouldFail bool - expectedUpstreamMap registrationMap + expectedUpstreamMap registeredHandlerMap expectedLocalMap registrationMap }{ { name: "Initial Config Should Succeed", initLocalMap: make(registrationMap), - initUpstreamMap: make(registrationMap), + initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -77,13 +79,13 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}}, + expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler}, expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "New Config Should Succeed", initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, + initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler}, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -101,13 +103,13 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}}, + expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler}, expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "Smaller Config Serial Should Be Skipped", initLocalMap: make(registrationMap), - initUpstreamMap: make(registrationMap), + initUpstreamMap: make(registeredHandlerMap), initSerial: 2, inputSerial: 1, shouldFail: true, @@ -115,7 +117,7 @@ func TestUpdateDNSServer(t *testing.T) { { name: "Empty NS Group Domain Or Not Primary Element Should Fail", initLocalMap: make(registrationMap), - initUpstreamMap: make(registrationMap), + initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -137,7 +139,7 @@ func TestUpdateDNSServer(t *testing.T) { { name: "Invalid NS Group Nameservers list Should Fail", initLocalMap: make(registrationMap), - initUpstreamMap: make(registrationMap), + initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -159,7 +161,7 @@ func TestUpdateDNSServer(t *testing.T) { { name: "Invalid Custom Zone Records list Should Fail", initLocalMap: make(registrationMap), - initUpstreamMap: make(registrationMap), + initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -181,21 +183,21 @@ func TestUpdateDNSServer(t *testing.T) { { name: "Empty Config Should Succeed and Clean Maps", initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: true}, - expectedUpstreamMap: make(registrationMap), + expectedUpstreamMap: make(registeredHandlerMap), expectedLocalMap: make(registrationMap), }, { name: "Disabled Service Should clean map", initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler}, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ServiceEnable: false}, - expectedUpstreamMap: make(registrationMap), + expectedUpstreamMap: make(registeredHandlerMap), expectedLocalMap: make(registrationMap), }, } @@ -431,7 +433,7 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe ctxCancel: cancel, server: dnsServer, dnsMux: mux, - dnsMuxMap: make(registrationMap), + dnsMuxMap: make(registeredHandlerMap), localResolver: &localResolver{ registeredMap: make(registrationMap), }, diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c2af76ecd..b4febd7a4 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -3,24 +3,31 @@ package dns import ( "context" "errors" + "fmt" "net" "sync" "sync/atomic" "time" + "github.com/cenkalti/backoff/v4" "github.com/miekg/dns" log "github.com/sirupsen/logrus" ) const ( - failsTillDeact = int32(3) - reactivatePeriod = time.Minute + failsTillDeact = int32(5) + reactivatePeriod = 30 * time.Second upstreamTimeout = 15 * time.Second ) +type upstreamClient interface { + ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) +} + type upstreamResolver struct { ctx context.Context - upstreamClient *dns.Client + cancel context.CancelFunc + upstreamClient upstreamClient upstreamServers []string disabled bool failsCount atomic.Int32 @@ -33,9 +40,11 @@ type upstreamResolver struct { reactivate func() } -func newUpstreamResolver(ctx context.Context) *upstreamResolver { +func newUpstreamResolver(parentCTX context.Context) *upstreamResolver { + ctx, cancel := context.WithCancel(parentCTX) return &upstreamResolver{ ctx: ctx, + cancel: cancel, upstreamClient: &dns.Client{}, upstreamTimeout: upstreamTimeout, reactivatePeriod: reactivatePeriod, @@ -43,6 +52,11 @@ func newUpstreamResolver(ctx context.Context) *upstreamResolver { } } +func (u *upstreamResolver) stop() { + log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers) + u.cancel() +} + // ServeDNS handles a DNS request func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { defer u.checkUpstreamFails() @@ -107,28 +121,57 @@ func (u *upstreamResolver) checkUpstreamFails() { log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) u.deactivate() u.disabled = true - go u.waitUntilReactivation() + go u.waitUntilResponse() } } -// waitUntilReactivation reset fails counter and activates upstream resolving -func (u *upstreamResolver) waitUntilReactivation() { - timer := time.NewTimer(u.reactivatePeriod) - defer func() { - if !timer.Stop() { - <-timer.C - } - }() - - select { - case <-u.ctx.Done(): - return - case <-timer.C: - log.Info("upstream resolving is reactivated") - u.failsCount.Store(0) - u.reactivate() - u.disabled = false +// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response +func (u *upstreamResolver) waitUntilResponse() { + exponentialBackOff := &backoff.ExponentialBackOff{ + InitialInterval: 500 * time.Millisecond, + RandomizationFactor: 0.5, + Multiplier: 1.1, + MaxInterval: u.reactivatePeriod, + MaxElapsedTime: 0, + Stop: backoff.Stop, + Clock: backoff.SystemClock, } + + r := new(dns.Msg).SetQuestion("netbird.io.", dns.TypeA) + + operation := func() error { + select { + case <-u.ctx.Done(): + return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers)) + default: + } + + var err error + for _, upstream := range u.upstreamServers { + ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) + _, _, err = u.upstreamClient.ExchangeContext(ctx, r, upstream) + + cancel() + + if err == nil { + return nil + } + } + + log.Tracef("checking connectivity with upstreams %s failed with error: %s. Retrying in %s", err, u.upstreamServers, exponentialBackOff.NextBackOff()) + return fmt.Errorf("got an error from upstream check call") + } + + err := backoff.Retry(operation, exponentialBackOff) + if err != nil { + log.Warn(err) + return + } + + log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) + u.failsCount.Store(0) + u.reactivate() + u.disabled = false } // isTimeout returns true if the given error is a network timeout error. diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 4915e777d..0a5de0b18 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -2,10 +2,11 @@ package dns import ( "context" - "github.com/miekg/dns" "strings" "testing" "time" + + "github.com/miekg/dns" ) func TestUpstreamResolver_ServeDNS(t *testing.T) { @@ -106,8 +107,29 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { } } +type mockUpstreamResolver struct { + r *dns.Msg + rtt time.Duration + err error +} + +// ExchangeContext mock implementation of ExchangeContext from upstreamResolver +func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) { + return c.r, c.rtt, c.err +} + func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - resolver := newUpstreamResolver(context.TODO()) + resolver := &upstreamResolver{ + ctx: context.TODO(), + upstreamClient: &mockUpstreamResolver{ + err: nil, + r: new(dns.Msg), + rtt: time.Millisecond, + }, + upstreamTimeout: upstreamTimeout, + reactivatePeriod: reactivatePeriod, + failsTillDeact: failsTillDeact, + } resolver.upstreamServers = []string{"0.0.0.0:-1"} resolver.failsTillDeact = 0 resolver.reactivatePeriod = time.Microsecond * 100