in progress integration tests

This commit is contained in:
David Dworken 2022-01-09 11:00:53 -08:00
parent a523504c40
commit 3d450a1175
12 changed files with 283 additions and 109 deletions

View File

@ -2,6 +2,7 @@ package main
import ( import (
"os" "os"
"strings"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -13,10 +14,11 @@ func main() {
case "query": case "query":
query() query()
case "init": case "init":
err := shared.Setup(os.Args) shared.CheckFatalError(shared.Setup(os.Args))
if err != nil { case "enable":
panic(err) shared.CheckFatalError(shared.Enable())
} case "disable":
shared.CheckFatalError(shared.Disable())
} }
} }
@ -28,19 +30,24 @@ func getServerHostname() string {
} }
func query() { func query() {
// TODO(ddworken) userSecret, err := shared.GetUserSecret()
var data []*shared.HistoryEntry shared.CheckFatalError(err)
db, err := shared.OpenDB()
shared.CheckFatalError(err)
query := strings.Join(os.Args[2:], " ")
data, err := shared.Search(db, query, userSecret, 25)
shared.CheckFatalError(err)
shared.DisplayResults(data) shared.DisplayResults(data)
} }
func saveHistoryEntry() { func saveHistoryEntry() {
isEnabled, err := shared.IsEnabled()
shared.CheckFatalError(err)
if !isEnabled {
return
}
entry, err := shared.BuildHistoryEntry(os.Args) entry, err := shared.BuildHistoryEntry(os.Args)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
err = shared.Persist(*entry) err = shared.Persist(*entry)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
} }

View File

@ -0,0 +1,24 @@
package main
import (
"bytes"
"fmt"
"os/exec"
"testing"
"github.com/ddworken/hishtory/shared"
)
func TestIntegration(t *testing.T) {
// Set up
defer shared.BackupAndRestore(t)
// Run the test
cmd := exec.Command("bash", "--init-file", "test_interaction.sh")
var out bytes.Buffer
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
t.Fatalf("unexpected error when running test script: %v", err)
}
fmt.Printf("%q\n", out.String())
}

View File

@ -0,0 +1,8 @@
go build -o /tmp/client clients/local/client.go
/tmp/client init
export PROMPT_COMMAND='/tmp/client upload $? "`history 1`"'
ls /a
ls /bar
ls /foo
echo foo
/tmp/client query

View File

@ -23,6 +23,10 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
case "enable":
shared.CheckFatalError(shared.Enable())
case "disable":
shared.CheckFatalError(shared.Disable())
} }
} }
@ -35,14 +39,10 @@ func getServerHostname() string {
func query() { func query() {
userSecret, err := shared.GetUserSecret() userSecret, err := shared.GetUserSecret()
if err != nil { shared.CheckFatalError(err)
panic(err)
}
req, err := http.NewRequest("GET", getServerHostname()+"/api/v1/search", nil) req, err := http.NewRequest("GET", getServerHostname()+"/api/v1/search", nil)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
q := req.URL.Query() q := req.URL.Query()
q.Add("query", strings.Join(os.Args[2:], " ")) q.Add("query", strings.Join(os.Args[2:], " "))
@ -52,36 +52,30 @@ func query() {
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
defer resp.Body.Close() defer resp.Body.Close()
resp_body, err := ioutil.ReadAll(resp.Body) resp_body, err := ioutil.ReadAll(resp.Body)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
if resp.Status != "200 OK" { if resp.Status != "200 OK" {
panic("search API returned invalid result. status=" + resp.Status) shared.CheckFatalError(fmt.Errorf("search API returned invalid result. status=" + resp.Status))
} }
var data []*shared.HistoryEntry var data []*shared.HistoryEntry
err = json.Unmarshal(resp_body, &data) err = json.Unmarshal(resp_body, &data)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
shared.DisplayResults(data) shared.DisplayResults(data)
} }
func saveHistoryEntry() { func saveHistoryEntry() {
isEnabled, err := shared.IsEnabled()
shared.CheckFatalError(err)
if !isEnabled {
return
}
entry, err := shared.BuildHistoryEntry(os.Args) entry, err := shared.BuildHistoryEntry(os.Args)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
err = send(*entry) err = send(*entry)
if err != nil { shared.CheckFatalError(err)
panic(err)
}
} }
func send(entry shared.HistoryEntry) error { func send(entry shared.HistoryEntry) error {

17
go.mod
View File

@ -1,11 +1,20 @@
module github.com/ddworken/hishtory module github.com/ddworken/hishtory
go 1.13 go 1.17
require ( require (
github.com/fatih/color v1.13.0 // indirect github.com/fatih/color v1.13.0
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0
github.com/rodaine/table v1.0.1 // indirect github.com/rodaine/table v1.0.1
gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlite v1.2.6
gorm.io/gorm v1.22.4 gorm.io/gorm v1.22.4
) )
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.3 // indirect
github.com/mattn/go-colorable v0.1.9 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-sqlite3 v1.14.9 // indirect
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect
)

5
go.sum
View File

@ -1,4 +1,5 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
@ -14,20 +15,24 @@ github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ= github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ=
github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4= github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.2.6 h1:SStaH/b+280M7C8vXeZLz/zo9cLQmIGwwj3cSj7p6l4= gorm.io/driver/sqlite v1.2.6 h1:SStaH/b+280M7C8vXeZLz/zo9cLQmIGwwj3cSj7p6l4=
gorm.io/driver/sqlite v1.2.6/go.mod h1:gyoX0vHiiwi0g49tv+x2E7l8ksauLK0U/gShcdUsjWY= gorm.io/driver/sqlite v1.2.6/go.mod h1:gyoX0vHiiwi0g49tv+x2E7l8ksauLK0U/gShcdUsjWY=

View File

@ -6,51 +6,10 @@ import (
"log" "log"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"gorm.io/gorm"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
func search(db *gorm.DB, userSecret, query string, limit int) ([]*shared.HistoryEntry, error) {
fmt.Println("Received search query: " + query)
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Debug().Where("user_secret = ?", userSecret)
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*shared.HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}
func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body) decoder := json.NewDecoder(r.Body)
var entry shared.HistoryEntry var entry shared.HistoryEntry
@ -66,10 +25,8 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
func apiSearchHandler(w http.ResponseWriter, r *http.Request) { func apiSearchHandler(w http.ResponseWriter, r *http.Request) {
userSecret := r.URL.Query().Get("user_secret") userSecret := r.URL.Query().Get("user_secret")
if userSecret == "" {
panic("cannot search without specifying a user secret")
}
query := r.URL.Query().Get("query") query := r.URL.Query().Get("query")
fmt.Println("Received search query: " + query)
limitStr := r.URL.Query().Get("limit") limitStr := r.URL.Query().Get("limit")
limit, err := strconv.Atoi(limitStr) limit, err := strconv.Atoi(limitStr)
if err != nil { if err != nil {
@ -79,7 +36,7 @@ func apiSearchHandler(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
entries, err := search(db, userSecret, query, limit) entries, err := shared.Search(db, userSecret, query, limit)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -44,10 +44,10 @@ func TestSubmitThenQuery(t *testing.T) {
shared.Check(t, err) shared.Check(t, err)
var retrievedEntries []*shared.HistoryEntry var retrievedEntries []*shared.HistoryEntry
shared.Check(t, json.Unmarshal(data, &retrievedEntries)) shared.Check(t, json.Unmarshal(data, &retrievedEntries))
dbEntry := retrievedEntries[0]
if len(retrievedEntries) != 1 { if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
dbEntry := retrievedEntries[0]
if dbEntry.UserSecret != "" { if dbEntry.UserSecret != "" {
t.Fatalf("Response contains a user secret: %#v", *dbEntry) t.Fatalf("Response contains a user secret: %#v", *dbEntry)
} }
@ -56,3 +56,72 @@ func TestSubmitThenQuery(t *testing.T) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, *entry) t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, *entry)
} }
} }
func TestNoUserSecretGivesNoResults(t *testing.T) {
// Set up
defer shared.BackupAndRestore(t)
shared.Check(t, shared.Setup([]string{}))
// Submit an entry
entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / "})
shared.Check(t, err)
reqBody, err := json.Marshal(entry)
shared.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(nil, submitReq)
// Retrieve entries with no user secret
w := httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/", nil)
apiSearchHandler(w, searchReq)
res := w.Result()
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
shared.Check(t, err)
var retrievedEntries []*shared.HistoryEntry
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
if len(retrievedEntries) != 0 {
t.Fatalf("Expected to retrieve 0 entries, found %d", len(retrievedEntries))
}
}
func TestSearchQuery(t *testing.T) {
// Set up
defer shared.BackupAndRestore(t)
shared.Check(t, shared.Setup([]string{}))
// Submit an entry that we'll match
entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls /bar "})
shared.Check(t, err)
reqBody, err := json.Marshal(entry)
shared.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(nil, submitReq)
// Submit an entry that we won't match
entry, err = shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls /foo "})
shared.Check(t, err)
reqBody, err = json.Marshal(entry)
shared.Check(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(nil, submitReq)
// Retrieve the entry
secret, err := shared.GetUserSecret()
shared.Check(t, err)
w := httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?user_secret="+secret+"&query=foo", nil)
apiSearchHandler(w, searchReq)
res := w.Result()
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
shared.Check(t, err)
var retrievedEntries []*shared.HistoryEntry
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
dbEntry := retrievedEntries[0]
if dbEntry.Command != "ls /foo" {
t.Fatalf("Response contains an unexpected command: %#v", *dbEntry)
}
}

View File

@ -1,7 +1,9 @@
package shared package shared
import ( import (
"encoding/json"
"fmt" "fmt"
"log"
"os" "os"
"os/user" "os/user"
"path" "path"
@ -15,7 +17,7 @@ import (
) )
const ( const (
SECRET_PATH = ".hishtory.secret" CONFIG_PATH = ".hishtory.config"
) )
func getCwd() (string, error) { func getCwd() (string, error) {
@ -91,15 +93,11 @@ func getLastCommand(history string) (string, error) {
} }
func GetUserSecret() (string, error) { func GetUserSecret() (string, error) {
homedir, err := os.UserHomeDir() config, err := GetConfig()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read secret hishtory key: %v", err) return "", err
} }
secret, err := os.ReadFile(path.Join(homedir, SECRET_PATH)) return config.UserSecret, nil
if err != nil {
return "", fmt.Errorf("failed to read secret hishtory key: %v", err)
}
return string(secret), nil
} }
func Setup(args []string) error { func Setup(args []string) error {
@ -109,15 +107,10 @@ func Setup(args []string) error {
} }
fmt.Println("Setting secret hishtory key to " + string(userSecret)) fmt.Println("Setting secret hishtory key to " + string(userSecret))
homedir, err := os.UserHomeDir() var config ClientConfig
if err != nil { config.UserSecret = userSecret
return fmt.Errorf("failed to retrieve homedir: %v", err) config.IsEnabled = true
} return SetConfig(config)
err = os.WriteFile(path.Join(homedir, SECRET_PATH), []byte(userSecret), 0600)
if err != nil {
return fmt.Errorf("failed to write hishtory secret: %v", err)
}
return nil
} }
func DisplayResults(results []*HistoryEntry) { func DisplayResults(results []*HistoryEntry) {
@ -131,3 +124,73 @@ func DisplayResults(results []*HistoryEntry) {
tbl.Print() tbl.Print()
} }
type ClientConfig struct {
UserSecret string `json:"user_secret"`
IsEnabled bool `json:"is_enabled"`
}
func GetConfig() (ClientConfig, error) {
homedir, err := os.UserHomeDir()
if err != nil {
return ClientConfig{}, fmt.Errorf("failed to retrieve homedir: %v", err)
}
data, err := os.ReadFile(path.Join(homedir, CONFIG_PATH))
if err != nil {
return ClientConfig{}, fmt.Errorf("failed to read config file: %v", err)
}
var config ClientConfig
err = json.Unmarshal(data, &config)
if err != nil {
return ClientConfig{}, fmt.Errorf("failed to parse config file: %v", err)
}
return config, nil
}
func SetConfig(config ClientConfig) error {
serializedConfig, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("failed to serialize config: %v", err)
}
homedir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to retrieve homedir: %v", err)
}
err = os.WriteFile(path.Join(homedir, CONFIG_PATH), serializedConfig, 0600)
if err != nil {
return fmt.Errorf("failed to write config: %v", err)
}
return nil
}
func IsEnabled() (bool, error) {
config, err := GetConfig()
if err != nil {
return false, err
}
return config.IsEnabled, nil
}
func Enable() error {
config, err := GetConfig()
if err != nil {
return err
}
config.IsEnabled = true
return SetConfig(config)
}
func Disable() error {
config, err := GetConfig()
if err != nil {
return err
}
config.IsEnabled = false
return SetConfig(config)
}
func CheckFatalError(err error) {
if err != nil {
log.Fatalf("hishtory fatal error: %v", err)
}
}

View File

@ -11,14 +11,14 @@ func TestSetup(t *testing.T) {
defer BackupAndRestore(t) defer BackupAndRestore(t)
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
Check(t, err) Check(t, err)
if _, err := os.Stat(path.Join(homedir, SECRET_PATH)); err == nil { if _, err := os.Stat(path.Join(homedir, CONFIG_PATH)); err == nil {
t.Fatalf("hishtory secret file already exists!") t.Fatalf("hishtory secret file already exists!")
} }
Check(t, Setup([]string{})) Check(t, Setup([]string{}))
if _, err := os.Stat(path.Join(homedir, SECRET_PATH)); err != nil { if _, err := os.Stat(path.Join(homedir, CONFIG_PATH)); err != nil {
t.Fatalf("hishtory secret file does not exist after Setup()!") t.Fatalf("hishtory secret file does not exist after Setup()!")
} }
data, err := os.ReadFile(path.Join(homedir, SECRET_PATH)) data, err := os.ReadFile(path.Join(homedir, CONFIG_PATH))
Check(t, err) Check(t, err)
if len(data) < 10 { if len(data) < 10 {
t.Fatalf("hishtory secret has unexpected length: %d", len(data)) t.Fatalf("hishtory secret has unexpected length: %d", len(data))

View File

@ -5,6 +5,7 @@ import (
"log" "log"
"os" "os"
"path" "path"
"strings"
"time" "time"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
@ -12,7 +13,7 @@ import (
) )
type HistoryEntry struct { type HistoryEntry struct {
UserSecret string `json:"user_secret"` UserSecret string `json:"user_secret" gorm:"index"`
LocalUsername string `json:"local_username"` LocalUsername string `json:"local_username"`
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
Command string `json:"command"` Command string `json:"command"`
@ -50,3 +51,40 @@ func OpenDB() (*gorm.DB, error) {
db.AutoMigrate(&HistoryEntry{}) db.AutoMigrate(&HistoryEntry{})
return db, nil return db, nil
} }
func Search(db *gorm.DB, userSecret, query string, limit int) ([]*HistoryEntry, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Where("user_secret = ?", userSecret)
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}

View File

@ -20,10 +20,10 @@ func BackupAndRestore(t *testing.T) func() {
} }
os.Rename(path.Join(homedir, DB_PATH), path.Join(homedir, DB_PATH+".bak")) os.Rename(path.Join(homedir, DB_PATH), path.Join(homedir, DB_PATH+".bak"))
os.Rename(path.Join(homedir, SECRET_PATH), path.Join(homedir, SECRET_PATH+".bak")) os.Rename(path.Join(homedir, CONFIG_PATH), path.Join(homedir, CONFIG_PATH+".bak"))
return func() { return func() {
Check(t, os.Rename(path.Join(homedir, DB_PATH+".bak"), path.Join(homedir, DB_PATH))) Check(t, os.Rename(path.Join(homedir, DB_PATH+".bak"), path.Join(homedir, DB_PATH)))
Check(t, os.Rename(path.Join(homedir, SECRET_PATH+".bak"), path.Join(homedir, SECRET_PATH))) Check(t, os.Rename(path.Join(homedir, CONFIG_PATH+".bak"), path.Join(homedir, CONFIG_PATH)))
} }
} }