[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)

[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
This commit is contained in:
hakansa 2025-04-15 15:54:17 +08:00 committed by GitHub
parent 4134b857b4
commit 51bb52cdf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 129 additions and 4 deletions

View File

@ -3,6 +3,7 @@ package dnsfwd
import (
"context"
"errors"
"math"
"net"
"net/netip"
"strings"
@ -62,7 +63,6 @@ func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string)
for _, d := range f.domains {
f.mux.HandleRemove(d)
f.statusRecorder.RemoveResolvedIPLookupEntry(d)
}
f.resId.Clear()
@ -122,8 +122,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
return
}
resId, ok := f.resId.Load(strings.TrimSuffix(domain, "."))
if ok {
resId := f.getResIdForDomain(strings.TrimSuffix(domain, "."))
if resId != "" {
for _, ip := range ips {
var ipWithSuffix string
if ip.Is4() {
@ -133,7 +133,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
ipWithSuffix = ip.String() + "/128"
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
}
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string))
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
}
}
@ -204,6 +204,36 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
}
}
func (f *DNSForwarder) getResIdForDomain(domain string) string {
var selectedResId string
var bestScore int
f.resId.Range(func(key, value interface{}) bool {
var score int
pattern := key.(string)
switch {
case strings.HasPrefix(pattern, "*."):
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
score = len(baseDomain)
}
case domain == pattern:
score = math.MaxInt
default:
return true
}
if score > bestScore {
bestScore = score
selectedResId = value.(string)
}
return true
})
return selectedResId
}
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))

View File

@ -0,0 +1,95 @@
package dnsfwd
import (
"sync"
"testing"
)
func TestGetResIdForDomain(t *testing.T) {
testCases := []struct {
name string
storedMappings map[string]string // key: domain pattern, value: resId
queryDomain string
expectedResId string
}{
{
name: "Empty map returns empty string",
storedMappings: map[string]string{},
queryDomain: "example.com",
expectedResId: "",
},
{
name: "Exact match returns stored resId",
storedMappings: map[string]string{"example.com": "res1"},
queryDomain: "example.com",
expectedResId: "res1",
},
{
name: "Wildcard pattern matches base domain",
storedMappings: map[string]string{"*.example.com": "res2"},
queryDomain: "example.com",
expectedResId: "res2",
},
{
name: "Wildcard pattern matches subdomain",
storedMappings: map[string]string{"*.example.com": "res3"},
queryDomain: "foo.example.com",
expectedResId: "res3",
},
{
name: "Wildcard pattern does not match different domain",
storedMappings: map[string]string{"*.example.com": "res4"},
queryDomain: "foo.notexample.com",
expectedResId: "",
},
{
name: "Non-wildcard pattern does not match subdomain",
storedMappings: map[string]string{"example.com": "res5"},
queryDomain: "foo.example.com",
expectedResId: "",
},
{
name: "Exact match over overlapping wildcard",
storedMappings: map[string]string{
"*.example.com": "resWildcard",
"foo.example.com": "resExact",
},
queryDomain: "foo.example.com",
expectedResId: "resExact",
},
{
name: "Overlapping wildcards: Select more specific wildcard",
storedMappings: map[string]string{
"*.example.com": "resA",
"*.sub.example.com": "resB",
},
queryDomain: "bar.sub.example.com",
expectedResId: "resB",
},
{
name: "Wildcard multi-level subdomain match",
storedMappings: map[string]string{
"*.example.com": "resMulti",
},
queryDomain: "a.b.example.com",
expectedResId: "resMulti",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fwd := &DNSForwarder{
resId: sync.Map{},
}
for domainPattern, resId := range tc.storedMappings {
fwd.resId.Store(domainPattern, resId)
}
got := fwd.getResIdForDomain(tc.queryDomain)
if got != tc.expectedResId {
t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
}
})
}
}