From 3b4bcdf5a419f896625082e687d773c537ae6729 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 26 Sep 2024 16:28:49 +0300 Subject: [PATCH] refactor posture checks save and deletion Signed-off-by: bcmmbaga --- management/server/account.go | 2 +- management/server/file_store.go | 8 + .../server/http/posture_checks_handler.go | 4 +- management/server/mock_server/account_mock.go | 6 +- management/server/posture/checks.go | 2 +- management/server/posture_checks.go | 143 +++++++++--------- management/server/sql_store.go | 22 +++ management/server/store.go | 2 + 8 files changed, 108 insertions(+), 81 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 11c3a17e0..10f965111 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -134,7 +134,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/file_store.go b/management/server/file_store.go index 4c42bde0b..ad8d31688 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1027,6 +1027,14 @@ func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") } +func (s *FileStore) SavePostureChecks(_ context.Context, _ LockingStrength, _ *posture.Checks) error { + return status.Errorf(status.Internal, "SavePostureChecks is not implemented") +} + +func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _ string) error { + return status.Errorf(status.Internal, "DeletePostureChecks is not implemented") +} + func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) { return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") } diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bc..a75898a01 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -163,13 +163,15 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. } } + isUpdate := postureChecksID != "" + postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID) if err != nil { util.WriteError(r.Context(), err, w) return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, isUpdate); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a43e5a18c..0c953c789 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -730,9 +730,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error { if am.SavePostureChecksFunc != nil { - return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) + return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, isUpdate) } return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf..16d87b7bf 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -41,7 +41,7 @@ type Checks struct { ID string `gorm:"primaryKey"` // Name of the posture checks - Name string + Name string `gorm:"unique"` // Description of the posture checks visible in the UI Description string diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 7a03effb1..9249b3304 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "slices" "github.com/netbirdio/netbird/management/server/activity" @@ -27,85 +28,105 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) } -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err - } - - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if !user.HasAdminPower() || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "only admin users are allowed to update posture checks") } if err := postureChecks.Validate(); err != nil { return status.Errorf(status.InvalidArgument, err.Error()) //nolint } - - exists, uniqName := am.savePostureChecks(account, postureChecks) - - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") - } + postureChecks.AccountID = accountID action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() - } - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if isUpdate { + action = activity.PostureCheckUpdated + + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil { + return fmt.Errorf("failed to get posture checks: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + } + + if err = transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks); err != nil { + return fmt.Errorf("failed to save posture checks: %w", err) + } + return nil + }) + if err != nil { return err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if exists { + + if isUpdate { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) } return nil } +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) + if !user.HasAdminPower() || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "only admin users are allowed to delete posture checks") + } + + if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil { + return err + } + + postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) if err != nil { return err } - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) + if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID); err != nil { + return fmt.Errorf("failed to delete posture checks: %w", err) + } + return nil + }) if err != nil { return err } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) + return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { @@ -119,48 +140,20 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return am.Store.GetAccountPostureChecks(ctx, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true - } - if p.Name == postureChecks.Name { - uniqName = false - } - } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) - } - return -} - -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. +func (am *DefaultAccountManager) isPostureCheckLinkedToPolicy(ctx context.Context, postureChecksID, accountID string) error { + policies, err := am.Store.GetAccountPolicies(ctx, accountID) + if err != nil { + return err } - // check policy links - for _, policy := range account.Policies { - for _, id := range policy.SourcePostureChecks { - if id == postureChecksID { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) - } + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) } } - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) - - return postureChecks, nil + return nil } // getPeerPostureChecks returns the posture checks applied for a given peer. diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6cf52836d..58c0adb24 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -859,6 +859,7 @@ func getGormConfig() *gorm.Config { Logger: logger.Default.LogMode(logger.Silent), CreateBatchSize: 400, PrepareStmt: true, + TranslateError: true, } } @@ -1129,6 +1130,27 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) } +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&postureCheck) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrDuplicatedKey) { + return status.Errorf(status.InvalidArgument, "name should be unique") + } + return result.Error + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error { + return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, idQueryCondition, postureChecksID).Error +} + // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) { return getRecords[*route.Route](s.db.WithContext(ctx), accountID) diff --git a/management/server/store.go b/management/server/store.go index 4ac58f6ee..b5088b751 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -77,6 +77,8 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error