mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 10:20:09 +01:00
[management] Optimize network map updates (#2718)
* Skip peer update on unchanged network map (#2236) * Enhance network updates by skipping unchanged messages Optimizes the network update process by skipping updates where no changes in the peer update message received. * Add unit tests * add locks * Improve concurrency and update peer message handling * Refactor account manager network update tests * fix test * Fix inverted network map update condition * Add default group and policy to test data * Run peer updates in a separate goroutine * Refactor * Refactor lock * Fix peers update by including NetworkMap and posture Checks * go mod tidy * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * [management] Skip account peers update if no changes affect peers (#2310) * Remove incrementing network serial and updating peers after group deletion * Update account peer if posture check is linked to policy * Remove account peers update on saving setup key * Refactor group link checking into re-usable functions * Add HasPeers function to group * Refactor group management * Optimize group change effects on account peers * Update account peers if ns group has peers * Refactor group changes * Optimize account peers update in DNS settings * Optimize update of account peers on jwt groups sync * Refactor peer account updates for efficiency * Optimize peer update on user deletion and changes * Remove condition check for network serial update * Optimize account peers updates on route changes * Remove UpdatePeerSSHKey method * Remove unused isPolicyRuleGroupsEmpty * Add tests for peer update behavior on posture check changes * Add tests for peer update behavior on policy changes * Add tests for peer update behavior on group changes * Add tests for peer update behavior on dns settings changes * Refactor * Add tests for peer update behavior on name server changes * Add tests for peer update behavior on user changes * Add tests for peer update behavior on route changes * fix tests * Add tests for peer update behavior on setup key changes * Add tests for peer update behavior on peers changes * fix merge * Fix tests * go mod tidy * Add NameServer and Route comparators * Update network map diff logic with custom comparators * Add tests * Refactor duplicate diff handling logic * fix linter * fix tests * Refactor policy group handling and update logic. Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update route check by checking if group has peers Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor posture check policy linking logic Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Simplify peer update condition in DNS management Refactor the condition for updating account peers to remove redundant checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add policy tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add posture checks tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix user and setup key tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix account and route tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix typo Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix nameserver tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix routes tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix group tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * upgrade diff package Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix nameserver tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * use generic differ for netip.Addr and netip.Prefix Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * go mod tidy Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add peer tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix management suite tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix postgres tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * enable diff nil structs comparison Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip the update only last sent the serial is larger Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor peer and user Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip spell check for groupD Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor group, ns group, policy and posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip spell check for GroupD Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * update account policy check before verifying policy status Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * Update management/server/route_test.go Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> * add tests missing tests for dns setting groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add tests for posture checks changes Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add ns group and policy tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add route and group tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * increase Linux test timeout to 10 minutes Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Run diff for client posture checks only Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add panic recovery and detailed logging in peer update comparison Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
parent
30ebcf38c7
commit
7bda385e1b
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
|
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif
|
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
|
3
go.mod
3
go.mod
@ -71,6 +71,7 @@ require (
|
|||||||
github.com/pion/transport/v3 v3.0.1
|
github.com/pion/transport/v3 v3.0.1
|
||||||
github.com/pion/turn/v3 v3.0.1
|
github.com/pion/turn/v3 v3.0.1
|
||||||
github.com/prometheus/client_golang v1.19.1
|
github.com/prometheus/client_golang v1.19.1
|
||||||
|
github.com/r3labs/diff/v3 v3.0.1
|
||||||
github.com/rs/xid v1.3.0
|
github.com/rs/xid v1.3.0
|
||||||
github.com/shirou/gopsutil/v3 v3.24.4
|
github.com/shirou/gopsutil/v3 v3.24.4
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
@ -210,6 +211,8 @@ require (
|
|||||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
github.com/yuin/goldmark v1.7.1 // indirect
|
github.com/yuin/goldmark v1.7.1 // indirect
|
||||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
|
6
go.sum
6
go.sum
@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
|||||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||||
|
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||||
|
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||||
@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
|||||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
@ -102,7 +102,6 @@ type AccountManager interface {
|
|||||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||||
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
|
|
||||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
||||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||||
@ -2132,8 +2131,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error getting account: %w", err)
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||||
am.updateAccountPeers(ctx, account)
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1122,66 +1122,196 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
|||||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
group := group.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{},
|
||||||
|
}
|
||||||
|
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||||
|
t.Errorf("save group: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := "account_creator"
|
policy := Policy{
|
||||||
|
ID: "policy",
|
||||||
account, err := createAccount(manager, "test_account", userID, "")
|
Enabled: true,
|
||||||
if err != nil {
|
Rules: []*PolicyRule{
|
||||||
t.Fatal(err)
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
if err != nil {
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
t.Fatal("error creating setup key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.Network.Serial != 0 {
|
wg := sync.WaitGroup{}
|
||||||
t.Errorf("expecting account network to have an initial Serial=0")
|
wg.Add(1)
|
||||||
return
|
go func() {
|
||||||
}
|
defer wg.Done()
|
||||||
|
|
||||||
getPeer := func() *nbpeer.Peer {
|
message := <-updMsg
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
networkMap := message.Update.GetNetworkMap()
|
||||||
if err != nil {
|
if len(networkMap.RemotePeers) != 2 {
|
||||||
t.Fatal(err)
|
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
expectedPeerKey := key.PublicKey().String()
|
}()
|
||||||
|
|
||||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
|
||||||
Key: expectedPeerKey,
|
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
t.Errorf("save group: %v", err)
|
||||||
})
|
return
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("expecting peer1 to be added, got failure %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peer1 := getPeer()
|
wg.Wait()
|
||||||
peer2 := getPeer()
|
}
|
||||||
peer3 := getPeer()
|
|
||||||
|
|
||||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||||
if err != nil {
|
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||||
t.Fatal(err)
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
message := <-updMsg
|
||||||
|
networkMap := message.Update.GetNetworkMap()
|
||||||
|
if len(networkMap.RemotePeers) != 0 {
|
||||||
|
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||||
|
t.Errorf("delete default rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
group := group.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID},
|
||||||
|
}
|
||||||
|
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||||
|
t.Errorf("save group: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
|
policy := Policy{
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
message := <-updMsg
|
||||||
|
networkMap := message.Update.GetNetworkMap()
|
||||||
|
if len(networkMap.RemotePeers) != 2 {
|
||||||
|
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||||
|
t.Errorf("delete default rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||||
|
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
group := group.Group{
|
group := group.Group{
|
||||||
ID: "group-id",
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer3.ID},
|
||||||
|
}
|
||||||
|
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||||
|
t.Errorf("save group: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
policy := Policy{
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||||
|
t.Errorf("save policy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
message := <-updMsg
|
||||||
|
networkMap := message.Update.GetNetworkMap()
|
||||||
|
if len(networkMap.RemotePeers) != 1 {
|
||||||
|
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||||
|
t.Errorf("delete peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
|
group := group.Group{
|
||||||
|
ID: "groupA",
|
||||||
Name: "GroupA",
|
Name: "GroupA",
|
||||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
}
|
}
|
||||||
@ -1191,116 +1321,48 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
|||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"group-id"},
|
Sources: []string{"groupA"},
|
||||||
Destinations: []string{"group-id"},
|
Destinations: []string{"groupA"},
|
||||||
Bidirectional: true,
|
Bidirectional: true,
|
||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||||
|
t.Errorf("delete default rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||||
|
t.Errorf("save policy: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
t.Run("save group update", func(t *testing.T) {
|
wg.Add(1)
|
||||||
wg.Add(1)
|
go func() {
|
||||||
go func() {
|
defer wg.Done()
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
message := <-updMsg
|
message := <-updMsg
|
||||||
networkMap := message.Update.GetNetworkMap()
|
networkMap := message.Update.GetNetworkMap()
|
||||||
if len(networkMap.RemotePeers) != 2 {
|
if len(networkMap.RemotePeers) != 0 {
|
||||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
|
||||||
t.Errorf("save group: %v", err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
// clean policy is pre requirement for delete group
|
||||||
})
|
if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil {
|
||||||
|
t.Errorf("delete default rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("delete policy update", func(t *testing.T) {
|
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||||
wg.Add(1)
|
t.Errorf("delete group: %v", err)
|
||||||
go func() {
|
return
|
||||||
defer wg.Done()
|
}
|
||||||
|
|
||||||
message := <-updMsg
|
wg.Wait()
|
||||||
networkMap := message.Update.GetNetworkMap()
|
|
||||||
if len(networkMap.RemotePeers) != 0 {
|
|
||||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
|
||||||
t.Errorf("delete default rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("save policy update", func(t *testing.T) {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
message := <-updMsg
|
|
||||||
networkMap := message.Update.GetNetworkMap()
|
|
||||||
if len(networkMap.RemotePeers) != 2 {
|
|
||||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
|
||||||
t.Errorf("delete default rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
t.Run("delete peer update", func(t *testing.T) {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
message := <-updMsg
|
|
||||||
networkMap := message.Update.GetNetworkMap()
|
|
||||||
if len(networkMap.RemotePeers) != 1 {
|
|
||||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
|
||||||
t.Errorf("delete peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete group update", func(t *testing.T) {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
message := <-updMsg
|
|
||||||
networkMap := message.Update.GetNetworkMap()
|
|
||||||
if len(networkMap.RemotePeers) != 0 {
|
|
||||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// clean policy is pre requirement for delete group
|
|
||||||
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
|
|
||||||
|
|
||||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
|
||||||
t.Errorf("delete group: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||||
@ -2754,3 +2816,73 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := createAccount(manager, "test_account", userID, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("error creating setup key")
|
||||||
|
}
|
||||||
|
|
||||||
|
getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
expectedPeerKey := key.PublicKey().String()
|
||||||
|
|
||||||
|
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||||
|
Key: expectedPeerKey,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||||
|
Status: &nbpeer.PeerStatus{
|
||||||
|
Connected: true,
|
||||||
|
LastSeen: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expecting peer to be added, got failure %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1 := getPeer(manager, setupKey)
|
||||||
|
peer2 := getPeer(manager, setupKey)
|
||||||
|
peer3 := getPeer(manager, setupKey)
|
||||||
|
|
||||||
|
return manager, account, peer1, peer2, peer3
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||||
|
t.Helper()
|
||||||
|
select {
|
||||||
|
case msg := <-updateMessage:
|
||||||
|
t.Errorf("Unexpected message received: %+v", msg)
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-updateMessage:
|
||||||
|
if msg == nil {
|
||||||
|
t.Errorf("Received nil update message, expected valid message")
|
||||||
|
}
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Error("Timed out waiting for update message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
82
management/server/differs/netip.go
Normal file
82
management/server/differs/netip.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package differs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/r3labs/diff/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetIPAddr is a custom differ for netip.Addr
|
||||||
|
type NetIPAddr struct {
|
||||||
|
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||||
|
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||||
|
if a.Kind() == reflect.Invalid {
|
||||||
|
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.Kind() == reflect.Invalid {
|
||||||
|
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||||
|
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return fmt.Errorf("invalid type for netip.Addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fromAddr.String() != toAddr.String() {
|
||||||
|
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||||
|
differ.DiffFunc = dfunc //nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetIPPrefix is a custom differ for netip.Prefix
|
||||||
|
type NetIPPrefix struct {
|
||||||
|
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||||
|
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||||
|
if a.Kind() == reflect.Invalid {
|
||||||
|
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if b.Kind() == reflect.Invalid {
|
||||||
|
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||||
|
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
return fmt.Errorf("invalid type for netip.Addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fromPrefix.String() != toPrefix.String() {
|
||||||
|
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||||
|
differ.DiffFunc = dfunc //nolint
|
||||||
|
}
|
@ -125,26 +125,29 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
|||||||
oldSettings := account.DNSSettings.Copy()
|
oldSettings := account.DNSSettings.Copy()
|
||||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||||
|
|
||||||
|
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||||
|
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
|
||||||
for _, id := range addedGroups {
|
for _, id := range addedGroups {
|
||||||
group := account.GetGroup(id)
|
group := account.GetGroup(id)
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
|
||||||
for _, id := range removedGroups {
|
for _, id := range removedGroups {
|
||||||
group := account.GetGroup(id)
|
group := account.GetGroup(id)
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -6,9 +6,11 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
|||||||
t.Errorf("Cache should contain name server group 'group2'")
|
t.Errorf("Cache should contain name server group 'group2'")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||||
|
t.Run("saving dns setting with unused groups", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"groupA"},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = manager.CreateNameServerGroup(
|
||||||
|
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||||
|
NSType: dns.UDPNameServerType,
|
||||||
|
Port: dns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
[]string{"groupA"},
|
||||||
|
true, []string{}, true, userID, false,
|
||||||
|
)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||||
|
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||||
|
// since there is no change in the network map
|
||||||
|
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||||
|
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"groupA"},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Removing group with peers from DNS settings should trigger updates to account peers and send peer updates
|
||||||
|
t.Run("removing group with peers from dns settings", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -121,12 +121,19 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
|||||||
eventsToStore = append(eventsToStore, events...)
|
eventsToStore = append(eventsToStore, events...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newGroupIDs := make([]string, 0, len(newGroups))
|
||||||
|
for _, newGroup := range newGroups {
|
||||||
|
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||||
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
for _, storeEvent := range eventsToStore {
|
for _, storeEvent := range eventsToStore {
|
||||||
storeEvent()
|
storeEvent()
|
||||||
@ -238,8 +245,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
|||||||
|
|
||||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,8 +287,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us
|
|||||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
|
|
||||||
return allErrors
|
return allErrors
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -336,7 +339,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -366,7 +371,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -469,3 +476,32 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
|||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||||
|
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -44,3 +44,8 @@ func (g *Group) Copy() *Group {
|
|||||||
copy(group.Peers, g.Peers)
|
copy(group.Peers, g.Peers)
|
||||||
return group
|
return group
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasPeers checks if the group has any peers.
|
||||||
|
func (g *Group) HasPeers() bool {
|
||||||
|
return len(g.Peers) > 0
|
||||||
|
}
|
||||||
|
@ -4,13 +4,16 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
|||||||
}
|
}
|
||||||
return am, acc, nil
|
return am, acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{peer1.ID, peer3.ID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupD",
|
||||||
|
Name: "GroupD",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving a group that is not linked to any resource should not update account peers
|
||||||
|
t.Run("saving unlinked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding a peer to a group that is not linked to any resource should not update account peers
|
||||||
|
// and not send peer update
|
||||||
|
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Removing a peer from a group that is not linked to any resource should not update account peers
|
||||||
|
// and not send peer update
|
||||||
|
t.Run("removing peer from unliked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting group should not update account peers and not send peer update
|
||||||
|
t.Run("deleting group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// adding a group to policy
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
ID: "policy",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Saving a group linked to policy should update account peers and send peer update
|
||||||
|
t.Run("saving linked group to policy", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||||
|
// since there is no change in the network map
|
||||||
|
t.Run("saving unchanged group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// adding peer to a used group should update account peers and send peer update
|
||||||
|
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// removing peer from a linked group should update account peers and send peer update
|
||||||
|
t.Run("removing peer from linked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving a group linked to name server group should update account peers and send peer update
|
||||||
|
t.Run("saving group linked to name server group", func(t *testing.T) {
|
||||||
|
_, err = manager.CreateNameServerGroup(
|
||||||
|
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
[]string{"groupC"},
|
||||||
|
true, nil, true, userID, false,
|
||||||
|
)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{peer1.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving a group linked to route should update account peers and send peer update
|
||||||
|
t.Run("saving group linked to route", func(t *testing.T) {
|
||||||
|
newRoute := route.Route{
|
||||||
|
ID: "route",
|
||||||
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
PeerGroups: []string{"groupA"},
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"groupC"},
|
||||||
|
}
|
||||||
|
_, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||||
|
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||||
|
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving a group linked to dns settings should update account peers and send peer update
|
||||||
|
t.Run("saving group linked to dns settings", func(t *testing.T) {
|
||||||
|
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||||
|
DisabledManagementGroups: []string{"groupD"},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupD",
|
||||||
|
Name: "GroupD",
|
||||||
|
Peers: []string{peer1.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -57,7 +57,6 @@ type MockAccountManager struct {
|
|||||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
|
||||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
@ -434,14 +433,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
|
||||||
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
|
||||||
if am.UpdatePeerSSHKeyFunc != nil {
|
|
||||||
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
|
|
||||||
}
|
|
||||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeer mocks UpdatePeerFunc function of the account manager
|
// UpdatePeer mocks UpdatePeerFunc function of the account manager
|
||||||
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
if am.UpdatePeerFunc != nil {
|
if am.UpdatePeerFunc != nil {
|
||||||
|
@ -66,13 +66,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
|
|
||||||
return newNSGroup.Copy(), nil
|
return newNSGroup.Copy(), nil
|
||||||
@ -80,7 +80,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
|
|
||||||
// SaveNameServerGroup saves nameserver group
|
// SaveNameServerGroup saves nameserver group
|
||||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -98,16 +97,17 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
||||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -131,13 +131,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
delete(account.NameServerGroups, nsGroupID)
|
delete(account.NameServerGroups, nsGroupID)
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -277,3 +277,11 @@ func validateDomain(domain string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||||
|
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||||
|
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||||
|
}
|
||||||
|
@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
var newNameServerGroupA *nbdns.NameServerGroup
|
||||||
|
var newNameServerGroupB *nbdns.NameServerGroup
|
||||||
|
|
||||||
|
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Creating a nameserver group with a distribution group no peers should not update account peers
|
||||||
|
// and not send peer update
|
||||||
|
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
newNameServerGroupA, err = manager.CreateNameServerGroup(
|
||||||
|
context.Background(), account.Id, "nsGroupA", "nsGroupA", []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
[]string{"groupA"},
|
||||||
|
true, []string{}, true, userID, false,
|
||||||
|
)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// saving a nameserver group with a distribution group with no peers should not update account peers
|
||||||
|
// and not send peer update
|
||||||
|
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Creating a nameserver group with a distribution group no peers should update account peers and send peer update
|
||||||
|
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
newNameServerGroupB, err = manager.CreateNameServerGroup(
|
||||||
|
context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
[]string{"groupB"},
|
||||||
|
true, []string{}, true, userID, false,
|
||||||
|
)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// saving a nameserver group with a distribution group with peers should update account peers and send peer update
|
||||||
|
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.2"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// saving unchanged nameserver group should update account peers and not send peer update
|
||||||
|
t.Run("saving unchanged nameserver group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
newNameServerGroupB.NameServers = []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.2"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting a nameserver group should update account peers and send peer update
|
||||||
|
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -41,9 +41,9 @@ type Network struct {
|
|||||||
Dns string
|
Dns string
|
||||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||||
// Used to synchronize state to the client apps.
|
// Used to synchronize state to the client apps.
|
||||||
Serial uint64
|
Serial uint64 `diff:"-"`
|
||||||
|
|
||||||
mu sync.Mutex `json:"-" gorm:"-"`
|
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -200,7 +201,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.Name != update.Name {
|
peerLabelUpdated := peer.Name != update.Name
|
||||||
|
|
||||||
|
if peerLabelUpdated {
|
||||||
peer.Name = update.Name
|
peer.Name = update.Name
|
||||||
|
|
||||||
existingLabels := account.getPeerDNSLabels()
|
existingLabels := account.getPeerDNSLabels()
|
||||||
@ -260,7 +263,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if peerLabelUpdated {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
@ -304,6 +309,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
|||||||
FirewallRulesIsEmpty: true,
|
FirewallRulesIsEmpty: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
NetworkMap: &NetworkMap{},
|
||||||
})
|
})
|
||||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||||
@ -322,6 +328,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||||
|
|
||||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -332,7 +340,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -422,9 +432,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var newPeer *nbpeer.Peer
|
var newPeer *nbpeer.Peer
|
||||||
|
var groupsToAdd []string
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
var groupsToAdd []string
|
|
||||||
var setupKeyID string
|
var setupKeyID string
|
||||||
var setupKeyName string
|
var setupKeyName string
|
||||||
var ephemeral bool
|
var ephemeral bool
|
||||||
@ -576,7 +586,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -897,51 +909,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerSSHKey updates peer's public SSH key
|
|
||||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
|
||||||
if sshKey == "" {
|
|
||||||
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
|
||||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := account.GetPeer(peerID)
|
|
||||||
if peer == nil {
|
|
||||||
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer.SSHKey == sshKey {
|
|
||||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.SSHKey = sshKey
|
|
||||||
account.UpdatePeer(peer)
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// trigger network map update
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
@ -1034,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
postureChecks := am.getPeerPostureChecks(account, p)
|
postureChecks := am.getPeerPostureChecks(account, p)
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||||
}(peer)
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1048,3 +1015,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
|||||||
}
|
}
|
||||||
return labelMap
|
return labelMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
|
// in an active DNS, route, or ACL configuration.
|
||||||
|
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||||
|
peerGroupIDs := make([]string, 0)
|
||||||
|
for _, group := range account.Groups {
|
||||||
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||||
|
}
|
||||||
|
@ -17,37 +17,37 @@ type Peer struct {
|
|||||||
// WireGuard public key
|
// WireGuard public key
|
||||||
Key string `gorm:"index"`
|
Key string `gorm:"index"`
|
||||||
// A setup key this peer was registered with
|
// A setup key this peer was registered with
|
||||||
SetupKey string
|
SetupKey string `diff:"-"`
|
||||||
// IP address of the Peer
|
// IP address of the Peer
|
||||||
IP net.IP `gorm:"serializer:json"`
|
IP net.IP `gorm:"serializer:json"`
|
||||||
// Meta is a Peer system meta data
|
// Meta is a Peer system meta data
|
||||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||||
// Name is peer's name (machine name)
|
// Name is peer's name (machine name)
|
||||||
Name string
|
Name string
|
||||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||||
DNSLabel string
|
DNSLabel string
|
||||||
// Status peer's management connection status
|
// Status peer's management connection status
|
||||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||||
// The user ID that registered the peer
|
// The user ID that registered the peer
|
||||||
UserID string
|
UserID string `diff:"-"`
|
||||||
// SSHKey is a public SSH key of the peer
|
// SSHKey is a public SSH key of the peer
|
||||||
SSHKey string
|
SSHKey string
|
||||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||||
SSHEnabled bool
|
SSHEnabled bool
|
||||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||||
// Works with LastLogin
|
// Works with LastLogin
|
||||||
LoginExpirationEnabled bool
|
LoginExpirationEnabled bool `diff:"-"`
|
||||||
|
|
||||||
InactivityExpirationEnabled bool
|
InactivityExpirationEnabled bool `diff:"-"`
|
||||||
// LastLogin the time when peer performed last login operation
|
// LastLogin the time when peer performed last login operation
|
||||||
LastLogin time.Time
|
LastLogin time.Time `diff:"-"`
|
||||||
// CreatedAt records the time the peer was created
|
// CreatedAt records the time the peer was created
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time `diff:"-"`
|
||||||
// Indicate ephemeral peer attribute
|
// Indicate ephemeral peer attribute
|
||||||
Ephemeral bool
|
Ephemeral bool `diff:"-"`
|
||||||
// Geo location based on connection IP
|
// Geo location based on connection IP
|
||||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeerStatus struct { //nolint:revive
|
type PeerStatus struct { //nolint:revive
|
||||||
@ -189,7 +189,6 @@ func (p *Peer) Copy() *Peer {
|
|||||||
CreatedAt: p.CreatedAt,
|
CreatedAt: p.CreatedAt,
|
||||||
Ephemeral: p.Ephemeral,
|
Ephemeral: p.Ephemeral,
|
||||||
Location: p.Location,
|
Location: p.Location,
|
||||||
|
|
||||||
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: p.InactivityExpirationEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1253,3 +1253,322 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
|||||||
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
|
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
|
||||||
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create a user with auto groups
|
||||||
|
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{
|
||||||
|
{
|
||||||
|
Id: "regularUser1",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
AutoGroups: []string{"groupA"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "regularUser2",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
AutoGroups: []string{"groupB"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "regularUser3",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
AutoGroups: []string{"groupC"},
|
||||||
|
},
|
||||||
|
}, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var peer4 *nbpeer.Peer
|
||||||
|
var peer5 *nbpeer.Peer
|
||||||
|
var peer6 *nbpeer.Peer
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
|
||||||
|
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding peer to unlinked group should not update account peers and not send peer update
|
||||||
|
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedPeerKey := key.PublicKey().String()
|
||||||
|
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
|
||||||
|
Key: expectedPeerKey,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting peer with unlinked group should not update account peers and not send peer update
|
||||||
|
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating peer label should update account peers and send peer update
|
||||||
|
t.Run("updating peer label", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
peer1.Name = "peer-1"
|
||||||
|
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding peer to group linked with policy should update account peers and send peer update
|
||||||
|
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
ID: "policy",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedPeerKey := key.PublicKey().String()
|
||||||
|
peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
|
||||||
|
Key: expectedPeerKey,
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting peer with linked group to policy should update account peers and send peer update
|
||||||
|
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding peer to group linked with route should update account peers and send peer update
|
||||||
|
t.Run("adding peer to group linked with route", func(t *testing.T) {
|
||||||
|
route := nbroute.Route{
|
||||||
|
ID: "testingRoute1",
|
||||||
|
Network: netip.MustParsePrefix("100.65.250.202/32"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: nbroute.IPv4Network,
|
||||||
|
PeerGroups: []string{"groupB"},
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"groupB"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||||
|
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||||
|
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedPeerKey := key.PublicKey().String()
|
||||||
|
peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
|
||||||
|
Key: expectedPeerKey,
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting peer with linked group to route should update account peers and send peer update
|
||||||
|
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding peer to group linked with name server group should update account peers and send peer update
|
||||||
|
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
|
||||||
|
_, err = manager.CreateNameServerGroup(
|
||||||
|
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
[]string{"groupC"},
|
||||||
|
true, []string{}, true, userID, false,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedPeerKey := key.PublicKey().String()
|
||||||
|
peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{
|
||||||
|
Key: expectedPeerKey,
|
||||||
|
LoginExpirationEnabled: true,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting peer with linked group to name server group should update account peers and send peer update
|
||||||
|
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ruleGroups returns a list of all groups referenced in the policy's rules,
|
||||||
|
// including sources and destinations.
|
||||||
|
func (p *Policy) ruleGroups() []string {
|
||||||
|
groups := make([]string, 0)
|
||||||
|
for _, rule := range p.Rules {
|
||||||
|
groups = append(groups, rule.Sources...)
|
||||||
|
groups = append(groups, rule.Destinations...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
// FirewallRule is a rule of the firewall.
|
// FirewallRule is a rule of the firewall.
|
||||||
type FirewallRule struct {
|
type FirewallRule struct {
|
||||||
// PeerIP of the peer
|
// PeerIP of the peer
|
||||||
@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
|||||||
|
|
||||||
// savePolicy saves or updates a policy in the given account.
|
// savePolicy saves or updates a policy in the given account.
|
||||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||||
for index, rule := range policyToSave.Rules {
|
for index, rule := range policyToSave.Rules {
|
||||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||||
@ -442,18 +457,25 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
|||||||
if isUpdate {
|
if isUpdate {
|
||||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||||
if policyIdx < 0 {
|
if policyIdx < 0 {
|
||||||
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldPolicy := account.Policies[policyIdx]
|
||||||
// Update the existing policy
|
// Update the existing policy
|
||||||
account.Policies[policyIdx] = policyToSave
|
account.Policies[policyIdx] = policyToSave
|
||||||
return nil
|
|
||||||
|
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||||
|
|
||||||
|
return updateAccountPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the new policy to the account
|
// Add the new policy to the account
|
||||||
account.Policies = append(account.Policies, policyToSave)
|
account.Policies = append(account.Policies, policyToSave)
|
||||||
|
|
||||||
return nil
|
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||||
|
@ -5,7 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int {
|
|||||||
return 0 // a is equal to b
|
return 0 // a is equal to b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer3.ID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupD",
|
||||||
|
Name: "GroupD",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||||
|
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-rule-groups-no-peers",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupB"},
|
||||||
|
Destinations: []string{"groupC"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving policy with source group containing peers, but destination group without peers should
|
||||||
|
// update account's peers and send peer update
|
||||||
|
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-source-has-peers-destination-none",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupB"},
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving policy with destination group containing peers, but source group without peers should
|
||||||
|
// update account's peers and send peer update
|
||||||
|
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-destination-has-peers-source-none",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: false,
|
||||||
|
Sources: []string{"groupC"},
|
||||||
|
Destinations: []string{"groupD"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg2)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving policy with destination and source groups containing peers should update account's peers
|
||||||
|
// and send peer update
|
||||||
|
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-source-destination-peers",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupD"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Disabling policy with destination and source groups containing peers should update account's peers
|
||||||
|
// and send peer update
|
||||||
|
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-source-destination-peers",
|
||||||
|
Enabled: false,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupD"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||||
|
// or send peer update
|
||||||
|
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-source-destination-peers",
|
||||||
|
Description: "updated description",
|
||||||
|
Enabled: false,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Enabling policy with destination and source groups containing peers should update account's peers
|
||||||
|
// and send peer update
|
||||||
|
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-source-destination-peers",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupD"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving unchanged policy should trigger account peers update but not send peer update
|
||||||
|
t.Run("saving unchanged policy", func(t *testing.T) {
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy-source-destination-peers",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupD"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting policy should trigger account peers update and send peer update
|
||||||
|
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
|
policyID := "policy-source-destination-peers"
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting policy with destination group containing peers, but source group without peers should
|
||||||
|
// update account's peers and send peer update
|
||||||
|
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||||
|
policyID := "policy-destination-has-peers-source-none"
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg2)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||||
|
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||||
|
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -67,7 +67,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||||
if exists {
|
|
||||||
|
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,13 +149,9 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh
|
|||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check policy links
|
// Check if posture check is linked to any policy
|
||||||
for _, policy := range account.Policies {
|
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
||||||
for _, id := range policy.SourcePostureChecks {
|
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
||||||
if id == postureChecksID {
|
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
postureChecks := account.PostureChecks[postureChecksIdx]
|
||||||
@ -217,3 +214,25 @@ func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
||||||
|
for _, policy := range account.Policies {
|
||||||
|
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||||
|
return true, policy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
||||||
|
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
||||||
|
if !isLinked {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
||||||
|
}
|
||||||
|
@ -3,7 +3,10 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/group"
|
||||||
|
"github.com/rs/xid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
@ -118,3 +121,458 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
|
|||||||
|
|
||||||
return am.Store.GetAccount(context.Background(), account.Id)
|
return am.Store.GetAccount(context.Background(), account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
postureCheck := posture.Checks{
|
||||||
|
ID: "postureCheck",
|
||||||
|
Name: "postureCheck",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.28.0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Saving unused posture check should not update account peers and not send peer update
|
||||||
|
t.Run("saving unused posture check", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating unused posture check should not update account peers and not send peer update
|
||||||
|
t.Run("updating unused posture check", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.29.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policyA",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheck.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Linking posture check to policy should trigger update account peers and send peer update
|
||||||
|
t.Run("linking posture check to policy with peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating linked posture checks should update account peers and send peer update
|
||||||
|
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||||
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.29.0",
|
||||||
|
},
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving unchanged posture check should not trigger account peers update and not send peer update
|
||||||
|
// since there is no change in the network map
|
||||||
|
t.Run("saving unchanged posture check", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Removing posture check from policy should trigger account peers update and send peer update
|
||||||
|
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
policy.SourcePostureChecks = []string{}
|
||||||
|
|
||||||
|
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting unused posture check should not trigger account peers update and not send peer update
|
||||||
|
t.Run("deleting unused posture check", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||||
|
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||||
|
policy = Policy{
|
||||||
|
ID: "policyB",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupB"},
|
||||||
|
Destinations: []string{"groupC"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheck.ID},
|
||||||
|
}
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.29.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating linked posture check to policy where destination has peers but source does not
|
||||||
|
// should trigger account peers update and send peer update
|
||||||
|
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
|
||||||
|
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||||
|
})
|
||||||
|
policy = Policy{
|
||||||
|
ID: "policyB",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupB"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheck.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.29.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating linked posture check to policy where source has peers but destination does not,
|
||||||
|
// should not trigger account peers update or send peer update
|
||||||
|
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||||
|
policy = Policy{
|
||||||
|
ID: "policyB",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupB"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheck.ID},
|
||||||
|
}
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.29.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||||
|
// should trigger account peers update and send peer update
|
||||||
|
t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||||
|
policy = Policy{
|
||||||
|
ID: "policyB",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupB"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheck.ID},
|
||||||
|
}
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{
|
||||||
|
LinuxPath: "/usr/bin/netbird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Policies: []*Policy{
|
||||||
|
{
|
||||||
|
ID: "policyA",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{"checkA"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: map[string]*group.Group{
|
||||||
|
"groupA": {
|
||||||
|
ID: "groupA",
|
||||||
|
Peers: []string{"peer1"},
|
||||||
|
},
|
||||||
|
"groupB": {
|
||||||
|
ID: "groupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PostureChecks: []*posture.Checks{
|
||||||
|
{
|
||||||
|
ID: "checkA",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "checkB",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||||
|
assert.True(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
||||||
|
assert.False(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("posture check does not exist", func(t *testing.T) {
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
||||||
|
assert.False(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||||
|
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
||||||
|
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||||
|
assert.True(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||||
|
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
||||||
|
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||||
|
assert.True(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||||
|
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
||||||
|
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||||
|
assert.False(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||||
|
account.Groups["groupA"].Peers = []string{}
|
||||||
|
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||||
|
assert.False(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
|
|
||||||
@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldRoute := account.Routes[routeToSave.ID]
|
||||||
account.Routes[routeToSave.ID] = routeToSave
|
account.Routes[routeToSave.ID] = routeToSave
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
|
|
||||||
@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if isRouteChangeAffectPeers(account, routy) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
|||||||
}
|
}
|
||||||
return &portInfo
|
return &portInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||||
|
// if it has a routing peer, distribution, or peer groups that include peers
|
||||||
|
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||||
|
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||||
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, err := createRouterManager(t)
|
||||||
|
require.NoError(t, err, "failed to create account manager")
|
||||||
|
|
||||||
|
account, err := initTestRouteAccount(t, manager)
|
||||||
|
require.NoError(t, err, "failed to init testing account")
|
||||||
|
|
||||||
|
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
|
||||||
|
t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) {
|
||||||
|
route := route.Route{
|
||||||
|
ID: "testingRoute1",
|
||||||
|
Network: netip.MustParsePrefix("100.65.250.202/32"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
PeerGroups: []string{"groupA"},
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"groupA"},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||||
|
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||||
|
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
|
||||||
|
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
|
||||||
|
route := route.Route{
|
||||||
|
ID: "testingRoute2",
|
||||||
|
Network: netip.MustParsePrefix("192.0.2.0/32"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
PeerGroups: []string{routeGroup3},
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{routeGroup3},
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||||
|
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||||
|
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
baseRoute := route.Route{
|
||||||
|
ID: "testingRoute3",
|
||||||
|
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
Peer: peer1ID,
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{routeGroup1},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creating route should update account peers and send peer update
|
||||||
|
t.Run("creating route with a routing peer", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
newRoute, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||||
|
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
|
||||||
|
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
baseRoute = *newRoute
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating the route should update account peers and send peer update when there is peers in group
|
||||||
|
t.Run("updating route", func(t *testing.T) {
|
||||||
|
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Updating unchanged route should update account peers and not send peer update
|
||||||
|
t.Run("updating unchanged route", func(t *testing.T) {
|
||||||
|
baseRoute.Groups = []string{routeGroup1, routeGroup2}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Deleting the route should update account peers and send peer update
|
||||||
|
t.Run("deleting route", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding peer to route peer groups that do not have any peers should update account peers and send peer update
|
||||||
|
t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) {
|
||||||
|
newRoute := route.Route{
|
||||||
|
Network: netip.MustParsePrefix("192.168.12.0/16"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
PeerGroups: []string{"groupB"},
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{routeGroup1},
|
||||||
|
}
|
||||||
|
_, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||||
|
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||||
|
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupB",
|
||||||
|
Name: "GroupB",
|
||||||
|
Peers: []string{peer1ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Adding peer to route groups that do not have any peers should update account peers and send peer update
|
||||||
|
t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) {
|
||||||
|
newRoute := route.Route{
|
||||||
|
Network: netip.MustParsePrefix("192.168.13.0/16"),
|
||||||
|
NetID: "superNet",
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
PeerGroups: []string{"groupB"},
|
||||||
|
Description: "super",
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"groupC"},
|
||||||
|
}
|
||||||
|
_, err := manager.CreateRoute(
|
||||||
|
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||||
|
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||||
|
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupC",
|
||||||
|
Name: "GroupC",
|
||||||
|
Peers: []string{peer1ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -323,8 +323,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
|
|
||||||
return newKey, nil
|
return newKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
@ -352,3 +353,73 @@ func TestSetupKey_Copy(t *testing.T) {
|
|||||||
key.UpdatedAt, key.AutoGroups)
|
key.UpdatedAt, key.AutoGroups)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"group"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
var setupKey *SetupKey
|
||||||
|
|
||||||
|
// Creating setup key should not update account peers and not send peer update
|
||||||
|
t.Run("creating setup key", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saving setup key should not update account peers and not send peer update
|
||||||
|
t.Run("saving setup key", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
5
management/server/testdata/store.sql
vendored
5
management/server/testdata/store.sql
vendored
@ -26,8 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
|
|||||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||||
|
|
||||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||||
|
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||||
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
|
INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
|
||||||
INSERT INTO installations VALUES(1,'');
|
INSERT INTO installations VALUES(1,'');
|
||||||
|
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||||
|
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||||
|
@ -2,9 +2,13 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/differs"
|
||||||
|
"github.com/r3labs/diff/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@ -14,14 +18,17 @@ import (
|
|||||||
const channelBufferSize = 100
|
const channelBufferSize = 100
|
||||||
|
|
||||||
type UpdateMessage struct {
|
type UpdateMessage struct {
|
||||||
Update *proto.SyncResponse
|
Update *proto.SyncResponse
|
||||||
|
NetworkMap *NetworkMap
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeersUpdateManager struct {
|
type PeersUpdateManager struct {
|
||||||
// peerChannels is an update channel indexed by Peer.ID
|
// peerChannels is an update channel indexed by Peer.ID
|
||||||
peerChannels map[string]chan *UpdateMessage
|
peerChannels map[string]chan *UpdateMessage
|
||||||
|
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||||
|
peerUpdateMessage map[string]*UpdateMessage
|
||||||
// channelsMux keeps the mutex to access peerChannels
|
// channelsMux keeps the mutex to access peerChannels
|
||||||
channelsMux *sync.Mutex
|
channelsMux *sync.RWMutex
|
||||||
// metrics provides method to collect application metrics
|
// metrics provides method to collect application metrics
|
||||||
metrics telemetry.AppMetrics
|
metrics telemetry.AppMetrics
|
||||||
}
|
}
|
||||||
@ -29,9 +36,10 @@ type PeersUpdateManager struct {
|
|||||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||||
return &PeersUpdateManager{
|
return &PeersUpdateManager{
|
||||||
peerChannels: make(map[string]chan *UpdateMessage),
|
peerChannels: make(map[string]chan *UpdateMessage),
|
||||||
channelsMux: &sync.Mutex{},
|
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||||
metrics: metrics,
|
channelsMux: &sync.RWMutex{},
|
||||||
|
metrics: metrics,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
var found, dropped bool
|
var found, dropped bool
|
||||||
|
|
||||||
|
// skip sending sync update to the peer if there is no change in update message,
|
||||||
|
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||||
|
if update.NetworkMap != nil {
|
||||||
|
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||||
|
if !updated {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
p.channelsMux.Lock()
|
p.channelsMux.Lock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
p.channelsMux.Unlock()
|
p.channelsMux.Unlock()
|
||||||
if p.metrics != nil {
|
if p.metrics != nil {
|
||||||
@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if update.NetworkMap != nil {
|
||||||
|
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||||
|
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
|
||||||
|
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||||
|
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.peerUpdateMessage[peerID] = update
|
||||||
|
}
|
||||||
|
|
||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
if channel, ok := p.peerChannels[peerID]; ok {
|
||||||
found = true
|
found = true
|
||||||
select {
|
select {
|
||||||
@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
|||||||
closed = true
|
closed = true
|
||||||
delete(p.peerChannels, peerID)
|
delete(p.peerChannels, peerID)
|
||||||
close(channel)
|
close(channel)
|
||||||
|
delete(p.peerUpdateMessage, peerID)
|
||||||
}
|
}
|
||||||
// mbragin: todo shouldn't it be more? or configurable?
|
// mbragin: todo shouldn't it be more? or configurable?
|
||||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||||
@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
|||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
if channel, ok := p.peerChannels[peerID]; ok {
|
||||||
delete(p.peerChannels, peerID)
|
delete(p.peerChannels, peerID)
|
||||||
close(channel)
|
close(channel)
|
||||||
|
delete(p.peerUpdateMessage, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||||
@ -170,3 +200,72 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
|||||||
|
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||||
|
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||||
|
p.channelsMux.RLock()
|
||||||
|
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||||
|
p.channelsMux.RUnlock()
|
||||||
|
|
||||||
|
if lastSentUpdate != nil {
|
||||||
|
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !updated {
|
||||||
|
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||||
|
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
|
||||||
|
isNew, err = true, nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
differ, err := diff.NewDiffer(
|
||||||
|
diff.CustomValueDiffers(&differs.NetIPAddr{}),
|
||||||
|
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to create differ: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
|
||||||
|
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
|
||||||
|
|
||||||
|
changelog, err := differ.Diff(lastSentFiles, currFiles)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||||
|
}
|
||||||
|
if len(changelog) > 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||||
|
}
|
||||||
|
return len(changelog) > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getChecksFiles returns a list of files from the given checks.
|
||||||
|
func getChecksFiles(checks []*proto.Checks) []string {
|
||||||
|
files := make([]string, 0, len(checks))
|
||||||
|
for _, check := range checks {
|
||||||
|
files = append(files, check.GetFiles()...)
|
||||||
|
}
|
||||||
|
return files
|
||||||
|
}
|
||||||
|
@ -2,10 +2,19 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// var peersUpdater *PeersUpdateManager
|
// var peersUpdater *PeersUpdateManager
|
||||||
@ -77,3 +86,470 @@ func TestCloseChannel(t *testing.T) {
|
|||||||
t.Error("Error closing the channel")
|
t.Error("Error closing the channel")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerID string
|
||||||
|
existingUpdate *UpdateMessage
|
||||||
|
newUpdate *UpdateMessage
|
||||||
|
expectedResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "update message with turn credentials update",
|
||||||
|
peerID: "peer",
|
||||||
|
newUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update message for peer without existing update",
|
||||||
|
peerID: "peer1",
|
||||||
|
newUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||||
|
},
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update message with no changes in update",
|
||||||
|
peerID: "peer2",
|
||||||
|
existingUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||||
|
},
|
||||||
|
newUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||||
|
},
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update message with changes in checks",
|
||||||
|
peerID: "peer3",
|
||||||
|
existingUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||||
|
},
|
||||||
|
newUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||||
|
Checks: []*proto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{"/usr/bin/netbird"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||||
|
},
|
||||||
|
expectedResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update message with lower serial number",
|
||||||
|
peerID: "peer4",
|
||||||
|
existingUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||||
|
},
|
||||||
|
newUpdate: &UpdateMessage{
|
||||||
|
Update: &proto.SyncResponse{
|
||||||
|
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||||
|
},
|
||||||
|
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||||
|
},
|
||||||
|
expectedResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := NewPeersUpdateManager(nil)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if tt.existingUpdate != nil {
|
||||||
|
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||||
|
}
|
||||||
|
|
||||||
|
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||||
|
assert.Equal(t, tt.expectedResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsNewPeerUpdateMessage(t *testing.T) {
|
||||||
|
t.Run("Unchanged value", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating routes network", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating routes groups", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating network map peers", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newPeer := &nbpeer.Peer{
|
||||||
|
IP: net.ParseIP("192.168.1.4"),
|
||||||
|
SSHEnabled: true,
|
||||||
|
Key: "peer4-key",
|
||||||
|
DNSLabel: "peer4",
|
||||||
|
SSHKey: "peer4-ssh-key",
|
||||||
|
}
|
||||||
|
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating process check", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, message)
|
||||||
|
|
||||||
|
newUpdateMessage3 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage3.Update.Checks = []*proto.Checks{}
|
||||||
|
newUpdateMessage3.Update.NetworkMap.Serial++
|
||||||
|
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
|
||||||
|
newUpdateMessage4 := createMockUpdateMessage(t)
|
||||||
|
check := &posture.Checks{
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{
|
||||||
|
LinuxPath: "/usr/local/netbird",
|
||||||
|
MacPath: "/usr/bin/netbird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||||
|
newUpdateMessage4.Update.NetworkMap.Serial++
|
||||||
|
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
|
||||||
|
newUpdateMessage5 := createMockUpdateMessage(t)
|
||||||
|
check = &posture.Checks{
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{
|
||||||
|
LinuxPath: "/usr/bin/netbird",
|
||||||
|
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||||
|
MacPath: "/usr/local/netbird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
|
||||||
|
newUpdateMessage5.Update.NetworkMap.Serial++
|
||||||
|
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating DNS configuration", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newDomain := "newexample.com"
|
||||||
|
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
|
||||||
|
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
|
||||||
|
newDomain,
|
||||||
|
)
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating peer IP", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating firewall rule", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Add new firewall rule", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newRule := &FirewallRule{
|
||||||
|
PeerIP: "192.168.1.3",
|
||||||
|
Direction: firewallRuleDirectionOUT,
|
||||||
|
Action: string(PolicyTrafficActionDrop),
|
||||||
|
Protocol: string(PolicyRuleProtocolUDP),
|
||||||
|
Port: "53",
|
||||||
|
}
|
||||||
|
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Removing nameserver", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating name server IP", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updating custom DNS zone", func(t *testing.T) {
|
||||||
|
newUpdateMessage1 := createMockUpdateMessage(t)
|
||||||
|
newUpdateMessage2 := createMockUpdateMessage(t)
|
||||||
|
|
||||||
|
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
|
||||||
|
newUpdateMessage2.Update.NetworkMap.Serial++
|
||||||
|
|
||||||
|
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, message)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Signal: &Host{
|
||||||
|
Proto: "https",
|
||||||
|
URI: "signal.uri",
|
||||||
|
Username: "",
|
||||||
|
Password: "",
|
||||||
|
},
|
||||||
|
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||||
|
TURNConfig: &TURNConfig{
|
||||||
|
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
SSHEnabled: true,
|
||||||
|
Key: "peer-key",
|
||||||
|
DNSLabel: "peer1",
|
||||||
|
SSHKey: "peer1-ssh-key",
|
||||||
|
}
|
||||||
|
|
||||||
|
secretManager := NewTimeBasedAuthSecretsManager(
|
||||||
|
NewPeersUpdateManager(nil),
|
||||||
|
&TURNConfig{
|
||||||
|
TimeBasedCredentials: false,
|
||||||
|
CredentialsTTL: util.Duration{
|
||||||
|
Duration: defaultDuration,
|
||||||
|
},
|
||||||
|
Secret: "secret",
|
||||||
|
Turns: []*Host{TurnTestHost},
|
||||||
|
},
|
||||||
|
&Relay{
|
||||||
|
Addresses: []string{"localhost:0"},
|
||||||
|
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||||
|
Secret: "secret",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
networkMap := &NetworkMap{
|
||||||
|
Network: &Network{Net: *ipNet, Serial: 1000},
|
||||||
|
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||||
|
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||||
|
Routes: []*nbroute.Route{
|
||||||
|
{
|
||||||
|
ID: "route1",
|
||||||
|
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
KeepRoute: true,
|
||||||
|
NetID: "route1",
|
||||||
|
Peer: "peer1",
|
||||||
|
NetworkType: 1,
|
||||||
|
Masquerade: true,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"test1", "test2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "route2",
|
||||||
|
Domains: domainList,
|
||||||
|
KeepRoute: true,
|
||||||
|
NetID: "route2",
|
||||||
|
Peer: "peer1",
|
||||||
|
NetworkType: 1,
|
||||||
|
Masquerade: true,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
Groups: []string{"test1", "test2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DNSConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
Primary: true,
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
Enabled: true,
|
||||||
|
SearchDomainsEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "ns1",
|
||||||
|
NameServers: []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
Primary: true,
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
Enabled: true,
|
||||||
|
SearchDomainsEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||||
|
},
|
||||||
|
FirewallRules: []*FirewallRule{
|
||||||
|
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsName := "example.com"
|
||||||
|
checks := []*posture.Checks{
|
||||||
|
{
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{
|
||||||
|
LinuxPath: "/usr/bin/netbird",
|
||||||
|
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
|
||||||
|
MacPath: "/usr/bin/netbird",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsCache := &DNSConfigCache{}
|
||||||
|
|
||||||
|
turnToken, err := secretManager.GenerateTurnToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayToken, err := secretManager.GenerateRelayToken()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UpdateMessage{
|
||||||
|
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
|
||||||
|
NetworkMap: networkMap,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -473,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
||||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -485,15 +486,22 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||||
am.updateAccountPeers(ctx, account)
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error {
|
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
|
||||||
peers, err := account.FindUserPeers(targetUserID)
|
peers, err := account.FindUserPeers(targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.Internal, "failed to find user peers")
|
return false, status.Errorf(status.Internal, "failed to find user peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
hadPeers := len(peers) > 0
|
||||||
|
if !hadPeers {
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
peerIDs := make([]string, 0, len(peers))
|
peerIDs := make([]string, 0, len(peers))
|
||||||
@ -501,7 +509,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
|||||||
peerIDs = append(peerIDs, peer.ID)
|
peerIDs = append(peerIDs, peer.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.deletePeers(ctx, account, peerIDs, initiatorUserID)
|
return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||||
@ -745,6 +753,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
updatedUsers := make([]*UserInfo, 0, len(updates))
|
updatedUsers := make([]*UserInfo, 0, len(updates))
|
||||||
var (
|
var (
|
||||||
expiredPeers []*nbpeer.Peer
|
expiredPeers []*nbpeer.Peer
|
||||||
|
userIDs []string
|
||||||
eventsToStore []func()
|
eventsToStore []func()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -753,6 +762,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userIDs = append(userIDs, update.Id)
|
||||||
|
|
||||||
oldUser := account.Users[update.Id]
|
oldUser := account.Users[update.Id]
|
||||||
if oldUser == nil {
|
if oldUser == nil {
|
||||||
if !addIfNotExists {
|
if !addIfNotExists {
|
||||||
@ -816,7 +827,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.GroupsPropagationEnabled {
|
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1167,7 +1178,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
|
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
|
||||||
}
|
}
|
||||||
|
|
||||||
var allErrors error
|
var (
|
||||||
|
allErrors error
|
||||||
|
updateAccountPeers bool
|
||||||
|
)
|
||||||
|
|
||||||
deletedUsersMeta := make(map[string]map[string]any)
|
deletedUsersMeta := make(map[string]map[string]any)
|
||||||
for _, targetUserID := range targetUserIDs {
|
for _, targetUserID := range targetUserIDs {
|
||||||
@ -1193,12 +1207,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
|
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hadPeers {
|
||||||
|
updateAccountPeers = true
|
||||||
|
}
|
||||||
|
|
||||||
delete(account.Users, targetUserID)
|
delete(account.Users, targetUserID)
|
||||||
deletedUsersMeta[targetUserID] = meta
|
deletedUsersMeta[targetUserID] = meta
|
||||||
}
|
}
|
||||||
@ -1208,7 +1226,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
return fmt.Errorf("failed to delete users: %w", err)
|
return fmt.Errorf("failed to delete users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
for targetUserID, meta := range deletedUsersMeta {
|
for targetUserID, meta := range deletedUsersMeta {
|
||||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||||
@ -1217,11 +1237,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
return allErrors
|
return allErrors
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
|
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
|
||||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
@ -1232,16 +1252,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
|||||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := account.FindUser(targetUserID)
|
u, err := account.FindUser(targetUserID)
|
||||||
@ -1254,7 +1274,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
|||||||
tuCreatedAt = u.CreatedAt
|
tuCreatedAt = u.CreatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
|
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||||
@ -1333,3 +1353,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa
|
|||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account.
|
||||||
|
func areUsersLinkedToPeers(account *Account, userIDs []string) bool {
|
||||||
|
for _, peer := range account.Peers {
|
||||||
|
if slices.Contains(userIDs, peer.UserID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -10,9 +10,12 @@ import (
|
|||||||
"github.com/eko/gocache/v3/cache"
|
"github.com/eko/gocache/v3/cache"
|
||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUserAccountPeersUpdate(t *testing.T) {
|
||||||
|
// account groups propagation is enabled
|
||||||
|
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||||
|
|
||||||
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||||
|
ID: "groupA",
|
||||||
|
Name: "GroupA",
|
||||||
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
policy := Policy{
|
||||||
|
ID: "policy",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Creating a new regular user should not update account peers and not send peer update
|
||||||
|
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||||
|
Id: "regularUser1",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleUser,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
}, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// updating user with no linked peers should not update account peers and not send peer update
|
||||||
|
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||||
|
Id: "regularUser1",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleUser,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
}, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// deleting user with no linked peers should not update account peers and not send peer update
|
||||||
|
t.Run("deleting user with no linked peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// create a user and add new peer with the user
|
||||||
|
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||||
|
Id: "regularUser2",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
}, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedPeerKey := key.PublicKey().String()
|
||||||
|
peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{
|
||||||
|
Key: expectedPeerKey,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// updating user with linked peers should update account peers and send peer update
|
||||||
|
t.Run("updating user with linked peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
|
||||||
|
Id: "regularUser2",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: UserRoleAdmin,
|
||||||
|
Issued: UserIssuedAPI,
|
||||||
|
}, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// deleting user with linked peers should update account peers and send peer update
|
||||||
|
t.Run("deleting user with linked peers", func(t *testing.T) {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
peerShouldReceiveUpdate(t, peer4UpdMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user