Fix unescape function per comments on #73 and add tests for searching for a backslash

This commit is contained in:
David Dworken 2023-02-20 15:46:39 -08:00
parent dddedcf9a7
commit ff24b66fce
No known key found for this signature in database
2 changed files with 23 additions and 10 deletions

View File

@ -1174,14 +1174,14 @@ func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data
}
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
wildcardedToken := "%" + stripBackslash(token) + "%"
wildcardedToken := "%" + unescape(token) + "%"
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
}
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
splitToken := splitEscaped(token, ':', 2)
field := stripBackslash(splitToken[0])
val := stripBackslash(splitToken[1])
field := unescape(splitToken[0])
val := unescape(splitToken[1])
switch field {
case "user":
return "(local_username = ?)", val, nil, nil
@ -1297,11 +1297,15 @@ func containsUnescaped(query string, token string) bool {
return false
}
func stripBackslash(query string) string {
func unescape(query string) string {
runeQuery := []rune(query)
var newQuery []rune
for _, char := range query {
if char != '\\' {
newQuery = append(newQuery, char)
for i := 0; i < len(runeQuery); i++ {
if runeQuery[i] == '\\' {
i++
}
if i < len(runeQuery) {
newQuery = append(newQuery, runeQuery[i])
}
}
return string(newQuery)

View File

@ -281,6 +281,14 @@ func TestSearch(t *testing.T) {
if len(results) != 3 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
}
// A search for an entry containing a backslash
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("echo '\\'")).Error)
results, err = Search(ctx, db, "\\\\", 5)
testutils.Check(t, err)
if len(results) != 1 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
}
}
func TestAddToDbIfNew(t *testing.T) {
@ -487,7 +495,7 @@ func TestParseTimeGenerously(t *testing.T) {
}
}
func TestStripBackslash(t *testing.T) {
func TestUnescape(t *testing.T) {
testcases := []struct {
input string
output string
@ -498,10 +506,11 @@ func TestStripBackslash(t *testing.T) {
{"f\\:bar\\", "f:bar"},
{"\\f\\:bar\\", "f:bar"},
{"", ""},
{"\\\\", ""},
{"\\", ""},
{"\\\\", "\\"},
}
for _, tc := range testcases {
actual := stripBackslash(tc.input)
actual := unescape(tc.input)
if !reflect.DeepEqual(actual, tc.output) {
t.Fatalf("unescape failure for %#v, actual=%#v", tc.input, actual)
}