From f118d81d3219b78b689206523e4159fcb495fa12 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 Nov 2024 12:46:05 +0300 Subject: [PATCH] [management] Refactor policy to use store methods (#2878) --- management/server/account.go | 2 +- management/server/account_test.go | 57 ++-- management/server/group_test.go | 5 +- management/server/http/policies_handler.go | 26 +- .../server/http/policies_handler_test.go | 4 +- management/server/mock_server/account_mock.go | 8 +- management/server/peer_test.go | 31 +- management/server/policy.go | 319 +++++++++++------- management/server/policy_test.go | 165 +++------ management/server/posture_checks_test.go | 43 +-- management/server/route_test.go | 3 +- management/server/setupkey_test.go | 5 +- management/server/sql_store.go | 82 ++++- management/server/sql_store_test.go | 158 +++++++++ management/server/status/error.go | 5 + management/server/store.go | 6 +- management/server/testdata/extended-store.sql | 1 + management/server/user_test.go | 5 +- 18 files changed, 576 insertions(+), 349 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 9fb56c855..fbe6fcc1a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f0..c8c2d5941 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - policy := Policy{ - ID: "policy", + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - policy := Policy{ - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }) + if err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - policy := Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { require.NoError(t, err, "failed to save group") - policy := Policy{ + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } diff --git a/management/server/group_test.go b/management/server/group_test.go index 59094a23e..ec017fc57 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) assert.NoError(t, err) // Saving a group linked to policy should update account peers and send peer update diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 73f3803b5..eff9092d4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,10 +6,8 @@ import ( "strconv" "github.com/gorilla/mux" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - isUpdate := policyID != "" - - if policyID == "" { - policyID = xid.New().String() - } - - policy := server.Policy{ + policy := &server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, } for _, rule := range req.Rules { + var ruleID string + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, + PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, Sources: rule.Sources, @@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + if err != nil { util.WriteError(r.Context(), err, w) return } @@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - resp := toPolicyResponse(allGroups, &policy) + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 228ebcbce..f8a897eb2 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } - return nil + return policy, nil }, GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 673ed33bb..46a4fbc1f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -49,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } - return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4e2dcb2c3..e410fa892 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { var ( group1 nbgroup.Group group2 nbgroup.Group - policy Policy ) group1.ID = xid.New().String() group2.ID = xid.New().String() group1.Name = "src" group2.Name = "dst" - policy.ID = xid.New().String() group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) @@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy.Name = "test" - policy.Enabled = true - policy.Rules = []*PolicyRule{ - { - Enabled: true, - Sources: []string{group1.ID}, - Destinations: []string{group2.ID}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, + policy := &Policy{ + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{group1.ID}, + Destinations: []string{group2.ID}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -1445,8 +1445,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1457,7 +1456,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) require.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/policy.go b/management/server/policy.go index c7872591d..2d3abc3f1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,13 +3,13 @@ package server import ( "context" _ "embed" - "slices" "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" + "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -125,6 +125,7 @@ type PolicyRule struct { func (pm *PolicyRule) Copy() *PolicyRule { rule := &PolicyRule{ ID: pm.ID, + PolicyID: pm.PolicyID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, @@ -171,6 +172,7 @@ type Policy struct { func (p *Policy) Copy() *Policy { c := &Policy{ ID: p.ID, + AccountID: p.AccountID, Name: p.Name, Description: p.Description, Enabled: p.Enabled, @@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + var isUpdate = policy.ID != "" + var updateAccountPeers bool + var action = activity.PolicyAdded + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + saveFunc := transaction.CreatePolicy + if isUpdate { + action = activity.PolicyUpdated + saveFunc = transaction.SavePolicy + } + + return saveFunc(ctx, LockingStrengthUpdate, policy) + }) if err != nil { - return err + return nil, err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated - } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } + return policy, nil +} + +// DeletePolicy from the store +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + + var policy *Policy + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + if err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } -// DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - policy, err := am.deletePolicy(account, policyID) - if err != nil { - return err - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - - if am.anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return false, err } - oldPolicy := account.Policies[policyIdx] - // Update the existing policy - account.Policies[policyIdx] = policyToSave - - if !policyToSave.Enabled && !oldPolicy.Enabled { + if !policy.Enabled && !existingPolicy.Enabled { return false, nil } - updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) - return updateAccountPeers, nil - } + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil -} - -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, + if hasPeers { + return true, nil } } - return result + + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) +} + +// validatePolicy validates the policy and its rules. +func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { + if policy.ID != "" { + _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return err + } + } else { + policy.ID = xid.New().String() + policy.AccountID = accountID + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + if err != nil { + return err + } + + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + if err != nil { + return err + } + + for i, rule := range policy.Rules { + ruleCopy := rule.Copy() + if ruleCopy.ID == "" { + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor + ruleCopy.PolicyID = policy.ID + } + + ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources) + ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations) + policy.Rules[i] = ruleCopy + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + + return nil } // getAllPeersFromGroups for given peer ID and list of groups @@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := postureChecks[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := groups[id]; exists { + validIDs = append(validIDs, id) + } + } + + return validIDs +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/policy_test.go b/management/server/policy_test.go index e7f0f9cd2..62d80f46e 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) + var policyWithGroupRulesNoPeers *Policy + var policyWithDestinationPeersOnly *Policy + var policyWithSourceAndDestinationPeers *Policy + // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-rule-groups-no-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with source group containing peers, but destination group without peers should // update account's peers and send peer update t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-has-peers-destination-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, @@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-destination-has-peers-source-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), - Enabled: false, + Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, @@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, @@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Disabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = false + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Updating disabled policy with destination and source groups containing peers should not update account's peers // or send peer update t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Description: "updated description", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Description = "updated description" + policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"} + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Enabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = true + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { - policyID := "policy-source-destination-peers" - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID) assert.NoError(t, err) select { @@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { - policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID) assert.NoError(t, err) select { @@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID) assert.NoError(t, err) select { diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 3c5c5fc79..93e5741cf 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := Policy{ - ID: "policyA", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -234,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -282,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }() policy.SourcePostureChecks = []string{} - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -316,12 +312,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -330,8 +324,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -362,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - policy = Policy{ - ID: "policyB", + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, @@ -376,9 +368,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -405,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -418,8 +407,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -490,12 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { require.NoError(t, err, "failed to save postureCheckB") policy := &Policy{ - ID: "policyA", AccountID: account.Id, Rules: []*PolicyRule{ { - ID: "ruleA", - PolicyID: "policyA", Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -504,7 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { @@ -528,7 +513,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -539,7 +524,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -560,7 +545,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c..108f791e0 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { defaultRule := rules[0] newPolicy := defaultRule.Copy() - newPolicy.ID = xid.New().String() newPolicy.Name = "peer1 only" newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) + _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 7c8200706..614547c60 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 47c17bb92..9a24857d1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1243,8 +1243,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren var groups []*nbgroup.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } groupsMap := make(map[string]*nbgroup.Group) @@ -1295,12 +1295,67 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policies from store") + } + + return policies, nil } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { + var policy *Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + First(&policy, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewPolicyNotFoundError(policyID) + } + log.WithContext(ctx).Errorf("failed to get policy from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get policy from store") + } + + return policy, nil +} + +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to create policy in store") + } + + return nil +} + +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) + return status.Errorf(status.Internal, "failed to save policy to store") + } + return nil +} + +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil } // GetAccountPostureChecks retrieves posture checks for an account. @@ -1331,6 +1386,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin return postureCheck, nil } +// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. +func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") + } + + postureChecksMap := make(map[string]*posture.Checks) + for _, postureCheck := range postureChecks { + postureChecksMap[postureCheck.ID] = postureCheck + } + + return postureChecksMap, nil +} + // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index de939e8d0..c05793fc6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1612,6 +1612,49 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } } +func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureCheckIDs []string + expectedCount int + }{ + { + name: "retrieve existing posture checks by existing IDs", + postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"}, + expectedCount: 2, + }, + { + name: "empty posture check IDs list", + postureCheckIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing posture check IDs", + postureCheckIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing posture check IDs", + postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + func TestSqlStore_SavePostureChecks(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) @@ -1699,3 +1742,118 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { }) } } + +func TestSqlStore_GetPolicyByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + policyID string + expectError bool + }{ + { + name: "retrieve existing policy", + policyID: "cs1tnh0hhcjnqoiuebf0", + expectError: false, + }, + { + name: "retrieve non-existing policy checks", + policyID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty policy ID", + policyID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, policy) + } else { + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, tt.policyID, policy.ID) + } + }) + } +} + +func TestSqlStore_CreatePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policy := &Policy{ + ID: "policy-id", + AccountID: accountID, + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) + +} + +func TestSqlStore_SavePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy.Enabled = false + policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} + err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) +} + +func TestSqlStore_DeletePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.Error(t, err) + require.Nil(t, policy) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 44391e1f1..0fff53559 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -144,3 +144,8 @@ func NewGroupNotFoundError(groupID string) error { func NewPostureChecksNotFoundError(postureChecksID string) error { return Errorf(NotFound, "posture checks: %s not found", postureChecksID) } + +// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy +func NewPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "policy: %s not found", policyID) +} diff --git a/management/server/store.go b/management/server/store.go index 03b5821e7..ba61d552d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -80,11 +80,15 @@ type Store interface { DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 1646ff4da..37db27316 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -35,4 +35,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); +INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54..498017afa 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) require.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)