Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-24 21:57:33 +03:00
parent 7561706627
commit eab85644cd
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 86 additions and 68 deletions

View File

@ -986,6 +986,24 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou
func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
}
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
}
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
}
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
}

View File

@ -3,7 +3,6 @@ package http
import (
"encoding/json"
"net/http"
"slices"
"strconv"
"github.com/gorilla/mux"
@ -84,18 +83,12 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID })
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
h.savePolicy(w, r, accountID, userID, policyID)
}

View File

@ -315,30 +315,16 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !(user.HasAdminPower() || user.IsServiceUser) {
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
for _, policy := range account.Policies {
if policy.ID == policyID {
return policy, nil
}
}
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
}
// SavePolicy in the store
@ -400,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
}
return account.Policies, nil
return am.Store.GetAccountPolicies(ctx, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {

View File

@ -15,30 +15,16 @@ const (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() {
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
for _, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks, nil
}
}
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() {
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return account.PostureChecks, nil
return am.Store.GetAccountPostureChecks(ctx, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {

View File

@ -489,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@ -1095,7 +1095,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations).
Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
@ -1109,7 +1110,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string)
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
@ -1118,3 +1119,48 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
return &group, nil
}
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
var policies []*Policy
result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID).
Preload(clause.Associations).Find(&policies)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
}
return policies, nil
}
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
var policy *Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}).
Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
}
return policy, nil
}
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
var postureChecks []*posture.Checks
result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
}
return postureChecks, nil
}
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
var postureCheck *posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}).
Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
}
return postureCheck, nil
}

View File

@ -68,7 +68,12 @@ type Store interface {
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error