mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-08 23:19:09 +01:00
Add dns interceptor based domain route functionality (#3032)
This commit is contained in:
parent
a145f0b811
commit
c91d7808bf
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@ -46,7 +46,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=12m
|
||||
args: --timeout=12m --out-format colored-line-number
|
||||
|
192
client/internal/dns/handler_chain.go
Normal file
192
client/internal/dns/handler_chain.go
Normal file
@ -0,0 +1,192 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
PriorityDNSRoute = 100
|
||||
PriorityMatchDomain = 50
|
||||
PriorityDefault = 0
|
||||
)
|
||||
|
||||
type HandlerEntry struct {
|
||||
Handler dns.Handler
|
||||
Priority int
|
||||
Pattern string
|
||||
OrigPattern string
|
||||
IsWildcard bool
|
||||
StopHandler handlerWithStop
|
||||
}
|
||||
|
||||
// HandlerChain represents a prioritized chain of DNS handlers
|
||||
type HandlerChain struct {
|
||||
mu sync.RWMutex
|
||||
handlers []HandlerEntry
|
||||
}
|
||||
|
||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||
type ResponseWriterChain struct {
|
||||
dns.ResponseWriter
|
||||
shouldContinue bool
|
||||
}
|
||||
|
||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
|
||||
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
|
||||
w.shouldContinue = true
|
||||
return nil
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
func NewHandlerChain() *HandlerChain {
|
||||
return &HandlerChain{
|
||||
handlers: make([]HandlerEntry, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
origPattern := pattern
|
||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||
if isWildcard {
|
||||
pattern = pattern[2:]
|
||||
}
|
||||
pattern = dns.Fqdn(pattern)
|
||||
origPattern = dns.Fqdn(origPattern)
|
||||
|
||||
// First remove any existing handler with same original pattern and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
|
||||
pattern, origPattern, isWildcard, priority)
|
||||
|
||||
entry := HandlerEntry{
|
||||
Handler: handler,
|
||||
Priority: priority,
|
||||
Pattern: pattern,
|
||||
OrigPattern: origPattern,
|
||||
IsWildcard: isWildcard,
|
||||
StopHandler: stopHandler,
|
||||
}
|
||||
|
||||
// Insert handler in priority order
|
||||
pos := 0
|
||||
for i, h := range c.handlers {
|
||||
if h.Priority < priority {
|
||||
pos = i
|
||||
break
|
||||
}
|
||||
pos = i + 1
|
||||
}
|
||||
|
||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||
}
|
||||
|
||||
// RemoveHandler removes a handler for the given pattern and priority
|
||||
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
// Find and remove handlers matching both original pattern and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if entry.OrigPattern == pattern && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
for _, entry := range c.handlers {
|
||||
if entry.Pattern == pattern {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
qname := r.Question[0].Name
|
||||
log.Debugf("handling DNS request for %s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
log.Debugf("current handlers (%d):", len(c.handlers))
|
||||
for _, h := range c.handlers {
|
||||
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
|
||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||
}
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range c.handlers {
|
||||
var matched bool
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
matched = true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
default:
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
}
|
||||
|
||||
if !matched {
|
||||
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
||||
entry.OrigPattern, qname, entry.IsWildcard)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
||||
entry.OrigPattern, qname, entry.IsWildcard)
|
||||
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
||||
entry.Handler.ServeDNS(chainWriter, r)
|
||||
|
||||
// If handler wants to continue, try next handler
|
||||
if chainWriter.shouldContinue {
|
||||
log.Debugf("handler requested continue to next handler")
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No handler matched or all handlers passed
|
||||
log.Debugf("no handler found for %s", qname)
|
||||
resp := &dns.Msg{}
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
490
client/internal/dns/handler_chain_test.go
Normal file
490
client/internal/dns/handler_chain_test.go
Normal file
@ -0,0 +1,490 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
)
|
||||
|
||||
// MockHandler implements dns.Handler interface for testing
|
||||
type MockHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.Called(w, r)
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
// Create mock handlers for different priorities
|
||||
defaultHandler := &MockHandler{}
|
||||
matchDomainHandler := &MockHandler{}
|
||||
dnsRouteHandler := &MockHandler{}
|
||||
|
||||
// Setup handlers with different priorities
|
||||
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Create test writer
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// Setup expectations - only highest priority handler should be called
|
||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||
|
||||
// Execute
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify all expectations were met
|
||||
dnsRouteHandler.AssertExpectations(t)
|
||||
matchDomainHandler.AssertExpectations(t)
|
||||
defaultHandler.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
|
||||
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlerDomain string
|
||||
queryDomain string
|
||||
isWildcard bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "subdomain with non-wildcard",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "sub.example.com.",
|
||||
isWildcard: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on apex",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "example.com.",
|
||||
isWildcard: true,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "root zone match",
|
||||
handlerDomain: ".",
|
||||
queryDomain: "anything.com.",
|
||||
isWildcard: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "no match different domain",
|
||||
handlerDomain: "example.com.",
|
||||
queryDomain: "example.org.",
|
||||
isWildcard: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
mockHandler := &MockHandler{}
|
||||
|
||||
pattern := tt.handlerDomain
|
||||
if tt.isWildcard {
|
||||
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
if tt.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
mockHandler.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
|
||||
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handlers []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}
|
||||
queryDomain string
|
||||
expectedCalls int
|
||||
expectedHandler int // index of the handler that should be called
|
||||
}{
|
||||
{
|
||||
name: "wildcard and exact same priority - exact should win",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||
},
|
||||
queryDomain: "example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1, // exact match handler should be called
|
||||
},
|
||||
{
|
||||
name: "higher priority wildcard over lower priority exact",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "test.example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1, // higher priority wildcard handler should be called
|
||||
},
|
||||
{
|
||||
name: "multiple wildcards different priorities",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "test.example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority handler should be called
|
||||
},
|
||||
{
|
||||
name: "subdomain with mix of patterns",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "sub.test.example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority matching handler should be called
|
||||
},
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "example.com.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1, // higher priority specific domain should win over root
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
var handlers []*MockHandler
|
||||
|
||||
// Setup handlers and expectations
|
||||
for i := range tt.handlers {
|
||||
handler := &MockHandler{}
|
||||
handlers = append(handlers, handler)
|
||||
|
||||
// Set expectation based on whether this handler should be called
|
||||
if i == tt.expectedHandler {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||
} else {
|
||||
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||
}
|
||||
|
||||
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||
}
|
||||
|
||||
// Create and execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify expectations
|
||||
for _, handler := range handlers {
|
||||
handler.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
|
||||
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
// Create handlers
|
||||
handler1 := &MockHandler{}
|
||||
handler2 := &MockHandler{}
|
||||
handler3 := &MockHandler{}
|
||||
|
||||
// Add handlers in priority order
|
||||
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Setup mock responses to simulate chain continuation
|
||||
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
// First handler signals continue
|
||||
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true // Signal to continue
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
|
||||
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
// Second handler signals continue
|
||||
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
resp.MsgHdr.Zero = true
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
|
||||
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||
// Last handler responds normally
|
||||
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
|
||||
// Execute
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify all handlers were called in order
|
||||
handler1.AssertExpectations(t)
|
||||
handler2.AssertExpectations(t)
|
||||
handler3.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
||||
type mockResponseWriter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||
func (m *mockResponseWriter) Close() error { return nil }
|
||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (m *mockResponseWriter) Hijack() {}
|
||||
|
||||
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ops []struct {
|
||||
action string // "add" or "remove"
|
||||
pattern string
|
||||
priority int
|
||||
}
|
||||
query string
|
||||
expectedCalls map[int]bool // map[priority]shouldBeCalled
|
||||
}{
|
||||
{
|
||||
name: "remove high priority keeps lower priority handler",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: map[int]bool{
|
||||
nbdns.PriorityDNSRoute: false,
|
||||
nbdns.PriorityMatchDomain: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove lower priority keeps high priority handler",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: map[int]bool{
|
||||
nbdns.PriorityDNSRoute: true,
|
||||
nbdns.PriorityMatchDomain: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove all handlers in order",
|
||||
ops: []struct {
|
||||
action string
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||
{"add", "example.com.", nbdns.PriorityDefault},
|
||||
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: map[int]bool{
|
||||
nbdns.PriorityDNSRoute: false,
|
||||
nbdns.PriorityMatchDomain: false,
|
||||
nbdns.PriorityDefault: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlers := make(map[int]*MockHandler)
|
||||
|
||||
// Execute operations
|
||||
for _, op := range tt.ops {
|
||||
if op.action == "add" {
|
||||
handler := &MockHandler{}
|
||||
handlers[op.priority] = handler
|
||||
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||
} else {
|
||||
chain.RemoveHandler(op.pattern, op.priority)
|
||||
}
|
||||
}
|
||||
|
||||
// Create test request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// Setup expectations
|
||||
for priority, handler := range handlers {
|
||||
if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
} else {
|
||||
handler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||
}
|
||||
}
|
||||
|
||||
// Execute request
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
// Verify expectations
|
||||
for _, handler := range handlers {
|
||||
handler.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verify handler exists check
|
||||
for priority, shouldExist := range tt.expectedCalls {
|
||||
if shouldExist {
|
||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
testDomain := "example.com."
|
||||
testQuery := "test.example.com."
|
||||
|
||||
// Create handlers for three priority levels
|
||||
routeHandler := &MockHandler{}
|
||||
matchHandler := &MockHandler{}
|
||||
defaultHandler := &MockHandler{}
|
||||
|
||||
// Create test request that will be reused
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Highest priority handler (routeHandler) should be called
|
||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
routeHandler.AssertExpectations(t)
|
||||
|
||||
// Test 2: Remove highest priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now middle priority handler (matchHandler) should be called
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
matchHandler.AssertExpectations(t)
|
||||
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
defaultHandler.AssertExpectations(t)
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
}
|
@ -3,14 +3,30 @@ package dns
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
// MockServer is the mock instance of a dns server
|
||||
type MockServer struct {
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||
DeregisterHandlerFunc func([]string, int)
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
if m.RegisterHandlerFunc != nil {
|
||||
m.RegisterHandlerFunc(domains, handler, priority)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
||||
if m.DeregisterHandlerFunc != nil {
|
||||
m.DeregisterHandlerFunc(domains, priority)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize mock implementation of Initialize from Server interface
|
||||
|
@ -30,6 +30,8 @@ type IosDnsManager interface {
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains []string, priority int)
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() string
|
||||
@ -48,12 +50,14 @@ type DefaultServer struct {
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
handlerPriorities map[string]int
|
||||
localResolver *localResolver
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
handlerChain *HandlerChain
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
@ -74,8 +78,9 @@ type handlerWithStop interface {
|
||||
}
|
||||
|
||||
type muxUpdate struct {
|
||||
domain string
|
||||
handler handlerWithStop
|
||||
domain string
|
||||
handler handlerWithStop
|
||||
priority int
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
@ -135,10 +140,12 @@ func NewDefaultServerIos(
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
service: dnsService,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
handlerPriorities: make(map[string]int),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
@ -151,6 +158,41 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.registerHandler(domains, handler, priority)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||
|
||||
for _, domain := range domains {
|
||||
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||
s.handlerPriorities[domain] = priority
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.deregisterHandler(domains, priority)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
for _, domain := range domains {
|
||||
s.handlerChain.RemoveHandler(domain, priority)
|
||||
|
||||
// Only deregister from service if no handlers remain
|
||||
if !s.handlerChain.HasHandlers(domain) {
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager and the dns service
|
||||
func (s *DefaultServer) Initialize() (err error) {
|
||||
s.mux.Lock()
|
||||
@ -343,14 +385,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||
|
||||
for _, customZone := range customZones {
|
||||
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
domain: customZone.Domain,
|
||||
handler: s.localResolver,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
@ -412,8 +454,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
|
||||
if nsGroup.Primary {
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
domain: nbdns.RootZone,
|
||||
handler: handler,
|
||||
priority: PriorityDefault,
|
||||
})
|
||||
continue
|
||||
}
|
||||
@ -429,8 +472,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
|
||||
}
|
||||
muxUpdates = append(muxUpdates, muxUpdate{
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
domain: domain,
|
||||
handler: handler,
|
||||
priority: PriorityMatchDomain,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -440,12 +484,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
handlersByPriority := make(map[string]int)
|
||||
|
||||
var isContainRootUpdate bool
|
||||
|
||||
// First register new handlers
|
||||
for _, update := range muxUpdates {
|
||||
s.service.RegisterMux(update.domain, update.handler)
|
||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
handlersByPriority[update.domain] = update.priority
|
||||
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
}
|
||||
@ -455,6 +503,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
// Then deregister old handlers not in the update
|
||||
for key, existingHandler := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
@ -463,12 +512,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
s.service.DeregisterMux(key)
|
||||
// Deregister with the priority that was used to register
|
||||
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
||||
s.deregisterHandler([]string{key}, oldPriority)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.dnsMuxMap = muxUpdateMap
|
||||
s.handlerPriorities = handlersByPriority
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||
@ -517,13 +570,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
if nsGroup.Primary {
|
||||
removeIndex[nbdns.RootZone] = -1
|
||||
s.currentConfig.RouteAll = false
|
||||
s.service.DeregisterMux(nbdns.RootZone)
|
||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
||||
}
|
||||
|
||||
for i, item := range s.currentConfig.Domains {
|
||||
if _, found := removeIndex[item.Domain]; found {
|
||||
s.currentConfig.Domains[i].Disabled = true
|
||||
s.service.DeregisterMux(item.Domain)
|
||||
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
||||
removeIndex[item.Domain] = i
|
||||
}
|
||||
}
|
||||
@ -554,7 +607,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.Domains[i].Disabled = false
|
||||
s.service.RegisterMux(domain, handler)
|
||||
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@ -562,7 +615,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
|
||||
if nsGroup.Primary {
|
||||
s.currentConfig.RouteAll = true
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
@ -593,7 +646,8 @@ func (s *DefaultServer) addHostRootZone() {
|
||||
}
|
||||
handler.deactivate = func(error) {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||
|
@ -512,7 +512,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
@ -560,7 +560,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
hostManager: hostManager,
|
||||
handlerChain: NewHandlerChain(),
|
||||
handlerPriorities: make(map[string]int),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
|
@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
|
@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("%v", u.upstreamServers)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
|
120
client/internal/dnsfwd/forwarder.go
Normal file
120
client/internal/dnsfwd/forwarder.go
Normal file
@ -0,0 +1,120 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
|
||||
dnsServer *dns.Server
|
||||
mux *dns.ServeMux
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
domains: domains,
|
||||
}
|
||||
}
|
||||
func (f *DNSForwarder) Listen() error {
|
||||
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
for _, d := range f.domains {
|
||||
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
}
|
||||
f.dnsServer = dnsServer
|
||||
f.mux = mux
|
||||
return dnsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleRemove(d)
|
||||
}
|
||||
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||
}
|
||||
f.domains = domains
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
if f.dnsServer == nil {
|
||||
return nil
|
||||
}
|
||||
return f.dnsServer.ShutdownContext(ctx)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
|
||||
|
||||
question := query.Question[0]
|
||||
domain := question.Name
|
||||
|
||||
resp := query.SetReply(query)
|
||||
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
var respRecord dns.RR
|
||||
if ip.To4() == nil {
|
||||
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
} else {
|
||||
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip,
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: f.ttl,
|
||||
},
|
||||
}
|
||||
respRecord = &rr
|
||||
}
|
||||
resp.Answer = append(resp.Answer, respRecord)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
106
client/internal/dnsfwd/manager.go
Normal file
106
client/internal/dnsfwd/manager.go
Normal file
@ -0,0 +1,106 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
ListenPort = 5353
|
||||
dnsTTL = 60 //seconds
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
|
||||
fwRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(domains []string) error {
|
||||
log.Infof("starting DNS forwarder")
|
||||
if m.dnsForwarder != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.allowDNSFirewall(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateDomains(domains []string) {
|
||||
if m.dnsForwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.dnsForwarder.UpdateDomains(domains)
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
if m.dnsForwarder == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var mErr *multierror.Error
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
|
||||
if err := m.dnsForwarder.Close(ctx); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
|
||||
m.dnsForwarder = nil
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
func (h *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
}
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
}
|
||||
h.fwRules = dnsRules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range h.fwRules {
|
||||
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
h.fwRules = nil
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -30,10 +29,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@ -117,7 +118,7 @@ type Engine struct {
|
||||
// mgmClient is a Management Service client
|
||||
mgmClient mgm.Client
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
peerStore *peerstore.Store
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
@ -137,10 +138,6 @@ type Engine struct {
|
||||
TURNs []*stun.URI
|
||||
stunTurn atomic.Value
|
||||
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
clientRoutesMu sync.RWMutex
|
||||
|
||||
clientCtx context.Context
|
||||
clientCancel context.CancelFunc
|
||||
|
||||
@ -161,9 +158,10 @@ type Engine struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
firewall manager.Manager
|
||||
routeManager routemanager.Manager
|
||||
acl acl.Manager
|
||||
firewall manager.Manager
|
||||
routeManager routemanager.Manager
|
||||
acl acl.Manager
|
||||
dnsForwardMgr *dnsfwd.Manager
|
||||
|
||||
dnsServer dns.Server
|
||||
|
||||
@ -234,7 +232,7 @@ func NewEngineWithProbes(
|
||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||
mgmClient: mgmClient,
|
||||
relayManager: relayManager,
|
||||
peerConns: make(map[string]*peer.Conn),
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
@ -287,6 +285,13 @@ func (e *Engine) Stop() error {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.dnsForwardMgr != nil {
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
@ -300,10 +305,6 @@ func (e *Engine) Stop() error {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = nil
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
@ -382,6 +383,8 @@ func (e *Engine) Start() error {
|
||||
e.relayManager,
|
||||
initialRoutes,
|
||||
e.stateManager,
|
||||
dnsServer,
|
||||
e.peerStore,
|
||||
)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
@ -460,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
var modified []*mgmProto.RemotePeerConfig
|
||||
for _, p := range peersUpdate {
|
||||
peerPubKey := p.GetWgPubKey()
|
||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
||||
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
||||
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
|
||||
if allowedIPs != strings.Join(p.AllowedIps, ",") {
|
||||
modified = append(modified, p)
|
||||
continue
|
||||
}
|
||||
@ -492,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
currentPeers := make([]string, 0, len(e.peerConns))
|
||||
for p := range e.peerConns {
|
||||
currentPeers = append(currentPeers, p)
|
||||
}
|
||||
|
||||
newPeers := make([]string, 0, len(peersUpdate))
|
||||
for _, p := range peersUpdate {
|
||||
newPeers = append(newPeers, p.GetWgPubKey())
|
||||
}
|
||||
|
||||
toRemove := util.SliceDiff(currentPeers, newPeers)
|
||||
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||
|
||||
for _, p := range toRemove {
|
||||
err := e.removePeer(p)
|
||||
@ -516,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
|
||||
func (e *Engine) removeAllPeers() error {
|
||||
log.Debugf("removing all peer connections")
|
||||
for p := range e.peerConns {
|
||||
for _, p := range e.peerStore.PeersPubKey() {
|
||||
err := e.removePeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -540,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
conn, exists := e.peerConns[peerKey]
|
||||
conn, exists := e.peerStore.Remove(peerKey)
|
||||
if exists {
|
||||
delete(e.peerConns, peerKey)
|
||||
conn.Close()
|
||||
}
|
||||
return nil
|
||||
@ -786,7 +783,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||
if networkMap.GetPeerConfig() != nil {
|
||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||
@ -806,19 +802,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
routedDomains, routes := toRoutes(networkMap.GetRoutes())
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
// todo: useRoutingPeerDnsResolutionEnabled from network map proto
|
||||
e.updateDNSForwarder(true, routedDomains)
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
@ -867,8 +858,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
||||
if err != nil {
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
@ -881,7 +871,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
var dnsRoutes []string
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
var prefix netip.Prefix
|
||||
@ -892,6 +887,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
continue
|
||||
}
|
||||
}
|
||||
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
@ -905,7 +902,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
return routes
|
||||
return dnsRoutes, routes
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
@ -982,12 +979,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
peerKey := peerConfig.GetWgPubKey()
|
||||
peerIPs := peerConfig.GetAllowedIps()
|
||||
if _, ok := e.peerConns[peerKey]; !ok {
|
||||
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
|
||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create peer connection: %w", err)
|
||||
}
|
||||
e.peerConns[peerKey] = conn
|
||||
|
||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||
conn.Close()
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
@ -1076,8 +1077,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
conn := e.peerConns[msg.Key]
|
||||
if conn == nil {
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
}
|
||||
|
||||
@ -1135,7 +1136,7 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
||||
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||
case sProto.Body_MODE:
|
||||
}
|
||||
|
||||
@ -1239,7 +1240,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
_, routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
@ -1322,26 +1323,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the current routes from the route map
|
||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
||||
e.clientRoutesMu.RLock()
|
||||
defer e.clientRoutesMu.RUnlock()
|
||||
|
||||
return maps.Clone(e.clientRoutes)
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
e.clientRoutesMu.RLock()
|
||||
defer e.clientRoutesMu.RUnlock()
|
||||
|
||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
||||
for id, v := range e.clientRoutes {
|
||||
routes[id.NetID()] = v
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetRouteManager returns the route manager
|
||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
@ -1426,9 +1407,8 @@ func (e *Engine) receiveProbeEvents() {
|
||||
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
||||
log.Debug("received wg probe request")
|
||||
|
||||
for _, peer := range e.peerConns {
|
||||
key := peer.GetKey()
|
||||
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
|
||||
for _, key := range e.peerStore.PeersPubKey() {
|
||||
wgStats, err := e.wgInterface.GetStats(key)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||
}
|
||||
@ -1505,7 +1485,7 @@ func (e *Engine) startNetworkMonitor() {
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.GetClientRoutes() {
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
@ -1573,6 +1553,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
return nm, nil
|
||||
}
|
||||
|
||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(domains) > 0 {
|
||||
log.Infof("enable domain router service for domains: %v", domains)
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
} else {
|
||||
log.Infof("update domain router service for domains: %v", domains)
|
||||
e.dnsForwardMgr.UpdateDomains(domains)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
|
@ -252,7 +252,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil)
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@ -392,8 +392,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(engine.peerConns) != c.expectedLen {
|
||||
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns))
|
||||
if len(engine.peerStore.PeersPubKey()) != c.expectedLen {
|
||||
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey()))
|
||||
}
|
||||
|
||||
if engine.networkSerial != c.expectedSerial {
|
||||
@ -401,7 +401,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, p := range c.expectedPeers {
|
||||
conn, ok := engine.peerConns[p.GetWgPubKey()]
|
||||
conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey())
|
||||
if !ok {
|
||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||
}
|
||||
@ -626,10 +626,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
}{}
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
input.inputSerial = updateSerial
|
||||
input.inputRoutes = newRoutes
|
||||
return nil, nil, testCase.inputErr
|
||||
return testCase.inputErr
|
||||
},
|
||||
}
|
||||
|
||||
@ -802,8 +802,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
return nil, nil, nil
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
@ -1238,7 +1238,8 @@ func getConnectedPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
i := 0
|
||||
for _, conn := range e.peerConns {
|
||||
for _, id := range e.peerStore.PeersPubKey() {
|
||||
conn, _ := e.peerStore.PeerConn(id)
|
||||
if conn.Status() == peer.StatusConnected {
|
||||
i++
|
||||
}
|
||||
@ -1250,5 +1251,5 @@ func getPeers(e *Engine) int {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
return len(e.peerConns)
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
// AllowedIP returns the allowed IP of the remote peer
|
||||
func (conn *Conn) AllowedIP() net.IP {
|
||||
return conn.allowedIP
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
87
client/internal/peerstore/store.go
Normal file
87
client/internal/peerstore/store.go
Normal file
@ -0,0 +1,87 @@
|
||||
package peerstore
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
// Store is a thread-safe store for peer connections.
|
||||
type Store struct {
|
||||
peerConns map[string]*peer.Conn
|
||||
peerConnsMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewConnStore() *Store {
|
||||
return &Store{
|
||||
peerConns: make(map[string]*peer.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
|
||||
s.peerConnsMu.Lock()
|
||||
defer s.peerConnsMu.Unlock()
|
||||
|
||||
_, ok := s.peerConns[pubKey]
|
||||
if ok {
|
||||
return false
|
||||
}
|
||||
|
||||
s.peerConns[pubKey] = conn
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
|
||||
s.peerConnsMu.Lock()
|
||||
defer s.peerConnsMu.Unlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
delete(s.peerConns, pubKey)
|
||||
return p, true
|
||||
}
|
||||
|
||||
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return p.WgConfig().AllowedIps, true
|
||||
}
|
||||
|
||||
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return p.AllowedIP(), true
|
||||
}
|
||||
|
||||
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return p, true
|
||||
}
|
||||
|
||||
func (s *Store) PeersPubKey() []string {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
return maps.Keys(s.peerConns)
|
||||
}
|
@ -13,12 +13,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const useNewDNSRoute = true
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
@ -53,7 +57,17 @@ type clientNetwork struct {
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
func newClientNetworkWatcher(
|
||||
ctx context.Context,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface iface.IWGIface,
|
||||
statusRecorder *peer.Status,
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
@ -65,7 +79,16 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
||||
handler: handlerFromRoute(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
dnsRouteInterval,
|
||||
statusRecorder,
|
||||
wgInterface,
|
||||
dnsServer,
|
||||
peerStore,
|
||||
),
|
||||
}
|
||||
return client
|
||||
}
|
||||
@ -368,10 +391,37 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
||||
func handlerFromRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.IWGIface,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
if useNewDNSRoute {
|
||||
return dnsinterceptor.New(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
statusRecorder,
|
||||
dnsServer,
|
||||
peerStore,
|
||||
)
|
||||
}
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||
return dynamic.NewRoute(
|
||||
rt,
|
||||
routeRefCounter,
|
||||
allowedIPsRefCounter,
|
||||
dnsRouterInteval,
|
||||
statusRecorder,
|
||||
wgInterface,
|
||||
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||
)
|
||||
}
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
|
334
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
334
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
@ -0,0 +1,334 @@
|
||||
package dnsinterceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type DnsInterceptor struct {
|
||||
mu sync.RWMutex
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
statusRecorder *peer.Status
|
||||
dnsServer nbdns.Server
|
||||
currentPeerKey string
|
||||
interceptedDomains domainMap
|
||||
peerStore *peerstore.Store
|
||||
}
|
||||
|
||||
func New(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
statusRecorder *peer.Status,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) *DnsInterceptor {
|
||||
return &DnsInterceptor{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
statusRecorder: statusRecorder,
|
||||
dnsServer: dnsServer,
|
||||
interceptedDomains: make(domainMap),
|
||||
peerStore: peerStore,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) String() string {
|
||||
return d.route.Domains.SafeString()
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveRoute() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
clear(d.interceptedDomains)
|
||||
|
||||
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, prefixes := range d.interceptedDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.currentPeerKey = ""
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// ServeDNS implements the dns.Handler interface
|
||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request: %v", r.Question[0].Name)
|
||||
|
||||
d.mu.RLock()
|
||||
peerKey := d.currentPeerKey
|
||||
d.mu.RUnlock()
|
||||
|
||||
if peerKey == "" {
|
||||
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
|
||||
d.continueToNextHandler(w, r, "no current peer key")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get upstream IP: %v", err)
|
||||
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Net: "udp",
|
||||
}
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||
|
||||
var answer []dns.RR
|
||||
if reply != nil {
|
||||
answer = reply.Answer
|
||||
}
|
||||
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reply.Id = r.Id
|
||||
if err := d.writeMsg(w, reply); err != nil {
|
||||
log.Errorf("failed writing DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// continueToNextHandler signals the handler chain to try the next handler
|
||||
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeNameError)
|
||||
// Set Zero bit to signal handler chain to continue
|
||||
resp.MsgHdr.Zero = true
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed writing DNS continue response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||
}
|
||||
return peerAllowedIP, nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("received nil DNS message")
|
||||
}
|
||||
|
||||
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||
// DNS names from miekg/dns are already in punycode format
|
||||
dom := domain.Domain(r.Question[0].Name)
|
||||
|
||||
var newPrefixes []netip.Prefix
|
||||
for _, answer := range r.Answer {
|
||||
var ip netip.Addr
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
addr, ok := netip.AddrFromSlice(rr.A)
|
||||
if !ok {
|
||||
log.Debugf("failed to convert A record IP: %v", rr.A)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
case *dns.AAAA:
|
||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||
if !ok {
|
||||
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
|
||||
continue
|
||||
}
|
||||
ip = addr
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
if len(newPrefixes) > 0 {
|
||||
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
|
||||
log.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(r); err != nil {
|
||||
return fmt.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
oldPrefixes := d.interceptedDomains[domain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Add new prefixes
|
||||
for _, prefix := range toAdd {
|
||||
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if d.currentPeerKey == "" {
|
||||
continue
|
||||
}
|
||||
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if !d.route.KeepRoute {
|
||||
// Remove old prefixes
|
||||
for _, prefix := range toRemove {
|
||||
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||
}
|
||||
if d.currentPeerKey != "" {
|
||||
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update domain prefixes
|
||||
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||
d.interceptedDomains[domain] = newPrefixes
|
||||
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
|
||||
|
||||
if len(toAdd) > 0 {
|
||||
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
|
||||
}
|
||||
if len(toRemove) > 0 {
|
||||
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
@ -74,11 +74,7 @@ func NewRoute(
|
||||
}
|
||||
|
||||
func (r *Route) String() string {
|
||||
s, err := r.route.Domains.String()
|
||||
if err != nil {
|
||||
return r.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
return r.route.Domains.SafeString()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(ctx context.Context) error {
|
||||
|
@ -12,12 +12,15 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
@ -33,9 +36,11 @@ import (
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
@ -60,6 +65,10 @@ type DefaultManager struct {
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
stateManager *statemanager.Manager
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
dnsServer dns.Server
|
||||
peerStore *peerstore.Store
|
||||
}
|
||||
|
||||
func NewManager(
|
||||
@ -71,6 +80,8 @@ func NewManager(
|
||||
relayMgr *relayClient.Manager,
|
||||
initialRoutes []*route.Route,
|
||||
stateManager *statemanager.Manager,
|
||||
dnsServer dns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
notifier := notifier.NewNotifier()
|
||||
@ -88,6 +99,8 @@ func NewManager(
|
||||
pubKey: pubKey,
|
||||
notifier: notifier,
|
||||
stateManager: stateManager,
|
||||
dnsServer: dnsServer,
|
||||
peerStore: peerStore,
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
@ -116,7 +129,7 @@ func NewManager(
|
||||
)
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.clientRoutes(initialRoutes)
|
||||
cr := dm.initialClientRoutes(initialRoutes)
|
||||
dm.notifier.SetInitialClientRoutes(cr)
|
||||
}
|
||||
return dm
|
||||
@ -207,33 +220,40 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
|
||||
m.ctx = nil
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
m.clientRoutes = nil
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return nil, nil, m.ctx.Err()
|
||||
return nil
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.clientRoutes = newClientRoutesIDMap
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||
@ -251,9 +271,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return m.routeSelector
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the client routes
|
||||
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
||||
return m.clientNetworks
|
||||
// GetClientRoutes returns most recent list of clientRoutes received from the Management Service
|
||||
func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
return maps.Clone(m.clientRoutes)
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
|
||||
for id, v := range m.clientRoutes {
|
||||
routes[id.NetID()] = v
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||
@ -273,7 +308,17 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
clientNetworkWatcher := newClientNetworkWatcher(
|
||||
m.ctx,
|
||||
m.dnsRouteInterval,
|
||||
m.wgInterface,
|
||||
m.statusRecorder,
|
||||
routes[0],
|
||||
m.routeRefCounter,
|
||||
m.allowedIPsRefCounter,
|
||||
m.dnsServer,
|
||||
m.peerStore,
|
||||
)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
@ -302,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
@ -345,7 +390,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
return newServerRoutesMap, newClientRoutesIDMap
|
||||
}
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0, len(crMap))
|
||||
for _, routes := range crMap {
|
||||
|
@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil)
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
|
@ -2,7 +2,6 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@ -15,10 +14,12 @@ import (
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
|
||||
}
|
||||
|
||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
if m.UpdateRoutesFunc != nil {
|
||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||
@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
if m.GetClientRoutesFunc != nil {
|
||||
return m.GetClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
return m.GetClientRoutesWithNetIDFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start mock implementation of Start from Manager interface
|
||||
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
||||
}
|
||||
|
@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
routesMap := engine.GetClientRoutesWithNetID()
|
||||
routeManager := engine.GetRouteManager()
|
||||
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||
if routeManager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
@ -365,12 +365,12 @@ func (c *Client) SelectRoute(id string) error {
|
||||
} else {
|
||||
log.Debugf("select route with id: %s", id)
|
||||
routes := toNetIDs([]string{id})
|
||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("error when selecting routes: %s", err)
|
||||
return fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||
return nil
|
||||
|
||||
}
|
||||
@ -392,12 +392,12 @@ func (c *Client) DeselectRoute(id string) error {
|
||||
} else {
|
||||
log.Debugf("deselect route with id: %s", id)
|
||||
routes := toNetIDs([]string{id})
|
||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("error when deselecting routes: %s", err)
|
||||
return fmt.Errorf("deselect routes: %w", err)
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
routesMap := engine.GetClientRoutesWithNetID()
|
||||
routesMap := engine.GetRouteManager().GetClientRoutesWithNetID()
|
||||
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||
|
||||
var routes []*selectRoute
|
||||
@ -116,11 +116,12 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
routeSelector.SelectAllRoutes()
|
||||
} else {
|
||||
routes := toNetIDs(req.GetNetworkIDs())
|
||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("select routes: %w", err)
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||
|
||||
return &proto.SelectNetworksResponse{}, nil
|
||||
}
|
||||
@ -145,11 +146,12 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
routeSelector.DeselectAllRoutes()
|
||||
} else {
|
||||
routes := toNetIDs(req.GetNetworkIDs())
|
||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
||||
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
||||
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
|
||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||
}
|
||||
}
|
||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
||||
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||
|
||||
return &proto.SelectNetworksResponse{}, nil
|
||||
}
|
||||
|
@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) {
|
||||
|
||||
return validHost, nil
|
||||
}
|
||||
|
||||
// NormalizeZone returns a normalized domain name without the wildcard prefix
|
||||
func NormalizeZone(domain string) string {
|
||||
d, _ := strings.CutPrefix(domain, "*.")
|
||||
return d
|
||||
}
|
||||
|
1
go.mod
1
go.mod
@ -207,6 +207,7 @@ require (
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
|
1
go.sum
1
go.sum
@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
|
@ -360,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) {
|
||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
}
|
||||
|
||||
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
|
||||
var domainList domain.List
|
||||
|
||||
|
@ -330,6 +330,14 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST Wildcard Domain",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST UnprocessableEntity when both network and domains are provided",
|
||||
requestType: http.MethodPost,
|
||||
@ -609,6 +617,30 @@ func TestValidateDomains(t *testing.T) {
|
||||
expected: domain.List{"google.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid wildcard domain",
|
||||
domains: []string{"*.example.com"},
|
||||
expected: domain.List{"*.example.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Wildcard with dot domain",
|
||||
domains: []string{".*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Wildcard with dot domain",
|
||||
domains: []string{".*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid wildcard domain",
|
||||
domains: []string{"a.*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
Loading…
Reference in New Issue
Block a user