mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 10:50:45 +01:00
add store policy save and method
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
c384874d7d
commit
87c8430e99
@ -118,7 +118,7 @@ func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) err
|
|||||||
return s.SaveAccount(ctx, account)
|
return s.SaveAccount(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, _ LockingStrength, accountId string) error {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
@ -1011,6 +1011,14 @@ func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) SavePolicy(_ context.Context, _ LockingStrength, _ *Policy) error {
|
||||||
|
return status.Errorf(status.Internal, "SavePolicy is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) DeletePolicy(_ context.Context, _ LockingStrength, _ string) error {
|
||||||
|
return status.Errorf(status.Internal, "DeletePolicy is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
|
func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
|
||||||
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
@ -130,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
|
|
||||||
policy := server.Policy{
|
policy := server.Policy{
|
||||||
ID: policyID,
|
ID: policyID,
|
||||||
|
AccountID: accountID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Enabled: req.Enabled,
|
Enabled: req.Enabled,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
|
@ -502,7 +502,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"slices"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -321,7 +321,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||||
@ -329,20 +329,48 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
|||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only admin users are allowed to update policies")
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
postureChecks, err := am.Store.GetAccountPostureChecks(ctx, accountID)
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, rule := range policy.Rules {
|
||||||
|
rule.Sources = getValidGroupIDs(groups, rule.Sources)
|
||||||
|
rule.Destinations = getValidGroupIDs(groups, rule.Destinations)
|
||||||
|
policy.Rules[index] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
if policy.SourcePostureChecks != nil {
|
||||||
|
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.SavePolicy(ctx, LockingStrengthUpdate, policy)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save policy: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -352,6 +380,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -359,26 +391,42 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
|
|
||||||
// DeletePolicy from the store
|
// DeletePolicy from the store
|
||||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
policy, err := am.deletePolicy(account, policyID)
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
|
return status.Errorf(status.PermissionDenied, "only admin users are allowed to delete policies")
|
||||||
|
}
|
||||||
|
|
||||||
|
policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, policyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete policy: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -392,7 +440,7 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only admin users are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPolicies(ctx, accountID)
|
return am.Store.GetAccountPolicies(ctx, accountID)
|
||||||
@ -415,36 +463,6 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
|||||||
return policy, nil
|
return policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// savePolicy saves or updates a policy in the given account.
|
|
||||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
|
||||||
for index, rule := range policyToSave.Rules {
|
|
||||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
|
||||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
|
||||||
policyToSave.Rules[index] = rule
|
|
||||||
}
|
|
||||||
|
|
||||||
if policyToSave.SourcePostureChecks != nil {
|
|
||||||
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isUpdate {
|
|
||||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
|
||||||
if policyIdx < 0 {
|
|
||||||
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the existing policy
|
|
||||||
account.Policies[policyIdx] = policyToSave
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the new policy to the account
|
|
||||||
account.Policies = append(account.Policies, policyToSave)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||||
result := make([]*proto.FirewallRule, len(update))
|
result := make([]*proto.FirewallRule, len(update))
|
||||||
for i := range update {
|
for i := range update {
|
||||||
@ -558,28 +576,36 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||||
// that are valid within the provided account.
|
func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
|
||||||
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
validPostureCheckIDs := make(map[string]struct{})
|
||||||
result := make([]string, 0, len(postureChecksIds))
|
for _, check := range postureChecks {
|
||||||
|
validPostureCheckIDs[check.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
validIDs := make([]string, 0, len(postureChecksIds))
|
||||||
for _, id := range postureChecksIds {
|
for _, id := range postureChecksIds {
|
||||||
for _, postureCheck := range account.PostureChecks {
|
if _, exists := validPostureCheckIDs[id]; exists {
|
||||||
if id == postureCheck.ID {
|
validIDs = append(validIDs, id)
|
||||||
result = append(result, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
|
||||||
|
return validIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||||
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
|
||||||
result := make([]string, 0, len(groupIDs))
|
validGroupIDs := make(map[string]struct{})
|
||||||
for _, groupID := range groupIDs {
|
for _, group := range groups {
|
||||||
if _, exists := account.Groups[groupID]; exists {
|
validGroupIDs[group.ID] = struct{}{}
|
||||||
result = append(result, groupID)
|
}
|
||||||
|
|
||||||
|
validIDs := make([]string, 0, len(groupIDs))
|
||||||
|
for _, id := range groupIDs {
|
||||||
|
if _, exists := validGroupIDs[id]; exists {
|
||||||
|
validIDs = append(validIDs, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
|
||||||
|
return validIDs
|
||||||
}
|
}
|
||||||
|
@ -1007,8 +1007,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "issue incrementing network serial count")
|
return status.Errorf(status.Internal, "issue incrementing network serial count")
|
||||||
}
|
}
|
||||||
@ -1106,6 +1107,18 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
|
|||||||
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SavePolicy saves a policy to the database.
|
||||||
|
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||||
|
return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&policy).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePolicy deletes a policy from the database.
|
||||||
|
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID string) error {
|
||||||
|
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&Policy{}, idQueryCondition, policyID).Error
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
|
||||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
|
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
|
||||||
|
@ -71,6 +71,8 @@ type Store interface {
|
|||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
|
GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
|
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||||
|
DeletePolicy(ctx context.Context, lockStrength LockingStrength, postureCheckID string) error
|
||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
|
||||||
@ -97,7 +99,7 @@ type Store interface {
|
|||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
|
Loading…
Reference in New Issue
Block a user