mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-26 04:31:56 +02:00
[client] Tighten allowed domains for dns forwarder (#3978)
This commit is contained in:
parent
75c1be69cf
commit
de7384e8ea
@ -2,6 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
requestID := GenerateRequestID()
|
||||||
|
logger := log.WithField("request_id", requestID)
|
||||||
var err error
|
var err error
|
||||||
defer func() {
|
defer func() {
|
||||||
u.checkUpstreamFails(err)
|
u.checkUpstreamFails(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
log.Tracef("%s has been stopped", u)
|
logger.Tracef("%s has been stopped", u)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
|
||||||
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rm == nil || !rm.Response {
|
if rm == nil || !rm.Response {
|
||||||
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
|
||||||
|
|
||||||
if err = w.WriteMsg(rm); err != nil {
|
if err = w.WriteMsg(rm); err != nil {
|
||||||
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
// count the fails only if they happen sequentially
|
// count the fails only if they happen sequentially
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
u.failsCount.Add(1)
|
u.failsCount.Add(1)
|
||||||
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
|
||||||
|
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(r, dns.RcodeServerFailure)
|
m.SetRcode(r, dns.RcodeServerFailure)
|
||||||
if err := w.WriteMsg(m); err != nil {
|
if err := w.WriteMsg(m); err != nil {
|
||||||
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
|
|
||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateRequestID() string {
|
||||||
|
bytes := make([]byte, 4)
|
||||||
|
_, err := rand.Read(bytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to generate request ID: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
||||||
|
@ -18,14 +18,20 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
const upstreamTimeout = 15 * time.Second
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type firewaller interface {
|
||||||
|
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
|
||||||
|
}
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress string
|
||||||
ttl uint32
|
ttl uint32
|
||||||
@ -38,16 +44,18 @@ type DNSForwarder struct {
|
|||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
fwdEntries []*ForwarderEntry
|
fwdEntries []*ForwarderEntry
|
||||||
firewall firewall.Manager
|
firewall firewaller
|
||||||
|
resolver resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
resolver: net.DefaultResolver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
|||||||
// UDP server
|
// UDP server
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
f.mux = mux
|
f.mux = mux
|
||||||
|
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||||
f.dnsServer = &dns.Server{
|
f.dnsServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Addr: f.listenAddress,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCP server
|
// TCP server
|
||||||
tcpMux := dns.NewServeMux()
|
tcpMux := dns.NewServeMux()
|
||||||
f.tcpMux = tcpMux
|
f.tcpMux = tcpMux
|
||||||
|
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||||
f.tcpServer = &dns.Server{
|
f.tcpServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Addr: f.listenAddress,
|
||||||
Net: "tcp",
|
Net: "tcp",
|
||||||
@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
|||||||
// return the first error we get (e.g. bind failure or shutdown)
|
// return the first error we get (e.g. bind failure or shutdown)
|
||||||
return <-errCh
|
return <-errCh
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
|
|
||||||
if f.mux == nil {
|
|
||||||
log.Debug("DNS mux is nil, skipping domain update")
|
|
||||||
f.fwdEntries = entries
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
oldDomains := filterDomains(f.fwdEntries)
|
|
||||||
for _, d := range oldDomains {
|
|
||||||
f.mux.HandleRemove(d.PunycodeString())
|
|
||||||
f.tcpMux.HandleRemove(d.PunycodeString())
|
|
||||||
}
|
|
||||||
|
|
||||||
newDomains := filterDomains(entries)
|
|
||||||
for _, d := range newDomains {
|
|
||||||
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
|
|
||||||
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.fwdEntries = entries
|
f.fwdEntries = entries
|
||||||
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||||
|
// query doesn't match any configured domain
|
||||||
|
if mostSpecificResId == "" {
|
||||||
|
resp.Rcode = dns.RcodeRefused
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.handleDNSError(w, query, resp, domain, err)
|
f.handleDNSError(w, query, resp, domain, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(domain, ips)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
|
||||||
resp := f.handleDNSQuery(w, query)
|
resp := f.handleDNSQuery(w, query)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return
|
return
|
||||||
@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
var prefixes []netip.Prefix
|
var prefixes []netip.Prefix
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
|
||||||
if mostSpecificResId != "" {
|
if mostSpecificResId != "" {
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var prefix netip.Prefix
|
var prefix netip.Prefix
|
||||||
@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
|
|||||||
|
|
||||||
return selectedResId, matches
|
return selectedResId, matches
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterDomains returns a list of normalized domains
|
|
||||||
func filterDomains(entries []*ForwarderEntry) domain.List {
|
|
||||||
newDomains := make(domain.List, 0, len(entries))
|
|
||||||
for _, d := range entries {
|
|
||||||
if d.Domain == "" {
|
|
||||||
log.Warn("empty domain in DNS forwarder")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
|
|
||||||
}
|
|
||||||
return newDomains
|
|
||||||
}
|
|
||||||
|
@ -1,11 +1,21 @@
|
|||||||
package dnsfwd
|
package dnsfwd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -13,7 +23,7 @@ import (
|
|||||||
func Test_getMatchingEntries(t *testing.T) {
|
func Test_getMatchingEntries(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
storedMappings map[string]route.ResID // key: domain pattern, value: resId
|
storedMappings map[string]route.ResID
|
||||||
queryDomain string
|
queryDomain string
|
||||||
expectedResId route.ResID
|
expectedResId route.ResID
|
||||||
}{
|
}{
|
||||||
@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Wildcard pattern does not match different domain",
|
name: "Wildcard pattern does not match different domain",
|
||||||
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
||||||
queryDomain: "foo.notexample.com",
|
queryDomain: "foo.example.org",
|
||||||
expectedResId: "",
|
expectedResId: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MockFirewall struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
args := m.Called(set, prefixes)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockResolver struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
args := m.Called(ctx, network, host)
|
||||||
|
return args.Get(0).([]netip.Addr), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomain string
|
||||||
|
queryDomain string
|
||||||
|
shouldMatch bool
|
||||||
|
expectedResID route.ResID
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain match should be allowed",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Direct match to configured domain should work",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain access should be restricted",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldMatch: false,
|
||||||
|
expectedResID: "",
|
||||||
|
description: "Subdomain should not be accessible unless explicitly configured",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard should allow subdomains",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard domains should allow subdomain access",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard should allow base domain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard should also match the base domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain should be restricted",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldMatch: false,
|
||||||
|
expectedResID: "",
|
||||||
|
description: "Deep subdomains should not be accessible",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows deep subdomains",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldMatch: true,
|
||||||
|
expectedResID: "test-res-id",
|
||||||
|
description: "Wildcard should allow deep subdomains",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
forwarder := &DNSForwarder{}
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{
|
||||||
|
{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res-id",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
|
||||||
|
|
||||||
|
if tt.shouldMatch {
|
||||||
|
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
|
||||||
|
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
|
||||||
|
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
|
||||||
|
assert.Empty(t, matchingEntries, "Expected no matching entries")
|
||||||
|
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomain string
|
||||||
|
queryDomain string
|
||||||
|
shouldResolve bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "configured exact domain resolves",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Exact match should resolve",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized subdomain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Subdomain should be blocked without wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows subdomain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "mail.example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow subdomain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows base domain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow base domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unrelated domain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "example.org",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Unrelated domain should be blocked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain blocked",
|
||||||
|
configuredDomain: "example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldResolve: false,
|
||||||
|
description: "Deep subdomain should be blocked",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard allows deep subdomain",
|
||||||
|
configuredDomain: "*.example.com",
|
||||||
|
queryDomain: "deep.mail.example.com",
|
||||||
|
shouldResolve: true,
|
||||||
|
description: "Wildcard should allow deep subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
if tt.shouldResolve {
|
||||||
|
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
|
||||||
|
|
||||||
|
// Mock successful DNS resolution
|
||||||
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configuredDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{
|
||||||
|
{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res-id",
|
||||||
|
Set: firewall.NewDomainSet([]domain.Domain{d}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
if tt.shouldResolve {
|
||||||
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
|
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
} else {
|
||||||
|
if resp != nil {
|
||||||
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
|
"Unauthorized domain should not return successful answers")
|
||||||
|
}
|
||||||
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configuredDomains []string
|
||||||
|
query string
|
||||||
|
mockIP string
|
||||||
|
shouldResolve bool
|
||||||
|
expectedSetCount int // How many sets should be updated
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain gets firewall update",
|
||||||
|
configuredDomains: []string{"example.com"},
|
||||||
|
query: "example.com",
|
||||||
|
mockIP: "1.1.1.1",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 1,
|
||||||
|
description: "Single exact match updates one set",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard domain gets firewall update",
|
||||||
|
configuredDomains: []string{"*.example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.2",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 1,
|
||||||
|
description: "Wildcard match updates one set",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overlapping exact and wildcard both get updates",
|
||||||
|
configuredDomains: []string{"*.example.com", "mail.example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.3",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 2,
|
||||||
|
description: "Both exact and wildcard sets should be updated",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized domain gets no firewall update",
|
||||||
|
configuredDomains: []string{"example.com"},
|
||||||
|
query: "mail.example.com",
|
||||||
|
mockIP: "1.1.1.4",
|
||||||
|
shouldResolve: false,
|
||||||
|
expectedSetCount: 0,
|
||||||
|
description: "No firewall update for unauthorized domains",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple wildcards matching get all updated",
|
||||||
|
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
|
||||||
|
query: "test.sub.example.com",
|
||||||
|
mockIP: "1.1.1.5",
|
||||||
|
shouldResolve: true,
|
||||||
|
expectedSetCount: 2,
|
||||||
|
description: "All matching wildcard sets should be updated",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
// Set up forwarder
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Create entries and track sets
|
||||||
|
var entries []*ForwarderEntry
|
||||||
|
sets := make([]firewall.Set, 0)
|
||||||
|
|
||||||
|
for i, configDomain := range tt.configuredDomains {
|
||||||
|
d, err := domain.FromString(configDomain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
sets = append(sets, set)
|
||||||
|
|
||||||
|
entries = append(entries, &ForwarderEntry{
|
||||||
|
Domain: d,
|
||||||
|
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
|
||||||
|
Set: set,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Set up mocks
|
||||||
|
if tt.shouldResolve {
|
||||||
|
fakeIP := netip.MustParseAddr(tt.mockIP)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
|
||||||
|
Return([]netip.Addr{fakeIP}, nil).Once()
|
||||||
|
|
||||||
|
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
|
||||||
|
|
||||||
|
// Count how many sets should actually match
|
||||||
|
updateCount := 0
|
||||||
|
for i, entry := range entries {
|
||||||
|
domain := strings.ToLower(tt.query)
|
||||||
|
pattern := entry.Domain.PunycodeString()
|
||||||
|
|
||||||
|
matches := false
|
||||||
|
if strings.HasPrefix(pattern, "*.") {
|
||||||
|
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||||
|
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
} else if domain == pattern {
|
||||||
|
matches = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if matches {
|
||||||
|
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
|
||||||
|
updateCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedSetCount, updateCount,
|
||||||
|
"Expected %d sets to be updated, but mock expects %d",
|
||||||
|
tt.expectedSetCount, updateCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
dnsQuery := &dns.Msg{}
|
||||||
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if tt.shouldResolve {
|
||||||
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.NotEmpty(t, resp.Answer)
|
||||||
|
} else if resp != nil {
|
||||||
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
|
"Unauthorized domain should be refused or have no answers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all mock expectations were met
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
|
||||||
|
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Configure a single domain
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
entries := []*ForwarderEntry{{
|
||||||
|
Domain: d,
|
||||||
|
ResID: "test-res",
|
||||||
|
Set: set,
|
||||||
|
}}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Mock resolver returns multiple IPs
|
||||||
|
ips := []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
netip.MustParseAddr("1.1.1.2"),
|
||||||
|
netip.MustParseAddr("1.1.1.3"),
|
||||||
|
}
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||||
|
Return(ips, nil).Once()
|
||||||
|
|
||||||
|
// Expect ONE UpdateSet call with ALL prefixes
|
||||||
|
expectedPrefixes := []netip.Prefix{
|
||||||
|
netip.PrefixFrom(ips[0], 32),
|
||||||
|
netip.PrefixFrom(ips[1], 32),
|
||||||
|
netip.PrefixFrom(ips[2], 32),
|
||||||
|
}
|
||||||
|
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
|
||||||
|
|
||||||
|
// Execute query
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// Verify response contains all IPs
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
|
|
||||||
|
// Verify mocks
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryType uint16
|
||||||
|
queryDomain string
|
||||||
|
configured string
|
||||||
|
expectedCode int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unauthorized domain returns REFUSED",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
queryDomain: "evil.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeRefused,
|
||||||
|
description: "RFC compliant REFUSED for unauthorized queries",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported query type returns NOTIMP",
|
||||||
|
queryType: dns.TypeMX,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "RFC compliant NOTIMP for unsupported types",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CNAME query returns NOTIMP",
|
||||||
|
queryType: dns.TypeCNAME,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "CNAME queries not supported",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TXT query returns NOTIMP",
|
||||||
|
queryType: dns.TypeTXT,
|
||||||
|
queryDomain: "example.com",
|
||||||
|
configured: "example.com",
|
||||||
|
expectedCode: dns.RcodeNotImplemented,
|
||||||
|
description: "TXT queries not supported",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|
||||||
|
d, err := domain.FromString(tt.configured)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
|
||||||
|
|
||||||
|
// Capture the written response
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
// Check the response written to the writer
|
||||||
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
|
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||||
|
// Test that large UDP responses are truncated with TC bit set
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, _ := domain.FromString("example.com")
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Mock many IPs to create a large response
|
||||||
|
var manyIPs []netip.Addr
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
|
||||||
|
}
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
|
||||||
|
|
||||||
|
// Query without EDNS0
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writtenResp = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
forwarder.handleDNSQueryUDP(mockWriter, query)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp)
|
||||||
|
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
|
||||||
|
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||||
|
// Test complex overlapping pattern scenarios
|
||||||
|
mockFirewall := &MockFirewall{}
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
// Set up complex overlapping patterns
|
||||||
|
patterns := []string{
|
||||||
|
"*.example.com", // Matches all subdomains
|
||||||
|
"*.mail.example.com", // More specific wildcard
|
||||||
|
"smtp.mail.example.com", // Exact match
|
||||||
|
"example.com", // Base domain
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []*ForwarderEntry
|
||||||
|
sets := make(map[string]firewall.Set)
|
||||||
|
|
||||||
|
for _, pattern := range patterns {
|
||||||
|
d, _ := domain.FromString(pattern)
|
||||||
|
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||||
|
sets[pattern] = set
|
||||||
|
entries = append(entries, &ForwarderEntry{
|
||||||
|
Domain: d,
|
||||||
|
ResID: route.ResID("res-" + pattern),
|
||||||
|
Set: set,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
// Test smtp.mail.example.com - should match 3 patterns
|
||||||
|
fakeIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
|
||||||
|
|
||||||
|
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
|
||||||
|
// All three matching patterns should get firewall updates
|
||||||
|
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
mockWriter := &test.MockResponseWriter{}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
|
// Verify all three sets were updated
|
||||||
|
mockFirewall.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Verify the most specific ResID was selected
|
||||||
|
// (exact match should win over wildcards)
|
||||||
|
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
|
||||||
|
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
|
||||||
|
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||||
|
// Test handling of malformed query with no questions
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
|
||||||
|
query := &dns.Msg{}
|
||||||
|
// Don't set any question
|
||||||
|
|
||||||
|
writeCalled := false
|
||||||
|
mockWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
writeCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||||
|
|
||||||
|
assert.Nil(t, resp, "Should return nil for empty query")
|
||||||
|
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||||
|
}
|
||||||
|
@ -144,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
|||||||
|
|
||||||
// ServeDNS implements the dns.Handler interface
|
// ServeDNS implements the dns.Handler interface
|
||||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
requestID := nbdns.GenerateRequestID()
|
||||||
|
logger := log.WithField("request_id", requestID)
|
||||||
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
// pass if non A/AAAA query
|
// pass if non A/AAAA query
|
||||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||||
d.continueToNextHandler(w, r, "non A/AAAA query")
|
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,13 +164,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
|
|
||||||
if peerKey == "" {
|
if peerKey == "" {
|
||||||
d.writeDNSError(w, r, "no current peer key")
|
d.writeDNSError(w, r, logger, "no current peer key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
|
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,9 +187,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||||
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -196,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
answer = reply.Answer
|
answer = reply.Answer
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
reply.Id = r.Id
|
reply.Id = r.Id
|
||||||
if err := d.writeMsg(w, reply); err != nil {
|
if err := d.writeMsg(w, reply); err != nil {
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||||
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
resp := new(dns.Msg)
|
||||||
resp.SetRcode(r, dns.RcodeServerFailure)
|
resp.SetRcode(r, dns.RcodeServerFailure)
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS error response: %v", err)
|
logger.Errorf("failed to write DNS error response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// continueToNextHandler signals the handler chain to try the next handler
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
|
||||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
resp := new(dns.Msg)
|
||||||
resp.SetRcode(r, dns.RcodeNameError)
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
// Set Zero bit to signal handler chain to continue
|
// Set Zero bit to signal handler chain to continue
|
||||||
resp.MsgHdr.Zero = true
|
resp.MsgHdr.Zero = true
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed writing DNS continue response: %v", err)
|
logger.Errorf("failed writing DNS continue response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user