diff --git a/management/server/group/group.go b/management/server/group/group.go index d293e1afc..e98e5ecc4 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -49,3 +49,8 @@ func (g *Group) Copy() *Group { func (g *Group) HasPeers() bool { return len(g.Peers) > 0 } + +// IsGroupAll checks if the group is a default "All" group. +func (g *Group) IsGroupAll() bool { + return g.Name == "All" +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index f54eafdc1..da248be25 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -233,20 +233,16 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewUserNotPartOfAccountError() } - var accountGroups []*nbgroup.Group + var groups []*nbgroup.Group var setupKey *SetupKey var plainKey string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups) if err != nil { return err } - if err = validateSetupKeyAutoGroups(accountGroups, autoGroups); err != nil { - return err - } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID @@ -257,8 +253,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) - groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) - for _, g := range accountGroups { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { groupMap[g.ID] = g } @@ -294,20 +290,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewUserNotPartOfAccountError() } - var accountGroups []*nbgroup.Group + var groups []*nbgroup.Group var oldKey *SetupKey var newKey *SetupKey err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - accountGroups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups) if err != nil { return err } - if err = validateSetupKeyAutoGroups(accountGroups, keyToSave.AutoGroups); err != nil { - return err - } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err @@ -334,8 +326,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) - groupMap := make(map[string]*nbgroup.Group, len(accountGroups)) - for _, g := range accountGroups { + groupMap := make(map[string]*nbgroup.Group, len(groups)) + for _, g := range groups { groupMap[g.ID] = g } @@ -439,22 +431,20 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error { - groupMap := make(map[string]*nbgroup.Group, len(groups)) - for _, g := range groups { - groupMap[g.ID] = g - } +func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) ([]*nbgroup.Group, error) { + autoGroups := make([]*nbgroup.Group, 0, len(autoGroupIDs)) - for _, groupID := range autoGroups { - g, exists := groupMap[groupID] - if !exists { - return status.Errorf(status.NotFound, "group %s doesn't exist", groupID) + for _, groupID := range autoGroupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + if err != nil { + return nil, err } - if g.Name == "All" { - return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") + if group.IsGroupAll() { + return nil, status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key") } + autoGroups = append(autoGroups, group) } - return nil + return autoGroups, nil }