mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
[client] Stop flow grpc receiver properly (#3596)
This commit is contained in:
parent
6124e3b937
commit
29a6e5be71
@ -1,7 +1,6 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -12,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -24,7 +23,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -15,7 +14,7 @@ import (
|
|||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
@ -31,7 +31,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
)
|
)
|
||||||
|
|
||||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type mocWGIface struct {
|
type mocWGIface struct {
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
|
@ -353,7 +353,7 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
// start flow manager right after interface creation
|
// start flow manager right after interface creation
|
||||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
|
@ -19,11 +19,9 @@ import (
|
|||||||
type rcvChan chan *types.EventFields
|
type rcvChan chan *types.EventFields
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
enabled atomic.Bool
|
enabled atomic.Bool
|
||||||
rcvChan atomic.Pointer[rcvChan]
|
rcvChan atomic.Pointer[rcvChan]
|
||||||
cancelReceiver context.CancelFunc
|
cancel context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgIfaceIPNet net.IPNet
|
wgIfaceIPNet net.IPNet
|
||||||
dnsCollection atomic.Bool
|
dnsCollection atomic.Bool
|
||||||
@ -31,12 +29,9 @@ type Logger struct {
|
|||||||
Store types.Store
|
Store types.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
return &Logger{
|
return &Logger{
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgIfaceIPNet: wgIfaceIPNet,
|
wgIfaceIPNet: wgIfaceIPNet,
|
||||||
Store: store.NewMemoryStore(),
|
Store: store.NewMemoryStore(),
|
||||||
@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
l.mux.Lock()
|
l.mux.Lock()
|
||||||
ctx, cancel := context.WithCancel(l.ctx)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
l.cancelReceiver = cancel
|
l.cancel = cancel
|
||||||
l.mux.Unlock()
|
l.mux.Unlock()
|
||||||
|
|
||||||
c := make(rcvChan, 100)
|
c := make(rcvChan, 100)
|
||||||
@ -109,7 +104,7 @@ func (l *Logger) startReceiver() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Disable() {
|
func (l *Logger) Close() {
|
||||||
l.stop()
|
l.stop()
|
||||||
l.Store.Close()
|
l.Store.Close()
|
||||||
}
|
}
|
||||||
@ -121,9 +116,9 @@ func (l *Logger) stop() {
|
|||||||
|
|
||||||
l.enabled.Store(false)
|
l.enabled.Store(false)
|
||||||
l.mux.Lock()
|
l.mux.Lock()
|
||||||
if l.cancelReceiver != nil {
|
if l.cancel != nil {
|
||||||
l.cancelReceiver()
|
l.cancel()
|
||||||
l.cancelReceiver = nil
|
l.cancel = nil
|
||||||
}
|
}
|
||||||
l.rcvChan.Store(nil)
|
l.rcvChan.Store(nil)
|
||||||
l.mux.Unlock()
|
l.mux.Unlock()
|
||||||
@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
|||||||
l.exitNodeCollection.Store(exitNodeCollection)
|
l.exitNodeCollection.Store(exitNodeCollection)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Close() {
|
|
||||||
l.stop()
|
|
||||||
l.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||||
// check dns collection
|
// check dns collection
|
||||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package logger_test
|
package logger_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -13,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
logger := logger.New(context.Background(), nil, net.IPNet{})
|
logger := logger.New(nil, net.IPNet{})
|
||||||
logger.Enable()
|
logger.Enable()
|
||||||
|
|
||||||
event := types.EventFields{
|
event := types.EventFields{
|
||||||
@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// test disable
|
// test disable
|
||||||
logger.Disable()
|
logger.Close()
|
||||||
wait()
|
wait()
|
||||||
logger.StoreEvent(event)
|
logger.StoreEvent(event)
|
||||||
wait()
|
wait()
|
||||||
|
@ -27,18 +27,18 @@ type Manager struct {
|
|||||||
logger nftypes.FlowLogger
|
logger nftypes.FlowLogger
|
||||||
flowConfig *nftypes.FlowConfig
|
flowConfig *nftypes.FlowConfig
|
||||||
conntrack nftypes.ConnTracker
|
conntrack nftypes.ConnTracker
|
||||||
ctx context.Context
|
|
||||||
receiverClient *client.GRPCClient
|
receiverClient *client.GRPCClient
|
||||||
publicKey []byte
|
publicKey []byte
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new netflow manager
|
// NewManager creates a new netflow manager
|
||||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||||
var ipNet net.IPNet
|
var ipNet net.IPNet
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
ipNet = *iface.Address().Network
|
ipNet = *iface.Address().Network
|
||||||
}
|
}
|
||||||
flowLogger := logger.New(ctx, statusRecorder, ipNet)
|
flowLogger := logger.New(statusRecorder, ipNet)
|
||||||
|
|
||||||
var ct nftypes.ConnTracker
|
var ct nftypes.ConnTracker
|
||||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||||
@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
|
|||||||
return &Manager{
|
return &Manager{
|
||||||
logger: flowLogger,
|
logger: flowLogger,
|
||||||
conntrack: ct,
|
conntrack: ct,
|
||||||
ctx: ctx,
|
|
||||||
publicKey: publicKey,
|
publicKey: publicKey,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -68,23 +67,11 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
|||||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||||
// first make sender ready so events don't pile up
|
// first make sender ready so events don't pile up
|
||||||
if m.needsNewClient(previous) {
|
if m.needsNewClient(previous) {
|
||||||
if m.receiverClient != nil {
|
if err := m.resetClient(); err != nil {
|
||||||
if err := m.receiverClient.Close(); err != nil {
|
return fmt.Errorf("reset client: %w", err)
|
||||||
log.Warnf("error closing previous flow client: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create client: %w", err)
|
|
||||||
}
|
|
||||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
|
||||||
|
|
||||||
m.receiverClient = flowClient
|
|
||||||
go m.receiveACKs(flowClient)
|
|
||||||
go m.startSender()
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Enable()
|
m.logger.Enable()
|
||||||
|
|
||||||
if m.conntrack != nil {
|
if m.conntrack != nil {
|
||||||
@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) resetClient() error {
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
if err := m.receiverClient.Close(); err != nil {
|
||||||
|
log.Warnf("error closing previous flow client: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create client: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||||
|
|
||||||
|
m.receiverClient = flowClient
|
||||||
|
|
||||||
|
if m.cancel != nil {
|
||||||
|
m.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
m.cancel = cancel
|
||||||
|
|
||||||
|
go m.receiveACKs(ctx, flowClient)
|
||||||
|
go m.startSender(ctx)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// disableFlow stops components for flow tracking
|
// disableFlow stops components for flow tracking
|
||||||
func (m *Manager) disableFlow() error {
|
func (m *Manager) disableFlow() error {
|
||||||
|
if m.cancel != nil {
|
||||||
|
m.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
if m.conntrack != nil {
|
if m.conntrack != nil {
|
||||||
m.conntrack.Stop()
|
m.conntrack.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
m.logger.Disable()
|
m.logger.Close()
|
||||||
|
|
||||||
if m.receiverClient != nil {
|
if m.receiverClient != nil {
|
||||||
return m.receiverClient.Close()
|
return m.receiverClient.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
|||||||
|
|
||||||
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
|
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
|
||||||
|
|
||||||
|
changed := previous != nil && update.Enabled != previous.Enabled
|
||||||
if update.Enabled {
|
if update.Enabled {
|
||||||
|
if changed {
|
||||||
log.Infof("netflow manager enabled; starting netflow manager")
|
log.Infof("netflow manager enabled; starting netflow manager")
|
||||||
|
}
|
||||||
return m.enableFlow(previous)
|
return m.enableFlow(previous)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if changed {
|
||||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||||
err := m.disableFlow()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to disable netflow manager: %v", err)
|
|
||||||
}
|
}
|
||||||
return err
|
return m.disableFlow()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close cleans up all resources
|
// Close cleans up all resources
|
||||||
@ -151,17 +172,9 @@ func (m *Manager) Close() {
|
|||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
if m.conntrack != nil {
|
if err := m.disableFlow(); err != nil {
|
||||||
m.conntrack.Close()
|
log.Warnf("failed to disable flow manager: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.receiverClient != nil {
|
|
||||||
if err := m.receiverClient.Close(); err != nil {
|
|
||||||
log.Warnf("failed to close receiver client: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.logger.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLogger returns the flow logger
|
// GetLogger returns the flow logger
|
||||||
@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
|
|||||||
return m.logger
|
return m.logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) startSender() {
|
func (m *Manager) startSender(ctx context.Context) {
|
||||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
events := m.logger.GetEvents()
|
events := m.logger.GetEvents()
|
||||||
@ -190,8 +203,8 @@ func (m *Manager) startSender() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
|
||||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||||
id, err := uuid.FromBytes(ack.EventId)
|
id, err := uuid.FromBytes(ack.EventId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to convert ack event id to uuid: %v", err)
|
log.Warnf("failed to convert ack event id to uuid: %v", err)
|
||||||
|
200
client/internal/netflow/manager_test.go
Normal file
200
client/internal/netflow/manager_test.go
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
package netflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockIFaceMapper struct {
|
||||||
|
address wgaddr.Address
|
||||||
|
isUserspaceBind bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockIFaceMapper) Name() string {
|
||||||
|
return "wt0"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||||
|
return m.address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||||
|
return m.isUserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Update(t *testing.T) {
|
||||||
|
mockIFace := &mockIFaceMapper{
|
||||||
|
address: wgaddr.Address{
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
isUserspaceBind: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := []byte("test-public-key")
|
||||||
|
statusRecorder := peer.NewRecorder("")
|
||||||
|
|
||||||
|
manager := NewManager(mockIFace, publicKey, statusRecorder)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *types.FlowConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil config",
|
||||||
|
config: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled config",
|
||||||
|
config: &types.FlowConfig{
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled config with minimal valid settings",
|
||||||
|
config: &types.FlowConfig{
|
||||||
|
Enabled: true,
|
||||||
|
URL: "https://example.com",
|
||||||
|
TokenPayload: "test-payload",
|
||||||
|
TokenSignature: "test-signature",
|
||||||
|
Interval: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := manager.Update(tc.config)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
if tc.config == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotNil(t, manager.flowConfig)
|
||||||
|
|
||||||
|
if tc.config.Enabled {
|
||||||
|
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.config.URL != "" {
|
||||||
|
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.config.TokenPayload != "" {
|
||||||
|
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||||
|
mockIFace := &mockIFaceMapper{
|
||||||
|
address: wgaddr.Address{
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
isUserspaceBind: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := []byte("test-public-key")
|
||||||
|
manager := NewManager(mockIFace, publicKey, nil)
|
||||||
|
|
||||||
|
// First update with tokens
|
||||||
|
initialConfig := &types.FlowConfig{
|
||||||
|
Enabled: false,
|
||||||
|
TokenPayload: "initial-payload",
|
||||||
|
TokenSignature: "initial-signature",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.Update(initialConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Second update without tokens should preserve them
|
||||||
|
updatedConfig := &types.FlowConfig{
|
||||||
|
Enabled: false,
|
||||||
|
URL: "https://example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.Update(updatedConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify tokens were preserved
|
||||||
|
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
|
||||||
|
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_NeedsNewClient(t *testing.T) {
|
||||||
|
manager := &Manager{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
previous *types.FlowConfig
|
||||||
|
current *types.FlowConfig
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil previous config",
|
||||||
|
previous: nil,
|
||||||
|
current: &types.FlowConfig{},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "previous disabled",
|
||||||
|
previous: &types.FlowConfig{Enabled: false},
|
||||||
|
current: &types.FlowConfig{Enabled: true},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different URL",
|
||||||
|
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
|
||||||
|
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different TokenPayload",
|
||||||
|
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
|
||||||
|
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different TokenSignature",
|
||||||
|
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
|
||||||
|
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same config",
|
||||||
|
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||||
|
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only interval changed",
|
||||||
|
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
|
||||||
|
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
manager.flowConfig = tc.current
|
||||||
|
result := manager.needsNewClient(tc.previous)
|
||||||
|
assert.Equal(t, tc.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -120,9 +120,6 @@ type FlowLogger interface {
|
|||||||
Close()
|
Close()
|
||||||
// Enable enables the flow logger receiver
|
// Enable enables the flow logger receiver
|
||||||
Enable()
|
Enable()
|
||||||
// Disable disables the flow logger receiver
|
|
||||||
Disable()
|
|
||||||
|
|
||||||
// UpdateConfig updates the flow manager configuration
|
// UpdateConfig updates the flow manager configuration
|
||||||
UpdateConfig(dnsCollection, exitNodeCollection bool)
|
UpdateConfig(dnsCollection, exitNodeCollection bool)
|
||||||
}
|
}
|
||||||
|
@ -13,10 +13,12 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/connectivity"
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/flow/proto"
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
@ -77,17 +79,24 @@ func (c *GRPCClient) Close() error {
|
|||||||
defer c.streamMu.Unlock()
|
defer c.streamMu.Unlock()
|
||||||
|
|
||||||
c.stream = nil
|
c.stream = nil
|
||||||
return c.clientConn.Close()
|
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
return fmt.Errorf("close client connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||||
backOff := defaultBackoff(ctx, interval)
|
backOff := defaultBackoff(ctx, interval)
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
err := c.establishStreamAndReceive(ctx, msgHandler)
|
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
|
||||||
if err != nil {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||||
log.Errorf("receive failed: %v", err)
|
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
|
||||||
}
|
}
|
||||||
return err
|
log.Errorf("receive failed: %v", err)
|
||||||
|
return fmt.Errorf("receive: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := backoff.Retry(operation, backOff); err != nil {
|
if err := backoff.Retry(operation, backOff); err != nil {
|
||||||
|
256
flow/client/client_test.go
Normal file
256
flow/client/client_test.go
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
package client_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
flow "github.com/netbirdio/netbird/flow/client"
|
||||||
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testServer struct {
|
||||||
|
proto.UnimplementedFlowServiceServer
|
||||||
|
events chan *proto.FlowEvent
|
||||||
|
acks chan *proto.FlowEventAck
|
||||||
|
grpcSrv *grpc.Server
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestServer(t *testing.T) *testServer {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s := &testServer{
|
||||||
|
events: make(chan *proto.FlowEvent, 100),
|
||||||
|
acks: make(chan *proto.FlowEventAck, 100),
|
||||||
|
grpcSrv: grpc.NewServer(),
|
||||||
|
addr: listener.Addr().String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
||||||
|
t.Logf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.grpcSrv.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||||
|
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(stream.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
for {
|
||||||
|
event, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !event.IsInitiator {
|
||||||
|
select {
|
||||||
|
case s.events <- event:
|
||||||
|
ack := &proto.FlowEventAck{
|
||||||
|
EventId: event.EventId,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case s.acks <- ack:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ack := <-s.acks:
|
||||||
|
if err := stream.Send(ack); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReceive(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := client.Close()
|
||||||
|
assert.NoError(t, err, "failed to close flow")
|
||||||
|
})
|
||||||
|
|
||||||
|
receivedAcks := make(map[string]bool)
|
||||||
|
receiveDone := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
|
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||||
|
id := string(msg.EventId)
|
||||||
|
receivedAcks[id] = true
|
||||||
|
|
||||||
|
if len(receivedAcks) >= 3 {
|
||||||
|
close(receiveDone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
t.Logf("receive error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
eventID := uuid.New().String()
|
||||||
|
|
||||||
|
// Create acknowledgment and send it to the flow through our test server
|
||||||
|
ack := &proto.FlowEventAck{
|
||||||
|
EventId: []byte(eventID),
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case server.acks <- ack:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timeout sending ack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-receiveDone:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for acks to be processed")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(receivedAcks))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReceive_ContextCancellation(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := client.Close()
|
||||||
|
assert.NoError(t, err, "failed to close flow")
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
msgHandler := func(msg *proto.FlowEventAck) error {
|
||||||
|
if !msg.IsInitiator {
|
||||||
|
handlerCalled = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = client.Receive(ctx, 1*time.Second, msgHandler)
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "context canceled")
|
||||||
|
assert.False(t, handlerCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSend(t *testing.T) {
|
||||||
|
server := newTestServer(t)
|
||||||
|
|
||||||
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := client.Close()
|
||||||
|
assert.NoError(t, err, "failed to close flow")
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
ackReceived := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
|
||||||
|
if len(ack.EventId) > 0 && !ack.IsInitiator {
|
||||||
|
close(ackReceived)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
t.Logf("receive error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
testEvent := &proto.FlowEvent{
|
||||||
|
EventId: []byte("test-event-id"),
|
||||||
|
PublicKey: []byte("test-public-key"),
|
||||||
|
FlowFields: &proto.FlowFields{
|
||||||
|
FlowId: []byte("test-flow-id"),
|
||||||
|
Protocol: 6,
|
||||||
|
SourceIp: []byte{192, 168, 1, 1},
|
||||||
|
DestIp: []byte{192, 168, 1, 2},
|
||||||
|
ConnectionInfo: &proto.FlowFields_PortInfo{
|
||||||
|
PortInfo: &proto.PortInfo{
|
||||||
|
SourcePort: 12345,
|
||||||
|
DestPort: 443,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = client.Send(testEvent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var receivedEvent *proto.FlowEvent
|
||||||
|
select {
|
||||||
|
case receivedEvent = <-server.events:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for event to be received by server")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
|
||||||
|
assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ackReceived:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ack to be received by flow")
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user