Files
netbird/client/internal/dns/handler_chain.go
Viktor Liu 5fee069379 Add handler chains (#3039)
---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-12-12 18:19:06 +01:00

156 lines
3.6 KiB
Go

package dns
import (
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// Handler priorities. AddHandler keeps the chain sorted so that entries with a
// higher numeric priority are placed, and therefore tried by ServeDNS, before
// entries with a lower one.
// NOTE(review): the intended use of each level is inferred from its name only
// (DNS routes, matched domains, default fallback) — confirm against callers.
const (
// PriorityDNSRoute is the highest of the predefined priorities.
PriorityDNSRoute = 100
// PriorityMatchDomain sits between PriorityDNSRoute and PriorityDefault.
PriorityMatchDomain = 50
// PriorityDefault is the lowest of the predefined priorities.
PriorityDefault = 0
)
// HandlerEntry is a single registration in a HandlerChain.
type HandlerEntry struct {
// Handler is invoked by ServeDNS when the query name matches Pattern.
Handler dns.Handler
// Priority determines the entry's position in the chain; higher values come first.
Priority int
// Pattern is the fully qualified domain the entry is registered for. AddHandler
// stores it without a leading "*." and records that prefix in IsWildcard instead.
Pattern string
// IsWildcard is true when the entry was registered with a "*." prefixed pattern.
IsWildcard bool
// StopHandler, when non-nil, has its stop method called by RemoveHandler
// right before the entry is dropped from the chain.
StopHandler handlerWithStop
}
// HandlerChain is a dns.Handler that dispatches each query to a list of
// registered handlers, trying them from highest to lowest priority.
type HandlerChain struct {
// mu guards handlers: AddHandler and RemoveHandler take the write lock,
// ServeDNS holds the read lock for the whole dispatch.
mu sync.RWMutex
// handlers is kept ordered by descending Priority; entries with equal
// priority stay in the order they were added.
handlers []HandlerEntry
}
// ResponseWriterChain wraps a dns.ResponseWriter to track whether the handler
// that was given this writer wants the chain to continue with the next handler.
type ResponseWriterChain struct {
// ResponseWriter is the underlying writer that real responses are forwarded to.
dns.ResponseWriter
// shouldContinue is set by WriteMsg when the handler emits the continue
// signal (NXDOMAIN with the Zero bit set); ServeDNS reads it after the
// handler returns.
shouldContinue bool
}
// WriteMsg forwards m to the wrapped writer unless m is the chain's "continue"
// signal: an NXDOMAIN response with the Zero header bit set. Such a message is
// swallowed (nothing is written, nil is returned) and only recorded so that the
// chain can move on to the next handler.
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
	isContinueSignal := m.MsgHdr.Zero && m.Rcode == dns.RcodeNameError
	if !isContinueSignal {
		return w.ResponseWriter.WriteMsg(m)
	}

	w.shouldContinue = true
	return nil
}
// NewHandlerChain returns a chain with no handlers registered.
func NewHandlerChain() *HandlerChain {
	chain := new(HandlerChain)
	chain.handlers = []HandlerEntry{}
	return chain
}
// AddHandler registers handler for pattern at the given priority.
//
// A pattern starting with "*." is registered as a wildcard: the prefix is
// removed and the entry is flagged with IsWildcard. The remaining pattern is
// stored in fully qualified form. stopHandler may be nil; it is kept on the
// entry as its StopHandler.
//
// The chain stays sorted by descending priority. A new entry is placed after
// any existing entries that share its priority.
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
	c.mu.Lock()
	defer c.mu.Unlock()

	base := strings.TrimPrefix(pattern, "*.")
	wildcard := base != pattern
	fqdn := dns.Fqdn(base)

	log.Debugf("adding handler for pattern: %s (wildcard: %v) with priority %d", fqdn, wildcard, priority)

	// The slot is the index of the first entry with a strictly lower priority,
	// or the end of the list when there is none.
	slot := len(c.handlers)
	for idx := range c.handlers {
		if c.handlers[idx].Priority < priority {
			slot = idx
			break
		}
	}

	// Grow by one, shift the tail right, and drop the new entry into the gap.
	c.handlers = append(c.handlers, HandlerEntry{})
	copy(c.handlers[slot+1:], c.handlers[slot:])
	c.handlers[slot] = HandlerEntry{
		Handler:     handler,
		Priority:    priority,
		Pattern:     fqdn,
		IsWildcard:  wildcard,
		StopHandler: stopHandler,
	}
}
// RemoveHandler removes the first handler in the chain registered for pattern
// and, if that entry has a StopHandler, calls its stop method first. It is a
// no-op when nothing matches.
//
// pattern may be given in the same form that was passed to AddHandler. Since
// AddHandler stores wildcard patterns without their "*." prefix, a pattern
// such as "*.example.com" is also compared, prefix stripped, against wildcard
// entries; without this a wildcard handler could never be removed by the
// pattern it was registered with. The literal fully qualified pattern is still
// matched against any entry, as before.
func (c *HandlerChain) RemoveHandler(pattern string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	literal := dns.Fqdn(pattern)

	isWildcard := strings.HasPrefix(pattern, "*.")
	var wildcardBase string
	if isWildcard {
		wildcardBase = dns.Fqdn(pattern[2:])
	}

	for i, entry := range c.handlers {
		matched := entry.Pattern == literal ||
			(isWildcard && entry.IsWildcard && entry.Pattern == wildcardBase)
		if !matched {
			continue
		}

		if entry.StopHandler != nil {
			entry.StopHandler.stop()
		}
		c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
		return
	}
}
// ServeDNS dispatches r to the registered handlers, from highest to lowest
// priority, and stops at the first matching handler that does not ask the
// chain to continue.
//
// Matching is done on the name of the first question, case-insensitively:
//   - an entry whose pattern is "." matches every name;
//   - a wildcard entry matches strict subdomains of its pattern only
//     ("*.example.com" matches "a.example.com." but not "example.com.");
//   - any other entry matches its pattern and all of its subdomains.
//
// A handler asks for continuation by writing an NXDOMAIN message with the Zero
// bit set to the ResponseWriterChain it is given. If no handler matches, or
// every matching handler asks to continue, an NXDOMAIN response is written to w.
// Requests without a question are ignored and nothing is written.
//
// The chain's read lock is held for the whole dispatch, including the calls
// into the handlers.
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) == 0 {
		return
	}

	qname := r.Question[0].Name
	log.Debugf("handling DNS request for %s", qname)

	c.mu.RLock()
	defer c.mu.RUnlock()

	log.Debugf("current handlers (%d):", len(c.handlers))
	for _, h := range c.handlers {
		log.Debugf(" - pattern: %s, wildcard: %v, priority: %d", h.Pattern, h.IsWildcard, h.Priority)
	}

	// DNS names compare case-insensitively, and resolvers may send mixed-case
	// names (e.g. 0x20 randomisation), so match on a lower-cased copy. The
	// original qname is kept for logging and is passed on to handlers untouched.
	name := strings.ToLower(qname)

	// Try handlers in priority order
	for _, entry := range c.handlers {
		pattern := strings.ToLower(entry.Pattern)

		var matched bool
		switch {
		case pattern == ".":
			matched = true
		case entry.IsWildcard:
			// Require a label boundary before the pattern so that
			// "a.notexample.com." does not match "*.example.com", and so that
			// the bare domain itself is excluded.
			matched = strings.HasSuffix(name, "."+pattern)
		default:
			matched = name == pattern || strings.HasSuffix(name, "."+pattern)
		}

		if !matched {
			log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
				entry.Pattern, qname, entry.IsWildcard)
			continue
		}

		log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
			entry.Pattern, qname, entry.IsWildcard)

		// A fresh wrapper per handler so a continue signal from one handler
		// does not leak into the next.
		chainWriter := &ResponseWriterChain{ResponseWriter: w}
		entry.Handler.ServeDNS(chainWriter, r)

		// If handler wants to continue, try next handler
		if chainWriter.shouldContinue {
			log.Debugf("handler requested continue to next handler")
			continue
		}
		return
	}

	// No handler matched or all handlers passed
	log.Debugf("no handler found for %s", qname)
	resp := &dns.Msg{}
	resp.SetRcode(r, dns.RcodeNameError)
	if err := w.WriteMsg(resp); err != nil {
		log.Errorf("failed to write DNS response: %v", err)
	}
}