diff --git a/management/server/account.go b/management/server/account.go index 8ebbb0fa0..114489c34 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f0..c8c2d5941 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - policy := Policy{ - ID: "policy", + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - policy := Policy{ - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }) + if err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - policy := Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { require.NoError(t, err, "failed to save group") - policy := Policy{ + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 73f3803b5..8255e4896 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,10 +6,8 @@ import ( "strconv" "github.com/gorilla/mux" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -122,14 +120,9 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - isUpdate := policyID != "" - - if policyID == "" { - policyID = xid.New().String() - } - - policy := server.Policy{ + policy := &server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, @@ -137,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID for _, rule := range req.Rules { pr := server.PolicyRule{ ID: policyID, // TODO: when policy can contain multiple rules, need refactor + PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, Sources: rule.Sources, @@ -225,7 +219,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + if err != nil { util.WriteError(r.Context(), err, w) return } @@ -236,7 +231,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - resp := toPolicyResponse(allGroups, &policy) + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 673ed33bb..46a4fbc1f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -49,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } - return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface diff --git a/management/server/policy.go b/management/server/policy.go index c7872591d..eb44a0436 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,13 +3,13 @@ package server import ( "context" _ "embed" - "slices" "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" + "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -125,6 +125,7 @@ type PolicyRule struct { func (pm *PolicyRule) Copy() *PolicyRule { rule := &PolicyRule{ ID: pm.ID, + PolicyID: pm.PolicyID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, @@ -171,6 +172,7 @@ type Policy struct { func (p *Policy) Copy() *Policy { c := &Policy{ ID: p.ID, + AccountID: p.AccountID, Name: p.Name, Description: p.Description, Enabled: p.Enabled, @@ -343,157 +345,209 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + var isUpdate = policy.ID != "" + var updateAccountPeers bool + var action = activity.PolicyAdded + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + saveFunc := transaction.CreatePolicy + if isUpdate { + action = activity.PolicyUpdated + saveFunc = transaction.SavePolicy + } + + return saveFunc(ctx, LockingStrengthUpdate, policy) + }) if err != nil { - return err + return nil, err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated - } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } + return policy, nil +} + +// DeletePolicy from the store +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + + var policy *Policy + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + if err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } -// DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - policy, err := am.deletePolicy(account, policyID) - if err != nil { - return err - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - - if am.anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return false, err } - oldPolicy := account.Policies[policyIdx] - // Update the existing policy - account.Policies[policyIdx] = policyToSave - - if !policyToSave.Enabled && !oldPolicy.Enabled { + if !policy.Enabled && !existingPolicy.Enabled { return false, nil } - updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) - return updateAccountPeers, nil + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) } - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) } -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, +// validatePolicy validates the policy and its rules. +func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { + if policy.ID != "" { + _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return err } + } else { + policy.ID = xid.New().String() + policy.AccountID = accountID } - return result + + groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for i, rule := range policy.Rules { + ruleCopy := rule.Copy() + if ruleCopy.ID == "" { + ruleCopy.ID = xid.New().String() + ruleCopy.PolicyID = policy.ID + } + + ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources) + ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations) + policy.Rules[i] = ruleCopy + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + + return nil } // getAllPeersFromGroups for given peer ID and list of groups @@ -574,27 +628,52 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { + validPostureCheckIDs := make(map[string]struct{}) + for _, check := range postureChecks { + validPostureCheckIDs[check.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := validPostureCheckIDs[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { + validGroupIDs := make(map[string]struct{}) + for _, group := range groups { + validGroupIDs[group.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := validGroupIDs[id]; exists { + validIDs = append(validIDs, id) + } + } + + return validIDs +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 81dc704c2..2cd7ac7fd 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1286,12 +1286,67 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policies from store") + } + + return policies, nil } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { + var policy *Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + First(&policy, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewPolicyNotFoundError(policyID) + } + log.WithContext(ctx).Errorf("failed to get policy from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get policy from store") + } + + return policy, nil +} + +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to create policy in store") + } + + return nil +} + +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) + return status.Errorf(status.Internal, "failed to save policy to store") + } + return nil +} + +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil } // GetAccountPostureChecks retrieves posture checks for an account. @@ -1324,7 +1379,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) if result.Error != nil { if errors.Is(result.Error, gorm.ErrDuplicatedKey) { return status.Errorf(status.InvalidArgument, "name should be unique") diff --git a/management/server/status/error.go b/management/server/status/error.go index ba9e01c4f..bef1f5143 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -139,3 +139,8 @@ func NewGroupNotFoundError(groupID string) error { func NewPostureChecksNotFoundError(postureChecksID string) error { return Errorf(NotFound, "posture checks: %s not found", postureChecksID) } + +// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy +func NewPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "policy: %s not found", policyID) +} diff --git a/management/server/store.go b/management/server/store.go index 03b5821e7..108b262b1 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -80,7 +80,10 @@ type Store interface { DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)