mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
This commit is contained in:
parent
4134b857b4
commit
51bb52cdf5
@ -3,6 +3,7 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@ -62,7 +63,6 @@ func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string)
|
||||
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleRemove(d)
|
||||
f.statusRecorder.RemoveResolvedIPLookupEntry(d)
|
||||
}
|
||||
f.resId.Clear()
|
||||
|
||||
@ -122,8 +122,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
resId, ok := f.resId.Load(strings.TrimSuffix(domain, "."))
|
||||
if ok {
|
||||
resId := f.getResIdForDomain(strings.TrimSuffix(domain, "."))
|
||||
if resId != "" {
|
||||
for _, ip := range ips {
|
||||
var ipWithSuffix string
|
||||
if ip.Is4() {
|
||||
@ -133,7 +133,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
ipWithSuffix = ip.String() + "/128"
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
|
||||
}
|
||||
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string))
|
||||
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,6 +204,36 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) getResIdForDomain(domain string) string {
|
||||
var selectedResId string
|
||||
var bestScore int
|
||||
|
||||
f.resId.Range(func(key, value interface{}) bool {
|
||||
var score int
|
||||
pattern := key.(string)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(pattern, "*."):
|
||||
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||
score = len(baseDomain)
|
||||
}
|
||||
case domain == pattern:
|
||||
score = math.MaxInt
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
selectedResId = value.(string)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return selectedResId
|
||||
}
|
||||
|
||||
// filterDomains returns a list of normalized domains
|
||||
func filterDomains(domains []string) []string {
|
||||
newDomains := make([]string, 0, len(domains))
|
||||
|
95
client/internal/dnsfwd/forwarder_test.go
Normal file
95
client/internal/dnsfwd/forwarder_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetResIdForDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storedMappings map[string]string // key: domain pattern, value: resId
|
||||
queryDomain string
|
||||
expectedResId string
|
||||
}{
|
||||
{
|
||||
name: "Empty map returns empty string",
|
||||
storedMappings: map[string]string{},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Exact match returns stored resId",
|
||||
storedMappings: map[string]string{"example.com": "res1"},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "res1",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern matches base domain",
|
||||
storedMappings: map[string]string{"*.example.com": "res2"},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "res2",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern matches subdomain",
|
||||
storedMappings: map[string]string{"*.example.com": "res3"},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "res3",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern does not match different domain",
|
||||
storedMappings: map[string]string{"*.example.com": "res4"},
|
||||
queryDomain: "foo.notexample.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Non-wildcard pattern does not match subdomain",
|
||||
storedMappings: map[string]string{"example.com": "res5"},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Exact match over overlapping wildcard",
|
||||
storedMappings: map[string]string{
|
||||
"*.example.com": "resWildcard",
|
||||
"foo.example.com": "resExact",
|
||||
},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "resExact",
|
||||
},
|
||||
{
|
||||
name: "Overlapping wildcards: Select more specific wildcard",
|
||||
storedMappings: map[string]string{
|
||||
"*.example.com": "resA",
|
||||
"*.sub.example.com": "resB",
|
||||
},
|
||||
queryDomain: "bar.sub.example.com",
|
||||
expectedResId: "resB",
|
||||
},
|
||||
{
|
||||
name: "Wildcard multi-level subdomain match",
|
||||
storedMappings: map[string]string{
|
||||
"*.example.com": "resMulti",
|
||||
},
|
||||
queryDomain: "a.b.example.com",
|
||||
expectedResId: "resMulti",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fwd := &DNSForwarder{
|
||||
resId: sync.Map{},
|
||||
}
|
||||
|
||||
for domainPattern, resId := range tc.storedMappings {
|
||||
fwd.resId.Store(domainPattern, resId)
|
||||
}
|
||||
|
||||
got := fwd.getResIdForDomain(tc.queryDomain)
|
||||
if got != tc.expectedResId {
|
||||
t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user