netbird/client/internal/dns/handler_chain.go

193 lines
4.8 KiB
Go

package dns
import (
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
)
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
}
// HandlerChain represents a prioritized chain of DNS handlers
type HandlerChain struct {
mu sync.RWMutex
handlers []HandlerEntry
}
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
shouldContinue bool
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
w.shouldContinue = true
return nil
}
return w.ResponseWriter.WriteMsg(m)
}
func NewHandlerChain() *HandlerChain {
return &HandlerChain{
handlers: make([]HandlerEntry, 0),
}
}
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
c.mu.Lock()
defer c.mu.Unlock()
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
pattern, origPattern, isWildcard, priority)
entry := HandlerEntry{
Handler: handler,
Priority: priority,
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
}
// Insert handler in priority order
pos := 0
for i, h := range c.handlers {
if h.Priority < priority {
pos = i
break
}
pos = i + 1
}
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}
// RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
}
}
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern)
for _, entry := range c.handlers {
if entry.Pattern == pattern {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := r.Question[0].Name
log.Debugf("handling DNS request for %s", qname)
c.mu.RLock()
defer c.mu.RUnlock()
log.Debugf("current handlers (%d):", len(c.handlers))
for _, h := range c.handlers {
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
}
// Try handlers in priority order
for _, entry := range c.handlers {
var matched bool
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
}
if !matched {
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
entry.OrigPattern, qname, entry.IsWildcard)
continue
}
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
entry.OrigPattern, qname, entry.IsWildcard)
chainWriter := &ResponseWriterChain{ResponseWriter: w}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Debugf("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
log.Debugf("no handler found for %s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}