From 175674749f90613dd7cd7bab14c43f5dd1b7b246 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 25 Feb 2025 15:23:43 +0000 Subject: [PATCH] Add memory flow store (#3386) --- client/internal/engine.go | 75 +++++++++++++++++------ client/internal/flowstore/store.go | 79 +++++++++++++++++++++++++ client/internal/flowstore/store_test.go | 28 +++++++++ 3 files changed, 164 insertions(+), 18 deletions(-) create mode 100644 client/internal/flowstore/store.go create mode 100644 client/internal/flowstore/store_test.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 3d7802675..3f0732be8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" + "github.com/netbirdio/netbird/client/internal/flowstore" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" @@ -189,6 +190,7 @@ type Engine struct { persistNetworkMap bool latestNetworkMap *mgmProto.NetworkMap connSemaphore *semaphoregroup.SemaphoreGroup + flowStore flowstore.Store } // Peer is an instance of the Connection Peer @@ -320,6 +322,13 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + if e.flowStore != nil { + if err := e.flowStore.Close(); err != nil { + e.flowStore = nil + log.Errorf("failed to close flow store: %v", err) + } + } + return nil } @@ -642,25 +651,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { stunTurn = append(stunTurn, e.TURNs...) e.stunTurn.Store(stunTurn) - relayMsg := wCfg.GetRelay() - if relayMsg != nil { - // when we receive token we expect valid address list too - c := &auth.Token{ - Payload: relayMsg.GetTokenPayload(), - Signature: relayMsg.GetTokenSignature(), - } - if err := e.relayManager.UpdateToken(c); err != nil { - log.Errorf("failed to update relay token: %v", err) - return fmt.Errorf("update relay token: %w", err) - } + err = e.handleRelayUpdate(wCfg.GetRelay()) + if err != nil { + return err + } - e.relayManager.UpdateServerURLs(relayMsg.Urls) - - // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. - // We can ignore all errors because the guard will manage the reconnection retries. - _ = e.relayManager.Serve() - } else { - e.relayManager.UpdateServerURLs(nil) + err = e.handleFlowUpdate(wCfg.GetFlow()) + if err != nil { + return fmt.Errorf("handle the flow configuration: %w", err) } // todo update signal @@ -691,6 +689,47 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return nil } +func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error { + if update != nil { + // when we receive token we expect valid address list too + c := &auth.Token{ + Payload: update.GetTokenPayload(), + Signature: update.GetTokenSignature(), + } + if err := e.relayManager.UpdateToken(c); err != nil { + return fmt.Errorf("update relay token: %w", err) + } + + e.relayManager.UpdateServerURLs(update.Urls) + + // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. + // We can ignore all errors because the guard will manage the reconnection retries. + _ = e.relayManager.Serve() + } else { + e.relayManager.UpdateServerURLs(nil) + } + + return nil +} + +func (e *Engine) handleFlowUpdate(update *mgmProto.FlowConfig) error { + if update == nil { + return nil + } + + if update.GetEnabled() && e.flowStore == nil { + e.flowStore = flowstore.New(e.ctx) + return nil + } + + if !update.GetEnabled() && e.flowStore != nil { + err := e.flowStore.Close() + e.flowStore = nil + return err + } + return nil +} + // updateChecksIfNew updates checks if there are changes and sync new meta with management func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { // if checks are equal, we skip the update diff --git a/client/internal/flowstore/store.go b/client/internal/flowstore/store.go new file mode 100644 index 000000000..27ff81718 --- /dev/null +++ b/client/internal/flowstore/store.go @@ -0,0 +1,79 @@ +package flowstore + +import ( + "context" + "io" + "sync" + + log "github.com/sirupsen/logrus" +) + +type Event struct { + ID string + FlowID string +} + +type Store interface { + io.Closer + // stores a flow event + StoreEvent(flowEvent Event) + // returns all stored events + GetEvents() []*Event +} + +func New(ctx context.Context) Store { + ctx, cancel := context.WithCancel(ctx) + store := &memory{ + events: make(map[string]*Event), + rcvChan: make(chan *Event, 100), + ctx: ctx, + cancel: cancel, + } + go store.startReceiver() + return store +} + +type memory struct { + mux sync.Mutex + events map[string]*Event + rcvChan chan *Event + ctx context.Context + cancel context.CancelFunc +} + +func (m *memory) startReceiver() { + for { + select { + case <-m.ctx.Done(): + log.Info("flow memory store receiver stopped") + return + case event := <-m.rcvChan: + m.mux.Lock() + m.events[event.ID] = event + m.mux.Unlock() + } + } +} + +func (m *memory) StoreEvent(flowEvent Event) { + select { + case m.rcvChan <- &flowEvent: + default: + log.Warn("flow memory store receiver is busy") + } +} + +func (m *memory) Close() error { + m.cancel() + return nil +} + +func (m *memory) GetEvents() []*Event { + m.mux.Lock() + defer m.mux.Unlock() + events := make([]*Event, 0, len(m.events)) + for _, event := range m.events { + events = append(events, event) + } + return events +} diff --git a/client/internal/flowstore/store_test.go b/client/internal/flowstore/store_test.go new file mode 100644 index 000000000..c0f2e1216 --- /dev/null +++ b/client/internal/flowstore/store_test.go @@ -0,0 +1,28 @@ +package flowstore_test + +import ( + "context" + "testing" + + "github.com/netbirdio/netbird/client/internal/flowstore" +) + +func TestStore(t *testing.T) { + store := flowstore.New(context.Background()) + t.Cleanup(func() { + store.Close() + }) + + event := flowstore.Event{ + ID: "1", + FlowID: "1", + } + + store.StoreEvent(event) + allEvents := store.GetEvents() + for _, e := range allEvents { + if e.ID != event.ID { + t.Errorf("expected event ID %s, got %s", event.ID, e.ID) + } + } +}