mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-22 05:01:24 +01:00
Refactor retrieval of policy and posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
7561706627
commit
eab85644cd
@ -986,6 +986,24 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou
|
||||
func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
@ -84,18 +83,12 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID })
|
||||
if policyIdx < 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, policyID)
|
||||
}
|
||||
|
||||
|
@ -315,30 +315,16 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if policy.ID == policyID {
|
||||
return policy, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
@ -400,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
|
||||
}
|
||||
|
||||
return account.Policies, nil
|
||||
return am.Store.GetAccountPolicies(ctx, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||
|
@ -15,30 +15,16 @@ const (
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
for _, postureChecks := range account.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
return postureChecks, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return account.PostureChecks, nil
|
||||
return am.Store.GetAccountPostureChecks(ctx, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||
|
@ -489,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&user, idQueryCondition, userID)
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@ -1095,7 +1095,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
|
||||
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations).
|
||||
Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
@ -1109,7 +1110,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string)
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
|
||||
Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
|
||||
Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
@ -1118,3 +1119,48 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
|
||||
var policies []*Policy
|
||||
result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID).
|
||||
Preload(clause.Associations).Find(&policies)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
|
||||
}
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||
var policy *Policy
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}).
|
||||
Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
|
||||
}
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||
var postureCheck *posture.Checks
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}).
|
||||
Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
|
||||
}
|
||||
return postureCheck, nil
|
||||
}
|
||||
|
@ -68,7 +68,12 @@ type Store interface {
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
|
Loading…
Reference in New Issue
Block a user