Improve dns forwarder errors and improve domain anonymization (#3052)

* Improve dns forwarder errors and improve domain anonymization


* Use original domain for dns states

* Don't match subdomains for non-wildcard dns routes

* Fix iOS

* Add string representation for local resolver

* Return correct handler for dynamic

* Add dns server dns route + upstream handler test
This commit is contained in:
Viktor Liu 2024-12-17 12:57:07 +01:00 committed by GitHub
parent 228672aed2
commit 21ba6ad266
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 429 additions and 161 deletions

View File

@ -21,6 +21,8 @@ type Anonymizer struct {
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
domainKeyRegex *regexp.Regexp
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
}
}
@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.SplitN(match, "=", 2)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, anonTLD) {
return match
}
randomDomain := generateRandomString(10) + anonTLD
return strings.Replace(match, domain, randomDomain, 1)
return "domain=" + a.AnonymizeDomain(domain)
}
return match
})

View File

@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
tests := []struct {
name string
input string
original string
expect string
}{
{
name: "Basic domain with trailing content",
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
original: "example.com",
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
},
{
name: "Domain with trailing dot",
input: "domain=example.com. processing request with status=pending",
original: "example.com",
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
},
{
name: "Multiple domains in log",
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
},
{
name: "Already anonymized domain",
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
original: "", // nothing should be anonymized
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
},
{
name: "Subdomain with trailing dot",
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
original: "example.com",
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
},
{
name: "Handler chain pattern log",
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
original: "example.com",
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
},
}
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDNSLogLine(tc.input)
if tc.original != "" {
assert.NotContains(t, result, tc.original)
}
assert.Regexp(t, tc.expect, result)
})
}
}
func TestAnonymizeDomain(t *testing.T) {

View File

@ -68,19 +68,19 @@ func networksList(cmd *cobra.Command, _ []string) error {
return nil
}
printRoutes(cmd, resp)
printNetworks(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
cmd.Println("Available Networks:")
for _, route := range resp.Routes {
printRoute(cmd, route)
printNetwork(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Network) {
func printNetwork(cmd *cobra.Command, route *proto.Network) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
@ -113,12 +113,10 @@ func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
for resolvedDomain, ipList := range resolvedIPs {
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
}
}

View File

@ -14,13 +14,19 @@ const (
PriorityDefault = 0
)
type SubdomainMatcher interface {
dns.Handler
MatchSubdomains() bool
}
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool
}
// HandlerChain represents a prioritized chain of DNS handlers
@ -32,6 +38,7 @@ type HandlerChain struct {
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
shouldContinue bool
}
@ -50,6 +57,11 @@ func NewHandlerChain() *HandlerChain {
}
}
// GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string {
return w.origPattern
}
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
c.mu.Lock()
@ -74,16 +86,23 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
}
}
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
pattern, origPattern, isWildcard, priority)
// Check if handler implements SubdomainMatcher interface
matchSubdomains := false
if matcher, ok := handler.(SubdomainMatcher); ok {
matchSubdomains = matcher.MatchSubdomains()
}
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
pattern, origPattern, isWildcard, matchSubdomains, priority)
entry := HandlerEntry{
Handler: handler,
Priority: priority,
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
Handler: handler,
Priority: priority,
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains,
}
// Insert handler in priority order
@ -139,14 +158,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
qname := r.Question[0].Name
log.Debugf("handling DNS request for %s", qname)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock()
defer c.mu.RUnlock()
log.Debugf("current handlers (%d):", len(c.handlers))
log.Tracef("current handlers (%d):", len(c.handlers))
for _, h := range c.handlers {
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
}
@ -160,30 +179,41 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = qname == entry.Pattern
}
}
if !matched {
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
entry.OrigPattern, qname, entry.IsWildcard)
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
continue
}
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
entry.OrigPattern, qname, entry.IsWildcard)
chainWriter := &ResponseWriterChain{ResponseWriter: w}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Debugf("handler requested continue to next handler")
log.Tracef("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
log.Debugf("no handler found for %s", qname)
log.Tracef("no handler found for domain=%s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(resp); err != nil {

View File

@ -11,23 +11,14 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns"
)
// MockHandler implements dns.Handler interface for testing
type MockHandler struct {
mock.Mock
}
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.Called(w, r)
}
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create mock handlers for different priorities
defaultHandler := &MockHandler{}
matchDomainHandler := &MockHandler{}
dnsRouteHandler := &MockHandler{}
defaultHandler := &nbdns.MockHandler{}
matchDomainHandler := &nbdns.MockHandler{}
dnsRouteHandler := &nbdns.MockHandler{}
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
@ -58,78 +49,108 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct {
name string
handlerDomain string
queryDomain string
isWildcard bool
shouldMatch bool
name string
handlerDomain string
queryDomain string
isWildcard bool
matchSubdomains bool
shouldMatch bool
}{
{
name: "exact match",
handlerDomain: "example.com.",
queryDomain: "example.com.",
isWildcard: false,
shouldMatch: true,
name: "exact match",
handlerDomain: "example.com.",
queryDomain: "example.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
shouldMatch: true,
name: "subdomain with non-wildcard and MatchSubdomains true",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "wildcard match",
handlerDomain: "*.example.com.",
queryDomain: "sub.example.com.",
isWildcard: true,
shouldMatch: true,
name: "subdomain with non-wildcard and MatchSubdomains false",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard no match on apex",
handlerDomain: "*.example.com.",
queryDomain: "example.com.",
isWildcard: true,
shouldMatch: false,
name: "wildcard match",
handlerDomain: "*.example.com.",
queryDomain: "sub.example.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "root zone match",
handlerDomain: ".",
queryDomain: "anything.com.",
isWildcard: false,
shouldMatch: true,
name: "wildcard no match on apex",
handlerDomain: "*.example.com.",
queryDomain: "example.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "no match different domain",
handlerDomain: "example.com.",
queryDomain: "example.org.",
isWildcard: false,
shouldMatch: false,
name: "root zone match",
handlerDomain: ".",
queryDomain: "anything.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "no match different domain",
handlerDomain: "example.com.",
queryDomain: "example.org.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
mockHandler := &MockHandler{}
var handler dns.Handler
if tt.matchSubdomains {
mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
handler = mockSubHandler
if tt.shouldMatch {
mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
}
} else {
mockHandler := &nbdns.MockHandler{}
handler = mockHandler
if tt.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
}
}
pattern := tt.handlerDomain
if tt.isWildcard {
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
pattern = "*." + tt.handlerDomain[2:]
}
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
if tt.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, r).Once()
}
chain.ServeDNS(w, r)
mockHandler.AssertExpectations(t)
if h, ok := handler.(*nbdns.MockHandler); ok {
h.AssertExpectations(t)
} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
h.AssertExpectations(t)
}
})
}
}
@ -218,11 +239,11 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handlers []*MockHandler
var handlers []*nbdns.MockHandler
// Setup handlers and expectations
for i := range tt.handlers {
handler := &MockHandler{}
handler := &nbdns.MockHandler{}
handlers = append(handlers, handler)
// Set expectation based on whether this handler should be called
@ -254,9 +275,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create handlers
handler1 := &MockHandler{}
handler2 := &MockHandler{}
handler3 := &MockHandler{}
handler1 := &nbdns.MockHandler{}
handler2 := &nbdns.MockHandler{}
handler3 := &nbdns.MockHandler{}
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
@ -388,12 +409,12 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[int]*MockHandler)
handlers := make(map[int]*nbdns.MockHandler)
// Execute operations
for _, op := range tt.ops {
if op.action == "add" {
handler := &MockHandler{}
handler := &nbdns.MockHandler{}
handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
} else {
@ -440,10 +461,10 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
testDomain := "example.com."
testQuery := "test.example.com."
// Create handlers for three priority levels
routeHandler := &MockHandler{}
matchHandler := &MockHandler{}
defaultHandler := &MockHandler{}
// Create handlers with MatchSubdomains enabled
routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
// Create test request that will be reused
r := new(dns.Msg)

View File

@ -17,12 +17,24 @@ type localResolver struct {
records sync.Map
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v", r.Question[0])
if len(r.Question) > 0 {
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true

View File

@ -11,7 +11,9 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@ -874,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver {
},
}
}
// MockHandler implements dns.Handler interface for testing
type MockHandler struct {
mock.Mock
}
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.Called(w, r)
}
type MockSubdomainHandler struct {
MockHandler
Subdomains bool
}
func (m *MockSubdomainHandler) MatchSubdomains() bool {
return m.Subdomains
}
func TestHandlerChain_DomainPriorities(t *testing.T) {
chain := NewHandlerChain()
dnsRouteHandler := &MockHandler{}
upstreamHandler := &MockSubdomainHandler{
Subdomains: true,
}
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
testCases := []struct {
name string
query string
expectedHandler dns.Handler
}{
{
name: "exact domain with dns route handler",
query: "example.com.",
expectedHandler: dnsRouteHandler,
},
{
name: "subdomain should use upstream handler",
query: "sub.example.com.",
expectedHandler: upstreamHandler,
},
{
name: "deep subdomain should use upstream handler",
query: "deep.sub.example.com.",
expectedHandler: upstreamHandler,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tc.query, dns.TypeA)
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once()
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once()
}
chain.ServeDNS(w, r)
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.AssertExpectations(t)
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
mh.AssertExpectations(t)
}
// Reset mocks
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.ExpectedCalls = nil
mh.Calls = nil
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
mh.ExpectedCalls = nil
mh.Calls = nil
}
})
}
}

View File

@ -68,7 +68,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("%v", u.upstreamServers)
return fmt.Sprintf("upstream %v", u.upstreamServers)
}
func (u *upstreamResolverBase) MatchSubdomains() bool {
return true
}
func (u *upstreamResolverBase) stop() {

View File

@ -2,6 +2,7 @@ package dnsfwd
import (
"context"
"errors"
"net"
"github.com/miekg/dns"
@ -10,6 +11,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
type DNSForwarder struct {
listenAddress string
ttl uint32
@ -20,15 +23,16 @@ type DNSForwarder struct {
}
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d domains=%v", listenAddress, ttl, domains)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
domains: domains,
}
}
func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
for _, d := range f.domains {
@ -67,7 +71,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
question := query.Question[0]
domain := question.Name
@ -76,8 +81,26 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
ips, err := net.LookupIP(domain)
if err != nil {
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeServerFailure
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
@ -87,7 +110,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
for _, ip := range ips {
var respRecord dns.RR
if ip.To4() == nil {
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
@ -99,7 +122,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
}
respRecord = &rr
} else {
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{

View File

@ -17,6 +17,11 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client"
)
type ResolvedDomainInfo struct {
Prefixes []netip.Prefix
ParentDomain domain.Domain
}
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
@ -138,7 +143,7 @@ type Status struct {
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@ -156,7 +161,7 @@ func NewRecorder(mgmAddress string) *Status {
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
}
}
@ -591,16 +596,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
// Store both the original domain pattern and resolved domain
d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{
Prefixes: prefixes,
ParentDomain: originalDomain,
}
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
// Remove all entries that have this domain as their parent
for k, v := range d.resolvedDomainsStates {
if v.ParentDomain == domain {
delete(d.resolvedDomainsStates, k)
}
}
}
func (d *Status) GetRosenpassState() RosenpassState {
@ -702,7 +718,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)

View File

@ -442,5 +442,5 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
if useNewDNSRoute {
return handlerTypeDomain
}
return handlerTypeStatic
return handlerTypeDynamic
}

View File

@ -83,6 +83,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
}
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
}
for _, domain := range d.route.Domains {
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
@ -138,14 +140,16 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
log.Tracef("received DNS request: %v", r.Question[0].Name)
log.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
if peerKey == "" {
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
d.continueToNextHandler(w, r, "no current peer key")
return
}
@ -168,7 +172,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if reply != nil {
answer = reply.Answer
}
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
@ -186,7 +190,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
@ -210,8 +215,18 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
if len(r.Answer) > 0 && len(r.Question) > 0 {
// DNS names from miekg/dns are already in punycode format
dom := domain.Domain(r.Question[0].Name)
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
origPattern = writer.GetOrigPattern()
}
resolvedDomain := domain.Domain(r.Question[0].Name)
// already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern)
if originalDomain == "" {
originalDomain = resolvedDomain
}
var newPrefixes []netip.Prefix
for _, answer := range r.Answer {
@ -220,14 +235,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Debugf("failed to convert A record IP: %v", rr.A)
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
@ -240,7 +255,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
}
@ -253,11 +268,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
return nil
}
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
oldPrefixes := d.interceptedDomains[domain]
oldPrefixes := d.interceptedDomains[resolvedDomain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error
@ -277,7 +292,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
resolvedDomain.SafeString(),
ref.Out,
)
}
@ -297,16 +312,23 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes
}
}
// Update domain prefixes
// Update domain prefixes using resolved domain as key
if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[domain] = newPrefixes
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 {
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}

View File

@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)

View File

@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
if info, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
for _, prefix := range info.Prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/netip"
"slices"
"sort"
"golang.org/x/exp/maps"
@ -77,17 +78,27 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
Selected: route.Selected,
}
for _, domain := range route.Domains {
if prefixes, exists := resolvedDomains[domain]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
Ips: ipStrings,
// Group resolved IPs by their parent domain
domainMap := map[domain.Domain][]string{}
for resolvedDomain, info := range resolvedDomains {
// Check if this resolved domain's parent is in our route's domains
if slices.Contains(route.Domains, info.ParentDomain) {
ips := make([]string, 0, len(info.Prefixes))
for _, prefix := range info.Prefixes {
ips = append(ips, prefix.Addr().String())
}
domainMap[resolvedDomain] = ips
}
}
// Convert to proto format
for domain, ips := range domainMap {
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
Ips: ips,
}
}
pbRoutes = append(pbRoutes, pbRoute)
}

View File

@ -129,10 +129,8 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) {
grid.Add(domainsSelector)
var resolvedIPsList []string
for _, domain := range domains {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
for domain, ipList := range r.GetResolvedIPs() {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
if len(resolvedIPsList) == 0 {