diff --git a/client/client_test.go b/client/client_test.go index daf93d9..3aab69e 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -10,11 +10,13 @@ import ( "path" "regexp" "runtime" + "strconv" "strings" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/google/uuid" "github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/lib" @@ -92,7 +94,7 @@ func (z zshTester) RunInteractiveShellRelaxed(t *testing.T, script string) (stri cmd.Stderr = &stderr err := cmd.Run() if err != nil { - return "", fmt.Errorf("unexpected error when running commands, out=%#v, err=%#v: %v", stdout.String(), stderr.String(), err) + return "", fmt.Errorf("unexpected error when running command=%#v, out=%#v, err=%#v: %v", script, stdout.String(), stderr.String(), err) } outStr := stdout.String() if strings.Contains(outStr, "hishtory fatal error") { @@ -127,6 +129,7 @@ func TestParameterized(t *testing.T) { t.Run("testHelpCommand/"+tester.ShellName(), func(t *testing.T) { testHelpCommand(t, tester) }) t.Run("testStripBashTimePrefix/"+tester.ShellName(), func(t *testing.T) { testStripBashTimePrefix(t, tester) }) t.Run("testReuploadHistoryEntries/"+tester.ShellName(), func(t *testing.T) { testReuploadHistoryEntries(t, tester) }) + t.Run("testInitialHistoryImport/"+tester.ShellName(), func(t *testing.T) { testInitialHistoryImport(t, tester) }) } } @@ -1269,4 +1272,47 @@ func testReuploadHistoryEntries(t *testing.T, tester shellTester) { } } +func testInitialHistoryImport(t *testing.T, tester shellTester) { + // Setup + defer shared.BackupAndRestore(t)() + + // Record some commands before installing hishtory + randomCmdUuid := uuid.Must(uuid.NewRandom()).String() + randomCmd := fmt.Sprintf(`echo %v-foo +echo %v-bar`, randomCmdUuid, randomCmdUuid) + tester.RunInteractiveShell(t, randomCmd) + + // Install hishtory + installHishtory(t, tester, "") + + // Check that hishtory export doesn't have the commands yet + out := tester.RunInteractiveShell(t, `hishtory export `+randomCmdUuid) + expectedOutput := "" + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } + + // Trigger an import + out = tester.RunInteractiveShell(t, "hishtory import") + r := regexp.MustCompile(`Imported (.+) history entries from your existing shell history`) + matches := r.FindStringSubmatch(out) + if len(matches) != 2 { + t.Fatalf("Failed to extract history entries count from output: matches=%#v, out=%#v", matches, out) + } + num, err := strconv.Atoi(matches[1]) + if err != nil { + t.Fatal(err) + } + if num <= 2 { + t.Fatalf("hishtory didn't import enough entries, only found %v entries", num) + } + + // Check that the previously recorded commands are in hishtory + out = tester.RunInteractiveShell(t, `hishtory export `+randomCmdUuid) + expectedOutput = fmt.Sprintf("hishtory export %s\necho %s-foo\necho %s-bar\nhishtory export %s\n", randomCmdUuid, randomCmdUuid, randomCmdUuid, randomCmdUuid) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } +} + // TODO: write a test that runs hishtroy export | grep -v pipefail and then see if that shows up in query/export, I think there is weird behavior here diff --git a/client/lib/lib.go b/client/lib/lib.go index 57e288d..a0587c2 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -464,39 +464,39 @@ func CheckFatalError(err error) { } } -func ImportHistory() error { +func ImportHistory() (int, error) { config, err := GetConfig() if err != nil { - return err + return 0, err } if config.HaveCompletedInitialImport { // Don't run an import if we already have run one. This avoids importing the same entry multiple times. - return nil + return 0, nil } homedir, err := os.UserHomeDir() if err != nil { - return fmt.Errorf("failed to get user's home directory: %v", err) + return 0, fmt.Errorf("failed to get user's home directory: %v", err) } historyEntries, err := parseBashHistory(homedir) if err != nil { - return fmt.Errorf("failed to parse bash history: %v", err) + return 0, fmt.Errorf("failed to parse bash history: %v", err) } extraEntries, err := parseZshHistory(homedir) if err != nil { - return fmt.Errorf("failed to parse zsh history: %v", err) + return 0, fmt.Errorf("failed to parse zsh history: %v", err) } historyEntries = append(historyEntries, extraEntries...) db, err := OpenLocalSqliteDb() if err != nil { - return nil + return 0, nil } currentUser, err := user.Current() if err != nil { - return err + return 0, err } hostname, err := os.Hostname() if err != nil { - return err + return 0, err } for _, cmd := range historyEntries { startTime := time.Now() @@ -513,15 +513,15 @@ func ImportHistory() error { DeviceId: config.DeviceId, }) if err != nil { - return err + return 0, err } } config.HaveCompletedInitialImport = true err = SetConfig(config) if err != nil { - return fmt.Errorf("failed to mark initial import as completed, this may lead to duplicate history entries: %v", err) + return 0, fmt.Errorf("failed to mark initial import as completed, this may lead to duplicate history entries: %v", err) } - return nil + return len(historyEntries), nil } func parseBashHistory(homedir string) ([]string, error) { diff --git a/hishtory.go b/hishtory.go index 941ceb0..95a47c4 100644 --- a/hishtory.go +++ b/hishtory.go @@ -35,13 +35,17 @@ func main() { case "install": lib.CheckFatalError(lib.Install()) if os.Getenv("HISHTORY_TEST") == "" { - lib.CheckFatalError(lib.ImportHistory()) + numImported, err := lib.ImportHistory() + lib.CheckFatalError(err) + fmt.Printf("Imported %v history entries from your existing shell history", numImported) } case "import": if os.Getenv("HISHTORY_TEST") == "" { lib.CheckFatalError(fmt.Errorf("the hishtory import command is only meant to be for testing purposes")) } - lib.CheckFatalError(lib.ImportHistory()) + numImported, err := lib.ImportHistory() + lib.CheckFatalError(err) + fmt.Printf("Imported %v history entries from your existing shell history", numImported) case "enable": lib.CheckFatalError(lib.Enable()) case "disable":