mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-03 03:49:26 +01:00
Improve dns forwarder errors and improve domain anonymization (#3052)
* Improve dns forwarder errors and improve domain anonymization * Use original domain for dns states * Don't match subdomains for non-wildcard dns routes * Fix iOS * Add string representation for local resolver * Return correct handler for dynamic * Add dns server dns route + upstream handler test
This commit is contained in:
parent
228672aed2
commit
21ba6ad266
@ -21,6 +21,8 @@ type Anonymizer struct {
|
|||||||
currentAnonIPv6 netip.Addr
|
currentAnonIPv6 netip.Addr
|
||||||
startAnonIPv4 netip.Addr
|
startAnonIPv4 netip.Addr
|
||||||
startAnonIPv6 netip.Addr
|
startAnonIPv6 netip.Addr
|
||||||
|
|
||||||
|
domainKeyRegex *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
|||||||
currentAnonIPv6: startIPv6,
|
currentAnonIPv6: startIPv6,
|
||||||
startAnonIPv4: startIPv4,
|
startAnonIPv4: startIPv4,
|
||||||
startAnonIPv6: startIPv6,
|
startAnonIPv6: startIPv6,
|
||||||
|
|
||||||
|
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
|||||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
|
||||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||||
domainRegex := regexp.MustCompile(domainPattern)
|
parts := strings.SplitN(match, "=", 2)
|
||||||
|
|
||||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
|
||||||
parts := strings.Split(match, `"`)
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
domain := parts[1]
|
domain := parts[1]
|
||||||
if strings.HasSuffix(domain, anonTLD) {
|
if strings.HasSuffix(domain, anonTLD) {
|
||||||
return match
|
return match
|
||||||
}
|
}
|
||||||
randomDomain := generateRandomString(10) + anonTLD
|
return "domain=" + a.AnonymizeDomain(domain)
|
||||||
return strings.Replace(match, domain, randomDomain, 1)
|
|
||||||
}
|
}
|
||||||
return match
|
return match
|
||||||
})
|
})
|
||||||
|
@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
|
|||||||
|
|
||||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
original string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic domain with trailing content",
|
||||||
|
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Domain with trailing dot",
|
||||||
|
input: "domain=example.com. processing request with status=pending",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple domains in log",
|
||||||
|
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
|
||||||
|
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
|
||||||
|
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Already anonymized domain",
|
||||||
|
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
|
||||||
|
original: "", // nothing should be anonymized
|
||||||
|
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Subdomain with trailing dot",
|
||||||
|
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Handler chain pattern log",
|
||||||
|
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
for _, tc := range tests {
|
||||||
require.NotEqual(t, testLog, result)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.NotContains(t, result, "example.com")
|
result := anonymizer.AnonymizeDNSLogLine(tc.input)
|
||||||
|
if tc.original != "" {
|
||||||
|
assert.NotContains(t, result, tc.original)
|
||||||
|
}
|
||||||
|
assert.Regexp(t, tc.expect, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymizeDomain(t *testing.T) {
|
func TestAnonymizeDomain(t *testing.T) {
|
||||||
|
@ -68,19 +68,19 @@ func networksList(cmd *cobra.Command, _ []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
printRoutes(cmd, resp)
|
printNetworks(cmd, resp)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func printRoutes(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||||
cmd.Println("Available Networks:")
|
cmd.Println("Available Networks:")
|
||||||
for _, route := range resp.Routes {
|
for _, route := range resp.Routes {
|
||||||
printRoute(cmd, route)
|
printNetwork(cmd, route)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func printRoute(cmd *cobra.Command, route *proto.Network) {
|
func printNetwork(cmd *cobra.Command, route *proto.Network) {
|
||||||
selectedStatus := getSelectedStatus(route)
|
selectedStatus := getSelectedStatus(route)
|
||||||
domains := route.GetDomains()
|
domains := route.GetDomains()
|
||||||
|
|
||||||
@ -113,12 +113,10 @@ func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus
|
|||||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
|
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
|
||||||
cmd.Printf(" Resolved IPs:\n")
|
cmd.Printf(" Resolved IPs:\n")
|
||||||
for _, domain := range domains {
|
for resolvedDomain, ipList := range resolvedIPs {
|
||||||
if ipList, exists := resolvedIPs[domain]; exists {
|
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
|
||||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,6 +14,11 @@ const (
|
|||||||
PriorityDefault = 0
|
PriorityDefault = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type SubdomainMatcher interface {
|
||||||
|
dns.Handler
|
||||||
|
MatchSubdomains() bool
|
||||||
|
}
|
||||||
|
|
||||||
type HandlerEntry struct {
|
type HandlerEntry struct {
|
||||||
Handler dns.Handler
|
Handler dns.Handler
|
||||||
Priority int
|
Priority int
|
||||||
@ -21,6 +26,7 @@ type HandlerEntry struct {
|
|||||||
OrigPattern string
|
OrigPattern string
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
StopHandler handlerWithStop
|
StopHandler handlerWithStop
|
||||||
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerChain represents a prioritized chain of DNS handlers
|
// HandlerChain represents a prioritized chain of DNS handlers
|
||||||
@ -32,6 +38,7 @@ type HandlerChain struct {
|
|||||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||||
type ResponseWriterChain struct {
|
type ResponseWriterChain struct {
|
||||||
dns.ResponseWriter
|
dns.ResponseWriter
|
||||||
|
origPattern string
|
||||||
shouldContinue bool
|
shouldContinue bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,6 +57,11 @@ func NewHandlerChain() *HandlerChain {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
||||||
|
func (w *ResponseWriterChain) GetOrigPattern() string {
|
||||||
|
return w.origPattern
|
||||||
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@ -74,8 +86,14 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
|
// Check if handler implements SubdomainMatcher interface
|
||||||
pattern, origPattern, isWildcard, priority)
|
matchSubdomains := false
|
||||||
|
if matcher, ok := handler.(SubdomainMatcher); ok {
|
||||||
|
matchSubdomains = matcher.MatchSubdomains()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
|
pattern, origPattern, isWildcard, matchSubdomains, priority)
|
||||||
|
|
||||||
entry := HandlerEntry{
|
entry := HandlerEntry{
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
@ -84,6 +102,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
OrigPattern: origPattern,
|
OrigPattern: origPattern,
|
||||||
IsWildcard: isWildcard,
|
IsWildcard: isWildcard,
|
||||||
StopHandler: stopHandler,
|
StopHandler: stopHandler,
|
||||||
|
MatchSubdomains: matchSubdomains,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert handler in priority order
|
// Insert handler in priority order
|
||||||
@ -139,14 +158,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
qname := r.Question[0].Name
|
qname := r.Question[0].Name
|
||||||
log.Debugf("handling DNS request for %s", qname)
|
log.Tracef("handling DNS request for domain=%s", qname)
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
log.Debugf("current handlers (%d):", len(c.handlers))
|
log.Tracef("current handlers (%d):", len(c.handlers))
|
||||||
for _, h := range c.handlers {
|
for _, h := range c.handlers {
|
||||||
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
|
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,30 +179,41 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||||
default:
|
default:
|
||||||
|
// For non-wildcard patterns:
|
||||||
|
// If handler wants subdomain matching, allow suffix match
|
||||||
|
// Otherwise require exact match
|
||||||
|
if entry.MatchSubdomains {
|
||||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
|
} else {
|
||||||
|
matched = qname == entry.Pattern
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !matched {
|
if !matched {
|
||||||
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
|
||||||
entry.OrigPattern, qname, entry.IsWildcard)
|
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
|
||||||
entry.OrigPattern, qname, entry.IsWildcard)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
|
||||||
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
|
||||||
|
chainWriter := &ResponseWriterChain{
|
||||||
|
ResponseWriter: w,
|
||||||
|
origPattern: entry.OrigPattern,
|
||||||
|
}
|
||||||
entry.Handler.ServeDNS(chainWriter, r)
|
entry.Handler.ServeDNS(chainWriter, r)
|
||||||
|
|
||||||
// If handler wants to continue, try next handler
|
// If handler wants to continue, try next handler
|
||||||
if chainWriter.shouldContinue {
|
if chainWriter.shouldContinue {
|
||||||
log.Debugf("handler requested continue to next handler")
|
log.Tracef("handler requested continue to next handler")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// No handler matched or all handlers passed
|
// No handler matched or all handlers passed
|
||||||
log.Debugf("no handler found for %s", qname)
|
log.Tracef("no handler found for domain=%s", qname)
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetRcode(r, dns.RcodeNameError)
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
@ -11,23 +11,14 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockHandler implements dns.Handler interface for testing
|
|
||||||
type MockHandler struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
m.Called(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
// Create mock handlers for different priorities
|
// Create mock handlers for different priorities
|
||||||
defaultHandler := &MockHandler{}
|
defaultHandler := &nbdns.MockHandler{}
|
||||||
matchDomainHandler := &MockHandler{}
|
matchDomainHandler := &nbdns.MockHandler{}
|
||||||
dnsRouteHandler := &MockHandler{}
|
dnsRouteHandler := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Setup handlers with different priorities
|
// Setup handlers with different priorities
|
||||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||||
@ -62,6 +53,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
handlerDomain string
|
handlerDomain string
|
||||||
queryDomain string
|
queryDomain string
|
||||||
isWildcard bool
|
isWildcard bool
|
||||||
|
matchSubdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@ -69,20 +61,31 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
handlerDomain: "example.com.",
|
handlerDomain: "example.com.",
|
||||||
queryDomain: "example.com.",
|
queryDomain: "example.com.",
|
||||||
isWildcard: false,
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
shouldMatch: true,
|
shouldMatch: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "subdomain with non-wildcard",
|
name: "subdomain with non-wildcard and MatchSubdomains true",
|
||||||
handlerDomain: "example.com.",
|
handlerDomain: "example.com.",
|
||||||
queryDomain: "sub.example.com.",
|
queryDomain: "sub.example.com.",
|
||||||
isWildcard: false,
|
isWildcard: false,
|
||||||
|
matchSubdomains: true,
|
||||||
shouldMatch: true,
|
shouldMatch: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain with non-wildcard and MatchSubdomains false",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "sub.example.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "wildcard match",
|
name: "wildcard match",
|
||||||
handlerDomain: "*.example.com.",
|
handlerDomain: "*.example.com.",
|
||||||
queryDomain: "sub.example.com.",
|
queryDomain: "sub.example.com.",
|
||||||
isWildcard: true,
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
shouldMatch: true,
|
shouldMatch: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -90,6 +93,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
handlerDomain: "*.example.com.",
|
handlerDomain: "*.example.com.",
|
||||||
queryDomain: "example.com.",
|
queryDomain: "example.com.",
|
||||||
isWildcard: true,
|
isWildcard: true,
|
||||||
|
matchSubdomains: false,
|
||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -97,6 +101,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
handlerDomain: ".",
|
handlerDomain: ".",
|
||||||
queryDomain: "anything.com.",
|
queryDomain: "anything.com.",
|
||||||
isWildcard: false,
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
shouldMatch: true,
|
shouldMatch: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -104,6 +109,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
handlerDomain: "example.com.",
|
handlerDomain: "example.com.",
|
||||||
queryDomain: "example.org.",
|
queryDomain: "example.org.",
|
||||||
isWildcard: false,
|
isWildcard: false,
|
||||||
|
matchSubdomains: false,
|
||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -111,25 +117,40 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
mockHandler := &MockHandler{}
|
var handler dns.Handler
|
||||||
|
|
||||||
|
if tt.matchSubdomains {
|
||||||
|
mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||||
|
handler = mockSubHandler
|
||||||
|
if tt.shouldMatch {
|
||||||
|
mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mockHandler := &nbdns.MockHandler{}
|
||||||
|
handler = mockHandler
|
||||||
|
if tt.shouldMatch {
|
||||||
|
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pattern := tt.handlerDomain
|
pattern := tt.handlerDomain
|
||||||
if tt.isWildcard {
|
if tt.isWildcard {
|
||||||
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
|
pattern = "*." + tt.handlerDomain[2:]
|
||||||
}
|
}
|
||||||
|
|
||||||
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
if tt.shouldMatch {
|
|
||||||
mockHandler.On("ServeDNS", mock.Anything, r).Once()
|
|
||||||
}
|
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
mockHandler.AssertExpectations(t)
|
|
||||||
|
if h, ok := handler.(*nbdns.MockHandler); ok {
|
||||||
|
h.AssertExpectations(t)
|
||||||
|
} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
|
||||||
|
h.AssertExpectations(t)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -218,11 +239,11 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
var handlers []*MockHandler
|
var handlers []*nbdns.MockHandler
|
||||||
|
|
||||||
// Setup handlers and expectations
|
// Setup handlers and expectations
|
||||||
for i := range tt.handlers {
|
for i := range tt.handlers {
|
||||||
handler := &MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
handlers = append(handlers, handler)
|
handlers = append(handlers, handler)
|
||||||
|
|
||||||
// Set expectation based on whether this handler should be called
|
// Set expectation based on whether this handler should be called
|
||||||
@ -254,9 +275,9 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
// Create handlers
|
// Create handlers
|
||||||
handler1 := &MockHandler{}
|
handler1 := &nbdns.MockHandler{}
|
||||||
handler2 := &MockHandler{}
|
handler2 := &nbdns.MockHandler{}
|
||||||
handler3 := &MockHandler{}
|
handler3 := &nbdns.MockHandler{}
|
||||||
|
|
||||||
// Add handlers in priority order
|
// Add handlers in priority order
|
||||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||||
@ -388,12 +409,12 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
handlers := make(map[int]*MockHandler)
|
handlers := make(map[int]*nbdns.MockHandler)
|
||||||
|
|
||||||
// Execute operations
|
// Execute operations
|
||||||
for _, op := range tt.ops {
|
for _, op := range tt.ops {
|
||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
handler := &MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
handlers[op.priority] = handler
|
handlers[op.priority] = handler
|
||||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||||
} else {
|
} else {
|
||||||
@ -440,10 +461,10 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
testDomain := "example.com."
|
testDomain := "example.com."
|
||||||
testQuery := "test.example.com."
|
testQuery := "test.example.com."
|
||||||
|
|
||||||
// Create handlers for three priority levels
|
// Create handlers with MatchSubdomains enabled
|
||||||
routeHandler := &MockHandler{}
|
routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||||
matchHandler := &MockHandler{}
|
matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||||
defaultHandler := &MockHandler{}
|
defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
|
||||||
|
|
||||||
// Create test request that will be reused
|
// Create test request that will be reused
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
|
@ -17,12 +17,24 @@ type localResolver struct {
|
|||||||
records sync.Map
|
records sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *localResolver) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (d *localResolver) stop() {
|
func (d *localResolver) stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the local resolver
|
||||||
|
func (d *localResolver) String() string {
|
||||||
|
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
log.Tracef("received question: %#v", r.Question[0])
|
if len(r.Question) > 0 {
|
||||||
|
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
}
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
replyMessage.SetReply(r)
|
replyMessage.SetReply(r)
|
||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
@ -11,7 +11,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
@ -874,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockHandler implements dns.Handler interface for testing
|
||||||
|
type MockHandler struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
m.Called(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSubdomainHandler struct {
|
||||||
|
MockHandler
|
||||||
|
Subdomains bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSubdomainHandler) MatchSubdomains() bool {
|
||||||
|
return m.Subdomains
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_DomainPriorities(t *testing.T) {
|
||||||
|
chain := NewHandlerChain()
|
||||||
|
|
||||||
|
dnsRouteHandler := &MockHandler{}
|
||||||
|
upstreamHandler := &MockSubdomainHandler{
|
||||||
|
Subdomains: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
|
||||||
|
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
expectedHandler dns.Handler
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact domain with dns route handler",
|
||||||
|
query: "example.com.",
|
||||||
|
expectedHandler: dnsRouteHandler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain should use upstream handler",
|
||||||
|
query: "sub.example.com.",
|
||||||
|
expectedHandler: upstreamHandler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep subdomain should use upstream handler",
|
||||||
|
query: "deep.sub.example.com.",
|
||||||
|
expectedHandler: upstreamHandler,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tc.query, dns.TypeA)
|
||||||
|
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.AssertExpectations(t)
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset mocks
|
||||||
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
|
mh.ExpectedCalls = nil
|
||||||
|
mh.Calls = nil
|
||||||
|
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
|
||||||
|
mh.ExpectedCalls = nil
|
||||||
|
mh.Calls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -68,7 +68,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
|||||||
|
|
||||||
// String returns a string representation of the upstream resolver
|
// String returns a string representation of the upstream resolver
|
||||||
func (u *upstreamResolverBase) String() string {
|
func (u *upstreamResolverBase) String() string {
|
||||||
return fmt.Sprintf("%v", u.upstreamServers)
|
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) stop() {
|
||||||
|
@ -2,6 +2,7 @@ package dnsfwd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -10,6 +11,8 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress string
|
||||||
ttl uint32
|
ttl uint32
|
||||||
@ -20,15 +23,16 @@ type DNSForwarder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||||
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d domains=%v", listenAddress, ttl, domains)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
domains: domains,
|
domains: domains,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Listen() error {
|
func (f *DNSForwarder) Listen() error {
|
||||||
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
|
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
for _, d := range f.domains {
|
for _, d := range f.domains {
|
||||||
@ -67,7 +71,8 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||||
|
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
|
||||||
|
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
domain := question.Name
|
domain := question.Name
|
||||||
@ -76,8 +81,26 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
|
|
||||||
ips, err := net.LookupIP(domain)
|
ips, err := net.LookupIP(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.As(err, &dnsErr):
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
if dnsErr.IsNotFound {
|
||||||
|
// Pass through NXDOMAIN
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write failure DNS response: %v", err)
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
}
|
}
|
||||||
@ -87,7 +110,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var respRecord dns.RR
|
var respRecord dns.RR
|
||||||
if ip.To4() == nil {
|
if ip.To4() == nil {
|
||||||
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||||
rr := dns.AAAA{
|
rr := dns.AAAA{
|
||||||
AAAA: ip,
|
AAAA: ip,
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
@ -99,7 +122,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
respRecord = &rr
|
respRecord = &rr
|
||||||
} else {
|
} else {
|
||||||
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||||
rr := dns.A{
|
rr := dns.A{
|
||||||
A: ip,
|
A: ip,
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
|
@ -17,6 +17,11 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ResolvedDomainInfo struct {
|
||||||
|
Prefixes []netip.Prefix
|
||||||
|
ParentDomain domain.Domain
|
||||||
|
}
|
||||||
|
|
||||||
// State contains the latest state of a peer
|
// State contains the latest state of a peer
|
||||||
type State struct {
|
type State struct {
|
||||||
Mux *sync.RWMutex
|
Mux *sync.RWMutex
|
||||||
@ -138,7 +143,7 @@ type Status struct {
|
|||||||
rosenpassEnabled bool
|
rosenpassEnabled bool
|
||||||
rosenpassPermissive bool
|
rosenpassPermissive bool
|
||||||
nsGroupStates []NSGroupState
|
nsGroupStates []NSGroupState
|
||||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||||
|
|
||||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||||
@ -156,7 +161,7 @@ func NewRecorder(mgmAddress string) *Status {
|
|||||||
offlinePeers: make([]State, 0),
|
offlinePeers: make([]State, 0),
|
||||||
notifier: newNotifier(),
|
notifier: newNotifier(),
|
||||||
mgmAddress: mgmAddress,
|
mgmAddress: mgmAddress,
|
||||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -591,16 +596,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
|||||||
d.nsGroupStates = dnsStates
|
d.nsGroupStates = dnsStates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
d.resolvedDomainsStates[domain] = prefixes
|
|
||||||
|
// Store both the original domain pattern and resolved domain
|
||||||
|
d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{
|
||||||
|
Prefixes: prefixes,
|
||||||
|
ParentDomain: originalDomain,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
delete(d.resolvedDomainsStates, domain)
|
|
||||||
|
// Remove all entries that have this domain as their parent
|
||||||
|
for k, v := range d.resolvedDomainsStates {
|
||||||
|
if v.ParentDomain == domain {
|
||||||
|
delete(d.resolvedDomainsStates, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetRosenpassState() RosenpassState {
|
func (d *Status) GetRosenpassState() RosenpassState {
|
||||||
@ -702,7 +718,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
|||||||
return d.nsGroupStates
|
return d.nsGroupStates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
return maps.Clone(d.resolvedDomainsStates)
|
return maps.Clone(d.resolvedDomainsStates)
|
||||||
|
@ -442,5 +442,5 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
|
|||||||
if useNewDNSRoute {
|
if useNewDNSRoute {
|
||||||
return handlerTypeDomain
|
return handlerTypeDomain
|
||||||
}
|
}
|
||||||
return handlerTypeStatic
|
return handlerTypeDynamic
|
||||||
}
|
}
|
||||||
|
@ -83,6 +83,8 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
}
|
}
|
||||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||||
|
|
||||||
|
}
|
||||||
|
for _, domain := range d.route.Domains {
|
||||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,14 +140,16 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Tracef("received DNS request: %v", r.Question[0].Name)
|
log.Tracef("received DNS request for domain=%s type=%v class=%v",
|
||||||
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
peerKey := d.currentPeerKey
|
peerKey := d.currentPeerKey
|
||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
|
|
||||||
if peerKey == "" {
|
if peerKey == "" {
|
||||||
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
|
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
|
||||||
|
|
||||||
d.continueToNextHandler(w, r, "no current peer key")
|
d.continueToNextHandler(w, r, "no current peer key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -168,7 +172,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if reply != nil {
|
if reply != nil {
|
||||||
answer = reply.Answer
|
answer = reply.Answer
|
||||||
}
|
}
|
||||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||||
@ -186,7 +190,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// continueToNextHandler signals the handler chain to try the next handler
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||||
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
|
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
|
||||||
|
|
||||||
resp := new(dns.Msg)
|
resp := new(dns.Msg)
|
||||||
resp.SetRcode(r, dns.RcodeNameError)
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
// Set Zero bit to signal handler chain to continue
|
// Set Zero bit to signal handler chain to continue
|
||||||
@ -210,8 +215,18 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||||
// DNS names from miekg/dns are already in punycode format
|
origPattern := ""
|
||||||
dom := domain.Domain(r.Question[0].Name)
|
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
|
||||||
|
origPattern = writer.GetOrigPattern()
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedDomain := domain.Domain(r.Question[0].Name)
|
||||||
|
|
||||||
|
// already punycode via RegisterHandler()
|
||||||
|
originalDomain := domain.Domain(origPattern)
|
||||||
|
if originalDomain == "" {
|
||||||
|
originalDomain = resolvedDomain
|
||||||
|
}
|
||||||
|
|
||||||
var newPrefixes []netip.Prefix
|
var newPrefixes []netip.Prefix
|
||||||
for _, answer := range r.Answer {
|
for _, answer := range r.Answer {
|
||||||
@ -220,14 +235,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
case *dns.A:
|
case *dns.A:
|
||||||
addr, ok := netip.AddrFromSlice(rr.A)
|
addr, ok := netip.AddrFromSlice(rr.A)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debugf("failed to convert A record IP: %v", rr.A)
|
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip = addr
|
ip = addr
|
||||||
case *dns.AAAA:
|
case *dns.AAAA:
|
||||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
|
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip = addr
|
ip = addr
|
||||||
@ -240,7 +255,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(newPrefixes) > 0 {
|
if len(newPrefixes) > 0 {
|
||||||
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
|
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
||||||
log.Errorf("failed to update domain prefixes: %v", err)
|
log.Errorf("failed to update domain prefixes: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -253,11 +268,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
|
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
oldPrefixes := d.interceptedDomains[domain]
|
oldPrefixes := d.interceptedDomains[resolvedDomain]
|
||||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
@ -277,7 +292,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes
|
|||||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
prefix.Addr(),
|
prefix.Addr(),
|
||||||
domain.SafeString(),
|
resolvedDomain.SafeString(),
|
||||||
ref.Out,
|
ref.Out,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -297,16 +312,23 @@ func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update domain prefixes
|
// Update domain prefixes using resolved domain as key
|
||||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||||
d.interceptedDomains[domain] = newPrefixes
|
d.interceptedDomains[resolvedDomain] = newPrefixes
|
||||||
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
|
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
|
||||||
|
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
|
||||||
|
|
||||||
if len(toAdd) > 0 {
|
if len(toAdd) > 0 {
|
||||||
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
|
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
resolvedDomain.SafeString(),
|
||||||
|
originalDomain.SafeString(),
|
||||||
|
toAdd)
|
||||||
}
|
}
|
||||||
if len(toRemove) > 0 {
|
if len(toRemove) > 0 {
|
||||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
|
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
|
resolvedDomain.SafeString(),
|
||||||
|
originalDomain.SafeString(),
|
||||||
|
toRemove)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
|
|||||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||||
r.dynamicDomains[domain] = updatedPrefixes
|
r.dynamicDomains[domain] = updatedPrefixes
|
||||||
|
|
||||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
|
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
|
||||||
var routeSelection []RoutesSelectionInfo
|
var routeSelection []RoutesSelectionInfo
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
domainList := make([]DomainInfo, 0)
|
domainList := make([]DomainInfo, 0)
|
||||||
@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
|
|||||||
domainResp := DomainInfo{
|
domainResp := DomainInfo{
|
||||||
Domain: d.SafeString(),
|
Domain: d.SafeString(),
|
||||||
}
|
}
|
||||||
if prefixes, exists := resolvedDomains[d]; exists {
|
|
||||||
|
if info, exists := resolvedDomains[d]; exists {
|
||||||
var ipStrings []string
|
var ipStrings []string
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range info.Prefixes {
|
||||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||||
}
|
}
|
||||||
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@ -77,17 +78,27 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
|||||||
Selected: route.Selected,
|
Selected: route.Selected,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range route.Domains {
|
// Group resolved IPs by their parent domain
|
||||||
if prefixes, exists := resolvedDomains[domain]; exists {
|
domainMap := map[domain.Domain][]string{}
|
||||||
var ipStrings []string
|
|
||||||
for _, prefix := range prefixes {
|
for resolvedDomain, info := range resolvedDomains {
|
||||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
// Check if this resolved domain's parent is in our route's domains
|
||||||
|
if slices.Contains(route.Domains, info.ParentDomain) {
|
||||||
|
ips := make([]string, 0, len(info.Prefixes))
|
||||||
|
for _, prefix := range info.Prefixes {
|
||||||
|
ips = append(ips, prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
domainMap[resolvedDomain] = ips
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to proto format
|
||||||
|
for domain, ips := range domainMap {
|
||||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||||
Ips: ipStrings,
|
Ips: ips,
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pbRoutes = append(pbRoutes, pbRoute)
|
pbRoutes = append(pbRoutes, pbRoute)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,11 +129,9 @@ func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) {
|
|||||||
grid.Add(domainsSelector)
|
grid.Add(domainsSelector)
|
||||||
|
|
||||||
var resolvedIPsList []string
|
var resolvedIPsList []string
|
||||||
for _, domain := range domains {
|
for domain, ipList := range r.GetResolvedIPs() {
|
||||||
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
|
|
||||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(resolvedIPsList) == 0 {
|
if len(resolvedIPsList) == 0 {
|
||||||
grid.Add(widget.NewLabel(""))
|
grid.Add(widget.NewLabel(""))
|
||||||
|
Loading…
Reference in New Issue
Block a user