Fix ctx wiring so installs work properly

This commit is contained in:
David Dworken 2022-09-20 23:30:57 -07:00
parent 414d8926f6
commit ceb1becfa6
5 changed files with 78 additions and 39 deletions

View File

@ -250,7 +250,7 @@ func installHishtory(t *testing.T, tester shellTester, userSecret string) string
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`) r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
matches := r.FindStringSubmatch(out) matches := r.FindStringSubmatch(out)
if len(matches) != 2 { if len(matches) != 2 {
t.Fatalf("Failed to extract userSecret from output: matches=%#v", matches) t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
} }
return matches[1] return matches[1]
} }
@ -321,9 +321,9 @@ echo thisisrecorded`)
runtimeMatcher := `[0-9.ms]+` runtimeMatcher := `[0-9.ms]+`
exitCodeMatcher := `0` exitCodeMatcher := `0`
pipefailMatcher := `set -em?o pipefail` pipefailMatcher := `set -em?o pipefail`
line1Matcher := `Hostname` + tableDividerMatcher + `CWD` + tableDividerMatcher + `Timestamp` + tableDividerMatcher + `Runtime` + tableDividerMatcher + `Exit Code` + tableDividerMatcher + `Command` + tableDividerMatcher + `\n` line1Matcher := tableDividerMatcher + `Hostname` + tableDividerMatcher + `CWD` + tableDividerMatcher + `Timestamp` + tableDividerMatcher + `Runtime` + tableDividerMatcher + `Exit Code` + tableDividerMatcher + `Command` + tableDividerMatcher + `\n`
line2Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + pipefailMatcher + tableDividerMatcher + `\n` line2Matcher := tableDividerMatcher + hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + pipefailMatcher + tableDividerMatcher + `\n`
line3Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + `echo thisisrecorded` + tableDividerMatcher + `\n` line3Matcher := tableDividerMatcher + hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + `echo thisisrecorded` + tableDividerMatcher + `\n`
match, err := regexp.MatchString(line3Matcher, out) match, err := regexp.MatchString(line3Matcher, out)
shared.Check(t, err) shared.Check(t, err)
if !match { if !match {
@ -868,7 +868,7 @@ go build -o /tmp/client
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`) r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
matches := r.FindStringSubmatch(out) matches := r.FindStringSubmatch(out)
if len(matches) != 2 { if len(matches) != 2 {
t.Fatalf("Failed to extract userSecret from output: matches=%#v", matches) t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
} }
userSecret := matches[1] userSecret := matches[1]
@ -1078,7 +1078,7 @@ func testInstallViaPythonScript(t *testing.T, tester shellTester) {
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`) r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
matches := r.FindStringSubmatch(out) matches := r.FindStringSubmatch(out)
if len(matches) != 2 { if len(matches) != 2 {
t.Fatalf("Failed to extract userSecret from output: matches=%#v", matches) t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
} }
userSecret := matches[1] userSecret := matches[1]

View File

@ -3,6 +3,7 @@ package hctx
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -29,6 +30,10 @@ func GetLogger() *log.Logger {
if err != nil { if err != nil {
panic(fmt.Errorf("failed to get user's home directory: %v", err)) panic(fmt.Errorf("failed to get user's home directory: %v", err))
} }
err = MakeHishtoryDir()
if err != nil {
panic(err)
}
f, err := os.OpenFile(path.Join(homedir, shared.HISHTORY_PATH, "hishtory.log"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o660) f, err := os.OpenFile(path.Join(homedir, shared.HISHTORY_PATH, "hishtory.log"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o660)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to open hishtory.log: %v", err)) panic(fmt.Errorf("failed to open hishtory.log: %v", err))
@ -39,14 +44,26 @@ func GetLogger() *log.Logger {
return hishtoryLogger return hishtoryLogger
} }
func MakeHishtoryDir() error {
homedir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user's home directory: %v", err)
}
err = os.MkdirAll(path.Join(homedir, shared.HISHTORY_PATH), 0o744)
if err != nil {
return fmt.Errorf("failed to create ~/.hishtory dir: %v", err)
}
return nil
}
func OpenLocalSqliteDb() (*gorm.DB, error) { func OpenLocalSqliteDb() (*gorm.DB, error) {
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user's home directory: %v", err) return nil, fmt.Errorf("failed to get user's home directory: %v", err)
} }
err = os.MkdirAll(path.Join(homedir, shared.HISHTORY_PATH), 0o744) err = MakeHishtoryDir()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ~/.hishtory dir: %v", err) return nil, err
} }
hishtoryLogger := GetLogger() hishtoryLogger := GetLogger()
newLogger := logger.New( newLogger := logger.New(
@ -81,12 +98,12 @@ func MakeContext() *context.Context {
ctx := context.Background() ctx := context.Background()
config, err := GetConfig() config, err := GetConfig()
if err != nil { if err != nil {
GetLogger().Fatalf("failed to retrieve config: %v", err) panic(fmt.Errorf("failed to retrieve config: %v", err))
} }
ctx = context.WithValue(ctx, hishtoryContextKey("config"), config) ctx = context.WithValue(ctx, hishtoryContextKey("config"), config)
db, err := OpenLocalSqliteDb() db, err := OpenLocalSqliteDb()
if err != nil { if err != nil {
GetLogger().Fatalf("failed to open local DB: %v", err) panic(fmt.Errorf("failed to open local DB: %v", err))
} }
ctx = context.WithValue(ctx, hishtoryContextKey("db"), db) ctx = context.WithValue(ctx, hishtoryContextKey("db"), db)
return &ctx return &ctx
@ -97,8 +114,7 @@ func GetConf(ctx *context.Context) ClientConfig {
if v != nil { if v != nil {
return v.(ClientConfig) return v.(ClientConfig)
} }
GetLogger().Fatalf("failed to find config in ctx") panic(fmt.Errorf("failed to find config in ctx"))
return ClientConfig{}
} }
func GetDb(ctx *context.Context) *gorm.DB { func GetDb(ctx *context.Context) *gorm.DB {
@ -106,8 +122,7 @@ func GetDb(ctx *context.Context) *gorm.DB {
if v != nil { if v != nil {
return v.(*gorm.DB) return v.(*gorm.DB)
} }
GetLogger().Fatalf("failed to find db in ctx") panic(fmt.Errorf("failed to find db in ctx"))
return nil
} }
type ClientConfig struct { type ClientConfig struct {
@ -161,10 +176,9 @@ func SetConfig(config ClientConfig) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to retrieve homedir: %v", err) return fmt.Errorf("failed to retrieve homedir: %v", err)
} }
clientDir := path.Join(homedir, shared.HISHTORY_PATH) err = MakeHishtoryDir()
err = os.MkdirAll(clientDir, 0o744)
if err != nil { if err != nil {
return fmt.Errorf("failed to create ~/.hishtory/ folder: %v", err) return err
} }
err = os.WriteFile(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH), serializedConfig, 0o600) err = os.WriteFile(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH), serializedConfig, 0o600)
if err != nil { if err != nil {
@ -172,3 +186,17 @@ func SetConfig(config ClientConfig) error {
} }
return nil return nil
} }
func InitConfig() error {
homedir, err := os.UserHomeDir()
if err != nil {
return err
}
_, err = os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH))
if errors.Is(err, os.ErrNotExist) {
return SetConfig(ClientConfig{})
}
return err
}
// TODO: make homedir part of the context

View File

@ -274,7 +274,7 @@ func shouldSkipHiddenCommand(ctx *context.Context, historyLine string) (bool, er
return false, nil return false, nil
} }
func Setup(ctx *context.Context, args []string) error { func Setup(args []string) error {
userSecret := uuid.Must(uuid.NewRandom()).String() userSecret := uuid.Must(uuid.NewRandom()).String()
if len(args) > 2 && args[2] != "" { if len(args) > 2 && args[2] != "" {
userSecret = args[2] userSecret = args[2]
@ -292,7 +292,10 @@ func Setup(ctx *context.Context, args []string) error {
} }
// Drop all existing data // Drop all existing data
db := hctx.GetDb(ctx) db, err := hctx.OpenLocalSqliteDb()
if err != nil {
return err
}
db.Exec("DELETE FROM history_entries") db.Exec("DELETE FROM history_entries")
// Bootstrap from remote date // Bootstrap from remote date
@ -403,21 +406,20 @@ func ImportHistory(ctx *context.Context) (int, error) {
return 0, err return 0, err
} }
for _, cmd := range historyEntries { for _, cmd := range historyEntries {
startTime := time.Now() entry := data.HistoryEntry{
endTime := time.Now()
err = ReliableDbCreate(db, data.HistoryEntry{
LocalUsername: currentUser.Name, LocalUsername: currentUser.Name,
Hostname: hostname, Hostname: hostname,
Command: cmd, Command: cmd,
CurrentWorkingDirectory: "Unknown", CurrentWorkingDirectory: "Unknown",
HomeDirectory: homedir, HomeDirectory: homedir,
ExitCode: 0, // Unknown, but assumed ExitCode: 0, // Unknown, but assumed
StartTime: startTime, StartTime: time.Now(),
EndTime: endTime, EndTime: time.Now(),
DeviceId: config.DeviceId, DeviceId: config.DeviceId,
}) }
err = ReliableDbCreate(db, entry)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("failed to insert imported history entry: %v", err)
} }
} }
config.HaveCompletedInitialImport = true config.HaveCompletedInitialImport = true
@ -465,15 +467,14 @@ func parseZshHistory(homedir string) ([]string, error) {
return readFileToArray(histfile) return readFileToArray(histfile)
} }
func Install(ctx *context.Context) error { func Install() error {
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
if err != nil { if err != nil {
return fmt.Errorf("failed to get user's home directory: %v", err) return fmt.Errorf("failed to get user's home directory: %v", err)
} }
clientDir := path.Join(homedir, shared.HISHTORY_PATH) err = hctx.MakeHishtoryDir()
err = os.MkdirAll(clientDir, 0o744)
if err != nil { if err != nil {
return fmt.Errorf("failed to create folder for hishtory binary: %v", err) return err
} }
path, err := installBinary(homedir) path, err := installBinary(homedir)
if err != nil { if err != nil {
@ -490,7 +491,7 @@ func Install(ctx *context.Context) error {
_, err = hctx.GetConfig() _, err = hctx.GetConfig()
if err != nil { if err != nil {
// No config, so set up a new installation // No config, so set up a new installation
return Setup(ctx, os.Args) return Setup(os.Args)
} }
return nil return nil
} }
@ -920,7 +921,7 @@ func ReliableDbCreate(db *gorm.DB, entry interface{}) error {
time.Sleep(time.Duration(i*rand.Intn(100)) * time.Millisecond) time.Sleep(time.Duration(i*rand.Intn(100)) * time.Millisecond)
continue continue
} }
if strings.Contains(errMsg, "constraint failed: UNIQUE constraint failed") { if strings.Contains(errMsg, "UNIQUE constraint failed") {
if i == 0 { if i == 0 {
return err return err
} else { } else {

View File

@ -16,12 +16,13 @@ import (
func TestSetup(t *testing.T) { func TestSetup(t *testing.T) {
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
defer shared.RunTestServer()() defer shared.RunTestServer()()
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
shared.Check(t, err) shared.Check(t, err)
if _, err := os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH)); err == nil { if _, err := os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH)); err == nil {
t.Fatalf("hishtory secret file already exists!") t.Fatalf("hishtory secret file already exists!")
} }
shared.Check(t, Setup(hctx.MakeContext(), []string{})) shared.Check(t, Setup([]string{}))
if _, err := os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH)); err != nil { if _, err := os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH)); err != nil {
t.Fatalf("hishtory secret file does not exist after Setup()!") t.Fatalf("hishtory secret file does not exist after Setup()!")
} }
@ -35,7 +36,7 @@ func TestSetup(t *testing.T) {
func TestBuildHistoryEntry(t *testing.T) { func TestBuildHistoryEntry(t *testing.T) {
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
defer shared.RunTestServer()() defer shared.RunTestServer()()
shared.Check(t, Setup(hctx.MakeContext(), []string{})) shared.Check(t, Setup([]string{}))
// Test building an actual entry for bash // Test building an actual entry for bash
entry, err := BuildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", " 123 ls /foo ", "1641774958"}) entry, err := BuildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", " 123 ls /foo ", "1641774958"})
@ -94,6 +95,7 @@ func TestBuildHistoryEntry(t *testing.T) {
func TestPersist(t *testing.T) { func TestPersist(t *testing.T) {
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
shared.Check(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext()) db := hctx.GetDb(hctx.MakeContext())
entry := data.MakeFakeHistoryEntry("ls ~/") entry := data.MakeFakeHistoryEntry("ls ~/")
@ -112,6 +114,7 @@ func TestPersist(t *testing.T) {
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
shared.Check(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext()) db := hctx.GetDb(hctx.MakeContext())
// Insert data // Insert data
@ -137,6 +140,7 @@ func TestSearch(t *testing.T) {
func TestAddToDbIfNew(t *testing.T) { func TestAddToDbIfNew(t *testing.T) {
// Set up // Set up
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
shared.Check(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext()) db := hctx.GetDb(hctx.MakeContext())
// Add duplicate entries // Add duplicate entries

View File

@ -21,21 +21,24 @@ func main() {
fmt.Println("Must specify a command! Do you mean `hishtory query`?") fmt.Println("Must specify a command! Do you mean `hishtory query`?")
return return
} }
ctx := hctx.MakeContext()
switch os.Args[1] { switch os.Args[1] {
case "saveHistoryEntry": case "saveHistoryEntry":
ctx := hctx.MakeContext()
lib.CheckFatalError(maybeUploadSkippedHistoryEntries(ctx)) lib.CheckFatalError(maybeUploadSkippedHistoryEntries(ctx))
saveHistoryEntry(ctx) saveHistoryEntry(ctx)
lib.CheckFatalError(processDeletionRequests(ctx)) lib.CheckFatalError(processDeletionRequests(ctx))
case "query": case "query":
ctx := hctx.MakeContext()
lib.CheckFatalError(processDeletionRequests(ctx)) lib.CheckFatalError(processDeletionRequests(ctx))
query(ctx, strings.Join(os.Args[2:], " ")) query(ctx, strings.Join(os.Args[2:], " "))
case "export": case "export":
ctx := hctx.MakeContext()
lib.CheckFatalError(processDeletionRequests(ctx)) lib.CheckFatalError(processDeletionRequests(ctx))
export(ctx, strings.Join(os.Args[2:], " ")) export(ctx, strings.Join(os.Args[2:], " "))
case "redact": case "redact":
fallthrough fallthrough
case "delete": case "delete":
ctx := hctx.MakeContext()
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(ctx)) lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(ctx))
lib.CheckFatalError(processDeletionRequests(ctx)) lib.CheckFatalError(processDeletionRequests(ctx))
query := strings.Join(os.Args[2:], " ") query := strings.Join(os.Args[2:], " ")
@ -46,10 +49,11 @@ func main() {
} }
lib.CheckFatalError(lib.Redact(ctx, query, force)) lib.CheckFatalError(lib.Redact(ctx, query, force))
case "init": case "init":
lib.CheckFatalError(lib.Setup(ctx, os.Args)) lib.CheckFatalError(lib.Setup(os.Args))
case "install": case "install":
lib.CheckFatalError(lib.Install(ctx)) lib.CheckFatalError(lib.Install())
if os.Getenv("HISHTORY_TEST") == "" { if os.Getenv("HISHTORY_TEST") == "" {
ctx := hctx.MakeContext()
numImported, err := lib.ImportHistory(ctx) numImported, err := lib.ImportHistory(ctx)
lib.CheckFatalError(err) lib.CheckFatalError(err)
if numImported > 0 { if numImported > 0 {
@ -57,6 +61,7 @@ func main() {
} }
} }
case "import": case "import":
ctx := hctx.MakeContext()
if os.Getenv("HISHTORY_TEST") == "" { if os.Getenv("HISHTORY_TEST") == "" {
lib.CheckFatalError(fmt.Errorf("the hishtory import command is only meant to be for testing purposes")) lib.CheckFatalError(fmt.Errorf("the hishtory import command is only meant to be for testing purposes"))
} }
@ -66,12 +71,15 @@ func main() {
fmt.Printf("Imported %v history entries from your existing shell history", numImported) fmt.Printf("Imported %v history entries from your existing shell history", numImported)
} }
case "enable": case "enable":
ctx := hctx.MakeContext()
lib.CheckFatalError(lib.Enable(ctx)) lib.CheckFatalError(lib.Enable(ctx))
case "disable": case "disable":
ctx := hctx.MakeContext()
lib.CheckFatalError(lib.Disable(ctx)) lib.CheckFatalError(lib.Disable(ctx))
case "version": case "version":
fallthrough fallthrough
case "status": case "status":
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx) config := hctx.GetConf(ctx)
fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", lib.Version, config.IsEnabled) fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", lib.Version, config.IsEnabled)
fmt.Printf("Secret Key: %s\n", config.UserSecret) fmt.Printf("Secret Key: %s\n", config.UserSecret)
@ -339,5 +347,3 @@ func export(ctx *context.Context, query string) {
fmt.Println(data[i].Command) fmt.Println(data[i].Command)
} }
} }
// TODO: Can we have a global db and config rather than this nonsense?