netbird/client/internal/dns/upstream_test.go
Maycon Santos b6105e9d7c
Use backoff.retry to check if upstreams are responsive (#901)
Retry, in an exponential interval, querying the upstream servers until it gets a positive response
2023-05-26 17:13:59 +02:00

179 lines
4.5 KiB
Go

package dns
import (
"context"
"strings"
"testing"
"time"
"github.com/miekg/dns"
)
func TestUpstreamResolver_ServeDNS(t *testing.T) {
testCases := []struct {
name string
inputMSG *dns.Msg
responseShouldBeNil bool
InputServers []string
timeout time.Duration
cancelCTX bool
expectedAnswer string
}{
{
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: upstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Resolve If First Upstream Times Out",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
timeout: 2 * time.Second,
expectedAnswer: "1.1.1.1",
},
{
name: "Should Not Resolve If Can't Connect To Both Servers",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.0.0.1:53"},
timeout: 200 * time.Millisecond,
responseShouldBeNil: true,
},
{
name: "Should Not Resolve If Parent Context Is Canceled",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: upstreamTimeout,
responseShouldBeNil: true,
},
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
}
// should resolve if first upstream times out
// should not write when both fails
// should not resolve if parent context is canceled
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver := newUpstreamResolver(ctx)
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
} else {
defer cancel()
}
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
foundAnswer := false
for _, answer := range responseMSG.Answer {
if strings.Contains(answer.String(), testCase.expectedAnswer) {
foundAnswer = true
break
}
}
if !foundAnswer {
t.Errorf("couldn't find the required answer, %s, in the dns response", testCase.expectedAnswer)
}
})
}
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// ExchangeContext mock implementation of ExchangeContext from upstreamResolver
func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) {
return c.r, c.rtt, c.err
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver := &upstreamResolver{
ctx: context.TODO(),
upstreamClient: &mockUpstreamResolver{
err: nil,
r: new(dns.Msg),
rtt: time.Millisecond,
},
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
resolver.upstreamServers = []string{"0.0.0.0:-1"}
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}
failed := false
resolver.deactivate = func() {
failed = true
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.failsCount.Load() != 0 {
t.Errorf("fails count after reactivation should be 0")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
}
}