mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-10 18:58:27 +02:00
Fix nameserver peer conn check (#676)
* Disable upstream DNS resolver after several tries and fails * Add tests for upstream fails * Use an extra flag to disable domains in DNS upstreams * Fix hashing IPs of nameservers for updates.
This commit is contained in:
parent
d5dfed498b
commit
eb45310c8f
@ -3,8 +3,9 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -14,6 +15,7 @@ const (
|
|||||||
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
|
||||||
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
fileGeneratedResolvConfSearchBeginContent + "%s\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
||||||
fileMaxLineCharsLimit = 256
|
fileMaxLineCharsLimit = 256
|
||||||
@ -66,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
var searchDomains string
|
var searchDomains string
|
||||||
appendedDomains := 0
|
appendedDomains := 0
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly || dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
if appendedDomains >= fileMaxNumberOfSearchDomains {
|
||||||
|
@ -2,8 +2,9 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
@ -19,6 +20,7 @@ type hostDNSConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type domainConfig struct {
|
type domainConfig struct {
|
||||||
|
disabled bool
|
||||||
domain string
|
domain string
|
||||||
matchOnly bool
|
matchOnly bool
|
||||||
}
|
}
|
||||||
@ -56,6 +58,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostD
|
|||||||
serverPort: port,
|
serverPort: port,
|
||||||
}
|
}
|
||||||
for _, nsConfig := range dnsConfig.NameServerGroups {
|
for _, nsConfig := range dnsConfig.NameServerGroups {
|
||||||
|
if len(nsConfig.NameServers) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if nsConfig.Primary {
|
if nsConfig.Primary {
|
||||||
config.routeAll = true
|
config.routeAll = true
|
||||||
}
|
}
|
||||||
|
@ -4,11 +4,12 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -61,6 +62,9 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.domain)
|
matchDomains = append(matchDomains, dConf.domain)
|
||||||
continue
|
continue
|
||||||
|
@ -2,10 +2,11 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -63,6 +64,9 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !dConf.matchOnly {
|
if !dConf.matchOnly {
|
||||||
searchDomains = append(searchDomains, dConf.domain)
|
searchDomains = append(searchDomains, dConf.domain)
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/miekg/dns"
|
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockResponseWriter struct {
|
type mockResponseWriter struct {
|
||||||
|
@ -4,14 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net/netip"
|
|
||||||
"regexp"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -106,6 +107,9 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
|
|||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain))
|
||||||
continue
|
continue
|
||||||
|
@ -2,10 +2,11 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@ -33,7 +34,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
var searchDomains string
|
var searchDomains string
|
||||||
appendedDomains := 0
|
appendedDomains := 0
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
if dConf.matchOnly {
|
if dConf.matchOnly || dConf.disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,16 +3,17 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -32,7 +33,8 @@ type Server interface {
|
|||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
|
upstreamCtxCancel context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
server *dns.Server
|
server *dns.Server
|
||||||
dnsMux *dns.ServeMux
|
dnsMux *dns.ServeMux
|
||||||
@ -45,6 +47,7 @@ type DefaultServer struct {
|
|||||||
runtimePort int
|
runtimePort int
|
||||||
runtimeIP string
|
runtimeIP string
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
|
currentConfig hostDNSConfig
|
||||||
customAddress *netip.AddrPort
|
customAddress *netip.AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,7 +82,7 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: stop,
|
ctxCancel: stop,
|
||||||
server: dnsServer,
|
server: dnsServer,
|
||||||
dnsMux: mux,
|
dnsMux: mux,
|
||||||
dnsMuxMap: make(registrationMap),
|
dnsMuxMap: make(registrationMap),
|
||||||
@ -102,7 +105,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
|
|
||||||
// Start runs the listener in a go routine
|
// Start runs the listener in a go routine
|
||||||
func (s *DefaultServer) Start() {
|
func (s *DefaultServer) Start() {
|
||||||
|
|
||||||
if s.customAddress != nil {
|
if s.customAddress != nil {
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
@ -163,7 +165,7 @@ func (s *DefaultServer) setListenerStatus(running bool) {
|
|||||||
func (s *DefaultServer) Stop() {
|
func (s *DefaultServer) Stop() {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
s.stop()
|
s.ctxCancel()
|
||||||
|
|
||||||
err := s.hostManager.restoreHostDNS()
|
err := s.hostManager.restoreHostDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -209,6 +211,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
ZeroNil: true,
|
ZeroNil: true,
|
||||||
IgnoreZeroValue: true,
|
IgnoreZeroValue: true,
|
||||||
SlicesAsSets: true,
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
@ -219,34 +222,9 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
s.updateSerial = serial
|
s.updateSerial = serial
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// if the service should be disabled, we stop the listener
|
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
|
||||||
if !update.ServiceEnable {
|
|
||||||
err := s.stopListener()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
|
||||||
s.updateLocalResolver(localRecords)
|
|
||||||
|
|
||||||
err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort))
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateSerial = serial
|
s.updateSerial = serial
|
||||||
@ -256,6 +234,40 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
|
// if the service should be disabled, we stop the listener
|
||||||
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
|
if !update.ServiceEnable {
|
||||||
|
err := s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
} else if !s.listenerIsRunning {
|
||||||
|
s.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
|
s.updateMux(muxUpdates)
|
||||||
|
s.updateLocalResolver(localRecords)
|
||||||
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
if err = s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||||
var muxUpdates []muxUpdate
|
var muxUpdates []muxUpdate
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
@ -284,16 +296,22 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||||
|
// clean up the previous upstream resolver
|
||||||
|
if s.upstreamCtxCancel != nil {
|
||||||
|
s.upstreamCtxCancel()
|
||||||
|
}
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
var muxUpdates []muxUpdate
|
||||||
for _, nsGroup := range nameServerGroups {
|
for _, nsGroup := range nameServerGroups {
|
||||||
if len(nsGroup.NameServers) == 0 {
|
if len(nsGroup.NameServers) == 0 {
|
||||||
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
}
|
continue
|
||||||
handler := &upstreamResolver{
|
|
||||||
parentCTX: s.ctx,
|
|
||||||
upstreamClient: &dns.Client{},
|
|
||||||
upstreamTimeout: defaultUpstreamTimeout,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, s.upstreamCtxCancel = context.WithCancel(s.ctx)
|
||||||
|
|
||||||
|
handler := newUpstreamResolver(ctx)
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||||
@ -308,6 +326,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when the upstream fails to resolve a domain several times across all of its servers,
|
||||||
|
// it calls the deactivate hook to exclude itself from the configuration and
|
||||||
|
// reapply the DNS settings; this does not touch the original configuration or serial number
|
||||||
|
// because it is only a temporary deactivation until the next try
|
||||||
|
//
|
||||||
|
// after a period defined by the upstream, it tries to reactivate itself by calling the reactivate hook;
|
||||||
|
// all we need to do then is re-apply the current configuration, because it already
|
||||||
|
// contains this upstream's settings (the temporary deactivation did not remove them)
|
||||||
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
@ -382,3 +410,63 @@ func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||||
s.dnsMux.HandleRemove(pattern)
|
s.dnsMux.HandleRemove(pattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
|
// reactivate it. Callers must not invoke reactivate before deactivate.
|
||||||
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
|
nsGroup *nbdns.NameServerGroup,
|
||||||
|
handler dns.Handler,
|
||||||
|
) (deactivate func(), reactivate func()) {
|
||||||
|
var removeIndex map[string]int
|
||||||
|
deactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Info("temporary deactivate nameservers group due timeout")
|
||||||
|
|
||||||
|
removeIndex = make(map[string]int)
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
removeIndex[domain] = -1
|
||||||
|
}
|
||||||
|
if nsGroup.Primary {
|
||||||
|
removeIndex[nbdns.RootZone] = -1
|
||||||
|
s.currentConfig.routeAll = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range s.currentConfig.domains {
|
||||||
|
if _, found := removeIndex[item.domain]; found {
|
||||||
|
s.currentConfig.domains[i].disabled = true
|
||||||
|
s.deregisterMux(item.domain)
|
||||||
|
removeIndex[item.domain] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
for domain, i := range removeIndex {
|
||||||
|
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.currentConfig.domains[i].disabled = false
|
||||||
|
s.registerMux(domain, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Debug("reactivate temporary disabled nameserver group")
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
s.currentConfig.routeAll = true
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -3,13 +3,15 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
@ -23,7 +25,6 @@ var zoneRecords = []nbdns.SimpleRecord{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
@ -263,7 +264,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
addrPort string
|
addrPort string
|
||||||
@ -333,6 +333,72 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
|
hostManager := &mockHostConfigurator{}
|
||||||
|
server := DefaultServer{
|
||||||
|
dnsMux: dns.DefaultServeMux,
|
||||||
|
localResolver: &localResolver{
|
||||||
|
registeredMap: make(registrationMap),
|
||||||
|
},
|
||||||
|
hostManager: hostManager,
|
||||||
|
currentConfig: hostDNSConfig{
|
||||||
|
domains: []domainConfig{
|
||||||
|
{false, "domain0", false},
|
||||||
|
{false, "domain1", false},
|
||||||
|
{false, "domain2", false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var domainsUpdate string
|
||||||
|
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error {
|
||||||
|
domains := []string{}
|
||||||
|
for _, item := range config.domains {
|
||||||
|
if item.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, item.domain)
|
||||||
|
}
|
||||||
|
domainsUpdate = strings.Join(domains, ",")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||||
|
Domains: []string{"domain1"},
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
deactivate()
|
||||||
|
expected := "domain0,domain2"
|
||||||
|
domains := []string{}
|
||||||
|
for _, item := range server.currentConfig.domains {
|
||||||
|
if item.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, item.domain)
|
||||||
|
}
|
||||||
|
got := strings.Join(domains, ",")
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
reactivate()
|
||||||
|
expected = "domain0,domain1,domain2"
|
||||||
|
domains = []string{}
|
||||||
|
for _, item := range server.currentConfig.domains {
|
||||||
|
if item.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domains = append(domains, item.domain)
|
||||||
|
}
|
||||||
|
got = strings.Join(domains, ",")
|
||||||
|
if expected != got {
|
||||||
|
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
@ -351,11 +417,11 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
|
|||||||
UDPSize: 65535,
|
UDPSize: 65535,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
|
||||||
return &DefaultServer{
|
return &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: stop,
|
ctxCancel: cancel,
|
||||||
server: dnsServer,
|
server: dnsServer,
|
||||||
dnsMux: mux,
|
dnsMux: mux,
|
||||||
dnsMuxMap: make(registrationMap),
|
dnsMuxMap: make(registrationMap),
|
||||||
|
@ -3,15 +3,16 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -95,6 +96,9 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|||||||
domainsInput []systemdDbusLinkDomainsInput
|
domainsInput []systemdDbusLinkDomainsInput
|
||||||
)
|
)
|
||||||
for _, dConf := range config.domains {
|
for _, dConf := range config.domains {
|
||||||
|
if dConf.disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: dns.Fqdn(dConf.domain),
|
Domain: dns.Fqdn(dConf.domain),
|
||||||
MatchOnly: dConf.matchOnly,
|
MatchOnly: dConf.matchOnly,
|
||||||
|
@ -3,44 +3,73 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultUpstreamTimeout = 15 * time.Second
|
const (
|
||||||
|
failsTillDeact = int32(3)
|
||||||
|
reactivatePeriod = time.Minute
|
||||||
|
upstreamTimeout = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
parentCTX context.Context
|
ctx context.Context
|
||||||
upstreamClient *dns.Client
|
upstreamClient *dns.Client
|
||||||
upstreamServers []string
|
upstreamServers []string
|
||||||
upstreamTimeout time.Duration
|
disabled bool
|
||||||
|
failsCount atomic.Int32
|
||||||
|
failsTillDeact int32
|
||||||
|
mutex sync.Mutex
|
||||||
|
reactivatePeriod time.Duration
|
||||||
|
upstreamTimeout time.Duration
|
||||||
|
|
||||||
|
deactivate func()
|
||||||
|
reactivate func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
|
||||||
|
return &upstreamResolver{
|
||||||
|
ctx: ctx,
|
||||||
|
upstreamClient: &dns.Client{},
|
||||||
|
upstreamTimeout: upstreamTimeout,
|
||||||
|
reactivatePeriod: reactivatePeriod,
|
||||||
|
failsTillDeact: failsTillDeact,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
defer u.checkUpstreamFails()
|
||||||
|
|
||||||
log.Tracef("received an upstream question: %#v", r.Question[0])
|
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.parentCTX.Done():
|
case <-u.ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
|
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||||
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
|
log.WithError(err).WithField("upstream", upstream).
|
||||||
|
Warn("got an error while connecting to upstream")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
|
u.failsCount.Add(1)
|
||||||
|
log.WithError(err).WithField("upstream", upstream).
|
||||||
|
Error("got an error while querying the upstream")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,11 +77,58 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
err = w.WriteMsg(rm)
|
err = w.WriteMsg(rm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
|
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
||||||
}
|
}
|
||||||
|
// count the fails only if they happen sequentially
|
||||||
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Errorf("all queries to the upstream nameservers failed with timeout")
|
u.failsCount.Add(1)
|
||||||
|
log.Error("all queries to the upstream nameservers failed with timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
||||||
|
//
|
||||||
|
// If the fails count reaches failsTillDeact (greater than or equal), upstream resolving
|
||||||
|
// will be disabled for reactivatePeriod; after that period the fails counter
|
||||||
|
// will be reset and upstream will be reactivated.
|
||||||
|
func (u *upstreamResolver) checkUpstreamFails() {
|
||||||
|
u.mutex.Lock()
|
||||||
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
|
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-u.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
||||||
|
u.deactivate()
|
||||||
|
u.disabled = true
|
||||||
|
go u.waitUntilReactivation()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitUntilReactivation waits for reactivatePeriod, then resets the fails counter and reactivates upstream resolving
|
||||||
|
func (u *upstreamResolver) waitUntilReactivation() {
|
||||||
|
timer := time.NewTimer(u.reactivatePeriod)
|
||||||
|
defer func() {
|
||||||
|
if !timer.Stop() {
|
||||||
|
<-timer.C
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-u.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-timer.C:
|
||||||
|
log.Info("upstream resolving is reactivated")
|
||||||
|
u.failsCount.Store(0)
|
||||||
|
u.reactivate()
|
||||||
|
u.disabled = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTimeout returns true if the given error is a network timeout error.
|
// isTimeout returns true if the given error is a network timeout error.
|
||||||
|
@ -23,7 +23,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||||
timeout: defaultUpstreamTimeout,
|
timeout: upstreamTimeout,
|
||||||
expectedAnswer: "1.1.1.1",
|
expectedAnswer: "1.1.1.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -45,7 +45,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||||
cancelCTX: true,
|
cancelCTX: true,
|
||||||
timeout: defaultUpstreamTimeout,
|
timeout: upstreamTimeout,
|
||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
//{
|
//{
|
||||||
@ -65,12 +65,9 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver := &upstreamResolver{
|
resolver := newUpstreamResolver(ctx)
|
||||||
parentCTX: ctx,
|
resolver.upstreamServers = testCase.InputServers
|
||||||
upstreamClient: &dns.Client{},
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
upstreamServers: testCase.InputServers,
|
|
||||||
upstreamTimeout: testCase.timeout,
|
|
||||||
}
|
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
cancel()
|
cancel()
|
||||||
} else {
|
} else {
|
||||||
@ -108,3 +105,52 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||||
|
resolver := newUpstreamResolver(context.TODO())
|
||||||
|
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
||||||
|
resolver.failsTillDeact = 0
|
||||||
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|
||||||
|
responseWriter := &mockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
failed := false
|
||||||
|
resolver.deactivate = func() {
|
||||||
|
failed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
reactivated := false
|
||||||
|
resolver.reactivate = func() {
|
||||||
|
reactivated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
|
||||||
|
|
||||||
|
if !failed {
|
||||||
|
t.Errorf("expected that resolving was deactivated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resolver.disabled {
|
||||||
|
t.Errorf("resolver should be disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 200)
|
||||||
|
|
||||||
|
if !reactivated {
|
||||||
|
t.Errorf("expected that resolving was reactivated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolver.failsCount.Load() != 0 {
|
||||||
|
t.Errorf("fails count after reactivation should be 0")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolver.disabled {
|
||||||
|
t.Errorf("should be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user