mirror of
https://github.com/ddworken/hishtory.git
synced 2025-08-19 11:20:18 +02:00
Merge pull request #73 from ddworken/escape
Add support for escaping search queries
This commit is contained in:
@@ -659,6 +659,21 @@ hishtory disable`)
|
|||||||
if diff := cmp.Diff(expectedOutput, out); diff != "" {
|
if diff := cmp.Diff(expectedOutput, out); diff != "" {
|
||||||
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
|
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Search using an escaped dash
|
||||||
|
out = tester.RunInteractiveShell(t, `hishtory export \\-echo`)
|
||||||
|
expectedOutput = "foo -echo\n"
|
||||||
|
if diff := cmp.Diff(expectedOutput, out); diff != "" {
|
||||||
|
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search using a colon that doesn't match a column name
|
||||||
|
manuallySubmitHistoryEntry(t, userSecret, testutils.MakeFakeHistoryEntry("foo:bar"))
|
||||||
|
out = tester.RunInteractiveShell(t, `hishtory export foo\\:bar`)
|
||||||
|
expectedOutput = "foo:bar\n"
|
||||||
|
if diff := cmp.Diff(expectedOutput, out); diff != "" {
|
||||||
|
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testUpdate(t *testing.T, tester shellTester) {
|
func testUpdate(t *testing.T, tester shellTester) {
|
||||||
|
@@ -1098,7 +1098,7 @@ func MakeWhereQueryFromSearch(ctx *context.Context, db *gorm.DB, query string) (
|
|||||||
tx := db.Model(&data.HistoryEntry{}).Where("true")
|
tx := db.Model(&data.HistoryEntry{}).Where("true")
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
if strings.HasPrefix(token, "-") {
|
if strings.HasPrefix(token, "-") {
|
||||||
if strings.Contains(token, ":") {
|
if containsUnescaped(token, ":") {
|
||||||
query, v1, v2, err := parseAtomizedToken(ctx, token[1:])
|
query, v1, v2, err := parseAtomizedToken(ctx, token[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1111,7 +1111,7 @@ func MakeWhereQueryFromSearch(ctx *context.Context, db *gorm.DB, query string) (
|
|||||||
}
|
}
|
||||||
tx = tx.Where("NOT "+query, v1, v2, v3)
|
tx = tx.Where("NOT "+query, v1, v2, v3)
|
||||||
}
|
}
|
||||||
} else if strings.Contains(token, ":") {
|
} else if containsUnescaped(token, ":") {
|
||||||
query, v1, v2, err := parseAtomizedToken(ctx, token)
|
query, v1, v2, err := parseAtomizedToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1150,14 +1150,14 @@ func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
|
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
|
||||||
wildcardedToken := "%" + token + "%"
|
wildcardedToken := "%" + stripBackslash(token) + "%"
|
||||||
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
|
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
|
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
|
||||||
splitToken := strings.SplitN(token, ":", 2)
|
splitToken := splitEscaped(token, ':', 2)
|
||||||
field := splitToken[0]
|
field := stripBackslash(splitToken[0])
|
||||||
val := splitToken[1]
|
val := stripBackslash(splitToken[1])
|
||||||
switch field {
|
switch field {
|
||||||
case "user":
|
case "user":
|
||||||
return "(local_username = ?)", val, nil, nil
|
return "(local_username = ?)", val, nil, nil
|
||||||
@@ -1237,7 +1237,50 @@ func tokenize(query string) ([]string, error) {
|
|||||||
if query == "" {
|
if query == "" {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
}
|
}
|
||||||
return strings.Split(query, " "), nil
|
return splitEscaped(query, ' ', -1), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitEscaped(query string, separator rune, maxSplit int) []string {
|
||||||
|
var token []rune
|
||||||
|
var tokens []string
|
||||||
|
splits := 1
|
||||||
|
runeQuery := []rune(query)
|
||||||
|
for i := 0; i < len(runeQuery); i++ {
|
||||||
|
if (maxSplit < 0 || splits < maxSplit) && runeQuery[i] == separator {
|
||||||
|
tokens = append(tokens, string(token))
|
||||||
|
token = token[:0]
|
||||||
|
splits++
|
||||||
|
} else if runeQuery[i] == '\\' && i+1 < len(runeQuery) {
|
||||||
|
token = append(token, runeQuery[i], runeQuery[i+1])
|
||||||
|
i++
|
||||||
|
} else {
|
||||||
|
token = append(token, runeQuery[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokens = append(tokens, string(token))
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsUnescaped(query string, token string) bool {
|
||||||
|
runeQuery := []rune(query)
|
||||||
|
for i := 0; i < len(runeQuery); i++ {
|
||||||
|
if runeQuery[i] == '\\' && i+1 < len(runeQuery) {
|
||||||
|
i++
|
||||||
|
} else if string(runeQuery[i:i+len(token)]) == token {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripBackslash(query string) string {
|
||||||
|
var newQuery []rune
|
||||||
|
for _, char := range query {
|
||||||
|
if char != '\\' {
|
||||||
|
newQuery = append(newQuery, char)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(newQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDumpRequests(config hctx.ClientConfig) ([]*shared.DumpRequest, error) {
|
func GetDumpRequests(config hctx.ClientConfig) ([]*shared.DumpRequest, error) {
|
||||||
|
@@ -228,6 +228,52 @@ func TestSearch(t *testing.T) {
|
|||||||
if !data.EntryEquals(*results[1], entry1) {
|
if !data.EntryEquals(*results[1], entry1) {
|
||||||
t.Fatalf("Search()[0]=%#v, expected: %#v", results[1], entry1)
|
t.Fatalf("Search()[0]=%#v, expected: %#v", results[1], entry1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Search but exclude bar
|
||||||
|
results, err = Search(ctx, db, "ls -bar", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search but exclude foo
|
||||||
|
results, err = Search(ctx, db, "ls -foo", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search but include / also
|
||||||
|
results, err = Search(ctx, db, "ls /", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search but exclude slash
|
||||||
|
results, err = Search(ctx, db, "ls -/", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 0 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 0, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests for escaping
|
||||||
|
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("ls -baz")).Error)
|
||||||
|
results, err = Search(ctx, db, "ls", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 3 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
|
results, err = Search(ctx, db, "ls -baz", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 2, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
|
results, err = Search(ctx, db, "ls \\-baz", 5)
|
||||||
|
testutils.Check(t, err)
|
||||||
|
if len(results) != 1 {
|
||||||
|
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddToDbIfNew(t *testing.T) {
|
func TestAddToDbIfNew(t *testing.T) {
|
||||||
@@ -433,3 +479,76 @@ func TestParseTimeGenerously(t *testing.T) {
|
|||||||
t.Fatalf("parsed time incorrectly: %d", ts.Unix())
|
t.Fatalf("parsed time incorrectly: %d", ts.Unix())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStripBackslash(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
input string
|
||||||
|
output string
|
||||||
|
}{
|
||||||
|
{"f bar", "f bar"},
|
||||||
|
{"f \\bar", "f bar"},
|
||||||
|
{"f\\:bar", "f:bar"},
|
||||||
|
{"f\\:bar\\", "f:bar"},
|
||||||
|
{"\\f\\:bar\\", "f:bar"},
|
||||||
|
{"", ""},
|
||||||
|
{"\\\\", ""},
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
actual := stripBackslash(tc.input)
|
||||||
|
if !reflect.DeepEqual(actual, tc.output) {
|
||||||
|
t.Fatalf("unescape failure for %#v, actual=%#v", tc.input, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContainsUnescaped(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
input string
|
||||||
|
token string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"f bar", "f", true},
|
||||||
|
{"f bar", "f bar", true},
|
||||||
|
{"f bar", "f r", false},
|
||||||
|
{"f bar", "f ", true},
|
||||||
|
{"foo:bar", ":", true},
|
||||||
|
{"foo:bar", "-", false},
|
||||||
|
{"foo\\:bar", ":", false},
|
||||||
|
{"foo\\-bar", "-", false},
|
||||||
|
{"foo\\-bar", "foo", true},
|
||||||
|
{"foo\\-bar", "bar", true},
|
||||||
|
{"foo\\-bar", "a", true},
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
actual := containsUnescaped(tc.input, tc.token)
|
||||||
|
if !reflect.DeepEqual(actual, tc.expected) {
|
||||||
|
t.Fatalf("containsUnescaped failure for containsUnescaped(%#v, %#v), actual=%#v", tc.input, tc.token, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitEscaped(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
input string
|
||||||
|
char rune
|
||||||
|
limit int
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{"foo bar", ' ', 2, []string{"foo", "bar"}},
|
||||||
|
{"foo bar baz", ' ', 2, []string{"foo", "bar baz"}},
|
||||||
|
{"foo bar baz", ' ', 3, []string{"foo", "bar", "baz"}},
|
||||||
|
{"foo bar baz", ' ', 1, []string{"foo bar baz"}},
|
||||||
|
{"foo bar baz", ' ', -1, []string{"foo", "bar", "baz"}},
|
||||||
|
{"foo\\ bar baz", ' ', -1, []string{"foo\\ bar", "baz"}},
|
||||||
|
{"foo\\bar baz", ' ', -1, []string{"foo\\bar", "baz"}},
|
||||||
|
{"foo\\bar baz foob", ' ', 2, []string{"foo\\bar", "baz foob"}},
|
||||||
|
{"foo\\ bar\\ baz", ' ', -1, []string{"foo\\ bar\\ baz"}},
|
||||||
|
{"foo\\ bar\\ baz", ' ', -1, []string{"foo\\ bar\\ ", "baz"}},
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
actual := splitEscaped(tc.input, tc.char, tc.limit)
|
||||||
|
if !reflect.DeepEqual(actual, tc.expected) {
|
||||||
|
t.Fatalf("containsUnescaped failure for splitEscaped(%#v, %#v, %#v), actual=%#v", tc.input, string(tc.char), tc.limit, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user