From 51bb52cdf53b4b6f165915ffd51d57e923d6ceca Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:54:17 +0800 Subject: [PATCH] [client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651) [client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651) --- client/internal/dnsfwd/forwarder.go | 38 +++++++++- client/internal/dnsfwd/forwarder_test.go | 95 ++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 4 deletions(-) create mode 100644 client/internal/dnsfwd/forwarder_test.go diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 097daa9e2..2d69ce858 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -3,6 +3,7 @@ package dnsfwd import ( "context" "errors" + "math" "net" "net/netip" "strings" @@ -62,7 +63,6 @@ func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) for _, d := range f.domains { f.mux.HandleRemove(d) - f.statusRecorder.RemoveResolvedIPLookupEntry(d) } f.resId.Clear() @@ -122,8 +122,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { return } - resId, ok := f.resId.Load(strings.TrimSuffix(domain, ".")) - if ok { + resId := f.getResIdForDomain(strings.TrimSuffix(domain, ".")) + if resId != "" { for _, ip := range ips { var ipWithSuffix string if ip.Is4() { @@ -133,7 +133,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { ipWithSuffix = ip.String() + "/128" log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix) } - f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId.(string)) + f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId) } } @@ -204,6 +204,36 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti } } +func (f *DNSForwarder) getResIdForDomain(domain string) string { + var selectedResId string + var bestScore int + + f.resId.Range(func(key, value interface{}) bool { + var score int + pattern := key.(string) + + switch { + case strings.HasPrefix(pattern, "*."): + baseDomain := strings.TrimPrefix(pattern, "*.") + if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + score = len(baseDomain) + } + case domain == pattern: + score = math.MaxInt + default: + return true + } + + if score > bestScore { + bestScore = score + selectedResId = value.(string) + } + return true + }) + + return selectedResId +} + // filterDomains returns a list of normalized domains func filterDomains(domains []string) []string { newDomains := make([]string, 0, len(domains)) diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go new file mode 100644 index 000000000..88ffc2af3 --- /dev/null +++ b/client/internal/dnsfwd/forwarder_test.go @@ -0,0 +1,95 @@ +package dnsfwd + +import ( + "sync" + "testing" +) + +func TestGetResIdForDomain(t *testing.T) { + testCases := []struct { + name string + storedMappings map[string]string // key: domain pattern, value: resId + queryDomain string + expectedResId string + }{ + { + name: "Empty map returns empty string", + storedMappings: map[string]string{}, + queryDomain: "example.com", + expectedResId: "", + }, + { + name: "Exact match returns stored resId", + storedMappings: map[string]string{"example.com": "res1"}, + queryDomain: "example.com", + expectedResId: "res1", + }, + { + name: "Wildcard pattern matches base domain", + storedMappings: map[string]string{"*.example.com": "res2"}, + queryDomain: "example.com", + expectedResId: "res2", + }, + { + name: "Wildcard pattern matches subdomain", + storedMappings: map[string]string{"*.example.com": "res3"}, + queryDomain: "foo.example.com", + expectedResId: "res3", + }, + { + name: "Wildcard pattern does not match different domain", + storedMappings: map[string]string{"*.example.com": "res4"}, + queryDomain: "foo.notexample.com", + expectedResId: "", + }, + { + name: "Non-wildcard pattern does not match subdomain", + storedMappings: map[string]string{"example.com": "res5"}, + queryDomain: "foo.example.com", + expectedResId: "", + }, + { + name: "Exact match over overlapping wildcard", + storedMappings: map[string]string{ + "*.example.com": "resWildcard", + "foo.example.com": "resExact", + }, + queryDomain: "foo.example.com", + expectedResId: "resExact", + }, + { + name: "Overlapping wildcards: Select more specific wildcard", + storedMappings: map[string]string{ + "*.example.com": "resA", + "*.sub.example.com": "resB", + }, + queryDomain: "bar.sub.example.com", + expectedResId: "resB", + }, + { + name: "Wildcard multi-level subdomain match", + storedMappings: map[string]string{ + "*.example.com": "resMulti", + }, + queryDomain: "a.b.example.com", + expectedResId: "resMulti", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + fwd := &DNSForwarder{ + resId: sync.Map{}, + } + + for domainPattern, resId := range tc.storedMappings { + fwd.resId.Store(domainPattern, resId) + } + + got := fwd.getResIdForDomain(tc.queryDomain) + if got != tc.expectedResId { + t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got) + } + }) + } +}