Run diff for client posture checks only

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-22 17:35:54 +03:00
parent b025dbeb75
commit abdba6c650
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
3 changed files with 66 additions and 19 deletions

View File

@ -310,7 +310,6 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
},
},
NetworkMap: &NetworkMap{},
Checks: []*posture.Checks{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
@ -1002,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks})
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}

View File

@ -7,7 +7,6 @@ import (
"time"
"github.com/netbirdio/netbird/management/server/differs"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/r3labs/diff/v3"
log "github.com/sirupsen/logrus"
@ -20,7 +19,6 @@ const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
NetworkMap *NetworkMap
Checks []*posture.Checks
}
type PeersUpdateManager struct {
@ -237,7 +235,10 @@ func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bo
return false, fmt.Errorf("failed to create differ: %v", err)
}
changelog, err := differ.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks)
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
changelog, err := differ.Diff(lastSentFiles, currFiles)
if err != nil {
return false, fmt.Errorf("failed to diff checks: %v", err)
}
@ -251,3 +252,12 @@ func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bo
}
return len(changelog) > 0, nil
}
// getChecksFiles returns a list of files from the given checks.
func getChecksFiles(checks []*proto.Checks) []string {
files := make([]string, 0, len(checks))
for _, check := range checks {
files = append(files, check.GetFiles()...)
}
return files
}

View File

@ -124,14 +124,12 @@ func TestHandlePeerMessageUpdate(t *testing.T) {
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
},
expectedResult: false,
},
@ -143,14 +141,12 @@ func TestHandlePeerMessageUpdate(t *testing.T) {
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
Checks: []*posture.Checks{{ID: "check1"}},
},
expectedResult: true,
},
@ -253,21 +249,58 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
assert.True(t, message)
})
t.Run("Updating posture checks", func(t *testing.T) {
t.Run("Updating process check", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newCheck := &posture.Checks{
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.False(t, message)
newUpdateMessage3 := createMockUpdateMessage(t)
newUpdateMessage3.Update.Checks = []*proto.Checks{}
newUpdateMessage3.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage3)
assert.NoError(t, err)
assert.True(t, message)
newUpdateMessage4 := createMockUpdateMessage(t)
check := &posture.Checks{
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "10.0",
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/local/netbird",
MacPath: "/usr/bin/netbird",
},
},
},
},
}
newUpdateMessage2.Checks = append(newUpdateMessage2.Checks, newCheck)
newUpdateMessage2.Update.NetworkMap.Serial++
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage4.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage4)
assert.NoError(t, err)
assert.True(t, message)
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2)
newUpdateMessage5 := createMockUpdateMessage(t)
check = &posture.Checks{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/local/netbird",
},
},
},
},
}
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage5.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage5)
assert.NoError(t, err)
assert.True(t, message)
})
@ -487,7 +520,13 @@ func createMockUpdateMessage(t *testing.T) *UpdateMessage {
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/bin/netbird",
},
},
},
},
},
@ -507,6 +546,5 @@ func createMockUpdateMessage(t *testing.T) *UpdateMessage {
return &UpdateMessage{
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
NetworkMap: networkMap,
Checks: checks,
}
}