mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-22 16:13:31 +01:00
Fix nameserver peer conn check (#676)
* Disable upstream DNS resolver after several tries and fails * Add tests for upstream fails * Use an extra flag to disable domains in DNS upstreams * Fix hashing IPs of nameservers for updates.
This commit is contained in:
parent
d5dfed498b
commit
eb45310c8f
@ -3,8 +3,9 @@ package dns
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -14,6 +15,7 @@ const (
|
||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||
)
|
||||
|
||||
const (
|
||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||
fileMaxLineCharsLimit = 256
|
||||
@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||
|
@ -2,8 +2,9 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"strings"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
@ -19,6 +20,7 @@ type hostDNSConfig struct {
|
||||
}
|
||||
|
||||
type domainConfig struct {
|
||||
disabled bool
|
||||
domain string
|
||||
matchOnly bool
|
||||
}
|
||||
@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
|
||||
serverPort: port,
|
||||
}
|
||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||
if len(nsConfig.NameServers) == 0 {
|
||||
continue
|
||||
}
|
||||
if nsConfig.Primary {
|
||||
config.routeAll = true
|
||||
}
|
||||
|
@ -4,11 +4,12 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, dConf.domain)
|
||||
continue
|
||||
|
@ -2,10 +2,11 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
)
|
||||
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if !dConf.matchOnly {
|
||||
searchDomains = append(searchDomains, dConf.domain)
|
||||
}
|
||||
|
@ -1,8 +1,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct {
|
||||
|
@ -4,14 +4,15 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
||||
matchDomains []string
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
if dConf.matchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||
continue
|
||||
|
@ -2,10 +2,11 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
||||
var searchDomains string
|
||||
appendedDomains := 0
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.matchOnly {
|
||||
if dConf.matchOnly || dConf.disabled {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -3,16 +3,17 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -32,7 +33,8 @@ type Server interface {
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
ctxCancel context.CancelFunc
|
||||
upstreamCtxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
@ -45,6 +47,7 @@ type DefaultServer struct {
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
}
|
||||
|
||||
@ -79,7 +82,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
stop: stop,
|
||||
ctxCancel: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
@ -102,7 +105,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *DefaultServer) Start() {
|
||||
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
@ -163,7 +165,7 @@ func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.stop()
|
||||
s.ctxCancel()
|
||||
|
||||
err := s.hostManager.restoreHostDNS()
|
||||
if err != nil {
|
||||
@ -209,6 +211,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
@ -219,34 +222,9 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
|
||||
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
@ -256,6 +234,40 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
@ -284,16 +296,22 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
// clean up the previous upstream resolver
|
||||
if s.upstreamCtxCancel != nil {
|
||||
s.upstreamCtxCancel()
|
||||
}
|
||||
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
|
||||
}
|
||||
handler := &upstreamResolver{
|
||||
parentCTX: s.ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamTimeout: defaultUpstreamTimeout,
|
||||
log.Warn("received a nameserver group with empty nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
var ctx context.Context
|
||||
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||
|
||||
handler := newUpstreamResolver(ctx)
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
@ -308,6 +326,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
continue
|
||||
}
|
||||
|
||||
// when upstream fails to resolve domain several times over all it servers
|
||||
// it will calls this hook to exclude self from the configuration and
|
||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||
// because it is temporal deactivation until next try
|
||||
//
|
||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||
// everything we need here is just to re-apply current configuration because it already
|
||||
// contains this upstream settings (temporal deactivation not removed it)
|
||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
@ -382,3 +410,63 @@ func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
func (s *DefaultServer) upstreamCallbacks(
|
||||
nsGroup *nbdns.NameServerGroup,
|
||||
handler dns.Handler,
|
||||
) (deactivate func(), reactivate func()) {
|
||||
var removeIndex map[string]int
|
||||
deactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Info("temporary deactivate nameservers group due timeout")
|
||||
|
||||
removeIndex = make(map[string]int)
|
||||
for _, domain := range nsGroup.Domains {
|
||||
removeIndex[domain] = -1
|
||||
}
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.routeAll = false
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||
}
|
||||
}
|
||||
reactivate = func() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
for domain, i := range removeIndex {
|
||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
l.Debug("reactivate temporary disabled nameserver group")
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.routeAll = true
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -3,13 +3,15 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
@ -23,7 +25,6 @@ var zoneRecords = []nbdns.SimpleRecord{
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
@ -263,7 +264,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
addrPort string
|
||||
@ -333,6 +333,72 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
dnsMux: dns.DefaultServeMux,
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
hostManager: hostManager,
|
||||
currentConfig: hostDNSConfig{
|
||||
domains: []domainConfig{
|
||||
{false, "domain0", false},
|
||||
{false, "domain1", false},
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.domains {
|
||||
if item.disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||
Domains: []string{"domain1"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
deactivate()
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.domains {
|
||||
if item.disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||
}
|
||||
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
for _, item := range server.currentConfig.domains {
|
||||
if item.disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
@ -351,11 +417,11 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(context.TODO())
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
return &DefaultServer{
|
||||
ctx: ctx,
|
||||
stop: stop,
|
||||
ctxCancel: cancel,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
|
@ -3,15 +3,16 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
||||
domainsInput []systemdDbusLinkDomainsInput
|
||||
)
|
||||
for _, dConf := range config.domains {
|
||||
if dConf.disabled {
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dns.Fqdn(dConf.domain),
|
||||
MatchOnly: dConf.matchOnly,
|
||||
|
@ -3,44 +3,73 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultUpstreamTimeout = 15 * time.Second
|
||||
const (
|
||||
failsTillDeact = int32(3)
|
||||
reactivatePeriod = time.Minute
|
||||
upstreamTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
parentCTX context.Context
|
||||
upstreamClient *dns.Client
|
||||
upstreamServers []string
|
||||
upstreamTimeout time.Duration
|
||||
ctx context.Context
|
||||
upstreamClient *dns.Client
|
||||
upstreamServers []string
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
failsTillDeact int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
|
||||
deactivate func()
|
||||
reactivate func()
|
||||
}
|
||||
|
||||
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
|
||||
return &upstreamResolver{
|
||||
ctx: ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
defer u.checkUpstreamFails()
|
||||
|
||||
log.Tracef("received an upstream question: %#v", r.Question[0])
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
|
||||
select {
|
||||
case <-u.parentCTX.Done():
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
|
||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Warn("got an error while connecting to upstream")
|
||||
continue
|
||||
}
|
||||
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
|
||||
u.failsCount.Add(1)
|
||||
log.WithError(err).WithField("upstream", upstream).
|
||||
Error("got an error while querying the upstream")
|
||||
return
|
||||
}
|
||||
|
||||
@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
if err != nil {
|
||||
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
|
||||
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
||||
}
|
||||
// count the fails only if they happen sequentially
|
||||
u.failsCount.Store(0)
|
||||
return
|
||||
}
|
||||
log.Errorf("all queries to the upstream nameservers failed with timeout")
|
||||
u.failsCount.Add(1)
|
||||
log.Error("all queries to the upstream nameservers failed with timeout")
|
||||
}
|
||||
|
||||
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||
//
|
||||
// If fails count is greater that failsTillDeact, upstream resolving
|
||||
// will be disabled for reactivatePeriod, after that time period fails counter
|
||||
// will be reset and upstream will be reactivated.
|
||||
func (u *upstreamResolver) checkUpstreamFails() {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||
u.deactivate()
|
||||
u.disabled = true
|
||||
go u.waitUntilReactivation()
|
||||
}
|
||||
}
|
||||
|
||||
// waitUntilReactivation reset fails counter and activates upstream resolving
|
||||
func (u *upstreamResolver) waitUntilReactivation() {
|
||||
timer := time.NewTimer(u.reactivatePeriod)
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
log.Info("upstream resolving is reactivated")
|
||||
u.failsCount.Store(0)
|
||||
u.reactivate()
|
||||
u.disabled = false
|
||||
}
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
|
@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: defaultUpstreamTimeout,
|
||||
timeout: upstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: defaultUpstreamTimeout,
|
||||
timeout: upstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
//{
|
||||
@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver := &upstreamResolver{
|
||||
parentCTX: ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamServers: testCase.InputServers,
|
||||
upstreamTimeout: testCase.timeout,
|
||||
}
|
||||
resolver := newUpstreamResolver(ctx)
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
} else {
|
||||
@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
resolver := newUpstreamResolver(context.TODO())
|
||||
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
||||
resolver.failsTillDeact = 0
|
||||
resolver.reactivatePeriod = time.Microsecond * 100
|
||||
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||
}
|
||||
|
||||
failed := false
|
||||
resolver.deactivate = func() {
|
||||
failed = true
|
||||
}
|
||||
|
||||
reactivated := false
|
||||
resolver.reactivate = func() {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
t.Errorf("resolver should be disabled")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
if !reactivated {
|
||||
t.Errorf("expected that resolving was reactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.failsCount.Load() != 0 {
|
||||
t.Errorf("fails count after reactivation should be 0")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user