diff --git a/client/iface/batcher.go b/client/iface/batcher.go
new file mode 100644
index 000000000..d00601f1f
--- /dev/null
+++ b/client/iface/batcher.go
@@ -0,0 +1,380 @@
+package iface
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/netip"
+	"os"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/hashicorp/go-multierror"
+	log "github.com/sirupsen/logrus"
+	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+	nberrors "github.com/netbirdio/netbird/client/errors"
+	"github.com/netbirdio/netbird/client/iface/configurer"
+	"github.com/netbirdio/netbird/client/iface/device"
+)
+
+const (
+	// DefaultBatchFlushInterval is the default maximum time to wait before flushing batched operations
+	DefaultBatchFlushInterval = 300 * time.Millisecond
+	// DefaultBatchSizeThreshold is the default number of operations to trigger an immediate flush
+	DefaultBatchSizeThreshold = 100
+
+	// AllowedIPOpAdd represents an add operation
+	AllowedIPOpAdd = "add"
+	// AllowedIPOpRemove represents a remove operation
+	AllowedIPOpRemove = "remove"
+
+	// EnvDisableWGBatching disables batching entirely when set to any non-empty value
+	EnvDisableWGBatching = "NB_DISABLE_WG_BATCHING"
+	// EnvWGBatchFlushIntervalMS overrides the flush interval, in milliseconds (positive integer)
+	EnvWGBatchFlushIntervalMS = "NB_WG_BATCH_FLUSH_INTERVAL_MS"
+	// EnvWGBatchSizeThreshold overrides the size threshold (positive integer)
+	EnvWGBatchSizeThreshold = "NB_WG_BATCH_SIZE_THRESHOLD"
+)
+
+// AllowedIPOperation represents a pending allowed IP operation
+type AllowedIPOperation struct {
+	PeerKey   string
+	Prefix    netip.Prefix
+	Operation string
+}
+
+// PeerUpdateOperation represents a pending peer update operation
+type PeerUpdateOperation struct {
+	PeerKey      string
+	AllowedIPs   []netip.Prefix
+	KeepAlive    time.Duration
+	Endpoint     *net.UDPAddr
+	PreSharedKey *wgtypes.Key
+}
+
+// WGBatcher batches WireGuard configuration updates to reduce syscall overhead
+type WGBatcher struct {
+	configurer device.WGConfigurer
+	// mu guards allowedIPOps, peerUpdates and flushTimer
+	mu sync.Mutex
+	// flushMu serializes Flush so batches are applied in the order they were taken
+	// and a Flush call only returns once every earlier batch has been applied
+	flushMu sync.Mutex
+
+	allowedIPOps []AllowedIPOperation
+	peerUpdates  map[string]*PeerUpdateOperation
+
+	flushTimer *time.Timer
+	flushChan  chan struct{}
+	ctx        context.Context
+	cancel     context.CancelFunc
+	wg         sync.WaitGroup
+
+	batchFlushInterval time.Duration
+	batchSizeThreshold int
+}
+
+// NewWGBatcher creates a new WireGuard operation batcher and starts its flush goroutine.
+// It returns nil when batching is disabled via EnvDisableWGBatching, so callers must
+// nil-check the result. Invalid or non-positive overrides in EnvWGBatchFlushIntervalMS
+// and EnvWGBatchSizeThreshold are ignored and the defaults are kept.
+func NewWGBatcher(configurer device.WGConfigurer) *WGBatcher {
+	if os.Getenv(EnvDisableWGBatching) != "" {
+		log.Infof("WireGuard allowed IP batching disabled via %s", EnvDisableWGBatching)
+		return nil
+	}
+
+	flushInterval := DefaultBatchFlushInterval
+	sizeThreshold := DefaultBatchSizeThreshold
+
+	if intervalMs := os.Getenv(EnvWGBatchFlushIntervalMS); intervalMs != "" {
+		if ms, err := strconv.Atoi(intervalMs); err == nil && ms > 0 {
+			flushInterval = time.Duration(ms) * time.Millisecond
+			log.Infof("WireGuard batch flush interval set to %v", flushInterval)
+		}
+	}
+
+	if threshold := os.Getenv(EnvWGBatchSizeThreshold); threshold != "" {
+		if size, err := strconv.Atoi(threshold); err == nil && size > 0 {
+			sizeThreshold = size
+			log.Infof("WireGuard batch size threshold set to %d", sizeThreshold)
+		}
+	}
+
+	log.Info("WireGuard allowed IP batching enabled")
+
+	ctx, cancel := context.WithCancel(context.Background())
+	b := &WGBatcher{
+		configurer:         configurer,
+		peerUpdates:        make(map[string]*PeerUpdateOperation),
+		flushChan:          make(chan struct{}, 1),
+		ctx:                ctx,
+		cancel:             cancel,
+		batchFlushInterval: flushInterval,
+		batchSizeThreshold: sizeThreshold,
+	}
+
+	b.wg.Add(1)
+	go b.flushLoop()
+
+	return b
+}
+
+// Close stops the batcher and flushes any pending operations.
+// Flush errors are logged rather than returned; Close always returns nil.
+func (b *WGBatcher) Close() error {
+	b.mu.Lock()
+	if b.flushTimer != nil {
+		b.flushTimer.Stop()
+	}
+	b.mu.Unlock()
+
+	b.cancel()
+
+	if err := b.Flush(); err != nil {
+		log.Errorf("failed to flush pending operations on close: %v", err)
+	}
+
+	b.wg.Wait()
+
+	return nil
+}
+
+// UpdatePeer batches a peer update operation.
+// A later update for the same peer key replaces an earlier one that has not been flushed yet.
+func (b *WGBatcher) UpdatePeer(peerKey string, allowedIPs []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.peerUpdates[peerKey] = &PeerUpdateOperation{
+		PeerKey:      peerKey,
+		AllowedIPs:   allowedIPs,
+		KeepAlive:    keepAlive,
+		Endpoint:     endpoint,
+		PreSharedKey: preSharedKey,
+	}
+
+	b.scheduleFlush()
+	return nil
+}
+
+// AddAllowedIP batches an allowed IP addition.
+// It always returns nil; configurer errors surface from the flush that applies the operation.
+func (b *WGBatcher) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
+		PeerKey:   peerKey,
+		Prefix:    allowedIP,
+		Operation: AllowedIPOpAdd,
+	})
+
+	b.scheduleFlush()
+	return nil
+}
+
+// RemoveAllowedIP batches an allowed IP removal.
+// It always returns nil; configurer errors surface from the flush that applies the operation.
+func (b *WGBatcher) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.allowedIPOps = append(b.allowedIPOps, AllowedIPOperation{
+		PeerKey:   peerKey,
+		Prefix:    allowedIP,
+		Operation: AllowedIPOpRemove,
+	})
+
+	b.scheduleFlush()
+	return nil
+}
+
+// Flush immediately processes all batched operations and returns once they have been applied.
+// Concurrent calls are serialized, so when Flush returns every batch taken by an earlier
+// call (including one made by the background flush loop) has been applied as well.
+func (b *WGBatcher) Flush() error {
+	// Hold flushMu across take-and-apply: without it a caller could find the queue
+	// already drained by the flush loop and return while that batch is still being applied.
+	b.flushMu.Lock()
+	defer b.flushMu.Unlock()
+
+	b.mu.Lock()
+
+	if b.flushTimer != nil {
+		b.flushTimer.Stop()
+		b.flushTimer = nil
+	}
+
+	peerUpdates := b.peerUpdates
+	allowedIPOps := b.allowedIPOps
+
+	b.peerUpdates = make(map[string]*PeerUpdateOperation)
+	b.allowedIPOps = nil
+
+	b.mu.Unlock()
+
+	return b.processBatch(peerUpdates, allowedIPOps)
+}
+
+// scheduleFlush schedules a batch flush if not already scheduled.
+// The caller must hold b.mu.
+func (b *WGBatcher) scheduleFlush() {
+	shouldFlushNow := len(b.allowedIPOps)+len(b.peerUpdates) >= b.batchSizeThreshold
+
+	if shouldFlushNow {
+		// Non-blocking send: a signal already queued in flushChan covers this request too
+		select {
+		case b.flushChan <- struct{}{}:
+		default:
+		}
+		return
+	}
+
+	if b.flushTimer == nil {
+		b.flushTimer = time.AfterFunc(b.batchFlushInterval, func() {
+			select {
+			case b.flushChan <- struct{}{}:
+			default:
+			}
+		})
+	}
+}
+
+// flushLoop flushes batched operations whenever flushChan is signaled, until ctx is cancelled
+func (b *WGBatcher) flushLoop() {
+	defer b.wg.Done()
+
+	for {
+		select {
+		case <-b.flushChan:
+			if err := b.Flush(); err != nil {
+				log.Errorf("Error flushing WireGuard operations: %v", err)
+			}
+		case <-b.ctx.Done():
+			return
+		}
+	}
+}
+
+// processBatch applies peer updates first, then allowed IP operations, and returns the combined errors
+func (b *WGBatcher) processBatch(peerUpdates map[string]*PeerUpdateOperation, allowedIPOps []AllowedIPOperation) error {
+	if len(peerUpdates) == 0 && len(allowedIPOps) == 0 {
+		return nil
+	}
+
+	start := time.Now()
+	defer func() {
+		duration := time.Since(start)
+		log.Debugf("Processed batch of %d peer updates and %d allowed IP operations in %v",
+			len(peerUpdates), len(allowedIPOps), duration)
+	}()
+
+	var merr *multierror.Error
+
+	if err := b.processPeerUpdates(peerUpdates); err != nil {
+		merr = multierror.Append(merr, err)
+	}
+
+	if err := b.processAllowedIPOps(allowedIPOps); err != nil {
+		merr = multierror.Append(merr, err)
+	}
+
+	return nberrors.FormatErrorOrNil(merr)
+}
+
+// processPeerUpdates processes peer update operations
+func (b *WGBatcher) processPeerUpdates(peerUpdates map[string]*PeerUpdateOperation) error {
+	var merr *multierror.Error
+	for _, update := range peerUpdates {
+		if err := b.configurer.UpdatePeer(
+			update.PeerKey,
+			update.AllowedIPs,
+			update.KeepAlive,
+			update.Endpoint,
+			update.PreSharedKey,
+		); err != nil {
+			merr = multierror.Append(merr, fmt.Errorf("update peer %s: %w", update.PeerKey, err))
+		}
+	}
+	return nberrors.FormatErrorOrNil(merr)
+}
+
+// processAllowedIPOps processes allowed IP add/remove operations
+func (b *WGBatcher) processAllowedIPOps(allowedIPOps []AllowedIPOperation) error {
+	peerChanges := b.groupAllowedIPChanges(allowedIPOps)
+	return b.applyAllowedIPChanges(peerChanges)
+}
+
+// groupAllowedIPChanges groups allowed IP operations by peer.
+// Only the last operation queued for a given peer and prefix is kept: removals are applied
+// before additions, so keeping both would turn "add X, remove X" into X being installed.
+func (b *WGBatcher) groupAllowedIPChanges(allowedIPOps []AllowedIPOperation) map[string]struct {
+	toAdd    []netip.Prefix
+	toRemove []netip.Prefix
+} {
+	type opKey struct {
+		peerKey string
+		prefix  netip.Prefix
+	}
+
+	// Index of the last operation for each peer/prefix pair
+	lastOp := make(map[opKey]int, len(allowedIPOps))
+	for i, op := range allowedIPOps {
+		lastOp[opKey{peerKey: op.PeerKey, prefix: op.Prefix}] = i
+	}
+
+	peerChanges := make(map[string]struct {
+		toAdd    []netip.Prefix
+		toRemove []netip.Prefix
+	})
+
+	for i, op := range allowedIPOps {
+		if lastOp[opKey{peerKey: op.PeerKey, prefix: op.Prefix}] != i {
+			// superseded by a later operation on the same peer and prefix
+			continue
+		}
+
+		changes := peerChanges[op.PeerKey]
+		if op.Operation == AllowedIPOpAdd {
+			changes.toAdd = append(changes.toAdd, op.Prefix)
+		} else {
+			changes.toRemove = append(changes.toRemove, op.Prefix)
+		}
+		peerChanges[op.PeerKey] = changes
+	}
+
+	return peerChanges
+}
+
+// applyAllowedIPChanges applies allowed IP changes for each peer, removals before additions.
+// Removals that fail because the peer or the prefix is already gone are logged and not reported.
+func (b *WGBatcher) applyAllowedIPChanges(peerChanges map[string]struct {
+	toAdd    []netip.Prefix
+	toRemove []netip.Prefix
+}) error {
+	var merr *multierror.Error
+
+	for peerKey, changes := range peerChanges {
+		for _, prefix := range changes.toRemove {
+			if err := b.configurer.RemoveAllowedIP(peerKey, prefix); err != nil {
+				if errors.Is(err, configurer.ErrPeerNotFound) || errors.Is(err, configurer.ErrAllowedIPNotFound) {
+					log.Debugf("remove allowed IP %s for peer %s: %v", prefix, peerKey, err)
+				} else {
+					merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s for peer %s: %w", prefix, peerKey, err))
+				}
+			}
+		}
+
+		for _, prefix := range changes.toAdd {
+			if err := b.configurer.AddAllowedIP(peerKey, prefix); err != nil {
+				merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s for peer %s: %w", prefix, peerKey, err))
+			}
+		}
+	}
+
+	return nberrors.FormatErrorOrNil(merr)
+}
diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go
index 15e26d02f..db0249d11 100644
--- a/client/iface/bind/udp_mux_ios.go
+++ b/client/iface/bind/udp_mux_ios.go
@@ -4,4 +4,4 @@ package bind
 
 func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
 	// iOS doesn't support nbnet hooks, so this is a no-op
-}
\ No newline at end of file
+}
diff --git a/client/iface/iface.go b/client/iface/iface.go
index 0e41f8e64..51588b8ba 100644
--- a/client/iface/iface.go
+++ b/client/iface/iface.go
@@ -59,6 +59,7 @@ type WGIface struct {
 	mu            sync.Mutex
 
 	configurer     device.WGConfigurer
+	batcher        *WGBatcher
 	filter         device.PacketFilter
 	wgProxyFactory wgProxyFactory
 }
@@ -128,6 +129,12 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
 	}
 
 	log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
+
+	if endpoint != nil && w.batcher != nil {
+		if err := w.batcher.Flush(); err != nil {
+			log.Warnf("failed to flush batched operations: %v", err)
+		}
+	}
 	return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
 }
 
@@ -152,6 +159,10 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
 	}
 
 	log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
+
+	if w.batcher != nil {
+		return w.batcher.AddAllowedIP(peerKey, allowedIP)
+	}
 	return w.configurer.AddAllowedIP(peerKey, allowedIP)
 }
 
@@ -164,6 +175,10 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
 	}
 
 	log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
+
+	if w.batcher != nil {
+		return w.batcher.RemoveAllowedIP(peerKey, allowedIP)
+	}
 	return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
 }
 
@@ -174,6 +189,12 @@ func (w *WGIface) Close() error {
 
 	var result *multierror.Error
 
+	if w.batcher != nil {
+		if err := w.batcher.Close(); err != nil {
+			result = multierror.Append(result, fmt.Errorf("failed to close WireGuard batcher: %w", err))
+		}
+	}
+
 	if err := w.wgProxyFactory.Free(); err != nil {
 		result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
 	}
diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go
index 5e17c6d41..685fccba6 100644
--- a/client/iface/iface_create.go
+++ b/client/iface/iface_create.go
@@ -17,6 +17,7 @@ func (w *WGIface) Create() error {
 	}
 
 	w.configurer = cfgr
+	w.batcher = NewWGBatcher(cfgr)
 	return nil
 }
 
diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go
index 373a9c95a..bc51bf1b7 100644
--- a/client/iface/iface_create_android.go
+++ b/client/iface/iface_create_android.go
@@ -1,8 +1,6 @@
 package iface
 
-import (
-	"fmt"
-)
+import "fmt"
 
 // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
 // Will reuse an existing one.
@@ -15,6 +13,7 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
 		return err
 	}
 	w.configurer = cfgr
+	w.batcher = NewWGBatcher(cfgr)
 	return nil
 }
 
diff --git a/client/iface/iface_create_darwin.go b/client/iface/iface_create_darwin.go
index 1d91bce54..bef388aba 100644
--- a/client/iface/iface_create_darwin.go
+++ b/client/iface/iface_create_darwin.go
@@ -29,6 +29,7 @@ func (w *WGIface) Create() error {
 		return err
 	}
 	w.configurer = cfgr
+	w.batcher = NewWGBatcher(cfgr)
 	return nil
 }
 