mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
[client] Add reverse dns zone (#3217)
This commit is contained in:
parent
6554026a82
commit
5134e3a06a
111
client/internal/dns.go
Normal file
111
client/internal/dns.go
Normal file
@ -0,0 +1,111 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ipOctets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(ipOctets)
|
||||
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||
|
||||
return nbdns.SimpleRecord{
|
||||
Name: rdnsName,
|
||||
Type: int(dns.TypePTR),
|
||||
Class: aRecord.Class,
|
||||
TTL: aRecord.TTL,
|
||||
RData: dns.Fqdn(aRecord.Name),
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
}
|
||||
|
||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||
func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.Domain == zoneName {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
if zoneExists(config, zoneName) {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
|
||||
}
|
@ -9,6 +9,11 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
@ -94,9 +99,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
MatchOnly: false,
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -395,12 +395,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("upstream handler updater: %w", err)
|
||||
}
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
||||
|
||||
@ -447,7 +447,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
|
||||
continue
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
@ -460,7 +461,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
|
@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Fail",
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
|
@ -953,7 +953,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
@ -1022,7 +1022,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
||||
return dnsRoutes
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
@ -1062,6 +1062,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
}
|
||||
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
|
||||
}
|
||||
|
||||
if len(dnsUpdate.CustomZones) > 0 {
|
||||
addReverseZone(&dnsUpdate, network)
|
||||
}
|
||||
|
||||
return dnsUpdate
|
||||
}
|
||||
|
||||
@ -1368,7 +1373,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
|
||||
|
@ -361,6 +361,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
@ -803,6 +812,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Domain: "0.66.100.in-addr.arpa.",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*mgmtProto.NameServerGroup{
|
||||
{
|
||||
@ -832,6 +844,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Domain: "0.66.100.in-addr.arpa.",
|
||||
},
|
||||
},
|
||||
expectedNSGroupsLen: 1,
|
||||
expectedNSGroups: []*nbdns.NameServerGroup{
|
||||
|
Loading…
x
Reference in New Issue
Block a user