mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-25 07:19:05 +01:00
Add dns interceptor based domain route functionality (#3032)
This commit is contained in:
parent
a145f0b811
commit
c91d7808bf
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@ -46,7 +46,7 @@ jobs:
|
|||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v4
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: latest
|
||||||
args: --timeout=12m
|
args: --timeout=12m --out-format colored-line-number
|
||||||
|
192
client/internal/dns/handler_chain.go
Normal file
192
client/internal/dns/handler_chain.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PriorityDNSRoute = 100
|
||||||
|
PriorityMatchDomain = 50
|
||||||
|
PriorityDefault = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type HandlerEntry struct {
|
||||||
|
Handler dns.Handler
|
||||||
|
Priority int
|
||||||
|
Pattern string
|
||||||
|
OrigPattern string
|
||||||
|
IsWildcard bool
|
||||||
|
StopHandler handlerWithStop
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerChain represents a prioritized chain of DNS handlers
|
||||||
|
type HandlerChain struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
handlers []HandlerEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||||
|
type ResponseWriterChain struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
shouldContinue bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||||
|
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
|
||||||
|
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
|
||||||
|
w.shouldContinue = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.ResponseWriter.WriteMsg(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandlerChain() *HandlerChain {
|
||||||
|
return &HandlerChain{
|
||||||
|
handlers: make([]HandlerEntry, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
|
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
origPattern := pattern
|
||||||
|
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||||
|
if isWildcard {
|
||||||
|
pattern = pattern[2:]
|
||||||
|
}
|
||||||
|
pattern = dns.Fqdn(pattern)
|
||||||
|
origPattern = dns.Fqdn(origPattern)
|
||||||
|
|
||||||
|
// First remove any existing handler with same original pattern and priority
|
||||||
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
|
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
||||||
|
if c.handlers[i].StopHandler != nil {
|
||||||
|
c.handlers[i].StopHandler.stop()
|
||||||
|
}
|
||||||
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("adding handler for pattern: %s (original: %s, wildcard: %v) with priority %d",
|
||||||
|
pattern, origPattern, isWildcard, priority)
|
||||||
|
|
||||||
|
entry := HandlerEntry{
|
||||||
|
Handler: handler,
|
||||||
|
Priority: priority,
|
||||||
|
Pattern: pattern,
|
||||||
|
OrigPattern: origPattern,
|
||||||
|
IsWildcard: isWildcard,
|
||||||
|
StopHandler: stopHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert handler in priority order
|
||||||
|
pos := 0
|
||||||
|
for i, h := range c.handlers {
|
||||||
|
if h.Priority < priority {
|
||||||
|
pos = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pos = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveHandler removes a handler for the given pattern and priority
|
||||||
|
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
|
// Find and remove handlers matching both original pattern and priority
|
||||||
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
|
entry := c.handlers[i]
|
||||||
|
if entry.OrigPattern == pattern && entry.Priority == priority {
|
||||||
|
if entry.StopHandler != nil {
|
||||||
|
entry.StopHandler.stop()
|
||||||
|
}
|
||||||
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
||||||
|
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
pattern = dns.Fqdn(pattern)
|
||||||
|
for _, entry := range c.handlers {
|
||||||
|
if entry.Pattern == pattern {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
qname := r.Question[0].Name
|
||||||
|
log.Debugf("handling DNS request for %s", qname)
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
log.Debugf("current handlers (%d):", len(c.handlers))
|
||||||
|
for _, h := range c.handlers {
|
||||||
|
log.Debugf(" - pattern: %s, original: %s, wildcard: %v, priority: %d",
|
||||||
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try handlers in priority order
|
||||||
|
for _, entry := range c.handlers {
|
||||||
|
var matched bool
|
||||||
|
switch {
|
||||||
|
case entry.Pattern == ".":
|
||||||
|
matched = true
|
||||||
|
case entry.IsWildcard:
|
||||||
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||||
|
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||||
|
default:
|
||||||
|
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
log.Debugf("trying domain match: pattern=%s qname=%s wildcard=%v matched=false",
|
||||||
|
entry.OrigPattern, qname, entry.IsWildcard)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("handler matched: pattern=%s qname=%s wildcard=%v",
|
||||||
|
entry.OrigPattern, qname, entry.IsWildcard)
|
||||||
|
chainWriter := &ResponseWriterChain{ResponseWriter: w}
|
||||||
|
entry.Handler.ServeDNS(chainWriter, r)
|
||||||
|
|
||||||
|
// If handler wants to continue, try next handler
|
||||||
|
if chainWriter.shouldContinue {
|
||||||
|
log.Debugf("handler requested continue to next handler")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No handler matched or all handlers passed
|
||||||
|
log.Debugf("no handler found for %s", qname)
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
490
client/internal/dns/handler_chain_test.go
Normal file
490
client/internal/dns/handler_chain_test.go
Normal file
@ -0,0 +1,490 @@
|
|||||||
|
package dns_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockHandler implements dns.Handler interface for testing
|
||||||
|
type MockHandler struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
m.Called(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
|
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
// Create mock handlers for different priorities
|
||||||
|
defaultHandler := &MockHandler{}
|
||||||
|
matchDomainHandler := &MockHandler{}
|
||||||
|
dnsRouteHandler := &MockHandler{}
|
||||||
|
|
||||||
|
// Setup handlers with different priorities
|
||||||
|
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
|
||||||
|
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
|
||||||
|
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
// Create test writer
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// Setup expectations - only highest priority handler should be called
|
||||||
|
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify all expectations were met
|
||||||
|
dnsRouteHandler.AssertExpectations(t)
|
||||||
|
matchDomainHandler.AssertExpectations(t)
|
||||||
|
defaultHandler.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
|
||||||
|
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handlerDomain string
|
||||||
|
queryDomain string
|
||||||
|
isWildcard bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact match",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain with non-wildcard",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "sub.example.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard match",
|
||||||
|
handlerDomain: "*.example.com.",
|
||||||
|
queryDomain: "sub.example.com.",
|
||||||
|
isWildcard: true,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard no match on apex",
|
||||||
|
handlerDomain: "*.example.com.",
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
isWildcard: true,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone match",
|
||||||
|
handlerDomain: ".",
|
||||||
|
queryDomain: "anything.com.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match different domain",
|
||||||
|
handlerDomain: "example.com.",
|
||||||
|
queryDomain: "example.org.",
|
||||||
|
isWildcard: false,
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
mockHandler := &MockHandler{}
|
||||||
|
|
||||||
|
pattern := tt.handlerDomain
|
||||||
|
if tt.isWildcard {
|
||||||
|
pattern = "*." + tt.handlerDomain[2:] // Remove the first two chars if it's a wildcard
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler(pattern, mockHandler, nbdns.PriorityDefault, nil)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
if tt.shouldMatch {
|
||||||
|
mockHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
mockHandler.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
|
||||||
|
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handlers []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}
|
||||||
|
queryDomain string
|
||||||
|
expectedCalls int
|
||||||
|
expectedHandler int // index of the handler that should be called
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wildcard and exact same priority - exact should win",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
},
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1, // exact match handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "higher priority wildcard over lower priority exact",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "test.example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1, // higher priority wildcard handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple wildcards different priorities",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "test.example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 2, // highest priority handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain with mix of patterns",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
|
||||||
|
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "sub.test.example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 2, // highest priority matching handler should be called
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone with specific domain",
|
||||||
|
handlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||||
|
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
queryDomain: "example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
expectedHandler: 1, // higher priority specific domain should win over root
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
var handlers []*MockHandler
|
||||||
|
|
||||||
|
// Setup handlers and expectations
|
||||||
|
for i := range tt.handlers {
|
||||||
|
handler := &MockHandler{}
|
||||||
|
handlers = append(handlers, handler)
|
||||||
|
|
||||||
|
// Set expectation based on whether this handler should be called
|
||||||
|
if i == tt.expectedHandler {
|
||||||
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
|
||||||
|
} else {
|
||||||
|
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and execute request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify expectations
|
||||||
|
for _, handler := range handlers {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
|
||||||
|
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
// Create handlers
|
||||||
|
handler1 := &MockHandler{}
|
||||||
|
handler2 := &MockHandler{}
|
||||||
|
handler3 := &MockHandler{}
|
||||||
|
|
||||||
|
// Add handlers in priority order
|
||||||
|
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
|
||||||
|
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
|
||||||
|
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
// Setup mock responses to simulate chain continuation
|
||||||
|
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
// First handler signals continue
|
||||||
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
resp.MsgHdr.Zero = true // Signal to continue
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
|
||||||
|
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
// Second handler signals continue
|
||||||
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
resp.MsgHdr.Zero = true
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
|
||||||
|
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
|
||||||
|
// Last handler responds normally
|
||||||
|
w := args.Get(0).(*nbdns.ResponseWriterChain)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify all handlers were called in order
|
||||||
|
handler1.AssertExpectations(t)
|
||||||
|
handler2.AssertExpectations(t)
|
||||||
|
handler3.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockResponseWriter implements dns.ResponseWriter for testing
|
||||||
|
type mockResponseWriter struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
||||||
|
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
func (m *mockResponseWriter) Close() error { return nil }
|
||||||
|
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||||
|
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||||
|
func (m *mockResponseWriter) Hijack() {}
|
||||||
|
|
||||||
|
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ops []struct {
|
||||||
|
action string // "add" or "remove"
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}
|
||||||
|
query string
|
||||||
|
expectedCalls map[int]bool // map[priority]shouldBeCalled
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "remove high priority keeps lower priority handler",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||||
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
|
},
|
||||||
|
query: "example.com.",
|
||||||
|
expectedCalls: map[int]bool{
|
||||||
|
nbdns.PriorityDNSRoute: false,
|
||||||
|
nbdns.PriorityMatchDomain: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remove lower priority keeps high priority handler",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||||
|
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
||||||
|
},
|
||||||
|
query: "example.com.",
|
||||||
|
expectedCalls: map[int]bool{
|
||||||
|
nbdns.PriorityDNSRoute: true,
|
||||||
|
nbdns.PriorityMatchDomain: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "remove all handlers in order",
|
||||||
|
ops: []struct {
|
||||||
|
action string
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
}{
|
||||||
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
|
{"add", "example.com.", nbdns.PriorityMatchDomain},
|
||||||
|
{"add", "example.com.", nbdns.PriorityDefault},
|
||||||
|
{"remove", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
|
{"remove", "example.com.", nbdns.PriorityMatchDomain},
|
||||||
|
},
|
||||||
|
query: "example.com.",
|
||||||
|
expectedCalls: map[int]bool{
|
||||||
|
nbdns.PriorityDNSRoute: false,
|
||||||
|
nbdns.PriorityMatchDomain: false,
|
||||||
|
nbdns.PriorityDefault: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
handlers := make(map[int]*MockHandler)
|
||||||
|
|
||||||
|
// Execute operations
|
||||||
|
for _, op := range tt.ops {
|
||||||
|
if op.action == "add" {
|
||||||
|
handler := &MockHandler{}
|
||||||
|
handlers[op.priority] = handler
|
||||||
|
chain.AddHandler(op.pattern, handler, op.priority, nil)
|
||||||
|
} else {
|
||||||
|
chain.RemoveHandler(op.pattern, op.priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// Setup expectations
|
||||||
|
for priority, handler := range handlers {
|
||||||
|
if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
} else {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Maybe()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute request
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
// Verify expectations
|
||||||
|
for _, handler := range handlers {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify handler exists check
|
||||||
|
for priority, shouldExist := range tt.expectedCalls {
|
||||||
|
if shouldExist {
|
||||||
|
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
||||||
|
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
testDomain := "example.com."
|
||||||
|
testQuery := "test.example.com."
|
||||||
|
|
||||||
|
// Create handlers for three priority levels
|
||||||
|
routeHandler := &MockHandler{}
|
||||||
|
matchHandler := &MockHandler{}
|
||||||
|
defaultHandler := &MockHandler{}
|
||||||
|
|
||||||
|
// Create test request that will be reused
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
|
// Add handlers in mixed order
|
||||||
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
|
||||||
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
|
||||||
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
|
||||||
|
|
||||||
|
// Test 1: Initial state with all three handlers
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
// Highest priority handler (routeHandler) should be called
|
||||||
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
routeHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Test 2: Remove highest priority handler
|
||||||
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
|
assert.True(t, chain.HasHandlers(testDomain))
|
||||||
|
|
||||||
|
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
// Now middle priority handler (matchHandler) should be called
|
||||||
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
matchHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Test 3: Remove middle priority handler
|
||||||
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
|
assert.True(t, chain.HasHandlers(testDomain))
|
||||||
|
|
||||||
|
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
defaultHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Test 4: Remove last handler
|
||||||
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
assert.False(t, chain.HasHandlers(testDomain))
|
||||||
|
}
|
@ -3,6 +3,8 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -11,6 +13,20 @@ type MockServer struct {
|
|||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
|
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||||
|
DeregisterHandlerFunc func([]string, int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
if m.RegisterHandlerFunc != nil {
|
||||||
|
m.RegisterHandlerFunc(domains, handler, priority)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
||||||
|
if m.DeregisterHandlerFunc != nil {
|
||||||
|
m.DeregisterHandlerFunc(domains, priority)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize mock implementation of Initialize from Server interface
|
// Initialize mock implementation of Initialize from Server interface
|
||||||
|
@ -30,6 +30,8 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
|
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||||
|
DeregisterHandler(domains []string, priority int)
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@ -48,12 +50,14 @@ type DefaultServer struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
|
handlerPriorities map[string]int
|
||||||
localResolver *localResolver
|
localResolver *localResolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
|
handlerChain *HandlerChain
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@ -76,6 +80,7 @@ type handlerWithStop interface {
|
|||||||
type muxUpdate struct {
|
type muxUpdate struct {
|
||||||
domain string
|
domain string
|
||||||
handler handlerWithStop
|
handler handlerWithStop
|
||||||
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
@ -138,7 +143,9 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
|
handlerPriorities: make(map[string]int),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
@ -151,6 +158,41 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.registerHandler(domains, handler, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
|
log.Debugf("registering handler %s with priority %d", handler, priority)
|
||||||
|
|
||||||
|
for _, domain := range domains {
|
||||||
|
s.handlerChain.AddHandler(domain, handler, priority, nil)
|
||||||
|
s.handlerPriorities[domain] = priority
|
||||||
|
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
s.deregisterHandler(domains, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
|
for _, domain := range domains {
|
||||||
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
|
// Only deregister from service if no handlers remain
|
||||||
|
if !s.handlerChain.HasHandlers(domain) {
|
||||||
|
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize instantiate host manager and the dns service
|
// Initialize instantiate host manager and the dns service
|
||||||
func (s *DefaultServer) Initialize() (err error) {
|
func (s *DefaultServer) Initialize() (err error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@ -343,7 +385,6 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
localRecords := make(map[string]nbdns.SimpleRecord, 0)
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
|
|
||||||
if len(customZone.Records) == 0 {
|
if len(customZone.Records) == 0 {
|
||||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||||
}
|
}
|
||||||
@ -351,6 +392,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: customZone.Domain,
|
domain: customZone.Domain,
|
||||||
handler: s.localResolver,
|
handler: s.localResolver,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
@ -414,6 +456,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityDefault,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -431,6 +474,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
muxUpdates = append(muxUpdates, muxUpdate{
|
muxUpdates = append(muxUpdates, muxUpdate{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -440,12 +484,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
|||||||
|
|
||||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
|
handlersByPriority := make(map[string]int)
|
||||||
|
|
||||||
var isContainRootUpdate bool
|
var isContainRootUpdate bool
|
||||||
|
|
||||||
|
// First register new handlers
|
||||||
for _, update := range muxUpdates {
|
for _, update := range muxUpdates {
|
||||||
s.service.RegisterMux(update.domain, update.handler)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.domain] = update.handler
|
muxUpdateMap[update.domain] = update.handler
|
||||||
|
handlersByPriority[update.domain] = update.priority
|
||||||
|
|
||||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
}
|
}
|
||||||
@ -455,6 +503,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Then deregister old handlers not in the update
|
||||||
for key, existingHandler := range s.dnsMuxMap {
|
for key, existingHandler := range s.dnsMuxMap {
|
||||||
_, found := muxUpdateMap[key]
|
_, found := muxUpdateMap[key]
|
||||||
if !found {
|
if !found {
|
||||||
@ -463,12 +512,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
|||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
} else {
|
} else {
|
||||||
existingHandler.stop()
|
existingHandler.stop()
|
||||||
s.service.DeregisterMux(key)
|
// Deregister with the priority that was used to register
|
||||||
|
if oldPriority, ok := s.handlerPriorities[key]; ok {
|
||||||
|
s.deregisterHandler([]string{key}, oldPriority)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
|
s.handlerPriorities = handlersByPriority
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
|
||||||
@ -517,13 +570,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
removeIndex[nbdns.RootZone] = -1
|
removeIndex[nbdns.RootZone] = -1
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
s.service.DeregisterMux(nbdns.RootZone)
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, item := range s.currentConfig.Domains {
|
for i, item := range s.currentConfig.Domains {
|
||||||
if _, found := removeIndex[item.Domain]; found {
|
if _, found := removeIndex[item.Domain]; found {
|
||||||
s.currentConfig.Domains[i].Disabled = true
|
s.currentConfig.Domains[i].Disabled = true
|
||||||
s.service.DeregisterMux(item.Domain)
|
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
|
||||||
removeIndex[item.Domain] = i
|
removeIndex[item.Domain] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -554,7 +607,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.currentConfig.Domains[i].Disabled = false
|
s.currentConfig.Domains[i].Disabled = false
|
||||||
s.service.RegisterMux(domain, handler)
|
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||||
@ -562,7 +615,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
|
|
||||||
if nsGroup.Primary {
|
if nsGroup.Primary {
|
||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
@ -593,7 +646,8 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
}
|
}
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
|
||||||
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
||||||
|
@ -512,7 +512,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
|
||||||
|
|
||||||
resolver := &net.Resolver{
|
resolver := &net.Resolver{
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
@ -560,6 +560,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
},
|
},
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
handlerPriorities: make(map[string]int),
|
||||||
hostManager: hostManager,
|
hostManager: hostManager,
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
|
@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +66,11 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the upstream resolver
|
||||||
|
func (u *upstreamResolverBase) String() string {
|
||||||
|
return fmt.Sprintf("%v", u.upstreamServers)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
|
120
client/internal/dnsfwd/forwarder.go
Normal file
120
client/internal/dnsfwd/forwarder.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DNSForwarder struct {
|
||||||
|
listenAddress string
|
||||||
|
ttl uint32
|
||||||
|
domains []string
|
||||||
|
|
||||||
|
dnsServer *dns.Server
|
||||||
|
mux *dns.ServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSForwarder {
|
||||||
|
log.Debugf("creating DNS forwarder with listen address: %s, ttl: %d, domains: %v", listenAddress, ttl, domains)
|
||||||
|
return &DNSForwarder{
|
||||||
|
listenAddress: listenAddress,
|
||||||
|
ttl: ttl,
|
||||||
|
domains: domains,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (f *DNSForwarder) Listen() error {
|
||||||
|
log.Infof("listen DNS forwarder on: %s", f.listenAddress)
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
|
for _, d := range f.domains {
|
||||||
|
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServer := &dns.Server{
|
||||||
|
Addr: f.listenAddress,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
f.dnsServer = dnsServer
|
||||||
|
f.mux = mux
|
||||||
|
return dnsServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||||
|
for _, d := range f.domains {
|
||||||
|
f.mux.HandleRemove(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range f.domains {
|
||||||
|
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
|
||||||
|
}
|
||||||
|
f.domains = domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
|
if f.dnsServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return f.dnsServer.ShutdownContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
if len(query.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Tracef("received DNS request for DNS forwarder: %v", query.Question[0].Name)
|
||||||
|
|
||||||
|
question := query.Question[0]
|
||||||
|
domain := question.Name
|
||||||
|
|
||||||
|
resp := query.SetReply(query)
|
||||||
|
|
||||||
|
ips, err := net.LookupIP(domain)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to resolve query for domain %s: %v", domain, err)
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
var respRecord dns.RR
|
||||||
|
if ip.To4() == nil {
|
||||||
|
log.Tracef("resolved domain %s to IPv6 %s", domain, ip)
|
||||||
|
rr := dns.AAAA{
|
||||||
|
AAAA: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.ttl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
} else {
|
||||||
|
log.Tracef("resolved domain %s to IPv4 %s", domain, ip)
|
||||||
|
rr := dns.A{
|
||||||
|
A: ip,
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: domain,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: f.ttl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
respRecord = &rr
|
||||||
|
}
|
||||||
|
resp.Answer = append(resp.Answer, respRecord)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
106
client/internal/dnsfwd/manager.go
Normal file
106
client/internal/dnsfwd/manager.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||||
|
ListenPort = 5353
|
||||||
|
dnsTTL = 60 //seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
firewall firewall.Manager
|
||||||
|
|
||||||
|
fwRules []firewall.Rule
|
||||||
|
dnsForwarder *DNSForwarder
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(fw firewall.Manager) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
firewall: fw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(domains []string) error {
|
||||||
|
log.Infof("starting DNS forwarder")
|
||||||
|
if m.dnsForwarder != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.allowDNSFirewall(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, domains)
|
||||||
|
go func() {
|
||||||
|
if err := m.dnsForwarder.Listen(); err != nil {
|
||||||
|
// todo handle close error if it is exists
|
||||||
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) UpdateDomains(domains []string) {
|
||||||
|
if m.dnsForwarder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder.UpdateDomains(domains)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop(ctx context.Context) error {
|
||||||
|
if m.dnsForwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var mErr *multierror.Error
|
||||||
|
if err := m.dropDNSFirewall(); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.dnsForwarder.Close(ctx); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = nil
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) allowDNSFirewall() error {
|
||||||
|
dport := &firewall.Port{
|
||||||
|
IsRange: false,
|
||||||
|
Values: []int{ListenPort},
|
||||||
|
}
|
||||||
|
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.fwRules = dnsRules
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) dropDNSFirewall() error {
|
||||||
|
var mErr *multierror.Error
|
||||||
|
for _, rule := range h.fwRules {
|
||||||
|
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.fwRules = nil
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -30,10 +29,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
@ -117,7 +118,7 @@ type Engine struct {
|
|||||||
// mgmClient is a Management Service client
|
// mgmClient is a Management Service client
|
||||||
mgmClient mgm.Client
|
mgmClient mgm.Client
|
||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerStore *peerstore.Store
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
beforePeerHook nbnet.AddHookFunc
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
afterPeerHook nbnet.RemoveHookFunc
|
||||||
@ -137,10 +138,6 @@ type Engine struct {
|
|||||||
TURNs []*stun.URI
|
TURNs []*stun.URI
|
||||||
stunTurn atomic.Value
|
stunTurn atomic.Value
|
||||||
|
|
||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
|
||||||
clientRoutes route.HAMap
|
|
||||||
clientRoutesMu sync.RWMutex
|
|
||||||
|
|
||||||
clientCtx context.Context
|
clientCtx context.Context
|
||||||
clientCancel context.CancelFunc
|
clientCancel context.CancelFunc
|
||||||
|
|
||||||
@ -164,6 +161,7 @@ type Engine struct {
|
|||||||
firewall manager.Manager
|
firewall manager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
|
dnsForwardMgr *dnsfwd.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
@ -234,7 +232,7 @@ func NewEngineWithProbes(
|
|||||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
peerConns: make(map[string]*peer.Conn),
|
peerStore: peerstore.NewConnStore(),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
mobileDep: mobileDep,
|
mobileDep: mobileDep,
|
||||||
@ -287,6 +285,13 @@ func (e *Engine) Stop() error {
|
|||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.dnsForwardMgr != nil {
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
if e.srWatcher != nil {
|
if e.srWatcher != nil {
|
||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
@ -300,10 +305,6 @@ func (e *Engine) Stop() error {
|
|||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
|
||||||
e.clientRoutes = nil
|
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@ -382,6 +383,8 @@ func (e *Engine) Start() error {
|
|||||||
e.relayManager,
|
e.relayManager,
|
||||||
initialRoutes,
|
initialRoutes,
|
||||||
e.stateManager,
|
e.stateManager,
|
||||||
|
dnsServer,
|
||||||
|
e.peerStore,
|
||||||
)
|
)
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -460,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
if peerConn, ok := e.peerConns[peerPubKey]; ok {
|
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
|
||||||
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
|
if allowedIPs != strings.Join(p.AllowedIps, ",") {
|
||||||
modified = append(modified, p)
|
modified = append(modified, p)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -492,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
currentPeers := make([]string, 0, len(e.peerConns))
|
|
||||||
for p := range e.peerConns {
|
|
||||||
currentPeers = append(currentPeers, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
newPeers := make([]string, 0, len(peersUpdate))
|
newPeers := make([]string, 0, len(peersUpdate))
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
newPeers = append(newPeers, p.GetWgPubKey())
|
newPeers = append(newPeers, p.GetWgPubKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
toRemove := util.SliceDiff(currentPeers, newPeers)
|
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||||
|
|
||||||
for _, p := range toRemove {
|
for _, p := range toRemove {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
@ -516,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
|
|
||||||
func (e *Engine) removeAllPeers() error {
|
func (e *Engine) removeAllPeers() error {
|
||||||
log.Debugf("removing all peer connections")
|
log.Debugf("removing all peer connections")
|
||||||
for p := range e.peerConns {
|
for _, p := range e.peerStore.PeersPubKey() {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -540,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, exists := e.peerConns[peerKey]
|
conn, exists := e.peerStore.Remove(peerKey)
|
||||||
if exists {
|
if exists {
|
||||||
delete(e.peerConns, peerKey)
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -786,7 +783,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||||
|
|
||||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||||
if networkMap.GetPeerConfig() != nil {
|
if networkMap.GetPeerConfig() != nil {
|
||||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||||
@ -806,19 +802,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
protoRoutes := networkMap.GetRoutes()
|
routedDomains, routes := toRoutes(networkMap.GetRoutes())
|
||||||
if protoRoutes == nil {
|
|
||||||
protoRoutes = []*mgmProto.Route{}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
// todo: useRoutingPeerDnsResolutionEnabled from network map proto
|
||||||
e.clientRoutes = clientRoutes
|
e.updateDNSForwarder(true, routedDomains)
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
@ -867,8 +858,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -881,7 +871,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) {
|
||||||
|
if protoRoutes == nil {
|
||||||
|
protoRoutes = []*mgmProto.Route{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var dnsRoutes []string
|
||||||
routes := make([]*route.Route, 0)
|
routes := make([]*route.Route, 0)
|
||||||
for _, protoRoute := range protoRoutes {
|
for _, protoRoute := range protoRoutes {
|
||||||
var prefix netip.Prefix
|
var prefix netip.Prefix
|
||||||
@ -892,6 +887,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||||
|
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: route.ID(protoRoute.ID),
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@ -905,7 +902,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
}
|
}
|
||||||
routes = append(routes, convertedRoute)
|
routes = append(routes, convertedRoute)
|
||||||
}
|
}
|
||||||
return routes
|
return dnsRoutes, routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||||
@ -982,12 +979,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||||
peerKey := peerConfig.GetWgPubKey()
|
peerKey := peerConfig.GetWgPubKey()
|
||||||
peerIPs := peerConfig.GetAllowedIps()
|
peerIPs := peerConfig.GetAllowedIps()
|
||||||
if _, ok := e.peerConns[peerKey]; !ok {
|
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
|
||||||
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create peer connection: %w", err)
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
|
||||||
|
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
@ -1076,8 +1077,8 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
conn := e.peerConns[msg.Key]
|
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||||
if conn == nil {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1135,7 +1136,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||||
case sProto.Body_MODE:
|
case sProto.Body_MODE:
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1239,7 +1240,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
routes := toRoutes(netMap.GetRoutes())
|
_, routes := toRoutes(netMap.GetRoutes())
|
||||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||||
return routes, &dnsCfg, nil
|
return routes, &dnsCfg, nil
|
||||||
}
|
}
|
||||||
@ -1322,26 +1323,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the current routes from the route map
|
|
||||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
return maps.Clone(e.clientRoutes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
|
||||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
|
||||||
for id, v := range e.clientRoutes {
|
|
||||||
routes[id.NetID()] = v
|
|
||||||
}
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRouteManager returns the route manager
|
// GetRouteManager returns the route manager
|
||||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||||
return e.routeManager
|
return e.routeManager
|
||||||
@ -1426,9 +1407,8 @@ func (e *Engine) receiveProbeEvents() {
|
|||||||
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
go e.probes.WgProbe.Receive(e.ctx, func() bool {
|
||||||
log.Debug("received wg probe request")
|
log.Debug("received wg probe request")
|
||||||
|
|
||||||
for _, peer := range e.peerConns {
|
for _, key := range e.peerStore.PeersPubKey() {
|
||||||
key := peer.GetKey()
|
wgStats, err := e.wgInterface.GetStats(key)
|
||||||
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
||||||
}
|
}
|
||||||
@ -1505,7 +1485,7 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
|
|
||||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||||
var vpnRoutes []netip.Prefix
|
var vpnRoutes []netip.Prefix
|
||||||
for _, routes := range e.GetClientRoutes() {
|
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||||
if len(routes) > 0 && routes[0] != nil {
|
if len(routes) > 0 && routes[0] != nil {
|
||||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||||
}
|
}
|
||||||
@ -1573,6 +1553,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
return nm, nil
|
return nm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
|
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||||
|
if !enabled {
|
||||||
|
if e.dnsForwardMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(domains) > 0 {
|
||||||
|
log.Infof("enable domain router service for domains: %v", domains)
|
||||||
|
if e.dnsForwardMgr == nil {
|
||||||
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
||||||
|
|
||||||
|
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
||||||
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("update domain router service for domains: %v", domains)
|
||||||
|
e.dnsForwardMgr.UpdateDomains(domains)
|
||||||
|
}
|
||||||
|
} else if e.dnsForwardMgr != nil {
|
||||||
|
log.Infof("disable domain router service")
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
|
@ -252,7 +252,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil)
|
||||||
_, _, err = engine.routeManager.Init()
|
_, _, err = engine.routeManager.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@ -392,8 +392,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(engine.peerConns) != c.expectedLen {
|
if len(engine.peerStore.PeersPubKey()) != c.expectedLen {
|
||||||
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns))
|
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if engine.networkSerial != c.expectedSerial {
|
if engine.networkSerial != c.expectedSerial {
|
||||||
@ -401,7 +401,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range c.expectedPeers {
|
for _, p := range c.expectedPeers {
|
||||||
conn, ok := engine.peerConns[p.GetWgPubKey()]
|
conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey())
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
|
||||||
}
|
}
|
||||||
@ -626,10 +626,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}{}
|
}{}
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
input.inputSerial = updateSerial
|
input.inputSerial = updateSerial
|
||||||
input.inputRoutes = newRoutes
|
input.inputRoutes = newRoutes
|
||||||
return nil, nil, testCase.inputErr
|
return testCase.inputErr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -802,8 +802,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1238,7 +1238,8 @@ func getConnectedPeers(e *Engine) int {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
i := 0
|
i := 0
|
||||||
for _, conn := range e.peerConns {
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
if conn.Status() == peer.StatusConnected {
|
if conn.Status() == peer.StatusConnected {
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
@ -1250,5 +1251,5 @@ func getPeers(e *Engine) int {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
return len(e.peerConns)
|
return len(e.peerStore.PeersPubKey())
|
||||||
}
|
}
|
||||||
|
@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
|||||||
conn.wgProxyRelay = proxy
|
conn.wgProxyRelay = proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowedIP returns the allowed IP of the remote peer
|
||||||
|
func (conn *Conn) AllowedIP() net.IP {
|
||||||
|
return conn.allowedIP
|
||||||
|
}
|
||||||
|
|
||||||
func isController(config ConnConfig) bool {
|
func isController(config ConnConfig) bool {
|
||||||
return config.LocalKey > config.Key
|
return config.LocalKey > config.Key
|
||||||
}
|
}
|
||||||
|
87
client/internal/peerstore/store.go
Normal file
87
client/internal/peerstore/store.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package peerstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is a thread-safe store for peer connections.
|
||||||
|
type Store struct {
|
||||||
|
peerConns map[string]*peer.Conn
|
||||||
|
peerConnsMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnStore() *Store {
|
||||||
|
return &Store{
|
||||||
|
peerConns: make(map[string]*peer.Conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
|
||||||
|
s.peerConnsMu.Lock()
|
||||||
|
defer s.peerConnsMu.Unlock()
|
||||||
|
|
||||||
|
_, ok := s.peerConns[pubKey]
|
||||||
|
if ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.peerConns[pubKey] = conn
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
|
||||||
|
s.peerConnsMu.Lock()
|
||||||
|
defer s.peerConnsMu.Unlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
delete(s.peerConns, pubKey)
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return p.WgConfig().AllowedIps, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return p.AllowedIP(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return p, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeersPubKey() []string {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
return maps.Keys(s.peerConns)
|
||||||
|
}
|
@ -13,12 +13,16 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const useNewDNSRoute = true
|
||||||
|
|
||||||
type routerPeerStatus struct {
|
type routerPeerStatus struct {
|
||||||
connected bool
|
connected bool
|
||||||
relayed bool
|
relayed bool
|
||||||
@ -53,7 +57,17 @@ type clientNetwork struct {
|
|||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
func newClientNetworkWatcher(
|
||||||
|
ctx context.Context,
|
||||||
|
dnsRouteInterval time.Duration,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
@ -65,7 +79,16 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
|||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
handler: handlerFromRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouteInterval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
dnsServer,
|
||||||
|
peerStore,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@ -368,10 +391,37 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
|
func handlerFromRoute(
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
dnsRouterInteval time.Duration,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
wgInterface iface.IWGIface,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
) RouteHandler {
|
||||||
if rt.IsDynamic() {
|
if rt.IsDynamic() {
|
||||||
|
if useNewDNSRoute {
|
||||||
|
return dnsinterceptor.New(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
statusRecorder,
|
||||||
|
dnsServer,
|
||||||
|
peerStore,
|
||||||
|
)
|
||||||
|
}
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
return dynamic.NewRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouterInteval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||||
}
|
}
|
||||||
|
334
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
334
client/internal/routemanager/dnsinterceptor/handler.go
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
package dnsinterceptor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type domainMap map[domain.Domain][]netip.Prefix
|
||||||
|
|
||||||
|
type DnsInterceptor struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
route *route.Route
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter
|
||||||
|
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
dnsServer nbdns.Server
|
||||||
|
currentPeerKey string
|
||||||
|
interceptedDomains domainMap
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(
|
||||||
|
rt *route.Route,
|
||||||
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
|
statusRecorder *peer.Status,
|
||||||
|
dnsServer nbdns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
|
) *DnsInterceptor {
|
||||||
|
return &DnsInterceptor{
|
||||||
|
route: rt,
|
||||||
|
routeRefCounter: routeRefCounter,
|
||||||
|
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
dnsServer: dnsServer,
|
||||||
|
interceptedDomains: make(domainMap),
|
||||||
|
peerStore: peerStore,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) String() string {
|
||||||
|
return d.route.Domains.SafeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) AddRoute(context.Context) error {
|
||||||
|
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) RemoveRoute() error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
if d.currentPeerKey != "" {
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||||
|
|
||||||
|
d.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
clear(d.interceptedDomains)
|
||||||
|
|
||||||
|
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for domain, prefixes := range d.interceptedDomains {
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||||
|
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||||
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
|
prefix.Addr(),
|
||||||
|
domain.SafeString(),
|
||||||
|
ref.Out,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.currentPeerKey = peerKey
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, prefixes := range d.interceptedDomains {
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.currentPeerKey = ""
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS implements the dns.Handler interface
|
||||||
|
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Tracef("received DNS request: %v", r.Question[0].Name)
|
||||||
|
|
||||||
|
d.mu.RLock()
|
||||||
|
peerKey := d.currentPeerKey
|
||||||
|
d.mu.RUnlock()
|
||||||
|
|
||||||
|
if peerKey == "" {
|
||||||
|
log.Debugf("no current peer key set, letting next handler try for %s", r.Question[0].Name)
|
||||||
|
d.continueToNextHandler(w, r, "no current peer key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamIP, err := d.getUpstreamIP(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get upstream IP: %v", err)
|
||||||
|
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &dns.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Net: "udp",
|
||||||
|
}
|
||||||
|
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
|
||||||
|
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
|
||||||
|
|
||||||
|
var answer []dns.RR
|
||||||
|
if reply != nil {
|
||||||
|
answer = reply.Answer
|
||||||
|
}
|
||||||
|
log.Debugf("upstream %s (%s) DNS response for %s: %v", upstreamIP, peerKey, r.Question[0].Name, answer)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
|
||||||
|
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
|
||||||
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reply.Id = r.Id
|
||||||
|
if err := d.writeMsg(w, reply); err != nil {
|
||||||
|
log.Errorf("failed writing DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// continueToNextHandler signals the handler chain to try the next handler
|
||||||
|
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
|
||||||
|
log.Debugf("continuing to next handler for %s: %s", r.Question[0].Name, reason)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeNameError)
|
||||||
|
// Set Zero bit to signal handler chain to continue
|
||||||
|
resp.MsgHdr.Zero = true
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed writing DNS continue response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
|
||||||
|
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
|
||||||
|
}
|
||||||
|
return peerAllowedIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||||
|
if r == nil {
|
||||||
|
return fmt.Errorf("received nil DNS message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Answer) > 0 && len(r.Question) > 0 {
|
||||||
|
// DNS names from miekg/dns are already in punycode format
|
||||||
|
dom := domain.Domain(r.Question[0].Name)
|
||||||
|
|
||||||
|
var newPrefixes []netip.Prefix
|
||||||
|
for _, answer := range r.Answer {
|
||||||
|
var ip netip.Addr
|
||||||
|
switch rr := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.A)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("failed to convert A record IP: %v", rr.A)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
case *dns.AAAA:
|
||||||
|
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("failed to convert AAAA record IP: %v", rr.AAAA)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
||||||
|
newPrefixes = append(newPrefixes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newPrefixes) > 0 {
|
||||||
|
if err := d.updateDomainPrefixes(dom, newPrefixes); err != nil {
|
||||||
|
log.Errorf("failed to update domain prefixes: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(r); err != nil {
|
||||||
|
return fmt.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DnsInterceptor) updateDomainPrefixes(domain domain.Domain, newPrefixes []netip.Prefix) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
oldPrefixes := d.interceptedDomains[domain]
|
||||||
|
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Add new prefixes
|
||||||
|
for _, prefix := range toAdd {
|
||||||
|
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.currentPeerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
|
||||||
|
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
|
||||||
|
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||||
|
prefix.Addr(),
|
||||||
|
domain.SafeString(),
|
||||||
|
ref.Out,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !d.route.KeepRoute {
|
||||||
|
// Remove old prefixes
|
||||||
|
for _, prefix := range toRemove {
|
||||||
|
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
if d.currentPeerKey != "" {
|
||||||
|
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update domain prefixes
|
||||||
|
if len(toAdd) > 0 || len(toRemove) > 0 {
|
||||||
|
d.interceptedDomains[domain] = newPrefixes
|
||||||
|
d.statusRecorder.UpdateResolvedDomainsStates(domain, newPrefixes)
|
||||||
|
|
||||||
|
if len(toAdd) > 0 {
|
||||||
|
log.Debugf("added dynamic route(s) for [%s]: %s", domain.SafeString(), toAdd)
|
||||||
|
}
|
||||||
|
if len(toRemove) > 0 {
|
||||||
|
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), toRemove)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||||
|
prefixSet := make(map[netip.Prefix]bool)
|
||||||
|
for _, prefix := range oldPrefixes {
|
||||||
|
prefixSet[prefix] = false
|
||||||
|
}
|
||||||
|
for _, prefix := range newPrefixes {
|
||||||
|
if _, exists := prefixSet[prefix]; exists {
|
||||||
|
prefixSet[prefix] = true
|
||||||
|
} else {
|
||||||
|
toAdd = append(toAdd, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for prefix, inUse := range prefixSet {
|
||||||
|
if !inUse {
|
||||||
|
toRemove = append(toRemove, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@ -74,11 +74,7 @@ func NewRoute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) String() string {
|
func (r *Route) String() string {
|
||||||
s, err := r.route.Domains.String()
|
return r.route.Domains.SafeString()
|
||||||
if err != nil {
|
|
||||||
return r.route.Domains.PunycodeString()
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Route) AddRoute(ctx context.Context) error {
|
func (r *Route) AddRoute(ctx context.Context) error {
|
||||||
|
@ -12,12 +12,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
@ -33,9 +36,11 @@ import (
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
|
GetClientRoutes() route.HAMap
|
||||||
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
EnableServerRouter(firewall firewall.Manager) error
|
||||||
@ -60,6 +65,10 @@ type DefaultManager struct {
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
|
clientRoutes route.HAMap
|
||||||
|
dnsServer dns.Server
|
||||||
|
peerStore *peerstore.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(
|
func NewManager(
|
||||||
@ -71,6 +80,8 @@ func NewManager(
|
|||||||
relayMgr *relayClient.Manager,
|
relayMgr *relayClient.Manager,
|
||||||
initialRoutes []*route.Route,
|
initialRoutes []*route.Route,
|
||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
|
dnsServer dns.Server,
|
||||||
|
peerStore *peerstore.Store,
|
||||||
) *DefaultManager {
|
) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
@ -88,6 +99,8 @@ func NewManager(
|
|||||||
pubKey: pubKey,
|
pubKey: pubKey,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
|
dnsServer: dnsServer,
|
||||||
|
peerStore: peerStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.routeRefCounter = refcounter.New(
|
dm.routeRefCounter = refcounter.New(
|
||||||
@ -116,7 +129,7 @@ func NewManager(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.clientRoutes(initialRoutes)
|
cr := dm.initialClientRoutes(initialRoutes)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
dm.notifier.SetInitialClientRoutes(cr)
|
||||||
}
|
}
|
||||||
return dm
|
return dm
|
||||||
@ -207,15 +220,21 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.ctx = nil
|
m.ctx = nil
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.clientRoutes = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not updating routes as context is closed")
|
log.Infof("not updating routes as context is closed")
|
||||||
return nil, nil, m.ctx.Err()
|
return nil
|
||||||
default:
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
@ -228,12 +247,13 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
}
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||||
@ -251,9 +271,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return m.routeSelector
|
return m.routeSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the client routes
|
// GetClientRoutes returns most recent list of clientRoutes received from the Management Service
|
||||||
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||||
return m.clientNetworks
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
return maps.Clone(m.clientRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
|
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
|
||||||
|
for id, v := range m.clientRoutes {
|
||||||
|
routes[id.NetID()] = v
|
||||||
|
}
|
||||||
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||||
@ -273,7 +308,17 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
clientNetworkWatcher := newClientNetworkWatcher(
|
||||||
|
m.ctx,
|
||||||
|
m.dnsRouteInterval,
|
||||||
|
m.wgInterface,
|
||||||
|
m.statusRecorder,
|
||||||
|
routes[0],
|
||||||
|
m.routeRefCounter,
|
||||||
|
m.allowedIPsRefCounter,
|
||||||
|
m.dnsServer,
|
||||||
|
m.peerStore,
|
||||||
|
)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||||
@ -302,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
if !found {
|
if !found {
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
}
|
}
|
||||||
@ -345,7 +390,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
|||||||
return newServerRoutesMap, newClientRoutesIDMap
|
return newServerRoutesMap, newClientRoutesIDMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||||
_, crMap := m.classifyRoutes(initialRoutes)
|
_, crMap := m.classifyRoutes(initialRoutes)
|
||||||
rs := make([]*route.Route, 0, len(crMap))
|
rs := make([]*route.Route, 0, len(crMap))
|
||||||
for _, routes := range crMap {
|
for _, routes := range crMap {
|
||||||
|
@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
_, _, err = routeManager.Init()
|
||||||
|
|
||||||
@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes with init routes")
|
require.NoError(t, err, "should update routes with init routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
|
@ -2,7 +2,6 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@ -15,9 +14,11 @@ import (
|
|||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
|
GetClientRoutesFunc func() route.HAMap
|
||||||
|
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
if m.UpdateRoutesFunc != nil {
|
if m.UpdateRoutesFunc != nil {
|
||||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||||
@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||||
|
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||||
|
if m.GetClientRoutesFunc != nil {
|
||||||
|
return m.GetClientRoutesFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||||
|
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
|
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||||
|
return m.GetClientRoutesWithNetIDFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Start mock implementation of Start from Manager interface
|
// Start mock implementation of Start from Manager interface
|
||||||
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
||||||
}
|
}
|
||||||
|
@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
|||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := engine.GetClientRoutesWithNetID()
|
|
||||||
routeManager := engine.GetRouteManager()
|
routeManager := engine.GetRouteManager()
|
||||||
|
routesMap := routeManager.GetClientRoutesWithNetID()
|
||||||
if routeManager == nil {
|
if routeManager == nil {
|
||||||
return nil, fmt.Errorf("could not get route manager")
|
return nil, fmt.Errorf("could not get route manager")
|
||||||
}
|
}
|
||||||
@ -365,12 +365,12 @@ func (c *Client) SelectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("select route with id: %s", id)
|
log.Debugf("select route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||||
log.Debugf("error when selecting routes: %s", err)
|
log.Debugf("error when selecting routes: %s", err)
|
||||||
return fmt.Errorf("select routes: %w", err)
|
return fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -392,12 +392,12 @@ func (c *Client) DeselectRoute(id string) error {
|
|||||||
} else {
|
} else {
|
||||||
log.Debugf("deselect route with id: %s", id)
|
log.Debugf("deselect route with id: %s", id)
|
||||||
routes := toNetIDs([]string{id})
|
routes := toNetIDs([]string{id})
|
||||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
|
||||||
log.Debugf("error when deselecting routes: %s", err)
|
log.Debugf("error when deselecting routes: %s", err)
|
||||||
return fmt.Errorf("deselect routes: %w", err)
|
return fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
|||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := engine.GetClientRoutesWithNetID()
|
routesMap := engine.GetRouteManager().GetClientRoutesWithNetID()
|
||||||
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||||
|
|
||||||
var routes []*selectRoute
|
var routes []*selectRoute
|
||||||
@ -116,11 +116,12 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
|||||||
routeSelector.SelectAllRoutes()
|
routeSelector.SelectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
routes := toNetIDs(req.GetNetworkIDs())
|
routes := toNetIDs(req.GetNetworkIDs())
|
||||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
||||||
|
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||||
return nil, fmt.Errorf("select routes: %w", err)
|
return nil, fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
|
|
||||||
return &proto.SelectNetworksResponse{}, nil
|
return &proto.SelectNetworksResponse{}, nil
|
||||||
}
|
}
|
||||||
@ -145,11 +146,12 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
|||||||
routeSelector.DeselectAllRoutes()
|
routeSelector.DeselectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
routes := toNetIDs(req.GetNetworkIDs())
|
routes := toNetIDs(req.GetNetworkIDs())
|
||||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID())
|
||||||
|
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
|
||||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(routeManager.GetClientRoutes())
|
||||||
|
|
||||||
return &proto.SelectNetworksResponse{}, nil
|
return &proto.SelectNetworksResponse{}, nil
|
||||||
}
|
}
|
||||||
|
@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) {
|
|||||||
|
|
||||||
return validHost, nil
|
return validHost, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeZone returns a normalized domain name without the wildcard prefix
|
||||||
|
func NormalizeZone(domain string) string {
|
||||||
|
d, _ := strings.CutPrefix(domain, "*.")
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
1
go.mod
1
go.mod
@ -207,6 +207,7 @@ require (
|
|||||||
github.com/spf13/cast v1.5.0 // indirect
|
github.com/spf13/cast v1.5.0 // indirect
|
||||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
|
||||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
|
1
go.sum
1
go.sum
@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
@ -360,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) {
|
|||||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||||
|
|
||||||
var domainList domain.List
|
var domainList domain.List
|
||||||
|
|
||||||
|
@ -330,6 +330,14 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "POST Wildcard Domain",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/routes",
|
||||||
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "POST UnprocessableEntity when both network and domains are provided",
|
name: "POST UnprocessableEntity when both network and domains are provided",
|
||||||
requestType: http.MethodPost,
|
requestType: http.MethodPost,
|
||||||
@ -609,6 +617,30 @@ func TestValidateDomains(t *testing.T) {
|
|||||||
expected: domain.List{"google.com"},
|
expected: domain.List{"google.com"},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Valid wildcard domain",
|
||||||
|
domains: []string{"*.example.com"},
|
||||||
|
expected: domain.List{"*.example.com"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard with dot domain",
|
||||||
|
domains: []string{".*.example.com"},
|
||||||
|
expected: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard with dot domain",
|
||||||
|
domains: []string{".*.example.com"},
|
||||||
|
expected: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid wildcard domain",
|
||||||
|
domains: []string{"a.*.example.com"},
|
||||||
|
expected: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
Loading…
Reference in New Issue
Block a user