package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"syscall"
	"testing"

	"github.com/ddworken/hishtory/client/data"
	"github.com/ddworken/hishtory/client/hctx"
	"github.com/ddworken/hishtory/client/lib"
	"github.com/ddworken/hishtory/shared"
	"github.com/ddworken/hishtory/shared/testutils"
	"github.com/stretchr/testify/require"
)

// shellTester abstracts over the shells that the integration tests drive, so
// the same test body can be run against each shell implementation below.
type shellTester interface {
	// RunInteractiveShell runs script in an interactive shell with errexit and
	// pipefail enabled, failing the test if the shell reports an error.
	RunInteractiveShell(t testing.TB, script string) string
	// RunInteractiveShellRelaxed runs script as-is in an interactive shell and
	// returns any error to the caller instead of failing the test.
	RunInteractiveShellRelaxed(t testing.TB, script string) (string, error)
	// RunInteractiveShellBackground starts script in an interactive shell and
	// returns without waiting for it to finish.
	RunInteractiveShellBackground(t testing.TB, script string) error
	// ShellName returns the shell's name, e.g. "bash" or "zsh".
	ShellName() string
}

// bashTester runs scripts by piping them into `bash -i`.
type bashTester struct{}

// RunInteractiveShell runs script under `set -emo pipefail` and returns its
// stdout. On failure it fails the test, reporting the caller's file and line.
func (b bashTester) RunInteractiveShell(t testing.TB, script string) string {
	out, err := b.RunInteractiveShellRelaxed(t, "set -emo pipefail\n"+script)
	if err != nil {
		// Point at the caller: the failing commands came from the script it passed in.
		_, filename, line, _ := runtime.Caller(1)
		require.NoError(t, err, fmt.Sprintf("error when running command at %s:%d", filename, line))
	}
	return out
}

// RunInteractiveShellRelaxed feeds script to `bash -i` on stdin and returns its
// stdout. If bash exits with an error, it returns "" and an error that includes
// the captured stdout and stderr. Regardless of exit status, it fails the test
// if the output contains "hishtory fatal error".
func (b bashTester) RunInteractiveShellRelaxed(t testing.TB, script string) (string, error) {
	cmd := exec.Command("bash", "-i")
	cmd.Stdin = strings.NewReader(script)
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	err := cmd.Run()
	if err != nil {
		return "", fmt.Errorf("unexpected error when running commands, out=%#v, err=%#v: %w", stdout.String(), stderr.String(), err)
	}
	outStr := stdout.String()
	require.NotContains(t, outStr, "hishtory fatal error", "Ran command, but hishtory had a fatal error!")
	return outStr, nil
}

// RunInteractiveShellBackground starts `bash -i` with script on stdin, with
// output discarded, and returns without waiting for the process.
func (b bashTester) RunInteractiveShellBackground(t testing.TB, script string) error {
	cmd := exec.Command("bash", "-i")
	// SetSid: true is required to prevent SIGTTIN signal killing the entire test
	cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
	cmd.Stdin = strings.NewReader(script)
	cmd.Stdout = nil
	cmd.Stderr = nil
	return cmd.Start()
}

// ShellName returns "bash".
func (b bashTester) ShellName() string {
	return "bash"
}

// zshTester runs scripts by piping them into `zsh -is`.
type zshTester struct{}

// RunInteractiveShell runs script under `set -eo pipefail` and returns its
// stdout. On failure it fails the test, reporting the caller's file and line
// (consistent with bashTester).
func (z zshTester) RunInteractiveShell(t testing.TB, script string) string {
	out, err := z.RunInteractiveShellRelaxed(t, "set -eo pipefail\n"+script)
	if err != nil {
		_, filename, line, _ := runtime.Caller(1)
		require.NoError(t, err, fmt.Sprintf("error when running command at %s:%d", filename, line))
	}
	return out
}

// RunInteractiveShellRelaxed feeds script to `zsh -is` on stdin. Unlike the
// bash variant, on error it returns the partial stdout alongside an error that
// also includes the script and stderr. On success it fails the test if the
// output contains "hishtory fatal error".
func (z zshTester) RunInteractiveShellRelaxed(t testing.TB, script string) (string, error) {
	cmd := exec.Command("zsh", "-is")
	cmd.Stdin = strings.NewReader(script)
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	err := cmd.Run()
	if err != nil {
		return stdout.String(), fmt.Errorf("unexpected error when running command=%#v, out=%#v, err=%#v: %w", script, stdout.String(), stderr.String(), err)
	}
	outStr := stdout.String()
	require.NotContains(t, outStr, "hishtory fatal error")
	return outStr, nil
}

// RunInteractiveShellBackground starts `zsh -is` with script on stdin, with
// output discarded, and returns without waiting for the process.
func (z zshTester) RunInteractiveShellBackground(t testing.TB, script string) error {
	cmd := exec.Command("zsh", "-is")
	cmd.Stdin = strings.NewReader(script)
	cmd.Stdout = nil
	cmd.Stderr = nil
	return cmd.Start()
}

// ShellName returns "zsh".
func (z zshTester) ShellName() string {
	return "zsh"
}

// OnlineStatus is the expected value of the client config's offline flag.
type OnlineStatus int64

const (
	Online OnlineStatus = iota
	Offline
)

// assertOnlineStatus fails the test if the current client config's IsOffline
// flag disagrees with onlineStatus.
func assertOnlineStatus(t testing.TB, onlineStatus OnlineStatus) {
	config := hctx.GetConf(hctx.MakeContext())
	if onlineStatus == Online && config.IsOffline {
		t.Fatalf("We're supposed to be online, yet config.IsOffline=%#v (config=%#v)", config.IsOffline, config)
	}
	if onlineStatus == Offline && !config.IsOffline {
		t.Fatalf("We're supposed to be offline, yet config.IsOffline=%#v (config=%#v)", config.IsOffline, config)
	}
}

// hishtoryQuery runs `hishtory query <query>` in the given shell and returns
// its output. query is interpolated into the shell script verbatim.
func hishtoryQuery(t testing.TB, tester shellTester, query string) string {
	return tester.RunInteractiveShell(t, "hishtory query "+query)
}

// manuallySubmitHistoryEntry encrypts entry with userSecret and POSTs it to the
// local test backend's submit endpoint, failing the test on any error or
// non-200 response.
func manuallySubmitHistoryEntry(t testing.TB, userSecret string, entry data.HistoryEntry) {
	encEntry, err := data.EncryptHistoryEntry(userSecret, entry)
	require.NoError(t, err)
	// Compare instants with Equal rather than != so that differences in
	// monotonic readings or Location pointers can't cause a spurious mismatch.
	if !encEntry.Date.Equal(entry.EndTime) {
		t.Fatalf("encEntry.Date does not match the entry")
	}
	jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
	require.NoError(t, err)
	require.NotEqual(t, "", entry.DeviceId)
	resp, err := http.Post("http://localhost:8080/api/v1/submit?source_device_id="+entry.DeviceId, "application/json", bytes.NewBuffer(jsonValue))
	require.NoError(t, err)
	// Defer the close before any t.Fatalf below so the body is closed on every path.
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Fatalf("failed to submit result to backend, status_code=%d", resp.StatusCode)
	}
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read resp.Body: %v", err)
	}
	submitResp := shared.SubmitResponse{}
	err = json.Unmarshal(respBody, &submitResp)
	if err != nil {
		t.Fatalf("failed to deserialize SubmitResponse: %v", err)
	}
}

// captureTerminalOutput runs each string in commands as tmux keystrokes in a
// 200x50 session of the tester's shell and returns the captured pane.
func captureTerminalOutput(t testing.TB, tester shellTester, commands []string) string {
	return captureTerminalOutputWithShellName(t, tester, tester.ShellName(), commands)
}

// captureTerminalOutputWithComplexCommands is like captureTerminalOutput but
// accepts TmuxCommands, which can also resize the window or add delays.
func captureTerminalOutputWithComplexCommands(t testing.TB, tester shellTester, commands []TmuxCommand) string {
	return captureTerminalOutputWithShellNameAndDimensions(t, tester, tester.ShellName(), 200, 50, commands)
}

// TmuxCommand is one step of a tmux-driven interaction.
type TmuxCommand struct {
	// Keys are passed to `tmux send -t foo --` when non-empty.
	Keys string
	// ResizeX/ResizeY resize the window when both are non-zero.
	ResizeX int
	ResizeY int
	// ExtraDelay is an additional sleep, in seconds, after this step when non-zero.
	ExtraDelay float64
}

// captureTerminalOutputWithShellName is like captureTerminalOutput but launches
// overriddenShellName inside tmux instead of the tester's own shell.
func captureTerminalOutputWithShellName(t testing.TB, tester shellTester, overriddenShellName string, commands []string) string {
	sCommands := make([]TmuxCommand, 0, len(commands))
	for _, command := range commands {
		sCommands = append(sCommands, TmuxCommand{Keys: command})
	}
	return captureTerminalOutputWithShellNameAndDimensions(t, tester, overriddenShellName, 200, 50, sCommands)
}

// captureTerminalOutputWithShellNameAndDimensions runs commands in a
// width x height tmux session of overriddenShellName and returns the pane.
func captureTerminalOutputWithShellNameAndDimensions(t testing.TB, tester shellTester, overriddenShellName string, width, height int, commands []TmuxCommand) string {
	return captureTerminalOutputComplex(t, TmuxCaptureConfig{
		tester:              tester,
		overriddenShellName: overriddenShellName,
		width:               width,
		height:              height,
		complexCommands:     commands,
	})
}

// TmuxCaptureConfig configures captureTerminalOutputComplex. Exactly one of
// commands or complexCommands may be set.
type TmuxCaptureConfig struct {
	tester                 shellTester
	overriddenShellName    string
	commands               []string
	complexCommands        []TmuxCommand
	width, height          int
	includeEscapeSequences bool
}

// buildTmuxInputCommands returns a shell script that starts a tmux session
// named "foo", replays the configured commands into it with sleeps between
// steps, and then waits for output to settle. It does not capture the pane or
// kill the session afterwards. Empty shell name, width and height default to
// the tester's shell, 200 and 50 respectively.
func buildTmuxInputCommands(t testing.TB, captureConfig TmuxCaptureConfig) string {
	if captureConfig.overriddenShellName == "" {
		captureConfig.overriddenShellName = captureConfig.tester.ShellName()
	}
	if captureConfig.width == 0 {
		captureConfig.width = 200
	}
	if captureConfig.height == 0 {
		captureConfig.height = 50
	}
	sleepAmount := "0.1"
	if runtime.GOOS == "linux" {
		sleepAmount = "0.2"
	}
	if captureConfig.overriddenShellName == "fish" {
		// Fish is considerably slower so this is sadly necessary
		sleepAmount = "0.5"
	}
	if testutils.IsGithubAction() {
		sleepAmount = "0.5"
	}
	fullCommand := ""
	fullCommand += " tmux kill-session -t foo || true\n"
	fullCommand += fmt.Sprintf(" tmux -u new-session -d -x %d -y %d -s foo %s\n", captureConfig.width, captureConfig.height, captureConfig.overriddenShellName)
	fullCommand += " sleep 1\n"
	if captureConfig.overriddenShellName == "bash" {
		fullCommand += " tmux send -t foo SPACE source SPACE ~/.bashrc ENTER\n"
	}
	fullCommand += " sleep " + sleepAmount + "\n"
	if len(captureConfig.commands) > 0 {
		require.Empty(t, captureConfig.complexCommands)
		for _, command := range captureConfig.commands {
			captureConfig.complexCommands = append(captureConfig.complexCommands, TmuxCommand{Keys: command})
		}
	}
	require.NotEmpty(t, captureConfig.complexCommands)
	for _, cmd := range captureConfig.complexCommands {
		if cmd.Keys != "" {
			fullCommand += " tmux send -t foo -- "
			fullCommand += cmd.Keys
			fullCommand += "\n"
		}
		if cmd.ResizeX != 0 && cmd.ResizeY != 0 {
			fullCommand += fmt.Sprintf(" tmux resize-window -t foo -x %d -y %d\n", cmd.ResizeX, cmd.ResizeY)
		}
		if cmd.ExtraDelay != 0 {
			fullCommand += fmt.Sprintf(" sleep %f\n", cmd.ExtraDelay)
		}
		fullCommand += " sleep " + sleepAmount + "\n"
	}
	fullCommand += " sleep 2.5\n"
	if testutils.IsGithubAction() {
		fullCommand += " sleep 2.5\n"
	}
	return fullCommand
}

// captureTerminalOutputComplex runs the tmux script from buildTmuxInputCommands,
// captures the pane (with escape sequences if requested), kills the session and
// returns the whitespace-trimmed capture.
func captureTerminalOutputComplex(t testing.TB, captureConfig TmuxCaptureConfig) string {
	require.NotNil(t, captureConfig.tester)
	fullCommand := ""
	fullCommand += buildTmuxInputCommands(t, captureConfig)
	fullCommand += " tmux capture-pane -t foo -p"
	if captureConfig.includeEscapeSequences {
		// -e ensures that tmux runs the command in an environment that supports escape sequences. Used for rendering colors in the TUI.
		fullCommand += "e"
	}
	fullCommand += "\n"
	fullCommand += " tmux kill-session -t foo\n"
	testutils.TestLog(t, "Running tmux command: "+fullCommand)
	return strings.TrimSpace(captureConfig.tester.RunInteractiveShell(t, fullCommand))
}

// assertNoLeakedConnections fails the test if the backend reports more than
// one open DB connection via /api/v1/get-num-connections.
func assertNoLeakedConnections(t testing.TB) {
	resp, err := lib.ApiGet(makeTestOnlyContextWithFakeConfig(), "/api/v1/get-num-connections")
	require.NoError(t, err)
	numConnections, err := strconv.Atoi(string(resp))
	require.NoError(t, err)
	if numConnections > 1 {
		t.Fatalf("DB has %d open connections, expected to have 1 or less", numConnections)
	}
}

// getPidofCommand returns the binary used to look up PIDs by process name.
func getPidofCommand() string {
	if runtime.GOOS == "darwin" {
		// MacOS doesn't have pidof by default
		return "pgrep"
	}
	return "pidof"
}

// makeTestOnlyContextWithFakeConfig returns a context carrying a fake client
// config and the user's home directory. It panics if the home directory
// cannot be determined.
func makeTestOnlyContextWithFakeConfig() context.Context {
	fakeConfig := hctx.ClientConfig{
		UserSecret: "FAKE_TEST_DEVICE",
		DeviceId:   "FAKE_TEST_DEVICE",
	}
	ctx := context.Background()
	ctx = context.WithValue(ctx, hctx.ConfigCtxKey, &fakeConfig)
	// Note: We don't create a DB here
	homedir, err := os.UserHomeDir()
	if err != nil {
		panic(fmt.Errorf("failed to get homedir: %w", err))
	}
	return context.WithValue(ctx, hctx.HomedirCtxKey, homedir)
}

// deviceSet tracks the simulated devices in a multi-device test and which one
// is currently active.
type deviceSet struct {
	deviceMap     *map[device]deviceOp
	currentDevice *device
}

// device identifies a simulated device by its user key and device ID.
type device struct {
	key      string
	deviceId string
}

// deviceOp holds the callbacks used to save and reload a device's state.
type deviceOp struct {
	backup  func()
	restore func()
}

// createDevice installs hishtory with key and registers backup/restore
// callbacks for (key, deviceId) in devices. It fails the test if that device
// already exists. Note that restore is obtained by calling
// testutils.BackupAndRestoreWithId immediately, at creation time.
func createDevice(t testing.TB, tester shellTester, devices *deviceSet, key, deviceId string) {
	d := device{key, deviceId}
	_, ok := (*devices.deviceMap)[d]
	if ok {
		t.Fatalf("cannot create device twice for key=%s deviceId=%s", key, deviceId)
	}
	installHishtory(t, tester, key)
	(*devices.deviceMap)[d] = deviceOp{
		backup:  func() { testutils.BackupAndRestoreWithId(t, key+deviceId) },
		restore: testutils.BackupAndRestoreWithId(t, key+deviceId),
	}
}

// switchToDevice backs up the current device (if any) and restores d. It is a
// no-op if d is already current.
func switchToDevice(devices *deviceSet, d device) {
	if devices.currentDevice != nil && d == *devices.currentDevice {
		return
	}
	if devices.currentDevice != nil {
		(*devices.deviceMap)[*devices.currentDevice].backup()
	}
	devices.currentDevice = &d
	(*devices.deviceMap)[d].restore()
}

// installHishtory runs `/tmp/client install <userSecret>` and returns the
// secret key reported in the install output.
func installHishtory(t testing.TB, tester shellTester, userSecret string) string {
	out := tester.RunInteractiveShell(t, ` /tmp/client install `+userSecret)
	r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
	matches := r.FindStringSubmatch(out)
	if len(matches) != 2 {
		t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
	}
	return matches[1]
}

// stripShellPrefix returns the trimmed text between the first and second
// "\n\n\n" separators (or to the end if there is only one). Output without a
// separator is returned unchanged.
func stripShellPrefix(out string) string {
	if strings.Contains(out, "\n\n\n") {
		return strings.TrimSpace(strings.Split(out, "\n\n\n")[1])
	}
	return out
}

// stripRequiredPrefix requires out to contain prefix and returns the trimmed
// text between the first and second occurrences of prefix.
func stripRequiredPrefix(t *testing.T, out, prefix string) string {
	require.Contains(t, out, prefix)
	return strings.TrimSpace(strings.Split(out, prefix)[1])
}

// stripTuiCommandPrefix strips everything up to the "hishtory tquery" invocation.
func stripTuiCommandPrefix(t *testing.T, out string) string {
	return stripRequiredPrefix(t, out, "hishtory tquery")
}

// Wrap the given test so that it can be run on Github Actions with sharding. This
// makes it possible to run only 1/N tests on each of N github action jobs, speeding
// up test execution through parallelization. This is necessary since the wrapped
// integration tests rely on OS-level globals (the shell history) that can't otherwise
// be parallelized.
//
// Each call assigns the next shard number to the test at wrap time, so the
// assignment does not depend on when the returned function is eventually run.
func wrapTestForSharding(test func(t *testing.T)) func(t *testing.T) {
	shardNumberAllocator += 1
	// Capture now: reading the global inside the closure would give every test
	// wrapped before it runs the same (final) shard number.
	testShardNumber := shardNumberAllocator
	return func(t *testing.T) {
		markTestForSharding(t, testShardNumber)
		test(t)
	}
}

// shardNumberAllocator counts calls to wrapTestForSharding; its value after
// increment is the shard number given to the wrapped test.
var shardNumberAllocator int

// Returns whether this is a sharded test run. false during all normal non-github action operations.
func isShardedTestRun() bool {
	return numTestShards() != -1 && currentShardNumber() != -1
}

// numTestShards returns the total number of test shards from NUM_TEST_SHARDS,
// or -1 if the variable is unset. It panics if the value is not an integer, or
// is not positive. A zero value would otherwise surface as an opaque
// integer-divide-by-zero panic in markTestForSharding.
func numTestShards() int {
	numTestShardsStr := os.Getenv("NUM_TEST_SHARDS")
	if numTestShardsStr == "" {
		return -1
	}
	numTestShards, err := strconv.Atoi(numTestShardsStr)
	if err != nil {
		panic(fmt.Errorf("failed to parse NUM_TEST_SHARDS: %w", err))
	}
	if numTestShards < 1 {
		panic(fmt.Errorf("NUM_TEST_SHARDS must be a positive integer, got %d", numTestShards))
	}
	return numTestShards
}

// currentShardNumber returns the zero-based shard index from CURRENT_SHARD_NUM,
// or -1 if the variable is unset. It panics if the value is not an integer.
func currentShardNumber() int {
	currentShardNumberStr := os.Getenv("CURRENT_SHARD_NUM")
	if currentShardNumberStr == "" {
		return -1
	}
	currentShardNumber, err := strconv.Atoi(currentShardNumberStr)
	if err != nil {
		panic(fmt.Errorf("failed to parse CURRENT_SHARD_NUM: %w", err))
	}
	return currentShardNumber
}

// markTestForSharding skips t during a sharded run unless testShardNumber maps
// to the current shard (testShardNumber % NUM_TEST_SHARDS == CURRENT_SHARD_NUM).
// It fails t if CURRENT_SHARD_NUM is outside [0, NUM_TEST_SHARDS). Such a
// shard would otherwise silently skip every test and report success.
func markTestForSharding(t *testing.T, testShardNumber int) {
	if !isShardedTestRun() {
		return
	}
	numShards := numTestShards()
	currentShard := currentShardNumber()
	if currentShard < 0 || currentShard >= numShards {
		t.Fatalf("CURRENT_SHARD_NUM=%d is out of range for NUM_TEST_SHARDS=%d (expected 0 <= CURRENT_SHARD_NUM < NUM_TEST_SHARDS)", currentShard, numShards)
	}
	if testShardNumber%numShards != currentShard {
		t.Skip("Skipping sharded test")
	}
}