refactored to move no longer shared things out of the shared/ folder

This commit is contained in:
David Dworken 2022-04-07 20:59:40 -07:00
parent a8d7ee2cc8
commit c2465d7c99
10 changed files with 253 additions and 253 deletions

View File

@ -11,6 +11,7 @@ import (
"gorm.io/gorm"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared"
)
@ -67,17 +68,17 @@ func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error {
if resp.StatusCode != 200 {
return fmt.Errorf("failed to retrieve data from backend, status_code=%d", resp.StatusCode)
}
data, err := ioutil.ReadAll(resp.Body)
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read latest history entries response body: %v", err)
}
var retrievedEntries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &retrievedEntries)
err = json.Unmarshal(respBody, &retrievedEntries)
if err != nil {
return fmt.Errorf("failed to load JSON response: %v", err)
}
for _, entry := range retrievedEntries {
decEntry, err := shared.DecryptHistoryEntry(config.UserSecret, *entry)
decEntry, err := data.DecryptHistoryEntry(config.UserSecret, *entry)
if err != nil {
return fmt.Errorf("failed to decrypt history entry from server: %v", err)
}
@ -91,7 +92,7 @@ func query(query string) {
lib.CheckFatalError(err)
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
lib.CheckFatalError(displayBannerIfSet())
data, err := shared.Search(db, query, 25)
data, err := data.Search(db, query, 25)
lib.CheckFatalError(err)
lib.DisplayResults(data, false)
}
@ -135,7 +136,7 @@ func saveHistoryEntry() {
lib.CheckFatalError(result.Error)
// Persist it remotely
encEntry, err := shared.EncryptHistoryEntry(config.UserSecret, *entry)
encEntry, err := data.EncryptHistoryEntry(config.UserSecret, *entry)
lib.CheckFatalError(err)
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
lib.CheckFatalError(err)
@ -150,7 +151,7 @@ func export() {
db, err := lib.OpenLocalSqliteDb()
lib.CheckFatalError(err)
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
data, err := shared.Search(db, "", 0)
data, err := data.Search(db, "", 0)
lib.CheckFatalError(err)
for i := len(data) - 1; i >= 0; i-- {
fmt.Println(data[i].Command)

View File

@ -12,6 +12,7 @@ import (
"strings"
"testing"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared"
)
@ -139,8 +140,8 @@ func TestIntegrationWithNewDevice(t *testing.T) {
// Manually submit an event that isn't in the local DB, and then we'll
// check if we see it when we do a query without ever having done an init
newEntry := shared.MakeFakeHistoryEntry("othercomputer")
encEntry, err := shared.EncryptHistoryEntry(userSecret, newEntry)
newEntry := data.MakeFakeHistoryEntry("othercomputer")
encEntry, err := data.EncryptHistoryEntry(userSecret, newEntry)
shared.Check(t, err)
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
shared.Check(t, err)

188
client/data/data.go Normal file
View File

@ -0,0 +1,188 @@
package data
import (
"fmt"
"io"
"strings"
"time"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"github.com/ddworken/hishtory/shared"
"github.com/google/uuid"
"gorm.io/gorm"
)
const (
KDF_USER_ID = "user_id"
KDF_DEVICE_ID = "device_id"
KDF_ENCRYPTION_KEY = "encryption_key"
)
type HistoryEntry struct {
LocalUsername string `json:"local_username" gorm:"uniqueIndex:compositeindex"`
Hostname string `json:"hostname" gorm:"uniqueIndex:compositeindex"`
Command string `json:"command" gorm:"uniqueIndex:compositeindex"`
CurrentWorkingDirectory string `json:"current_working_directory" gorm:"uniqueIndex:compositeindex"`
ExitCode int `json:"exit_code" gorm:"uniqueIndex:compositeindex"`
StartTime time.Time `json:"start_time" gorm:"uniqueIndex:compositeindex"`
EndTime time.Time `json:"end_time" gorm:"uniqueIndex:compositeindex"`
}
func sha256hmac(key, additionalData string) []byte {
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(additionalData))
return h.Sum(nil)
}
func UserId(key string) string {
return base64.URLEncoding.EncodeToString(sha256hmac(key, KDF_USER_ID))
}
func EncryptionKey(userSecret string) []byte {
return sha256hmac(userSecret, KDF_ENCRYPTION_KEY)
}
func makeAead(userSecret string) (cipher.AEAD, error) {
key := EncryptionKey(userSecret)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aead, nil
}
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
}
ciphertext := aead.Seal(nil, nonce, data, additionalData)
_, err = aead.Open(nil, nonce, ciphertext, additionalData)
if err != nil {
panic(err)
}
return ciphertext, nonce, nil
}
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
plaintext, err := aead.Open(nil, nonce, data, additionalData)
if err != nil {
return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
}
return plaintext, nil
}
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (shared.EncHistoryEntry, error) {
data, err := json.Marshal(entry)
if err != nil {
return shared.EncHistoryEntry{}, err
}
ciphertext, nonce, err := Encrypt(userSecret, data, []byte(UserId(userSecret)))
if err != nil {
return shared.EncHistoryEntry{}, err
}
return shared.EncHistoryEntry{
EncryptedData: ciphertext,
Nonce: nonce,
UserId: UserId(userSecret),
Date: time.Now(),
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
ReadCount: 0,
}, nil
}
func DecryptHistoryEntry(userSecret string, entry shared.EncHistoryEntry) (HistoryEntry, error) {
if entry.UserId != UserId(userSecret) {
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
}
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
if err != nil {
return HistoryEntry{}, nil
}
var decryptedEntry HistoryEntry
err = json.Unmarshal(plaintext, &decryptedEntry)
if err != nil {
return HistoryEntry{}, nil
}
return decryptedEntry, nil
}
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Where("true")
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
} else if strings.HasPrefix(token, "-") {
panic("TODO(ddworken): Implement -foo as filtering out foo")
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}
func EntryEquals(entry1, entry2 HistoryEntry) bool {
return entry1.LocalUsername == entry2.LocalUsername &&
entry1.Hostname == entry2.Hostname &&
entry1.Command == entry2.Command &&
entry1.CurrentWorkingDirectory == entry2.CurrentWorkingDirectory &&
entry1.ExitCode == entry2.ExitCode &&
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
}
func MakeFakeHistoryEntry(command string) HistoryEntry {
return HistoryEntry{
LocalUsername: "david",
Hostname: "localhost",
Command: command,
CurrentWorkingDirectory: "/tmp/",
ExitCode: 2,
StartTime: time.Now(),
EndTime: time.Now(),
}
}

View File

@ -1,22 +1,22 @@
package shared
package data
import (
"testing"
"github.com/ddworken/hishtory/shared"
)
func TestEncryptDecrypt(t *testing.T) {
k1, err := EncryptionKey("key")
Check(t, err)
k2, err := EncryptionKey("key")
Check(t, err)
k1 := EncryptionKey("key")
k2 := EncryptionKey("key")
if string(k1) != string(k2) {
t.Fatalf("Expected EncryptionKey to be deterministic!")
}
ciphertext, nonce, err := Encrypt("key", []byte("hello world!"), []byte("extra"))
Check(t, err)
shared.Check(t, err)
plaintext, err := Decrypt("key", ciphertext, []byte("extra"), nonce)
Check(t, err)
shared.Check(t, err)
if string(plaintext) != "hello world!" {
t.Fatalf("Expected decrypt(encrypt(x)) to work, but it didn't!")
}

View File

@ -27,6 +27,7 @@ import (
"github.com/google/uuid"
"github.com/rodaine/table"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared"
)
@ -53,8 +54,8 @@ func getCwd() (string, error) {
return cwd, nil
}
func BuildHistoryEntry(args []string) (*shared.HistoryEntry, error) {
var entry shared.HistoryEntry
func BuildHistoryEntry(args []string) (*data.HistoryEntry, error) {
var entry data.HistoryEntry
// exitCode
exitCode, err := strconv.Atoi(args[2])
@ -141,7 +142,7 @@ func Setup(args []string) error {
db.Exec("DELETE FROM history_entries")
// Bootstrap from remote date
resp, err := http.Get(GetServerHostname() + "/api/v1/eregister?user_id=" + shared.UserId(userSecret) + "&device_id=" + config.DeviceId)
resp, err := http.Get(GetServerHostname() + "/api/v1/eregister?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
if err != nil {
return fmt.Errorf("failed to register device with backend: %v", err)
}
@ -149,7 +150,7 @@ func Setup(args []string) error {
return fmt.Errorf("failed to register device with backend, status_code=%d", resp.StatusCode)
}
resp, err = http.Get(GetServerHostname() + "/api/v1/ebootstrap?user_id=" + shared.UserId(userSecret) + "&device_id=" + config.DeviceId)
resp, err = http.Get(GetServerHostname() + "/api/v1/ebootstrap?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
if err != nil {
return fmt.Errorf("failed to bootstrap device from the backend: %v", err)
}
@ -157,17 +158,17 @@ func Setup(args []string) error {
if resp.StatusCode != 200 {
return fmt.Errorf("failed to bootstrap device with data from existing devices, status_code=%d", resp.StatusCode)
}
data, err := ioutil.ReadAll(resp.Body)
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read bootstrap response body: %v", err)
}
var retrievedEntries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &retrievedEntries)
err = json.Unmarshal(respBody, &retrievedEntries)
if err != nil {
return fmt.Errorf("failed to load JSON response: %v", err)
}
for _, entry := range retrievedEntries {
decEntry, err := shared.DecryptHistoryEntry(userSecret, *entry)
decEntry, err := data.DecryptHistoryEntry(userSecret, *entry)
if err != nil {
return fmt.Errorf("failed to decrypt history entry from server: %v", err)
}
@ -177,7 +178,7 @@ func Setup(args []string) error {
return nil
}
func AddToDbIfNew(db *gorm.DB, entry shared.HistoryEntry) {
func AddToDbIfNew(db *gorm.DB, entry data.HistoryEntry) {
tx := db.Where("local_username = ?", entry.LocalUsername)
tx = tx.Where("hostname = ?", entry.Hostname)
tx = tx.Where("command = ?", entry.Command)
@ -185,14 +186,14 @@ func AddToDbIfNew(db *gorm.DB, entry shared.HistoryEntry) {
tx = tx.Where("exit_code = ?", entry.ExitCode)
tx = tx.Where("start_time = ?", entry.StartTime)
tx = tx.Where("end_time = ?", entry.EndTime)
var results []shared.HistoryEntry
var results []data.HistoryEntry
tx.Limit(1).Find(&results)
if len(results) == 0 {
db.Create(entry)
}
}
func DisplayResults(results []*shared.HistoryEntry, displayHostname bool) {
func DisplayResults(results []*data.HistoryEntry, displayHostname bool) {
headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc()
tbl := table.New("CWD", "Timestamp", "Runtime", "Exit Code", "Command")
if displayHostname {
@ -438,6 +439,6 @@ func OpenLocalSqliteDb() (*gorm.DB, error) {
if err != nil {
return nil, err
}
db.AutoMigrate(&shared.HistoryEntry{})
db.AutoMigrate(&data.HistoryEntry{})
return db, nil
}

View File

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared"
)
@ -76,16 +77,16 @@ func TestPersist(t *testing.T) {
db, err := OpenLocalSqliteDb()
shared.Check(t, err)
entry := shared.MakeFakeHistoryEntry("ls ~/")
entry := data.MakeFakeHistoryEntry("ls ~/")
db.Create(entry)
var historyEntries []*shared.HistoryEntry
var historyEntries []*data.HistoryEntry
result := db.Find(&historyEntries)
shared.Check(t, result.Error)
if len(historyEntries) != 1 {
t.Fatalf("DB has %d entries, expected 1!", len(historyEntries))
}
dbEntry := historyEntries[0]
if !shared.EntryEquals(entry, *dbEntry) {
if !data.EntryEquals(entry, *dbEntry) {
t.Fatalf("DB data is different than input! \ndb =%#v \ninput=%#v", *dbEntry, entry)
}
}
@ -96,21 +97,21 @@ func TestSearch(t *testing.T) {
shared.Check(t, err)
// Insert data
entry1 := shared.MakeFakeHistoryEntry("ls /foo")
entry1 := data.MakeFakeHistoryEntry("ls /foo")
db.Create(entry1)
entry2 := shared.MakeFakeHistoryEntry("ls /bar")
entry2 := data.MakeFakeHistoryEntry("ls /bar")
db.Create(entry2)
// Search for data
results, err := shared.Search(db, "ls", 5)
results, err := data.Search(db, "ls", 5)
shared.Check(t, err)
if len(results) != 2 {
t.Fatalf("Search() returned %d results, expected 2!", len(results))
}
if !shared.EntryEquals(*results[0], entry2) {
if !data.EntryEquals(*results[0], entry2) {
t.Fatalf("Search()[0]=%#v, expected: %#v", results[0], entry2)
}
if !shared.EntryEquals(*results[1], entry1) {
if !data.EntryEquals(*results[1], entry1) {
t.Fatalf("Search()[0]=%#v, expected: %#v", results[1], entry1)
}
}

View File

@ -6,6 +6,7 @@ import (
"io/ioutil"
"log"
"net/http"
"os"
"github.com/ddworken/hishtory/shared"
_ "github.com/lib/pq"
@ -104,8 +105,12 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(forcedBanner))
}
func isTestEnvironment() bool {
return os.Getenv("HISHTORY_TEST") != ""
}
func OpenDB() (*gorm.DB, error) {
if shared.IsTestEnvironment() {
if isTestEnvironment() {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("failed to connect to the DB: %v", err)

View File

@ -8,6 +8,7 @@ import (
"net/http/httptest"
"testing"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared"
"github.com/google/uuid"
)
@ -18,10 +19,10 @@ func TestESubmitThenQuery(t *testing.T) {
InitDB()
// Register a few devices
userId := shared.UserId("key")
userId := data.UserId("key")
devId1 := uuid.Must(uuid.NewRandom()).String()
devId2 := uuid.Must(uuid.NewRandom()).String()
otherUser := shared.UserId("otherkey")
otherUser := data.UserId("otherkey")
otherDev := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiERegisterHandler(nil, deviceReq)
@ -31,8 +32,8 @@ func TestESubmitThenQuery(t *testing.T) {
apiERegisterHandler(nil, deviceReq)
// Submit a few entries for different devices
entry := shared.MakeFakeHistoryEntry("ls ~/")
encEntry, err := shared.EncryptHistoryEntry("key", entry)
entry := data.MakeFakeHistoryEntry("ls ~/")
encEntry, err := data.EncryptHistoryEntry("key", entry)
shared.Check(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
shared.Check(t, err)
@ -45,10 +46,10 @@ func TestESubmitThenQuery(t *testing.T) {
apiEQueryHandler(w, searchReq)
res := w.Result()
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
respBody, err := ioutil.ReadAll(res.Body)
shared.Check(t, err)
var retrievedEntries []*shared.EncHistoryEntry
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
@ -56,15 +57,15 @@ func TestESubmitThenQuery(t *testing.T) {
if dbEntry.DeviceId != devId1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != shared.UserId("key") {
if dbEntry.UserId != data.UserId("key") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
decEntry, err := shared.DecryptHistoryEntry("key", *dbEntry)
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
shared.Check(t, err)
if !shared.EntryEquals(decEntry, entry) {
if !data.EntryEquals(decEntry, entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
}
@ -74,9 +75,9 @@ func TestESubmitThenQuery(t *testing.T) {
apiEQueryHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
data, err = ioutil.ReadAll(res.Body)
respBody, err = ioutil.ReadAll(res.Body)
shared.Check(t, err)
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
}
@ -84,27 +85,27 @@ func TestESubmitThenQuery(t *testing.T) {
if dbEntry.DeviceId != devId2 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.UserId != shared.UserId("key") {
if dbEntry.UserId != data.UserId("key") {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
}
if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
}
decEntry, err = shared.DecryptHistoryEntry("key", *dbEntry)
decEntry, err = data.DecryptHistoryEntry("key", *dbEntry)
shared.Check(t, err)
if !shared.EntryEquals(decEntry, entry) {
if !data.EntryEquals(decEntry, entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
}
// Bootstrap handler should return 2 entries, one for each device
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+shared.UserId("key"), nil)
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key"), nil)
apiEBootstrapHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
data, err = ioutil.ReadAll(res.Body)
respBody, err = ioutil.ReadAll(res.Body)
shared.Check(t, err)
shared.Check(t, json.Unmarshal(data, &retrievedEntries))
shared.Check(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
}

View File

@ -1,33 +1,9 @@
package shared
import (
"fmt"
"io"
"os"
"strings"
"time"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"github.com/google/uuid"
"gorm.io/gorm"
)
type HistoryEntry struct {
LocalUsername string `json:"local_username" gorm:"uniqueIndex:compositeindex"`
Hostname string `json:"hostname" gorm:"uniqueIndex:compositeindex"`
Command string `json:"command" gorm:"uniqueIndex:compositeindex"`
CurrentWorkingDirectory string `json:"current_working_directory" gorm:"uniqueIndex:compositeindex"`
ExitCode int `json:"exit_code" gorm:"uniqueIndex:compositeindex"`
StartTime time.Time `json:"start_time" gorm:"uniqueIndex:compositeindex"`
EndTime time.Time `json:"end_time" gorm:"uniqueIndex:compositeindex"`
}
type EncHistoryEntry struct {
EncryptedData []byte `json:"enc_data"`
Nonce []byte `json:"nonce"`
@ -43,160 +19,8 @@ type Device struct {
DeviceId string `json:"device_id"`
}
// const (
// MESSAGE_TYPE_REQUEST_DUMP = iota
// )
// type AsyncMessage struct {
// MessageType int `json:"message_type"`
// }
const (
CONFIG_PATH = ".hishtory.config"
HISHTORY_PATH = ".hishtory"
DB_PATH = ".hishtory.db"
KDF_USER_ID = "user_id"
KDF_DEVICE_ID = "device_id"
KDF_ENCRYPTION_KEY = "encryption_key"
CONFIG_PATH = ".hishtory.config"
HISHTORY_PATH = ".hishtory"
DB_PATH = ".hishtory.db"
)
func Hmac(key, additionalData string) string {
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(additionalData))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}
func UserId(key string) string {
return Hmac(key, KDF_USER_ID)
}
func EncryptionKey(userSecret string) ([]byte, error) {
encryptionKey, err := base64.URLEncoding.DecodeString(Hmac(userSecret, KDF_ENCRYPTION_KEY))
if err != nil {
return []byte{}, fmt.Errorf("Impossible state, decode(encode(hmac)) failed: %v", err)
}
return encryptionKey, nil
}
func makeAead(userSecret string) (cipher.AEAD, error) {
key, err := EncryptionKey(userSecret)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return aead, nil
}
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
}
ciphertext := aead.Seal(nil, nonce, data, additionalData)
_, err = aead.Open(nil, nonce, ciphertext, additionalData)
if err != nil {
panic(err)
}
return ciphertext, nonce, nil
}
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
aead, err := makeAead(userSecret)
if err != nil {
return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
}
plaintext, err := aead.Open(nil, nonce, data, additionalData)
if err != nil {
return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
}
return plaintext, nil
}
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (EncHistoryEntry, error) {
data, err := json.Marshal(entry)
if err != nil {
return EncHistoryEntry{}, err
}
ciphertext, nonce, err := Encrypt(userSecret, data, []byte(UserId(userSecret)))
if err != nil {
return EncHistoryEntry{}, err
}
return EncHistoryEntry{
EncryptedData: ciphertext,
Nonce: nonce,
UserId: UserId(userSecret),
Date: time.Now(),
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
ReadCount: 0,
}, nil
}
func DecryptHistoryEntry(userSecret string, entry EncHistoryEntry) (HistoryEntry, error) {
if entry.UserId != UserId(userSecret) {
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
}
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
if err != nil {
return HistoryEntry{}, nil
}
var decryptedEntry HistoryEntry
err = json.Unmarshal(plaintext, &decryptedEntry)
if err != nil {
return HistoryEntry{}, nil
}
return decryptedEntry, nil
}
func IsTestEnvironment() bool {
return os.Getenv("HISHTORY_TEST") != ""
}
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Where("true")
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
} else if strings.HasPrefix(token, "-") {
panic("TODO(ddworken): Implement -foo as filtering out foo")
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
func tokenize(query string) ([]string, error) {
if query == "" {
return []string{}, nil
}
return strings.Split(query, " "), nil
}

View File

@ -91,25 +91,3 @@ func CheckWithInfo(t *testing.T, err error, additionalInfo string) {
t.Fatalf("Unexpected error: %v! Additional info: %v", err, additionalInfo)
}
}
func EntryEquals(entry1, entry2 HistoryEntry) bool {
return entry1.LocalUsername == entry2.LocalUsername &&
entry1.Hostname == entry2.Hostname &&
entry1.Command == entry2.Command &&
entry1.CurrentWorkingDirectory == entry2.CurrentWorkingDirectory &&
entry1.ExitCode == entry2.ExitCode &&
entry1.StartTime.Format(time.RFC3339) == entry2.StartTime.Format(time.RFC3339) &&
entry1.EndTime.Format(time.RFC3339) == entry2.EndTime.Format(time.RFC3339)
}
func MakeFakeHistoryEntry(command string) HistoryEntry {
return HistoryEntry{
LocalUsername: "david",
Hostname: "localhost",
Command: command,
CurrentWorkingDirectory: "/tmp/",
ExitCode: 2,
StartTime: time.Now(),
EndTime: time.Now(),
}
}