From 5391ecd220fe0ab681c1bbd07623747852e0140b Mon Sep 17 00:00:00 2001 From: David Dworken Date: Mon, 19 Sep 2022 22:49:48 -0700 Subject: [PATCH] First version of working redaction with passing integration tests --- backend/server/server.go | 68 +++++++++++++++++++++++ client/client_test.go | 113 +++++++++++++++++++++++++++++++++++++++ client/data/data.go | 43 +++------------ client/lib/lib.go | 67 +++++++++++++++++++++++ hishtory.go | 64 +++++++++++++++++++--- shared/data.go | 36 +++++++++++++ 6 files changed, 347 insertions(+), 44 deletions(-) diff --git a/backend/server/server.go b/backend/server/server.go index 8732479..641bc0e 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -124,6 +124,9 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) { panic(err) } w.Write(resp) + + // TODO: Make thsi method also check the pending deletion requests + // And then can delete the extra round trip of doing processDeletionRequests() after pulling from the remote } func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { @@ -201,6 +204,67 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte(forcedBanner)) } +func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { + // TODO: Count how many times they've been read and eventually delete them + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + var deletionRequests []*shared.DeletionRequest + result := GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests) + if result.Error != nil { + panic(fmt.Errorf("DB query error: %v", result.Error)) + } + respBody, err := json.Marshal(deletionRequests) + if err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) + } + w.Write(respBody) +} + +func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + panic(err) + } + var request shared.DeletionRequest + err = json.Unmarshal(data, &request) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) + + // Store the deletion request so all the devices will get it + tx := GLOBAL_DB.Where("user_id = ?", request.UserId) + var devices []*shared.Device + result := tx.Find(&devices) + if result.Error != nil { + panic(fmt.Errorf("DB query error: %v", result.Error)) + } + if len(devices) == 0 { + panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId)) + } + fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices)) + for _, device := range devices { + request.DestinationDeviceId = device.DeviceId + result := GLOBAL_DB.Create(&request) + if result.Error != nil { + panic(result.Error) + } + } + + // Also delete anything currently in the DB matching it + numDeleted := 0 + for _, message := range request.Messages.Ids { + // TODO: Optimize this into one query + tx = GLOBAL_DB.Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date) + result := tx.Delete(&shared.EncHistoryEntry{}) + if result.Error != nil { + panic(result.Error) + } + numDeleted += int(result.RowsAffected) + } + fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted) +} + func wipeDbHandler(w http.ResponseWriter, r *http.Request) { result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries") if result.Error != nil { @@ -226,6 +290,7 @@ func OpenDB() (*gorm.DB, error) { db.AutoMigrate(&shared.Device{}) db.AutoMigrate(&UsageData{}) db.AutoMigrate(&shared.DumpRequest{}) + db.AutoMigrate(&shared.DeletionRequest{}) db.Exec("PRAGMA journal_mode = WAL") return db, nil } @@ -238,6 +303,7 @@ func OpenDB() (*gorm.DB, error) { db.AutoMigrate(&shared.Device{}) db.AutoMigrate(&UsageData{}) db.AutoMigrate(&shared.DumpRequest{}) + db.AutoMigrate(&shared.DeletionRequest{}) return db, nil } @@ -468,6 +534,8 @@ func main() { http.Handle("/api/v1/banner", withLogging(apiBannerHandler)) http.Handle("/api/v1/download", withLogging(apiDownloadHandler)) http.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler)) + http.Handle("/api/v1/get-deletion-requests", withLogging(getDeletionRequestsHandler)) + http.Handle("/api/v1/add-deletion-request", withLogging(addDeletionRequestHandler)) if isTestEnvironment() { http.Handle("/api/v1/wipe-db", withLogging(wipeDbHandler)) } diff --git a/client/client_test.go b/client/client_test.go index 121e5e3..12447fa 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -130,6 +130,8 @@ func TestParameterized(t *testing.T) { t.Run("testStripBashTimePrefix/"+tester.ShellName(), func(t *testing.T) { testStripBashTimePrefix(t, tester) }) t.Run("testReuploadHistoryEntries/"+tester.ShellName(), func(t *testing.T) { testReuploadHistoryEntries(t, tester) }) t.Run("testInitialHistoryImport/"+tester.ShellName(), func(t *testing.T) { testInitialHistoryImport(t, tester) }) + t.Run("testLocalRedaction/"+tester.ShellName(), func(t *testing.T) { testLocalRedaction(t, tester) }) + t.Run("testRemoteRedaction/"+tester.ShellName(), func(t *testing.T) { testRemoteRedaction(t, tester) }) } } @@ -785,6 +787,9 @@ func hishtoryQuery(t *testing.T, tester shellTester, query string) string { func manuallySubmitHistoryEntry(t *testing.T, userSecret string, entry data.HistoryEntry) { encEntry, err := data.EncryptHistoryEntry(userSecret, entry) shared.Check(t, err) + if encEntry.Date != entry.EndTime { + t.Fatalf("encEntry.Date does not match the entry") + } jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) shared.Check(t, err) resp, err := http.Post("http://localhost:8080/api/v1/submit", "application/json", bytes.NewBuffer(jsonValue)) @@ -1320,9 +1325,117 @@ echo %v-bar`, randomCmdUuid, randomCmdUuid) } // Check that the previously recorded commands are in hishtory + // TODO: change the below to | grep -v pipefail and see that it fails weirdly with zsh out = tester.RunInteractiveShell(t, `hishtory export `+randomCmdUuid) expectedOutput = fmt.Sprintf("hishtory export %s\necho %s-foo\necho %s-bar\nhishtory export %s\n", randomCmdUuid, randomCmdUuid, randomCmdUuid, randomCmdUuid) if diff := cmp.Diff(expectedOutput, out); diff != "" { t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) } } + +func testLocalRedaction(t *testing.T, tester shellTester) { + // Setup + defer shared.BackupAndRestore(t)() + + // Install hishtory + installHishtory(t, tester, "") + + // Record some commands + randomCmdUuid := uuid.Must(uuid.NewRandom()).String() + randomCmd := fmt.Sprintf(`echo %v-foo +echo %v-bas +echo foo +ls /tmp`, randomCmdUuid, randomCmdUuid) + tester.RunInteractiveShell(t, randomCmd) + + // Check that the previously recorded commands are in hishtory + out := tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + expectedOutput := fmt.Sprintf("echo %s-foo\necho %s-bas\necho foo\nls /tmp\n", randomCmdUuid, randomCmdUuid) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } + + // Redact foo + out = tester.RunInteractiveShell(t, `hishtory redact --force foo`) + if out != "Permanently deleting 2 entries" { + t.Fatalf("hishtory redact gave unexpected output=%#v", out) + } + + // Check that the commands are redacted + out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + expectedOutput = fmt.Sprintf("echo %s-bas\nls /tmp\nhishtory redact --force foo\n", randomCmdUuid) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } + + // Redact s + out = tester.RunInteractiveShell(t, `hishtory redact --force s`) + if out != "Permanently deleting 10 entries" { + t.Fatalf("hishtory redact gave unexpected output=%#v", out) + } + + // Check that the commands are redacted + out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + expectedOutput = "hishtory redact --force s\n" + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } +} + +func testRemoteRedaction(t *testing.T, tester shellTester) { + // Setup + defer shared.BackupAndRestore(t)() + + // Install hishtory client 1 + userSecret := installHishtory(t, tester, "") + + // Record some commands + randomCmdUuid := uuid.Must(uuid.NewRandom()).String() + randomCmd := fmt.Sprintf(`echo %v-foo +echo %v-bas`, randomCmdUuid, randomCmdUuid) + tester.RunInteractiveShell(t, randomCmd) + time.Sleep(2 * time.Second) // TODO: this sleep is covering up a bug + tester.RunInteractiveShell(t, `echo foo +ls /tmp`) + + // Check that the previously recorded commands are in hishtory + out := tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + expectedOutput := fmt.Sprintf("echo %s-foo\necho %s-bas\necho foo\nls /tmp\n", randomCmdUuid, randomCmdUuid) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } + + // Install hishtory client 2 + restoreInstall1 := shared.BackupAndRestoreWithId(t, "-1") + installHishtory(t, tester, userSecret) + + // And confirm that it has the commands too + out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } + + // Restore the first client, and redact some commands + restoreInstall2 := shared.BackupAndRestoreWithId(t, "-2") + restoreInstall1() + out = tester.RunInteractiveShell(t, `hishtory redact --force `+randomCmdUuid) + if out != "Permanently deleting 2 entries" { + t.Fatalf("hishtory redact gave unexpected output=%#v", out) + } + + // Confirm that client1 doesn't have the commands + out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + expectedOutput = fmt.Sprintf("echo foo\nls /tmp\nhishtory redact --force %s\n", randomCmdUuid) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } + + // Swap back to the second client and then confirm it processed the deletion request + restoreInstall2() + out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`) + if diff := cmp.Diff(expectedOutput, out); diff != "" { + t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) + } +} + +// TODO: some tests for offline behavior diff --git a/client/data/data.go b/client/data/data.go index 5ab101b..f035f72 100644 --- a/client/data/data.go +++ b/client/data/data.go @@ -1,7 +1,6 @@ package data import ( - "bufio" "crypto/aes" "crypto/cipher" "crypto/hmac" @@ -11,7 +10,6 @@ import ( "encoding/json" "fmt" "io" - "os" "strings" "time" @@ -39,6 +37,10 @@ type HistoryEntry struct { DeviceId string `json:"device_id" gorm:"uniqueIndex:compositeindex"` } +func (h *HistoryEntry) GoString() string { + return fmt.Sprintf("%#v", *h) +} + func sha256hmac(key, additionalData string) []byte { h := hmac.New(sha256.New, []byte(key)) h.Write([]byte(additionalData)) @@ -108,7 +110,7 @@ func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (shared.EncHisto EncryptedData: ciphertext, Nonce: nonce, UserId: UserId(userSecret), - Date: time.Now(), + Date: entry.EndTime, EncryptedId: uuid.Must(uuid.NewRandom()).String(), ReadCount: 0, }, nil @@ -135,7 +137,7 @@ func parseTimeGenerously(input string) (time.Time, error) { return dateparse.ParseLocal(input) } -func makeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) { +func MakeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) { tokens, err := tokenize(query) if err != nil { return nil, fmt.Errorf("failed to tokenize query: %v", err) @@ -173,39 +175,8 @@ func makeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) { return tx, nil } -func Redact(db *gorm.DB, query string) error { - tx, err := makeWhereQueryFromSearch(db, query) - if err != nil { - return err - } - var count int64 - res := tx.Count(&count) - if res.Error != nil { - return res.Error - } - fmt.Printf("This will permanently delete %d entries, are you sure? [y/N]", count) - reader := bufio.NewReader(os.Stdin) - resp, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("failed to read response: %v", err) - } - if strings.TrimSpace(resp) != "y" { - fmt.Printf("Aborting delete per user response of %#v\n", strings.TrimSpace(resp)) - return nil - } - tx, err = makeWhereQueryFromSearch(db, query) - if err != nil { - return err - } - res = tx.Delete(&HistoryEntry{}) - if res.Error != nil { - return res.Error - } - return nil -} - func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) { - tx, err := makeWhereQueryFromSearch(db, query) + tx, err := MakeWhereQueryFromSearch(db, query) if err != nil { return nil, err } diff --git a/client/lib/lib.go b/client/lib/lib.go index fe23f25..1f0a726 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -1101,3 +1101,70 @@ func EncryptAndMarshal(config ClientConfig, entry *data.HistoryEntry) ([]byte, e } return jsonValue, nil } + +func Redact(db *gorm.DB, query string, force bool) error { + tx, err := data.MakeWhereQueryFromSearch(db, query) + if err != nil { + return err + } + var historyEntries []*data.HistoryEntry + res := tx.Find(&historyEntries) + if res.Error != nil { + return res.Error + } + if force { + fmt.Printf("Permanently deleting %d entries", len(historyEntries)) + } else { + // TODO: Find a way to test the prompting + fmt.Printf("This will permanently delete %d entries, are you sure? [y/N]", len(historyEntries)) + reader := bufio.NewReader(os.Stdin) + resp, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("failed to read response: %v", err) + } + if strings.TrimSpace(resp) != "y" { + fmt.Printf("Aborting delete per user response of %#v\n", strings.TrimSpace(resp)) + return nil + } + } + tx, err = data.MakeWhereQueryFromSearch(db, query) + if err != nil { + return err + } + res = tx.Delete(&data.HistoryEntry{}) + if res.Error != nil { + return res.Error + } + if res.RowsAffected != int64(len(historyEntries)) { + return fmt.Errorf("DB deleted %d rows, when we only expected to delete %d rows, something may have gone wrong", res.RowsAffected, len(historyEntries)) + } + err = deleteOnRemoteInstances(historyEntries) + if err != nil { + return err + } + return nil +} + +func deleteOnRemoteInstances(historyEntries []*data.HistoryEntry) error { + config, err := GetConfig() + if err != nil { + return err + } + + var deletionRequest shared.DeletionRequest + deletionRequest.SendTime = time.Now() + deletionRequest.UserId = data.UserId(config.UserSecret) + + for _, entry := range historyEntries { + deletionRequest.Messages.Ids = append(deletionRequest.Messages.Ids, shared.MessageIdentifier{Date: entry.EndTime, DeviceId: entry.DeviceId}) + } + data, err := json.Marshal(deletionRequest) + if err != nil { + return err + } + _, err = ApiPost("/api/v1/add-deletion-request", "application/json", data) + if err != nil { + return fmt.Errorf("failed to send deletion request to backend service, this may cause commands to not get deleted on other instances of hishtory: %v", err) + } + return nil +} diff --git a/hishtory.go b/hishtory.go index 9f1f010..f61e042 100644 --- a/hishtory.go +++ b/hishtory.go @@ -8,8 +8,6 @@ import ( "strings" "time" - "gorm.io/gorm" - "github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/lib" "github.com/ddworken/hishtory/shared" @@ -26,16 +24,27 @@ func main() { case "saveHistoryEntry": lib.CheckFatalError(maybeUploadSkippedHistoryEntries()) saveHistoryEntry() + lib.CheckFatalError(processDeletionRequests()) case "query": + lib.CheckFatalError(processDeletionRequests()) query(strings.Join(os.Args[2:], " ")) case "export": + lib.CheckFatalError(processDeletionRequests()) export(strings.Join(os.Args[2:], " ")) case "redact": fallthrough case "delete": + lib.CheckFatalError(retrieveAdditionalEntriesFromRemote()) + lib.CheckFatalError(processDeletionRequests()) db, err := lib.OpenLocalSqliteDb() lib.CheckFatalError(err) - lib.CheckFatalError(data.Redact(db, strings.Join(os.Args[2:], " "))) + query := strings.Join(os.Args[2:], " ") + force := false + if os.Args[2] == "--force" { + query = strings.Join(os.Args[3:], " ") + force = true + } + lib.CheckFatalError(lib.Redact(db, query, force)) case "init": lib.CheckFatalError(lib.Setup(os.Args)) case "install": @@ -128,7 +137,44 @@ func getDumpRequests(config lib.ClientConfig) ([]*shared.DumpRequest, error) { return dumpRequests, err } -func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error { +func processDeletionRequests() error { + config, err := lib.GetConfig() + if err != nil { + return err + } + + resp, err := lib.ApiGet("/api/v1/get-deletion-requests?user_id=" + data.UserId(config.UserSecret) + "&device_id=" + config.DeviceId) + if lib.IsOfflineError(err) { + return nil + } + if err != nil { + return err + } + var deletionRequests []*shared.DeletionRequest + err = json.Unmarshal(resp, &deletionRequests) + if err != nil { + return err + } + db, err := lib.OpenLocalSqliteDb() + if err != nil { + return err + } + for _, request := range deletionRequests { + for _, entry := range request.Messages.Ids { + res := db.Where("device_id = ? AND end_time = ?", entry.DeviceId, entry.Date).Delete(&data.HistoryEntry{}) + if res.Error != nil { + return fmt.Errorf("DB error: %v", res.Error) + } + } + } + return nil +} + +func retrieveAdditionalEntriesFromRemote() error { + db, err := lib.OpenLocalSqliteDb() + if err != nil { + return err + } config, err := lib.GetConfig() if err != nil { return err @@ -152,13 +198,13 @@ func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error { } lib.AddToDbIfNew(db, decEntry) } - return nil + return processDeletionRequests() } func query(query string) { db, err := lib.OpenLocalSqliteDb() lib.CheckFatalError(err) - err = retrieveAdditionalEntriesFromRemote(db) + err = retrieveAdditionalEntriesFromRemote() if err != nil { if lib.IsOfflineError(err) { fmt.Println("Warning: hishtory is offline so this may be missing recent results from your other machines!") @@ -284,7 +330,7 @@ func saveHistoryEntry() { } } if len(dumpRequests) > 0 { - lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db)) + lib.CheckFatalError(retrieveAdditionalEntriesFromRemote()) entries, err := data.Search(db, "", 0) lib.CheckFatalError(err) var encEntries []*shared.EncHistoryEntry @@ -305,7 +351,7 @@ func saveHistoryEntry() { func export(query string) { db, err := lib.OpenLocalSqliteDb() lib.CheckFatalError(err) - err = retrieveAdditionalEntriesFromRemote(db) + err = retrieveAdditionalEntriesFromRemote() if err != nil { if lib.IsOfflineError(err) { fmt.Println("Warning: hishtory is offline so this may be missing recent results from your other machines!") @@ -319,3 +365,5 @@ func export(query string) { fmt.Println(data[i].Command) } } + +// TODO: Can we have a global db and config rather than this nonsense? diff --git a/shared/data.go b/shared/data.go index ece70ea..d5b6aaa 100644 --- a/shared/data.go +++ b/shared/data.go @@ -1,6 +1,9 @@ package shared import ( + "database/sql/driver" + "encoding/json" + "fmt" "time" ) @@ -43,6 +46,39 @@ type UpdateInfo struct { Version string `json:"version"` } +type DeletionRequest struct { + // TODO: Add a ReadCount + UserId string `json:"user_id"` + DestinationDeviceId string `json:"destination_device_id"` + SendTime time.Time `json:"send_time"` + Messages MessageIdentifiers `json:"messages"` +} + +type MessageIdentifiers struct { + Ids []MessageIdentifier `json:"message_ids"` +} + +type MessageIdentifier struct { + DeviceId string `json:"device_id"` + Date time.Time `json:"date"` +} + +func (m *MessageIdentifiers) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return fmt.Errorf("failed to unmarshal JSONB value: %v", value) + } + + result := MessageIdentifiers{} + err := json.Unmarshal(bytes, &result) + *m = result + return err +} + +func (m MessageIdentifiers) Value() (driver.Value, error) { + return json.Marshal(m) +} + const ( CONFIG_PATH = ".hishtory.config" HISHTORY_PATH = ".hishtory"