hishtory/shared/data.go

246 lines
6.7 KiB
Go
Raw Normal View History

2022-01-09 05:27:18 +01:00
package shared
import (
"fmt"
"io"
"os"
"path"
"strconv"
2022-01-09 20:00:53 +01:00
"strings"
"time"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"github.com/google/uuid"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
2022-01-09 05:27:18 +01:00
// HistoryEntry is a single recorded shell command together with the context
// it was run in. The json tags define its serialized form (it is marshaled to
// JSON before encryption in EncryptHistoryEntry), and it is also migrated as
// a gorm model in OpenLocalSqliteDb.
type HistoryEntry struct {
	// TODO: UserSecret needs to be removed from here once I drop all the old code
	// UserSecret is indexed because Search filters on it.
	UserSecret string `json:"user_secret" gorm:"index"`
	// LocalUsername is the name of the local user that ran the command.
	LocalUsername string `json:"local_username"`
	// Hostname is the machine the command ran on; matched by Search.
	Hostname string `json:"hostname"`
	// Command is the command line itself; matched by Search.
	Command string `json:"command"`
	// CurrentWorkingDirectory is the directory the command ran in; matched by Search.
	CurrentWorkingDirectory string `json:"current_working_directory"`
	// ExitCode is the exit status of the command.
	ExitCode int `json:"exit_code"`
	// StartTime and EndTime bracket the command's execution. Search orders
	// results by EndTime, newest first.
	StartTime time.Time `json:"start_time"`
	EndTime   time.Time `json:"end_time"`
}
// EncHistoryEntry is the encrypted form of a HistoryEntry, as produced by
// EncryptHistoryEntry and consumed by DecryptHistoryEntry. It is also
// migrated as a gorm model in OpenLocalSqliteDb.
type EncHistoryEntry struct {
	// EncryptedData is the AEAD ciphertext of the JSON-encoded HistoryEntry.
	EncryptedData []byte `json:"enc_data"`
	// Nonce is the random nonce that was used to produce EncryptedData.
	Nonce []byte `json:"nonce"`
	// DeviceId is compared against DeviceId(userSecret, deviceId) before
	// decryption. EncryptHistoryEntry leaves it empty.
	// NOTE(review): presumably filled in by another component — confirm.
	DeviceId string `json:"device_id"`
	// UserId is UserId(userSecret) of the encrypting user; it is also the
	// additional authenticated data of the ciphertext.
	UserId string `json:"user_id"`
	// Date is the time at which the entry was encrypted.
	Date time.Time `json:"time"`
	// EncryptedId is a random UUID assigned at encryption time.
	EncryptedId string `json:"id"`
	// ReadCount starts at 0 when the entry is created.
	ReadCount int `json:"read_count"`
}
// Device pairs a user identifier with a device identifier. It is migrated as
// a gorm model in OpenLocalSqliteDb.
// NOTE(review): the values are presumably the outputs of UserId and DeviceId — confirm against callers.
type Device struct {
	UserId   string `json:"user_id"`
	DeviceId string `json:"device_id"`
}
// const (
// MESSAGE_TYPE_REQUEST_DUMP = iota
// )
// type AsyncMessage struct {
// MessageType int `json:"message_type"`
// }
const (
	// HISHTORY_PATH is the directory, relative to the user's home directory,
	// that holds hishtory's local state.
	HISHTORY_PATH = ".hishtory"
	// DB_PATH is the file name of the SQLite database inside HISHTORY_PATH.
	DB_PATH = ".hishtory.db"

	// The KDF_* values are the labels passed to Hmac alongside the user
	// secret to derive distinct values for distinct purposes.
	KDF_USER_ID        = "user_id"
	KDF_DEVICE_ID      = "device_id"
	KDF_ENCRYPTION_KEY = "encryption_key"
)
// Hmac computes the HMAC-SHA256 of additionalData keyed by key and returns
// the digest encoded with base64.URLEncoding.
func Hmac(key, additionalData string) string {
	mac := hmac.New(sha256.New, []byte(key))
	mac.Write([]byte(additionalData))
	digest := mac.Sum(nil)
	return base64.URLEncoding.EncodeToString(digest)
}
// UserId derives the user identifier from key by HMACing the KDF_USER_ID
// label with it.
func UserId(key string) string {
	userId := Hmac(key, KDF_USER_ID)
	return userId
}
func DeviceId(key string, id int) string {
return Hmac(key, KDF_DEVICE_ID+strconv.Itoa(id))
}
// EncryptionKey derives the raw encryption key bytes from userSecret. It
// HMACs the KDF_ENCRYPTION_KEY label and decodes Hmac's base64 output back
// into bytes.
func EncryptionKey(userSecret string) ([]byte, error) {
	encoded := Hmac(userSecret, KDF_ENCRYPTION_KEY)
	rawKey, decodeErr := base64.URLEncoding.DecodeString(encoded)
	if decodeErr != nil {
		// Hmac always returns valid base64, so this should never trigger.
		return []byte{}, fmt.Errorf("Impossible state, decode(encode(hmac)) failed: %v", decodeErr)
	}
	return rawKey, nil
}
// makeAead builds an AES-GCM AEAD keyed with EncryptionKey(userSecret).
// Any error from key derivation or cipher construction is returned as-is.
func makeAead(userSecret string) (cipher.AEAD, error) {
	keyBytes, err := EncryptionKey(userSecret)
	if err != nil {
		return nil, err
	}
	blockCipher, err := aes.NewCipher(keyBytes)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, err
	}
	return gcm, nil
}
// Encrypt seals data with the AEAD derived from userSecret, authenticating
// (but not encrypting) additionalData alongside it.
//
// It returns the ciphertext followed by the freshly generated random nonce;
// both, together with the same additionalData, are required by Decrypt.
// An error is returned if the AEAD cannot be built, if the nonce cannot be
// read from the system's random source, or if the ciphertext fails the
// round-trip check.
func Encrypt(userSecret string, data, additionalData []byte) ([]byte, []byte, error) {
	aead, err := makeAead(userSecret)
	if err != nil {
		return []byte{}, []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
	}
	// Ask the AEAD for its nonce length rather than hard-coding it.
	nonce := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return []byte{}, []byte{}, fmt.Errorf("Failed to read a nonce: %v", err)
	}
	ciphertext := aead.Seal(nil, nonce, data, additionalData)
	// Sanity check: make sure what was just sealed can be opened again.
	// A failure is reported as an error instead of crashing the caller.
	if _, err := aead.Open(nil, nonce, ciphertext, additionalData); err != nil {
		return []byte{}, []byte{}, fmt.Errorf("Failed to verify freshly encrypted data: %v", err)
	}
	return ciphertext, nonce, nil
}
// Decrypt opens data (a ciphertext produced by Encrypt) with the AEAD derived
// from userSecret. additionalData and nonce must be the values that were used
// when encrypting.
//
// It returns the plaintext, or an error if the AEAD cannot be built, the
// nonce has the wrong length, or authentication/decryption fails.
func Decrypt(userSecret string, data, additionalData, nonce []byte) ([]byte, error) {
	aead, err := makeAead(userSecret)
	if err != nil {
		return []byte{}, fmt.Errorf("Failed to make AEAD: %v", err)
	}
	// The GCM implementation panics on a wrong-length nonce. The nonce comes
	// from stored data, so validate it and report a normal error instead.
	if len(nonce) != aead.NonceSize() {
		return []byte{}, fmt.Errorf("Failed to decrypt: nonce has length %d, expected %d", len(nonce), aead.NonceSize())
	}
	plaintext, err := aead.Open(nil, nonce, data, additionalData)
	if err != nil {
		return []byte{}, fmt.Errorf("Failed to decrypt: %v", err)
	}
	return plaintext, nil
}
// EncryptHistoryEntry JSON-encodes entry and encrypts it under userSecret,
// using UserId(userSecret) as the additional authenticated data.
//
// The returned EncHistoryEntry carries the ciphertext, nonce and user ID, is
// stamped with the current time, receives a random UUID as its EncryptedId
// and starts with a ReadCount of 0. DeviceId is left unset.
func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (EncHistoryEntry, error) {
	serialized, err := json.Marshal(entry)
	if err != nil {
		return EncHistoryEntry{}, err
	}
	userId := UserId(userSecret)
	sealed, nonce, err := Encrypt(userSecret, serialized, []byte(userId))
	if err != nil {
		return EncHistoryEntry{}, err
	}
	encEntry := EncHistoryEntry{
		EncryptedData: sealed,
		Nonce:         nonce,
		UserId:        userId,
		Date:          time.Now(),
		EncryptedId:   uuid.Must(uuid.NewRandom()).String(),
		ReadCount:     0,
	}
	return encEntry, nil
}
// DecryptHistoryEntry decrypts entry with userSecret and decodes the
// resulting JSON into a HistoryEntry.
//
// Before decrypting, it verifies that entry.UserId and entry.DeviceId match
// the values derived from userSecret and deviceId, and refuses to proceed
// otherwise. Decryption and JSON decoding failures are returned as errors
// rather than reported as a successful empty entry.
func DecryptHistoryEntry(userSecret string, deviceId int, entry EncHistoryEntry) (HistoryEntry, error) {
	userId := UserId(userSecret)
	if entry.UserId != userId {
		return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
	}
	if entry.DeviceId != DeviceId(userSecret, deviceId) {
		return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching DeviceId")
	}
	// The user ID is the additional authenticated data used at encryption time.
	plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(userId), entry.Nonce)
	if err != nil {
		return HistoryEntry{}, err
	}
	var decryptedEntry HistoryEntry
	if err := json.Unmarshal(plaintext, &decryptedEntry); err != nil {
		return HistoryEntry{}, fmt.Errorf("Failed to unmarshal decrypted history entry: %v", err)
	}
	return decryptedEntry, nil
}
// IsTestEnvironment reports whether the HISHTORY_TEST environment variable
// is set to a non-empty value.
func IsTestEnvironment() bool {
	marker := os.Getenv("HISHTORY_TEST")
	return len(marker) > 0
}
func OpenLocalSqliteDb() (*gorm.DB, error) {
homedir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get user's home directory: %v", err)
}
err = os.MkdirAll(path.Join(homedir, HISHTORY_PATH), 0744)
if err != nil {
return nil, fmt.Errorf("failed to create ~/.hishtory dir: %v", err)
}
db, err := gorm.Open(sqlite.Open(path.Join(homedir, HISHTORY_PATH, DB_PATH)), &gorm.Config{SkipDefaultTransaction: true})
if err != nil {
2022-01-10 01:39:13 +01:00
return nil, fmt.Errorf("failed to connect to the DB: %v", err)
}
tx, err := db.DB()
if err != nil {
return nil, err
}
err = tx.Ping()
if err != nil {
return nil, err
}
db.AutoMigrate(&HistoryEntry{})
db.AutoMigrate(&EncHistoryEntry{})
db.AutoMigrate(&Device{})
return db, nil
}
2022-01-09 20:00:53 +01:00
// Persist inserts entry into db and returns any error reported by the
// insert, so callers can tell when an entry was not saved.
func Persist(db *gorm.DB, entry HistoryEntry) error {
	return db.Create(&entry).Error
}
2022-01-09 20:00:53 +01:00
func Search(db *gorm.DB, userSecret, query string, limit int) ([]*HistoryEntry, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
}
tx := db.Where("user_secret = ?", userSecret)
for _, token := range tokens {
if strings.Contains(token, ":") {
splitToken := strings.SplitN(token, ":", 2)
field := splitToken[0]
val := splitToken[1]
// tx = tx.Where()
panic("TODO(ddworken): Use " + field + val)
2022-01-10 01:39:13 +01:00
} else if strings.HasPrefix(token, "-") {
panic("TODO(ddworken): Implement -foo as filtering out foo")
2022-01-09 20:00:53 +01:00
} else {
wildcardedToken := "%" + token + "%"
tx = tx.Where("(command LIKE ? OR hostname LIKE ? OR current_working_directory LIKE ?)", wildcardedToken, wildcardedToken, wildcardedToken)
}
}
tx = tx.Order("end_time DESC")
if limit > 0 {
tx = tx.Limit(limit)
}
var historyEntries []*HistoryEntry
result := tx.Find(&historyEntries)
if result.Error != nil {
return nil, fmt.Errorf("DB query error: %v", result.Error)
}
return historyEntries, nil
}
// tokenize splits query into whitespace-separated tokens. Leading, trailing
// and repeated whitespace never produce empty tokens; an empty or
// all-whitespace query yields an empty slice. The error is currently always
// nil.
func tokenize(query string) ([]string, error) {
	return strings.Fields(query), nil
}