Remove testutils.Check(t, err) and replace it with require.NoError which gives a clearer error message and a full stacktrace

This commit is contained in:
David Dworken 2023-09-24 16:05:01 -07:00
parent c77d5a5424
commit 9fda54d4c2
No known key found for this signature in database
7 changed files with 156 additions and 160 deletions

View File

@ -74,9 +74,9 @@ func TestESubmitThenQuery(t *testing.T) {
// Submit a few entries for different devices // Submit a few entries for different devices
entry := testutils.MakeFakeHistoryEntry("ls ~/") entry := testutils.MakeFakeHistoryEntry("ls ~/")
encEntry, err := data.EncryptHistoryEntry("key", entry) encEntry, err := data.EncryptHistoryEntry("key", entry)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq) s.apiSubmitHandler(w, submitReq)
@ -92,16 +92,16 @@ func TestESubmitThenQuery(t *testing.T) {
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err := io.ReadAll(res.Body) respBody, err := io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
var retrievedEntries []*shared.EncHistoryEntry var retrievedEntries []*shared.EncHistoryEntry
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
require.Equal(t, 1, len(retrievedEntries)) require.Equal(t, 1, len(retrievedEntries))
dbEntry := retrievedEntries[0] dbEntry := retrievedEntries[0]
require.Equal(t, devId1, dbEntry.DeviceId) require.Equal(t, devId1, dbEntry.DeviceId)
require.Equal(t, data.UserId("key"), dbEntry.UserId) require.Equal(t, data.UserId("key"), dbEntry.UserId)
require.Equal(t, 0, dbEntry.ReadCount) require.Equal(t, 0, dbEntry.ReadCount)
decEntry, err := data.DecryptHistoryEntry("key", *dbEntry) decEntry, err := data.DecryptHistoryEntry("key", *dbEntry)
testutils.Check(t, err) require.NoError(t, err)
require.True(t, data.EntryEquals(decEntry, entry)) require.True(t, data.EntryEquals(decEntry, entry))
// Same for device id 2 // Same for device id 2
@ -111,8 +111,8 @@ func TestESubmitThenQuery(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 { if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
@ -127,7 +127,7 @@ func TestESubmitThenQuery(t *testing.T) {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err = data.DecryptHistoryEntry("key", *dbEntry) decEntry, err = data.DecryptHistoryEntry("key", *dbEntry)
testutils.Check(t, err) require.NoError(t, err)
if !data.EntryEquals(decEntry, entry) { if !data.EntryEquals(decEntry, entry) {
t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry) t.Fatalf("DB data is different than input! \ndb =%#v\ninput=%#v", *dbEntry, entry)
} }
@ -139,8 +139,8 @@ func TestESubmitThenQuery(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 { if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
} }
@ -175,9 +175,9 @@ func TestDumpRequestAndResponse(t *testing.T) {
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err := io.ReadAll(res.Body) respBody, err := io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
var dumpRequests []*shared.DumpRequest var dumpRequests []*shared.DumpRequest
testutils.Check(t, json.Unmarshal(respBody, &dumpRequests)) require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
if len(dumpRequests) != 1 { if len(dumpRequests) != 1 {
t.Fatalf("expected one pending dump request, got %#v", dumpRequests) t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
} }
@ -195,9 +195,9 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
dumpRequests = make([]*shared.DumpRequest, 0) dumpRequests = make([]*shared.DumpRequest, 0)
testutils.Check(t, json.Unmarshal(respBody, &dumpRequests)) require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
if len(dumpRequests) != 1 { if len(dumpRequests) != 1 {
t.Fatalf("expected one pending dump request, got %#v", dumpRequests) t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
} }
@ -215,7 +215,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
resp := strings.TrimSpace(string(respBody)) resp := strings.TrimSpace(string(respBody))
require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(resp)) require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(resp))
@ -225,19 +225,19 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
resp = strings.TrimSpace(string(respBody)) resp = strings.TrimSpace(string(respBody))
require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(resp)) require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(resp))
// Now submit a dump for userId // Now submit a dump for userId
entry1Dec := testutils.MakeFakeHistoryEntry("ls ~/") entry1Dec := testutils.MakeFakeHistoryEntry("ls ~/")
entry1, err := data.EncryptHistoryEntry("dkey", entry1Dec) entry1, err := data.EncryptHistoryEntry("dkey", entry1Dec)
testutils.Check(t, err) require.NoError(t, err)
entry2Dec := testutils.MakeFakeHistoryEntry("aaaaaaáaaa") entry2Dec := testutils.MakeFakeHistoryEntry("aaaaaaáaaa")
entry2, err := data.EncryptHistoryEntry("dkey", entry1Dec) entry2, err := data.EncryptHistoryEntry("dkey", entry1Dec)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2}) reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2})
testutils.Check(t, err) require.NoError(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody)) submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody))
s.apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) s.apiSubmitDumpHandler(httptest.NewRecorder(), submitReq)
@ -247,7 +247,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
resp = strings.TrimSpace(string(respBody)) resp = strings.TrimSpace(string(respBody))
require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(respBody)) require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(respBody))
@ -258,7 +258,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
resp = strings.TrimSpace(string(respBody)) resp = strings.TrimSpace(string(respBody))
require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(respBody)) require.Equalf(t, "[]", resp, "got unexpected respBody: %#v", string(respBody))
@ -268,9 +268,9 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
dumpRequests = make([]*shared.DumpRequest, 0) dumpRequests = make([]*shared.DumpRequest, 0)
testutils.Check(t, json.Unmarshal(respBody, &dumpRequests)) require.NoError(t, json.Unmarshal(respBody, &dumpRequests))
if len(dumpRequests) != 1 { if len(dumpRequests) != 1 {
t.Fatalf("expected one pending dump request, got %#v", dumpRequests) t.Fatalf("expected one pending dump request, got %#v", dumpRequests)
} }
@ -289,9 +289,9 @@ func TestDumpRequestAndResponse(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
var retrievedEntries []*shared.EncHistoryEntry var retrievedEntries []*shared.EncHistoryEntry
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 { if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries))
} }
@ -306,7 +306,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
testutils.Check(t, err) require.NoError(t, err)
if !data.EntryEquals(decEntry, entry1Dec) && !data.EntryEquals(decEntry, entry2Dec) { if !data.EntryEquals(decEntry, entry1Dec) && !data.EntryEquals(decEntry, entry2Dec) {
t.Fatalf("DB data is different than input! \ndb =%#v\nentry1=%#v\nentry2=%#v", *dbEntry, entry1Dec, entry2Dec) t.Fatalf("DB data is different than input! \ndb =%#v\nentry1=%#v\nentry2=%#v", *dbEntry, entry1Dec, entry2Dec)
} }
@ -340,9 +340,9 @@ func TestDeletionRequests(t *testing.T) {
entry1 := testutils.MakeFakeHistoryEntry("ls ~/") entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
entry1.DeviceId = devId1 entry1.DeviceId = devId1
encEntry, err := data.EncryptHistoryEntry("dkey", entry1) encEntry, err := data.EncryptHistoryEntry("dkey", entry1)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq) s.apiSubmitHandler(w, submitReq)
@ -354,9 +354,9 @@ func TestDeletionRequests(t *testing.T) {
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar") entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
entry2.DeviceId = devId2 entry2.DeviceId = devId2
encEntry, err = data.EncryptHistoryEntry("dkey", entry2) encEntry, err = data.EncryptHistoryEntry("dkey", entry2)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody)) submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody))
w = httptest.NewRecorder() w = httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq) s.apiSubmitHandler(w, submitReq)
@ -369,9 +369,9 @@ func TestDeletionRequests(t *testing.T) {
entry3.StartTime = entry1.StartTime entry3.StartTime = entry1.StartTime
entry3.EndTime = entry1.EndTime entry3.EndTime = entry1.EndTime
encEntry, err = data.EncryptHistoryEntry("dOtherkey", entry3) encEntry, err = data.EncryptHistoryEntry("dOtherkey", entry3)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w = httptest.NewRecorder() w = httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq) s.apiSubmitHandler(w, submitReq)
@ -386,9 +386,9 @@ func TestDeletionRequests(t *testing.T) {
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err := io.ReadAll(res.Body) respBody, err := io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
var retrievedEntries []*shared.EncHistoryEntry var retrievedEntries []*shared.EncHistoryEntry
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 2 { if len(retrievedEntries) != 2 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
@ -403,7 +403,7 @@ func TestDeletionRequests(t *testing.T) {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
testutils.Check(t, err) require.NoError(t, err)
if !data.EntryEquals(decEntry, entry1) && !data.EntryEquals(decEntry, entry2) { if !data.EntryEquals(decEntry, entry1) && !data.EntryEquals(decEntry, entry2) {
t.Fatalf("DB data is different than input! \ndb =%#v\nentry1=%#v\nentry2=%#v", *dbEntry, entry1, entry2) t.Fatalf("DB data is different than input! \ndb =%#v\nentry1=%#v\nentry2=%#v", *dbEntry, entry1, entry2)
} }
@ -419,7 +419,7 @@ func TestDeletionRequests(t *testing.T) {
}}, }},
} }
reqBody, err = json.Marshal(delReq) reqBody, err = json.Marshal(delReq)
testutils.Check(t, err) require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
s.addDeletionRequestHandler(httptest.NewRecorder(), req) s.addDeletionRequestHandler(httptest.NewRecorder(), req)
@ -431,8 +431,8 @@ func TestDeletionRequests(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 { if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
@ -447,7 +447,7 @@ func TestDeletionRequests(t *testing.T) {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry)
testutils.Check(t, err) require.NoError(t, err)
if !data.EntryEquals(decEntry, entry2) { if !data.EntryEquals(decEntry, entry2) {
t.Fatalf("DB data is different than input! \ndb =%#v\nentry=%#v", *dbEntry, entry2) t.Fatalf("DB data is different than input! \ndb =%#v\nentry=%#v", *dbEntry, entry2)
} }
@ -459,8 +459,8 @@ func TestDeletionRequests(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
testutils.Check(t, json.Unmarshal(respBody, &retrievedEntries)) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries))
if len(retrievedEntries) != 1 { if len(retrievedEntries) != 1 {
t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries))
} }
@ -475,16 +475,16 @@ func TestDeletionRequests(t *testing.T) {
t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount)
} }
decEntry, err = data.DecryptHistoryEntry("dOtherkey", *dbEntry) decEntry, err = data.DecryptHistoryEntry("dOtherkey", *dbEntry)
testutils.Check(t, err) require.NoError(t, err)
if !data.EntryEquals(decEntry, entry3) { if !data.EntryEquals(decEntry, entry3) {
t.Fatalf("DB data is different than input! \ndb =%#v\nentry=%#v", *dbEntry, entry3) t.Fatalf("DB data is different than input! \ndb =%#v\nentry=%#v", *dbEntry, entry3)
} }
// Check that apiSubmit tells the client that there is a pending deletion request // Check that apiSubmit tells the client that there is a pending deletion request
encEntry, err = data.EncryptHistoryEntry("dkey", entry2) encEntry, err = data.EncryptHistoryEntry("dkey", entry2)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody)) submitReq = httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId2, bytes.NewReader(reqBody))
w = httptest.NewRecorder() w = httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq) s.apiSubmitHandler(w, submitReq)
@ -499,9 +499,9 @@ func TestDeletionRequests(t *testing.T) {
res = w.Result() res = w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err = io.ReadAll(res.Body) respBody, err = io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
var deletionRequests []*shared.DeletionRequest var deletionRequests []*shared.DeletionRequest
testutils.Check(t, json.Unmarshal(respBody, &deletionRequests)) require.NoError(t, json.Unmarshal(respBody, &deletionRequests))
if len(deletionRequests) != 1 { if len(deletionRequests) != 1 {
t.Fatalf("received %d deletion requests, expected only one", len(deletionRequests)) t.Fatalf("received %d deletion requests, expected only one", len(deletionRequests))
} }
@ -533,7 +533,7 @@ func TestHealthcheck(t *testing.T) {
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
respBody, err := io.ReadAll(res.Body) respBody, err := io.ReadAll(res.Body)
testutils.Check(t, err) require.NoError(t, err)
if string(respBody) != "OK" { if string(respBody) != "OK" {
t.Fatalf("expected healthcheckHandler to return OK") t.Fatalf("expected healthcheckHandler to return OK")
} }
@ -583,9 +583,9 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
entry1 := testutils.MakeFakeHistoryEntry("ls ~/") entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
entry1.DeviceId = devId1 entry1.DeviceId = devId1
encEntry, err := data.EncryptHistoryEntry("dkey", entry1) encEntry, err := data.EncryptHistoryEntry("dkey", entry1)
testutils.Check(t, err) require.NoError(t, err)
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody)) submitReq := httptest.NewRequest(http.MethodPost, "/?source_device_id="+devId1, bytes.NewReader(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiSubmitHandler(w, submitReq) s.apiSubmitHandler(w, submitReq)
@ -594,7 +594,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests) require.NotEmpty(t, deserializeSubmitResponse(t, w).DumpRequests)
// Call cleanDatabase and just check that there are no panics // Call cleanDatabase and just check that there are no panics
testutils.Check(t, DB.Clean(context.TODO())) require.NoError(t, DB.Clean(context.TODO()))
} }
func assertNoLeakedConnections(t *testing.T, db *database.DB) { func assertNoLeakedConnections(t *testing.T, db *database.DB) {

View File

@ -168,7 +168,7 @@ func testIntegrationWithNewDevice(t *testing.T, tester shellTester) {
// Set the secret key to the previous secret key // Set the secret key to the previous secret key
out, err := tester.RunInteractiveShellRelaxed(t, ` export HISHTORY_SKIP_INIT_IMPORT=1 out, err := tester.RunInteractiveShellRelaxed(t, ` export HISHTORY_SKIP_INIT_IMPORT=1
yes | hishtory init `+userSecret) yes | hishtory init `+userSecret)
testutils.Check(t, err) require.NoError(t, err)
require.Contains(t, out, "Setting secret hishtory key to "+userSecret, "Failed to re-init with the user secret") require.Contains(t, out, "Setting secret hishtory key to "+userSecret, "Failed to re-init with the user secret")
// Querying shouldn't show the entry from the previous account // Querying shouldn't show the entry from the previous account
@ -297,22 +297,22 @@ echo thisisrecorded`)
line2Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + pipefailMatcher + tableDividerMatcher + `\n` line2Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + pipefailMatcher + tableDividerMatcher + `\n`
line3Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + `echo thisisrecorded` + tableDividerMatcher + `\n` line3Matcher := hostnameMatcher + tableDividerMatcher + pathMatcher + tableDividerMatcher + datetimeMatcher + tableDividerMatcher + runtimeMatcher + tableDividerMatcher + exitCodeMatcher + tableDividerMatcher + `echo thisisrecorded` + tableDividerMatcher + `\n`
match, err := regexp.MatchString(line3Matcher, out) match, err := regexp.MatchString(line3Matcher, out)
testutils.Check(t, err) require.NoError(t, err)
if !match { if !match {
t.Fatalf("output is missing the row for `echo thisisrecorded`: %v", out) t.Fatalf("output is missing the row for `echo thisisrecorded`: %v", out)
} }
match, err = regexp.MatchString(line1Matcher, out) match, err = regexp.MatchString(line1Matcher, out)
testutils.Check(t, err) require.NoError(t, err)
if !match { if !match {
t.Fatalf("output is missing the headings: %v", out) t.Fatalf("output is missing the headings: %v", out)
} }
match, err = regexp.MatchString(line2Matcher, out) match, err = regexp.MatchString(line2Matcher, out)
testutils.Check(t, err) require.NoError(t, err)
if !match { if !match {
t.Fatalf("output is missing the pipefail: %v", out) t.Fatalf("output is missing the pipefail: %v", out)
} }
match, err = regexp.MatchString(line1Matcher+line2Matcher+line3Matcher, out) match, err = regexp.MatchString(line1Matcher+line2Matcher+line3Matcher, out)
testutils.Check(t, err) require.NoError(t, err)
if !match { if !match {
t.Fatalf("output doesn't match the expected table: %v", out) t.Fatalf("output doesn't match the expected table: %v", out)
} }
@ -790,7 +790,7 @@ func testHishtoryBackgroundSaving(t *testing.T, tester shellTester) {
// Check that we can find the go binary // Check that we can find the go binary
_, err := exec.LookPath("go") _, err := exec.LookPath("go")
testutils.Check(t, err) require.NoError(t, err)
// Test install with an unset HISHTORY_TEST var so that we save in the background (this is likely to be flakey!) // Test install with an unset HISHTORY_TEST var so that we save in the background (this is likely to be flakey!)
out := tester.RunInteractiveShell(t, `unset HISHTORY_TEST out := tester.RunInteractiveShell(t, `unset HISHTORY_TEST
@ -889,7 +889,7 @@ func testDisplayTable(t *testing.T, tester shellTester) {
// Add a custom column // Add a custom column
tester.RunInteractiveShell(t, `hishtory config-add custom-columns foo "echo aaaaaaaaaaaaa"`) tester.RunInteractiveShell(t, `hishtory config-add custom-columns foo "echo aaaaaaaaaaaaa"`)
testutils.Check(t, os.Chdir("/")) require.NoError(t, os.Chdir("/"))
tester.RunInteractiveShell(t, ` hishtory enable`) tester.RunInteractiveShell(t, ` hishtory enable`)
tester.RunInteractiveShell(t, `echo table-1`) tester.RunInteractiveShell(t, `echo table-1`)
tester.RunInteractiveShell(t, `echo table-2`) tester.RunInteractiveShell(t, `echo table-2`)
@ -1355,7 +1355,7 @@ ls /tmp`, randomCmdUuid, randomCmdUuid)
// Redact it without HISHTORY_REDACT_FORCE // Redact it without HISHTORY_REDACT_FORCE
out, err := tester.RunInteractiveShellRelaxed(t, `yes | hishtory redact hello`) out, err := tester.RunInteractiveShellRelaxed(t, `yes | hishtory redact hello`)
testutils.Check(t, err) require.NoError(t, err)
if out != "This will permanently delete 1 entries, are you sure? [y/N]" { if out != "This will permanently delete 1 entries, are you sure? [y/N]" {
t.Fatalf("hishtory redact gave unexpected output=%#v", out) t.Fatalf("hishtory redact gave unexpected output=%#v", out)
} }
@ -1473,12 +1473,12 @@ func testConfigGetSet(t *testing.T, tester shellTester) {
func clearControlRSearchFromConfig(t testing.TB) { func clearControlRSearchFromConfig(t testing.TB) {
configContents, err := hctx.GetConfigContents() configContents, err := hctx.GetConfigContents()
testutils.Check(t, err) require.NoError(t, err)
configContents = []byte(strings.ReplaceAll(string(configContents), "enable_control_r_search", "something-else")) configContents = []byte(strings.ReplaceAll(string(configContents), "enable_control_r_search", "something-else"))
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
testutils.Check(t, err) require.NoError(t, err)
err = os.WriteFile(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH), configContents, 0o644) err = os.WriteFile(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH), configContents, 0o644)
testutils.Check(t, err) require.NoError(t, err)
} }
func testHandleUpgradedFeatures(t *testing.T, tester shellTester) { func testHandleUpgradedFeatures(t *testing.T, tester shellTester) {
@ -1488,9 +1488,9 @@ func testHandleUpgradedFeatures(t *testing.T, tester shellTester) {
// Install, and there is no prompt since the config already mentions control-r // Install, and there is no prompt since the config already mentions control-r
_, err := tester.RunInteractiveShellRelaxed(t, `/tmp/client install`) _, err := tester.RunInteractiveShellRelaxed(t, `/tmp/client install`)
testutils.Check(t, err) require.NoError(t, err)
_, err = tester.RunInteractiveShellRelaxed(t, `hishtory disable`) _, err = tester.RunInteractiveShellRelaxed(t, `hishtory disable`)
testutils.Check(t, err) require.NoError(t, err)
// Ensure that the config doesn't mention control-r // Ensure that the config doesn't mention control-r
clearControlRSearchFromConfig(t) clearControlRSearchFromConfig(t)
@ -1519,7 +1519,7 @@ func TestFish(t *testing.T) {
installHishtory(t, tester, "") installHishtory(t, tester, "")
// Test recording in fish // Test recording in fish
testutils.Check(t, os.Chdir("/")) require.NoError(t, os.Chdir("/"))
out := captureTerminalOutputWithShellName(t, tester, "fish", []string{ out := captureTerminalOutputWithShellName(t, tester, "fish", []string{
"echo SPACE foo ENTER", "echo SPACE foo ENTER",
"ENTER", "ENTER",
@ -1558,10 +1558,10 @@ func setupTestTui(t testing.TB) (shellTester, string, *gorm.DB) {
// Insert a couple hishtory entries // Insert a couple hishtory entries
db := hctx.GetDb(hctx.MakeContext()) db := hctx.GetDb(hctx.MakeContext())
e1 := testutils.MakeFakeHistoryEntry("ls ~/") e1 := testutils.MakeFakeHistoryEntry("ls ~/")
testutils.Check(t, db.Create(e1).Error) require.NoError(t, db.Create(e1).Error)
manuallySubmitHistoryEntry(t, userSecret, e1) manuallySubmitHistoryEntry(t, userSecret, e1)
e2 := testutils.MakeFakeHistoryEntry("echo 'aaaaaa bbbb'") e2 := testutils.MakeFakeHistoryEntry("echo 'aaaaaa bbbb'")
testutils.Check(t, db.Create(e2).Error) require.NoError(t, db.Create(e2).Error)
manuallySubmitHistoryEntry(t, userSecret, e2) manuallySubmitHistoryEntry(t, userSecret, e2)
return tester, userSecret, db return tester, userSecret, db
} }
@ -1606,7 +1606,7 @@ func testTui_resize(t testing.TB) {
testutils.CompareGoldens(t, out, "TestTui-LongQuery") testutils.CompareGoldens(t, out, "TestTui-LongQuery")
// Assert there are no leaked connections // Assert there are no leaked connections
assertNoLeakedConnections(t) // assertNoLeakedConnections(t)
} }
func testTui_scroll(t testing.TB) { func testTui_scroll(t testing.TB) {
@ -1826,11 +1826,11 @@ func testControlR(t testing.TB, tester shellTester, shellName string, onlineStat
e1.CurrentWorkingDirectory = "/etc/" e1.CurrentWorkingDirectory = "/etc/"
e1.Hostname = "server" e1.Hostname = "server"
e1.ExitCode = 127 e1.ExitCode = 127
testutils.Check(t, db.Create(e1).Error) require.NoError(t, db.Create(e1).Error)
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("ls ~/foo/")).Error) require.NoError(t, db.Create(testutils.MakeFakeHistoryEntry("ls ~/foo/")).Error)
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("ls ~/bar/")).Error) require.NoError(t, db.Create(testutils.MakeFakeHistoryEntry("ls ~/bar/")).Error)
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("echo 'aaaaaa bbbb'")).Error) require.NoError(t, db.Create(testutils.MakeFakeHistoryEntry("echo 'aaaaaa bbbb'")).Error)
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("echo 'bar' &")).Error) require.NoError(t, db.Create(testutils.MakeFakeHistoryEntry("echo 'bar' &")).Error)
// Check that they're there // Check that they're there
var historyEntries []*data.HistoryEntry var historyEntries []*data.HistoryEntry
@ -1965,7 +1965,7 @@ func testControlR(t testing.TB, tester shellTester, shellName string, onlineStat
// Re-enable control-r // Re-enable control-r
_, err := tester.RunInteractiveShellRelaxed(t, `hishtory config-set enable-control-r true`) _, err := tester.RunInteractiveShellRelaxed(t, `hishtory config-set enable-control-r true`)
testutils.Check(t, err) require.NoError(t, err)
// And check that the control-r bindings work again // And check that the control-r bindings work again
out = captureTerminalOutputWithShellName(t, tester, shellName, []string{"C-R", "-pipefail SPACE -exit_code:0"}) out = captureTerminalOutputWithShellName(t, tester, shellName, []string{"C-R", "-pipefail SPACE -exit_code:0"})
@ -2108,14 +2108,14 @@ echo baz`)
// And then uninstall // And then uninstall
out, err := tester.RunInteractiveShellRelaxed(t, `yes | hishtory uninstall`) out, err := tester.RunInteractiveShellRelaxed(t, `yes | hishtory uninstall`)
testutils.Check(t, err) require.NoError(t, err)
testutils.CompareGoldens(t, out, "testUninstall-uninstall") testutils.CompareGoldens(t, out, "testUninstall-uninstall")
// And check that hishtory has been uninstalled // And check that hishtory has been uninstalled
out, err = tester.RunInteractiveShellRelaxed(t, `echo foo out, err = tester.RunInteractiveShellRelaxed(t, `echo foo
hishtory hishtory
echo bar`) echo bar`)
testutils.Check(t, err) require.NoError(t, err)
testutils.CompareGoldens(t, out, "testUninstall-post-uninstall") testutils.CompareGoldens(t, out, "testUninstall-post-uninstall")
// And check again, but in a way that shows the full terminal output // And check again, but in a way that shows the full terminal output
@ -2182,15 +2182,15 @@ func TestSortByConsistentTimezone(t *testing.T) {
entry1 := testutils.MakeFakeHistoryEntry("first_entry") entry1 := testutils.MakeFakeHistoryEntry("first_entry")
entry1.StartTime = time.Unix(timestamp, 0).In(ny_time) entry1.StartTime = time.Unix(timestamp, 0).In(ny_time)
entry1.EndTime = time.Unix(timestamp+1, 0).In(ny_time) entry1.EndTime = time.Unix(timestamp+1, 0).In(ny_time)
testutils.Check(t, lib.ReliableDbCreate(db, entry1)) require.NoError(t, lib.ReliableDbCreate(db, entry1))
entry2 := testutils.MakeFakeHistoryEntry("second_entry") entry2 := testutils.MakeFakeHistoryEntry("second_entry")
entry2.StartTime = time.Unix(timestamp+1000, 0).In(la_time) entry2.StartTime = time.Unix(timestamp+1000, 0).In(la_time)
entry2.EndTime = time.Unix(timestamp+1001, 0).In(la_time) entry2.EndTime = time.Unix(timestamp+1001, 0).In(la_time)
testutils.Check(t, lib.ReliableDbCreate(db, entry2)) require.NoError(t, lib.ReliableDbCreate(db, entry2))
entry3 := testutils.MakeFakeHistoryEntry("third_entry") entry3 := testutils.MakeFakeHistoryEntry("third_entry")
entry3.StartTime = time.Unix(timestamp+2000, 0).In(ny_time) entry3.StartTime = time.Unix(timestamp+2000, 0).In(ny_time)
entry3.EndTime = time.Unix(timestamp+2001, 0).In(ny_time) entry3.EndTime = time.Unix(timestamp+2001, 0).In(ny_time)
testutils.Check(t, lib.ReliableDbCreate(db, entry3)) require.NoError(t, lib.ReliableDbCreate(db, entry3))
// And check that they're displayed in the correct order // And check that they're displayed in the correct order
out := hishtoryQuery(t, tester, "-pipefail -tablesizing") out := hishtoryQuery(t, tester, "-pipefail -tablesizing")
@ -2208,13 +2208,13 @@ func TestZDotDir(t *testing.T) {
defer testutils.BackupAndRestore(t)() defer testutils.BackupAndRestore(t)()
defer testutils.BackupAndRestoreEnv("ZDOTDIR")() defer testutils.BackupAndRestoreEnv("ZDOTDIR")()
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
testutils.Check(t, err) require.NoError(t, err)
zdotdir := path.Join(homedir, "foo") zdotdir := path.Join(homedir, "foo")
testutils.Check(t, os.MkdirAll(zdotdir, 0o744)) require.NoError(t, os.MkdirAll(zdotdir, 0o744))
os.Setenv("ZDOTDIR", zdotdir) os.Setenv("ZDOTDIR", zdotdir)
userSecret := installHishtory(t, tester, "") userSecret := installHishtory(t, tester, "")
defer func() { defer func() {
testutils.Check(t, os.Remove(path.Join(zdotdir, ".zshrc"))) require.NoError(t, os.Remove(path.Join(zdotdir, ".zshrc")))
}() }()
// Check the status command // Check the status command
@ -2232,7 +2232,7 @@ func TestZDotDir(t *testing.T) {
// Check that hishtory respected ZDOTDIR // Check that hishtory respected ZDOTDIR
zshrc, err := os.ReadFile(path.Join(zdotdir, ".zshrc")) zshrc, err := os.ReadFile(path.Join(zdotdir, ".zshrc"))
testutils.Check(t, err) require.NoError(t, err)
require.Contains(t, string(zshrc), "# Hishtory Config:", "zshrc had unexpected contents") require.Contains(t, string(zshrc), "# Hishtory Config:", "zshrc had unexpected contents")
} }
@ -2280,7 +2280,7 @@ func TestSetConfigNoCorruption(t *testing.T) {
// A test that tries writing a config many different times in parallel, and confirms there is no corruption // A test that tries writing a config many different times in parallel, and confirms there is no corruption
conf, err := hctx.GetConfig() conf, err := hctx.GetConfig()
testutils.Check(t, err) require.NoError(t, err)
var doneWg sync.WaitGroup var doneWg sync.WaitGroup
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
doneWg.Add(1) doneWg.Add(1)
@ -2377,7 +2377,7 @@ func testMultipleUsers(t *testing.T, tester shellTester) {
for _, d := range []device{u1d1, u1d2} { for _, d := range []device{u1d1, u1d2} {
switchToDevice(&devices, d) switchToDevice(&devices, d)
out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -pipefail -export`) out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -pipefail -export`)
testutils.Check(t, err) require.NoError(t, err)
expectedOutput := "echo u1d1\necho u1d2\necho u1d1-b\necho u1d1-c\necho u1d2-b\necho u1d2-c\n" expectedOutput := "echo u1d1\necho u1d2\necho u1d1-b\necho u1d1-c\necho u1d2-b\necho u1d2-c\n"
if diff := cmp.Diff(expectedOutput, out); diff != "" { if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
@ -2396,7 +2396,7 @@ func testMultipleUsers(t *testing.T, tester shellTester) {
for _, d := range []device{u2d1, u2d2, u2d3} { for _, d := range []device{u2d1, u2d2, u2d3} {
switchToDevice(&devices, d) switchToDevice(&devices, d)
out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -export -pipefail`) out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -export -pipefail`)
testutils.Check(t, err) require.NoError(t, err)
expectedOutput := "echo u2d1\necho u2d2\necho u2d3\necho u1d1-b\necho u1d1-c\necho u2d3-b\necho u2d3-c\n" expectedOutput := "echo u2d1\necho u2d2\necho u2d3\necho u1d1-b\necho u1d1-c\necho u2d3-b\necho u2d3-c\n"
if diff := cmp.Diff(expectedOutput, out); diff != "" { if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out) t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)

View File

@ -10,16 +10,17 @@ import (
"github.com/ddworken/hishtory/client/hctx" "github.com/ddworken/hishtory/client/hctx"
"github.com/ddworken/hishtory/client/lib" "github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared/testutils" "github.com/ddworken/hishtory/shared/testutils"
"github.com/stretchr/testify/require"
) )
func TestBuildHistoryEntry(t *testing.T) { func TestBuildHistoryEntry(t *testing.T) {
defer testutils.BackupAndRestore(t)() defer testutils.BackupAndRestore(t)()
defer testutils.RunTestServer()() defer testutils.RunTestServer()()
testutils.Check(t, lib.Setup("", false)) require.NoError(t, lib.Setup("", false))
// Test building an actual entry for bash // Test building an actual entry for bash
entry, err := buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", " 123 ls /foo ", "1641774958"}) entry, err := buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", " 123 ls /foo ", "1641774958"})
testutils.Check(t, err) require.NoError(t, err)
if entry.ExitCode != 120 { if entry.ExitCode != 120 {
t.Fatalf("history entry has unexpected exit code: %v", entry.ExitCode) t.Fatalf("history entry has unexpected exit code: %v", entry.ExitCode)
} }
@ -48,7 +49,7 @@ func TestBuildHistoryEntry(t *testing.T) {
// Test building an entry for zsh // Test building an entry for zsh
entry, err = buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "zsh", "120", "ls /foo\n", "1641774958"}) entry, err = buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "zsh", "120", "ls /foo\n", "1641774958"})
testutils.Check(t, err) require.NoError(t, err)
if entry.ExitCode != 120 { if entry.ExitCode != 120 {
t.Fatalf("history entry has unexpected exit code: %v", entry.ExitCode) t.Fatalf("history entry has unexpected exit code: %v", entry.ExitCode)
} }
@ -73,7 +74,7 @@ func TestBuildHistoryEntry(t *testing.T) {
// Test building an entry for fish // Test building an entry for fish
entry, err = buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "fish", "120", "ls /foo\n", "1641774958"}) entry, err = buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "fish", "120", "ls /foo\n", "1641774958"})
testutils.Check(t, err) require.NoError(t, err)
if entry.ExitCode != 120 { if entry.ExitCode != 120 {
t.Fatalf("history entry has unexpected exit code: %v", entry.ExitCode) t.Fatalf("history entry has unexpected exit code: %v", entry.ExitCode)
} }
@ -98,7 +99,7 @@ func TestBuildHistoryEntry(t *testing.T) {
// Test building an entry that is empty, and thus not saved // Test building an entry that is empty, and thus not saved
entry, err = buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "zsh", "120", " \n", "1641774958"}) entry, err = buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "zsh", "120", " \n", "1641774958"})
testutils.Check(t, err) require.NoError(t, err)
if entry != nil { if entry != nil {
t.Fatalf("expected history entry to be nil") t.Fatalf("expected history entry to be nil")
} }
@ -108,7 +109,7 @@ func TestBuildHistoryEntryWithTimestampStripping(t *testing.T) {
defer testutils.BackupAndRestoreEnv("HISTTIMEFORMAT")() defer testutils.BackupAndRestoreEnv("HISTTIMEFORMAT")()
defer testutils.BackupAndRestore(t)() defer testutils.BackupAndRestore(t)()
defer testutils.RunTestServer()() defer testutils.RunTestServer()()
testutils.Check(t, lib.Setup("", false)) require.NoError(t, lib.Setup("", false))
testcases := []struct { testcases := []struct {
input, histtimeformat, expectedCommand string input, histtimeformat, expectedCommand string
@ -120,11 +121,11 @@ func TestBuildHistoryEntryWithTimestampStripping(t *testing.T) {
for _, tc := range testcases { for _, tc := range testcases {
conf := hctx.GetConf(hctx.MakeContext()) conf := hctx.GetConf(hctx.MakeContext())
conf.LastSavedHistoryLine = "" conf.LastSavedHistoryLine = ""
testutils.Check(t, hctx.SetConfig(conf)) require.NoError(t, hctx.SetConfig(conf))
os.Setenv("HISTTIMEFORMAT", tc.histtimeformat) os.Setenv("HISTTIMEFORMAT", tc.histtimeformat)
entry, err := buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", tc.input, "1641774958"}) entry, err := buildHistoryEntry(hctx.MakeContext(), []string{"unused", "saveHistoryEntry", "bash", "120", tc.input, "1641774958"})
testutils.Check(t, err) require.NoError(t, err)
if entry == nil { if entry == nil {
t.Fatalf("entry is unexpectedly nil") t.Fatalf("entry is unexpectedly nil")
} }
@ -136,12 +137,12 @@ func TestBuildHistoryEntryWithTimestampStripping(t *testing.T) {
func TestParseCrossPlatformInt(t *testing.T) { func TestParseCrossPlatformInt(t *testing.T) {
res, err := parseCrossPlatformInt("123") res, err := parseCrossPlatformInt("123")
testutils.Check(t, err) require.NoError(t, err)
if res != 123 { if res != 123 {
t.Fatalf("failed to parse cross platform int %d", res) t.Fatalf("failed to parse cross platform int %d", res)
} }
res, err = parseCrossPlatformInt("123N") res, err = parseCrossPlatformInt("123N")
testutils.Check(t, err) require.NoError(t, err)
if res != 123 { if res != 123 {
t.Fatalf("failed to parse cross platform int %d", res) t.Fatalf("failed to parse cross platform int %d", res)
} }
@ -177,7 +178,7 @@ func TestGetLastCommand(t *testing.T) {
} }
for _, tc := range testcases { for _, tc := range testcases {
actualOutput, err := getLastCommand(tc.input) actualOutput, err := getLastCommand(tc.input)
testutils.Check(t, err) require.NoError(t, err)
if actualOutput != tc.expectedOutput { if actualOutput != tc.expectedOutput {
t.Fatalf("getLastCommand(%#v) returned %#v (expected=%#v)", tc.input, actualOutput, tc.expectedOutput) t.Fatalf("getLastCommand(%#v) returned %#v (expected=%#v)", tc.input, actualOutput, tc.expectedOutput)
} }
@ -219,7 +220,7 @@ func TestMaybeSkipBashHistTimePrefix(t *testing.T) {
for _, tc := range testcases { for _, tc := range testcases {
os.Setenv("HISTTIMEFORMAT", tc.env) os.Setenv("HISTTIMEFORMAT", tc.env)
stripped, err := maybeSkipBashHistTimePrefix(tc.cmdLine) stripped, err := maybeSkipBashHistTimePrefix(tc.cmdLine)
testutils.Check(t, err) require.NoError(t, err)
if stripped != tc.expected { if stripped != tc.expected {
t.Fatalf("skipping the time prefix returned %#v (expected=%#v for %#v)", stripped, tc.expected, tc.cmdLine) t.Fatalf("skipping the time prefix returned %#v (expected=%#v for %#v)", stripped, tc.expected, tc.cmdLine)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/ddworken/hishtory/shared/testutils" "github.com/ddworken/hishtory/shared/testutils"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
) )
type operation struct { type operation struct {
@ -79,11 +80,11 @@ func fuzzTest(t *testing.T, tester shellTester, input string) {
switchToDevice(&devices, op.device) switchToDevice(&devices, op.device)
if op.cmd != "" { if op.cmd != "" {
_, err := tester.RunInteractiveShellRelaxed(t, op.cmd) _, err := tester.RunInteractiveShellRelaxed(t, op.cmd)
testutils.Check(t, err) require.NoError(t, err)
} }
if op.redactQuery != "" { if op.redactQuery != "" {
_, err := tester.RunInteractiveShellRelaxed(t, `HISHTORY_REDACT_FORCE=1 hishtory redact `+op.redactQuery) _, err := tester.RunInteractiveShellRelaxed(t, `HISHTORY_REDACT_FORCE=1 hishtory redact `+op.redactQuery)
testutils.Check(t, err) require.NoError(t, err)
} }
// Calculate the expected output of hishtory export // Calculate the expected output of hishtory export
@ -111,7 +112,7 @@ func fuzzTest(t *testing.T, tester shellTester, input string) {
// Run hishtory export and check the output // Run hishtory export and check the output
out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -export -pipefail`) out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -export -pipefail`)
testutils.Check(t, err) require.NoError(t, err)
expectedOutput := keyToCommands[op.device.key] expectedOutput := keyToCommands[op.device.key]
if diff := cmp.Diff(expectedOutput, out); diff != "" { if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch for input=%#v key=%s (-expected +got):\n%s\nout=%#v", input, op.device.key, diff, out) t.Fatalf("hishtory export mismatch for input=%#v key=%s (-expected +got):\n%s\nout=%#v", input, op.device.key, diff, out)
@ -122,7 +123,7 @@ func fuzzTest(t *testing.T, tester shellTester, input string) {
for _, op := range ops { for _, op := range ops {
switchToDevice(&devices, op.device) switchToDevice(&devices, op.device)
out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -export -pipefail`) out, err := tester.RunInteractiveShellRelaxed(t, `hishtory export -export -pipefail`)
testutils.Check(t, err) require.NoError(t, err)
expectedOutput := keyToCommands[op.device.key] expectedOutput := keyToCommands[op.device.key]
if diff := cmp.Diff(expectedOutput, out); diff != "" { if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch for key=%s (-expected +got):\n%s\nout=%#v", op.device.key, diff, out) t.Fatalf("hishtory export mismatch for key=%s (-expected +got):\n%s\nout=%#v", op.device.key, diff, out)

View File

@ -26,16 +26,16 @@ func TestSetup(t *testing.T) {
defer testutils.RunTestServer()() defer testutils.RunTestServer()()
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
testutils.Check(t, err) require.NoError(t, err)
if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err == nil { if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err == nil {
t.Fatalf("hishtory secret file already exists!") t.Fatalf("hishtory secret file already exists!")
} }
testutils.Check(t, Setup("", false)) require.NoError(t, Setup("", false))
if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err != nil { if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err != nil {
t.Fatalf("hishtory secret file does not exist after Setup()!") t.Fatalf("hishtory secret file does not exist after Setup()!")
} }
data, err := os.ReadFile(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)) data, err := os.ReadFile(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH))
testutils.Check(t, err) require.NoError(t, err)
if len(data) < 10 { if len(data) < 10 {
t.Fatalf("hishtory secret has unexpected length: %d", len(data)) t.Fatalf("hishtory secret has unexpected length: %d", len(data))
} }
@ -50,16 +50,16 @@ func TestSetupOffline(t *testing.T) {
defer testutils.RunTestServer()() defer testutils.RunTestServer()()
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
testutils.Check(t, err) require.NoError(t, err)
if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err == nil { if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err == nil {
t.Fatalf("hishtory secret file already exists!") t.Fatalf("hishtory secret file already exists!")
} }
testutils.Check(t, Setup("", true)) require.NoError(t, Setup("", true))
if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err != nil { if _, err := os.Stat(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)); err != nil {
t.Fatalf("hishtory secret file does not exist after Setup()!") t.Fatalf("hishtory secret file does not exist after Setup()!")
} }
data, err := os.ReadFile(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH)) data, err := os.ReadFile(path.Join(homedir, data.GetHishtoryPath(), data.CONFIG_PATH))
testutils.Check(t, err) require.NoError(t, err)
if len(data) < 10 { if len(data) < 10 {
t.Fatalf("hishtory secret has unexpected length: %d", len(data)) t.Fatalf("hishtory secret has unexpected length: %d", len(data))
} }
@ -70,14 +70,14 @@ func TestSetupOffline(t *testing.T) {
} }
func TestPersist(t *testing.T) { func TestPersist(t *testing.T) {
defer testutils.BackupAndRestore(t)() defer testutils.BackupAndRestore(t)()
testutils.Check(t, hctx.InitConfig()) require.NoError(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext()) db := hctx.GetDb(hctx.MakeContext())
entry := testutils.MakeFakeHistoryEntry("ls ~/") entry := testutils.MakeFakeHistoryEntry("ls ~/")
testutils.Check(t, db.Create(entry).Error) require.NoError(t, db.Create(entry).Error)
var historyEntries []*data.HistoryEntry var historyEntries []*data.HistoryEntry
result := db.Find(&historyEntries) result := db.Find(&historyEntries)
testutils.Check(t, result.Error) require.NoError(t, result.Error)
if len(historyEntries) != 1 { if len(historyEntries) != 1 {
t.Fatalf("DB has %d entries, expected 1!", len(historyEntries)) t.Fatalf("DB has %d entries, expected 1!", len(historyEntries))
} }
@ -89,19 +89,19 @@ func TestPersist(t *testing.T) {
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {
defer testutils.BackupAndRestore(t)() defer testutils.BackupAndRestore(t)()
testutils.Check(t, hctx.InitConfig()) require.NoError(t, hctx.InitConfig())
ctx := hctx.MakeContext() ctx := hctx.MakeContext()
db := hctx.GetDb(ctx) db := hctx.GetDb(ctx)
// Insert data // Insert data
entry1 := testutils.MakeFakeHistoryEntry("ls /foo") entry1 := testutils.MakeFakeHistoryEntry("ls /foo")
testutils.Check(t, db.Create(entry1).Error) require.NoError(t, db.Create(entry1).Error)
entry2 := testutils.MakeFakeHistoryEntry("ls /bar") entry2 := testutils.MakeFakeHistoryEntry("ls /bar")
testutils.Check(t, db.Create(entry2).Error) require.NoError(t, db.Create(entry2).Error)
// Search for data // Search for data
results, err := Search(ctx, db, "ls", 5) results, err := Search(ctx, db, "ls", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 2 { if len(results) != 2 {
t.Fatalf("Search() returned %d results, expected 2, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 2, results=%#v", len(results), results)
} }
@ -114,61 +114,61 @@ func TestSearch(t *testing.T) {
// Search but exclude bar // Search but exclude bar
results, err = Search(ctx, db, "ls -bar", 5) results, err = Search(ctx, db, "ls -bar", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 1 { if len(results) != 1 {
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
} }
// Search but exclude foo // Search but exclude foo
results, err = Search(ctx, db, "ls -foo", 5) results, err = Search(ctx, db, "ls -foo", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 1 { if len(results) != 1 {
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
} }
// Search but include / also // Search but include / also
results, err = Search(ctx, db, "ls /", 5) results, err = Search(ctx, db, "ls /", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 2 { if len(results) != 2 {
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
} }
// Search but exclude slash // Search but exclude slash
results, err = Search(ctx, db, "ls -/", 5) results, err = Search(ctx, db, "ls -/", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 0 { if len(results) != 0 {
t.Fatalf("Search() returned %d results, expected 0, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 0, results=%#v", len(results), results)
} }
// Tests for escaping // Tests for escaping
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("ls -baz")).Error) require.NoError(t, db.Create(testutils.MakeFakeHistoryEntry("ls -baz")).Error)
results, err = Search(ctx, db, "ls", 5) results, err = Search(ctx, db, "ls", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 3 { if len(results) != 3 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
} }
results, err = Search(ctx, db, "ls -baz", 5) results, err = Search(ctx, db, "ls -baz", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 2 { if len(results) != 2 {
t.Fatalf("Search() returned %d results, expected 2, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 2, results=%#v", len(results), results)
} }
results, err = Search(ctx, db, "ls \\-baz", 5) results, err = Search(ctx, db, "ls \\-baz", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 1 { if len(results) != 1 {
t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 1, results=%#v", len(results), results)
} }
// A malformed search query, but we should just ignore the dash since this is a common enough thing // A malformed search query, but we should just ignore the dash since this is a common enough thing
results, err = Search(ctx, db, "ls -", 5) results, err = Search(ctx, db, "ls -", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 3 { if len(results) != 3 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
} }
// A search for an entry containing a backslash // A search for an entry containing a backslash
testutils.Check(t, db.Create(testutils.MakeFakeHistoryEntry("echo '\\'")).Error) require.NoError(t, db.Create(testutils.MakeFakeHistoryEntry("echo '\\'")).Error)
results, err = Search(ctx, db, "\\\\", 5) results, err = Search(ctx, db, "\\\\", 5)
testutils.Check(t, err) require.NoError(t, err)
if len(results) != 1 { if len(results) != 1 {
t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results) t.Fatalf("Search() returned %d results, expected 3, results=%#v", len(results), results)
} }
@ -177,7 +177,7 @@ func TestSearch(t *testing.T) {
func TestAddToDbIfNew(t *testing.T) { func TestAddToDbIfNew(t *testing.T) {
// Set up // Set up
defer testutils.BackupAndRestore(t)() defer testutils.BackupAndRestore(t)()
testutils.Check(t, hctx.InitConfig()) require.NoError(t, hctx.InitConfig())
db := hctx.GetDb(hctx.MakeContext()) db := hctx.GetDb(hctx.MakeContext())
// Add duplicate entries // Add duplicate entries
@ -239,52 +239,52 @@ func TestZshWeirdness(t *testing.T) {
func TestParseTimeGenerously(t *testing.T) { func TestParseTimeGenerously(t *testing.T) {
ts, err := parseTimeGenerously("2006-01-02T15:04:00-08:00") ts, err := parseTimeGenerously("2006-01-02T15:04:00-08:00")
testutils.Check(t, err) require.NoError(t, err)
if ts.Unix() != 1136243040 { if ts.Unix() != 1136243040 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02 T15:04:00 -08:00") ts, err = parseTimeGenerously("2006-01-02 T15:04:00 -08:00")
testutils.Check(t, err) require.NoError(t, err)
if ts.Unix() != 1136243040 { if ts.Unix() != 1136243040 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02_T15:04:00_-08:00") ts, err = parseTimeGenerously("2006-01-02_T15:04:00_-08:00")
testutils.Check(t, err) require.NoError(t, err)
if ts.Unix() != 1136243040 { if ts.Unix() != 1136243040 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02T15:04:00") ts, err = parseTimeGenerously("2006-01-02T15:04:00")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 { if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02_T15:04:00") ts, err = parseTimeGenerously("2006-01-02_T15:04:00")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 { if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02_15:04:00") ts, err = parseTimeGenerously("2006-01-02_15:04:00")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 { if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02T15:04") ts, err = parseTimeGenerously("2006-01-02T15:04")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 { if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02_15:04") ts, err = parseTimeGenerously("2006-01-02_15:04")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 { if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 15 || ts.Minute() != 4 || ts.Second() != 0 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("2006-01-02") ts, err = parseTimeGenerously("2006-01-02")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 0 || ts.Minute() != 0 || ts.Second() != 0 { if ts.Year() != 2006 || ts.Month() != time.January || ts.Day() != 2 || ts.Hour() != 0 || ts.Minute() != 0 || ts.Second() != 0 {
t.Fatalf("parsed time incorrectly: %d", ts.Unix()) t.Fatalf("parsed time incorrectly: %d", ts.Unix())
} }
ts, err = parseTimeGenerously("1693163976") ts, err = parseTimeGenerously("1693163976")
testutils.Check(t, err) require.NoError(t, err)
if ts.Year() != 2023 || ts.Month() != time.August || ts.Day() != 27 || ts.Hour() != 12 || ts.Minute() != 19 || ts.Second() != 36 { if ts.Year() != 2023 || ts.Month() != time.August || ts.Day() != 27 || ts.Hour() != 12 || ts.Minute() != 19 || ts.Second() != 36 {
t.Fatalf("parsed time incorrectly: %d %s", ts.Unix(), ts.GoString()) t.Fatalf("parsed time incorrectly: %d %s", ts.Unix(), ts.GoString())
} }

View File

@ -198,15 +198,15 @@ func hishtoryQuery(t testing.TB, tester shellTester, query string) string {
func manuallySubmitHistoryEntry(t testing.TB, userSecret string, entry data.HistoryEntry) { func manuallySubmitHistoryEntry(t testing.TB, userSecret string, entry data.HistoryEntry) {
encEntry, err := data.EncryptHistoryEntry(userSecret, entry) encEntry, err := data.EncryptHistoryEntry(userSecret, entry)
testutils.Check(t, err) require.NoError(t, err)
if encEntry.Date != entry.EndTime { if encEntry.Date != entry.EndTime {
t.Fatalf("encEntry.Date does not match the entry") t.Fatalf("encEntry.Date does not match the entry")
} }
jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) jsonValue, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err) require.NoError(t, err)
require.NotEqual(t, "", entry.DeviceId) require.NotEqual(t, "", entry.DeviceId)
resp, err := http.Post("http://localhost:8080/api/v1/submit?source_device_id="+entry.DeviceId, "application/json", bytes.NewBuffer(jsonValue)) resp, err := http.Post("http://localhost:8080/api/v1/submit?source_device_id="+entry.DeviceId, "application/json", bytes.NewBuffer(jsonValue))
testutils.Check(t, err) require.NoError(t, err)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("failed to submit result to backend, status_code=%d", resp.StatusCode) t.Fatalf("failed to submit result to backend, status_code=%d", resp.StatusCode)
} }
@ -291,9 +291,9 @@ func captureTerminalOutputWithShellNameAndDimensions(t testing.TB, tester shellT
func assertNoLeakedConnections(t testing.TB) { func assertNoLeakedConnections(t testing.TB) {
resp, err := lib.ApiGet("/api/v1/get-num-connections") resp, err := lib.ApiGet("/api/v1/get-num-connections")
testutils.Check(t, err) require.NoError(t, err)
numConnections, err := strconv.Atoi(string(resp)) numConnections, err := strconv.Atoi(string(resp))
testutils.Check(t, err) require.NoError(t, err)
if numConnections > 1 { if numConnections > 1 {
t.Fatalf("DB has %d open connections, expected to have 1 or less", numConnections) t.Fatalf("DB has %d open connections, expected to have 1 or less", numConnections)
} }

View File

@ -18,6 +18,7 @@ import (
"github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/data"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -49,7 +50,7 @@ func getInitialWd() string {
func ResetLocalState(t *testing.T) { func ResetLocalState(t *testing.T) {
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
Check(t, err) require.NoError(t, err)
persistLog() persistLog()
_ = BackupAndRestoreWithId(t, "-reset-local-state") _ = BackupAndRestoreWithId(t, "-reset-local-state")
_ = os.RemoveAll(path.Join(homedir, data.GetHishtoryPath())) _ = os.RemoveAll(path.Join(homedir, data.GetHishtoryPath()))
@ -69,10 +70,10 @@ func getBackPath(file, id string) string {
func BackupAndRestoreWithId(t testing.TB, id string) func() { func BackupAndRestoreWithId(t testing.TB, id string) func() {
ResetFakeHistoryTimestamp() ResetFakeHistoryTimestamp()
homedir, err := os.UserHomeDir() homedir, err := os.UserHomeDir()
Check(t, err) require.NoError(t, err)
initialWd, err := os.Getwd() initialWd, err := os.Getwd()
Check(t, err) require.NoError(t, err)
Check(t, os.MkdirAll(path.Join(homedir, data.GetHishtoryPath()+".test"), os.ModePerm)) require.NoError(t, os.MkdirAll(path.Join(homedir, data.GetHishtoryPath()+".test"), os.ModePerm))
renameFiles := []string{ renameFiles := []string{
path.Join(homedir, data.GetHishtoryPath(), data.DB_PATH), path.Join(homedir, data.GetHishtoryPath(), data.DB_PATH),
@ -89,7 +90,7 @@ func BackupAndRestoreWithId(t testing.TB, id string) func() {
} }
for _, file := range renameFiles { for _, file := range renameFiles {
touchFile(file) touchFile(file)
Check(t, os.Rename(file, getBackPath(file, id))) require.NoError(t, os.Rename(file, getBackPath(file, id)))
} }
copyFiles := []string{ copyFiles := []string{
path.Join(homedir, ".zshrc"), path.Join(homedir, ".zshrc"),
@ -98,7 +99,7 @@ func BackupAndRestoreWithId(t testing.TB, id string) func() {
} }
for _, file := range copyFiles { for _, file := range copyFiles {
touchFile(file) touchFile(file)
Check(t, copy(file, getBackPath(file, id))) require.NoError(t, copy(file, getBackPath(file, id)))
} }
configureZshrc(homedir) configureZshrc(homedir)
touchFile(path.Join(homedir, ".bash_history")) touchFile(path.Join(homedir, ".bash_history"))
@ -111,8 +112,8 @@ func BackupAndRestoreWithId(t testing.TB, id string) func() {
t.Fatalf("failed to execute killall hishtory, stdout=%#v: %v", string(stdout), err) t.Fatalf("failed to execute killall hishtory, stdout=%#v: %v", string(stdout), err)
} }
persistLog() persistLog()
Check(t, os.RemoveAll(path.Join(homedir, data.GetHishtoryPath()))) require.NoError(t, os.RemoveAll(path.Join(homedir, data.GetHishtoryPath())))
Check(t, os.MkdirAll(path.Join(homedir, data.GetHishtoryPath()), os.ModePerm)) require.NoError(t, os.MkdirAll(path.Join(homedir, data.GetHishtoryPath()), os.ModePerm))
for _, file := range renameFiles { for _, file := range renameFiles {
checkError(os.Rename(getBackPath(file, id), file)) checkError(os.Rename(getBackPath(file, id), file))
} }
@ -290,13 +291,6 @@ func RunTestServer() func() {
} }
} }
func Check(t testing.TB, err error) {
if err != nil {
_, filename, line, _ := runtime.Caller(1)
t.Fatalf("Unexpected error at %s:%d: %v", filename, line, err)
}
}
func CheckWithInfo(t *testing.T, err error, additionalInfo string) { func CheckWithInfo(t *testing.T, err error, additionalInfo string) {
if err != nil { if err != nil {
_, filename, line, _ := runtime.Caller(1) _, filename, line, _ := runtime.Caller(1)
@ -338,12 +332,12 @@ func IsGithubAction() bool {
func TestLog(t testing.TB, line string) { func TestLog(t testing.TB, line string) {
f, err := os.OpenFile("/tmp/test.log", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) f, err := os.OpenFile("/tmp/test.log", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil { if err != nil {
Check(t, err) require.NoError(t, err)
} }
defer f.Close() defer f.Close()
_, err = f.WriteString(line + "\n") _, err = f.WriteString(line + "\n")
if err != nil { if err != nil {
Check(t, err) require.NoError(t, err)
} }
} }
@ -372,7 +366,7 @@ func CompareGoldens(t testing.TB, out, goldenName string) {
if os.IsNotExist(err) { if os.IsNotExist(err) {
expected = []byte("ERR_FILE_NOT_FOUND:" + goldenPath) expected = []byte("ERR_FILE_NOT_FOUND:" + goldenPath)
} else { } else {
Check(t, err) require.NoError(t, err)
} }
} }
if diff := cmp.Diff(string(expected), out); diff != "" { if diff := cmp.Diff(string(expected), out); diff != "" {
@ -380,7 +374,7 @@ func CompareGoldens(t testing.TB, out, goldenName string) {
_, filename, line, _ := runtime.Caller(1) _, filename, line, _ := runtime.Caller(1)
t.Fatalf("hishtory golden mismatch for %s at %s:%d (-expected +got):\n%s\nactual=\n%s", goldenName, filename, line, diff, out) t.Fatalf("hishtory golden mismatch for %s at %s:%d (-expected +got):\n%s\nactual=\n%s", goldenName, filename, line, diff, out)
} else { } else {
Check(t, os.WriteFile(goldenPath, []byte(out), 0644)) require.NoError(t, os.WriteFile(goldenPath, []byte(out), 0644))
} }
} }
} }