mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-04 17:15:39 +02:00
[management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add lock to get account groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add check for regular user Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * get only required groups for auto-group validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor account peers update Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor groups to use store methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Run groups ops in transaction Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix missing group removed from setup key activity Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix sonar Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Change setup key log level to debug for missing group Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Retrieve modified peers once for group events Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add account locking and merge group deletion methods Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
d9b691b8a5
commit
12f442439a
@ -110,7 +110,6 @@ type AccountManager interface {
|
|||||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
|
||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||||
@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool {
|
|||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error saving groups: %w", err)
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2101,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
} else {
|
} else {
|
||||||
@ -2114,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range removeOldGroups {
|
for _, g := range removeOldGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
} else {
|
} else {
|
||||||
@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings.GroupsPropagationEnabled {
|
if settings.GroupsPropagationEnabled {
|
||||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.NewGetAccountError(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2401,12 +2405,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
|
|||||||
|
|
||||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||||
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
am.updateAccountPeers(ctx, updatedAccount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||||
|
@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
group := group.Group{
|
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||||
ID: "groupA",
|
ID: "groupA",
|
||||||
Name: "GroupA",
|
Name: "GroupA",
|
||||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||||
}
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err, "failed to save group")
|
||||||
|
|
||||||
policy := Policy{
|
policy := Policy{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
|
||||||
t.Errorf("delete group: %v", err)
|
t.Errorf("delete group: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 0)
|
assert.Len(t, user.AutoGroups, 0)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1)
|
assert.Len(t, user.AutoGroups, 1)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
|
@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -33,8 +33,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -45,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
|||||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
@ -54,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
|||||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
@ -73,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
|||||||
// SaveGroups adds new groups to the account.
|
// SaveGroups adds new groups to the account.
|
||||||
// Note: This function does not acquire the global lock.
|
// Note: This function does not acquire the global lock.
|
||||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
var groupsToSave []*nbgroup.Group
|
||||||
|
var updateAccountPeers bool
|
||||||
|
|
||||||
for _, newGroup := range newGroups {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
groupIDs := make([]string, 0, len(groups))
|
||||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
for _, newGroup := range groups {
|
||||||
}
|
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||||
|
|
||||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
|
||||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
|
||||||
if err != nil {
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok || s.ErrorType != status.NotFound {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Avoid duplicate groups only for the API issued groups.
|
newGroup.AccountID = accountID
|
||||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
groupsToSave = append(groupsToSave, newGroup)
|
||||||
if existingGroup != nil {
|
groupIDs = append(groupIDs, newGroup.ID)
|
||||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
newGroup.ID = xid.New().String()
|
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||||
}
|
|
||||||
|
|
||||||
for _, peerID := range newGroup.Peers {
|
|
||||||
if account.Peers[peerID] == nil {
|
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oldGroup := account.Groups[newGroup.ID]
|
|
||||||
account.Groups[newGroup.ID] = newGroup
|
|
||||||
|
|
||||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
|
||||||
eventsToStore = append(eventsToStore, events...)
|
eventsToStore = append(eventsToStore, events...)
|
||||||
}
|
}
|
||||||
|
|
||||||
newGroupIDs := make([]string, 0, len(newGroups))
|
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||||
for _, newGroup := range newGroups {
|
if err != nil {
|
||||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
am.updateAccountPeers(ctx, account)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, storeEvent := range eventsToStore {
|
for _, storeEvent := range eventsToStore {
|
||||||
storeEvent()
|
storeEvent()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
addedPeers := make([]string, 0)
|
addedPeers := make([]string, 0)
|
||||||
removedPeers := make([]string, 0)
|
removedPeers := make([]string, 0)
|
||||||
|
|
||||||
if oldGroup != nil {
|
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||||
|
if err == nil && oldGroup != nil {
|
||||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||||
} else {
|
} else {
|
||||||
@ -155,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range addedPeers {
|
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||||
peer := account.Peers[p]
|
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||||
if peer == nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range addedPeers {
|
||||||
|
peer, ok := peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peerCopy := peer // copy to avoid closure issues
|
|
||||||
eventsToStore = append(eventsToStore, func() {
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
meta := map[string]any{
|
||||||
map[string]any{
|
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
}
|
||||||
})
|
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range removedPeers {
|
for _, peerID := range removedPeers {
|
||||||
peer := account.Peers[p]
|
peer, ok := peers[peerID]
|
||||||
if peer == nil {
|
if !ok {
|
||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
peerCopy := peer // copy to avoid closure issues
|
|
||||||
eventsToStore = append(eventsToStore, func() {
|
eventsToStore = append(eventsToStore, func() {
|
||||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
meta := map[string]any{
|
||||||
map[string]any{
|
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
}
|
||||||
})
|
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,42 +210,10 @@ func difference(a, b []string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup object of the peers.
|
// DeleteGroup object of the peers.
|
||||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||||
account, err := am.Store.GetAccount(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
|
||||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
delete(account.Groups, groupID)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroups deletes groups from an account.
|
// DeleteGroups deletes groups from an account.
|
||||||
@ -250,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
|||||||
//
|
//
|
||||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||||
// Errors are collected and returned at the end.
|
// Errors are collected and returned at the end.
|
||||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||||
account, err := am.Store.GetAccount(ctx, accountId)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
var allErrors error
|
var allErrors error
|
||||||
|
var groupIDsToDelete []string
|
||||||
|
var deletedGroups []*nbgroup.Group
|
||||||
|
|
||||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
group, ok := account.Groups[groupID]
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(account.Groups, groupID)
|
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||||
deletedGroups = append(deletedGroups, group)
|
deletedGroups = append(deletedGroups, group)
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range deletedGroups {
|
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range deletedGroups {
|
||||||
|
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||||
}
|
}
|
||||||
|
|
||||||
return allErrors
|
return allErrors
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListGroups objects of the peers
|
|
||||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
|
||||||
for _, item := range account.Groups {
|
|
||||||
groups = append(groups, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupAddPeer appends peer to the group
|
// GroupAddPeer appends peer to the group
|
||||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
var group *nbgroup.Group
|
||||||
|
var updateAccountPeers bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
if updated := group.AddPeer(peerID); !updated {
|
||||||
if !ok {
|
return nil
|
||||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
add := true
|
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||||
for _, itemID := range group.Peers {
|
if err != nil {
|
||||||
if itemID == peerID {
|
|
||||||
add = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if add {
|
|
||||||
group.Peers = append(group.Peers, peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
am.updateAccountPeers(ctx, account)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -347,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
|||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
var group *nbgroup.Group
|
||||||
|
var updateAccountPeers bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
if updated := group.RemovePeer(peerID); !updated {
|
||||||
if !ok {
|
return nil
|
||||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||||
for i, itemID := range group.Peers {
|
if err != nil {
|
||||||
if itemID == peerID {
|
|
||||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||||
am.updateAccountPeers(ctx, account)
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
// validateNewGroup validates the new group for existence and required fields.
|
||||||
|
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||||
|
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||||
|
}
|
||||||
|
|
||||||
|
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||||
|
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent duplicate groups for API-issued groups.
|
||||||
|
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||||
|
if existingGroup != nil {
|
||||||
|
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
newGroup.ID = xid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range newGroup.Peers {
|
||||||
|
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||||
// disable a deleting integration group if the initiator is not an admin service user
|
// disable a deleting integration group if the initiator is not an admin service user
|
||||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||||
executingUser := account.Users[userID]
|
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if executingUser == nil {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return err
|
||||||
}
|
}
|
||||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
if group.IsGroupAll() {
|
||||||
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"user", linkedUser.Id}
|
return &GroupLinkError{"user", linkedUser.Id}
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||||
|
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||||
|
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.Extra != nil {
|
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
if err != nil {
|
||||||
return &GroupLinkError{"integrated validator", group.Name}
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||||
|
return &GroupLinkError{"integrated validator", group.Name}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||||
|
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||||
return true, r
|
return true, r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||||
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, policy := range policies {
|
for _, policy := range policies {
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||||
@ -442,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||||
|
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, dns := range nameServerGroups {
|
for _, dns := range nameServerGroups {
|
||||||
for _, g := range dns.Groups {
|
for _, g := range dns.Groups {
|
||||||
if g == groupID {
|
if g == groupID {
|
||||||
@ -450,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||||
|
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, setupKey := range setupKeys {
|
for _, setupKey := range setupKeys {
|
||||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||||
return true, setupKey
|
return true, setupKey
|
||||||
@ -464,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||||
|
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
if slices.Contains(user.AutoGroups, groupID) {
|
if slices.Contains(user.AutoGroups, groupID) {
|
||||||
return true, user
|
return true, user
|
||||||
@ -473,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||||
|
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
@ -482,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
|
||||||
for _, groupID := range groupIDs {
|
|
||||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool {
|
|||||||
func (g *Group) IsGroupAll() bool {
|
func (g *Group) IsGroupAll() bool {
|
||||||
return g.Name == "All"
|
return g.Name == "All"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||||
|
func (g *Group) AddPeer(peerID string) bool {
|
||||||
|
if peerID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, itemID := range g.Peers {
|
||||||
|
if itemID == peerID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Peers = append(g.Peers, peerID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes peerID from Peers if present, returning true if removed.
|
||||||
|
func (g *Group) RemovePeer(peerID string) bool {
|
||||||
|
for i, itemID := range g.Peers {
|
||||||
|
if itemID == peerID {
|
||||||
|
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
90
management/server/group/group_test.go
Normal file
90
management/server/group/group_test.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package group
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAddPeer(t *testing.T) {
|
||||||
|
t.Run("add new peer to empty slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.True(t, group.AddPeer(peerID))
|
||||||
|
assert.Contains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add new peer to nil slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: nil}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.True(t, group.AddPeer(peerID))
|
||||||
|
assert.Contains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add new peer to non-empty slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := "peer3"
|
||||||
|
assert.True(t, group.AddPeer(peerID))
|
||||||
|
assert.Contains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add duplicate peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.False(t, group.AddPeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add empty peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := ""
|
||||||
|
assert.False(t, group.AddPeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemovePeer(t *testing.T) {
|
||||||
|
t.Run("remove existing peer from slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
|
||||||
|
peerID := "peer2"
|
||||||
|
assert.True(t, group.RemovePeer(peerID))
|
||||||
|
assert.NotContains(t, group.Peers, peerID)
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove peer from empty slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 0, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove peer from nil slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: nil}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Nil(t, group.Peers)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove non-existent peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := "peer3"
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove peer from single-item slice", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1"}}
|
||||||
|
peerID := "peer1"
|
||||||
|
assert.True(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 0, len(group.Peers))
|
||||||
|
assert.NotContains(t, group.Peers, peerID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove empty peer", func(t *testing.T) {
|
||||||
|
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||||
|
peerID := ""
|
||||||
|
assert.False(t, group.RemovePeer(peerID))
|
||||||
|
assert.Equal(t, 2, len(group.Peers))
|
||||||
|
})
|
||||||
|
}
|
@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "delete non-existent group",
|
name: "delete non-existent group",
|
||||||
groupIDs: []string{"non-existent-group"},
|
groupIDs: []string{"non-existent-group"},
|
||||||
expectedDeleted: []string{"non-existent-group"},
|
expectedReasons: []string{"group: non-existent-group not found"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "delete multiple groups with mixed results",
|
name: "delete multiple groups with mixed results",
|
||||||
|
@ -52,26 +52,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
|||||||
return am.Store.SaveAccount(ctx, a)
|
return am.Store.SaveAccount(ctx, a)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||||
if len(groups) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
|
||||||
|
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
for _, group := range groups {
|
|
||||||
var found bool
|
|
||||||
for _, accountGroup := range accountsGroups {
|
|
||||||
if accountGroup.ID == group {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,6 @@ type MockAccountManager struct {
|
|||||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
|
||||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||||
@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
|
|||||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
|
||||||
if am.ListGroupsFunc != nil {
|
|
||||||
return am.ListGroupsFunc(ctx, accountID)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||||
if am.GroupAddPeerFunc != nil {
|
if am.GroupAddPeerFunc != nil {
|
||||||
|
@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
if expired {
|
if expired {
|
||||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||||
// the expired one. Here we notify them that connection is now allowed again.
|
// the expired one. Here we notify them that connection is now allowed again.
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -271,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peerLabelUpdated || requiresPeerUpdates {
|
if peerLabelUpdated || requiresPeerUpdates {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
@ -335,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -348,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -555,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@ -598,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
||||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
|
||||||
am.updateAccountPeers(ctx, account)
|
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if newGroupsAffectsPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
@ -666,7 +674,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
}
|
}
|
||||||
|
|
||||||
if sync.UpdateAccountPeers {
|
if sync.UpdateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -685,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isStatusChanged {
|
if isStatusChanged {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
validPeersMap, err := am.GetValidatedPeers(account)
|
validPeersMap, err := am.GetValidatedPeers(account)
|
||||||
@ -816,7 +824,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updateRemotePeers || isStatusChanged {
|
if updateRemotePeers || isStatusChanged {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
|
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
|
||||||
@ -979,7 +987,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
|||||||
|
|
||||||
// updateAccountPeers updates all peers that belong to an account.
|
// updateAccountPeers updates all peers that belong to an account.
|
||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if am.metrics != nil {
|
if am.metrics != nil {
|
||||||
@ -987,6 +995,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
peers := account.GetPeers()
|
peers := account.GetPeers()
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
@ -1033,12 +1046,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
|||||||
|
|
||||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
// in an active DNS, route, or ACL configuration.
|
// in an active DNS, route, or ACL configuration.
|
||||||
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
|
||||||
peerGroupIDs := make([]string, 0)
|
peerGroupIDs := make([]string, 0)
|
||||||
for _, group := range account.Groups {
|
for _, group := range account.Groups {
|
||||||
if slices.Contains(group.Peers, peerID) {
|
if slices.Contains(group.Peers, peerID) {
|
||||||
peerGroupIDs = append(peerGroupIDs, group.ID)
|
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
|
||||||
}
|
}
|
||||||
|
@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.updateAccountPeers(ctx, account)
|
manager.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||||
|
|
||||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
|||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||||
|
|
||||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, routy) {
|
if isRouteChangeAffectPeers(account, routy) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||||
|
|
||||||
groups, err := am.ListGroups(context.Background(), account.Id)
|
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var groupHA1, groupHA2 *nbgroup.Group
|
var groupHA1, groupHA2 *nbgroup.Group
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
|
@ -453,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
|
|||||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
|
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range removedGroups {
|
for _, g := range removedGroups {
|
||||||
group, ok := groups[g]
|
group, ok := groups[g]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
|
|||||||
for _, g := range addedGroups {
|
for _, g := range addedGroups {
|
||||||
group, ok := groups[g]
|
group, ok := groups[g]
|
||||||
if !ok {
|
if !ok {
|
||||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ const (
|
|||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
keyQueryCondition = "key = ?"
|
keyQueryCondition = "key = ?"
|
||||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
|
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
|
||||||
accountIDCondition = "account_id = ?"
|
accountIDCondition = "account_id = ?"
|
||||||
peerNotFoundFMT = "peer %s not found"
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||||
var users []*User
|
var users []*User
|
||||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
|
|||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.NewUserNotFoundError(userID)
|
return status.NewUserNotFoundError(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return status.NewGetUserFromStoreError()
|
return status.NewGetUserFromStoreError()
|
||||||
}
|
}
|
||||||
user.LastLogin = lastLogin
|
user.LastLogin = lastLogin
|
||||||
@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
|||||||
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.Errorf(status.NotFound, "group not found for account")
|
return status.NewGroupNotFoundError(groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||||
@ -1079,10 +1079,45 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||||
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
||||||
|
var peer *nbpeer.Peer
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&peer, accountAndIDQueryCondition, accountID, peerID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeersByIDs retrieves peers by their IDs and account ID.
|
||||||
|
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
|
||||||
|
}
|
||||||
|
|
||||||
|
peersMap := make(map[string]*nbpeer.Peer)
|
||||||
|
for _, peer := range peers {
|
||||||
|
peersMap[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
return peersMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1104,6 +1139,7 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
|||||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||||
return &SqlStore{
|
return &SqlStore{
|
||||||
db: tx,
|
db: tx,
|
||||||
|
storeEngine: s.storeEngine,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1155,12 +1191,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByID retrieves a group by ID and account ID.
|
// GetGroupByID retrieves a group by ID and account ID.
|
||||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
|
||||||
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
|
var group *nbgroup.Group
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewGroupNotFoundError(groupID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName retrieves a group by name and account ID.
|
// GetGroupByName retrieves a group by name and account ID.
|
||||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
|
||||||
var group nbgroup.Group
|
var group nbgroup.Group
|
||||||
|
|
||||||
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
// TODO: This fix is accepted for now, but if we need to handle this more frequently
|
||||||
@ -1172,12 +1218,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
|||||||
query = query.Order("json_array_length(peers) DESC")
|
query = query.Order("json_array_length(peers) DESC")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
|
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "group not found")
|
return nil, status.NewGroupNotFoundError(groupName)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
||||||
}
|
}
|
||||||
return &group, nil
|
return &group, nil
|
||||||
}
|
}
|
||||||
@ -1185,7 +1232,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
|||||||
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
// GetGroupsByIDs retrieves groups by their IDs and account ID.
|
||||||
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
||||||
@ -1203,11 +1250,40 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
|||||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to save group to store")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteGroup deletes a group from the database.
|
||||||
|
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewGroupNotFoundError(groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteGroups deletes groups from the database.
|
||||||
|
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
|
||||||
|
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete groups from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountPolicies retrieves policies for an account.
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
|
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
|
||||||
|
@ -14,11 +14,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
route2 "github.com/netbirdio/netbird/route"
|
route2 "github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
|||||||
t.Fatal("failed to save group")
|
t.Fatal("failed to save group")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to get group")
|
t.Fatal("failed to get group")
|
||||||
return err
|
return err
|
||||||
@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
|
|||||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
account, err := store.GetAccount(context.Background(), accountID)
|
account, err := store.GetAccount(context.Background(), accountID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, users, len(account.Users))
|
require.Len(t, users, len(account.Users))
|
||||||
}
|
}
|
||||||
@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
|
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "All", group.Name)
|
require.True(t, group.IsGroupAll())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||||
@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
|||||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
groupIDs []string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing groups by existing IDs",
|
||||||
|
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||||
|
expectedCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty group IDs list",
|
||||||
|
groupIDs: []string{},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing group IDs",
|
||||||
|
groupIDs: []string{"nonexistent1", "nonexistent2"},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed existing and non-existing group IDs",
|
||||||
|
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveGroup(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
group := &nbgroup.Group{
|
||||||
|
ID: "group-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "api",
|
||||||
|
Peers: []string{"peer1", "peer2"},
|
||||||
|
}
|
||||||
|
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, savedGroup, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveGroups(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
groups := []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "group-1",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "api",
|
||||||
|
Peers: []string{"peer1", "peer2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "group-2",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "integration",
|
||||||
|
Peers: []string{"peer3", "peer4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
groupID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "delete existing group",
|
||||||
|
groupID: "cfefqs706sqkneg59g4g",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete non-existing group",
|
||||||
|
groupID: "non-existing-group-id",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with empty group ID",
|
||||||
|
groupID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, group)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteGroups(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
groupIDs []string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "delete multiple existing groups",
|
||||||
|
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete non-existing groups",
|
||||||
|
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with empty group IDs list",
|
||||||
|
groupIDs: []string{},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, groupID := range tt.groupIDs {
|
||||||
|
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPeerByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing peer",
|
||||||
|
peerID: "cfefqs706sqkneg59g4g",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing peer",
|
||||||
|
peerID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty peer ID",
|
||||||
|
peerID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, peer)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, peer)
|
||||||
|
require.Equal(t, tt.peerID, peer.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerIDs []string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing peers by existing IDs",
|
||||||
|
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
|
||||||
|
expectedCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty peer IDs list",
|
||||||
|
peerIDs: []string{},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing peer IDs",
|
||||||
|
peerIDs: []string{"nonexistent1", "nonexistent2"},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed existing and non-existing peer IDs",
|
||||||
|
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, peers, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,7 +3,6 @@ package status
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -126,11 +125,6 @@ func NewAdminPermissionError() error {
|
|||||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
|
||||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
|
||||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||||
func NewInvalidKeyIDError() error {
|
func NewInvalidKeyIDError() error {
|
||||||
return Errorf(InvalidArgument, "invalid key ID")
|
return Errorf(InvalidArgument, "invalid key ID")
|
||||||
@ -140,3 +134,8 @@ func NewInvalidKeyIDError() error {
|
|||||||
func NewGetAccountError(err error) error {
|
func NewGetAccountError(err error) error {
|
||||||
return Errorf(Internal, "error getting account: %s", err)
|
return Errorf(Internal, "error getting account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
|
||||||
|
func NewGroupNotFoundError(groupID string) error {
|
||||||
|
return Errorf(NotFound, "group: %s not found", groupID)
|
||||||
|
}
|
||||||
|
@ -62,7 +62,7 @@ type Store interface {
|
|||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
@ -76,6 +76,8 @@ type Store interface {
|
|||||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
|
||||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||||
|
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||||
|
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||||
@ -90,6 +92,8 @@ type Store interface {
|
|||||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
|
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||||
|
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
@ -108,7 +112,7 @@ type Store interface {
|
|||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
|
@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
|
|||||||
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -835,7 +835,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, storeEvent := range eventsToStore {
|
for _, storeEvent := range eventsToStore {
|
||||||
@ -1132,7 +1132,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
|||||||
if len(peerIDs) != 0 {
|
if len(peerIDs) != 0 {
|
||||||
// this will trigger peer disconnect from the management service
|
// this will trigger peer disconnect from the management service
|
||||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account.Id)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -1240,7 +1240,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
for targetUserID, meta := range deletedUsersMeta {
|
for targetUserID, meta := range deletedUsersMeta {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user