mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-23 14:28:51 +01:00
Improve dns forwarder errors and improve domain anonymization (#3052)
* Improve dns forwarder errors and improve domain anonymization * Use original domain for dns states * Don't match subdomains for non-wildcard dns routes * Fix iOS * Add string representation for local resolver * Return correct handler for dynamic * Add dns server dns route + upstream handler test
This commit is contained in:
parent
228672aed2
commit
21ba6ad266
@ -21,6 +21,8 @@ type Anonymizer struct {
|
||||
currentAnonIPv6 netip.Addr
|
||||
startAnonIPv4 netip.Addr
|
||||
startAnonIPv6 netip.Addr
|
||||
|
||||
domainKeyRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
||||
currentAnonIPv6: startIPv6,
|
||||
startAnonIPv4: startIPv4,
|
||||
startAnonIPv6: startIPv6,
|
||||
|
||||
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||
}
|
||||
|
||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
||||
domainRegex := regexp.MustCompile(domainPattern)
|
||||
|
||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||
parts := strings.Split(match, `"`)
|
||||
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||
parts := strings.SplitN(match, "=", 2)
|
||||
if len(parts) >= 2 {
|
||||
domain := parts[1]
|
||||
if strings.HasSuffix(domain, anonTLD) {
|
||||
return match
|
||||
}
|
||||
randomDomain := generateRandomString(10) + anonTLD
|
||||
return strings.Replace(match, domain, randomDomain, 1)
|
||||
return "domain=" + a.AnonymizeDomain(domain)
|
||||
}
|
||||
return match
|
||||
})
|
||||
|
@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
|
||||
|
||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
original string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "Basic domain with trailing content",
|
||||
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
|
||||
original: "example.com",
|
||||
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
|
||||
},
|
||||
{
|
||||
name: "Domain with trailing dot",
|
||||
input: "domain=example.com. processing request with status=pending",
|
||||
original: "example.com",
|
||||
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
|
||||
},
|
||||
{
|
||||
name: "Multiple domains in log",
|
||||
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
|
||||
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
|
||||
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
|
||||
},
|
||||
{
|
||||
name: "Already anonymized domain",
|
||||
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
|
||||
original: "", // nothing should be anonymized
|
||||
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
|
||||
},
|
||||
{
|
||||
name: "Subdomain with trailing dot",
|
||||
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
|
||||
original: "example.com",
|
||||
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
|
||||
},
|
||||
{
|
||||
name: "Handler chain pattern log",
|
||||
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
|
||||
original: "example.com",
|
||||
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
|
||||
},
|
||||
}
|
||||
|
||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
||||
require.NotEqual(t, testLog, result)
|
||||
assert.NotContains(t, result, "example.com")
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := anonymizer.AnonymizeDNSLogLine(tc.input)
|
||||
if tc.original != "" {
|
||||
assert.NotContains(t, result, tc.original)
|
||||
}
|
||||
assert.Regexp(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymizeDomain(t *testing.T) {
|
||||
|
@ -68,19 +68,19 @@ func networksList(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
printRoutes(cmd, resp)
|
||||
printNetworks(cmd, resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRoutes(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||
cmd.Println("Available Networks:")
|
||||
for _, route := range resp.Routes {
|
||||
printRoute(cmd, route)
|
||||
printNetwork(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printRoute(cmd *cobra.Command, route *proto.Network) {
|
||||
func printNetwork(cmd *cobra.Command, route *proto.Network) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
@ -113,12 +113,10 @@ func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
||||
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := resolvedIPs[domain]; exists {
|
||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
for resolvedDomain, ipList := range resolvedIPs {
|
||||
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,13 +14,19 @@ const (
|
||||
PriorityDefault = 0
|
||||
)
|
||||
|
||||
type SubdomainMatcher interface {
|
||||
dns.Handler
|
||||
MatchSubdomains() bool
|
||||
}
|
||||
|
||||
type HandlerEntry struct {
|
||||
Handler dns.Handler
|
||||
Priority int
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
Handler dns.Handler
|
||||
Priority int
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
MatchSubdomains bool
|
||||
}
|
||||
|
||||
// HandlerChain represents a prioritized chain of DNS handlers
|
||||
@ -32,6 +38,7 @@ type HandlerChain struct {
|
||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
origPattern string
|
||||
shouldContinue bool
|
||||
}
|
||||
|
||||
@ -50,6 +57,11 @@ func NewHandlerChain() *HandlerChain {
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
||||
func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||
return w.origPattern
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||
c.mu.Lock()
|
||||
@ -74,16 +86,23 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
|
||||
pattern, origPattern, isWildcard, priority)
|
||||
// Check if handler implements SubdomainMatcher interface
|
||||
matchSubdomains := false
|
||||
if matcher, ok := handler.(SubdomainMatcher); ok {
|
||||
matchSubdomains = matcher.MatchSubdomains()
|
||||
}
|
||||
|
||||
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||
pattern, origPattern, isWildcard, matchSubdomains, priority)
|
||||
|
||||
entry := HandlerEntry{
|
||||
Handler: handler,
|
||||
Priority: priority,
|
||||
Pattern: pattern,
|
||||
OrigPattern: origPattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
Handler: handler,
|
||||
Priority: priority,
|
||||
Pattern: pattern,
|
||||
OrigPattern: origPattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
MatchSubdomains: matchSubdomains,
|
||||
}
|
||||
|
||||
// Insert handler in priority order
|
||||
@ -139,14 +158,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
qname := r.Question[0].Name
|
||||
log.Debugf("handling DNS request for %s", qname)
|
||||
log.Tracef("handling DNS request for domain=%s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
log.Debugf("current handlers (%d):", len(c.handlers))
|
||||
log.Tracef("current handlers (%d):", len(c.handlers))
|
||||
for _, h := range c.handlers {
|
||||
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
|
||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||
}
|
||||
|
||||
@ -160,30 +179,41 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
default:
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
// Otherwise require exact match
|
||||
if entry.MatchSubdomains {
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
} else {
|
||||
matched = qname == entry.Pattern
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
||||
entry.OrigPattern, qname, entry.IsWildcard)
|
||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
||||
entry.OrigPattern, qname, entry.IsWildcard)
|
||||
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
||||
|
||||
chainWriter := &ResponseWriterChain{
|
||||
ResponseWriter: w,
|
||||
origPattern: entry.OrigPattern,
|
||||
}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
log.Debugf("handler requested continue to next handler")
|
||||
log.Tracef("handler requested continue to next handler")
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
log.Debugf("no handler found for %s", qname)
|
||||
log.Tracef("no handler found for domain=%s", qname)
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
|
@ -11,23 +11,14 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
)
|
||||
|
||||
// MockHandler implements dns.Handler interface for testing
|
||||
type MockHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.Called(w, r)
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
// Create mock handlers for different priorities
|
||||
defaultHandler := &MockHandler{}
|
||||
matchDomainHandler := &MockHandler{}
|
||||
dnsRouteHandler := &MockHandler{}
|
||||
defaultHandler := &nbdns.MockHandler{}
|
||||
matchDomainHandler := &nbdns.MockHandler{}
|
||||
dnsRouteHandler := &nbdns.MockHandler{}
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||
@ -58,78 +49,108 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
|
||||
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlerDomain string
|
||||
queryDomain string
|
||||
isWildcard bool
|
||||
shouldMatch bool
|
||||
name string
|
||||
handlerDomain string
|
||||
queryDomain string
|
||||
isWildcard bool
|
||||
matchSubdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
name: "exact match",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain with non-wildcard",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
name: "subdomain with non-wildcard and MatchSubdomains true",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: true,
|
||||
shouldMatch: true,
|
||||
name: "subdomain with non-wildcard and MatchSubdomains false",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on apex",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: true,
|
||||
shouldMatch: false,
|
||||
name: "wildcard match",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "root zone match",
|
||||
handlerDomain: ".",
|
||||
queryDomain: "anything.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
name: "wildcard no match on apex",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no match different domain",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.org.",
|
||||
isWildcard: false,
|
||||
shouldMatch: false,
|
||||
name: "root zone match",
|
||||
handlerDomain: ".",
|
||||
queryDomain: "anything.com.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "no match different domain",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.org.",
|
||||
isWildcard: false,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
mockHandler := &MockHandler{}
|
||||
var handler dns.Handler
|
||||
|
||||
if tt.matchSubdomains {
|
||||
mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||
handler = mockSubHandler
|
||||
if tt.shouldMatch {
|
||||
mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||
}
|
||||
} else {
|
||||
mockHandler := &nbdns.MockHandler{}
|
||||
handler = mockHandler
|
||||
if tt.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||
}
|
||||
}
|
||||
|
||||
pattern := tt.handlerDomain
|
||||
if tt.isWildcard {
|
||||
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
|
||||
pattern = "*." + tt.handlerDomain[2:]
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
if tt.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
mockHandler.AssertExpectations(t)
|
||||
|
||||
if h, ok := handler.(*nbdns.MockHandler); ok {
|
||||
h.AssertExpectations(t)
|
||||
} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
|
||||
h.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -218,11 +239,11 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
var handlers []*MockHandler
|
||||
var handlers []*nbdns.MockHandler
|
||||
|
||||
// Setup handlers and expectations
|
||||
for i := range tt.handlers {
|
||||
handler := &MockHandler{}
|
||||
handler := &nbdns.MockHandler{}
|
||||
handlers = append(handlers, handler)
|
||||
|
||||
// Set expectation based on whether this handler should be called
|
||||
@ -254,9 +275,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
// Create handlers
|
||||
handler1 := &MockHandler{}
|
||||
handler2 := &MockHandler{}
|
||||
handler3 := &MockHandler{}
|
||||
handler1 := &nbdns.MockHandler{}
|
||||
handler2 := &nbdns.MockHandler{}
|
||||
handler3 := &nbdns.MockHandler{}
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||
@ -388,12 +409,12 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlers := make(map[int]*MockHandler)
|
||||
handlers := make(map[int]*nbdns.MockHandler)
|
||||
|
||||
// Execute operations
|
||||
for _, op := range tt.ops {
|
||||
if op.action == "add" {
|
||||
handler := &MockHandler{}
|
||||
handler := &nbdns.MockHandler{}
|
||||
handlers[op.priority] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
} else {
|
||||
@ -440,10 +461,10 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
testDomain := "example.com."
|
||||
testQuery := "test.example.com."
|
||||
|
||||
// Create handlers for three priority levels
|
||||
routeHandler := &MockHandler{}
|
||||
matchHandler := &MockHandler{}
|
||||
defaultHandler := &MockHandler{}
|
||||
// Create handlers with MatchSubdomains enabled
|
||||
routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||
matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||
defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||
|
||||
// Create test request that will be reused
|
||||
r := new(dns.Msg)
|
||||
|
@ -17,12 +17,24 @@ type localResolver struct {
|
||||
records sync.Map
|
||||
}
|
||||
|
||||
func (d *localResolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *localResolver) stop() {
|
||||
}
|
||||
|
||||
// String returns a string representation of the local resolver
|
||||
func (d *localResolver) String() string {
|
||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
log.Tracef("received question: %#v", r.Question[0])
|
||||
if len(r.Question) > 0 {
|
||||
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
}
|
||||
|
||||
replyMessage := &dns.Msg{}
|
||||
replyMessage.SetReply(r)
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
@ -11,7 +11,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
@ -874,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// MockHandler implements dns.Handler interface for testing
|
||||
type MockHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.Called(w, r)
|
||||
}
|
||||
|
||||
type MockSubdomainHandler struct {
|
||||
MockHandler
|
||||
Subdomains bool
|
||||
}
|
||||
|
||||
func (m *MockSubdomainHandler) MatchSubdomains() bool {
|
||||
return m.Subdomains
|
||||
}
|
||||
|
||||
func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||
chain := NewHandlerChain()
|
||||
|
||||
dnsRouteHandler := &MockHandler{}
|
||||
upstreamHandler := &MockSubdomainHandler{
|
||||
Subdomains: true,
|
||||
}
|
||||
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
query string
|
||||
expectedHandler dns.Handler
|
||||
}{
|
||||
{
|
||||
name: "exact domain with dns route handler",
|
||||
query: "example.com.",
|
||||
expectedHandler: dnsRouteHandler,
|
||||
},
|
||||
{
|
||||
name: "subdomain should use upstream handler",
|
||||
query: "sub.example.com.",
|
||||
expectedHandler: upstreamHandler,
|
||||
},
|
||||
{
|
||||
name: "deep subdomain should use upstream handler",
|
||||
query: "deep.sub.example.com.",
|
||||
expectedHandler: upstreamHandler,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tc.query, dns.TypeA)
|
||||
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||
mh.AssertExpectations(t)
|
||||
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||
mh.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Reset mocks
|
||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||
mh.ExpectedCalls = nil
|
||||
mh.Calls = nil
|
||||
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||
mh.ExpectedCalls = nil
|
||||
mh.Calls = nil
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -68,7 +68,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("%v", u.upstreamServers)
|
||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) stop() {
|
||||
|
@ -2,6 +2,7 @@ package dnsfwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@ -10,6 +11,8 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
@ -20,15 +23,16 @@ type DNSForwarder struct {
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d domains=%v", listenAddress, ttl, domains)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
domains: domains,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen() error {
|
||||
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
|
||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
for _, d := range f.domains {
|
||||
@ -67,7 +71,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
|
||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
|
||||
|
||||
question := query.Question[0]
|
||||
domain := question.Name
|
||||
@ -76,8 +81,26 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
// Pass through NXDOMAIN
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
@ -87,7 +110,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
for _, ip := range ips {
|
||||
var respRecord dns.RR
|
||||
if ip.To4() == nil {
|
||||
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
@ -99,7 +122,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
respRecord = &rr
|
||||
} else {
|
||||
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
|
@ -17,6 +17,11 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
)
|
||||
|
||||
type ResolvedDomainInfo struct {
|
||||
Prefixes []netip.Prefix
|
||||
ParentDomain domain.Domain
|
||||
}
|
||||
|
||||
// State contains the latest state of a peer
|
||||
type State struct {
|
||||
Mux *sync.RWMutex
|
||||
@ -138,7 +143,7 @@ type Status struct {
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
||||
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||
|
||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||
@ -156,7 +161,7 @@ func NewRecorder(mgmAddress string) *Status {
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
||||
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
|
||||
}
|
||||
}
|
||||
|
||||
@ -591,16 +596,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.resolvedDomainsStates[domain] = prefixes
|
||||
|
||||
// Store both the original domain pattern and resolved domain
|
||||
d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{
|
||||
Prefixes: prefixes,
|
||||
ParentDomain: originalDomain,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
delete(d.resolvedDomainsStates, domain)
|
||||
|
||||
// Remove all entries that have this domain as their parent
|
||||
for k, v := range d.resolvedDomainsStates {
|
||||
if v.ParentDomain == domain {
|
||||
delete(d.resolvedDomainsStates, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
@ -702,7 +718,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
||||
return d.nsGroupStates
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
|
@ -442,5 +442,5 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
|
||||
if useNewDNSRoute {
|
||||
return handlerTypeDomain
|
||||
}
|
||||
return handlerTypeStatic
|
||||
return handlerTypeDynamic
|
||||
}
|
||||
|
@ -83,6 +83,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
}
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
}
|
||||
for _, domain := range d.route.Domains {
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
@ -138,14 +140,16 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request: %v", r.Question[0].Name)
|
||||
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
|
||||
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
||||
|
||||
d.continueToNextHandler(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
@ -168,7 +172,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
@ -186,7 +190,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
|
||||
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
// Set Zero bit to signal handler chain to continue
|
||||
@ -210,8 +215,18 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
// DNS names from miekg/dns are already in punycode format
|
||||
dom := domain.Domain(r.Question[0].Name)
|
||||
origPattern := ""
|
||||
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||
origPattern = writer.GetOrigPattern()
|
||||
}
|
||||
|
||||
resolvedDomain := domain.Domain(r.Question[0].Name)
|
||||
|
||||
// already punycode via RegisterHandler()
|
||||
originalDomain := domain.Domain(origPattern)
|
||||
if originalDomain == "" {
|
||||
originalDomain = resolvedDomain
|
||||
}
|
||||
|
||||
var newPrefixes []netip.Prefix
|
||||
for _, answer := range r.Answer {
|
||||
@ -220,14 +235,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Debugf("failed to convert A record IP: %v", rr.A)
|
||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
|
||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
@ -240,7 +255,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
|
||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
}
|
||||
@ -253,11 +268,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
oldPrefixes := d.interceptedDomains[domain]
|
||||
oldPrefixes := d.interceptedDomains[resolvedDomain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
var merr *multierror.Error
|
||||
@ -277,7 +292,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
resolvedDomain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
@ -297,16 +312,23 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes
|
||||
}
|
||||
}
|
||||
|
||||
// Update domain prefixes
|
||||
// Update domain prefixes using resolved domain as key
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
d.interceptedDomains[domain] = newPrefixes
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
|
||||
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
|
||||
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
|
||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 {
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
|
||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||
resolvedDomain.SafeString(),
|
||||
originalDomain.SafeString(),
|
||||
toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
|
||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||
r.dynamicDomains[domain] = updatedPrefixes
|
||||
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
|
@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
|
||||
}
|
||||
|
||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
|
||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||
var routeSelection []RoutesSelectionInfo
|
||||
for _, r := range routes {
|
||||
domainList := make([]DomainInfo, 0)
|
||||
@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
||||
domainResp := DomainInfo{
|
||||
Domain: d.SafeString(),
|
||||
}
|
||||
if prefixes, exists := resolvedDomains[d]; exists {
|
||||
|
||||
if info, exists := resolvedDomains[d]; exists {
|
||||
var ipStrings []string
|
||||
for _, prefix := range prefixes {
|
||||
for _, prefix := range info.Prefixes {
|
||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||
}
|
||||
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
@ -77,17 +78,27 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
Selected: route.Selected,
|
||||
}
|
||||
|
||||
for _, domain := range route.Domains {
|
||||
if prefixes, exists := resolvedDomains[domain]; exists {
|
||||
var ipStrings []string
|
||||
for _, prefix := range prefixes {
|
||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||
}
|
||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||
Ips: ipStrings,
|
||||
// Group resolved IPs by their parent domain
|
||||
domainMap := map[domain.Domain][]string{}
|
||||
|
||||
for resolvedDomain, info := range resolvedDomains {
|
||||
// Check if this resolved domain's parent is in our route's domains
|
||||
if slices.Contains(route.Domains, info.ParentDomain) {
|
||||
ips := make([]string, 0, len(info.Prefixes))
|
||||
for _, prefix := range info.Prefixes {
|
||||
ips = append(ips, prefix.Addr().String())
|
||||
}
|
||||
domainMap[resolvedDomain] = ips
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to proto format
|
||||
for domain, ips := range domainMap {
|
||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||
Ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
pbRoutes = append(pbRoutes, pbRoute)
|
||||
}
|
||||
|
||||
|
@ -129,10 +129,8 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) {
|
||||
grid.Add(domainsSelector)
|
||||
|
||||
var resolvedIPsList []string
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
|
||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||
}
|
||||
for domain, ipList := range r.GetResolvedIPs() {
|
||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||
}
|
||||
|
||||
if len(resolvedIPsList) == 0 {
|
||||
|
Loading…
Reference in New Issue
Block a user