From f48cfd52e9579a8b12e16053e77137a7039bfb2b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 28 Feb 2025 00:28:17 +0000 Subject: [PATCH] fix logger stop (#3403) * fix logger stop * use context to stop receiver * update test --- client/internal/netflow/logger/logger.go | 36 +++++++++++-------- client/internal/netflow/logger/logger_test.go | 33 +++++++++++++++-- 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index d1afa7667..640f87c49 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "sync" "sync/atomic" "time" @@ -14,21 +15,21 @@ import ( type rcvChan chan *types.EventFields type Logger struct { - ctx context.Context - cancel context.CancelFunc - enabled atomic.Bool - rcvChan atomic.Pointer[rcvChan] - stopChan chan struct{} - Store types.Store + mux sync.Mutex + ctx context.Context + cancel context.CancelFunc + enabled atomic.Bool + rcvChan atomic.Pointer[rcvChan] + cancelReceiver context.CancelFunc + Store types.Store } func New(ctx context.Context) *Logger { ctx, cancel := context.WithCancel(ctx) return &Logger{ - ctx: ctx, - cancel: cancel, - Store: store.NewMemoryStore(), - stopChan: make(chan struct{}), + ctx: ctx, + cancel: cancel, + Store: store.NewMemoryStore(), } } @@ -57,6 +58,10 @@ func (l *Logger) startReceiver() { if l.enabled.Load() { return } + l.mux.Lock() + ctx, cancel := context.WithCancel(l.ctx) + l.cancelReceiver = cancel + l.mux.Unlock() c := make(rcvChan, 100) l.rcvChan.Swap(&c) @@ -64,7 +69,7 @@ func (l *Logger) startReceiver() { for { select { - case <-l.ctx.Done(): + case <-ctx.Done(): log.Info("flow Memory store receiver stopped") return case eventFields := <-c: @@ -75,8 +80,6 @@ func (l *Logger) startReceiver() { Timestamp: time.Now(), } l.Store.StoreEvent(&event) - case <-l.stopChan: - return } } } @@ -92,7 +95,12 @@ func (l *Logger) stop() { } l.enabled.Store(false) - l.stopChan <- struct{}{} + l.mux.Lock() + if l.cancelReceiver != nil { + l.cancelReceiver() + l.cancelReceiver = nil + } + l.mux.Unlock() } func (l *Logger) GetEvents() []*types.Event { diff --git a/client/internal/netflow/logger/logger_test.go b/client/internal/netflow/logger/logger_test.go index cb0913639..e986118ec 100644 --- a/client/internal/netflow/logger/logger_test.go +++ b/client/internal/netflow/logger/logger_test.go @@ -21,9 +21,11 @@ func TestStore(t *testing.T) { Direction: types.Ingress, Protocol: 6, } - time.Sleep(time.Millisecond) + + wait := func() { time.Sleep(time.Millisecond) } + wait() logger.StoreEvent(event) - time.Sleep(time.Millisecond) + wait() allEvents := logger.GetEvents() matched := false @@ -35,4 +37,31 @@ func TestStore(t *testing.T) { if !matched { t.Errorf("didn't match any event") } + + // test disable + logger.Disable() + wait() + logger.StoreEvent(event) + wait() + allEvents = logger.GetEvents() + if len(allEvents) != 0 { + t.Errorf("expected 0 events, got %d", len(allEvents)) + } + + // test re-enable + logger.Enable() + wait() + logger.StoreEvent(event) + wait() + + allEvents = logger.GetEvents() + matched = false + for _, e := range allEvents { + if e.EventFields.FlowID == event.FlowID { + matched = true + } + } + if !matched { + t.Errorf("didn't match any event") + } }