Allow search strings to escape special chars ' ', ':' and '-' using '\'

This commit is contained in:
Håkan Fouren 2023-02-04 15:55:55 +08:00
parent f502cbee1d
commit 9062c24a7e
No known key found for this signature in database

View File

@ -1098,7 +1098,7 @@ func MakeWhereQueryFromSearch(ctx *context.Context, db *gorm.DB, query string) (
tx := db.Model(&data.HistoryEntry{}).Where("true") tx := db.Model(&data.HistoryEntry{}).Where("true")
for _, token := range tokens { for _, token := range tokens {
if strings.HasPrefix(token, "-") { if strings.HasPrefix(token, "-") {
if strings.Contains(token, ":") { if containsUnescaped(token, ":") {
query, v1, v2, err := parseAtomizedToken(ctx, token[1:]) query, v1, v2, err := parseAtomizedToken(ctx, token[1:])
if err != nil { if err != nil {
return nil, err return nil, err
@ -1111,7 +1111,7 @@ func MakeWhereQueryFromSearch(ctx *context.Context, db *gorm.DB, query string) (
} }
tx = tx.Where("NOT "+query, v1, v2, v3) tx = tx.Where("NOT "+query, v1, v2, v3)
} }
} else if strings.Contains(token, ":") { } else if containsUnescaped(token, ":") {
query, v1, v2, err := parseAtomizedToken(ctx, token) query, v1, v2, err := parseAtomizedToken(ctx, token)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1150,14 +1150,14 @@ func Search(ctx *context.Context, db *gorm.DB, query string, limit int) ([]*data
} }
func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) { func parseNonAtomizedToken(token string) (string, interface{}, interface{}, interface{}, error) {
wildcardedToken := "%" + token + "%" wildcardedToken := "%" + deEscape(token) + "%"
return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil return "(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken, nil
} }
func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) { func parseAtomizedToken(ctx *context.Context, token string) (string, interface{}, interface{}, error) {
splitToken := strings.SplitN(token, ":", 2) splitToken := splitEscaped(token, ':', 2)
field := splitToken[0] field := deEscape(splitToken[0])
val := splitToken[1] val := deEscape(splitToken[1])
switch field { switch field {
case "user": case "user":
return "(local_username = ?)", val, nil, nil return "(local_username = ?)", val, nil, nil
@ -1237,7 +1237,49 @@ func tokenize(query string) ([]string, error) {
if query == "" { if query == "" {
return []string{}, nil return []string{}, nil
} }
return strings.Split(query, " "), nil return splitEscaped(query, ' ', -1), nil
}
func splitEscaped(query string, separator byte, maxSplit int) []string {
var token []byte
var tokens []string
var splits = 1
for i := 0; i < len(query); i++ {
if (maxSplit < 0 || splits < maxSplit) && query[i] == separator {
tokens = append(tokens, string(token))
token = token[:0]
splits++
} else if query[i] == '\\' && i+1 < len(query) {
token = append(token, query[i], query[i+1])
i++
} else {
token = append(token, query[i])
}
}
tokens = append(tokens, string(token))
return tokens
}
func containsUnescaped(query string, token string) bool {
for i := 0; i < len(query); i++ {
if query[i] == '\\' && i+1 < len(query) {
i++
} else if query[i:i+len(token)] == token {
return true
}
}
return false
}
func deEscape(query string) string {
var newQuery []byte
for i := 0; i < len(query); i++ {
if query[i] == '\\' && i+1 < len(query) {
i++
}
newQuery = append(newQuery, query[i])
}
return string(newQuery)
} }
func GetDumpRequests(config hctx.ClientConfig) ([]*shared.DumpRequest, error) { func GetDumpRequests(config hctx.ClientConfig) ([]*shared.DumpRequest, error) {