mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 19:43:57 +01:00
eb45310c8f
* Disable upstream DNS resolver after several tries and fails * Add tests for upstream fails * Use an extra flag to disable domains in DNS upstreams * Fix hashing IPs of nameservers for updates.
144 lines
3.3 KiB
Go
144 lines
3.3 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Defaults copied into every upstreamResolver by newUpstreamResolver.
const (
	// failsTillDeact is the failure count at which upstream resolving is
	// deactivated.
	failsTillDeact int32 = 3

	// reactivatePeriod is how long upstream resolving stays deactivated
	// before it is switched back on.
	reactivatePeriod time.Duration = time.Minute

	// upstreamTimeout bounds a single exchange with one upstream server.
	upstreamTimeout time.Duration = 15 * time.Second
)
|
|
|
|
type upstreamResolver struct {
|
|
ctx context.Context
|
|
upstreamClient *dns.Client
|
|
upstreamServers []string
|
|
disabled bool
|
|
failsCount atomic.Int32
|
|
failsTillDeact int32
|
|
mutex sync.Mutex
|
|
reactivatePeriod time.Duration
|
|
upstreamTimeout time.Duration
|
|
|
|
deactivate func()
|
|
reactivate func()
|
|
}
|
|
|
|
func newUpstreamResolver(ctx context.Context) *upstreamResolver {
|
|
return &upstreamResolver{
|
|
ctx: ctx,
|
|
upstreamClient: &dns.Client{},
|
|
upstreamTimeout: upstreamTimeout,
|
|
reactivatePeriod: reactivatePeriod,
|
|
failsTillDeact: failsTillDeact,
|
|
}
|
|
}
|
|
|
|
// ServeDNS handles a DNS request
|
|
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
defer u.checkUpstreamFails()
|
|
|
|
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
|
|
|
select {
|
|
case <-u.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
for _, upstream := range u.upstreamServers {
|
|
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
|
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
|
|
|
|
cancel()
|
|
|
|
if err != nil {
|
|
if err == context.DeadlineExceeded || isTimeout(err) {
|
|
log.WithError(err).WithField("upstream", upstream).
|
|
Warn("got an error while connecting to upstream")
|
|
continue
|
|
}
|
|
u.failsCount.Add(1)
|
|
log.WithError(err).WithField("upstream", upstream).
|
|
Error("got an error while querying the upstream")
|
|
return
|
|
}
|
|
|
|
log.Tracef("took %s to query the upstream %s", t, upstream)
|
|
|
|
err = w.WriteMsg(rm)
|
|
if err != nil {
|
|
log.WithError(err).Error("got an error while writing the upstream resolver response")
|
|
}
|
|
// count the fails only if they happen sequentially
|
|
u.failsCount.Store(0)
|
|
return
|
|
}
|
|
u.failsCount.Add(1)
|
|
log.Error("all queries to the upstream nameservers failed with timeout")
|
|
}
|
|
|
|
// checkUpstreamFails counts fails and disables or enables upstream resolving
|
|
//
|
|
// If fails count is greater that failsTillDeact, upstream resolving
|
|
// will be disabled for reactivatePeriod, after that time period fails counter
|
|
// will be reset and upstream will be reactivated.
|
|
func (u *upstreamResolver) checkUpstreamFails() {
|
|
u.mutex.Lock()
|
|
defer u.mutex.Unlock()
|
|
|
|
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-u.ctx.Done():
|
|
return
|
|
default:
|
|
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
|
|
u.deactivate()
|
|
u.disabled = true
|
|
go u.waitUntilReactivation()
|
|
}
|
|
}
|
|
|
|
// waitUntilReactivation reset fails counter and activates upstream resolving
|
|
func (u *upstreamResolver) waitUntilReactivation() {
|
|
timer := time.NewTimer(u.reactivatePeriod)
|
|
defer func() {
|
|
if !timer.Stop() {
|
|
<-timer.C
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-u.ctx.Done():
|
|
return
|
|
case <-timer.C:
|
|
log.Info("upstream resolving is reactivated")
|
|
u.failsCount.Store(0)
|
|
u.reactivate()
|
|
u.disabled = false
|
|
}
|
|
}
|
|
|
|
// isTimeout reports whether err, or any error it wraps, is a net.Error
// whose Timeout method returns true.
//
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
func isTimeout(err error) bool {
	var netErr net.Error
	if !errors.As(err, &netErr) {
		return false
	}
	return netErr != nil && netErr.Timeout()
}
|