[client] Add reverse dns zone (#3217)

This commit is contained in:
Viktor Liu 2025-02-21 12:52:04 +01:00 committed by GitHub
parent 6554026a82
commit 5134e3a06a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 153 additions and 10 deletions

111
client/internal/dns.go Normal file
View File

@ -0,0 +1,111 @@
package internal
import (
"fmt"
"net"
"slices"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData)
if ip == nil || ip.To4() == nil {
return nbdns.SimpleRecord{}, false
}
if !ipNet.Contains(ip) {
return nbdns.SimpleRecord{}, false
}
ipOctets := strings.Split(ip.String(), ".")
slices.Reverse(ipOctets)
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
return nbdns.SimpleRecord{
Name: rdnsName,
Type: int(dns.TypePTR),
Class: aRecord.Class,
TTL: aRecord.TTL,
RData: dns.Fqdn(aRecord.Name),
}, true
}
// generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask)
maskOnes, _ := ipNet.Mask.Size()
// round up to nearest byte
octetsToUse := (maskOnes + 7) / 8
octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
}
reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
}
// zoneExists checks if a zone with the given name already exists in the configuration
func zoneExists(config *nbdns.Config, zoneName string) bool {
for _, zone := range config.CustomZones {
if zone.Domain == zoneName {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return true
}
}
return false
}
// collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
}
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
records = append(records, ptrRecord)
}
}
}
return records
}
// addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
zoneName, err := generateReverseZoneName(ipNet)
if err != nil {
log.Warn(err)
return
}
if zoneExists(config, zoneName) {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return
}
records := collectPTRRecords(config, ipNet)
reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
}
config.CustomZones = append(config.CustomZones, reverseZone)
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
}

View File

@ -9,6 +9,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
const (
ipv4ReverseZone = ".in-addr.arpa"
ipv6ReverseZone = ".ip6.arpa"
)
type hostManager interface { type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error restoreHostDNS() error
@ -94,9 +99,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
} }
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."), Domain: strings.TrimSuffix(customZone.Domain, "."),
MatchOnly: false, MatchOnly: matchOnly,
}) })
} }

View File

@ -395,12 +395,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil { if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err) return fmt.Errorf("local handler updater: %w", err)
} }
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil { if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err) return fmt.Errorf("upstream handler updater: %w", err)
} }
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
@ -447,7 +447,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
for _, customZone := range customZones { for _, customZone := range customZones {
if len(customZone.Records) == 0 { if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records") log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
continue
} }
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
@ -460,7 +461,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
for _, record := range customZone.Records { for _, record := range customZone.Records {
var class uint16 = dns.ClassINET var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass { if record.Class != nbdns.DefaultClass {
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) log.Warnf("received an invalid class type: %s", record.Class)
continue
} }
key := buildRecordKey(record.Name, class, uint16(record.Type)) key := buildRecordKey(record.Name, class, uint16(record.Type))

View File

@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Invalid Custom Zone Records list Should Fail", name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap), initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
shouldFail: true, expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
}, },
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",

View File

@ -953,7 +953,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
} }
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
@ -1022,7 +1022,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
return dnsRoutes return dnsRoutes
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
@ -1062,6 +1062,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
} }
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup) dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
} }
if len(dnsUpdate.CustomZones) > 0 {
addReverseZone(&dnsUpdate, network)
}
return dnsUpdate return dnsUpdate
} }
@ -1368,7 +1373,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
return nil, nil, err return nil, nil, err
} }
routes := toRoutes(netMap.GetRoutes()) routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig()) dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
return routes, &dnsCfg, nil return routes, &dnsCfg, nil
} }

View File

@ -361,6 +361,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
@ -803,6 +812,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}, },
}, },
}, },
{
Domain: "0.66.100.in-addr.arpa.",
},
}, },
NameServerGroups: []*mgmtProto.NameServerGroup{ NameServerGroups: []*mgmtProto.NameServerGroup{
{ {
@ -832,6 +844,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}, },
}, },
}, },
{
Domain: "0.66.100.in-addr.arpa.",
},
}, },
expectedNSGroupsLen: 1, expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{ expectedNSGroups: []*nbdns.NameServerGroup{