diff --git a/client/client_test.go b/client/client_test.go index dbf2229..31130a6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -2148,6 +2148,34 @@ func TestTimestampFormat(t *testing.T) { compareGoldens(t, out, goldenName) } +func TestZDotDir(t *testing.T) { + // Setup + tester := zshTester{} + defer testutils.BackupAndRestore(t)() + defer testutils.BackupAndRestoreEnv("ZDOTDIR")() + homedir, err := os.UserHomeDir() + if err != nil { + t.Fatalf("failed to get homedir: %v", err) + } + os.Setenv("ZDOTDIR", path.Join(homedir, data.HISHTORY_PATH)) + installHishtory(t, tester, "") + + // Run a command and check that it was recorded + tester.RunInteractiveShell(t, `echo foo`) + out := tester.RunInteractiveShell(t, `hishtory export -pipefail -install`) + if out != "echo foo\n" { + t.Fatalf("hishtory export had unexpected out=%#v", out) + } + + // Check that hishtory respected ZDOTDIR + zshrc, err := os.ReadFile(path.Join(homedir, data.HISHTORY_PATH, ".zshrc")) + zshrc = []byte(tester.RunInteractiveShell(t, `cat ~/.hishtory/.zshrc`)) + testutils.Check(t, err) + if string(zshrc) != "\n# Hishtory Config:\nexport PATH=\"$PATH:/Users/david/.hishtory\"\nsource /Users/david/.hishtory/config.zsh\n" { + t.Fatalf("zshrc had unexpected contents=%#v", string(zshrc)) + } +} + func TestRemoveDuplicateRows(t *testing.T) { // Setup tester := zshTester{} diff --git a/client/lib/lib.go b/client/lib/lib.go index 0d79fbe..64b9fb2 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -791,13 +791,20 @@ func configureZshrc(homedir, binaryPath string) error { // Check if we need to configure the zshrc zshIsConfigured, err := isZshConfigured(homedir) if err != nil { - return fmt.Errorf("failed to check ~/.zshrc: %v", err) + return fmt.Errorf("failed to check .zshrc: %v", err) } if zshIsConfigured { return nil } // Add to zshrc - return addToShellConfig(path.Join(homedir, ".zshrc"), getZshConfigFragment(homedir)) + return addToShellConfig(getZshRcPath(homedir), getZshConfigFragment(homedir)) +} + +func getZshRcPath(homedir string) string { + if zdotdir := os.Getenv("ZDOTDIR"); zdotdir != "" { + return path.Join(zdotdir, ".zshrc") + } + return path.Join(homedir, ".zshrc") } func getZshConfigFragment(homedir string) string { @@ -805,11 +812,11 @@ func getZshConfigFragment(homedir string) string { } func isZshConfigured(homedir string) (bool, error) { - _, err := os.Stat(path.Join(homedir, ".zshrc")) + _, err := os.Stat(getZshRcPath(homedir)) if errors.Is(err, os.ErrNotExist) { return false, nil } - bashrc, err := ioutil.ReadFile(path.Join(homedir, ".zshrc")) + bashrc, err := ioutil.ReadFile(getZshRcPath(homedir)) if err != nil { return false, fmt.Errorf("failed to read zshrc: %v", err) } @@ -1679,7 +1686,7 @@ func Uninstall(ctx *context.Context) error { if err != nil { return err } - err = stripLines(path.Join(homedir, ".zshrc"), getZshConfigFragment(homedir)) + err = stripLines(getZshRcPath(homedir), getZshConfigFragment(homedir)) if err != nil { return err }