Fix nameserver peer conn check (#676)

* Disable upstream DNS resolver after several tries and fails

* Add tests for upstream fails

* Use an extra flag to disable domains in DNS upstreams

* Fix hashing IPs of nameservers for updates.
This commit is contained in:
Givi Khojanashvili 2023-02-13 18:25:11 +04:00 committed by GitHub
parent d5dfed498b
commit eb45310c8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 389 additions and 88 deletions

View File

@ -3,8 +3,9 @@ package dns
import ( import (
"bytes" "bytes"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"os" "os"
log "github.com/sirupsen/logrus"
) )
const ( const (
@ -14,6 +15,7 @@ const (
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n" fileGeneratedResolvConfSearchBeginContent + "%s\n"
) )
const ( const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256 fileMaxLineCharsLimit = 256
@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
var searchDomains string var searchDomains string
appendedDomains := 0 appendedDomains := 0
for _, dConf := range config.domains { for _, dConf := range config.domains {
if dConf.matchOnly { if dConf.matchOnly || dConf.disabled {
continue continue
} }
if appendedDomains >= fileMaxNumberOfSearchDomains { if appendedDomains >= fileMaxNumberOfSearchDomains {

View File

@ -2,8 +2,9 @@ package dns
import ( import (
"fmt" "fmt"
nbdns "github.com/netbirdio/netbird/dns"
"strings" "strings"
nbdns "github.com/netbirdio/netbird/dns"
) )
type hostManager interface { type hostManager interface {
@ -19,6 +20,7 @@ type hostDNSConfig struct {
} }
type domainConfig struct { type domainConfig struct {
disabled bool
domain string domain string
matchOnly bool matchOnly bool
} }
@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
serverPort: port, serverPort: port,
} }
for _, nsConfig := range dnsConfig.NameServerGroups { for _, nsConfig := range dnsConfig.NameServerGroups {
if len(nsConfig.NameServers) == 0 {
continue
}
if nsConfig.Primary { if nsConfig.Primary {
config.routeAll = true config.routeAll = true
} }

View File

@ -4,11 +4,12 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
) )
const ( const (
@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
) )
for _, dConf := range config.domains { for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if dConf.matchOnly { if dConf.matchOnly {
matchDomains = append(matchDomains, dConf.domain) matchDomains = append(matchDomains, dConf.domain)
continue continue

View File

@ -2,10 +2,11 @@ package dns
import ( import (
"fmt" "fmt"
"strings"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"strings"
) )
const ( const (
@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
) )
for _, dConf := range config.domains { for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if !dConf.matchOnly { if !dConf.matchOnly {
searchDomains = append(searchDomains, dConf.domain) searchDomains = append(searchDomains, dConf.domain)
} }

View File

@ -1,8 +1,9 @@
package dns package dns
import ( import (
"github.com/miekg/dns"
"net" "net"
"github.com/miekg/dns"
) )
type mockResponseWriter struct { type mockResponseWriter struct {

View File

@ -4,14 +4,15 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net/netip"
"regexp"
"time"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/netip"
"regexp"
"time"
) )
const ( const (
@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
matchDomains []string matchDomains []string
) )
for _, dConf := range config.domains { for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if dConf.matchOnly { if dConf.matchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain)) matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
continue continue

View File

@ -2,10 +2,11 @@ package dns
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os/exec" "os/exec"
"strings" "strings"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
var searchDomains string var searchDomains string
appendedDomains := 0 appendedDomains := 0
for _, dConf := range config.domains { for _, dConf := range config.domains {
if dConf.matchOnly { if dConf.matchOnly || dConf.disabled {
continue continue
} }

View File

@ -3,16 +3,17 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
) )
const ( const (
@ -32,7 +33,8 @@ type Server interface {
// DefaultServer dns server object // DefaultServer dns server object
type DefaultServer struct { type DefaultServer struct {
ctx context.Context ctx context.Context
stop context.CancelFunc ctxCancel context.CancelFunc
upstreamCtxCancel context.CancelFunc
mux sync.Mutex mux sync.Mutex
server *dns.Server server *dns.Server
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
@ -45,6 +47,7 @@ type DefaultServer struct {
runtimePort int runtimePort int
runtimeIP string runtimeIP string
previousConfigHash uint64 previousConfigHash uint64
currentConfig hostDNSConfig
customAddress *netip.AddrPort customAddress *netip.AddrPort
} }
@ -79,7 +82,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
stop: stop, ctxCancel: stop,
server: dnsServer, server: dnsServer,
dnsMux: mux, dnsMux: mux,
dnsMuxMap: make(registrationMap), dnsMuxMap: make(registrationMap),
@ -102,7 +105,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
// Start runs the listener in a go routine // Start runs the listener in a go routine
func (s *DefaultServer) Start() { func (s *DefaultServer) Start() {
if s.customAddress != nil { if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String() s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port()) s.runtimePort = int(s.customAddress.Port())
@ -163,7 +165,7 @@ func (s *DefaultServer) setListenerStatus(running bool) {
func (s *DefaultServer) Stop() { func (s *DefaultServer) Stop() {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
s.stop() s.ctxCancel()
err := s.hostManager.restoreHostDNS() err := s.hostManager.restoreHostDNS()
if err != nil { if err != nil {
@ -209,6 +211,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
ZeroNil: true, ZeroNil: true,
IgnoreZeroValue: true, IgnoreZeroValue: true,
SlicesAsSets: true, SlicesAsSets: true,
UseStringer: true,
}) })
if err != nil { if err != nil {
log.Errorf("unable to hash the dns configuration update, got error: %s", err) log.Errorf("unable to hash the dns configuration update, got error: %s", err)
@ -219,6 +222,19 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
s.updateSerial = serial s.updateSerial = serial
return nil return nil
} }
if err := s.applyConfiguration(update); err != nil {
return err
}
s.updateSerial = serial
s.previousConfigHash = hash
return nil
}
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener // is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable { if !update.ServiceEnable {
@ -243,18 +259,14 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
s.updateMux(muxUpdates) s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords) s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)) if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
if err != nil {
log.Error(err) log.Error(err)
} }
s.updateSerial = serial
s.previousConfigHash = hash
return nil return nil
} }
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate var muxUpdates []muxUpdate
@ -284,16 +296,22 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
} }
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
// clean up the previous upstream resolver
if s.upstreamCtxCancel != nil {
s.upstreamCtxCancel()
}
var muxUpdates []muxUpdate var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups { for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 { if len(nsGroup.NameServers) == 0 {
return nil, fmt.Errorf("received a nameserver group with empty nameserver list") log.Warn("received a nameserver group with empty nameserver list")
} continue
handler := &upstreamResolver{
parentCTX: s.ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: defaultUpstreamTimeout,
} }
var ctx context.Context
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
handler := newUpstreamResolver(ctx)
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType { if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
@ -308,6 +326,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue continue
} }
// When the upstream fails to resolve a domain several times across all of its servers,
// it calls the deactivate hook to exclude itself from the configuration and
// re-apply the DNS settings. The original configuration and serial number are left untouched,
// because this is only a temporary deactivation until the next try.
//
// After a period defined by the upstream, it tries to reactivate itself by calling the reactivate hook.
// All that is needed there is to re-apply the current configuration, because it still
// contains this upstream's settings (the temporary deactivation did not remove them).
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
if nsGroup.Primary { if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{ muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone, domain: nbdns.RootZone,
@ -382,3 +410,63 @@ func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
func (s *DefaultServer) deregisterMux(pattern string) { func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern) s.dnsMux.HandleRemove(pattern)
} }
// upstreamCallbacks builds the pair of hooks handed to an upstream resolver.
//
// deactivate marks every domain of nsGroup found in s.currentConfig as disabled,
// removes its handler from the DNS mux and pushes the resulting configuration
// to the host manager. reactivate undoes that for the domains deactivate
// recorded, registering handler for them again, and re-applies the
// configuration. Both hooks take s.mux for their whole duration.
//
// reactivate is meant to be called only after deactivate: the index of
// disabled domains it walks is filled in by deactivate.
func (s *DefaultServer) upstreamCallbacks(
	nsGroup *nbdns.NameServerGroup,
	handler dns.Handler,
) (deactivate func(), reactivate func()) {
	// disabledAt is shared by both closures: domain -> position inside
	// s.currentConfig.domains, or -1 when the domain was not found there.
	var disabledAt map[string]int

	deactivate = func() {
		s.mux.Lock()
		defer s.mux.Unlock()

		logger := log.WithField("nameservers", nsGroup.NameServers)
		logger.Info("temporary deactivate nameservers group due timeout")

		disabledAt = make(map[string]int)
		for _, name := range nsGroup.Domains {
			disabledAt[name] = -1
		}
		if nsGroup.Primary {
			disabledAt[nbdns.RootZone] = -1
			s.currentConfig.routeAll = false
		}

		for idx := range s.currentConfig.domains {
			entry := &s.currentConfig.domains[idx]
			if _, tracked := disabledAt[entry.domain]; !tracked {
				continue
			}
			entry.disabled = true
			s.deregisterMux(entry.domain)
			// remember where the domain lives so reactivate can find it again
			disabledAt[entry.domain] = idx
		}

		err := s.hostManager.applyDNSConfig(s.currentConfig)
		if err != nil {
			logger.WithError(err).Error("fail to apply nameserver deactivation on the host")
		}
	}

	reactivate = func() {
		s.mux.Lock()
		defer s.mux.Unlock()

		for name, idx := range disabledAt {
			// only touch entries that were found and still sit at the recorded position
			inRange := idx != -1 && idx < len(s.currentConfig.domains)
			if inRange && s.currentConfig.domains[idx].domain == name {
				s.currentConfig.domains[idx].disabled = false
				s.registerMux(name, handler)
			}
		}

		logger := log.WithField("nameservers", nsGroup.NameServers)
		logger.Debug("reactivate temporary disabled nameserver group")

		if nsGroup.Primary {
			s.currentConfig.routeAll = true
		}

		err := s.hostManager.applyDNSConfig(s.currentConfig)
		if err != nil {
			logger.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
		}
	}

	return deactivate, reactivate
}

View File

@ -3,13 +3,15 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"net"
"net/netip"
"testing"
"time"
) )
var zoneRecords = []nbdns.SimpleRecord{ var zoneRecords = []nbdns.SimpleRecord{
@ -23,7 +25,6 @@ var zoneRecords = []nbdns.SimpleRecord{
} }
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
IP: netip.MustParseAddr("8.8.8.8"), IP: netip.MustParseAddr("8.8.8.8"),
@ -263,7 +264,6 @@ func TestUpdateDNSServer(t *testing.T) {
} }
func TestDNSServerStartStop(t *testing.T) { func TestDNSServerStartStop(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
addrPort string addrPort string
@ -333,6 +333,72 @@ func TestDNSServerStartStop(t *testing.T) {
} }
} }
// TestDNSServerUpstreamDeactivateCallback checks the hooks returned by
// upstreamCallbacks: deactivate must disable the nameserver group's domain in
// the server's current configuration, and reactivate must enable it again.
// After each hook it also checks that the configuration handed to the host
// manager carries the same set of active domains.
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
	hostManager := &mockHostConfigurator{}
	server := DefaultServer{
		dnsMux: dns.DefaultServeMux,
		localResolver: &localResolver{
			registeredMap: make(registrationMap),
		},
		hostManager: hostManager,
		currentConfig: hostDNSConfig{
			domains: []domainConfig{
				{domain: "domain0"},
				{domain: "domain1"},
				{domain: "domain2"},
			},
		},
	}

	// activeDomains renders the non-disabled domains of a config as a
	// comma-separated list so they can be compared as a single string.
	activeDomains := func(config hostDNSConfig) string {
		var active []string
		for _, item := range config.domains {
			if item.disabled {
				continue
			}
			active = append(active, item.domain)
		}
		return strings.Join(active, ",")
	}

	// domainsUpdate records what the host manager was last asked to apply.
	var domainsUpdate string
	hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
		domainsUpdate = activeDomains(config)
		return nil
	}

	deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
		Domains: []string{"domain1"},
		NameServers: []nbdns.NameServer{
			{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
		},
	}, nil)

	deactivate()
	expected := "domain0,domain2"
	if got := activeDomains(server.currentConfig); expected != got {
		t.Errorf("expected domains list: %q, got %q", expected, got)
	}
	if expected != domainsUpdate {
		t.Errorf("expected domains list applied on the host: %q, got %q", expected, domainsUpdate)
	}

	reactivate()
	expected = "domain0,domain1,domain2"
	if got := activeDomains(server.currentConfig); expected != got {
		// report the value that was actually compared
		t.Errorf("expected domains list: %q, got %q", expected, got)
	}
	if expected != domainsUpdate {
		t.Errorf("expected domains list applied on the host: %q, got %q", expected, domainsUpdate)
	}
}
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer { func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
mux := dns.NewServeMux() mux := dns.NewServeMux()
@ -351,11 +417,11 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
UDPSize: 65535, UDPSize: 65535,
} }
ctx, stop := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
return &DefaultServer{ return &DefaultServer{
ctx: ctx, ctx: ctx,
stop: stop, ctxCancel: cancel,
server: dnsServer, server: dnsServer,
dnsMux: mux, dnsMux: mux,
dnsMuxMap: make(registrationMap), dnsMuxMap: make(registrationMap),

View File

@ -3,15 +3,16 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip"
"time"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net"
"net/netip"
"time"
) )
const ( const (
@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
domainsInput []systemdDbusLinkDomainsInput domainsInput []systemdDbusLinkDomainsInput
) )
for _, dConf := range config.domains { for _, dConf := range config.domains {
if dConf.disabled {
continue
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.domain), Domain: dns.Fqdn(dConf.domain),
MatchOnly: dConf.matchOnly, MatchOnly: dConf.matchOnly,

View File

@ -3,44 +3,73 @@ package dns
import ( import (
"context" "context"
"errors" "errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net"
"time"
) )
const defaultUpstreamTimeout = 15 * time.Second const (
failsTillDeact = int32(3)
reactivatePeriod = time.Minute
upstreamTimeout = 15 * time.Second
)
type upstreamResolver struct { type upstreamResolver struct {
parentCTX context.Context ctx context.Context
upstreamClient *dns.Client upstreamClient *dns.Client
upstreamServers []string upstreamServers []string
disabled bool
failsCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration upstreamTimeout time.Duration
deactivate func()
reactivate func()
}
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
return &upstreamResolver{
ctx: ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
} }
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails()
log.Tracef("received an upstream question: %#v", r.Question[0]) log.WithField("question", r.Question[0]).Trace("received an upstream question")
select { select {
case <-u.parentCTX.Done(): case <-u.ctx.Done():
return return
default: default:
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream) rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel() cancel()
if err != nil { if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) { if err == context.DeadlineExceeded || isTimeout(err) {
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err) log.WithError(err).WithField("upstream", upstream).
Warn("got an error while connecting to upstream")
continue continue
} }
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err) u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream).
Error("got an error while querying the upstream")
return return
} }
@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
err = w.WriteMsg(rm) err = w.WriteMsg(rm)
if err != nil { if err != nil {
log.Errorf("got an error while writing the upstream resolver response, error: %v", err) log.WithError(err).Error("got an error while writing the upstream resolver response")
} }
// count the fails only if they happen sequentially
u.failsCount.Store(0)
return return
} }
log.Errorf("all queries to the upstream nameservers failed with timeout") u.failsCount.Add(1)
log.Error("all queries to the upstream nameservers failed with timeout")
}
// checkUpstreamFails inspects the fails counter and disables upstream
// resolving when needed.
//
// Once the fails count reaches failsTillDeact (and the resolver is not already
// disabled, and its context is still alive), the deactivate callback is
// invoked, the resolver is marked as disabled and a goroutine is started that
// re-enables it after reactivatePeriod.
func (u *upstreamResolver) checkUpstreamFails() {
	u.mutex.Lock()
	defer u.mutex.Unlock()

	if u.failsCount.Load() < u.failsTillDeact || u.disabled {
		return
	}

	select {
	case <-u.ctx.Done():
		// the resolver is being torn down, nothing to deactivate
		return
	default:
		// report the period configured on this resolver, not the package default
		log.Warnf("upstream resolving is disabled for %v", u.reactivatePeriod)
		// the callback is optional: newUpstreamResolver leaves it unset
		if u.deactivate != nil {
			u.deactivate()
		}
		u.disabled = true
		go u.waitUntilReactivation()
	}
}
// waitUntilReactivation reset fails counter and activates upstream resolving
func (u *upstreamResolver) waitUntilReactivation() {
timer := time.NewTimer(u.reactivatePeriod)
defer func() {
if !timer.Stop() {
<-timer.C
}
}()
select {
case <-u.ctx.Done():
return
case <-timer.C:
log.Info("upstream resolving is reactivated")
u.failsCount.Store(0)
u.reactivate()
u.disabled = false
}
} }
// isTimeout returns true if the given error is a network timeout error. // isTimeout returns true if the given error is a network timeout error.

View File

@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
name: "Should Resolve A Record", name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"}, InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: defaultUpstreamTimeout, timeout: upstreamTimeout,
expectedAnswer: "1.1.1.1", expectedAnswer: "1.1.1.1",
}, },
{ {
@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"}, InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true, cancelCTX: true,
timeout: defaultUpstreamTimeout, timeout: upstreamTimeout,
responseShouldBeNil: true, responseShouldBeNil: true,
}, },
//{ //{
@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver := &upstreamResolver{ resolver := newUpstreamResolver(ctx)
parentCTX: ctx, resolver.upstreamServers = testCase.InputServers
upstreamClient: &dns.Client{}, resolver.upstreamTimeout = testCase.timeout
upstreamServers: testCase.InputServers,
upstreamTimeout: testCase.timeout,
}
if testCase.cancelCTX { if testCase.cancelCTX {
cancel() cancel()
} else { } else {
@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}) })
} }
} }
// TestUpstreamResolver_DeactivationReactivation drives a resolver whose only
// upstream address is invalid and whose failure threshold is zero, so a single
// query must trigger the deactivate callback. It then waits longer than the
// (very short) reactivation period and expects the reactivate callback to have
// run, the fails counter to be reset and the disabled flag to be cleared.
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
	resolver := newUpstreamResolver(context.TODO())
	resolver.upstreamServers = []string{"0.0.0.0:-1"}
	resolver.failsTillDeact = 0
	resolver.reactivatePeriod = time.Microsecond * 100

	writer := &mockResponseWriter{
		WriteMsgFunc: func(m *dns.Msg) error { return nil },
	}

	var failed, reactivated bool
	resolver.deactivate = func() { failed = true }
	resolver.reactivate = func() { reactivated = true }

	query := new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)
	resolver.ServeDNS(writer, query)

	if !failed {
		t.Fatal("expected that resolving was deactivated")
	}
	if !resolver.disabled {
		t.Fatal("resolver should be disabled")
	}

	// give the reactivation goroutine ample time to fire
	time.Sleep(time.Millisecond * 200)

	if !reactivated {
		t.Fatal("expected that resolving was reactivated")
	}
	if resolver.failsCount.Load() != 0 {
		t.Fatal("fails count after reactivation should be 0")
	}
	if resolver.disabled {
		t.Error("should be enabled")
	}
}