From eb45310c8f443cc3ba537ed6f5ffc5775f5e51d5 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Mon, 13 Feb 2023 18:25:11 +0400 Subject: [PATCH] Fix nameserver peer conn check (#676) * Disable upstream DNS resolver after several tries and fails * Add tests for upstream fails * Use an extra flag to disable domains in DNS upstreams * Fix hashing IPs of nameservers for updates. --- client/internal/dns/file_linux.go | 6 +- client/internal/dns/host.go | 7 +- client/internal/dns/host_darwin.go | 8 +- client/internal/dns/host_windows.go | 6 +- client/internal/dns/mock_test.go | 3 +- client/internal/dns/network_manager_linux.go | 10 +- client/internal/dns/resolvconf_linux.go | 7 +- client/internal/dns/server.go | 172 ++++++++++++++----- client/internal/dns/server_test.go | 82 ++++++++- client/internal/dns/systemd_linux.go | 10 +- client/internal/dns/upstream.go | 104 +++++++++-- client/internal/dns/upstream_test.go | 62 ++++++- 12 files changed, 389 insertions(+), 88 deletions(-) diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go index b9a7f4804..45ac18886 100644 --- a/client/internal/dns/file_linux.go +++ b/client/internal/dns/file_linux.go @@ -3,8 +3,9 @@ package dns import ( "bytes" "fmt" - log "github.com/sirupsen/logrus" "os" + + log "github.com/sirupsen/logrus" ) const ( @@ -14,6 +15,7 @@ const ( "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + fileGeneratedResolvConfSearchBeginContent + "%s\n" ) + const ( fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" fileMaxLineCharsLimit = 256 @@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error { var searchDomains string appendedDomains := 0 for _, dConf := range config.domains { - if dConf.matchOnly { + if dConf.matchOnly || dConf.disabled { continue } if appendedDomains >= fileMaxNumberOfSearchDomains { diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index c077e2032..756f0a515 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -2,8 +2,9 @@ package dns import ( "fmt" - nbdns "github.com/netbirdio/netbird/dns" "strings" + + nbdns "github.com/netbirdio/netbird/dns" ) type hostManager interface { @@ -19,6 +20,7 @@ type hostDNSConfig struct { } type domainConfig struct { + disabled bool domain string matchOnly bool } @@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD serverPort: port, } for _, nsConfig := range dnsConfig.NameServerGroups { + if len(nsConfig.NameServers) == 0 { + continue + } if nsConfig.Primary { config.routeAll = true } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 546561d88..9ced1fe48 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -4,11 +4,12 @@ import ( "bufio" "bytes" "fmt" - "github.com/netbirdio/netbird/iface" - log "github.com/sirupsen/logrus" "os/exec" "strconv" "strings" + + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" ) const ( @@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { ) for _, dConf := range config.domains { + if dConf.disabled { + continue + } if dConf.matchOnly { matchDomains = append(matchDomains, dConf.domain) continue diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index e3f6cf34c..3c241a6c1 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -2,10 +2,11 @@ package dns import ( "fmt" + "strings" + "github.com/netbirdio/netbird/iface" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows/registry" - "strings" ) const ( @@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { ) for _, dConf := range config.domains { + if dConf.disabled { + continue + } if !dConf.matchOnly { searchDomains = append(searchDomains, dConf.domain) } diff --git a/client/internal/dns/mock_test.go b/client/internal/dns/mock_test.go index 511e31dc8..d52ae24da 100644 --- a/client/internal/dns/mock_test.go +++ b/client/internal/dns/mock_test.go @@ -1,8 +1,9 @@ package dns import ( - "github.com/miekg/dns" "net" + + "github.com/miekg/dns" ) type mockResponseWriter struct { diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_linux.go index a668dc518..b9d5ff24f 100644 --- a/client/internal/dns/network_manager_linux.go +++ b/client/internal/dns/network_manager_linux.go @@ -4,14 +4,15 @@ import ( "context" "encoding/binary" "fmt" + "net/netip" + "regexp" + "time" + "github.com/godbus/dbus/v5" "github.com/hashicorp/go-version" "github.com/miekg/dns" "github.com/netbirdio/netbird/iface" log "github.com/sirupsen/logrus" - "net/netip" - "regexp" - "time" ) const ( @@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er matchDomains []string ) for _, dConf := range config.domains { + if dConf.disabled { + continue + } if dConf.matchOnly { matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain)) continue diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_linux.go index 7bd4511ac..c1802858d 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_linux.go @@ -2,10 +2,11 @@ package dns import ( "fmt" - "github.com/netbirdio/netbird/iface" - log "github.com/sirupsen/logrus" "os/exec" "strings" + + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" ) const resolvconfCommand = "resolvconf" @@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error { var searchDomains string appendedDomains := 0 for _, dConf := range config.domains { - if dConf.matchOnly { + if dConf.matchOnly || dConf.disabled { continue } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index b0195aa14..a267ab94e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -3,16 +3,17 @@ package dns import ( "context" "fmt" - "github.com/miekg/dns" - "github.com/mitchellh/hashstructure/v2" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - log "github.com/sirupsen/logrus" "net" "net/netip" "runtime" "sync" "time" + + "github.com/miekg/dns" + "github.com/mitchellh/hashstructure/v2" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" ) const ( @@ -32,7 +33,8 @@ type Server interface { // DefaultServer dns server object type DefaultServer struct { ctx context.Context - stop context.CancelFunc + ctxCancel context.CancelFunc + upstreamCtxCancel context.CancelFunc mux sync.Mutex server *dns.Server dnsMux *dns.ServeMux @@ -45,6 +47,7 @@ type DefaultServer struct { runtimePort int runtimeIP string previousConfigHash uint64 + currentConfig hostDNSConfig customAddress *netip.AddrPort } @@ -79,7 +82,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd defaultServer := &DefaultServer{ ctx: ctx, - stop: stop, + ctxCancel: stop, server: dnsServer, dnsMux: mux, dnsMuxMap: make(registrationMap), @@ -102,7 +105,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd // Start runs the listener in a go routine func (s *DefaultServer) Start() { - if s.customAddress != nil { s.runtimeIP = s.customAddress.Addr().String() s.runtimePort = int(s.customAddress.Port()) @@ -163,7 +165,7 @@ func (s *DefaultServer) setListenerStatus(running bool) { func (s *DefaultServer) Stop() { s.mux.Lock() defer s.mux.Unlock() - s.stop() + s.ctxCancel() err := s.hostManager.restoreHostDNS() if err != nil { @@ -209,6 +211,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro ZeroNil: true, IgnoreZeroValue: true, SlicesAsSets: true, + UseStringer: true, }) if err != nil { log.Errorf("unable to hash the dns configuration update, got error: %s", err) @@ -219,34 +222,9 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro s.updateSerial = serial return nil } - // is the service should be disabled, we stop the listener - // and proceed with a regular update to clean up the handlers and records - if !update.ServiceEnable { - err := s.stopListener() - if err != nil { - log.Error(err) - } - } else if !s.listenerIsRunning { - s.Start() - } - localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) - if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) - } - upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) - if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) - } - - muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) - - s.updateMux(muxUpdates) - s.updateLocalResolver(localRecords) - - err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)) - if err != nil { - log.Error(err) + if err := s.applyConfiguration(update); err != nil { + return err } s.updateSerial = serial @@ -256,6 +234,40 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro } } +func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { + // is the service should be disabled, we stop the listener + // and proceed with a regular update to clean up the handlers and records + if !update.ServiceEnable { + err := s.stopListener() + if err != nil { + log.Error(err) + } + } else if !s.listenerIsRunning { + s.Start() + } + + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + + muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) + + s.updateMux(muxUpdates) + s.updateLocalResolver(localRecords) + s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) + + if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + log.Error(err) + } + + return nil +} + func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { var muxUpdates []muxUpdate localRecords := make(map[string]nbdns.SimpleRecord, 0) @@ -284,16 +296,22 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) } func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { + // clean up the previous upstream resolver + if s.upstreamCtxCancel != nil { + s.upstreamCtxCancel() + } + var muxUpdates []muxUpdate for _, nsGroup := range nameServerGroups { if len(nsGroup.NameServers) == 0 { - return nil, fmt.Errorf("received a nameserver group with empty nameserver list") - } - handler := &upstreamResolver{ - parentCTX: s.ctx, - upstreamClient: &dns.Client{}, - upstreamTimeout: defaultUpstreamTimeout, + log.Warn("received a nameserver group with empty nameserver list") + continue } + + var ctx context.Context + ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx) + + handler := newUpstreamResolver(ctx) for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", @@ -308,6 +326,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam continue } + // when upstream fails to resolve domain several times over all it servers + // it will calls this hook to exclude self from the configuration and + // reapply DNS settings, but it not touch the original configuration and serial number + // because it is temporal deactivation until next try + // + // after some period defined by upstream it trys to reactivate self by calling this hook + // everything we need here is just to re-apply current configuration because it already + // contains this upstream settings (temporal deactivation not removed it) + handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) + if nsGroup.Primary { muxUpdates = append(muxUpdates, muxUpdate{ domain: nbdns.RootZone, @@ -382,3 +410,63 @@ func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) { func (s *DefaultServer) deregisterMux(pattern string) { s.dnsMux.HandleRemove(pattern) } + +// upstreamCallbacks returns two functions, the first one is used to deactivate +// the upstream resolver from the configuration, the second one is used to +// reactivate it. Not allowed to call reactivate before deactivate. +func (s *DefaultServer) upstreamCallbacks( + nsGroup *nbdns.NameServerGroup, + handler dns.Handler, +) (deactivate func(), reactivate func()) { + var removeIndex map[string]int + deactivate = func() { + s.mux.Lock() + defer s.mux.Unlock() + + l := log.WithField("nameservers", nsGroup.NameServers) + l.Info("temporary deactivate nameservers group due timeout") + + removeIndex = make(map[string]int) + for _, domain := range nsGroup.Domains { + removeIndex[domain] = -1 + } + if nsGroup.Primary { + removeIndex[nbdns.RootZone] = -1 + s.currentConfig.routeAll = false + } + + for i, item := range s.currentConfig.domains { + if _, found := removeIndex[item.domain]; found { + s.currentConfig.domains[i].disabled = true + s.deregisterMux(item.domain) + removeIndex[item.domain] = i + } + } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + l.WithError(err).Error("fail to apply nameserver deactivation on the host") + } + } + reactivate = func() { + s.mux.Lock() + defer s.mux.Unlock() + + for domain, i := range removeIndex { + if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { + continue + } + s.currentConfig.domains[i].disabled = false + s.registerMux(domain, handler) + } + + l := log.WithField("nameservers", nsGroup.NameServers) + l.Debug("reactivate temporary disabled nameserver group") + + if nsGroup.Primary { + s.currentConfig.routeAll = true + } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + } + } + return +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 62de05be5..395652733 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -3,13 +3,15 @@ package dns import ( "context" "fmt" + "net" + "net/netip" + "strings" + "testing" + "time" + "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" - "net" - "net/netip" - "testing" - "time" ) var zoneRecords = []nbdns.SimpleRecord{ @@ -23,7 +25,6 @@ var zoneRecords = []nbdns.SimpleRecord{ } func TestUpdateDNSServer(t *testing.T) { - nameServers := []nbdns.NameServer{ { IP: netip.MustParseAddr("8.8.8.8"), @@ -263,7 +264,6 @@ func TestUpdateDNSServer(t *testing.T) { } func TestDNSServerStartStop(t *testing.T) { - testCases := []struct { name string addrPort string @@ -333,6 +333,72 @@ func TestDNSServerStartStop(t *testing.T) { } } +func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { + hostManager := &mockHostConfigurator{} + server := DefaultServer{ + dnsMux: dns.DefaultServeMux, + localResolver: &localResolver{ + registeredMap: make(registrationMap), + }, + hostManager: hostManager, + currentConfig: hostDNSConfig{ + domains: []domainConfig{ + {false, "domain0", false}, + {false, "domain1", false}, + {false, "domain2", false}, + }, + }, + } + + var domainsUpdate string + hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error { + domains := []string{} + for _, item := range config.domains { + if item.disabled { + continue + } + domains = append(domains, item.domain) + } + domainsUpdate = strings.Join(domains, ",") + return nil + } + + deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ + Domains: []string{"domain1"}, + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + }, nil) + + deactivate() + expected := "domain0,domain2" + domains := []string{} + for _, item := range server.currentConfig.domains { + if item.disabled { + continue + } + domains = append(domains, item.domain) + } + got := strings.Join(domains, ",") + if expected != got { + t.Errorf("expected domains list: %q, got %q", expected, got) + } + + reactivate() + expected = "domain0,domain1,domain2" + domains = []string{} + for _, item := range server.currentConfig.domains { + if item.disabled { + continue + } + domains = append(domains, item.domain) + } + got = strings.Join(domains, ",") + if expected != got { + t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) + } +} + func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer { mux := dns.NewServeMux() @@ -351,11 +417,11 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe UDPSize: 65535, } - ctx, stop := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(context.TODO()) return &DefaultServer{ ctx: ctx, - stop: stop, + ctxCancel: cancel, server: dnsServer, dnsMux: mux, dnsMuxMap: make(registrationMap), diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index d61ef97b4..fbb16fe64 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -3,15 +3,16 @@ package dns import ( "context" "fmt" + "net" + "net/netip" + "time" + "github.com/godbus/dbus/v5" "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/iface" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" - "net" - "net/netip" - "time" ) const ( @@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { domainsInput []systemdDbusLinkDomainsInput ) for _, dConf := range config.domains { + if dConf.disabled { + continue + } domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ Domain: dns.Fqdn(dConf.domain), MatchOnly: dConf.matchOnly, diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index fcc8bc685..c2af76ecd 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -3,44 +3,73 @@ package dns import ( "context" "errors" + "net" + "sync" + "sync/atomic" + "time" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "net" - "time" ) -const defaultUpstreamTimeout = 15 * time.Second +const ( + failsTillDeact = int32(3) + reactivatePeriod = time.Minute + upstreamTimeout = 15 * time.Second +) type upstreamResolver struct { - parentCTX context.Context - upstreamClient *dns.Client - upstreamServers []string - upstreamTimeout time.Duration + ctx context.Context + upstreamClient *dns.Client + upstreamServers []string + disabled bool + failsCount atomic.Int32 + failsTillDeact int32 + mutex sync.Mutex + reactivatePeriod time.Duration + upstreamTimeout time.Duration + + deactivate func() + reactivate func() +} + +func newUpstreamResolver(ctx context.Context) *upstreamResolver { + return &upstreamResolver{ + ctx: ctx, + upstreamClient: &dns.Client{}, + upstreamTimeout: upstreamTimeout, + reactivatePeriod: reactivatePeriod, + failsTillDeact: failsTillDeact, + } } // ServeDNS handles a DNS request func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + defer u.checkUpstreamFails() - log.Tracef("received an upstream question: %#v", r.Question[0]) + log.WithField("question", r.Question[0]).Trace("received an upstream question") select { - case <-u.parentCTX.Done(): + case <-u.ctx.Done(): return default: } for _, upstream := range u.upstreamServers { - ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout) + ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream) cancel() if err != nil { if err == context.DeadlineExceeded || isTimeout(err) { - log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err) + log.WithError(err).WithField("upstream", upstream). + Warn("got an error while connecting to upstream") continue } - log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err) + u.failsCount.Add(1) + log.WithError(err).WithField("upstream", upstream). + Error("got an error while querying the upstream") return } @@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { err = w.WriteMsg(rm) if err != nil { - log.Errorf("got an error while writing the upstream resolver response, error: %v", err) + log.WithError(err).Error("got an error while writing the upstream resolver response") } + // count the fails only if they happen sequentially + u.failsCount.Store(0) return } - log.Errorf("all queries to the upstream nameservers failed with timeout") + u.failsCount.Add(1) + log.Error("all queries to the upstream nameservers failed with timeout") +} + +// checkUpstreamFails counts fails and disables or enables upstream resolving +// +// If fails count is greater that failsTillDeact, upstream resolving +// will be disabled for reactivatePeriod, after that time period fails counter +// will be reset and upstream will be reactivated. +func (u *upstreamResolver) checkUpstreamFails() { + u.mutex.Lock() + defer u.mutex.Unlock() + + if u.failsCount.Load() < u.failsTillDeact || u.disabled { + return + } + + select { + case <-u.ctx.Done(): + return + default: + log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) + u.deactivate() + u.disabled = true + go u.waitUntilReactivation() + } +} + +// waitUntilReactivation reset fails counter and activates upstream resolving +func (u *upstreamResolver) waitUntilReactivation() { + timer := time.NewTimer(u.reactivatePeriod) + defer func() { + if !timer.Stop() { + <-timer.C + } + }() + + select { + case <-u.ctx.Done(): + return + case <-timer.C: + log.Info("upstream resolving is reactivated") + u.failsCount.Store(0) + u.reactivate() + u.disabled = false + } } // isTimeout returns true if the given error is a network timeout error. diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 0fbb7e49f..4915e777d 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { name: "Should Resolve A Record", inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"}, - timeout: defaultUpstreamTimeout, + timeout: upstreamTimeout, expectedAnswer: "1.1.1.1", }, { @@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"}, cancelCTX: true, - timeout: defaultUpstreamTimeout, + timeout: upstreamTimeout, responseShouldBeNil: true, }, //{ @@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver := &upstreamResolver{ - parentCTX: ctx, - upstreamClient: &dns.Client{}, - upstreamServers: testCase.InputServers, - upstreamTimeout: testCase.timeout, - } + resolver := newUpstreamResolver(ctx) + resolver.upstreamServers = testCase.InputServers + resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { cancel() } else { @@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { }) } } + +func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { + resolver := newUpstreamResolver(context.TODO()) + resolver.upstreamServers = []string{"0.0.0.0:-1"} + resolver.failsTillDeact = 0 + resolver.reactivatePeriod = time.Microsecond * 100 + + responseWriter := &mockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { return nil }, + } + + failed := false + resolver.deactivate = func() { + failed = true + } + + reactivated := false + resolver.reactivate = func() { + reactivated = true + } + + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) + + if !failed { + t.Errorf("expected that resolving was deactivated") + return + } + + if !resolver.disabled { + t.Errorf("resolver should be disabled") + return + } + + time.Sleep(time.Millisecond * 200) + + if !reactivated { + t.Errorf("expected that resolving was reactivated") + return + } + + if resolver.failsCount.Load() != 0 { + t.Errorf("fails count after reactivation should be 0") + return + } + + if resolver.disabled { + t.Errorf("should be enabled") + } +}