Enhance DNS failover reliability (#1637)

* Fix using wrong array index in log to avoid potential panic

* Increase gRPC connection timeout and add the timeout resolv.conf option

This makes sure the dns client is able to failover to a second
configured nameserver, if present. That is the case then when using the
dns `file` manager and a resolv.conf file generated for netbird.

* On file backup restore, remove the first NS if it's the netbird NS

* Bump dns mangager discovery message from debug to info to ease debugging
This commit is contained in:
Viktor Liu 2024-03-01 15:17:35 +01:00 committed by GitHub
parent a4b9e93217
commit 17b1099032
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 224 additions and 15 deletions

View File

@ -8,6 +8,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -23,10 +24,16 @@ const (
fileMaxNumberOfSearchDomains = 6 fileMaxNumberOfSearchDomains = 6
) )
const (
dnsFailoverTimeout = 4 * time.Second
dnsFailoverAttempts = 1
)
type fileConfigurator struct { type fileConfigurator struct {
repair *repair repair *repair
originalPerms os.FileMode originalPerms os.FileMode
nbNameserverIP string
} }
func newFileConfigurator() (hostManager, error) { func newFileConfigurator() (hostManager, error) {
@ -64,7 +71,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
} }
nbSearchDomains := searchDomains(config) nbSearchDomains := searchDomains(config)
nbNameserverIP := config.ServerIP f.nbNameserverIP = config.ServerIP
resolvConf, err := parseBackupResolvConf() resolvConf, err := parseBackupResolvConf()
if err != nil { if err != nil {
@ -73,11 +80,11 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
f.repair.stopWatchFileChanges() f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, nbNameserverIP, resolvConf) err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
if err != nil { if err != nil {
return err return err
} }
f.repair.watchFileChanges(nbSearchDomains, nbNameserverIP) f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
return nil return nil
} }
@ -85,10 +92,11 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg) nameServers := generateNsList(nbNameserverIP, cfg)
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent( buf := prepareResolvConfContent(
searchDomainList, searchDomainList,
nameServers, nameServers,
cfg.others) options)
log.Debugf("creating managed file %s", defaultResolvConfPath) log.Debugf("creating managed file %s", defaultResolvConfPath)
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
@ -131,7 +139,12 @@ func (f *fileConfigurator) backup() error {
} }
func (f *fileConfigurator) restore() error { func (f *fileConfigurator) restore() error {
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil { if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
@ -157,7 +170,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0]) currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore // not a valid first nameserver -> restore
if err != nil { if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[1], err) log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile() return restoreResolvConfFile()
} }

View File

@ -5,6 +5,7 @@ package dns
import ( import (
"fmt" "fmt"
"os" "os"
"regexp"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -14,6 +15,9 @@ const (
defaultResolvConfPath = "/etc/resolv.conf" defaultResolvConfPath = "/etc/resolv.conf"
) )
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
type resolvConf struct { type resolvConf struct {
nameServers []string nameServers []string
searchDomains []string searchDomains []string
@ -103,3 +107,62 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
} }
return rconf, nil return rconf, nil
} }
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
// otherwise it adds a new option with timeout and attempts.
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
configs := make([]string, len(input))
copy(configs, input)
for i, config := range configs {
if strings.HasPrefix(config, "options") {
config = strings.ReplaceAll(config, "rotate", "")
config = strings.Join(strings.Fields(config), " ")
if strings.Contains(config, "timeout:") {
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
}
if strings.Contains(config, "attempts:") {
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
}
configs[i] = config
return configs
}
}
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
}
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename, nameserverIP string) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@ -6,6 +6,8 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func Test_parseResolvConf(t *testing.T) { func Test_parseResolvConf(t *testing.T) {
@ -172,3 +174,131 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg) t.Errorf("unexpected resolv.conf content: %v", cfg)
} }
} }
func TestPrepareOptionsWithTimeout(t *testing.T) {
tests := []struct {
name string
others []string
timeout int
attempts int
expected []string
}{
{
name: "Append new options with timeout and attempts",
others: []string{"some config"},
timeout: 2,
attempts: 2,
expected: []string{"some config", "options timeout:2 attempts:2"},
},
{
name: "Modify existing options to exclude rotate and include timeout and attempts",
others: []string{"some config", "options rotate someother"},
timeout: 3,
attempts: 2,
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
},
{
name: "Existing options with timeout and attempts are updated",
others: []string{"some config", "options timeout:4 attempts:3"},
timeout: 5,
attempts: 4,
expected: []string{"some config", "options timeout:5 attempts:4"},
},
{
name: "Modify existing options, add missing attempts before timeout",
others: []string{"some config", "options timeout:4"},
timeout: 4,
attempts: 3,
expected: []string{"some config", "options attempts:3 timeout:4"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
assert.Equal(t, tc.expected, result)
})
}
}
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
err = removeFirstNbNameserver(tempFile, tc.ipToRemove)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@ -65,7 +65,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return nil, err return nil, err
} }
log.Debugf("discovered mode is: %s", osManager) log.Infof("System DNS manager discovered: %s", osManager)
return newHostManagerFromType(wgInterface, osManager) return newHostManagerFromType(wgInterface, osManager)
} }

View File

@ -53,10 +53,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
searchDomainList := searchDomains(config) searchDomainList := searchDomains(config)
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains) searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent( buf := prepareResolvConfContent(
searchDomainList, searchDomainList,
append([]string{config.ServerIP}, r.originalNameServers...), append([]string{config.ServerIP}, r.originalNameServers...),
r.othersConfigs) options)
// create a backup for unclean shutdown detection before the resolv.conf is changed // create a backup for unclean shutdown detection before the resolv.conf is changed
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {

View File

@ -26,6 +26,8 @@ import (
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
) )
const ConnectTimeout = 10 * time.Second
// ConnStateNotifier is a wrapper interface of the status recorders // ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface { type ConnStateNotifier interface {
MarkManagementDisconnected(error) MarkManagementDisconnected(error)
@ -49,7 +51,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
} }
mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second) mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
mgmCtx, mgmCtx,
@ -318,7 +320,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)
return nil, err return nil, err
} }
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel() defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(), WgPubKey: c.key.PublicKey().String(),

View File

@ -21,11 +21,10 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
) )
const defaultSendTimeout = 5 * time.Second
// ConnStateNotifier is a wrapper interface of the status recorder // ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface { type ConnStateNotifier interface {
MarkSignalDisconnected(error) MarkSignalDisconnected(error)
@ -71,7 +70,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
} }
sigCtx, cancel := context.WithTimeout(ctx, 5*time.Second) sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
sigCtx, sigCtx,
@ -353,7 +352,7 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
return err return err
} }
attemptTimeout := defaultSendTimeout attemptTimeout := client.ConnectTimeout
for attempt := 0; attempt < 4; attempt++ { for attempt := 0; attempt < 4; attempt++ {
if attempt > 1 { if attempt > 1 {