mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 02:08:40 +02:00
[client] Ignore case when matching domains in handler chain (#3133)
This commit is contained in:
parent
18316be09a
commit
43ef64cf67
@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
origPattern := pattern
|
||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||
if isWildcard {
|
||||
pattern = pattern[2:]
|
||||
}
|
||||
pattern = dns.Fqdn(pattern)
|
||||
origPattern = dns.Fqdn(origPattern)
|
||||
|
||||
// First remove any existing handler with same original pattern and priority
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
// Find and remove handlers matching both original pattern and priority
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if entry.OrigPattern == pattern && entry.Priority == priority {
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
for _, entry := range c.handlers {
|
||||
if entry.Pattern == pattern {
|
||||
if strings.EqualFold(entry.Pattern, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
qname := r.Question[0].Name
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
log.Tracef("handling DNS request for domain=%s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
// Otherwise require exact match
|
||||
if entry.MatchSubdomains {
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
} else {
|
||||
matched = qname == entry.Pattern
|
||||
matched = strings.EqualFold(qname, entry.Pattern)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
}
|
||||
|
||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
addHandlers []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}
|
||||
query string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "case insensitive exact match",
|
||||
scenario: "handler registered lowercase, query uppercase",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "case insensitive wildcard match",
|
||||
scenario: "handler registered mixed case wildcard, query different case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "sub.EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple handlers different case same domain",
|
||||
scenario: "second handler should replace first despite case difference",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "ExAmPlE.cOm.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "subdomain matching case insensitive",
|
||||
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"example.com.", nbdns.PriorityDefault, true, true},
|
||||
},
|
||||
query: "SUB.EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone case insensitive",
|
||||
scenario: "root zone handler should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{".", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple handlers different priority",
|
||||
scenario: "should call higher priority handler despite case differences",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
||||
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlerCalls := make(map[string]bool) // track which patterns were called
|
||||
|
||||
// Add handlers according to test case
|
||||
for _, h := range tt.addHandlers {
|
||||
var handler dns.Handler
|
||||
pattern := h.pattern // capture pattern for closure
|
||||
|
||||
if h.subdomains {
|
||||
subHandler := &nbdns.MockSubdomainHandler{
|
||||
Subdomains: true,
|
||||
}
|
||||
if h.shouldMatch {
|
||||
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
handlerCalls[pattern] = true
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
handler = subHandler
|
||||
} else {
|
||||
mockHandler := &nbdns.MockHandler{}
|
||||
if h.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
handlerCalls[pattern] = true
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
handler = mockHandler
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||
|
||||
// Verify each handler was called exactly as expected
|
||||
for _, h := range tt.addHandlers {
|
||||
wasCalled := handlerCalls[h.pattern]
|
||||
assert.Equal(t, h.shouldMatch, wasCalled,
|
||||
"Handler for pattern %q was %s when it should%s have been",
|
||||
h.pattern,
|
||||
map[bool]string{true: "called", false: "not called"}[wasCalled],
|
||||
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
|
||||
}
|
||||
|
||||
// Verify total number of calls
|
||||
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
|
||||
"Wrong number of total handler calls")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user