mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
Feature/dns-server (#537)
Adding DNS server for client Updated the API with new fields Added custom zone object for peer's DNS resolution
This commit is contained in:
parent
6aa7a2c5e1
commit
e8d82c1bd3
1
.github/workflows/golang-test-windows.yml
vendored
1
.github/workflows/golang-test-windows.yml
vendored
@ -25,7 +25,6 @@ jobs:
|
||||
needs: pre
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
|
56
client/internal/dns/local.go
Normal file
56
client/internal/dns/local.go
Normal file
@ -0,0 +1,56 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type localResolver struct {
|
||||
registeredMap registrationMap
|
||||
records sync.Map
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Tracef("received question: %#v\n", r.Question[0])
|
||||
response := d.lookupRecord(r)
|
||||
if response == nil {
|
||||
log.Debugf("got empty response for question: %#v\n", r.Question[0])
|
||||
return
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.Answer = append(replyMessage.Answer, response)
|
||||
|
||||
err := w.WriteMsg(replyMessage)
|
||||
if err != nil {
|
||||
log.Debugf("got an error while writing the local resolver response, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
|
||||
record, found := d.records.Load(r.Question[0].Name)
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
return record.(dns.RR)
|
||||
}
|
||||
|
||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||
fullRecord, err := dns.NewRR(record.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.records.Store(fullRecord.Header().Name, fullRecord)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *localResolver) deleteRecord(recordKey string) {
|
||||
d.records.Delete(dns.Fqdn(recordKey))
|
||||
}
|
86
client/internal/dns/local_test.go
Normal file
86
client/internal/dns/local_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
}
|
||||
|
||||
recordCNAME := nbdns.SimpleRecord{
|
||||
Name: "peerb.netbird.cloud.",
|
||||
Type: 5,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "www.netbird.io",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputRecord nbdns.SimpleRecord
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||
},
|
||||
{
|
||||
name: "Should Resolve CNAME Record",
|
||||
inputRecord: recordCNAME,
|
||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||
},
|
||||
{
|
||||
name: "Should Not Write When Not Found A Record",
|
||||
inputRecord: recordA,
|
||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
resolver := &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
}
|
||||
_ = resolver.registerRecord(testCase.inputRecord)
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
answerString := responseMSG.Answer[0].String()
|
||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||
}
|
||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
25
client/internal/dns/mock_test.go
Normal file
25
client/internal/dns/mock_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
)
|
||||
|
||||
type mockResponseWriter struct {
|
||||
WriteMsgFunc func(m *dns.Msg) error
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
if rw.WriteMsgFunc != nil {
|
||||
return rw.WriteMsgFunc(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (rw *mockResponseWriter) Close() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (rw *mockResponseWriter) Hijack() {}
|
270
client/internal/dns/server.go
Normal file
270
client/internal/dns/server.go
Normal file
@ -0,0 +1,270 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
port = 5053
|
||||
defaultIP = "0.0.0.0"
|
||||
)
|
||||
|
||||
// Server dns server object
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
dnsMuxMap registrationMap
|
||||
localResolver *localResolver
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
}
|
||||
|
||||
type registrationMap map[string]struct{}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler dns.Handler
|
||||
}
|
||||
|
||||
// NewServer returns a new dns server
|
||||
func NewServer(ctx context.Context) *Server {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", defaultIP, port),
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
return &Server{
|
||||
ctx: ctx,
|
||||
stop: stop,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registrationMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the listener in a go routine
|
||||
func (s *Server) Start() {
|
||||
log.Debugf("starting dns on %s:%d", defaultIP, port)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server returned an error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *Server) Stop() {
|
||||
s.stop()
|
||||
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) stopListener() error {
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *Server) UpdateDNSServer(serial uint64, update nbdns.Update) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
// is the service should be disabled, we stop the listener
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.Start()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
|
||||
s.updateSerial = serial
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
localRecords[record.Name] = record
|
||||
}
|
||||
}
|
||||
return muxUpdates, localRecords, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildUpstreamHandlerUpdate(nameServerGroups []nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||
var muxUpdates []muxUpdate
|
||||
for _, nsGroup := range nameServerGroups {
|
||||
if len(nsGroup.NameServers) == 0 {
|
||||
return nil, fmt.Errorf("received a nameserver group with empty nameserver list")
|
||||
}
|
||||
handler := &upstreamResolver{
|
||||
parentCTX: s.ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamTimeout: defaultUpstreamTimeout,
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if ns.NSType != nbdns.UDPNameServerType {
|
||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||
}
|
||||
|
||||
if len(handler.upstreamServers) == 0 {
|
||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||
continue
|
||||
}
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if len(nsGroup.Domains) == 0 {
|
||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||
}
|
||||
|
||||
for _, domain := range nsGroup.Domains {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
}
|
||||
return muxUpdates, nil
|
||||
}
|
||||
|
||||
func (s *Server) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registrationMap)
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = struct{}{}
|
||||
}
|
||||
|
||||
for key := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
s.deregisterMux(key)
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
}
|
||||
|
||||
func (s *Server) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
for key := range s.localResolver.registeredMap {
|
||||
_, found := update[key]
|
||||
if !found {
|
||||
s.localResolver.deleteRecord(key)
|
||||
}
|
||||
}
|
||||
|
||||
updatedMap := make(registrationMap)
|
||||
for key, record := range update {
|
||||
err := s.localResolver.registerRecord(record)
|
||||
if err != nil {
|
||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||
}
|
||||
updatedMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
s.localResolver.registeredMap = updatedMap
|
||||
}
|
||||
|
||||
func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *Server) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *Server) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
285
client/internal/dns/server_test.go
Normal file
285
client/internal/dns/server_test.go
Normal file
@ -0,0 +1,285 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
Type: 1,
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: "1.2.3.4",
|
||||
},
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initUpstreamMap registrationMap
|
||||
initLocalMap registrationMap
|
||||
initSerial uint64
|
||||
inputSerial uint64
|
||||
inputUpdate nbdns.Update
|
||||
shouldFail bool
|
||||
expectedUpstreamMap registrationMap
|
||||
expectedLocalMap registrationMap
|
||||
}{
|
||||
{
|
||||
name: "Initial Update Should Succeed",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}},
|
||||
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "New Update Should Succeed",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"netbird.io"},
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}},
|
||||
expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Update Serial Should Be Skipped",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Fail",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registrationMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: nameServers,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty Update Should Succeed and Clean Maps",
|
||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
||||
initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}},
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Update{ServiceEnable: true},
|
||||
expectedUpstreamMap: make(registrationMap),
|
||||
expectedLocalMap: make(registrationMap),
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dnsServer := NewServer(ctx)
|
||||
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
dnsServer.listenerIsRunning = true
|
||||
|
||||
err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
}
|
||||
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
||||
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedUpstreamMap {
|
||||
_, found := dnsServer.dnsMuxMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
||||
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
||||
}
|
||||
|
||||
for key := range testCase.expectedLocalMap {
|
||||
_, found := dnsServer.localResolver.registeredMap[key]
|
||||
if !found {
|
||||
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerStartStop(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dnsServer := NewServer(ctx)
|
||||
if runtime.GOOS == "windows" && os.Getenv("CI") == "true" {
|
||||
// todo review why this test is not working only on github actions workflows
|
||||
t.Skip("skipping test in Windows CI workflows.")
|
||||
}
|
||||
|
||||
dnsServer.Start()
|
||||
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
// retry test before exit, for slower systems
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to the server, error: %v", err)
|
||||
}
|
||||
|
||||
t.Log(ips)
|
||||
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
|
||||
}
|
||||
|
||||
dnsServer.Stop()
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*1)
|
||||
defer cancel()
|
||||
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
|
||||
if err == nil {
|
||||
t.Fatalf("we should encounter an error when querying a stopped server")
|
||||
}
|
||||
}
|
67
client/internal/dns/upstream.go
Normal file
67
client/internal/dns/upstream.go
Normal file
@ -0,0 +1,67 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultUpstreamTimeout = 15 * time.Second
|
||||
|
||||
type upstreamResolver struct {
|
||||
parentCTX context.Context
|
||||
upstreamClient *dns.Client
|
||||
upstreamServers []string
|
||||
upstreamTimeout time.Duration
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
log.Tracef("received an upstream question: %#v", r.Question[0])
|
||||
|
||||
select {
|
||||
case <-u.parentCTX.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
ctx, cancel := context.WithTimeout(u.parentCTX, u.upstreamTimeout)
|
||||
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
||||
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
if err == context.DeadlineExceeded || isTimeout(err) {
|
||||
log.Warnf("got an error while connecting to upstream %s, error: %v", upstream, err)
|
||||
continue
|
||||
}
|
||||
log.Errorf("got an error while querying the upstream %s, error: %v", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
if err != nil {
|
||||
log.Errorf("got an error while writing the upstream resolver response, error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Errorf("all queries to the upstream nameservers failed with timeout")
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
//
|
||||
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
||||
func isTimeout(err error) bool {
|
||||
var neterr net.Error
|
||||
if errors.As(err, &neterr) {
|
||||
return neterr != nil && neterr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
110
client/internal/dns/upstream_test.go
Normal file
110
client/internal/dns/upstream_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/miekg/dns"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputMSG *dns.Msg
|
||||
responseShouldBeNil bool
|
||||
InputServers []string
|
||||
timeout time.Duration
|
||||
cancelCTX bool
|
||||
expectedAnswer string
|
||||
}{
|
||||
{
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: defaultUpstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "Should Resolve If First Upstream Times Out",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
timeout: 2 * time.Second,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
name: "Should Not Resolve If Can't Connect To Both Servers",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
|
||||
timeout: 200 * time.Millisecond,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Resolve If Parent Context Is Canceled",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: defaultUpstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
//{
|
||||
// name: "Should Resolve CNAME Record",
|
||||
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
|
||||
//},
|
||||
//{
|
||||
// name: "Should Not Write When Not Found A Record",
|
||||
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||
// responseShouldBeNil: true,
|
||||
//},
|
||||
}
|
||||
// should resolve if first upstream times out
|
||||
// should not write when both fails
|
||||
// should not resolve if parent context is canceled
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver := &upstreamResolver{
|
||||
parentCTX: ctx,
|
||||
upstreamClient: &dns.Client{},
|
||||
upstreamServers: testCase.InputServers,
|
||||
upstreamTimeout: testCase.timeout,
|
||||
}
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
} else {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &mockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||
|
||||
if responseMSG == nil {
|
||||
if testCase.responseShouldBeNil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should write a response message")
|
||||
}
|
||||
|
||||
foundAnswer := false
|
||||
for _, answer := range responseMSG.Answer {
|
||||
if strings.Contains(answer.String(), testCase.expectedAnswer) {
|
||||
foundAnswer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundAnswer {
|
||||
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
@ -103,6 +104,8 @@ type Engine struct {
|
||||
statusRecorder *nbstatus.Status
|
||||
|
||||
routeManager routemanager.Manager
|
||||
|
||||
dnsServer *dns.Server
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@ -130,6 +133,7 @@ func NewEngine(
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dns.NewServer(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
@ -190,6 +194,10 @@ func (e *Engine) Stop() error {
|
||||
e.routeManager.Stop()
|
||||
}
|
||||
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
|
54
dns/dns.go
54
dns/dns.go
@ -2,5 +2,55 @@
|
||||
// to parse and normalize dns records and configuration
|
||||
package dns
|
||||
|
||||
// DefaultDNSPort well-known port number
|
||||
const DefaultDNSPort = 53
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDNSPort well-known port number
|
||||
DefaultDNSPort = 53
|
||||
// RootZone is a string representation of the root zone
|
||||
RootZone = "."
|
||||
// DefaultClass is the class supported by the system
|
||||
DefaultClass = "IN"
|
||||
)
|
||||
|
||||
// Update represents a dns update that is exchanged between management and peers
|
||||
type Update struct {
|
||||
// ServiceEnable indicates if the service should be enabled
|
||||
ServiceEnable bool
|
||||
// NameServerGroups contains a list of nameserver group
|
||||
NameServerGroups []NameServerGroup
|
||||
// CustomZones contains a list of custom zone
|
||||
CustomZones []CustomZone
|
||||
}
|
||||
|
||||
// CustomZone represents a custom zone to be resolved by the dns server
|
||||
type CustomZone struct {
|
||||
// Domain is the zone's domain
|
||||
Domain string
|
||||
// Records custom zone records
|
||||
Records []SimpleRecord
|
||||
}
|
||||
|
||||
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records
|
||||
type SimpleRecord struct {
|
||||
// Name domain name
|
||||
Name string
|
||||
// Type of record, 1 for A, 5 for CNAME, 28 for AAAA. see https://pkg.go.dev/github.com/miekg/dns@v1.1.41#pkg-constants
|
||||
Type int
|
||||
// Class dns class, currently use the DefaultClass for all records
|
||||
Class string
|
||||
// TTL time-to-live for the record
|
||||
TTL int
|
||||
// RData is the actual value resolved in a dns query
|
||||
RData string
|
||||
}
|
||||
|
||||
// String returns a string of the simple record formatted as:
|
||||
// <Name> <TTL> <Class> <Type> <RDATA>
|
||||
func (s SimpleRecord) String() string {
|
||||
fqdn := dns.Fqdn(s.Name)
|
||||
return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData)
|
||||
}
|
||||
|
@ -9,8 +9,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxGroupNameChar maximum group name size
|
||||
MaxGroupNameChar = 40
|
||||
// InvalidNameServerType invalid nameserver type
|
||||
InvalidNameServerType NameServerType = iota
|
||||
// UDPNameServerType udp nameserver type
|
||||
@ -18,6 +16,8 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxGroupNameChar maximum group name size
|
||||
MaxGroupNameChar = 40
|
||||
// InvalidNameServerTypeString invalid nameserver type as string
|
||||
InvalidNameServerTypeString = "invalid"
|
||||
// UDPNameServerTypeString udp nameserver type as string
|
||||
@ -59,6 +59,10 @@ type NameServerGroup struct {
|
||||
NameServers []NameServer
|
||||
// Groups list of peer group IDs to distribute the nameservers information
|
||||
Groups []string
|
||||
// Primary indicates that the nameserver group is the primary resolver for any dns query
|
||||
Primary bool
|
||||
// Domains indicate the dns query domains to use with this nameserver group
|
||||
Domains []string
|
||||
// Enabled group status
|
||||
Enabled bool
|
||||
}
|
||||
@ -128,6 +132,8 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
|
||||
NameServers: g.NameServers,
|
||||
Groups: g.Groups,
|
||||
Enabled: g.Enabled,
|
||||
Primary: g.Primary,
|
||||
Domains: g.Domains,
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,8 +142,10 @@ func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
|
||||
return other.ID == g.ID &&
|
||||
other.Name == g.Name &&
|
||||
other.Description == g.Description &&
|
||||
other.Primary == g.Primary &&
|
||||
compareNameServerList(g.NameServers, other.NameServers) &&
|
||||
compareGroupsList(g.Groups, other.Groups)
|
||||
compareGroupsList(g.Groups, other.Groups) &&
|
||||
compareGroupsList(g.Domains, other.Domains)
|
||||
}
|
||||
|
||||
func compareNameServerList(list, other []NameServer) bool {
|
||||
|
1
go.mod
1
go.mod
@ -38,6 +38,7 @@ require (
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/libp2p/go-netroute v0.2.0
|
||||
github.com/magiconair/properties v1.8.5
|
||||
github.com/miekg/dns v1.1.41
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/prometheus/client_golang v1.13.0
|
||||
github.com/rs/xid v1.3.0
|
||||
|
1
go.sum
1
go.sum
@ -455,6 +455,7 @@ github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb h1:2dC7L10LmTqlyMV
|
||||
github.com/mdlayher/socket v0.0.0-20211102153432-57e3fa563ecb/go.mod h1:nFZ1EtZYK8Gi/k6QNu7z7CgO20i/4ExeQswwWuPmG/g=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
||||
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
|
||||
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
|
||||
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
|
||||
|
@ -87,7 +87,7 @@ type AccountManager interface {
|
||||
DeleteRoute(accountID, routeID string) error
|
||||
ListRoutes(accountID string) ([]*route.Route, error)
|
||||
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
||||
DeleteNameServerGroup(accountID, nsGroupID string) error
|
||||
|
@ -444,12 +444,24 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
primary:
|
||||
description: Nameserver group primary status
|
||||
type: boolean
|
||||
domains:
|
||||
description: Nameserver group domain list
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minLength: 1
|
||||
maxLength: 255
|
||||
required:
|
||||
- name
|
||||
- description
|
||||
- nameservers
|
||||
- enabled
|
||||
- groups
|
||||
- primary
|
||||
- domains
|
||||
NameserverGroup:
|
||||
allOf:
|
||||
- type: object
|
||||
|
@ -159,6 +159,9 @@ type NameserverGroup struct {
|
||||
// Description Nameserver group description
|
||||
Description string `json:"description"`
|
||||
|
||||
// Domains Nameserver group domain list
|
||||
Domains []string `json:"domains"`
|
||||
|
||||
// Enabled Nameserver group status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@ -173,6 +176,9 @@ type NameserverGroup struct {
|
||||
|
||||
// Nameservers Nameserver group
|
||||
Nameservers []Nameserver `json:"nameservers"`
|
||||
|
||||
// Primary Nameserver group primary status
|
||||
Primary bool `json:"primary"`
|
||||
}
|
||||
|
||||
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
|
||||
@ -198,6 +204,9 @@ type NameserverGroupRequest struct {
|
||||
// Description Nameserver group description
|
||||
Description string `json:"description"`
|
||||
|
||||
// Domains Nameserver group domain list
|
||||
Domains []string `json:"domains"`
|
||||
|
||||
// Enabled Nameserver group status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
@ -209,6 +218,9 @@ type NameserverGroupRequest struct {
|
||||
|
||||
// Nameservers Nameserver group
|
||||
Nameservers []Nameserver `json:"nameservers"`
|
||||
|
||||
// Primary Nameserver group primary status
|
||||
Primary bool `json:"primary"`
|
||||
}
|
||||
|
||||
// PatchMinimum defines model for PatchMinimum.
|
||||
|
@ -71,7 +71,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Enabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
return
|
||||
|
@ -35,6 +35,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
Name: "super",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
@ -60,7 +61,7 @@ func initNameserversTestData() *Nameservers {
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
},
|
||||
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
return &nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
Name: name,
|
||||
@ -68,6 +69,8 @@ func initNameserversTestData() *Nameservers {
|
||||
NameServers: nameServerList,
|
||||
Groups: groups,
|
||||
Enabled: enabled,
|
||||
Primary: primary,
|
||||
Domains: domains,
|
||||
}, nil
|
||||
},
|
||||
DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
|
||||
@ -150,7 +153,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/dns/nameservers",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedNSGroup: &api.NameserverGroup{
|
||||
@ -173,7 +176,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/dns/nameservers",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: false,
|
||||
},
|
||||
@ -182,7 +185,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedNSGroup: &api.NameserverGroup{
|
||||
@ -205,7 +208,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
@ -214,7 +217,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: false,
|
||||
},
|
||||
|
@ -54,7 +54,7 @@ type MockAccountManager struct {
|
||||
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
|
||||
@ -435,9 +435,9 @@ func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*
|
||||
}
|
||||
|
||||
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
if am.CreateNameServerGroupFunc != nil {
|
||||
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, enabled)
|
||||
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/rs/xid"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -20,6 +21,10 @@ const (
|
||||
UpdateNameServerGroupGroups
|
||||
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
|
||||
UpdateNameServerGroupEnabled
|
||||
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
|
||||
UpdateNameServerGroupPrimary
|
||||
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
|
||||
UpdateNameServerGroupDomains
|
||||
)
|
||||
|
||||
// NameServerGroupUpdateOperationType operation type
|
||||
@ -37,6 +42,10 @@ func (t NameServerGroupUpdateOperationType) String() string {
|
||||
return "UpdateNameServerGroupGroups"
|
||||
case UpdateNameServerGroupEnabled:
|
||||
return "UpdateNameServerGroupEnabled"
|
||||
case UpdateNameServerGroupPrimary:
|
||||
return "UpdateNameServerGroupPrimary"
|
||||
case UpdateNameServerGroupDomains:
|
||||
return "UpdateNameServerGroupDomains"
|
||||
default:
|
||||
return "InvalidOperation"
|
||||
}
|
||||
@ -67,7 +76,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
@ -83,6 +92,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
NameServers: nameServerList,
|
||||
Groups: groups,
|
||||
Enabled: enabled,
|
||||
Primary: primary,
|
||||
Domains: domains,
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(false, newNSGroup, account)
|
||||
@ -205,6 +216,18 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newNSGroup.Enabled = enabled
|
||||
case UpdateNameServerGroupPrimary:
|
||||
primary, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newNSGroup.Primary = primary
|
||||
case UpdateNameServerGroupDomains:
|
||||
err = validateDomainInput(false, operation.Values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newNSGroup.Domains = operation.Values
|
||||
}
|
||||
}
|
||||
|
||||
@ -268,7 +291,12 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
||||
}
|
||||
}
|
||||
|
||||
err := validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -286,6 +314,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string) error {
|
||||
if !primary && len(domains) == 0 {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
" it should be primary or have at least one domain")
|
||||
}
|
||||
if primary && len(domains) != 0 {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
|
||||
" you should set either primary or domain")
|
||||
}
|
||||
for _, domain := range domains {
|
||||
_, valid := dns.IsDomainName(domain)
|
||||
if !valid {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||
|
@ -14,6 +14,8 @@ const (
|
||||
existingNSGroupID = "existingNSGroup"
|
||||
nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
|
||||
nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
|
||||
validDomain = "example.com"
|
||||
invalidDomain = "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com"
|
||||
)
|
||||
|
||||
func TestCreateNameServerGroup(t *testing.T) {
|
||||
@ -23,6 +25,8 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
enabled bool
|
||||
groups []string
|
||||
nameServers []nbdns.NameServer
|
||||
primary bool
|
||||
domains []string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@ -33,11 +37,12 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
expectedNSGroup *nbdns.NameServerGroup
|
||||
}{
|
||||
{
|
||||
name: "Create A NS Group",
|
||||
name: "Create A NS Group With Primary Status",
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
groups: []string{group1ID},
|
||||
primary: true,
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
@ -57,6 +62,52 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
expectedNSGroup: &nbdns.NameServerGroup{
|
||||
Name: "super",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
Groups: []string{group1ID},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.2.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create A NS Group With Domains",
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
groups: []string{group1ID},
|
||||
primary: false,
|
||||
domains: []string{validDomain},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.2.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
enabled: true,
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedNSGroup: &nbdns.NameServerGroup{
|
||||
Name: "super",
|
||||
Description: "super",
|
||||
Primary: false,
|
||||
Domains: []string{"example.com"},
|
||||
Groups: []string{group1ID},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -78,6 +129,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: existingNSGroupName,
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{group1ID},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -101,6 +153,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{group1ID},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -124,6 +177,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "1234567890123456789012345678901234567890extra",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{group1ID},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -147,6 +201,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{group1ID},
|
||||
nameServers: []nbdns.NameServer{},
|
||||
enabled: true,
|
||||
@ -159,6 +214,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{group1ID},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -187,6 +243,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -210,6 +267,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{"missingGroup"},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -233,6 +291,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
primary: true,
|
||||
groups: []string{""},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
@ -251,6 +310,53 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Should Not Create If No Domain Or Primary",
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
groups: []string{group1ID},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.2.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
enabled: true,
|
||||
},
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Should Not Create If Domain List Is Invalid",
|
||||
inputArgs: input{
|
||||
name: "super",
|
||||
description: "super",
|
||||
groups: []string{group1ID},
|
||||
domains: []string{invalidDomain},
|
||||
nameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.2.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
enabled: true,
|
||||
},
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
@ -270,6 +376,8 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
testCase.inputArgs.description,
|
||||
testCase.inputArgs.nameServers,
|
||||
testCase.inputArgs.groups,
|
||||
testCase.inputArgs.primary,
|
||||
testCase.inputArgs.domains,
|
||||
testCase.inputArgs.enabled,
|
||||
)
|
||||
|
||||
@ -295,6 +403,7 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
ID: "testingNSGroup",
|
||||
Name: "super",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
@ -313,6 +422,10 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
|
||||
validGroups := []string{group2ID}
|
||||
invalidGroups := []string{"nonExisting"}
|
||||
disabledPrimary := false
|
||||
validDomains := []string{validDomain}
|
||||
invalidDomains := []string{invalidDomain}
|
||||
|
||||
validNameServerList := []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
@ -348,6 +461,8 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
existingNSGroup *nbdns.NameServerGroup
|
||||
newID *string
|
||||
newName *string
|
||||
newPrimary *bool
|
||||
newDomains []string
|
||||
newNSList []nbdns.NameServer
|
||||
newGroups []string
|
||||
skipCopying bool
|
||||
@ -360,12 +475,16 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
existingNSGroup: existingNSGroup,
|
||||
newName: &validName,
|
||||
newGroups: validGroups,
|
||||
newPrimary: &disabledPrimary,
|
||||
newDomains: validDomains,
|
||||
newNSList: validNameServerList,
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedNSGroup: &nbdns.NameServerGroup{
|
||||
ID: "testingNSGroup",
|
||||
Name: validName,
|
||||
Primary: false,
|
||||
Domains: validDomains,
|
||||
Description: "super",
|
||||
NameServers: validNameServerList,
|
||||
Groups: validGroups,
|
||||
@ -435,6 +554,29 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Should Not Update If Domains List Is Empty",
|
||||
existingNSGroup: existingNSGroup,
|
||||
newPrimary: &disabledPrimary,
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Should Not Update If Primary And Domains",
|
||||
existingNSGroup: existingNSGroup,
|
||||
newPrimary: &existingNSGroup.Primary,
|
||||
newDomains: validDomains,
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
{
|
||||
name: "Should Not Update If Domains List Is Invalid",
|
||||
existingNSGroup: existingNSGroup,
|
||||
newPrimary: &disabledPrimary,
|
||||
newDomains: invalidDomains,
|
||||
errFunc: require.Error,
|
||||
shouldCreate: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
@ -475,6 +617,14 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
if testCase.newNSList != nil {
|
||||
nsGroupToSave.NameServers = testCase.newNSList
|
||||
}
|
||||
|
||||
if testCase.newPrimary != nil {
|
||||
nsGroupToSave.Primary = *testCase.newPrimary
|
||||
}
|
||||
|
||||
if testCase.newDomains != nil {
|
||||
nsGroupToSave.Domains = testCase.newDomains
|
||||
}
|
||||
}
|
||||
|
||||
err = am.SaveNameServerGroup(account.Id, nsGroupToSave)
|
||||
@ -503,6 +653,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
|
||||
ID: nsGroupID,
|
||||
Name: "super",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
@ -544,6 +695,7 @@ func TestUpdateNameServerGroup(t *testing.T) {
|
||||
ID: nsGroupID,
|
||||
Name: "superNew",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
@ -585,6 +737,14 @@ func TestUpdateNameServerGroup(t *testing.T) {
|
||||
Type: UpdateNameServerGroupEnabled,
|
||||
Values: []string{"false"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupPrimary,
|
||||
Values: []string{"false"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupDomains,
|
||||
Values: []string{validDomain},
|
||||
},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
@ -592,6 +752,8 @@ func TestUpdateNameServerGroup(t *testing.T) {
|
||||
ID: nsGroupID,
|
||||
Name: "superNew",
|
||||
Description: "superDescription",
|
||||
Primary: false,
|
||||
Domains: []string{validDomain},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("127.0.0.1"),
|
||||
@ -740,6 +902,30 @@ func TestUpdateNameServerGroup(t *testing.T) {
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Update On Invalid Domains",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupDomains,
|
||||
Values: []string{invalidDomain},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Update On Invalid Primary Status",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupPrimary,
|
||||
Values: []string{"yes"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user