mirror of
https://github.com/ddworken/hishtory.git
synced 2025-01-22 22:28:51 +01:00
refactored to move no longer shared things out of the shared/ folder
This commit is contained in:
parent
a8d7ee2cc8
commit
c2465d7c99
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/client/lib"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
)
|
||||
@ -67,17 +68,17 @@ func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error {
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("failed to retrieve data from backend, status_code=%d", resp.StatusCode)
|
||||
}
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read latest history entries response body: %v", err)
|
||||
}
|
||||
var retrievedEntries []*shared.EncHistoryEntry
|
||||
err = json.Unmarshal(data, &retrievedEntries)
|
||||
err = json.Unmarshal(respBody, &retrievedEntries)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load JSON response: %v", err)
|
||||
}
|
||||
for _, entry := range retrievedEntries {
|
||||
decEntry, err := shared.DecryptHistoryEntry(config.UserSecret, *entry)
|
||||
decEntry, err := data.DecryptHistoryEntry(config.UserSecret, *entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt history entry from server: %v", err)
|
||||
}
|
||||
@ -91,7 +92,7 @@ func query(query string) {
|
||||
lib.CheckFatalError(err)
|
||||
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
|
||||
lib.CheckFatalError(displayBannerIfSet())
|
||||
data, err := shared.Search(db, query, 25)
|
||||
data, err := data.Search(db, query, 25)
|
||||
lib.CheckFatalError(err)
|
||||
lib.DisplayResults(data, false)
|
||||
}
|
||||
@ -135,7 +136,7 @@ func saveHistoryEntry() {
|
||||
lib.CheckFatalError(result.Error)
|
||||
|
||||
// Persist it remotely
|
||||
encEntry, err := shared.EncryptHistoryEntry(config.UserSecret, *entry)
|
||||
encEntry, err := data.EncryptHistoryEntry(config.UserSecret, *entry)
|
||||
lib.CheckFatalError(err)
|
||||
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||
lib.CheckFatalError(err)
|
||||
@ -150,7 +151,7 @@ func export() {
|
||||
db, err := lib.OpenLocalSqliteDb()
|
||||
lib.CheckFatalError(err)
|
||||
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
|
||||
data, err := shared.Search(db, "", 0)
|
||||
data, err := data.Search(db, "", 0)
|
||||
lib.CheckFatalError(err)
|
||||
for i := len(data) - 1; i >= 0; i-- {
|
||||
fmt.Println(data[i].Command)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
)
|
||||
|
||||
@ -139,8 +140,8 @@ func TestIntegrationWithNewDevice(t *testing.T) {
|
||||
|
||||
// Manually submit an event that isn't in the local DB, and then we'll
|
||||
// check if we see it when we do a query without ever having done an init
|
||||
newEntry := shared.MakeFakeHistoryEntry("othercomputer")
|
||||
encEntry, err := shared.EncryptHistoryEntry(userSecret, newEntry)
|
||||
newEntry := data.MakeFakeHistoryEntry("othercomputer")
|
||||
encEntry, err := data.EncryptHistoryEntry(userSecret, newEntry)
|
||||
shared.Check(t, err)
|
||||
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||
shared.Check(t, err)
|
||||
|
188
client/data/data.go
Normal file
188
client/data/data.go
Normal file
@ -0,0 +1,188 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
KDF_USER_ID = "user_id"
|
||||
KDF_DEVICE_ID = "device_id"
|
||||
KDF_ENCRYPTION_KEY = "encryption_key"
|
||||
)
|
||||
|
||||
type HistoryEntry struct {
|
||||
LocalUsername string `json:"local_username" gorm:"uniqueIndex:compositeindex"`
|
||||
Hostname string `json:"hostname" gorm:"uniqueIndex:compositeindex"`
|
||||
Command string `json:"command" gorm:"uniqueIndex:compositeindex"`
|
||||
CurrentWorkingDirectory string `json:"current_working_directory" gorm:"uniqueIndex:compositeindex"`
|
||||
ExitCode int `json:"exit_code" gorm:"uniqueIndex:compositeindex"`
|
||||
StartTime time.Time `json:"start_time" gorm:"uniqueIndex:compositeindex"`
|
||||
EndTime time.Time `json:"end_time" gorm:"uniqueIndex:compositeindex"`
|
||||
}
|
||||
|
||||
func sha256hmac(key, additionalData string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(key))
|
||||
h.Write([]byte(additionalData))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func UserId(key string) string {
|
||||
return base64.URLEncoding.EncodeToString(sha256hmac(key, KDF_USER_ID))
|
||||
}
|
||||
|
||||
func EncryptionKey(userSecret string) []byte {
|
||||
return sha256hmac(userSecret, KDF_ENCRYPTION_KEY)
|
||||
}
|
||||
|
||||
func makeAead(userSecret string) (cipher.AEAD, error) {
|
||||
key := EncryptionKey(userSecret)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead, nil
|
||||
}
|
||||
|
||||
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
|
||||
aead, err := makeAead(userSecret)
|
||||
if err != nil {
|
||||
return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
|
||||
}
|
||||
nonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
|
||||
}
|
||||
ciphertext := aead.Seal(nil, nonce, data, additionalData)
|
||||
_, err = aead.Open(nil, nonce, ciphertext, additionalData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ciphertext, nonce, nil
|
||||
}
|
||||
|
||||
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
|
||||
aead, err := makeAead(userSecret)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
|
||||
}
|
||||
plaintext, err := aead.Open(nil, nonce, data, additionalData)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (shared.EncHistoryEntry, error) {
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return shared.EncHistoryEntry{}, err
|
||||
}
|
||||
ciphertext, nonce, err := Encrypt(userSecret, data, []byte(UserId(userSecret)))
|
||||
if err != nil {
|
||||
return shared.EncHistoryEntry{}, err
|
||||
}
|
||||
return shared.EncHistoryEntry{
|
||||
EncryptedData: ciphertext,
|
||||
Nonce: nonce,
|
||||
UserId: UserId(userSecret),
|
||||
Date: time.Now(),
|
||||
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
|
||||
ReadCount: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (HistoryEntry, error) {
|
||||
if entry.UserId != UserId(userSecret) {
|
||||
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
|
||||
}
|
||||
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
|
||||
if err != nil {
|
||||
return HistoryEntry{}, nil
|
||||
}
|
||||
var decryptedEntry HistoryEntry
|
||||
err = json.Unmarshal(plaintext, &decryptedEntry)
|
||||
if err != nil {
|
||||
return HistoryEntry{}, nil
|
||||
}
|
||||
return decryptedEntry, nil
|
||||
}
|
||||
|
||||
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
|
||||
tokens, err := tokenize(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to tokenize query: %v", err)
|
||||
}
|
||||
tx := db.Where("true")
|
||||
for _, token := range tokens {
|
||||
if strings.Contains(token, ":") {
|
||||
splitToken := strings.SplitN(token, ":", 2)
|
||||
field := splitToken[0]
|
||||
val := splitToken[1]
|
||||
// tx = tx.Where()
|
||||
panic("TODO(ddworken): Use " + field + val)
|
||||
} else if strings.HasPrefix(token, "-") {
|
||||
panic("TODO(ddworken): Implement -foo as filtering out foo")
|
||||
} else {
|
||||
wildcardedToken := "%" + token + "%"
|
||||
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
|
||||
}
|
||||
}
|
||||
tx = tx.Order("end_time DESC")
|
||||
if limit > 0 {
|
||||
tx = tx.Limit(limit)
|
||||
}
|
||||
var historyEntries []*HistoryEntry
|
||||
result := tx.Find(&historyEntries)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("DB query error: %v", result.Error)
|
||||
}
|
||||
return historyEntries, nil
|
||||
}
|
||||
|
||||
func tokenize(query string) ([]string, error) {
|
||||
if query == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
return strings.Split(query, " "), nil
|
||||
}
|
||||
|
||||
func EntryEquals(entry1, entry2 HistoryEntry) bool {
|
||||
return entry1.LocalUsername == entry2.LocalUsername &&
|
||||
entry1.Hostname == entry2.Hostname &&
|
||||
entry1.Command == entry2.Command &&
|
||||
entry1.CurrentWorkingDirectory == entry2.CurrentWorkingDirectory &&
|
||||
entry1.ExitCode == entry2.ExitCode &&
|
||||
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
|
||||
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func MakeFakeHistoryEntry(command string) HistoryEntry {
|
||||
return HistoryEntry{
|
||||
LocalUsername: "david",
|
||||
Hostname: "localhost",
|
||||
Command: command,
|
||||
CurrentWorkingDirectory: "/tmp/",
|
||||
ExitCode: 2,
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now(),
|
||||
}
|
||||
}
|
@ -1,22 +1,22 @@
|
||||
package shared
|
||||
package data
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
k1, err := EncryptionKey("key")
|
||||
Check(t, err)
|
||||
k2, err := EncryptionKey("key")
|
||||
Check(t, err)
|
||||
k1 := EncryptionKey("key")
|
||||
k2 := EncryptionKey("key")
|
||||
if string(k1) != string(k2) {
|
||||
t.Fatalf("Expected EncryptionKey to be deterministic!")
|
||||
}
|
||||
|
||||
ciphertext, nonce, err := Encrypt("key", []byte("hello world!"), []byte("extra"))
|
||||
Check(t, err)
|
||||
shared.Check(t, err)
|
||||
plaintext, err := Decrypt("key", ciphertext, []byte("extra"), nonce)
|
||||
Check(t, err)
|
||||
shared.Check(t, err)
|
||||
if string(plaintext) != "hello world!" {
|
||||
t.Fatalf("Expected decrypt(encrypt(x)) to work, but it didn't!")
|
||||
}
|
@ -27,6 +27,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/rodaine/table"
|
||||
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
)
|
||||
|
||||
@ -53,8 +54,8 @@ func getCwd() (string, error) {
|
||||
return cwd, nil
|
||||
}
|
||||
|
||||
func BuildHistoryEntry(args []string) (*shared.HistoryEntry, error) {
|
||||
var entry shared.HistoryEntry
|
||||
func BuildHistoryEntry(args []string) (*data.HistoryEntry, error) {
|
||||
var entry data.HistoryEntry
|
||||
|
||||
// exitCode
|
||||
exitCode, err := strconv.Atoi(args[2])
|
||||
@ -141,7 +142,7 @@ func Setup(args []string) error {
|
||||
db.Exec("DELETE FROM history_entries")
|
||||
|
||||
// Bootstrap from remote date
|
||||
resp, err := http.Get(GetServerHostname() + "/api/v1/eregister?user_id=" + shared.UserId(userSecret) + "&device_id=" + config.DeviceId)
|
||||
resp, err := http.Get(GetServerHostname() + "/api/v1/eregister?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register device with backend: %v", err)
|
||||
}
|
||||
@ -149,7 +150,7 @@ func Setup(args []string) error {
|
||||
return fmt.Errorf("failed to register device with backend, status_code=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
resp, err = http.Get(GetServerHostname() + "/api/v1/ebootstrap?user_id=" + shared.UserId(userSecret) + "&device_id=" + config.DeviceId)
|
||||
resp, err = http.Get(GetServerHostname() + "/api/v1/ebootstrap?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bootstrap device from the backend: %v", err)
|
||||
}
|
||||
@ -157,17 +158,17 @@ func Setup(args []string) error {
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("failed to bootstrap device with data from existing devices, status_code=%d", resp.StatusCode)
|
||||
}
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read bootstrap response body: %v", err)
|
||||
}
|
||||
var retrievedEntries []*shared.EncHistoryEntry
|
||||
err = json.Unmarshal(data, &retrievedEntries)
|
||||
err = json.Unmarshal(respBody, &retrievedEntries)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load JSON response: %v", err)
|
||||
}
|
||||
for _, entry := range retrievedEntries {
|
||||
decEntry, err := shared.DecryptHistoryEntry(userSecret, *entry)
|
||||
decEntry, err := data.DecryptHistoryEntry(userSecret, *entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt history entry from server: %v", err)
|
||||
}
|
||||
@ -177,7 +178,7 @@ func Setup(args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddToDbIfNew(db *gorm.DB, entry shared.HistoryEntry) {
|
||||
func AddToDbIfNew(db *gorm.DB, entry data.HistoryEntry) {
|
||||
tx := db.Where("local_username = ?", entry.LocalUsername)
|
||||
tx = tx.Where("hostname = ?", entry.Hostname)
|
||||
tx = tx.Where("command = ?", entry.Command)
|
||||
@ -185,14 +186,14 @@ func AddToDbIfNew(db *gorm.DB, entry shared.HistoryEntry) {
|
||||
tx = tx.Where("exit_code = ?", entry.ExitCode)
|
||||
tx = tx.Where("start_time = ?", entry.StartTime)
|
||||
tx = tx.Where("end_time = ?", entry.EndTime)
|
||||
var results []shared.HistoryEntry
|
||||
var results []data.HistoryEntry
|
||||
tx.Limit(1).Find(&results)
|
||||
if len(results) == 0 {
|
||||
db.Create(entry)
|
||||
}
|
||||
}
|
||||
|
||||
func DisplayResults(results []*shared.HistoryEntry, displayHostname bool) {
|
||||
func DisplayResults(results []*data.HistoryEntry, displayHostname bool) {
|
||||
headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc()
|
||||
tbl := table.New("CWD", "Timestamp", "Runtime", "Exit Code", "Command")
|
||||
if displayHostname {
|
||||
@ -438,6 +439,6 @@ func OpenLocalSqliteDb() (*gorm.DB, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.AutoMigrate(&shared.HistoryEntry{})
|
||||
db.AutoMigrate(&data.HistoryEntry{})
|
||||
return db, nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
)
|
||||
|
||||
@ -76,16 +77,16 @@ func TestPersist(t *testing.T) {
|
||||
db, err := OpenLocalSqliteDb()
|
||||
shared.Check(t, err)
|
||||
|
||||
entry := shared.MakeFakeHistoryEntry("ls ~/")
|
||||
entry := data.MakeFakeHistoryEntry("ls ~/")
|
||||
db.Create(entry)
|
||||
var historyEntries []*shared.HistoryEntry
|
||||
var historyEntries []*data.HistoryEntry
|
||||
result := db.Find(&historyEntries)
|
||||
shared.Check(t, result.Error)
|
||||
if len(historyEntries) != 1 {
|
||||
t.Fatalf("DB has %d entries, expected 1!", len(historyEntries))
|
||||
}
|
||||
dbEntry := historyEntries[0]
|
||||
if !shared.EntryEquals(entry, *dbEntry) {
|
||||
if !data.EntryEquals(entry, *dbEntry) {
|
||||
t.Fatalf("DB data is different than input! \ndb =%#v \ninput=%#v", *dbEntry, entry)
|
||||
}
|
||||
}
|
||||
@ -96,21 +97,21 @@ func TestSearch(t *testing.T) {
|
||||
shared.Check(t, err)
|
||||
|
||||
// Insert data
|
||||
entry1 := shared.MakeFakeHistoryEntry("ls /foo")
|
||||
entry1 := data.MakeFakeHistoryEntry("ls /foo")
|
||||
db.Create(entry1)
|
||||
entry2 := shared.MakeFakeHistoryEntry("ls /bar")
|
||||
entry2 := data.MakeFakeHistoryEntry("ls /bar")
|
||||
db.Create(entry2)
|
||||
|
||||
// Search for data
|
||||
results, err := shared.Search(db, "ls", 5)
|
||||
results, err := data.Search(db, "ls", 5)
|
||||
shared.Check(t, err)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("Search() returned %d results, expected 2!", len(results))
|
||||
}
|
||||
if !shared.EntryEquals(*results[0], entry2) {
|
||||
if !data.EntryEquals(*results[0], entry2) {
|
||||
t.Fatalf("Search()[0]=%#v, expected: %#v", results[0], entry2)
|
||||
}
|
||||
if !shared.EntryEquals(*results[1], entry1) {
|
||||
if !data.EntryEquals(*results[1], entry1) {
|
||||
t.Fatalf("Search()[0]=%#v, expected: %#v", results[1], entry1)
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
_ "github.com/lib/pq"
|
||||
@ -104,8 +105,12 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(forcedBanner))
|
||||
}
|
||||
|
||||
func isTestEnvironment() bool {
|
||||
return os.Getenv("HISHTORY_TEST") != ""
|
||||
}
|
||||
|
||||
func OpenDB() (*gorm.DB, error) {
|
||||
if shared.IsTestEnvironment() {
|
||||
if isTestEnvironment() {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to the DB: %v", err)
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/ddworken/hishtory/client/data"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -18,10 +19,10 @@ func TestESubmitThenQuery(t *testing.T) {
|
||||
InitDB()
|
||||
|
||||
// Register a few devices
|
||||
userId := shared.UserId("key")
|
||||
userId := data.UserId("key")
|
||||
devId1 := uuid.Must(uuid.NewRandom()).String()
|
||||
devId2 := uuid.Must(uuid.NewRandom()).String()
|
||||
otherUser := shared.UserId("otherkey")
|
||||
otherUser := data.UserId("otherkey")
|
||||
otherDev := uuid.Must(uuid.NewRandom()).String()
|
||||
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||
apiERegisterHandler(nil, deviceReq)
|
||||
@ -31,8 +32,8 @@ func TestESubmitThenQuery(t *testing.T) {
|
||||
apiERegisterHandler(nil, deviceReq)
|
||||
|
||||
// Submit a few entries for different devices
|
||||
entry := shared.MakeFakeHistoryEntry("ls ~/")
|
||||
encEntry, err := shared.EncryptHistoryEntry("key", entry)
|
||||
entry := data.MakeFakeHistoryEntry("ls ~/")
|
||||
encEntry, err := data.EncryptHistoryEntry("key", entry)
|
||||
shared.Check(t, err)
|
||||
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||
shared.Check(t, err)
|
||||
@ -45,10 +46,10 @@ func TestESubmitThenQuery(t *testing.T) {
|
||||
apiEQueryHandler(w, searchReq)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
respBody, err := ioutil.ReadAll(res.Body)
|
||||
shared.Check(t, err)
|
||||
var retrievedEntries []*shared.EncHistoryEntry
|
||||
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
|
||||
shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||
if len(retrievedEntries) != 1 {
|
||||
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
|
||||
}
|
||||
@ -56,15 +57,15 @@ func TestESubmitThenQuery(t *testing.T) {
|
||||
if dbEntry.DeviceId != devId1 {
|
||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
||||
}
|
||||
if dbEntry.UserId != shared.UserId("key") {
|
||||
if dbEntry.UserId != data.UserId("key") {
|
||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
||||
}
|
||||
if dbEntry.ReadCount != 1 {
|
||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
||||
}
|
||||
decEntry, err := shared.DecryptHistoryEntry("key", *dbEntry)
|
||||
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
|
||||
shared.Check(t, err)
|
||||
if !shared.EntryEquals(decEntry, entry) {
|
||||
if !data.EntryEquals(decEntry, entry) {
|
||||
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
|
||||
}
|
||||
|
||||
@ -74,9 +75,9 @@ func TestESubmitThenQuery(t *testing.T) {
|
||||
apiEQueryHandler(w, searchReq)
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
data, err = ioutil.ReadAll(res.Body)
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
shared.Check(t, err)
|
||||
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
|
||||
shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||
if len(retrievedEntries) != 1 {
|
||||
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
|
||||
}
|
||||
@ -84,27 +85,27 @@ func TestESubmitThenQuery(t *testing.T) {
|
||||
if dbEntry.DeviceId != devId2 {
|
||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
||||
}
|
||||
if dbEntry.UserId != shared.UserId("key") {
|
||||
if dbEntry.UserId != data.UserId("key") {
|
||||
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
|
||||
}
|
||||
if dbEntry.ReadCount != 1 {
|
||||
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
|
||||
}
|
||||
decEntry, err = shared.DecryptHistoryEntry("key", *dbEntry)
|
||||
decEntry, err = data.DecryptHistoryEntry("key", *dbEntry)
|
||||
shared.Check(t, err)
|
||||
if !shared.EntryEquals(decEntry, entry) {
|
||||
if !data.EntryEquals(decEntry, entry) {
|
||||
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
|
||||
}
|
||||
|
||||
// Bootstrap handler should return 2 entries, one for each device
|
||||
w = httptest.NewRecorder()
|
||||
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+shared.UserId("key"), nil)
|
||||
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key"), nil)
|
||||
apiEBootstrapHandler(w, searchReq)
|
||||
res = w.Result()
|
||||
defer res.Body.Close()
|
||||
data, err = ioutil.ReadAll(res.Body)
|
||||
respBody, err = ioutil.ReadAll(res.Body)
|
||||
shared.Check(t, err)
|
||||
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
|
||||
shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
|
||||
if len(retrievedEntries) != 2 {
|
||||
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
|
||||
}
|
||||
|
182
shared/data.go
182
shared/data.go
@ -1,33 +1,9 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type HistoryEntry struct {
|
||||
LocalUsername string `json:"local_username" gorm:"uniqueIndex:compositeindex"`
|
||||
Hostname string `json:"hostname" gorm:"uniqueIndex:compositeindex"`
|
||||
Command string `json:"command" gorm:"uniqueIndex:compositeindex"`
|
||||
CurrentWorkingDirectory string `json:"current_working_directory" gorm:"uniqueIndex:compositeindex"`
|
||||
ExitCode int `json:"exit_code" gorm:"uniqueIndex:compositeindex"`
|
||||
StartTime time.Time `json:"start_time" gorm:"uniqueIndex:compositeindex"`
|
||||
EndTime time.Time `json:"end_time" gorm:"uniqueIndex:compositeindex"`
|
||||
}
|
||||
|
||||
type EncHistoryEntry struct {
|
||||
EncryptedData []byte `json:"enc_data"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
@ -43,160 +19,8 @@ type Device struct {
|
||||
DeviceId string `json:"device_id"`
|
||||
}
|
||||
|
||||
// const (
|
||||
// MESSAGE_TYPE_REQUEST_DUMP = iota
|
||||
// )
|
||||
|
||||
// type AsyncMessage struct {
|
||||
// MessageType int `json:"message_type"`
|
||||
// }
|
||||
|
||||
const (
|
||||
CONFIG_PATH = ".hishtory.config"
|
||||
HISHTORY_PATH = ".hishtory"
|
||||
DB_PATH = ".hishtory.db"
|
||||
KDF_USER_ID = "user_id"
|
||||
KDF_DEVICE_ID = "device_id"
|
||||
KDF_ENCRYPTION_KEY = "encryption_key"
|
||||
CONFIG_PATH = ".hishtory.config"
|
||||
HISHTORY_PATH = ".hishtory"
|
||||
DB_PATH = ".hishtory.db"
|
||||
)
|
||||
|
||||
func Hmac(key, additionalData string) string {
|
||||
h := hmac.New(sha256.New, []byte(key))
|
||||
h.Write([]byte(additionalData))
|
||||
return base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func UserId(key string) string {
|
||||
return Hmac(key, KDF_USER_ID)
|
||||
}
|
||||
|
||||
func EncryptionKey(userSecret string) ([]byte, error) {
|
||||
encryptionKey, err := base64.URLEncoding.DecodeString(Hmac(userSecret, KDF_ENCRYPTION_KEY))
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("Impossible state, decode(encode(hmac)) failed: %v", err)
|
||||
}
|
||||
return encryptionKey, nil
|
||||
}
|
||||
|
||||
func makeAead(userSecret string) (cipher.AEAD, error) {
|
||||
key, err := EncryptionKey(userSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead, nil
|
||||
}
|
||||
|
||||
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
|
||||
aead, err := makeAead(userSecret)
|
||||
if err != nil {
|
||||
return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
|
||||
}
|
||||
nonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
|
||||
}
|
||||
ciphertext := aead.Seal(nil, nonce, data, additionalData)
|
||||
_, err = aead.Open(nil, nonce, ciphertext, additionalData)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ciphertext, nonce, nil
|
||||
}
|
||||
|
||||
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
|
||||
aead, err := makeAead(userSecret)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
|
||||
}
|
||||
plaintext, err := aead.Open(nil, nonce, data, additionalData)
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (EncHistoryEntry, error) {
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return EncHistoryEntry{}, err
|
||||
}
|
||||
ciphertext, nonce, err := Encrypt(userSecret, data, []byte(UserId(userSecret)))
|
||||
if err != nil {
|
||||
return EncHistoryEntry{}, err
|
||||
}
|
||||
return EncHistoryEntry{
|
||||
EncryptedData: ciphertext,
|
||||
Nonce: nonce,
|
||||
UserId: UserId(userSecret),
|
||||
Date: time.Now(),
|
||||
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
|
||||
ReadCount: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DecryptHistoryEntry(userSecret string, entry EncHistoryEntry) (HistoryEntry, error) {
|
||||
if entry.UserId != UserId(userSecret) {
|
||||
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
|
||||
}
|
||||
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
|
||||
if err != nil {
|
||||
return HistoryEntry{}, nil
|
||||
}
|
||||
var decryptedEntry HistoryEntry
|
||||
err = json.Unmarshal(plaintext, &decryptedEntry)
|
||||
if err != nil {
|
||||
return HistoryEntry{}, nil
|
||||
}
|
||||
return decryptedEntry, nil
|
||||
}
|
||||
|
||||
func IsTestEnvironment() bool {
|
||||
return os.Getenv("HISHTORY_TEST") != ""
|
||||
}
|
||||
|
||||
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
|
||||
tokens, err := tokenize(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to tokenize query: %v", err)
|
||||
}
|
||||
tx := db.Where("true")
|
||||
for _, token := range tokens {
|
||||
if strings.Contains(token, ":") {
|
||||
splitToken := strings.SplitN(token, ":", 2)
|
||||
field := splitToken[0]
|
||||
val := splitToken[1]
|
||||
// tx = tx.Where()
|
||||
panic("TODO(ddworken): Use " + field + val)
|
||||
} else if strings.HasPrefix(token, "-") {
|
||||
panic("TODO(ddworken): Implement -foo as filtering out foo")
|
||||
} else {
|
||||
wildcardedToken := "%" + token + "%"
|
||||
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
|
||||
}
|
||||
}
|
||||
tx = tx.Order("end_time DESC")
|
||||
if limit > 0 {
|
||||
tx = tx.Limit(limit)
|
||||
}
|
||||
var historyEntries []*HistoryEntry
|
||||
result := tx.Find(&historyEntries)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("DB query error: %v", result.Error)
|
||||
}
|
||||
return historyEntries, nil
|
||||
}
|
||||
|
||||
func tokenize(query string) ([]string, error) {
|
||||
if query == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
return strings.Split(query, " "), nil
|
||||
}
|
||||
|
@ -91,25 +91,3 @@ func CheckWithInfo(t *testing.T, err error, additionalInfo string) {
|
||||
t.Fatalf("Unexpected error: %v! Additional info: %v", err, additionalInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func EntryEquals(entry1, entry2 HistoryEntry) bool {
|
||||
return entry1.LocalUsername == entry2.LocalUsername &&
|
||||
entry1.Hostname == entry2.Hostname &&
|
||||
entry1.Command == entry2.Command &&
|
||||
entry1.CurrentWorkingDirectory == entry2.CurrentWorkingDirectory &&
|
||||
entry1.ExitCode == entry2.ExitCode &&
|
||||
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
|
||||
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func MakeFakeHistoryEntry(command string) HistoryEntry {
|
||||
return HistoryEntry{
|
||||
LocalUsername: "david",
|
||||
Hostname: "localhost",
|
||||
Command: command,
|
||||
CurrentWorkingDirectory: "/tmp/",
|
||||
ExitCode: 2,
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user