//go:build (linux && !android) || freebsd package dns import ( "bytes" "fmt" "net/netip" "os" "strings" "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( fileGeneratedResolvConfContentHeader = "# Generated by NetBird" fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + ` # If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n" fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" fileMaxLineCharsLimit = 256 fileMaxNumberOfSearchDomains = 6 ) const ( dnsFailoverTimeout = 4 * time.Second dnsFailoverAttempts = 1 ) type fileConfigurator struct { repair *repair originalPerms os.FileMode nbNameserverIP string } func newFileConfigurator() (*fileConfigurator, error) { fc := &fileConfigurator{} fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) return fc, nil } func (f *fileConfigurator) supportCustomPort() bool { return false } func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { backupFileExist := f.isBackupFileExist() if !config.RouteAll { if backupFileExist { f.repair.stopWatchFileChanges() err := f.restore() if err != nil { return fmt.Errorf("restoring the original resolv.conf file return err: %w", err) } } return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") } if !backupFileExist { err := f.backup() if err != nil { return fmt.Errorf("unable to backup the resolv.conf file: %w", err) } } nbSearchDomains := searchDomains(config) f.nbNameserverIP = config.ServerIP resolvConf, err := parseBackupResolvConf() if err != nil { log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err) } f.repair.stopWatchFileChanges() err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) if err != nil { return err } f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager) return nil } func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) nameServers := generateNsList(nbNameserverIP, cfg) options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) buf := prepareResolvConfContent( searchDomainList, nameServers, options) log.Debugf("creating managed file %s", defaultResolvConfPath) err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) if err != nil { restoreErr := f.restore() if restoreErr != nil { log.Errorf("attempt to restore default file failed with error: %s", err) } return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err) } log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) // create another backup for unclean shutdown detection right after overwriting the original resolv.conf if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil { log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) } return nil } func (f *fileConfigurator) restoreHostDNS() error { f.repair.stopWatchFileChanges() return f.restore() } func (f *fileConfigurator) backup() error { stats, err := os.Stat(defaultResolvConfPath) if err != nil { return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, err) } f.originalPerms = stats.Mode() err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation) if err != nil { return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, err) } return nil } func (f *fileConfigurator) restore() error { err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) if err != nil { log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err) } err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) if err != nil { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } return os.RemoveAll(fileDefaultResolvConfBackupLocation) } func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { resolvConf, err := parseDefaultResolvConf() if err != nil { return fmt.Errorf("parse current resolv.conf: %w", err) } // no current nameservers set -> restore if len(resolvConf.nameServers) == 0 { return restoreResolvConfFile() } currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0]) // not a valid first nameserver -> restore if err != nil { log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err) return restoreResolvConfFile() } // current address is still netbird's non-available dns address -> restore // comparing parsed addresses only, to remove ambiguity if currentDNSAddress.String() == storedDNSAddress.String() { return restoreResolvConfFile() } log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress) return nil } func (f *fileConfigurator) isBackupFileExist() bool { _, err := os.Stat(fileDefaultResolvConfBackupLocation) return err == nil } func restoreResolvConfFile() error { log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation) if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) } return nil } // generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list func generateNsList(nbNameserverIP string, cfg *resolvConf) []string { ns := make([]string, 1, len(cfg.nameServers)+1) ns[0] = nbNameserverIP for _, cfgNs := range cfg.nameServers { if nbNameserverIP != cfgNs { ns = append(ns, cfgNs) } } return ns } func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer { var buf bytes.Buffer buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine) for _, cfgLine := range others { buf.WriteString(cfgLine) buf.WriteString("\n") } if len(searchDomains) > 0 { buf.WriteString("search ") buf.WriteString(strings.Join(searchDomains, " ")) buf.WriteString("\n") } for _, ns := range nameServers { buf.WriteString("nameserver ") buf.WriteString(ns) buf.WriteString("\n") } return buf } func searchDomains(config HostDNSConfig) []string { listOfDomains := make([]string, 0) for _, dConf := range config.Domains { if dConf.MatchOnly || dConf.Disabled { continue } listOfDomains = append(listOfDomains, dConf.Domain) } return listOfDomains } // merge search Domains lists and cut off the list if it is too long func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string { lineSize := len("search") searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains)) lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains) _ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains) return searchDomainsList } // validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long // extend s slice with vs elements // return with the number of characters in the searchDomains line func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int { for _, sd := range vs { duplicated := false for _, fs := range *s { if fs == sd { duplicated = true break } } if duplicated { continue } tmpCharsNumber := initialLineChars + 1 + len(sd) if tmpCharsNumber > fileMaxLineCharsLimit { // lets log all skipped Domains log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd) continue } initialLineChars = tmpCharsNumber if len(*s) >= fileMaxNumberOfSearchDomains { // lets log all skipped Domains log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd) continue } *s = append(*s, sd) } return initialLineChars } func copyFile(src, dest string) error { stats, err := os.Stat(src) if err != nil { return fmt.Errorf("checking stats for %s file when copying it. Error: %s", src, err) } bytesRead, err := os.ReadFile(src) if err != nil { return fmt.Errorf("reading the file %s file for copy. Error: %s", src, err) } err = os.WriteFile(dest, bytesRead, stats.Mode()) if err != nil { return fmt.Errorf("writing the destination file %s for copy. Error: %s", dest, err) } return nil } func isContains(subList []string, list []string) bool { for _, sl := range subList { var found bool for _, l := range list { if sl == l { found = true } } if !found { return false } } return true }