mirror of
https://github.com/netbirdio/netbird.git
synced 2025-07-21 00:13:26 +02:00
1956 lines
51 KiB
Go
1956 lines
51 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
"github.com/netbirdio/netbird/client/iface"
|
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
"github.com/netbirdio/netbird/client/iface/device"
|
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/formatter"
|
|
"github.com/netbirdio/netbird/management/domain"
|
|
)
|
|
|
|
// flowLogger is the netflow logger passed to uspfilter.Create when a test
// builds a userspace packet filter. It is taken from a manager constructed
// with nil dependencies and an empty byte slice.
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
|
|
|
type mocWGIface struct {
|
|
filter device.PacketFilter
|
|
}
|
|
|
|
func (w *mocWGIface) Name() string {
|
|
return "utun2301"
|
|
}
|
|
|
|
func (w *mocWGIface) Address() wgaddr.Address {
|
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
|
return wgaddr.Address{
|
|
IP: ip,
|
|
Network: network,
|
|
}
|
|
}
|
|
|
|
func (w *mocWGIface) ToInterface() *net.Interface {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) GetFilter() device.PacketFilter {
|
|
return w.filter
|
|
}
|
|
|
|
func (w *mocWGIface) GetDevice() *device.FilteredDevice {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) IsUserspaceBind() bool {
|
|
return false
|
|
}
|
|
|
|
func (w *mocWGIface) SetFilter(filter device.PacketFilter) error {
|
|
w.filter = filter
|
|
return nil
|
|
}
|
|
|
|
func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
|
return configurer.WGStats{}, nil
|
|
}
|
|
|
|
var zoneRecords = []nbdns.SimpleRecord{
|
|
{
|
|
Name: "peera.netbird.cloud",
|
|
Type: 1,
|
|
Class: nbdns.DefaultClass,
|
|
TTL: 300,
|
|
RData: "1.2.3.4",
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
log.SetLevel(log.TraceLevel)
|
|
formatter.SetTextFormatter(log.StandardLogger())
|
|
}
|
|
|
|
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
|
var srvs []string
|
|
for _, srv := range servers {
|
|
srvs = append(srvs, getNSHostPort(srv))
|
|
}
|
|
return &upstreamResolverBase{
|
|
domain: domain,
|
|
upstreamServers: srvs,
|
|
cancel: func() {},
|
|
}
|
|
}
|
|
|
|
func TestUpdateDNSServer(t *testing.T) {
|
|
|
|
nameServers := []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.8.8"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
}
|
|
|
|
dummyHandler := local.NewResolver()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
initUpstreamMap registeredHandlerMap
|
|
initLocalRecords []nbdns.SimpleRecord
|
|
initSerial uint64
|
|
inputSerial uint64
|
|
inputUpdate nbdns.Config
|
|
shouldFail bool
|
|
expectedUpstreamMap registeredHandlerMap
|
|
expectedLocalQs []dns.Question
|
|
}{
|
|
{
|
|
name: "Initial Config Should Succeed",
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"netbird.io"},
|
|
NameServers: nameServers,
|
|
},
|
|
{
|
|
NameServers: nameServers,
|
|
Primary: true,
|
|
},
|
|
},
|
|
},
|
|
expectedUpstreamMap: registeredHandlerMap{
|
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
|
domain: "netbird.io",
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
dummyHandler.ID(): handlerWrapper{
|
|
domain: "netbird.cloud",
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
|
domain: nbdns.RootZone,
|
|
handler: dummyHandler,
|
|
priority: PriorityDefault,
|
|
},
|
|
},
|
|
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
|
},
|
|
{
|
|
name: "New Config Should Succeed",
|
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
|
initUpstreamMap: registeredHandlerMap{
|
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
|
domain: "netbird.cloud",
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"netbird.io"},
|
|
NameServers: nameServers,
|
|
},
|
|
},
|
|
},
|
|
expectedUpstreamMap: registeredHandlerMap{
|
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
|
domain: "netbird.io",
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
"local-resolver": handlerWrapper{
|
|
domain: "netbird.cloud",
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
|
},
|
|
{
|
|
name: "Smaller Config Serial Should Be Skipped",
|
|
initLocalRecords: []nbdns.SimpleRecord{},
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 2,
|
|
inputSerial: 1,
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
|
initLocalRecords: []nbdns.SimpleRecord{},
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: nameServers,
|
|
},
|
|
},
|
|
},
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Invalid NS Group Nameservers list Should Fail",
|
|
initLocalRecords: []nbdns.SimpleRecord{},
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: nameServers,
|
|
},
|
|
},
|
|
},
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Invalid Custom Zone Records list Should Skip",
|
|
initLocalRecords: []nbdns.SimpleRecord{},
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: nameServers,
|
|
Primary: true,
|
|
},
|
|
},
|
|
},
|
|
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
|
domain: ".",
|
|
handler: dummyHandler,
|
|
priority: PriorityDefault,
|
|
}},
|
|
},
|
|
{
|
|
name: "Empty Config Should Succeed and Clean Maps",
|
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
|
initUpstreamMap: registeredHandlerMap{
|
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
|
domain: zoneRecords[0].Name,
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
|
expectedUpstreamMap: make(registeredHandlerMap),
|
|
expectedLocalQs: []dns.Question{},
|
|
},
|
|
{
|
|
name: "Disabled Service Should clean map",
|
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
|
initUpstreamMap: registeredHandlerMap{
|
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
|
domain: zoneRecords[0].Name,
|
|
handler: dummyHandler,
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
|
expectedUpstreamMap: make(registeredHandlerMap),
|
|
expectedLocalQs: []dns.Question{},
|
|
},
|
|
}
|
|
|
|
for n, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
privKey, _ := wgtypes.GenerateKey()
|
|
newNet, err := stdnet.NewNet(nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
opts := iface.WGIFaceOpts{
|
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
|
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
|
WGPort: 33100,
|
|
WGPrivKey: privKey.String(),
|
|
MTU: iface.DefaultMTU,
|
|
TransportNet: newNet,
|
|
}
|
|
|
|
wgIface, err := iface.NewWGIFace(opts)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = wgIface.Create()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err = wgIface.Close()
|
|
if err != nil {
|
|
t.Log(err)
|
|
}
|
|
}()
|
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err = dnsServer.hostManager.restoreHostDNS()
|
|
if err != nil {
|
|
t.Log(err)
|
|
}
|
|
}()
|
|
|
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
|
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
|
dnsServer.updateSerial = testCase.initSerial
|
|
|
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
|
if err != nil {
|
|
if testCase.shouldFail {
|
|
return
|
|
}
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
}
|
|
|
|
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
|
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
|
}
|
|
|
|
for key := range testCase.expectedUpstreamMap {
|
|
_, found := dnsServer.dnsMuxMap[key]
|
|
if !found {
|
|
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
|
}
|
|
}
|
|
|
|
var responseMSG *dns.Msg
|
|
responseWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
responseMSG = m
|
|
return nil
|
|
},
|
|
}
|
|
for _, q := range testCase.expectedLocalQs {
|
|
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
|
Question: []dns.Question{q},
|
|
})
|
|
}
|
|
|
|
if len(testCase.expectedLocalQs) > 0 {
|
|
assert.NotNil(t, responseMSG, "response message should not be nil")
|
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
|
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
|
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
|
if err != nil {
|
|
t.Errorf("create stdnet: %v", err)
|
|
return
|
|
}
|
|
|
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
opts := iface.WGIFaceOpts{
|
|
IFaceName: "utun2301",
|
|
Address: "100.66.100.1/32",
|
|
WGPort: 33100,
|
|
WGPrivKey: privKey.String(),
|
|
MTU: iface.DefaultMTU,
|
|
TransportNet: newNet,
|
|
}
|
|
wgIface, err := iface.NewWGIFace(opts)
|
|
if err != nil {
|
|
t.Errorf("build interface wireguard: %v", err)
|
|
return
|
|
}
|
|
|
|
err = wgIface.Create()
|
|
if err != nil {
|
|
t.Errorf("create and init wireguard interface: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err = wgIface.Close(); err != nil {
|
|
t.Logf("close wireguard interface: %v", err)
|
|
}
|
|
}()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
|
|
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
|
if err != nil {
|
|
t.Errorf("parse CIDR: %v", err)
|
|
return
|
|
}
|
|
|
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
|
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
|
t.Errorf("set packet filter: %v", err)
|
|
return
|
|
}
|
|
|
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false)
|
|
if err != nil {
|
|
t.Errorf("create DNS server: %v", err)
|
|
return
|
|
}
|
|
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("run DNS server: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
|
t.Logf("restore DNS settings on the host: %v", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
|
"id1": handlerWrapper{
|
|
domain: zoneRecords[0].Name,
|
|
handler: &local.Resolver{},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
}
|
|
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
|
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
|
dnsServer.updateSerial = 0
|
|
|
|
nameServers := []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.8.8"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
}
|
|
|
|
update := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"netbird.io"},
|
|
NameServers: nameServers,
|
|
},
|
|
{
|
|
NameServers: nameServers,
|
|
Primary: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Start the server with regular configuration
|
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
return
|
|
}
|
|
|
|
update2 := update
|
|
update2.ServiceEnable = false
|
|
// Disable the server, stop the listener
|
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
return
|
|
}
|
|
|
|
update3 := update2
|
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
|
// But service still get updates and we checking that we handle
|
|
// internal state in the right way
|
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestDNSServerStartStop(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
addrPort string
|
|
}{
|
|
{
|
|
name: "Should Pass With Port Discovery",
|
|
},
|
|
{
|
|
name: "Should Pass With Custom Port",
|
|
addrPort: "127.0.0.1:3535",
|
|
},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false)
|
|
if err != nil {
|
|
t.Fatalf("%v", err)
|
|
}
|
|
dnsServer.hostManager = newNoopHostMocker()
|
|
err = dnsServer.service.Listen()
|
|
if err != nil {
|
|
t.Fatalf("dns server is not running: %s", err)
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
defer dnsServer.Stop()
|
|
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
|
|
|
resolver := &net.Resolver{
|
|
PreferGo: true,
|
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
d := net.Dialer{
|
|
Timeout: time.Second * 5,
|
|
}
|
|
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
conn, err := d.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
t.Log(err)
|
|
// retry test before exit, for slower systems
|
|
return d.DialContext(ctx, network, addr)
|
|
}
|
|
|
|
return conn, nil
|
|
},
|
|
}
|
|
|
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect to the server, error: %v", err)
|
|
}
|
|
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
|
|
}
|
|
|
|
dnsServer.Stop()
|
|
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
|
|
defer cancel()
|
|
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
|
|
if err == nil {
|
|
t.Fatalf("we should encounter an error when querying a stopped server")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|
hostManager := &mockHostConfigurator{}
|
|
server := DefaultServer{
|
|
ctx: context.Background(),
|
|
service: NewServiceViaMemory(&mocWGIface{}),
|
|
localResolver: local.NewResolver(),
|
|
handlerChain: NewHandlerChain(),
|
|
hostManager: hostManager,
|
|
currentConfig: HostDNSConfig{
|
|
Domains: []DomainConfig{
|
|
{false, "domain0", false},
|
|
{false, "domain1", false},
|
|
{false, "domain2", false},
|
|
},
|
|
},
|
|
statusRecorder: peer.NewRecorder("mgm"),
|
|
}
|
|
|
|
var domainsUpdate string
|
|
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
|
domains := []string{}
|
|
for _, item := range config.Domains {
|
|
if item.Disabled {
|
|
continue
|
|
}
|
|
domains = append(domains, item.Domain)
|
|
}
|
|
domainsUpdate = strings.Join(domains, ",")
|
|
return nil
|
|
}
|
|
|
|
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
|
Domains: []string{"domain1"},
|
|
NameServers: []nbdns.NameServer{
|
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
|
},
|
|
}, nil, 0)
|
|
|
|
deactivate(nil)
|
|
expected := "domain0,domain2"
|
|
domains := []string{}
|
|
for _, item := range server.currentConfig.Domains {
|
|
if item.Disabled {
|
|
continue
|
|
}
|
|
domains = append(domains, item.Domain)
|
|
}
|
|
got := strings.Join(domains, ",")
|
|
if expected != got {
|
|
t.Errorf("expected domains list: %q, got %q", expected, got)
|
|
}
|
|
|
|
reactivate()
|
|
expected = "domain0,domain1,domain2"
|
|
domains = []string{}
|
|
for _, item := range server.currentConfig.Domains {
|
|
if item.Disabled {
|
|
continue
|
|
}
|
|
domains = append(domains, item.Domain)
|
|
}
|
|
got = strings.Join(domains, ",")
|
|
if expected != got {
|
|
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
|
}
|
|
}
|
|
|
|
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|
wgIFace, err := createWgInterfaceWithBind(t)
|
|
if err != nil {
|
|
t.Fatal("failed to initialize wg interface")
|
|
}
|
|
defer wgIFace.Close()
|
|
|
|
var dnsList []string
|
|
dnsConfig := nbdns.Config{}
|
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("failed to initialize DNS server: %v", err)
|
|
return
|
|
}
|
|
defer dnsServer.Stop()
|
|
|
|
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
|
|
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
}
|
|
|
|
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|
wgIFace, err := createWgInterfaceWithBind(t)
|
|
if err != nil {
|
|
t.Fatal("failed to initialize wg interface")
|
|
}
|
|
defer wgIFace.Close()
|
|
dnsConfig := nbdns.Config{}
|
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("failed to initialize DNS server: %v", err)
|
|
return
|
|
}
|
|
defer dnsServer.Stop()
|
|
|
|
// check initial state
|
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
|
|
update := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
},
|
|
Enabled: true,
|
|
Primary: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
err = dnsServer.UpdateDNSServer(1, update)
|
|
if err != nil {
|
|
t.Errorf("failed to update dns server: %s", err)
|
|
}
|
|
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed resolve zone record: %v", err)
|
|
}
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("invalid zone record: %v", err)
|
|
}
|
|
|
|
update2 := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{},
|
|
}
|
|
|
|
err = dnsServer.UpdateDNSServer(2, update2)
|
|
if err != nil {
|
|
t.Errorf("failed to update dns server: %s", err)
|
|
}
|
|
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
|
|
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed resolve zone record: %v", err)
|
|
}
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("invalid zone record: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDNSPermanent_matchOnly(t *testing.T) {
|
|
wgIFace, err := createWgInterfaceWithBind(t)
|
|
if err != nil {
|
|
t.Fatal("failed to initialize wg interface")
|
|
}
|
|
defer wgIFace.Close()
|
|
dnsConfig := nbdns.Config{}
|
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("failed to initialize DNS server: %v", err)
|
|
return
|
|
}
|
|
defer dnsServer.Stop()
|
|
|
|
// check initial state
|
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
|
|
update := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
{
|
|
IP: netip.MustParseAddr("9.9.9.9"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
},
|
|
Domains: []string{"google.com"},
|
|
Primary: false,
|
|
},
|
|
},
|
|
}
|
|
|
|
err = dnsServer.UpdateDNSServer(1, update)
|
|
if err != nil {
|
|
t.Errorf("failed to update dns server: %s", err)
|
|
}
|
|
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed resolve zone record: %v", err)
|
|
}
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("invalid zone record: %v", err)
|
|
}
|
|
_, err = resolver.LookupHost(context.Background(), "google.com")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
}
|
|
|
|
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|
t.Helper()
|
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
|
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
|
if err != nil {
|
|
t.Fatalf("create stdnet: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
|
|
opts := iface.WGIFaceOpts{
|
|
IFaceName: "utun2301",
|
|
Address: "100.66.100.2/24",
|
|
WGPort: 33100,
|
|
WGPrivKey: privKey.String(),
|
|
MTU: iface.DefaultMTU,
|
|
TransportNet: newNet,
|
|
}
|
|
|
|
wgIface, err := iface.NewWGIFace(opts)
|
|
if err != nil {
|
|
t.Fatalf("build interface wireguard: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
err = wgIface.Create()
|
|
if err != nil {
|
|
t.Fatalf("create and init wireguard interface: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
|
if err != nil {
|
|
t.Fatalf("failed to create uspfilter: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
err = wgIface.SetFilter(pf)
|
|
if err != nil {
|
|
t.Fatalf("set packet filter: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
return wgIface, nil
|
|
}
|
|
|
|
// newDnsResolver returns a pure-Go resolver that sends every query to ip:port,
// ignoring the server address the resolver would normally dial. Each dial uses
// a three second timeout.
//
// The target is built with net.JoinHostPort so that IPv6 literals are wrapped
// in brackets and yield a dialable address.
func newDnsResolver(ip string, port int) *net.Resolver {
	target := net.JoinHostPort(ip, fmt.Sprintf("%d", port))

	return &net.Resolver{
		PreferGo: true,
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			d := net.Dialer{
				Timeout: time.Second * 3,
			}
			return d.DialContext(ctx, network, target)
		},
	}
}
|
|
|
|
// MockHandler implements dns.Handler interface for testing
|
|
type MockHandler struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
m.Called(w, r)
|
|
}
|
|
|
|
type MockSubdomainHandler struct {
|
|
MockHandler
|
|
Subdomains bool
|
|
}
|
|
|
|
func (m *MockSubdomainHandler) MatchSubdomains() bool {
|
|
return m.Subdomains
|
|
}
|
|
|
|
func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|
chain := NewHandlerChain()
|
|
|
|
dnsRouteHandler := &MockHandler{}
|
|
upstreamHandler := &MockSubdomainHandler{
|
|
Subdomains: true,
|
|
}
|
|
|
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
|
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
query string
|
|
expectedHandler dns.Handler
|
|
}{
|
|
{
|
|
name: "exact domain with dns route handler",
|
|
query: "example.com.",
|
|
expectedHandler: dnsRouteHandler,
|
|
},
|
|
{
|
|
name: "subdomain should use upstream handler",
|
|
query: "sub.example.com.",
|
|
expectedHandler: upstreamHandler,
|
|
},
|
|
{
|
|
name: "deep subdomain should use upstream handler",
|
|
query: "deep.sub.example.com.",
|
|
expectedHandler: upstreamHandler,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
r := new(dns.Msg)
|
|
r.SetQuestion(tc.query, dns.TypeA)
|
|
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
|
|
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
|
}
|
|
|
|
chain.ServeDNS(w, r)
|
|
|
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
|
mh.AssertExpectations(t)
|
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
|
mh.AssertExpectations(t)
|
|
}
|
|
|
|
// Close mocks
|
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
|
mh.ExpectedCalls = nil
|
|
mh.Calls = nil
|
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
|
mh.ExpectedCalls = nil
|
|
mh.Calls = nil
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type mockHandler struct {
|
|
Id string
|
|
}
|
|
|
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
|
func (m *mockHandler) Stop() {}
|
|
func (m *mockHandler) ProbeAvailability() {}
|
|
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
|
|
|
type mockService struct{}
|
|
|
|
func (m *mockService) Listen() error { return nil }
|
|
func (m *mockService) Stop() {}
|
|
func (m *mockService) RuntimeIP() string { return "127.0.0.1" }
|
|
func (m *mockService) RuntimePort() int { return 53 }
|
|
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
|
func (m *mockService) DeregisterMux(string) {}
|
|
|
|
func TestDefaultServer_UpdateMux(t *testing.T) {
|
|
baseMatchHandlers := registeredHandlerMap{
|
|
"upstream-group1": {
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
"upstream-group2": {
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group2",
|
|
},
|
|
priority: PriorityMatchDomain - 1,
|
|
},
|
|
}
|
|
|
|
baseRootHandlers := registeredHandlerMap{
|
|
"upstream-root1": {
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root1",
|
|
},
|
|
priority: PriorityDefault,
|
|
},
|
|
"upstream-root2": {
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root2",
|
|
},
|
|
priority: PriorityDefault - 1,
|
|
},
|
|
}
|
|
|
|
baseMixedHandlers := registeredHandlerMap{
|
|
"upstream-group1": {
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
"upstream-group2": {
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group2",
|
|
},
|
|
priority: PriorityMatchDomain - 1,
|
|
},
|
|
"upstream-other": {
|
|
domain: "other.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-other",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
initialHandlers registeredHandlerMap
|
|
updates []handlerWrapper
|
|
expectedHandlers map[string]string // map[HandlerID]domain
|
|
description string
|
|
}{
|
|
{
|
|
name: "Remove group1 from update",
|
|
initialHandlers: baseMatchHandlers,
|
|
updates: []handlerWrapper{
|
|
// Only group2 remains
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group2",
|
|
},
|
|
priority: PriorityMatchDomain - 1,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-group2": "example.com",
|
|
},
|
|
description: "When group1 is not included in the update, it should be removed while group2 remains",
|
|
},
|
|
{
|
|
name: "Remove group2 from update",
|
|
initialHandlers: baseMatchHandlers,
|
|
updates: []handlerWrapper{
|
|
// Only group1 remains
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-group1": "example.com",
|
|
},
|
|
description: "When group2 is not included in the update, it should be removed while group1 remains",
|
|
},
|
|
{
|
|
name: "Add group3 in first position",
|
|
initialHandlers: baseMatchHandlers,
|
|
updates: []handlerWrapper{
|
|
// Add group3 with highest priority
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group3",
|
|
},
|
|
priority: PriorityMatchDomain + 1,
|
|
},
|
|
// Keep existing groups with their original priorities
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group2",
|
|
},
|
|
priority: PriorityMatchDomain - 1,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-group1": "example.com",
|
|
"upstream-group2": "example.com",
|
|
"upstream-group3": "example.com",
|
|
},
|
|
description: "When adding group3 with highest priority, it should be first in chain while maintaining existing groups",
|
|
},
|
|
{
|
|
name: "Add group3 in last position",
|
|
initialHandlers: baseMatchHandlers,
|
|
updates: []handlerWrapper{
|
|
// Keep existing groups with their original priorities
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group2",
|
|
},
|
|
priority: PriorityMatchDomain - 1,
|
|
},
|
|
// Add group3 with lowest priority
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group3",
|
|
},
|
|
priority: PriorityMatchDomain - 2,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-group1": "example.com",
|
|
"upstream-group2": "example.com",
|
|
"upstream-group3": "example.com",
|
|
},
|
|
description: "When adding group3 with lowest priority, it should be last in chain while maintaining existing groups",
|
|
},
|
|
// Root zone tests
|
|
{
|
|
name: "Remove root1 from update",
|
|
initialHandlers: baseRootHandlers,
|
|
updates: []handlerWrapper{
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root2",
|
|
},
|
|
priority: PriorityDefault - 1,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-root2": ".",
|
|
},
|
|
description: "When root1 is not included in the update, it should be removed while root2 remains",
|
|
},
|
|
{
|
|
name: "Remove root2 from update",
|
|
initialHandlers: baseRootHandlers,
|
|
updates: []handlerWrapper{
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root1",
|
|
},
|
|
priority: PriorityDefault,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-root1": ".",
|
|
},
|
|
description: "When root2 is not included in the update, it should be removed while root1 remains",
|
|
},
|
|
{
|
|
name: "Add root3 in first position",
|
|
initialHandlers: baseRootHandlers,
|
|
updates: []handlerWrapper{
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root3",
|
|
},
|
|
priority: PriorityDefault + 1,
|
|
},
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root1",
|
|
},
|
|
priority: PriorityDefault,
|
|
},
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root2",
|
|
},
|
|
priority: PriorityDefault - 1,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-root1": ".",
|
|
"upstream-root2": ".",
|
|
"upstream-root3": ".",
|
|
},
|
|
description: "When adding root3 with highest priority, it should be first in chain while maintaining existing root handlers",
|
|
},
|
|
{
|
|
name: "Add root3 in last position",
|
|
initialHandlers: baseRootHandlers,
|
|
updates: []handlerWrapper{
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root1",
|
|
},
|
|
priority: PriorityDefault,
|
|
},
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root2",
|
|
},
|
|
priority: PriorityDefault - 1,
|
|
},
|
|
{
|
|
domain: ".",
|
|
handler: &mockHandler{
|
|
Id: "upstream-root3",
|
|
},
|
|
priority: PriorityDefault - 2,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-root1": ".",
|
|
"upstream-root2": ".",
|
|
"upstream-root3": ".",
|
|
},
|
|
description: "When adding root3 with lowest priority, it should be last in chain while maintaining existing root handlers",
|
|
},
|
|
// Mixed domain tests
|
|
{
|
|
name: "Update with mixed domains - remove one of duplicate domain",
|
|
initialHandlers: baseMixedHandlers,
|
|
updates: []handlerWrapper{
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
{
|
|
domain: "other.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-other",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-group1": "example.com",
|
|
"upstream-other": "other.com",
|
|
},
|
|
description: "When updating mixed domains, should correctly handle removal of one duplicate while maintaining other domains",
|
|
},
|
|
{
|
|
name: "Update with mixed domains - add new domain",
|
|
initialHandlers: baseMixedHandlers,
|
|
updates: []handlerWrapper{
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group1",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
{
|
|
domain: "example.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-group2",
|
|
},
|
|
priority: PriorityMatchDomain - 1,
|
|
},
|
|
{
|
|
domain: "other.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-other",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
{
|
|
domain: "new.com",
|
|
handler: &mockHandler{
|
|
Id: "upstream-new",
|
|
},
|
|
priority: PriorityMatchDomain,
|
|
},
|
|
},
|
|
expectedHandlers: map[string]string{
|
|
"upstream-group1": "example.com",
|
|
"upstream-group2": "example.com",
|
|
"upstream-other": "other.com",
|
|
"upstream-new": "new.com",
|
|
},
|
|
description: "When updating mixed domains, should maintain existing duplicates and add new domain",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := &DefaultServer{
|
|
dnsMuxMap: tt.initialHandlers,
|
|
handlerChain: NewHandlerChain(),
|
|
service: &mockService{},
|
|
}
|
|
|
|
// Perform the update
|
|
server.updateMux(tt.updates)
|
|
|
|
// Verify the results
|
|
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap),
|
|
"Number of handlers after update doesn't match expected")
|
|
|
|
// Check each expected handler
|
|
for id, expectedDomain := range tt.expectedHandlers {
|
|
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
|
assert.True(t, exists, "Expected handler %s not found", id)
|
|
if exists {
|
|
assert.Equal(t, expectedDomain, handler.domain,
|
|
"Domain mismatch for handler %s", id)
|
|
}
|
|
}
|
|
|
|
// Verify no unexpected handlers exist
|
|
for HandlerID := range server.dnsMuxMap {
|
|
_, expected := tt.expectedHandlers[string(HandlerID)]
|
|
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
|
}
|
|
|
|
// Verify the handlerChain state and order
|
|
previousPriority := 0
|
|
for _, chainEntry := range server.handlerChain.handlers {
|
|
// Verify priority order
|
|
if previousPriority > 0 {
|
|
assert.True(t, chainEntry.Priority <= previousPriority,
|
|
"Handlers in chain not properly ordered by priority")
|
|
}
|
|
previousPriority = chainEntry.Priority
|
|
|
|
// Verify handler exists in mux
|
|
foundInMux := false
|
|
for _, muxEntry := range server.dnsMuxMap {
|
|
if chainEntry.Handler == muxEntry.handler &&
|
|
chainEntry.Priority == muxEntry.priority &&
|
|
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
|
|
foundInMux = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, foundInMux,
|
|
"Handler in chain not found in dnsMuxMap")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtraDomains(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
initialConfig nbdns.Config
|
|
registerDomains []domain.List
|
|
deregisterDomains []domain.List
|
|
finalConfig nbdns.Config
|
|
expectedDomains []string
|
|
expectedMatchOnly []string
|
|
applyHostConfigCall int
|
|
}{
|
|
{
|
|
name: "Register domains before config update",
|
|
registerDomains: []domain.List{
|
|
{"extra1.example.com", "extra2.example.com"},
|
|
},
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
},
|
|
expectedDomains: []string{
|
|
"config.example.com.",
|
|
"extra1.example.com.",
|
|
"extra2.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra1.example.com.",
|
|
"extra2.example.com.",
|
|
},
|
|
applyHostConfigCall: 2,
|
|
},
|
|
{
|
|
name: "Register domains after config update",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra1.example.com", "extra2.example.com"},
|
|
},
|
|
expectedDomains: []string{
|
|
"config.example.com.",
|
|
"extra1.example.com.",
|
|
"extra2.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra1.example.com.",
|
|
"extra2.example.com.",
|
|
},
|
|
applyHostConfigCall: 2,
|
|
},
|
|
{
|
|
name: "Register overlapping domains",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
{Domain: "overlap.example.com"},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra.example.com", "overlap.example.com"},
|
|
},
|
|
expectedDomains: []string{
|
|
"config.example.com.",
|
|
"overlap.example.com.",
|
|
"extra.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra.example.com.",
|
|
},
|
|
applyHostConfigCall: 2,
|
|
},
|
|
{
|
|
name: "Register and deregister domains",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra1.example.com", "extra2.example.com"},
|
|
{"extra3.example.com", "extra4.example.com"},
|
|
},
|
|
deregisterDomains: []domain.List{
|
|
{"extra1.example.com", "extra3.example.com"},
|
|
},
|
|
expectedDomains: []string{
|
|
"config.example.com.",
|
|
"extra2.example.com.",
|
|
"extra4.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra2.example.com.",
|
|
"extra4.example.com.",
|
|
},
|
|
applyHostConfigCall: 4,
|
|
},
|
|
{
|
|
name: "Register domains with ref counter",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra.example.com", "duplicate.example.com"},
|
|
{"other.example.com", "duplicate.example.com"},
|
|
},
|
|
deregisterDomains: []domain.List{
|
|
{"duplicate.example.com"},
|
|
},
|
|
expectedDomains: []string{
|
|
"config.example.com.",
|
|
"extra.example.com.",
|
|
"other.example.com.",
|
|
"duplicate.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra.example.com.",
|
|
"other.example.com.",
|
|
"duplicate.example.com.",
|
|
},
|
|
applyHostConfigCall: 4,
|
|
},
|
|
{
|
|
name: "Config update with new domains after registration",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra.example.com", "duplicate.example.com"},
|
|
},
|
|
finalConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
{Domain: "newconfig.example.com"},
|
|
},
|
|
},
|
|
expectedDomains: []string{
|
|
"config.example.com.",
|
|
"newconfig.example.com.",
|
|
"extra.example.com.",
|
|
"duplicate.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra.example.com.",
|
|
"duplicate.example.com.",
|
|
},
|
|
applyHostConfigCall: 3,
|
|
},
|
|
{
|
|
name: "Deregister domain that is part of customZones",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
{Domain: "protected.example.com"},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra.example.com", "protected.example.com"},
|
|
},
|
|
deregisterDomains: []domain.List{
|
|
{"protected.example.com"},
|
|
},
|
|
expectedDomains: []string{
|
|
"extra.example.com.",
|
|
"config.example.com.",
|
|
"protected.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"extra.example.com.",
|
|
},
|
|
applyHostConfigCall: 3,
|
|
},
|
|
{
|
|
name: "Register domain that is part of nameserver group",
|
|
initialConfig: nbdns.Config{
|
|
ServiceEnable: true,
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
|
NameServers: []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.8.8"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
registerDomains: []domain.List{
|
|
{"extra.example.com", "overlap.ns.example.com"},
|
|
},
|
|
expectedDomains: []string{
|
|
"ns.example.com.",
|
|
"overlap.ns.example.com.",
|
|
"extra.example.com.",
|
|
},
|
|
expectedMatchOnly: []string{
|
|
"ns.example.com.",
|
|
"overlap.ns.example.com.",
|
|
"extra.example.com.",
|
|
},
|
|
applyHostConfigCall: 2,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var capturedConfigs []HostDNSConfig
|
|
mockHostConfig := &mockHostConfigurator{
|
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
|
capturedConfigs = append(capturedConfigs, config)
|
|
return nil
|
|
},
|
|
restoreHostDNSFunc: func() error {
|
|
return nil
|
|
},
|
|
supportCustomPortFunc: func() bool {
|
|
return true
|
|
},
|
|
stringFunc: func() string {
|
|
return "mock"
|
|
},
|
|
}
|
|
|
|
mockSvc := &mockService{}
|
|
|
|
server := &DefaultServer{
|
|
ctx: context.Background(),
|
|
handlerChain: NewHandlerChain(),
|
|
wgInterface: &mocWGIface{},
|
|
hostManager: mockHostConfig,
|
|
localResolver: &local.Resolver{},
|
|
service: mockSvc,
|
|
statusRecorder: peer.NewRecorder("test"),
|
|
extraDomains: make(map[domain.Domain]int),
|
|
}
|
|
|
|
// Apply initial configuration
|
|
if tt.initialConfig.ServiceEnable {
|
|
err := server.applyConfiguration(tt.initialConfig)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Register domains
|
|
for _, domains := range tt.registerDomains {
|
|
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
|
}
|
|
|
|
// Deregister domains if specified
|
|
for _, domains := range tt.deregisterDomains {
|
|
server.DeregisterHandler(domains, PriorityDefault)
|
|
}
|
|
|
|
// Apply final configuration if specified
|
|
if tt.finalConfig.ServiceEnable {
|
|
err := server.applyConfiguration(tt.finalConfig)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Verify number of calls
|
|
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
|
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
|
|
|
// Get the last applied config
|
|
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
|
|
|
// Check all expected domains are present
|
|
domainMap := make(map[string]bool)
|
|
matchOnlyMap := make(map[string]bool)
|
|
|
|
for _, d := range lastConfig.Domains {
|
|
domainMap[d.Domain] = true
|
|
if d.MatchOnly {
|
|
matchOnlyMap[d.Domain] = true
|
|
}
|
|
}
|
|
|
|
// Verify expected domains
|
|
for _, d := range tt.expectedDomains {
|
|
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
|
}
|
|
|
|
// Verify match-only domains
|
|
for _, d := range tt.expectedMatchOnly {
|
|
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
|
}
|
|
|
|
// Verify no unexpected domains
|
|
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
|
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtraDomainsRefCounting(t *testing.T) {
|
|
mockHostConfig := &mockHostConfigurator{
|
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
|
return nil
|
|
},
|
|
restoreHostDNSFunc: func() error {
|
|
return nil
|
|
},
|
|
supportCustomPortFunc: func() bool {
|
|
return true
|
|
},
|
|
stringFunc: func() string {
|
|
return "mock"
|
|
},
|
|
}
|
|
|
|
mockSvc := &mockService{}
|
|
|
|
server := &DefaultServer{
|
|
ctx: context.Background(),
|
|
handlerChain: NewHandlerChain(),
|
|
hostManager: mockHostConfig,
|
|
localResolver: &local.Resolver{},
|
|
service: mockSvc,
|
|
statusRecorder: peer.NewRecorder("test"),
|
|
extraDomains: make(map[domain.Domain]int),
|
|
}
|
|
|
|
// Register domains from different handlers with same domain
|
|
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
|
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
|
|
|
// Verify refcount is 2
|
|
zoneKey := toZone("shared.example.com")
|
|
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
|
|
|
// Deregister one handler
|
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
|
|
|
// Verify refcount is 1
|
|
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
|
|
|
// Deregister the other handler
|
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
|
|
|
// Verify domain is removed
|
|
_, exists := server.extraDomains[zoneKey]
|
|
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
|
}
|
|
|
|
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
|
var capturedConfig HostDNSConfig
|
|
mockHostConfig := &mockHostConfigurator{
|
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
|
capturedConfig = config
|
|
return nil
|
|
},
|
|
restoreHostDNSFunc: func() error {
|
|
return nil
|
|
},
|
|
supportCustomPortFunc: func() bool {
|
|
return true
|
|
},
|
|
stringFunc: func() string {
|
|
return "mock"
|
|
},
|
|
}
|
|
|
|
mockSvc := &mockService{}
|
|
|
|
server := &DefaultServer{
|
|
ctx: context.Background(),
|
|
handlerChain: NewHandlerChain(),
|
|
hostManager: mockHostConfig,
|
|
localResolver: &local.Resolver{},
|
|
service: mockSvc,
|
|
statusRecorder: peer.NewRecorder("test"),
|
|
extraDomains: make(map[domain.Domain]int),
|
|
}
|
|
|
|
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
|
|
|
initialConfig := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
}
|
|
err := server.applyConfiguration(initialConfig)
|
|
assert.NoError(t, err)
|
|
|
|
var domains []string
|
|
for _, d := range capturedConfig.Domains {
|
|
domains = append(domains, d.Domain)
|
|
}
|
|
assert.Contains(t, domains, "config.example.com.")
|
|
assert.Contains(t, domains, "extra.example.com.")
|
|
|
|
// Now apply a new configuration with overlapping domain
|
|
updatedConfig := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
{Domain: "extra.example.com"},
|
|
},
|
|
}
|
|
err = server.applyConfiguration(updatedConfig)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify both domains are in config, but no duplicates
|
|
domains = []string{}
|
|
matchOnlyCount := 0
|
|
for _, d := range capturedConfig.Domains {
|
|
domains = append(domains, d.Domain)
|
|
if d.MatchOnly {
|
|
matchOnlyCount++
|
|
}
|
|
}
|
|
|
|
assert.Contains(t, domains, "config.example.com.")
|
|
assert.Contains(t, domains, "extra.example.com.")
|
|
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
|
|
|
// Extra domain should no longer be marked as match-only when in config
|
|
matchOnlyDomain := ""
|
|
for _, d := range capturedConfig.Domains {
|
|
if d.Domain == "extra.example.com." && d.MatchOnly {
|
|
matchOnlyDomain = d.Domain
|
|
break
|
|
}
|
|
}
|
|
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
|
}
|
|
|
|
func TestDomainCaseHandling(t *testing.T) {
|
|
var capturedConfig HostDNSConfig
|
|
mockHostConfig := &mockHostConfigurator{
|
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
|
capturedConfig = config
|
|
return nil
|
|
},
|
|
restoreHostDNSFunc: func() error {
|
|
return nil
|
|
},
|
|
supportCustomPortFunc: func() bool {
|
|
return true
|
|
},
|
|
stringFunc: func() string {
|
|
return "mock"
|
|
},
|
|
}
|
|
|
|
mockSvc := &mockService{}
|
|
server := &DefaultServer{
|
|
ctx: context.Background(),
|
|
handlerChain: NewHandlerChain(),
|
|
hostManager: mockHostConfig,
|
|
localResolver: &local.Resolver{},
|
|
service: mockSvc,
|
|
statusRecorder: peer.NewRecorder("test"),
|
|
extraDomains: make(map[domain.Domain]int),
|
|
}
|
|
|
|
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
|
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
|
|
|
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
|
|
|
config := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{Domain: "config.example.com"},
|
|
},
|
|
}
|
|
err := server.applyConfiguration(config)
|
|
assert.NoError(t, err)
|
|
|
|
var domains []string
|
|
for _, d := range capturedConfig.Domains {
|
|
domains = append(domains, d.Domain)
|
|
}
|
|
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
|
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
|
}
|