Fix ctx wiring so installs work properly

This commit is contained in:
David Dworken 2022-09-20 23:30:57 -07:00
parent 414d8926f6
commit ceb1becfa6
5 changed files with 78 additions and 39 deletions

View File

@ -250,7 +250,7 @@ func installHishtory(t *testing.T, tester shellTester, userSecret string) string
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
matches := r.FindStringSubmatch(out)
if len(matches) != 2 {
t.Fatalf("Failed to extract userSecret from output: matches=%#v", matches)
t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
}
return matches[1]
}
@ -321,9 +321,9 @@ echo thisisrecorded`)
runtimeMatcher := `[0-9.ms]+`
exitCodeMatcher := `0`
pipefailMatcher := `set -em?o pipefail`
line1Matcher := `Hostname` + tableDividerMatcher + `CWD` + tableDividerMatcher + `Timestamp` + tableDividerMatcher + `Runtime` + tableDividerMatcher + `Exit Code` + tableDividerMatcher + `Command` + tableDividerMatcher + `\n`
line2Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + pipefailMatcher + tableDividerMatcher + `\n`
line3Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + `echo thisisrecorded` + tableDividerMatcher + `\n`
line1Matcher := tableDividerMatcher + `Hostname` + tableDividerMatcher + `CWD` + tableDividerMatcher + `Timestamp` + tableDividerMatcher + `Runtime` + tableDividerMatcher + `Exit Code` + tableDividerMatcher + `Command` + tableDividerMatcher + `\n`
line2Matcher := tableDividerMatcher + hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + pipefailMatcher + tableDividerMatcher + `\n`
line3Matcher := tableDividerMatcher + hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + `echo thisisrecorded` + tableDividerMatcher + `\n`
match, err := regexp.MatchString(line3Matcher, out)
shared.Check(t, err)
if !match {
@ -868,7 +868,7 @@ go build -o /tmp/client
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
matches := r.FindStringSubmatch(out)
if len(matches) != 2 {
t.Fatalf("Failed to extract userSecret from output: matches=%#v", matches)
t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
}
userSecret := matches[1]
@ -1078,7 +1078,7 @@ func testInstallViaPythonScript(t *testing.T, tester shellTester) {
r := regexp.MustCompile(`Setting secret hishtory key to (.*)`)
matches := r.FindStringSubmatch(out)
if len(matches) != 2 {
t.Fatalf("Failed to extract userSecret from output: matches=%#v", matches)
t.Fatalf("Failed to extract userSecret from output=%#v: matches=%#v", out, matches)
}
userSecret := matches[1]

View File

@ -3,6 +3,7 @@ package hctx
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
@ -29,6 +30,10 @@ func GetLogger() *log.Logger {
if err != nil {
panic(fmt.Errorf("failed to get user's home directory: %v", err))
}
err = MakeHishtoryDir()
if err != nil {
panic(err)
}
f, err := os.OpenFile(path.Join(homedir, shared.HISHTORY_PATH, "hishtory.log"), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o660)
if err != nil {
panic(fmt.Errorf("failed to open hishtory.log: %v", err))
@ -39,14 +44,26 @@ func GetLogger() *log.Logger {
return hishtoryLogger
}
func MakeHishtoryDir() error {
homedir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user's home directory: %v", err)
}
err = os.MkdirAll(path.Join(homedir, shared.HISHTORY_PATH), 0o744)
if err != nil {
return fmt.Errorf("failed to create ~/.hishtory dir: %v", err)
}
return nil
}
func OpenLocalSqliteDb() (*gorm.DB, error) {
homedir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get user's home directory: %v", err)
}
err = os.MkdirAll(path.Join(homedir, shared.HISHTORY_PATH), 0o744)
err = MakeHishtoryDir()
if err != nil {
return nil, fmt.Errorf("failed to create ~/.hishtory dir: %v", err)
return nil, err
}
hishtoryLogger := GetLogger()
newLogger := logger.New(
@ -81,12 +98,12 @@ func MakeContext() *context.Context {
ctx := context.Background()
config, err := GetConfig()
if err != nil {
GetLogger().Fatalf("failed to retrieve config: %v", err)
panic(fmt.Errorf("failed to retrieve config: %v", err))
}
ctx = context.WithValue(ctx, hishtoryContextKey("config"), config)
db, err := OpenLocalSqliteDb()
if err != nil {
GetLogger().Fatalf("failed to open local DB: %v", err)
panic(fmt.Errorf("failed to open local DB: %v", err))
}
ctx = context.WithValue(ctx, hishtoryContextKey("db"), db)
return &ctx
@ -97,8 +114,7 @@ func GetConf(ctx *context.Context) ClientConfig {
if v != nil {
return v.(ClientConfig)
}
GetLogger().Fatalf("failed to find config in ctx")
return ClientConfig{}
panic(fmt.Errorf("failed to find config in ctx"))
}
func GetDb(ctx *context.Context) *gorm.DB {
@ -106,8 +122,7 @@ func GetDb(ctx *context.Context) *gorm.DB {
if v != nil {
return v.(*gorm.DB)
}
GetLogger().Fatalf("failed to find db in ctx")
return nil
panic(fmt.Errorf("failed to find db in ctx"))
}
type ClientConfig struct {
@ -161,10 +176,9 @@ func SetConfig(config ClientConfig) error {
if err != nil {
return fmt.Errorf("failed to retrieve homedir: %v", err)
}
clientDir := path.Join(homedir, shared.HISHTORY_PATH)
err = os.MkdirAll(clientDir, 0o744)
err = MakeHishtoryDir()
if err != nil {
return fmt.Errorf("failed to create ~/.hishtory/ folder: %v", err)
return err
}
err = os.WriteFile(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH), serializedConfig, 0o600)
if err != nil {
@ -172,3 +186,17 @@ func SetConfig(config ClientConfig) error {
}
return nil
}
func InitConfig() error {
homedir, err := os.UserHomeDir()
if err != nil {
return err
}
_, err = os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH))
if errors.Is(err, os.ErrNotExist) {
return SetConfig(ClientConfig{})
}
return err
}
// TODO: make homedir part of the context

View File

@ -274,7 +274,7 @@ func shouldSkipHiddenCommand(ctx *context.Context, historyLine string) (bool, er
return false, nil
}
func Setup(ctx *context.Context, args []string) error {
func Setup(args []string) error {
userSecret := uuid.Must(uuid.NewRandom()).String()
if len(args) > 2 && args[2] != "" {
userSecret = args[2]
@ -292,7 +292,10 @@ func Setup(ctx *context.Context, args []string) error {
}
// Drop all existing data
db := hctx.GetDb(ctx)
db, err := hctx.OpenLocalSqliteDb()
if err != nil {
return err
}
db.Exec("DELETE FROM history_entries")
// Bootstrap from remote date
@ -403,21 +406,20 @@ func ImportHistory(ctx *context.Context) (int, error) {
return 0, err
}
for _, cmd := range historyEntries {
startTime := time.Now()
endTime := time.Now()
err = ReliableDbCreate(db, data.HistoryEntry{
entry := data.HistoryEntry{
LocalUsername: currentUser.Name,
Hostname: hostname,
Command: cmd,
CurrentWorkingDirectory: "Unknown",
HomeDirectory: homedir,
ExitCode: 0, // Unknown, but assumed
StartTime: startTime,
EndTime: endTime,
StartTime: time.Now(),
EndTime: time.Now(),
DeviceId: config.DeviceId,
})
}
err = ReliableDbCreate(db, entry)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to insert imported history entry: %v", err)
}
}
config.HaveCompletedInitialImport = true
@ -465,15 +467,14 @@ func parseZshHistory(homedir string) ([]string, error) {
return readFileToArray(histfile)
}
func Install(ctx *context.Context) error {
func Install() error {
homedir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user's home directory: %v", err)
}
clientDir := path.Join(homedir, shared.HISHTORY_PATH)
err = os.MkdirAll(clientDir, 0o744)
err = hctx.MakeHishtoryDir()
if err != nil {
return fmt.Errorf("failed to create folder for hishtory binary: %v", err)
return err
}
path, err := installBinary(homedir)
if err != nil {
@ -490,7 +491,7 @@ func Install(ctx *context.Context) error {
_, err = hctx.GetConfig()
if err != nil {
// No config, so set up a new installation
return Setup(ctx, os.Args)
return Setup(os.Args)
}
return nil
}
@ -920,7 +921,7 @@ func ReliableDbCreate(db *gorm.DB, entry interface{}) error {
time.Sleep(time.Duration(i*rand.Intn(100)) * time.Millisecond)
continue
}
if strings.Contains(errMsg, "constraint failed: UNIQUE constraint failed") {
if strings.Contains(errMsg, "UNIQUE constraint failed") {
if i == 0 {
return err
} else {

View File

@ -16,12 +16,13 @@ import (
func TestSetup(t *testing.T) {
defer shared.BackupAndRestore(t)()
defer shared.RunTestServer()()
homedir, err := os.UserHomeDir()
shared.Check(t, err)
if _, err := os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH)); err == nil {
t.Fatalf("hishtory secret file already exists!")
}
shared.Check(t, Setup(hctx.MakeContext(), []string{}))
shared.Check(t, Setup([]string{}))
if _, err := os.Stat(path.Join(homedir, shared.HISHTORY_PATH, shared.CONFIG_PATH)); err != nil {
t.Fatalf("hishtory secret file does not exist after Setup()!")
}
@ -35,7 +36,7 @@ func TestSetup(t *testing.T) {
func TestBuildHistoryEntry(t *testing.T) {
defer shared.BackupAndRestore(t)()
defer shared.RunTestServer()()
shared.Check(t, Setup(hctx.MakeContext(), []string{}))
shared.Check(t, Setup([]string{}))
// Test building an actual entry for bash
entry, err := BuildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", " 123 ls /foo ", "1641774958"})
@ -94,6 +95,7 @@ func TestBuildHistoryEntry(t *testing.T) {
func TestPersist(t *testing.T) {
defer shared.BackupAndRestore(t)()
shared.Check(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext())
entry := data.MakeFakeHistoryEntry("ls ~/")
@ -112,6 +114,7 @@ func TestPersist(t *testing.T) {
func TestSearch(t *testing.T) {
defer shared.BackupAndRestore(t)()
shared.Check(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext())
// Insert data
@ -137,6 +140,7 @@ func TestSearch(t *testing.T) {
func TestAddToDbIfNew(t *testing.T) {
// Set up
defer shared.BackupAndRestore(t)()
shared.Check(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext())
// Add duplicate entries

View File

@ -21,21 +21,24 @@ func main() {
fmt.Println("Must specify a command! Do you mean `hishtory query`?")
return
}
ctx := hctx.MakeContext()
switch os.Args[1] {
case "saveHistoryEntry":
ctx := hctx.MakeContext()
lib.CheckFatalError(maybeUploadSkippedHistoryEntries(ctx))
saveHistoryEntry(ctx)
lib.CheckFatalError(processDeletionRequests(ctx))
case "query":
ctx := hctx.MakeContext()
lib.CheckFatalError(processDeletionRequests(ctx))
query(ctx, strings.Join(os.Args[2:], " "))
case "export":
ctx := hctx.MakeContext()
lib.CheckFatalError(processDeletionRequests(ctx))
export(ctx, strings.Join(os.Args[2:], " "))
case "redact":
fallthrough
case "delete":
ctx := hctx.MakeContext()
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(ctx))
lib.CheckFatalError(processDeletionRequests(ctx))
query := strings.Join(os.Args[2:], " ")
@ -46,10 +49,11 @@ func main() {
}
lib.CheckFatalError(lib.Redact(ctx, query, force))
case "init":
lib.CheckFatalError(lib.Setup(ctx, os.Args))
lib.CheckFatalError(lib.Setup(os.Args))
case "install":
lib.CheckFatalError(lib.Install(ctx))
lib.CheckFatalError(lib.Install())
if os.Getenv("HISHTORY_TEST") == "" {
ctx := hctx.MakeContext()
numImported, err := lib.ImportHistory(ctx)
lib.CheckFatalError(err)
if numImported > 0 {
@ -57,6 +61,7 @@ func main() {
}
}
case "import":
ctx := hctx.MakeContext()
if os.Getenv("HISHTORY_TEST") == "" {
lib.CheckFatalError(fmt.Errorf("the hishtory import command is only meant to be for testing purposes"))
}
@ -66,12 +71,15 @@ func main() {
fmt.Printf("Imported %v history entries from your existing shell history", numImported)
}
case "enable":
ctx := hctx.MakeContext()
lib.CheckFatalError(lib.Enable(ctx))
case "disable":
ctx := hctx.MakeContext()
lib.CheckFatalError(lib.Disable(ctx))
case "version":
fallthrough
case "status":
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx)
fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", lib.Version, config.IsEnabled)
fmt.Printf("Secret Key: %s\n", config.UserSecret)
@ -339,5 +347,3 @@ func export(ctx *context.Context, query string) {
fmt.Println(data[i].Command)
}
}
// TODO: Can we have a global db and config rather than this nonsense?