Refactor policy save and delete

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-17 14:11:22 +03:00
parent b66f331711
commit 408d0cd504
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547

View File

@ -3,7 +3,7 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"slices" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -331,38 +331,83 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if !user.IsAdminOrServiceUser() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
} }
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
if err = am.savePolicy(account, policy, isUpdate); err != nil { if !user.IsAdminOrServiceUser() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err return err
} }
account.Network.IncSerial() postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
for index, rule := range policy.Rules {
rule.Sources = getValidGroupIDs(groups, rule.Sources)
rule.Destinations = getValidGroupIDs(groups, rule.Destinations)
policy.Rules[index] = rule
}
if policy.SourcePostureChecks != nil {
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
}
action := activity.PolicyAdded action := activity.PolicyAdded
if isUpdate { if isUpdate {
action = activity.PolicyUpdated action = activity.PolicyUpdated
if _, err = am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID); err != nil {
return err
} }
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy)
if err != nil {
return fmt.Errorf("failed to save policy: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
@ -370,26 +415,42 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store // DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
policy, err := am.deletePolicy(account, policyID) if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
if err != nil { if err != nil {
return err return err
} }
account.Network.IncSerial() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = am.Store.SaveAccount(ctx, account); err != nil { err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
if err != nil {
return fmt.Errorf("failed to delete policy: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
@ -409,53 +470,6 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
} }
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return nil
}
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules)) result := make([]*proto.FirewallRule, len(rules))
for i := range rules { for i := range rules {
@ -550,28 +564,36 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil return nil
} }
// filterValidPostureChecks filters and returns the posture check IDs from the given list // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
// that are valid within the provided account. func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { validPostureCheckIDs := make(map[string]struct{})
result := make([]string, 0, len(postureChecksIds)) for _, check := range postureChecks {
for _, id := range postureChecksIds { validPostureCheckIDs[check.ID] = struct{}{}
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
} }
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. validIDs := make([]string, 0, len(postureChecksIds))
func filterValidGroupIDs(account *Account, groupIDs []string) []string { for _, id := range postureChecksIds {
result := make([]string, 0, len(groupIDs)) if _, exists := validPostureCheckIDs[id]; exists {
for _, groupID := range groupIDs { validIDs = append(validIDs, id)
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
} }
} }
return result
return validIDs
}
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
validGroupIDs := make(map[string]struct{})
for _, group := range groups {
validGroupIDs[group.ID] = struct{}{}
}
validIDs := make([]string, 0, len(groupIDs))
for _, id := range groupIDs {
if _, exists := validGroupIDs[id]; exists {
validIDs = append(validIDs, id)
}
}
return validIDs
} }