mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-26 01:53:42 +01:00
9c4bf1e899
Handle original search domains in resolv.conf type implementations. - parse the original resolv.conf file - merge the search domains - ignore the domain keyword - append any other config lines (sortlist, options) - fix reading the original resolv.conf from the backup in the resolvconf implementation - fix line length validation - fix number of search domains validation
269 lines
7.2 KiB
Go
269 lines
7.2 KiB
Go
//go:build !android
|
|
|
|
package dns
|
|
|
|
import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"strings"

	log "github.com/sirupsen/logrus"
)
|
|
|
|
const (
|
|
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
|
|
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
|
|
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"
|
|
|
|
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
|
|
|
|
fileMaxLineCharsLimit = 256
|
|
fileMaxNumberOfSearchDomains = 6
|
|
)
|
|
|
|
// fileConfigurator manages host DNS by overwriting resolv.conf directly, keeping the
// original file at fileDefaultResolvConfBackupLocation so it can be put back.
type fileConfigurator struct {
	// originalPerms is the mode of resolv.conf captured by backup(); the managed
	// file is written with this mode.
	originalPerms os.FileMode
}
|
|
|
|
func newFileConfigurator() (hostManager, error) {
|
|
return &fileConfigurator{}, nil
|
|
}
|
|
|
|
func (f *fileConfigurator) supportCustomPort() bool {
|
|
return false
|
|
}
|
|
|
|
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|
backupFileExist := false
|
|
_, err := os.Stat(fileDefaultResolvConfBackupLocation)
|
|
if err == nil {
|
|
backupFileExist = true
|
|
}
|
|
|
|
if !config.routeAll {
|
|
if backupFileExist {
|
|
err = f.restore()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err)
|
|
}
|
|
}
|
|
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
|
}
|
|
|
|
if !backupFileExist {
|
|
err = f.backup()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to backup the resolv.conf file")
|
|
}
|
|
}
|
|
|
|
searchDomainList := searchDomains(config)
|
|
|
|
originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
|
|
searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
|
|
|
|
buf := prepareResolvConfContent(
|
|
searchDomainList,
|
|
append([]string{config.serverIP}, nameServers...),
|
|
others)
|
|
|
|
log.Debugf("creating managed file %s", defaultResolvConfPath)
|
|
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
|
|
if err != nil {
|
|
restoreErr := f.restore()
|
|
if restoreErr != nil {
|
|
log.Errorf("attempt to restore default file failed with error: %s", err)
|
|
}
|
|
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
|
|
}
|
|
|
|
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
|
return nil
|
|
}
|
|
|
|
func (f *fileConfigurator) restoreHostDNS() error {
|
|
return f.restore()
|
|
}
|
|
|
|
func (f *fileConfigurator) backup() error {
|
|
stats, err := os.Stat(defaultResolvConfPath)
|
|
if err != nil {
|
|
return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err)
|
|
}
|
|
|
|
f.originalPerms = stats.Mode()
|
|
|
|
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
|
|
if err != nil {
|
|
return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fileConfigurator) restore() error {
|
|
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
|
if err != nil {
|
|
return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
|
}
|
|
|
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
|
}
|
|
|
|
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
|
|
var buf bytes.Buffer
|
|
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
|
|
|
|
for _, cfgLine := range others {
|
|
buf.WriteString(cfgLine)
|
|
buf.WriteString("\n")
|
|
}
|
|
|
|
if len(searchDomains) > 0 {
|
|
buf.WriteString("search ")
|
|
buf.WriteString(strings.Join(searchDomains, " "))
|
|
buf.WriteString("\n")
|
|
}
|
|
|
|
for _, ns := range nameServers {
|
|
buf.WriteString("nameserver ")
|
|
buf.WriteString(ns)
|
|
buf.WriteString("\n")
|
|
}
|
|
return buf
|
|
}
|
|
|
|
func searchDomains(config hostDNSConfig) []string {
|
|
listOfDomains := make([]string, 0)
|
|
for _, dConf := range config.domains {
|
|
if dConf.matchOnly || dConf.disabled {
|
|
continue
|
|
}
|
|
|
|
listOfDomains = append(listOfDomains, dConf.domain)
|
|
}
|
|
return listOfDomains
|
|
}
|
|
|
|
// originalDNSConfigs parses a resolv.conf-formatted file and splits its directives
// into the parts NetBird merges into its generated file.
//
// Returns:
//   - searchDomains: the domains of the last "search" directive that lists at least
//     one domain. A later search line replaces an earlier one.
//   - nameServers: the address of every well-formed "nameserver <addr>" line, in file order.
//   - others: every remaining directive (options, sortlist, ...), trimmed, to be
//     copied into the generated file.
//
// Dropped lines:
//   - blank lines,
//   - comment lines (starting with '#' or ';' after trimming),
//   - "domain" directives.
//
// Directives are recognized by their first whitespace-separated field, so e.g.
// "searchfoo" is not treated as "search". The "rotate" option is stripped from
// "options" lines, and an "options" line left with nothing else is dropped.
//
// Errors: opening or reading the file fails, or a line exceeds the bufio.Reader
// buffer. On error, all returned slices are nil.
func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
	file, err := os.Open(resolvconfFile)
	if err != nil {
		return nil, nil, nil, fmt.Errorf("could not read existing resolv.conf: %w", err)
	}
	defer file.Close()

	reader := bufio.NewReader(file)
	for {
		lineBytes, isPrefix, readErr := reader.ReadLine()
		if readErr != nil {
			if errors.Is(readErr, io.EOF) {
				break
			}
			return nil, nil, nil, fmt.Errorf("could not read existing resolv.conf: %w", readErr)
		}
		if isPrefix {
			return nil, nil, nil, errors.New("resolv.conf line too long")
		}

		line := strings.TrimSpace(string(lineBytes))
		if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
			continue
		}

		fields := strings.Fields(line)
		switch fields[0] {
		case "domain":
			// NetBird writes its own search list; the local domain directive is ignored.
			continue
		case "search":
			if len(fields) < 2 {
				continue
			}
			searchDomains = fields[1:]
		case "nameserver":
			if len(fields) == 2 {
				nameServers = append(nameServers, fields[1])
			}
		case "options":
			// "rotate" is removed, presumably so NetBird's nameserver, written first,
			// keeps priority over the original ones.
			kept := fields[:1]
			for _, opt := range fields[1:] {
				if opt != "rotate" {
					kept = append(kept, opt)
				}
			}
			if len(kept) == 1 {
				continue
			}
			others = append(others, strings.Join(kept, " "))
		default:
			others = append(others, line)
		}
	}

	return searchDomains, nameServers, others, nil
}
|
|
|
|
// merge search domains lists and cut off the list if it is too long
|
|
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
|
|
lineSize := len("search")
|
|
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
|
|
|
|
lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
|
|
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)
|
|
|
|
return searchDomainsList
|
|
}
|
|
|
|
// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long
|
|
// extend s slice with vs elements
|
|
// return with the number of characters in the searchDomains line
|
|
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
|
|
for _, sd := range vs {
|
|
tmpCharsNumber := initialLineChars + 1 + len(sd)
|
|
if tmpCharsNumber > fileMaxLineCharsLimit {
|
|
// lets log all skipped domains
|
|
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
|
|
continue
|
|
}
|
|
|
|
initialLineChars = tmpCharsNumber
|
|
|
|
if len(*s) >= fileMaxNumberOfSearchDomains {
|
|
// lets log all skipped domains
|
|
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
|
|
continue
|
|
}
|
|
*s = append(*s, sd)
|
|
}
|
|
return initialLineChars
|
|
}
|
|
|
|
// copyFile copies the contents of src to dest.
//
// If dest does not exist, it is created with src's mode; an existing dest is
// truncated and keeps its own mode, as os.WriteFile does. Errors from stat, read
// and write are wrapped with %w, so callers can test them with errors.Is
// (e.g. fs.ErrNotExist).
func copyFile(src, dest string) error {
	stats, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %w", src, err)
	}

	bytesRead, err := os.ReadFile(src)
	if err != nil {
		return fmt.Errorf("got an error while reading the file %s file for copy. Error: %w", src, err)
	}

	if err := os.WriteFile(dest, bytesRead, stats.Mode()); err != nil {
		return fmt.Errorf("got an error while writing the destination file %s for copy. Error: %w", dest, err)
	}
	return nil
}
|