[client] Stop flow grpc receiver properly (#3596)

This commit is contained in:
Viktor Liu 2025-03-28 16:08:31 +01:00 committed by GitHub
parent 6124e3b937
commit 29a6e5be71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 536 additions and 75 deletions

View File

@ -1,7 +1,6 @@
package conntrack package conntrack
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
@ -12,7 +11,7 @@ import (
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {

View File

@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -24,7 +23,7 @@ import (
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error

View File

@ -1,7 +1,6 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
@ -15,7 +14,7 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{

View File

@ -31,7 +31,7 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter device.PacketFilter

View File

@ -353,7 +353,7 @@ func (e *Engine) Start() error {
// start flow manager right after interface creation // start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey() publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder) e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")

View File

@ -19,11 +19,9 @@ import (
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
type Logger struct { type Logger struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context
cancel context.CancelFunc
enabled atomic.Bool enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgIfaceIPNet net.IPNet wgIfaceIPNet net.IPNet
dnsCollection atomic.Bool dnsCollection atomic.Bool
@ -31,12 +29,9 @@ type Logger struct {
Store types.Store Store types.Store
} }
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
ctx, cancel := context.WithCancel(ctx)
return &Logger{ return &Logger{
ctx: ctx,
cancel: cancel,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet, wgIfaceIPNet: wgIfaceIPNet,
Store: store.NewMemoryStore(), Store: store.NewMemoryStore(),
@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
} }
l.mux.Lock() l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx) ctx, cancel := context.WithCancel(context.Background())
l.cancelReceiver = cancel l.cancel = cancel
l.mux.Unlock() l.mux.Unlock()
c := make(rcvChan, 100) c := make(rcvChan, 100)
@ -109,7 +104,7 @@ func (l *Logger) startReceiver() {
} }
} }
func (l *Logger) Disable() { func (l *Logger) Close() {
l.stop() l.stop()
l.Store.Close() l.Store.Close()
} }
@ -121,9 +116,9 @@ func (l *Logger) stop() {
l.enabled.Store(false) l.enabled.Store(false)
l.mux.Lock() l.mux.Lock()
if l.cancelReceiver != nil { if l.cancel != nil {
l.cancelReceiver() l.cancel()
l.cancelReceiver = nil l.cancel = nil
} }
l.rcvChan.Store(nil) l.rcvChan.Store(nil)
l.mux.Unlock() l.mux.Unlock()
@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
l.exitNodeCollection.Store(exitNodeCollection) l.exitNodeCollection.Store(exitNodeCollection)
} }
func (l *Logger) Close() {
l.stop()
l.cancel()
}
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection // check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) { if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {

View File

@ -1,7 +1,6 @@
package logger_test package logger_test
import ( import (
"context"
"net" "net"
"testing" "testing"
"time" "time"
@ -13,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(context.Background(), nil, net.IPNet{}) logger := logger.New(nil, net.IPNet{})
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{
@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
} }
// test disable // test disable
logger.Disable() logger.Close()
wait() wait()
logger.StoreEvent(event) logger.StoreEvent(event)
wait() wait()

View File

@ -27,18 +27,18 @@ type Manager struct {
logger nftypes.FlowLogger logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker conntrack nftypes.ConnTracker
ctx context.Context
receiverClient *client.GRPCClient receiverClient *client.GRPCClient
publicKey []byte publicKey []byte
cancel context.CancelFunc
} }
// NewManager creates a new netflow manager // NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet var ipNet net.IPNet
if iface != nil { if iface != nil {
ipNet = *iface.Address().Network ipNet = *iface.Address().Network
} }
flowLogger := logger.New(ctx, statusRecorder, ipNet) flowLogger := logger.New(statusRecorder, ipNet)
var ct nftypes.ConnTracker var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
return &Manager{ return &Manager{
logger: flowLogger, logger: flowLogger,
conntrack: ct, conntrack: ct,
ctx: ctx,
publicKey: publicKey, publicKey: publicKey,
} }
} }
@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error { func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up // first make sender ready so events don't pile up
if m.needsNewClient(previous) { if m.needsNewClient(previous) {
if m.receiverClient != nil { if err := m.resetClient(); err != nil {
if err := m.receiverClient.Close(); err != nil { return fmt.Errorf("reset client: %w", err)
log.Warnf("error closing previous flow client: %v", err)
}
} }
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
} }
m.logger.Enable() m.logger.Enable()
@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
return nil return nil
} }
func (m *Manager) resetClient() error {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %v", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
if m.cancel != nil {
m.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
return nil
}
// disableFlow stops components for flow tracking // disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error { func (m *Manager) disableFlow() error {
if m.cancel != nil {
m.cancel()
}
if m.conntrack != nil { if m.conntrack != nil {
m.conntrack.Stop() m.conntrack.Stop()
} }
m.logger.Disable() m.logger.Close()
if m.receiverClient != nil { if m.receiverClient != nil {
return m.receiverClient.Close() return m.receiverClient.Close()
} }
return nil return nil
} }
@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection) m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
changed := previous != nil && update.Enabled != previous.Enabled
if update.Enabled { if update.Enabled {
log.Infof("netflow manager enabled; starting netflow manager") if changed {
log.Infof("netflow manager enabled; starting netflow manager")
}
return m.enableFlow(previous) return m.enableFlow(previous)
} }
log.Infof("netflow manager disabled; stopping netflow manager") if changed {
err := m.disableFlow() log.Infof("netflow manager disabled; stopping netflow manager")
if err != nil {
log.Errorf("failed to disable netflow manager: %v", err)
} }
return err return m.disableFlow()
} }
// Close cleans up all resources // Close cleans up all resources
@ -151,17 +172,9 @@ func (m *Manager) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
if m.conntrack != nil { if err := m.disableFlow(); err != nil {
m.conntrack.Close() log.Warnf("failed to disable flow manager: %v", err)
} }
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %v", err)
}
}
m.logger.Close()
} }
// GetLogger returns the flow logger // GetLogger returns the flow logger
@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
return m.logger return m.logger
} }
func (m *Manager) startSender() { func (m *Manager) startSender(ctx context.Context) {
ticker := time.NewTicker(m.flowConfig.Interval) ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-m.ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
events := m.logger.GetEvents() events := m.logger.GetEvents()
@ -190,8 +203,8 @@ func (m *Manager) startSender() {
} }
} }
func (m *Manager) receiveACKs(client *client.GRPCClient) { func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error { err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
id, err := uuid.FromBytes(ack.EventId) id, err := uuid.FromBytes(ack.EventId)
if err != nil { if err != nil {
log.Warnf("failed to convert ack event id to uuid: %v", err) log.Warnf("failed to convert ack event id to uuid: %v", err)

View File

@ -0,0 +1,200 @@
package netflow
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
)
// mockIFaceMapper is a test double for the interface mapper consumed by
// NewManager; it returns canned values for the address and bind mode.
type mockIFaceMapper struct {
	address         wgaddr.Address // value returned by Address()
	isUserspaceBind bool           // value returned by IsUserspaceBind()
}
// Name returns a fixed WireGuard-style interface name for tests.
func (mi *mockIFaceMapper) Name() string {
	return "wt0"
}
// Address returns the address configured on the mock.
func (mi *mockIFaceMapper) Address() wgaddr.Address {
	return mi.address
}
// IsUserspaceBind reports whether the mock simulates a userspace bind.
func (mi *mockIFaceMapper) IsUserspaceBind() bool {
	return mi.isUserspaceBind
}
// TestManager_Update exercises Manager.Update with a nil config, a disabled
// config, and a minimal enabled config, and verifies that the stored
// flowConfig reflects each input.
func TestManager_Update(t *testing.T) {
	mockIFace := &mockIFaceMapper{
		address: wgaddr.Address{
			Network: &net.IPNet{
				IP:   net.ParseIP("192.168.1.1"),
				Mask: net.CIDRMask(24, 32),
			},
		},
		isUserspaceBind: true,
	}

	publicKey := []byte("test-public-key")
	statusRecorder := peer.NewRecorder("")

	manager := NewManager(mockIFace, publicKey, statusRecorder)
	// The "enabled" case below starts the flow client and its sender/receiver
	// goroutines; make sure they are torn down when the test finishes.
	t.Cleanup(manager.Close)

	tests := []struct {
		name   string
		config *types.FlowConfig
	}{
		{
			name:   "nil config",
			config: nil,
		},
		{
			name: "disabled config",
			config: &types.FlowConfig{
				Enabled: false,
			},
		},
		{
			name: "enabled config with minimal valid settings",
			config: &types.FlowConfig{
				Enabled:        true,
				URL:            "https://example.com",
				TokenPayload:   "test-payload",
				TokenSignature: "test-signature",
				Interval:       30 * time.Second,
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := manager.Update(tc.config)
			assert.NoError(t, err)

			// A nil update is a no-op; nothing further to check.
			if tc.config == nil {
				return
			}

			require.NotNil(t, manager.flowConfig)

			if tc.config.Enabled {
				assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
			}
			if tc.config.URL != "" {
				assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
			}
			if tc.config.TokenPayload != "" {
				assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
			}
		})
	}
}
// TestManager_Update_TokenPreservation verifies that an update lacking token
// fields keeps the previously configured TokenPayload and TokenSignature.
func TestManager_Update_TokenPreservation(t *testing.T) {
	iface := &mockIFaceMapper{
		address: wgaddr.Address{
			Network: &net.IPNet{
				IP:   net.ParseIP("192.168.1.1"),
				Mask: net.CIDRMask(24, 32),
			},
		},
		isUserspaceBind: true,
	}

	mgr := NewManager(iface, []byte("test-public-key"), nil)

	// Seed the manager with a config carrying both tokens.
	require.NoError(t, mgr.Update(&types.FlowConfig{
		Enabled:        false,
		TokenPayload:   "initial-payload",
		TokenSignature: "initial-signature",
	}))

	// A follow-up update that omits the tokens must not wipe them.
	require.NoError(t, mgr.Update(&types.FlowConfig{
		Enabled: false,
		URL:     "https://example.com",
	}))

	assert.Equal(t, "initial-payload", mgr.flowConfig.TokenPayload)
	assert.Equal(t, "initial-signature", mgr.flowConfig.TokenSignature)
}
func TestManager_NeedsNewClient(t *testing.T) {
manager := &Manager{}
tests := []struct {
name string
previous *types.FlowConfig
current *types.FlowConfig
expected bool
}{
{
name: "nil previous config",
previous: nil,
current: &types.FlowConfig{},
expected: true,
},
{
name: "previous disabled",
previous: &types.FlowConfig{Enabled: false},
current: &types.FlowConfig{Enabled: true},
expected: true,
},
{
name: "different URL",
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
expected: true,
},
{
name: "different TokenPayload",
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
expected: true,
},
{
name: "different TokenSignature",
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
expected: true,
},
{
name: "same config",
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
expected: false,
},
{
name: "only interval changed",
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
manager.flowConfig = tc.current
result := manager.needsNewClient(tc.previous)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@ -120,9 +120,6 @@ type FlowLogger interface {
Close() Close()
// Enable enables the flow logger receiver // Enable enables the flow logger receiver
Enable() Enable()
// Disable disables the flow logger receiver
Disable()
// UpdateConfig updates the flow manager configuration // UpdateConfig updates the flow manager configuration
UpdateConfig(dnsCollection, exitNodeCollection bool) UpdateConfig(dnsCollection, exitNodeCollection bool)
} }

View File

@ -13,10 +13,12 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/flow/proto"
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
@ -77,17 +79,24 @@ func (c *GRPCClient) Close() error {
defer c.streamMu.Unlock() defer c.streamMu.Unlock()
c.stream = nil c.stream = nil
return c.clientConn.Close() if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("close client connection: %w", err)
}
return nil
} }
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error { func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
backOff := defaultBackoff(ctx, interval) backOff := defaultBackoff(ctx, interval)
operation := func() error { operation := func() error {
err := c.establishStreamAndReceive(ctx, msgHandler) if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
}
log.Errorf("receive failed: %v", err) log.Errorf("receive failed: %v", err)
return fmt.Errorf("receive: %w", err)
} }
return err return nil
} }
if err := backoff.Retry(operation, backOff); err != nil { if err := backoff.Retry(operation, backOff); err != nil {

256
flow/client/client_test.go Normal file
View File

@ -0,0 +1,256 @@
package client_test
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
flow "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
)
// testServer is an in-process FlowService implementation used to exercise the
// gRPC client: events received from the client are forwarded to the events
// channel, and acks queued on the acks channel are streamed back.
type testServer struct {
	proto.UnimplementedFlowServiceServer
	events  chan *proto.FlowEvent    // events the server has received from the client
	acks    chan *proto.FlowEventAck // acks queued to be sent back over the stream
	grpcSrv *grpc.Server
	addr    string // listener address, e.g. "127.0.0.1:<port>"
}
// newTestServer starts a FlowService gRPC server on a random loopback port
// and registers a cleanup that stops it when the test finishes.
func newTestServer(t *testing.T) *testServer {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	s := &testServer{
		events:  make(chan *proto.FlowEvent, 100),
		acks:    make(chan *proto.FlowEventAck, 100),
		grpcSrv: grpc.NewServer(),
		addr:    listener.Addr().String(),
	}

	proto.RegisterFlowServiceServer(s.grpcSrv, s)

	go func() {
		// ErrServerStopped is the expected outcome of the Stop() in Cleanup.
		// NOTE(review): t.Logf here runs on a background goroutine — confirm
		// Serve always returns before the test completes, otherwise this could
		// log after the test has finished.
		if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Logf("server error: %v", err)
		}
	}()

	t.Cleanup(func() {
		s.grpcSrv.Stop()
	})

	return s
}
// Events implements the bidirectional FlowService stream. After sending an
// initiator handshake ack, it concurrently (a) receives events from the
// client, forwarding each to s.events and queueing a matching ack on s.acks,
// and (b) drains s.acks, sending each ack back to the client. It returns when
// the stream context is cancelled or a send/recv fails.
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
	// Handshake: announce the server as the stream initiator.
	err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
	if err != nil {
		return err
	}

	ctx, cancel := context.WithCancel(stream.Context())
	defer cancel()

	// Receiver goroutine: cancels the shared context when Recv fails
	// (EOF or error), which also unblocks the sender loop below.
	go func() {
		defer cancel()
		for {
			event, err := stream.Recv()
			if err != nil {
				return
			}
			// Ignore the client's own initiator messages.
			if !event.IsInitiator {
				select {
				case s.events <- event:
					// Acknowledge the event by echoing its ID back.
					ack := &proto.FlowEventAck{
						EventId: event.EventId,
					}
					select {
					case s.acks <- ack:
					case <-ctx.Done():
						return
					}
				case <-ctx.Done():
					return
				}
			}
		}
	}()

	// Sender loop: push queued acks to the client until the stream ends.
	for {
		select {
		case ack := <-s.acks:
			if err := stream.Send(ack); err != nil {
				return err
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// TestReceive verifies that acks pushed by the test server are delivered to
// the client's Receive handler.
func TestReceive(t *testing.T) {
	server := newTestServer(t)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	t.Cleanup(cancel)

	client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
	require.NoError(t, err)
	t.Cleanup(func() {
		err := client.Close()
		assert.NoError(t, err, "failed to close flow")
	})

	// NOTE(review): receivedAcks is written in the Receive goroutine and read
	// at the bottom of the test. Safety relies on close(receiveDone) happening
	// after the final write and on no further acks arriving — confirm.
	receivedAcks := make(map[string]bool)
	receiveDone := make(chan struct{})

	go func() {
		err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
			// Count only real (non-initiator) acks that carry an event ID.
			if !msg.IsInitiator && len(msg.EventId) > 0 {
				id := string(msg.EventId)
				receivedAcks[id] = true
				// Signal once all three expected acks have arrived.
				if len(receivedAcks) >= 3 {
					close(receiveDone)
				}
			}
			return nil
		})
		if err != nil && !errors.Is(err, context.Canceled) {
			t.Logf("receive error: %v", err)
		}
	}()

	// Give the client time to establish the stream before queueing acks.
	time.Sleep(500 * time.Millisecond)

	for i := 0; i < 3; i++ {
		eventID := uuid.New().String()

		// Create acknowledgment and send it to the flow through our test server
		ack := &proto.FlowEventAck{
			EventId: []byte(eventID),
		}

		select {
		case server.acks <- ack:
		case <-time.After(time.Second):
			t.Fatal("timeout sending ack")
		}
	}

	select {
	case <-receiveDone:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for acks to be processed")
	}

	assert.Equal(t, 3, len(receivedAcks))
}
// TestReceive_ContextCancellation checks that Receive surfaces a context
// cancellation error and never invokes the handler when the context is
// cancelled shortly after the call starts.
func TestReceive_ContextCancellation(t *testing.T) {
	server := newTestServer(t)

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
	require.NoError(t, err)
	t.Cleanup(func() {
		closeErr := client.Close()
		assert.NoError(t, closeErr, "failed to close flow")
	})

	// Cancel the receive context shortly after Receive starts blocking.
	time.AfterFunc(100*time.Millisecond, cancel)

	var sawAck bool
	err = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
		if !msg.IsInitiator {
			sawAck = true
		}
		return nil
	})

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "context canceled")
	assert.False(t, sawAck)
}
// TestSend verifies the full round trip: an event sent by the client reaches
// the test server, and the server's ack for it is delivered back to the
// client's Receive handler.
func TestSend(t *testing.T) {
	server := newTestServer(t)

	client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
	require.NoError(t, err)
	t.Cleanup(func() {
		err := client.Close()
		assert.NoError(t, err, "failed to close flow")
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	t.Cleanup(cancel)

	ackReceived := make(chan struct{})
	go func() {
		// Wait for the ack corresponding to the event sent below.
		err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
			if len(ack.EventId) > 0 && !ack.IsInitiator {
				close(ackReceived)
			}
			return nil
		})
		if err != nil && !errors.Is(err, context.Canceled) {
			t.Logf("receive error: %v", err)
		}
	}()

	// Give the client time to establish the stream before sending.
	time.Sleep(500 * time.Millisecond)

	testEvent := &proto.FlowEvent{
		EventId:   []byte("test-event-id"),
		PublicKey: []byte("test-public-key"),
		FlowFields: &proto.FlowFields{
			FlowId:   []byte("test-flow-id"),
			Protocol: 6,
			SourceIp: []byte{192, 168, 1, 1},
			DestIp:   []byte{192, 168, 1, 2},
			ConnectionInfo: &proto.FlowFields_PortInfo{
				PortInfo: &proto.PortInfo{
					SourcePort: 12345,
					DestPort:   443,
				},
			},
		},
	}

	err = client.Send(testEvent)
	require.NoError(t, err)

	// The server should surface the event on its events channel.
	var receivedEvent *proto.FlowEvent
	select {
	case receivedEvent = <-server.events:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for event to be received by server")
	}

	assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
	assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)

	// And the server's ack for that event should make it back to the client.
	select {
	case <-ackReceived:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ack to be received by flow")
	}
}