diff --git a/client/internal/dns/dbus_linux.go b/client/internal/dns/dbus_linux.go index 0f6d4156a..cb38cd9d9 100644 --- a/client/internal/dns/dbus_linux.go +++ b/client/internal/dns/dbus_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package dns import ( diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go index 817c5a9b5..1d3a1e383 100644 --- a/client/internal/dns/file_linux.go +++ b/client/internal/dns/file_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package dns import ( diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go new file mode 100644 index 000000000..1cb07f0c7 --- /dev/null +++ b/client/internal/dns/host_android.go @@ -0,0 +1,24 @@ +package dns + +import ( + "github.com/netbirdio/netbird/iface" +) + +type androidHostManager struct { +} + +func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { + return &androidHostManager{}, nil +} + +func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error { + return nil +} + +func (a androidHostManager) restoreHostDNS() error { + return nil +} + +func (a androidHostManager) supportCustomPort() bool { + return false +} diff --git a/client/internal/dns/host_linux.go b/client/internal/dns/host_linux.go index e411280cc..ee80ab5f6 100644 --- a/client/internal/dns/host_linux.go +++ b/client/internal/dns/host_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package dns import ( diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 18fec812a..b4110f342 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -22,7 +22,7 @@ func (d *localResolver) stop() { // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v\n", r.Question[0]) + log.Tracef("received question: %#v", r.Question[0]) replyMessage := &dns.Msg{} replyMessage.SetReply(r) replyMessage.RecursionAvailable = true diff --git a/client/internal/dns/mockServer.go b/client/internal/dns/mockServer.go index ce02ed88e..8a7adabd7 100644 --- a/client/internal/dns/mockServer.go +++ b/client/internal/dns/mockServer.go @@ -26,6 +26,10 @@ func (m *MockServer) Stop() { } } +func (m *MockServer) DnsIP() string { + return "" +} + // UpdateDNSServer mock implementation of UpdateDNSServer from Server interface func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { if m.UpdateDNSServerFunc != nil { diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_linux.go index 1867f9f50..1fa713f46 100644 --- a/client/internal/dns/network_manager_linux.go +++ b/client/internal/dns/network_manager_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package dns import ( diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_linux.go index 374882ec2..0d2616f31 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package dns import ( diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 53006b164..a543d469e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -1,10 +1,597 @@ package dns -import nbdns "github.com/netbirdio/netbird/dns" +import ( + "context" + "fmt" + "math/big" + "net" + "net/netip" + "runtime" + "sync" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/miekg/dns" + "github.com/mitchellh/hashstructure/v2" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" +) + +const ( + defaultPort = 53 + customPort = 5053 + defaultIP = "127.0.0.1" + customIP = "127.0.0.153" +) // Server is a dns server interface type Server interface { Start() Stop() + DnsIP() string UpdateDNSServer(serial uint64, update nbdns.Config) error } + +type registeredHandlerMap map[string]handlerWithStop + +// DefaultServer dns server object +type DefaultServer struct { + ctx context.Context + ctxCancel context.CancelFunc + mux sync.Mutex + fakeResolverWG sync.WaitGroup + server *dns.Server + dnsMux *dns.ServeMux + dnsMuxMap registeredHandlerMap + localResolver *localResolver + wgInterface *iface.WGIface + hostManager hostManager + updateSerial uint64 + listenerIsRunning bool + runtimePort int + runtimeIP string + previousConfigHash uint64 + currentConfig hostDNSConfig + customAddress *netip.AddrPort + enabled bool +} + +type handlerWithStop interface { + dns.Handler + stop() +} + +type muxUpdate struct { + domain string + handler handlerWithStop +} + +// NewDefaultServer returns a new dns server +func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) { + mux := dns.NewServeMux() + + var addrPort *netip.AddrPort + if customAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if err != nil { + return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) + } + addrPort = &parsedAddrPort + } + + hostManager, err := newHostManager(wgInterface) + if err != nil { + return nil, err + } + + ctx, stop := context.WithCancel(ctx) + + defaultServer := &DefaultServer{ + ctx: ctx, + ctxCancel: stop, + server: &dns.Server{ + Net: "udp", + Handler: mux, + UDPSize: 65535, + }, + dnsMux: mux, + dnsMuxMap: make(registeredHandlerMap), + localResolver: &localResolver{ + registeredMap: make(registrationMap), + }, + wgInterface: wgInterface, + customAddress: addrPort, + hostManager: hostManager, + } + + if initialDnsCfg != nil { + defaultServer.enabled = hasValidDnsServer(initialDnsCfg) + } + + defaultServer.evalRuntimeAddress() + return defaultServer, nil +} + +// Start runs the listener in a go routine +func (s *DefaultServer) Start() { + // nil check required in unit tests + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { + s.fakeResolverWG.Add(1) + go func() { + s.setListenerStatus(true) + defer s.setListenerStatus(false) + + hookID := s.filterDNSTraffic() + s.fakeResolverWG.Wait() + if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil { + log.Errorf("unable to remove DNS packet hook: %s", err) + } + }() + return + } + + log.Debugf("starting dns on %s", s.server.Addr) + + go func() { + s.setListenerStatus(true) + defer s.setListenerStatus(false) + + err := s.server.ListenAndServe() + if err != nil { + log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) + } + }() +} + +func (s *DefaultServer) DnsIP() string { + if !s.enabled { + return "" + } + return s.runtimeIP +} + +func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) { + ips := []string{defaultIP, customIP} + if runtime.GOOS != "darwin" && s.wgInterface != nil { + ips = append([]string{s.wgInterface.Address().IP.String()}, ips...) + } + ports := []int{defaultPort, customPort} + for _, port := range ports { + for _, ip := range ips { + addrString := fmt.Sprintf("%s:%d", ip, port) + udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) + probeListener, err := net.ListenUDP("udp", udpAddr) + if err == nil { + err = probeListener.Close() + if err != nil { + log.Errorf("got an error closing the probe listener, error: %s", err) + } + return ip, port, nil + } + log.Warnf("binding dns on %s is not available, error: %s", addrString, err) + } + } + return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) +} + +func (s *DefaultServer) setListenerStatus(running bool) { + s.listenerIsRunning = running +} + +// Stop stops the server +func (s *DefaultServer) Stop() { + s.mux.Lock() + defer s.mux.Unlock() + s.ctxCancel() + + err := s.hostManager.restoreHostDNS() + if err != nil { + log.Error(err) + } + + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { + s.fakeResolverWG.Done() + } + + err = s.stopListener() + if err != nil { + log.Error(err) + } +} + +func (s *DefaultServer) stopListener() error { + if !s.listenerIsRunning { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := s.server.ShutdownContext(ctx) + if err != nil { + return fmt.Errorf("stopping dns server listener returned an error: %v", err) + } + return nil +} + +// UpdateDNSServer processes an update received from the management service +func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { + select { + case <-s.ctx.Done(): + log.Infof("not updating DNS server as context is closed") + return s.ctx.Err() + default: + if serial < s.updateSerial { + return fmt.Errorf("not applying dns update, error: "+ + "network update is %d behind the last applied update", s.updateSerial-serial) + } + s.mux.Lock() + defer s.mux.Unlock() + + hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Errorf("unable to hash the dns configuration update, got error: %s", err) + } + + if s.previousConfigHash == hash { + log.Debugf("not applying the dns configuration update as there is nothing new") + s.updateSerial = serial + return nil + } + + if err := s.applyConfiguration(update); err != nil { + return err + } + + s.updateSerial = serial + s.previousConfigHash = hash + + return nil + } +} + +func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { + // is the service should be disabled, we stop the listener or fake resolver + // and proceed with a regular update to clean up the handlers and records + if !update.ServiceEnable { + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { + s.fakeResolverWG.Done() + } else { + if err := s.stopListener(); err != nil { + log.Error(err) + } + } + } else if !s.listenerIsRunning { + s.Start() + } + + localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) + if err != nil { + return fmt.Errorf("not applying dns update, error: %v", err) + } + + muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) + + s.updateMux(muxUpdates) + s.updateLocalResolver(localRecords) + s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) + + hostUpdate := s.currentConfig + if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() { + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + + "Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") + hostUpdate.routeAll = false + } + + if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { + log.Error(err) + } + + return nil +} + +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { + var muxUpdates []muxUpdate + localRecords := make(map[string]nbdns.SimpleRecord, 0) + + for _, customZone := range customZones { + + if len(customZone.Records) == 0 { + return nil, nil, fmt.Errorf("received an empty list of records") + } + + muxUpdates = append(muxUpdates, muxUpdate{ + domain: customZone.Domain, + handler: s.localResolver, + }) + + for _, record := range customZone.Records { + var class uint16 = dns.ClassINET + if record.Class != nbdns.DefaultClass { + return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) + } + key := buildRecordKey(record.Name, class, uint16(record.Type)) + localRecords[key] = record + } + } + return muxUpdates, localRecords, nil +} + +func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { + + var muxUpdates []muxUpdate + for _, nsGroup := range nameServerGroups { + if len(nsGroup.NameServers) == 0 { + log.Warn("received a nameserver group with empty nameserver list") + continue + } + + handler := newUpstreamResolver(s.ctx) + for _, ns := range nsGroup.NameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) + } + + if len(handler.upstreamServers) == 0 { + handler.stop() + log.Errorf("received a nameserver group with an invalid nameserver list") + continue + } + + // when upstream fails to resolve domain several times over all it servers + // it will calls this hook to exclude self from the configuration and + // reapply DNS settings, but it not touch the original configuration and serial number + // because it is temporal deactivation until next try + // + // after some period defined by upstream it trys to reactivate self by calling this hook + // everything we need here is just to re-apply current configuration because it already + // contains this upstream settings (temporal deactivation not removed it) + handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) + + if nsGroup.Primary { + muxUpdates = append(muxUpdates, muxUpdate{ + domain: nbdns.RootZone, + handler: handler, + }) + continue + } + + if len(nsGroup.Domains) == 0 { + handler.stop() + return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") + } + + for _, domain := range nsGroup.Domains { + if domain == "" { + handler.stop() + return nil, fmt.Errorf("received a nameserver group with an empty domain element") + } + muxUpdates = append(muxUpdates, muxUpdate{ + domain: domain, + handler: handler, + }) + } + } + return muxUpdates, nil +} + +func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { + muxUpdateMap := make(registeredHandlerMap) + + for _, update := range muxUpdates { + s.registerMux(update.domain, update.handler) + muxUpdateMap[update.domain] = update.handler + if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { + existingHandler.stop() + } + } + + for key, existingHandler := range s.dnsMuxMap { + _, found := muxUpdateMap[key] + if !found { + existingHandler.stop() + s.deregisterMux(key) + } + } + + s.dnsMuxMap = muxUpdateMap +} + +func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { + for key := range s.localResolver.registeredMap { + _, found := update[key] + if !found { + s.localResolver.deleteRecord(key) + } + } + + updatedMap := make(registrationMap) + for key, record := range update { + err := s.localResolver.registerRecord(record) + if err != nil { + log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err) + } + updatedMap[key] = struct{}{} + } + + s.localResolver.registeredMap = updatedMap +} + +func getNSHostPort(ns nbdns.NameServer) string { + return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) +} + +func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) { + s.dnsMux.Handle(pattern, handler) +} + +func (s *DefaultServer) deregisterMux(pattern string) { + s.dnsMux.HandleRemove(pattern) +} + +// upstreamCallbacks returns two functions, the first one is used to deactivate +// the upstream resolver from the configuration, the second one is used to +// reactivate it. Not allowed to call reactivate before deactivate. +func (s *DefaultServer) upstreamCallbacks( + nsGroup *nbdns.NameServerGroup, + handler dns.Handler, +) (deactivate func(), reactivate func()) { + var removeIndex map[string]int + deactivate = func() { + s.mux.Lock() + defer s.mux.Unlock() + + l := log.WithField("nameservers", nsGroup.NameServers) + l.Info("temporary deactivate nameservers group due timeout") + + removeIndex = make(map[string]int) + for _, domain := range nsGroup.Domains { + removeIndex[domain] = -1 + } + if nsGroup.Primary { + removeIndex[nbdns.RootZone] = -1 + s.currentConfig.routeAll = false + } + + for i, item := range s.currentConfig.domains { + if _, found := removeIndex[item.domain]; found { + s.currentConfig.domains[i].disabled = true + s.deregisterMux(item.domain) + removeIndex[item.domain] = i + } + } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + l.WithError(err).Error("fail to apply nameserver deactivation on the host") + } + } + reactivate = func() { + s.mux.Lock() + defer s.mux.Unlock() + + for domain, i := range removeIndex { + if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { + continue + } + s.currentConfig.domains[i].disabled = false + s.registerMux(domain, handler) + } + + l := log.WithField("nameservers", nsGroup.NameServers) + l.Debug("reactivate temporary disabled nameserver group") + + if nsGroup.Primary { + s.currentConfig.routeAll = true + } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + } + } + return +} + +func (s *DefaultServer) filterDNSTraffic() string { + filter := s.wgInterface.GetFilter() + if filter == nil { + log.Error("can't set DNS filter, filter not initialized") + return "" + } + + firstLayerDecoder := layers.LayerTypeIPv4 + if s.wgInterface.Address().Network.IP.To4() == nil { + firstLayerDecoder = layers.LayerTypeIPv6 + } + + hook := func(packetData []byte) bool { + // Decode the packet + packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) + + // Get the UDP layer + udpLayer := packet.Layer(layers.LayerTypeUDP) + udp := udpLayer.(*layers.UDP) + + msg := new(dns.Msg) + if err := msg.Unpack(udp.Payload); err != nil { + log.Tracef("parse DNS request: %v", err) + return true + } + + writer := responseWriter{ + packet: packet, + device: s.wgInterface.GetDevice().Device, + } + go s.dnsMux.ServeDNS(&writer, msg) + return true + } + + return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook) +} + +func (s *DefaultServer) evalRuntimeAddress() { + defer func() { + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) + }() + + if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { + s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1) + s.runtimePort = defaultPort + return + } + + if s.customAddress != nil { + s.runtimeIP = s.customAddress.Addr().String() + s.runtimePort = int(s.customAddress.Port()) + return + } + + ip, port, err := s.getFirstListenerAvailable() + if err != nil { + log.Error(err) + return + } + s.runtimeIP = ip + s.runtimePort = port +} + +func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string { + // Calculate the last IP in the CIDR range + var endIP net.IP + for i := 0; i < len(network.IP); i++ { + endIP = append(endIP, network.IP[i]|^network.Mask[i]) + } + + // convert to big.Int + endInt := big.NewInt(0) + endInt.SetBytes(endIP) + + // subtract fromEnd from the last ip + fromEndBig := big.NewInt(int64(fromEnd)) + resultInt := big.NewInt(0) + resultInt.Sub(endInt, fromEndBig) + + return net.IP(resultInt.Bytes()).String() +} + +func hasValidDnsServer(cfg *nbdns.Config) bool { + for _, c := range cfg.NameServerGroups { + if c.Primary { + return true + } + } + return false +} diff --git a/client/internal/dns/server_android.go b/client/internal/dns/server_android.go deleted file mode 100644 index dddbc65a2..000000000 --- a/client/internal/dns/server_android.go +++ /dev/null @@ -1,32 +0,0 @@ -package dns - -import ( - "context" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" -) - -// DefaultServer dummy dns server -type DefaultServer struct { -} - -// NewDefaultServer On Android the DNS feature is not supported yet -func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) { - return &DefaultServer{}, nil -} - -// Start dummy implementation -func (s DefaultServer) Start() { - -} - -// Stop dummy implementation -func (s DefaultServer) Stop() { - -} - -// UpdateDNSServer dummy implementation -func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { - return nil -} diff --git a/client/internal/dns/server_nonandroid.go b/client/internal/dns/server_nonandroid.go deleted file mode 100644 index ec970bccf..000000000 --- a/client/internal/dns/server_nonandroid.go +++ /dev/null @@ -1,565 +0,0 @@ -//go:build !android - -package dns - -import ( - "context" - "fmt" - "math/big" - "net" - "net/netip" - "runtime" - "sync" - "time" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/miekg/dns" - "github.com/mitchellh/hashstructure/v2" - log "github.com/sirupsen/logrus" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" -) - -const ( - defaultPort = 53 - customPort = 5053 - defaultIP = "127.0.0.1" - customIP = "127.0.0.153" -) - -type registeredHandlerMap map[string]handlerWithStop - -// DefaultServer dns server object -type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc - mux sync.Mutex - fakeResolverWG sync.WaitGroup - server *dns.Server - dnsMux *dns.ServeMux - dnsMuxMap registeredHandlerMap - localResolver *localResolver - wgInterface *iface.WGIface - hostManager hostManager - updateSerial uint64 - listenerIsRunning bool - runtimePort int - runtimeIP string - previousConfigHash uint64 - currentConfig hostDNSConfig - customAddress *netip.AddrPort -} - -type handlerWithStop interface { - dns.Handler - stop() -} - -type muxUpdate struct { - domain string - handler handlerWithStop -} - -// NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) { - mux := dns.NewServeMux() - - dnsServer := &dns.Server{ - Net: "udp", - Handler: mux, - UDPSize: 65535, - } - - ctx, stop := context.WithCancel(ctx) - - var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) - if err != nil { - stop() - return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) - } - addrPort = &parsedAddrPort - } - - defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - server: dnsServer, - dnsMux: mux, - dnsMuxMap: make(registeredHandlerMap), - localResolver: &localResolver{ - registeredMap: make(registrationMap), - }, - wgInterface: wgInterface, - runtimePort: defaultPort, - customAddress: addrPort, - } - - hostmanager, err := newHostManager(wgInterface) - if err != nil { - stop() - return nil, err - } - defaultServer.hostManager = hostmanager - return defaultServer, err -} - -// Start runs the listener in a go routine -func (s *DefaultServer) Start() { - if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() { - s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1) - s.runtimePort = 53 - - s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) - s.fakeResolverWG.Add(1) - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) - - hookID := s.filterDNSTraffic() - s.fakeResolverWG.Wait() - if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil { - log.Errorf("unable to remove DNS packet hook: %s", err) - } - }() - return - } - - if s.customAddress != nil { - s.runtimeIP = s.customAddress.Addr().String() - s.runtimePort = int(s.customAddress.Port()) - } else { - ip, port, err := s.getFirstListenerAvailable() - if err != nil { - log.Error(err) - return - } - s.runtimeIP = ip - s.runtimePort = port - } - - s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort) - - log.Debugf("starting dns on %s", s.server.Addr) - - go func() { - s.setListenerStatus(true) - defer s.setListenerStatus(false) - - err := s.server.ListenAndServe() - if err != nil { - log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) - } - }() -} - -func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) { - ips := []string{defaultIP, customIP} - if runtime.GOOS != "darwin" && s.wgInterface != nil { - ips = append([]string{s.wgInterface.Address().IP.String()}, ips...) - } - ports := []int{defaultPort, customPort} - for _, port := range ports { - for _, ip := range ips { - addrString := fmt.Sprintf("%s:%d", ip, port) - udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) - probeListener, err := net.ListenUDP("udp", udpAddr) - if err == nil { - err = probeListener.Close() - if err != nil { - log.Errorf("got an error closing the probe listener, error: %s", err) - } - return ip, port, nil - } - log.Warnf("binding dns on %s is not available, error: %s", addrString, err) - } - } - return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports) -} - -func (s *DefaultServer) setListenerStatus(running bool) { - s.listenerIsRunning = running -} - -// Stop stops the server -func (s *DefaultServer) Stop() { - s.mux.Lock() - defer s.mux.Unlock() - s.ctxCancel() - - err := s.hostManager.restoreHostDNS() - if err != nil { - log.Error(err) - } - - if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { - s.fakeResolverWG.Done() - } - - err = s.stopListener() - if err != nil { - log.Error(err) - } -} - -func (s *DefaultServer) stopListener() error { - if !s.listenerIsRunning { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - err := s.server.ShutdownContext(ctx) - if err != nil { - return fmt.Errorf("stopping dns server listener returned an error: %v", err) - } - return nil -} - -// UpdateDNSServer processes an update received from the management service -func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { - select { - case <-s.ctx.Done(): - log.Infof("not updating DNS server as context is closed") - return s.ctx.Err() - default: - if serial < s.updateSerial { - return fmt.Errorf("not applying dns update, error: "+ - "network update is %d behind the last applied update", s.updateSerial-serial) - } - s.mux.Lock() - defer s.mux.Unlock() - - hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ - ZeroNil: true, - IgnoreZeroValue: true, - SlicesAsSets: true, - UseStringer: true, - }) - if err != nil { - log.Errorf("unable to hash the dns configuration update, got error: %s", err) - } - - if s.previousConfigHash == hash { - log.Debugf("not applying the dns configuration update as there is nothing new") - s.updateSerial = serial - return nil - } - - if err := s.applyConfiguration(update); err != nil { - return err - } - - s.updateSerial = serial - s.previousConfigHash = hash - - return nil - } -} - -func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { - // is the service should be disabled, we stop the listener or fake resolver - // and proceed with a regular update to clean up the handlers and records - if !update.ServiceEnable { - if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { - s.fakeResolverWG.Done() - } else { - if err := s.stopListener(); err != nil { - log.Error(err) - } - } - } else if !s.listenerIsRunning { - s.Start() - } - - localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) - if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) - } - upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups) - if err != nil { - return fmt.Errorf("not applying dns update, error: %v", err) - } - - muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) - - s.updateMux(muxUpdates) - s.updateLocalResolver(localRecords) - s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) - - hostUpdate := s.currentConfig - if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() { - log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + - "Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") - hostUpdate.routeAll = false - } - - if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { - log.Error(err) - } - - return nil -} - -func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) { - var muxUpdates []muxUpdate - localRecords := make(map[string]nbdns.SimpleRecord, 0) - - for _, customZone := range customZones { - - if len(customZone.Records) == 0 { - return nil, nil, fmt.Errorf("received an empty list of records") - } - - muxUpdates = append(muxUpdates, muxUpdate{ - domain: customZone.Domain, - handler: s.localResolver, - }) - - for _, record := range customZone.Records { - var class uint16 = dns.ClassINET - if record.Class != nbdns.DefaultClass { - return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) - } - key := buildRecordKey(record.Name, class, uint16(record.Type)) - localRecords[key] = record - } - } - return muxUpdates, localRecords, nil -} - -func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) { - - var muxUpdates []muxUpdate - for _, nsGroup := range nameServerGroups { - if len(nsGroup.NameServers) == 0 { - log.Warn("received a nameserver group with empty nameserver list") - continue - } - - handler := newUpstreamResolver(s.ctx) - for _, ns := range nsGroup.NameServers { - if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skiping nameserver %s with type %s, this peer supports only %s", - ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) - continue - } - handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) - } - - if len(handler.upstreamServers) == 0 { - handler.stop() - log.Errorf("received a nameserver group with an invalid nameserver list") - continue - } - - // when upstream fails to resolve domain several times over all it servers - // it will calls this hook to exclude self from the configuration and - // reapply DNS settings, but it not touch the original configuration and serial number - // because it is temporal deactivation until next try - // - // after some period defined by upstream it trys to reactivate self by calling this hook - // everything we need here is just to re-apply current configuration because it already - // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler) - - if nsGroup.Primary { - muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, - }) - continue - } - - if len(nsGroup.Domains) == 0 { - handler.stop() - return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list") - } - - for _, domain := range nsGroup.Domains { - if domain == "" { - handler.stop() - return nil, fmt.Errorf("received a nameserver group with an empty domain element") - } - muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, - }) - } - } - return muxUpdates, nil -} - -func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { - muxUpdateMap := make(registeredHandlerMap) - - for _, update := range muxUpdates { - s.registerMux(update.domain, update.handler) - muxUpdateMap[update.domain] = update.handler - if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { - existingHandler.stop() - } - } - - for key, existingHandler := range s.dnsMuxMap { - _, found := muxUpdateMap[key] - if !found { - existingHandler.stop() - s.deregisterMux(key) - } - } - - s.dnsMuxMap = muxUpdateMap -} - -func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { - for key := range s.localResolver.registeredMap { - _, found := update[key] - if !found { - s.localResolver.deleteRecord(key) - } - } - - updatedMap := make(registrationMap) - for key, record := range update { - err := s.localResolver.registerRecord(record) - if err != nil { - log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err) - } - updatedMap[key] = struct{}{} - } - - s.localResolver.registeredMap = updatedMap -} - -func getNSHostPort(ns nbdns.NameServer) string { - return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) -} - -func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) { - s.dnsMux.Handle(pattern, handler) -} - -func (s *DefaultServer) deregisterMux(pattern string) { - s.dnsMux.HandleRemove(pattern) -} - -// upstreamCallbacks returns two functions, the first one is used to deactivate -// the upstream resolver from the configuration, the second one is used to -// reactivate it. Not allowed to call reactivate before deactivate. -func (s *DefaultServer) upstreamCallbacks( - nsGroup *nbdns.NameServerGroup, - handler dns.Handler, -) (deactivate func(), reactivate func()) { - var removeIndex map[string]int - deactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("temporary deactivate nameservers group due timeout") - - removeIndex = make(map[string]int) - for _, domain := range nsGroup.Domains { - removeIndex[domain] = -1 - } - if nsGroup.Primary { - removeIndex[nbdns.RootZone] = -1 - s.currentConfig.routeAll = false - } - - for i, item := range s.currentConfig.domains { - if _, found := removeIndex[item.domain]; found { - s.currentConfig.domains[i].disabled = true - s.deregisterMux(item.domain) - removeIndex[item.domain] = i - } - } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("fail to apply nameserver deactivation on the host") - } - } - reactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { - continue - } - s.currentConfig.domains[i].disabled = false - s.registerMux(domain, handler) - } - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") - - if nsGroup.Primary { - s.currentConfig.routeAll = true - } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") - } - } - return -} - -func (s *DefaultServer) filterDNSTraffic() string { - filter := s.wgInterface.GetFilter() - if filter == nil { - log.Error("can't set DNS filter, filter not initialized") - return "" - } - - firstLayerDecoder := layers.LayerTypeIPv4 - if s.wgInterface.Address().Network.IP.To4() == nil { - firstLayerDecoder = layers.LayerTypeIPv6 - } - - hook := func(packetData []byte) bool { - // Decode the packet - packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default) - - // Get the UDP layer - udpLayer := packet.Layer(layers.LayerTypeUDP) - udp := udpLayer.(*layers.UDP) - - msg := new(dns.Msg) - if err := msg.Unpack(udp.Payload); err != nil { - log.Tracef("parse DNS request: %v", err) - return true - } - - writer := responseWriter{ - packet: packet, - device: s.wgInterface.GetDevice().Device, - } - go s.dnsMux.ServeDNS(&writer, msg) - return true - } - - return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook) -} - -func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string { - // Calculate the last IP in the CIDR range - var endIP net.IP - for i := 0; i < len(network.IP); i++ { - endIP = append(endIP, network.IP[i]|^network.Mask[i]) - } - - // convert to big.Int - endInt := big.NewInt(0) - endInt.SetBytes(endIP) - - // subtract fromEnd from the last ip - fromEndBig := big.NewInt(int64(fromEnd)) - resultInt := big.NewInt(0) - resultInt.Sub(endInt, fromEndBig) - - return net.IP(resultInt.Bytes()).String() -} diff --git a/client/internal/dns/server_nonandroid_test.go b/client/internal/dns/server_nonandroid_test.go deleted file mode 100644 index bea4f4ce8..000000000 --- a/client/internal/dns/server_nonandroid_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package dns - -import ( - "net" - "testing" -) - -func TestGetLastIPFromNetwork(t *testing.T) { - tests := []struct { - addr string - ip string - }{ - {"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"}, - {"192.168.0.0/30", "192.168.0.2"}, - {"192.168.0.0/16", "192.168.255.254"}, - {"192.168.0.0/24", "192.168.0.254"}, - } - - for _, tt := range tests { - _, ipnet, err := net.ParseCIDR(tt.addr) - if err != nil { - t.Errorf("Error parsing CIDR: %v", err) - return - } - - lastIP := getLastIPFromNetwork(ipnet, 1) - if lastIP != tt.ip { - t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) - } - } -} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 201392105..56b44abf6 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -221,7 +221,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "") + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil) if err != nil { t.Fatal(err) } @@ -428,7 +428,7 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe ctx, cancel := context.WithCancel(context.TODO()) - return &DefaultServer{ + ds := &DefaultServer{ ctx: ctx, ctxCancel: cancel, server: dnsServer, @@ -439,4 +439,31 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe }, customAddress: parsedAddrPort, } + ds.evalRuntimeAddress() + return ds +} + +func TestGetLastIPFromNetwork(t *testing.T) { + tests := []struct { + addr string + ip string + }{ + {"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"}, + {"192.168.0.0/30", "192.168.0.2"}, + {"192.168.0.0/16", "192.168.255.254"}, + {"192.168.0.0/24", "192.168.0.254"}, + } + + for _, tt := range tests { + _, ipnet, err := net.ParseCIDR(tt.addr) + if err != nil { + t.Errorf("Error parsing CIDR: %v", err) + return + } + + lastIP := getLastIPFromNetwork(ipnet, 1) + if lastIP != tt.ip { + t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) + } + } } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 19776eacb..e870181af 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package dns import ( diff --git a/client/internal/engine.go b/client/internal/engine.go index 289f80c2a..2c08b1415 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -189,14 +189,37 @@ func (e *Engine) Start() error { return err } - routes, err := e.readInitialRoutes() - if err != nil { - return err + var routes []*route.Route + var dnsCfg *nbdns.Config + + if runtime.GOOS == "android" { + routes, dnsCfg, err = e.readInitialSettings() + if err != nil { + return err + } } + + if e.dnsServer == nil { + // todo fix custom address + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg) + if err != nil { + e.close() + return err + } + e.dnsServer = dnsServer + } + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener) - err = e.wgInterface.Create() + if runtime.GOOS != "android" { + err = e.wgInterface.Create() + } else { + err = e.wgInterface.CreateOnMobile(iface.MobileIFaceArguments{ + Routes: e.routeManager.InitialRouteRange(), + Dns: e.dnsServer.DnsIP(), + }) + } if err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error()) e.close() @@ -236,16 +259,6 @@ func (e *Engine) Start() error { e.acl = acl } - if e.dnsServer == nil { - // todo fix custom address - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) - if err != nil { - e.close() - return err - } - e.dnsServer = dnsServer - } - e.receiveSignalEvents() e.receiveManagementEvents() @@ -1027,17 +1040,14 @@ func (e *Engine) close() { } } -func (e *Engine) readInitialRoutes() ([]*route.Route, error) { - if runtime.GOOS != "android" { - return nil, nil - } - - routesResp, err := e.mgmClient.GetRoutes() +func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { + netMap, err := e.mgmClient.GetNetworkMap() if err != nil { - return nil, err + return nil, nil, err } - return toRoutes(routesResp), nil - + routes := toRoutes(netMap.GetRoutes()) + dnsCfg := toDNSConfig(netMap.GetDNSConfig()) + return routes, &dnsCfg, nil } func findIPFromInterfaceName(ifaceName string) (net.IP, error) { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 840d74269..7324759f9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -17,6 +17,7 @@ import ( type Manager interface { UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener RouteListener) + InitialRouteRange() []string Stop() } @@ -51,10 +52,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, if runtime.GOOS == "android" { cr := dm.clientRoutes(initialRoutes) dm.notifier.setInitialClientRoutes(cr) - networks := readRouteNetworks(cr) - - // make sense to call before create interface - wgInterface.SetInitialRoutes(networks) } return dm } @@ -94,6 +91,11 @@ func (m *DefaultManager) SetRouteChangeListener(listener RouteListener) { m.notifier.setListener(listener) } +// InitialRouteRange return the list of initial routes. It used by mobile systems +func (m *DefaultManager) InitialRouteRange() []string { + return m.notifier.initialRouteRanges() +} + func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { // removing routes that do not exist as per the update from the Management service. for id, client := range m.clientNetworks { @@ -163,11 +165,3 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } return rs } - -func readRouteNetworks(cr []*route.Route) []string { - routesNetworks := make([]string, 0) - for _, r := range cr { - routesNetworks = append(routesNetworks, r.Network.String()) - } - return routesNetworks -} diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index bd619a1c8..f56dbfb17 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -14,8 +14,8 @@ type MockManager struct { StopFunc func() } -// InitialClientRoutesNetworks mock implementation of InitialClientRoutesNetworks from Manager interface -func (m *MockManager) InitialClientRoutesNetworks() []string { +// InitialRouteRange mock implementation of InitialRouteRange from Manager interface +func (m *MockManager) InitialRouteRange() []string { return nil } diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go index 2d1afa055..e37811166 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier.go @@ -84,3 +84,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool { } return false } + +func (n *notifier) initialRouteRanges() []string { + return n.initialRouteRangers +} diff --git a/iface/iface.go b/iface/iface.go index 4fcc064d1..6c7e1a1cd 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -36,15 +36,6 @@ func (w *WGIface) GetBind() *bind.ICEBind { return w.tun.iceBind } -// Create creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) Create() error { - w.mu.Lock() - defer w.mu.Unlock() - log.Debugf("create WireGuard interface %s", w.tun.DeviceName()) - return w.tun.Create() -} - // Name returns the interface name func (w *WGIface) Name() string { return w.tun.DeviceName() diff --git a/iface/iface_android.go b/iface/iface_android.go index 8b6e55f96..6f47e5aa7 100644 --- a/iface/iface_android.go +++ b/iface/iface_android.go @@ -1,9 +1,11 @@ package iface import ( + "fmt" "sync" "github.com/pion/transport/v2" + log "github.com/sirupsen/logrus" ) // NewWGIFace Creates a new WireGuard interface instance @@ -27,7 +29,16 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter return wgIFace, nil } -// SetInitialRoutes store the given routes and on the tun creation will be used -func (w *WGIface) SetInitialRoutes(routes []string) { - w.tun.SetRoutes(routes) +// CreateOnMobile creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { + w.mu.Lock() + defer w.mu.Unlock() + log.Debugf("create WireGuard interface %s", w.tun.DeviceName()) + return w.tun.Create(mIFaceArgs) +} + +// Create this function make sense on mobile only +func (w *WGIface) Create() error { + return fmt.Errorf("this function has not implemented on mobile") } diff --git a/iface/iface_nonandroid.go b/iface/iface_nonandroid.go index fca7059f0..9622207bb 100644 --- a/iface/iface_nonandroid.go +++ b/iface/iface_nonandroid.go @@ -3,9 +3,11 @@ package iface import ( + "fmt" "sync" "github.com/pion/transport/v2" + log "github.com/sirupsen/logrus" ) // NewWGIFace Creates a new WireGuard interface instance @@ -26,7 +28,16 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter return wgIFace, nil } -// SetInitialRoutes unused function on non Android -func (w *WGIface) SetInitialRoutes(routes []string) { - +// CreateOnMobile this function make sense on mobile only +func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { + return fmt.Errorf("this function has not implemented on non mobile") +} + +// Create creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +func (w *WGIface) Create() error { + w.mu.Lock() + defer w.mu.Unlock() + log.Debugf("create WireGuard interface %s", w.tun.DeviceName()) + return w.tun.Create() } diff --git a/iface/tun.go b/iface/tun.go index f81222cdb..51a7783a1 100644 --- a/iface/tun.go +++ b/iface/tun.go @@ -1,5 +1,10 @@ package iface +type MobileIFaceArguments struct { + Routes []string + Dns string +} + // NetInterface represents a generic network tunnel interface type NetInterface interface { Close() error diff --git a/iface/tun_adapter.go b/iface/tun_adapter.go index 07b593ffb..0ba0bde22 100644 --- a/iface/tun_adapter.go +++ b/iface/tun_adapter.go @@ -2,6 +2,6 @@ package iface // TunAdapter is an interface for create tun device from externel service type TunAdapter interface { - ConfigureInterface(address string, mtu int, routes string) (int, error) + ConfigureInterface(address string, mtu int, dns string, routes string) (int, error) UpdateAddr(address string) error } diff --git a/iface/tun_android.go b/iface/tun_android.go index 9f4f6e192..e54a6f730 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -15,13 +15,12 @@ import ( type tunDevice struct { address WGAddress mtu int - routes []string tunAdapter TunAdapter + iceBind *bind.ICEBind fd int name string device *device.Device - iceBind *bind.ICEBind wrapper *DeviceWrapper } @@ -34,14 +33,10 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe } } -func (t *tunDevice) SetRoutes(routes []string) { - t.routes = routes -} - -func (t *tunDevice) Create() error { +func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { var err error - routesString := t.routesToString() - t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString) + routesString := t.routesToString(mIFaceArgs.Routes) + t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return err @@ -95,6 +90,6 @@ func (t *tunDevice) Close() (err error) { return } -func (t *tunDevice) routesToString() string { - return strings.Join(t.routes, ";") +func (t *tunDevice) routesToString(routes []string) string { + return strings.Join(routes, ";") } diff --git a/management/client/client.go b/management/client/client.go index 2f903d210..d2022f806 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -15,5 +15,5 @@ type Client interface { Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetRoutes() ([]*proto.Route, error) + GetNetworkMap() (*proto.NetworkMap, error) } diff --git a/management/client/grpc.go b/management/client/grpc.go index 6b3e74b30..d2ca8c088 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -172,8 +172,8 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error return nil } -// GetRoutes return with the routes -func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) { +// GetNetworkMap return with the network map +func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) { serverPubKey, err := c.GetServerPublicKey() if err != nil { log.Debugf("failed getting Management Service public key: %s", err) @@ -212,7 +212,7 @@ func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) { return nil, fmt.Errorf("invalid msg, required network map") } - return decryptedResp.GetNetworkMap().GetRoutes(), nil + return decryptedResp.GetNetworkMap(), nil } func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { diff --git a/management/client/mock.go b/management/client/mock.go index 589a4f784..ccad538c1 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -57,7 +57,7 @@ func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D return m.GetDeviceAuthorizationFlowFunc(serverKey) } -// GetRoutes mock implementation of GetRoutes from mgm.Client interface -func (m *MockClient) GetRoutes() ([]*proto.Route, error) { +// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface +func (m *MockClient) GetNetworkMap() (*proto.NetworkMap, error) { return nil, nil }