diff --git a/client/client_test.go b/client/client_test.go index 1efa500..c783274 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -60,15 +60,37 @@ var shellTesters []shellTester = []shellTester{bashTester{}, zshTester{}} var shardNumberAllocator int = 0 -func markTestForSharding(t *testing.T, testShardNumber int) { +func numTestShards() int { numTestShardsStr := os.Getenv("NUM_TEST_SHARDS") + if numTestShardsStr == "" { + return 0 + } + numTestShards, err := strconv.Atoi(numTestShardsStr) + if err != nil { + panic(fmt.Errorf("failed to parse NUM_TEST_SHARDS: %v", err)) + } + return numTestShards +} + +func currentShardNumber() int { currentShardNumberStr := os.Getenv("CURRENT_SHARD_NUM") - if numTestShardsStr != "" && currentShardNumberStr != "" { - numTestShards, err := strconv.Atoi(numTestShardsStr) - require.NoError(t, err) - currentShardNumber, err := strconv.Atoi(currentShardNumberStr) - require.NoError(t, err) - if testShardNumber%numTestShards != currentShardNumber { + if currentShardNumberStr == "" { + return 0 + } + currentShardNumber, err := strconv.Atoi(currentShardNumberStr) + if err != nil { + panic(fmt.Errorf("failed to parse CURRENT_SHARD_NUM: %v", err)) + } + return currentShardNumber +} + +func isShardedTestRun() bool { + return numTestShards() != 0 && currentShardNumber() != 0 +} + +func markTestForSharding(t *testing.T, testShardNumber int) { + if isShardedTestRun() { + if testShardNumber%numTestShards() == currentShardNumber() { t.Skip("Skipping sharded test") } } diff --git a/client/fuzz_test.go b/client/fuzz_test.go index 5af928e..9e8c500 100644 --- a/client/fuzz_test.go +++ b/client/fuzz_test.go @@ -142,6 +142,12 @@ func FuzzTestMultipleUsers(f *testing.F) { if skipSlowTests() { f.Skip("skipping slow tests") } + if isShardedTestRun() { + if currentShardNumber() == 0 { + f.Skip("Skipping sharded test") + } + } + s := os.Getenv("SPLIT_TESTS") if s != "" && s != "BASIC" { f.Skip()