Fix nameserver peer conn check (#676)

* Disable upstream DNS resolver after several tries and fails

* Add tests for upstream fails

* Use an extra flag to disable domains in DNS upstreams

* Fix hashing IPs of nameservers for updates.
This commit is contained in:
Givi Khojanashvili 2023-02-13 18:25:11 +04:00 committed by GitHub
parent d5dfed498b
commit eb45310c8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 389 additions and 88 deletions

View File

@ -3,8 +3,9 @@ package dns
import (
"bytes"
"fmt"
log "github.com/sirupsen/logrus"
"os"
log "github.com/sirupsen/logrus"
)
const (
@ -14,6 +15,7 @@ const (
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n"
)
const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly {
if dConf.matchOnly || dConf.disabled {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {

View File

@ -2,8 +2,9 @@ package dns
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
"strings"
nbdns "github.com/netbirdio/netbird/dns"
)
type hostManager interface {
@ -19,6 +20,7 @@ type hostDNSConfig struct {
}
type domainConfig struct {
disabled bool
domain string
matchOnly bool
}
@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
serverPort: port,
}
for _, nsConfig := range dnsConfig.NameServerGroups {
if len(nsConfig.NameServers) == 0 {
continue
}
if nsConfig.Primary {
config.routeAll = true
}

View File

@ -4,11 +4,12 @@ import (
"bufio"
"bytes"
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os/exec"
"strconv"
"strings"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
)
const (
@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
)
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if dConf.matchOnly {
matchDomains = append(matchDomains, dConf.domain)
continue

View File

@ -2,10 +2,11 @@ package dns
import (
"fmt"
"strings"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"strings"
)
const (
@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
)
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if !dConf.matchOnly {
searchDomains = append(searchDomains, dConf.domain)
}

View File

@ -1,8 +1,9 @@
package dns
import (
"github.com/miekg/dns"
"net"
"github.com/miekg/dns"
)
type mockResponseWriter struct {

View File

@ -4,14 +4,15 @@ import (
"context"
"encoding/binary"
"fmt"
"net/netip"
"regexp"
"time"
"github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"net/netip"
"regexp"
"time"
)
const (
@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
matchDomains []string
)
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
if dConf.matchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
continue

View File

@ -2,10 +2,11 @@ package dns
import (
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os/exec"
"strings"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
)
const resolvconfCommand = "resolvconf"
@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly {
if dConf.matchOnly || dConf.disabled {
continue
}

View File

@ -3,16 +3,17 @@ package dns
import (
"context"
"fmt"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
)
const (
@ -32,7 +33,8 @@ type Server interface {
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
stop context.CancelFunc
ctxCancel context.CancelFunc
upstreamCtxCancel context.CancelFunc
mux sync.Mutex
server *dns.Server
dnsMux *dns.ServeMux
@ -45,6 +47,7 @@ type DefaultServer struct {
runtimePort int
runtimeIP string
previousConfigHash uint64
currentConfig hostDNSConfig
customAddress *netip.AddrPort
}
@ -79,7 +82,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
defaultServer := &DefaultServer{
ctx: ctx,
stop: stop,
ctxCancel: stop,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),
@ -102,7 +105,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
@ -163,7 +165,7 @@ func (s *DefaultServer) setListenerStatus(running bool) {
func (s *DefaultServer) Stop() {
s.mux.Lock()
defer s.mux.Unlock()
s.stop()
s.ctxCancel()
err := s.hostManager.restoreHostDNS()
if err != nil {
@ -209,6 +211,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
@ -219,34 +222,9 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
s.updateSerial = serial
return nil
}
// is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.Start()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort))
if err != nil {
log.Error(err)
if err := s.applyConfiguration(update); err != nil {
return err
}
s.updateSerial = serial
@ -256,6 +234,40 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
}
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.Start()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
log.Error(err)
}
return nil
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
var muxUpdates []muxUpdate
localRecords := make(map[string]nbdns.SimpleRecord, 0)
@ -284,16 +296,22 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
// clean up the previous upstream resolver
if s.upstreamCtxCancel != nil {
s.upstreamCtxCancel()
}
var muxUpdates []muxUpdate
for _, nsGroup := range nameServerGroups {
if len(nsGroup.NameServers) == 0 {
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
}
handler := &upstreamResolver{
parentCTX: s.ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: defaultUpstreamTimeout,
log.Warn("received a nameserver group with empty nameserver list")
continue
}
var ctx context.Context
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
handler := newUpstreamResolver(ctx)
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
@ -308,6 +326,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue
}
// when upstream fails to resolve domain several times over all it servers
// it will calls this hook to exclude self from the configuration and
// reapply DNS settings, but it not touch the original configuration and serial number
// because it is temporal deactivation until next try
//
// after some period defined by upstream it trys to reactivate self by calling this hook
// everything we need here is just to re-apply current configuration because it already
// contains this upstream settings (temporal deactivation not removed it)
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
@ -382,3 +410,63 @@ func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate.
func (s *DefaultServer) upstreamCallbacks(
nsGroup *nbdns.NameServerGroup,
handler dns.Handler,
) (deactivate func(), reactivate func()) {
var removeIndex map[string]int
deactivate = func() {
s.mux.Lock()
defer s.mux.Unlock()
l := log.WithField("nameservers", nsGroup.NameServers)
l.Info("temporary deactivate nameservers group due timeout")
removeIndex = make(map[string]int)
for _, domain := range nsGroup.Domains {
removeIndex[domain] = -1
}
if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1
s.currentConfig.routeAll = false
}
for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true
s.deregisterMux(item.domain)
removeIndex[item.domain] = i
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
}
}
reactivate = func() {
s.mux.Lock()
defer s.mux.Unlock()
for domain, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
continue
}
s.currentConfig.domains[i].disabled = false
s.registerMux(domain, handler)
}
l := log.WithField("nameservers", nsGroup.NameServers)
l.Debug("reactivate temporary disabled nameserver group")
if nsGroup.Primary {
s.currentConfig.routeAll = true
}
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
}
return
}

View File

@ -3,13 +3,15 @@ package dns
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"net"
"net/netip"
"testing"
"time"
)
var zoneRecords = []nbdns.SimpleRecord{
@ -23,7 +25,6 @@ var zoneRecords = []nbdns.SimpleRecord{
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
@ -263,7 +264,6 @@ func TestUpdateDNSServer(t *testing.T) {
}
func TestDNSServerStartStop(t *testing.T) {
testCases := []struct {
name string
addrPort string
@ -333,6 +333,72 @@ func TestDNSServerStartStop(t *testing.T) {
}
}
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
dnsMux: dns.DefaultServeMux,
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
hostManager: hostManager,
currentConfig: hostDNSConfig{
domains: []domainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
},
},
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
domains := []string{}
for _, item := range config.domains {
if item.disabled {
continue
}
domains = append(domains, item.domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil)
deactivate()
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.domains {
if item.disabled {
continue
}
domains = append(domains, item.domain)
}
got := strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.domains {
if item.disabled {
continue
}
domains = append(domains, item.domain)
}
got = strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
}
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
mux := dns.NewServeMux()
@ -351,11 +417,11 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
UDPSize: 65535,
}
ctx, stop := context.WithCancel(context.TODO())
ctx, cancel := context.WithCancel(context.TODO())
return &DefaultServer{
ctx: ctx,
stop: stop,
ctxCancel: cancel,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registrationMap),

View File

@ -3,15 +3,16 @@ package dns
import (
"context"
"fmt"
"net"
"net/netip"
"time"
"github.com/godbus/dbus/v5"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"net"
"net/netip"
"time"
)
const (
@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
domainsInput []systemdDbusLinkDomainsInput
)
for _, dConf := range config.domains {
if dConf.disabled {
continue
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.domain),
MatchOnly: dConf.matchOnly,

View File

@ -3,44 +3,73 @@ package dns
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"net"
"time"
)
const defaultUpstreamTimeout = 15 * time.Second
const (
failsTillDeact = int32(3)
reactivatePeriod = time.Minute
upstreamTimeout = 15 * time.Second
)
type upstreamResolver struct {
parentCTX context.Context
upstreamClient *dns.Client
upstreamServers []string
upstreamTimeout time.Duration
ctx context.Context
upstreamClient *dns.Client
upstreamServers []string
disabled bool
failsCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
deactivate func()
reactivate func()
}
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
return &upstreamResolver{
ctx: ctx,
upstreamClient: &dns.Client{},
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
}
// ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails()
log.Tracef("received an upstream question: %#v", r.Question[0])
log.WithField("question", r.Question[0]).Trace("received an upstream question")
select {
case <-u.parentCTX.Done():
case <-u.ctx.Done():
return
default:
}
for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) {
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
log.WithError(err).WithField("upstream", upstream).
Warn("got an error while connecting to upstream")
continue
}
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream).
Error("got an error while querying the upstream")
return
}
@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
err = w.WriteMsg(rm)
if err != nil {
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
log.WithError(err).Error("got an error while writing the upstream resolver response")
}
// count the fails only if they happen sequentially
u.failsCount.Store(0)
return
}
log.Errorf("all queries to the upstream nameservers failed with timeout")
u.failsCount.Add(1)
log.Error("all queries to the upstream nameservers failed with timeout")
}
// checkUpstreamFails counts fails and disables or enables upstream resolving
//
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolver) checkUpstreamFails() {
u.mutex.Lock()
defer u.mutex.Unlock()
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
return
}
select {
case <-u.ctx.Done():
return
default:
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
u.deactivate()
u.disabled = true
go u.waitUntilReactivation()
}
}
// waitUntilReactivation reset fails counter and activates upstream resolving
func (u *upstreamResolver) waitUntilReactivation() {
timer := time.NewTimer(u.reactivatePeriod)
defer func() {
if !timer.Stop() {
<-timer.C
}
}()
select {
case <-u.ctx.Done():
return
case <-timer.C:
log.Info("upstream resolving is reactivated")
u.failsCount.Store(0)
u.reactivate()
u.disabled = false
}
}
// isTimeout returns true if the given error is a network timeout error.

View File

@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: defaultUpstreamTimeout,
timeout: upstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: defaultUpstreamTimeout,
timeout: upstreamTimeout,
responseShouldBeNil: true,
},
//{
@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := &upstreamResolver{
parentCTX: ctx,
upstreamClient: &dns.Client{},
upstreamServers: testCase.InputServers,
upstreamTimeout: testCase.timeout,
}
resolver := newUpstreamResolver(ctx)
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
} else {
@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
})
}
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver := newUpstreamResolver(context.TODO())
resolver.upstreamServers = []string{"0.0.0.0:-1"}
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}
failed := false
resolver.deactivate = func() {
failed = true
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.failsCount.Load() != 0 {
t.Errorf("fails count after reactivation should be 0")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
}
}