diff --git a/client/lib/lib.go b/client/lib/lib.go index 990039c..1af130b 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -1174,14 +1174,14 @@ func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data } func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) { - wildcardedToken := "%" + stripBackslash(token) + "%" + wildcardedToken := "%" + unescape(token) + "%" return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil } func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) { splitToken := splitEscaped(token, ':', 2) - field := stripBackslash(splitToken[0]) - val := stripBackslash(splitToken[1]) + field := unescape(splitToken[0]) + val := unescape(splitToken[1]) switch field { case "user": return "(local_username = ?)", val, nil, nil @@ -1297,11 +1297,15 @@ func containsUnescaped(query string, token string) bool { return false } -func stripBackslash(query string) string { +func unescape(query string) string { + runeQuery := []rune(query) var newQuery []rune - for _, char := range query { - if char != '\\' { - newQuery = append(newQuery, char) + for i := 0; i < len(runeQuery); i++ { + if runeQuery[i] == '\\' { + i++ + } + if i < len(runeQuery) { + newQuery = append(newQuery, runeQuery[i]) } } return string(newQuery) diff --git a/client/lib/lib_test.go b/client/lib/lib_test.go index 27b4add..885fb0b 100644 --- a/client/lib/lib_test.go +++ b/client/lib/lib_test.go @@ -281,6 +281,14 @@ func TestSearch(t *testing.T) { if len(results) != 3 { t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results) } + + // A search for an entry containing a backslash + testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("echo '\\'")).Error) + results, err = Search(ctx, db, "\\\\", 5) + testutils.Check(t, err) + if len(results) != 1 { + t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results) + } } func TestAddToDbIfNew(t *testing.T) { @@ -487,7 +495,7 @@ func TestParseTimeGenerously(t *testing.T) { } } -func TestStripBackslash(t *testing.T) { +func TestUnescape(t *testing.T) { testcases := []struct { input string output string @@ -498,10 +506,11 @@ func TestStripBackslash(t *testing.T) { {"f\\:bar\\", "f:bar"}, {"\\f\\:bar\\", "f:bar"}, {"", ""}, - {"\\\\", ""}, + {"\\", ""}, + {"\\\\", "\\"}, } for _, tc := range testcases { - actual := stripBackslash(tc.input) + actual := unescape(tc.input) if !reflect.DeepEqual(actual, tc.output) { t.Fatalf("unescape failure for %#v, actual=%#v", tc.input, actual) }