diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index ce75369c8..cbcf6a256 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -695,6 +695,12 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) continue } + + if ns.IP == s.service.RuntimeIP() { + log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) + continue + } + handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 91543da8f..068f001d8 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -2056,3 +2056,124 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } + +func TestDNSLoopPrevention(t *testing.T) { + wgInterface := &mocWGIface{} + service := NewServiceViaMemory(wgInterface) + dnsServerIP := service.RuntimeIP() + + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgInterface, + service: service, + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: &noopHostConfigurator{}, + dnsMuxMap: make(registeredHandlerMap), + } + + tests := []struct { + name string + nsGroups []*nbdns.NameServerGroup + expectedHandlers int + expectedServers []netip.Addr + shouldFilterOwnIP bool + }{ + { + name: "FilterOwnDNSServerIP", + nsGroups: []*nbdns.NameServerGroup{ + { + Primary: true, + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{}, + }, + }, + expectedHandlers: 1, + expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")}, + shouldFilterOwnIP: true, + }, + { + name: "AllServersFiltered", + nsGroups: []*nbdns.NameServerGroup{ + { + Primary: false, + NameServers: []nbdns.NameServer{ + {IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + }, + expectedHandlers: 0, + expectedServers: []netip.Addr{}, + shouldFilterOwnIP: true, + }, + { + name: "MixedServersWithOwnIP", + nsGroups: []*nbdns.NameServerGroup{ + { + Primary: false, + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate + }, + Domains: []string{"test.com"}, + }, + }, + expectedHandlers: 1, + expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")}, + shouldFilterOwnIP: true, + }, + { + name: "NoOwnIPInList", + nsGroups: []*nbdns.NameServerGroup{ + { + Primary: true, + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{}, + }, + }, + expectedHandlers: 1, + expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")}, + shouldFilterOwnIP: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups) + assert.NoError(t, err) + assert.Len(t, muxUpdates, tt.expectedHandlers) + + if tt.expectedHandlers > 0 { + handler := muxUpdates[0].handler.(*upstreamResolver) + assert.Len(t, handler.upstreamServers, len(tt.expectedServers)) + + if tt.shouldFilterOwnIP { + for _, upstream := range handler.upstreamServers { + assert.NotEqual(t, dnsServerIP, upstream.Addr()) + } + } + + for _, expected := range tt.expectedServers { + found := false + for _, upstream := range handler.upstreamServers { + if upstream.Addr() == expected { + found = true + break + } + } + assert.True(t, found, "Expected server %s not found", expected) + } + } + }) + } +}