mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 10:50:45 +01:00
refactor posture checks save and deletion
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
87c8430e99
commit
3b4bcdf5a4
@ -134,7 +134,7 @@ type AccountManager interface {
|
|||||||
HasConnectedChannel(peerID string) bool
|
HasConnectedChannel(peerID string) bool
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
|
||||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
|
@ -1027,6 +1027,14 @@ func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _
|
|||||||
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) SavePostureChecks(_ context.Context, _ LockingStrength, _ *posture.Checks) error {
|
||||||
|
return status.Errorf(status.Internal, "SavePostureChecks is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) DeletePostureChecks(_ context.Context, _ LockingStrength, _ string) error {
|
||||||
|
return status.Errorf(status.Internal, "DeletePostureChecks is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) {
|
func (s *FileStore) GetAccountRoutes(_ context.Context, _ string) ([]*route.Route, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
|
||||||
}
|
}
|
||||||
|
@ -163,13 +163,15 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isUpdate := postureChecksID != ""
|
||||||
|
|
||||||
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
|
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, isUpdate); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ type MockAccountManager struct {
|
|||||||
HasConnectedChannelFunc func(peerID string) bool
|
HasConnectedChannelFunc func(peerID string) bool
|
||||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
|
||||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManagerFunc func() idp.Manager
|
GetIdpManagerFunc func() idp.Manager
|
||||||
@ -730,9 +730,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
|
||||||
if am.SavePostureChecksFunc != nil {
|
if am.SavePostureChecksFunc != nil {
|
||||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, isUpdate)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ type Checks struct {
|
|||||||
ID string `gorm:"primaryKey"`
|
ID string `gorm:"primaryKey"`
|
||||||
|
|
||||||
// Name of the posture checks
|
// Name of the posture checks
|
||||||
Name string
|
Name string `gorm:"unique"`
|
||||||
|
|
||||||
// Description of the posture checks visible in the UI
|
// Description of the posture checks visible in the UI
|
||||||
Description string
|
Description string
|
||||||
|
@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
@ -27,85 +28,105 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
|||||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
// SavePostureChecks saves a posture check.
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
|
||||||
defer unlock()
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
return status.Errorf(status.PermissionDenied, "only admin users are allowed to update posture checks")
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := postureChecks.Validate(); err != nil {
|
if err := postureChecks.Validate(); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||||
}
|
}
|
||||||
|
postureChecks.AccountID = accountID
|
||||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
|
||||||
|
|
||||||
// we do not allow create new posture checks with non uniq name
|
|
||||||
if !exists && !uniqName {
|
|
||||||
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
|
|
||||||
}
|
|
||||||
|
|
||||||
action := activity.PostureCheckCreated
|
action := activity.PostureCheckCreated
|
||||||
if exists {
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if isUpdate {
|
||||||
action = activity.PostureCheckUpdated
|
action = activity.PostureCheckUpdated
|
||||||
account.Network.IncSerial()
|
|
||||||
|
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to get posture checks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks); err != nil {
|
||||||
|
return fmt.Errorf("failed to save posture checks: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||||
if exists {
|
|
||||||
|
if isUpdate {
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture check by ID.
|
||||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only admin users are allowed to delete posture checks")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete posture checks: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListPostureChecks returns a list of posture checks.
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -119,48 +140,20 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
|||||||
return am.Store.GetAccountPostureChecks(ctx, accountID)
|
return am.Store.GetAccountPostureChecks(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||||
uniqName = true
|
func (am *DefaultAccountManager) isPostureCheckLinkedToPolicy(ctx context.Context, postureChecksID, accountID string) error {
|
||||||
for i, p := range account.PostureChecks {
|
policies, err := am.Store.GetAccountPolicies(ctx, accountID)
|
||||||
if !exists && p.ID == postureChecks.ID {
|
if err != nil {
|
||||||
account.PostureChecks[i] = postureChecks
|
return err
|
||||||
exists = true
|
|
||||||
}
|
|
||||||
if p.Name == postureChecks.Name {
|
|
||||||
uniqName = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
account.PostureChecks = append(account.PostureChecks, postureChecks)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
|
|
||||||
postureChecksIdx := -1
|
|
||||||
for i, postureChecks := range account.PostureChecks {
|
|
||||||
if postureChecks.ID == postureChecksID {
|
|
||||||
postureChecksIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if postureChecksIdx < 0 {
|
|
||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check policy links
|
for _, policy := range policies {
|
||||||
for _, policy := range account.Policies {
|
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||||
for _, id := range policy.SourcePostureChecks {
|
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||||
if id == postureChecksID {
|
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
return nil
|
||||||
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
|
|
||||||
|
|
||||||
return postureChecks, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||||
|
@ -859,6 +859,7 @@ func getGormConfig() *gorm.Config {
|
|||||||
Logger: logger.Default.LogMode(logger.Silent),
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
CreateBatchSize: 400,
|
CreateBatchSize: 400,
|
||||||
PrepareStmt: true,
|
PrepareStmt: true,
|
||||||
|
TranslateError: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1129,6 +1130,27 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
|
|||||||
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SavePostureChecks saves a posture checks to the database.
|
||||||
|
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||||
|
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&postureCheck)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
|
||||||
|
return status.Errorf(status.InvalidArgument, "name should be unique")
|
||||||
|
}
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture checks from the database.
|
||||||
|
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error {
|
||||||
|
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&posture.Checks{}, idQueryCondition, postureChecksID).Error
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountRoutes retrieves network routes for an account.
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
|
||||||
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
|
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
|
||||||
|
@ -77,6 +77,8 @@ type Store interface {
|
|||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||||
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
Loading…
Reference in New Issue
Block a user