mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-26 15:59:07 +01:00
193 lines
4.8 KiB
Go
193 lines
4.8 KiB
Go
package dns
|
|
|
|
import (
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Predefined handler priorities for HandlerChain. Entries with a larger value
// are consulted before entries with a smaller one.
const (
	// PriorityDefault is the lowest of the predefined priorities.
	PriorityDefault = 0
	// PriorityMatchDomain sits between PriorityDefault and PriorityDNSRoute.
	PriorityMatchDomain = 50
	// PriorityDNSRoute is the highest of the predefined priorities.
	PriorityDNSRoute = 100
)
|
|
|
|
// HandlerEntry is a single registration in a HandlerChain.
type HandlerEntry struct {
	// Handler serves the queries whose name matches this entry.
	Handler dns.Handler
	// Priority orders entries within the chain; higher values are tried first.
	Priority int
	// Pattern is the FQDN matched against query names, with any leading "*."
	// wildcard prefix removed.
	Pattern string
	// OrigPattern is the pattern as registered, in FQDN form with the "*."
	// prefix kept; it identifies the entry for replacement and removal.
	OrigPattern string
	// IsWildcard reports whether the registered pattern started with "*.".
	IsWildcard bool
	// StopHandler, when non-nil, is stopped as the entry is replaced or removed.
	StopHandler handlerWithStop
}
|
|
|
|
// HandlerChain represents a prioritized chain of DNS handlers.
// Its ServeDNS method tries the registered entries in descending priority order.
type HandlerChain struct {
	// mu guards handlers; every method locks it around access to the slice.
	mu sync.RWMutex
	// handlers is kept sorted by descending Priority; entries of equal
	// priority keep their insertion order.
	handlers []HandlerEntry
}
|
|
|
|
// ResponseWriterChain wraps a dns.ResponseWriter to track whether the handler
// wants the chain to continue with the next matching handler.
type ResponseWriterChain struct {
	dns.ResponseWriter
	// shouldContinue is set by WriteMsg when the handler writes the continue
	// signal (an NXDOMAIN response with the header Zero bit set) instead of
	// a real answer.
	shouldContinue bool
}
|
|
|
|
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
|
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
|
|
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
|
|
w.shouldContinue = true
|
|
return nil
|
|
}
|
|
return w.ResponseWriter.WriteMsg(m)
|
|
}
|
|
|
|
func NewHandlerChain() *HandlerChain {
|
|
return &HandlerChain{
|
|
handlers: make([]HandlerEntry, 0),
|
|
}
|
|
}
|
|
|
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
origPattern := pattern
|
|
isWildcard := strings.HasPrefix(pattern, "*.")
|
|
if isWildcard {
|
|
pattern = pattern[2:]
|
|
}
|
|
pattern = dns.Fqdn(pattern)
|
|
origPattern = dns.Fqdn(origPattern)
|
|
|
|
// First remove any existing handler with same original pattern and priority
|
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
|
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
|
if c.handlers[i].StopHandler != nil {
|
|
c.handlers[i].StopHandler.stop()
|
|
}
|
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
|
break
|
|
}
|
|
}
|
|
|
|
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
|
|
pattern, origPattern, isWildcard, priority)
|
|
|
|
entry := HandlerEntry{
|
|
Handler: handler,
|
|
Priority: priority,
|
|
Pattern: pattern,
|
|
OrigPattern: origPattern,
|
|
IsWildcard: isWildcard,
|
|
StopHandler: stopHandler,
|
|
}
|
|
|
|
// Insert handler in priority order
|
|
pos := 0
|
|
for i, h := range c.handlers {
|
|
if h.Priority < priority {
|
|
pos = i
|
|
break
|
|
}
|
|
pos = i + 1
|
|
}
|
|
|
|
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
|
}
|
|
|
|
// RemoveHandler removes a handler for the given pattern and priority
|
|
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
pattern = dns.Fqdn(pattern)
|
|
|
|
// Find and remove handlers matching both original pattern and priority
|
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
|
entry := c.handlers[i]
|
|
if entry.OrigPattern == pattern && entry.Priority == priority {
|
|
if entry.StopHandler != nil {
|
|
entry.StopHandler.stop()
|
|
}
|
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
|
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
pattern = dns.Fqdn(pattern)
|
|
for _, entry := range c.handlers {
|
|
if entry.Pattern == pattern {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
if len(r.Question) == 0 {
|
|
return
|
|
}
|
|
|
|
qname := r.Question[0].Name
|
|
log.Debugf("handling DNS request for %s", qname)
|
|
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
log.Debugf("current handlers (%d):", len(c.handlers))
|
|
for _, h := range c.handlers {
|
|
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
|
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
|
}
|
|
|
|
// Try handlers in priority order
|
|
for _, entry := range c.handlers {
|
|
var matched bool
|
|
switch {
|
|
case entry.Pattern == ".":
|
|
matched = true
|
|
case entry.IsWildcard:
|
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
|
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
|
default:
|
|
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
|
}
|
|
|
|
if !matched {
|
|
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
|
entry.OrigPattern, qname, entry.IsWildcard)
|
|
continue
|
|
}
|
|
|
|
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
|
entry.OrigPattern, qname, entry.IsWildcard)
|
|
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
|
entry.Handler.ServeDNS(chainWriter, r)
|
|
|
|
// If handler wants to continue, try next handler
|
|
if chainWriter.shouldContinue {
|
|
log.Debugf("handler requested continue to next handler")
|
|
continue
|
|
}
|
|
return
|
|
}
|
|
|
|
// No handler matched or all handlers passed
|
|
log.Debugf("no handler found for %s", qname)
|
|
resp := &dns.Msg{}
|
|
resp.SetRcode(r, dns.RcodeNameError)
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write DNS response: %v", err)
|
|
}
|
|
}
|