[client] Create NRPT rules separately per domain (#4329)

This commit is contained in:
Viktor Liu
2025-08-12 15:40:37 +02:00
committed by GitHub
parent ccbabd9e2a
commit 0fdb944058
3 changed files with 60 additions and 24 deletions

View File

@ -176,4 +176,3 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg) t.Errorf("unexpected resolv.conf content: %v", cfg)
} }
} }

View File

@ -64,9 +64,10 @@ const (
) )
type registryConfigurator struct { type registryConfigurator struct {
guid string guid string
routingAll bool routingAll bool
gpo bool gpo bool
nrptEntryCount int
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil { if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err) log.Errorf("failed to update shutdown state: %s", err)
} }
@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
return fmt.Errorf("add dns match policy: %w", err) return fmt.Errorf("add dns match policy: %w", err)
} }
r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil { if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err) return fmt.Errorf("remove dns match policies: %w", err)
} }
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
@ -220,28 +236,34 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return nil return nil
} }
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
if r.gpo { for i, domain := range domains {
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil { policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
return fmt.Errorf("configure GPO DNS policy: %w", err) if r.gpo {
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
} }
singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
}
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err) log.Warnf("failed to refresh group policy: %v", err)
} }
} else {
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure local DNS policy: %w", err)
}
} }
log.Infof("added %d match domains. Domain list: %s", len(domains), domains) log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return nil return len(domains), nil
} }
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
func (r *registryConfigurator) removeDNSMatchPolicies() error { func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error var merr *multierror.Error
// Try to remove the base entries (for backward compatibility)
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { for i := 0; i < r.nrptEntryCount; i++ {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err)) localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
}
} }
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {

View File

@ -5,8 +5,9 @@ import (
) )
type ShutdownState struct { type ShutdownState struct {
Guid string Guid string
GPO bool GPO bool
NRPTEntryCount int
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@ -15,8 +16,9 @@ func (s *ShutdownState) Name() string {
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
manager := &registryConfigurator{ manager := &registryConfigurator{
guid: s.Guid, guid: s.Guid,
gpo: s.GPO, gpo: s.GPO,
nrptEntryCount: s.NRPTEntryCount,
} }
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {