device ID is now a random uuid generated by the client

This commit is contained in:
David Dworken 2022-04-03 21:00:46 -07:00
parent 32e74eb3a1
commit c5eea01a23
6 changed files with 26 additions and 33 deletions

View File

@ -24,7 +24,7 @@ func main() {
case "export": case "export":
export() export()
case "init": case "init":
shared.CheckFatalError(shared.Setup(0, os.Args)) shared.CheckFatalError(shared.Setup( os.Args))
// TODO: Call ebootstrap here // TODO: Call ebootstrap here
case "install": case "install":
shared.CheckFatalError(shared.Install()) shared.CheckFatalError(shared.Install())

View File

@ -8,6 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/google/uuid"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
) )
@ -15,7 +16,7 @@ func TestSubmitThenQuery(t *testing.T) {
// Set up // Set up
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
InitDB() InitDB()
shared.Check(t, shared.Setup(0, []string{})) shared.Check(t, shared.Setup([]string{}))
// Submit an entry // Submit an entry
entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / ", "1641774958326745663"}) entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / ", "1641774958326745663"})
@ -62,7 +63,7 @@ func TestNoUserSecretGivesNoResults(t *testing.T) {
// Set up // Set up
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
InitDB() InitDB()
shared.Check(t, shared.Setup(0, []string{})) shared.Check(t, shared.Setup([]string{}))
// Submit an entry // Submit an entry
entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / ", "1641774958326745663"}) entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / ", "1641774958326745663"})
@ -91,7 +92,7 @@ func TestSearchQuery(t *testing.T) {
// Set up // Set up
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
InitDB() InitDB()
shared.Check(t, shared.Setup(0, []string{})) shared.Check(t, shared.Setup([]string{}))
// Submit an entry that we'll match // Submit an entry that we'll match
entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls /bar ", "1641774958326745663"}) entry, err := shared.BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls /bar ", "1641774958326745663"})
@ -133,14 +134,14 @@ func TestESubmitThenQuery(t *testing.T) {
// Set up // Set up
defer shared.BackupAndRestore(t)() defer shared.BackupAndRestore(t)()
InitDB() InitDB()
shared.Check(t, shared.Setup(0, []string{})) shared.Check(t, shared.Setup([]string{}))
// Register a few devices // Register a few devices
userId := shared.UserId("key") userId := shared.UserId("key")
devId1 := shared.DeviceId("key", 1) devId1 := uuid.Must(uuid.NewRandom()).String()
devId2 := shared.DeviceId("key", 2) devId2 := uuid.Must(uuid.NewRandom()).String()
otherUser := shared.UserId("otherkey") otherUser := shared.UserId("otherkey")
otherDev := shared.DeviceId("otherkey", 1) otherDev := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiERegisterHandler(nil, deviceReq) apiERegisterHandler(nil, deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
@ -160,7 +161,7 @@ func TestESubmitThenQuery(t *testing.T) {
// Query for device id 1 // Query for device id 1
w := httptest.NewRecorder() w := httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+shared.DeviceId("key", 1), nil) searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1, nil)
apiEQueryHandler(w, searchReq) apiEQueryHandler(w, searchReq)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
@ -172,7 +173,7 @@ func TestESubmitThenQuery(t *testing.T) {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
dbEntry := retrievedEntries[0] dbEntry := retrievedEntries[0]
if dbEntry.DeviceId != shared.DeviceId("key", 1) { if dbEntry.DeviceId != devId1 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
} }
if dbEntry.UserId != shared.UserId("key") { if dbEntry.UserId != shared.UserId("key") {
@ -181,7 +182,7 @@ func TestESubmitThenQuery(t *testing.T) {
if dbEntry.ReadCount != 1 { if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err := shared.DecryptHistoryEntry("key", 1, *dbEntry) decEntry, err := shared.DecryptHistoryEntry("key", *dbEntry)
shared.Check(t, err) shared.Check(t, err)
if !shared.EntryEquals(decEntry, *entry) { if !shared.EntryEquals(decEntry, *entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, *entry) t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, *entry)
@ -189,7 +190,7 @@ func TestESubmitThenQuery(t *testing.T) {
// Same for device id 2 // Same for device id 2
w = httptest.NewRecorder() w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+shared.DeviceId("key", 2), nil) searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2, nil)
apiEQueryHandler(w, searchReq) apiEQueryHandler(w, searchReq)
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
@ -200,7 +201,7 @@ func TestESubmitThenQuery(t *testing.T) {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
dbEntry = retrievedEntries[0] dbEntry = retrievedEntries[0]
if dbEntry.DeviceId != shared.DeviceId("key", 2) { if dbEntry.DeviceId != devId2 {
t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry)
} }
if dbEntry.UserId != shared.UserId("key") { if dbEntry.UserId != shared.UserId("key") {
@ -209,7 +210,7 @@ func TestESubmitThenQuery(t *testing.T) {
if dbEntry.ReadCount != 1 { if dbEntry.ReadCount != 1 {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err = shared.DecryptHistoryEntry("key", 2, *dbEntry) decEntry, err = shared.DecryptHistoryEntry("key", *dbEntry)
shared.Check(t, err) shared.Check(t, err)
if !shared.EntryEquals(decEntry, *entry) { if !shared.EntryEquals(decEntry, *entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, *entry) t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, *entry)

View File

@ -120,7 +120,7 @@ func GetUserSecret() (string, error) {
return config.UserSecret, nil return config.UserSecret, nil
} }
func Setup(deviceId int, args []string) error { func Setup(args []string) error {
userSecret := uuid.Must(uuid.NewRandom()).String() userSecret := uuid.Must(uuid.NewRandom()).String()
if len(args) > 2 && args[2] != "" { if len(args) > 2 && args[2] != "" {
userSecret = args[2] userSecret = args[2]
@ -130,7 +130,7 @@ func Setup(deviceId int, args []string) error {
var config ClientConfig var config ClientConfig
config.UserSecret = userSecret config.UserSecret = userSecret
config.IsEnabled = true config.IsEnabled = true
config.DeviceId = deviceId config.DeviceId = uuid.Must(uuid.NewRandom()).String()
return SetConfig(config) return SetConfig(config)
} }
@ -158,7 +158,7 @@ func DisplayResults(results []*HistoryEntry, displayHostname bool) {
type ClientConfig struct { type ClientConfig struct {
UserSecret string `json:"user_secret"` UserSecret string `json:"user_secret"`
IsEnabled bool `json:"is_enabled"` IsEnabled bool `json:"is_enabled"`
DeviceId int `json:"device_id"` DeviceId string `json:"device_id"`
} }
func GetConfig() (ClientConfig, error) { func GetConfig() (ClientConfig, error) {
@ -253,7 +253,7 @@ func Install() error {
if err != nil { if err != nil {
// No config, so set up a new installation // No config, so set up a new installation
// TODO: GO THROUGH THE REGISTRATION FLOW // TODO: GO THROUGH THE REGISTRATION FLOW
return Setup(0, os.Args) return Setup(os.Args)
} }
return nil return nil
} }

View File

@ -15,7 +15,7 @@ func TestSetup(t *testing.T) {
if _, err := os.Stat(path.Join(homedir, HISHTORY_PATH, CONFIG_PATH)); err == nil { if _, err := os.Stat(path.Join(homedir, HISHTORY_PATH, CONFIG_PATH)); err == nil {
t.Fatalf("hishtory secret file already exists!") t.Fatalf("hishtory secret file already exists!")
} }
Check(t, Setup(0, []string{})) Check(t, Setup([]string{}))
if _, err := os.Stat(path.Join(homedir, HISHTORY_PATH, CONFIG_PATH)); err != nil { if _, err := os.Stat(path.Join(homedir, HISHTORY_PATH, CONFIG_PATH)); err != nil {
t.Fatalf("hishtory secret file does not exist after Setup()!") t.Fatalf("hishtory secret file does not exist after Setup()!")
} }
@ -28,7 +28,7 @@ func TestSetup(t *testing.T) {
func TestBuildHistoryEntry(t *testing.T) { func TestBuildHistoryEntry(t *testing.T) {
defer BackupAndRestore(t)() defer BackupAndRestore(t)()
Check(t, Setup(0, []string{})) Check(t, Setup([]string{}))
entry, err := BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / ", "1641774958326745663"}) entry, err := BuildHistoryEntry([]string{"unused", "saveHistoryEntry", "120", " 123 ls / ", "1641774958326745663"})
Check(t, err) Check(t, err)
if entry.UserSecret == "" || len(entry.UserSecret) < 10 || strings.TrimSpace(entry.UserSecret) != entry.UserSecret { if entry.UserSecret == "" || len(entry.UserSecret) < 10 || strings.TrimSpace(entry.UserSecret) != entry.UserSecret {
@ -53,14 +53,14 @@ func TestBuildHistoryEntry(t *testing.T) {
func TestGetUserSecret(t *testing.T) { func TestGetUserSecret(t *testing.T) {
defer BackupAndRestore(t)() defer BackupAndRestore(t)()
Check(t, Setup(0, []string{})) Check(t, Setup([]string{}))
secret1, err := GetUserSecret() secret1, err := GetUserSecret()
Check(t, err) Check(t, err)
if len(secret1) < 10 || strings.Contains(secret1, " ") || strings.Contains(secret1, "\n") { if len(secret1) < 10 || strings.Contains(secret1, " ") || strings.Contains(secret1, "\n") {
t.Fatalf("unexpected secret: %v", secret1) t.Fatalf("unexpected secret: %v", secret1)
} }
Check(t, Setup(0, []string{})) Check(t, Setup([]string{}))
secret2, err := GetUserSecret() secret2, err := GetUserSecret()
Check(t, err) Check(t, err)

View File

@ -5,7 +5,6 @@ import (
"io" "io"
"os" "os"
"path" "path"
"strconv"
"strings" "strings"
"time" "time"
@ -74,10 +73,6 @@ func UserId(key string) string {
return Hmac(key, KDF_USER_ID) return Hmac(key, KDF_USER_ID)
} }
func DeviceId(key string, id int) string {
return Hmac(key, KDF_DEVICE_ID+strconv.Itoa(id))
}
func EncryptionKey(userSecret string) ([]byte, error) { func EncryptionKey(userSecret string) ([]byte, error) {
encryptionKey, err := base64.URLEncoding.DecodeString(Hmac(userSecret, KDF_ENCRYPTION_KEY)) encryptionKey, err := base64.URLEncoding.DecodeString(Hmac(userSecret, KDF_ENCRYPTION_KEY))
if err != nil { if err != nil {
@ -150,13 +145,10 @@ func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (EncHistoryEntry
}, nil }, nil
} }
func DecryptHistoryEntry(userSecret string, deviceId int, entry EncHistoryEntry) (HistoryEntry, error) { func DecryptHistoryEntry(userSecret string, entry EncHistoryEntry) (HistoryEntry, error) {
if entry.UserId != UserId(userSecret) { if entry.UserId != UserId(userSecret) {
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId") return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching UserId")
} }
if entry.DeviceId != DeviceId(userSecret, deviceId) {
return HistoryEntry{}, fmt.Errorf("Refusing to decrypt history entry with mismatching DeviceId")
}
plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce) plaintext, err := Decrypt(userSecret, entry.EncryptedData, []byte(UserId(userSecret)), entry.Nonce)
if err != nil { if err != nil {
return HistoryEntry{}, nil return HistoryEntry{}, nil

View File

@ -6,7 +6,7 @@ import (
func TestPersist(t *testing.T) { func TestPersist(t *testing.T) {
defer BackupAndRestore(t)() defer BackupAndRestore(t)()
Check(t, Setup(0, []string{})) Check(t, Setup([]string{}))
db, err := OpenLocalSqliteDb() db, err := OpenLocalSqliteDb()
Check(t, err) Check(t, err)
@ -27,7 +27,7 @@ func TestPersist(t *testing.T) {
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {
defer BackupAndRestore(t)() defer BackupAndRestore(t)()
Check(t, Setup(0, []string{})) Check(t, Setup([]string{}))
db, err := OpenLocalSqliteDb() db, err := OpenLocalSqliteDb()
Check(t, err) Check(t, err)