mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-30 22:50:22 +02:00
Refactor retrieval of policy and posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@ -986,6 +986,24 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou
|
|||||||
func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) {
|
func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
|
||||||
|
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
|
||||||
|
}
|
||||||
|
@ -3,7 +3,6 @@ package http
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
@ -84,18 +83,12 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID })
|
|
||||||
if policyIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.savePolicy(w, r, accountID, userID, policyID)
|
h.savePolicy(w, r, accountID, userID, policyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -315,30 +315,16 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
|||||||
|
|
||||||
// GetPolicy from the store
|
// GetPolicy from the store
|
||||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, policy := range account.Policies {
|
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||||
if policy.ID == policyID {
|
|
||||||
return policy, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
@ -400,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
|
|
||||||
// ListPolicies from the store
|
// ListPolicies from the store
|
||||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
|
||||||
if err != nil {
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
return am.Store.GetAccountPolicies(ctx, accountID)
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Policies, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||||
|
@ -15,30 +15,16 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, postureChecks := range account.PostureChecks {
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||||
if postureChecks.ID == postureChecksID {
|
|
||||||
return postureChecks, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||||
@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.PostureChecks, nil
|
return am.Store.GetAccountPostureChecks(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||||
|
@ -489,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
|||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
First(&user, idQueryCondition, userID)
|
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@ -1095,7 +1095,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
|||||||
|
|
||||||
func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
|
||||||
var group nbgroup.Group
|
var group nbgroup.Group
|
||||||
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
|
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations).
|
||||||
|
Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "group not found")
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
@ -1109,7 +1110,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string)
|
|||||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
var group nbgroup.Group
|
var group nbgroup.Group
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
|
||||||
Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
|
Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "group not found")
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
@ -1118,3 +1119,48 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
|||||||
}
|
}
|
||||||
return &group, nil
|
return &group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
|
||||||
|
var policies []*Policy
|
||||||
|
result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID).
|
||||||
|
Preload(clause.Associations).Find(&policies)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
|
||||||
|
}
|
||||||
|
return policies, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||||
|
var policy *Policy
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}).
|
||||||
|
Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
|
||||||
|
}
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||||
|
var postureChecks []*posture.Checks
|
||||||
|
result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
|
||||||
|
}
|
||||||
|
return postureChecks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||||
|
var postureCheck *posture.Checks
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}).
|
||||||
|
Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
|
||||||
|
}
|
||||||
|
return postureCheck, nil
|
||||||
|
}
|
||||||
|
@ -68,7 +68,12 @@ type Store interface {
|
|||||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||||
|
|
||||||
|
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
|
||||||
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
|
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
||||||
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
|
Reference in New Issue
Block a user