diff --git a/management/server/policy.go b/management/server/policy.go index 0b7fce48f..63ac36cbf 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,7 +3,7 @@ package server import ( "context" _ "embed" - "slices" + "fmt" "strconv" "strings" @@ -331,38 +331,83 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + if !user.IsAdminOrServiceUser() { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + } + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - if err = am.savePolicy(account, policy, isUpdate); err != nil { + if !user.IsAdminOrServiceUser() { + return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + } + + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + } + + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { return err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { return err } + for index, rule := range policy.Rules { + rule.Sources = getValidGroupIDs(groups, rule.Sources) + rule.Destinations = getValidGroupIDs(groups, rule.Destinations) + policy.Rules[index] = rule + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + action := activity.PolicyAdded if isUpdate { action = activity.PolicyUpdated + + if _, err = am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID); err != nil { + return err + } } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf(errNetworkSerialIncrementFmt, err) + } + + err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy) + if err != nil { + return fmt.Errorf("failed to save policy: %w", err) + } + + return nil + }) + if err != nil { + return err + } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return nil @@ -370,26 +415,42 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - policy, err := am.deletePolicy(account, policyID) + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + } + + policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) if err != nil { return err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf(errNetworkSerialIncrementFmt, err) + } + + err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + if err != nil { + return fmt.Errorf("failed to delete policy: %w", err) + } + return nil + }) + if err != nil { return err } - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } am.updateAccountPeers(ctx, account) return nil @@ -409,53 +470,6 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - - if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) - } - - // Update the existing policy - account.Policies[policyIdx] = policyToSave - return nil - } - - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return nil -} - func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { result := make([]*proto.FirewallRule, len(rules)) for i := range rules { @@ -550,28 +564,36 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { + validPostureCheckIDs := make(map[string]struct{}) + for _, check := range postureChecks { + validPostureCheckIDs[check.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := validPostureCheckIDs[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { + validGroupIDs := make(map[string]struct{}) + for _, group := range groups { + validGroupIDs[group.ID] = struct{}{} + } + + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := validGroupIDs[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs }