mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-03 14:00:34 +02:00
Fix disabled DNS resolver fail (#978)
Fix fail of DNS when it disabled in the settings
This commit is contained in:
parent
c20f98c8b6
commit
774d8e955c
@ -42,7 +42,7 @@ type DefaultServer struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
fakeResolverWG sync.WaitGroup
|
udpFilterHookID string
|
||||||
server *dns.Server
|
server *dns.Server
|
||||||
dnsMux *dns.ServeMux
|
dnsMux *dns.ServeMux
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
@ -105,7 +105,10 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultServer.evalRuntimeAddress()
|
if wgInterface.IsUserspaceBind() {
|
||||||
|
defaultServer.evelRuntimeAddressForUserspace()
|
||||||
|
}
|
||||||
|
|
||||||
return defaultServer, nil
|
return defaultServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,6 +121,9 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !s.wgInterface.IsUserspaceBind() {
|
||||||
|
s.evalRuntimeAddress()
|
||||||
|
}
|
||||||
s.hostManager, err = newHostManager(s.wgInterface)
|
s.hostManager, err = newHostManager(s.wgInterface)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -126,17 +132,8 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
func (s *DefaultServer) listen() {
|
func (s *DefaultServer) listen() {
|
||||||
// nil check required in unit tests
|
// nil check required in unit tests
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||||
s.fakeResolverWG.Add(1)
|
s.udpFilterHookID = s.filterDNSTraffic()
|
||||||
go func() {
|
s.setListenerStatus(true)
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
hookID := s.filterDNSTraffic()
|
|
||||||
s.fakeResolverWG.Wait()
|
|
||||||
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
|
||||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,6 +150,10 @@ func (s *DefaultServer) listen() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DnsIP returns the DNS resolver server IP address
|
||||||
|
//
|
||||||
|
// When kernel space interface used it return real DNS server listener IP address
|
||||||
|
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||||
func (s *DefaultServer) DnsIP() string {
|
func (s *DefaultServer) DnsIP() string {
|
||||||
if !s.enabled {
|
if !s.enabled {
|
||||||
return ""
|
return ""
|
||||||
@ -201,10 +202,6 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
|
||||||
s.fakeResolverWG.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.stopListener()
|
err := s.stopListener()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
@ -212,6 +209,18 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
func (s *DefaultServer) stopListener() error {
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||||
|
// udpFilterHookID here empty only in the unit tests
|
||||||
|
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
|
||||||
|
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||||
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.udpFilterHookID = ""
|
||||||
|
s.listenerIsRunning = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if !s.listenerIsRunning {
|
if !s.listenerIsRunning {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -275,12 +284,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
// is the service should be disabled, we stop the listener or fake resolver
|
// is the service should be disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if !update.ServiceEnable {
|
if !update.ServiceEnable {
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
if err := s.stopListener(); err != nil {
|
||||||
s.fakeResolverWG.Done()
|
log.Error(err)
|
||||||
} else {
|
|
||||||
if err := s.stopListener(); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if !s.listenerIsRunning {
|
} else if !s.listenerIsRunning {
|
||||||
s.listen()
|
s.listen()
|
||||||
@ -555,17 +560,17 @@ func (s *DefaultServer) filterDNSTraffic() string {
|
|||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
|
||||||
|
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||||
|
s.runtimePort = defaultPort
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) evalRuntimeAddress() {
|
func (s *DefaultServer) evalRuntimeAddress() {
|
||||||
defer func() {
|
defer func() {
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
|
||||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
|
||||||
s.runtimePort = defaultPort
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.customAddress != nil {
|
if s.customAddress != nil {
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
|
@ -5,15 +5,18 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
var zoneRecords = []nbdns.SimpleRecord{
|
var zoneRecords = []nbdns.SimpleRecord{
|
||||||
@ -241,7 +244,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
// pretend we are running
|
// pretend we are running
|
||||||
dnsServer.listenerIsRunning = true
|
dnsServer.listenerIsRunning = true
|
||||||
dnsServer.fakeResolverWG.Add(1)
|
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -276,6 +278,133 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||||
|
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||||
|
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
|
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
newNet, err := stdnet.NewNet(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create stdnet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("build interface wireguard: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = wgIface.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("crate and init wireguard interface: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = wgIface.Close(); err != nil {
|
||||||
|
t.Logf("close wireguard interface: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("parse CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
|
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
|
||||||
|
|
||||||
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
|
t.Errorf("set packet filter: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("create DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = dnsServer.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("run DNS server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
|
||||||
|
t.Logf("restore DNS settings on the host: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
|
||||||
|
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
||||||
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
|
nameServers := []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.4.4"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
update := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "netbird.cloud",
|
||||||
|
Records: zoneRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"netbird.io"},
|
||||||
|
NameServers: nameServers,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NameServers: nameServers,
|
||||||
|
Primary: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the server with regular configuration
|
||||||
|
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update2 := update
|
||||||
|
update2.ServiceEnable = false
|
||||||
|
// Disable the server, stop the listener
|
||||||
|
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
update3 := update2
|
||||||
|
update3.NameServerGroups = update3.NameServerGroups[:1]
|
||||||
|
// But service still get updates and we checking that we handle
|
||||||
|
// internal state in the right way
|
||||||
|
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
|
||||||
|
t.Fatalf("update dns server should not fail, got error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSServerStartStop(t *testing.T) {
|
func TestDNSServerStartStop(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
Loading…
Reference in New Issue
Block a user