From 2b1ba7e3ba6cc7ce737504caccf3aab1a08c5e18 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Wed, 6 Sep 2023 11:37:14 -0400 Subject: [PATCH 1/3] use single context and always return a status to the client api handlers do not need an extra context. http.Request already has a context that is being ignored, so we leverage it and stop creating a new one. make the endpoints return http.StatusNoContent instead of just closing the connection from the client. --- backend/server/server.go | 93 +++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 35 deletions(-) diff --git a/backend/server/server.go b/backend/server/server.go index 19050cb..91494b0 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -88,7 +88,7 @@ func updateUsageData(ctx context.Context, r *http.Request, userId, deviceId stri } } -func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func usageStatsHandler(w http.ResponseWriter, r *http.Request) { query := ` SELECT MIN(devices.registration_date) as registration_date, @@ -104,7 +104,7 @@ func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque GROUP BY devices.user_id ORDER BY registration_date ` - rows, err := GLOBAL_DB.WithContext(ctx).Raw(query).Rows() + rows, err := GLOBAL_DB.WithContext(r.Context()).Raw(query).Rows() if err != nil { panic(err) } @@ -131,7 +131,8 @@ func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque tbl.Print() } -func statsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func statsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() var numDevices int64 = 0 checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices)) type numEntriesProcessed struct { @@ -153,15 +154,16 @@ func statsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { if err != nil { panic(err) } - w.Write([]byte(fmt.Sprintf("Num devices: %d\n", numDevices))) - w.Write([]byte(fmt.Sprintf("Num history entries processed: %d\n", nep.Total))) - w.Write([]byte(fmt.Sprintf("Num DB entries: %d\n", numDbEntries))) - w.Write([]byte(fmt.Sprintf("Weekly active installs: %d\n", weeklyActiveInstalls))) - w.Write([]byte(fmt.Sprintf("Weekly active queries: %d\n", weeklyQueryUsers))) - w.Write([]byte(fmt.Sprintf("Last registration: %s\n", lastRegistration))) + fmt.Fprintf(w, "Num devices: %d\n", numDevices) + fmt.Fprintf(w, "Num history entries processed: %d\n", nep.Total) + fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) + fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) + fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) + fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) } -func apiSubmitHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() data, err := io.ReadAll(r.Body) if err != nil { panic(err) @@ -201,9 +203,12 @@ func apiSubmitHandler(ctx context.Context, w http.ResponseWriter, r *http.Reques if GLOBAL_STATSD != nil { GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) } + + w.WriteHeader(http.StatusNoContent) } -func apiBootstrapHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") updateUsageData(ctx, r, userId, deviceId, 0, false) @@ -218,7 +223,8 @@ func apiBootstrapHandler(ctx context.Context, w http.ResponseWriter, r *http.Req w.Write(resp) } -func apiQueryHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiQueryHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") updateUsageData(ctx, r, userId, deviceId, 0, true) @@ -276,7 +282,8 @@ func getRemoteAddr(r *http.Request) string { return addr[0] } -func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() if getMaximumNumberOfAllowedUsers() < math.MaxInt { row := GLOBAL_DB.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row() var numDistinctUsers int64 = 0 @@ -302,14 +309,16 @@ func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ if GLOBAL_STATSD != nil { GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0) } + + w.WriteHeader(http.StatusNoContent) } -func apiGetPendingDumpRequestsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") var dumpRequests []*shared.DumpRequest // Filter out ones requested by the hishtory instance that sent this request - checkGormResult(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)) + checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)) respBody, err := json.Marshal(dumpRequests) if err != nil { panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err)) @@ -317,7 +326,8 @@ func apiGetPendingDumpRequestsHandler(ctx context.Context, w http.ResponseWriter w.Write(respBody) } -func apiSubmitDumpHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() userId := getRequiredQueryParam(r, "user_id") srcDeviceId := getRequiredQueryParam(r, "source_device_id") requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") @@ -346,9 +356,11 @@ func apiSubmitDumpHandler(ctx context.Context, w http.ResponseWriter, r *http.Re } checkGormResult(GLOBAL_DB.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId)) updateUsageData(ctx, r, userId, srcDeviceId, len(entries), false) + + w.WriteHeader(http.StatusNoContent) } -func apiBannerHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiBannerHandler(w http.ResponseWriter, r *http.Request) { commitHash := getRequiredQueryParam(r, "commit_hash") deviceId := getRequiredQueryParam(r, "device_id") forcedBanner := r.URL.Query().Get("forced_banner") @@ -360,7 +372,8 @@ func apiBannerHandler(ctx context.Context, w http.ResponseWriter, r *http.Reques w.Write([]byte(html.EscapeString(forcedBanner))) } -func getDeletionRequestsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") @@ -377,7 +390,8 @@ func getDeletionRequestsHandler(ctx context.Context, w http.ResponseWriter, r *h w.Write(respBody) } -func addDeletionRequestHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() data, err := io.ReadAll(r.Body) if err != nil { panic(err) @@ -409,9 +423,12 @@ func addDeletionRequestHandler(ctx context.Context, w http.ResponseWriter, r *ht panic(err) } fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted) + + w.WriteHeader(http.StatusNoContent) } -func healthCheckHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func healthCheckHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() if isProductionEnvironment() { // Check that we have a reasonable looking set of devices/entries in the DB rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() @@ -447,8 +464,7 @@ func healthCheckHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ panic(fmt.Sprintf("failed to ping DB: %v", err)) } } - ok := "OK" - w.Write([]byte(ok)) + w.Write([]byte("OK")) } func applyDeletionRequestsToBackend(ctx context.Context, request shared.DeletionRequest) (int, error) { @@ -461,22 +477,24 @@ func applyDeletionRequestsToBackend(ctx context.Context, request shared.Deletion return int(result.RowsAffected), nil } -func wipeDbEntriesHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { if r.Host == "api.hishtory.dev" || isProductionEnvironment() { panic("refusing to wipe the DB for prod") } if !isTestEnvironment() { panic("refusing to wipe the DB non-test environment") } - checkGormResult(GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM enc_history_entries")) + checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries")) + + w.WriteHeader(http.StatusNoContent) } -func getNumConnectionsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { sqlDb, err := GLOBAL_DB.DB() if err != nil { panic(err) } - w.Write([]byte(fmt.Sprintf("%#v", sqlDb.Stats().OpenConnections))) + fmt.Fprintf(w, "%#v", sqlDb.Stats().OpenConnections) } func isTestEnvironment() bool { @@ -588,11 +606,13 @@ func runBackgroundJobs(ctx context.Context) { } } -func triggerCronHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { - err := cron(ctx) +func triggerCronHandler(w http.ResponseWriter, r *http.Request) { + err := cron(r.Context()) if err != nil { panic(err) } + + w.WriteHeader(http.StatusNoContent) } type releaseInfo struct { @@ -719,7 +739,7 @@ func buildUpdateInfo(version string) shared.UpdateInfo { } } -func apiDownloadHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func apiDownloadHandler(w http.ResponseWriter, r *http.Request) { updateInfo := buildUpdateInfo(ReleaseVersion) resp, err := json.Marshal(updateInfo) if err != nil { @@ -728,7 +748,7 @@ func apiDownloadHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ w.Write(resp) } -func slsaStatusHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func slsaStatusHandler(w http.ResponseWriter, r *http.Request) { // returns "OK" unless there is a current SLSA bug v := getHishtoryVersion(r) if !strings.Contains(v, "v0.") { @@ -747,7 +767,7 @@ func slsaStatusHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque w.Write([]byte("OK")) } -func feedbackHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { +func feedbackHandler(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { panic(err) @@ -758,11 +778,13 @@ func feedbackHandler(ctx context.Context, w http.ResponseWriter, r *http.Request panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) } fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) - checkGormResult(GLOBAL_DB.WithContext(ctx).Create(feedback)) + checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(feedback)) if GLOBAL_STATSD != nil { GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0) } + + w.WriteHeader(http.StatusNoContent) } type loggedResponseData struct { @@ -789,7 +811,7 @@ func getFunctionName(temp interface{}) string { return strs[len(strs)-1] } -func withLogging(h func(context.Context, http.ResponseWriter, *http.Request)) http.Handler { +func withLogging(h http.HandlerFunc) http.Handler { logFn := func(rw http.ResponseWriter, r *http.Request) { var responseData loggedResponseData lrw := loggingResponseWriter{ @@ -798,14 +820,14 @@ func withLogging(h func(context.Context, http.ResponseWriter, *http.Request)) ht } start := time.Now() span, ctx := tracer.StartSpanFromContext( - context.Background(), + r.Context(), getFunctionName(h), tracer.SpanType(ext.SpanTypeSQL), tracer.ServiceName("hishtory-api"), ) defer span.Finish() - h(ctx, &lrw, r) + h(&lrw, r.WithContext(ctx)) duration := time.Since(start) fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) @@ -952,6 +974,7 @@ func main() { mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler)) mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler)) } + fmt.Println("Listening on localhost:8080") log.Fatal(http.ListenAndServe(":8080", mux)) } From 589b99e500944ad3bb795bc7b3042755b46ee14e Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Thu, 7 Sep 2023 08:34:21 -0400 Subject: [PATCH 2/3] do not use http.StatusNoContent --- backend/server/server.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/backend/server/server.go b/backend/server/server.go index 91494b0..1b0b462 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -204,7 +204,8 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) } - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { @@ -310,7 +311,8 @@ func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0) } - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { @@ -357,7 +359,8 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { checkGormResult(GLOBAL_DB.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId)) updateUsageData(ctx, r, userId, srcDeviceId, len(entries), false) - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } func apiBannerHandler(w http.ResponseWriter, r *http.Request) { @@ -424,7 +427,8 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { } fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted) - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } func healthCheckHandler(w http.ResponseWriter, r *http.Request) { @@ -486,7 +490,8 @@ func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { } checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries")) - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { @@ -612,7 +617,8 @@ func triggerCronHandler(w http.ResponseWriter, r *http.Request) { panic(err) } - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } type releaseInfo struct { @@ -784,7 +790,8 @@ func feedbackHandler(w http.ResponseWriter, r *http.Request) { GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0) } - w.WriteHeader(http.StatusNoContent) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) } type loggedResponseData struct { From e6d922709d4b215f8aa5c2f82f512ab15bc33d76 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Thu, 7 Sep 2023 09:43:04 -0400 Subject: [PATCH 3/3] fix tests --- backend/server/server_test.go | 78 +++++++++++++++++------------------ 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/backend/server/server_test.go b/backend/server/server_test.go index c9a0128..8efe178 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -31,11 +31,11 @@ func TestESubmitThenQuery(t *testing.T) { otherUser := data.UserId("otherkey") otherDev := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Submit a few entries for different devices entry := testutils.MakeFakeHistoryEntry("ls ~/") @@ -44,12 +44,12 @@ func TestESubmitThenQuery(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -78,7 +78,7 @@ func TestESubmitThenQuery(t *testing.T) { // Same for device id 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -106,7 +106,7 @@ func TestESubmitThenQuery(t *testing.T) { // Bootstrap handler should return 2 entries, one for each device w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil) - apiBootstrapHandler(context.Background(), w, searchReq) + apiBootstrapHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -132,17 +132,17 @@ func TestDumpRequestAndResponse(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Query for dump requests, there should be one for userId w := httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -162,7 +162,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And one for otherUser w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -182,7 +182,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none if we query for a user ID that doesn't exit w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -193,7 +193,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none for a missing user ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -212,11 +212,11 @@ func TestDumpRequestAndResponse(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody)) - apiSubmitDumpHandler(context.Background(), nil, submitReq) + apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) // Check that the dump request is no longer there for userId for either device ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -226,7 +226,7 @@ func TestDumpRequestAndResponse(t *testing.T) { } w = httptest.NewRecorder() // The other user - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -237,7 +237,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // But it is there for the other user w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -258,7 +258,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And finally, query to ensure that the dumped entries are in the DB w = httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -333,13 +333,13 @@ func TestDeletionRequests(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Add an entry for user1 entry1 := testutils.MakeFakeHistoryEntry("ls ~/") @@ -349,7 +349,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // And another entry for user1 entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -359,7 +359,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // And an entry for user2 that has the same timestamp as the previous entry entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -370,12 +370,12 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -414,13 +414,13 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal(delReq) testutils.Check(t, err) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - addDeletionRequestHandler(context.Background(), nil, req) + addDeletionRequestHandler(httptest.NewRecorder(), req) // Query again for device id 1 and get a single result time.Sleep(10 * time.Millisecond) w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -448,7 +448,7 @@ func TestDeletionRequests(t *testing.T) { // Query for user 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiQueryHandler(context.Background(), w, searchReq) + apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -476,7 +476,7 @@ func TestDeletionRequests(t *testing.T) { // Query for deletion requests w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - getDeletionRequestsHandler(context.Background(), w, searchReq) + getDeletionRequestsHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -506,7 +506,7 @@ func TestDeletionRequests(t *testing.T) { func TestHealthcheck(t *testing.T) { w := httptest.NewRecorder() - healthCheckHandler(context.Background(), w, httptest.NewRequest(http.MethodGet, "/", nil)) + healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { t.Fatalf("expected 200 resp code for healthCheckHandler") } @@ -532,16 +532,16 @@ func TestLimitRegistrations(t *testing.T) { // Register three devices across two users deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) // And this next one should fail since it is a new user defer func() { _ = recover() }() deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) t.Errorf("expected panic") } @@ -553,7 +553,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { userId := data.UserId("dkey") devId1 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(context.Background(), nil, deviceReq) + apiRegisterHandler(httptest.NewRecorder(), deviceReq) entry1 := testutils.MakeFakeHistoryEntry("ls ~/") entry1.DeviceId = devId1 encEntry, err := data.EncryptHistoryEntry("dkey", entry1) @@ -561,7 +561,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(context.Background(), nil, submitReq) + apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics testutils.Check(t, cleanDatabase(context.TODO()))