diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index e1e1ff236..9457d3a66 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 2d743f790..dacb1922b 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif + ignore_words_list: erro,clienta,hastable,iif,groupd skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/go.mod b/go.mod index e7e3c17a6..a6b83794d 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/r3labs/diff/v3 v3.0.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -210,6 +211,8 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect + github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go.sum b/go.sum index e9bc318d6..412542d5e 100644 --- a/go.sum +++ b/go.sum @@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= +github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/account.go b/management/server/account.go index b49b82f91..a8a244bdf 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -102,7 +102,6 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) @@ -2132,8 +2131,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting account: %w", err) } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } } return nil diff --git a/management/server/account_test.go b/management/server/account_test.go index 19514dad1..3c3fcebc6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1122,66 +1122,196 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) +func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) return } - userID := "account_creator" - - account, err := createAccount(manager, "test_account", userID, "") - if err != nil { - t.Fatal(err) + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, } + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - t.Fatal("error creating setup key") - return - } + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() - getPeer := func() *nbpeer.Peer { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return nil + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) } - expectedPeerKey := key.PublicKey().String() + }() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ - Key: expectedPeerKey, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - if err != nil { - t.Fatalf("expecting peer1 to be added, got failure %v", err) - return nil - } - - return peer + group.Peers = []string{peer1.ID, peer2.ID, peer3.ID} + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return } - peer1 := getPeer() - peer2 := getPeer() - peer3 := getPeer() + wg.Wait() +} - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) +func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) return } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + group := group.Group{ - ID: "group-id", + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return + } + + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 1 { + t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { + t.Errorf("delete peer: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + group := group.Group{ + ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, } @@ -1191,116 +1321,48 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { Rules: []*PolicyRule{ { Enabled: true, - Sources: []string{"group-id"}, - Destinations: []string{"group-id"}, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, Bidirectional: true, Action: PolicyTrafficActionAccept, }, }, } + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } + wg := sync.WaitGroup{} - t.Run("save group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Add(1) + go func() { + defer wg.Done() - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { - t.Errorf("save group: %v", err) - return + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + // clean policy is pre requirement for delete group + if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } - t.Run("delete policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + t.Errorf("delete group: %v", err) + return + } - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - wg.Wait() - }) - - t.Run("save policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - wg.Wait() - }) - t.Run("delete peer update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 1 { - t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { - t.Errorf("delete peer: %v", err) - return - } - - wg.Wait() - }) - - t.Run("delete group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) - - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { - t.Errorf("delete group: %v", err) - return - } - - wg.Wait() - }) + wg.Wait() } func TestAccountManager_DeletePeer(t *testing.T) { @@ -2754,3 +2816,73 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true } } + +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { + t.Helper() + + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + account, err := createAccount(manager, "test_account", userID, "") + if err != nil { + t.Fatal(err) + } + + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + if err != nil { + t.Fatal("error creating setup key") + } + + getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + expectedPeerKey := key.PublicKey().String() + + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + Status: &nbpeer.PeerStatus{ + Connected: true, + LastSeen: time.Now().UTC(), + }, + }) + if err != nil { + t.Fatalf("expecting peer to be added, got failure %v", err) + } + + return peer + } + + peer1 := getPeer(manager, setupKey) + peer2 := getPeer(manager, setupKey) + peer3 := getPeer(manager, setupKey) + + return manager, account, peer1, peer2, peer3 +} + +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + case <-time.After(500 * time.Millisecond): + t.Error("Timed out waiting for update message") + } +} diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go new file mode 100644 index 000000000..de4aa334c --- /dev/null +++ b/management/server/differs/netip.go @@ -0,0 +1,82 @@ +package differs + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/r3labs/diff/v3" +) + +// NetIPAddr is a custom differ for netip.Addr +type NetIPAddr struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPAddr) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) +} + +func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromAddr, ok1 := a.Interface().(netip.Addr) + toAddr, ok2 := b.Interface().(netip.Addr) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromAddr.String() != toAddr.String() { + cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) + } + + return nil +} + +func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} + +// NetIPPrefix is a custom differ for netip.Prefix +type NetIPPrefix struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPPrefix) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) +} + +func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromPrefix, ok1 := a.Interface().(netip.Prefix) + toPrefix, ok2 := b.Interface().(netip.Prefix) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromPrefix.String() != toPrefix.String() { + cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) + } + + return nil +} + +func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} diff --git a/management/server/dns.go b/management/server/dns.go index 7410aaa15..256b8b125 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -125,26 +125,29 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID oldSettings := account.DNSSettings.Copy() account.DNSSettings = dnsSettingsToSave.Copy() + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) for _, id := range addedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) } - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) for _, id := range removedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - am.updateAccountPeers(ctx, account) + if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + am.updateAccountPeers(ctx, account) + } return nil } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c7f435b68..c675fc12c 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -6,9 +6,11 @@ import ( "net/netip" "reflect" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Cache should contain name server group 'group2'") } } + +func TestDNSAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates + t.Run("saving dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + // Saving DNS settings with groups that have peers should update account peers and send peer update + t.Run("saving dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA", "groupB"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged DNS settings with used groups should update account peers and not send peer update + // since there is no change in the network map + t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA", "groupB"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates + t.Run("removing group with no peers from dns settings", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing group with peers from DNS settings should trigger updates to account peers and send peer updates + t.Run("removing group with peers from dns settings", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/group.go b/management/server/group.go index aa387c058..bdb569e37 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -121,12 +121,19 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user eventsToStore = append(eventsToStore, events...) } + newGroupIDs := make([]string, 0, len(newGroups)) + for _, newGroup := range newGroups { + newGroupIDs = append(newGroupIDs, newGroup.ID) + } + account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, newGroupIDs) { + am.updateAccountPeers(ctx, account) + } for _, storeEvent := range eventsToStore { storeEvent() @@ -238,8 +245,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) - am.updateAccountPeers(ctx, account) - return nil } @@ -282,8 +287,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) } - am.updateAccountPeers(ctx, account) - return allErrors } @@ -336,7 +339,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, []string{group.ID}) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -366,7 +371,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, []string{group.ID}) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -469,3 +476,32 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { } return false, nil } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if group, exists := account.Groups[groupID]; exists && group.HasPeers() { + return true + } + } + return false +} + +func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { + return true + } + if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { + return true + } + if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { + return true + } + if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { + return true + } + } + + return false +} diff --git a/management/server/group/group.go b/management/server/group/group.go index 79dfd995c..d293e1afc 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -44,3 +44,8 @@ func (g *Group) Copy() *Group { copy(group.Peers, g.Peers) return group } + +// HasPeers checks if the group has any peers. +func (g *Group) HasPeers() bool { + return len(g.Peers) > 0 +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 89b68ad6c..1e59b74ef 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -4,13 +4,16 @@ import ( "context" "errors" "fmt" + "net/netip" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A } return am, acc, nil } + +func TestGroupAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1.ID, peer3.ID}, + }, + { + ID: "groupD", + Name: "GroupD", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving a group that is not linked to any resource should not update account peers + t.Run("saving unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding a peer to a group that is not linked to any resource should not update account peers + // and not send peer update + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing a peer from a group that is not linked to any resource should not update account peers + // and not send peer update + t.Run("removing peer from unliked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting group should not update account peers and not send peer update + t.Run("deleting group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB") + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // adding a group to policy + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }, false) + assert.NoError(t, err) + + // Saving a group linked to policy should update account peers and send peer update + t.Run("saving linked group to policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving an unchanged group should trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("saving unchanged group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // adding peer to a used group should update account peers and send peer update + t.Run("adding peer to linked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // removing peer from a linked group should update account peers and send peer update + t.Run("removing peer from linked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to name server group should update account peers and send peer update + t.Run("saving group linked to name server group", func(t *testing.T) { + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupC"}, + true, nil, true, userID, false, + ) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to route should update account peers and send peer update + t.Run("saving group linked to route", func(t *testing.T) { + newRoute := route.Route{ + ID: "route", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupA"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupC"}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to dns settings should update account peers and send peer update + t.Run("saving group linked to dns settings", func(t *testing.T) { + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupD"}, + }) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupD", + Name: "GroupD", + Peers: []string{peer1.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 74557e227..681bf533a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -57,7 +57,6 @@ type MockAccountManager struct { GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -434,14 +433,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ( return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } -// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager -func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if am.UpdatePeerSSHKeyFunc != nil { - return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey) - } - return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented") -} - // UpdatePeer mocks UpdatePeerFunc function of the account manager func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { if am.UpdatePeerFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 0eb5d9ae4..5ebd263dc 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -66,13 +66,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco account.NameServerGroups[newNSGroup.ID] = newNSGroup account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, newNSGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil @@ -80,7 +80,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -98,16 +97,17 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } + oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) - + if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil @@ -131,13 +131,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco delete(account.NameServerGroups, nsGroupID) account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, nsGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil @@ -277,3 +277,11 @@ func validateDomain(domain string) error { return nil } + +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. +func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false + } + return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 8a3fe6eb0..96637cd39 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -4,7 +4,9 @@ import ( "context" "net/netip" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" @@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) { } } + +func TestNameServerAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + var newNameServerGroupA *nbdns.NameServerGroup + var newNameServerGroupB *nbdns.NameServerGroup + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Creating a nameserver group with a distribution group no peers should not update account peers + // and not send peer update + t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupA, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroupA", "nsGroupA", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // saving a nameserver group with a distribution group with no peers should not update account peers + // and not send peer update + t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Creating a nameserver group with a distribution group no peers should update account peers and send peer update + t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupB"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // saving a nameserver group with a distribution group with peers should update account peers and send peer update + t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB.NameServers = []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + } + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // saving unchanged nameserver group should update account peers and not send peer update + t.Run("saving unchanged nameserver group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB.NameServers = []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + } + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting a nameserver group should update account peers and send peer update + t.Run("deleting nameserver group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/network.go b/management/server/network.go index a5b188b46..8fb6a8b3c 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -41,9 +41,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 + Serial uint64 `diff:"-"` - mu sync.Mutex `json:"-" gorm:"-"` + mu sync.Mutex `json:"-" gorm:"-" diff:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer.go b/management/server/peer.go index a4c7e1266..80d43497a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "slices" "strings" "sync" "time" @@ -200,7 +201,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - if peer.Name != update.Name { + peerLabelUpdated := peer.Name != update.Name + + if peerLabelUpdated { peer.Name = update.Name existingLabels := account.getPeerDNSLabels() @@ -260,7 +263,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - am.updateAccountPeers(ctx, account) + if peerLabelUpdated { + am.updateAccountPeers(ctx, account) + } return peer, nil } @@ -304,6 +309,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, + NetworkMap: &NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -322,6 +328,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + updateAccountPeers := isPeerInActiveGroup(account, peerID) + err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { return err @@ -332,7 +340,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -422,9 +432,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer + var groupsToAdd []string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - var groupsToAdd []string var setupKeyID string var setupKeyName string var ephemeral bool @@ -576,7 +586,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, fmt.Errorf("error getting account: %w", err) } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, groupsToAdd) { + am.updateAccountPeers(ctx, account) + } approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { @@ -897,51 +909,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings return false } -// UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if sshKey == "" { - log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) - return nil - } - - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, account.Id) - if err != nil { - return err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - if peer.SSHKey == sshKey { - log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) - return nil - } - - peer.SSHKey = sshKey - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err - } - - // trigger network map update - am.updateAccountPeers(ctx, account) - - return nil -} - // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1034,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks := am.getPeerPostureChecks(account, p) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1048,3 +1015,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { } return labelMap } + +// IsPeerInActiveGroup checks if the given peer is part of a group that is used +// in an active DNS, route, or ACL configuration. +func isPeerInActiveGroup(account *Account, peerID string) bool { + peerGroupIDs := make([]string, 0) + for _, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + peerGroupIDs = append(peerGroupIDs, group.ID) + } + } + return areGroupChangesAffectPeers(account, peerGroupIDs) +} diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 9a53459a8..ef96bce7d 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -17,37 +17,37 @@ type Peer struct { // WireGuard public key Key string `gorm:"index"` // A setup key this peer was registered with - SetupKey string + SetupKey string `diff:"-"` // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` // The user ID that registered the peer - UserID string + UserID string `diff:"-"` // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin - LoginExpirationEnabled bool + LoginExpirationEnabled bool `diff:"-"` - InactivityExpirationEnabled bool + InactivityExpirationEnabled bool `diff:"-"` // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin time.Time `diff:"-"` // CreatedAt records the time the peer was created - CreatedAt time.Time + CreatedAt time.Time `diff:"-"` // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `diff:"-"` // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_"` + Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` } type PeerStatus struct { //nolint:revive @@ -189,7 +189,6 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, - InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c5edb5636..7b2180bf0 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1253,3 +1253,322 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) } + +func TestPeerAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) + require.NoError(t, err) + + err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + require.NoError(t, err) + + // create a user with auto groups + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + { + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupA"}, + }, + { + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupB"}, + }, + { + Id: "regularUser3", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupC"}, + }, + }, true) + require.NoError(t, err) + + var peer4 *nbpeer.Peer + var peer5 *nbpeer.Peer + var peer6 *nbpeer.Peer + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update + t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding peer to unlinked group should not update account peers and not send peer update + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting peer with unlinked group should not update account peers and not send peer update + t.Run("deleting peer with unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating peer label should update account peers and send peer update + t.Run("updating peer label", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + peer1.Name = "peer-1" + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with policy should update account peers and send peer update + t.Run("adding peer to group linked with policy", func(t *testing.T) { + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }, false) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to policy should update account peers and send peer update + t.Run("deleting peer with linked group to policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with route should update account peers and send peer update + t.Run("adding peer to group linked with route", func(t *testing.T) { + route := nbroute.Route{ + ID: "testingRoute1", + Network: netip.MustParsePrefix("100.65.250.202/32"), + NetID: "superNet", + NetworkType: nbroute.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupB"}, + } + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to route should update account peers and send peer update + t.Run("deleting peer with linked group to route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with name server group should update account peers and send peer update + t.Run("adding peer to group linked with name server group", func(t *testing.T) { + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupC"}, + true, []string{}, true, userID, false, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to name server group should update account peers and send peer update + t.Run("deleting peer with linked group to route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/policy.go b/management/server/policy.go index 75647de44..055542430 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() { } } +// ruleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) ruleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + + return groups +} + // FirewallRule is a rule of the firewall. type FirewallRule struct { // PeerIP of the peer @@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = am.savePolicy(account, policy, isUpdate); err != nil { + updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if err != nil { return err } @@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) // savePolicy saves or updates a policy in the given account. // If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { for index, rule := range policyToSave.Rules { rule.Sources = filterValidGroupIDs(account, rule.Sources) rule.Destinations = filterValidGroupIDs(account, rule.Destinations) @@ -442,18 +457,25 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if isUpdate { policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) if policyIdx < 0 { - return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } + oldPolicy := account.Policies[policyIdx] // Update the existing policy account.Policies[policyIdx] = policyToSave - return nil + + if !policyToSave.Enabled && !oldPolicy.Enabled { + return false, nil + } + updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + + return updateAccountPeers, nil } // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return nil + return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/policy_test.go b/management/server/policy_test.go index bf9a53d16..5b1411702 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -5,7 +5,9 @@ import ( "fmt" "net" "testing" + "time" + "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { return 0 // a is equal to b } } + +func TestPolicyAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + { + ID: "groupD", + Name: "GroupD", + Peers: []string{peer1.ID, peer2.ID}, + }, + }) + assert.NoError(t, err) + + updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + }) + + // Saving policy with rule groups with no peers should not update account's peers and not send peer update + t.Run("saving policy with rule groups with no peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-rule-groups-no-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Saving policy with source group containing peers, but destination group without peers should + // update account's peers and send peer update + t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-has-peers-destination-none", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving policy with destination group containing peers, but source group without peers should + // update account's peers and send peer update + t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { + policy := Policy{ + ID: "policy-destination-has-peers-source-none", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: false, + Sources: []string{"groupC"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg2) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Disabling policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: false, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating disabled policy with destination and source groups containing peers should not update account's peers + // or send peer update + t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Description: "updated description", + Enabled: false, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Enabling policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged policy should trigger account peers update but not send peer update + t.Run("saving unchanged policy", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting policy should trigger account peers update and send peer update + t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { + policyID := "policy-source-destination-peers" + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) + + // Deleting policy with destination group containing peers, but source group without peers should + // update account's peers and send peer update + t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { + policyID := "policy-destination-has-peers-source-none" + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg2) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting policy with no peers in groups should not update account's peers and not send peer update + t.Run("deleting policy with no peers in groups", func(t *testing.T) { + policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2 + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9a4b679ce..2dccd8f59 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -67,7 +67,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if exists { + + if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { am.updateAccountPeers(ctx, account) } @@ -148,13 +149,9 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) } - // check policy links - for _, policy := range account.Policies { - for _, id := range policy.SourcePostureChecks { - if id == postureChecksID { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) - } - } + // Check if posture check is linked to any policy + if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { + return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) } postureChecks := account.PostureChecks[postureChecksIdx] @@ -217,3 +214,25 @@ func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks } } } + +func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { + for _, policy := range account.Policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return true, policy + } + } + return false, nil +} + +// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { + if !exists { + return false + } + + isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) + if !isLinked { + return false + } + return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) +} diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index d837120f4..7d31956f9 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -3,7 +3,10 @@ package server import ( "context" "testing" + "time" + "github.com/netbirdio/netbird/management/server/group" + "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/posture" @@ -118,3 +121,458 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { return am.Store.GetAccount(context.Background(), account.Id) } + +func TestPostureCheckAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + postureCheck := posture.Checks{ + ID: "postureCheck", + Name: "postureCheck", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.28.0", + }, + }, + } + + // Saving unused posture check should not update account peers and not send peer update + t.Run("saving unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating unused posture check should not update account peers and not send peer update + t.Run("updating unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + policy := Policy{ + ID: "policyA", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + + // Linking posture check to policy should trigger update account peers and send peer update + t.Run("linking posture check to policy with peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating linked posture checks should update account peers and send peer update + t.Run("updating linked to posture check with peers", func(t *testing.T) { + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged posture check should not trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("saving unchanged posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing posture check from policy should trigger account peers update and send peer update + t.Run("removing posture check from policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policy.SourcePostureChecks = []string{} + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting unused posture check should not trigger account peers update and not send peer update + t.Run("deleting unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update + t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating linked posture check to policy where destination has peers but source does not + // should trigger account peers update and send peer update + t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { + updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + }) + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating linked posture check to policy where source has peers but destination does not, + // should not trigger account peers update or send peer update + t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating linked client posture check to policy where source has peers but destination does not, + // should trigger account peers update and send peer update + t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + }, + }, + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} + +func TestArePostureCheckChangesAffectingPeers(t *testing.T) { + account := &Account{ + Policies: []*Policy{ + { + ID: "policyA", + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{"checkA"}, + }, + }, + Groups: map[string]*group.Group{ + "groupA": { + ID: "groupA", + Peers: []string{"peer1"}, + }, + "groupB": { + ID: "groupB", + Peers: []string{}, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: "checkA", + }, + { + ID: "checkB", + }, + }, + } + + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + assert.False(t, result) + }) + + t.Run("posture check does not exist", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + assert.False(t, result) + }) + + t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"groupB"} + account.Policies[0].Rules[0].Destinations = []string{"groupA"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"groupA"} + account.Policies[0].Rules[0].Destinations = []string{"groupB"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} + account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.False(t, result) + }) + + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + account.Groups["groupA"].Peers = []string{} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.False(t, result) + }) +} diff --git a/management/server/route.go b/management/server/route.go index 39ee6170c..1cf00b37c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, &newRoute) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + oldRoute := account.Routes[routeToSave.ID] account.Routes[routeToSave.ID] = routeToSave account.Network.IncSerial() @@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, routy) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { } return &portInfo } + +// isRouteChangeAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers +func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 09cbe53ff..a4b320c7e 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "testing" + "time" "github.com/rs/xid" "github.com/stretchr/testify/assert" @@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }) } + +func TestRouteAccountPeersUpdate(t *testing.T) { + manager, err := createRouterManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestRouteAccount(t, manager) + require.NoError(t, err, "failed to init testing account") + + err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + }) + + // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update + t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) { + route := route.Route{ + ID: "testingRoute1", + Network: netip.MustParsePrefix("100.65.250.202/32"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupA"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupA"}, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + + }) + + // Creating a route with no routing peer and having peers in groups should update account peers and send peer update + t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) { + route := route.Route{ + ID: "testingRoute2", + Network: netip.MustParsePrefix("192.0.2.0/32"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroup3}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup3}, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) + + baseRoute := route.Route{ + ID: "testingRoute3", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + + // Creating route should update account peers and send peer update + t.Run("creating route with a routing peer", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newRoute, err := manager.CreateRoute( + context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, + baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, + baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, + ) + require.NoError(t, err) + baseRoute = *newRoute + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating the route should update account peers and send peer update when there is peers in group + t.Run("updating route", func(t *testing.T) { + baseRoute.Groups = []string{routeGroup1, routeGroup2} + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating unchanged route should update account peers and not send peer update + t.Run("updating unchanged route", func(t *testing.T) { + baseRoute.Groups = []string{routeGroup1, routeGroup2} + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting the route should update account peers and send peer update + t.Run("deleting route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to route peer groups that do not have any peers should update account peers and send peer update + t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) { + newRoute := route.Route{ + Network: netip.MustParsePrefix("192.168.12.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to route groups that do not have any peers should update account peers and send peer update + t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) { + newRoute := route.Route{ + Network: netip.MustParsePrefix("192.168.13.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupC"}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 9521e22d3..e84f8fcd6 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -323,8 +323,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } }() - am.updateAccountPeers(ctx, account) - return newKey, nil } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index aa5075b02..651b54010 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -352,3 +353,73 @@ func TestSetupKey_Copy(t *testing.T) { key.UpdatedAt, key.AutoGroups) } + +func TestSetupKeyAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"group"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + var setupKey *SetupKey + + // Creating setup key should not update account peers and not send peer update + t.Run("creating setup key", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Saving setup key should not update account peers and not send peer update + t.Run("saving setup key", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) +} diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 32a59128b..168973cad 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -26,8 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO installations VALUES(1,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 0188cef52..6fb96c971 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,9 +2,13 @@ package server import ( "context" + "fmt" + "runtime/debug" "sync" "time" + "github.com/netbirdio/netbird/management/server/differs" + "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -14,14 +18,17 @@ import ( const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + NetworkMap *NetworkMap } type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage + // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. + peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels - channelsMux *sync.Mutex + channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } @@ -29,9 +36,10 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - channelsMux: &sync.Mutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + peerUpdateMessage: make(map[string]*UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool + // skip sending sync update to the peer if there is no change in update message, + // it will not check on turn credential refresh as we do not send network map or client posture checks + if update.NetworkMap != nil { + updated := p.handlePeerMessageUpdate(ctx, peerID, update) + if !updated { + return + } + } + p.channelsMux.Lock() + defer func() { p.channelsMux.Unlock() if p.metrics != nil { @@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() + if update.NetworkMap != nil { + lastSentUpdate := p.peerUpdateMessage[peerID] + if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() { + log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", + peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) + return + } + p.peerUpdateMessage[peerID] = update + } + if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -170,3 +200,72 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. +func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { + p.channelsMux.RLock() + lastSentUpdate := p.peerUpdateMessage[peerID] + p.channelsMux.RUnlock() + + if lastSentUpdate != nil { + updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update) + if err != nil { + log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) + return false + } + if !updated { + log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) + return false + } + } + + return true +} + +// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. +func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) { + defer func() { + if r := recover(); r != nil { + log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack()) + isNew, err = true, nil + } + }() + + if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { + return false, nil + } + + differ, err := diff.NewDiffer( + diff.CustomValueDiffers(&differs.NetIPAddr{}), + diff.CustomValueDiffers(&differs.NetIPPrefix{}), + ) + if err != nil { + return false, fmt.Errorf("failed to create differ: %v", err) + } + + lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks) + currFiles := getChecksFiles(currUpdateToSend.Update.Checks) + + changelog, err := differ.Diff(lastSentFiles, currFiles) + if err != nil { + return false, fmt.Errorf("failed to diff checks: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) + if err != nil { + return false, fmt.Errorf("failed to diff network map: %v", err) + } + return len(changelog) > 0, nil +} + +// getChecksFiles returns a list of files from the given checks. +func getChecksFiles(checks []*proto.Checks) []string { + files := make([]string, 0, len(checks)) + for _, check := range checks { + files = append(files, check.GetFiles()...) + } + return files +} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c..52b715e95 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -2,10 +2,19 @@ package server import ( "context" + "net" + "net/netip" "testing" "time" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/util" + "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -77,3 +86,470 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } + +func TestHandlePeerMessageUpdate(t *testing.T) { + tests := []struct { + name string + peerID string + existingUpdate *UpdateMessage + newUpdate *UpdateMessage + expectedResult bool + }{ + { + name: "update message with turn credentials update", + peerID: "peer", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{}, + }, + }, + expectedResult: true, + }, + { + name: "update message for peer without existing update", + peerID: "peer1", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with no changes in update", + peerID: "peer2", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + { + name: "update message with changes in checks", + peerID: "peer3", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + Checks: []*proto.Checks{ + { + Files: []string{"/usr/bin/netbird"}, + }, + }, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with lower serial number", + peerID: "peer4", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPeersUpdateManager(nil) + ctx := context.Background() + + if tt.existingUpdate != nil { + p.peerUpdateMessage[tt.peerID] = tt.existingUpdate + } + + result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestIsNewPeerUpdateMessage(t *testing.T) { + t.Run("Unchanged value", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + }) + + t.Run("Unchanged value with serial incremented", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + }) + + t.Run("Updating routes network", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + + }) + + t.Run("Updating routes groups", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating network map peers", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newPeer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.4"), + SSHEnabled: true, + Key: "peer4-key", + DNSLabel: "peer4", + SSHKey: "peer4-ssh-key", + } + newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating process check", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + + newUpdateMessage2 := createMockUpdateMessage(t) + newUpdateMessage2.Update.NetworkMap.Serial++ + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + + newUpdateMessage3 := createMockUpdateMessage(t) + newUpdateMessage3.Update.Checks = []*proto.Checks{} + newUpdateMessage3.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3) + assert.NoError(t, err) + assert.True(t, message) + + newUpdateMessage4 := createMockUpdateMessage(t) + check := &posture.Checks{ + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/local/netbird", + MacPath: "/usr/bin/netbird", + }, + }, + }, + }, + } + newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} + newUpdateMessage4.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4) + assert.NoError(t, err) + assert.True(t, message) + + newUpdateMessage5 := createMockUpdateMessage(t) + check = &posture.Checks{ + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", + MacPath: "/usr/local/netbird", + }, + }, + }, + }, + } + newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} + newUpdateMessage5.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating DNS configuration", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newDomain := "newexample.com" + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append( + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains, + newDomain, + ) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating peer IP", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating firewall rule", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Add new firewall rule", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newRule := &FirewallRule{ + PeerIP: "192.168.1.3", + Direction: firewallRuleDirectionOUT, + Action: string(PolicyTrafficActionDrop), + Protocol: string(PolicyRuleProtocolUDP), + Port: "53", + } + newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Removing nameserver", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating name server IP", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating custom DNS zone", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + +} + +func createMockUpdateMessage(t *testing.T) *UpdateMessage { + t.Helper() + + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + if err != nil { + t.Fatal(err) + } + domainList, err := domain.FromStringList([]string{"example.com"}) + if err != nil { + t.Fatal(err) + } + + config := &Config{ + Signal: &Host{ + Proto: "https", + URI: "signal.uri", + Username: "", + Password: "", + }, + Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, + TURNConfig: &TURNConfig{ + Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, + }, + } + peer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.1"), + SSHEnabled: true, + Key: "peer-key", + DNSLabel: "peer1", + SSHKey: "peer1-ssh-key", + } + + secretManager := NewTimeBasedAuthSecretsManager( + NewPeersUpdateManager(nil), + &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{ + Duration: defaultDuration, + }, + Secret: "secret", + Turns: []*Host{TurnTestHost}, + }, + &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "secret", + }, + ) + + networkMap := &NetworkMap{ + Network: &Network{Net: *ipNet, Serial: 1000}, + Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Routes: []*nbroute.Route{ + { + ID: "route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + KeepRoute: true, + NetID: "route1", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + { + ID: "route2", + Domains: domainList, + KeepRoute: true, + NetID: "route2", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + }, + DNSConfig: nbdns.Config{ + ServiceEnable: true, + NameServerGroups: []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + { + ID: "ns1", + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Groups: []string{"group1"}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + }, + CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, + }, + FirewallRules: []*FirewallRule{ + {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + }, + } + dnsName := "example.com" + checks := []*posture.Checks{ + { + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", + MacPath: "/usr/bin/netbird", + }, + }, + }, + }, + }, + } + dnsCache := &DNSConfigCache{} + + turnToken, err := secretManager.GenerateTurnToken() + if err != nil { + t.Fatal(err) + } + + relayToken, err := secretManager.GenerateRelayToken() + if err != nil { + t.Fatal(err) + } + + return &UpdateMessage{ + Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), + NetworkMap: networkMap, + } +} diff --git a/management/server/user.go b/management/server/user.go index 71608ef20..9fdd3a6ee 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" @@ -473,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init } func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err } @@ -485,15 +486,22 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { peers, err := account.FindUserPeers(targetUserID) if err != nil { - return status.Errorf(status.Internal, "failed to find user peers") + return false, status.Errorf(status.Internal, "failed to find user peers") + } + + hadPeers := len(peers) > 0 + if !hadPeers { + return false, nil } peerIDs := make([]string, 0, len(peers)) @@ -501,7 +509,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU peerIDs = append(peerIDs, peer.ID) } - return am.deletePeers(ctx, account, peerIDs, initiatorUserID) + return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. @@ -745,6 +753,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, updatedUsers := make([]*UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer + userIDs []string eventsToStore []func() ) @@ -753,6 +762,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } + userIDs = append(userIDs, update.Id) + oldUser := account.Users[update.Id] if oldUser == nil { if !addIfNotExists { @@ -816,7 +827,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, err } - if account.Settings.GroupsPropagationEnabled { + if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { am.updateAccountPeers(ctx, account) } @@ -1167,7 +1178,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") } - var allErrors error + var ( + allErrors error + updateAccountPeers bool + ) deletedUsersMeta := make(map[string]map[string]any) for _, targetUserID := range targetUserIDs { @@ -1193,12 +1207,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err)) continue } + if hadPeers { + updateAccountPeers = true + } + delete(account.Users, targetUserID) deletedUsersMeta[targetUserID] = meta } @@ -1208,7 +1226,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return fmt.Errorf("failed to delete users: %w", err) } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } for targetUserID, meta := range deletedUsersMeta { am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) @@ -1217,11 +1237,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) - return nil, err + return nil, false, err } if !isNil(am.idpManager) { @@ -1232,16 +1252,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) if err != nil { log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) - return nil, err + return nil, false, err } } else { log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) } } - err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) if err != nil { - return nil, err + return nil, false, err } u, err := account.FindUser(targetUserID) @@ -1254,7 +1274,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun tuCreatedAt = u.CreatedAt } - return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil + return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. @@ -1333,3 +1353,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } return nil, false } + +// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. +func areUsersLinkedToPeers(account *Account, userIDs []string) bool { + for _, peer := range account.Peers { + if slices.Contains(userIDs, peer.UserID) { + return true + } + } + return false +} diff --git a/management/server/user_test.go b/management/server/user_test.go index 1a5704551..d4f560a54 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,9 +10,12 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" @@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { }) } } + +func TestUserAccountPeersUpdate(t *testing.T) { + // account groups propagation is enabled + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + require.NoError(t, err) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Creating a new regular user should not update account peers and not send peer update + t.Run("creating new regular user with no groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleUser, + Issued: UserIssuedAPI, + }, true) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // updating user with no linked peers should not update account peers and not send peer update + t.Run("updating user with no linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleUser, + Issued: UserIssuedAPI, + }, false) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // deleting user with no linked peers should not update account peers and not send peer update + t.Run("deleting user with no linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1") + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // create a user and add new peer with the user + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + }, true) + require.NoError(t, err) + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + // updating user with linked peers should update account peers and send peer update + t.Run("updating user with linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + }, false) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + }) + + // deleting user with linked peers should update account peers and send peer update + t.Run("deleting user with linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, peer4UpdMsg) + close(done) + }() + + err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2") + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +}