mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-07 06:29:06 +01:00
735ed7ab34
Stop the file repairer before doing the restore
330 lines
9.4 KiB
Go
330 lines
9.4 KiB
Go
//go:build !android
|
|
|
|
package dns
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
const (
|
|
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
|
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
|
|
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
|
|
|
|
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
|
|
|
fileMaxLineCharsLimit = 256
|
|
fileMaxNumberOfSearchDomains = 6
|
|
)
|
|
|
|
// Resolver failover tuning that updateConfig hands to prepareOptionsWithTimeout
// when building the options of the generated resolv.conf.
const (
	// dnsFailoverTimeout is the resolver timeout; updateConfig converts it to
	// whole seconds.
	dnsFailoverTimeout = time.Second * 4
	// dnsFailoverAttempts is the resolver attempt count.
	dnsFailoverAttempts = 1
)
|
|
|
|
type fileConfigurator struct {
|
|
repair *repair
|
|
|
|
originalPerms os.FileMode
|
|
nbNameserverIP string
|
|
}
|
|
|
|
func newFileConfigurator() (hostManager, error) {
|
|
fc := &fileConfigurator{}
|
|
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
|
return fc, nil
|
|
}
|
|
|
|
func (f *fileConfigurator) supportCustomPort() bool {
|
|
return false
|
|
}
|
|
|
|
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|
backupFileExist := f.isBackupFileExist()
|
|
if !config.RouteAll {
|
|
if backupFileExist {
|
|
f.repair.stopWatchFileChanges()
|
|
err := f.restore()
|
|
if err != nil {
|
|
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
|
|
}
|
|
}
|
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
|
}
|
|
|
|
if !backupFileExist {
|
|
err := f.backup()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
|
|
}
|
|
}
|
|
|
|
nbSearchDomains := searchDomains(config)
|
|
f.nbNameserverIP = config.ServerIP
|
|
|
|
resolvConf, err := parseBackupResolvConf()
|
|
if err != nil {
|
|
log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err)
|
|
}
|
|
|
|
f.repair.stopWatchFileChanges()
|
|
|
|
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
|
return nil
|
|
}
|
|
|
|
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
|
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
|
nameServers := generateNsList(nbNameserverIP, cfg)
|
|
|
|
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
|
|
buf := prepareResolvConfContent(
|
|
searchDomainList,
|
|
nameServers,
|
|
options)
|
|
|
|
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
|
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
|
if err != nil {
|
|
restoreErr := f.restore()
|
|
if restoreErr != nil {
|
|
log.Errorf("attempt to restore default file failed with error: %s", err)
|
|
}
|
|
return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
|
|
}
|
|
|
|
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
|
|
|
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
|
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (f *fileConfigurator) restoreHostDNS() error {
|
|
f.repair.stopWatchFileChanges()
|
|
return f.restore()
|
|
}
|
|
|
|
func (f *fileConfigurator) backup() error {
|
|
stats, err := os.Stat(defaultResolvConfPath)
|
|
if err != nil {
|
|
return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, err)
|
|
}
|
|
|
|
f.originalPerms = stats.Mode()
|
|
|
|
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
|
|
if err != nil {
|
|
return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fileConfigurator) restore() error {
|
|
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
|
if err != nil {
|
|
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
|
}
|
|
|
|
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
|
if err != nil {
|
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
|
}
|
|
|
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
}
|
|
|
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
|
}
|
|
|
|
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
|
resolvConf, err := parseDefaultResolvConf()
|
|
if err != nil {
|
|
return fmt.Errorf("parse current resolv.conf: %w", err)
|
|
}
|
|
|
|
// no current nameservers set -> restore
|
|
if len(resolvConf.nameServers) == 0 {
|
|
return restoreResolvConfFile()
|
|
}
|
|
|
|
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
|
// not a valid first nameserver -> restore
|
|
if err != nil {
|
|
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
|
return restoreResolvConfFile()
|
|
}
|
|
|
|
// current address is still netbird's non-available dns address -> restore
|
|
// comparing parsed addresses only, to remove ambiguity
|
|
if currentDNSAddress.String() == storedDNSAddress.String() {
|
|
return restoreResolvConfFile()
|
|
}
|
|
|
|
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
|
return nil
|
|
}
|
|
|
|
func (f *fileConfigurator) isBackupFileExist() bool {
|
|
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
|
return err == nil
|
|
}
|
|
|
|
func restoreResolvConfFile() error {
|
|
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
|
|
|
|
if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil {
|
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
|
}
|
|
|
|
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
|
|
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
|
|
ns := make([]string, 1, len(cfg.nameServers)+1)
|
|
ns[0] = nbNameserverIP
|
|
for _, cfgNs := range cfg.nameServers {
|
|
if nbNameserverIP != cfgNs {
|
|
ns = append(ns, cfgNs)
|
|
}
|
|
}
|
|
return ns
|
|
}
|
|
|
|
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
|
var buf bytes.Buffer
|
|
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
|
|
|
for _, cfgLine := range others {
|
|
buf.WriteString(cfgLine)
|
|
buf.WriteString("\n")
|
|
}
|
|
|
|
if len(searchDomains) > 0 {
|
|
buf.WriteString("search ")
|
|
buf.WriteString(strings.Join(searchDomains, " "))
|
|
buf.WriteString("\n")
|
|
}
|
|
|
|
for _, ns := range nameServers {
|
|
buf.WriteString("nameserver ")
|
|
buf.WriteString(ns)
|
|
buf.WriteString("\n")
|
|
}
|
|
return buf
|
|
}
|
|
|
|
func searchDomains(config HostDNSConfig) []string {
|
|
listOfDomains := make([]string, 0)
|
|
for _, dConf := range config.Domains {
|
|
if dConf.MatchOnly || dConf.Disabled {
|
|
continue
|
|
}
|
|
|
|
listOfDomains = append(listOfDomains, dConf.Domain)
|
|
}
|
|
return listOfDomains
|
|
}
|
|
|
|
// merge search Domains lists and cut off the list if it is too long
|
|
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
|
lineSize := len("search")
|
|
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
|
|
|
|
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
|
|
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
|
|
|
|
return searchDomainsList
|
|
}
|
|
|
|
// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
|
|
// extend s slice with vs elements
|
|
// return with the number of characters in the searchDomains line
|
|
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
|
for _, sd := range vs {
|
|
duplicated := false
|
|
for _, fs := range *s {
|
|
if fs == sd {
|
|
duplicated = true
|
|
break
|
|
}
|
|
|
|
}
|
|
|
|
if duplicated {
|
|
continue
|
|
}
|
|
|
|
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
|
if tmpCharsNumber > fileMaxLineCharsLimit {
|
|
// lets log all skipped Domains
|
|
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
|
|
continue
|
|
}
|
|
|
|
initialLineChars = tmpCharsNumber
|
|
|
|
if len(*s) >= fileMaxNumberOfSearchDomains {
|
|
// lets log all skipped Domains
|
|
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
|
|
continue
|
|
}
|
|
*s = append(*s, sd)
|
|
}
|
|
|
|
return initialLineChars
|
|
}
|
|
|
|
// copyFile copies the contents of src to dest.
//
// A new dest is created with src's mode. An existing dest is truncated and
// keeps its own permissions (os.WriteFile semantics). Returned errors wrap the
// underlying os error, so callers can use errors.Is (e.g. with fs.ErrNotExist).
func copyFile(src, dest string) error {
	stats, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("checking stats for %s file when copying it. Error: %w", src, err)
	}

	bytesRead, err := os.ReadFile(src)
	if err != nil {
		return fmt.Errorf("reading the file %s file for copy. Error: %w", src, err)
	}

	err = os.WriteFile(dest, bytesRead, stats.Mode())
	if err != nil {
		return fmt.Errorf("writing the destination file %s for copy. Error: %w", dest, err)
	}
	return nil
}
|
|
|
|
// isContains reports whether every element of subList also appears in list.
// An empty subList is trivially contained; multiplicity is not checked.
func isContains(subList []string, list []string) bool {
nextWanted:
	for _, want := range subList {
		for _, have := range list {
			if have == want {
				continue nextWanted
			}
		}
		return false
	}
	return true
}
|