mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 01:38:41 +02:00
[client] Match more specific dns handler first (#3226)
This commit is contained in:
parent
2e61ce006d
commit
790a9ed7df
@ -105,17 +105,30 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
MatchSubdomains: matchSubdomains,
|
||||
}
|
||||
|
||||
// Insert handler in priority order
|
||||
pos := 0
|
||||
pos := c.findHandlerPosition(entry)
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
}
|
||||
|
||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
||||
for i, h := range c.handlers {
|
||||
if h.Priority < priority {
|
||||
pos = i
|
||||
break
|
||||
// prio first
|
||||
if h.Priority < newEntry.Priority {
|
||||
return i
|
||||
}
|
||||
|
||||
// domain specificity next
|
||||
if h.Priority == newEntry.Priority {
|
||||
newDots := strings.Count(newEntry.Pattern, ".")
|
||||
existingDots := strings.Count(h.Pattern, ".")
|
||||
if newDots > existingDots {
|
||||
return i
|
||||
}
|
||||
}
|
||||
pos = i + 1
|
||||
}
|
||||
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
// add at end
|
||||
return len(c.handlers)
|
||||
}
|
||||
|
||||
// RemoveHandler removes a handler for the given pattern and priority
|
||||
|
@ -677,3 +677,156 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
ops []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}
|
||||
query string
|
||||
expectedMatch string
|
||||
}{
|
||||
{
|
||||
name: "more specific domain matches first",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "more specific domain matches first, both match subdomains",
|
||||
scenario: "sub.example.com should match before example.com",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "maintain specificity order after removal",
|
||||
scenario: "after removing most specific, should fall back to less specific",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "test.sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "priority overrides specificity",
|
||||
scenario: "less specific domain with higher priority should match first",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "example.com.",
|
||||
},
|
||||
{
|
||||
name: "equal priority respects specificity",
|
||||
scenario: "with equal priority, more specific domain should match",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
|
||||
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
{
|
||||
name: "specific matches before wildcard",
|
||||
scenario: "specific domain should match before wildcard at same priority",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
subdomain bool
|
||||
}{
|
||||
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
|
||||
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
|
||||
},
|
||||
query: "sub.example.com.",
|
||||
expectedMatch: "sub.example.com.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
||||
|
||||
for _, op := range tt.ops {
|
||||
if op.action == "add" {
|
||||
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
|
||||
handlers[op.pattern] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
}
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// Setup handler expectations
|
||||
for pattern, handler := range handlers {
|
||||
if pattern == tt.expectedMatch {
|
||||
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetReply(r)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
for pattern, handler := range handlers {
|
||||
if pattern == tt.expectedMatch {
|
||||
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
|
||||
} else {
|
||||
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user