Fix disabled DNS resolver fail (#978)

Fix fail of DNS when it disabled in the settings
This commit is contained in:
Givi Khojanashvili 2023-06-22 16:59:21 +04:00 committed by GitHub
parent c20f98c8b6
commit 774d8e955c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 164 additions and 30 deletions

View File

@ -42,7 +42,7 @@ type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
mux sync.Mutex
fakeResolverWG sync.WaitGroup
udpFilterHookID string
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap
@ -105,7 +105,10 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
}
defaultServer.evalRuntimeAddress()
if wgInterface.IsUserspaceBind() {
defaultServer.evelRuntimeAddressForUserspace()
}
return defaultServer, nil
}
@ -118,6 +121,9 @@ func (s *DefaultServer) Initialize() (err error) {
return nil
}
if !s.wgInterface.IsUserspaceBind() {
s.evalRuntimeAddress()
}
s.hostManager, err = newHostManager(s.wgInterface)
return
}
@ -126,17 +132,8 @@ func (s *DefaultServer) Initialize() (err error) {
func (s *DefaultServer) listen() {
// nil check required in unit tests
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.fakeResolverWG.Add(1)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
hookID := s.filterDNSTraffic()
s.fakeResolverWG.Wait()
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}()
s.udpFilterHookID = s.filterDNSTraffic()
s.setListenerStatus(true)
return
}
@ -153,6 +150,10 @@ func (s *DefaultServer) listen() {
}()
}
// DnsIP returns the DNS resolver server IP address
//
// When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string {
if !s.enabled {
return ""
@ -201,10 +202,6 @@ func (s *DefaultServer) Stop() {
}
}
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
}
err := s.stopListener()
if err != nil {
log.Error(err)
@ -212,6 +209,18 @@ func (s *DefaultServer) Stop() {
}
func (s *DefaultServer) stopListener() error {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
// udpFilterHookID here empty only in the unit tests
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}
s.udpFilterHookID = ""
s.listenerIsRunning = false
return nil
}
if !s.listenerIsRunning {
return nil
}
@ -275,12 +284,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.fakeResolverWG.Done()
} else {
if err := s.stopListener(); err != nil {
log.Error(err)
}
if err := s.stopListener(); err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.listen()
@ -555,17 +560,17 @@ func (s *DefaultServer) filterDNSTraffic() string {
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}
func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
return
}
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())

View File

@ -5,15 +5,18 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strings"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
)
var zoneRecords = []nbdns.SimpleRecord{
@ -241,7 +244,6 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
dnsServer.fakeResolverWG.Add(1)
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
@ -276,6 +278,133 @@ func TestUpdateDNSServer(t *testing.T) {
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("crate and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().SetNetwork(ipNet)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxMap = registeredHandlerMap{zoneRecords[0].Name: &localResolver{}}
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) {
testCases := []struct {
name string