mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 02:08:40 +02:00
[client] Ignore case when matching domains in handler chain (#3133)
This commit is contained in:
parent
18316be09a
commit
43ef64cf67
@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||||
origPattern := pattern
|
origPattern := pattern
|
||||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||||
if isWildcard {
|
if isWildcard {
|
||||||
pattern = pattern[2:]
|
pattern = pattern[2:]
|
||||||
}
|
}
|
||||||
pattern = dns.Fqdn(pattern)
|
|
||||||
origPattern = dns.Fqdn(origPattern)
|
|
||||||
|
|
||||||
// First remove any existing handler with same original pattern and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
if c.handlers[i].StopHandler != nil {
|
if c.handlers[i].StopHandler != nil {
|
||||||
c.handlers[i].StopHandler.stop()
|
c.handlers[i].StopHandler.stop()
|
||||||
}
|
}
|
||||||
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
// Find and remove handlers matching both original pattern and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if entry.OrigPattern == pattern && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
if entry.StopHandler != nil {
|
if entry.StopHandler != nil {
|
||||||
entry.StopHandler.stop()
|
entry.StopHandler.stop()
|
||||||
}
|
}
|
||||||
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||||
for _, entry := range c.handlers {
|
for _, entry := range c.handlers {
|
||||||
if entry.Pattern == pattern {
|
if strings.EqualFold(entry.Pattern, pattern) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
qname := r.Question[0].Name
|
qname := strings.ToLower(r.Question[0].Name)
|
||||||
log.Tracef("handling DNS request for domain=%s", qname)
|
log.Tracef("handling DNS request for domain=%s", qname)
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
// If handler wants subdomain matching, allow suffix match
|
// If handler wants subdomain matching, allow suffix match
|
||||||
// Otherwise require exact match
|
// Otherwise require exact match
|
||||||
if entry.MatchSubdomains {
|
if entry.MatchSubdomains {
|
||||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
} else {
|
} else {
|
||||||
matched = qname == entry.Pattern
|
matched = strings.EqualFold(qname, entry.Pattern)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
|
|
||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
assert.False(t, chain.HasHandlers(testDomain))
|
assert.False(t, chain.HasHandlers(testDomain))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
scenario string
|
||||||
|
addHandlers []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}
|
||||||
|
query string
|
||||||
|
expectedCalls int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "case insensitive exact match",
|
||||||
|
scenario: "handler registered lowercase, query uppercase",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive wildcard match",
|
||||||
|
scenario: "handler registered mixed case wildcard, query different case",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "sub.EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple handlers different case same domain",
|
||||||
|
scenario: "second handler should replace first despite case difference",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||||
|
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "ExAmPlE.cOm.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain matching case insensitive",
|
||||||
|
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"example.com.", nbdns.PriorityDefault, true, true},
|
||||||
|
},
|
||||||
|
query: "SUB.EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone case insensitive",
|
||||||
|
scenario: "root zone handler should match regardless of case",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{".", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple handlers different priority",
|
||||||
|
scenario: "should call higher priority handler despite case differences",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||||
|
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
||||||
|
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||||
|
},
|
||||||
|
query: "example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
handlerCalls := make(map[string]bool) // track which patterns were called
|
||||||
|
|
||||||
|
// Add handlers according to test case
|
||||||
|
for _, h := range tt.addHandlers {
|
||||||
|
var handler dns.Handler
|
||||||
|
pattern := h.pattern // capture pattern for closure
|
||||||
|
|
||||||
|
if h.subdomains {
|
||||||
|
subHandler := &nbdns.MockSubdomainHandler{
|
||||||
|
Subdomains: true,
|
||||||
|
}
|
||||||
|
if h.shouldMatch {
|
||||||
|
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
handlerCalls[pattern] = true
|
||||||
|
w := args.Get(0).(dns.ResponseWriter)
|
||||||
|
r := args.Get(1).(*dns.Msg)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
}
|
||||||
|
handler = subHandler
|
||||||
|
} else {
|
||||||
|
mockHandler := &nbdns.MockHandler{}
|
||||||
|
if h.shouldMatch {
|
||||||
|
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
handlerCalls[pattern] = true
|
||||||
|
w := args.Get(0).(dns.ResponseWriter)
|
||||||
|
r := args.Get(1).(*dns.Msg)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
}
|
||||||
|
handler = mockHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
|
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||||
|
|
||||||
|
// Verify each handler was called exactly as expected
|
||||||
|
for _, h := range tt.addHandlers {
|
||||||
|
wasCalled := handlerCalls[h.pattern]
|
||||||
|
assert.Equal(t, h.shouldMatch, wasCalled,
|
||||||
|
"Handler for pattern %q was %s when it should%s have been",
|
||||||
|
h.pattern,
|
||||||
|
map[bool]string{true: "called", false: "not called"}[wasCalled],
|
||||||
|
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify total number of calls
|
||||||
|
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
|
||||||
|
"Wrong number of total handler calls")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user