refactored to move no longer shared things out of the shared/ folder

This commit is contained in:
David Dworken 2022-04-07 20:59:40 -07:00
parent a8d7ee2cc8
commit c2465d7c99
10 changed files with 253 additions and 253 deletions

View File

@ -11,6 +11,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/lib" "github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -67,17 +68,17 @@ func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error {
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return fmt.Errorf("failed to retrieve data from backend, status_code=%d", resp.StatusCode) return fmt.Errorf("failed to retrieve data from backend, status_code=%d", resp.StatusCode)
} }
data, err := ioutil.ReadAll(resp.Body) respBody, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to read latest history entries response body: %v", err) return fmt.Errorf("failed to read latest history entries response body: %v", err)
} }
var retrievedEntries []*shared.EncHistoryEntry var retrievedEntries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &retrievedEntries) err = json.Unmarshal(respBody, &retrievedEntries)
if err != nil { if err != nil {
return fmt.Errorf("failed to load JSON response: %v", err) return fmt.Errorf("failed to load JSON response: %v", err)
} }
for _, entry := range retrievedEntries { for _, entry := range retrievedEntries {
decEntry, err := shared.DecryptHistoryEntry(config.UserSecret, *entry) decEntry, err := data.DecryptHistoryEntry(config.UserSecret, *entry)
if err != nil { if err != nil {
return fmt.Errorf("failed to decrypt history entry from server: %v", err) return fmt.Errorf("failed to decrypt history entry from server: %v", err)
} }
@ -91,7 +92,7 @@ func query(query string) {
lib.CheckFatalError(err) lib.CheckFatalError(err)
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db)) lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
lib.CheckFatalError(displayBannerIfSet()) lib.CheckFatalError(displayBannerIfSet())
data, err := shared.Search(db, query, 25) data, err := data.Search(db, query, 25)
lib.CheckFatalError(err) lib.CheckFatalError(err)
lib.DisplayResults(data, false) lib.DisplayResults(data, false)
} }
@ -135,7 +136,7 @@ func saveHistoryEntry() {
lib.CheckFatalError(result.Error) lib.CheckFatalError(result.Error)
// Persist it remotely // Persist it remotely
encEntry, err := shared.EncryptHistoryEntry(config.UserSecret, *entry) encEntry, err := data.EncryptHistoryEntry(config.UserSecret, *entry)
lib.CheckFatalError(err) lib.CheckFatalError(err)
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
lib.CheckFatalError(err) lib.CheckFatalError(err)
@ -150,7 +151,7 @@ func export() {
db, err := lib.OpenLocalSqliteDb() db, err := lib.OpenLocalSqliteDb()
lib.CheckFatalError(err) lib.CheckFatalError(err)
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db)) lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
data, err := shared.Search(db, "", 0) data, err := data.Search(db, "", 0)
lib.CheckFatalError(err) lib.CheckFatalError(err)
for i := len(data) - 1; i >= 0; i-- { for i := len(data) - 1; i >= 0; i-- {
fmt.Println(data[i].Command) fmt.Println(data[i].Command)

View File

@ -12,6 +12,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -139,8 +140,8 @@ func TestIntegrationWithNewDevice(t *testing.T) {
// Manually submit an event that isn't in the local DB, and then we'll // Manually submit an event that isn't in the local DB, and then we'll
// check if we see it when we do a query without ever having done an init // check if we see it when we do a query without ever having done an init
newEntry := shared.MakeFakeHistoryEntry("othercomputer") newEntry := data.MakeFakeHistoryEntry("othercomputer")
encEntry, err := shared.EncryptHistoryEntry(userSecret, newEntry) encEntry, err := data.EncryptHistoryEntry(userSecret, newEntry)
shared.Check(t, err) shared.Check(t, err)
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
shared.Check(t, err) shared.Check(t, err)

188
client/data/data.go Normal file
View File

@ -0,0 +1,188 @@
package data
import (
"fmt"
"io"
"strings"
"time"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"github.com/ddworken/hishtory/shared"
"github.com/google/uuid"
"gorm.io/gorm"
)
const (
KDF_USER_ID = "user_id"
KDF_DEVICE_ID = "device_id"
KDF_ENCRYPTION_KEY = "encryption_key"
)
type HistoryEntry struct {
LocalUsername string `json:"local_username" gorm:"uniqueIndex:compositeindex"`
Hostname string `json:"hostname" gorm:"uniqueIndex:compositeindex"`
Command string `json:"command" gorm:"uniqueIndex:compositeindex"`
CurrentWorkingDirectory string `json:"current_working_directory" gorm:"uniqueIndex:compositeindex"`
ExitCode int `json:"exit_code" gorm:"uniqueIndex:compositeindex"`
StartTime time.Time `json:"start_time" gorm:"uniqueIndex:compositeindex"`
EndTime time.Time `json:"end_time" gorm:"uniqueIndex:compositeindex"`
}
func sha256hmac(key, additionalData string) []byte {
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(additionalData))
return h.Sum(nil)
}
func UserId(key string) string {
return base64.URLEncoding.EncodeToString(sha256hmac(key, KDF_USER_ID))
}
func EncryptionKey(userSecret string) []byte {
return sha256hmac(userSecret, KDF_ENCRYPTION_KEY)
}
func makeAead(userSecret string) (cipher.AEAD, error) {
key := EncryptionKey(userSecret)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aead, nil
}
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
}
ciphertext := aead.Seal(nil, nonce, data, additionalData)
_, err = aead.Open(nil, nonce, ciphertext, additionalData)
if err != nil {
panic(err)
}
return ciphertext, nonce, nil
}
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
plaintext, err := aead.Open(nil, nonce, data, additionalData)
if err != nil {
return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
}
return plaintext, nil
}
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (shared.EncHistoryEntry, error) {
data, err := json.Marshal(entry)
if err != nil {
return shared.EncHistoryEntry{}, err
}
ciphertext, nonce, err := Encrypt(userSecret, data, []byte(UserId(userSecret)))
if err != nil {
return shared.EncHistoryEntry{}, err
}
return shared.EncHistoryEntry{
EncryptedData: ciphertext,
Nonce: nonce,
UserId: UserId(userSecret),
Date: time.Now(),
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
ReadCount: 0,
}, nil
}
func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (HistoryEntry, error) {
if entry.UserId != UserId(userSecret) {
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
}
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
if err != nil {
return HistoryEntry{}, nil
}
var decryptedEntry HistoryEntry
err = json.Unmarshal(plaintext, &decryptedEntry)
if err != nil {
return HistoryEntry{}, nil
}
return decryptedEntry, nil
}
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Where("true")
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
} else if strings.HasPrefix(token, "-") {
panic("TODO(ddworken): Implement -foo as filtering out foo")
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}
func EntryEquals(entry1, entry2 HistoryEntry) bool {
return entry1.LocalUsername == entry2.LocalUsername &&
entry1.Hostname == entry2.Hostname &&
entry1.Command == entry2.Command &&
entry1.CurrentWorkingDirectory == entry2.CurrentWorkingDirectory &&
entry1.ExitCode == entry2.ExitCode &&
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
}
func MakeFakeHistoryEntry(command string) HistoryEntry {
return HistoryEntry{
LocalUsername: "david",
Hostname: "localhost",
Command: command,
CurrentWorkingDirectory: "/tmp/",
ExitCode: 2,
StartTime: time.Now(),
EndTime: time.Now(),
}
}

View File

@ -1,22 +1,22 @@
package shared package data
import ( import (
"testing" "testing"
"github.com/ddworken/hishtory/shared"
) )
func TestEncryptDecrypt(t *testing.T) { func TestEncryptDecrypt(t *testing.T) {
k1, err := EncryptionKey("key") k1 := EncryptionKey("key")
Check(t, err) k2 := EncryptionKey("key")
k2, err := EncryptionKey("key")
Check(t, err)
if string(k1) != string(k2) { if string(k1) != string(k2) {
t.Fatalf("Expected EncryptionKey to be deterministic!") t.Fatalf("Expected EncryptionKey to be deterministic!")
} }
ciphertext, nonce, err := Encrypt("key", []byte("hello world!"), []byte("extra")) ciphertext, nonce, err := Encrypt("key", []byte("hello world!"), []byte("extra"))
Check(t, err) shared.Check(t, err)
plaintext, err := Decrypt("key", ciphertext, []byte("extra"), nonce) plaintext, err := Decrypt("key", ciphertext, []byte("extra"), nonce)
Check(t, err) shared.Check(t, err)
if string(plaintext) != "hello world!" { if string(plaintext) != "hello world!" {
t.Fatalf("Expected decrypt(encrypt(x)) to work, but it didn't!") t.Fatalf("Expected decrypt(encrypt(x)) to work, but it didn't!")
} }

View File

@ -27,6 +27,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rodaine/table" "github.com/rodaine/table"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -53,8 +54,8 @@ func getCwd() (string, error) {
return cwd, nil return cwd, nil
} }
func BuildHistoryEntry(args []string) (*shared.HistoryEntry, error) { func BuildHistoryEntry(args []string) (*data.HistoryEntry, error) {
var entry shared.HistoryEntry var entry data.HistoryEntry
// exitCode // exitCode
exitCode, err := strconv.Atoi(args[2]) exitCode, err := strconv.Atoi(args[2])
@ -141,7 +142,7 @@ func Setup(args []string) error {
db.Exec("DELETE FROM history_entries") db.Exec("DELETE FROM history_entries")
// Bootstrap from remote date // Bootstrap from remote date
resp, err := http.Get(GetServerHostname() + "/api/v1/eregister?user_id=" + shared.UserId(userSecret) + "&device_id=" + config.DeviceId) resp, err := http.Get(GetServerHostname() + "/api/v1/eregister?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
if err != nil { if err != nil {
return fmt.Errorf("failed to register device with backend: %v", err) return fmt.Errorf("failed to register device with backend: %v", err)
} }
@ -149,7 +150,7 @@ func Setup(args []string) error {
return fmt.Errorf("failed to register device with backend, status_code=%d", resp.StatusCode) return fmt.Errorf("failed to register device with backend, status_code=%d", resp.StatusCode)
} }
resp, err = http.Get(GetServerHostname() + "/api/v1/ebootstrap?user_id=" + shared.UserId(userSecret) + "&device_id=" + config.DeviceId) resp, err = http.Get(GetServerHostname() + "/api/v1/ebootstrap?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
if err != nil { if err != nil {
return fmt.Errorf("failed to bootstrap device from the backend: %v", err) return fmt.Errorf("failed to bootstrap device from the backend: %v", err)
} }
@ -157,17 +158,17 @@ func Setup(args []string) error {
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return fmt.Errorf("failed to bootstrap device with data from existing devices, status_code=%d", resp.StatusCode) return fmt.Errorf("failed to bootstrap device with data from existing devices, status_code=%d", resp.StatusCode)
} }
data, err := ioutil.ReadAll(resp.Body) respBody, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to read bootstrap response body: %v", err) return fmt.Errorf("failed to read bootstrap response body: %v", err)
} }
var retrievedEntries []*shared.EncHistoryEntry var retrievedEntries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &retrievedEntries) err = json.Unmarshal(respBody, &retrievedEntries)
if err != nil { if err != nil {
return fmt.Errorf("failed to load JSON response: %v", err) return fmt.Errorf("failed to load JSON response: %v", err)
} }
for _, entry := range retrievedEntries { for _, entry := range retrievedEntries {
decEntry, err := shared.DecryptHistoryEntry(userSecret, *entry) decEntry, err := data.DecryptHistoryEntry(userSecret, *entry)
if err != nil { if err != nil {
return fmt.Errorf("failed to decrypt history entry from server: %v", err) return fmt.Errorf("failed to decrypt history entry from server: %v", err)
} }
@ -177,7 +178,7 @@ func Setup(args []string) error {
return nil return nil
} }
func AddToDbIfNew(db *gorm.DB, entry shared.HistoryEntry) { func AddToDbIfNew(db *gorm.DB, entry data.HistoryEntry) {
tx := db.Where("local_username = ?", entry.LocalUsername) tx := db.Where("local_username = ?", entry.LocalUsername)
tx = tx.Where("hostname = ?", entry.Hostname) tx = tx.Where("hostname = ?", entry.Hostname)
tx = tx.Where("command = ?", entry.Command) tx = tx.Where("command = ?", entry.Command)
@ -185,14 +186,14 @@ func AddToDbIfNew(db *gorm.DB, entry shared.HistoryEntry) {
tx = tx.Where("exit_code = ?", entry.ExitCode) tx = tx.Where("exit_code = ?", entry.ExitCode)
tx = tx.Where("start_time = ?", entry.StartTime) tx = tx.Where("start_time = ?", entry.StartTime)
tx = tx.Where("end_time = ?", entry.EndTime) tx = tx.Where("end_time = ?", entry.EndTime)
var results []shared.HistoryEntry var results []data.HistoryEntry
tx.Limit(1).Find(&results) tx.Limit(1).Find(&results)
if len(results) == 0 { if len(results) == 0 {
db.Create(entry) db.Create(entry)
} }
} }
func DisplayResults(results []*shared.HistoryEntry, displayHostname bool) { func DisplayResults(results []*data.HistoryEntry, displayHostname bool) {
headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc() headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc()
tbl := table.New("CWD", "Timestamp", "Runtime", "Exit Code", "Command") tbl := table.New("CWD", "Timestamp", "Runtime", "Exit Code", "Command")
if displayHostname { if displayHostname {
@ -438,6 +439,6 @@ func OpenLocalSqliteDb() (*gorm.DB, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
db.AutoMigrate(&shared.HistoryEntry{}) db.AutoMigrate(&data.HistoryEntry{})
return db, nil return db, nil
} }

View File

@ -7,6 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -76,16 +77,16 @@ func TestPersist(t *testing.T) {
db, err := OpenLocalSqliteDb() db, err := OpenLocalSqliteDb()
shared.Check(t, err) shared.Check(t, err)
entry := shared.MakeFakeHistoryEntry("ls ~/") entry := data.MakeFakeHistoryEntry("ls ~/")
db.Create(entry) db.Create(entry)
var historyEntries []*shared.HistoryEntry var historyEntries []*data.HistoryEntry
result := db.Find(&historyEntries) result := db.Find(&historyEntries)
shared.Check(t, result.Error) shared.Check(t, result.Error)
if len(historyEntries) != 1 { if len(historyEntries) != 1 {
t.Fatalf("DB has %d entries, expected 1!", len(historyEntries)) t.Fatalf("DB has %d entries, expected 1!", len(historyEntries))
} }
dbEntry := historyEntries[0] dbEntry := historyEntries[0]
if !shared.EntryEquals(entry, *dbEntry) { if !data.EntryEquals(entry, *dbEntry) {
t.Fatalf("DB data is different than input! \ndb =%#v \ninput=%#v", *dbEntry, entry) t.Fatalf("DB data is different than input! \ndb =%#v \ninput=%#v", *dbEntry, entry)
} }
} }
@ -96,21 +97,21 @@ func TestSearch(t *testing.T) {
shared.Check(t, err) shared.Check(t, err)
// Insert data // Insert data
entry1 := shared.MakeFakeHistoryEntry("ls /foo") entry1 := data.MakeFakeHistoryEntry("ls /foo")
db.Create(entry1) db.Create(entry1)
entry2 := shared.MakeFakeHistoryEntry("ls /bar") entry2 := data.MakeFakeHistoryEntry("ls /bar")
db.Create(entry2) db.Create(entry2)
// Search for data // Search for data
results, err := shared.Search(db, "ls", 5) results, err := data.Search(db, "ls", 5)
shared.Check(t, err) shared.Check(t, err)
if len(results) != 2 { if len(results) != 2 {
t.Fatalf("Search() returned %d results, expected 2!", len(results)) t.Fatalf("Search() returned %d results, expected 2!", len(results))
} }
if !shared.EntryEquals(*results[0], entry2) { if !data.EntryEquals(*results[0], entry2) {
t.Fatalf("Search()[0]=%#v, expected: %#v", results[0], entry2) t.Fatalf("Search()[0]=%#v, expected: %#v", results[0], entry2)
} }
if !shared.EntryEquals(*results[1], entry1) { if !data.EntryEquals(*results[1], entry1) {
t.Fatalf("Search()[0]=%#v, expected: %#v", results[1], entry1) t.Fatalf("Search()[0]=%#v, expected: %#v", results[1], entry1)
} }
} }

View File

@ -6,6 +6,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
"os"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -104,8 +105,12 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(forcedBanner)) w.Write([]byte(forcedBanner))
} }
func isTestEnvironment() bool {
return os.Getenv("HISHTORY_TEST") != ""
}
func OpenDB() (*gorm.DB, error) { func OpenDB() (*gorm.DB, error) {
if shared.IsTestEnvironment() { if isTestEnvironment() {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to the DB: %v", err) return nil, fmt.Errorf("failed to connect to the DB: %v", err)

View File

@ -8,6 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -18,10 +19,10 @@ func TestESubmitThenQuery(t *testing.T) {
InitDB() InitDB()
// Register a few devices // Register a few devices
userId := shared.UserId("key") userId := data.UserId("key")
devId1 := uuid.Must(uuid.NewRandom()).String() devId1 := uuid.Must(uuid.NewRandom()).String()
devId2 := uuid.Must(uuid.NewRandom()).String() devId2 := uuid.Must(uuid.NewRandom()).String()
otherUser := shared.UserId("otherkey") otherUser := data.UserId("otherkey")
otherDev := uuid.Must(uuid.NewRandom()).String() otherDev := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiERegisterHandler(nil, deviceReq) apiERegisterHandler(nil, deviceReq)
@ -31,8 +32,8 @@ func TestESubmitThenQuery(t *testing.T) {
apiERegisterHandler(nil, deviceReq) apiERegisterHandler(nil, deviceReq)
// Submit a few entries for different devices // Submit a few entries for different devices
entry := shared.MakeFakeHistoryEntry("ls ~/") entry := data.MakeFakeHistoryEntry("ls ~/")
encEntry, err := shared.EncryptHistoryEntry("key", entry) encEntry, err := data.EncryptHistoryEntry("key", entry)
shared.Check(t, err) shared.Check(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
shared.Check(t, err) shared.Check(t, err)
@ -45,10 +46,10 @@ func TestESubmitThenQuery(t *testing.T) {
apiEQueryHandler(w, searchReq) apiEQueryHandler(w, searchReq)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body) respBody, err := ioutil.ReadAll(res.Body)
shared.Check(t, err) shared.Check(t, err)
var retrievedEntries []*shared.EncHistoryEntry var retrievedEntries []*shared.EncHistoryEntry
shared.Check(t, json.Unmarshal(data, &retrievedEntries)) shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 { if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
@ -56,15 +57,15 @@ func TestESubmitThenQuery(t *testing.T) {
if dbEntry.DeviceId != devId1 { if dbEntry.DeviceId != devId1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
} }
if dbEntry.UserId != shared.UserId("key") { if dbEntry.UserId != data.UserId("key") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
} }
if dbEntry.ReadCount != 1 { if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err := shared.DecryptHistoryEntry("key", *dbEntry) decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
shared.Check(t, err) shared.Check(t, err)
if !shared.EntryEquals(decEntry, entry) { if !data.EntryEquals(decEntry, entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry) t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
} }
@ -74,9 +75,9 @@ func TestESubmitThenQuery(t *testing.T) {
apiEQueryHandler(w, searchReq) apiEQueryHandler(w, searchReq)
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
data, err = ioutil.ReadAll(res.Body) respBody, err = ioutil.ReadAll(res.Body)
shared.Check(t, err) shared.Check(t, err)
shared.Check(t, json.Unmarshal(data, &retrievedEntries)) shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 { if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
@ -84,27 +85,27 @@ func TestESubmitThenQuery(t *testing.T) {
if dbEntry.DeviceId != devId2 { if dbEntry.DeviceId != devId2 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
} }
if dbEntry.UserId != shared.UserId("key") { if dbEntry.UserId != data.UserId("key") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
} }
if dbEntry.ReadCount != 1 { if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err = shared.DecryptHistoryEntry("key", *dbEntry) decEntry, err = data.DecryptHistoryEntry("key", *dbEntry)
shared.Check(t, err) shared.Check(t, err)
if !shared.EntryEquals(decEntry, entry) { if !data.EntryEquals(decEntry, entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry) t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
} }
// Bootstrap handler should return 2 entries, one for each device // Bootstrap handler should return 2 entries, one for each device
w = httptest.NewRecorder() w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+shared.UserId("key"), nil) searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key"), nil)
apiEBootstrapHandler(w, searchReq) apiEBootstrapHandler(w, searchReq)
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
data, err = ioutil.ReadAll(res.Body) respBody, err = ioutil.ReadAll(res.Body)
shared.Check(t, err) shared.Check(t, err)
shared.Check(t, json.Unmarshal(data, &retrievedEntries)) shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 { if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
} }

View File

@ -1,33 +1,9 @@
package shared package shared
import ( import (
"fmt"
"io"
"os"
"strings"
"time" "time"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"github.com/google/uuid"
"gorm.io/gorm"
) )
type HistoryEntry struct {
LocalUsername string `json:"local_username" gorm:"uniqueIndex:compositeindex"`
Hostname string `json:"hostname" gorm:"uniqueIndex:compositeindex"`
Command string `json:"command" gorm:"uniqueIndex:compositeindex"`
CurrentWorkingDirectory string `json:"current_working_directory" gorm:"uniqueIndex:compositeindex"`
ExitCode int `json:"exit_code" gorm:"uniqueIndex:compositeindex"`
StartTime time.Time `json:"start_time" gorm:"uniqueIndex:compositeindex"`
EndTime time.Time `json:"end_time" gorm:"uniqueIndex:compositeindex"`
}
type EncHistoryEntry struct { type EncHistoryEntry struct {
EncryptedData []byte `json:"enc_data"` EncryptedData []byte `json:"enc_data"`
Nonce []byte `json:"nonce"` Nonce []byte `json:"nonce"`
@ -43,160 +19,8 @@ type Device struct {
DeviceId string `json:"device_id"` DeviceId string `json:"device_id"`
} }
// const (
// MESSAGE_TYPE_REQUEST_DUMP = iota
// )
// type AsyncMessage struct {
// MessageType int `json:"message_type"`
// }
const ( const (
CONFIG_PATH = ".hishtory.config" CONFIG_PATH = ".hishtory.config"
HISHTORY_PATH = ".hishtory" HISHTORY_PATH = ".hishtory"
DB_PATH = ".hishtory.db" DB_PATH = ".hishtory.db"
KDF_USER_ID = "user_id"
KDF_DEVICE_ID = "device_id"
KDF_ENCRYPTION_KEY = "encryption_key"
) )
func Hmac(key, additionalData string) string {
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(additionalData))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}
func UserId(key string) string {
return Hmac(key, KDF_USER_ID)
}
func EncryptionKey(userSecret string) ([]byte, error) {
encryptionKey, err := base64.URLEncoding.DecodeString(Hmac(userSecret, KDF_ENCRYPTION_KEY))
if err != nil {
return []byte{}, fmt.Errorf("Impossible state, decode(encode(hmac)) failed: %v", err)
}
return encryptionKey, nil
}
func makeAead(userSecret string) (cipher.AEAD, error) {
key, err := EncryptionKey(userSecret)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aead, nil
}
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
}
ciphertext := aead.Seal(nil, nonce, data, additionalData)
_, err = aead.Open(nil, nonce, ciphertext, additionalData)
if err != nil {
panic(err)
}
return ciphertext, nonce, nil
}
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
plaintext, err := aead.Open(nil, nonce, data, additionalData)
if err != nil {
return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
}
return plaintext, nil
}
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (EncHistoryEntry, error) {
data, err := json.Marshal(entry)
if err != nil {
return EncHistoryEntry{}, err
}
ciphertext, nonce, err := Encrypt(userSecret, data, []byte(UserId(userSecret)))
if err != nil {
return EncHistoryEntry{}, err
}
return EncHistoryEntry{
EncryptedData: ciphertext,
Nonce: nonce,
UserId: UserId(userSecret),
Date: time.Now(),
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
ReadCount: 0,
}, nil
}
func DecryptHistoryEntry(userSecret string, entry EncHistoryEntry) (HistoryEntry, error) {
if entry.UserId != UserId(userSecret) {
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
}
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
if err != nil {
return HistoryEntry{}, nil
}
var decryptedEntry HistoryEntry
err = json.Unmarshal(plaintext, &decryptedEntry)
if err != nil {
return HistoryEntry{}, nil
}
return decryptedEntry, nil
}
func IsTestEnvironment() bool {
return os.Getenv("HISHTORY_TEST") != ""
}
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Where("true")
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
} else if strings.HasPrefix(token, "-") {
panic("TODO(ddworken): Implement -foo as filtering out foo")
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}

View File

@ -91,25 +91,3 @@ func CheckWithInfo(t *testing.T, err error, additionalInfo string) {
t.Fatalf("Unexpected error: %v! Additional info: %v", err, additionalInfo) t.Fatalf("Unexpected error: %v! Additional info: %v", err, additionalInfo)
} }
} }
func EntryEquals(entry1, entry2 HistoryEntry) bool {
return entry1.LocalUsername == entry2.LocalUsername &&
entry1.Hostname == entry2.Hostname &&
entry1.Command == entry2.Command &&
entry1.CurrentWorkingDirectory == entry2.CurrentWorkingDirectory &&
entry1.ExitCode == entry2.ExitCode &&
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
}
func MakeFakeHistoryEntry(command string) HistoryEntry {
return HistoryEntry{
LocalUsername: "david",
Hostname: "localhost",
Command: command,
CurrentWorkingDirectory: "/tmp/",
ExitCode: 2,
StartTime: time.Now(),
EndTime: time.Now(),
}
}