mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 10:21:10 +01:00
7ebe58f20a
* Add DNS list argument for mobile client * Write testable code: many places checked the wgInterface != nil condition solely to avoid creating a real wgInterface in tests. Instead, introduce a wgInterface interface that is mockable. * Refactor the internal code structure of the DNS server: the fake resolver introduced several if-else statements and some unused variables to distinguish the listener and fake-resolver solutions at run time. With this commit, the fake-resolver and listener-based solutions are moved into two separate structures; this layer is named the 'service'. With this modification the unit tests look simpler, and it opens the option to add new logic for permanent DNS service usage on mobile systems. * Remove the is-running check in the test: we cannot reliably determine the state, so remove this check. The test will fail if the server is not running properly.
268 lines
7.9 KiB
Go
268 lines
7.9 KiB
Go
package dns
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/sys/windows/registry"
|
|
)
|
|
|
|
// Registry key and value names used for the NetBird match-domain DNS policy
// that addDNSMatchPolicy writes under HKEY_LOCAL_MACHINE.
const (
	// dnsPolicyConfigMatchPath is the registry key, relative to
	// HKEY_LOCAL_MACHINE, that holds the NetBird match-domain policy.
	dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match"

	// dnsPolicyConfigVersionKey is the DWORD value name for the policy version.
	dnsPolicyConfigVersionKey = "Version"

	// dnsPolicyConfigVersionValue is the version number written to the policy.
	dnsPolicyConfigVersionValue = 2

	// dnsPolicyConfigNameKey is the multi-string value name that receives the
	// list of match domains.
	dnsPolicyConfigNameKey = "Name"

	// dnsPolicyConfigGenericDNSServersKey is the string value name that
	// receives the DNS server IP.
	dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"

	// dnsPolicyConfigConfigOptionsKey is the DWORD value name for the policy
	// options.
	dnsPolicyConfigConfigOptionsKey = "ConfigOptions"

	// dnsPolicyConfigConfigOptionsValue is the options value written to the
	// policy.
	// NOTE(review): presumably the flag selecting "generic DNS servers" rules;
	// confirm against the Windows NRPT documentation.
	dnsPolicyConfigConfigOptionsValue = 0x8
)
|
|
|
|
// Registry key and value names used for per-interface and global TCP/IP
// settings. All paths are relative to HKEY_LOCAL_MACHINE.
const (
	// interfaceConfigPath is the parent key of the per-interface settings; the
	// interface GUID is appended to it to address a single interface.
	interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces"

	// interfaceConfigNameServerKey is the value name set on the interface key
	// when this peer's DNS server handles all queries.
	interfaceConfigNameServerKey = "NameServer"

	// interfaceConfigSearchListKey is the value name holding the
	// comma-separated search domain list.
	interfaceConfigSearchListKey = "SearchList"

	// tcpipParametersPath is the key whose SearchList value is read and
	// rewritten by updateSearchDomains.
	tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"
)
|
|
|
|
// registryConfigurator applies and reverts host DNS settings by editing keys
// under HKEY_LOCAL_MACHINE in the Windows registry.
type registryConfigurator struct {
	// guid is the interface GUID string; it is appended to interfaceConfigPath
	// to locate the interface's registry key.
	guid string

	// routingAll is true while the interface NameServer value has been set by
	// addDNSSetupForAll and not yet removed by applyDNSConfig.
	routingAll bool

	// existingSearchDomains caches the SearchList entries read from the
	// registry the first time updateSearchDomains runs.
	existingSearchDomains []string
}
|
|
|
|
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ®istryConfigurator{
|
|
guid: guid,
|
|
}, nil
|
|
}
|
|
|
|
func (s *registryConfigurator) supportCustomPort() bool {
|
|
return false
|
|
}
|
|
|
|
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
|
|
var err error
|
|
if config.routeAll {
|
|
err = r.addDNSSetupForAll(config.serverIP)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if r.routingAll {
|
|
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.routingAll = false
|
|
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP)
|
|
}
|
|
|
|
var (
|
|
searchDomains []string
|
|
matchDomains []string
|
|
)
|
|
|
|
for _, dConf := range config.domains {
|
|
if dConf.disabled {
|
|
continue
|
|
}
|
|
if !dConf.matchOnly {
|
|
searchDomains = append(searchDomains, dConf.domain)
|
|
}
|
|
matchDomains = append(matchDomains, "."+dConf.domain)
|
|
}
|
|
|
|
if len(matchDomains) != 0 {
|
|
err = r.addDNSMatchPolicy(matchDomains, config.serverIP)
|
|
} else {
|
|
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = r.updateSearchDomains(searchDomains)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
|
|
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
|
|
if err != nil {
|
|
return fmt.Errorf("adding dns setup for all failed with error: %s", err)
|
|
}
|
|
r.routingAll = true
|
|
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
|
return nil
|
|
}
|
|
|
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
|
|
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE)
|
|
if err == nil {
|
|
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
|
}
|
|
}
|
|
|
|
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err)
|
|
}
|
|
|
|
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err)
|
|
}
|
|
|
|
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err)
|
|
}
|
|
|
|
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err)
|
|
}
|
|
|
|
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err)
|
|
}
|
|
|
|
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *registryConfigurator) restoreHostDNS() error {
|
|
err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
|
|
if err != nil {
|
|
log.Error(err)
|
|
}
|
|
|
|
return r.updateSearchDomains([]string{})
|
|
}
|
|
|
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
|
value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to get current search domains failed with error: %s", err)
|
|
}
|
|
|
|
valueList := strings.Split(value, ",")
|
|
setExisting := false
|
|
if len(r.existingSearchDomains) == 0 {
|
|
r.existingSearchDomains = valueList
|
|
setExisting = true
|
|
}
|
|
|
|
if len(domains) == 0 && setExisting {
|
|
log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains)
|
|
return nil
|
|
}
|
|
|
|
newList := append(r.existingSearchDomains, domains...)
|
|
|
|
err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ","))
|
|
if err != nil {
|
|
return fmt.Errorf("adding search domain failed with error: %s", err)
|
|
}
|
|
|
|
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
|
|
regKey, err := r.getInterfaceRegistryKey()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer regKey.Close()
|
|
|
|
err = regKey.SetStringValue(key, value)
|
|
if err != nil {
|
|
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
|
|
regKey, err := r.getInterfaceRegistryKey()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer regKey.Close()
|
|
|
|
err = regKey.DeleteValue(propertyKey)
|
|
if err != nil {
|
|
return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
|
var regKey registry.Key
|
|
|
|
regKeyPath := interfaceConfigPath + "\\" + r.guid
|
|
|
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
|
|
if err != nil {
|
|
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
|
}
|
|
|
|
return regKey, nil
|
|
}
|
|
|
|
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
|
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
|
|
if err == nil {
|
|
k.Close()
|
|
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) {
|
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
|
}
|
|
defer regKey.Close()
|
|
|
|
val, _, err := regKey.GetStringValue(key)
|
|
if err != nil {
|
|
return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err)
|
|
}
|
|
|
|
return val, nil
|
|
}
|
|
|
|
func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error {
|
|
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err)
|
|
}
|
|
defer regKey.Close()
|
|
|
|
err = regKey.SetStringValue(key, value)
|
|
if err != nil {
|
|
return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err)
|
|
}
|
|
|
|
return nil
|
|
}
|