diff --git a/client/client_test.go b/client/client_test.go index 68a4341..adc9dd6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -9,7 +9,6 @@ import ( "path" "regexp" "runtime" - "strconv" "strings" "sync" "testing" @@ -58,53 +57,6 @@ func TestMain(m *testing.M) { var shellTesters []shellTester = []shellTester{bashTester{}, zshTester{}} -func numTestShards() int { - numTestShardsStr := os.Getenv("NUM_TEST_SHARDS") - if numTestShardsStr == "" { - return -1 - } - numTestShards, err := strconv.Atoi(numTestShardsStr) - if err != nil { - panic(fmt.Errorf("failed to parse NUM_TEST_SHARDS: %v", err)) - } - return numTestShards -} - -func currentShardNumber() int { - currentShardNumberStr := os.Getenv("CURRENT_SHARD_NUM") - if currentShardNumberStr == "" { - return -1 - } - currentShardNumber, err := strconv.Atoi(currentShardNumberStr) - if err != nil { - panic(fmt.Errorf("failed to parse CURRENT_SHARD_NUM: %v", err)) - } - return currentShardNumber -} - -func isShardedTestRun() bool { - return numTestShards() != -1 && currentShardNumber() != -1 -} - -func markTestForSharding(t *testing.T, testShardNumber int) { - if isShardedTestRun() { - if testShardNumber%numTestShards() != currentShardNumber() { - t.Skip("Skipping sharded test") - } - } -} - -var shardNumberAllocator int = 0 - -func wrapTestForSharding(test func(t *testing.T)) func(t *testing.T) { - shardNumberAllocator += 1 - return func(t *testing.T) { - testShardNumber := shardNumberAllocator - markTestForSharding(t, testShardNumber) - test(t) - } -} - func TestParam(t *testing.T) { if skipSlowTests() { shellTesters = shellTesters[:1] diff --git a/client/testutils.go b/client/testutils.go index be30435..5016417 100644 --- a/client/testutils.go +++ b/client/testutils.go @@ -366,3 +366,59 @@ func stripRequiredPrefix(t *testing.T, out, prefix string) string { func stripTuiCommandPrefix(t *testing.T, out string) string { return stripRequiredPrefix(t, out, "hishtory tquery") } + +// Wrap the given test so that it can be run on Github Actions with sharding. This +// makes it possible to run only 1/N tests on each of N github action jobs, speeding +// up test execution through parallelization. This is necessary since the wrapped +// integration tests rely on OS-level globals (the shell history) that can't otherwise +// be parallelized. +func wrapTestForSharding(test func(t *testing.T)) func(t *testing.T) { + shardNumberAllocator += 1 + return func(t *testing.T) { + testShardNumber := shardNumberAllocator + markTestForSharding(t, testShardNumber) + test(t) + } +} + +var shardNumberAllocator int = 0 + +// Returns whether this is a sharded test run. false during all normal non-github action operations. +func isShardedTestRun() bool { + return numTestShards() != -1 && currentShardNumber() != -1 +} + +// Get the total number of test shards +func numTestShards() int { + numTestShardsStr := os.Getenv("NUM_TEST_SHARDS") + if numTestShardsStr == "" { + return -1 + } + numTestShards, err := strconv.Atoi(numTestShardsStr) + if err != nil { + panic(fmt.Errorf("failed to parse NUM_TEST_SHARDS: %v", err)) + } + return numTestShards +} + +// Get the current shard number +func currentShardNumber() int { + currentShardNumberStr := os.Getenv("CURRENT_SHARD_NUM") + if currentShardNumberStr == "" { + return -1 + } + currentShardNumber, err := strconv.Atoi(currentShardNumberStr) + if err != nil { + panic(fmt.Errorf("failed to parse CURRENT_SHARD_NUM: %v", err)) + } + return currentShardNumber +} + +// Mark the given test for sharding with the given test ID number. +func markTestForSharding(t *testing.T, testShardNumber int) { + if isShardedTestRun() { + if testShardNumber%numTestShards() != currentShardNumber() { + t.Skip("Skipping sharded test") + } + } +}