From 2b1ba7e3ba6cc7ce737504caccf3aab1a08c5e18 Mon Sep 17 00:00:00 2001
From: Sergio Moura <sergio@moura.ca>
Date: Wed, 6 Sep 2023 11:37:14 -0400
Subject: [PATCH] use single context and always return a status to the client
 api handlers do not need an extra context. http.Request already has a context
 that is being ignored, so we leverage it and stop creating a new one. make
 the endpoints return http.StatusNoContent instead of just closing the
 connection from the client.

---
 backend/server/server.go | 93 +++++++++++++++++++++++++---------------
 1 file changed, 58 insertions(+), 35 deletions(-)

diff --git a/backend/server/server.go b/backend/server/server.go
index 19050cb..91494b0 100644
--- a/backend/server/server.go
+++ b/backend/server/server.go
@@ -88,7 +88,7 @@ func updateUsageData(ctx context.Context, r *http.Request, userId, deviceId stri
 	}
 }
 
-func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
 	query := `
 	SELECT 
 		MIN(devices.registration_date) as registration_date, 
@@ -104,7 +104,7 @@ func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque
 	GROUP BY devices.user_id
 	ORDER BY registration_date
 	`
-	rows, err := GLOBAL_DB.WithContext(ctx).Raw(query).Rows()
+	rows, err := GLOBAL_DB.WithContext(r.Context()).Raw(query).Rows()
 	if err != nil {
 		panic(err)
 	}
@@ -131,7 +131,8 @@ func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque
 	tbl.Print()
 }
 
-func statsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func statsHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	var numDevices int64 = 0
 	checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices))
 	type numEntriesProcessed struct {
@@ -153,15 +154,16 @@ func statsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		panic(err)
 	}
-	w.Write([]byte(fmt.Sprintf("Num devices: %d\n", numDevices)))
-	w.Write([]byte(fmt.Sprintf("Num history entries processed: %d\n", nep.Total)))
-	w.Write([]byte(fmt.Sprintf("Num DB entries: %d\n", numDbEntries)))
-	w.Write([]byte(fmt.Sprintf("Weekly active installs: %d\n", weeklyActiveInstalls)))
-	w.Write([]byte(fmt.Sprintf("Weekly active queries: %d\n", weeklyQueryUsers)))
-	w.Write([]byte(fmt.Sprintf("Last registration: %s\n", lastRegistration)))
+	fmt.Fprintf(w, "Num devices: %d\n", numDevices)
+	fmt.Fprintf(w, "Num history entries processed: %d\n", nep.Total)
+	fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries)
+	fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls)
+	fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers)
+	fmt.Fprintf(w, "Last registration: %s\n", lastRegistration)
 }
 
-func apiSubmitHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	data, err := io.ReadAll(r.Body)
 	if err != nil {
 		panic(err)
@@ -201,9 +203,12 @@ func apiSubmitHandler(ctx context.Context, w http.ResponseWriter, r *http.Reques
 	if GLOBAL_STATSD != nil {
 		GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
 	}
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
-func apiBootstrapHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	userId := getRequiredQueryParam(r, "user_id")
 	deviceId := getRequiredQueryParam(r, "device_id")
 	updateUsageData(ctx, r, userId, deviceId, 0, false)
@@ -218,7 +223,8 @@ func apiBootstrapHandler(ctx context.Context, w http.ResponseWriter, r *http.Req
 	w.Write(resp)
 }
 
-func apiQueryHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	userId := getRequiredQueryParam(r, "user_id")
 	deviceId := getRequiredQueryParam(r, "device_id")
 	updateUsageData(ctx, r, userId, deviceId, 0, true)
@@ -276,7 +282,8 @@ func getRemoteAddr(r *http.Request) string {
 	return addr[0]
 }
 
-func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	if getMaximumNumberOfAllowedUsers() < math.MaxInt {
 		row := GLOBAL_DB.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row()
 		var numDistinctUsers int64 = 0
@@ -302,14 +309,16 @@ func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ
 	if GLOBAL_STATSD != nil {
 		GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0)
 	}
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
-func apiGetPendingDumpRequestsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
 	userId := getRequiredQueryParam(r, "user_id")
 	deviceId := getRequiredQueryParam(r, "device_id")
 	var dumpRequests []*shared.DumpRequest
 	// Filter out ones requested by the hishtory instance that sent this request
-	checkGormResult(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
+	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
 	respBody, err := json.Marshal(dumpRequests)
 	if err != nil {
 		panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
@@ -317,7 +326,8 @@ func apiGetPendingDumpRequestsHandler(ctx context.Context, w http.ResponseWriter
 	w.Write(respBody)
 }
 
-func apiSubmitDumpHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	userId := getRequiredQueryParam(r, "user_id")
 	srcDeviceId := getRequiredQueryParam(r, "source_device_id")
 	requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
@@ -346,9 +356,11 @@ func apiSubmitDumpHandler(ctx context.Context, w http.ResponseWriter, r *http.Re
 	}
 	checkGormResult(GLOBAL_DB.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
 	updateUsageData(ctx, r, userId, srcDeviceId, len(entries), false)
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
-func apiBannerHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
 	commitHash := getRequiredQueryParam(r, "commit_hash")
 	deviceId := getRequiredQueryParam(r, "device_id")
 	forcedBanner := r.URL.Query().Get("forced_banner")
@@ -360,7 +372,8 @@ func apiBannerHandler(ctx context.Context, w http.ResponseWriter, r *http.Reques
 	w.Write([]byte(html.EscapeString(forcedBanner)))
 }
 
-func getDeletionRequestsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	userId := getRequiredQueryParam(r, "user_id")
 	deviceId := getRequiredQueryParam(r, "device_id")
 
@@ -377,7 +390,8 @@ func getDeletionRequestsHandler(ctx context.Context, w http.ResponseWriter, r *h
 	w.Write(respBody)
 }
 
-func addDeletionRequestHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	data, err := io.ReadAll(r.Body)
 	if err != nil {
 		panic(err)
@@ -409,9 +423,12 @@ func addDeletionRequestHandler(ctx context.Context, w http.ResponseWriter, r *ht
 		panic(err)
 	}
 	fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
-func healthCheckHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
 	if isProductionEnvironment() {
 		// Check that we have a reasonable looking set of devices/entries in the DB
 		rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
@@ -447,8 +464,7 @@ func healthCheckHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ
 			panic(fmt.Sprintf("failed to ping DB: %v", err))
 		}
 	}
-	ok := "OK"
-	w.Write([]byte(ok))
+	w.Write([]byte("OK"))
 }
 
 func applyDeletionRequestsToBackend(ctx context.Context, request shared.DeletionRequest) (int, error) {
@@ -461,22 +477,24 @@ func applyDeletionRequestsToBackend(ctx context.Context, request shared.Deletion
 	return int(result.RowsAffected), nil
 }
 
-func wipeDbEntriesHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
 	if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
 		panic("refusing to wipe the DB for prod")
 	}
 	if !isTestEnvironment() {
 		panic("refusing to wipe the DB non-test environment")
 	}
-	checkGormResult(GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM enc_history_entries"))
+	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries"))
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
-func getNumConnectionsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
 	sqlDb, err := GLOBAL_DB.DB()
 	if err != nil {
 		panic(err)
 	}
-	w.Write([]byte(fmt.Sprintf("%#v", sqlDb.Stats().OpenConnections)))
+	fmt.Fprintf(w, "%#v", sqlDb.Stats().OpenConnections)
 }
 
 func isTestEnvironment() bool {
@@ -588,11 +606,13 @@ func runBackgroundJobs(ctx context.Context) {
 	}
 }
 
-func triggerCronHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
-	err := cron(ctx)
+func triggerCronHandler(w http.ResponseWriter, r *http.Request) {
+	err := cron(r.Context())
 	if err != nil {
 		panic(err)
 	}
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
 type releaseInfo struct {
@@ -719,7 +739,7 @@ func buildUpdateInfo(version string) shared.UpdateInfo {
 	}
 }
 
-func apiDownloadHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func apiDownloadHandler(w http.ResponseWriter, r *http.Request) {
 	updateInfo := buildUpdateInfo(ReleaseVersion)
 	resp, err := json.Marshal(updateInfo)
 	if err != nil {
@@ -728,7 +748,7 @@ func apiDownloadHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ
 	w.Write(resp)
 }
 
-func slsaStatusHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
 	// returns "OK" unless there is a current SLSA bug
 	v := getHishtoryVersion(r)
 	if !strings.Contains(v, "v0.") {
@@ -747,7 +767,7 @@ func slsaStatusHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque
 	w.Write([]byte("OK"))
 }
 
-func feedbackHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+func feedbackHandler(w http.ResponseWriter, r *http.Request) {
 	data, err := io.ReadAll(r.Body)
 	if err != nil {
 		panic(err)
@@ -758,11 +778,13 @@ func feedbackHandler(ctx context.Context, w http.ResponseWriter, r *http.Request
 		panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
 	}
 	fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
-	checkGormResult(GLOBAL_DB.WithContext(ctx).Create(feedback))
+	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(feedback))
 
 	if GLOBAL_STATSD != nil {
 		GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0)
 	}
+
+	w.WriteHeader(http.StatusNoContent)
 }
 
 type loggedResponseData struct {
@@ -789,7 +811,7 @@ func getFunctionName(temp interface{}) string {
 	return strs[len(strs)-1]
 }
 
-func withLogging(h func(context.Context, http.ResponseWriter, *http.Request)) http.Handler {
+func withLogging(h http.HandlerFunc) http.Handler {
 	logFn := func(rw http.ResponseWriter, r *http.Request) {
 		var responseData loggedResponseData
 		lrw := loggingResponseWriter{
@@ -798,14 +820,14 @@ func withLogging(h func(context.Context, http.ResponseWriter, *http.Request)) ht
 		}
 		start := time.Now()
 		span, ctx := tracer.StartSpanFromContext(
-			context.Background(),
+			r.Context(),
 			getFunctionName(h),
 			tracer.SpanType(ext.SpanTypeSQL),
 			tracer.ServiceName("hishtory-api"),
 		)
 		defer span.Finish()
 
-		h(ctx, &lrw, r)
+		h(&lrw, r.WithContext(ctx))
 
 		duration := time.Since(start)
 		fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
@@ -952,6 +974,7 @@ func main() {
 		mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler))
 		mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler))
 	}
+
 	fmt.Println("Listening on localhost:8080")
 	log.Fatal(http.ListenAndServe(":8080", mux))
 }