diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index f28cd56e5..d868dd1fb 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -1,7 +1,6 @@ package conntrack import ( - "context" "net/netip" "testing" @@ -12,7 +11,7 @@ import ( ) var logger = log.NewFromLogrus(logrus.StandardLogger()) -var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() +var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index a095a5e39..a48a483f8 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -1,7 +1,6 @@ package uspfilter import ( - "context" "fmt" "net" "net/netip" @@ -24,7 +23,7 @@ import ( ) var logger = log.NewFromLogrus(logrus.StandardLogger()) -var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() +var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() type IFaceMock struct { SetFilterFunc func(device.PacketFilter) error diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index ca79111ef..9488d33ab 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,7 +1,6 @@ package acl import ( - "context" "net" "testing" @@ -15,7 +14,7 @@ import ( mgmProto "github.com/netbirdio/netbird/management/proto" ) -var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() +var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() func TestDefaultManager(t *testing.T) { networkMap := &mgmProto.NetworkMap{ diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index c7eeb7870..8a15c430b 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -31,7 +31,7 @@ import ( "github.com/netbirdio/netbird/formatter" ) -var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() +var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() type mocWGIface struct { filter device.PacketFilter diff --git a/client/internal/engine.go b/client/internal/engine.go index babea2131..6ae494312 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -353,7 +353,7 @@ func (e *Engine) Start() error { // start flow manager right after interface creation publicKey := e.config.WgPrivateKey.PublicKey() - e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder) + e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder) if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 882fed2cb..43dc975fd 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -19,11 +19,9 @@ import ( type rcvChan chan *types.EventFields type Logger struct { mux sync.Mutex - ctx context.Context - cancel context.CancelFunc enabled atomic.Bool rcvChan atomic.Pointer[rcvChan] - cancelReceiver context.CancelFunc + cancel context.CancelFunc statusRecorder *peer.Status wgIfaceIPNet net.IPNet dnsCollection atomic.Bool @@ -31,12 +29,9 @@ type Logger struct { Store types.Store } -func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { +func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { - ctx, cancel := context.WithCancel(ctx) return &Logger{ - ctx: ctx, - cancel: cancel, statusRecorder: statusRecorder, wgIfaceIPNet: wgIfaceIPNet, Store: store.NewMemoryStore(), @@ -70,8 +65,8 @@ func (l *Logger) startReceiver() { } l.mux.Lock() - ctx, cancel := context.WithCancel(l.ctx) - l.cancelReceiver = cancel + ctx, cancel := context.WithCancel(context.Background()) + l.cancel = cancel l.mux.Unlock() c := make(rcvChan, 100) @@ -109,7 +104,7 @@ func (l *Logger) startReceiver() { } } -func (l *Logger) Disable() { +func (l *Logger) Close() { l.stop() l.Store.Close() } @@ -121,9 +116,9 @@ func (l *Logger) stop() { l.enabled.Store(false) l.mux.Lock() - if l.cancelReceiver != nil { - l.cancelReceiver() - l.cancelReceiver = nil + if l.cancel != nil { + l.cancel() + l.cancel = nil } l.rcvChan.Store(nil) l.mux.Unlock() @@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { l.exitNodeCollection.Store(exitNodeCollection) } -func (l *Logger) Close() { - l.stop() - l.cancel() -} - func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) { diff --git a/client/internal/netflow/logger/logger_test.go b/client/internal/netflow/logger/logger_test.go index 3ce9d8fd8..06e10c36c 100644 --- a/client/internal/netflow/logger/logger_test.go +++ b/client/internal/netflow/logger/logger_test.go @@ -1,7 +1,6 @@ package logger_test import ( - "context" "net" "testing" "time" @@ -13,7 +12,7 @@ import ( ) func TestStore(t *testing.T) { - logger := logger.New(context.Background(), nil, net.IPNet{}) + logger := logger.New(nil, net.IPNet{}) logger.Enable() event := types.EventFields{ @@ -40,7 +39,7 @@ func TestStore(t *testing.T) { } // test disable - logger.Disable() + logger.Close() wait() logger.StoreEvent(event) wait() diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index ce642b86a..0f1cdce37 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -27,18 +27,18 @@ type Manager struct { logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker - ctx context.Context receiverClient *client.GRPCClient publicKey []byte + cancel context.CancelFunc } // NewManager creates a new netflow manager -func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { +func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { var ipNet net.IPNet if iface != nil { ipNet = *iface.Address().Network } - flowLogger := logger.New(ctx, statusRecorder, ipNet) + flowLogger := logger.New(statusRecorder, ipNet) var ct nftypes.ConnTracker if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { @@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte return &Manager{ logger: flowLogger, conntrack: ct, - ctx: ctx, publicKey: publicKey, } } @@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool { func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error { // first make sender ready so events don't pile up if m.needsNewClient(previous) { - if m.receiverClient != nil { - if err := m.receiverClient.Close(); err != nil { - log.Warnf("error closing previous flow client: %v", err) - } + if err := m.resetClient(); err != nil { + return fmt.Errorf("reset client: %w", err) } - - flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval) - if err != nil { - return fmt.Errorf("create client: %w", err) - } - log.Infof("flow client configured to connect to %s", m.flowConfig.URL) - - m.receiverClient = flowClient - go m.receiveACKs(flowClient) - go m.startSender() } m.logger.Enable() @@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error { return nil } +func (m *Manager) resetClient() error { + if m.receiverClient != nil { + if err := m.receiverClient.Close(); err != nil { + log.Warnf("error closing previous flow client: %v", err) + } + } + + flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval) + if err != nil { + return fmt.Errorf("create client: %w", err) + } + log.Infof("flow client configured to connect to %s", m.flowConfig.URL) + + m.receiverClient = flowClient + + if m.cancel != nil { + m.cancel() + } + + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + + go m.receiveACKs(ctx, flowClient) + go m.startSender(ctx) + + return nil +} + // disableFlow stops components for flow tracking func (m *Manager) disableFlow() error { + if m.cancel != nil { + m.cancel() + } + if m.conntrack != nil { m.conntrack.Stop() } - m.logger.Disable() + m.logger.Close() if m.receiverClient != nil { return m.receiverClient.Close() } + return nil } @@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection) + changed := previous != nil && update.Enabled != previous.Enabled if update.Enabled { - log.Infof("netflow manager enabled; starting netflow manager") + if changed { + log.Infof("netflow manager enabled; starting netflow manager") + } return m.enableFlow(previous) } - log.Infof("netflow manager disabled; stopping netflow manager") - err := m.disableFlow() - if err != nil { - log.Errorf("failed to disable netflow manager: %v", err) + if changed { + log.Infof("netflow manager disabled; stopping netflow manager") } - return err + return m.disableFlow() } // Close cleans up all resources @@ -151,17 +172,9 @@ func (m *Manager) Close() { m.mux.Lock() defer m.mux.Unlock() - if m.conntrack != nil { - m.conntrack.Close() + if err := m.disableFlow(); err != nil { + log.Warnf("failed to disable flow manager: %v", err) } - - if m.receiverClient != nil { - if err := m.receiverClient.Close(); err != nil { - log.Warnf("failed to close receiver client: %v", err) - } - } - - m.logger.Close() } // GetLogger returns the flow logger @@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger { return m.logger } -func (m *Manager) startSender() { +func (m *Manager) startSender(ctx context.Context) { ticker := time.NewTicker(m.flowConfig.Interval) defer ticker.Stop() for { select { - case <-m.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: events := m.logger.GetEvents() @@ -190,8 +203,8 @@ func (m *Manager) startSender() { } } -func (m *Manager) receiveACKs(client *client.GRPCClient) { - err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error { +func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) { + err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error { id, err := uuid.FromBytes(ack.EventId) if err != nil { log.Warnf("failed to convert ack event id to uuid: %v", err) diff --git a/client/internal/netflow/manager_test.go b/client/internal/netflow/manager_test.go new file mode 100644 index 000000000..bf7e05f8e --- /dev/null +++ b/client/internal/netflow/manager_test.go @@ -0,0 +1,200 @@ +package netflow + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/netflow/types" + "github.com/netbirdio/netbird/client/internal/peer" +) + +type mockIFaceMapper struct { + address wgaddr.Address + isUserspaceBind bool +} + +func (m *mockIFaceMapper) Name() string { + return "wt0" +} + +func (m *mockIFaceMapper) Address() wgaddr.Address { + return m.address +} + +func (m *mockIFaceMapper) IsUserspaceBind() bool { + return m.isUserspaceBind +} + +func TestManager_Update(t *testing.T) { + mockIFace := &mockIFaceMapper{ + address: wgaddr.Address{ + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.1"), + Mask: net.CIDRMask(24, 32), + }, + }, + isUserspaceBind: true, + } + + publicKey := []byte("test-public-key") + statusRecorder := peer.NewRecorder("") + + manager := NewManager(mockIFace, publicKey, statusRecorder) + + tests := []struct { + name string + config *types.FlowConfig + }{ + { + name: "nil config", + config: nil, + }, + { + name: "disabled config", + config: &types.FlowConfig{ + Enabled: false, + }, + }, + { + name: "enabled config with minimal valid settings", + config: &types.FlowConfig{ + Enabled: true, + URL: "https://example.com", + TokenPayload: "test-payload", + TokenSignature: "test-signature", + Interval: 30 * time.Second, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := manager.Update(tc.config) + + assert.NoError(t, err) + + if tc.config == nil { + return + } + + require.NotNil(t, manager.flowConfig) + + if tc.config.Enabled { + assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled) + } + + if tc.config.URL != "" { + assert.Equal(t, tc.config.URL, manager.flowConfig.URL) + } + + if tc.config.TokenPayload != "" { + assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload) + } + }) + } +} + +func TestManager_Update_TokenPreservation(t *testing.T) { + mockIFace := &mockIFaceMapper{ + address: wgaddr.Address{ + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.1"), + Mask: net.CIDRMask(24, 32), + }, + }, + isUserspaceBind: true, + } + + publicKey := []byte("test-public-key") + manager := NewManager(mockIFace, publicKey, nil) + + // First update with tokens + initialConfig := &types.FlowConfig{ + Enabled: false, + TokenPayload: "initial-payload", + TokenSignature: "initial-signature", + } + + err := manager.Update(initialConfig) + require.NoError(t, err) + + // Second update without tokens should preserve them + updatedConfig := &types.FlowConfig{ + Enabled: false, + URL: "https://example.com", + } + + err = manager.Update(updatedConfig) + require.NoError(t, err) + + // Verify tokens were preserved + assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload) + assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature) +} + +func TestManager_NeedsNewClient(t *testing.T) { + manager := &Manager{} + + tests := []struct { + name string + previous *types.FlowConfig + current *types.FlowConfig + expected bool + }{ + { + name: "nil previous config", + previous: nil, + current: &types.FlowConfig{}, + expected: true, + }, + { + name: "previous disabled", + previous: &types.FlowConfig{Enabled: false}, + current: &types.FlowConfig{Enabled: true}, + expected: true, + }, + { + name: "different URL", + previous: &types.FlowConfig{Enabled: true, URL: "old-url"}, + current: &types.FlowConfig{Enabled: true, URL: "new-url"}, + expected: true, + }, + { + name: "different TokenPayload", + previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"}, + current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"}, + expected: true, + }, + { + name: "different TokenSignature", + previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"}, + current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"}, + expected: true, + }, + { + name: "same config", + previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"}, + current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"}, + expected: false, + }, + { + name: "only interval changed", + previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second}, + current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + manager.flowConfig = tc.current + result := manager.needsNewClient(tc.previous) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index 881f30bd8..ea752131b 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -120,9 +120,6 @@ type FlowLogger interface { Close() // Enable enables the flow logger receiver Enable() - // Disable disables the flow logger receiver - Disable() - // UpdateConfig updates the flow manager configuration UpdateConfig(dnsCollection, exitNodeCollection bool) } diff --git a/flow/client/client.go b/flow/client/client.go index 2d3890ba5..949824065 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -13,10 +13,12 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" @@ -77,17 +79,24 @@ func (c *GRPCClient) Close() error { defer c.streamMu.Unlock() c.stream = nil - return c.clientConn.Close() + if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) { + return fmt.Errorf("close client connection: %w", err) + } + + return nil } func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error { backOff := defaultBackoff(ctx, interval) operation := func() error { - err := c.establishStreamAndReceive(ctx, msgHandler) - if err != nil { + if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil { + if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + return fmt.Errorf("receive: %w: %w", err, context.Canceled) + } log.Errorf("receive failed: %v", err) + return fmt.Errorf("receive: %w", err) } - return err + return nil } if err := backoff.Retry(operation, backOff); err != nil { diff --git a/flow/client/client_test.go b/flow/client/client_test.go new file mode 100644 index 000000000..efe01c003 --- /dev/null +++ b/flow/client/client_test.go @@ -0,0 +1,256 @@ +package client_test + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + flow "github.com/netbirdio/netbird/flow/client" + "github.com/netbirdio/netbird/flow/proto" +) + +type testServer struct { + proto.UnimplementedFlowServiceServer + events chan *proto.FlowEvent + acks chan *proto.FlowEventAck + grpcSrv *grpc.Server + addr string +} + +func newTestServer(t *testing.T) *testServer { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + s := &testServer{ + events: make(chan *proto.FlowEvent, 100), + acks: make(chan *proto.FlowEventAck, 100), + grpcSrv: grpc.NewServer(), + addr: listener.Addr().String(), + } + + proto.RegisterFlowServiceServer(s.grpcSrv, s) + + go func() { + if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) { + t.Logf("server error: %v", err) + } + }() + + t.Cleanup(func() { + s.grpcSrv.Stop() + }) + + return s +} + +func (s *testServer) Events(stream proto.FlowService_EventsServer) error { + err := stream.Send(&proto.FlowEventAck{IsInitiator: true}) + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + + go func() { + defer cancel() + for { + event, err := stream.Recv() + if err != nil { + return + } + + if !event.IsInitiator { + select { + case s.events <- event: + ack := &proto.FlowEventAck{ + EventId: event.EventId, + } + select { + case s.acks <- ack: + case <-ctx.Done(): + return + } + case <-ctx.Done(): + return + } + } + } + }() + + for { + select { + case ack := <-s.acks: + if err := stream.Send(ack); err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func TestReceive(t *testing.T) { + server := newTestServer(t) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + t.Cleanup(func() { + err := client.Close() + assert.NoError(t, err, "failed to close flow") + }) + + receivedAcks := make(map[string]bool) + receiveDone := make(chan struct{}) + + go func() { + err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { + if !msg.IsInitiator && len(msg.EventId) > 0 { + id := string(msg.EventId) + receivedAcks[id] = true + + if len(receivedAcks) >= 3 { + close(receiveDone) + } + } + return nil + }) + if err != nil && !errors.Is(err, context.Canceled) { + t.Logf("receive error: %v", err) + } + }() + + time.Sleep(500 * time.Millisecond) + + for i := 0; i < 3; i++ { + eventID := uuid.New().String() + + // Create acknowledgment and send it to the flow through our test server + ack := &proto.FlowEventAck{ + EventId: []byte(eventID), + } + + select { + case server.acks <- ack: + case <-time.After(time.Second): + t.Fatal("timeout sending ack") + } + } + + select { + case <-receiveDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for acks to be processed") + } + + assert.Equal(t, 3, len(receivedAcks)) +} + +func TestReceive_ContextCancellation(t *testing.T) { + server := newTestServer(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + t.Cleanup(func() { + err := client.Close() + assert.NoError(t, err, "failed to close flow") + }) + + go func() { + time.Sleep(100 * time.Millisecond) + cancel() + }() + + handlerCalled := false + msgHandler := func(msg *proto.FlowEventAck) error { + if !msg.IsInitiator { + handlerCalled = true + } + return nil + } + + err = client.Receive(ctx, 1*time.Second, msgHandler) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + assert.False(t, handlerCalled) +} + +func TestSend(t *testing.T) { + server := newTestServer(t) + + client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) + require.NoError(t, err) + t.Cleanup(func() { + err := client.Close() + assert.NoError(t, err, "failed to close flow") + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + ackReceived := make(chan struct{}) + + go func() { + err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error { + if len(ack.EventId) > 0 && !ack.IsInitiator { + close(ackReceived) + } + return nil + }) + if err != nil && !errors.Is(err, context.Canceled) { + t.Logf("receive error: %v", err) + } + }() + + time.Sleep(500 * time.Millisecond) + + testEvent := &proto.FlowEvent{ + EventId: []byte("test-event-id"), + PublicKey: []byte("test-public-key"), + FlowFields: &proto.FlowFields{ + FlowId: []byte("test-flow-id"), + Protocol: 6, + SourceIp: []byte{192, 168, 1, 1}, + DestIp: []byte{192, 168, 1, 2}, + ConnectionInfo: &proto.FlowFields_PortInfo{ + PortInfo: &proto.PortInfo{ + SourcePort: 12345, + DestPort: 443, + }, + }, + }, + } + + err = client.Send(testEvent) + require.NoError(t, err) + + var receivedEvent *proto.FlowEvent + select { + case receivedEvent = <-server.events: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for event to be received by server") + } + + assert.Equal(t, testEvent.EventId, receivedEvent.EventId) + assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey) + + select { + case <-ackReceived: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ack to be received by flow") + } +}