mirror of
https://github.com/ddworken/hishtory.git
synced 2025-01-09 07:48:19 +01:00
425 lines
13 KiB
Go
425 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"regexp"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"testing"
|
|
|
|
"github.com/ddworken/hishtory/client/data"
|
|
"github.com/ddworken/hishtory/client/hctx"
|
|
"github.com/ddworken/hishtory/client/lib"
|
|
"github.com/ddworken/hishtory/shared"
|
|
"github.com/ddworken/hishtory/shared/testutils"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// shellTester abstracts over the shells the integration tests drive, so the same test
// body can be run against each supported shell implementation.
type shellTester interface {
	// ShellName returns the name of the shell this tester drives (e.g. "bash").
	ShellName() string
	// RunInteractiveShell runs script in an interactive shell and returns its stdout,
	// failing the test if the script errors.
	RunInteractiveShell(t testing.TB, script string) string
	// RunInteractiveShellRelaxed is like RunInteractiveShell but hands the error back to
	// the caller instead of failing the test.
	RunInteractiveShellRelaxed(t testing.TB, script string) (string, error)
	// RunInteractiveShellBackground starts script in an interactive shell without
	// waiting for it; only an error from starting the shell is returned.
	RunInteractiveShellBackground(t testing.TB, script string) error
}
|
|
type bashTester struct{}
|
|
|
|
func (b bashTester) RunInteractiveShell(t testing.TB, script string) string {
|
|
out, err := b.RunInteractiveShellRelaxed(t, "set -emo pipefail\n"+script)
|
|
if err != nil {
|
|
_, filename, line, _ := runtime.Caller(1)
|
|
require.NoError(t, err, fmt.Sprintf("error when running command at %s:%dv", filename, line))
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (b bashTester) RunInteractiveShellRelaxed(t testing.TB, script string) (string, error) {
|
|
cmd := exec.Command("bash", "-i")
|
|
cmd.Stdin = strings.NewReader(script)
|
|
var stdout bytes.Buffer
|
|
cmd.Stdout = &stdout
|
|
var stderr bytes.Buffer
|
|
cmd.Stderr = &stderr
|
|
err := cmd.Run()
|
|
if err != nil {
|
|
return "", fmt.Errorf("unexpected error when running commands, out=%#v, err=%#v: %w", stdout.String(), stderr.String(), err)
|
|
}
|
|
outStr := stdout.String()
|
|
require.NotContains(t, outStr, "hishtory fatal error", "Ran command, but hishtory had a fatal error!")
|
|
return outStr, nil
|
|
}
|
|
|
|
func (b bashTester) RunInteractiveShellBackground(t testing.TB, script string) error {
|
|
cmd := exec.Command("bash", "-i")
|
|
// SetSid: true is required to prevent SIGTTIN signal killing the entire test
|
|
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
|
|
cmd.Stdin = strings.NewReader(script)
|
|
cmd.Stdout = nil
|
|
cmd.Stderr = nil
|
|
return cmd.Start()
|
|
}
|
|
|
|
func (b bashTester) ShellName() string {
|
|
return "bash"
|
|
}
|
|
|
|
type zshTester struct{}
|
|
|
|
func (z zshTester) RunInteractiveShell(t testing.TB, script string) string {
|
|
res, err := z.RunInteractiveShellRelaxed(t, "set -eo pipefail\n"+script)
|
|
require.NoError(t, err)
|
|
return res
|
|
}
|
|
|
|
func (z zshTester) RunInteractiveShellRelaxed(t testing.TB, script string) (string, error) {
|
|
cmd := exec.Command("zsh", "-is")
|
|
cmd.Stdin = strings.NewReader(script)
|
|
var stdout bytes.Buffer
|
|
cmd.Stdout = &stdout
|
|
var stderr bytes.Buffer
|
|
cmd.Stderr = &stderr
|
|
err := cmd.Run()
|
|
if err != nil {
|
|
return stdout.String(), fmt.Errorf("unexpected error when running command=%#v, out=%#v, err=%#v: %w", script, stdout.String(), stderr.String(), err)
|
|
}
|
|
outStr := stdout.String()
|
|
require.NotContains(t, outStr, "hishtory fatal error")
|
|
return outStr, nil
|
|
}
|
|
|
|
func (z zshTester) RunInteractiveShellBackground(t testing.TB, script string) error {
|
|
cmd := exec.Command("zsh", "-is")
|
|
cmd.Stdin = strings.NewReader(script)
|
|
cmd.Stdout = nil
|
|
cmd.Stderr = nil
|
|
return cmd.Start()
|
|
}
|
|
|
|
func (z zshTester) ShellName() string {
|
|
return "zsh"
|
|
}
|
|
|
|
// OnlineStatus is the connectivity state a test expects the hishtory client config to report.
type OnlineStatus int64

// The OnlineStatus values, numbered explicitly.
const (
	Online  OnlineStatus = 0
	Offline OnlineStatus = 1
)
|
|
|
|
func assertOnlineStatus(t testing.TB, onlineStatus OnlineStatus) {
|
|
config := hctx.GetConf(hctx.MakeContext())
|
|
if onlineStatus == Online && config.IsOffline {
|
|
t.Fatalf("We're supposed to be online, yet config.IsOffline=%#v (config=%#v)", config.IsOffline, config)
|
|
}
|
|
if onlineStatus == Offline && !config.IsOffline {
|
|
t.Fatalf("We're supposed to be offline, yet config.IsOffline=%#v (config=%#v)", config.IsOffline, config)
|
|
}
|
|
}
|
|
|
|
func hishtoryQuery(t testing.TB, tester shellTester, query string) string {
|
|
return tester.RunInteractiveShell(t, "hishtory query "+query)
|
|
}
|
|
|
|
func manuallySubmitHistoryEntry(t testing.TB, userSecret string, entry data.HistoryEntry) {
|
|
encEntry, err := data.EncryptHistoryEntry(userSecret, entry)
|
|
require.NoError(t, err)
|
|
if encEntry.Date != entry.EndTime {
|
|
t.Fatalf("encEntry.Date does not match the entry")
|
|
}
|
|
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, "", entry.DeviceId)
|
|
resp, err := http.Post("http://localhost:8080/api/v1/submit?source_device_id="+entry.DeviceId, "application/json", bytes.NewBuffer(jsonValue))
|
|
require.NoError(t, err)
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("failed to submit result to backend, status_code=%d", resp.StatusCode)
|
|
}
|
|
defer resp.Body.Close()
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("failed to read resp.Body: %v", err)
|
|
}
|
|
submitResp := shared.SubmitResponse{}
|
|
err = json.Unmarshal(respBody, &submitResp)
|
|
if err != nil {
|
|
t.Fatalf("failed to deserialize SubmitResponse: %v", err)
|
|
}
|
|
}
|
|
|
|
func captureTerminalOutput(t testing.TB, tester shellTester, commands []string) string {
|
|
return captureTerminalOutputWithShellName(t, tester, tester.ShellName(), commands)
|
|
}
|
|
|
|
func captureTerminalOutputWithComplexCommands(t testing.TB, tester shellTester, commands []TmuxCommand) string {
|
|
return captureTerminalOutputWithShellNameAndDimensions(t, tester, tester.ShellName(), 200, 50, commands)
|
|
}
|
|
|
|
// TmuxCommand is a single step of a scripted tmux session.
type TmuxCommand struct {
	// Keys is appended verbatim to a `tmux send -t foo --` shell line; empty sends nothing.
	Keys string
	// ResizeX and ResizeY resize the window; the resize happens only when both are non-zero.
	ResizeX, ResizeY int
	// ExtraDelay is an additional pause, in seconds, after this step (0 means none).
	ExtraDelay float64
}
|
|
|
|
func captureTerminalOutputWithShellName(t testing.TB, tester shellTester, overriddenShellName string, commands []string) string {
|
|
sCommands := make([]TmuxCommand, 0)
|
|
for _, command := range commands {
|
|
sCommands = append(sCommands, TmuxCommand{Keys: command})
|
|
}
|
|
return captureTerminalOutputWithShellNameAndDimensions(t, tester, overriddenShellName, 200, 50, sCommands)
|
|
}
|
|
|
|
func captureTerminalOutputWithShellNameAndDimensions(t testing.TB, tester shellTester, overriddenShellName string, width, height int, commands []TmuxCommand) string {
|
|
return captureTerminalOutputComplex(t,
|
|
TmuxCaptureConfig{
|
|
tester: tester,
|
|
overriddenShellName: overriddenShellName,
|
|
width: width,
|
|
height: height,
|
|
complexCommands: commands,
|
|
})
|
|
}
|
|
|
|
type TmuxCaptureConfig struct {
|
|
tester shellTester
|
|
overriddenShellName string
|
|
commands []string
|
|
complexCommands []TmuxCommand
|
|
width, height int
|
|
includeEscapeSequences bool
|
|
}
|
|
|
|
func buildTmuxInputCommands(t testing.TB, captureConfig TmuxCaptureConfig) string {
|
|
if captureConfig.overriddenShellName == "" {
|
|
captureConfig.overriddenShellName = captureConfig.tester.ShellName()
|
|
}
|
|
if captureConfig.width == 0 {
|
|
captureConfig.width = 200
|
|
}
|
|
if captureConfig.height == 0 {
|
|
captureConfig.height = 50
|
|
}
|
|
sleepAmount := "0.1"
|
|
if runtime.GOOS == "linux" {
|
|
sleepAmount = "0.2"
|
|
}
|
|
if captureConfig.overriddenShellName == "fish" {
|
|
// Fish is considerably slower so this is sadly necessary
|
|
sleepAmount = "0.5"
|
|
}
|
|
if testutils.IsGithubAction() {
|
|
sleepAmount = "0.5"
|
|
}
|
|
fullCommand := ""
|
|
fullCommand += " tmux kill-session -t foo || true\n"
|
|
fullCommand += fmt.Sprintf(" tmux -u new-session -d -x %d -y %d -s foo %s\n", captureConfig.width, captureConfig.height, captureConfig.overriddenShellName)
|
|
fullCommand += " sleep 1\n"
|
|
if captureConfig.overriddenShellName == "bash" {
|
|
fullCommand += " tmux send -t foo SPACE source SPACE ~/.bashrc ENTER\n"
|
|
}
|
|
fullCommand += " sleep " + sleepAmount + "\n"
|
|
if len(captureConfig.commands) > 0 {
|
|
require.Empty(t, captureConfig.complexCommands)
|
|
for _, command := range captureConfig.commands {
|
|
captureConfig.complexCommands = append(captureConfig.complexCommands, TmuxCommand{Keys: command})
|
|
}
|
|
}
|
|
require.NotEmpty(t, captureConfig.complexCommands)
|
|
for _, cmd := range captureConfig.complexCommands {
|
|
if cmd.Keys != "" {
|
|
fullCommand += " tmux send -t foo -- "
|
|
fullCommand += cmd.Keys
|
|
fullCommand += "\n"
|
|
}
|
|
if cmd.ResizeX != 0 && cmd.ResizeY != 0 {
|
|
fullCommand += fmt.Sprintf(" tmux resize-window -t foo -x %d -y %d\n", cmd.ResizeX, cmd.ResizeY)
|
|
}
|
|
if cmd.ExtraDelay != 0 {
|
|
fullCommand += fmt.Sprintf(" sleep %f\n", cmd.ExtraDelay)
|
|
}
|
|
fullCommand += " sleep " + sleepAmount + "\n"
|
|
}
|
|
fullCommand += " sleep 2.5\n"
|
|
if testutils.IsGithubAction() {
|
|
fullCommand += " sleep 2.5\n"
|
|
}
|
|
return fullCommand
|
|
}
|
|
|
|
func captureTerminalOutputComplex(t testing.TB, captureConfig TmuxCaptureConfig) string {
|
|
require.NotNil(t, captureConfig.tester)
|
|
fullCommand := ""
|
|
fullCommand += buildTmuxInputCommands(t, captureConfig)
|
|
fullCommand += " tmux capture-pane -t foo -p"
|
|
if captureConfig.includeEscapeSequences {
|
|
// -e ensures that tmux runs the command in an environment that supports escape sequences. Used for rendering colors in the TUI.
|
|
fullCommand += "e"
|
|
}
|
|
fullCommand += "\n"
|
|
fullCommand += " tmux kill-session -t foo\n"
|
|
testutils.TestLog(t, "Running tmux command: "+fullCommand)
|
|
return strings.TrimSpace(captureConfig.tester.RunInteractiveShell(t, fullCommand))
|
|
}
|
|
|
|
func assertNoLeakedConnections(t testing.TB) {
|
|
resp, err := lib.ApiGet(makeTestOnlyContextWithFakeConfig(), "/api/v1/get-num-connections")
|
|
require.NoError(t, err)
|
|
numConnections, err := strconv.Atoi(string(resp))
|
|
require.NoError(t, err)
|
|
if numConnections > 1 {
|
|
t.Fatalf("DB has %d open connections, expected to have 1 or less", numConnections)
|
|
}
|
|
}
|
|
|
|
// getPidofCommand returns the name of the command used to look up PIDs by process name
// on the current OS.
func getPidofCommand() string {
	switch runtime.GOOS {
	case "darwin":
		// MacOS doesn't have pidof by default
		return "pgrep"
	default:
		return "pidof"
	}
}
|
|
|
|
func makeTestOnlyContextWithFakeConfig() context.Context {
|
|
fakeConfig := hctx.ClientConfig{
|
|
UserSecret: "FAKE_TEST_DEVICE",
|
|
DeviceId: "FAKE_TEST_DEVICE",
|
|
}
|
|
ctx := context.Background()
|
|
ctx = context.WithValue(ctx, hctx.ConfigCtxKey, &fakeConfig)
|
|
// Note: We don't create a DB here
|
|
homedir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
panic(fmt.Errorf("failed to get homedir: %w", err))
|
|
}
|
|
return context.WithValue(ctx, hctx.HomedirCtxKey, homedir)
|
|
}
|
|
|
|
// deviceSet tracks the simulated devices created by createDevice and which one is
// currently active.
type deviceSet struct {
	// deviceMap holds the backup/restore hooks registered for each created device.
	deviceMap *map[device]deviceOp
	// currentDevice is the active device, or nil if none has been switched to yet.
	currentDevice *device
}

// device identifies a simulated device by the user secret it was installed with and
// its device ID.
type device struct {
	key, deviceId string
}

// deviceOp holds the hooks run when switching away from (backup) and onto (restore) a device.
type deviceOp struct {
	backup, restore func()
}
|
|
|
|
func createDevice(t testing.TB, tester shellTester, devices *deviceSet, key, deviceId string) {
|
|
d := device{key, deviceId}
|
|
_, ok := (*devices.deviceMap)[d]
|
|
if ok {
|
|
t.Fatalf("cannot create device twice for key=%s deviceId=%s", key, deviceId)
|
|
}
|
|
installHishtory(t, tester, key)
|
|
(*devices.deviceMap)[d] = deviceOp{
|
|
backup: func() { testutils.BackupAndRestoreWithId(t, key+deviceId) },
|
|
restore: testutils.BackupAndRestoreWithId(t, key+deviceId),
|
|
}
|
|
}
|
|
|
|
func switchToDevice(devices *deviceSet, d device) {
|
|
if devices.currentDevice != nil && d == *devices.currentDevice {
|
|
return
|
|
}
|
|
if devices.currentDevice != nil {
|
|
(*devices.deviceMap)[*devices.currentDevice].backup()
|
|
}
|
|
devices.currentDevice = &d
|
|
(*devices.deviceMap)[d].restore()
|
|
}
|
|
|
|
func installHishtory(t testing.TB, tester shellTester, userSecret string) string {
|
|
out := tester.RunInteractiveShell(t, ` /tmp/client install `+userSecret)
|
|
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
|
|
matches := r.FindStringSubmatch(out)
|
|
if len(matches) != 2 {
|
|
t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
|
|
}
|
|
return matches[1]
|
|
}
|
|
|
|
// stripShellPrefix returns the whitespace-trimmed text between the first and second
// occurrences of "\n\n\n" in out (or everything after the first, if there is no second).
// If out contains no "\n\n\n" at all, it is returned unchanged and untrimmed.
func stripShellPrefix(out string) string {
	const separator = "\n\n\n"
	parts := strings.SplitN(out, separator, 3)
	if len(parts) < 2 {
		return out
	}
	return strings.TrimSpace(parts[1])
}
|
|
|
|
func stripRequiredPrefix(t *testing.T, out, prefix string) string {
|
|
require.Contains(t, out, prefix)
|
|
return strings.TrimSpace(strings.Split(out, prefix)[1])
|
|
}
|
|
|
|
func stripTuiCommandPrefix(t *testing.T, out string) string {
|
|
return stripRequiredPrefix(t, out, "hishtory tquery")
|
|
}
|
|
|
|
// Wrap the given test so that it can be run on Github Actions with sharding. This
// makes it possible to run only 1/N tests on each of N github action jobs, speeding
// up test execution through parallelization. This is necessary since the wrapped
// integration tests rely on OS-level globals (the shell history) that can't otherwise
// be parallelized.
//
// Each call assigns the next shard number. The number is captured when the wrapper is
// created, not when it runs. A test therefore keeps its own shard number even if other
// tests are wrapped before it executes (e.g. when wrappers are collected up front and
// run later).
func wrapTestForSharding(test func(t *testing.T)) func(t *testing.T) {
	shardNumberAllocator += 1
	testShardNumber := shardNumberAllocator
	return func(t *testing.T) {
		markTestForSharding(t, testShardNumber)
		test(t)
	}
}

// shardNumberAllocator is the shard number most recently handed out by
// wrapTestForSharding. It is not synchronized, so wrapTestForSharding must not be called
// concurrently.
var shardNumberAllocator int

// Returns whether this is a sharded test run. false during all normal non-github action operations.
func isShardedTestRun() bool {
	return numTestShards() != -1 && currentShardNumber() != -1
}

// Get the total number of test shards from NUM_TEST_SHARDS. Returns -1 when the variable
// is unset or empty, and panics if it is not an integer.
func numTestShards() int {
	numTestShardsStr := os.Getenv("NUM_TEST_SHARDS")
	if numTestShardsStr == "" {
		return -1
	}
	numTestShards, err := strconv.Atoi(numTestShardsStr)
	if err != nil {
		panic(fmt.Errorf("failed to parse NUM_TEST_SHARDS: %v", err))
	}
	return numTestShards
}

// Get the current shard number from CURRENT_SHARD_NUM. Returns -1 when the variable is
// unset or empty, and panics if it is not an integer.
func currentShardNumber() int {
	currentShardNumberStr := os.Getenv("CURRENT_SHARD_NUM")
	if currentShardNumberStr == "" {
		return -1
	}
	currentShardNumber, err := strconv.Atoi(currentShardNumberStr)
	if err != nil {
		panic(fmt.Errorf("failed to parse CURRENT_SHARD_NUM: %v", err))
	}
	return currentShardNumber
}

// Mark the given test for sharding with the given test ID number. In a sharded run, the
// test is skipped unless testShardNumber modulo the shard count equals the current shard.
func markTestForSharding(t *testing.T, testShardNumber int) {
	if isShardedTestRun() {
		if testShardNumber%numTestShards() != currentShardNumber() {
			t.Skip("Skipping sharded test")
		}
	}
}
|