mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
[client] Skip dns upstream servers pointing to our dns server IP to prevent loops (#4330)
This commit is contained in:
@ -2056,3 +2056,124 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
func TestDNSLoopPrevention(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
dnsServerIP := service.RuntimeIP()
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgInterface,
|
||||
service: service,
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nsGroups []*nbdns.NameServerGroup
|
||||
expectedHandlers int
|
||||
expectedServers []netip.Addr
|
||||
shouldFilterOwnIP bool
|
||||
}{
|
||||
{
|
||||
name: "FilterOwnDNSServerIP",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 1,
|
||||
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||
shouldFilterOwnIP: true,
|
||||
},
|
||||
{
|
||||
name: "AllServersFiltered",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: false,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 0,
|
||||
expectedServers: []netip.Addr{},
|
||||
shouldFilterOwnIP: true,
|
||||
},
|
||||
{
|
||||
name: "MixedServersWithOwnIP",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: false,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
|
||||
},
|
||||
Domains: []string{"test.com"},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 1,
|
||||
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||
shouldFilterOwnIP: true,
|
||||
},
|
||||
{
|
||||
name: "NoOwnIPInList",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 1,
|
||||
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||
shouldFilterOwnIP: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, muxUpdates, tt.expectedHandlers)
|
||||
|
||||
if tt.expectedHandlers > 0 {
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
||||
|
||||
if tt.shouldFilterOwnIP {
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range tt.expectedServers {
|
||||
found := false
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
if upstream.Addr() == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected server %s not found", expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user