Add unit tests + rename method

This commit is contained in:
David Dworken 2023-02-13 22:26:02 -08:00
parent b6eb4da4f3
commit 162dd86893
No known key found for this signature in database
2 changed files with 74 additions and 4 deletions

View File

@ -1150,14 +1150,14 @@ func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data
} }
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) { func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
wildcardedToken := "%" + unescape(token) + "%" wildcardedToken := "%" + stripBackslash(token) + "%"
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
} }
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) { func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
splitToken := splitEscaped(token, ':', 2) splitToken := splitEscaped(token, ':', 2)
field := unescape(splitToken[0]) field := stripBackslash(splitToken[0])
val := unescape(splitToken[1]) val := stripBackslash(splitToken[1])
switch field { switch field {
case "user": case "user":
return "(local_username = ?)", val, nil, nil return "(local_username = ?)", val, nil, nil
@ -1273,7 +1273,7 @@ func containsUnescaped(query string, token string) bool {
return false return false
} }
func unescape(query string) string { func stripBackslash(query string) string {
runeQuery := []rune(query) runeQuery := []rune(query)
var newQuery []rune var newQuery []rune
for i := 0; i < len(runeQuery); i++ { for i := 0; i < len(runeQuery); i++ {

View File

@ -433,3 +433,73 @@ func TestParseTimeGenerously(t *testing.T) {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
} }
func TestStripBackslash(t *testing.T) {
testcases := []struct {
input string
output string
}{
{"f bar", "f bar"},
{"f \\bar", "f bar"},
{"f\\:bar", "f:bar"},
{"f\\:bar\\", "f:bar"},
}
for _, tc := range testcases {
actual := stripBackslash(tc.input)
if !reflect.DeepEqual(actual, tc.output) {
t.Fatalf("unescape failure for %#v, actual=%#v", tc.input, actual)
}
}
}
func TestContainsUnescaped(t *testing.T) {
testcases := []struct {
input string
token string
expected bool
}{
{"f bar", "f", true},
{"f bar", "f bar", true},
{"f bar", "f r", false},
{"f bar", "f ", true},
{"foo:bar", ":", true},
{"foo:bar", "-", false},
{"foo\\:bar", ":", false},
{"foo\\-bar", "-", false},
{"foo\\-bar", "foo", true},
{"foo\\-bar", "bar", true},
{"foo\\-bar", "a", true},
}
for _, tc := range testcases {
actual := containsUnescaped(tc.input, tc.token)
if !reflect.DeepEqual(actual, tc.expected) {
t.Fatalf("containsUnescaped failure for containsUnescaped(%#v, %#v), actual=%#v", tc.input, tc.token, actual)
}
}
}
func TestSplitEscaped(t *testing.T) {
testcases := []struct {
input string
char rune
limit int
expected []string
}{
{"foo bar", ' ', 2, []string{"foo", "bar"}},
{"foo bar baz", ' ', 2, []string{"foo", "bar baz"}},
{"foo bar baz", ' ', 3, []string{"foo", "bar", "baz"}},
{"foo bar baz", ' ', 1, []string{"foo bar baz"}},
{"foo bar baz", ' ', -1, []string{"foo", "bar", "baz"}},
{"foo\\ bar baz", ' ', -1, []string{"foo\\ bar", "baz"}},
{"foo\\bar baz", ' ', -1, []string{"foo\\bar", "baz"}},
{"foo\\bar baz foob", ' ', 2, []string{"foo\\bar", "baz foob"}},
{"foo\\ bar\\ baz", ' ', -1, []string{"foo\\ bar\\ baz"}},
{"foo\\ bar\\ baz", ' ', -1, []string{"foo\\ bar\\ ", "baz"}},
}
for _, tc := range testcases {
actual := splitEscaped(tc.input, tc.char, tc.limit)
if !reflect.DeepEqual(actual, tc.expected) {
t.Fatalf("containsUnescaped failure for splitEscaped(%#v, %#v, %#v), actual=%#v", tc.input, string(tc.char), tc.limit, actual)
}
}
}