mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 10:31:45 +02:00
[client] Eliminate upstream server strings in dns code (#4267)
This commit is contained in:
@ -97,9 +97,9 @@ func init() {
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []string
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, getNSHostPort(srv))
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
var dnsList []string
|
||||
var dnsList []netip.AddrPort
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
|
||||
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
func TestFormatAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
address string
|
||||
port int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
address: "8.8.8.8",
|
||||
port: 53,
|
||||
expected: "8.8.8.8:53",
|
||||
},
|
||||
{
|
||||
name: "IPv4 address with custom port",
|
||||
address: "1.1.1.1",
|
||||
port: 5353,
|
||||
expected: "1.1.1.1:5353",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
address: "fd78:94bf:7df8::1",
|
||||
port: 53,
|
||||
expected: "[fd78:94bf:7df8::1]:53",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address with custom port",
|
||||
address: "2001:db8::1",
|
||||
port: 5353,
|
||||
expected: "[2001:db8::1]:5353",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
address: "::1",
|
||||
port: 53,
|
||||
expected: "[::1]:53",
|
||||
},
|
||||
{
|
||||
name: "Invalid address treated as hostname",
|
||||
address: "dns.example.com",
|
||||
port: 53,
|
||||
expected: "dns.example.com:53",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := formatAddr(tt.address, tt.port)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user