First version of working redaction with passing integration tests

This commit is contained in:
David Dworken 2022-09-19 22:49:48 -07:00
parent afe1fc5043
commit 5391ecd220
6 changed files with 347 additions and 44 deletions

View File

@ -124,6 +124,9 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
panic(err)
}
w.Write(resp)
// TODO: Make thsi method also check the pending deletion requests
// And then can delete the extra round trip of doing processDeletionRequests() after pulling from the remote
}
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
@ -201,6 +204,67 @@ func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(forcedBanner))
}
func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
// TODO: Count how many times they've been read and eventually delete them
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var deletionRequests []*shared.DeletionRequest
result := GLOBAL_DB.Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
respBody, err := json.Marshal(deletionRequests)
if err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
}
w.Write(respBody)
}
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadAll(r.Body)
if err != nil {
panic(err)
}
var request shared.DeletionRequest
err = json.Unmarshal(data, &request)
if err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
// Store the deletion request so all the devices will get it
tx := GLOBAL_DB.Where("user_id = ?", request.UserId)
var devices []*shared.Device
result := tx.Find(&devices)
if result.Error != nil {
panic(fmt.Errorf("DB query error: %v", result.Error))
}
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId))
}
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
for _, device := range devices {
request.DestinationDeviceId = device.DeviceId
result := GLOBAL_DB.Create(&request)
if result.Error != nil {
panic(result.Error)
}
}
// Also delete anything currently in the DB matching it
numDeleted := 0
for _, message := range request.Messages.Ids {
// TODO: Optimize this into one query
tx = GLOBAL_DB.Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date)
result := tx.Delete(&shared.EncHistoryEntry{})
if result.Error != nil {
panic(result.Error)
}
numDeleted += int(result.RowsAffected)
}
fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
}
func wipeDbHandler(w http.ResponseWriter, r *http.Request) {
result := GLOBAL_DB.Exec("DELETE FROM enc_history_entries")
if result.Error != nil {
@ -226,6 +290,7 @@ func OpenDB() (*gorm.DB, error) {
db.AutoMigrate(&shared.Device{})
db.AutoMigrate(&UsageData{})
db.AutoMigrate(&shared.DumpRequest{})
db.AutoMigrate(&shared.DeletionRequest{})
db.Exec("PRAGMA journal_mode = WAL")
return db, nil
}
@ -238,6 +303,7 @@ func OpenDB() (*gorm.DB, error) {
db.AutoMigrate(&shared.Device{})
db.AutoMigrate(&UsageData{})
db.AutoMigrate(&shared.DumpRequest{})
db.AutoMigrate(&shared.DeletionRequest{})
return db, nil
}
@ -468,6 +534,8 @@ func main() {
http.Handle("/api/v1/banner", withLogging(apiBannerHandler))
http.Handle("/api/v1/download", withLogging(apiDownloadHandler))
http.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler))
http.Handle("/api/v1/get-deletion-requests", withLogging(getDeletionRequestsHandler))
http.Handle("/api/v1/add-deletion-request", withLogging(addDeletionRequestHandler))
if isTestEnvironment() {
http.Handle("/api/v1/wipe-db", withLogging(wipeDbHandler))
}

View File

@ -130,6 +130,8 @@ func TestParameterized(t *testing.T) {
t.Run("testStripBashTimePrefix/"+tester.ShellName(), func(t *testing.T) { testStripBashTimePrefix(t, tester) })
t.Run("testReuploadHistoryEntries/"+tester.ShellName(), func(t *testing.T) { testReuploadHistoryEntries(t, tester) })
t.Run("testInitialHistoryImport/"+tester.ShellName(), func(t *testing.T) { testInitialHistoryImport(t, tester) })
t.Run("testLocalRedaction/"+tester.ShellName(), func(t *testing.T) { testLocalRedaction(t, tester) })
t.Run("testRemoteRedaction/"+tester.ShellName(), func(t *testing.T) { testRemoteRedaction(t, tester) })
}
}
@ -785,6 +787,9 @@ func hishtoryQuery(t *testing.T, tester shellTester, query string) string {
func manuallySubmitHistoryEntry(t *testing.T, userSecret string, entry data.HistoryEntry) {
encEntry, err := data.EncryptHistoryEntry(userSecret, entry)
shared.Check(t, err)
if encEntry.Date != entry.EndTime {
t.Fatalf("encEntry.Date does not match the entry")
}
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
shared.Check(t, err)
resp, err := http.Post("http://localhost:8080/api/v1/submit", "application/json", bytes.NewBuffer(jsonValue))
@ -1320,9 +1325,117 @@ echo %v-bar`, randomCmdUuid, randomCmdUuid)
}
// Check that the previously recorded commands are in hishtory
// TODO: change the below to | grep -v pipefail and see that it fails weirdly with zsh
out = tester.RunInteractiveShell(t, `hishtory export `+randomCmdUuid)
expectedOutput = fmt.Sprintf("hishtory export %s\necho %s-foo\necho %s-bar\nhishtory export %s\n", randomCmdUuid, randomCmdUuid, randomCmdUuid, randomCmdUuid)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
}
func testLocalRedaction(t *testing.T, tester shellTester) {
// Setup
defer shared.BackupAndRestore(t)()
// Install hishtory
installHishtory(t, tester, "")
// Record some commands
randomCmdUuid := uuid.Must(uuid.NewRandom()).String()
randomCmd := fmt.Sprintf(`echo %v-foo
echo %v-bas
echo foo
ls /tmp`, randomCmdUuid, randomCmdUuid)
tester.RunInteractiveShell(t, randomCmd)
// Check that the previously recorded commands are in hishtory
out := tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
expectedOutput := fmt.Sprintf("echo %s-foo\necho %s-bas\necho foo\nls /tmp\n", randomCmdUuid, randomCmdUuid)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
// Redact foo
out = tester.RunInteractiveShell(t, `hishtory redact --force foo`)
if out != "Permanently deleting 2 entries" {
t.Fatalf("hishtory redact gave unexpected output=%#v", out)
}
// Check that the commands are redacted
out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
expectedOutput = fmt.Sprintf("echo %s-bas\nls /tmp\nhishtory redact --force foo\n", randomCmdUuid)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
// Redact s
out = tester.RunInteractiveShell(t, `hishtory redact --force s`)
if out != "Permanently deleting 10 entries" {
t.Fatalf("hishtory redact gave unexpected output=%#v", out)
}
// Check that the commands are redacted
out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
expectedOutput = "hishtory redact --force s\n"
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
}
func testRemoteRedaction(t *testing.T, tester shellTester) {
// Setup
defer shared.BackupAndRestore(t)()
// Install hishtory client 1
userSecret := installHishtory(t, tester, "")
// Record some commands
randomCmdUuid := uuid.Must(uuid.NewRandom()).String()
randomCmd := fmt.Sprintf(`echo %v-foo
echo %v-bas`, randomCmdUuid, randomCmdUuid)
tester.RunInteractiveShell(t, randomCmd)
time.Sleep(2 * time.Second) // TODO: this sleep is covering up a bug
tester.RunInteractiveShell(t, `echo foo
ls /tmp`)
// Check that the previously recorded commands are in hishtory
out := tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
expectedOutput := fmt.Sprintf("echo %s-foo\necho %s-bas\necho foo\nls /tmp\n", randomCmdUuid, randomCmdUuid)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
// Install hishtory client 2
restoreInstall1 := shared.BackupAndRestoreWithId(t, "-1")
installHishtory(t, tester, userSecret)
// And confirm that it has the commands too
out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
// Restore the first client, and redact some commands
restoreInstall2 := shared.BackupAndRestoreWithId(t, "-2")
restoreInstall1()
out = tester.RunInteractiveShell(t, `hishtory redact --force `+randomCmdUuid)
if out != "Permanently deleting 2 entries" {
t.Fatalf("hishtory redact gave unexpected output=%#v", out)
}
// Confirm that client1 doesn't have the commands
out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
expectedOutput = fmt.Sprintf("echo foo\nls /tmp\nhishtory redact --force %s\n", randomCmdUuid)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
// Swap back to the second client and then confirm it processed the deletion request
restoreInstall2()
out = tester.RunInteractiveShell(t, `hishtory export | grep -v pipefail`)
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
}
// TODO: some tests for offline behavior

View File

@ -1,7 +1,6 @@
package data
import (
"bufio"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
@ -11,7 +10,6 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"strings"
"time"
@ -39,6 +37,10 @@ type HistoryEntry struct {
DeviceId string `json:"device_id" gorm:"uniqueIndex:compositeindex"`
}
func (h *HistoryEntry) GoString() string {
return fmt.Sprintf("%#v", *h)
}
func sha256hmac(key, additionalData string) []byte {
h := hmac.New(sha256.New, []byte(key))
h.Write([]byte(additionalData))
@ -108,7 +110,7 @@ func EncryptHistoryEntry(userSecret string, entry HistoryEntry) (shared.EncHisto
EncryptedData: ciphertext,
Nonce: nonce,
UserId: UserId(userSecret),
Date: time.Now(),
Date: entry.EndTime,
EncryptedId: uuid.Must(uuid.NewRandom()).String(),
ReadCount: 0,
}, nil
@ -135,7 +137,7 @@ func parseTimeGenerously(input string) (time.Time, error) {
return dateparse.ParseLocal(input)
}
func makeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) {
func MakeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) {
tokens, err := tokenize(query)
if err != nil {
return nil, fmt.Errorf("failed to tokenize query: %v", err)
@ -173,39 +175,8 @@ func makeWhereQueryFromSearch(db *gorm.DB, query string) (*gorm.DB, error) {
return tx, nil
}
func Redact(db *gorm.DB, query string) error {
tx, err := makeWhereQueryFromSearch(db, query)
if err != nil {
return err
}
var count int64
res := tx.Count(&count)
if res.Error != nil {
return res.Error
}
fmt.Printf("This will permanently delete %d entries, are you sure? [y/N]", count)
reader := bufio.NewReader(os.Stdin)
resp, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read response: %v", err)
}
if strings.TrimSpace(resp) != "y" {
fmt.Printf("Aborting delete per user response of %#v\n", strings.TrimSpace(resp))
return nil
}
tx, err = makeWhereQueryFromSearch(db, query)
if err != nil {
return err
}
res = tx.Delete(&HistoryEntry{})
if res.Error != nil {
return res.Error
}
return nil
}
func Search(db *gorm.DB, query string, limit int) ([]*HistoryEntry, error) {
tx, err := makeWhereQueryFromSearch(db, query)
tx, err := MakeWhereQueryFromSearch(db, query)
if err != nil {
return nil, err
}

View File

@ -1101,3 +1101,70 @@ func EncryptAndMarshal(config ClientConfig, entry *data.HistoryEntry) ([]byte, e
}
return jsonValue, nil
}
func Redact(db *gorm.DB, query string, force bool) error {
tx, err := data.MakeWhereQueryFromSearch(db, query)
if err != nil {
return err
}
var historyEntries []*data.HistoryEntry
res := tx.Find(&historyEntries)
if res.Error != nil {
return res.Error
}
if force {
fmt.Printf("Permanently deleting %d entries", len(historyEntries))
} else {
// TODO: Find a way to test the prompting
fmt.Printf("This will permanently delete %d entries, are you sure? [y/N]", len(historyEntries))
reader := bufio.NewReader(os.Stdin)
resp, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read response: %v", err)
}
if strings.TrimSpace(resp) != "y" {
fmt.Printf("Aborting delete per user response of %#v\n", strings.TrimSpace(resp))
return nil
}
}
tx, err = data.MakeWhereQueryFromSearch(db, query)
if err != nil {
return err
}
res = tx.Delete(&data.HistoryEntry{})
if res.Error != nil {
return res.Error
}
if res.RowsAffected != int64(len(historyEntries)) {
return fmt.Errorf("DB deleted %d rows, when we only expected to delete %d rows, something may have gone wrong", res.RowsAffected, len(historyEntries))
}
err = deleteOnRemoteInstances(historyEntries)
if err != nil {
return err
}
return nil
}
func deleteOnRemoteInstances(historyEntries []*data.HistoryEntry) error {
config, err := GetConfig()
if err != nil {
return err
}
var deletionRequest shared.DeletionRequest
deletionRequest.SendTime = time.Now()
deletionRequest.UserId = data.UserId(config.UserSecret)
for _, entry := range historyEntries {
deletionRequest.Messages.Ids = append(deletionRequest.Messages.Ids, shared.MessageIdentifier{Date: entry.EndTime, DeviceId: entry.DeviceId})
}
data, err := json.Marshal(deletionRequest)
if err != nil {
return err
}
_, err = ApiPost("/api/v1/add-deletion-request", "application/json", data)
if err != nil {
return fmt.Errorf("failed to send deletion request to backend service, this may cause commands to not get deleted on other instances of hishtory: %v", err)
}
return nil
}

View File

@ -8,8 +8,6 @@ import (
"strings"
"time"
"gorm.io/gorm"
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared"
@ -26,16 +24,27 @@ func main() {
case "saveHistoryEntry":
lib.CheckFatalError(maybeUploadSkippedHistoryEntries())
saveHistoryEntry()
lib.CheckFatalError(processDeletionRequests())
case "query":
lib.CheckFatalError(processDeletionRequests())
query(strings.Join(os.Args[2:], " "))
case "export":
lib.CheckFatalError(processDeletionRequests())
export(strings.Join(os.Args[2:], " "))
case "redact":
fallthrough
case "delete":
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote())
lib.CheckFatalError(processDeletionRequests())
db, err := lib.OpenLocalSqliteDb()
lib.CheckFatalError(err)
lib.CheckFatalError(data.Redact(db, strings.Join(os.Args[2:], " ")))
query := strings.Join(os.Args[2:], " ")
force := false
if os.Args[2] == "--force" {
query = strings.Join(os.Args[3:], " ")
force = true
}
lib.CheckFatalError(lib.Redact(db, query, force))
case "init":
lib.CheckFatalError(lib.Setup(os.Args))
case "install":
@ -128,7 +137,44 @@ func getDumpRequests(config lib.ClientConfig) ([]*shared.DumpRequest, error) {
return dumpRequests, err
}
func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error {
func processDeletionRequests() error {
config, err := lib.GetConfig()
if err != nil {
return err
}
resp, err := lib.ApiGet("/api/v1/get-deletion-requests?user_id=" + data.UserId(config.UserSecret) + "&device_id=" + config.DeviceId)
if lib.IsOfflineError(err) {
return nil
}
if err != nil {
return err
}
var deletionRequests []*shared.DeletionRequest
err = json.Unmarshal(resp, &deletionRequests)
if err != nil {
return err
}
db, err := lib.OpenLocalSqliteDb()
if err != nil {
return err
}
for _, request := range deletionRequests {
for _, entry := range request.Messages.Ids {
res := db.Where("device_id = ? AND end_time = ?", entry.DeviceId, entry.Date).Delete(&data.HistoryEntry{})
if res.Error != nil {
return fmt.Errorf("DB error: %v", res.Error)
}
}
}
return nil
}
func retrieveAdditionalEntriesFromRemote() error {
db, err := lib.OpenLocalSqliteDb()
if err != nil {
return err
}
config, err := lib.GetConfig()
if err != nil {
return err
@ -152,13 +198,13 @@ func retrieveAdditionalEntriesFromRemote(db *gorm.DB) error {
}
lib.AddToDbIfNew(db, decEntry)
}
return nil
return processDeletionRequests()
}
func query(query string) {
db, err := lib.OpenLocalSqliteDb()
lib.CheckFatalError(err)
err = retrieveAdditionalEntriesFromRemote(db)
err = retrieveAdditionalEntriesFromRemote()
if err != nil {
if lib.IsOfflineError(err) {
fmt.Println("Warning: hishtory is offline so this may be missing recent results from your other machines!")
@ -284,7 +330,7 @@ func saveHistoryEntry() {
}
}
if len(dumpRequests) > 0 {
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote(db))
lib.CheckFatalError(retrieveAdditionalEntriesFromRemote())
entries, err := data.Search(db, "", 0)
lib.CheckFatalError(err)
var encEntries []*shared.EncHistoryEntry
@ -305,7 +351,7 @@ func saveHistoryEntry() {
func export(query string) {
db, err := lib.OpenLocalSqliteDb()
lib.CheckFatalError(err)
err = retrieveAdditionalEntriesFromRemote(db)
err = retrieveAdditionalEntriesFromRemote()
if err != nil {
if lib.IsOfflineError(err) {
fmt.Println("Warning: hishtory is offline so this may be missing recent results from your other machines!")
@ -319,3 +365,5 @@ func export(query string) {
fmt.Println(data[i].Command)
}
}
// TODO: Can we have a global db and config rather than this nonsense?

View File

@ -1,6 +1,9 @@
package shared
import (
"database/sql/driver"
"encoding/json"
"fmt"
"time"
)
@ -43,6 +46,39 @@ type UpdateInfo struct {
Version string `json:"version"`
}
type DeletionRequest struct {
// TODO: Add a ReadCount
UserId string `json:"user_id"`
DestinationDeviceId string `json:"destination_device_id"`
SendTime time.Time `json:"send_time"`
Messages MessageIdentifiers `json:"messages"`
}
type MessageIdentifiers struct {
Ids []MessageIdentifier `json:"message_ids"`
}
type MessageIdentifier struct {
DeviceId string `json:"device_id"`
Date time.Time `json:"date"`
}
func (m *MessageIdentifiers) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return fmt.Errorf("failed to unmarshal JSONB value: %v", value)
}
result := MessageIdentifiers{}
err := json.Unmarshal(bytes, &result)
*m = result
return err
}
func (m MessageIdentifiers) Value() (driver.Value, error) {
return json.Marshal(m)
}
const (
CONFIG_PATH = ".hishtory.config"
HISHTORY_PATH = ".hishtory"