diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 41307bb72..fbf74a1ef 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -25,7 +25,6 @@ jobs: needs: pre runs-on: windows-latest steps: - - name: Checkout code uses: actions/checkout@v2 diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go new file mode 100644 index 000000000..741ab97b4 --- /dev/null +++ b/client/internal/dns/local.go @@ -0,0 +1,56 @@ +package dns + +import ( + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" + log "github.com/sirupsen/logrus" + "sync" +) + +type localResolver struct { + registeredMap registrationMap + records sync.Map +} + +// ServeDNS handles a DNS request +func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + log.Tracef("received question: %#v\n", r.Question[0]) + response := d.lookupRecord(r) + if response == nil { + log.Debugf("got empty response for question: %#v\n", r.Question[0]) + return + } + + replyMessage := &dns.Msg{} + replyMessage.SetReply(r) + replyMessage.Answer = append(replyMessage.Answer, response) + + err := w.WriteMsg(replyMessage) + if err != nil { + log.Debugf("got an error while writing the local resolver response, error: %v", err) + } +} + +func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR { + record, found := d.records.Load(r.Question[0].Name) + if !found { + return nil + } + + return record.(dns.RR) +} + +func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error { + fullRecord, err := dns.NewRR(record.String()) + if err != nil { + return err + } + + d.records.Store(fullRecord.Header().Name, fullRecord) + + return nil +} + +func (d *localResolver) deleteRecord(recordKey string) { + d.records.Delete(dns.Fqdn(recordKey)) +} diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go new file mode 100644 index 000000000..79a57881b --- /dev/null +++ b/client/internal/dns/local_test.go @@ -0,0 +1,86 @@ +package dns + +import ( + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" + "strings" + "testing" +) + +func TestLocalResolver_ServeDNS(t *testing.T) { + recordA := nbdns.SimpleRecord{ + Name: "peera.netbird.cloud.", + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "1.2.3.4", + } + + recordCNAME := nbdns.SimpleRecord{ + Name: "peerb.netbird.cloud.", + Type: 5, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "www.netbird.io", + } + + testCases := []struct { + name string + inputRecord nbdns.SimpleRecord + inputMSG *dns.Msg + responseShouldBeNil bool + }{ + { + name: "Should Resolve A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA), + }, + { + name: "Should Resolve CNAME Record", + inputRecord: recordCNAME, + inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME), + }, + { + name: "Should Not Write When Not Found A Record", + inputRecord: recordA, + inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), + responseShouldBeNil: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resolver := &localResolver{ + registeredMap: make(registrationMap), + } + _ = resolver.registerRecord(testCase.inputRecord) + var responseMSG *dns.Msg + responseWriter := &mockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, testCase.inputMSG) + + if responseMSG == nil { + if testCase.responseShouldBeNil { + return + } + t.Fatalf("should write a response message") + } + + answerString := responseMSG.Answer[0].String() + if !strings.Contains(answerString, testCase.inputRecord.Name) { + t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString) + } + if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) { + t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString) + } + if !strings.Contains(answerString, testCase.inputRecord.RData) { + t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString) + } + }) + } +} diff --git a/client/internal/dns/mock_test.go b/client/internal/dns/mock_test.go new file mode 100644 index 000000000..511e31dc8 --- /dev/null +++ b/client/internal/dns/mock_test.go @@ -0,0 +1,25 @@ +package dns + +import ( + "github.com/miekg/dns" + "net" +) + +type mockResponseWriter struct { + WriteMsgFunc func(m *dns.Msg) error +} + +func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error { + if rw.WriteMsgFunc != nil { + return rw.WriteMsgFunc(m) + } + return nil +} + +func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (rw *mockResponseWriter) Close() error { return nil } +func (rw *mockResponseWriter) TsigStatus() error { return nil } +func (rw *mockResponseWriter) TsigTimersOnly(bool) {} +func (rw *mockResponseWriter) Hijack() {} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go new file mode 100644 index 000000000..d8e8fc4be --- /dev/null +++ b/client/internal/dns/server.go @@ -0,0 +1,270 @@ +package dns + +import ( + "context" + "fmt" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" + log "github.com/sirupsen/logrus" + "sync" + "time" +) + +const ( + port = 5053 + defaultIP = "0.0.0.0" +) + +// Server dns server object +type Server struct { + ctx context.Context + stop context.CancelFunc + mux sync.Mutex + server *dns.Server + dnsMux *dns.ServeMux + dnsMuxMap registrationMap + localResolver *localResolver + updateSerial uint64 + listenerIsRunning bool +} + +type registrationMap map[string]struct{} + +type muxUpdate struct { + domain string + handler dns.Handler +} + +// NewServer returns a new dns server +func NewServer(ctx context.Context) *Server { + mux := dns.NewServeMux() + + dnsServer := &dns.Server{ + Addr: fmt.Sprintf("%s:%d", defaultIP, port), + Net: "udp", + Handler: mux, + UDPSize: 65535, + } + + ctx, stop := context.WithCancel(ctx) + + return &Server{ + ctx: ctx, + stop: stop, + server: dnsServer, + dnsMux: mux, + dnsMuxMap: make(registrationMap), + localResolver: &localResolver{ + registeredMap: make(registrationMap), + }, + } +} + +// Start runs the listener in a go routine +func (s *Server) Start() { + log.Debugf("starting dns on %s:%d", defaultIP, port) + go func() { + s.setListenerStatus(true) + defer s.setListenerStatus(false) + err := s.server.ListenAndServe() + if err != nil { + log.Errorf("dns server returned an error: %v", err) + } + }() +} + +func (s *Server) setListenerStatus(running bool) { + s.listenerIsRunning = running +} + +// Stop stops the server +func (s *Server) Stop() { + s.stop() + + err := s.stopListener() + if err != nil { + log.Error(err) + } +} + +func (s *Server) stopListener() error { + if !s.listenerIsRunning { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := s.server.ShutdownContext(ctx) + if err != nil { + return fmt.Errorf("stopping dns server listener returned an error: %v", err) + } + return nil +} + +// UpdateDNSServer processes an update received from the management service +func (s *Server) UpdateDNSServer(serial uint64, update nbdns.Update) error { + select { + case <-s.ctx.Done(): + log.Infof("not updating DNS server as context is closed") + return s.ctx.Err() + default: + if serial < s.updateSerial { + return fmt.Errorf("not applying dns update, error: "+ + "network update is %d behind the last applied update", s.updateSerial-serial) + } + s.mux.Lock() + defer s.mux.Unlock() + + // is the service should be disabled, we stop the listener + // and proceed with a regular update to clean up the handlers and records + if !update.ServiceEnable { + err := s.stopListener() + if err != nil { + log.Error(err) + } + } else if !s.listenerIsRunning { + s.Start() + } + + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + + muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) + + s.updateMux(muxUpdates) + s.updateLocalResolver(localRecords) + + s.updateSerial = serial + + return nil + } +} + +func (s *Server) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { + var muxUpdates []muxUpdate + localRecords := make(map[string]nbdns.SimpleRecord, 0) + + for _, customZone := range customZones { + + if len(customZone.Records) == 0 { + return nil, nil, fmt.Errorf("received an empty list of records") + } + + muxUpdates = append(muxUpdates, muxUpdate{ + domain: customZone.Domain, + handler: s.localResolver, + }) + + for _, record := range customZone.Records { + localRecords[record.Name] = record + } + } + return muxUpdates, localRecords, nil +} + +func (s *Server) buildUpstreamHandlerUpdate(nameServerGroups []nbdns.NameServerGroup) ([]muxUpdate, error) { + var muxUpdates []muxUpdate + for _, nsGroup := range nameServerGroups { + if len(nsGroup.NameServers) == 0 { + return nil, fmt.Errorf("received a nameserver group with empty nameserver list") + } + handler := &upstreamResolver{ + parentCTX: s.ctx, + upstreamClient: &dns.Client{}, + upstreamTimeout: defaultUpstreamTimeout, + } + for _, ns := range nsGroup.NameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) + } + + if len(handler.upstreamServers) == 0 { + log.Errorf("received a nameserver group with an invalid nameserver list") + continue + } + + if nsGroup.Primary { + muxUpdates = append(muxUpdates, muxUpdate{ + domain: nbdns.RootZone, + handler: handler, + }) + continue + } + + if len(nsGroup.Domains) == 0 { + return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") + } + + for _, domain := range nsGroup.Domains { + if domain == "" { + return nil, fmt.Errorf("received a nameserver group with an empty domain element") + } + muxUpdates = append(muxUpdates, muxUpdate{ + domain: domain, + handler: handler, + }) + } + } + return muxUpdates, nil +} + +func (s *Server) updateMux(muxUpdates []muxUpdate) { + muxUpdateMap := make(registrationMap) + + for _, update := range muxUpdates { + s.registerMux(update.domain, update.handler) + muxUpdateMap[update.domain] = struct{}{} + } + + for key := range s.dnsMuxMap { + _, found := muxUpdateMap[key] + if !found { + s.deregisterMux(key) + } + } + + s.dnsMuxMap = muxUpdateMap +} + +func (s *Server) updateLocalResolver(update map[string]nbdns.SimpleRecord) { + for key := range s.localResolver.registeredMap { + _, found := update[key] + if !found { + s.localResolver.deleteRecord(key) + } + } + + updatedMap := make(registrationMap) + for key, record := range update { + err := s.localResolver.registerRecord(record) + if err != nil { + log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err) + } + updatedMap[key] = struct{}{} + } + + s.localResolver.registeredMap = updatedMap +} + +func getNSHostPort(ns nbdns.NameServer) string { + return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) +} + +func (s *Server) registerMux(pattern string, handler dns.Handler) { + s.dnsMux.Handle(pattern, handler) +} + +func (s *Server) deregisterMux(pattern string) { + s.dnsMux.HandleRemove(pattern) +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go new file mode 100644 index 000000000..43c35aa7e --- /dev/null +++ b/client/internal/dns/server_test.go @@ -0,0 +1,285 @@ +package dns + +import ( + "context" + "fmt" + nbdns "github.com/netbirdio/netbird/dns" + "net" + "net/netip" + "os" + "runtime" + "testing" + "time" +) + +var zoneRecords = []nbdns.SimpleRecord{ + { + Name: "peera.netbird.cloud", + Type: 1, + Class: nbdns.DefaultClass, + TTL: 300, + RData: "1.2.3.4", + }, +} + +func TestUpdateDNSServer(t *testing.T) { + + nameServers := []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + } + + testCases := []struct { + name string + initUpstreamMap registrationMap + initLocalMap registrationMap + initSerial uint64 + inputSerial uint64 + inputUpdate nbdns.Update + shouldFail bool + expectedUpstreamMap registrationMap + expectedLocalMap registrationMap + }{ + { + name: "Initial Update Should Succeed", + initLocalMap: make(registrationMap), + initUpstreamMap: make(registrationMap), + initSerial: 0, + inputSerial: 1, + inputUpdate: nbdns.Update{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud", + Records: zoneRecords, + }, + }, + NameServerGroups: []nbdns.NameServerGroup{ + { + Domains: []string{"netbird.io"}, + NameServers: nameServers, + }, + { + NameServers: nameServers, + Primary: true, + }, + }, + }, + expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}}, + expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + }, + { + name: "New Update Should Succeed", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + initSerial: 0, + inputSerial: 1, + inputUpdate: nbdns.Update{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud", + Records: zoneRecords, + }, + }, + NameServerGroups: []nbdns.NameServerGroup{ + { + Domains: []string{"netbird.io"}, + NameServers: nameServers, + }, + }, + }, + expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}}, + expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + }, + { + name: "Smaller Update Serial Should Be Skipped", + initLocalMap: make(registrationMap), + initUpstreamMap: make(registrationMap), + initSerial: 2, + inputSerial: 1, + shouldFail: true, + }, + { + name: "Empty NS Group Domain Or Not Primary Element Should Fail", + initLocalMap: make(registrationMap), + initUpstreamMap: make(registrationMap), + initSerial: 0, + inputSerial: 1, + inputUpdate: nbdns.Update{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud", + Records: zoneRecords, + }, + }, + NameServerGroups: []nbdns.NameServerGroup{ + { + NameServers: nameServers, + }, + }, + }, + shouldFail: true, + }, + { + name: "Invalid NS Group Nameservers list Should Fail", + initLocalMap: make(registrationMap), + initUpstreamMap: make(registrationMap), + initSerial: 0, + inputSerial: 1, + inputUpdate: nbdns.Update{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud", + Records: zoneRecords, + }, + }, + NameServerGroups: []nbdns.NameServerGroup{ + { + NameServers: nameServers, + }, + }, + }, + shouldFail: true, + }, + { + name: "Invalid Custom Zone Records list Should Fail", + initLocalMap: make(registrationMap), + initUpstreamMap: make(registrationMap), + initSerial: 0, + inputSerial: 1, + inputUpdate: nbdns.Update{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "netbird.cloud", + }, + }, + NameServerGroups: []nbdns.NameServerGroup{ + { + NameServers: nameServers, + Primary: true, + }, + }, + }, + shouldFail: true, + }, + { + name: "Empty Update Should Succeed and Clean Maps", + initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, + initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + initSerial: 0, + inputSerial: 1, + inputUpdate: nbdns.Update{ServiceEnable: true}, + expectedUpstreamMap: make(registrationMap), + expectedLocalMap: make(registrationMap), + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx := context.Background() + dnsServer := NewServer(ctx) + + dnsServer.dnsMuxMap = testCase.initUpstreamMap + dnsServer.localResolver.registeredMap = testCase.initLocalMap + dnsServer.updateSerial = testCase.initSerial + dnsServer.listenerIsRunning = true + + err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) + if err != nil { + if testCase.shouldFail { + return + } + t.Fatalf("update dns server should not fail, got error: %v", err) + } + + if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) { + t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap)) + } + + for key := range testCase.expectedUpstreamMap { + _, found := dnsServer.dnsMuxMap[key] + if !found { + t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap) + } + } + + if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) { + t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap)) + } + + for key := range testCase.expectedLocalMap { + _, found := dnsServer.localResolver.registeredMap[key] + if !found { + t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap) + } + } + }) + } +} + +func TestDNSServerStartStop(t *testing.T) { + ctx := context.Background() + dnsServer := NewServer(ctx) + if runtime.GOOS == "windows" && os.Getenv("CI") == "true" { + // todo review why this test is not working only on github actions workflows + t.Skip("skipping test in Windows CI workflows.") + } + + dnsServer.Start() + + err := dnsServer.localResolver.registerRecord(zoneRecords[0]) + if err != nil { + t.Error(err) + } + + dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver) + + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: time.Second * 5, + } + addr := fmt.Sprintf("127.0.0.1:%d", port) + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + t.Log(err) + // retry test before exit, for slower systems + return d.DialContext(ctx, network, addr) + } + + return conn, nil + }, + } + + ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name) + if err != nil { + t.Fatalf("failed to connect to the server, error: %v", err) + } + + t.Log(ips) + + if ips[0] != zoneRecords[0].RData { + t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0]) + } + + dnsServer.Stop() + ctx, cancel := context.WithTimeout(ctx, time.Second*1) + defer cancel() + _, err = resolver.LookupHost(ctx, zoneRecords[0].Name) + if err == nil { + t.Fatalf("we should encounter an error when querying a stopped server") + } +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go new file mode 100644 index 000000000..fcc8bc685 --- /dev/null +++ b/client/internal/dns/upstream.go @@ -0,0 +1,67 @@ +package dns + +import ( + "context" + "errors" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "net" + "time" +) + +const defaultUpstreamTimeout = 15 * time.Second + +type upstreamResolver struct { + parentCTX context.Context + upstreamClient *dns.Client + upstreamServers []string + upstreamTimeout time.Duration +} + +// ServeDNS handles a DNS request +func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + + log.Tracef("received an upstream question: %#v", r.Question[0]) + + select { + case <-u.parentCTX.Done(): + return + default: + } + + for _, upstream := range u.upstreamServers { + ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout) + rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream) + + cancel() + + if err != nil { + if err == context.DeadlineExceeded || isTimeout(err) { + log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err) + continue + } + log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err) + return + } + + log.Tracef("took %s to query the upstream %s", t, upstream) + + err = w.WriteMsg(rm) + if err != nil { + log.Errorf("got an error while writing the upstream resolver response, error: %v", err) + } + return + } + log.Errorf("all queries to the upstream nameservers failed with timeout") +} + +// isTimeout returns true if the given error is a network timeout error. +// +// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout +func isTimeout(err error) bool { + var neterr net.Error + if errors.As(err, &neterr) { + return neterr != nil && neterr.Timeout() + } + return false +} diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go new file mode 100644 index 000000000..0fbb7e49f --- /dev/null +++ b/client/internal/dns/upstream_test.go @@ -0,0 +1,110 @@ +package dns + +import ( + "context" + "github.com/miekg/dns" + "strings" + "testing" + "time" +) + +func TestUpstreamResolver_ServeDNS(t *testing.T) { + + testCases := []struct { + name string + inputMSG *dns.Msg + responseShouldBeNil bool + InputServers []string + timeout time.Duration + cancelCTX bool + expectedAnswer string + }{ + { + name: "Should Resolve A Record", + inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), + InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"}, + timeout: defaultUpstreamTimeout, + expectedAnswer: "1.1.1.1", + }, + { + name: "Should Resolve If First Upstream Times Out", + inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), + InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"}, + timeout: 2 * time.Second, + expectedAnswer: "1.1.1.1", + }, + { + name: "Should Not Resolve If Can't Connect To Both Servers", + inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), + InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"}, + timeout: 200 * time.Millisecond, + responseShouldBeNil: true, + }, + { + name: "Should Not Resolve If Parent Context Is Canceled", + inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA), + InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"}, + cancelCTX: true, + timeout: defaultUpstreamTimeout, + responseShouldBeNil: true, + }, + //{ + // name: "Should Resolve CNAME Record", + // inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME), + //}, + //{ + // name: "Should Not Write When Not Found A Record", + // inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), + // responseShouldBeNil: true, + //}, + } + // should resolve if first upstream times out + // should not write when both fails + // should not resolve if parent context is canceled + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + resolver := &upstreamResolver{ + parentCTX: ctx, + upstreamClient: &dns.Client{}, + upstreamServers: testCase.InputServers, + upstreamTimeout: testCase.timeout, + } + if testCase.cancelCTX { + cancel() + } else { + defer cancel() + } + + var responseMSG *dns.Msg + responseWriter := &mockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, testCase.inputMSG) + + if responseMSG == nil { + if testCase.responseShouldBeNil { + return + } + t.Fatalf("should write a response message") + } + + foundAnswer := false + for _, answer := range responseMSG.Answer { + if strings.Contains(answer.String(), testCase.expectedAnswer) { + foundAnswer = true + break + } + } + + if !foundAnswer { + t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer) + } + }) + } +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 08dc4de4b..b94519e67 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -3,6 +3,7 @@ package internal import ( "context" "fmt" + "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/routemanager" nbssh "github.com/netbirdio/netbird/client/ssh" nbstatus "github.com/netbirdio/netbird/client/status" @@ -103,6 +104,8 @@ type Engine struct { statusRecorder *nbstatus.Status routeManager routemanager.Manager + + dnsServer *dns.Server } // Peer is an instance of the Connection Peer @@ -130,6 +133,7 @@ func NewEngine( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, + dnsServer: dns.NewServer(ctx), } } @@ -190,6 +194,10 @@ func (e *Engine) Stop() error { e.routeManager.Stop() } + if e.dnsServer != nil { + e.dnsServer.Stop() + } + log.Infof("stopped Netbird Engine") return nil diff --git a/dns/dns.go b/dns/dns.go index e34218554..5c9144375 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -2,5 +2,55 @@ // to parse and normalize dns records and configuration package dns -// DefaultDNSPort well-known port number -const DefaultDNSPort = 53 +import ( + "fmt" + "github.com/miekg/dns" +) + +const ( + // DefaultDNSPort well-known port number + DefaultDNSPort = 53 + // RootZone is a string representation of the root zone + RootZone = "." + // DefaultClass is the class supported by the system + DefaultClass = "IN" +) + +// Update represents a dns update that is exchanged between management and peers +type Update struct { + // ServiceEnable indicates if the service should be enabled + ServiceEnable bool + // NameServerGroups contains a list of nameserver group + NameServerGroups []NameServerGroup + // CustomZones contains a list of custom zone + CustomZones []CustomZone +} + +// CustomZone represents a custom zone to be resolved by the dns server +type CustomZone struct { + // Domain is the zone's domain + Domain string + // Records custom zone records + Records []SimpleRecord +} + +// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records +type SimpleRecord struct { + // Name domain name + Name string + // Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants + Type int + // Class dns class, currently use the DefaultClass for all records + Class string + // TTL time-to-live for the record + TTL int + // RData is the actual value resolved in a dns query + RData string +} + +// String returns a string of the simple record formatted as: +// +func (s SimpleRecord) String() string { + fqdn := dns.Fqdn(s.Name) + return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData) +} diff --git a/dns/nameserver.go b/dns/nameserver.go index 9f9cc8177..2af633354 100644 --- a/dns/nameserver.go +++ b/dns/nameserver.go @@ -9,8 +9,6 @@ import ( ) const ( - // MaxGroupNameChar maximum group name size - MaxGroupNameChar = 40 // InvalidNameServerType invalid nameserver type InvalidNameServerType NameServerType = iota // UDPNameServerType udp nameserver type @@ -18,6 +16,8 @@ const ( ) const ( + // MaxGroupNameChar maximum group name size + MaxGroupNameChar = 40 // InvalidNameServerTypeString invalid nameserver type as string InvalidNameServerTypeString = "invalid" // UDPNameServerTypeString udp nameserver type as string @@ -59,6 +59,10 @@ type NameServerGroup struct { NameServers []NameServer // Groups list of peer group IDs to distribute the nameservers information Groups []string + // Primary indicates that the nameserver group is the primary resolver for any dns query + Primary bool + // Domains indicate the dns query domains to use with this nameserver group + Domains []string // Enabled group status Enabled bool } @@ -128,6 +132,8 @@ func (g *NameServerGroup) Copy() *NameServerGroup { NameServers: g.NameServers, Groups: g.Groups, Enabled: g.Enabled, + Primary: g.Primary, + Domains: g.Domains, } } @@ -136,8 +142,10 @@ func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool { return other.ID == g.ID && other.Name == g.Name && other.Description == g.Description && + other.Primary == g.Primary && compareNameServerList(g.NameServers, other.NameServers) && - compareGroupsList(g.Groups, other.Groups) + compareGroupsList(g.Groups, other.Groups) && + compareGroupsList(g.Domains, other.Domains) } func compareNameServerList(list, other []NameServer) bool { diff --git a/go.mod b/go.mod index a7960088d..676a7e4db 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 + github.com/miekg/dns v1.1.41 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.13.0 github.com/rs/xid v1.3.0 diff --git a/go.sum b/go.sum index 7924122a0..43a64c2a8 100644 --- a/go.sum +++ b/go.sum @@ -455,6 +455,7 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= +github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= diff --git a/management/server/account.go b/management/server/account.go index 7ed09615a..b0555726d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -87,7 +87,7 @@ type AccountManager interface { DeleteRoute(accountID, routeID string) error ListRoutes(accountID string) ([]*route.Route, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) + CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroup(accountID, nsGroupID string) error diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 88ac357dd..ff285e7de 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -444,12 +444,24 @@ components: type: array items: type: string + primary: + description: Nameserver group primary status + type: boolean + domains: + description: Nameserver group domain list + type: array + items: + type: string + minLength: 1 + maxLength: 255 required: - name - description - nameservers - enabled - groups + - primary + - domains NameserverGroup: allOf: - type: object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index d82f1254c..352e8717a 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -159,6 +159,9 @@ type NameserverGroup struct { // Description Nameserver group description Description string `json:"description"` + // Domains Nameserver group domain list + Domains []string `json:"domains"` + // Enabled Nameserver group status Enabled bool `json:"enabled"` @@ -173,6 +176,9 @@ type NameserverGroup struct { // Nameservers Nameserver group Nameservers []Nameserver `json:"nameservers"` + + // Primary Nameserver group primary status + Primary bool `json:"primary"` } // NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation. @@ -198,6 +204,9 @@ type NameserverGroupRequest struct { // Description Nameserver group description Description string `json:"description"` + // Domains Nameserver group domain list + Domains []string `json:"domains"` + // Enabled Nameserver group status Enabled bool `json:"enabled"` @@ -209,6 +218,9 @@ type NameserverGroupRequest struct { // Nameservers Nameserver group Nameservers []Nameserver `json:"nameservers"` + + // Primary Nameserver group primary status + Primary bool `json:"primary"` } // PatchMinimum defines model for PatchMinimum. diff --git a/management/server/http/nameservers.go b/management/server/http/nameservers.go index af8bb08a4..fed939374 100644 --- a/management/server/http/nameservers.go +++ b/management/server/http/nameservers.go @@ -71,7 +71,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Enabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled) if err != nil { toHTTPError(err, w) return diff --git a/management/server/http/nameservers_test.go b/management/server/http/nameservers_test.go index c1c55a352..2433078f8 100644 --- a/management/server/http/nameservers_test.go +++ b/management/server/http/nameservers_test.go @@ -35,6 +35,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: "super", Description: "super", + Primary: true, NameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("1.1.1.1"), @@ -60,7 +61,7 @@ func initNameserversTestData() *Nameservers { } return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) }, - CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { + CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { return &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: name, @@ -68,6 +69,8 @@ func initNameserversTestData() *Nameservers { NameServers: nameServerList, Groups: groups, Enabled: enabled, + Primary: primary, + Domains: domains, }, nil }, DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error { @@ -150,7 +153,7 @@ func TestNameserversHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/dns/nameservers", requestBody: bytes.NewBuffer( - []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), + []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), expectedStatus: http.StatusOK, expectedBody: true, expectedNSGroup: &api.NameserverGroup{ @@ -173,7 +176,7 @@ func TestNameserversHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/dns/nameservers", requestBody: bytes.NewBuffer( - []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), + []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), expectedStatus: http.StatusBadRequest, expectedBody: false, }, @@ -182,7 +185,7 @@ func TestNameserversHandlers(t *testing.T) { requestType: http.MethodPut, requestPath: "/api/dns/nameservers/" + existingNSGroupID, requestBody: bytes.NewBuffer( - []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), + []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), expectedStatus: http.StatusOK, expectedBody: true, expectedNSGroup: &api.NameserverGroup{ @@ -205,7 +208,7 @@ func TestNameserversHandlers(t *testing.T) { requestType: http.MethodPut, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestBody: bytes.NewBuffer( - []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), + []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), expectedStatus: http.StatusNotFound, expectedBody: false, }, @@ -214,7 +217,7 @@ func TestNameserversHandlers(t *testing.T) { requestType: http.MethodPut, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestBody: bytes.NewBuffer( - []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")), + []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), expectedStatus: http.StatusBadRequest, expectedBody: false, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4a6099726..ae7018183 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -54,7 +54,7 @@ type MockAccountManager struct { ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) DeleteNameServerGroupFunc func(accountID, nsGroupID string) error @@ -435,9 +435,9 @@ func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (* } // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { if am.CreateNameServerGroupFunc != nil { - return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, enabled) + return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled) } return nil, nil } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 9a0468b05..476a9a8e6 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -1,6 +1,7 @@ package server import ( + "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" "github.com/rs/xid" "google.golang.org/grpc/codes" @@ -20,6 +21,10 @@ const ( UpdateNameServerGroupGroups // UpdateNameServerGroupEnabled indicates a nameserver group status update operation UpdateNameServerGroupEnabled + // UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation + UpdateNameServerGroupPrimary + // UpdateNameServerGroupDomains indicates a nameserver group' domains update operation + UpdateNameServerGroupDomains ) // NameServerGroupUpdateOperationType operation type @@ -37,6 +42,10 @@ func (t NameServerGroupUpdateOperationType) String() string { return "UpdateNameServerGroupGroups" case UpdateNameServerGroupEnabled: return "UpdateNameServerGroupEnabled" + case UpdateNameServerGroupPrimary: + return "UpdateNameServerGroupPrimary" + case UpdateNameServerGroupDomains: + return "UpdateNameServerGroupDomains" default: return "InvalidOperation" } @@ -67,7 +76,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) } // CreateNameServerGroup creates and saves a new nameserver group -func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { am.mux.Lock() defer am.mux.Unlock() @@ -83,6 +92,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d NameServers: nameServerList, Groups: groups, Enabled: enabled, + Primary: primary, + Domains: domains, } err = validateNameServerGroup(false, newNSGroup, account) @@ -205,6 +216,18 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) } newNSGroup.Enabled = enabled + case UpdateNameServerGroupPrimary: + primary, err := strconv.ParseBool(operation.Values[0]) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) + } + newNSGroup.Primary = primary + case UpdateNameServerGroupDomains: + err = validateDomainInput(false, operation.Values) + if err != nil { + return nil, err + } + newNSGroup.Domains = operation.Values } } @@ -268,7 +291,12 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ } } - err := validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) + err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains) + if err != nil { + return err + } + + err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) if err != nil { return err } @@ -286,6 +314,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ return nil } +func validateDomainInput(primary bool, domains []string) error { + if !primary && len(domains) == 0 { + return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ + " it should be primary or have at least one domain") + } + if primary && len(domains) != 0 { + return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+ + " you should set either primary or domain") + } + for _, domain := range domains { + _, valid := dns.IsDomainName(domain) + if !valid { + return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain) + } + } + return nil +} + func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d4bdb70f1..647a2fa59 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -14,6 +14,8 @@ const ( existingNSGroupID = "existingNSGroup" nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" + validDomain = "example.com" + invalidDomain = "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com" ) func TestCreateNameServerGroup(t *testing.T) { @@ -23,6 +25,8 @@ func TestCreateNameServerGroup(t *testing.T) { enabled bool groups []string nameServers []nbdns.NameServer + primary bool + domains []string } testCases := []struct { @@ -33,11 +37,12 @@ func TestCreateNameServerGroup(t *testing.T) { expectedNSGroup *nbdns.NameServerGroup }{ { - name: "Create A NS Group", + name: "Create A NS Group With Primary Status", inputArgs: input{ name: "super", description: "super", groups: []string{group1ID}, + primary: true, nameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("1.1.1.1"), @@ -57,6 +62,52 @@ func TestCreateNameServerGroup(t *testing.T) { expectedNSGroup: &nbdns.NameServerGroup{ Name: "super", Description: "super", + Primary: true, + Groups: []string{group1ID}, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("1.1.2.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + }, + Enabled: true, + }, + }, + { + name: "Create A NS Group With Domains", + inputArgs: input{ + name: "super", + description: "super", + groups: []string{group1ID}, + primary: false, + domains: []string{validDomain}, + nameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("1.1.2.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + }, + enabled: true, + }, + errFunc: require.NoError, + shouldCreate: true, + expectedNSGroup: &nbdns.NameServerGroup{ + Name: "super", + Description: "super", + Primary: false, + Domains: []string{"example.com"}, Groups: []string{group1ID}, NameServers: []nbdns.NameServer{ { @@ -78,6 +129,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: existingNSGroupName, description: "super", + primary: true, groups: []string{group1ID}, nameServers: []nbdns.NameServer{ { @@ -101,6 +153,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "", description: "super", + primary: true, groups: []string{group1ID}, nameServers: []nbdns.NameServer{ { @@ -124,6 +177,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "1234567890123456789012345678901234567890extra", description: "super", + primary: true, groups: []string{group1ID}, nameServers: []nbdns.NameServer{ { @@ -147,6 +201,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "super", description: "super", + primary: true, groups: []string{group1ID}, nameServers: []nbdns.NameServer{}, enabled: true, @@ -159,6 +214,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "super", description: "super", + primary: true, groups: []string{group1ID}, nameServers: []nbdns.NameServer{ { @@ -187,6 +243,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "super", description: "super", + primary: true, groups: []string{}, nameServers: []nbdns.NameServer{ { @@ -210,6 +267,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "super", description: "super", + primary: true, groups: []string{"missingGroup"}, nameServers: []nbdns.NameServer{ { @@ -233,6 +291,7 @@ func TestCreateNameServerGroup(t *testing.T) { inputArgs: input{ name: "super", description: "super", + primary: true, groups: []string{""}, nameServers: []nbdns.NameServer{ { @@ -251,6 +310,53 @@ func TestCreateNameServerGroup(t *testing.T) { errFunc: require.Error, shouldCreate: false, }, + { + name: "Should Not Create If No Domain Or Primary", + inputArgs: input{ + name: "super", + description: "super", + groups: []string{group1ID}, + nameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("1.1.2.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + }, + enabled: true, + }, + errFunc: require.Error, + shouldCreate: false, + }, + { + name: "Should Not Create If Domain List Is Invalid", + inputArgs: input{ + name: "super", + description: "super", + groups: []string{group1ID}, + domains: []string{invalidDomain}, + nameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("1.1.2.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + }, + enabled: true, + }, + errFunc: require.Error, + shouldCreate: false, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { @@ -270,6 +376,8 @@ func TestCreateNameServerGroup(t *testing.T) { testCase.inputArgs.description, testCase.inputArgs.nameServers, testCase.inputArgs.groups, + testCase.inputArgs.primary, + testCase.inputArgs.domains, testCase.inputArgs.enabled, ) @@ -295,6 +403,7 @@ func TestSaveNameServerGroup(t *testing.T) { ID: "testingNSGroup", Name: "super", Description: "super", + Primary: true, NameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("1.1.1.1"), @@ -313,6 +422,10 @@ func TestSaveNameServerGroup(t *testing.T) { validGroups := []string{group2ID} invalidGroups := []string{"nonExisting"} + disabledPrimary := false + validDomains := []string{validDomain} + invalidDomains := []string{invalidDomain} + validNameServerList := []nbdns.NameServer{ { IP: netip.MustParseAddr("1.1.1.1"), @@ -348,6 +461,8 @@ func TestSaveNameServerGroup(t *testing.T) { existingNSGroup *nbdns.NameServerGroup newID *string newName *string + newPrimary *bool + newDomains []string newNSList []nbdns.NameServer newGroups []string skipCopying bool @@ -360,12 +475,16 @@ func TestSaveNameServerGroup(t *testing.T) { existingNSGroup: existingNSGroup, newName: &validName, newGroups: validGroups, + newPrimary: &disabledPrimary, + newDomains: validDomains, newNSList: validNameServerList, errFunc: require.NoError, shouldCreate: true, expectedNSGroup: &nbdns.NameServerGroup{ ID: "testingNSGroup", Name: validName, + Primary: false, + Domains: validDomains, Description: "super", NameServers: validNameServerList, Groups: validGroups, @@ -435,6 +554,29 @@ func TestSaveNameServerGroup(t *testing.T) { errFunc: require.Error, shouldCreate: false, }, + { + name: "Should Not Update If Domains List Is Empty", + existingNSGroup: existingNSGroup, + newPrimary: &disabledPrimary, + errFunc: require.Error, + shouldCreate: false, + }, + { + name: "Should Not Update If Primary And Domains", + existingNSGroup: existingNSGroup, + newPrimary: &existingNSGroup.Primary, + newDomains: validDomains, + errFunc: require.Error, + shouldCreate: false, + }, + { + name: "Should Not Update If Domains List Is Invalid", + existingNSGroup: existingNSGroup, + newPrimary: &disabledPrimary, + newDomains: invalidDomains, + errFunc: require.Error, + shouldCreate: false, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { @@ -475,6 +617,14 @@ func TestSaveNameServerGroup(t *testing.T) { if testCase.newNSList != nil { nsGroupToSave.NameServers = testCase.newNSList } + + if testCase.newPrimary != nil { + nsGroupToSave.Primary = *testCase.newPrimary + } + + if testCase.newDomains != nil { + nsGroupToSave.Domains = testCase.newDomains + } } err = am.SaveNameServerGroup(account.Id, nsGroupToSave) @@ -503,6 +653,7 @@ func TestUpdateNameServerGroup(t *testing.T) { ID: nsGroupID, Name: "super", Description: "super", + Primary: true, NameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("1.1.1.1"), @@ -544,6 +695,7 @@ func TestUpdateNameServerGroup(t *testing.T) { ID: nsGroupID, Name: "superNew", Description: "super", + Primary: true, NameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("1.1.1.1"), @@ -585,6 +737,14 @@ func TestUpdateNameServerGroup(t *testing.T) { Type: UpdateNameServerGroupEnabled, Values: []string{"false"}, }, + NameServerGroupUpdateOperation{ + Type: UpdateNameServerGroupPrimary, + Values: []string{"false"}, + }, + NameServerGroupUpdateOperation{ + Type: UpdateNameServerGroupDomains, + Values: []string{validDomain}, + }, }, errFunc: require.NoError, shouldCreate: true, @@ -592,6 +752,8 @@ func TestUpdateNameServerGroup(t *testing.T) { ID: nsGroupID, Name: "superNew", Description: "superDescription", + Primary: false, + Domains: []string{validDomain}, NameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("127.0.0.1"), @@ -740,6 +902,30 @@ func TestUpdateNameServerGroup(t *testing.T) { }, errFunc: require.Error, }, + { + name: "Should Not Update On Invalid Domains", + existingNSGroup: existingNSGroup, + nsGroupID: existingNSGroup.ID, + operations: []NameServerGroupUpdateOperation{ + NameServerGroupUpdateOperation{ + Type: UpdateNameServerGroupDomains, + Values: []string{invalidDomain}, + }, + }, + errFunc: require.Error, + }, + { + name: "Should Not Update On Invalid Primary Status", + existingNSGroup: existingNSGroup, + nsGroupID: existingNSGroup.ID, + operations: []NameServerGroupUpdateOperation{ + NameServerGroupUpdateOperation{ + Type: UpdateNameServerGroupPrimary, + Values: []string{"yes"}, + }, + }, + errFunc: require.Error, + }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) {