diff --git a/backend/server/internal/server/server_test.go b/backend/server/internal/server/server_test.go index 80005f2..0186e34 100644 --- a/backend/server/internal/server/server_test.go +++ b/backend/server/internal/server/server_test.go @@ -107,19 +107,11 @@ func TestESubmitThenQuery(t *testing.T) { respBody, err = io.ReadAll(res.Body) require.NoError(t, err) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 1 { - t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) - } + require.Len(t, retrievedEntries, 1) dbEntry := retrievedEntries[0] - if dbEntry.DeviceId != devId2 { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.UserId != data.UserId("key") { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.ReadCount != 0 { - t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) - } + require.Equal(t, dbEntry.DeviceId, devId2) + require.Equal(t, dbEntry.UserId, data.UserId("key")) + require.Equal(t, 0, dbEntry.ReadCount) decEntry, err := data.DecryptHistoryEntry("key", *dbEntry) require.NoError(t, err) require.Equal(t, decEntry, entry) @@ -133,9 +125,7 @@ func TestESubmitThenQuery(t *testing.T) { respBody, err = io.ReadAll(res.Body) require.NoError(t, err) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 2 { - t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries)) - } + require.Len(t, retrievedEntries, 2) // Assert that we aren't leaking connections assertNoLeakedConnections(t, DB) @@ -170,16 +160,10 @@ func TestDumpRequestAndResponse(t *testing.T) { require.NoError(t, err) var dumpRequests []*shared.DumpRequest require.NoError(t, json.Unmarshal(respBody, &dumpRequests)) - if len(dumpRequests) != 1 { - t.Fatalf("expected one pending dump request, got %#v", dumpRequests) - } + require.Len(t, dumpRequests, 1) dumpRequest := dumpRequests[0] - if dumpRequest.RequestingDeviceId != devId2 { - t.Fatalf("unexpected device ID") - } - if dumpRequest.UserId != userId { - t.Fatalf("unexpected user ID") - } + require.Equal(t, devId2, dumpRequest.RequestingDeviceId) + require.Equal(t, userId, dumpRequest.UserId) // And one for otherUser w = httptest.NewRecorder() @@ -190,16 +174,10 @@ func TestDumpRequestAndResponse(t *testing.T) { require.NoError(t, err) dumpRequests = make([]*shared.DumpRequest, 0) require.NoError(t, json.Unmarshal(respBody, &dumpRequests)) - if len(dumpRequests) != 1 { - t.Fatalf("expected one pending dump request, got %#v", dumpRequests) - } + require.Len(t, dumpRequest, 1) dumpRequest = dumpRequests[0] - if dumpRequest.RequestingDeviceId != otherDev2 { - t.Fatalf("unexpected device ID") - } - if dumpRequest.UserId != otherUser { - t.Fatalf("unexpected user ID") - } + require.Equal(t, otherDev2, dumpRequest.RequestingDeviceId) + require.Equal(t, otherUser, dumpRequest.UserId) // And none if we query for a user ID that doesn't exit w = httptest.NewRecorder() @@ -263,16 +241,10 @@ func TestDumpRequestAndResponse(t *testing.T) { require.NoError(t, err) dumpRequests = make([]*shared.DumpRequest, 0) require.NoError(t, json.Unmarshal(respBody, &dumpRequests)) - if len(dumpRequests) != 1 { - t.Fatalf("expected one pending dump request, got %#v", dumpRequests) - } + require.Len(t, dumpRequest, 1) dumpRequest = dumpRequests[0] - if dumpRequest.RequestingDeviceId != otherDev2 { - t.Fatalf("unexpected device ID") - } - if dumpRequest.UserId != otherUser { - t.Fatalf("unexpected user ID") - } + require.Equal(t, otherDev2, dumpRequest.RequestingDeviceId) + require.Equal(t, otherUser, dumpRequest.UserId) // And finally, query to ensure that the dumped entries are in the DB w = httptest.NewRecorder() @@ -284,19 +256,11 @@ func TestDumpRequestAndResponse(t *testing.T) { require.NoError(t, err) var retrievedEntries []*shared.EncHistoryEntry require.NoError(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 2 { - t.Fatalf("Expected to retrieve 2 entries, found %d", len(retrievedEntries)) - } + require.Len(t, retrievedEntries, 2) for _, dbEntry := range retrievedEntries { - if dbEntry.DeviceId != devId2 { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.UserId != userId { - t.Fatalf("Response contains an incorrect user ID: %#v", *dbEntry) - } - if dbEntry.ReadCount != 0 { - t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) - } + require.Equal(t, devId2, dbEntry.DeviceId) + require.Equal(t, userId, dbEntry.UserId) + require.Equal(t, 0, dbEntry.ReadCount) decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) require.NoError(t, err) require.True(t, assert.ObjectsAreEqual(decEntry, entry1Dec) || assert.ObjectsAreEqual(decEntry, entry2Dec)) @@ -376,19 +340,11 @@ func TestDeletionRequests(t *testing.T) { require.NoError(t, err) var retrievedEntries []*shared.EncHistoryEntry require.NoError(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 1 { - t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) - } + require.Len(t, retrievedEntries, 1) for _, dbEntry := range retrievedEntries { - if dbEntry.DeviceId != devId1 { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.UserId != data.UserId("dkey") { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.ReadCount != 0 { - t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) - } + require.Equal(t, devId1, dbEntry.DeviceId) + require.Equal(t, data.UserId("dkey"), dbEntry.UserId) + require.Equal(t, 0, dbEntry.ReadCount) decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) require.NoError(t, err) require.True(t, assert.ObjectsAreEqual(decEntry, entry1) || assert.ObjectsAreEqual(decEntry, entry2)) @@ -418,19 +374,11 @@ func TestDeletionRequests(t *testing.T) { respBody, err = io.ReadAll(res.Body) require.NoError(t, err) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 1 { - t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) - } + require.Len(t, retrievedEntries, 1) dbEntry := retrievedEntries[0] - if dbEntry.DeviceId != devId1 { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.UserId != data.UserId("dkey") { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.ReadCount != 1 { - t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) - } + require.Equal(t, devId1, dbEntry.DeviceId) + require.Equal(t, data.UserId("dkey"), dbEntry.UserId) + require.Equal(t, 1, dbEntry.ReadCount) decEntry, err := data.DecryptHistoryEntry("dkey", *dbEntry) require.NoError(t, err) require.Equal(t, decEntry, entry2) @@ -444,19 +392,11 @@ func TestDeletionRequests(t *testing.T) { respBody, err = io.ReadAll(res.Body) require.NoError(t, err) require.NoError(t, json.Unmarshal(respBody, &retrievedEntries)) - if len(retrievedEntries) != 1 { - t.Fatalf("Expected to retrieve 1 entry, found %d", len(retrievedEntries)) - } + require.Len(t, retrievedEntries, 1) dbEntry = retrievedEntries[0] - if dbEntry.DeviceId != otherDev1 { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.UserId != data.UserId("dOtherkey") { - t.Fatalf("Response contains an incorrect device ID: %#v", *dbEntry) - } - if dbEntry.ReadCount != 0 { - t.Fatalf("db.ReadCount should have been 1, was %v", dbEntry.ReadCount) - } + require.Equal(t, otherDev1, dbEntry.DeviceId) + require.Equal(t, data.UserId("dOtherkey"), dbEntry.UserId) + require.Equal(t, 0, dbEntry.ReadCount) decEntry, err = data.DecryptHistoryEntry("dOtherkey", *dbEntry) require.NoError(t, err) require.Equal(t, decEntry, entry3) @@ -482,9 +422,7 @@ func TestDeletionRequests(t *testing.T) { require.NoError(t, err) var deletionRequests []*shared.DeletionRequest require.NoError(t, json.Unmarshal(respBody, &deletionRequests)) - if len(deletionRequests) != 1 { - t.Fatalf("received %d deletion requests, expected only one", len(deletionRequests)) - } + require.Len(t, deletionRequests, 1) deletionRequest := deletionRequests[0] expected := shared.DeletionRequest{ UserId: data.UserId("dkey"), @@ -507,16 +445,12 @@ func TestHealthcheck(t *testing.T) { s := NewServer(DB, TrackUsageData(true)) w := httptest.NewRecorder() s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) - if w.Code != 200 { - t.Fatalf("expected 200 resp code for healthCheckHandler") - } + require.Equal(t, 200, w.Code) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) require.NoError(t, err) - if string(respBody) != "OK" { - t.Fatalf("expected healthcheckHandler to return OK") - } + require.Equal(t, "OK", string(respBody)) // Assert that we aren't leaking connections assertNoLeakedConnections(t, DB)