mirror of
https://github.com/ddworken/hishtory.git
synced 2025-02-02 11:39:24 +01:00
use single context and always return a status to the client
api handlers do not need an extra context. http.Request already has a context that is being ignored, so we leverage it and stop creating a new one. make the endpoints return http.StatusNoContent instead of just closing the connection from the client.
This commit is contained in:
parent
efa9ddd6df
commit
2b1ba7e3ba
@ -88,7 +88,7 @@ func updateUsageData(ctx context.Context, r *http.Request, userId, deviceId stri
|
||||
}
|
||||
}
|
||||
|
||||
func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
query := `
|
||||
SELECT
|
||||
MIN(devices.registration_date) as registration_date,
|
||||
@ -104,7 +104,7 @@ func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
GROUP BY devices.user_id
|
||||
ORDER BY registration_date
|
||||
`
|
||||
rows, err := GLOBAL_DB.WithContext(ctx).Raw(query).Rows()
|
||||
rows, err := GLOBAL_DB.WithContext(r.Context()).Raw(query).Rows()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -131,7 +131,8 @@ func usageStatsHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
tbl.Print()
|
||||
}
|
||||
|
||||
func statsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func statsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
var numDevices int64 = 0
|
||||
checkGormResult(GLOBAL_DB.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices))
|
||||
type numEntriesProcessed struct {
|
||||
@ -153,15 +154,16 @@ func statsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Write([]byte(fmt.Sprintf("Num devices: %d\n", numDevices)))
|
||||
w.Write([]byte(fmt.Sprintf("Num history entries processed: %d\n", nep.Total)))
|
||||
w.Write([]byte(fmt.Sprintf("Num DB entries: %d\n", numDbEntries)))
|
||||
w.Write([]byte(fmt.Sprintf("Weekly active installs: %d\n", weeklyActiveInstalls)))
|
||||
w.Write([]byte(fmt.Sprintf("Weekly active queries: %d\n", weeklyQueryUsers)))
|
||||
w.Write([]byte(fmt.Sprintf("Last registration: %s\n", lastRegistration)))
|
||||
fmt.Fprintf(w, "Num devices: %d\n", numDevices)
|
||||
fmt.Fprintf(w, "Num history entries processed: %d\n", nep.Total)
|
||||
fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries)
|
||||
fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls)
|
||||
fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers)
|
||||
fmt.Fprintf(w, "Last registration: %s\n", lastRegistration)
|
||||
}
|
||||
|
||||
func apiSubmitHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -201,9 +203,12 @@ func apiSubmitHandler(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
if GLOBAL_STATSD != nil {
|
||||
GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func apiBootstrapHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
updateUsageData(ctx, r, userId, deviceId, 0, false)
|
||||
@ -218,7 +223,8 @@ func apiBootstrapHandler(ctx context.Context, w http.ResponseWriter, r *http.Req
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
func apiQueryHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
updateUsageData(ctx, r, userId, deviceId, 0, true)
|
||||
@ -276,7 +282,8 @@ func getRemoteAddr(r *http.Request) string {
|
||||
return addr[0]
|
||||
}
|
||||
|
||||
func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if getMaximumNumberOfAllowedUsers() < math.MaxInt {
|
||||
row := GLOBAL_DB.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row()
|
||||
var numDistinctUsers int64 = 0
|
||||
@ -302,14 +309,16 @@ func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ
|
||||
if GLOBAL_STATSD != nil {
|
||||
GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func apiGetPendingDumpRequestsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
var dumpRequests []*shared.DumpRequest
|
||||
// Filter out ones requested by the hishtory instance that sent this request
|
||||
checkGormResult(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
|
||||
respBody, err := json.Marshal(dumpRequests)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to JSON marshall the dump requests: %v", err))
|
||||
@ -317,7 +326,8 @@ func apiGetPendingDumpRequestsHandler(ctx context.Context, w http.ResponseWriter
|
||||
w.Write(respBody)
|
||||
}
|
||||
|
||||
func apiSubmitDumpHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
srcDeviceId := getRequiredQueryParam(r, "source_device_id")
|
||||
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
|
||||
@ -346,9 +356,11 @@ func apiSubmitDumpHandler(ctx context.Context, w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
checkGormResult(GLOBAL_DB.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
|
||||
updateUsageData(ctx, r, userId, srcDeviceId, len(entries), false)
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func apiBannerHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
commitHash := getRequiredQueryParam(r, "commit_hash")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
forcedBanner := r.URL.Query().Get("forced_banner")
|
||||
@ -360,7 +372,8 @@ func apiBannerHandler(ctx context.Context, w http.ResponseWriter, r *http.Reques
|
||||
w.Write([]byte(html.EscapeString(forcedBanner)))
|
||||
}
|
||||
|
||||
func getDeletionRequestsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
|
||||
@ -377,7 +390,8 @@ func getDeletionRequestsHandler(ctx context.Context, w http.ResponseWriter, r *h
|
||||
w.Write(respBody)
|
||||
}
|
||||
|
||||
func addDeletionRequestHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -409,9 +423,12 @@ func addDeletionRequestHandler(ctx context.Context, w http.ResponseWriter, r *ht
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func healthCheckHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if isProductionEnvironment() {
|
||||
// Check that we have a reasonable looking set of devices/entries in the DB
|
||||
rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
|
||||
@ -447,8 +464,7 @@ func healthCheckHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ
|
||||
panic(fmt.Sprintf("failed to ping DB: %v", err))
|
||||
}
|
||||
}
|
||||
ok := "OK"
|
||||
w.Write([]byte(ok))
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func applyDeletionRequestsToBackend(ctx context.Context, request shared.DeletionRequest) (int, error) {
|
||||
@ -461,22 +477,24 @@ func applyDeletionRequestsToBackend(ctx context.Context, request shared.Deletion
|
||||
return int(result.RowsAffected), nil
|
||||
}
|
||||
|
||||
func wipeDbEntriesHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
|
||||
panic("refusing to wipe the DB for prod")
|
||||
}
|
||||
if !isTestEnvironment() {
|
||||
panic("refusing to wipe the DB non-test environment")
|
||||
}
|
||||
checkGormResult(GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM enc_history_entries"))
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries"))
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func getNumConnectionsHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
sqlDb, err := GLOBAL_DB.DB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Write([]byte(fmt.Sprintf("%#v", sqlDb.Stats().OpenConnections)))
|
||||
fmt.Fprintf(w, "%#v", sqlDb.Stats().OpenConnections)
|
||||
}
|
||||
|
||||
func isTestEnvironment() bool {
|
||||
@ -588,11 +606,13 @@ func runBackgroundJobs(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func triggerCronHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
err := cron(ctx)
|
||||
func triggerCronHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err := cron(r.Context())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
@ -719,7 +739,7 @@ func buildUpdateInfo(version string) shared.UpdateInfo {
|
||||
}
|
||||
}
|
||||
|
||||
func apiDownloadHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func apiDownloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||
updateInfo := buildUpdateInfo(ReleaseVersion)
|
||||
resp, err := json.Marshal(updateInfo)
|
||||
if err != nil {
|
||||
@ -728,7 +748,7 @@ func apiDownloadHandler(ctx context.Context, w http.ResponseWriter, r *http.Requ
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
func slsaStatusHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// returns "OK" unless there is a current SLSA bug
|
||||
v := getHishtoryVersion(r)
|
||||
if !strings.Contains(v, "v0.") {
|
||||
@ -747,7 +767,7 @@ func slsaStatusHandler(ctx context.Context, w http.ResponseWriter, r *http.Reque
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func feedbackHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
func feedbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -758,11 +778,13 @@ func feedbackHandler(ctx context.Context, w http.ResponseWriter, r *http.Request
|
||||
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
|
||||
}
|
||||
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
|
||||
checkGormResult(GLOBAL_DB.WithContext(ctx).Create(feedback))
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(feedback))
|
||||
|
||||
if GLOBAL_STATSD != nil {
|
||||
GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type loggedResponseData struct {
|
||||
@ -789,7 +811,7 @@ func getFunctionName(temp interface{}) string {
|
||||
return strs[len(strs)-1]
|
||||
}
|
||||
|
||||
func withLogging(h func(context.Context, http.ResponseWriter, *http.Request)) http.Handler {
|
||||
func withLogging(h http.HandlerFunc) http.Handler {
|
||||
logFn := func(rw http.ResponseWriter, r *http.Request) {
|
||||
var responseData loggedResponseData
|
||||
lrw := loggingResponseWriter{
|
||||
@ -798,14 +820,14 @@ func withLogging(h func(context.Context, http.ResponseWriter, *http.Request)) ht
|
||||
}
|
||||
start := time.Now()
|
||||
span, ctx := tracer.StartSpanFromContext(
|
||||
context.Background(),
|
||||
r.Context(),
|
||||
getFunctionName(h),
|
||||
tracer.SpanType(ext.SpanTypeSQL),
|
||||
tracer.ServiceName("hishtory-api"),
|
||||
)
|
||||
defer span.Finish()
|
||||
|
||||
h(ctx, &lrw, r)
|
||||
h(&lrw, r.WithContext(ctx))
|
||||
|
||||
duration := time.Since(start)
|
||||
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
|
||||
@ -952,6 +974,7 @@ func main() {
|
||||
mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler))
|
||||
mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler))
|
||||
}
|
||||
|
||||
fmt.Println("Listening on localhost:8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", mux))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user