mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 18:21:24 +01:00
Feature/android dns (#943)
Support DNS feature on mobile systems --------- Co-authored-by: Givi Khojanashvili <gigovich@gmail.com>
This commit is contained in:
parent
f8da516128
commit
481465e1ae
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
24
client/internal/dns/host_android.go
Normal file
24
client/internal/dns/host_android.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
type androidHostManager struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||||
|
return &androidHostManager{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a androidHostManager) restoreHostDNS() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a androidHostManager) supportCustomPort() bool {
|
||||||
|
return false
|
||||||
|
}
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -22,7 +22,7 @@ func (d *localResolver) stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
log.Tracef("received question: %#v\n", r.Question[0])
|
log.Tracef("received question: %#v", r.Question[0])
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
replyMessage.SetReply(r)
|
replyMessage.SetReply(r)
|
||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
@ -26,6 +26,10 @@ func (m *MockServer) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) DnsIP() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
if m.UpdateDNSServerFunc != nil {
|
if m.UpdateDNSServerFunc != nil {
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -1,10 +1,597 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import nbdns "github.com/netbirdio/netbird/dns"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 53
|
||||||
|
customPort = 5053
|
||||||
|
defaultIP = "127.0.0.1"
|
||||||
|
customIP = "127.0.0.153"
|
||||||
|
)
|
||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
Start()
|
Start()
|
||||||
Stop()
|
Stop()
|
||||||
|
DnsIP() string
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type registeredHandlerMap map[string]handlerWithStop
|
||||||
|
|
||||||
|
// DefaultServer dns server object
|
||||||
|
type DefaultServer struct {
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
mux sync.Mutex
|
||||||
|
fakeResolverWG sync.WaitGroup
|
||||||
|
server *dns.Server
|
||||||
|
dnsMux *dns.ServeMux
|
||||||
|
dnsMuxMap registeredHandlerMap
|
||||||
|
localResolver *localResolver
|
||||||
|
wgInterface *iface.WGIface
|
||||||
|
hostManager hostManager
|
||||||
|
updateSerial uint64
|
||||||
|
listenerIsRunning bool
|
||||||
|
runtimePort int
|
||||||
|
runtimeIP string
|
||||||
|
previousConfigHash uint64
|
||||||
|
currentConfig hostDNSConfig
|
||||||
|
customAddress *netip.AddrPort
|
||||||
|
enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type handlerWithStop interface {
|
||||||
|
dns.Handler
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxUpdate struct {
|
||||||
|
domain string
|
||||||
|
handler handlerWithStop
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultServer returns a new dns server
|
||||||
|
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
var addrPort *netip.AddrPort
|
||||||
|
if customAddress != "" {
|
||||||
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
||||||
|
}
|
||||||
|
addrPort = &parsedAddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
hostManager, err := newHostManager(wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
defaultServer := &DefaultServer{
|
||||||
|
ctx: ctx,
|
||||||
|
ctxCancel: stop,
|
||||||
|
server: &dns.Server{
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
UDPSize: 65535,
|
||||||
|
},
|
||||||
|
dnsMux: mux,
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
localResolver: &localResolver{
|
||||||
|
registeredMap: make(registrationMap),
|
||||||
|
},
|
||||||
|
wgInterface: wgInterface,
|
||||||
|
customAddress: addrPort,
|
||||||
|
hostManager: hostManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
if initialDnsCfg != nil {
|
||||||
|
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultServer.evalRuntimeAddress()
|
||||||
|
return defaultServer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start runs the listener in a go routine
|
||||||
|
func (s *DefaultServer) Start() {
|
||||||
|
// nil check required in unit tests
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||||
|
s.fakeResolverWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
hookID := s.filterDNSTraffic()
|
||||||
|
s.fakeResolverWG.Wait()
|
||||||
|
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
||||||
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
err := s.server.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) DnsIP() string {
|
||||||
|
if !s.enabled {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.runtimeIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||||
|
ips := []string{defaultIP, customIP}
|
||||||
|
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||||
|
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||||
|
}
|
||||||
|
ports := []int{defaultPort, customPort}
|
||||||
|
for _, port := range ports {
|
||||||
|
for _, ip := range ips {
|
||||||
|
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
|
if err == nil {
|
||||||
|
err = probeListener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||||
|
s.listenerIsRunning = running
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the server
|
||||||
|
func (s *DefaultServer) Stop() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
s.ctxCancel()
|
||||||
|
|
||||||
|
err := s.hostManager.restoreHostDNS()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||||
|
s.fakeResolverWG.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.stopListener()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) stopListener() error {
|
||||||
|
if !s.listenerIsRunning {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.ShutdownContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDNSServer processes an update received from the management service
|
||||||
|
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
log.Infof("not updating DNS server as context is closed")
|
||||||
|
return s.ctx.Err()
|
||||||
|
default:
|
||||||
|
if serial < s.updateSerial {
|
||||||
|
return fmt.Errorf("not applying dns update, error: "+
|
||||||
|
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||||
|
}
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||||
|
ZeroNil: true,
|
||||||
|
IgnoreZeroValue: true,
|
||||||
|
SlicesAsSets: true,
|
||||||
|
UseStringer: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.previousConfigHash == hash {
|
||||||
|
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||||
|
s.updateSerial = serial
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.applyConfiguration(update); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.updateSerial = serial
|
||||||
|
s.previousConfigHash = hash
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
|
// is the service should be disabled, we stop the listener or fake resolver
|
||||||
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
|
if !update.ServiceEnable {
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||||
|
s.fakeResolverWG.Done()
|
||||||
|
} else {
|
||||||
|
if err := s.stopListener(); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !s.listenerIsRunning {
|
||||||
|
s.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||||
|
|
||||||
|
s.updateMux(muxUpdates)
|
||||||
|
s.updateLocalResolver(localRecords)
|
||||||
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||||
|
|
||||||
|
hostUpdate := s.currentConfig
|
||||||
|
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
|
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||||
|
hostUpdate.routeAll = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
|
if len(customZone.Records) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
|
}
|
||||||
|
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: customZone.Domain,
|
||||||
|
handler: s.localResolver,
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, record := range customZone.Records {
|
||||||
|
var class uint16 = dns.ClassINET
|
||||||
|
if record.Class != nbdns.DefaultClass {
|
||||||
|
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||||
|
}
|
||||||
|
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||||
|
localRecords[key] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, localRecords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
||||||
|
|
||||||
|
var muxUpdates []muxUpdate
|
||||||
|
for _, nsGroup := range nameServerGroups {
|
||||||
|
if len(nsGroup.NameServers) == 0 {
|
||||||
|
log.Warn("received a nameserver group with empty nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := newUpstreamResolver(s.ctx)
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
|
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
||||||
|
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(handler.upstreamServers) == 0 {
|
||||||
|
handler.stop()
|
||||||
|
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// when upstream fails to resolve domain several times over all it servers
|
||||||
|
// it will calls this hook to exclude self from the configuration and
|
||||||
|
// reapply DNS settings, but it not touch the original configuration and serial number
|
||||||
|
// because it is temporal deactivation until next try
|
||||||
|
//
|
||||||
|
// after some period defined by upstream it trys to reactivate self by calling this hook
|
||||||
|
// everything we need here is just to re-apply current configuration because it already
|
||||||
|
// contains this upstream settings (temporal deactivation not removed it)
|
||||||
|
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: nbdns.RootZone,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nsGroup.Domains) == 0 {
|
||||||
|
handler.stop()
|
||||||
|
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
if domain == "" {
|
||||||
|
handler.stop()
|
||||||
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
|
}
|
||||||
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
|
domain: domain,
|
||||||
|
handler: handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return muxUpdates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
|
||||||
|
for _, update := range muxUpdates {
|
||||||
|
s.registerMux(update.domain, update.handler)
|
||||||
|
muxUpdateMap[update.domain] = update.handler
|
||||||
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
|
existingHandler.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
|
_, found := muxUpdateMap[key]
|
||||||
|
if !found {
|
||||||
|
existingHandler.stop()
|
||||||
|
s.deregisterMux(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
|
for key := range s.localResolver.registeredMap {
|
||||||
|
_, found := update[key]
|
||||||
|
if !found {
|
||||||
|
s.localResolver.deleteRecord(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedMap := make(registrationMap)
|
||||||
|
for key, record := range update {
|
||||||
|
err := s.localResolver.registerRecord(record)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
||||||
|
}
|
||||||
|
updatedMap[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.localResolver.registeredMap = updatedMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNSHostPort(ns nbdns.NameServer) string {
|
||||||
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||||
|
s.dnsMux.Handle(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||||
|
s.dnsMux.HandleRemove(pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
|
func (s *DefaultServer) upstreamCallbacks(
|
||||||
|
nsGroup *nbdns.NameServerGroup,
|
||||||
|
handler dns.Handler,
|
||||||
|
) (deactivate func(), reactivate func()) {
|
||||||
|
var removeIndex map[string]int
|
||||||
|
deactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Info("temporary deactivate nameservers group due timeout")
|
||||||
|
|
||||||
|
removeIndex = make(map[string]int)
|
||||||
|
for _, domain := range nsGroup.Domains {
|
||||||
|
removeIndex[domain] = -1
|
||||||
|
}
|
||||||
|
if nsGroup.Primary {
|
||||||
|
removeIndex[nbdns.RootZone] = -1
|
||||||
|
s.currentConfig.routeAll = false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range s.currentConfig.domains {
|
||||||
|
if _, found := removeIndex[item.domain]; found {
|
||||||
|
s.currentConfig.domains[i].disabled = true
|
||||||
|
s.deregisterMux(item.domain)
|
||||||
|
removeIndex[item.domain] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reactivate = func() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
for domain, i := range removeIndex {
|
||||||
|
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.currentConfig.domains[i].disabled = false
|
||||||
|
s.registerMux(domain, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
|
l.Debug("reactivate temporary disabled nameserver group")
|
||||||
|
|
||||||
|
if nsGroup.Primary {
|
||||||
|
s.currentConfig.routeAll = true
|
||||||
|
}
|
||||||
|
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
||||||
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) filterDNSTraffic() string {
|
||||||
|
filter := s.wgInterface.GetFilter()
|
||||||
|
if filter == nil {
|
||||||
|
log.Error("can't set DNS filter, filter not initialized")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
|
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||||
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := func(packetData []byte) bool {
|
||||||
|
// Decode the packet
|
||||||
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
udp := udpLayer.(*layers.UDP)
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(udp.Payload); err != nil {
|
||||||
|
log.Tracef("parse DNS request: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := responseWriter{
|
||||||
|
packet: packet,
|
||||||
|
device: s.wgInterface.GetDevice().Device,
|
||||||
|
}
|
||||||
|
go s.dnsMux.ServeDNS(&writer, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) evalRuntimeAddress() {
|
||||||
|
defer func() {
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||||
|
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||||
|
s.runtimePort = defaultPort
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.customAddress != nil {
|
||||||
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, port, err := s.getFirstListenerAvailable()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtimeIP = ip
|
||||||
|
s.runtimePort = port
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||||
|
// Calculate the last IP in the CIDR range
|
||||||
|
var endIP net.IP
|
||||||
|
for i := 0; i < len(network.IP); i++ {
|
||||||
|
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to big.Int
|
||||||
|
endInt := big.NewInt(0)
|
||||||
|
endInt.SetBytes(endIP)
|
||||||
|
|
||||||
|
// subtract fromEnd from the last ip
|
||||||
|
fromEndBig := big.NewInt(int64(fromEnd))
|
||||||
|
resultInt := big.NewInt(0)
|
||||||
|
resultInt.Sub(endInt, fromEndBig)
|
||||||
|
|
||||||
|
return net.IP(resultInt.Bytes()).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
||||||
|
for _, c := range cfg.NameServerGroups {
|
||||||
|
if c.Primary {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DefaultServer dummy dns server
|
|
||||||
type DefaultServer struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServer On Android the DNS feature is not supported yet
|
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
|
||||||
return &DefaultServer{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start dummy implementation
|
|
||||||
func (s DefaultServer) Start() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop dummy implementation
|
|
||||||
func (s DefaultServer) Stop() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer dummy implementation
|
|
||||||
func (s DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,565 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultPort = 53
|
|
||||||
customPort = 5053
|
|
||||||
defaultIP = "127.0.0.1"
|
|
||||||
customIP = "127.0.0.153"
|
|
||||||
)
|
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
|
||||||
|
|
||||||
// DefaultServer dns server object
|
|
||||||
type DefaultServer struct {
|
|
||||||
ctx context.Context
|
|
||||||
ctxCancel context.CancelFunc
|
|
||||||
mux sync.Mutex
|
|
||||||
fakeResolverWG sync.WaitGroup
|
|
||||||
server *dns.Server
|
|
||||||
dnsMux *dns.ServeMux
|
|
||||||
dnsMuxMap registeredHandlerMap
|
|
||||||
localResolver *localResolver
|
|
||||||
wgInterface *iface.WGIface
|
|
||||||
hostManager hostManager
|
|
||||||
updateSerial uint64
|
|
||||||
listenerIsRunning bool
|
|
||||||
runtimePort int
|
|
||||||
runtimeIP string
|
|
||||||
previousConfigHash uint64
|
|
||||||
currentConfig hostDNSConfig
|
|
||||||
customAddress *netip.AddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
type handlerWithStop interface {
|
|
||||||
dns.Handler
|
|
||||||
stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
type muxUpdate struct {
|
|
||||||
domain string
|
|
||||||
handler handlerWithStop
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
|
||||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string) (*DefaultServer, error) {
|
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
UDPSize: 65535,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, stop := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
var addrPort *netip.AddrPort
|
|
||||||
if customAddress != "" {
|
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
|
||||||
if err != nil {
|
|
||||||
stop()
|
|
||||||
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
|
|
||||||
}
|
|
||||||
addrPort = &parsedAddrPort
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
|
||||||
ctx: ctx,
|
|
||||||
ctxCancel: stop,
|
|
||||||
server: dnsServer,
|
|
||||||
dnsMux: mux,
|
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
|
||||||
localResolver: &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
|
||||||
runtimePort: defaultPort,
|
|
||||||
customAddress: addrPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
hostmanager, err := newHostManager(wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
stop()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defaultServer.hostManager = hostmanager
|
|
||||||
return defaultServer, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start runs the listener in a go routine
|
|
||||||
func (s *DefaultServer) Start() {
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
|
||||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
|
||||||
s.runtimePort = 53
|
|
||||||
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
s.fakeResolverWG.Add(1)
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
hookID := s.filterDNSTraffic()
|
|
||||||
s.fakeResolverWG.Wait()
|
|
||||||
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
|
||||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.customAddress != nil {
|
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
|
||||||
} else {
|
|
||||||
ip, port, err := s.getFirstListenerAvailable()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtimeIP = ip
|
|
||||||
s.runtimePort = port
|
|
||||||
}
|
|
||||||
|
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
|
||||||
|
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.setListenerStatus(true)
|
|
||||||
defer s.setListenerStatus(false)
|
|
||||||
|
|
||||||
err := s.server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
|
||||||
ips := []string{defaultIP, customIP}
|
|
||||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
|
||||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
|
||||||
}
|
|
||||||
ports := []int{defaultPort, customPort}
|
|
||||||
for _, port := range ports {
|
|
||||||
for _, ip := range ips {
|
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err == nil {
|
|
||||||
err = probeListener.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
|
||||||
}
|
|
||||||
return ip, port, nil
|
|
||||||
}
|
|
||||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
|
||||||
s.listenerIsRunning = running
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *DefaultServer) Stop() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
s.ctxCancel()
|
|
||||||
|
|
||||||
err := s.hostManager.restoreHostDNS()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
|
||||||
s.fakeResolverWG.Done()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.stopListener()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) stopListener() error {
|
|
||||||
if !s.listenerIsRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := s.server.ShutdownContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDNSServer processes an update received from the management service
|
|
||||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
log.Infof("not updating DNS server as context is closed")
|
|
||||||
return s.ctx.Err()
|
|
||||||
default:
|
|
||||||
if serial < s.updateSerial {
|
|
||||||
return fmt.Errorf("not applying dns update, error: "+
|
|
||||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
|
||||||
}
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
|
||||||
ZeroNil: true,
|
|
||||||
IgnoreZeroValue: true,
|
|
||||||
SlicesAsSets: true,
|
|
||||||
UseStringer: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.previousConfigHash == hash {
|
|
||||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
|
||||||
s.updateSerial = serial
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.applyConfiguration(update); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateSerial = serial
|
|
||||||
s.previousConfigHash = hash
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|
||||||
// is the service should be disabled, we stop the listener or fake resolver
|
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
|
||||||
if !update.ServiceEnable {
|
|
||||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
|
||||||
s.fakeResolverWG.Done()
|
|
||||||
} else {
|
|
||||||
if err := s.stopListener(); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if !s.listenerIsRunning {
|
|
||||||
s.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
|
||||||
|
|
||||||
s.updateMux(muxUpdates)
|
|
||||||
s.updateLocalResolver(localRecords)
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
|
||||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
|
||||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
|
||||||
hostUpdate.routeAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]muxUpdate, map[string]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
|
||||||
}
|
|
||||||
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: customZone.Domain,
|
|
||||||
handler: s.localResolver,
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
|
||||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
|
||||||
}
|
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
|
||||||
localRecords[key] = record
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, localRecords, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]muxUpdate, error) {
|
|
||||||
|
|
||||||
var muxUpdates []muxUpdate
|
|
||||||
for _, nsGroup := range nameServerGroups {
|
|
||||||
if len(nsGroup.NameServers) == 0 {
|
|
||||||
log.Warn("received a nameserver group with empty nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := newUpstreamResolver(s.ctx)
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
|
||||||
log.Warnf("skiping nameserver %s with type %s, this peer supports only %s",
|
|
||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// when upstream fails to resolve domain several times over all it servers
|
|
||||||
// it will calls this hook to exclude self from the configuration and
|
|
||||||
// reapply DNS settings, but it not touch the original configuration and serial number
|
|
||||||
// because it is temporal deactivation until next try
|
|
||||||
//
|
|
||||||
// after some period defined by upstream it trys to reactivate self by calling this hook
|
|
||||||
// everything we need here is just to re-apply current configuration because it already
|
|
||||||
// contains this upstream settings (temporal deactivation not removed it)
|
|
||||||
handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler)
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: nbdns.RootZone,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nsGroup.Domains) == 0 {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a non primary nameserver group with an empty domain list")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
if domain == "" {
|
|
||||||
handler.stop()
|
|
||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
|
||||||
}
|
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
|
||||||
domain: domain,
|
|
||||||
handler: handler,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return muxUpdates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
|
||||||
s.registerMux(update.domain, update.handler)
|
|
||||||
muxUpdateMap[update.domain] = update.handler
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
|
||||||
existingHandler.stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
|
||||||
_, found := muxUpdateMap[key]
|
|
||||||
if !found {
|
|
||||||
existingHandler.stop()
|
|
||||||
s.deregisterMux(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for key, record := range update {
|
|
||||||
err := s.localResolver.registerRecord(record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
|
|
||||||
}
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
|
||||||
s.dnsMux.HandleRemove(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
|
||||||
func (s *DefaultServer) upstreamCallbacks(
|
|
||||||
nsGroup *nbdns.NameServerGroup,
|
|
||||||
handler dns.Handler,
|
|
||||||
) (deactivate func(), reactivate func()) {
|
|
||||||
var removeIndex map[string]int
|
|
||||||
deactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Info("temporary deactivate nameservers group due timeout")
|
|
||||||
|
|
||||||
removeIndex = make(map[string]int)
|
|
||||||
for _, domain := range nsGroup.Domains {
|
|
||||||
removeIndex[domain] = -1
|
|
||||||
}
|
|
||||||
if nsGroup.Primary {
|
|
||||||
removeIndex[nbdns.RootZone] = -1
|
|
||||||
s.currentConfig.routeAll = false
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, item := range s.currentConfig.domains {
|
|
||||||
if _, found := removeIndex[item.domain]; found {
|
|
||||||
s.currentConfig.domains[i].disabled = true
|
|
||||||
s.deregisterMux(item.domain)
|
|
||||||
removeIndex[item.domain] = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("fail to apply nameserver deactivation on the host")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
reactivate = func() {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
for domain, i := range removeIndex {
|
|
||||||
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.currentConfig.domains[i].disabled = false
|
|
||||||
s.registerMux(domain, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
|
||||||
l.Debug("reactivate temporary disabled nameserver group")
|
|
||||||
|
|
||||||
if nsGroup.Primary {
|
|
||||||
s.currentConfig.routeAll = true
|
|
||||||
}
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DefaultServer) filterDNSTraffic() string {
|
|
||||||
filter := s.wgInterface.GetFilter()
|
|
||||||
if filter == nil {
|
|
||||||
log.Error("can't set DNS filter, filter not initialized")
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
|
|
||||||
hook := func(packetData []byte) bool {
|
|
||||||
// Decode the packet
|
|
||||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
|
||||||
|
|
||||||
// Get the UDP layer
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
udp := udpLayer.(*layers.UDP)
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(udp.Payload); err != nil {
|
|
||||||
log.Tracef("parse DNS request: %v", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := responseWriter{
|
|
||||||
packet: packet,
|
|
||||||
device: s.wgInterface.GetDevice().Device,
|
|
||||||
}
|
|
||||||
go s.dnsMux.ServeDNS(&writer, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
|
||||||
// Calculate the last IP in the CIDR range
|
|
||||||
var endIP net.IP
|
|
||||||
for i := 0; i < len(network.IP); i++ {
|
|
||||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert to big.Int
|
|
||||||
endInt := big.NewInt(0)
|
|
||||||
endInt.SetBytes(endIP)
|
|
||||||
|
|
||||||
// subtract fromEnd from the last ip
|
|
||||||
fromEndBig := big.NewInt(int64(fromEnd))
|
|
||||||
resultInt := big.NewInt(0)
|
|
||||||
resultInt.Sub(endInt, fromEndBig)
|
|
||||||
|
|
||||||
return net.IP(resultInt.Bytes()).String()
|
|
||||||
}
|
|
@ -1,31 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
|
||||||
if lastIP != tt.ip {
|
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -221,7 +221,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -428,7 +428,7 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
|
|
||||||
return &DefaultServer{
|
ds := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: cancel,
|
ctxCancel: cancel,
|
||||||
server: dnsServer,
|
server: dnsServer,
|
||||||
@ -439,4 +439,31 @@ func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultSe
|
|||||||
},
|
},
|
||||||
customAddress: parsedAddrPort,
|
customAddress: parsedAddrPort,
|
||||||
}
|
}
|
||||||
|
ds.evalRuntimeAddress()
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addr string
|
||||||
|
ip string
|
||||||
|
}{
|
||||||
|
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||||
|
{"192.168.0.0/30", "192.168.0.2"},
|
||||||
|
{"192.168.0.0/16", "192.168.255.254"},
|
||||||
|
{"192.168.0.0/24", "192.168.0.254"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error parsing CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||||
|
if lastIP != tt.ip {
|
||||||
|
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -189,14 +189,37 @@ func (e *Engine) Start() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := e.readInitialRoutes()
|
var routes []*route.Route
|
||||||
if err != nil {
|
var dnsCfg *nbdns.Config
|
||||||
return err
|
|
||||||
|
if runtime.GOOS == "android" {
|
||||||
|
routes, dnsCfg, err = e.readInitialSettings()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.dnsServer == nil {
|
||||||
|
// todo fix custom address
|
||||||
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
|
||||||
|
if err != nil {
|
||||||
|
e.close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
e.dnsServer = dnsServer
|
||||||
|
}
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener)
|
||||||
|
|
||||||
err = e.wgInterface.Create()
|
if runtime.GOOS != "android" {
|
||||||
|
err = e.wgInterface.Create()
|
||||||
|
} else {
|
||||||
|
err = e.wgInterface.CreateOnMobile(iface.MobileIFaceArguments{
|
||||||
|
Routes: e.routeManager.InitialRouteRange(),
|
||||||
|
Dns: e.dnsServer.DnsIP(),
|
||||||
|
})
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
@ -236,16 +259,6 @@ func (e *Engine) Start() error {
|
|||||||
e.acl = acl
|
e.acl = acl
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.dnsServer == nil {
|
|
||||||
// todo fix custom address
|
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
e.dnsServer = dnsServer
|
|
||||||
}
|
|
||||||
|
|
||||||
e.receiveSignalEvents()
|
e.receiveSignalEvents()
|
||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
|
|
||||||
@ -1027,17 +1040,14 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) readInitialRoutes() ([]*route.Route, error) {
|
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||||
if runtime.GOOS != "android" {
|
netMap, err := e.mgmClient.GetNetworkMap()
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
routesResp, err := e.mgmClient.GetRoutes()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return toRoutes(routesResp), nil
|
routes := toRoutes(netMap.GetRoutes())
|
||||||
|
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||||
|
return routes, &dnsCfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
|
@ -17,6 +17,7 @@ import (
|
|||||||
type Manager interface {
|
type Manager interface {
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
SetRouteChangeListener(listener RouteListener)
|
SetRouteChangeListener(listener RouteListener)
|
||||||
|
InitialRouteRange() []string
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,10 +52,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.clientRoutes(initialRoutes)
|
cr := dm.clientRoutes(initialRoutes)
|
||||||
dm.notifier.setInitialClientRoutes(cr)
|
dm.notifier.setInitialClientRoutes(cr)
|
||||||
networks := readRouteNetworks(cr)
|
|
||||||
|
|
||||||
// make sense to call before create interface
|
|
||||||
wgInterface.SetInitialRoutes(networks)
|
|
||||||
}
|
}
|
||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
@ -94,6 +91,11 @@ func (m *DefaultManager) SetRouteChangeListener(listener RouteListener) {
|
|||||||
m.notifier.setListener(listener)
|
m.notifier.setListener(listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
||||||
|
func (m *DefaultManager) InitialRouteRange() []string {
|
||||||
|
return m.notifier.initialRouteRanges()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
for id, client := range m.clientNetworks {
|
for id, client := range m.clientNetworks {
|
||||||
@ -163,11 +165,3 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
|||||||
}
|
}
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
func readRouteNetworks(cr []*route.Route) []string {
|
|
||||||
routesNetworks := make([]string, 0)
|
|
||||||
for _, r := range cr {
|
|
||||||
routesNetworks = append(routesNetworks, r.Network.String())
|
|
||||||
}
|
|
||||||
return routesNetworks
|
|
||||||
}
|
|
||||||
|
@ -14,8 +14,8 @@ type MockManager struct {
|
|||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitialClientRoutesNetworks mock implementation of InitialClientRoutesNetworks from Manager interface
|
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||||
func (m *MockManager) InitialClientRoutesNetworks() []string {
|
func (m *MockManager) InitialRouteRange() []string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,3 +84,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *notifier) initialRouteRanges() []string {
|
||||||
|
return n.initialRouteRangers
|
||||||
|
}
|
||||||
|
@ -36,15 +36,6 @@ func (w *WGIface) GetBind() *bind.ICEBind {
|
|||||||
return w.tun.iceBind
|
return w.tun.iceBind
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
|
||||||
// Will reuse an existing one.
|
|
||||||
func (w *WGIface) Create() error {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
|
||||||
return w.tun.Create()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the interface name
|
// Name returns the interface name
|
||||||
func (w *WGIface) Name() string {
|
func (w *WGIface) Name() string {
|
||||||
return w.tun.DeviceName()
|
return w.tun.DeviceName()
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
@ -27,7 +29,16 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
|
|||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInitialRoutes store the given routes and on the tun creation will be used
|
// CreateOnMobile creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
func (w *WGIface) SetInitialRoutes(routes []string) {
|
// Will reuse an existing one.
|
||||||
w.tun.SetRoutes(routes)
|
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||||
|
return w.tun.Create(mIFaceArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create this function make sense on mobile only
|
||||||
|
func (w *WGIface) Create() error {
|
||||||
|
return fmt.Errorf("this function has not implemented on mobile")
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,11 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v2"
|
"github.com/pion/transport/v2"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
@ -26,7 +28,16 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter
|
|||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInitialRoutes unused function on non Android
|
// CreateOnMobile this function make sense on mobile only
|
||||||
func (w *WGIface) SetInitialRoutes(routes []string) {
|
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||||
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
|
// Will reuse an existing one.
|
||||||
|
func (w *WGIface) Create() error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||||
|
return w.tun.Create()
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
|
type MobileIFaceArguments struct {
|
||||||
|
Routes []string
|
||||||
|
Dns string
|
||||||
|
}
|
||||||
|
|
||||||
// NetInterface represents a generic network tunnel interface
|
// NetInterface represents a generic network tunnel interface
|
||||||
type NetInterface interface {
|
type NetInterface interface {
|
||||||
Close() error
|
Close() error
|
||||||
|
@ -2,6 +2,6 @@ package iface
|
|||||||
|
|
||||||
// TunAdapter is an interface for create tun device from externel service
|
// TunAdapter is an interface for create tun device from externel service
|
||||||
type TunAdapter interface {
|
type TunAdapter interface {
|
||||||
ConfigureInterface(address string, mtu int, routes string) (int, error)
|
ConfigureInterface(address string, mtu int, dns string, routes string) (int, error)
|
||||||
UpdateAddr(address string) error
|
UpdateAddr(address string) error
|
||||||
}
|
}
|
||||||
|
@ -15,13 +15,12 @@ import (
|
|||||||
type tunDevice struct {
|
type tunDevice struct {
|
||||||
address WGAddress
|
address WGAddress
|
||||||
mtu int
|
mtu int
|
||||||
routes []string
|
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
|
iceBind *bind.ICEBind
|
||||||
|
|
||||||
fd int
|
fd int
|
||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
iceBind *bind.ICEBind
|
|
||||||
wrapper *DeviceWrapper
|
wrapper *DeviceWrapper
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,14 +33,10 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tunDevice) SetRoutes(routes []string) {
|
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
||||||
t.routes = routes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tunDevice) Create() error {
|
|
||||||
var err error
|
var err error
|
||||||
routesString := t.routesToString()
|
routesString := t.routesToString(mIFaceArgs.Routes)
|
||||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
|
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
return err
|
return err
|
||||||
@ -95,6 +90,6 @@ func (t *tunDevice) Close() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *tunDevice) routesToString() string {
|
func (t *tunDevice) routesToString(routes []string) string {
|
||||||
return strings.Join(t.routes, ";")
|
return strings.Join(routes, ";")
|
||||||
}
|
}
|
||||||
|
@ -15,5 +15,5 @@ type Client interface {
|
|||||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||||
GetRoutes() ([]*proto.Route, error)
|
GetNetworkMap() (*proto.NetworkMap, error)
|
||||||
}
|
}
|
||||||
|
@ -172,8 +172,8 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoutes return with the routes
|
// GetNetworkMap return with the network map
|
||||||
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
|
func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.GetServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed getting Management Service public key: %s", err)
|
log.Debugf("failed getting Management Service public key: %s", err)
|
||||||
@ -212,7 +212,7 @@ func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
|
|||||||
return nil, fmt.Errorf("invalid msg, required network map")
|
return nil, fmt.Errorf("invalid msg, required network map")
|
||||||
}
|
}
|
||||||
|
|
||||||
return decryptedResp.GetNetworkMap().GetRoutes(), nil
|
return decryptedResp.GetNetworkMap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||||
|
@ -57,7 +57,7 @@ func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
|||||||
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoutes mock implementation of GetRoutes from mgm.Client interface
|
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
|
||||||
func (m *MockClient) GetRoutes() ([]*proto.Route, error) {
|
func (m *MockClient) GetNetworkMap() (*proto.NetworkMap, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user