mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 18:31:18 +01:00
Add handler chains (#3039)
--------- Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
parent
589456a393
commit
5fee069379
155
client/internal/dns/handler_chain.go
Normal file
155
client/internal/dns/handler_chain.go
Normal file
@ -0,0 +1,155 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityDNSRoute = 100
|
||||
PriorityMatchDomain = 50
|
||||
PriorityDefault = 0
|
||||
)
|
||||
|
||||
type HandlerEntry struct {
|
||||
Handler dns.Handler
|
||||
Priority int
|
||||
Pattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
}
|
||||
|
||||
type HandlerChain struct {
|
||||
mu sync.RWMutex
|
||||
handlers []HandlerEntry
|
||||
}
|
||||
|
||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
shouldContinue bool
|
||||
}
|
||||
|
||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
|
||||
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
|
||||
w.shouldContinue = true
|
||||
return nil
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
func NewHandlerChain() *HandlerChain {
|
||||
return &HandlerChain{
|
||||
handlers: make([]HandlerEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||
if isWildcard {
|
||||
pattern = pattern[2:]
|
||||
}
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
log.Debugf("adding handler for pattern: %s (wildcard: %v) with priority %d", pattern, isWildcard, priority)
|
||||
|
||||
entry := HandlerEntry{
|
||||
Handler: handler,
|
||||
Priority: priority,
|
||||
Pattern: pattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
}
|
||||
|
||||
// Insert handler in priority order
|
||||
pos := 0
|
||||
for i, h := range c.handlers {
|
||||
if h.Priority < priority {
|
||||
pos = i
|
||||
break
|
||||
}
|
||||
pos = i + 1
|
||||
}
|
||||
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
}
|
||||
|
||||
func (c *HandlerChain) RemoveHandler(pattern string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
for i, entry := range c.handlers {
|
||||
if entry.Pattern == pattern {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
qname := r.Question[0].Name
|
||||
log.Debugf("handling DNS request for %s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
log.Debugf("current handlers (%d):", len(c.handlers))
|
||||
for _, h := range c.handlers {
|
||||
log.Debugf(" - pattern: %s, wildcard: %v, priority: %d", h.Pattern, h.IsWildcard, h.Priority)
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range c.handlers {
|
||||
var matched bool
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
matched = true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
default:
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
}
|
||||
|
||||
if !matched {
|
||||
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
||||
entry.Pattern, qname, entry.IsWildcard)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
||||
entry.Pattern, qname, entry.IsWildcard)
|
||||
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
log.Debugf("handler requested continue to next handler")
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
log.Debugf("no handler found for %s", qname)
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
319
client/internal/dns/handler_chain_test.go
Normal file
319
client/internal/dns/handler_chain_test.go
Normal file
@ -0,0 +1,319 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
)
|
||||
|
||||
// MockHandler implements dns.Handler interface for testing
|
||||
type MockHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.Called(w, r)
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
// Create mock handlers for different priorities
|
||||
defaultHandler := &MockHandler{}
|
||||
matchDomainHandler := &MockHandler{}
|
||||
dnsRouteHandler := &MockHandler{}
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Create test writer
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// Setup expectations - only highest priority handler should be called
|
||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||
|
||||
// Execute
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify all expectations were met
|
||||
dnsRouteHandler.AssertExpectations(t)
|
||||
matchDomainHandler.AssertExpectations(t)
|
||||
defaultHandler.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
|
||||
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlerDomain string
|
||||
queryDomain string
|
||||
isWildcard bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain with non-wildcard",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on apex",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: true,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "root zone match",
|
||||
handlerDomain: ".",
|
||||
queryDomain: "anything.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "no match different domain",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.org.",
|
||||
isWildcard: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
mockHandler := &MockHandler{}
|
||||
|
||||
pattern := tt.handlerDomain
|
||||
if tt.isWildcard {
|
||||
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
if tt.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
mockHandler.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
|
||||
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlers []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}
|
||||
queryDomain string
|
||||
expectedCalls int
|
||||
expectedHandler int // index of the handler that should be called
|
||||
}{
|
||||
{
|
||||
name: "wildcard and exact same priority - exact should win",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||
},
|
||||
queryDomain: "example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1, // exact match handler should be called
|
||||
},
|
||||
{
|
||||
name: "higher priority wildcard over lower priority exact",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "test.example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1, // higher priority wildcard handler should be called
|
||||
},
|
||||
{
|
||||
name: "multiple wildcards different priorities",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "test.example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority handler should be called
|
||||
},
|
||||
{
|
||||
name: "subdomain with mix of patterns",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "sub.test.example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority matching handler should be called
|
||||
},
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1, // higher priority specific domain should win over root
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
var handlers []*MockHandler
|
||||
|
||||
// Setup handlers and expectations
|
||||
for i := range tt.handlers {
|
||||
handler := &MockHandler{}
|
||||
handlers = append(handlers, handler)
|
||||
|
||||
// Set expectation based on whether this handler should be called
|
||||
if i == tt.expectedHandler {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||
} else {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||
}
|
||||
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||
}
|
||||
|
||||
// Create and execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify expectations
|
||||
for _, handler := range handlers {
|
||||
handler.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
|
||||
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
// Create handlers
|
||||
handler1 := &MockHandler{}
|
||||
handler2 := &MockHandler{}
|
||||
handler3 := &MockHandler{}
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Setup mock responses to simulate chain continuation
|
||||
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
// First handler signals continue
|
||||
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true // Signal to continue
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
|
||||
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
// Second handler signals continue
|
||||
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
|
||||
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
// Last handler responds normally
|
||||
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
|
||||
// Execute
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify all handlers were called in order
|
||||
handler1.AssertExpectations(t)
|
||||
handler2.AssertExpectations(t)
|
||||
handler3.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
||||
type mockResponseWriter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (m *mockResponseWriter) Close() error { return nil }
|
||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (m *mockResponseWriter) Hijack() {}
|
@ -13,27 +13,20 @@ type MockServer struct {
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler) error
|
||||
UnregisterHandlerFunc func([]string) error
|
||||
DeregisterHandlerFunc func([]string) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||
DeregisterHandlerFunc func([]string)
|
||||
}
|
||||
|
||||
func (m *MockServer) UnregisterHandler(domains []string) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error {
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
if m.RegisterHandlerFunc != nil {
|
||||
return m.RegisterHandlerFunc(domains, handler)
|
||||
m.RegisterHandlerFunc(domains, handler, priority)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockServer) DeregisterHandler(domains []string) error {
|
||||
func (m *MockServer) DeregisterHandler(domains []string) {
|
||||
if m.DeregisterHandlerFunc != nil {
|
||||
return m.DeregisterHandlerFunc(domains)
|
||||
m.DeregisterHandlerFunc(domains)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize mock implementation of Initialize from Server interface
|
||||
|
@ -30,7 +30,8 @@ type IosDnsManager interface {
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
RegisterHandler(domains []string, handler dns.Handler) error
|
||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains []string)
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() string
|
||||
@ -38,7 +39,6 @@ type Server interface {
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
SearchDomains() []string
|
||||
ProbeAvailability()
|
||||
DeregisterHandler(domains []string) error
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@ -56,6 +56,7 @@ type DefaultServer struct {
|
||||
updateSerial uint64
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
handlerChain *HandlerChain
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
@ -78,6 +79,7 @@ type handlerWithStop interface {
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler handlerWithStop
|
||||
priority int
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
@ -140,6 +142,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
@ -153,32 +156,38 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error {
|
||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
log.Debugf("registering handler %s", handler)
|
||||
for _, domain := range domains {
|
||||
wosuff, _ := strings.CutPrefix(domain, "*.")
|
||||
pattern := dns.Fqdn(wosuff)
|
||||
s.service.RegisterMux(pattern, handler)
|
||||
}
|
||||
|
||||
return nil
|
||||
s.registerHandler(domains, handler, priority)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) DeregisterHandler(domains []string) error {
|
||||
// registerhandler without lock
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||
|
||||
for _, domain := range domains {
|
||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) DeregisterHandler(domains []string) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.deregisterHandler(domains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string) {
|
||||
log.Debugf("unregistering handler for domains %s", domains)
|
||||
for _, domain := range domains {
|
||||
wosuff, _ := strings.CutPrefix(domain, "*.")
|
||||
pattern := dns.Fqdn(wosuff)
|
||||
s.service.DeregisterMux(pattern)
|
||||
}
|
||||
s.handlerChain.RemoveHandler(domain)
|
||||
|
||||
return nil
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager and the dns service
|
||||
@ -444,6 +453,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
priority: PriorityDefault,
|
||||
})
|
||||
continue
|
||||
}
|
||||
@ -461,6 +471,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -474,7 +485,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
var isContainRootUpdate bool
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.service.RegisterMux(update.domain, update.handler)
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
@ -493,7 +504,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
s.service.DeregisterMux(key)
|
||||
s.deregisterHandler([]string{key})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -547,13 +558,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.RouteAll = false
|
||||
s.service.DeregisterMux(nbdns.RootZone)
|
||||
s.deregisterHandler([]string{nbdns.RootZone})
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.service.DeregisterMux(item.Domain)
|
||||
s.deregisterHandler([]string{item.Domain})
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
@ -584,7 +595,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.service.RegisterMux(domain, handler)
|
||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@ -592,7 +603,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.RouteAll = true
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
@ -623,7 +634,8 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
}
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
|
||||
s.RegisterHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
|
@ -68,7 +68,6 @@ func (s *ServiceViaMemory) Stop() {
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
|
@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("%v", u.upstreamServers)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type DNSForwarder struct {
|
||||
@ -18,6 +20,7 @@ type DNSForwarder struct {
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
@ -29,7 +32,7 @@ func (f *DNSForwarder) Listen() error {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
for _, d := range f.domains {
|
||||
mux.HandleFunc(d, f.handleDNSQuery)
|
||||
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
@ -47,8 +50,8 @@ func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||
f.mux.HandleRemove(d)
|
||||
}
|
||||
|
||||
for _, d := range domains {
|
||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||
}
|
||||
f.domains = domains
|
||||
}
|
||||
@ -61,10 +64,10 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
log.Tracef("received DNS query for DNS forwarder: %v", query)
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
|
||||
|
||||
question := query.Question[0]
|
||||
domain := question.Name
|
||||
@ -74,16 +77,17 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
||||
resp.Rcode = dns.RcodeRefused
|
||||
_ = w.WriteMsg(resp)
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
log.Infof("resolved domain %s to IP %s", domain, ip)
|
||||
var respRecord dns.RR
|
||||
if ip.To4() == nil {
|
||||
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
|
||||
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
@ -95,6 +99,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
respRecord = &rr
|
||||
} else {
|
||||
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
miekdns "github.com/miekg/dns"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -104,8 +103,6 @@ func TestEngine_SSH(t *testing.T) {
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
RegisterHandlerFunc: func(domains []string, handler miekdns.Handler) error { return nil },
|
||||
DeregisterHandlerFunc: func(domains []string) error { return nil },
|
||||
}
|
||||
|
||||
var sshKeysAdded []string
|
||||
|
@ -57,15 +57,12 @@ func New(
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) String() string {
|
||||
s, err := d.route.Domains.String()
|
||||
if err != nil {
|
||||
return d.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
return d.route.Domains.SafeString()
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
return d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d)
|
||||
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveRoute() error {
|
||||
@ -91,9 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
||||
|
||||
clear(d.interceptedDomains)
|
||||
|
||||
if err := d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList()); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err))
|
||||
}
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList())
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
@ -143,23 +138,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Debugf("received DNS request: %v", r.Question[0].Name)
|
||||
log.Tracef("received DNS request: %v", r.Question[0].Name)
|
||||
|
||||
if d.currentPeerKey == "" {
|
||||
// TODO: call normal upstream instead of returning an error?
|
||||
log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
|
||||
d.continueToNextHandler(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP()
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get upstream IP: %v", err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
@ -169,7 +163,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer)
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
@ -185,13 +184,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
// Set Zero bit to signal handler chain to continue
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed writing DNS continue response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey)
|
||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
|
||||
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||
}
|
||||
return peerAllowedIP, nil
|
||||
}
|
||||
|
@ -74,11 +74,7 @@ func NewRoute(
|
||||
}
|
||||
|
||||
func (r *Route) String() string {
|
||||
s, err := r.route.Domains.String()
|
||||
if err != nil {
|
||||
return r.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
return r.route.Domains.SafeString()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(ctx context.Context) error {
|
||||
|
@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) {
|
||||
|
||||
return validHost, nil
|
||||
}
|
||||
|
||||
// NormalizeZone returns a normalized domain name without the wildcard prefix
|
||||
func NormalizeZone(domain string) string {
|
||||
d, _ := strings.CutPrefix(domain, "*.")
|
||||
return d
|
||||
}
|
||||
|
1
go.mod
1
go.mod
@ -207,6 +207,7 @@ require (
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
|
1
go.sum
1
go.sum
@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
|
Loading…
Reference in New Issue
Block a user