diff --git a/client/data/data.go b/client/data/data.go index 28eb194..292a611 100644 --- a/client/data/data.go +++ b/client/data/data.go @@ -11,18 +11,18 @@ import ( "encoding/json" "fmt" "io" - "strings" "time" - "github.com/araddon/dateparse" "github.com/ddworken/hishtory/shared" "github.com/google/uuid" - "gorm.io/gorm" ) const ( KdfUserID = "user_id" KdfEncryptionKey = "encryption_key" + CONFIG_PATH = ".hishtory.config" + HISHTORY_PATH = ".hishtory" + DB_PATH = ".hishtory.db" ) type HistoryEntry struct { @@ -153,110 +153,6 @@ func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (Histo return decryptedEntry, nil } -func parseTimeGenerously(input string) (time.Time, error) { - input = strings.ReplaceAll(input, "_", " ") - return dateparse.ParseLocal(input) -} - -func MakeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) { - tokens, err := tokenize(query) - if err != nil { - return nil, fmt.Errorf("failed to tokenize query: %v", err) - } - tx := db.Model(&HistoryEntry{}).Where("true") - for _, token := range tokens { - if strings.HasPrefix(token, "-") { - if strings.Contains(token, ":") { - query, v1, v2, err := parseAtomizedToken(token[1:]) - if err != nil { - return nil, err - } - tx = tx.Where("NOT "+query, v1, v2) - } else { - query, v1, v2, v3, err := parseNonAtomizedToken(token[1:]) - if err != nil { - return nil, err - } - tx = tx.Where("NOT "+query, v1, v2, v3) - } - } else if strings.Contains(token, ":") { - query, v1, v2, err := parseAtomizedToken(token) - if err != nil { - return nil, err - } - tx = tx.Where(query, v1, v2) - } else { - query, v1, v2, v3, err := parseNonAtomizedToken(token) - if err != nil { - return nil, err - } - tx = tx.Where(query, v1, v2, v3) - } - } - return tx, nil -} - -func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) { - tx, err := MakeWhereQueryFromSearch(db, query) - if err != nil { - return nil, err - } - tx = tx.Order("end_time DESC") - if limit > 0 { - tx = tx.Limit(limit) - } - var historyEntries []*HistoryEntry - result := tx.Find(&historyEntries) - if result.Error != nil { - return nil, fmt.Errorf("DB query error: %v", result.Error) - } - return historyEntries, nil -} - -func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) { - wildcardedToken := "%" + token + "%" - return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil -} - -func parseAtomizedToken(token string) (string, interface{}, interface{}, error) { - splitToken := strings.SplitN(token, ":", 2) - field := splitToken[0] - val := splitToken[1] - switch field { - case "user": - return "(local_username = ?)", val, nil, nil - case "host": - fallthrough - case "hostname": - return "(instr(hostname, ?) > 0)", val, nil, nil - case "cwd": - return "(instr(current_working_directory, ?) > 0 OR instr(REPLACE(current_working_directory, '~/', home_directory), ?) > 0)", strings.TrimSuffix(val, "/"), strings.TrimSuffix(val, "/"), nil - case "exit_code": - return "(exit_code = ?)", val, nil, nil - case "before": - t, err := parseTimeGenerously(val) - if err != nil { - return "", nil, nil, fmt.Errorf("failed to parse before:%s as a timestamp: %v", val, err) - } - return "(CAST(strftime(\"%s\",start_time) AS INTEGER) < ?)", t.Unix(), nil, nil - case "after": - t, err := parseTimeGenerously(val) - if err != nil { - return "", nil, nil, fmt.Errorf("failed to parse after:%s as a timestamp: %v", val, err) - } - return "(CAST(strftime(\"%s\",start_time) AS INTEGER) > ?)", t.Unix(), nil, nil - default: - return "", nil, nil, fmt.Errorf("search query contains unknown search atom %s", field) - } -} - -func tokenize(query string) ([]string, error) { - if query == "" { - return []string{}, nil - } - return strings.Split(query, " "), nil -} - func EntryEquals(entry1, entry2 HistoryEntry) bool { return entry1.LocalUsername == entry2.LocalUsername && entry1.Hostname == entry2.Hostname && @@ -267,9 +163,3 @@ func EntryEquals(entry1, entry2 HistoryEntry) bool { entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) && entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339) } - -const ( - CONFIG_PATH = ".hishtory.config" - HISHTORY_PATH = ".hishtory" - DB_PATH = ".hishtory.db" -) diff --git a/client/lib/lib.go b/client/lib/lib.go index 4676b5f..6a17954 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -28,6 +28,7 @@ import ( "gorm.io/gorm" + "github.com/araddon/dateparse" "github.com/fatih/color" "github.com/google/uuid" "github.com/rodaine/table" @@ -1221,7 +1222,7 @@ func EncryptAndMarshal(config hctx.ClientConfig, entries []*data.HistoryEntry) ( } func Redact(ctx *context.Context, query string, force bool) error { - tx, err := data.MakeWhereQueryFromSearch(hctx.GetDb(ctx), query) + tx, err := MakeWhereQueryFromSearch(ctx, hctx.GetDb(ctx), query) if err != nil { return err } @@ -1244,7 +1245,7 @@ func Redact(ctx *context.Context, query string, force bool) error { return nil } } - tx, err = data.MakeWhereQueryFromSearch(hctx.GetDb(ctx), query) + tx, err = MakeWhereQueryFromSearch(ctx, hctx.GetDb(ctx), query) if err != nil { return err } @@ -1285,7 +1286,7 @@ func deleteOnRemoteInstances(ctx *context.Context, historyEntries []*data.Histor func Reupload(ctx *context.Context) error { config := hctx.GetConf(ctx) - entries, err := data.Search(hctx.GetDb(ctx), "", 0) + entries, err := Search(ctx, hctx.GetDb(ctx), "", 0) if err != nil { return fmt.Errorf("failed to reupload due to failed search: %v", err) } @@ -1394,3 +1395,158 @@ func tweakConfigForTests(configContents string) (string, error) { } return ret, nil } + +func parseTimeGenerously(input string) (time.Time, error) { + input = strings.ReplaceAll(input, "_", " ") + return dateparse.ParseLocal(input) +} + +func MakeWhereQueryFromSearch(ctx *context.Context, db *gorm.DB, query string) (*gorm.DB, error) { + tokens, err := tokenize(query) + if err != nil { + return nil, fmt.Errorf("failed to tokenize query: %v", err) + } + tx := db.Model(&data.HistoryEntry{}).Where("true") + for _, token := range tokens { + if strings.HasPrefix(token, "-") { + if strings.Contains(token, ":") { + query, v1, v2, err := parseAtomizedToken(ctx, token[1:]) + if err != nil { + return nil, err + } + tx = tx.Where("NOT "+query, v1, v2) + } else { + query, v1, v2, v3, err := parseNonAtomizedToken(token[1:]) + if err != nil { + return nil, err + } + tx = tx.Where("NOT "+query, v1, v2, v3) + } + } else if strings.Contains(token, ":") { + query, v1, v2, err := parseAtomizedToken(ctx, token) + if err != nil { + return nil, err + } + tx = tx.Where(query, v1, v2) + } else { + query, v1, v2, v3, err := parseNonAtomizedToken(token) + if err != nil { + return nil, err + } + tx = tx.Where(query, v1, v2, v3) + } + } + return tx, nil +} + +func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data.HistoryEntry, error) { + if ctx == nil && query != "" { + return nil, fmt.Errorf("lib.Search called with a nil context and a non-empty query (this should never happen)") + } + + tx, err := MakeWhereQueryFromSearch(ctx, db, query) + if err != nil { + return nil, err + } + tx = tx.Order("end_time DESC") + if limit > 0 { + tx = tx.Limit(limit) + } + var historyEntries []*data.HistoryEntry + result := tx.Find(&historyEntries) + if result.Error != nil { + return nil, fmt.Errorf("DB query error: %v", result.Error) + } + return historyEntries, nil +} + +func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) { + wildcardedToken := "%" + token + "%" + return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil +} + +func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) { + splitToken := strings.SplitN(token, ":", 2) + field := splitToken[0] + val := splitToken[1] + switch field { + case "user": + return "(local_username = ?)", val, nil, nil + case "host": + fallthrough + case "hostname": + return "(instr(hostname, ?) > 0)", val, nil, nil + case "cwd": + return "(instr(current_working_directory, ?) > 0 OR instr(REPLACE(current_working_directory, '~/', home_directory), ?) > 0)", strings.TrimSuffix(val, "/"), strings.TrimSuffix(val, "/"), nil + case "exit_code": + return "(exit_code = ?)", val, nil, nil + case "before": + t, err := parseTimeGenerously(val) + if err != nil { + return "", nil, nil, fmt.Errorf("failed to parse before:%s as a timestamp: %v", val, err) + } + return "(CAST(strftime(\"%s\",start_time) AS INTEGER) < ?)", t.Unix(), nil, nil + case "after": + t, err := parseTimeGenerously(val) + if err != nil { + return "", nil, nil, fmt.Errorf("failed to parse after:%s as a timestamp: %v", val, err) + } + return "(CAST(strftime(\"%s\",start_time) AS INTEGER) > ?)", t.Unix(), nil, nil + default: + knownCustomColumns := make([]string, 0) + // Get custom columns that are defined on this machine + conf := hctx.GetConf(ctx) + for _, c := range conf.CustomColumns { + knownCustomColumns = append(knownCustomColumns, c.ColumnName) + } + // Also get all ones that are in the DB + names, err := getAllCustomColumnNames(ctx) + if err != nil { + return "", nil, nil, fmt.Errorf("failed to get custom column names from the DB: %v", err) + } + knownCustomColumns = append(knownCustomColumns, names...) + // Check if the atom is for a custom column that exists and if it isn't, return an error + isCustomColumn := false + for _, ccName := range knownCustomColumns { + if ccName == field { + isCustomColumn = true + } + } + if !isCustomColumn { + return "", nil, nil, fmt.Errorf("search query contains unknown search atom %s", field) + } + // Build the where clause for the custom column + return "EXISTS (SELECT 1 FROM json_each(custom_columns) WHERE json_extract(value, '$.name') = ? and instr(json_extract(value, '$.value'), ?) > 0)", field, val, nil + } +} + +func getAllCustomColumnNames(ctx *context.Context) ([]string, error) { + db := hctx.GetDb(ctx) + query := ` + SELECT DISTINCT json_extract(value, '$.name') as cc_name + FROM history_entries + JOIN json_each(custom_columns) + WHERE value IS NOT NULL + LIMIT 10` + rows, err := db.Raw(query).Rows() + if err != nil { + return nil, err + } + ccNames := make([]string, 0) + for rows.Next() { + var ccName string + err = rows.Scan(&ccName) + if err != nil { + return nil, err + } + ccNames = append(ccNames, ccName) + } + return ccNames, nil +} + +func tokenize(query string) ([]string, error) { + if query == "" { + return []string{}, nil + } + return strings.Split(query, " "), nil +} diff --git a/client/lib/lib_test.go b/client/lib/lib_test.go index 9eaa633..77cee2b 100644 --- a/client/lib/lib_test.go +++ b/client/lib/lib_test.go @@ -187,7 +187,7 @@ func TestSearch(t *testing.T) { db.Create(entry2) // Search for data - results, err := data.Search(db, "ls", 5) + results, err := lib.Search(ctx, db, "ls", 5) testutils.Check(t, err) if len(results) != 2 { t.Fatalf("Search() returned %d results, expected 2!", len(results)) diff --git a/client/lib/tui.go b/client/lib/tui.go index 063049e..56219fb 100644 --- a/client/lib/tui.go +++ b/client/lib/tui.go @@ -14,7 +14,6 @@ import ( "github.com/charmbracelet/bubbles/textinput" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/hctx" "github.com/muesli/termenv" "golang.org/x/term" @@ -199,7 +198,7 @@ func (m model) View() string { func getRows(ctx *context.Context, columnNames []string, query string, numEntries int) ([]table.Row, int, error) { db := hctx.GetDb(ctx) - data, err := data.Search(db, query, numEntries) + data, err := Search(ctx, db, query, numEntries) if err != nil { return nil, 0, err } diff --git a/hishtory.go b/hishtory.go index 5da3a02..e795d91 100644 --- a/hishtory.go +++ b/hishtory.go @@ -57,7 +57,7 @@ func main() { case "init": db, err := hctx.OpenLocalSqliteDb() lib.CheckFatalError(err) - data, err := data.Search(db, "", 10) + data, err := lib.Search(nil, db, "", 10) lib.CheckFatalError(err) if len(data) > 0 { fmt.Printf("Your current hishtory profile has saved history entries, are you sure you want to run `init` and reset? [y/N]") @@ -75,7 +75,7 @@ func main() { if os.Getenv("HISHTORY_TEST") == "" { db, err := hctx.OpenLocalSqliteDb() lib.CheckFatalError(err) - data, err := data.Search(db, "", 10) + data, err := lib.Search(nil, db, "", 10) lib.CheckFatalError(err) if len(data) < 10 { fmt.Println("Importing existing shell history...") @@ -299,7 +299,7 @@ func query(ctx *context.Context, query string) { } } lib.CheckFatalError(displayBannerIfSet(ctx)) - data, err := data.Search(db, query, 25) + data, err := lib.Search(ctx, db, query, 25) lib.CheckFatalError(err) lib.CheckFatalError(lib.DisplayResults(ctx, data)) } @@ -327,7 +327,7 @@ func maybeUploadSkippedHistoryEntries(ctx *context.Context) error { // Upload the missing entries db := hctx.GetDb(ctx) query := fmt.Sprintf("after:%s", time.Unix(config.MissedUploadTimestamp, 0).Format("2006-01-02")) - entries, err := data.Search(db, query, 0) + entries, err := lib.Search(ctx, db, query, 0) if err != nil { return fmt.Errorf("failed to retrieve history entries that haven't been uploaded yet: %v", err) } @@ -400,7 +400,7 @@ func saveHistoryEntry(ctx *context.Context) { } if len(dumpRequests) > 0 { lib.CheckFatalError(lib.RetrieveAdditionalEntriesFromRemote(ctx)) - entries, err := data.Search(db, "", 0) + entries, err := lib.Search(ctx, db, "", 0) lib.CheckFatalError(err) var encEntries []*shared.EncHistoryEntry for _, entry := range entries { @@ -427,7 +427,7 @@ func export(ctx *context.Context, query string) { lib.CheckFatalError(err) } } - data, err := data.Search(db, query, 0) + data, err := lib.Search(ctx, db, query, 0) lib.CheckFatalError(err) for i := len(data) - 1; i >= 0; i-- { fmt.Println(data[i].Command)