From 5134e3a06adf50700165984a6a4a0f67b7789a21 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 21 Feb 2025 12:52:04 +0100 Subject: [PATCH] [client] Add reverse dns zone (#3217) --- client/internal/dns.go | 111 +++++++++++++++++++++++++++++ client/internal/dns/host.go | 8 ++- client/internal/dns/server.go | 10 +-- client/internal/dns/server_test.go | 8 ++- client/internal/engine.go | 11 ++- client/internal/engine_test.go | 15 ++++ 6 files changed, 153 insertions(+), 10 deletions(-) create mode 100644 client/internal/dns.go diff --git a/client/internal/dns.go b/client/internal/dns.go new file mode 100644 index 000000000..8a73f50f2 --- /dev/null +++ b/client/internal/dns.go @@ -0,0 +1,111 @@ +package internal + +import ( + "fmt" + "net" + "slices" + "strings" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" +) + +func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { + ip := net.ParseIP(aRecord.RData) + if ip == nil || ip.To4() == nil { + return nbdns.SimpleRecord{}, false + } + + if !ipNet.Contains(ip) { + return nbdns.SimpleRecord{}, false + } + + ipOctets := strings.Split(ip.String(), ".") + slices.Reverse(ipOctets) + rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa") + + return nbdns.SimpleRecord{ + Name: rdnsName, + Type: int(dns.TypePTR), + Class: aRecord.Class, + TTL: aRecord.TTL, + RData: dns.Fqdn(aRecord.Name), + }, true +} + +// generateReverseZoneName creates the reverse DNS zone name for a given network +func generateReverseZoneName(ipNet *net.IPNet) (string, error) { + networkIP := ipNet.IP.Mask(ipNet.Mask) + maskOnes, _ := ipNet.Mask.Size() + + // round up to nearest byte + octetsToUse := (maskOnes + 7) / 8 + + octets := strings.Split(networkIP.String(), ".") + if octetsToUse > len(octets) { + return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) + } + + reverseOctets := make([]string, octetsToUse) + for i := 0; i < octetsToUse; i++ { + reverseOctets[octetsToUse-1-i] = octets[i] + } + + return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil +} + +// zoneExists checks if a zone with the given name already exists in the configuration +func zoneExists(config *nbdns.Config, zoneName string) bool { + for _, zone := range config.CustomZones { + if zone.Domain == zoneName { + log.Debugf("reverse DNS zone %s already exists", zoneName) + return true + } + } + return false +} + +// collectPTRRecords gathers all PTR records for the given network from A records +func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + + for _, zone := range config.CustomZones { + for _, record := range zone.Records { + if record.Type != int(dns.TypeA) { + continue + } + + if ptrRecord, ok := createPTRRecord(record, ipNet); ok { + records = append(records, ptrRecord) + } + } + } + + return records +} + +// addReverseZone adds a reverse DNS zone to the configuration for the given network +func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { + zoneName, err := generateReverseZoneName(ipNet) + if err != nil { + log.Warn(err) + return + } + + if zoneExists(config, zoneName) { + log.Debugf("reverse DNS zone %s already exists", zoneName) + return + } + + records := collectPTRRecords(config, ipNet) + + reverseZone := nbdns.CustomZone{ + Domain: zoneName, + Records: records, + } + + config.CustomZones = append(config.CustomZones, reverseZone) + log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records)) +} diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index fbe8c4dbb..cfc0cc3c3 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -9,6 +9,11 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) +const ( + ipv4ReverseZone = ".in-addr.arpa" + ipv6ReverseZone = ".ip6.arpa" +) + type hostManager interface { applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error @@ -94,9 +99,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD } for _, customZone := range dnsConfig.CustomZones { + matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ Domain: strings.TrimSuffix(customZone.Domain, "."), - MatchOnly: false, + MatchOnly: matchOnly, }) } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4d68370d..f536a1434 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -395,12 +395,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) + return fmt.Errorf("local handler updater: %w", err) } upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) + return fmt.Errorf("upstream handler updater: %w", err) } muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic @@ -447,7 +447,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate( for _, customZone := range customZones { if len(customZone.Records) == 0 { - return nil, nil, fmt.Errorf("received an empty list of records") + log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain) + continue } muxUpdates = append(muxUpdates, handlerWrapper{ @@ -460,7 +461,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate( for _, record := range customZone.Records { var class uint16 = dns.ClassINET if record.Class != nbdns.DefaultClass { - return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) + log.Warnf("received an invalid class type: %s", record.Class) + continue } key := buildRecordKey(record.Name, class, uint16(record.Type)) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index e9ddd5f59..1354462d9 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) { shouldFail: true, }, { - name: "Invalid Custom Zone Records list Should Fail", + name: "Invalid Custom Zone Records list Should Skip", initLocalMap: make(registrationMap), initUpstreamMap: make(registeredHandlerMap), initSerial: 0, @@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) { }, }, }, - shouldFail: true, + expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{ + domain: ".", + handler: dummyHandler, + priority: PriorityDefault, + }}, }, { name: "Empty Config Should Succeed and Clean Maps", diff --git a/client/internal/engine.go b/client/internal/engine.go index 90f70a827..ebb68b98b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -953,7 +953,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { log.Errorf("failed to update dns server, err: %v", err) } @@ -1022,7 +1022,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string { return dnsRoutes } -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), @@ -1062,6 +1062,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { } dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup) } + + if len(dnsUpdate.CustomZones) > 0 { + addReverseZone(&dnsUpdate, network) + } + return dnsUpdate } @@ -1368,7 +1373,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { return nil, nil, err } routes := toRoutes(netMap.GetRoutes()) - dnsCfg := toDNSConfig(netMap.GetDNSConfig()) + dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) return routes, &dnsCfg, nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2b1d3b098..599d36eab 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -361,6 +361,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { RemovePeerFunc: func(peerKey string) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, } engine.wgInterface = wgIface engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ @@ -803,6 +812,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, }, }, + { + Domain: "0.66.100.in-addr.arpa.", + }, }, NameServerGroups: []*mgmtProto.NameServerGroup{ { @@ -832,6 +844,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, }, }, + { + Domain: "0.66.100.in-addr.arpa.", + }, }, expectedNSGroupsLen: 1, expectedNSGroups: []*nbdns.NameServerGroup{