Fix unescape function per comments on #73 and add tests for searching for a backslash

This commit is contained in:
David Dworken 2023-02-20 15:46:39 -08:00
parent dddedcf9a7
commit ff24b66fce
No known key found for this signature in database
2 changed files with 23 additions and 10 deletions

View File

@ -1174,14 +1174,14 @@ func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data
} }
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) { func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
wildcardedToken := "%" + stripBackslash(token) + "%" wildcardedToken := "%" + unescape(token) + "%"
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
} }
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) { func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
splitToken := splitEscaped(token, ':', 2) splitToken := splitEscaped(token, ':', 2)
field := stripBackslash(splitToken[0]) field := unescape(splitToken[0])
val := stripBackslash(splitToken[1]) val := unescape(splitToken[1])
switch field { switch field {
case "user": case "user":
return "(local_username = ?)", val, nil, nil return "(local_username = ?)", val, nil, nil
@ -1297,11 +1297,15 @@ func containsUnescaped(query string, token string) bool {
return false return false
} }
func stripBackslash(query string) string { func unescape(query string) string {
runeQuery := []rune(query)
var newQuery []rune var newQuery []rune
for _, char := range query { for i := 0; i < len(runeQuery); i++ {
if char != '\\' { if runeQuery[i] == '\\' {
newQuery = append(newQuery, char) i++
}
if i < len(runeQuery) {
newQuery = append(newQuery, runeQuery[i])
} }
} }
return string(newQuery) return string(newQuery)

View File

@ -281,6 +281,14 @@ func TestSearch(t *testing.T) {
if len(results) != 3 { if len(results) != 3 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
} }
// A search for an entry containing a backslash
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("echo '\\'")).Error)
results, err = Search(ctx, db, "\\\\", 5)
testutils.Check(t, err)
if len(results) != 1 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
}
} }
func TestAddToDbIfNew(t *testing.T) { func TestAddToDbIfNew(t *testing.T) {
@ -487,7 +495,7 @@ func TestParseTimeGenerously(t *testing.T) {
} }
} }
func TestStripBackslash(t *testing.T) { func TestUnescape(t *testing.T) {
testcases := []struct { testcases := []struct {
input string input string
output string output string
@ -498,10 +506,11 @@ func TestStripBackslash(t *testing.T) {
{"f\\:bar\\", "f:bar"}, {"f\\:bar\\", "f:bar"},
{"\\f\\:bar\\", "f:bar"}, {"\\f\\:bar\\", "f:bar"},
{"", ""}, {"", ""},
{"\\\\", ""}, {"\\", ""},
{"\\\\", "\\"},
} }
for _, tc := range testcases { for _, tc := range testcases {
actual := stripBackslash(tc.input) actual := unescape(tc.input)
if !reflect.DeepEqual(actual, tc.output) { if !reflect.DeepEqual(actual, tc.output) {
t.Fatalf("unescape failure for %#v, actual=%#v", tc.input, actual) t.Fatalf("unescape failure for %#v, actual=%#v", tc.input, actual)
} }