package dns

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/libp2p/go-netroute"
	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
	"golang.org/x/sys/unix"
)

// Defaults applied to every resolver created by newUpstreamResolver.
const (
	// failsTillDeact is the number of sequential failures after which
	// upstream resolving gets deactivated.
	failsTillDeact = int32(5)
	// reactivatePeriod caps the retry interval used while probing
	// deactivated upstreams.
	reactivatePeriod = 30 * time.Second
	// upstreamTimeout bounds a single exchange with one upstream server.
	upstreamTimeout = 15 * time.Second
)

// upstreamClient is the subset of *dns.Client used by upstreamResolver.
type upstreamClient interface {
	ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error)
}

// upstreamResolver forwards DNS questions to a list of upstream servers and
// deactivates itself (through the deactivate callback) after too many
// sequential failures.
type upstreamResolver struct {
	ctx            context.Context
	cancel         context.CancelFunc
	upstreamClient upstreamClient
	// upstreamServers are queried in order; each entry is an "ip:port" string.
	upstreamServers []string
	// disabled is guarded by mutex.
	disabled         bool
	failsCount       atomic.Int32
	failsTillDeact   int32
	mutex            sync.Mutex
	reactivatePeriod time.Duration
	upstreamTimeout  time.Duration
	// deactivate and reactivate are invoked when the upstreams are taken out
	// of, respectively put back into, service. They are not set by
	// newUpstreamResolver.
	deactivate func()
	reactivate func()
}

// getInterfaceIndex returns the index of the network interface with the given
// name.
func getInterfaceIndex(interfaceName string) (int, error) {
	iface, err := net.InterfaceByName(interfaceName)
	if err != nil {
		return 0, err
	}
	return iface.Index, nil
}

// newUpstreamResolver creates a resolver whose DNS client sends queries from
// the IP found in wgAddr (CIDR notation) and, when the interface index can be
// resolved, from sockets bound to interfaceName via IP_BOUND_IF.
//
// Parsing or lookup failures are logged and the resolver is still returned;
// in that case the corresponding binding is simply not applied.
func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr string) *upstreamResolver {
	ctx, cancel := context.WithCancel(parentCTX)

	// localIP should be our interface IP; it stays nil if wgAddr is invalid.
	localIP, _, err := net.ParseCIDR(wgAddr)
	if err != nil {
		log.Errorf("error while parsing CIDR: %s", err)
	}

	index, indexErr := getInterfaceIndex(interfaceName)
	if indexErr != nil {
		log.Debugf("unable to get interface index for %s: %s", interfaceName, indexErr)
	}

	// Use a private generator instead of reseeding the process-wide one.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	port := rng.Intn(4001) + 1000 // random source port in [1000, 5000]

	log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s, port: %d", interfaceName, index, localIP, port)

	// Custom dialer with the local address set to the desired IP and port.
	// NOTE(review): the port is chosen once and reused for every connection
	// made through this dialer — confirm this is intended rather than
	// letting the OS pick a free port (Port: 0).
	dialer := &net.Dialer{
		LocalAddr: &net.UDPAddr{
			IP:   localIP,
			Port: port,
		},
	}

	// Only bind the socket to the interface when its index is known.
	if indexErr == nil {
		dialer.Control = func(network, address string, c syscall.RawConn) error {
			var operr error
			fn := func(s uintptr) {
				operr = syscall.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
			}
			if err := c.Control(fn); err != nil {
				return err
			}
			return operr
		}
	}

	// DNS client that dials through the custom dialer.
	client := &dns.Client{
		Dialer: dialer,
	}

	return &upstreamResolver{
		ctx:              ctx,
		cancel:           cancel,
		upstreamClient:   client,
		upstreamTimeout:  upstreamTimeout,
		reactivatePeriod: reactivatePeriod,
		failsTillDeact:   failsTillDeact,
	}
}

// stop cancels the resolver context, which ends in-flight exchanges and the
// reactivation retry loop.
func (u *upstreamResolver) stop() {
	log.Debugf("stoping serving DNS for upstreams %s", u.upstreamServers)
	u.cancel()
}

// logUpstreamRoute logs the gateway and preferred source address the routing
// table yields for the given "ip:port" upstream. It is purely diagnostic:
// every failure is logged and otherwise ignored.
func logUpstreamRoute(upstream string) {
	router, err := netroute.New()
	if err != nil {
		log.Errorf("unable to create networute: %s", err)
		return
	}

	addrPort, err := netip.ParseAddrPort(upstream)
	if err != nil {
		log.Errorf("unable to parse upstream address %s: %s", upstream, err)
		return
	}

	_, gateway, preferredSrc, err := router.Route(addrPort.Addr().AsSlice())
	if err != nil {
		log.Errorf("getting routes returned an error: %v", err)
		return
	}
	log.Infof("upstream %s gateway: %s, preferredSrc: %s", addrPort.Addr(), gateway, preferredSrc)
}

// ServeDNS handles a DNS request by forwarding it to the upstream servers in
// order. The first successful answer is written to w. Timeouts move on to the
// next upstream; any other error increments the failure counter and aborts
// the request.
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	defer u.checkUpstreamFails()

	// A message may legally carry no question; do not index blindly.
	if len(r.Question) > 0 {
		log.WithField("question", r.Question[0]).Debug("received an upstream question")
	} else {
		log.Debug("received an upstream request without a question")
	}

	select {
	case <-u.ctx.Done():
		return
	default:
	}

	for _, upstream := range u.upstreamServers {
		log.Debugf("querying the upstream %s", upstream)

		logUpstreamRoute(upstream)

		ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
		rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
		cancel()

		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
				log.WithError(err).WithField("upstream", upstream).
					Warn("got an error while connecting to upstream")
				continue
			}
			u.failsCount.Add(1)
			log.WithError(err).WithField("upstream", upstream).
				Error("got an error while querying the upstream")
			return
		}

		log.Debugf("took %s to query the upstream %s", t, upstream)

		if err := w.WriteMsg(rm); err != nil {
			log.WithError(err).Error("got an error while writing the upstream resolver response")
		}
		// count the fails only if they happen sequentially
		u.failsCount.Store(0)
		return
	}

	u.failsCount.Add(1)
	log.Error("all queries to the upstream nameservers failed with timeout")
}

// checkUpstreamFails counts fails and disables or enables upstream resolving
//
// If the fails count reaches failsTillDeact, upstream resolving is disabled
// and a background probe is started; once an upstream answers again the fails
// counter is reset and upstream resolving is reactivated.
func (u *upstreamResolver) checkUpstreamFails() {
	u.mutex.Lock()
	defer u.mutex.Unlock()

	if u.failsCount.Load() < u.failsTillDeact || u.disabled {
		return
	}

	select {
	case <-u.ctx.Done():
		return
	default:
		log.Warnf("upstream resolving is disabled for %v", u.reactivatePeriod)
		u.deactivate()
		u.disabled = true
		go u.waitUntilResponse()
	}
}

// waitUntilResponse retries, in an exponential interval, querying the upstream
// servers until it gets a positive response. It returns without reactivating
// when the resolver context is canceled.
func (u *upstreamResolver) waitUntilResponse() {
	exponentialBackOff := &backoff.ExponentialBackOff{
		InitialInterval:     500 * time.Millisecond,
		RandomizationFactor: 0.5,
		Multiplier:          1.1,
		MaxInterval:         u.reactivatePeriod,
		MaxElapsedTime:      0,
		Stop:                backoff.Stop,
		Clock:               backoff.SystemClock,
	}

	probe := new(dns.Msg).SetQuestion("netbird.io.", dns.TypeA)

	// lastErr keeps the most recent exchange error for the retry log line.
	var lastErr error

	operation := func() error {
		select {
		case <-u.ctx.Done():
			return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
		default:
		}

		for _, upstream := range u.upstreamServers {
			ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
			_, _, err := u.upstreamClient.ExchangeContext(ctx, probe, upstream)
			cancel()

			if err == nil {
				return nil
			}
			lastErr = err
		}

		return errors.New("got an error from upstream check call")
	}

	// The retry delay is reported by the backoff package itself. Calling
	// NextBackOff() by hand just to log it would advance the backoff state a
	// second time on every failed attempt.
	notify := func(_ error, next time.Duration) {
		log.Tracef("checking connectivity with upstreams %s failed with error: %s. Retrying in %s", u.upstreamServers, lastErr, next)
	}

	if err := backoff.RetryNotify(operation, exponentialBackOff, notify); err != nil {
		log.Warn(err)
		return
	}

	log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
	u.failsCount.Store(0)
	u.reactivate()

	// disabled is read and written by checkUpstreamFails under the mutex.
	u.mutex.Lock()
	u.disabled = false
	u.mutex.Unlock()
}

// isTimeout returns true if the given error is a network timeout error.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
func isTimeout(err error) bool {
	var neterr net.Error
	if errors.As(err, &neterr) {
		return neterr != nil && neterr.Timeout()
	}
	return false
}