diff --git a/client/internal/dns/mockServer.go b/client/internal/dns/mockServer.go index 8a7adabd7..ff218b888 100644 --- a/client/internal/dns/mockServer.go +++ b/client/internal/dns/mockServer.go @@ -7,16 +7,17 @@ import ( // MockServer is the mock instance of a dns server type MockServer struct { - StartFunc func() + InitializeFunc func() error StopFunc func() UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error } -// Start mock implementation of Start from Server interface -func (m *MockServer) Start() { - if m.StartFunc != nil { - m.StartFunc() +// Initialize mock implementation of Initialize from Server interface +func (m *MockServer) Initialize() error { + if m.InitializeFunc != nil { + return m.InitializeFunc() } + return nil } // Stop mock implementation of Stop from Server interface diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index a543d469e..dbbf4e602 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -29,7 +29,7 @@ const ( // Server is a dns server interface type Server interface { - Start() + Initialize() error Stop() DnsIP() string UpdateDNSServer(serial uint64, update nbdns.Config) error @@ -82,11 +82,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd addrPort = &parsedAddrPort } - hostManager, err := newHostManager(wgInterface) - if err != nil { - return nil, err - } - ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ @@ -104,7 +99,6 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd }, wgInterface: wgInterface, customAddress: addrPort, - hostManager: hostManager, } if initialDnsCfg != nil { @@ -115,8 +109,21 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd return defaultServer, nil } -// Start runs the listener in a go routine -func (s *DefaultServer) Start() { +// Initialize instantiate host manager. It required to be initialized wginterface +func (s *DefaultServer) Initialize() (err error) { + s.mux.Lock() + defer s.mux.Unlock() + + if s.hostManager != nil { + return nil + } + + s.hostManager, err = newHostManager(s.wgInterface) + return +} + +// listen runs the listener in a go routine +func (s *DefaultServer) listen() { // nil check required in unit tests if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { s.fakeResolverWG.Add(1) @@ -231,6 +238,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro s.mux.Lock() defer s.mux.Unlock() + if s.hostManager == nil { + return fmt.Errorf("dns service is not initialized yet") + } + hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ ZeroNil: true, IgnoreZeroValue: true, @@ -270,7 +281,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { } } } else if !s.listenerIsRunning { - s.Start() + s.listen() } localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 56b44abf6..46ab169fe 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -225,6 +225,10 @@ func TestUpdateDNSServer(t *testing.T) { if err != nil { t.Fatal(err) } + err = dnsServer.Initialize() + if err != nil { + t.Fatal(err) + } defer func() { err = dnsServer.hostManager.restoreHostDNS() if err != nil { @@ -291,7 +295,7 @@ func TestDNSServerStartStop(t *testing.T) { dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort) dnsServer.hostManager = newNoopHostMocker() - dnsServer.Start() + dnsServer.listen() time.Sleep(100 * time.Millisecond) if !dnsServer.listenerIsRunning { t.Fatal("dns server listener is not running") diff --git a/client/internal/engine.go b/client/internal/engine.go index 9b277602b..3bc3bb334 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -199,7 +199,7 @@ func (e *Engine) Start() error { } } - if e.dnsServer == nil && runtime.GOOS == "android" { + if e.dnsServer == nil { // todo fix custom address dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg) if err != nil { @@ -259,14 +259,10 @@ func (e *Engine) Start() error { e.acl = acl } - if e.dnsServer == nil && runtime.GOOS != "android" { - // todo fix custom address - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg) - if err != nil { - e.close() - return err - } - e.dnsServer = dnsServer + err = e.dnsServer.Initialize() + if err != nil { + e.close() + return err } e.receiveSignalEvents()