Add handler chains (#3039)

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
Viktor Liu 2024-12-12 18:19:06 +01:00 committed by GitHub
parent 589456a393
commit 5fee069379
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 588 additions and 91 deletions

View File

@ -0,0 +1,155 @@
package dns
import (
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
)
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
IsWildcard bool
StopHandler handlerWithStop
}
type HandlerChain struct {
mu sync.RWMutex
handlers []HandlerEntry
}
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
shouldContinue bool
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
w.shouldContinue = true
return nil
}
return w.ResponseWriter.WriteMsg(m)
}
func NewHandlerChain() *HandlerChain {
return &HandlerChain{
handlers: make([]HandlerEntry, 0),
}
}
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
c.mu.Lock()
defer c.mu.Unlock()
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
pattern = dns.Fqdn(pattern)
log.Debugf("adding handler for pattern: %s (wildcard: %v) with priority %d", pattern, isWildcard, priority)
entry := HandlerEntry{
Handler: handler,
Priority: priority,
Pattern: pattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
}
// Insert handler in priority order
pos := 0
for i, h := range c.handlers {
if h.Priority < priority {
pos = i
break
}
pos = i + 1
}
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}
func (c *HandlerChain) RemoveHandler(pattern string) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = dns.Fqdn(pattern)
for i, entry := range c.handlers {
if entry.Pattern == pattern {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
}
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := r.Question[0].Name
log.Debugf("handling DNS request for %s", qname)
c.mu.RLock()
defer c.mu.RUnlock()
log.Debugf("current handlers (%d):", len(c.handlers))
for _, h := range c.handlers {
log.Debugf(" - pattern: %s, wildcard: %v, priority: %d", h.Pattern, h.IsWildcard, h.Priority)
}
// Try handlers in priority order
for _, entry := range c.handlers {
var matched bool
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
}
if !matched {
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
entry.Pattern, qname, entry.IsWildcard)
continue
}
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
entry.Pattern, qname, entry.IsWildcard)
chainWriter := &ResponseWriterChain{ResponseWriter: w}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Debugf("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
log.Debugf("no handler found for %s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}

View File

@ -0,0 +1,319 @@
package dns_test
import (
"net"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
)
// MockHandler implements dns.Handler interface for testing
type MockHandler struct {
mock.Mock
}
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.Called(w, r)
}
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create mock handlers for different priorities
defaultHandler := &MockHandler{}
matchDomainHandler := &MockHandler{}
dnsRouteHandler := &MockHandler{}
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
// Execute
chain.ServeDNS(w, r)
// Verify all expectations were met
dnsRouteHandler.AssertExpectations(t)
matchDomainHandler.AssertExpectations(t)
defaultHandler.AssertExpectations(t)
}
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct {
name string
handlerDomain string
queryDomain string
isWildcard bool
shouldMatch bool
}{
{
name: "exact match",
handlerDomain: "example.com.",
queryDomain: "example.com.",
isWildcard: false,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
shouldMatch: true,
},
{
name: "wildcard match",
handlerDomain: "*.example.com.",
queryDomain: "sub.example.com.",
isWildcard: true,
shouldMatch: true,
},
{
name: "wildcard no match on apex",
handlerDomain: "*.example.com.",
queryDomain: "example.com.",
isWildcard: true,
shouldMatch: false,
},
{
name: "root zone match",
handlerDomain: ".",
queryDomain: "anything.com.",
isWildcard: false,
shouldMatch: true,
},
{
name: "no match different domain",
handlerDomain: "example.com.",
queryDomain: "example.org.",
isWildcard: false,
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
mockHandler := &MockHandler{}
pattern := tt.handlerDomain
if tt.isWildcard {
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
}
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
if tt.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, r).Once()
}
chain.ServeDNS(w, r)
mockHandler.AssertExpectations(t)
})
}
}
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct {
name string
handlers []struct {
pattern string
priority int
}
queryDomain string
expectedCalls int
expectedHandler int // index of the handler that should be called
}{
{
name: "wildcard and exact same priority - exact should win",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDefault},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // exact match handler should be called
},
{
name: "higher priority wildcard over lower priority exact",
handlers: []struct {
pattern string
priority int
}{
{pattern: "example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority wildcard handler should be called
},
{
name: "multiple wildcards different priorities",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority handler should be called
},
{
name: "subdomain with mix of patterns",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "root zone with specific domain",
handlers: []struct {
pattern string
priority int
}{
{pattern: ".", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority specific domain should win over root
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handlers []*MockHandler
// Setup handlers and expectations
for i := range tt.handlers {
handler := &MockHandler{}
handlers = append(handlers, handler)
// Set expectation based on whether this handler should be called
if i == tt.expectedHandler {
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
} else {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
}
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
}
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
})
}
}
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create handlers
handler1 := &MockHandler{}
handler2 := &MockHandler{}
handler3 := &MockHandler{}
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Setup mock responses to simulate chain continuation
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// First handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true // Signal to continue
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Second handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Last handler responds normally
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
// Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
handler1.AssertExpectations(t)
handler2.AssertExpectations(t)
handler3.AssertExpectations(t)
}
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}

View File

@ -13,27 +13,20 @@ type MockServer struct {
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler) error
UnregisterHandlerFunc func([]string) error
DeregisterHandlerFunc func([]string) error
RegisterHandlerFunc func([]string, dns.Handler, int)
DeregisterHandlerFunc func([]string)
}
func (m *MockServer) UnregisterHandler(domains []string) error {
panic("implement me")
}
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler) error {
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
if m.RegisterHandlerFunc != nil {
return m.RegisterHandlerFunc(domains, handler)
m.RegisterHandlerFunc(domains, handler, priority)
}
return nil
}
func (m *MockServer) DeregisterHandler(domains []string) error {
func (m *MockServer) DeregisterHandler(domains []string) {
if m.DeregisterHandlerFunc != nil {
return m.DeregisterHandlerFunc(domains)
m.DeregisterHandlerFunc(domains)
}
return nil
}
// Initialize mock implementation of Initialize from Server interface

View File

@ -30,7 +30,8 @@ type IosDnsManager interface {
// Server is a dns server interface
type Server interface {
RegisterHandler(domains []string, handler dns.Handler) error
RegisterHandler(domains []string, handler dns.Handler, priority int)
DeregisterHandler(domains []string)
Initialize() error
Stop()
DnsIP() string
@ -38,7 +39,6 @@ type Server interface {
OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string
ProbeAvailability()
DeregisterHandler(domains []string) error
}
type registeredHandlerMap map[string]handlerWithStop
@ -56,6 +56,7 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
handlerChain *HandlerChain
// permanent related properties
permanent bool
@ -78,6 +79,7 @@ type handlerWithStop interface {
type muxUpdate struct {
domain string
handler handlerWithStop
priority int
}
// NewDefaultServer returns a new dns server
@ -140,6 +142,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
ctx: ctx,
ctxCancel: stop,
service: dnsService,
handlerChain: NewHandlerChain(),
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
@ -153,32 +156,38 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
return defaultServer
}
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler) error {
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("registering handler %s", handler)
for _, domain := range domains {
wosuff, _ := strings.CutPrefix(domain, "*.")
pattern := dns.Fqdn(wosuff)
s.service.RegisterMux(pattern, handler)
}
return nil
s.registerHandler(domains, handler, priority)
}
func (s *DefaultServer) DeregisterHandler(domains []string) error {
// registerhandler without lock
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains {
s.handlerChain.AddHandler(domain, handler, priority, nil)
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
}
}
func (s *DefaultServer) DeregisterHandler(domains []string) {
s.mux.Lock()
defer s.mux.Unlock()
s.deregisterHandler(domains)
}
func (s *DefaultServer) deregisterHandler(domains []string) {
log.Debugf("unregistering handler for domains %s", domains)
for _, domain := range domains {
wosuff, _ := strings.CutPrefix(domain, "*.")
pattern := dns.Fqdn(wosuff)
s.service.DeregisterMux(pattern)
}
s.handlerChain.RemoveHandler(domain)
return nil
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
}
}
// Initialize instantiate host manager and the dns service
@ -444,6 +453,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
priority: PriorityDefault,
})
continue
}
@ -461,6 +471,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
priority: PriorityMatchDomain,
})
}
}
@ -474,7 +485,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
var isContainRootUpdate bool
for _, update := range muxUpdates {
s.service.RegisterMux(update.domain, update.handler)
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
@ -493,7 +504,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
existingHandler.stop()
} else {
existingHandler.stop()
s.service.DeregisterMux(key)
s.deregisterHandler([]string{key})
}
}
}
@ -547,13 +558,13 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false
s.service.DeregisterMux(nbdns.RootZone)
s.deregisterHandler([]string{nbdns.RootZone})
}
for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true
s.service.DeregisterMux(item.Domain)
s.deregisterHandler([]string{item.Domain})
removeIndex[item.Domain] = i
}
}
@ -584,7 +595,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue
}
s.currentConfig.Domains[i].Disabled = false
s.service.RegisterMux(domain, handler)
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
}
l := log.WithField("nameservers", nsGroup.NameServers)
@ -592,7 +603,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler)
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
}
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
@ -623,7 +634,8 @@ func (s *DefaultServer) addHostRootZone() {
}
handler.deactivate = func(error) {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
s.RegisterHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
}
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {

View File

@ -68,7 +68,6 @@ func (s *ServiceViaMemory) Stop() {
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
log.Debugf("registering dns handler for pattern: %s", pattern)
s.dnsMux.Handle(pattern, handler)
}

View File

@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("%v", u.upstreamServers)
}
func (u *upstreamResolverBase) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()

View File

@ -6,6 +6,8 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type DNSForwarder struct {
@ -18,6 +20,7 @@ type DNSForwarder struct {
}
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
@ -29,7 +32,7 @@ func (f *DNSForwarder) Listen() error {
mux := dns.NewServeMux()
for _, d := range f.domains {
mux.HandleFunc(d, f.handleDNSQuery)
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
}
dnsServer := &dns.Server{
@ -47,8 +50,8 @@ func (f *DNSForwarder) UpdateDomains(domains []string) {
f.mux.HandleRemove(d)
}
for _, d := range domains {
f.mux.HandleFunc(d, f.handleDNSQuery)
for _, d := range f.domains {
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
}
f.domains = domains
}
@ -61,10 +64,10 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
log.Tracef("received DNS query for DNS forwarder: %v", query)
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
question := query.Question[0]
domain := question.Name
@ -74,16 +77,17 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
ips, err := net.LookupIP(domain)
if err != nil {
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
resp.Rcode = dns.RcodeRefused
_ = w.WriteMsg(resp)
resp.Rcode = dns.RcodeServerFailure
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
return
}
for _, ip := range ips {
log.Infof("resolved domain %s to IP %s", domain, ip)
var respRecord dns.RR
if ip.To4() == nil {
log.Infof("resolved domain %s to IPv6 %s", domain, ip)
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
@ -95,6 +99,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
}
respRecord = &rr
} else {
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{

View File

@ -13,7 +13,6 @@ import (
"time"
"github.com/google/uuid"
miekdns "github.com/miekg/dns"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@ -104,8 +103,6 @@ func TestEngine_SSH(t *testing.T) {
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
RegisterHandlerFunc: func(domains []string, handler miekdns.Handler) error { return nil },
DeregisterHandlerFunc: func(domains []string) error { return nil },
}
var sshKeysAdded []string

View File

@ -57,15 +57,12 @@ func New(
}
func (d *DnsInterceptor) String() string {
s, err := d.route.Domains.String()
if err != nil {
return d.route.Domains.PunycodeString()
}
return s
return d.route.Domains.SafeString()
}
func (d *DnsInterceptor) AddRoute(context.Context) error {
return d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d)
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
return nil
}
func (d *DnsInterceptor) RemoveRoute() error {
@ -91,9 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
clear(d.interceptedDomains)
if err := d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList()); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unregister DNS handler: %v", err))
}
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList())
return nberrors.FormatErrorOrNil(merr)
}
@ -143,23 +138,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
log.Debugf("received DNS request: %v", r.Question[0].Name)
log.Tracef("received DNS request: %v", r.Question[0].Name)
if d.currentPeerKey == "" {
// TODO: call normal upstream instead of returning an error?
log.Debugf("no current peer key set, not resolving DNS request %s", r.Question[0].Name)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
if peerKey == "" {
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
d.continueToNextHandler(w, r, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP()
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
log.Errorf("failed to get upstream IP: %v", err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
return
}
@ -169,7 +163,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, d.currentPeerKey, r.Question[0].Name, reply.Answer)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
@ -185,13 +184,22 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}
func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
d.mu.RLock()
defer d.mu.RUnlock()
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err)
}
}
peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey)
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
}
return peerAllowedIP, nil
}

View File

@ -74,11 +74,7 @@ func NewRoute(
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
return r.route.Domains.SafeString()
}
func (r *Route) AddRoute(ctx context.Context) error {

View File

@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) {
return validHost, nil
}
// NormalizeZone returns a normalized domain name without the wildcard prefix
func NormalizeZone(domain string) string {
d, _ := strings.CutPrefix(domain, "*.")
return d
}

1
go.mod
View File

@ -207,6 +207,7 @@ require (
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect

1
go.sum
View File

@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=