mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
[management] Optimize network map updates (#2718)
* Skip peer update on unchanged network map (#2236) * Enhance network updates by skipping unchanged messages Optimizes the network update process by skipping updates where no changes in the peer update message received. * Add unit tests * add locks * Improve concurrency and update peer message handling * Refactor account manager network update tests * fix test * Fix inverted network map update condition * Add default group and policy to test data * Run peer updates in a separate goroutine * Refactor * Refactor lock * Fix peers update by including NetworkMap and posture Checks * go mod tidy * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Skip account peers update if no changes affect peers (#2310) * Remove incrementing network serial and updating peers after group deletion * Update account peer if posture check is linked to policy * Remove account peers update on saving setup key * Refactor group link checking into re-usable functions * Add HasPeers function to group * Refactor group management * Optimize group change effects on account peers * Update account peers if ns group has peers * Refactor group changes * Optimize account peers update in DNS settings * Optimize update of account peers on jwt groups sync * Refactor peer account updates for efficiency * Optimize peer update on user deletion and changes * Remove condition check for network serial update * Optimize account peers updates on route changes * Remove UpdatePeerSSHKey method * Remove unused isPolicyRuleGroupsEmpty * Add tests for peer update behavior on posture check changes * Add tests for peer update behavior on policy changes * Add tests for peer update behavior on group changes * Add tests for peer update behavior on dns settings changes * Refactor * Add tests for peer update behavior on name server changes * Add tests for peer update behavior on user changes * Add tests for peer update behavior on route changes * fix tests * Add tests for peer update behavior on setup key changes * Add tests for peer update behavior on peers changes * fix merge * Fix tests * go mod tidy * Add NameServer and Route comparators * Update network map diff logic with custom comparators * Add tests * Refactor duplicate diff handling logic * fix linter * fix tests * Refactor policy group handling and update logic. Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update route check by checking if group has peers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor posture check policy linking logic Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Simplify peer update condition in DNS management Refactor the condition for updating account peers to remove redundant checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add policy tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add posture checks tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix user and setup key tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix account and route tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix typo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix nameserver tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix routes tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix group tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * upgrade diff package Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix nameserver tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * use generic differ for netip.Addr and netip.Prefix Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * go mod tidy Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add peer tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix management suite tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix postgres tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * enable diff nil structs comparison Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip the update only last sent the serial is larger Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor peer and user Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip spell check for groupD Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor group, ns group, policy and posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip spell check for GroupD Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update account policy check before verifying policy status Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * add tests missing tests for dns setting groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add tests for posture checks changes Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add ns group and policy tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add route and group tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * increase Linux test timeout to 10 minutes Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Run diff for client posture checks only Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add panic recovery and detailed logging in peer update comparison Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
parent
30ebcf38c7
commit
7bda385e1b
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
|
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
3
go.mod
3
go.mod
@ -71,6 +71,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/r3labs/diff/v3 v3.0.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@ -210,6 +211,8 @@ require (
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
|
6
go.sum
6
go.sum
@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
|
@ -102,7 +102,6 @@ type AccountManager interface {
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||
@ -2132,8 +2131,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1122,66 +1122,196 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
userID := "account_creator"
|
||||
|
||||
account, err := createAccount(manager, "test_account", userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
return
|
||||
}
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
getPeer := func() *nbpeer.Peer {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
}()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expecting peer1 to be added, got failure %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return peer
|
||||
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := getPeer()
|
||||
peer2 := getPeer()
|
||||
peer3 := getPeer()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := group.Group{
|
||||
ID: "group-id",
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}
|
||||
@ -1191,116 +1321,48 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"group-id"},
|
||||
Destinations: []string{"group-id"},
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
t.Run("save group update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
// clean policy is pre requirement for delete group
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("delete policy update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("save policy update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("delete peer update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("delete group update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
// clean policy is pre requirement for delete group
|
||||
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
@ -2754,3 +2816,73 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
t.Helper()
|
||||
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := createAccount(manager, "test_account", userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
}
|
||||
|
||||
getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LastSeen: time.Now().UTC(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expecting peer to be added, got failure %v", err)
|
||||
}
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
peer1 := getPeer(manager, setupKey)
|
||||
peer2 := getPeer(manager, setupKey)
|
||||
peer3 := getPeer(manager, setupKey)
|
||||
|
||||
return manager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
t.Errorf("Unexpected message received: %+v", msg)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
82
management/server/differs/netip.go
Normal file
82
management/server/differs/netip.go
Normal file
@ -0,0 +1,82 @@
|
||||
package differs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
|
||||
"github.com/r3labs/diff/v3"
|
||||
)
|
||||
|
||||
// NetIPAddr is a custom differ for netip.Addr
|
||||
type NetIPAddr struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromAddr.String() != toAddr.String() {
|
||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
|
||||
// NetIPPrefix is a custom differ for netip.Prefix
|
||||
type NetIPPrefix struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromPrefix.String() != toPrefix.String() {
|
||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
@ -125,26 +125,29 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -6,9 +6,11 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||
t.Run("saving dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with peers from DNS settings should trigger updates to account peers and send peer updates
|
||||
t.Run("removing group with peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -121,12 +121,19 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
newGroupIDs := make([]string, 0, len(newGroups))
|
||||
for _, newGroup := range newGroups {
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
@ -238,8 +245,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -282,8 +287,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
@ -336,7 +339,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -366,7 +371,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -469,3 +476,32 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -44,3 +44,8 @@ func (g *Group) Copy() *Group {
|
||||
copy(group.Peers, g.Peers)
|
||||
return group
|
||||
}
|
||||
|
||||
// HasPeers checks if the group has any peers.
|
||||
func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
@ -4,13 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
}
|
||||
return am, acc, nil
|
||||
}
|
||||
|
||||
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Saving a group that is not linked to any resource should not update account peers
|
||||
t.Run("saving unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding a peer to a group that is not linked to any resource should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing a peer from a group that is not linked to any resource should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("removing peer from unliked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting group should not update account peers and not send peer update
|
||||
t.Run("deleting group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving a group linked to policy should update account peers and send peer update
|
||||
t.Run("saving linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding peer to a used group should update account peers and send peer update
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// removing peer from a linked group should update account peers and send peer update
|
||||
t.Run("removing peer from linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to name server group should update account peers and send peer update
|
||||
t.Run("saving group linked to name server group", func(t *testing.T) {
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupC"},
|
||||
true, nil, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to route should update account peers and send peer update
|
||||
t.Run("saving group linked to route", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
ID: "route",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupA"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupC"},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving a group linked to dns settings should update account peers and send peer update
|
||||
t.Run("saving group linked to dns settings", func(t *testing.T) {
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupD"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -57,7 +57,6 @@ type MockAccountManager struct {
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@ -434,14 +433,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if am.UpdatePeerSSHKeyFunc != nil {
|
||||
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeer mocks UpdatePeerFunc function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
if am.UpdatePeerFunc != nil {
|
||||
|
@ -66,13 +66,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@ -80,7 +80,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
@ -98,16 +97,17 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
@ -131,13 +131,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
return nil
|
||||
@ -277,3 +277,11 @@ func validateDomain(domain string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||
}
|
||||
|
@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
var newNameServerGroupA *nbdns.NameServerGroup
|
||||
var newNameServerGroupB *nbdns.NameServerGroup
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupA, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroupA", "nsGroupA", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// saving a nameserver group with a distribution group with no peers should not update account peers
|
||||
// and not send peer update
|
||||
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should update account peers and send peer update
|
||||
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// saving a nameserver group with a distribution group with peers should update account peers and send peer update
|
||||
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// saving unchanged nameserver group should update account peers and not send peer update
|
||||
t.Run("saving unchanged nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
}
|
||||
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting a nameserver group should update account peers and send peer update
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -41,9 +41,9 @@ type Network struct {
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64
|
||||
Serial uint64 `diff:"-"`
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -200,7 +201,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
if peer.Name != update.Name {
|
||||
peerLabelUpdated := peer.Name != update.Name
|
||||
|
||||
if peerLabelUpdated {
|
||||
peer.Name = update.Name
|
||||
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
@ -260,7 +263,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if peerLabelUpdated {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
@ -304,6 +309,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
FirewallRulesIsEmpty: true,
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
@ -322,6 +328,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||
|
||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -332,7 +340,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -422,9 +432,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var groupsToAdd []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
var groupsToAdd []string
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
@ -576,7 +586,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
@ -897,51 +909,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if peer.SSHKey == sshKey {
|
||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
peer.SSHKey = sshKey
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@ -1034,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
@ -1048,3 +1015,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
}
|
||||
return labelMap
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||
peerGroupIDs := make([]string, 0)
|
||||
for _, group := range account.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||
}
|
||||
}
|
||||
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||
}
|
||||
|
@ -17,37 +17,37 @@ type Peer struct {
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string
|
||||
SetupKey string `diff:"-"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||
// The user ID that registered the peer
|
||||
UserID string
|
||||
UserID string `diff:"-"`
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
|
||||
InactivityExpirationEnabled bool
|
||||
InactivityExpirationEnabled bool `diff:"-"`
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time
|
||||
LastLogin time.Time `diff:"-"`
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool
|
||||
Ephemeral bool `diff:"-"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
@ -189,7 +189,6 @@ func (p *Peer) Copy() *Peer {
|
||||
CreatedAt: p.CreatedAt,
|
||||
Ephemeral: p.Ephemeral,
|
||||
Location: p.Location,
|
||||
|
||||
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||
}
|
||||
}
|
||||
|
@ -1253,3 +1253,322 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
|
||||
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||
}
|
||||
|
||||
func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// create a user with auto groups
|
||||
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{
|
||||
{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
AutoGroups: []string{"groupA"},
|
||||
},
|
||||
{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
AutoGroups: []string{"groupB"},
|
||||
},
|
||||
{
|
||||
Id: "regularUser3",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
AutoGroups: []string{"groupC"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
var peer4 *nbpeer.Peer
|
||||
var peer5 *nbpeer.Peer
|
||||
var peer6 *nbpeer.Peer
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
|
||||
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to unlinked group should not update account peers and not send peer update
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with unlinked group should not update account peers and not send peer update
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating peer label should update account peers and send peer update
|
||||
t.Run("updating peer label", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
peer1.Name = "peer-1"
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to policy should update account peers and send peer update
|
||||
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with route should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with route", func(t *testing.T) {
|
||||
route := nbroute.Route{
|
||||
ID: "testingRoute1",
|
||||
Network: netip.MustParsePrefix("100.65.250.202/32"),
|
||||
NetID: "superNet",
|
||||
NetworkType: nbroute.IPv4Network,
|
||||
PeerGroups: []string{"groupB"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupB"},
|
||||
}
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to route should update account peers and send peer update
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to group linked with name server group should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupC"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
LoginExpirationEnabled: true,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting peer with linked group to name server group should update account peers and send peer update
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
|
||||
}
|
||||
}
|
||||
|
||||
// ruleGroups returns a list of all groups referenced in the policy's rules,
|
||||
// including sources and destinations.
|
||||
func (p *Policy) ruleGroups() []string {
|
||||
groups := make([]string, 0)
|
||||
for _, rule := range p.Rules {
|
||||
groups = append(groups, rule.Sources...)
|
||||
groups = append(groups, rule.Destinations...)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// FirewallRule is a rule of the firewall.
|
||||
type FirewallRule struct {
|
||||
// PeerIP of the peer
|
||||
@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
||||
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
}
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
||||
|
||||
// savePolicy saves or updates a policy in the given account.
|
||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||
for index, rule := range policyToSave.Rules {
|
||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||
@ -442,18 +457,25 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
||||
if isUpdate {
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||
if policyIdx < 0 {
|
||||
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
}
|
||||
|
||||
oldPolicy := account.Policies[policyIdx]
|
||||
// Update the existing policy
|
||||
account.Policies[policyIdx] = policyToSave
|
||||
return nil
|
||||
|
||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||
|
||||
return updateAccountPeers, nil
|
||||
}
|
||||
|
||||
// Add the new policy to the account
|
||||
account.Policies = append(account.Policies, policyToSave)
|
||||
|
||||
return nil
|
||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
|
@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int {
|
||||
return 0 // a is equal to b
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-rule-groups-no-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving policy with source group containing peers, but destination group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-has-peers-destination-none",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-destination-has-peers-source-none",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: false,
|
||||
Sources: []string{"groupC"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Disabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||
// or send peer update
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Description: "updated description",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Enabling policy with destination and source groups containing peers should update account's peers
|
||||
// and send peer update
|
||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged policy should trigger account peers update but not send peer update
|
||||
t.Run("saving unchanged policy", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy should trigger account peers update and send peer update
|
||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policyID := "policy-source-destination-peers"
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Deleting policy with destination group containing peers, but source group without peers should
|
||||
// update account's peers and send peer update
|
||||
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg2)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -67,7 +67,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
if exists {
|
||||
|
||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
@ -148,13 +149,9 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
||||
}
|
||||
|
||||
// check policy links
|
||||
for _, policy := range account.Policies {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if id == postureChecksID {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||
}
|
||||
}
|
||||
// Check if posture check is linked to any policy
|
||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
||||
}
|
||||
|
||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
||||
@ -217,3 +214,25 @@ func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
||||
for _, policy := range account.Policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
||||
if !isLinked {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
||||
}
|
||||
|
@ -3,7 +3,10 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@ -118,3 +121,458 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
postureCheck := posture.Checks{
|
||||
ID: "postureCheck",
|
||||
Name: "postureCheck",
|
||||
AccountID: account.Id,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.28.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Saving unused posture check should not update account peers and not send peer update
|
||||
t.Run("saving unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unused posture check should not update account peers and not send peer update
|
||||
t.Run("updating unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
policy := Policy{
|
||||
ID: "policyA",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
|
||||
// Linking posture check to policy should trigger update account peers and send peer update
|
||||
t.Run("linking posture check to policy with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture checks should update account peers and send peer update
|
||||
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged posture check should not trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing posture check from policy should trigger account peers update and send peer update
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.SourcePostureChecks = []string{}
|
||||
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting unused posture check should not trigger account peers update and not send peer update
|
||||
t.Run("deleting unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture check to policy where destination has peers but source does not
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked posture check to policy where source has peers but destination does not,
|
||||
// should not trigger account peers update or send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{postureCheck.ID},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
postureCheck.Checks = posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policyA",
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"checkA"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*group.Group{
|
||||
"groupA": {
|
||||
ID: "groupA",
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
"groupB": {
|
||||
ID: "groupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: "checkA",
|
||||
},
|
||||
{
|
||||
ID: "checkB",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check does not exist", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
account.Groups["groupA"].Peers = []string{}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute := account.Routes[routeToSave.ID]
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
}
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
manager, err := createRouterManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestRouteAccount(t, manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
|
||||
t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) {
|
||||
route := route.Route{
|
||||
ID: "testingRoute1",
|
||||
Network: netip.MustParsePrefix("100.65.250.202/32"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupA"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupA"},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
|
||||
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
|
||||
route := route.Route{
|
||||
ID: "testingRoute2",
|
||||
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{routeGroup3},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup3},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
baseRoute := route.Route{
|
||||
ID: "testingRoute3",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
|
||||
// Creating route should update account peers and send peer update
|
||||
t.Run("creating route with a routing peer", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
newRoute, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
|
||||
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
baseRoute = *newRoute
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating the route should update account peers and send peer update when there is peers in group
|
||||
t.Run("updating route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Updating unchanged route should update account peers and not send peer update
|
||||
t.Run("updating unchanged route", func(t *testing.T) {
|
||||
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Deleting the route should update account peers and send peer update
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to route peer groups that do not have any peers should update account peers and send peer update
|
||||
t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
Network: netip.MustParsePrefix("192.168.12.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupB"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Adding peer to route groups that do not have any peers should update account peers and send peer update
|
||||
t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
Network: netip.MustParsePrefix("192.168.13.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
PeerGroups: []string{"groupB"},
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupC"},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -323,8 +323,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
}
|
||||
}()
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return newKey, nil
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
@ -352,3 +353,73 @@ func TestSetupKey_Copy(t *testing.T) {
|
||||
key.UpdatedAt, key.AutoGroups)
|
||||
|
||||
}
|
||||
|
||||
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"group"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
var setupKey *SetupKey
|
||||
|
||||
// Creating setup key should not update account peers and not send peer update
|
||||
t.Run("creating setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Saving setup key should not update account peers and not send peer update
|
||||
t.Run("saving setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
5
management/server/testdata/store.sql
vendored
5
management/server/testdata/store.sql
vendored
@ -26,8 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||
|
@ -2,9 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/differs"
|
||||
"github.com/r3labs/diff/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@ -14,14 +18,17 @@ import (
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *NetworkMap
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||
peerUpdateMessage map[string]*UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.Mutex
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
@ -29,9 +36,10 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.Mutex{},
|
||||
metrics: metrics,
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
// skip sending sync update to the peer if there is no change in update message,
|
||||
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||
if update.NetworkMap != nil {
|
||||
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.channelsMux.Lock()
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.Unlock()
|
||||
if p.metrics != nil {
|
||||
@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if update.NetworkMap != nil {
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
|
||||
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||
return
|
||||
}
|
||||
p.peerUpdateMessage[peerID] = update
|
||||
}
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
@ -170,3 +200,72 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||
p.channelsMux.RLock()
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
p.channelsMux.RUnlock()
|
||||
|
||||
if lastSentUpdate != nil {
|
||||
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||
return false
|
||||
}
|
||||
if !updated {
|
||||
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
|
||||
isNew, err = true, nil
|
||||
}
|
||||
}()
|
||||
|
||||
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
differ, err := diff.NewDiffer(
|
||||
diff.CustomValueDiffers(&differs.NetIPAddr{}),
|
||||
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create differ: %v", err)
|
||||
}
|
||||
|
||||
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
|
||||
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
|
||||
|
||||
changelog, err := differ.Diff(lastSentFiles, currFiles)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||
}
|
||||
return len(changelog) > 0, nil
|
||||
}
|
||||
|
||||
// getChecksFiles returns a list of files from the given checks.
|
||||
func getChecksFiles(checks []*proto.Checks) []string {
|
||||
files := make([]string, 0, len(checks))
|
||||
for _, check := range checks {
|
||||
files = append(files, check.GetFiles()...)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
@ -2,10 +2,19 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// var peersUpdater *PeersUpdateManager
|
||||
@ -77,3 +86,470 @@ func TestCloseChannel(t *testing.T) {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
existingUpdate *UpdateMessage
|
||||
newUpdate *UpdateMessage
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "update message with turn credentials update",
|
||||
peerID: "peer",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||
},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message for peer without existing update",
|
||||
peerID: "peer1",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with no changes in update",
|
||||
peerID: "peer2",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "update message with changes in checks",
|
||||
peerID: "peer3",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
Checks: []*proto.Checks{
|
||||
{
|
||||
Files: []string{"/usr/bin/netbird"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with lower serial number",
|
||||
peerID: "peer4",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewPeersUpdateManager(nil)
|
||||
ctx := context.Background()
|
||||
|
||||
if tt.existingUpdate != nil {
|
||||
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||
}
|
||||
|
||||
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNewPeerUpdateMessage(t *testing.T) {
|
||||
t.Run("Unchanged value", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating routes network", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
})
|
||||
|
||||
t.Run("Updating routes groups", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating network map peers", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.4"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer4-key",
|
||||
DNSLabel: "peer4",
|
||||
SSHKey: "peer4-ssh-key",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating process check", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, message)
|
||||
|
||||
newUpdateMessage3 := createMockUpdateMessage(t)
|
||||
newUpdateMessage3.Update.Checks = []*proto.Checks{}
|
||||
newUpdateMessage3.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage4 := createMockUpdateMessage(t)
|
||||
check := &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/local/netbird",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage4.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
|
||||
newUpdateMessage5 := createMockUpdateMessage(t)
|
||||
check = &posture.Checks{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/local/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||
newUpdateMessage5.Update.NetworkMap.Serial++
|
||||
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating DNS configuration", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newDomain := "newexample.com"
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
|
||||
newDomain,
|
||||
)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating peer IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Add new firewall rule", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newRule := &FirewallRule{
|
||||
PeerIP: "192.168.1.3",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Action: string(PolicyTrafficActionDrop),
|
||||
Protocol: string(PolicyRuleProtocolUDP),
|
||||
Port: "53",
|
||||
}
|
||||
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Removing nameserver", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating name server IP", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
t.Run("Updating custom DNS zone", func(t *testing.T) {
|
||||
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||
|
||||
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
|
||||
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||
|
||||
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, message)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
|
||||
t.Helper()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
|
||||
secretManager := NewTimeBasedAuthSecretsManager(
|
||||
NewPeersUpdateManager(nil),
|
||||
&TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{
|
||||
Duration: defaultDuration,
|
||||
},
|
||||
Secret: "secret",
|
||||
Turns: []*Host{TurnTestHost},
|
||||
},
|
||||
&Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "secret",
|
||||
},
|
||||
)
|
||||
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipNet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
{
|
||||
ID: "route2",
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route2",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"test1", "test2"},
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{
|
||||
{
|
||||
LinuxPath: "/usr/bin/netbird",
|
||||
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||
MacPath: "/usr/bin/netbird",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
turnToken, err := secretManager.GenerateTurnToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
relayToken, err := secretManager.GenerateRelayToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &UpdateMessage{
|
||||
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
|
||||
NetworkMap: networkMap,
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -473,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -485,15 +486,22 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error {
|
||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
|
||||
peers, err := account.FindUserPeers(targetUserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to find user peers")
|
||||
return false, status.Errorf(status.Internal, "failed to find user peers")
|
||||
}
|
||||
|
||||
hadPeers := len(peers) > 0
|
||||
if !hadPeers {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
peerIDs := make([]string, 0, len(peers))
|
||||
@ -501,7 +509,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
}
|
||||
|
||||
return am.deletePeers(ctx, account, peerIDs, initiatorUserID)
|
||||
return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID)
|
||||
}
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
@ -745,6 +753,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
updatedUsers := make([]*UserInfo, 0, len(updates))
|
||||
var (
|
||||
expiredPeers []*nbpeer.Peer
|
||||
userIDs []string
|
||||
eventsToStore []func()
|
||||
)
|
||||
|
||||
@ -753,6 +762,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
userIDs = append(userIDs, update.Id)
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
if !addIfNotExists {
|
||||
@ -816,7 +827,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
@ -1167,7 +1178,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var (
|
||||
allErrors error
|
||||
updateAccountPeers bool
|
||||
)
|
||||
|
||||
deletedUsersMeta := make(map[string]map[string]any)
|
||||
for _, targetUserID := range targetUserIDs {
|
||||
@ -1193,12 +1207,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
continue
|
||||
}
|
||||
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if hadPeers {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
deletedUsersMeta[targetUserID] = meta
|
||||
}
|
||||
@ -1208,7 +1226,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return fmt.Errorf("failed to delete users: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for targetUserID, meta := range deletedUsersMeta {
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
@ -1217,11 +1237,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
||||
return allErrors
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
@ -1232,16 +1252,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
u, err := account.FindUser(targetUserID)
|
||||
@ -1254,7 +1274,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
||||
tuCreatedAt = u.CreatedAt
|
||||
}
|
||||
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil
|
||||
}
|
||||
|
||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||
@ -1333,3 +1353,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account.
|
||||
func areUsersLinkedToPeers(account *Account, userIDs []string) bool {
|
||||
for _, peer := range account.Peers {
|
||||
if slices.Contains(userIDs, peer.UserID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -10,9 +10,12 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
// account groups propagation is enabled
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedAPI,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// updating user with no linked peers should not update account peers and not send peer update
|
||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser1",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedAPI,
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// deleting user with no linked peers should not update account peers and not send peer update
|
||||
t.Run("deleting user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1")
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// create a user and add new peer with the user
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// updating user with linked peers should update account peers and send peer update
|
||||
t.Run("updating user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||
Id: "regularUser2",
|
||||
AccountID: account.Id,
|
||||
Role: UserRoleAdmin,
|
||||
Issued: UserIssuedAPI,
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
|
||||
})
|
||||
|
||||
// deleting user with linked peers should update account peers and send peer update
|
||||
t.Run("deleting user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, peer4UpdMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2")
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user