mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-05 01:17:10 +02:00
[management] Refactor posture check to use store methods (#2874)
This commit is contained in:
parent
9810386937
commit
ca12bc6953
@ -139,7 +139,7 @@ type AccountManager interface {
|
|||||||
HasConnectedChannel(peerID string) bool
|
HasConnectedChannel(peerID string) bool
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
|
@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
|||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||||
return true
|
return true
|
||||||
@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||||
|
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range groups {
|
||||||
|
if group.HasPeers() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
|
||||||
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
},
|
},
|
||||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
postureChecks.ID = "postureCheck"
|
postureChecks.ID = "postureCheck"
|
||||||
testPostureChecks[postureChecks.ID] = postureChecks
|
testPostureChecks[postureChecks.ID] = postureChecks
|
||||||
|
|
||||||
if err := postureChecks.Validate(); err != nil {
|
if err := postureChecks.Validate(); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return postureChecks, nil
|
||||||
},
|
},
|
||||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||||
_, ok := testPostureChecks[postureChecksID]
|
_, ok := testPostureChecks[postureChecksID]
|
||||||
|
@ -96,7 +96,7 @@ type MockAccountManager struct {
|
|||||||
HasConnectedChannelFunc func(peerID string) bool
|
HasConnectedChannelFunc func(peerID string) bool
|
||||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManagerFunc func() idp.Manager
|
GetIdpManagerFunc func() idp.Manager
|
||||||
@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
if am.SavePostureChecksFunc != nil {
|
if am.SavePostureChecksFunc != nil {
|
||||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
||||||
|
@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
if am.anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
if am.anyGroupHasPeers(account, nsGroup.Groups) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
@ -279,9 +279,9 @@ func validateDomain(domain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||||
}
|
}
|
||||||
|
@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
return newPeer, networkMap, postureChecks, nil
|
return newPeer, networkMap, postureChecks, nil
|
||||||
@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
||||||
}
|
}
|
||||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
|
||||||
|
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
|
||||||
|
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, p)
|
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||||
|
@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||||
|
|
||||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
if am.anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
|||||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||||
|
|
||||||
return updateAccountPeers, nil
|
return updateAccountPeers, nil
|
||||||
}
|
}
|
||||||
@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
|||||||
// Add the new policy to the account
|
// Add the new policy to the account
|
||||||
account.Policies = append(account.Policies, policyToSave)
|
account.Policies = append(account.Policies, policyToSave)
|
||||||
|
|
||||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||||
|
@ -7,8 +7,6 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/rs/xid"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
||||||
if postureChecksID == "" {
|
|
||||||
postureChecksID = xid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
postureChecks := Checks{
|
postureChecks := Checks{
|
||||||
ID: postureChecksID,
|
ID: postureChecksID,
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -2,16 +2,14 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
"github.com/rs/xid"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
const (
|
|
||||||
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
@ -20,219 +18,284 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
|
||||||
|
|
||||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
if !user.HasAdminPower() {
|
||||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := postureChecks.Validate(); err != nil {
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
}
|
||||||
|
|
||||||
|
// SavePostureChecks saves a posture check.
|
||||||
|
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
if user.AccountID != accountID {
|
||||||
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
// we do not allow create new posture checks with non uniq name
|
|
||||||
if !exists && !uniqName {
|
|
||||||
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
action := activity.PostureCheckCreated
|
if !user.HasAdminPower() {
|
||||||
if exists {
|
return nil, status.NewAdminPermissionError()
|
||||||
action = activity.PostureCheckUpdated
|
|
||||||
account.Network.IncSerial()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
var updateAccountPeers bool
|
||||||
return err
|
var isUpdate = postureChecks.ID != ""
|
||||||
|
var action = activity.PostureCheckCreated
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUpdate {
|
||||||
|
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
action = activity.PostureCheckUpdated
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks.AccountID = accountID
|
||||||
|
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||||
|
|
||||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture check by ID.
|
||||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if user.AccountID != accountID {
|
||||||
if err != nil {
|
return status.NewUserNotPartOfAccountError()
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
if !user.HasAdminPower() {
|
||||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
var postureChecks *posture.Checks
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListPostureChecks returns a list of posture checks.
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
|
||||||
uniqName = true
|
|
||||||
for i, p := range account.PostureChecks {
|
|
||||||
if !exists && p.ID == postureChecks.ID {
|
|
||||||
account.PostureChecks[i] = postureChecks
|
|
||||||
exists = true
|
|
||||||
}
|
|
||||||
if p.Name == postureChecks.Name {
|
|
||||||
uniqName = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
account.PostureChecks = append(account.PostureChecks, postureChecks)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
|
|
||||||
postureChecksIdx := -1
|
|
||||||
for i, postureChecks := range account.PostureChecks {
|
|
||||||
if postureChecks.ID == postureChecksID {
|
|
||||||
postureChecksIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if postureChecksIdx < 0 {
|
|
||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if posture check is linked to any policy
|
|
||||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
|
||||||
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
|
|
||||||
|
|
||||||
return postureChecks, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
|
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
|
||||||
peerPostureChecks := make(map[string]posture.Checks)
|
peerPostureChecks := make(map[string]*posture.Checks)
|
||||||
|
|
||||||
if len(account.PostureChecks) == 0 {
|
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(postureChecks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policy := range policies {
|
||||||
|
if !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return maps.Values(peerPostureChecks), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||||
|
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
|
||||||
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policy := range policies {
|
||||||
|
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||||
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePostureChecks validates the posture checks.
|
||||||
|
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
|
||||||
|
if err := postureChecks.Validate(); err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the posture check already has an ID, verify its existence in the store.
|
||||||
|
if postureChecks.ID != "" {
|
||||||
|
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, policy := range account.Policies {
|
// For new posture checks, ensure no duplicates by name.
|
||||||
if !policy.Enabled {
|
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
continue
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
|
for _, check := range checks {
|
||||||
addPolicyPostureChecks(account, policy, peerPostureChecks)
|
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
|
postureChecks.ID = xid.New().String()
|
||||||
for _, check := range peerPostureChecks {
|
|
||||||
checkCopy := check
|
return nil
|
||||||
postureChecksList = append(postureChecksList, &checkCopy)
|
}
|
||||||
|
|
||||||
|
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||||
|
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||||
|
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return postureChecksList
|
if !isInGroup {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||||
|
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||||
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
|
func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if !rule.Enabled {
|
if !rule.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sourceGroup := range rule.Sources {
|
for _, sourceGroup := range rule.Sources {
|
||||||
group, ok := account.Groups[sourceGroup]
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
|
||||||
if ok && slices.Contains(group.Peers, peerID) {
|
if err != nil {
|
||||||
return true
|
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
|
|
||||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
|
||||||
for _, postureCheck := range account.PostureChecks {
|
|
||||||
if postureCheck.ID == sourcePostureCheckID {
|
|
||||||
peerPostureChecks[sourcePostureCheckID] = *postureCheck
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
|
||||||
for _, policy := range account.Policies {
|
|
||||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
|
||||||
return true, policy
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
|
||||||
if !exists {
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
return false
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
for _, policy := range policies {
|
||||||
if !isLinked {
|
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||||
return false
|
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/group"
|
"github.com/netbirdio/netbird/management/server/group"
|
||||||
|
|
||||||
@ -16,7 +17,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
adminUserID = "adminUserID"
|
adminUserID = "adminUserID"
|
||||||
regularUserID = "regularUserID"
|
regularUserID = "regularUserID"
|
||||||
postureCheckID = "existing-id"
|
|
||||||
postureCheckName = "Existing check"
|
postureCheckName = "Existing check"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Generic posture check flow", func(t *testing.T) {
|
t.Run("Generic posture check flow", func(t *testing.T) {
|
||||||
// regular users can not create checks
|
// regular users can not create checks
|
||||||
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
_, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// regular users cannot list check
|
// regular users cannot list check
|
||||||
@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// should be possible to create posture check with uniq name
|
// should be possible to create posture check with uniq name
|
||||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||||
ID: postureCheckID,
|
|
||||||
Name: postureCheckName,
|
Name: postureCheckName,
|
||||||
Checks: posture.ChecksDefinition{
|
Checks: posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
assert.Len(t, checks, 1)
|
assert.Len(t, checks, 1)
|
||||||
|
|
||||||
// should not be possible to create posture check with non uniq name
|
// should not be possible to create posture check with non uniq name
|
||||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||||
ID: "new-id",
|
|
||||||
Name: postureCheckName,
|
Name: postureCheckName,
|
||||||
Checks: posture.ChecksDefinition{
|
Checks: posture.ChecksDefinition{
|
||||||
GeoLocationCheck: &posture.GeoLocationCheck{
|
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||||
@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// admins can update posture checks
|
// admins can update posture checks
|
||||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
ID: postureCheckID,
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
Name: postureCheckName,
|
MinVersion: "0.27.0",
|
||||||
Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
|
||||||
MinVersion: "0.27.0",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// users should not be able to delete posture checks
|
// users should not be able to delete posture checks
|
||||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
|
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// admin should be able to delete posture checks
|
// admin should be able to delete posture checks
|
||||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
|
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
postureCheck := posture.Checks{
|
postureCheckA := &posture.Checks{
|
||||||
ID: "postureCheck",
|
Name: "postureCheckA",
|
||||||
Name: "postureCheck",
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
postureCheckB := &posture.Checks{
|
||||||
|
Name: "postureCheckB",
|
||||||
AccountID: account.Id,
|
AccountID: account.Id,
|
||||||
Checks: posture.ChecksDefinition{
|
Checks: posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Linking posture check to policy should trigger update account peers and send peer update
|
// Linking posture check to policy should trigger update account peers and send peer update
|
||||||
@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
// Updating linked posture checks should update account peers and send peer update
|
// Updating linked posture checks should update account peers and send peer update
|
||||||
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||||
@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
ProcessCheck: &posture.ProcessCheck{
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
Processes: []posture.Process{
|
Processes: []posture.Process{
|
||||||
{
|
{
|
||||||
@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||||
account := &Account{
|
manager, err := createManager(t)
|
||||||
Policies: []*Policy{
|
require.NoError(t, err, "failed to create account manager")
|
||||||
{
|
|
||||||
ID: "policyA",
|
account, err := initTestPostureChecksAccount(manager)
|
||||||
Rules: []*PolicyRule{
|
require.NoError(t, err, "failed to init testing account")
|
||||||
{
|
|
||||||
Enabled: true,
|
groupA := &group.Group{
|
||||||
Sources: []string{"groupA"},
|
ID: "groupA",
|
||||||
Destinations: []string{"groupA"},
|
AccountID: account.Id,
|
||||||
},
|
Peers: []string{"peer1"},
|
||||||
},
|
|
||||||
SourcePostureChecks: []string{"checkA"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Groups: map[string]*group.Group{
|
|
||||||
"groupA": {
|
|
||||||
ID: "groupA",
|
|
||||||
Peers: []string{"peer1"},
|
|
||||||
},
|
|
||||||
"groupB": {
|
|
||||||
ID: "groupB",
|
|
||||||
Peers: []string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PostureChecks: []*posture.Checks{
|
|
||||||
{
|
|
||||||
ID: "checkA",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "checkB",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupB := &group.Group{
|
||||||
|
ID: "groupB",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Peers: []string{},
|
||||||
|
}
|
||||||
|
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||||
|
require.NoError(t, err, "failed to save groups")
|
||||||
|
|
||||||
|
postureCheckA := &posture.Checks{
|
||||||
|
Name: "checkA",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
|
||||||
|
require.NoError(t, err, "failed to save postureCheckA")
|
||||||
|
|
||||||
|
postureCheckB := &posture.Checks{
|
||||||
|
Name: "checkB",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
|
||||||
|
require.NoError(t, err, "failed to save postureCheckB")
|
||||||
|
|
||||||
|
policy := &Policy{
|
||||||
|
ID: "policyA",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "ruleA",
|
||||||
|
PolicyID: "policyA",
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheckA.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false)
|
||||||
|
require.NoError(t, err, "failed to save policy")
|
||||||
|
|
||||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check does not exist", func(t *testing.T) {
|
t.Run("posture check does not exist", func(t *testing.T) {
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
policy.Rules[0].Sources = []string{"groupB"}
|
||||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
policy.Rules[0].Destinations = []string{"groupA"}
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
|
||||||
|
require.NoError(t, err, "failed to update policy")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
policy.Rules[0].Sources = []string{"groupA"}
|
||||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
policy.Rules[0].Destinations = []string{"groupB"}
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
|
||||||
|
require.NoError(t, err, "failed to update policy")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
groupA.Peers = []string{}
|
||||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
require.NoError(t, err, "failed to save groups")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||||
account.Groups["groupA"].Peers = []string{}
|
policy.Rules[0].Sources = []string{"nonExistentGroup"}
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||||
|
err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
|
||||||
|
require.NoError(t, err, "failed to update policy")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
if am.isRouteChangeAffectPeers(account, &newRoute) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, routy) {
|
if am.isRouteChangeAffectPeers(account, routy) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
|||||||
|
|
||||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||||
// if it has a routing peer, distribution, or peer groups that include peers
|
// if it has a routing peer, distribution, or peer groups that include peers
|
||||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||||
}
|
}
|
||||||
|
@ -1305,12 +1305,57 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
|
|||||||
|
|
||||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||||
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
|
var postureChecks []*posture.Checks
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
|
||||||
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
|
var postureCheck *posture.Checks
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureCheck, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePostureChecks saves a posture checks to the database.
|
||||||
|
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to save posture checks to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture checks from the database.
|
||||||
|
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete posture checks from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewPostureChecksNotFoundError(postureChecksID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountRoutes retrieves network routes for an account.
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
postureChecksID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing posture checks",
|
||||||
|
postureChecksID: "csplshq7qv948l48f7t0",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing posture checks",
|
||||||
|
postureChecksID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty posture checks ID",
|
||||||
|
postureChecksID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, postureChecks)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, postureChecks)
|
||||||
|
require.Equal(t, tt.postureChecksID, postureChecks.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SavePostureChecks(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
postureChecks := &posture.Checks{
|
||||||
|
ID: "posture-checks-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.31.0",
|
||||||
|
},
|
||||||
|
OSVersionCheck: &posture.OSVersionCheck{
|
||||||
|
Ios: &posture.MinVersionCheck{
|
||||||
|
MinVersion: "13.0.1",
|
||||||
|
},
|
||||||
|
Linux: &posture.MinKernelVersionCheck{
|
||||||
|
MinKernelVersion: "5.3.3-dev",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||||
|
Locations: []posture.Location{
|
||||||
|
{
|
||||||
|
CountryCode: "DE",
|
||||||
|
CityName: "Berlin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Action: posture.CheckActionAllow,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, savePostureChecks, postureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeletePostureChecks(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
postureChecksID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "delete existing posture checks",
|
||||||
|
postureChecksID: "csplshq7qv948l48f7t0",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete non-existing posture checks",
|
||||||
|
postureChecksID: "non-existing-posture-checks-id",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with empty posture checks ID",
|
||||||
|
postureChecksID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, group)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -139,3 +139,8 @@ func NewGetAccountError(err error) error {
|
|||||||
func NewGroupNotFoundError(groupID string) error {
|
func NewGroupNotFoundError(groupID string) error {
|
||||||
return Errorf(NotFound, "group: %s not found", groupID)
|
return Errorf(NotFound, "group: %s not found", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
|
||||||
|
func NewPostureChecksNotFoundError(postureChecksID string) error {
|
||||||
|
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
|
||||||
|
}
|
||||||
|
@ -84,7 +84,9 @@ type Store interface {
|
|||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||||
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
|
|||||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
||||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
||||||
|
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||||
INSERT INTO installations VALUES(1,'');
|
INSERT INTO installations VALUES(1,'');
|
||||||
|
Loading…
x
Reference in New Issue
Block a user