diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 6a6753595..5c4ddf666 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -420,7 +420,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, return nil } - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&users) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save users to store") @@ -444,7 +444,7 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, return nil } - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&groups) if result.Error != nil { return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 4dcdadf44..6e04c7d9d 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -1331,6 +1331,14 @@ func TestSqlStore_SaveGroups(t *testing.T) { } err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) require.NoError(t, err) + + groups[1].Peers = []string{} + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID) + require.NoError(t, err) + require.Equal(t, groups[1], group) } func TestSqlStore_DeleteGroup(t *testing.T) { @@ -3046,6 +3054,14 @@ func TestSqlStore_SaveUsers(t *testing.T) { accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, accountUsers, 4) + + users[1].AutoGroups = []string{"groupA", "groupC"} + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users) + require.NoError(t, err) + + user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, users[1].Id) + require.NoError(t, err) + require.Equal(t, users[1].AutoGroups, user.AutoGroups) } func TestSqlStore_DeleteUser(t *testing.T) { @@ -3198,3 +3214,61 @@ func TestSqlStore_DeletePAT(t *testing.T) { require.Error(t, err) require.Nil(t, pat) } + +func TestSqlStore_SaveUsers_LargeBatch(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Len(t, accountUsers, 2) + + usersToSave := make([]*types.User, 0) + + for i := 1; i <= 8000; i++ { + usersToSave = append(usersToSave, &types.User{ + Id: fmt.Sprintf("user-%d", i), + AccountID: accountID, + Role: types.UserRoleUser, + }) + } + + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, usersToSave) + require.NoError(t, err) + + accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, 8002, len(accountUsers)) +} + +func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + accountGroups, err := store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Len(t, accountGroups, 3) + + groupsToSave := make([]*types.Group, 0) + + for i := 1; i <= 8000; i++ { + groupsToSave = append(groupsToSave, &types.Group{ + ID: fmt.Sprintf("%d", i), + AccountID: accountID, + Name: fmt.Sprintf("group-%d", i), + }) + } + + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groupsToSave) + require.NoError(t, err) + + accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, 8003, len(accountGroups)) +}