mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 02:41:34 +01:00
Add handler chains (#3039)
--------- Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
parent
589456a393
commit
5fee069379
155
client/internal/dns/handler_chain.go
Normal file
155
client/internal/dns/handler_chain.go
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PriorityDNSRoute = 100
|
||||||
|
PriorityMatchDomain = 50
|
||||||
|
PriorityDefault = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerEntry struct {
|
||||||
|
Handler dns.Handler
|
||||||
|
Priority int
|
||||||
|
Pattern string
|
||||||
|
IsWildcard bool
|
||||||
|
StopHandler handlerWithStop
|
||||||
|
}
|
||||||
|
|
||||||
|
type HandlerChain struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
handlers []HandlerEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||||
|
type ResponseWriterChain struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
shouldContinue bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||||
|
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
|
||||||
|
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
|
||||||
|
w.shouldContinue = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.ResponseWriter.WriteMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandlerChain() *HandlerChain {
|
||||||
|
return &HandlerChain{
|
||||||
|
handlers: make([]HandlerEntry, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||||
|
if isWildcard {
|
||||||
|
pattern = pattern[2:]
|
||||||
|
}
|
||||||
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
|
log.Debugf("adding handler for pattern: %s (wildcard: %v) with priority %d", pattern, isWildcard, priority)
|
||||||
|
|
||||||
|
entry := HandlerEntry{
|
||||||
|
Handler: handler,
|
||||||
|
Priority: priority,
|
||||||
|
Pattern: pattern,
|
||||||
|
IsWildcard: isWildcard,
|
||||||
|
StopHandler: stopHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert handler in priority order
|
||||||
|
pos := 0
|
||||||
|
for i, h := range c.handlers {
|
||||||
|
if h.Priority < priority {
|
||||||
|
pos = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pos = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) RemoveHandler(pattern string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
pattern = dns.Fqdn(pattern)
|
||||||
|
for i, entry := range c.handlers {
|
||||||
|
if entry.Pattern == pattern {
|
||||||
|
if entry.StopHandler != nil {
|
||||||
|
entry.StopHandler.stop()
|
||||||
|
}
|
||||||
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
qname := r.Question[0].Name
|
||||||
|
log.Debugf("handling DNS request for %s", qname)
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
log.Debugf("current handlers (%d):", len(c.handlers))
|
||||||
|
for _, h := range c.handlers {
|
||||||
|
log.Debugf(" - pattern: %s, wildcard: %v, priority: %d", h.Pattern, h.IsWildcard, h.Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try handlers in priority order
|
||||||
|
for _, entry := range c.handlers {
|
||||||
|
var matched bool
|
||||||
|
switch {
|
||||||
|
case entry.Pattern == ".":
|
||||||
|
matched = true
|
||||||
|
case entry.IsWildcard:
|
||||||
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||||
|
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||||
|
default:
|
||||||
|
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
||||||
|
entry.Pattern, qname, entry.IsWildcard)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
||||||
|
entry.Pattern, qname, entry.IsWildcard)
|
||||||
|
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
||||||
|
entry.Handler.ServeDNS(chainWriter, r)
|
||||||
|
|
||||||
|
// If handler wants to continue, try next handler
|
||||||
|
if chainWriter.shouldContinue {
|
||||||
|
log.Debugf("handler requested continue to next handler")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No handler matched or all handlers passed
|
||||||
|
log.Debugf("no handler found for %s", qname)
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
319
client/internal/dns/handler_chain_test.go
Normal file
319
client/internal/dns/handler_chain_test.go
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
package dns_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockHandler implements dns.Handler interface for testing
|
||||||
|
type MockHandler struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
m.Called(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
|
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
// Create mock handlers for different priorities
|
||||||
|
defaultHandler := &MockHandler{}
|
||||||
|
matchDomainHandler := &MockHandler{}
|
||||||
|
dnsRouteHandler := &MockHandler{}
|
||||||
|
|
||||||
|
// Setup handlers with different priorities
|
||||||
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||||
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||||
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
// Create test writer
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// Setup expectations - only highest priority handler should be called
|
||||||
|
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify all expectations were met
|
||||||
|
dnsRouteHandler.AssertExpectations(t)
|
||||||
|
matchDomainHandler.AssertExpectations(t)
|
||||||
|
defaultHandler.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
|
||||||
|
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handlerDomain string
|
||||||
|
queryDomain string
|
||||||
|
isWildcard bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact match",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain with non-wildcard",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "sub.example.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard match",
|
||||||
|
handlerDomain: "*.example.com.",
|
||||||
|
queryDomain: "sub.example.com.",
|
||||||
|
isWildcard: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard no match on apex",
|
||||||
|
handlerDomain: "*.example.com.",
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
isWildcard: true,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone match",
|
||||||
|
handlerDomain: ".",
|
||||||
|
queryDomain: "anything.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match different domain",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "example.org.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
mockHandler := &MockHandler{}
|
||||||
|
|
||||||
|
pattern := tt.handlerDomain
|
||||||
|
if tt.isWildcard {
|
||||||
|
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
if tt.shouldMatch {
|
||||||
|
mockHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
mockHandler.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
|
||||||
|
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handlers []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}
|
||||||
|
queryDomain string
|
||||||
|
expectedCalls int
|
||||||
|
expectedHandler int // index of the handler that should be called
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wildcard and exact same priority - exact should win",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
},
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1, // exact match handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "higher priority wildcard over lower priority exact",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "test.example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1, // higher priority wildcard handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple wildcards different priorities",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "test.example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 2, // highest priority handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain with mix of patterns",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||||
|
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "sub.test.example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 2, // highest priority matching handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone with specific domain",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1, // higher priority specific domain should win over root
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
var handlers []*MockHandler
|
||||||
|
|
||||||
|
// Setup handlers and expectations
|
||||||
|
for i := range tt.handlers {
|
||||||
|
handler := &MockHandler{}
|
||||||
|
handlers = append(handlers, handler)
|
||||||
|
|
||||||
|
// Set expectation based on whether this handler should be called
|
||||||
|
if i == tt.expectedHandler {
|
||||||
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||||
|
} else {
|
||||||
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and execute request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify expectations
|
||||||
|
for _, handler := range handlers {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
|
||||||
|
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
// Create handlers
|
||||||
|
handler1 := &MockHandler{}
|
||||||
|
handler2 := &MockHandler{}
|
||||||
|
handler3 := &MockHandler{}
|
||||||
|
|
||||||
|
// Add handlers in priority order
|
||||||
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||||
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||||
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
// Setup mock responses to simulate chain continuation
|
||||||
|
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
// First handler signals continue
|
||||||
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
resp.MsgHdr.Zero = true // Signal to continue
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
|
||||||
|
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
// Second handler signals continue
|
||||||
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
resp.MsgHdr.Zero = true
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
|
||||||
|
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
// Last handler responds normally
|
||||||
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify all handlers were called in order
|
||||||
|
handler1.AssertExpectations(t)
|
||||||
|
handler2.AssertExpectations(t)
|
||||||
|
handler3.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockResponseWriter implements dns.ResponseWriter for testing
|
||||||
|
type mockResponseWriter struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||||
|
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
func (m *mockResponseWriter) Close() error { return nil }
|
||||||
|
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||||
|
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||||
|
func (m *mockResponseWriter) Hijack() {}
|
@ -13,27 +13,20 @@ type MockServer struct {
|
|||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
RegisterHandlerFunc func([]string, dns.Handler) error
|
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||||
UnregisterHandlerFunc func([]string) error
|
DeregisterHandlerFunc func([]string)
|
||||||
DeregisterHandlerFunc func([]string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) UnregisterHandler(domains []string) error {
|
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error {
|
|
||||||
if m.RegisterHandlerFunc != nil {
|
if m.RegisterHandlerFunc != nil {
|
||||||
return m.RegisterHandlerFunc(domains, handler)
|
m.RegisterHandlerFunc(domains, handler, priority)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) DeregisterHandler(domains []string) error {
|
func (m *MockServer) DeregisterHandler(domains []string) {
|
||||||
if m.DeregisterHandlerFunc != nil {
|
if m.DeregisterHandlerFunc != nil {
|
||||||
return m.DeregisterHandlerFunc(domains)
|
m.DeregisterHandlerFunc(domains)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize mock implementation of Initialize from Server interface
|
// Initialize mock implementation of Initialize from Server interface
|
||||||
|
@ -30,7 +30,8 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains []string, handler dns.Handler) error
|
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||||
|
DeregisterHandler(domains []string)
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@ -38,7 +39,6 @@ type Server interface {
|
|||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(strings []string)
|
||||||
SearchDomains() []string
|
SearchDomains() []string
|
||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
DeregisterHandler(domains []string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[string]handlerWithStop
|
type registeredHandlerMap map[string]handlerWithStop
|
||||||
@ -56,6 +56,7 @@ type DefaultServer struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
|
handlerChain *HandlerChain
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@ -76,8 +77,9 @@ type handlerWithStop interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type muxUpdate struct {
|
type muxUpdate struct {
|
||||||
domain string
|
domain string
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
@ -137,10 +139,11 @@ func NewDefaultServerIos(
|
|||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
handlerChain: NewHandlerChain(),
|
||||||
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@ -153,32 +156,38 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error {
|
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
log.Debugf("registering handler %s", handler)
|
s.registerHandler(domains, handler, priority)
|
||||||
for _, domain := range domains {
|
|
||||||
wosuff, _ := strings.CutPrefix(domain, "*.")
|
|
||||||
pattern := dns.Fqdn(wosuff)
|
|
||||||
s.service.RegisterMux(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) DeregisterHandler(domains []string) error {
|
// registerhandler without lock
|
||||||
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||||
|
|
||||||
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains []string) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.deregisterHandler(domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterHandler(domains []string) {
|
||||||
log.Debugf("unregistering handler for domains %s", domains)
|
log.Debugf("unregistering handler for domains %s", domains)
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
wosuff, _ := strings.CutPrefix(domain, "*.")
|
s.handlerChain.RemoveHandler(domain)
|
||||||
pattern := dns.Fqdn(wosuff)
|
|
||||||
s.service.DeregisterMux(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize instantiate host manager and the dns service
|
// Initialize instantiate host manager and the dns service
|
||||||
@ -442,8 +451,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityDefault,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -459,8 +469,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||||
}
|
}
|
||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -474,7 +485,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
var isContainRootUpdate bool
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.service.RegisterMux(update.domain, update.handler)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
@ -493,7 +504,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
} else {
|
} else {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
s.service.DeregisterMux(key)
|
s.deregisterHandler([]string{key})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -547,13 +558,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.service.DeregisterMux(nbdns.RootZone)
|
s.deregisterHandler([]string{nbdns.RootZone})
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.service.DeregisterMux(item.Domain)
|
s.deregisterHandler([]string{item.Domain})
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -584,7 +595,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.service.RegisterMux(domain, handler)
|
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@ -592,7 +603,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
@ -623,7 +634,8 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
}
|
}
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
|
||||||
|
s.RegisterHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||||
|
@ -68,7 +68,6 @@ func (s *ServiceViaMemory) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
log.Debugf("registering dns handler for pattern: %s", pattern)
|
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the upstream resolver
|
||||||
|
func (u *upstreamResolverBase) String() string {
|
||||||
|
return fmt.Sprintf("%v", u.upstreamServers)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
@ -18,6 +20,7 @@ type DNSForwarder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||||
|
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
@ -29,7 +32,7 @@ func (f *DNSForwarder) Listen() error {
|
|||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
for _, d := range f.domains {
|
for _, d := range f.domains {
|
||||||
mux.HandleFunc(d, f.handleDNSQuery)
|
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
dnsServer := &dns.Server{
|
||||||
@ -47,8 +50,8 @@ func (f *DNSForwarder) UpdateDomains(domains []string) {
|
|||||||
f.mux.HandleRemove(d)
|
f.mux.HandleRemove(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, d := range domains {
|
for _, d := range f.domains {
|
||||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||||
}
|
}
|
||||||
f.domains = domains
|
f.domains = domains
|
||||||
}
|
}
|
||||||
@ -61,10 +64,10 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
log.Tracef("received DNS query for DNS forwarder: %v", query)
|
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
|
||||||
|
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
domain := question.Name
|
domain := question.Name
|
||||||
@ -74,16 +77,17 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
ips, err := net.LookupIP(domain)
|
ips, err := net.LookupIP(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
_ = w.WriteMsg(resp)
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
log.Infof("resolved domain %s to IP %s", domain, ip)
|
|
||||||
var respRecord dns.RR
|
var respRecord dns.RR
|
||||||
if ip.To4() == nil {
|
if ip.To4() == nil {
|
||||||
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
|
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
|
||||||
rr := dns.AAAA{
|
rr := dns.AAAA{
|
||||||
AAAA: ip,
|
AAAA: ip,
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
@ -95,6 +99,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
respRecord = &rr
|
respRecord = &rr
|
||||||
} else {
|
} else {
|
||||||
|
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
|
||||||
rr := dns.A{
|
rr := dns.A{
|
||||||
A: ip,
|
A: ip,
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
miekdns "github.com/miekg/dns"
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -103,9 +102,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
RegisterHandlerFunc: func(domains []string, handler miekdns.Handler) error { return nil },
|
|
||||||
DeregisterHandlerFunc: func(domains []string) error { return nil },
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var sshKeysAdded []string
|
var sshKeysAdded []string
|
||||||
|
@ -57,15 +57,12 @@ func New(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) String() string {
|
func (d *DnsInterceptor) String() string {
|
||||||
s, err := d.route.Domains.String()
|
return d.route.Domains.SafeString()
|
||||||
if err != nil {
|
|
||||||
return d.route.Domains.PunycodeString()
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||||
return d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d)
|
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) RemoveRoute() error {
|
func (d *DnsInterceptor) RemoveRoute() error {
|
||||||
@ -91,9 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
|
|||||||
|
|
||||||
clear(d.interceptedDomains)
|
clear(d.interceptedDomains)
|
||||||
|
|
||||||
if err := d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList()); err != nil {
|
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList())
|
||||||
merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
@ -143,23 +138,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("received DNS request: %v", r.Question[0].Name)
|
log.Tracef("received DNS request: %v", r.Question[0].Name)
|
||||||
|
|
||||||
if d.currentPeerKey == "" {
|
d.mu.RLock()
|
||||||
// TODO: call normal upstream instead of returning an error?
|
peerKey := d.currentPeerKey
|
||||||
log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name)
|
d.mu.RUnlock()
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
if peerKey == "" {
|
||||||
}
|
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
|
||||||
|
d.continueToNextHandler(w, r, "no current peer key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamIP, err := d.getUpstreamIP()
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to get upstream IP: %v", err)
|
log.Errorf("failed to get upstream IP: %v", err)
|
||||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||||
log.Errorf("failed writing DNS response: %v", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,7 +163,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer)
|
|
||||||
|
var answer []dns.RR
|
||||||
|
if reply != nil {
|
||||||
|
answer = reply.Answer
|
||||||
|
}
|
||||||
|
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||||
@ -185,13 +184,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
d.mu.RLock()
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||||
defer d.mu.RUnlock()
|
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
// Set Zero bit to signal handler chain to continue
|
||||||
|
resp.MsgHdr.Zero = true
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed writing DNS continue response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey)
|
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||||
|
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
|
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||||
}
|
}
|
||||||
return peerAllowedIP, nil
|
return peerAllowedIP, nil
|
||||||
}
|
}
|
||||||
|
@ -74,11 +74,7 @@ func NewRoute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) String() string {
|
func (r *Route) String() string {
|
||||||
s, err := r.route.Domains.String()
|
return r.route.Domains.SafeString()
|
||||||
if err != nil {
|
|
||||||
return r.route.Domains.PunycodeString()
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) AddRoute(ctx context.Context) error {
|
func (r *Route) AddRoute(ctx context.Context) error {
|
||||||
|
@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) {
|
|||||||
|
|
||||||
return validHost, nil
|
return validHost, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeZone returns a normalized domain name without the wildcard prefix
|
||||||
|
func NormalizeZone(domain string) string {
|
||||||
|
d, _ := strings.CutPrefix(domain, "*.")
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
1
go.mod
1
go.mod
@ -207,6 +207,7 @@ require (
|
|||||||
github.com/spf13/cast v1.5.0 // indirect
|
github.com/spf13/cast v1.5.0 // indirect
|
||||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
|
1
go.sum
1
go.sum
@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
Loading…
Reference in New Issue
Block a user