Add basic untested ability to do searches using atoms containing custom columns

This commit is contained in:
David Dworken 2022-11-01 10:23:35 -07:00
parent c70134a6fb
commit 8a3969cfc4
No known key found for this signature in database
5 changed files with 170 additions and 125 deletions

View File

@ -11,18 +11,18 @@ import (
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/araddon/dateparse"
"github.com/ddworken/hishtory/shared"
"github.com/google/uuid"
"gorm.io/gorm"
)
const (
KdfUserID = "user_id"
KdfEncryptionKey = "encryption_key"
CONFIG_PATH = ".hishtory.config"
HISHTORY_PATH = ".hishtory"
DB_PATH = ".hishtory.db"
)
type HistoryEntry struct {
@ -153,110 +153,6 @@ func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (Histo
return decryptedEntry, nil
}
func parseTimeGenerously(input string) (time.Time, error) {
input = strings.ReplaceAll(input, "_", " ")
return dateparse.ParseLocal(input)
}
func MakeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Model(&HistoryEntry{}).Where("true")
for _, token := range tokens {
if strings.HasPrefix(token, "-") {
if strings.Contains(token, ":") {
query, v1, v2, err := parseAtomizedToken(token[1:])
if err != nil {
return nil, err
}
tx = tx.Where("NOT "+query, v1, v2)
} else {
query, v1, v2, v3, err := parseNonAtomizedToken(token[1:])
if err != nil {
return nil, err
}
tx = tx.Where("NOT "+query, v1, v2, v3)
}
} else if strings.Contains(token, ":") {
query, v1, v2, err := parseAtomizedToken(token)
if err != nil {
return nil, err
}
tx = tx.Where(query, v1, v2)
} else {
query, v1, v2, v3, err := parseNonAtomizedToken(token)
if err != nil {
return nil, err
}
tx = tx.Where(query, v1, v2, v3)
}
}
return tx, nil
}
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
tx, err := MakeWhereQueryFromSearch(db, query)
if err != nil {
return nil, err
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
wildcardedToken := "%" + token + "%"
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
}
func parseAtomizedToken(token string) (string, interface{}, interface{}, error) {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
switch field {
case "user":
return "(local_username = ?)", val, nil, nil
case "host":
fallthrough
case "hostname":
return "(instr(hostname, ?) > 0)", val, nil, nil
case "cwd":
return "(instr(current_working_directory, ?) > 0 OR instr(REPLACE(current_working_directory, '~/', home_directory), ?) > 0)", strings.TrimSuffix(val, "/"), strings.TrimSuffix(val, "/"), nil
case "exit_code":
return "(exit_code = ?)", val, nil, nil
case "before":
t, err := parseTimeGenerously(val)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to parse before:%s as a timestamp: %v", val, err)
}
return "(CAST(strftime(\"%s\",start_time) AS INTEGER) < ?)", t.Unix(), nil, nil
case "after":
t, err := parseTimeGenerously(val)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to parse after:%s as a timestamp: %v", val, err)
}
return "(CAST(strftime(\"%s\",start_time) AS INTEGER) > ?)", t.Unix(), nil, nil
default:
return "", nil, nil, fmt.Errorf("search query contains unknown search atom %s", field)
}
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}
func EntryEquals(entry1, entry2 HistoryEntry) bool {
return entry1.LocalUsername == entry2.LocalUsername &&
entry1.Hostname == entry2.Hostname &&
@ -267,9 +163,3 @@ func EntryEquals(entry1, entry2 HistoryEntry) bool {
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
}
const (
CONFIG_PATH = ".hishtory.config"
HISHTORY_PATH = ".hishtory"
DB_PATH = ".hishtory.db"
)

View File

@ -28,6 +28,7 @@ import (
"gorm.io/gorm"
"github.com/araddon/dateparse"
"github.com/fatih/color"
"github.com/google/uuid"
"github.com/rodaine/table"
@ -1221,7 +1222,7 @@ func EncryptAndMarshal(config hctx.ClientConfig, entries []*data.HistoryEntry) (
}
func Redact(ctx *context.Context, query string, force bool) error {
tx, err := data.MakeWhereQueryFromSearch(hctx.GetDb(ctx), query)
tx, err := MakeWhereQueryFromSearch(ctx, hctx.GetDb(ctx), query)
if err != nil {
return err
}
@ -1244,7 +1245,7 @@ func Redact(ctx *context.Context, query string, force bool) error {
return nil
}
}
tx, err = data.MakeWhereQueryFromSearch(hctx.GetDb(ctx), query)
tx, err = MakeWhereQueryFromSearch(ctx, hctx.GetDb(ctx), query)
if err != nil {
return err
}
@ -1285,7 +1286,7 @@ func deleteOnRemoteInstances(ctx *context.Context, historyEntries []*data.Histor
func Reupload(ctx *context.Context) error {
config := hctx.GetConf(ctx)
entries, err := data.Search(hctx.GetDb(ctx), "", 0)
entries, err := Search(ctx, hctx.GetDb(ctx), "", 0)
if err != nil {
return fmt.Errorf("failed to reupload due to failed search: %v", err)
}
@ -1394,3 +1395,158 @@ func tweakConfigForTests(configContents string) (string, error) {
}
return ret, nil
}
func parseTimeGenerously(input string) (time.Time, error) {
input = strings.ReplaceAll(input, "_", " ")
return dateparse.ParseLocal(input)
}
func MakeWhereQueryFromSearch(ctx *context.Context, db *gorm.DB, query string) (*gorm.DB, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Model(&data.HistoryEntry{}).Where("true")
for _, token := range tokens {
if strings.HasPrefix(token, "-") {
if strings.Contains(token, ":") {
query, v1, v2, err := parseAtomizedToken(ctx, token[1:])
if err != nil {
return nil, err
}
tx = tx.Where("NOT "+query, v1, v2)
} else {
query, v1, v2, v3, err := parseNonAtomizedToken(token[1:])
if err != nil {
return nil, err
}
tx = tx.Where("NOT "+query, v1, v2, v3)
}
} else if strings.Contains(token, ":") {
query, v1, v2, err := parseAtomizedToken(ctx, token)
if err != nil {
return nil, err
}
tx = tx.Where(query, v1, v2)
} else {
query, v1, v2, v3, err := parseNonAtomizedToken(token)
if err != nil {
return nil, err
}
tx = tx.Where(query, v1, v2, v3)
}
}
return tx, nil
}
func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data.HistoryEntry, error) {
if ctx == nil && query != "" {
return nil, fmt.Errorf("lib.Search called with a nil context and a non-empty query (this should never happen)")
}
tx, err := MakeWhereQueryFromSearch(ctx, db, query)
if err != nil {
return nil, err
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*data.HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
wildcardedToken := "%" + token + "%"
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
}
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
switch field {
case "user":
return "(local_username = ?)", val, nil, nil
case "host":
fallthrough
case "hostname":
return "(instr(hostname, ?) > 0)", val, nil, nil
case "cwd":
return "(instr(current_working_directory, ?) > 0 OR instr(REPLACE(current_working_directory, '~/', home_directory), ?) > 0)", strings.TrimSuffix(val, "/"), strings.TrimSuffix(val, "/"), nil
case "exit_code":
return "(exit_code = ?)", val, nil, nil
case "before":
t, err := parseTimeGenerously(val)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to parse before:%s as a timestamp: %v", val, err)
}
return "(CAST(strftime(\"%s\",start_time) AS INTEGER) < ?)", t.Unix(), nil, nil
case "after":
t, err := parseTimeGenerously(val)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to parse after:%s as a timestamp: %v", val, err)
}
return "(CAST(strftime(\"%s\",start_time) AS INTEGER) > ?)", t.Unix(), nil, nil
default:
knownCustomColumns := make([]string, 0)
// Get custom columns that are defined on this machine
conf := hctx.GetConf(ctx)
for _, c := range conf.CustomColumns {
knownCustomColumns = append(knownCustomColumns, c.ColumnName)
}
// Also get all ones that are in the DB
names, err := getAllCustomColumnNames(ctx)
if err != nil {
return "", nil, nil, fmt.Errorf("failed to get custom column names from the DB: %v", err)
}
knownCustomColumns = append(knownCustomColumns, names...)
// Check if the atom is for a custom column that exists and if it isn't, return an error
isCustomColumn := false
for _, ccName := range knownCustomColumns {
if ccName == field {
isCustomColumn = true
}
}
if !isCustomColumn {
return "", nil, nil, fmt.Errorf("search query contains unknown search atom %s", field)
}
// Build the where clause for the custom column
return "EXISTS (SELECT 1 FROM json_each(custom_columns) WHERE json_extract(value, '$.name') = ? and instr(json_extract(value, '$.value'), ?) > 0)", field, val, nil
}
}
func getAllCustomColumnNames(ctx *context.Context) ([]string, error) {
db := hctx.GetDb(ctx)
query := `
SELECT DISTINCT json_extract(value, '$.name') as cc_name
FROM history_entries
JOIN json_each(custom_columns)
WHERE value IS NOT NULL
LIMIT 10`
rows, err := db.Raw(query).Rows()
if err != nil {
return nil, err
}
ccNames := make([]string, 0)
for rows.Next() {
var ccName string
err = rows.Scan(&ccName)
if err != nil {
return nil, err
}
ccNames = append(ccNames, ccName)
}
return ccNames, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}

View File

@ -187,7 +187,7 @@ func TestSearch(t *testing.T) {
db.Create(entry2)
// Search for data
results, err := data.Search(db, "ls", 5)
results, err := lib.Search(ctx, db, "ls", 5)
testutils.Check(t, err)
if len(results) != 2 {
t.Fatalf("Search() returned %d results, expected 2!", len(results))

View File

@ -14,7 +14,6 @@ import (
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/hctx"
"github.com/muesli/termenv"
"golang.org/x/term"
@ -199,7 +198,7 @@ func (m model) View() string {
func getRows(ctx *context.Context, columnNames []string, query string, numEntries int) ([]table.Row, int, error) {
db := hctx.GetDb(ctx)
data, err := data.Search(db, query, numEntries)
data, err := Search(ctx, db, query, numEntries)
if err != nil {
return nil, 0, err
}

View File

@ -57,7 +57,7 @@ func main() {
case "init":
db, err := hctx.OpenLocalSqliteDb()
lib.CheckFatalError(err)
data, err := data.Search(db, "", 10)
data, err := lib.Search(nil, db, "", 10)
lib.CheckFatalError(err)
if len(data) > 0 {
fmt.Printf("Your current hishtory profile has saved history entries, are you sure you want to run `init` and reset? [y/N]")
@ -75,7 +75,7 @@ func main() {
if os.Getenv("HISHTORY_TEST") == "" {
db, err := hctx.OpenLocalSqliteDb()
lib.CheckFatalError(err)
data, err := data.Search(db, "", 10)
data, err := lib.Search(nil, db, "", 10)
lib.CheckFatalError(err)
if len(data) < 10 {
fmt.Println("Importing existing shell history...")
@ -299,7 +299,7 @@ func query(ctx *context.Context, query string) {
}
}
lib.CheckFatalError(displayBannerIfSet(ctx))
data, err := data.Search(db, query, 25)
data, err := lib.Search(ctx, db, query, 25)
lib.CheckFatalError(err)
lib.CheckFatalError(lib.DisplayResults(ctx, data))
}
@ -327,7 +327,7 @@ func maybeUploadSkippedHistoryEntries(ctx *context.Context) error {
// Upload the missing entries
db := hctx.GetDb(ctx)
query := fmt.Sprintf("after:%s", time.Unix(config.MissedUploadTimestamp, 0).Format("2006-01-02"))
entries, err := data.Search(db, query, 0)
entries, err := lib.Search(ctx, db, query, 0)
if err != nil {
return fmt.Errorf("failed to retrieve history entries that haven't been uploaded yet: %v", err)
}
@ -400,7 +400,7 @@ func saveHistoryEntry(ctx *context.Context) {
}
if len(dumpRequests) > 0 {
lib.CheckFatalError(lib.RetrieveAdditionalEntriesFromRemote(ctx))
entries, err := data.Search(db, "", 0)
entries, err := lib.Search(ctx, db, "", 0)
lib.CheckFatalError(err)
var encEntries []*shared.EncHistoryEntry
for _, entry := range entries {
@ -427,7 +427,7 @@ func export(ctx *context.Context, query string) {
lib.CheckFatalError(err)
}
}
data, err := data.Search(db, query, 0)
data, err := lib.Search(ctx, db, query, 0)
lib.CheckFatalError(err)
for i := len(data) - 1; i >= 0; i-- {
fmt.Println(data[i].Command)