mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-26 15:59:07 +01:00
ddc365f7a0
--------- Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
962 lines
25 KiB
Go
962 lines
25 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/mock"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
"github.com/netbirdio/netbird/client/iface"
|
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
"github.com/netbirdio/netbird/client/iface/device"
|
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/formatter"
|
|
)
|
|
|
|
type mocWGIface struct {
|
|
filter device.PacketFilter
|
|
}
|
|
|
|
func (w *mocWGIface) Name() string {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) Address() iface.WGAddress {
|
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
|
return iface.WGAddress{
|
|
IP: ip,
|
|
Network: network,
|
|
}
|
|
}
|
|
|
|
func (w *mocWGIface) ToInterface() *net.Interface {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) GetFilter() device.PacketFilter {
|
|
return w.filter
|
|
}
|
|
|
|
func (w *mocWGIface) GetDevice() *device.FilteredDevice {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
|
panic("implement me")
|
|
}
|
|
|
|
func (w *mocWGIface) IsUserspaceBind() bool {
|
|
return false
|
|
}
|
|
|
|
func (w *mocWGIface) SetFilter(filter device.PacketFilter) error {
|
|
w.filter = filter
|
|
return nil
|
|
}
|
|
|
|
func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
|
|
return configurer.WGStats{}, nil
|
|
}
|
|
|
|
var zoneRecords = []nbdns.SimpleRecord{
|
|
{
|
|
Name: "peera.netbird.cloud",
|
|
Type: 1,
|
|
Class: nbdns.DefaultClass,
|
|
TTL: 300,
|
|
RData: "1.2.3.4",
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
log.SetLevel(log.TraceLevel)
|
|
formatter.SetTextFormatter(log.StandardLogger())
|
|
}
|
|
|
|
func TestUpdateDNSServer(t *testing.T) {
|
|
nameServers := []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.8.8"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
}
|
|
|
|
dummyHandler := &localResolver{}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
initUpstreamMap registeredHandlerMap
|
|
initLocalMap registrationMap
|
|
initSerial uint64
|
|
inputSerial uint64
|
|
inputUpdate nbdns.Config
|
|
shouldFail bool
|
|
expectedUpstreamMap registeredHandlerMap
|
|
expectedLocalMap registrationMap
|
|
}{
|
|
{
|
|
name: "Initial Config Should Succeed",
|
|
initLocalMap: make(registrationMap),
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"netbird.io"},
|
|
NameServers: nameServers,
|
|
},
|
|
{
|
|
NameServers: nameServers,
|
|
Primary: true,
|
|
},
|
|
},
|
|
},
|
|
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler, nbdns.RootZone: dummyHandler},
|
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
|
},
|
|
{
|
|
name: "New Config Should Succeed",
|
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
|
initUpstreamMap: registeredHandlerMap{buildRecordKey(zoneRecords[0].Name, 1, 1): dummyHandler},
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"netbird.io"},
|
|
NameServers: nameServers,
|
|
},
|
|
},
|
|
},
|
|
expectedUpstreamMap: registeredHandlerMap{"netbird.io": dummyHandler, "netbird.cloud": dummyHandler},
|
|
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
|
},
|
|
{
|
|
name: "Smaller Config Serial Should Be Skipped",
|
|
initLocalMap: make(registrationMap),
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 2,
|
|
inputSerial: 1,
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
|
initLocalMap: make(registrationMap),
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: nameServers,
|
|
},
|
|
},
|
|
},
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Invalid NS Group Nameservers list Should Fail",
|
|
initLocalMap: make(registrationMap),
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: nameServers,
|
|
},
|
|
},
|
|
},
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Invalid Custom Zone Records list Should Fail",
|
|
initLocalMap: make(registrationMap),
|
|
initUpstreamMap: make(registeredHandlerMap),
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: nameServers,
|
|
Primary: true,
|
|
},
|
|
},
|
|
},
|
|
shouldFail: true,
|
|
},
|
|
{
|
|
name: "Empty Config Should Succeed and Clean Maps",
|
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
|
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
|
expectedUpstreamMap: make(registeredHandlerMap),
|
|
expectedLocalMap: make(registrationMap),
|
|
},
|
|
{
|
|
name: "Disabled Service Should clean map",
|
|
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
|
initUpstreamMap: registeredHandlerMap{zoneRecords[0].Name: dummyHandler},
|
|
initSerial: 0,
|
|
inputSerial: 1,
|
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
|
expectedUpstreamMap: make(registeredHandlerMap),
|
|
expectedLocalMap: make(registrationMap),
|
|
},
|
|
}
|
|
|
|
for n, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
privKey, _ := wgtypes.GenerateKey()
|
|
newNet, err := stdnet.NewNet(nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
opts := iface.WGIFaceOpts{
|
|
IFaceName: fmt.Sprintf("utun230%d", n),
|
|
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
|
WGPort: 33100,
|
|
WGPrivKey: privKey.String(),
|
|
MTU: iface.DefaultMTU,
|
|
TransportNet: newNet,
|
|
}
|
|
|
|
wgIface, err := iface.NewWGIFace(opts)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = wgIface.Create()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err = wgIface.Close()
|
|
if err != nil {
|
|
t.Log(err)
|
|
}
|
|
}()
|
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func() {
|
|
err = dnsServer.hostManager.restoreHostDNS()
|
|
if err != nil {
|
|
t.Log(err)
|
|
}
|
|
}()
|
|
|
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
|
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
|
dnsServer.updateSerial = testCase.initSerial
|
|
|
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
|
if err != nil {
|
|
if testCase.shouldFail {
|
|
return
|
|
}
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
}
|
|
|
|
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) {
|
|
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap))
|
|
}
|
|
|
|
for key := range testCase.expectedUpstreamMap {
|
|
_, found := dnsServer.dnsMuxMap[key]
|
|
if !found {
|
|
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap)
|
|
}
|
|
}
|
|
|
|
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
|
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
|
}
|
|
|
|
for key := range testCase.expectedLocalMap {
|
|
_, found := dnsServer.localResolver.registeredMap[key]
|
|
if !found {
|
|
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
|
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
newNet, err := stdnet.NewNet(nil)
|
|
if err != nil {
|
|
t.Errorf("create stdnet: %v", err)
|
|
return
|
|
}
|
|
|
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
opts := iface.WGIFaceOpts{
|
|
IFaceName: "utun2301",
|
|
Address: "100.66.100.1/32",
|
|
WGPort: 33100,
|
|
WGPrivKey: privKey.String(),
|
|
MTU: iface.DefaultMTU,
|
|
TransportNet: newNet,
|
|
}
|
|
wgIface, err := iface.NewWGIFace(opts)
|
|
if err != nil {
|
|
t.Errorf("build interface wireguard: %v", err)
|
|
return
|
|
}
|
|
|
|
err = wgIface.Create()
|
|
if err != nil {
|
|
t.Errorf("create and init wireguard interface: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err = wgIface.Close(); err != nil {
|
|
t.Logf("close wireguard interface: %v", err)
|
|
}
|
|
}()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
|
|
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
|
if err != nil {
|
|
t.Errorf("parse CIDR: %v", err)
|
|
return
|
|
}
|
|
|
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
|
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
|
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
|
t.Errorf("set packet filter: %v", err)
|
|
return
|
|
}
|
|
|
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
|
if err != nil {
|
|
t.Errorf("create DNS server: %v", err)
|
|
return
|
|
}
|
|
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("run DNS server: %v", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
|
t.Logf("restore DNS settings on the host: %v", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
|
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
|
dnsServer.updateSerial = 0
|
|
|
|
nameServers := []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.8.8"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
}
|
|
|
|
update := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
Domains: []string{"netbird.io"},
|
|
NameServers: nameServers,
|
|
},
|
|
{
|
|
NameServers: nameServers,
|
|
Primary: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Start the server with regular configuration
|
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
return
|
|
}
|
|
|
|
update2 := update
|
|
update2.ServiceEnable = false
|
|
// Disable the server, stop the listener
|
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
return
|
|
}
|
|
|
|
update3 := update2
|
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
|
// But service still get updates and we checking that we handle
|
|
// internal state in the right way
|
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestDNSServerStartStop(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
addrPort string
|
|
}{
|
|
{
|
|
name: "Should Pass With Port Discovery",
|
|
},
|
|
{
|
|
name: "Should Pass With Custom Port",
|
|
addrPort: "127.0.0.1:3535",
|
|
},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
|
if err != nil {
|
|
t.Fatalf("%v", err)
|
|
}
|
|
dnsServer.hostManager = newNoopHostMocker()
|
|
err = dnsServer.service.Listen()
|
|
if err != nil {
|
|
t.Fatalf("dns server is not running: %s", err)
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
defer dnsServer.Stop()
|
|
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
|
|
|
resolver := &net.Resolver{
|
|
PreferGo: true,
|
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
d := net.Dialer{
|
|
Timeout: time.Second * 5,
|
|
}
|
|
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
conn, err := d.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
t.Log(err)
|
|
// retry test before exit, for slower systems
|
|
return d.DialContext(ctx, network, addr)
|
|
}
|
|
|
|
return conn, nil
|
|
},
|
|
}
|
|
|
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed to connect to the server, error: %v", err)
|
|
}
|
|
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("got a different IP from the server: want %s, got %s", zoneRecords[0].RData, ips[0])
|
|
}
|
|
|
|
dnsServer.Stop()
|
|
ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1)
|
|
defer cancel()
|
|
_, err = resolver.LookupHost(ctx, zoneRecords[0].Name)
|
|
if err == nil {
|
|
t.Fatalf("we should encounter an error when querying a stopped server")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|
hostManager := &mockHostConfigurator{}
|
|
server := DefaultServer{
|
|
ctx: context.Background(),
|
|
service: NewServiceViaMemory(&mocWGIface{}),
|
|
localResolver: &localResolver{
|
|
registeredMap: make(registrationMap),
|
|
},
|
|
handlerChain: NewHandlerChain(),
|
|
handlerPriorities: make(map[string]int),
|
|
hostManager: hostManager,
|
|
currentConfig: HostDNSConfig{
|
|
Domains: []DomainConfig{
|
|
{false, "domain0", false},
|
|
{false, "domain1", false},
|
|
{false, "domain2", false},
|
|
},
|
|
},
|
|
statusRecorder: &peer.Status{},
|
|
}
|
|
|
|
var domainsUpdate string
|
|
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
|
domains := []string{}
|
|
for _, item := range config.Domains {
|
|
if item.Disabled {
|
|
continue
|
|
}
|
|
domains = append(domains, item.Domain)
|
|
}
|
|
domainsUpdate = strings.Join(domains, ",")
|
|
return nil
|
|
}
|
|
|
|
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
|
Domains: []string{"domain1"},
|
|
NameServers: []nbdns.NameServer{
|
|
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
|
},
|
|
}, nil)
|
|
|
|
deactivate(nil)
|
|
expected := "domain0,domain2"
|
|
domains := []string{}
|
|
for _, item := range server.currentConfig.Domains {
|
|
if item.Disabled {
|
|
continue
|
|
}
|
|
domains = append(domains, item.Domain)
|
|
}
|
|
got := strings.Join(domains, ",")
|
|
if expected != got {
|
|
t.Errorf("expected domains list: %q, got %q", expected, got)
|
|
}
|
|
|
|
reactivate()
|
|
expected = "domain0,domain1,domain2"
|
|
domains = []string{}
|
|
for _, item := range server.currentConfig.Domains {
|
|
if item.Disabled {
|
|
continue
|
|
}
|
|
domains = append(domains, item.Domain)
|
|
}
|
|
got = strings.Join(domains, ",")
|
|
if expected != got {
|
|
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
|
}
|
|
}
|
|
|
|
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|
wgIFace, err := createWgInterfaceWithBind(t)
|
|
if err != nil {
|
|
t.Fatal("failed to initialize wg interface")
|
|
}
|
|
defer wgIFace.Close()
|
|
|
|
var dnsList []string
|
|
dnsConfig := nbdns.Config{}
|
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("failed to initialize DNS server: %v", err)
|
|
return
|
|
}
|
|
defer dnsServer.Stop()
|
|
|
|
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
|
|
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
}
|
|
|
|
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|
wgIFace, err := createWgInterfaceWithBind(t)
|
|
if err != nil {
|
|
t.Fatal("failed to initialize wg interface")
|
|
}
|
|
defer wgIFace.Close()
|
|
dnsConfig := nbdns.Config{}
|
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("failed to initialize DNS server: %v", err)
|
|
return
|
|
}
|
|
defer dnsServer.Stop()
|
|
|
|
// check initial state
|
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
|
|
update := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
},
|
|
Enabled: true,
|
|
Primary: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
err = dnsServer.UpdateDNSServer(1, update)
|
|
if err != nil {
|
|
t.Errorf("failed to update dns server: %s", err)
|
|
}
|
|
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed resolve zone record: %v", err)
|
|
}
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("invalid zone record: %v", err)
|
|
}
|
|
|
|
update2 := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{},
|
|
}
|
|
|
|
err = dnsServer.UpdateDNSServer(2, update2)
|
|
if err != nil {
|
|
t.Errorf("failed to update dns server: %s", err)
|
|
}
|
|
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
|
|
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed resolve zone record: %v", err)
|
|
}
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("invalid zone record: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDNSPermanent_matchOnly(t *testing.T) {
|
|
wgIFace, err := createWgInterfaceWithBind(t)
|
|
if err != nil {
|
|
t.Fatal("failed to initialize wg interface")
|
|
}
|
|
defer wgIFace.Close()
|
|
dnsConfig := nbdns.Config{}
|
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
|
err = dnsServer.Initialize()
|
|
if err != nil {
|
|
t.Errorf("failed to initialize DNS server: %v", err)
|
|
return
|
|
}
|
|
defer dnsServer.Stop()
|
|
|
|
// check initial state
|
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
|
|
update := nbdns.Config{
|
|
ServiceEnable: true,
|
|
CustomZones: []nbdns.CustomZone{
|
|
{
|
|
Domain: "netbird.cloud",
|
|
Records: zoneRecords,
|
|
},
|
|
},
|
|
NameServerGroups: []*nbdns.NameServerGroup{
|
|
{
|
|
NameServers: []nbdns.NameServer{
|
|
{
|
|
IP: netip.MustParseAddr("8.8.4.4"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
{
|
|
IP: netip.MustParseAddr("9.9.9.9"),
|
|
NSType: nbdns.UDPNameServerType,
|
|
Port: 53,
|
|
},
|
|
},
|
|
Domains: []string{"google.com"},
|
|
Primary: false,
|
|
},
|
|
},
|
|
}
|
|
|
|
err = dnsServer.UpdateDNSServer(1, update)
|
|
if err != nil {
|
|
t.Errorf("failed to update dns server: %s", err)
|
|
}
|
|
|
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
|
if err != nil {
|
|
t.Fatalf("failed resolve zone record: %v", err)
|
|
}
|
|
if ips[0] != zoneRecords[0].RData {
|
|
t.Fatalf("invalid zone record: %v", err)
|
|
}
|
|
_, err = resolver.LookupHost(context.Background(), "google.com")
|
|
if err != nil {
|
|
t.Errorf("failed to resolve: %s", err)
|
|
}
|
|
}
|
|
|
|
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|
t.Helper()
|
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
|
|
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
newNet, err := stdnet.NewNet(nil)
|
|
if err != nil {
|
|
t.Fatalf("create stdnet: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
privKey, _ := wgtypes.GeneratePrivateKey()
|
|
|
|
opts := iface.WGIFaceOpts{
|
|
IFaceName: "utun2301",
|
|
Address: "100.66.100.2/24",
|
|
WGPort: 33100,
|
|
WGPrivKey: privKey.String(),
|
|
MTU: iface.DefaultMTU,
|
|
TransportNet: newNet,
|
|
}
|
|
|
|
wgIface, err := iface.NewWGIFace(opts)
|
|
if err != nil {
|
|
t.Fatalf("build interface wireguard: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
err = wgIface.Create()
|
|
if err != nil {
|
|
t.Fatalf("create and init wireguard interface: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
pf, err := uspfilter.Create(wgIface)
|
|
if err != nil {
|
|
t.Fatalf("failed to create uspfilter: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
err = wgIface.SetFilter(pf)
|
|
if err != nil {
|
|
t.Fatalf("set packet filter: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
return wgIface, nil
|
|
}
|
|
|
|
// newDnsResolver returns a pure-Go resolver whose every connection is dialed
// to ip:port with a 3-second timeout; the address the resolver would normally
// contact is ignored.
//
// ip may be an IPv4 or IPv6 literal (or a host name); port is the server port.
func newDnsResolver(ip string, port int) *net.Resolver {
	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" format does
	// not, so the resulting address is always dialable.
	target := net.JoinHostPort(ip, fmt.Sprint(port))

	return &net.Resolver{
		PreferGo: true,
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			d := net.Dialer{
				Timeout: time.Second * 3,
			}
			return d.DialContext(ctx, network, target)
		},
	}
}
|
|
|
|
// MockHandler implements dns.Handler interface for testing
|
|
type MockHandler struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
m.Called(w, r)
|
|
}
|
|
|
|
type MockSubdomainHandler struct {
|
|
MockHandler
|
|
Subdomains bool
|
|
}
|
|
|
|
func (m *MockSubdomainHandler) MatchSubdomains() bool {
|
|
return m.Subdomains
|
|
}
|
|
|
|
func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|
chain := NewHandlerChain()
|
|
|
|
dnsRouteHandler := &MockHandler{}
|
|
upstreamHandler := &MockSubdomainHandler{
|
|
Subdomains: true,
|
|
}
|
|
|
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
query string
|
|
expectedHandler dns.Handler
|
|
}{
|
|
{
|
|
name: "exact domain with dns route handler",
|
|
query: "example.com.",
|
|
expectedHandler: dnsRouteHandler,
|
|
},
|
|
{
|
|
name: "subdomain should use upstream handler",
|
|
query: "sub.example.com.",
|
|
expectedHandler: upstreamHandler,
|
|
},
|
|
{
|
|
name: "deep subdomain should use upstream handler",
|
|
query: "deep.sub.example.com.",
|
|
expectedHandler: upstreamHandler,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
r := new(dns.Msg)
|
|
r.SetQuestion(tc.query, dns.TypeA)
|
|
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
|
|
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
|
}
|
|
|
|
chain.ServeDNS(w, r)
|
|
|
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
|
mh.AssertExpectations(t)
|
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
|
mh.AssertExpectations(t)
|
|
}
|
|
|
|
// Reset mocks
|
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
|
mh.ExpectedCalls = nil
|
|
mh.Calls = nil
|
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
|
mh.ExpectedCalls = nil
|
|
mh.Calls = nil
|
|
}
|
|
})
|
|
}
|
|
}
|