diff --git a/backend/server/server_test.go b/backend/server/server_test.go index c9a0128..8efe178 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -31,11 +31,11 @@ func TestESubmitThenQuery(t *testing.T) { otherUser := data.UserId("otherkey") otherDev := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Submit a few entries for different devices entry := testutils.MakeFakeHistoryEntry("ls ~/") @@ -44,12 +44,12 @@ func TestESubmitThenQuery(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -78,7 +78,7 @@ func TestESubmitThenQuery(t *testing.T) { // Same for device id 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -106,7 +106,7 @@ func TestESubmitThenQuery(t *testing.T) { // Bootstrap handler should return 2 entries, one for each device w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil) - apiBootstrapHandler(context.Background(), w, searchReq) + apiBootstrapHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -132,17 +132,17 @@ func TestDumpRequestAndResponse(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Query for dump requests, there should be one for userId w := httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -162,7 +162,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And one for otherUser w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -182,7 +182,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none if we query for a user ID that doesn't exit w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -193,7 +193,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none for a missing user ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -212,11 +212,11 @@ func TestDumpRequestAndResponse(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody)) - apiSubmitDumpHandler(context.Background(), nil, submitReq) + apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) // Check that the dump request is no longer there for userId for either device ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -226,7 +226,7 @@ func TestDumpRequestAndResponse(t *testing.T) { } w = httptest.NewRecorder() // The other user - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -237,7 +237,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // But it is there for the other user w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -258,7 +258,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And finally, query to ensure that the dumped entries are in the DB w = httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -333,13 +333,13 @@ func TestDeletionRequests(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Add an entry for user1 entry1 := testutils.MakeFakeHistoryEntry("ls ~/") @@ -349,7 +349,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // And another entry for user1 entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -359,7 +359,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // And an entry for user2 that has the same timestamp as the previous entry entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -370,12 +370,12 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -414,13 +414,13 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal(delReq) testutils.Check(t, err) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - addDeletionRequestHandler(context.Background(), nil, req) + addDeletionRequestHandler(httptest.NewRecorder(), req) // Query again for device id 1 and get a single result time.Sleep(10 * time.Millisecond) w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -448,7 +448,7 @@ func TestDeletionRequests(t *testing.T) { // Query for user 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -476,7 +476,7 @@ func TestDeletionRequests(t *testing.T) { // Query for deletion requests w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - getDeletionRequestsHandler(context.Background(), w, searchReq) + getDeletionRequestsHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -506,7 +506,7 @@ func TestDeletionRequests(t *testing.T) { func TestHealthcheck(t *testing.T) { w := httptest.NewRecorder() - healthCheckHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/", nil)) + healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { t.Fatalf("expected 200 resp code for healthCheckHandler") } @@ -532,16 +532,16 @@ func TestLimitRegistrations(t *testing.T) { // Register three devices across two users deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // And this next one should fail since it is a new user defer func() { _ = recover() }() deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) t.Errorf("expected panic") } @@ -553,7 +553,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { userId := data.UserId("dkey") devId1 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) entry1 := testutils.MakeFakeHistoryEntry("ls ~/") entry1.DeviceId = devId1 encEntry, err := data.EncryptHistoryEntry("dkey", entry1) @@ -561,7 +561,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics testutils.Check(t, cleanDatabase(context.TODO()))