Refactor posture checks to remove get and save account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-11 12:37:19 +03:00
parent 871500c5cc
commit 174e07fefd
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
9 changed files with 400 additions and 240 deletions

View File

@ -139,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager

View File

@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@ -96,7 +96,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
}
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface

View File

@ -7,8 +7,6 @@ import (
"regexp"
"github.com/hashicorp/go-version"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
}
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := Checks{
ID: postureChecksID,
Name: name,

View File

@ -2,16 +2,15 @@ package server
import (
"context"
"fmt"
"slices"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
const (
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@ -20,219 +19,279 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
return nil, status.NewAdminPermissionError()
}
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
exists, uniqName := am.savePostureChecks(account, postureChecks)
// we do not allow create new posture checks with non uniq name
if !exists && !uniqName {
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
action := activity.PostureCheckCreated
if exists {
action = activity.PostureCheckUpdated
account.Network.IncSerial()
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
var updateAccountPeers bool
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
return err
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
action = activity.PostureCheckUpdated
}
postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
})
if err != nil {
return nil, err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
return postureChecks, nil
}
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
return status.NewAdminPermissionError()
}
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
})
if err != nil {
return err
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
return nil
}
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
uniqName = true
for i, p := range account.PostureChecks {
if !exists && p.ID == postureChecks.ID {
account.PostureChecks[i] = postureChecks
exists = true
}
if p.Name == postureChecks.Name {
uniqName = false
}
}
if !exists {
account.PostureChecks = append(account.PostureChecks, postureChecks)
}
return
}
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
postureChecksIdx := -1
for i, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
return postureChecks, nil
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
peerPostureChecks := make(map[string]posture.Checks)
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if len(postureChecks) == 0 {
return nil
}
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
for _, policy := range policies {
if !policy.Enabled {
continue
}
if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return maps.Values(peerPostureChecks), nil
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
}
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return err
}
return nil
}
for _, policy := range account.Policies {
if !policy.Enabled {
continue
}
// For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
addPolicyPostureChecks(account, policy, peerPostureChecks)
for _, check := range checks {
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
}
}
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
checkCopy := check
postureChecksList = append(postureChecksList, &checkCopy)
postureChecks.ID = xid.New().String()
return nil
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
if err != nil {
return err
}
return postureChecksList
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
if err != nil {
return err
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup]
if ok && slices.Contains(group.Peers, peerID) {
return true
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
if err != nil {
log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil
}
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
if !exists {
return false
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
if !isLinked {
return false
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
}
}
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group"
@ -16,7 +17,6 @@ import (
const (
adminUserID = "adminUserID"
regularUserID = "regularUserID"
postureCheckID = "existing-id"
postureCheckName = "Existing check"
)
@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
_, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
assert.Error(t, err)
// regular users cannot list check
@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err)
// should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: "new-id",
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: postureCheckName,
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err)
// admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
})
}
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
assert.NoError(t, err)
// users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
assert.Error(t, err)
// admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
assert.NoError(t, err)
@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
postureCheck := posture.Checks{
ID: "postureCheck",
Name: "postureCheck",
postureCheckA := &posture.Checks{
Name: "postureCheckA",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
},
},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
require.NoError(t, err)
postureCheckB := &posture.Checks{
Name: "postureCheckB",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
SourcePostureChecks: []string{postureCheckB.ID},
}
// Linking posture check to policy should trigger update account peers and send peer update
@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture checks should update account peers and send peer update
t.Run("updating linked to posture check with peers", func(t *testing.T) {
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
assert.NoError(t, err)
select {
@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
SourcePostureChecks: []string{postureCheckB.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
SourcePostureChecks: []string{postureCheckB.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
SourcePostureChecks: []string{postureCheckB.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
postureCheckB.Checks = posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
})
}
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
account := &Account{
Policies: []*Policy{
{
ID: "policyA",
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestPostureChecksAccount(manager)
require.NoError(t, err, "failed to init testing account")
groupA := &group.Group{
ID: "groupA",
AccountID: account.Id,
Peers: []string{"peer1"},
}
groupB := &group.Group{
ID: "groupB",
AccountID: account.Id,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
postureCheckA := &posture.Checks{
Name: "checkA",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
Name: "checkB",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
require.NoError(t, err, "failed to save postureCheckB")
policy := &Policy{
ID: "policyA",
AccountID: account.Id,
Rules: []*PolicyRule{
{
ID: "ruleA",
PolicyID: "policyA",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{postureCheckA.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false)
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
policy.Rules[0].Sources = []string{"groupB"}
policy.Rules[0].Destinations = []string{"groupA"}
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
policy.Rules[0].Sources = []string{"groupA"}
policy.Rules[0].Destinations = []string{"groupB"}
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
groupA.Peers = []string{}
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
policy.Rules[0].Sources = []string{"nonExistentGroup"}
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
})
}

View File

@ -1257,12 +1257,60 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
var postureChecks []*posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
}
return postureChecks, nil
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
var postureCheck *posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
}
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
}
return postureCheck, nil
}
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
return status.Errorf(status.InvalidArgument, "name should be unique")
}
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save posture checks to store")
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete posture checks from store")
}
if result.RowsAffected == 0 {
return status.NewPostureChecksNotFoundError(postureChecksID)
}
return nil
}
// GetAccountRoutes retrieves network routes for an account.

View File

@ -140,3 +140,8 @@ func NewInvalidKeyIDError() error {
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
}
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
func NewPostureChecksNotFoundError(postureChecksID string) error {
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
}

View File

@ -83,7 +83,9 @@ type Store interface {
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error